import threading
from collections import OrderedDict
from queue import Queue, Empty
from typing import Dict, List, Set, Optional, Tuple
import torch


class TieredCacheManager:
    """LRU block cache in host RAM, sitting between GPU tensors and a storage manager.

    Blocks arrive either from ``sync_from_gpu`` (copied to CPU by a background
    thread and marked dirty) or from ``prefetch`` (read through from
    ``storage_manager``). When the estimated footprint exceeds
    ``max_ram_bytes * eviction_threshold`` the least-recently-used blocks are
    evicted; dirty ones are written with ``storage_manager.write_patch`` first.

    The storage manager must provide ``read_blocks(list_of_ids) -> dict`` and
    ``write_patch(dict_of_id_to_tensor)``; nothing else about it is assumed.
    """

    def __init__(
        self,
        storage_manager,
        max_ram_gb: float = 16.0,
        block_size: int = 4096,
        point_dim: int = 59,
        eviction_threshold: float = 0.8
    ):
        """Set up the cache state and start the background sync thread.

        Args:
            storage_manager: Backing store used for read-through and write-back.
            max_ram_gb: RAM budget in GiB.
            block_size: Number of rows per block, used only for size accounting.
            point_dim: Number of values per row, used only for size accounting.
            eviction_threshold: Fraction of the RAM budget above which eviction
                starts.
        """
        self.storage = storage_manager
        self.max_ram_bytes = int(max_ram_gb * 1024 * 1024 * 1024)
        self.block_size = block_size
        self.point_dim = point_dim
        self.eviction_threshold = eviction_threshold
        # Every block is accounted at a fixed size; the factor 4 presumably
        # means 4-byte (float32) elements -- TODO confirm against callers.
        self.bytes_per_block = block_size * point_dim * 4

        # Ordered oldest -> newest; the front is the next eviction victim.
        self.cache_data: OrderedDict[int, torch.Tensor] = OrderedDict()
        # Ids whose cached tensor has not been written to storage yet.
        self.dirty_set: Set[int] = set()
        self.cache_lock = threading.RLock()

        # Blocks that left the cache but whose write-back is still running.
        # The second tuple slot is always stored as 0 by this class.
        self.flushing_buffer: Dict[int, Tuple[torch.Tensor, float]] = {}
        self.flushing_lock = threading.RLock()

        self.sync_queue: Queue = Queue(maxsize=100)
        self.sync_thread = None
        self.sync_running = False

        self._start_sync_thread()

    def _start_sync_thread(self) -> None:
        """Start the daemon thread that drains ``sync_queue``."""
        self.sync_running = True
        self.sync_thread = threading.Thread(target=self._sync_worker, daemon=True)
        self.sync_thread.start()

    def _sync_worker(self) -> None:
        """Copy queued tensors to CPU, cache them as dirty, then evict if needed.

        The loop keeps running after ``sync_running`` is cleared until the
        queue is empty, so updates that were already enqueued when
        ``shutdown`` is called are not dropped.
        """
        while self.sync_running or not self.sync_queue.empty():
            try:
                block_id, gpu_tensor, cuda_stream = self.sync_queue.get(timeout=0.1)
            except Empty:
                continue

            try:
                if cuda_stream is not None:
                    # Wait for work queued on the producer's stream before reading.
                    cuda_stream.synchronize()

                # One forced copy: always yields a private CPU tensor, without
                # the extra allocation that .cpu().clone() makes for tensors
                # that are not already on the CPU.
                cpu_tensor = gpu_tensor.to("cpu", copy=True)

                with self.cache_lock:
                    self.cache_data[block_id] = cpu_tensor
                    self.cache_data.move_to_end(block_id)
                    self.dirty_set.add(block_id)

                self._maybe_evict()
            except Exception as e:
                # Thread boundary: report and keep serving later items.
                print(f"[TieredCache] Sync worker error: {e}")

    def sync_from_gpu(
        self,
        block_ids: List[int],
        gpu_tensors: List[torch.Tensor],
        cuda_stream: Optional[torch.cuda.Stream] = None
    ) -> None:
        """Queue tensors for asynchronous copy into the RAM cache.

        Blocks while the queue (capacity 100) is full.

        Args:
            block_ids: Ids of the blocks being updated.
            gpu_tensors: One tensor per id, in the same order.
            cuda_stream: Optional stream the worker synchronizes before copying.

        Raises:
            ValueError: If ``block_ids`` and ``gpu_tensors`` differ in length;
                zipping them would silently drop the unmatched updates.
        """
        if len(block_ids) != len(gpu_tensors):
            raise ValueError(
                f"block_ids ({len(block_ids)}) and gpu_tensors "
                f"({len(gpu_tensors)}) must have the same length"
            )
        for block_id, gpu_tensor in zip(block_ids, gpu_tensors):
            self.sync_queue.put((block_id, gpu_tensor, cuda_stream))

    def prefetch(self, needed_block_ids: List[int]) -> Dict[int, torch.Tensor]:
        """Return the requested blocks, reading from storage on a cache miss.

        Lookup order is RAM cache, then the flushing buffer, then
        ``storage.read_blocks``. Blocks loaded from storage are inserted into
        the cache. Ids that storage does not return are absent from the result.

        Args:
            needed_block_ids: Ids of the blocks to fetch.

        Returns:
            Mapping from block id to tensor for every block that was found.
        """
        result: Dict[int, torch.Tensor] = {}
        cache_misses = []

        with self.cache_lock:
            for block_id in needed_block_ids:
                if block_id in self.cache_data:
                    result[block_id] = self.cache_data[block_id]
                    self.cache_data.move_to_end(block_id)
                else:
                    cache_misses.append(block_id)

        still_missing = []
        with self.flushing_lock:
            for block_id in cache_misses:
                if block_id in self.flushing_buffer:
                    result[block_id] = self.flushing_buffer[block_id][0]
                else:
                    still_missing.append(block_id)

        if still_missing:
            # Storage is read without holding any lock.
            ssd_blocks = self.storage.read_blocks(still_missing)
            with self.cache_lock:
                for block_id, tensor in ssd_blocks.items():
                    if block_id in self.cache_data:
                        # The sync worker cached a version while we were
                        # reading; the storage copy may be stale, so keep and
                        # return the cached one instead of overwriting it.
                        tensor = self.cache_data[block_id]
                    else:
                        self.cache_data[block_id] = tensor
                    self.cache_data.move_to_end(block_id)
                    result[block_id] = tensor
            self._maybe_evict()

        return result

    def _maybe_evict(self) -> None:
        """Evict least-recently-used blocks until the cache is under threshold.

        Dirty victims are written to storage before they are dropped. The
        write happens while ``cache_lock`` is held, as in the original code.
        If the write raises, the victim is put back at the LRU end with its
        dirty flag intact and the exception propagates, so the data is neither
        lost nor stranded in ``flushing_buffer``.
        """
        with self.cache_lock:
            threshold = self.max_ram_bytes * self.eviction_threshold

            while self.cache_data and len(self.cache_data) * self.bytes_per_block > threshold:
                evicted_id, evicted_tensor = self.cache_data.popitem(last=False)

                if evicted_id not in self.dirty_set:
                    continue

                with self.flushing_lock:
                    self.flushing_buffer[evicted_id] = (evicted_tensor, 0)

                written = False
                try:
                    self.storage.write_patch({evicted_id: evicted_tensor})
                    written = True
                finally:
                    if written:
                        self.dirty_set.discard(evicted_id)
                    else:
                        # Write failed: restore the block as the oldest entry
                        # so a later eviction retries it first.
                        self.cache_data[evicted_id] = evicted_tensor
                        self.cache_data.move_to_end(evicted_id, last=False)
                    with self.flushing_lock:
                        self.flushing_buffer.pop(evicted_id, None)

    def flush_all_dirty(self) -> None:
        """Write every dirty cached block to storage in one ``write_patch`` call.

        The write runs without ``cache_lock``. Afterwards only the blocks that
        were actually written, and whose cached tensor is still the same
        object, are marked clean. Blocks dirtied or replaced during the write
        stay dirty for the next flush.
        """
        with self.cache_lock:
            dirty_blocks = {
                bid: self.cache_data[bid]
                for bid in self.dirty_set
                if bid in self.cache_data
            }
        if not dirty_blocks:
            return

        self.storage.write_patch(dirty_blocks)

        with self.cache_lock:
            for bid, tensor in dirty_blocks.items():
                if self.cache_data.get(bid) is tensor:
                    self.dirty_set.discard(bid)

    def shutdown(self) -> None:
        """Stop the sync thread, then flush all dirty blocks to storage.

        The worker drains already-queued items before exiting; the join is
        bounded at 2 seconds so a stuck worker cannot hang shutdown.
        """
        self.sync_running = False
        if self.sync_thread:
            self.sync_thread.join(timeout=2.0)
        self.flush_all_dirty()
