昨天看一篇文章说保证线程安全有两种方法,一种是加锁,另一种是使用ThreadLocal,所以今天看一下ThreadLocal是怎么样实现的。
打开ThreadLocal源码,发现映入眼帘的是AtomicInteger,虽然早有耳闻,但仍不知道AtomicXXX系列是怎样实现的。于是又打开了AtomicInteger的源码,发现它使用了Unsafe,好吧,我们从Unsafe开始。
Unsafe
首先这个类是个单例模式,构造方法用了private,只能通过Unsafe.getUnsafe()拿到实例,然后通过一个static初始了许多OFFSET,其余本类中的大部分方法都是通过native修饰,说明这些方法都不是java实现,而是用其他语言(C/C++)实现。从网上资料查到,Unsafe主要的用途是操作内存,可以直接将值写入内存中。
而且其中compareAndSwapXXX(传说中的CAS)方法通过本地方法实现了原子的操作。例如compareAndSwapInt这个方法,这个方法有四个参数,其中第一个参数为需要改变的对象,第二个为偏移量,第三个参数为期待的值,第四个为更新后的值。整个方法的作用即为若调用该方法时,value的值与expect这个值相等,那么则将value修改为update这个值,并返回一个true,如果调用该方法时,value的值与expect这个值不相等,那么不做任何操作,并返回一个false,这其实是乐观锁的一种实现。
还有getXXXVolatile与putXXXVolatile方法,与volatile关键字类似,就是每次写都刷新主内存,每次读都从主内存中读。
AtomicInteger
AtomicInteger中有一个volatile修饰的value,还有个valueOffset,它表示的是value在AtomicInteger对象中的偏移量,主要是为了直接写内存用的,剩下很多原子操作的方法,先来看看compareAndSet:
1 2 3
| public final boolean compareAndSet(int expect, int update) { return unsafe.compareAndSwapInt(this, valueOffset, expect, update); }
|
直接调用unsafe的compareAndSwapInt
1
| public final native boolean compareAndSwapInt(Object var1, long var2, int var4, int var5);
|
这个是拿expect与内存中的值比较,如果一致则更新,不一致则返回false,全程原子操作,典型的乐观锁。
再看一个getAndAdd
1 2 3
| public final int getAndAdd(int delta) { return unsafe.getAndAddInt(this, valueOffset, delta); }
|
调用unsafe的getAndAddInt
1 2 3 4 5 6 7 8
| public final int getAndAddInt(Object var1, long var2, int var4) { int var5; do { var5 = this.getIntVolatile(var1, var2); } while(!this.compareAndSwapInt(var1, var2, var5, var5 + var4)); return var5; }
|
循环获取偏移量指向内存的值作为返回,直到拿到最新的内存中的值且执行了add操作,有点难理解,我们分两种情况看:
1、没有并发,从内存中指定偏移量取出为var5,while中compareAndSwapInt会将内存中相应偏移量的值更新为 var5+var4,并且返回true,退出while循环并return,此时正常更新内存中的值,也正常返回。
2、假设存在并发,从内存中指定偏移量取出为var5,此时其他线程更新了内存中的值为T,此时while中compareAndSwapInt执行会失败,while会再次执行,重新取var5为T,此时没有并发,再更新内存中的值为T+var4,成功,退出循环并return,此时拿到的return值是新的T而不是var5,且add后的值为T+var4。
可以看出getAndAddInt保证拿到的是最新的值,且增加都是在最新值的基础上增加。
ThreadLocal
最后回到今天的主角ThreadLocal,首先ThreadLocal内部有一个静态的ThreadLocalMap类,ThreadLocalMap类里面有个Entry数组。很关键的是ThreadLocal有一个
1
| private final int threadLocalHashCode = nextHashCode();
|
threadLocalHashCode 是ThreadLocal对象的属性,且不可变,看他是如何初始化的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| private static AtomicInteger nextHashCode = new AtomicInteger(); /** * The difference between successively generated hash codes - turns * implicit sequential thread-local IDs into near-optimally spread * multiplicative hash values for power-of-two-sized tables. */ private static final int HASH_INCREMENT = 0x61c88647; /** * Returns the next hash code. */ private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT); }
|
这里利用了AtomicInteger的getAndAdd,也就是说在多线程并发的情况下,每个ThreadLocal对象拿到的threadLocalHashCode的值都不一样。
看一下set
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
| public void set(T value) { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) map.set(this, value); else createMap(t, value); } ThreadLocalMap getMap(Thread t) { return t.threadLocals; } private void set(ThreadLocal<?> key, Object value) { // We don't use a fast path as with get() because it is at // least as common to use set() to create new entries as // it is to replace existing ones, in which case, a fast // path would fail more often than not. Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get(); if (k == key) { e.value = value; return; } if (k == null) { replaceStaleEntry(key, value, i); return; } } tab[i] = new Entry(key, value); int sz = ++size; if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); }
|
这里set的时候会根据当前的ThreadLocal的threadLocalHashCode算出区别于其他ThreadLocal的hash值,然后将对象塞进这个hash对应的数组中,get也是类似。也就是说区分线程的是不同的ThreadLocalMap,而同一个线程内则是靠不同ThreadLocal区分开,听上去很难懂,而且跟想象中的不太一样,看下图:
每个Thread会挂一个ThreadLocalMap对象,每个ThreadLocal对象有一个唯一的key,当我们需要在当前线程存储一个对象时,拿到一个ThreadLocal生成一个key,放到当前线程的ThreadlLocalMap的对应位置就可以了(似乎网上有一些说的是错误的)。
这个我写了一个程序验证了:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
| package thread; import java.lang.reflect.Field; import java.util.Random; public class ThreadLocalTest implements Runnable { public static java.lang.ThreadLocal<Object> threadLocal = new java.lang.ThreadLocal<Object>(); public static java.lang.ThreadLocal<Object> threadLocal2 = new java.lang.ThreadLocal<Object>(); @Override public void run() { Random random = new Random(); int i = random.nextInt(10); System.out.println(Thread.currentThread().getName() + " 生成 " + i); threadLocal.set(i); threadLocal2.set(i*2); try { Thread.sleep(1000L); } catch (Exception e) { } System.out.println(Thread.currentThread().getName() + " threadLocal取出 " + threadLocal.get()); System.out.println(Thread.currentThread().getName() + " threadLocal2取出 " + threadLocal2.get()); try { Field f = ThreadLocal.class.getDeclaredField("threadLocalHashCode"); f.setAccessible(true); System.out.println(Thread.currentThread().getName() + " threadLocal.threadLocalHashCode = " + f.getInt(threadLocal)); } catch (Exception e) { } try { Field f1 = Thread.currentThread().getClass().getDeclaredField("threadLocals"); f1.setAccessible(true); System.out.println(Thread.currentThread().getName() + " threadLocal.threadLocalHashCode = " + f1.get(Thread.currentThread())); } catch (Exception e) { } } public static void main(String[] args) { ThreadLocalTest t = new ThreadLocalTest(); Thread thread1 = new Thread(t); Thread thread2 = new Thread(t); thread1.start(); thread2.start(); } }
|
输出:
1 2 3 4 5 6 7 8 9 10
| Thread-0 生成 3 Thread-1 生成 7 Thread-0 threadLocal取出 3 Thread-0 threadLocal2取出 6 Thread-1 threadLocal取出 7 Thread-1 threadLocal2取出 14 Thread-0 threadLocal.threadLocalHashCode = -387276957 Thread-1 threadLocal.threadLocalHashCode = -387276957 Thread-0 threadLocal.threadLocalHashCode = java.lang.ThreadLocal$ThreadLocalMap@2682ba92 Thread-1 threadLocal.threadLocalHashCode = java.lang.ThreadLocal$ThreadLocalMap@6e740a76
|
证明 两个ThreadLocal在两个线程分别存入对象时,取出来也是正确的,而且两个Thread对同一个ThreadLocal对象拿出来的threadLocalHashCode也是一样的,但是两个Thread的ThreadLocalMap却是不同的。