import torch
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteria
import logging
import nvtx
import os

from .base import GeneratorBase
from ..utils.mixin import MBSDProfilingMixin
from ..utils.utils import DraftParams, invert_mask

class ClassicSDGeneratorBase(GeneratorBase):
    def __init__(self, generator_kwargs, *model_args, **kwargs):
        """Initialize the parent generator and read generator-specific options.

        Args:
            generator_kwargs: Mapping of generator options. Only the
                ``"prefill_chunk_size"`` key is read here; it defaults to
                ``None`` when absent.
            *model_args: Forwarded positionally to the parent initializer.
            **kwargs: Forwarded as keyword arguments to the parent initializer.
        """
        super().__init__(*model_args, **kwargs)

        # One entry is appended per instrumented target forward (see `_tree_decoding`).
        self.all_attention_latencies = []
        # Analysis instrumentation is opt-in via the ANALYSIS_MODE=1 environment variable.
        self.analysis_mode = os.environ.get("ANALYSIS_MODE", "0") == "1"
        # Maximum number of prompt tokens per prefill forward pass; None means "not set".
        self.prefill_chunk_size = generator_kwargs.get("prefill_chunk_size")
        
    def _speculate(self, input_ids, attention_mask=None, position_ids=None):
        return self.draft_model.speculate(input_ids, attention_mask=attention_mask, position_ids=position_ids)
 
    def _tree_decoding(self, draft_input_ids, past_key_values, position_ids=None, attention_mask=None):
        enable_analysis = self.analysis_mode and bool(getattr(self, "profiling", False))
        
        if enable_analysis:
            from ..utils.monkey_patch import CaptureAttentionContext
            tf_start = torch.cuda.Event(enable_timing=True)
            tf_end = torch.cuda.Event(enable_timing=True)
            tf_start.record()
            
            capture_ctx = CaptureAttentionContext(
                self.target_model,
                capture_queries=False,
                measure_latency=True,
            )
            with capture_ctx:
                with nvtx.annotate("target_model forward", color="orange"):
                    outputs = self.target_model(
                        draft_input_ids,
                        past_key_values=past_key_values,
                        position_ids=position_ids,
                        attention_mask=attention_mask,
                    )
            
            tf_end.record()
            tf_end.synchronize()
            self._last_target_forward_ms = tf_start.elapsed_time(tf_end)
            
            attn_latencies = getattr(self.target_model, "latest_attention_latencies", None)
            if attn_latencies:
                total_ms = sum(attn_latencies)
                per_layer = ",".join(f"{v:.3f}" for v in attn_latencies)
                logging.debug(f"Target self-attention latency: total={total_ms:.3f} ms; per-layer={per_layer}")
                self.all_attention_latencies.append(attn_latencies)
            logging.debug(f"target_model forward latency: {self._last_target_forward_ms:.3f} ms")
            return outputs
        else:
            return self.target_model(
                draft_input_ids,
                past_key_values=past_key_values,
                position_ids=position_ids,
                attention_mask=attention_mask,
            )
    
    def _verify(self, draft_input_ids, logits, logits_processor, do_sample, finished):
        # Sequential verification in GPUs.
        global_p = torch.argmax(self._sample_token(logits, logits_processor, do_sample, return_probs=True), dim=-1)
        # Initialize variables

        verified_mask = draft_input_ids[:, 1:] == global_p[:, :-1]
        eos_mask = global_p[:, :-1] == self.draft_model.eos_token_id
        verified_mask = torch.where(eos_mask, False, verified_mask)
        verified_mask = torch.cummin(verified_mask, dim=1).values
        accept_len_arr = torch.sum(verified_mask, dim=1)
        # Set the values to -1 for finished sequences
        accept_len_arr[finished] = -1
        max_accept = torch.max(accept_len_arr)
        accepted_tokens = global_p[:, :max_accept].clone()
        # Set the rejected tokens to pad_token_id
        accepted_tokens[verified_mask[:, :max_accept] == False] = self.tokenizer.pad_token_id
        
        # Add bonus token
        bonus_token = global_p[torch.arange(global_p.size(0), device=global_p.device), accept_len_arr]
        sampled_tokens = torch.cat([accepted_tokens, bonus_token.unsqueeze(1)], dim=1)

        # total len = accept_len + 1 when 1. reject 2. got EOS
        total_len_arr = accept_len_arr + torch.logical_or(
            sampled_tokens[:, -1] == self.draft_model.eos_token_id,
            torch.logical_not(accept_len_arr == draft_input_ids.size(1)-1)
        ).long()

        total_len_arr[finished] = -1

        return sampled_tokens, None, (total_len_arr, accept_len_arr)

    def _generate(
        self,
        input_ids: torch.LongTensor,
        stopping_criteria: StoppingCriteria,
        logits_processor: LogitsProcessorList,
        do_sample: bool,
        **model_kwargs,
    ):
        """
        Generate sequence of tokens with speculative decoding.

        This method consists of two main stages: prefill and decode.

        Prefill Stage:
        - Feed the prompt tokens that are not yet in the target KV cache through
          the target model, optionally in chunks of ``self.prefill_chunk_size``.
        - Sample a token from the last position and append it to the input_ids.

        Decode Stage (with speculative decoding):
        - Iterate through the following steps until every row is finished:
            1. Ask the draft model for candidate tokens (``_speculate``).
            2. Run the target model over the candidates in one forward pass (``_tree_decoding``).
            3. Accept or reject the candidates (``_verify``).
            4. Crop the KV cache and extend input_ids, attention mask and position ids.

        Args:
            input_ids (torch.LongTensor): The input token IDs, shape (batch, seq_len).
            stopping_criteria (StoppingCriteria): The criteria to stop the generation.
                Must expose ``max_length`` and be callable as ``stopping_criteria(input_ids, None)``.
            logits_processor (LogitsProcessor): The processor to modify the logits.
            do_sample (bool): Whether to sample tokens during generation. If False, the generation will be deterministic.
            **model_kwargs: Must contain ``past_key_values`` (target KV cache) and
                ``draft_past_key_values`` (draft KV cache), both not None.

        Returns:
            input_ids (torch.LongTensor): The generated token IDs. Every position after a
            row's finished step is overwritten with ``self.tokenizer.eos_token_id``.

        Raises:
            ValueError: If ``max_length`` is unset while the cache implementation is
                "static", if either KV cache is missing, if there are no prompt tokens
                left to prefill, or if the effective prefill chunk size is not positive.

        Side effects:
            Resets ``self.draft_model._prefill_events`` to an empty list, stores the
            prefill CUDA events in ``self._target_prefill_event``, and updates the
            ``seq_len`` attribute of both KV caches.
        """
        assert self.target_model is not None, "target_model must be provided"
        assert self.draft_model is not None, "draft_model must be provided"
        assert self.tokenizer is not None, "tokenizer must be provided"

        # * initialize variables
        # A row starts out finished only if it is completely filled with eos tokens.
        finished_arr = torch.all(input_ids == self.draft_model.eos_token_id, dim=1)

        # Per-row column count up to which tokens are kept; starts at the prompt length.
        finished_step = torch.full(
            (input_ids.size(0),), input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )

        # * clone input_ids so the caller's tensor is never modified
        input_ids = input_ids.clone()
        batch_size, _ = input_ids.shape

        # * prepare kv-cache
        # Raise error if max_length not set while using static cache
        if stopping_criteria.max_length is None and self.cache_implementation == "static":
            raise ValueError(
                "max_length is not set. Only 'dynamic' kv-cache is supported when max_length is unspecified."
            )

        past_key_values = model_kwargs.get("past_key_values")
        draft_past_key_values = model_kwargs.get("draft_past_key_values")
        if past_key_values is None or draft_past_key_values is None:
            raise ValueError("past_key_values and draft_past_key_values should both be provided")

        self.draft_model.set_past_key_values(draft_past_key_values)
        # Initialize draft prefill profiling list
        self.draft_model._prefill_events = []

        # Pad tokens are masked out; position ids count only the non-pad tokens.
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(self.device)
        position_ids = attention_mask.long().cumsum(-1) - 1

        # * prefill stage
        target_prefill_start = torch.cuda.Event(enable_timing=True)
        target_prefill_end = torch.cuda.Event(enable_timing=True)

        target_prefill_start.record()
        with nvtx.annotate("chunked prefill", color="orange"):
            current_kv_len = past_key_values.get_seq_length()
            prefill_tokens = input_ids[:, current_kv_len:]
            prefill_length = prefill_tokens.size(1)
            if prefill_length == 0:
                # Without at least one new token there are no logits to sample the first token from.
                raise ValueError(
                    "No tokens left to prefill: the target kv-cache already covers the whole input."
                )
            chunk_size = prefill_length if self.prefill_chunk_size is None else min(prefill_length, self.prefill_chunk_size)
            if chunk_size <= 0:
                raise ValueError(f"prefill_chunk_size must be a positive integer, got {self.prefill_chunk_size}")

            next_token_logits = None
            for start in range(0, prefill_length, chunk_size):
                chunk = prefill_tokens[:, start:start + chunk_size]
                current_kv_len = past_key_values.get_seq_length()
                cache_position = torch.arange(
                    current_kv_len, current_kv_len + chunk.size(1),
                    dtype=torch.long, device=input_ids.device
                )
                am_so_far = attention_mask[:, :current_kv_len + chunk.size(1)]  # (B, cur_len_so_far)
                pos_ids = position_ids[:, current_kv_len:current_kv_len + chunk.size(1)]
                if start + chunk_size < prefill_length:
                    # Intermediate chunk: does not need output logits, just update kv-cache
                    self.target_model.model(
                        chunk,
                        past_key_values=past_key_values,
                        position_ids=pos_ids,
                        cache_position=cache_position,
                        attention_mask=am_so_far,
                    )
                else:
                    # Last chunk: keep only the final position's logits for sampling.
                    outputs = self.target_model.prefill_forward(
                        chunk,
                        past_key_values=past_key_values,
                        position_ids=pos_ids,
                        cache_position=cache_position,
                        logits_to_keep=1,
                        attention_mask=am_so_far,
                    )
                    next_token_logits = outputs.logits
                    del outputs

                # The cache length is tracked manually through its seq_len attribute.
                past_key_values.seq_len += chunk.size(1)
        target_prefill_end.record()
        # Store for profiling mixin
        self._target_prefill_event = (target_prefill_start, target_prefill_end)

        with nvtx.annotate("sample tokens"):
            sampled_tokens = self._sample_token(next_token_logits, logits_processor, do_sample)

        with nvtx.annotate("update data"):
            input_ids = torch.cat([input_ids, sampled_tokens], dim=-1)
            # Draft model receives positions of the whole prompt plus the freshly sampled token.
            draft_pos_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=1)
            # Target positions for the sampled token followed by max_depth drafted tokens.
            pos_ids = draft_pos_ids[:, -1:] + torch.arange(self.draft_params.max_depth+1, device=draft_pos_ids.device)
            attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1, device=input_ids.device, dtype=attention_mask.dtype)], dim=1)

        with nvtx.annotate("decoding"):
            all_finished = False
            while not all_finished:
                # * speculate
                with nvtx.annotate("speculate", color="cyan"):
                    input_ids = input_ids.clone(memory_format=torch.contiguous_format)
                    draft_input_ids, _ = self._speculate(input_ids, attention_mask=attention_mask, position_ids=draft_pos_ids)

                # * tree decoding
                # All drafted positions are visible to the target model.
                target_forward_mask = torch.cat([attention_mask, torch.ones(batch_size, self.draft_params.max_depth, device=input_ids.device, dtype=attention_mask.dtype)], dim=1)

                with nvtx.annotate("tree_decoding", color="orange"):
                    prev_kv_len = past_key_values.get_seq_length()
                    outputs = self._tree_decoding(draft_input_ids, past_key_values, position_ids=pos_ids, attention_mask=target_forward_mask)
                    next_token_logits = outputs.logits
                    del outputs

                # * verify
                with nvtx.annotate("verify"):
                    sampled_tokens, _, (_, accept_len_arr) = self._verify(
                        draft_input_ids,
                        next_token_logits,
                        logits_processor,
                        do_sample,
                        finished_arr,
                    )
                    del next_token_logits

                with nvtx.annotate("reorder kv"):
                    # The cache is cropped to the longest accepted prefix in the batch;
                    # shorter rows are padded inside sampled_tokens and masked out below.
                    accept_len = torch.max(accept_len_arr).item()
                    new_kv_len = prev_kv_len + accept_len + 1
                    past_key_values.crop(new_kv_len)
                    past_key_values.seq_len = new_kv_len
                    # When every drafted token was accepted the draft cache is one entry behind.
                    self.draft_model.past_key_values.seq_len = new_kv_len - (self.draft_params.max_depth == accept_len)

                # * update input_ids, attention mask and position ids
                with nvtx.annotate("update data"):
                    input_ids = torch.cat([input_ids, sampled_tokens], dim=-1)
                    # Pad tokens mark rejected slots and must not be attended to.
                    part_attn = sampled_tokens.ne(self.tokenizer.pad_token_id).to(self.device)
                    attention_mask = torch.cat([attention_mask, part_attn], dim=1)

                    # accept_len == max_depth means the second last token is not forwarded by draft model
                    pos = draft_pos_ids[:, -1:] + accept_len_arr[:, None]
                    if accept_len == self.draft_params.max_depth:
                        draft_pos_ids = torch.cat([pos, pos+1], dim=1)
                    else:
                        draft_pos_ids = pos+1
                    pos_offsets = torch.arange(self.draft_params.max_depth+1, device=draft_pos_ids.device)
                    pos_ids = draft_pos_ids[:, -1:] + pos_offsets[None, :]

                # check stopping criteria
                with nvtx.annotate("stopping criteria"):
                    # Only rows that were still running before this step advance their finished step.
                    finished_step += torch.logical_not(finished_arr).long() * sampled_tokens.size(1)
                    finished_arr = torch.logical_or(stopping_criteria(input_ids, None), finished_arr)
                    all_finished = torch.all(finished_arr).item()

        # Turn every token after a row's finished step into the eos token (one masked
        # fill instead of a per-row loop, which would sync with the device for each row).
        columns = torch.arange(input_ids.size(1), device=input_ids.device)
        after_finish = columns[None, :] > finished_step.to(input_ids.device)[:, None]
        input_ids.masked_fill_(after_finish, self.tokenizer.eos_token_id)

        return input_ids
    
class ClassicSDGenerator(MBSDProfilingMixin, ClassicSDGeneratorBase):
    """``ClassicSDGeneratorBase`` combined with ``MBSDProfilingMixin``.

    Defines no members of its own. The mixin is listed first, so its
    attributes take precedence over the base generator's in the MRO.
    """