查看原文
其他

动手实现一个 LRU cache

crossoverJie crossoverJie 2018-10-25

前言

LRU 是 LeastRecentlyUsed 的简写,字面意思则是 最近最少使用

通常用于缓存的淘汰策略实现,由于缓存的内存非常宝贵,所以需要根据某种规则来剔除数据保证内存不被撑满。

如常用的 Redis 就有以下几种策略:

策略描述
volatile-lru从已设置过期时间的数据集中挑选最近最少使用的数据淘汰
volatile-ttl从已设置过期时间的数据集中挑选将要过期的数据淘汰
volatile-random从已设置过期时间的数据集中任意选择数据淘汰
allkeys-lru从所有数据集中挑选最近最少使用的数据淘汰
allkeys-random从所有数据集中任意选择数据进行淘汰
no-envicition禁止驱逐数据

摘抄自:https://github.com/CyC2018/Interview-Notebook/blob/master/notes/Redis.md#%E5%8D%81%E4%B8%89%E6%95%B0%E6%8D%AE%E6%B7%98%E6%B1%B0%E7%AD%96%E7%95%A5

实现一

之前也有接触过一道面试题,大概需求是:

  • 实现一个 LRU 缓存,当缓存数据达到 N 之后需要淘汰掉最近最少使用的数据。

  • N 小时之内没有被访问的数据也需要淘汰掉。

以下是我的实现:

  1. public class LRUAbstractMap extends java.util.AbstractMap {

  2.    private final static Logger LOGGER = LoggerFactory.getLogger(LRUAbstractMap.class);

  3.    /**

  4.     * 检查是否超期线程

  5.     */

  6.    private ExecutorService checkTimePool ;

  7.    /**

  8.     * map 最大size

  9.     */

  10.    private final static int MAX_SIZE = 1024 ;

  11.    private final static ArrayBlockingQueue<Node> QUEUE = new ArrayBlockingQueue<>(MAX_SIZE) ;

  12.    /**

  13.     * 默认大小

  14.     */

  15.    private final static int DEFAULT_ARRAY_SIZE =1024 ;

  16.    /**

  17.     * 数组长度

  18.     */

  19.    private int arraySize ;

  20.    /**

  21.     * 数组

  22.     */

  23.    private Object[] arrays ;

  24.    /**

  25.     * 判断是否停止 flag

  26.     */

  27.    private volatile boolean flag = true ;

  28.    /**

  29.     * 超时时间

  30.     */

  31.    private final static Long EXPIRE_TIME = 60 * 60 * 1000L ;

  32.    /**

  33.     * 整个 Map 的大小

  34.     */

  35.    private volatile AtomicInteger size  ;

  36.    public LRUAbstractMap() {

  37.        arraySize = DEFAULT_ARRAY_SIZE;

  38.        arrays = new Object[arraySize] ;

  39.        //开启一个线程检查最先放入队列的值是否超期

  40.        executeCheckTime();

  41.    }

  42.    /**

  43.     * 开启一个线程检查最先放入队列的值是否超期 设置为守护线程

  44.     */

  45.    private void executeCheckTime() {

  46.        ThreadFactory namedThreadFactory = new ThreadFactoryBuilder()

  47.                .setNameFormat("check-thread-%d")

  48.                .setDaemon(true)

  49.                .build();

  50.        checkTimePool = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS,

  51.                new ArrayBlockingQueue<>(1),namedThreadFactory,new ThreadPoolExecutor.AbortPolicy());

  52.        checkTimePool.execute(new CheckTimeThread()) ;

  53.    }

  54.    @Override

  55.    public Set<Entry> entrySet() {

  56.        return super.keySet();

  57.    }

  58.    @Override

  59.    public Object put(Object key, Object value) {

  60.        int hash = hash(key);

  61.        int index = hash % arraySize ;

  62.        Node currentNode = (Node) arrays[index] ;

  63.        if (currentNode == null){

  64.            arrays[index] = new Node(null,null, key, value);

  65.            //写入队列

  66.            QUEUE.offer((Node) arrays[index]) ;

  67.            sizeUp();

  68.        }else {

  69.            Node cNode = currentNode ;

  70.            Node nNode = cNode ;

  71.            //存在就覆盖

  72.            if (nNode.key == key){

  73.                cNode.val = value ;

  74.            }

  75.            while (nNode.next != null){

  76.                //key 存在 就覆盖 简单判断

  77.                if (nNode.key == key){

  78.                    nNode.val = value ;

  79.                    break ;

  80.                }else {

  81.                    //不存在就新增链表

  82.                    sizeUp();

  83.                    Node node = new Node(nNode,null,key,value) ;

  84.                    //写入队列

  85.                    QUEUE.offer(currentNode) ;

  86.                    cNode.next = node ;

  87.                }

  88.                nNode = nNode.next ;

  89.            }

  90.        }

  91.        return null ;

  92.    }

  93.    @Override

  94.    public Object get(Object key) {

  95.        int hash = hash(key) ;

  96.        int index = hash % arraySize ;

  97.        Node currentNode = (Node) arrays[index] ;

  98.        if (currentNode == null){

  99.            return null ;

  100.        }

  101.        if (currentNode.next == null){

  102.            //更新时间

  103.            currentNode.setUpdateTime(System.currentTimeMillis());

  104.            //没有冲突

  105.            return currentNode ;

  106.        }

  107.        Node nNode = currentNode ;

  108.        while (nNode.next != null){

  109.            if (nNode.key == key){

  110.                //更新时间

  111.                currentNode.setUpdateTime(System.currentTimeMillis());

  112.                return nNode ;

  113.            }

  114.            nNode = nNode.next ;

  115.        }

  116.        return super.get(key);

  117.    }

  118.    @Override

  119.    public Object remove(Object key) {

  120.        int hash = hash(key) ;

  121.        int index = hash % arraySize ;

  122.        Node currentNode = (Node) arrays[index] ;

  123.        if (currentNode == null){

  124.            return null ;

  125.        }

  126.        if (currentNode.key == key){

  127.            sizeDown();

  128.            arrays[index] = null ;

  129.            //移除队列

  130.            QUEUE.poll();

  131.            return currentNode ;

  132.        }

  133.        Node nNode = currentNode ;

  134.        while (nNode.next != null){

  135.            if (nNode.key == key){

  136.                sizeDown();

  137.                //在链表中找到了 把上一个节点的 next 指向当前节点的下一个节点

  138.                nNode.pre.next = nNode.next ;

  139.                nNode = null ;

  140.                //移除队列

  141.                QUEUE.poll();

  142.                return nNode;

  143.            }

  144.            nNode = nNode.next ;

  145.        }

  146.        return super.remove(key);

  147.    }

  148.    /**

  149.     * 增加size

  150.     */

  151.    private void sizeUp(){

  152.        //在put值时候认为里边已经有数据了

  153.        flag = true ;

  154.        if (size == null){

  155.            size = new AtomicInteger() ;

  156.        }

  157.        int size = this.size.incrementAndGet();

  158.        if (size >= MAX_SIZE) {

  159.            //找到队列头的数据

  160.            Node node = QUEUE.poll() ;

  161.            if (node == null){

  162.                throw new RuntimeException("data error") ;

  163.            }

  164.            //移除该 key

  165.            Object key = node.key ;

  166.            remove(key) ;

  167.            lruCallback() ;

  168.        }

  169.    }

  170.    /**

  171.     * 数量减小

  172.     */

  173.    private void sizeDown(){

  174.        if (QUEUE.size() == 0){

  175.            flag = false ;

  176.        }

  177.        this.size.decrementAndGet() ;

  178.    }

  179.    @Override

  180.    public int size() {

  181.        return size.get() ;

  182.    }

  183.    /**

  184.     * 链表

  185.     */

  186.    private class Node{

  187.        private Node next ;

  188.        private Node pre ;

  189.        private Object key ;

  190.        private Object val ;

  191.        private Long updateTime ;

  192.        public Node(Node pre,Node next, Object key, Object val) {

  193.            this.pre = pre ;

  194.            this.next = next;

  195.            this.key = key;

  196.            this.val = val;

  197.            this.updateTime = System.currentTimeMillis() ;

  198.        }

  199.        public void setUpdateTime(Long updateTime) {

  200.            this.updateTime = updateTime;

  201.        }

  202.        public Long getUpdateTime() {

  203.            return updateTime;

  204.        }

  205.        @Override

  206.        public String toString() {

  207.            return "Node{" +

  208.                    "key=" + key +

  209.                    ", val=" + val +

  210.                    '}';

  211.        }

  212.    }

  213.    /**

  214.     * copy HashMap 的 hash 实现

  215.     * @param key

  216.     * @return

  217.     */

  218.    public int hash(Object key) {

  219.        int h;

  220.        return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);

  221.    }

  222.    private void lruCallback(){

  223.        LOGGER.debug("lruCallback");

  224.    }

  225.    private class CheckTimeThread implements Runnable{

  226.        @Override

  227.        public void run() {

  228.            while (flag){

  229.                try {

  230.                    Node node = QUEUE.poll();

  231.                    if (node == null){

  232.                        continue ;

  233.                    }

  234.                    Long updateTime = node.getUpdateTime() ;

  235.                    if ((updateTime - System.currentTimeMillis()) >= EXPIRE_TIME){

  236.                        remove(node.key) ;

  237.                    }

  238.                } catch (Exception e) {

  239.                    LOGGER.error("InterruptedException");

  240.                }

  241.            }

  242.        }

  243.    }

  244. }

感兴趣的朋友可以直接从:

https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRUAbstractMap.java

下载代码本地运行。

代码看着比较多,其实实现的思路还是比较简单:

  • 采用了与 HashMap 一样的保存数据方式,只是自己手动实现了一个简易版。

  • 内部采用了一个队列来保存每次写入的数据。

  • 写入的时候判断缓存是否大于了阈值 N,如果满足则根据队列的 FIFO 特性将队列头的数据删除。因为队列头的数据肯定是最先放进去的。

  • 再开启了一个守护线程用于判断最先放进去的数据是否超期(因为就算超期也是最先放进去的数据最有可能满足超期条件。)

  • 设置为守护线程可以更好的表明其目的(最坏的情况下,如果是一个用户线程最终有可能导致程序不能正常退出,因为该线程一直在运行,守护线程则不会有这个情况。)

以上代码大体功能满足了,但是有一个致命问题。

就是最近最少使用没有满足,删除的数据都是最先放入的数据。

不过其中的 putget 流程算是一个简易的 HashMap 实现,可以对 HashMap 加深一些理解。

实现二

因此如何来实现一个完整的 LRU 缓存呢,这次不考虑过期时间的问题。

其实从上一个实现也能想到一些思路:

  • 要记录最近最少使用,那至少需要一个有序的集合来保证写入的顺序。

  • 在使用了数据之后能够更新它的顺序。

基于以上两点很容易想到一个常用的数据结构:链表

  1. 每次写入数据时将数据放入链表头结点。

  2. 使用数据时候将数据移动到头结点

  3. 缓存数量超过阈值时移除链表尾部数据。

因此有了以下实现:

  1. public class LRUMap<K, V> {

  2.    private final Map<K, V> cacheMap = new HashMap<>();

  3.    /**

  4.     * 最大缓存大小

  5.     */

  6.    private int cacheSize;

  7.    /**

  8.     * 节点大小

  9.     */

  10.    private int nodeCount;

  11.    /**

  12.     * 头结点

  13.     */

  14.    private Node<K, V> header;

  15.    /**

  16.     * 尾结点

  17.     */

  18.    private Node<K, V> tailer;

  19.    public LRUMap(int cacheSize) {

  20.        this.cacheSize = cacheSize;

  21.        //头结点的下一个结点为空

  22.        header = new Node<>();

  23.        header.next = null;

  24.        //尾结点的上一个结点为空

  25.        tailer = new Node<>();

  26.        tailer.tail = null;

  27.        //双向链表 头结点的上结点指向尾结点

  28.        header.tail = tailer;

  29.        //尾结点的下结点指向头结点

  30.        tailer.next = header;

  31.    }

  32.    public void put(K key, V value) {

  33.        cacheMap.put(key, value);

  34.        //双向链表中添加结点

  35.        addNode(key, value);

  36.    }

  37.    public V get(K key){

  38.        Node<K, V> node = getNode(key);

  39.        //移动到头结点

  40.        moveToHead(node) ;

  41.        return cacheMap.get(key);

  42.    }

  43.    private void moveToHead(Node<K,V> node){

  44.        //如果是最后的一个节点

  45.        if (node.tail == null){

  46.            node.next.tail = null ;

  47.            tailer = node.next ;

  48.            nodeCount -- ;

  49.        }

  50.        //如果是本来就是头节点 不作处理

  51.        if (node.next == null){

  52.            return ;

  53.        }

  54.        //如果处于中间节点

  55.        if (node.tail != null && node.next != null){

  56.            //它的上一节点指向它的下一节点 也就删除当前节点

  57.            node.tail.next = node.next ;

  58.            nodeCount -- ;

  59.        }

  60.        //最后在头部增加当前节点

  61.        //注意这里需要重新 new 一个对象,不然原本的node 还有着下面的引用,会造成内存溢出。

  62.        node = new Node<>(node.getKey(),node.getValue()) ;

  63.        addHead(node) ;

  64.    }

  65.    /**

  66.     * 链表查询 效率较低

  67.     * @param key

  68.     * @return

  69.     */

  70.    private Node<K,V> getNode(K key){

  71.        Node<K,V> node = tailer ;

  72.        while (node != null){

  73.            if (node.getKey().equals(key)){

  74.                return node ;

  75.            }

  76.            node = node.next ;

  77.        }

  78.        return null ;

  79.    }

  80.    /**

  81.     * 写入头结点

  82.     * @param key

  83.     * @param value

  84.     */

  85.    private void addNode(K key, V value) {

  86.        Node<K, V> node = new Node<>(key, value);

  87.        //容量满了删除最后一个

  88.        if (cacheSize == nodeCount) {

  89.            //删除尾结点

  90.            delTail();

  91.        }

  92.        //写入头结点

  93.        addHead(node);

  94.    }

  95.    /**

  96.     * 添加头结点

  97.     *

  98.     * @param node

  99.     */

  100.    private void addHead(Node<K, V> node) {

  101.        //写入头结点

  102.        header.next = node;

  103.        node.tail = header;

  104.        header = node;

  105.        nodeCount++;

  106.        //如果写入的数据大于2个 就将初始化的头尾结点删除

  107.        if (nodeCount == 2) {

  108.            tailer.next.next.tail = null;

  109.            tailer = tailer.next.next;

  110.        }

  111.    }    

  112.    private void delTail() {

  113.        //把尾结点从缓存中删除

  114.        cacheMap.remove(tailer.getKey());

  115.        //删除尾结点

  116.        tailer.next.tail = null;

  117.        tailer = tailer.next;

  118.        nodeCount--;

  119.    }

  120.    private class Node<K, V> {

  121.        private K key;

  122.        private V value;

  123.        Node<K, V> tail;

  124.        Node<K, V> next;

  125.        public Node(K key, V value) {

  126.            this.key = key;

  127.            this.value = value;

  128.        }

  129.        public Node() {

  130.        }

  131.        public K getKey() {

  132.            return key;

  133.        }

  134.        public void setKey(K key) {

  135.            this.key = key;

  136.        }

  137.        public V getValue() {

  138.            return value;

  139.        }

  140.        public void setValue(V value) {

  141.            this.value = value;

  142.        }

  143.    }

  144.    @Override

  145.    public String toString() {

  146.        StringBuilder sb = new StringBuilder() ;

  147.        Node<K,V> node = tailer ;

  148.        while (node != null){

  149.            sb.append(node.getKey()).append(":")

  150.                    .append(node.getValue())

  151.                    .append("-->") ;

  152.            node = node.next ;

  153.        }

  154.        return sb.toString();

  155.    }

  156. }

源码: https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRUMap.java

实际效果,写入时:

  1.    @Test

  2.    public void put() throws Exception {

  3.        LRUMap<String,Integer> lruMap = new LRUMap(3) ;

  4.        lruMap.put("1",1) ;

  5.        lruMap.put("2",2) ;

  6.        lruMap.put("3",3) ;

  7.        System.out.println(lruMap.toString());

  8.        lruMap.put("4",4) ;

  9.        System.out.println(lruMap.toString());

  10.        lruMap.put("5",5) ;

  11.        System.out.println(lruMap.toString());

  12.    }

  13. //输出:

  14. 1:1-->2:2-->3:3-->

  15. 2:2-->3:3-->4:4-->

  16. 3:3-->4:4-->5:5-->

使用时:

  1.    @Test

  2.    public void get() throws Exception {

  3.        LRUMap<String,Integer> lruMap = new LRUMap(3) ;

  4.        lruMap.put("1",1) ;

  5.        lruMap.put("2",2) ;

  6.        lruMap.put("3",3) ;

  7.        System.out.println(lruMap.toString());

  8.        System.out.println("==============");

  9.        Integer integer = lruMap.get("1");

  10.        System.out.println(integer);

  11.        System.out.println("==============");

  12.        System.out.println(lruMap.toString());

  13.    }

  14. //输出

  15. 1:1-->2:2-->3:3-->

  16. ==============

  17. 1

  18. ==============

  19. 2:2-->3:3-->1:1-->

实现思路和上文提到的一致,说下重点:

  • 数据是直接利用 HashMap 来存放的。

  • 内部使用了一个双向链表来存放数据,所以有一个头结点 header,以及尾结点 tailer。

  • 每次写入头结点,删除尾结点时都是依赖于 header tailer,如果看着比较懵建议自己实现一个链表熟悉下,或结合下文的对象关系图一起理解。

  • 使用数据移动到链表头时,第一步是需要在双向链表中找到该节点。这里就体现出链表的问题了。查找效率很低,最差需要 O(N)。之后依赖于当前节点进行移动。

  • 在写入头结点时有判断链表大小等于 2 时需要删除初始化的头尾结点。这是因为初始化时候生成了两个双向节点,没有数据只是为了形成一个数据结构。当真实数据进来之后需要删除以方便后续的操作(这点可以继续优化)。

  • 以上的所有操作都是线程不安全的,需要使用者自行控制。

下面是对象关系图:

初始化时

写入数据时

  1. LRUMap<String,Integer> lruMap = new LRUMap(3) ;

  2. lruMap.put("1",1) ;

  1. lruMap.put("2",2) ;

  1. lruMap.put("3",3) ;

  1. lruMap.put("4",4) ;

获取数据时

数据和上文一样:

  1. Integer integer = lruMap.get("2");

通过以上几张图应该是很好理解数据是如何存放的了。

实现三

其实如果对 Java 的集合比较熟悉的话,会发现上文的结构和 LinkedHashMap 非常类似。

对此不太熟悉的朋友可以先了解下 LinkedHashMap 底层分析 。

所以我们完全可以借助于它来实现:

  1. public class LRULinkedMap<K,V> {

  2.    /**

  3.     * 最大缓存大小

  4.     */

  5.    private int cacheSize;

  6.    private LinkedHashMap<K,V> cacheMap ;

  7.    public LRULinkedMap(int cacheSize) {

  8.        this.cacheSize = cacheSize;

  9.        cacheMap = new LinkedHashMap(16,0.75F,true){

  10.            @Override

  11.            protected boolean removeEldestEntry(Map.Entry eldest) {

  12.                if (cacheSize + 1 == cacheMap.size()){

  13.                    return true ;

  14.                }else {

  15.                    return false ;

  16.                }

  17.            }

  18.        };

  19.    }

  20.    public void put(K key,V value){

  21.        cacheMap.put(key,value) ;

  22.    }

  23.    public V get(K key){

  24.        return cacheMap.get(key) ;

  25.    }

  26.    public Collection<Map.Entry<K, V>> getAll() {

  27.        return new ArrayList<Map.Entry<K, V>>(cacheMap.entrySet());

  28.    }

  29. }

源码: https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRULinkedMap.java

这次就比较简洁了,也就几行代码(具体的逻辑 LinkedHashMap 已经帮我们实现好了)

实际效果:

  1.    @Test

  2.    public void put() throws Exception {

  3.        LRULinkedMap<String,Integer> map = new LRULinkedMap(3) ;

  4.        map.put("1",1);

  5.        map.put("2",2);

  6.        map.put("3",3);

  7.        for (Map.Entry<String, Integer> e : map.getAll()){

  8.            System.out.print(e.getKey() + " : " + e.getValue() + "\t");

  9.        }

  10.        System.out.println("");

  11.        map.put("4",4);

  12.        for (Map.Entry<String, Integer> e : map.getAll()){

  13.            System.out.print(e.getKey() + " : " + e.getValue() + "\t");

  14.        }

  15.    }

  16. //输出

  17. 1 : 1    2 : 2   3 : 3  

  18. 2 : 2    3 : 3   4 : 4      

使用时:

  1.    @Test

  2.    public void get() throws Exception {

  3.        LRULinkedMap<String,Integer> map = new LRULinkedMap(4) ;

  4.        map.put("1",1);

  5.        map.put("2",2);

  6.        map.put("3",3);

  7.        map.put("4",4);

  8.        for (Map.Entry<String, Integer> e : map.getAll()){

  9.            System.out.print(e.getKey() + " : " + e.getValue() + "\t");

  10.        }

  11.        System.out.println("");

  12.        map.get("1") ;

  13.        for (Map.Entry<String, Integer> e : map.getAll()){

  14.            System.out.print(e.getKey() + " : " + e.getValue() + "\t");

  15.        }

  16.    }

  17. }

  18. //输出

  19. 1 : 1    2 : 2   3 : 3   4 : 4  

  20. 2 : 2    3 : 3   4 : 4   1 : 1

LinkedHashMap 内部也有维护一个双向队列,在初始化时也会给定一个缓存大小的阈值。初始化时自定义是否需要删除最近不常使用的数据,如果是则会按照实现二中的方式管理数据。

其实主要代码就是重写了 LinkedHashMap 的 removeEldestEntry 方法:

  1.    protected boolean removeEldestEntry(Map.Entry<K,V> eldest) {

  2.        return false;

  3.    }

它默认是返回 false,也就是不会管有没有超过阈值。

所以我们自定义大于了阈值时返回 true,这样 LinkedHashMap 就会帮我们删除最近最少使用的数据。

总结

以上就是对 LRU 缓存的实现,了解了这些至少在平时使用时可以知其所以然。

当然业界使用较多的还有 guava 的实现,并且它还支持多种过期策略。

号外

最近在总结一些 Java 相关的知识点,感兴趣的朋友可以一起维护。

地址: https://github.com/crossoverJie/Java-Interview


    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存