import torch
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteria
import logging
import nvtx
import json
import os

from .base import GeneratorBase
from ..utils.mixin import MBSDProfilingMixin
from ..utils.compressKV import compressKV_mb
# Import CaptureAttentionContext to capture Q from target model
from ..utils.monkey_patch import CaptureAttentionContext

class TargetKVSDGeneratorBase(GeneratorBase):
    def __init__(self, generator_kwargs, *model_args, **kwargs):
        """Set up the state used for compressed-KV speculative decoding.

        Args:
            generator_kwargs: Mapping of generator options. The keys read
                here are ``prefill_chunk_size``, ``Target_KV_size``,
                ``Draft_Select_size``, ``window_size`` and ``sink_size``.
            *model_args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor. The
                ``target_key_values`` entry (if any) is also kept on the
                instance.
        """
        super().__init__(*model_args, **kwargs)
        options = generator_kwargs

        # Prefill chunking and KV-compression budgets (None when not configured).
        self.prefill_chunk_size = options.get("prefill_chunk_size")
        self.compressKV_size = options.get("Target_KV_size")
        self.compressKV_draft_select_size = options.get("Draft_Select_size", 128)
        print(f"Target_KV_size: {self.compressKV_size}, Draft_Select_size: {self.compressKV_draft_select_size}")

        # Window / sink sizes handed to compressKV_mb during tree decoding.
        self.window_size = options.get("window_size", 8)
        self.sink_size = options.get("sink_size", 4)

        # Cache object the target model runs against; filled position ids
        # are tracked separately and start out unset.
        self.target_key_values = kwargs.get("target_key_values")
        self.full_pos_ids = None

        # Latency bookkeeping, only populated when ANALYSIS_MODE=1 and
        # profiling is enabled on the instance.
        self.analysis_mode = os.environ.get("ANALYSIS_MODE", "0") == "1"
        self.all_attention_latencies = []
        self.all_compresskv_latencies = []
        self.all_criticality_estimation = []

    def set_important_head_idx(self, filename):
        # check file existence
        if not os.path.exists(filename):
            print(f"Important heads file not found: {filename}, selecting all heads as important heads.")
            self.target_model.important_heads = torch.arange(self.target_model.config.num_attention_heads, device=self.device).unsqueeze(0).expand(self.target_model.config.num_hidden_layers, -1)
        else:
            with open(filename, 'r') as f:
                important_heads = json.load(f)
            self.target_model.important_heads = [important_heads[str(i)] for i in range(len(important_heads))]
            # if SRH_head_num is None, select all heads. if not, select top SRH_head_num heads
            if self.draft_params.generator_kwargs['SRH_head_num'] is None:
                self.target_model.important_heads = torch.tensor(self.target_model.important_heads, device=self.device)
            else:
                self.target_model.important_heads = torch.tensor(self.target_model.important_heads, device=self.device)[:,:self.draft_params.generator_kwargs['SRH_head_num']]
        self.target_model.important_heads = self.target_model.important_heads[self.target_model.important_layers.to("cpu")]
        print(f"self.target_model.important_heads:\n{self.target_model.important_heads}")
    def set_important_layers(self, important_layer):
        if type(important_layer) is list:
            self.target_model.important_layers = torch.tensor(important_layer, device=self.device)
        else:
            self.target_model.important_layers = torch.tensor([important_layer], device=self.device)
        print(f"self.target_model.important_layers:\n{self.target_model.important_layers}")

    def _speculate(self, input_ids, attention_mask=None, position_ids=None):
        return self.draft_model.speculate(input_ids, attention_mask=attention_mask, position_ids=position_ids)

    def compressKV_target_model_forward(self, draft_input_ids, tree_position_ids, target_attn_mask):
        """Run the target model on the drafted tokens against ``self.target_key_values``.

        The attention mask is extended with ones for the drafted tokens, and
        the forward pass is wrapped in ``CaptureAttentionContext`` with
        ``capture_all_queries=True`` so the queries of the drafted positions
        are recorded on the target model (``_verify`` later selects among them).

        Args:
            draft_input_ids: Drafted token ids; dimension 1 is the number of
                drafted tokens.
            tree_position_ids: Position ids passed to the target model for
                the drafted tokens.
            target_attn_mask: Attention mask over the cached entries; one
                column per drafted token is appended here.

        Returns:
            The output object returned by ``self.target_model(...)``.
        """
        num_draft_tokens = draft_input_ids.size(1)

        # Every drafted token is attendable, so append a block of ones.
        draft_mask = torch.ones(
            (target_attn_mask.size(0), num_draft_tokens),
            device=target_attn_mask.device,
            dtype=target_attn_mask.dtype,
        )
        target_attn_mask = torch.cat([target_attn_mask, draft_mask], dim=1)

        # The drafted tokens occupy the last `num_draft_tokens` mask columns.
        total_len = target_attn_mask.size(1)
        cache_position = torch.arange(total_len - num_draft_tokens, total_len, dtype=torch.long, device=target_attn_mask.device)

        enable_analysis = self.analysis_mode and bool(getattr(self, "profiling", False))

        # Monkey patch the target model's attention layers for the duration
        # of the forward pass so that queries (and, in analysis mode,
        # per-layer latencies) are captured.
        capture_ctx = CaptureAttentionContext(
            self.target_model,
            capture_all_queries=True,
            measure_latency=enable_analysis,
            capture_queries=True,
        )

        self.target_model._capture_enabled = True
        try:
            if enable_analysis:
                tf_start = torch.cuda.Event(enable_timing=True)
                tf_end = torch.cuda.Event(enable_timing=True)
                tf_start.record()

            with nvtx.annotate("target_model forward", color="orange"):
                with capture_ctx:
                    outputs = self.target_model(
                        draft_input_ids,
                        past_key_values=self.target_key_values,
                        position_ids=tree_position_ids,
                        cache_position=cache_position,
                        attention_mask=target_attn_mask,
                    )

            if enable_analysis:
                tf_end.record()
                tf_end.synchronize()
                self._last_target_forward_ms = tf_start.elapsed_time(tf_end)
        finally:
            # Always clear the flag, even when the forward pass raises, so a
            # failed step cannot leave capturing switched on for later calls.
            self.target_model._capture_enabled = False

        if enable_analysis:
            attn_latencies = getattr(self.target_model, "latest_attention_latencies", None)
            if attn_latencies:
                total_ms = sum(attn_latencies)
                per_layer = ",".join(f"{v:.3f}" for v in attn_latencies)
                logging.debug(f"Target self-attention latency: total={total_ms:.3f} ms; per-layer={per_layer}")
                self.all_attention_latencies.append(attn_latencies)

            logging.debug(f"target_model forward latency: {self._last_target_forward_ms:.3f} ms")

        return outputs
    
    def _tree_decoding(self, draft_input_ids, past_key_values, position_offset, cache_position, position_ids=None, device="cuda", attention_mask=None):
        """Compress the KV cache, then verify the drafted tokens with the target model.

        Args:
            draft_input_ids: Drafted token ids to run through the target model.
            past_key_values: Cache handed to ``compressKV_mb`` as its first argument.
            position_offset: Number of leading ``attention_mask`` columns
                passed to ``compressKV_mb``.
            cache_position: Accepted for interface compatibility; not used here.
            position_ids: Position ids for the drafted tokens.
            device: Accepted for interface compatibility; not used here.
            attention_mask: Mask sliced to ``[:, :position_offset]`` before
                compression; must not be None.

        Returns:
            The target model outputs from ``compressKV_target_model_forward``.
        """
        profiling_on = self.analysis_mode and bool(getattr(self, "profiling", False))

        with nvtx.annotate("compress KV", color="blue"):
            # CUDA events bracket the compression call when profiling.
            if profiling_on:
                start_evt = torch.cuda.Event(enable_timing=True)
                stop_evt = torch.cuda.Event(enable_timing=True)
                start_evt.record()

            # Only the mask and the criticality-estimation latency are needed
            # from the result; the first element is discarded.
            _, target_attn_mask, criticality_estimation = compressKV_mb(
                past_key_values,
                position_offset,
                self.window_size,
                self.sink_size,
                self.compressKV_size,
                self.compressKV_draft_select_size,
                self.target_key_values,
                self.draft_model,
                self.full_pos_ids[:, :-1],
                attention_mask=attention_mask[:, :position_offset],
                target_model=self.target_model,
                enable_analysis=profiling_on,
            )

            if profiling_on:
                stop_evt.record()
                stop_evt.synchronize()
                elapsed_ms = start_evt.elapsed_time(stop_evt)
                self._last_compresskv_ms = elapsed_ms
                self.all_compresskv_latencies.append(elapsed_ms)
                logging.debug(f"compressKV_mb latency: {elapsed_ms:.3f} ms")
                self.all_criticality_estimation.append(criticality_estimation)
                logging.debug(f"Criticality Estimation Latency: {criticality_estimation:.3f} ms")

        return self.compressKV_target_model_forward(
            draft_input_ids, position_ids, target_attn_mask
        )

    def _verify(self, draft_input_ids, logits, logits_processor, do_sample, finished):
        # Sequential verification in GPUs.
        global_p = torch.argmax(self._sample_token(logits, logits_processor, do_sample, return_probs=True), dim=-1)

        # Initialize variables

        verified_mask = draft_input_ids[:, 1:] == global_p[:, :-1]
        eos_mask = global_p[:, :-1] == self.draft_model.eos_token_id
        verified_mask = torch.where(eos_mask, False, verified_mask)
        verified_mask = torch.cummin(verified_mask, dim=1).values
        accept_len_arr = torch.sum(verified_mask, dim=1)

        # Set the values to -1 for finished sequences
        accept_len_arr[finished] = -1

        max_accept = torch.max(accept_len_arr)
        accepted_tokens = global_p[:, :max_accept].clone()
        # Set the rejected tokens to pad_token_id
        accepted_tokens[verified_mask[:, :max_accept] == False] = self.tokenizer.pad_token_id
        
        # Add bonus token
        bonus_token = global_p[torch.arange(global_p.size(0), device=global_p.device), accept_len_arr]
        sampled_tokens = torch.cat([accepted_tokens, bonus_token.unsqueeze(1)], dim=1)

        # total len = accept_len + 1 when 1. reject 2. got EOS
        total_len_arr = accept_len_arr + torch.logical_or(
            sampled_tokens[:, -1] == self.draft_model.eos_token_id,
            torch.logical_not(accept_len_arr == draft_input_ids.size(1)-1)
        ).long()

        total_len_arr[finished] = -1

        # Select target queries based on acc_len
        S, L, B, H, Q, D = self.target_model.latest_captured_rope_queries.shape
        acc_q_index = accept_len_arr.clamp(min=0, max=Q-1)
        acc_q_index = acc_q_index.view(1, 1, B, 1, 1, 1)
        acc_q_index = acc_q_index.expand(S, L, B, H, 1, D)
        self.target_model.latest_captured_rope_queries = torch.gather(self.target_model.latest_captured_rope_queries, dim=4, index=acc_q_index)
        self.target_model.latest_captured_rope_queries = self.target_model.latest_captured_rope_queries.squeeze(4)

        return sampled_tokens, None, (total_len_arr, accept_len_arr)

    def _generate(
        self,
        input_ids: torch.LongTensor,
        stopping_criteria: StoppingCriteria,
        logits_processor: LogitsProcessorList,
        do_sample: bool,
        **model_kwargs,
    ):
        """
        Generate sequence of tokens with speculative decoding.

        This method consists of two main stages: prefill and decode.

        Prefill Stage:
        - Run the target model over the not-yet-cached prompt tokens, optionally in chunks
          of `prefill_chunk_size`; only the last chunk produces logits.
        - Sample a token and append it to the input_ids.

        Decode Stage (with speculative decoding):
        - Iterate through the following steps:
            1. Perform SSM speculative sampling, returns sampled tokens in tree form.
            2. Compress the KV cache and decode the sampled tokens in parallel with the language model (LLM).
            3. Verify the sampled tokens by accepting or rejecting them.
            4. Update the key-value cache and input_ids accordingly.

        Args:
            input_ids (torch.LongTensor): The input token IDs.
            stopping_criteria (StoppingCriteria): The criteria to stop the generation.
            logits_processor (LogitsProcessor): The processor to modify the logits.
            do_sample (bool): Whether to sample tokens during generation. If False, the generation will be deterministic.
            **model_kwargs: Must contain `past_key_values` and `draft_past_key_values`.

        Returns:
            input_ids (torch.LongTensor): The generated token IDs.

        Raises:
            ValueError: If `Target_KV_size` was not configured, if either cache is missing,
                if a static cache is used without `max_length`, or if there are no prompt
                tokens left to prefill.
        """
        assert self.target_model is not None, "target_model must be provided"
        assert self.draft_model is not None, "draft_model must be provided"
        assert self.tokenizer is not None, "tokenizer must be provided"

        # Fail fast: the decode loop evaluates min(self.compressKV_size, ...), which
        # cannot work with None, so reject it before paying for the prefill.
        if self.compressKV_size is None:
            raise ValueError("Target_KV_size must be set in generator_kwargs for compressed-KV decoding")

        # * initialize variables
        # A sequence completely filled with eos tokens is treated as already finished
        finished_arr = torch.all(input_ids == self.draft_model.eos_token_id, dim=1)

        finished_step = torch.zeros(input_ids.size(0), dtype=torch.long, device=input_ids.device)
        finished_step += input_ids.shape[1]

        # * clone input_ids
        input_ids = input_ids.clone()
        batch_size = input_ids.size(0)

        # * prepare kv-cache
        # Raise error if max_length not set while using static cache
        if stopping_criteria.max_length is None and self.cache_implementation == "static":
            raise ValueError(
                "max_length is not set. Only 'dynamic' kv-cache is supported when max_length is unspecified."
            )

        past_key_values = model_kwargs.get("past_key_values")
        draft_past_key_values = model_kwargs.get("draft_past_key_values")
        if past_key_values is None or draft_past_key_values is None:
            raise ValueError("past_key_values and draft_past_key_values should both be provided")
        self.draft_model.set_past_key_values(draft_past_key_values)

        # Initialize draft prefill profiling list
        self.draft_model._prefill_events = []

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(self.device)

        position_ids = attention_mask.long().cumsum(-1) - 1

        # * prefill stage
        target_prefill_start = torch.cuda.Event(enable_timing=True)
        target_prefill_end = torch.cuda.Event(enable_timing=True)

        target_prefill_start.record()
        with nvtx.annotate("chunked prefill", color="orange"):
            current_kv_len = past_key_values.get_seq_length()
            prefill_tokens = input_ids[:, current_kv_len:]
            prefill_length = prefill_tokens.size(1)
            if prefill_length == 0:
                # Without at least one uncached token there are no logits to sample
                # from (and range() below would be given a zero step).
                raise ValueError("No tokens left to prefill: input_ids is fully covered by past_key_values")
            chunk_size = prefill_length if self.prefill_chunk_size is None else min(prefill_length, self.prefill_chunk_size)
            next_token_logits = None
            for start in range(0, prefill_length, chunk_size):
                chunk = prefill_tokens[:, start:start + chunk_size]
                current_kv_len = past_key_values.get_seq_length()
                cache_position = torch.arange(
                    current_kv_len, current_kv_len + chunk.size(1),
                    dtype=torch.long, device=input_ids.device
                )
                am_so_far = attention_mask[:, :current_kv_len + chunk.size(1)]  # (B, cur_len_so_far)
                pos_ids = position_ids[:, current_kv_len:current_kv_len + chunk.size(1)]
                if start + chunk_size < prefill_length:
                    # intermediate chunk: does not need output logits, just update kv-cache
                    self.target_model.model(
                        chunk,
                        past_key_values=past_key_values,
                        position_ids=pos_ids,
                        cache_position=cache_position,
                        attention_mask=am_so_far,
                    )
                else:
                    # last iteration
                    # TODO: there will be bugs if chunk size is smaller than draft_window_size
                    # For the last chunk, run with query capturing enabled on the target model.
                    capture_ctx = CaptureAttentionContext(self.target_model)
                    # Enable capturing for the target model
                    self.target_model._capture_enabled = True
                    with capture_ctx:
                        outputs = self.target_model.prefill_forward(
                            chunk,
                            past_key_values=past_key_values,
                            position_ids=pos_ids,
                            cache_position=cache_position,
                            logits_to_keep=1,
                            attention_mask=am_so_far,
                        )

                    # Reset capture flag (CaptureAttentionContext __exit__ resets it)
                    self.target_model._capture_enabled = False

                    next_token_logits = outputs.logits
                    del outputs

                past_key_values.seq_len += chunk.size(1)
        target_prefill_end.record()
        # Store for profiling mixin
        self._target_prefill_event = (target_prefill_start, target_prefill_end)

        with nvtx.annotate("sample tokens"):
            sampled_tokens = self._sample_token(next_token_logits, logits_processor, do_sample)

        with nvtx.annotate("update data"):
            input_ids = torch.cat([input_ids, sampled_tokens], dim=-1)
            draft_pos_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=1)
            pos_ids = draft_pos_ids[:, -1:] + torch.arange(self.draft_params.max_depth+1, device=draft_pos_ids.device)
            attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1, device=input_ids.device, dtype=attention_mask.dtype)], dim=1)
            self.full_pos_ids = draft_pos_ids.clone()

        with nvtx.annotate("decoding"):
            finished = False
            while not finished:
                # Drafted tokens are written right after the (compressed) cache contents.
                kv_offset = min(self.compressKV_size, input_ids.shape[1]-1)
                cache_position = torch.arange(kv_offset, kv_offset+self.draft_params.max_verify_tokens, dtype=torch.long, device=input_ids.device)

                # * speculate
                with nvtx.annotate("speculate", color="cyan"):
                    input_ids = input_ids.clone(memory_format=torch.contiguous_format)
                    draft_input_ids, draft_probs = self._speculate(input_ids, attention_mask=attention_mask, position_ids=draft_pos_ids)

                # * tree decoding
                with nvtx.annotate("tree_decoding", color="orange"):
                    prev_kv_len = past_key_values.get_seq_length()
                    outputs = self._tree_decoding(draft_input_ids, past_key_values, position_offset=input_ids.shape[1]-1, cache_position=cache_position, position_ids=pos_ids, device=input_ids.device, attention_mask=attention_mask)
                    next_token_logits = outputs.logits
                    del outputs

                # * verify
                with nvtx.annotate("verify"):
                    sampled_tokens, _, (_total_len, accept_len_arr) = self._verify(
                                                        draft_input_ids, next_token_logits,
                                                        logits_processor,
                                                        do_sample,
                                                        finished_arr
                                                    )
                    del next_token_logits

                with nvtx.annotate("reorder kv"):
                    accept_len = torch.max(accept_len_arr).item()
                    new_kv_len = prev_kv_len + accept_len + 1
                    past_key_values.reorder_full_cache_with_offset_seq(self.target_key_values, start=self.target_key_values.get_seq_length(), end=self.target_key_values.get_seq_length()+accept_len+1)

                    past_key_values.seq_len = new_kv_len
                    self.draft_model.past_key_values.seq_len = new_kv_len - (self.draft_params.max_depth == accept_len)

                # * update input_ids, attention mask and position ids
                with nvtx.annotate("update data"):
                    # 1. append the new tokens 2. set attention mask based on pad tokens 3. set pos ids accordingly
                    input_ids = torch.cat([input_ids, sampled_tokens], dim=-1)
                    part_attn = sampled_tokens.ne(self.tokenizer.pad_token_id).to(self.device)
                    attention_mask = torch.cat([attention_mask, part_attn], dim=1)

                    # accept_len == max_depth means the second last token is not forwarded by draft model
                    pos = draft_pos_ids[:, -1:] + accept_len_arr[:, None]
                    if accept_len == self.draft_params.max_depth:
                        draft_pos_ids = torch.cat([pos, pos+1], dim=1)
                    else:
                        draft_pos_ids = pos+1
                    pos_offsets = torch.arange(self.draft_params.max_depth+1, device=draft_pos_ids.device)
                    pos_ids = draft_pos_ids[:, -1:] + pos_offsets[None, :]

                    # Continue the running position ids over the non-pad new tokens.
                    to_cat = part_attn.long().cumsum(-1)
                    to_cat += self.full_pos_ids[:, -1:].clone()

                    self.full_pos_ids = torch.cat([self.full_pos_ids, to_cat], dim=1)

                # check stopping criteria
                with nvtx.annotate("stopping criteria"):
                    finished_step += torch.logical_not(finished_arr).long() * sampled_tokens.size(1)
                    finished_arr = torch.logical_or(stopping_criteria(input_ids, None), finished_arr)
                    finished = torch.all(finished_arr).item()

        # Turn every tokens after finished steps to eos token: column j of row i is
        # overwritten when j >= finished_step[i] + 1, done in one masked fill rather
        # than a per-row Python loop.
        column_index = torch.arange(input_ids.size(1), device=input_ids.device)
        after_finish = column_index[None, :] > finished_step[:, None]
        input_ids.masked_fill_(after_finish, self.tokenizer.eos_token_id)

        return input_ids
    
class TargetKVSDGenerator(MBSDProfilingMixin, TargetKVSDGeneratorBase):
    """Concrete generator: ``TargetKVSDGeneratorBase`` combined with ``MBSDProfilingMixin``.

    Defines no members of its own; all behaviour comes from the two bases,
    with the mixin first in the method resolution order.
    """