import copy
import json
import time
import os
import inspect
from collections import deque

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers import PreTrainedModel, PretrainedConfig

from utils import Timer, prepare_logits_processor, pad_path, generate_tree_buffers, initialize_tree, reset_tree_mode, reset_past_key_values, generate_candidates, tree_decoding, evaluate_posterior, update_inference_inputs, update_inputs_only
from kv_cache import initialize_past_key_values

from cnets_infer import Model
# from .cnets1 import Model as Model1
from configs import EConfig


class EaModel(nn.Module):
    """EAGLE-style speculative-decoding wrapper around a KV-cached base causal LM.

    Holds the (KV-cache-patched) base model, its tokenizer and a draft head
    (``self.ea_layer``) and exposes several generation loops:

    * :meth:`eagenerate` - draft-tree speculative decoding with an optional
      confidence-driven early exit from the ``<think>`` section;
    * :meth:`naivegenerate` - plain token-by-token decoding with the base model;
    * :meth:`naive_generate` - the streaming (generator) variant of
      :meth:`naivegenerate`.
    """

    # ------------------------------------------------------------------
    # Score predictors used by ``eagenerate`` to smooth/extrapolate the
    # per-step confidence score. Each exposes ``add_score`` and
    # ``predict_next_score``; ``predict_next_score`` returns ``None`` while
    # it has not seen enough scores yet.
    # ------------------------------------------------------------------
    class _MeanPredictor:
        """Predict the next score as the mean of the last ``window_size`` scores."""

        def __init__(self, window_size=100):
            self.scores = deque(maxlen=window_size)

        def add_score(self, score):
            self.scores.append(score)

        def predict_next_score(self):
            """Return the window mean, or ``None`` if no score was added yet."""
            if len(self.scores) == 0:
                return None
            return sum(self.scores) / len(self.scores)

        def clear_before(self):
            """Drop every stored score except the most recent one."""
            if len(self.scores) == 0:
                return
            score = self.scores[-1]
            self.scores.clear()
            self.scores.append(score)

    class _ScorePredictor:
        """Linear extrapolation: last score plus the mean of the recent deltas."""

        def __init__(self, window_size=10):
            if window_size < 2:
                raise ValueError("window_size must be at least 2")
            self.scores = deque(maxlen=window_size)
            self.delta_window_size = window_size - 1
            self.deltas = deque(maxlen=self.delta_window_size)
            self.last_score = None

        def add_score(self, score):
            self.scores.append(score)
            if len(self.scores) > 1:
                delta = score - self.scores[-2]
                self.deltas.append(delta)
            self.last_score = score

        def predict_next_score(self):
            """Return ``None`` until the delta window is full."""
            if len(self.deltas) < self.delta_window_size:
                return None
            average_delta = sum(self.deltas) / len(self.deltas)
            return self.last_score + average_delta

    class _EWMAScore:
        """Exponentially weighted moving average of the scores."""

        def __init__(self, alpha=0.2, initial_score=None):
            if not 0 < alpha <= 1:
                raise ValueError("alpha must be between 0 and 1")
            self.alpha = alpha
            self.ewma = initial_score

        def add_score(self, score):
            if self.ewma is None:
                self.ewma = score
            else:
                self.ewma = self.alpha * score + (1 - self.alpha) * self.ewma

        def predict_next_score(self):
            return self.ewma

    @staticmethod
    def _is_empty_score(score):
        """Return True when a (predicted) score is unavailable.

        A scalar is empty only when it is ``None``; a list/tuple (the
        ``confidence_progress_remain`` case) is empty when it has no items or
        any item is ``None``. Unlike a truthiness test, a real score of
        ``0.0`` is NOT treated as missing.
        """
        if isinstance(score, (list, tuple)):
            return len(score) == 0 or any(s is None for s in score)
        return score is None

    def __init__(
        self,
        use_eagle3,
        base_model,
        base_model_name_or_path,
        ea_model_path,
        total_token,
        depth,
        top_k,
        threshold,
        ea_layer_state_dict,
        confidence_loss_type
    ):
        """Wrap ``base_model`` and build/load the draft head.

        Args:
            use_eagle3: Must be truthy; the non-EAGLE-3 draft head is not
                implemented and raises ``NotImplementedError``.
            base_model: KV-cache-patched causal LM (see :meth:`from_pretrained`).
            base_model_name_or_path: Path/hub id used to load the base model's
                tokenizer and config.
            ea_model_path: Path to the draft head's JSON config FILE (it is
                opened directly, not treated as a directory).
            total_token, depth, top_k, threshold: Draft-tree hyper-parameters
                forwarded to ``Model``.
            ea_layer_state_dict: Draft-head weights, loaded with ``strict=False``.
            confidence_loss_type: Forwarded to ``Model``; also selects how
                :meth:`eagenerate` interprets the predicted confidence score.

        Raises:
            NotImplementedError: If ``use_eagle3`` is false or the base model
                architecture is neither ``LlamaForCausalLM`` nor
                ``Qwen3ForCausalLM``.
        """
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.use_eagle3 = use_eagle3
        self.confidence_loss_type = confidence_loss_type

        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path, use_fast=False)
        self.model_type = AutoConfig.from_pretrained(self.base_model_name_or_path).architectures[0]

        config = EConfig.from_pretrained(ea_model_path)
        with open(ea_model_path, "r", encoding="utf-8") as f:
            con = json.load(f)
        # Draft configs without an explicit "bias" entry default to biased layers.
        bias = con.get("bias", True)
        if use_eagle3:
            self.ea_layer = Model(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k,
                                  threshold=threshold, path=base_model_name_or_path, load_emb=True, confidence_loss_type=confidence_loss_type)
        else:
            raise NotImplementedError("Model1 not implemented")

        print(f"self.model_type: {self.model_type}")
        if self.model_type not in ("LlamaForCausalLM", "Qwen3ForCausalLM"):
            raise NotImplementedError(f"Model type {self.model_type} not implemented")
        # Only the first token id of each marker is kept; this assumes
        # '<think>' / '</think>' encode to a single token - TODO confirm per tokenizer.
        self.tokenizer.start_think_id = self.tokenizer.encode('<think>', add_special_tokens=False)[0]
        self.tokenizer.stop_think_id = self.tokenizer.encode('</think>', add_special_tokens=False)[0]

        low_memory = False

        self.device = base_model.device
        # The draft head is placed next to the base model's LAST decoder layer,
        # which may differ from ``base_model.device`` when the model is sharded.
        device = base_model.model.layers[-1].self_attn.q_proj.weight.device
        if device != self.device:
            self.ea_layer.diff_device = True
            if not low_memory:
                self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device)
            else:
                self.ea_layer.layer_device = device
        else:
            self.ea_layer.diff_device = False
        if self.use_eagle3 and config.vocab_size == config.draft_vocab_size:
            # Same vocab for draft and target: the draft<->target id maps are not needed.
            del self.ea_layer.d2t, self.ea_layer.t2d
        self.ea_layer.load_state_dict(ea_layer_state_dict, strict=False)
        self.ea_layer.to(self.base_model.dtype).to(device)
        self.ea_layer.init_tree()

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer

    @classmethod
    def from_pretrained(
            cls,
            use_eagle3=True,
            base_model_path=None,
            ea_model_path=None,
            total_token=60,
            depth=7,
            top_k=10,
            threshold=1.0,
            confidence_loss_type=None,
            **kwargs,
    ):
        """Load the base model and the draft-head weights and build an :class:`EaModel`.

        Args:
            use_eagle3: Forwarded to ``__init__``.
            base_model_path: Local path or hub id of the base model; its
                ``architectures[0]`` must be ``LlamaForCausalLM`` or
                ``Qwen3ForCausalLM``.
            ea_model_path: Local directory or hub repo id that contains
                ``config.json`` and either ``pytorch_model.bin`` or
                ``model.safetensors``.
            total_token: Draft-tree size. ``-1`` selects one of
                ``[40, 48, 50, 56, 60]`` with a short CUDA timing benchmark of
                the base model and then sets ``ea_layer.total_tokens`` to the
                chosen value minus one.
            depth, top_k, threshold, confidence_loss_type: Forwarded to ``__init__``.
            **kwargs: Forwarded to the base model's ``from_pretrained``.

        Raises:
            NotImplementedError: For unsupported base model architectures.
        """
        print(f"EaModel.from_pretrained kwargs={kwargs} base_model_path={base_model_path}")
        base_model_config = AutoConfig.from_pretrained(base_model_path)
        Type = base_model_config.architectures[0]
        print(f"Model type: {Type}")

        if Type == 'LlamaForCausalLM':
            from models.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
            base_model = KVLlamaForCausalLM.from_pretrained(
                base_model_path, **kwargs
            )
        elif Type == 'Qwen3ForCausalLM':
            from models.modeling_qwen3_kv import Qwen3ForCausalLM as KVQwen3ForCausalLM
            base_model = KVQwen3ForCausalLM.from_pretrained(
                base_model_path, **kwargs
            )
        else:
            raise NotImplementedError(f"Type {Type} not supported now")

        configpath = os.path.join(ea_model_path, "config.json")
        if not os.path.exists(configpath):
            configpath = hf_hub_download(ea_model_path, "config.json")

        try:
            load_model_path = os.path.join(ea_model_path, "pytorch_model.bin")
            if not os.path.exists(load_model_path):
                load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin")
            # NOTE(review): torch.load unpickles arbitrary objects unless
            # weights_only=True; only load draft checkpoints from trusted sources.
            ea_layer_state_dict = torch.load(load_model_path,
                                             map_location=base_model.device)
            for k, v in ea_layer_state_dict.items():
                print(f"load {k}: {v.shape}")
        except Exception:
            # No (loadable) pytorch_model.bin: fall back to safetensors. Unlike a
            # bare ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
            from safetensors.torch import load_file
            load_model_path = os.path.join(ea_model_path, "model.safetensors")
            if not os.path.exists(load_model_path):
                load_model_path = hf_hub_download(ea_model_path, "model.safetensors")
            ea_layer_state_dict = load_file(load_model_path)
        model = cls(
            use_eagle3,
            base_model,
            base_model_path,
            configpath,
            total_token,
            depth,
            top_k,
            threshold,
            ea_layer_state_dict,
            confidence_loss_type
        )

        if total_token == -1:
            # Benchmark the base model on candidate tree sizes; each time is
            # divided by a per-size factor before picking the minimum.
            device = model.base_model.model.layers[0].self_attn.q_proj.weight.device
            cans = [40, 48, 50, 56, 60]
            x = [1, 1.05, 1.07, 1.1, 1.13]
            times = []

            for length, factor in zip(cans, x):
                input_ids = torch.randint(0, model.config.vocab_size - 200, (1, length)).to(device)
                torch.cuda.synchronize()
                start_time = time.time()
                for _ in range(20):
                    torch.cuda.synchronize()
                    with torch.no_grad():
                        model.base_model(input_ids)
                    torch.cuda.synchronize()
                torch.cuda.synchronize()
                end_time = time.time()
                times.append((end_time - start_time) / factor)
            total_token = cans[times.index(min(times))]
            model.ea_layer.total_tokens = total_token - 1

        return model

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            past_key_values=None,
            output_orig=False,
            position_ids=None,
    ):
        """Run the base model's decoder stack (``base_model.model``) in inference mode.

        Returns:
            ``(outputs, hidden_states)``, or ``(outputs, orig, hidden_states)``
            when ``output_orig`` is true, where ``orig`` is ``lm_head`` applied
            to ``outputs[0]`` and ``hidden_states`` is ``outputs[0]``.
        """
        with torch.inference_mode():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
            hidden_states = outputs[0]

        if output_orig:
            return outputs, orig, hidden_states
        else:
            return outputs, hidden_states

    def _prepare_past_key_values(self, max_length):
        """Return the KV-cache triple ``(past_key_values, past_key_values_data, current_length_data)``.

        The buffers are created by ``initialize_past_key_values`` on first use
        and cached on ``self``. Later calls reuse them and only reset
        ``current_length_data`` to zero. If a call requests a larger
        ``max_length`` than the cached buffers were created with, they are
        re-created. ``initialize_past_key_values`` is given ``max_length``, so
        the buffers are presumably sized by it (TODO confirm in kv_cache), and
        reusing a smaller cache would not fit the request. Buffers attached to
        ``self`` without a recorded size are reused as before.
        """
        cached_len = getattr(self, "_kv_cache_max_length", None)
        if hasattr(self, "past_key_values"):
            if cached_len is None or cached_len >= max_length:
                # Reset the past key and value states
                self.current_length_data.zero_()
                return self.past_key_values, self.past_key_values_data, self.current_length_data
            # Too small for this request: release the old buffers before allocating.
            del self.past_key_values, self.past_key_values_data, self.current_length_data
        (
            past_key_values,
            past_key_values_data,
            current_length_data,
        ) = initialize_past_key_values(self.base_model, max_length=max_length)
        self.past_key_values = past_key_values
        self.past_key_values_data = past_key_values_data
        self.current_length_data = current_length_data
        self._kv_cache_max_length = max_length
        return past_key_values, past_key_values_data, current_length_data

    @torch.no_grad()
    def eagenerate(
        self,
        input_ids,
        temperature=0.0,
        top_p=0.0,
        top_k=0.0,
        repetition_penalty=1.0,
        max_new_tokens=8192,
        max_length=8192,
        log=False,
        is_llama3=False,
        think_step_split_map_ids=None,
        input_text=None,
        enable_think_exit=False,
        exit_threshold=None,
        min_think=100,
        min_paragraph=0,
        window_size=100,
        pool_type="ewma",
        stop_think_prompt_ids=None,
        stop_prob_threshold=0,
        split_type="paragraph"
    ):
        """Speculative (draft-tree) generation with an optional early exit from thinking.

        Each iteration does three things:

        1. Verifies the current draft tree with the base model
           (``tree_decoding`` + ``evaluate_posterior``).
        2. Appends the accepted tokens and one sampled token.
        3. Drafts the next tree (``update_inference_inputs``).

        While generation is still inside the ``<think>`` section, the
        confidence score returned with each new tree is fed to a predictor.
        When a step-split token (a key of ``think_step_split_map_ids``) is
        generated, the predicted score may trigger either an early exit
        (inserting ``</think>``) or a replacement of the split token.

        Args:
            input_ids: Prompt token ids. Batch size 1 is effectively assumed
                (the code indexes ``input_ids[0, ...]``).
            temperature, top_p, top_k, repetition_penalty: Sampling settings;
                ``temperature <= 1e-5`` means greedy decoding.
            max_new_tokens: Stop once more than this many tokens were generated.
            max_length: Size requested for the KV cache; generation also stops
                once the sequence exceeds ``max_length - total_tokens - 10``.
            log: If true, also return statistics (see Returns).
            is_llama3: Also stop on ``<|eot_id|>``.
            think_step_split_map_ids: Maps a split token id to its replacement
                id. In ``"paragraph"`` mode, an identity mapping means "insert
                ``</think>`` after the split token". Defaults to ``{}``.
            input_text, stop_think_prompt_ids, stop_prob_threshold: Accepted
                for compatibility; not used by this method.
            enable_think_exit: Enable the confidence-driven early exit.
            exit_threshold: Thresholds compared with the predicted score.
                ``"mse"`` and ``"rmsle"`` use one value;
                ``"confidence_progress_remain"`` needs three (the default
                ``[50]`` would raise ``IndexError`` there).
            min_think: Early exit is only considered after more than this many
                new tokens.
            min_paragraph: A split is ignored (only logged) when fewer than
                this many decoding ITERATIONS (not tokens) have passed since
                the previous split.
            window_size: Window for the ``"predict"`` and fallback mean
                predictors.
            pool_type: ``"ewma"``, ``"predict"``, ``"mean"`` or
                ``"paragraph_mean"`` use the predictor's output. Any other
                value uses the raw current score.
            split_type: ``"paragraph"``, or any string containing ``"marker"``.

        Returns:
            ``input_ids`` if ``log`` is false. Otherwise ``(input_ids,
            new_token, idx, accept_length_list, exit_score, confidence_list,
            paragraph_states)``.

        Raises:
            ValueError: If an exit fires with an unsupported ``split_type``.
        """
        if think_step_split_map_ids is None:
            think_step_split_map_ids = {}
        if exit_threshold is None:
            exit_threshold = [50]
        if is_llama3:
            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
        else:
            logits_processor = None

        # A single -1 token appended to each draft before gathering candidates;
        # presumably ``retrieve_indices`` uses -1 for unused slots so they pick
        # this value - TODO confirm in utils.
        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        past_key_values, past_key_values_data, current_length_data = self._prepare_past_key_values(max_length)

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        # prefill
        draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token, confidence_score = initialize_tree(
            input_ids, self, past_key_values, logits_processor
        )
        print(f"prefill confidence_score={confidence_score}")
        # The prefill score is only logged: clearing it means the first decoding
        # iteration never feeds the predictors.
        confidence_score_cpu = confidence_score = None

        new_token = 0
        paragraph_tokens = 0  # decoding iterations since the last split
        accept_length_list = []
        # Leave head-room for one draft tree in the KV cache.
        max_length = max_length - self.ea_layer.total_tokens - 10

        confidence_list, split_step_list, paragraph_states = [], [], []
        think_end = False
        exit_score = None

        if pool_type == "ewma":
            predict_class, predict_kwargs = self._EWMAScore, {"alpha": 0.1}
        elif pool_type == "predict":
            predict_class, predict_kwargs = self._ScorePredictor, {"window_size": window_size}
        elif pool_type == "paragraph_mean":
            predict_class, predict_kwargs = self._MeanPredictor, {"window_size": 1000}
        else:
            predict_class, predict_kwargs = self._MeanPredictor, {"window_size": window_size}

        # "confidence_progress_remain" yields three scores per step: one predictor each.
        if self.confidence_loss_type == "confidence_progress_remain":
            predictor = [predict_class(**predict_kwargs) for _ in range(3)]
        else:
            predictor = predict_class(**predict_kwargs)

        for idx in range(max_length):
            paragraph_tokens += 1
            self.base_model.model.tree_mask = tree_mask

            draft_tokens = draft_tokens.to(input_ids.device)
            # Target model forward over the draft tree, get logits
            logits, hidden_state_new, outputs = tree_decoding(
                self,
                draft_tokens,
                past_key_values,
                tree_position_ids,
                input_ids,
                retrieve_indices,
                tree_mask,
                self.tokenizer
            )

            draft_tokens = torch.cat((draft_tokens, padding), dim=1)
            candidates = draft_tokens[0, retrieve_indices]
            # verification
            best_candidate, accept_length, sample_p = evaluate_posterior(
                input_ids, logits, candidates, logits_processor, self.tokenizer
            )

            if logits_processor is not None:
                sample_token = torch.multinomial(sample_p, 1)
                sample_token = sample_token[None]
            else:
                sample_token = torch.argmax(sample_p)
                sample_token = sample_token[None, None]

            # Tokens this iteration would emit: accepted candidate prefix + sampled token.
            new_token_ids = candidates[None, best_candidate, : accept_length + 1].view(-1).tolist() + [sample_token.item()]

            if (not think_end) and self.tokenizer.stop_think_id in new_token_ids:
                think_end = True
                print(f"Normal stop thinking")

            if (not think_end) and confidence_score is not None:
                confidence_list.append(confidence_score_cpu)
                if self.confidence_loss_type == "confidence_progress_remain":
                    for p_i, score_i in zip(predictor, confidence_score_cpu):
                        p_i.add_score(score_i)
                else:
                    predictor.add_score(confidence_score_cpu)

                # Locate the first split token. In "marker" mode the search skips
                # position 0 and the logits are read one position earlier
                # (pos_offset=-1); otherwise the last (sampled) token is skipped.
                if "marker" in split_type:
                    pos_offset = -1
                    search_range = range(1, len(new_token_ids))
                else:
                    pos_offset = 0
                    search_range = range(len(new_token_ids) - 1)
                split_step_pos = next(
                    (i for i in search_range if new_token_ids[i] in think_step_split_map_ids),
                    None,
                )

                if split_step_pos is not None:
                    empty_score = self._is_empty_score(confidence_score_cpu)
                    if paragraph_tokens < min_paragraph:
                        print(f"split think step! paragraph_tokens={paragraph_tokens} < min_paragraph={min_paragraph} confidence_score={confidence_score_cpu}")
                    else:
                        if pool_type in ["ewma", "predict", "mean", "paragraph_mean"]:
                            if self.confidence_loss_type == "confidence_progress_remain":
                                pred_next_score = [p_i.predict_next_score() for p_i in predictor]
                            else:
                                pred_next_score = predictor.predict_next_score()
                            # None means "not enough history yet"; 0.0 is a real score.
                            empty_score = self._is_empty_score(pred_next_score)

                            if pool_type == "paragraph_mean":
                                if self.confidence_loss_type == "confidence_progress_remain":
                                    for p_i in predictor:
                                        p_i.clear_before()
                                else:
                                    predictor.clear_before()
                        else:
                            pred_next_score = confidence_score_cpu
                        new_token_strs = [self.tokenizer.decode([t]) for t in new_token_ids]
                        split_step_list.append(idx)

                        if enable_think_exit and not empty_score and new_token > min_think:
                            # logits is indexed as [candidate, position, vocab]
                            # (rank fixed by the indexing below).
                            split_logits = logits[best_candidate, split_step_pos + pos_offset]
                            stop_think_logit = split_logits[self.tokenizer.stop_think_id].float().item()
                            origin_prob = torch.softmax(split_logits.float(), dim=-1)
                            stop_think_prob = origin_prob[self.tokenizer.stop_think_id].item()

                            origin_token_id = origin_prob.argmax(dim=-1).squeeze().item()
                            origin_token = self.tokenizer.decode(origin_token_id)

                            entropy = -torch.sum(origin_prob * torch.log(origin_prob + 1e-10)).item()

                            paragraph_states.append({
                                "split": new_token_strs[split_step_pos],
                                "confidence": confidence_score_cpu,
                                "stop_think_prob": stop_think_prob,
                                "stop_think_logit": stop_think_logit,
                                "origin_token": origin_token,
                                "entropy": entropy
                            })

                            # Exit rule per loss type: "mse" exits above the threshold,
                            # "rmsle" below it; "confidence_progress_remain" needs
                            # scores 0 and 1 above and score 2 below their thresholds.
                            if (self.confidence_loss_type == "mse" and pred_next_score > exit_threshold[0]) or \
                            (self.confidence_loss_type == "rmsle" and pred_next_score < exit_threshold[0]) or \
                            (self.confidence_loss_type == "confidence_progress_remain" and pred_next_score[0] > exit_threshold[0] and pred_next_score[1] > exit_threshold[1] and pred_next_score[2] < exit_threshold[2]):

                                map_value_id = think_step_split_map_ids[new_token_ids[split_step_pos]]

                                if split_type == "paragraph":
                                    if map_value_id == new_token_ids[split_step_pos]:
                                        # Keep tokens up to and including the split, then force '</think>'.
                                        print(f"Note: pred score = {pred_next_score}, threshlod = {exit_threshold}: stop thinking, new_tokens={new_token_strs!r}")
                                        sample_token[..., -1] = self.tokenizer.stop_think_id
                                        accept_length = split_step_pos
                                        exit_score = pred_next_score
                                        think_end = True
                                        new_token_ids = new_token_ids[:split_step_pos + 1] + [self.tokenizer.stop_think_id]
                                    else:
                                        # Replace the split token with its mapped id; thinking continues.
                                        print(f"Note: pred score = {pred_next_score}, threshlod = {exit_threshold}: map {self.tokenizer.decode(new_token_ids[split_step_pos])!r} => {self.tokenizer.decode(map_value_id)!r}, new_tokens={new_token_strs!r}")
                                        sample_token[..., -1] = map_value_id
                                        accept_length = split_step_pos - 1
                                        new_token_ids = new_token_ids[:split_step_pos] + [map_value_id]

                                elif "marker" in split_type:
                                    # Replace the marker with its mapped id and end thinking.
                                    print(f"Note: pred score = {pred_next_score}, threshlod = {exit_threshold}: replace {self.tokenizer.decode(new_token_ids[split_step_pos])!r} => {self.tokenizer.decode(map_value_id)!r}, new_tokens={new_token_strs!r}")
                                    sample_token[..., -1] = map_value_id
                                    accept_length = split_step_pos - 1
                                    think_end = True
                                    new_token_ids = new_token_ids[:split_step_pos] + [map_value_id]

                                else:
                                    raise ValueError(f"split_type={split_type} not supported")

                            else:
                                print(f"Note: confidence_loss_type={self.confidence_loss_type} pred_next_score={pred_next_score}, exit_threshold={exit_threshold}: continue thinking, new_tokens={new_token_strs!r}")
                    paragraph_tokens = 0

            # accept_length comes back from evaluate_posterior as a tensor, but
            # the think-exit branches above overwrite it with a plain int.
            accept_length_list.append(
                accept_length.item() if isinstance(accept_length, torch.Tensor) else accept_length
            )

            input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, confidence_score = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                retrieve_indices,
                logits_processor,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state_new,
                sample_p,
                sample_token,
            )

            if confidence_score is None:
                confidence_score_cpu = None
            elif self.confidence_loss_type == "confidence_progress_remain":
                confidence_score_cpu = [round(c, 4) for c in confidence_score.view(-1).tolist()]
            else:
                confidence_score_cpu = confidence_score.item()
            candidates = draft_tokens[0, retrieve_indices]

            if is_llama3:
                if stop_token_id in input_ids[0, input_len:].tolist():
                    break

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > max_new_tokens:
                break
            if input_ids.shape[1] > max_length:
                break

        if not log:
            return input_ids
        else:
            return input_ids, new_token, idx, accept_length_list, exit_score, confidence_list, paragraph_states

    @torch.no_grad()
    def naivegenerate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            repetition_penalty=1.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,
    ):
        """Plain autoregressive decoding with the base model (no draft tree).

        Args:
            input_ids: Prompt token ids; batch size 1 is effectively assumed.
            temperature, top_p, top_k, repetition_penalty: Sampling settings;
                ``temperature <= 1e-5`` means greedy decoding.
            max_new_tokens: Stop once more than this many tokens were generated.
            max_length: Size requested for the KV cache; generation also stops
                once the sequence exceeds ``max_length - total_tokens - 10``.
            log: If true, return ``(input_ids, new_token, idx)``.
            is_llama3: Also stop on ``<|eot_id|>``.

        Returns:
            The extended ``input_ids`` (or the tuple above when ``log``).
        """
        if is_llama3:
            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
        else:
            logits_processor = None

        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        past_key_values, _, _ = self._prepare_past_key_values(max_length)

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_token = 0
        max_length = max_length - self.ea_layer.total_tokens - 10
        for idx in range(max_length):
            if logits_processor is not None:
                logits = outputs.logits[:, -1]
                logits = logits_processor(None, logits)
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                input_id = torch.multinomial(probabilities, 1)
            else:
                input_id = outputs.logits[:, -1:].argmax(dim=-1)
            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1

            if is_llama3:
                if stop_token_id in input_ids[0, input_len:].tolist():
                    break

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > max_new_tokens:
                break
            if input_ids.shape[1] > max_length:
                break
        if not log:
            return input_ids
        else:
            return input_ids, new_token, idx

    @torch.no_grad()
    def naive_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,

    ):
        """Streaming variant of :meth:`naivegenerate`: yields ``input_ids`` after every token.

        Takes the same sampling/stopping arguments as :meth:`naivegenerate`
        except ``repetition_penalty``. ``log`` is accepted but unused.
        """
        if is_llama3:
            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None

        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        past_key_values, _, _ = self._prepare_past_key_values(max_length)

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_token = 0
        max_length = max_length - self.ea_layer.total_tokens - 10
        for idx in range(max_length):
            if logits_processor is not None:
                logits = outputs.logits[:, -1]
                logits = logits_processor(None, logits)
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                input_id = torch.multinomial(probabilities, 1)
            else:
                input_id = outputs.logits[:, -1:].argmax(dim=-1)

            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1

            yield input_ids

            if is_llama3:
                if stop_token_id in input_ids[0, input_len:].tolist():
                    break

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > max_new_tokens:
                break
            if input_ids.shape[1] > max_length:
                break