import os
import json
import hashlib
import time
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass, asdict, field
from contextlib import contextmanager
import numpy as np
import torch

@dataclass
class CurvatureCacheConfig:
    """Settings that control where and how curvature results are cached.

    Attributes:
        cache_dir: Root directory for cache files and the index.
        mode: One of "read", "write", "readwrite", "refresh" or "disabled".
        compress: Compression flag. NOTE(review): not read anywhere in the
            visible part of this file - confirm where it is consumed.
        max_cache_size_gb: Size budget. NOTE(review): not enforced anywhere
            in the visible part of this file - confirm where it is consumed.

    Raises:
        AssertionError: If ``mode`` is not one of the accepted values.
    """

    cache_dir: str = "./cache/curvature"
    mode: str = "readwrite"
    compress: bool = True
    max_cache_size_gb: float = 50.0

    def __post_init__(self):
        # Raised explicitly instead of via an `assert` statement so the check
        # still runs under `python -O`. AssertionError is kept as the exception
        # type so existing callers that catch it keep working.
        if self.mode not in ("read", "write", "readwrite", "refresh", "disabled"):
            raise AssertionError(f"Invalid cache mode: {self.mode}")

@dataclass
class CurvatureCacheEntry:
    """One cached curvature result together with the provenance that produced it.

    ``positions`` and ``kappa_values`` hold the cached payload; the remaining
    fields record hashes, model/tokenizer identity, timing and optional
    settings so an entry can be validated when it is read back.
    """

    positions: List[int]
    kappa_values: List[float]
    config_hash: str
    text_hash: str
    model_name: str
    model_revision: Optional[str]
    tokenizer_name: str
    tokenizer_revision: Optional[str]
    timestamp: float
    computation_time_ms: float
    num_forward_passes: int
    belief_window_L: Optional[int] = None
    belief_window_R: Optional[int] = None
    stride_used: Optional[int] = None
    code_version: str = "1.0.0"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Return the entry as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> 'CurvatureCacheEntry':
        """Build an entry from a dict, filling in optional keys that are absent.

        Args:
            d: Mapping of field names to values. Missing revision, belief
                window and stride keys default to ``None``; a missing
                ``code_version`` defaults to ``'1.0.0'``.

        Returns:
            A new ``CurvatureCacheEntry``.

        Raises:
            TypeError: If required fields are missing or unknown keys are
                present (raised by the dataclass constructor).
        """
        # Work on a shallow copy so the caller's dict is not modified.
        payload = dict(d)
        for optional_key in (
            'model_revision',
            'tokenizer_revision',
            'belief_window_L',
            'belief_window_R',
            'stride_used',
        ):
            payload.setdefault(optional_key, None)
        payload.setdefault('code_version', '1.0.0')
        return cls(**payload)

@dataclass
class CacheStats:
    """Running counters describing how the cache has been used."""

    hits: int = 0
    misses: int = 0
    writes: int = 0
    total_time_saved_ms: float = 0.0
    positions_served_from_cache: int = 0
    positions_computed: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups that were hits; 0.0 when nothing was looked up."""
        # Clamp the denominator to 1 so an unused cache does not divide by zero.
        denominator = max(1, self.hits + self.misses)
        return self.hits / denominator

    def to_dict(self) -> Dict:
        """Return the counters plus the derived ``hit_rate`` as a plain dict."""
        snapshot = {name: getattr(self, name) for name in ("hits", "misses", "writes")}
        snapshot["hit_rate"] = self.hit_rate
        for name in (
            "total_time_saved_ms",
            "positions_served_from_cache",
            "positions_computed",
        ):
            snapshot[name] = getattr(self, name)
        return snapshot

class CurvatureCache:
    """File-backed cache of curvature results.

    Each entry is stored with ``torch.save`` as a ``.pt`` file under
    ``cache_dir/<dataset>/<split>/<example_id>/``; a JSON index
    (``index.json``) in ``cache_dir`` records what has been written along with
    the usage statistics. The configured mode decides whether ``get`` and
    ``put`` are active.
    """

    # Class attribute: get() and put() of every instance take this same lock.
    _lock = threading.RLock()

    def __init__(self, config: Optional[CurvatureCacheConfig] = None):
        """Create the cache, its directory (unless disabled) and load the index.

        Args:
            config: Cache settings; a default ``CurvatureCacheConfig`` is used
                when omitted.
        """
        self.config = config or CurvatureCacheConfig()
        self.cache_dir = Path(self.config.cache_dir)

        if self.config.mode != "disabled":
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.stats = CacheStats()

        self.index_path = self.cache_dir / "index.json"
        self._load_index()

    def _load_index(self):
        """Load ``index.json`` into ``self.index``, falling back to an empty index.

        A missing file is the normal first-run case and is silent. An
        unreadable or malformed file produces a warning and an empty index.
        """
        loaded = None
        try:
            with open(self.index_path, 'r') as f:
                loaded = json.load(f)
        except FileNotFoundError:
            pass
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and decoding errors.
            print(f"[Cache] Warning: Failed to load index, starting empty: {e}")

        # put() writes into index["entries"], so insist on the expected shape.
        if not isinstance(loaded, dict) or not isinstance(loaded.get("entries"), dict):
            loaded = {"entries": {}, "stats": {}}
        loaded.setdefault("stats", {})
        self.index = loaded

    def _save_index(self):
        """Write the index (with current stats) to disk; no-op when disabled.

        Failures are reported with a warning rather than raised.
        """
        if self.config.mode == "disabled":
            return
        self.index["stats"] = self.stats.to_dict()
        try:
            with open(self.index_path, 'w') as f:
                json.dump(self.index, f, indent=2)
        except Exception as e:
            print(f"[Cache] Warning: Failed to save index: {e}")

    @staticmethod
    def compute_text_hash(text: str) -> str:
        """Return the first 32 hex chars of the SHA-256 of the UTF-8 text."""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()[:32]

    @staticmethod
    def compute_config_hash(config_dict: Dict) -> str:
        """Return the first 16 hex chars of the SHA-256 of the config.

        The dict is serialized as JSON with sorted keys; values JSON cannot
        encode are converted with ``str``.
        """
        config_str = json.dumps(config_dict, sort_keys=True, default=str)
        return hashlib.sha256(config_str.encode('utf-8')).hexdigest()[:16]

    @staticmethod
    def compute_positions_hash(positions: List[int]) -> str:
        """Return a 12-hex-char hash of the positions, independent of their order."""
        pos_str = ",".join(map(str, sorted(positions)))
        return hashlib.sha256(pos_str.encode('utf-8')).hexdigest()[:12]

    def build_cache_key(
        self,
        text: str,
        config_dict: Dict,
        model_name: str,
        tokenizer_name: str,
        model_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        positions: Optional[List[int]] = None,
    ) -> Tuple[str, str, str]:
        """Derive the hashes identifying a cache entry.

        Args:
            text: Input text.
            config_dict: Computation settings to fold into the config hash.
            model_name: Model identifier.
            tokenizer_name: Tokenizer identifier.
            model_revision: Model revision; ``None`` is hashed as "default".
            tokenizer_revision: Tokenizer revision; ``None`` is hashed as
                "default".
            positions: When non-empty, a hash of the positions is folded into
                the config hash.

        Returns:
            ``(text_hash, config_hash, full_key)`` where ``full_key`` is
            ``"<config_hash>_<first 8 chars of text_hash>"``.
        """
        text_hash = self.compute_text_hash(text)

        full_config = {
            **config_dict,
            "model_name": model_name,
            "model_revision": model_revision or "default",
            "tokenizer_name": tokenizer_name,
            "tokenizer_revision": tokenizer_revision or "default",
        }
        if positions:
            full_config["positions_hash"] = self.compute_positions_hash(positions)

        config_hash = self.compute_config_hash(full_config)

        full_key = f"{config_hash}_{text_hash[:8]}"

        return text_hash, config_hash, full_key

    def _get_cache_path(
        self,
        text_hash: str,
        config_hash: str,
        dataset: str = "default",
        split: str = "default",
        example_id: str = "0"
    ) -> Path:
        """Return the ``.pt`` path for an entry under dataset/split/example_id."""
        return self.cache_dir / dataset / split / example_id / f"{config_hash}_{text_hash[:8]}.pt"

    def get(
        self,
        text: str,
        config_dict: Dict,
        model_name: str,
        tokenizer_name: str,
        model_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        dataset: str = "default",
        split: str = "default",
        example_id: str = "0",
        positions: Optional[List[int]] = None,
    ) -> Optional[CurvatureCacheEntry]:
        """Look up a cached entry.

        Args:
            text, config_dict, model_name, tokenizer_name, model_revision,
            tokenizer_revision: Identify the entry; they must match what was
                passed to ``put``.
            dataset, split, example_id: Directory components of the entry path.
            positions: When non-empty, the stored entry is returned only if it
                holds the same set of positions (order-insensitive).

        Returns:
            The cached entry, or ``None`` when the mode does not allow reads,
            the file is missing or unreadable, or its hashes/positions do not
            match. Every lookup in a reading mode updates ``self.stats``.
        """
        if self.config.mode not in ["read", "readwrite"]:
            return None

        # Use exactly the key put() writes under. put() does not fold positions
        # into the key, so folding them in here would make every lookup with
        # positions miss; positions are compared on the loaded entry instead.
        text_hash, config_hash, _ = self.build_cache_key(
            text, config_dict, model_name, tokenizer_name,
            model_revision, tokenizer_revision
        )
        wanted_positions = self.compute_positions_hash(positions) if positions else None

        cache_path = self._get_cache_path(text_hash, config_hash, dataset, split, example_id)

        with self._lock:
            if cache_path.exists():
                try:
                    # NOTE(security): weights_only=False lets torch.load
                    # unpickle arbitrary objects, which can execute code.
                    # Only point cache_dir at trusted locations.
                    data = torch.load(cache_path, weights_only=False)
                    entry = CurvatureCacheEntry.from_dict(data)

                    hashes_match = (
                        entry.text_hash == text_hash
                        and entry.config_hash == config_hash
                    )
                    positions_match = (
                        wanted_positions is None
                        or self.compute_positions_hash(entry.positions) == wanted_positions
                    )
                    if hashes_match and positions_match:
                        self.stats.hits += 1
                        self.stats.total_time_saved_ms += entry.computation_time_ms
                        self.stats.positions_served_from_cache += len(entry.positions)
                        return entry
                except Exception as e:
                    # A corrupt or incompatible file is treated as a miss.
                    print(f"[Cache] Failed to load {cache_path}: {e}")

            self.stats.misses += 1
            return None

    def put(
        self,
        text: str,
        config_dict: Dict,
        positions: List[int],
        kappa_values: Union[List[float], np.ndarray, torch.Tensor],
        model_name: str,
        tokenizer_name: str,
        computation_time_ms: float,
        num_forward_passes: int,
        model_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        belief_window_L: Optional[int] = None,
        belief_window_R: Optional[int] = None,
        stride_used: Optional[int] = None,
        dataset: str = "default",
        split: str = "default",
        example_id: str = "0",
        metadata: Optional[Dict] = None
    ) -> bool:
        """Store a computed result on disk and record it in the index.

        Args:
            text, config_dict, model_name, tokenizer_name, model_revision,
            tokenizer_revision: Identify the entry (see ``build_cache_key``).
            positions: Positions the values belong to.
            kappa_values: Values to cache; NumPy arrays and tensors are
                converted to lists first.
            computation_time_ms, num_forward_passes, belief_window_L,
            belief_window_R, stride_used, metadata: Stored with the entry.
            dataset, split, example_id: Directory components of the entry path.

        Returns:
            ``True`` when the entry was written; ``False`` when the mode does
            not allow writes or saving failed.
        """
        if self.config.mode not in ["write", "readwrite", "refresh"]:
            return False

        if isinstance(kappa_values, np.ndarray):
            kappa_values = kappa_values.tolist()
        elif isinstance(kappa_values, torch.Tensor):
            kappa_values = kappa_values.cpu().tolist()

        text_hash, config_hash, _ = self.build_cache_key(
            text, config_dict, model_name, tokenizer_name,
            model_revision, tokenizer_revision
        )

        cache_path = self._get_cache_path(text_hash, config_hash, dataset, split, example_id)

        with self._lock:
            cache_path.parent.mkdir(parents=True, exist_ok=True)

            entry = CurvatureCacheEntry(
                positions=positions,
                kappa_values=kappa_values,
                config_hash=config_hash,
                text_hash=text_hash,
                model_name=model_name,
                model_revision=model_revision,
                tokenizer_name=tokenizer_name,
                tokenizer_revision=tokenizer_revision,
                timestamp=time.time(),
                computation_time_ms=computation_time_ms,
                num_forward_passes=num_forward_passes,
                belief_window_L=belief_window_L,
                belief_window_R=belief_window_R,
                stride_used=stride_used,
                metadata=metadata or {}
            )

            try:
                torch.save(entry.to_dict(), cache_path)
                self.stats.writes += 1
                self.stats.positions_computed += len(positions)

                key = f"{dataset}/{split}/{example_id}/{config_hash}"
                self.index["entries"][key] = {
                    "path": str(cache_path),
                    "timestamp": entry.timestamp,
                    "num_positions": len(positions),
                    "text_hash": text_hash[:8]
                }

                return True
            except Exception as e:
                print(f"[Cache] Failed to save {cache_path}: {e}")
                return False

    def get_stats(self) -> Dict:
        """Return the current usage statistics as a dict."""
        return self.stats.to_dict()

    def print_stats(self, prefix: str = "[Cache]"):
        """Print a one-line summary of the usage statistics.

        Args:
            prefix: Text printed before the summary.
        """
        stats = self.get_stats()
        print(f"{prefix} Hits: {stats['hits']}, Misses: {stats['misses']}, "
              f"Hit Rate: {stats['hit_rate']*100:.1f}%, "
              f"Time Saved: {stats['total_time_saved_ms']/1000:.1f}s, "
              f"Positions Cached: {stats['positions_served_from_cache']}, "
              f"Positions Computed: {stats['positions_computed']}")

    def clear(self, dataset: Optional[str] = None):
        """Delete cached files and the matching index entries.

        Args:
            dataset: When given, only that dataset's directory and index
                entries are removed. Otherwise everything under ``cache_dir``
                except ``index.json`` is deleted and the index is emptied.
        """
        import shutil

        if dataset:
            target = self.cache_dir / dataset
            if target.exists():
                shutil.rmtree(target)
            # Index keys are "<dataset>/<split>/<example_id>/<config_hash>";
            # drop only this dataset's entries and keep the rest.
            prefix = f"{dataset}/"
            self.index["entries"] = {
                key: value
                for key, value in self.index.get("entries", {}).items()
                if not key.startswith(prefix)
            }
        else:
            # The directory is not created in "disabled" mode, so it may be absent.
            if self.cache_dir.exists():
                for item in self.cache_dir.iterdir():
                    if item.is_dir():
                        shutil.rmtree(item)
                    elif item.name != "index.json":
                        item.unlink()
            self.index = {"entries": {}, "stats": {}}

        self._save_index()

    def get_cache_size_mb(self) -> float:
        """Return the combined size of all ``.pt`` files under ``cache_dir`` in MiB."""
        total_size = sum(path.stat().st_size for path in self.cache_dir.rglob("*.pt"))
        return total_size / (1024 * 1024)

    def finalize(self):
        """Persist the index and print the usage statistics."""
        self._save_index()
        self.print_stats()

# Process-wide cache instance; created on demand by get_global_cache() and
# set back to None by reset_global_cache().
_global_cache: Optional[CurvatureCache] = None
# Held while _global_cache is created or reset.
_global_cache_lock = threading.Lock()

def get_global_cache(config: Optional[CurvatureCacheConfig] = None) -> CurvatureCache:
    """Return the process-wide cache, creating it on first use.

    Args:
        config: Used only when the instance is created; if an instance already
            exists it is returned as-is and ``config`` is ignored.

    Returns:
        The shared ``CurvatureCache`` instance.
    """
    global _global_cache
    with _global_cache_lock:
        existing = _global_cache
        if existing is not None:
            return existing
        created = CurvatureCache(config)
        _global_cache = created
        return created

def reset_global_cache():
    """Finalize the process-wide cache (if any) and drop it.

    The next call to ``get_global_cache`` creates a fresh instance.
    """
    global _global_cache
    with _global_cache_lock:
        try:
            if _global_cache is not None:
                _global_cache.finalize()
        finally:
            # Drop the instance even if finalize() raises, so a failed
            # finalize does not leave a stale cache installed.
            _global_cache = None

@contextmanager
def cache_context(config: Optional[CurvatureCacheConfig] = None):
    """Yield a new ``CurvatureCache`` and finalize it when the block exits.

    Args:
        config: Settings passed to the ``CurvatureCache`` constructor.

    Yields:
        The cache. ``finalize()`` is called on exit, including when the body
        raises.
    """
    scoped_cache = CurvatureCache(config)
    try:
        yield scoped_cache
    finally:
        scoped_cache.finalize()
