昨天看一篇文章说保证线程安全有两种方法,一种是加锁,另一种是使用ThreadLocal,所以今天看一下ThreadLocal是怎么样实现的。
打开ThreadLocal源码,发现映入眼帘的是AtomicInteger,虽然早有耳闻,但仍不知道AtomicXXX系列是怎样实现的。于是又打开了AtomicInteger的源码,发现它使用了Unsafe,好吧,我们从Unsafe开始。

Unsafe

首先这个类是个单例模式,构造方法用了private,只能通过Unsafe.getUnsafe()拿到实例,然后通过一个static初始了许多OFFSET,其余本类中的大部分方法都是通过native修饰,说明这些方法都不是java实现,而是用其他语言(C/C++)实现。从网上资料查到,Unsafe主要的用途是操作内存,可以直接将值写入内存中。

而且其中compareAndSwapXXX(传说中的CAS)方法通过本地方法实现了原子的操作。例如compareAndSwapInt这个方法,这个方法有四个参数,其中第一个参数为需要改变的对象,第二个为偏移量,第三个参数为期待的值,第四个为更新后的值。整个方法的作用即为若调用该方法时,value的值与expect这个值相等,那么则将value修改为update这个值,并返回一个true,如果调用该方法时,value的值与expect这个值不相等,那么不做任何操作,并返回一个false,这其实是乐观锁的一种实现。

还有getXXXVolatile与putXXXVolatile方法,与volatile关键字类似,就是每次写都刷新主内存,每次读都从主内存中读。

AtomicInteger

AtomicInteger中有一个volatile修饰的value,还有个valueOffset,它表示的是value在AtomicInteger对象中的偏移量,主要是为了直接写内存用的,剩下很多原子操作的方法,先来看看compareAndSet:

1
2
3
public final boolean compareAndSet(int expect, int update) {
return unsafe.compareAndSwapInt(this, valueOffset, expect, update);
}

直接调用unsafe的compareAndSwapInt

1
public final native boolean compareAndSwapInt(Object var1, long var2, int var4, int var5);

这个是拿expect与内存中的值比较,如果一致则更新,不一致则返回false,全程原子操作,典型的乐观锁。
再看一个getAndAdd

1
2
3
public final int getAndAdd(int delta) {
return unsafe.getAndAddInt(this, valueOffset, delta);
}

调用unsafe的getAndAddInt

1
2
3
4
5
6
7
8
public final int getAndAddInt(Object var1, long var2, int var4) {
int var5;
do {
var5 = this.getIntVolatile(var1, var2);
} while(!this.compareAndSwapInt(var1, var2, var5, var5 + var4));
return var5;
}

循环获取偏移量指向内存的值作为返回,直到拿到最新的内存中的值且执行了add操作,有点难理解,我们分两种情况看:

1、没有并发,从内存中指定偏移量取出为var5,while中compareAndSwapInt会将内存中相应偏移量的值更新为 var5+var4,并且返回true,退出while循环并return,此时正常更新内存中的值,也正常返回。
2、假设存在并发,从内存中指定偏移量取出为var5,此时其他线程更新了内存中的值为T,此时while中compareAndSwapInt执行会失败,while会再次执行,重新取var5为T,此时没有并发,再更新内存中的值为T+var4,成功,退出循环并return,此时拿到的return值是新的T而不是var5,且add后的值为T+var4。
可以看出getAndAddInt保证拿到的是最新的值,且增加都是在最新值的基础上增加。

ThreadLocal

最后回到今天的主角ThreadLocal,首先ThreadLocal内部有一个静态的ThreadLocalMap类,ThreadLocalMap类里面有个Entry数组。很关键的是ThreadLocal有一个

1
private final int threadLocalHashCode = nextHashCode();

threadLocalHashCode 是ThreadLocal对象的属性,且不可变,看他是如何初始化的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
private static AtomicInteger nextHashCode =
new AtomicInteger();
/**
* The difference between successively generated hash codes - turns
* implicit sequential thread-local IDs into near-optimally spread
* multiplicative hash values for power-of-two-sized tables.
*/
private static final int HASH_INCREMENT = 0x61c88647;
/**
* Returns the next hash code.
*/
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}

这里利用了AtomicInteger的getAndAdd,也就是说在多线程并发的情况下,每个ThreadLocal对象拿到的threadLocalHashCode的值都不一样。
看一下set

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

这里set的时候会根据当前的ThreadLocal的threadLocalHashCode算出区别于其他ThreadLocal的hash值,然后将对象塞进这个hash对应的数组中,get也是类似。也就是说区分线程的是不同的ThreadLocalMap,而同一个线程内则是靠不同ThreadLocal区分开,听上去很难懂,而且跟想象中的不太一样,看下图:
1.png

每个Thread会挂一个ThreadLocalMap对象,每个ThreadLocal对象有一个唯一的key,当我们需要在当前线程存储一个对象时,拿到一个ThreadLocal生成一个key,放到当前线程的ThreadlLocalMap的对应位置就可以了(似乎网上有一些说的是错误的)。

这个我写了一个程序验证了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
package thread;
import java.lang.reflect.Field;
import java.util.Random;
public class ThreadLocalTest implements Runnable {
public static java.lang.ThreadLocal<Object> threadLocal = new java.lang.ThreadLocal<Object>();
public static java.lang.ThreadLocal<Object> threadLocal2 = new java.lang.ThreadLocal<Object>();
@Override
public void run() {
Random random = new Random();
int i = random.nextInt(10);
System.out.println(Thread.currentThread().getName() + " 生成 " + i);
threadLocal.set(i);
threadLocal2.set(i*2);
try {
Thread.sleep(1000L);
} catch (Exception e) {
}
System.out.println(Thread.currentThread().getName() + " threadLocal取出 " + threadLocal.get());
System.out.println(Thread.currentThread().getName() + " threadLocal2取出 " + threadLocal2.get());
try {
Field f = ThreadLocal.class.getDeclaredField("threadLocalHashCode");
f.setAccessible(true);
System.out.println(Thread.currentThread().getName() + " threadLocal.threadLocalHashCode = " + f.getInt(threadLocal));
} catch (Exception e) {
}
try {
Field f1 = Thread.currentThread().getClass().getDeclaredField("threadLocals");
f1.setAccessible(true);
System.out.println(Thread.currentThread().getName() + " threadLocal.threadLocalHashCode = " + f1.get(Thread.currentThread()));
} catch (Exception e) {
}
}
public static void main(String[] args) {
ThreadLocalTest t = new ThreadLocalTest();
Thread thread1 = new Thread(t);
Thread thread2 = new Thread(t);
thread1.start();
thread2.start();
}
}

输出:

1
2
3
4
5
6
7
8
9
10
Thread-0 生成 3
Thread-1 生成 7
Thread-0 threadLocal取出 3
Thread-0 threadLocal2取出 6
Thread-1 threadLocal取出 7
Thread-1 threadLocal2取出 14
Thread-0 threadLocal.threadLocalHashCode = -387276957
Thread-1 threadLocal.threadLocalHashCode = -387276957
Thread-0 threadLocal.threadLocalHashCode = java.lang.ThreadLocal$ThreadLocalMap@2682ba92
Thread-1 threadLocal.threadLocalHashCode = java.lang.ThreadLocal$ThreadLocalMap@6e740a76

证明 两个ThreadLocal在两个线程分别存入对象时,取出来也是正确的,而且两个Thread对同一个ThreadLocal对象拿出来的threadLocalHashCode也是一样的,但是两个Thread的ThreadLocalMap却是不同的。