import os
import json
import threading
from functools import wraps
import wandb
class CacheManager:
    """
    File-backed key/value cache whose keys are built from a fixed list of fields.

    Args:
        cache_dir (str): Directory that holds the cache files; created if missing.
        metadata_fields (list): Names of the fields that make up a cache key.
            Only used when no metadata file exists yet; if one exists, the
            ``key_fields`` stored in it take precedence. May be None when the
            metadata file is already present.

    Files are stored directly in ``cache_dir``:
        cache_dir/cache.json       - the cached entries
        cache_dir/cache_meta.json  - ``{"key_fields": [...]}``

    Raises:
        ValueError: If no metadata file exists and ``metadata_fields`` is None.
    """

    def __init__(self, cache_dir, metadata_fields):
        self.cache_dir = cache_dir
        self.metadata_fields = metadata_fields
        # No per-key subdirectories are created: everything lives in cache_dir.
        self.full_cache_dir = self.cache_dir
        os.makedirs(self.full_cache_dir, exist_ok=True)
        print("Cache directory: ", self.full_cache_dir)
        self.cache_file = os.path.join(self.full_cache_dir, "cache.json")
        self.meta_file = os.path.join(self.full_cache_dir, "cache_meta.json")
        # Guards self.cache and the hit/miss/save counters.
        self.lock = threading.Lock()
        self.cache = self._load_cache()
        print("Cache initialized with entries: ", len(self.cache))
        self.metadata = self._load_metadata(metadata_fields)
        self.hits = 0
        self.misses = 0
        # Number of new keys added since the cache was last written to disk.
        self.new_entries_since_save = 0
        # set() writes the cache to disk once this many new keys accumulate.
        self.auto_save_threshold = 10

    def _load_cache(self):
        """Return the entries stored in the cache file, or {} if absent or unreadable."""
        if not os.path.exists(self.cache_file):
            return {}
        with open(self.cache_file, 'r') as f:
            try:
                loaded = json.load(f)
            except Exception:
                # Best effort: a corrupt cache file must not abort start-up.
                print("Error loading cache file: ", self.cache_file)
                return {}
        if not isinstance(loaded, dict):
            # get()/set() rely on dict semantics; any other JSON value is
            # treated the same way as a corrupt file.
            print("Error loading cache file: ", self.cache_file)
            return {}
        return loaded

    def _save_cache(self):
        """Write the in-memory cache to the cache file.

        The data is written to a temporary sibling file first and then moved
        into place, so a failure while serializing never truncates the
        previously saved cache. Callers are expected to hold ``self.lock``.
        """
        tmp_path = self.cache_file + ".tmp"
        with open(tmp_path, 'w') as f:
            json.dump(self.cache, f)
        os.replace(tmp_path, self.cache_file)

    def _load_metadata(self, metadata_fields):
        """Load the metadata file, creating it from ``metadata_fields`` if missing.

        An existing metadata file always wins; ``metadata_fields`` is ignored
        in that case.

        Raises:
            ValueError: If the file does not exist and ``metadata_fields`` is None.
        """
        if os.path.exists(self.meta_file):
            with open(self.meta_file, 'r') as f:
                return json.load(f)
        if metadata_fields is None:
            raise ValueError("Metadata file not found and no metadata_fields provided.")
        meta = {"key_fields": metadata_fields}
        with open(self.meta_file, 'w') as f:
            json.dump(meta, f)
        return meta

    def _make_key(self, kwargs):
        """Build the cache key string for ``kwargs``.

        Only the fields listed under ``key_fields`` in the metadata are used.
        Each value is converted with ``str``; a missing field contributes an
        empty string. The result is a JSON array of those strings.
        """
        key_fields = self.metadata["key_fields"]
        key_tuple = tuple(str(kwargs.get(field, '')) for field in key_fields)
        return json.dumps(key_tuple, sort_keys=True)

    def get(self, kwargs):
        """Return the cached value for ``kwargs``, or None on a miss.

        Updates the hit/miss counters. When a wandb run is active, the current
        hit rate is also printed and logged as ``cache_hit_rate``.
        """
        key = self._make_key(kwargs)
        with self.lock:
            value = self.cache.get(key)
            # A stored None is indistinguishable from a missing key here and
            # is counted as a miss.
            if value is not None:
                self.hits += 1
            else:
                self.misses += 1

            if wandb.run is not None:
                hit_rate = self.report_hit_rate()
                wandb.log({"cache_hit_rate": hit_rate})

            return value

    def set(self, kwargs, value):
        """Store ``value`` under the key derived from ``kwargs``.

        Adding a previously unseen key counts toward ``auto_save_threshold``;
        once the threshold is reached the whole cache is written to disk.
        Overwriting an existing key never triggers a save; call ``save()``
        to persist explicitly.
        """
        key = self._make_key(kwargs)
        with self.lock:
            is_new = key not in self.cache
            # Insert before any auto-save so the entry that reaches the
            # threshold is part of what gets written.
            self.cache[key] = value
            if is_new:
                self.new_entries_since_save += 1
                if self.new_entries_since_save >= self.auto_save_threshold:
                    self._save_cache()
                    self.new_entries_since_save = 0

    def save(self):
        """Write the cache to disk now and reset the auto-save counter."""
        with self.lock:
            self._save_cache()
            self.new_entries_since_save = 0

    def report_hit_rate(self):
        """Print and return hits / (hits + misses); 0.0 when nothing was looked up."""
        total = self.hits + self.misses
        rate = self.hits / total if total > 0 else 0.0
        print(f"Cache hit rate: {self.hits}/{total} = {rate:.2%}")
        return rate


def cache_response(cache_manager):
    """
    Decorator that memoizes a function's result through ``cache_manager``.

    The cache key is derived by ``cache_manager`` from the call's arguments by
    parameter name. Keyword arguments are used as given; positional arguments
    are mapped to their parameter names via the function's signature, so they
    no longer collapse onto a single empty key. Parameter defaults that the
    caller did not pass are not added to the key.

    A result of None is never stored, and a cached None is treated as a miss.

    If ``cache_manager`` is None, caching is skipped: every call prints a
    notice and is forwarded to the wrapped function.
    """
    def decorator(func):
        # Function-scope import keeps this block independent of the module's
        # import list.
        import inspect

        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            # Some callables expose no signature; fall back to kwargs only.
            signature = None

        def _key_arguments(args, kwargs):
            """Return the name -> value mapping handed to the cache manager."""
            if signature is None or not args:
                return kwargs
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                # Invalid call: let func itself raise the real error.
                return kwargs
            named = {}
            for name, value in bound.arguments.items():
                kind = signature.parameters[name].kind
                if kind is inspect.Parameter.VAR_KEYWORD:
                    named.update(value)
                elif kind is not inspect.Parameter.VAR_POSITIONAL:
                    named[name] = value
            return named

        @wraps(func)
        def wrapper(*args, **kwargs):
            if cache_manager is None:
                print("Cache manager is None, skipping cache")
                return func(*args, **kwargs)
            key_arguments = _key_arguments(args, kwargs)
            cached = cache_manager.get(key_arguments)
            if cached is not None:
                return cached
            result = func(*args, **kwargs)
            if result is not None:
                cache_manager.set(key_arguments, result)
            return result
        return wrapper
    return decorator
