from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

from collections import OrderedDict

import torch


class LRUCache:
    """
    Func:
        LRU bookkeeping for `MCTSNode` states that caps how many of them are
            kept on the GPU at the same time.
        Nodes are ordered from least to most recently used. When a node's state
            must become GPU-resident and the cap is reached, the least recently
            used GPU-resident node is offloaded to CPU first.

    Note:
        The cache is keyed by node only; `sample_idx` selects which entry of
            `node.info_list` is moved. Presumably a given cache is always used
            with the same `sample_idx` -- TODO confirm with callers.
    """

    def __init__(
        self, 

        num_gpu_resident_lim: int, 

        device: Optional[str] = "cpu", 
    ):
        """
        Args:
            `num_gpu_resident_lim` (`int`): Maximum number of nodes whose state
                is kept on the GPU. The node being accessed is always placed on
                the GPU, so values below 1 behave like 1.
            `device` (`str`): Target device for GPU-resident states,
                e.g. `"cuda"` or `"cuda:0"`.

        Raises:
            `ValueError`: If `device` is a CPU device. Note that the default
                (`"cpu"`) therefore always raises, so callers must pass a GPU device.
        """
        # Maps node -> its current state tensor; insertion order is LRU order
        # (first = least recently used, last = most recently used).
        self._cache = OrderedDict()
        # Nodes whose state this cache has placed on the GPU.
        self._gpu_resident_node_set = set()

        self._num_gpu_resident_lim = num_gpu_resident_lim

        # `str()` + prefix check also rejects `torch.device("cpu")` and "cpu:N".
        if str(device).split(":", 1)[0] == "cpu":
            raise ValueError(
                "`device` should be GPU. "
            )
        self._device = device

        # `__init__()` done
        pass


    @property
    def num_gpu_resident(
        self
    ) -> int:
        """Number of nodes currently tracked as GPU-resident."""
        return len(self._gpu_resident_node_set)


    def _move_to_cpu(
        self, 

        node: "MCTSNode", 

        sample_idx: int
    ):
        """
        Func:
            Ensure `node.info_list[sample_idx]._state` is on CPU and stop
                tracking `node` as GPU-resident.
        """

        info = node.info_list[sample_idx]

        if info._state.is_cuda:
            # Write the CPU copy back onto the node: rebinding a local alone
            # would leave the node (and its GPU memory) untouched.
            info._state = info._state.cpu()

            if node in self._cache:
                # Drop the cache's own reference to the GPU tensor as well.
                self._cache[node] = info._state

        # `discard` (not `remove`): an untracked node is not an error.
        self._gpu_resident_node_set.discard(node)

        # `_move_to_cpu()` done
        pass


    def _offload_least_recently_used(
        self, 

        sample_idx: int, 

        exclude: Optional["MCTSNode"] = None, 
    ) -> bool:
        """
        Func:
            Offload the least recently used GPU-resident state to CPU.

        Args:
            `sample_idx` (`int`): Index into `node.info_list` of the state to offload.
            `exclude` (`MCTSNode`, optional): A node that must never be chosen,
                typically the node currently being moved to the GPU.

        Ret:
            `offloaded` (`bool`): `True` if a node was offloaded, `False` if no
                GPU-resident node other than `exclude` is in the cache.
        """

        # Choose the victim before mutating anything, so the cache is never
        # modified while it is being iterated.
        victim = next(
            (
                node for node in self._cache
                if node is not exclude and node in self._gpu_resident_node_set
            ), 
            None, 
        )

        if victim is None:
            return False

        self._move_to_cpu(
            node = victim, 

            sample_idx = sample_idx
        )

        # `_offload_least_recently_used()` done
        return True


    def _move_to_gpu(
        self, 

        node: "MCTSNode", 

        sample_idx: int
    ):
        """
        Func:
            Ensure `node.info_list[sample_idx]._state` is on GPU and track `node`
                as GPU-resident, offloading least recently used nodes first if
                the GPU-resident limit would otherwise be exceeded.
        """

        info = node.info_list[sample_idx]

        if node not in self._gpu_resident_node_set:
            # Free a slot *before* moving, so the GPU-resident count does not
            # exceed the limit. This also applies to states that already live
            # on the GPU but are not tracked yet.
            while len(self._gpu_resident_node_set) >= self._num_gpu_resident_lim:
                if not self._offload_least_recently_used(
                    sample_idx = sample_idx, 

                    exclude = node
                ):
                    # Nothing else can be offloaded (e.g. limit < 1): stop
                    # instead of looping forever.
                    break

        if not info._state.is_cuda:
            info._state = info._state.to(self._device)

        if node in self._cache:
            # Keep the cached value pointing at the live tensor.
            self._cache[node] = info._state

        self._gpu_resident_node_set.add(node)

        # `_move_to_gpu()` done
        pass


    def get(
        self, 

        node: "MCTSNode", 

        sample_idx: int
    ) -> Optional[torch.Tensor]:
        """
        Func:
            Get the state of the MCTSNode `node` from the cache, ensuring it is on GPU.
            Move the accessed node to the end, i.e., most recently used.

        Ret:
            `state` (`torch.Tensor` or `None`): The queried state, or `None` if
                `node` is not in the cache.
        """

        if node not in self._cache:
            return None

        self._cache.move_to_end(node, last = True)

        self._move_to_gpu(
            node = node, 

            sample_idx = sample_idx
        )

        # Read from the node: it holds the tensor that was just moved.
        # `get()` done
        return node.info_list[sample_idx]._state


    def push(
        self, 

        node: "MCTSNode", 

        sample_idx: int
    ):
        """
        Func:
            Push the node to the cache, and ensure its state is on GPU.
            If the node is already in the cache, move it to the end, i.e., most recently used.
            If placing the state on GPU would exceed `self._num_gpu_resident_lim`,
                least recently used GPU-resident states are offloaded to CPU first.
        """

        if node not in self._cache:
            self._cache[node] = node.info_list[sample_idx]._state

        self._cache.move_to_end(node, last = True)

        self._move_to_gpu(
            node = node, 

            sample_idx = sample_idx
        )

        # `push()` done
        pass


    def access(
        self, 

        node: "MCTSNode", 

        sample_idx: int
    ):
        """
        Func:
            Ensure `node.info_list[sample_idx]._state` is on GPU, and mark the node
                as most recently used (inserting it into the cache if needed).
        """

        if node in self._cache:
            self.get(
                node = node, 
                sample_idx = sample_idx
            )
        else:
            self.push(
                node = node, 
                sample_idx = sample_idx
            )

        # `access()` done
        pass