import hashlib
import numpy as np
import torch
import diskcache
import os
import functools
import inspect
import json
from transformers import HfArgumentParser, AutoModelForCausalLM
import __main__

def printif(condition, message):
    """Print ``message`` to stdout only when ``condition`` is truthy."""
    if not condition:
        return
    print(message)

def parse_args(Arguments):
    """Build an ``Arguments`` instance, from the command line when possible.

    If the ``__main__`` module has a ``__file__`` attribute, the command line
    is parsed with an ``HfArgumentParser`` built over ``Arguments`` and the
    parsed namespace is expanded into ``Arguments(**fields)``. Otherwise
    (presumably an interactive session or notebook -- confirm against
    callers) ``Arguments()`` is created with its defaults.

    Args:
        Arguments: The class to instantiate; it is passed to
            ``HfArgumentParser`` and called with the parsed values as
            keyword arguments.

    Returns:
        An instance of ``Arguments``.
    """
    # No script file behind __main__: skip command-line parsing entirely.
    if not hasattr(__main__, '__file__'):
        return Arguments()

    cli_parser = HfArgumentParser((Arguments,))
    namespace = cli_parser.parse_args()
    return Arguments(**vars(namespace))

def update_args(args, **kwargs):
    """Overwrite existing attributes of ``args`` with the given keyword values.

    Every key is validated before anything is assigned, so a call that
    raises leaves ``args`` exactly as it was.

    Args:
        args: Object whose attributes are updated in place.
        **kwargs: Attribute names mapped to their new values.

    Returns:
        The same ``args`` object, after the update.

    Raises:
        ValueError: If a key is not already an attribute of ``args``.
    """
    # Validate first: assigning while validating would leave ``args``
    # half-updated when a later key turns out to be invalid.
    for key in kwargs:
        if not hasattr(args, key):
            raise ValueError(f"Invalid argument: {key} for {args.__class__.__name__}")

    for key, value in kwargs.items():
        setattr(args, key, value)
    return args

def _hash_array(array):
    """Return a SHA-256 hex digest of a NumPy array's dtype, shape and bytes."""
    digest = hashlib.sha256()
    # Prefix with dtype and shape: the raw bytes alone cannot tell a (2, 3)
    # array from a (3, 2) one, or int32 zeros from float32 zeros.
    digest.update(f"{array.dtype}|{array.shape}|".encode())
    digest.update(array.tobytes())
    return digest.hexdigest()


def quick_hash(instance, depth=0):
    """Return a deterministic string identifying ``instance`` for cache keys.

    Handling by type, checked in this order:

    * objects with a ``name_or_path`` attribute (e.g. models): that attribute;
    * ``np.ndarray``: SHA-256 over dtype, shape and raw bytes;
    * ``torch.Tensor``: detached, moved to CPU and hashed like an array;
    * ``dict``: ``{key:hash,...}`` with entries ordered by ``repr(key)`` so
      that equal dicts give equal strings regardless of insertion order;
    * ``list`` / ``tuple``: ``[hash,hash,...]``;
    * anything else: ``json.dumps(instance)``.

    Results longer than 1000 characters are replaced by their SHA-256 hex
    digest.

    NOTE: container elements are delimited and arrays include dtype/shape so
    that distinct values cannot produce the same string (previously
    ``[1, 23]`` and ``[12, 3]`` collided). Keys for lists, tuples, dicts and
    arrays therefore differ from those of earlier versions; keys for scalars,
    strings and ``name_or_path`` objects are unchanged.

    Args:
        instance: The value to hash.
        depth: Current nesting level; callers normally leave this at 0.

    Returns:
        A string derived from ``instance``.

    Raises:
        ValueError: If nesting exceeds 5 levels (this also stops
            self-referencing containers) or if a leaf value cannot be
            serialized by ``json.dumps``.
    """
    # Check before recursing so cyclic containers fail with the intended
    # ValueError instead of exhausting the interpreter stack.
    if depth > 5:
        raise ValueError("Depth limit reached for hash")

    # For models, we can just hash the name_or_path
    if hasattr(instance, "name_or_path"):
        hashed = instance.name_or_path

    elif isinstance(instance, np.ndarray):
        hashed = _hash_array(instance)

    elif isinstance(instance, torch.Tensor):
        # detach() is required for tensors that track gradients.
        hashed = _hash_array(instance.detach().cpu().numpy())

    elif isinstance(instance, dict):
        entries = sorted(instance.items(), key=lambda item: repr(item[0]))
        # repr() quotes string keys, so separators inside a key stay unambiguous.
        parts = [f"{k!r}:{quick_hash(v, depth + 1)}" for k, v in entries]
        hashed = "{" + ",".join(parts) + "}"

    elif isinstance(instance, (list, tuple)):
        parts = [quick_hash(item, depth + 1) for item in instance]
        hashed = "[" + ",".join(parts) + "]"

    else:
        try:
            hashed = json.dumps(instance)
        except Exception as exc:
            raise ValueError(f"Instance: {instance} is not JSON serializable") from exc

    if len(hashed) > 1000:
        hashed = hashlib.sha256(hashed.encode()).hexdigest()

    return hashed


def _is_cacheable(result):
    """Return True unless ``result`` is None, an empty string, or an exception."""
    if result is None or isinstance(result, Exception):
        return False
    # Compare against "" only for real strings: for array-like results
    # (e.g. NumPy arrays) ``result != ""`` may be evaluated elementwise and
    # then cannot be used as a single truth value.
    return not (isinstance(result, str) and result == "")


def prompt_cache(param_names: list = None, cache_dir: str = None, cache_name: str = 'prompt_cache', verbose: bool = False):
    """Decorator factory that caches a function's return value on disk.

    The cache key is built from the bound argument names and the
    ``quick_hash`` of their values, joined as ``name:hash|name:hash``.

    NOTE: the key does not include the decorated function's name, so two
    functions that share a cache directory and receive identical argument
    names/values share cache entries. Use distinct ``cache_name`` or
    ``cache_dir`` values to keep them apart.

    Environment variables (any non-empty value enables them):

    * ``PROMPT_CACHE_DISABLED``: bypass the cache; just call the function.
    * ``PROMPT_CACHE_WRITE_ONLY``: never read from the cache; always call
      the function and store its result.

    Results that are ``None``, an empty string or an ``Exception`` instance
    are never stored, and nothing is stored when the function raises.

    Args:
        param_names: Names of parameters to include in the key. ``None`` or
            an empty list means all bound parameters (defaults applied).
            Names that are not bound for a call are skipped.
        cache_dir: Directory for the ``diskcache.Cache``. Defaults to
            ``~/.cache/<cache_name>``.
        cache_name: Sub-directory of ``~/.cache`` used when ``cache_dir`` is
            not given; ignored otherwise.
        verbose: If True, print cache operations (init, hit, miss, save).

    Returns:
        A decorator that wraps a function with the caching behaviour.
    """
    # Determine cache directory
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.cache', cache_name)
    printif(verbose, f"Initializing cache at: {cache_dir}")
    cache = diskcache.Cache(cache_dir)

    # Sentinel so a lookup is a single read: a separate "key in cache" test
    # followed by "cache[key]" can fail if the entry disappears in between.
    missing = object()

    def decorator(func):
        # The signature does not change between calls; inspect it once.
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Check if cache is completely disabled first
            if os.environ.get("PROMPT_CACHE_DISABLED"):
                printif(verbose, "Cache disabled")
                return func(*args, **kwargs)

            # Bind arguments to parameter names and calculate key
            bound = signature.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            names = param_names or list(bound.arguments)
            key = '|'.join(
                f"{name}:{quick_hash(bound.arguments[name])}"
                for name in names
                if name in bound.arguments
            )

            # In write-only mode the lookup is skipped and the function
            # always runs.
            if not os.environ.get("PROMPT_CACHE_WRITE_ONLY"):
                cached = cache.get(key, default=missing)
                if cached is not missing:
                    printif(verbose, f"Cache hit for key: {key}")
                    return cached
                printif(verbose, f"Cache miss for key: {key}")

            try:
                result = func(*args, **kwargs)
            except Exception:
                printif(verbose, f"Skipping cache save due to exception for key: {key}")
                # Re-raise the exception without caching
                raise

            # Only cache if result is valid (not None/empty and no errors)
            if _is_cacheable(result):
                printif(verbose, f"Saving result to cache for key: {key}")
                cache[key] = result
            else:
                printif(verbose, f"Skipping cache save for invalid result (None/empty): {key}")

            return result

        return wrapper
    return decorator
