
1 ThreadLocal 是什么?
ThreadLocal是指线程的本地变量,我们可以通过ThreadLocal去设计只有线程内部才可以访问的变量,该变量是与其他线程所隔离的。
2 ThreadLocal 可以干什么?
ThreadLocal最大的特点就是线程隔离,其他线程无法获取,当前线程ThreadLocal存放的数据,然后就是在ThreadLocal中存放的信息,无论其程序的调用链路有多深,只要是同一个线程,无需参数传递,可以直接获取。
3 ThreadLocal 典型应用
一个比较典型的应用场景就是,MyBatis分页插件中的应用,需要分页的接口,需要先将分页信息放入ThreadLocal中,需要用的时候直接通过ThreadLocal获取,这样就不需要每个接口都传入分页信息,然后再传给分页插件。

Spring的声明式事物中,也应用了ThreadLocal来存储事物信息,因为我们只有使用同一个数据库连接,设置事物才会生效,Spring的事物管理器在获取到数据库连接后就会将其与ThreadLocal绑定,事物完成后解除绑定。
4 ThreadLocal 源码解析
ThreadLocal 的源码是相对容易看懂的,我们以ThreadLocal的set方法为入口,开始看其原理
- 首先set方法会先获取当前线程 t
- 通过当前线程 t 获取当前线程的ThreadLocalMap对象
- 判断ThreadLocalMap是否为空 
  - 不为空时 为map设置值,key为当前的ThreadLocal对象,value为传入的参数
- 为空时 调用ThreadLocalMap构造函数初始化对象
 
public class ThreadLocal<T> {
    ......
        
    // ThreadLocal设置值
	public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    // 获取当前线程的ThreadLocalMap对象
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    // 当前线程ThreadLocalMap对象为空时,调用ThreadLocalMap构造函数初始化对象
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    // ThreadLocalMap为ThreadLocal的静态内部类,为Thread的属性
    static class ThreadLocalMap {
        ......
        
    	ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
    	}
    }
    
}
复制代码通过这一段代码,我们就可以很清楚的分析到ThreadLocal为什么能做到线程间的隔离,因为ThreadLocalMap是Thread的属性,也就是说这些数据是存储在Thread上的,这也是对其名字最好的理解。
然后我们再来看一下这个静态内部类ThreadLocalMap的数据结构,我们可以看到这个Map比HashMap简单许多,它是一个Entry数组,Entry是继承自WeakReference对象的(弱应用,垃圾回收时会清理),并且该Map并没有采用拉链的方式来解决hash冲突,而是拿当前下标,判断寻找下一个符合条件的位置存放。
Entry的类继承关系如下图,其中key是由弱应用,保存在Reference对象中的,通过get()方法获取。

public class ThreadLocal<T> {
    ......
    
    static class ThreadLocalMap {
        
        ......
        
       static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];
            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }
        
        private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                if (k == key) {
                    e.value = value;
                    return;
                }
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;
                    
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == key) {
                    e.value = value;
                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }
            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
        
        /**
         * Remove the entry for key.
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
    }
}   
复制代码ThreadLocal子类InheritableThreadLocal的设计,它主要是从写其getMap,createMap,childValue方法,使其调用Thread中inheritableThreadLocals属性,inheritableThreadLocals在Thread会继承父线程的inheritableThreadLocals属性,从而实现子父线程间变量的传递。
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    protected T childValue(T parentValue) {
        return parentValue;
    }
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
复制代码5 ThreadLocal 精妙设计

-  线程隔离 ThreadLocal线程本地变量如其名字一样,并没有将数据放在ThreadLocal对象中,而将ThreadLocalMap放在Thread对象上,这样ThreadLocal只是提供操作数据的入口,并不具备实际存储能力,这样就可以做到线程隔离。 
-  弱引用解决key的内存释放 ThreadLocal通过弱引用的方式,使得ThreadLocal在外部强引用消失时,可以自动被垃圾回收收集,而不需要去释放ThreadLocalMap中key的引用。 重点:key可以通过弱引用被释放,那么value如何处理呢? 尤其是在使用线程池,线程会被复用,如果不能有效的回收ThreadLocalMap,那么很容易出现脏数据,甚至内存溢出,当我们在使用完毕ThreadLocal后,一定要去调用其remove方法。 public static void main(String[] args) { new Thread(()-> { try { threadLocal.set("hello"); threadLocal.get(); } finally { threadLocal.remove(); } }).start(); } 复制代码
-  父子线程传递ThreadLocal 虽然说ThreadLocal是线程隔离的,但是理论上子线程想要获取父线程中设置的值是可以的,这时我们就可以通过InheritableThreadLocal来实现。 InheritableThreadLocal是ThreadLocal的子类,其重写了创建createMap、getMap、childValue等方法,在Thread类中有两个成员变量一个是threadLocals,一个是inheritableThreadLocals,在线程创建的时候回判断父线程是否具有inheritableThreadLocals,有的化会继承下来,从而实现父线程向子线程ThreadLocalMap的传递。 
6 ThreadLocal 测试案例
下面是ThreadLocal最简单的使用方法,大家可以自行跟踪源码进行验证。
package com.zhj.interview;
public class Test18 {
    private static ThreadLocal<String> threadLocal = new ThreadLocal<>();
    private static ThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();
    public static void main(String[] args) {
        inheritableThreadLocal.set("hello main");
        new Thread(()-> {
            try {
                threadLocal.set("hello");
                System.out.println( Thread.currentThread().getName() + threadLocal.get());
                System.out.println( Thread.currentThread().getName() + inheritableThreadLocal.get());
                inheritableThreadLocal.set("hello thread1");
                System.out.println( Thread.currentThread().getName() + inheritableThreadLocal.get());
            } finally {
                threadLocal.remove();
            }
        },"thread1 - ").start();
        new Thread(()-> {
            try {
                Thread.sleep(10);
                System.out.println( Thread.currentThread().getName() + inheritableThreadLocal.get());
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                threadLocal.remove();
            }
        }, "thread2 - ").start();
    }
}
复制代码感谢大家的阅读,如果感觉有帮助,帮点个赞,万分感谢!!!










