总结Python的并发编程~

以前写得并发编程常用于网络资源获取、数据库查询上,例如网络爬虫、并发同步MongoDB和MySQL数据。并发在I/O上有极大的好处。通常情况下,网络请求、磁盘读写等I/O的时间周期是CPU时钟周期的一百万倍,程序在执行时大部分时间都浪费在I/O等待上。CPU、内存、网络的时间对比如图:

由于项目原因,觉得有必要总结下。于是,就在博客上写下线程并发编程,未来可能会总结协程并发(好吧,我挖的坑)

背景

本文的重点不是Python线程库threading的使用(具体的使用可以看官方文档),而是关注线程的原理、使用技巧和与线程并发编程相关的各类话题。

线程的相关概念

本节内容:

  1. 线程id、线程的常量、竞争条件、临界区、线程安全和非线程安全、死锁
  2. 内核级线程和用户级线程
  3. 线程的内存模型
  4. 线程的底层实现
  5. 线程的有限状态机
  6. 不同语言的线程实现的差异
  7. Future、Callable

Python、Golang、Java、C

线程的创建和使用

Python线程的创建有三种方法:(1)通过继承threading.Thread类,重写run方法(2)实例化threading.Thread(3)使用底层库_thread。下面具体说明。

  • 通过继承的方法

假定我们要实现一个线程,能定时打印出当前时间。我们只需要两点:(1)把类初始化的参数在init方法中传入(2)重写run方法—线程的运行逻辑

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

import threading
import datetime
import time

class Timer(threading.Thread):

def __init__(self, interval): # 该类要初始化的参数
self.interval = interval
super().__init__() # 初始化父类

def run(self): # 线程的运行逻辑在这里
while True:
print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), flush=True) # flush是指定清空缓冲,让字符直接在终端显示
time.sleep(self.interval) # 每次循环指定休眠时间

if __name__ == '__main__':
timer = Timer(1)
timer.start() # 调用start方法触发上面重写的run方法的逻辑

输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
2017-08-19 23:25:35
2017-08-19 23:25:36
2017-08-19 23:25:37
2017-08-19 23:25:38
2017-08-19 23:25:39
2017-08-19 23:25:40
2017-08-19 23:25:41
2017-08-19 23:25:42
2017-08-19 23:25:43
2017-08-19 23:25:44
2017-08-19 23:25:45
2017-08-19 23:25:46
2017-08-19 23:25:47
2017-08-19 23:25:48

上面实现的类的一个明显缺点是无法让启动了的线程停下来,文章后面会讲到如何处理这个问题的。

  • 通过实例化threading.Thread来创建线程

功能依旧是上面的一个心跳时钟

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

import time
import datetime
import threading

def timer(interval):
while True:
print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), flush=True)
time.sleep(interval)

if __name__ == '__main__':
interval = 1
task = threading.Thread(target=timer, args=(interval,))
task.start()

target是指定要在线程上执行的代码逻辑。args是传入target所指定的函数的参数即timer函数args要以元组的方式表示,即便是只有一个参数的情况。

有时候我们并不想通过args传入参数,我们还可以通过偏函数的方法提前给函数指定参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import time
import datetime
import threading
import functools

def timer(interval):
while True:
print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), flush=True)
time.sleep(interval)

if __name__ == '__main__':
timer_with_args = functools.partial(timer, interval=1) # 通过偏函数指定timer的参数
task = threading.Thread(target=timer_with_args) # 参数已经传入了,不同通过args传入
task.start()

通过threading库的源码我们可以看看Thread还有哪些参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

class Thread:
def __init__(self, group=None, target=None, name=None,
args=(), kwargs=None, *, daemon=None):
assert group is None, "group argument must be None for now"
if kwargs is None:
kwargs = {}
self._target = target
self._name = str(name or _newname())
self._args = args
self._kwargs = kwargs
if daemon is not None:
self._daemonic = daemon
else:
self._daemonic = current_thread().daemon
self._ident = None
self._tstate_lock = None
self._started = Event()
self._is_stopped = False
self._initialized = True
# sys.stderr is not stored in the class like
# sys.exc_info since it can be changed between instances
self._stderr = _sys.stderr
# For debugging and _after_fork()
_dangling.add(self)

从源码可以知道,我们还可以通过关键字参数给我们的target函数传入参数。另外还可指定线程是否为守护线程。守护线程可以到达的效果是:当有多个线程在并发执行时,所有的非守护线程都执行完毕退出了,那么解析器不管有多小守护线程,不管它们在执行什么,都直接退出。这会即将讲到。

以上就是实现线程的两种方法:实例化Thread类、继承Thread重写run方法。那么问题来了,到底什么时候使用前者,什么时候使用后者。一般惯例如下:

当我们要实现的线程的执行流程很复杂,需要分解为小的函数,我们就可以:采用继承Thread重写run方法。这样我们把复杂的功能分解为该类的方法,有run方法调用这些方法从而构成复杂的多线程调用链。模板如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

class Task(threading.Thread):

def __init__(self, *args, **kwargs):
# deal with args kwargs
super().__init__()

def _task1(self):
# task2 function

def _task2(self):
# task2 function

...

def run(self):
self._task1()
self._task2()

而如果多线程执行的函数相对简单,一个函数就可以实现,通过实例化Thread的方法就很方便了。例如上面的timer例子。

标准库threading中有一个Timer类,其实现的功能和心跳时钟相似。Timer类会在启动后的指定时间到达的一刻运行传入的func函数。通常可以把这个类用于一次性的定时任务,例如定时清理日志。例子:

1
2
3
4
5
6
7
8
9
10
11
12
import threading

def task(*args, **kwargs):
print("args params", args)
print("kwargs params", kwargs)
print("I quit now!")

if __name__ == '__main__':
interval = 1
timer = threading.Timer(interval, task, args=(1, 2, 3))
timer.start()

另外,还有一种不常用的方法。使用_thread(Python3)中的start_new_thread函数。但这个函数来自底层库_thread,并不常用。在实际开发中应该优先使用threading库中的函数。

1
2
3
4
5
6
7
8
9
10
11

import _thread

def timer(interval):
while True:
print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), flush=True)
time.sleep(interval)

if __name__ == '__main__':
thread_id = _thread.start_new_thread(timer, (1,))

如果我们需要更细粒度控制线程以及并发原语,我们可以使用_thread这个库。

结合Python的装饰器语法糖,我们甚至可以创建一线程装饰器,被装饰的函数或类在新的线程中运行。这一实现在技巧部分出现。

守护线程

关于守护进程参考历史文章:Linux下创建守护进程

守护线程和守护进程的概念类似。后者是进程在后台(没有终端)运行,前者是协助其他线程运行。当进程中的所有非守护线程已经运行完成(回想线程状态图),进程就直接退出,而不等待守护线程。使用守护线程的情景是:业务需要多个线程执行,其中一部分线程并不参与实际业务流程只是协助执行业务流程的线程,当执行业务流程的线程退出了,协助业务流程的线程就没有运行的意义了,它不应该让进程(主线程)等待它们退出。那么可以把这类协助目的的线程设置为守护线程。例如:先前的项目文件系统搜索引擎,上面运行着两类线程,一类是web服务线程,一类是文件系统扫描线程。

线程本地

本节包括两部分:

  1. threading.local
  2. Flask

threading.local为每个线程创建一个对其他线程不可见的对象,用以存储当前线程的数据。threading.local本质上是字典数据结构,字典的键为线程唯一的id,可以通过`threading.current_thread().ident获取,字段的值为线程id对应的线程存储的对象空间。

下面举一个例子,两个线程分别保存以自己命名的数据,然后试图读取自己和另外一线程的数据,以此考察threading.local的效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import threading
import time
import random

def task(_id, local):
# tid = threading.current_thread().ident // 获取当前线程id
thread_id = "thread_{}".format(_id)
setattr(local, thread_id, _id)

time.sleep(random.random())

try:
value = local.thread_1
print(thread_id, "got thread_id", value)
except AttributeError as err:
print(thread_id, "can't see local.thread_1", flush=True)

try:
value = local.thread_2
print(thread_id, "got thread_id", value)
except AttributeError as err:
print(thread_id, "can't see local.thread_2", flush=True)

def main():

local = threading.local()
thread_1 = threading.Thread(target=task, args=(1, local))
thread_2 = threading.Thread(target=task, args=(2, local))

thread_1.start()
thread_2.start()

thread_1.join()
thread_2.join()

if __name__ == '__main__':
main()

运行程序输出:

1
2
3
4
thread_2 can't see local.thread_1
thread_2 got thread_id 2
thread_1 got thread_id 1
thread_1 can't see local.thread_2

可以验证,线程无法看到其他线程在threading.local上保存的变量,起到隔离作用。

threading.local本身是一个类,该类源自_thread._local。可以通过继承threading.local来实现更丰富的控制或接口。下面例子通过继承threading.local为其添加部分dict接口。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import threading

def _is_dunder(name):
"""Returns True if a __dunder__ name, False otherwise."""
return (name[:2] == name[-2:] == '__' and
name[2:3] != '_' and
name[-3:-2] != '_' and
len(name) > 4)

def _is_sunder(name):
"""Returns True if a _sunder_ name, False otherwise."""
return (name[0] == name[-1] == '_' and
name[1:2] != '_' and
name[-2:-1] != '_' and
len(name) > 2)

class DictLocal(threading.local):

def __init__(self, **kwargs):
super().__init__()
for key, value in kwargs.items():
self[key] = value

def __getitem__(self, key):
return getattr(self, key)

def __setitem__(self, key, value):
if _is_dunder(key) or _is_sunder(key):
raise ValueError("key is dunder or sunder")
setattr(self, key, value)

def __delitem__(self, key):
try:
delattr(self, key)
except AttributeError:
pass

def items(self):
data = {}
for item in dir(self):
if _is_dunder(item) or _is_sunder(item) or item == 'items':
continue
data[item] = self[item]
return data

def __contains__(self, key):
return hasattr(self, key)

这个类比threading.local的使用方便多了。_is_dunder_is_sunder方法用于检测私有方法或被保护方法。通常我们不会以这两类名字命名线程local变量。

上面讨论了threading.local的实现原理,下面实现一个简单的local。注意递归问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

import threading
from collections import defaultdict

class SimpleLocal:

def __init__(self, **kwargs):
self._data = defaultdict(dict)
_id = threading.current_thread().ident
for key, value in kwargs.items():
self._data[_id][key] = value

def __getitem__(self, key):
_id = threading.current_thread().ident
return self._data[_id][key] # 这里并没有做一次处理,和threading.local一样

def __setitem__(self, key, value):
_id = threading.current_thread().ident
self._data[_id][key] = value

def __contains__(self, key):
_id = threading.current_thread().ident
return key in self._data[_id]

def __delitem__(self, key):
if key in self:
_id = threading.current_thread().ident
del self._data[_id][key]

我们依旧使用上面的例子,但把threading.local改为SimpleLocal。由于数据结构上的差异,修改了部分细节。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import threading
import time
import random

def task(_id, local):
# tid = threading.current_thread().ident // 获取当前线程id
thread_id = "thread_{}".format(_id)
local[thread_id] = _id # setattr(local, thread_id, _id)

time.sleep(random.random())

try:
value = local["thread_1"]
print(thread_id, "got thread_id", value)
except KeyError as err:
print(thread_id, "can't see local.thread_1", flush=True)

try:
value = local["thread_2"]
print(thread_id, "got thread_id", value)
except KeyError as err:
print(thread_id, "can't see local.thread_2", flush=True)

def main():

local = SimpleLocal() # 类似threading.local()
thread_1 = threading.Thread(target=task, args=(1, local))
thread_2 = threading.Thread(target=task, args=(2, local))

thread_1.start()
thread_2.start()

thread_1.join()
thread_2.join()

if __name__ == '__main__':
main()

运行结果依旧不变。

那么threading.local有什么用呢?下面通过Flask中的上下文Context来举例说明。

线程间通信

Python进行线程间通信有两类方法:1. 共享变量,通过同步原语实现并发访问控制。2. 线程安全的队列,把数据发送到队列中,另一方线程取出。前者在下一节详述。后者在实现上也有两种方法:(1)标准库中的queue模块的实现方法,使用面向对象方式,在线程不安全的队列上添加同步原语。(2)通过CAS实现无锁队列。从底层角度,有锁无锁的实现方式本质上都是一样的,只是层次、粒度不一样。在不同层次上实现锁导致不同的性能差异。

本节详述使用线程安全队列进行的方式线程间通信。然后剖析queue模块,根据该模块实现实现线程安全的优先队列。最后实现基于CAS的无锁队列。

Pythonqueue模块有三个线程安全的队列,根据其命名可以知道其作用:

  • queue.Queue
  • queue.LifoQueue
  • queue.PriorityQueue

它们的接口都一致的。下面以生产者/消费者为例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import threading
import queue
import time
import random
import string

_sentinel = object()

class Consumer(threading.Thread):
def __init__(self, q):
self.q = q
super().__init__()

def run(self):
while True:
e = self.q.get()
if e is _sentinel:
self.q.put(e)
break
print('consume:{} by thread<{}>'.format(e, threading.current_thread().ident))
time.sleep(random.random())

class Producer(threading.Thread):
def __init__(self, q):
self.q = q
super().__init__()

def run(self):
loop = 0
while True:
e = random.choice(string.ascii_letters)
self.q.put(e)
print('produce:{} by thread<{}>'.format(e, threading.current_thread().ident))
time.sleep(random.random())
loop += 1
if loop == 10:
break
self.q.put(_sentinel)

def main():
q = queue.Queue(maxsize=10)
producer = Producer(q)
consumer = Consumer(q)
producer.start()
consumer.start()

producer.join()
consumer.join()

if __name__ == '__main__':
main()

上面的实现有一个特殊的处理,通过_sentinel对象告知消费者退出,当消费者收到这个对象后重新把它放到队列中,然后退出循环。

另外有一种方法可以不使用_sentinel,把消费者和生产者设置为守护线程,当队列为空时,main函数(主线程)不等待消费者和生产者进程而直接退出,进而程序退出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import threading
import queue
import time
import random
import string

class Consumer(threading.Thread):
def __init__(self, q):
self.q = q
super().__init__(daemon=True)

def run(self):
while True:
e = self.q.get()
print('consume:{} by thread<{}>'.format(e, threading.current_thread().ident))
self.q.task_done()
time.sleep(random.random())

class Producer(threading.Thread):
def __init__(self, q):
self.q = q
super().__init__(daemon=True)

def run(self):
while True:
e = random.choice(string.ascii_letters)
self.q.put(e)
print('produce:{} by thread<{}>'.format(e, threading.current_thread().ident))
time.sleep(random.random())

def main():
q = queue.Queue(maxsize=10)
producer = Producer(q)
consumer = Consumer(q)
producer.start()
consumer.start()

time.sleep(3) # 避免队列一开始为空而直接退出
q.join()
print('queue is empty, so quit', flush=True)

if __name__ == '__main__':
main()

在不采取主动终止的情况,上面的代码不确定何时终止,就像网络爬虫并不能确定何时检索完整个网络。这个问题融入一定的技巧可以解决,文章后面会提及。

关于线程通信的深入探讨见旧文线程安全的优先队列的实现

转载请包括本文地址:https://allenwind.github.io/blog/4871
更多文章请参考:https://allenwind.github.io/blog/archives/