import copy
import json
import time

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig,AutoConfig
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .utils import *
from .kv_cache import initialize_past_key_values
from .choices import medusa
from transformers import AutoTokenizer
import os
from huggingface_hub import hf_hub_download
from .draft_head import Model
from .configs import setConfig
from huggingface_hub import hf_hub_download




class Model(nn.Module):
    """Wrap a KV-cached base causal LM together with a draft head (``self.draft_module``).

    The base model produces hidden states and logits. The draft head is built
    from a separate JSON config and state dict and is seeded from those hidden
    states in :meth:`forward`.
    """

    def __init__(
            self,
            base_model,
            base_model_name_or_path,
            head_path,
            low_memory=False,
    ):
        """Attach a freshly built draft head to ``base_model``.

        Args:
            base_model: Loaded base causal LM. It must expose ``config``,
                ``dtype``, ``lm_head`` and
                ``model.layers[-1].self_attn.q_proj``.
            base_model_name_or_path: Path or hub id used to load the tokenizer.
                It is also passed to the draft head as ``path``.
            head_path: Path to the draft head's JSON config *file* (not a
                directory). Its optional ``"bias"`` key (default ``True``) is
                forwarded to the draft head.
            low_memory: Only used when the last decoder layer and ``lm_head``
                live on different devices. ``False`` (the default) copies the
                ``lm_head`` weight to the decoder's device as
                ``draft_module.headweight``. ``True`` only records that device
                as ``draft_module.layer_device``.
        """
        # This class is itself named ``Model`` and rebinds the module-level
        # name imported from ``.draft_head``. A bare ``Model(...)`` here would
        # therefore resolve to this wrapper, so import the head under an alias.
        from .draft_head import Model as DraftModel

        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)

        config = setConfig.from_pretrained(head_path)
        with open(head_path, "r", encoding="utf-8") as f:
            head_config = json.load(f)
        bias = head_config.get("bias", True)
        self.draft_module = DraftModel(
            config, load_emb=True, path=base_model_name_or_path, bias=bias
        )

        # Place the draft head next to the last decoder layer, whose output
        # hidden states it consumes.
        device = base_model.model.layers[-1].self_attn.q_proj.weight.device
        if device != base_model.lm_head.weight.device:
            self.draft_module.diff_device = True
            if not low_memory:
                self.draft_module.headweight = base_model.lm_head.weight.clone().to(device)
            else:
                self.draft_module.layer_device = device
        else:
            self.draft_module.diff_device = False
        self.draft_module.to(self.base_model.dtype).to(device)
        self.draft_module.init_tree()

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer

    @staticmethod
    def _resolve_head_file(head_path, filename):
        """Return ``head_path/filename`` if it exists locally, else download it from the Hub."""
        local_path = os.path.join(head_path, filename)
        if os.path.exists(local_path):
            return local_path
        return hf_hub_download(head_path, filename)

    @classmethod
    def from_pretrained(
            cls,
            Type="LLaMA",
            base_model_path=None,
            head_path=None,
            use_safetensor_weight=False,
            **kwargs,
    ):
        """Load the base model and the draft head, and return the wrapper.

        Args:
            Type: Ignored. The architecture is always read from the base
                model's config. It is kept only for call-site compatibility.
            base_model_path: Path or hub id of the base model.
            head_path: Local directory or hub repo id holding the draft head.
            use_safetensor_weight: If False, read ``config.json`` and
                ``pytorch_model.bin`` from ``head_path``. If True, read
                ``model.safetensors`` from ``head_path`` and take the config
                from ``vicuna_13B_config.json``.
            **kwargs: Forwarded to ``KVLlamaForCausalLM.from_pretrained``.

        Raises:
            ValueError: If the base model's architecture is not
                ``LlamaForCausalLM``.
        """
        architectures = AutoConfig.from_pretrained(base_model_path).architectures or []
        architecture = architectures[0] if architectures else None
        if architecture != 'LlamaForCausalLM':
            raise ValueError(
                f"Unsupported base model architecture {architecture!r}; "
                "only 'LlamaForCausalLM' is supported."
            )
        base_model = KVLlamaForCausalLM.from_pretrained(base_model_path, **kwargs)

        if not use_safetensor_weight:
            configpath = cls._resolve_head_file(head_path, "config.json")
            model = cls(base_model, base_model_path, configpath)
            load_model_path = cls._resolve_head_file(head_path, "pytorch_model.bin")
            # NOTE(review): torch.load unpickles the file. Only load heads from
            # trusted sources (consider weights_only=True if the torch version
            # in use supports it).
            draft_module_state_dict = torch.load(load_model_path,
                                                 map_location=base_model.device)
        else:
            # NOTE(review): this path is hard-coded and resolved relative to
            # the current working directory, not to head_path. Confirm intent.
            configpath = "vicuna_13B_config.json"
            load_model_path = cls._resolve_head_file(head_path, "model.safetensors")
            model = cls(base_model, base_model_path, configpath)
            from safetensors.torch import load_file
            # Load onto the base model's device (as the torch.load branch does)
            # rather than a hard-coded "cuda", so CPU-only setups also work.
            draft_module_state_dict = load_file(load_model_path,
                                                device=str(base_model.device))

        model.draft_module.load_state_dict(draft_module_state_dict, strict=True)
        return model

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            labels=None,
            past_key_values=None,
            output_orig=False,
            position_ids=None,
            init=True,
            logits_processor=None
    ):
        """Run the base model and, when ``init`` is set, seed the draft head.

        Args:
            input_ids: 2-D token ids. The chosen next token is concatenated
                on dim 1.
            attention_mask, past_key_values, position_ids: Forwarded unchanged
                to ``self.base_model.model``.
            labels: Unused. Accepted for call-site compatibility.
            output_orig: Also return the raw base-model outputs and the
                ``lm_head`` logits.
            init: Choose the next token from the base logits at the last
                position, append it to ``input_ids``, and call
                ``draft_module.topK_genrate``.
            logits_processor: Optional callable ``(None, logits) -> logits``.
                When given, the next token is sampled from its softmax
                instead of being chosen greedily. It is also forwarded to the
                draft head.

        Returns:
            - ``init`` and ``output_orig``:
              ``(draft_logits, outputs, orig, hidden_states, token)``.
            - ``init`` only: ``(draft_logits, hidden_states, token)``.
            - ``output_orig`` only: ``(outputs, orig, hidden_states)``.
            - Neither: ``None``.
        """
        with torch.inference_mode():
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            # ``init`` needs the base logits to choose the next token, so
            # compute them whenever either flag is set.
            orig = None
            if output_orig or init:
                orig = self.base_model.lm_head(outputs[0])
            hidden_states = outputs[0].clone()

        if not init:
            if output_orig:
                return outputs, orig, hidden_states
            return None

        last_logits = orig[:, -1]
        if logits_processor is not None:
            last_logits = logits_processor(None, last_logits)
            probabilities = torch.nn.functional.softmax(last_logits, dim=1)
            token = torch.multinomial(probabilities, 1)
        else:
            # Greedy choice per batch row, shaped (batch, 1) so it can be
            # concatenated onto input_ids.
            token = torch.argmax(last_logits, dim=-1, keepdim=True)
        input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)

        draft_logits, _topk_time = self.draft_module.topK_genrate(
            hidden_states, input_ids, self.base_model.lm_head, logits_processor
        )
        if output_orig:
            return draft_logits, outputs, orig, hidden_states, token
        return draft_logits, hidden_states, token

