import random
import torch
import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM

from cache_utils import *


def load_kv_cache(
    method,
    num_recent: int = None,
    num_heavy: int = None,
    recent_ratio: float = None,
    heavy_ratio: float = None,
    decode_evict: bool = True,  # whether to evict during decoding
    fix_recent_token: bool = False,  # whether to fix the number of recent tokens in the ratio scenario
    cache_ratio: float = None,  # used when fix_recent is True
):
    """Build the KV-cache object that corresponds to ``method``.

    Args:
        method: Name of the cache strategy. ``'full'`` means no custom cache,
            ``'sink'`` builds a ``SinkCache``, and every name in the
            h2o / snapkv / tova / obc_* families builds an ``OBCache``.
        num_recent: Recent-token budget forwarded to the cache. For most
            ``OBCache`` families it is also used as the ``ptb_window``.
        num_heavy: Heavy-token budget forwarded to the cache.
        recent_ratio: Ratio-based counterpart of ``num_recent``.
        heavy_ratio: Ratio-based counterpart of ``num_heavy`` (forwarded as
            ``sink_ratio`` for ``'sink'``).
        decode_evict: Forwarded to ``OBCache`` unchanged.
        fix_recent_token: Forwarded to ``OBCache`` unchanged.
        cache_ratio: Forwarded to ``OBCache`` unchanged.

    Returns:
        ``None`` for ``'full'``, a ``SinkCache`` for ``'sink'``, otherwise an
        ``OBCache`` whose ``method`` attribute is set to ``method``.

    Raises:
        ValueError: If ``method`` is not a known strategy name.
    """
    if method == 'sink':
        return SinkCache(num_recent=num_recent,
                         num_heavy=num_heavy,
                         recent_ratio=recent_ratio,
                         sink_ratio=heavy_ratio)

    if method == 'full':
        return None

    obcache_methods = (
        'h2o', 'h2o+ptb_window', 'h2o+maxpool', 'h2o+avgpool', 'snapkv', 'snapkv_avgpool', 'tova',
        'obc_value_p1', 'obc_value_p1+maxpool', 'obc_value_p1+avgpool', 'obc_value_p1+tova',
        'obc_value_p2', 'obc_value_p2+maxpool', 'obc_value_p2+avgpool', 'obc_value_p2+tova',
        'obc_key_p2', 'obc_key_p2+maxpool', 'obc_key_p2+avgpool', 'obc_key_p2+tova',
        'obc_value_key_p2', 'obc_value_key_p2+maxpool', 'obc_value_key_p2+avgpool', 'obc_value_key_p2+tova',
        'obc_value_key_p2_wo_cross',
    )
    if method not in obcache_methods:
        raise ValueError(f"Unknown kv_cache method: {method}")

    # Plain 'tova' as well as every '<family>+tova' variant.
    is_tova = "tova" in method

    # Defaults shared by most families; each family below overrides only
    # what differs. Plain 'tova' keeps all of these defaults.
    use_v_score = True
    use_k_score = False
    use_cross = False
    p = 1
    use_act = False
    ptb_window = num_recent
    if "maxpool" in method:
        pool_fn = "maxpool"
    elif "avgpool" in method:
        pool_fn = "avgpool"
    else:
        pool_fn = None

    if method.startswith('h2o'):
        # Only the explicit '+ptb_window' variant gets a window.
        if "ptb_window" not in method:
            ptb_window = None

    elif method.startswith('snapkv'):
        # snapkv pools by default: maxpool unless avgpool is requested.
        if pool_fn is None:
            pool_fn = "maxpool"

    elif method.startswith('obc_value'):
        use_act = True
        p = 1 if 'p1' in method else 2
        if "key" in method:
            use_k_score = True
            use_cross = "wo_cross" not in method

    elif method.startswith('obc_key'):
        use_v_score = False
        use_k_score = True
        p = 2
        use_act = True

    # The window is taken from the caller's num_recent above, i.e. before
    # the tova adjustment below may reset num_recent to 0.
    ptb_is_recent = is_tova
    if is_tova:
        ptb_window = 1

        # tova variants run with no recent budget: fold any requested recent
        # budget into the heavy budget instead.
        if num_recent is not None and num_recent > 0:
            if num_heavy is not None:
                num_heavy += num_recent
            num_recent = 0
            print("Warning: For tova, num_recent should be 0. Setting num_recent to 0 and num_heavy to ", num_heavy)
        if recent_ratio is not None and recent_ratio > 0:
            # Guard against heavy_ratio=None (its default), mirroring the
            # num_heavy handling above; otherwise `None += float` raises.
            if heavy_ratio is not None:
                heavy_ratio += recent_ratio
            recent_ratio = 0
            print("Warning: For tova, recent_ratio should be 0.0. Setting recent_ratio to 0.0 and heavy_ratio to ", heavy_ratio)

    past_key_values = OBCache(
        num_recent=num_recent,
        num_heavy=num_heavy,
        recent_ratio=recent_ratio,
        heavy_ratio=heavy_ratio,
        decode_evict=decode_evict,
        fix_recent_token=fix_recent_token,
        cache_ratio=cache_ratio,
        use_v_score=use_v_score,
        use_k_score=use_k_score,
        use_cross=use_cross,
        p=p,
        use_act=use_act,
        ptb_window=ptb_window,
        pool_fn=pool_fn,
        ptb_is_recent=ptb_is_recent,
        num_sink=0,
    )
    past_key_values.method = method
    return past_key_values



def load_model_and_tokenizer(
    model_name_or_path,
    precision="fp16",
    hf_cache_dir=None,
    flash_attn=False,
):
    """Load a causal LM and its tokenizer via the ``transformers`` Auto classes.

    Args:
        model_name_or_path: Identifier or path passed to ``from_pretrained``.
        precision: ``"fp16"`` selects ``torch.float16``, ``"bf16"`` selects
            ``torch.bfloat16``; any other value selects ``torch.float32``.
        hf_cache_dir: Passed as ``cache_dir`` to both ``from_pretrained`` calls.
        flash_attn: When truthy, request the ``"flash_attention_2"`` attention
            implementation; otherwise ``"eager"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if precision == "fp16":
        weight_dtype = torch.float16
    elif precision == "bf16":
        weight_dtype = torch.bfloat16
    else:
        # Any unrecognised precision string falls back to full precision.
        weight_dtype = torch.float32

    attn_backend = "flash_attention_2" if flash_attn else "eager"

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        cache_dir=hf_cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        attn_implementation=attn_backend,
        torch_dtype=weight_dtype,
        trust_remote_code=True,
        cache_dir=hf_cache_dir,
    )
    return model, tokenizer


def seed_everything(seed):
    """Seed every RNG this module uses and pin the cuDNN flags.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU plus the CUDA seeding
    entry points) with ``seed``, then sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``.

    Args:
        seed: Value handed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False