import copy
import json
import time

from safetensors import safe_open
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
import os
from transformers import PreTrainedModel, PretrainedConfig,AutoConfig

from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .modeling_mixtral_kv import MixtralForCausalLM as KVMixtralForCausalLM
from .modeling_qwen2_kv import LlamaForCausalLM as KVQwen2ForCausalLM
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, AutoProcessor
from .modeling_llava_next import LlavaNextForConditionalGeneration

from .utils import *
from .kv_cache import initialize_past_key_values

from .cnets import Model
from .configs import EConfig





class EaModel(nn.Module):
    """EAGLE speculative-decoding wrapper around a multimodal (LLaVA-Next style) base model.

    The full base model is kept as ``embed_model`` (used to embed text+image inputs) and its
    ``language_model`` is used as the verifier (``base_model``). A draft model (``ea_layer``)
    proposes token trees that the verifier checks in one forward pass per step.
    """

    def __init__(
            self,
            base_model,
            base_model_name_or_path,
            ea_model_path,
            total_token,
            depth,
            top_k,
            threshold,
            ea_layer_state_dict
    ):
        """Build the draft layer and attach it to ``base_model``.

        Args:
            base_model: Loaded base model; its ``language_model`` attribute is used as the
                verifier LM. NOTE(review): only LLaVA-Next style models expose
                ``language_model`` — the text-only architectures accepted by
                ``from_pretrained`` would fail here; confirm intended support.
            base_model_name_or_path: Local checkpoint directory of the base model; the
                tokenizer, processor and ``lm_head`` weight are read from it.
            ea_model_path: Path of the JSON config for the draft model (read with
                ``EConfig``); its optional ``"bias"`` key defaults to True.
            total_token, depth, top_k, threshold: Draft-tree settings forwarded to ``Model``.
            ea_layer_state_dict: Draft-model weights, loaded with ``strict=False``.
        """
        super().__init__()
        self.embed_model = base_model
        base_model = base_model.language_model
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path, use_fast=False)
        self.processor = AutoProcessor.from_pretrained(self.base_model_name_or_path)

        config = EConfig.from_pretrained(ea_model_path)
        bias = self._read_bias(ea_model_path)
        self.ea_layer = Model(config, bias=bias, total_tokens=total_token, depth=depth,
                              top_k=top_k, threshold=threshold)

        # Hard-coded: the `layer_device` branch below is currently never taken.
        low_memory = False

        device = base_model.model.layers[-1].self_attn.q_proj.weight.device
        if device != base_model.lm_head.weight.device:
            # lm_head sits on a different device than the last decoder layer.
            self.ea_layer.diff_device = True
            if not low_memory:
                self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device)
            else:
                self.ea_layer.layer_device = device
        else:
            self.ea_layer.diff_device = False

        print(base_model_name_or_path)
        head_weight = self._load_lm_head_weight(base_model_name_or_path)
        head = torch.nn.Linear(head_weight.shape[1], head_weight.shape[0], bias=False)
        head.weight.data = head_weight
        self.ea_layer.head_weight = head

        self.ea_layer.load_state_dict(ea_layer_state_dict, strict=False)
        self.ea_layer.to(self.base_model.dtype).to(device)
        self.ea_layer.init_tree()
        self.ea_layer.embed_model = self.embed_model

    @staticmethod
    def _read_bias(config_path):
        """Return the ``"bias"`` flag from the JSON config at *config_path* (True if absent)."""
        with open(config_path, "r") as f:
            con = json.load(f)
        return con.get("bias", True) if isinstance(con, dict) else True

    @staticmethod
    def _load_lm_head_weight(base_model_name_or_path):
        """Read the base model's ``lm_head`` weight from its checkpoint shards as bfloat16.

        Tries the sharded safetensors index (key ``language_model.lm_head.weight``) first and
        falls back to the sharded PyTorch ``.bin`` index (key ``lm_head.weight``).
        """
        try:
            index_path = os.path.join(base_model_name_or_path, "model.safetensors.index.json")
            with open(index_path, "r") as f:
                index_json = json.load(f)
            head_path = index_json["weight_map"]["language_model.lm_head.weight"]
            print(head_path)
            with safe_open(os.path.join(base_model_name_or_path, head_path),
                           framework="pt",
                           device="cpu") as f:
                tensor_slice = f.get_slice("language_model.lm_head.weight")
                _, hidden_dim = tensor_slice.get_shape()
                return tensor_slice[:, :hidden_dim].to(torch.bfloat16)
        except Exception:
            # NOTE(review): this fallback uses the unprefixed "lm_head.weight" key, unlike the
            # safetensors path above — confirm it is right for LLaVA .bin checkpoints.
            index_path = os.path.join(base_model_name_or_path, "pytorch_model.bin.index.json")
            with open(index_path, "r") as f:
                index_json = json.load(f)
            head_path = index_json["weight_map"]["lm_head.weight"]
            # weights_only=False unpickles arbitrary objects: only use with trusted checkpoints.
            weights = torch.load(os.path.join(base_model_name_or_path, head_path), weights_only=False)
            return weights["lm_head.weight"].to(torch.bfloat16)

    @staticmethod
    def _load_ea_state_dict(ea_model_path, map_location):
        """Load draft weights from ``pytorch_model.bin``, falling back to ``model.safetensors``.

        Each file is taken from the local directory *ea_model_path* when present, otherwise
        downloaded from the Hub repo of that name.
        """
        try:
            load_model_path = os.path.join(ea_model_path, "pytorch_model.bin")
            if not os.path.exists(load_model_path):
                load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin")
            return torch.load(load_model_path, map_location=map_location)
        except Exception:
            from safetensors.torch import load_file
            load_model_path = os.path.join(ea_model_path, "model.safetensors")
            if not os.path.exists(load_model_path):
                load_model_path = hf_hub_download(ea_model_path, "model.safetensors")
            return load_file(load_model_path)

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer

    @classmethod
    def from_pretrained(
            cls,
            Type="LLaMA",
            base_model_path=None,
            ea_model_path=None,
            total_token=32,
            depth=6,
            top_k=4,
            threshold=1.0,
            **kwargs,
    ):
        """Load the base model and EAGLE draft weights and build an ``EaModel``.

        Args:
            Type: Ignored; the architecture is read from ``base_model_path``'s config.
            base_model_path: Local directory of the base model checkpoint.
            ea_model_path: Local directory or Hub repo id holding the draft weights.
            total_token: Draft tree size; ``-1`` benchmarks a few sizes on CUDA and picks one.
            depth, top_k, threshold: Draft-tree settings.
            **kwargs: Forwarded to the base model's ``from_pretrained``.

        Raises:
            ValueError: If the base model architecture is not supported.
        """
        loaders = {
            'LlamaForCausalLM': KVLlamaForCausalLM,
            'Qwen2ForCausalLM': KVQwen2ForCausalLM,
            'MixtralForCausalLM': KVMixtralForCausalLM,
            'LlavaNextForConditionalGeneration': LlavaNextForConditionalGeneration,
        }
        architecture = AutoConfig.from_pretrained(base_model_path).architectures[0]
        try:
            loader = loaders[architecture]
        except KeyError:
            raise ValueError(
                f"Unsupported base model architecture {architecture!r}; "
                f"expected one of {sorted(loaders)}"
            ) from None
        base_model = loader.from_pretrained(base_model_path, **kwargs)

        configpath = os.path.join(base_model_path, "config.json")
        if not os.path.exists(configpath):
            # NOTE(review): the local check uses base_model_path but the download uses
            # ea_model_path — confirm which config the draft model (EConfig) should read.
            configpath = hf_hub_download(ea_model_path, "config.json")

        ea_layer_state_dict = cls._load_ea_state_dict(ea_model_path, base_model.device)
        model = cls(
            base_model,
            base_model_path,
            configpath,
            total_token,
            depth,
            top_k,
            threshold,
            ea_layer_state_dict
        )

        if total_token == -1:
            model._autotune_total_tokens()

        return model

    def _autotune_total_tokens(self):
        """Pick ``ea_layer.total_tokens`` by timing base-model forwards at a few prompt lengths.

        Each candidate length is timed over 20 forwards on random ids, and the time is divided
        by a per-length weight. The lowest weighted time wins. Requires CUDA, because it calls
        ``torch.cuda.synchronize``. Returns the chosen length.
        """
        device = self.base_model.model.layers[0].self_attn.q_proj.weight.device
        candidate_lengths = [40, 48, 50, 56, 60]
        weights = [1, 1.05, 1.07, 1.1, 1.13]
        scores = []
        for length, weight in zip(candidate_lengths, weights):
            input_ids = torch.randint(0, self.config.vocab_size - 200, (1, length)).to(device)
            torch.cuda.synchronize()
            start_time = time.time()
            for _ in range(20):
                torch.cuda.synchronize()
                with torch.no_grad():
                    self.base_model(input_ids)
                torch.cuda.synchronize()
            torch.cuda.synchronize()
            scores.append((time.time() - start_time) / weight)
        best = candidate_lengths[scores.index(min(scores))]
        self.ea_layer.total_tokens = best - 1
        return best

    def forward(
            self,
            input_ids=None,
            inputs_embeds=None,
            attention_mask=None,
            past_key_values=None,
            output_orig=False,
            position_ids=None,
            output_attan_score=False,
            output_draft_attention_scores=False,
    ):
        """Run the base language model's decoder stack under ``torch.inference_mode``.

        Returns:
            ``(outputs, orig, hidden_states)`` when *output_orig* (``orig`` is the lm_head
            applied to ``outputs[0]``), else ``(outputs, hidden_states)``. When
            *output_attan_score* is set, ``outputs[-1]`` is appended as a last element.
            *output_draft_attention_scores* is accepted but unused.
        """
        with torch.inference_mode():
            outputs = self.base_model.model(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
                output_attan_score=output_attan_score,
            )
            hidden_states = outputs[0]
            results = [outputs]
            if output_orig:
                results.append(self.base_model.lm_head(hidden_states))
        results.append(hidden_states)
        if output_attan_score:
            results.append(outputs[-1])
        return tuple(results)

    @staticmethod
    def _select_logits_processor(temperature, top_p, top_k):
        """Return a sampling logits processor, or None (greedy) when temperature <= 1e-5."""
        if temperature > 1e-5:
            return prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        return None

    def _prepare_kv_cache(self):
        """Return ``(past_key_values, past_key_values_data, current_length_data)``.

        The cache is allocated once via ``initialize_past_key_values`` and stored on the
        instance; later calls reuse it and only reset ``current_length_data`` to zero.
        """
        if hasattr(self, "past_key_values"):
            self.current_length_data.zero_()
        else:
            (
                self.past_key_values,
                self.past_key_values_data,
                self.current_length_data,
            ) = initialize_past_key_values(self.base_model)
        return self.past_key_values, self.past_key_values_data, self.current_length_data

    def _should_stop(self, input_ids, input_len, new_token, max_new_tokens, max_length,
                     stop_token_id=None):
        """Return True when generation should end.

        Generation ends when any of these holds:
        - *stop_token_id* appears among the generated ids of row 0.
        - The tokenizer's EOS id appears among those ids.
        - More than *max_new_tokens* tokens have been produced.
        - The sequence length exceeds *max_length*.

        Only ids after the first *input_len* positions (the prompt) are searched.
        """
        generated = input_ids[0, input_len:].tolist()
        if stop_token_id is not None and stop_token_id in generated:
            return True
        if self.tokenizer.eos_token_id in generated:
            return True
        return new_token > max_new_tokens or input_ids.shape[1] > max_length

    @torch.no_grad()
    def eagenerate(
            self,
            inputs,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,

    ):
        """Run EAGLE speculative decoding to completion (non-streaming ``ea_generate``).

        Delegates to :meth:`ea_generate` so both entry points share one decoding loop.

        Returns:
            The final ``input_ids`` (prompt + generated tokens). When *log* is set, returns
            ``(input_ids, new_token, idx)``. Here ``new_token`` counts the tokens appended
            after the prompt, and ``idx`` is the 0-based index of the last decoding step.
        """
        output_ids = None
        steps = 0
        for output_ids in self.ea_generate(
                inputs,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
                max_length=max_length,
                is_llama3=is_llama3,
        ):
            steps += 1
        if output_ids is None:
            # No decoding step ran (length budget too small): return the prompt unchanged.
            output_ids = inputs.input_ids.to(self.base_model.device).clone()
        if not log:
            return output_ids
        new_token = output_ids.shape[1] - inputs.input_ids.shape[1]
        return output_ids, new_token, steps - 1

    @torch.no_grad()
    def naivegenerate(
            self,
            inputs,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,

    ):
        """Run plain autoregressive decoding to completion (non-streaming ``naive_generate``).

        Returns:
            The final ``input_ids``. When *log* is set, returns ``(input_ids, new_token, idx)``.
            Here ``new_token`` is the number of decoding steps (one token each), and ``idx``
            is the index of the last step.
        """
        output_ids = None
        steps = 0
        for output_ids in self.naive_generate(
                inputs,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
                max_length=max_length,
                is_llama3=is_llama3,
        ):
            steps += 1
        if output_ids is None:
            # No decoding step ran (length budget too small): return the prompt unchanged.
            output_ids = inputs.input_ids.to(self.base_model.device).clone()
        if not log:
            return output_ids
        return output_ids, steps, steps - 1

    @torch.no_grad()
    def ea_generate(
            self,
            inputs,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=4096,
            log=False,
            is_llama3=False,
            output_attention_scores=False,

    ):
        """Stream EAGLE speculative decoding, yielding after every verification step.

        Each step does three things:
        1. Verifies the current draft tree with the base model (``tree_decoding``).
        2. Picks the accepted candidate (``evaluate_posterior``).
        3. Drafts the next tree (``update_inference_inputs``).

        Args:
            inputs: Processor output with ``input_ids`` and optional ``pixel_values`` /
                ``image_sizes`` (pixel values are cast to bfloat16).
            temperature, top_p, top_k: Sampling settings; temperature <= 1e-5 is greedy.
            max_new_tokens: Stop once more than this many tokens were generated.
            max_length: Length budget, reduced by ``ea_layer.total_tokens + 10``.
            log: Unused; kept for signature parity.
            is_llama3: Also stop on ``<|eot_id|>``.
            output_attention_scores: Collect and yield attention statistics too.

        Yields:
            ``input_ids`` after each step. When *output_attention_scores* is set, it yields
            ``(input_ids, target_score, all_attention_scores)`` instead.
        """
        device = self.base_model.device
        input_ids = inputs.input_ids.to(device)
        pixel_values = inputs.pixel_values.to(device).to(torch.bfloat16) if hasattr(inputs, "pixel_values") else None
        image_sizes = inputs.image_sizes.to(device) if hasattr(inputs, "image_sizes") else None
        original_prompt_length = input_ids.shape[1]

        stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") if is_llama3 else None
        max_length = max_length - self.ea_layer.total_tokens - 10
        logits_processor = self._select_logits_processor(temperature, top_p, top_k)

        # A -1 column appended to draft_tokens each step; presumably lets -1 entries in
        # retrieve_indices select a padding token — TODO confirm against utils.
        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()
        past_key_values, past_key_values_data, current_length_data = self._prepare_kv_cache()

        input_len = input_ids.shape[1]
        reset_tree_mode(self)

        start = time.time()
        if output_attention_scores:
            (draft_input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids,
             logits, hidden_state, sample_token, target_score, draft_score,
             image_start, image_end, text_start, text_end) = initialize_tree(
                input_ids, self, past_key_values, logits_processor, self.embed_model,
                pixel_values, image_sizes,
                output_draft_attention_scores=True,
                original_prompt_length=original_prompt_length,
            )
            all_attention_scores = [draft_score]
            # Extra arguments update_inference_inputs needs only when collecting attention.
            update_kwargs = dict(
                output_draft_attention_scores=True,
                original_prompt_length=original_prompt_length,
                image_start=image_start,
                image_end=image_end,
                text_start=text_start,
                text_end=text_end,
            )
        else:
            (draft_input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids,
             logits, hidden_state, sample_token, target_score, draft_score) = initialize_tree(
                input_ids, self, past_key_values, logits_processor, self.embed_model,
                pixel_values, image_sizes,
            )
            all_attention_scores = []
            update_kwargs = {}
        t0 = time.time() - start
        t1 = t2 = 0
        new_token = 0

        # idx starts at 1 so the per-step averages t1/idx and t2/idx never divide by zero.
        for idx in range(1, max_length + 1):
            self.base_model.model.tree_mask = tree_mask
            draft_tokens = draft_tokens.to(input_ids.device)

            start = time.time()
            logits, hidden_state_new, _ = tree_decoding(
                self,
                draft_tokens,
                past_key_values,
                tree_position_ids,
                input_ids,
                retrieve_indices,
                output_draft_attention_scores=output_attention_scores,
            )
            t1 += time.time() - start

            draft_tokens = torch.cat((draft_tokens, padding), dim=1)
            candidates = draft_tokens[0, retrieve_indices]
            best_candidate, accept_length, sample_p = evaluate_posterior(
                logits, candidates, logits_processor
            )

            start = time.time()
            (input_ids, draft_input_ids, draft_tokens, retrieve_indices, tree_mask,
             tree_position_ids, new_token, hidden_state, sample_token,
             draft_attention_scores) = update_inference_inputs(
                input_ids,
                draft_input_ids,
                candidates,
                best_candidate,
                accept_length,
                retrieve_indices,
                logits_processor,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state_new,
                sample_p,
                **update_kwargs,
            )
            t2 += time.time() - start
            if output_attention_scores and draft_attention_scores:
                all_attention_scores.append(draft_attention_scores)

            if output_attention_scores:
                yield input_ids, target_score, all_attention_scores
            else:
                yield input_ids

            if self._should_stop(input_ids, input_len, new_token, max_new_tokens, max_length,
                                 stop_token_id):
                print(f"init time {t0}, verify time {t1/idx}, draft time {t2/idx}")
                break

    @torch.no_grad()
    def naive_generate(
            self,
            inputs,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=4096,
            log=False,
            is_llama3=False,

    ):
        """Stream plain autoregressive decoding, one base-model forward per token.

        Baseline for comparison with :meth:`ea_generate`; yields the full id tensor after
        every generated token. *log* is accepted for signature parity and unused.
        """
        device = self.base_model.device
        if hasattr(inputs, "to"):
            # Move every field (not just input_ids) so embed_model(**inputs) sees one device.
            inputs = inputs.to(device)
        input_ids = inputs.input_ids.to(device)
        stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") if is_llama3 else None
        max_length = max_length - self.ea_layer.total_tokens - 10
        logits_processor = self._select_logits_processor(temperature, top_p, top_k)

        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()
        past_key_values, _, _ = self._prepare_kv_cache()

        input_len = input_ids.shape[1]
        reset_tree_mode(self)

        input_embeds = self.embed_model(**inputs)
        outputs = self.base_model(inputs_embeds=input_embeds, past_key_values=past_key_values, use_cache=True)
        new_token = 0

        for _ in range(max_length):
            if logits_processor is not None:
                logits = logits_processor(None, outputs.logits[:, -1])
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                input_id = torch.multinomial(probabilities, 1)
            else:
                input_id = outputs.logits[:, -1:].argmax(dim=-1)

            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1

            yield input_ids

            if self._should_stop(input_ids, input_len, new_token, max_new_tokens, max_length,
                                 stop_token_id):
                break