import torch
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteria
import logging
import nvtx
import os

from .base import GeneratorBase
from ..utils.mixin import ProfilingMixin


class NaiveGeneratorBase(GeneratorBase):
    """Plain autoregressive generator: (optionally chunked) prefill, then one token per decode step.

    Relies on attributes that are not defined in this class and are presumably
    provided by ``GeneratorBase``: ``target_model``, ``tokenizer``, ``device``,
    ``cache_implementation`` and ``_sample_token`` -- verify against the base class.
    """

    def __init__(self, generator_kwargs, *model_args, **kwargs):
        """Initialise the generator.

        Args:
            generator_kwargs: Mapping of generator options. Only
                ``"prefill_chunk_size"`` is read here; ``None`` (the default)
                means the prompt is prefetched in a single forward pass.
            *model_args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(*model_args, **kwargs)
        self.prefill_chunk_size = generator_kwargs.get("prefill_chunk_size", None)

        # Per-step attention latency capture is opt-in via the ANALYSIS_MODE=1 env var.
        self.analysis_mode = os.getenv("ANALYSIS_MODE", "0") == "1"
        # One entry per decode step, appended only while analysis is enabled.
        self.all_attention_latencies = []

    def _generate(
        self,
        input_ids: torch.LongTensor,
        stopping_criteria: StoppingCriteria,
        logits_processor: LogitsProcessorList,
        do_sample: bool,
        **model_kwargs,
    ):
        """Generate tokens for a batch of prompts.

        Args:
            input_ids: 2-D tensor of prompt token ids (batch, sequence). It is
                cloned, so the caller's tensor is not modified.
            stopping_criteria: Called as ``stopping_criteria(input_ids, None)``
                after each decode step; the result is OR-ed into the per-sequence
                finished flags. Its ``max_length`` attribute is also read.
            logits_processor: Passed through to ``self._sample_token``.
            do_sample: Passed through to ``self._sample_token``.
            **model_kwargs: Must contain ``past_key_values`` (the KV cache).

        Returns:
            The prompt with the generated tokens appended. For each sequence,
            every position after the step at which it finished is overwritten
            with the EOS token id.

        Raises:
            ValueError: If ``max_length`` is unset while a static cache is
                configured, if ``past_key_values`` is missing, if
                ``prefill_chunk_size`` is not positive, or if the cache already
                covers the whole prompt (nothing left to prefill).

        NOTE: the stopping criteria are first evaluated after the second
        generated token has been appended, so at least two tokens are always
        produced.
        """
        assert self.target_model is not None, "target_model must be provided"

        # A sequence made up entirely of EOS tokens counts as finished from the start.
        finished_arr = torch.all(input_ids == self.tokenizer.eos_token_id, dim=1)

        # Work on a copy so the caller's tensor is left untouched.
        input_ids = input_ids.clone()
        batch_size, input_len = input_ids.shape

        # Per-sequence step counter; starts at the prompt length and advances
        # once per decode step while the sequence is still unfinished.
        finished_step = torch.full((batch_size,), input_len, dtype=torch.long, device=input_ids.device)

        # A static KV cache needs a known upper bound on the sequence length.
        if stopping_criteria.max_length is None and self.cache_implementation == "static":
            raise ValueError(
                "max_length is not set. Only 'dynamic' kv-cache is supported when max_length is unspecified."
            )

        past_key_values = model_kwargs.get("past_key_values")
        if past_key_values is None:
            raise ValueError("past_key_values should be provided")

        if self.prefill_chunk_size is not None and self.prefill_chunk_size <= 0:
            raise ValueError(
                f"prefill_chunk_size must be a positive integer or None, got {self.prefill_chunk_size!r}"
            )

        # Padding positions are masked out; position ids count only real tokens.
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(self.device)
        position_ids = attention_mask.long().cumsum(-1) - 1

        # Prefill stage
        target_prefill_start = torch.cuda.Event(enable_timing=True)
        target_prefill_end = torch.cuda.Event(enable_timing=True)

        target_prefill_start.record()
        with nvtx.annotate("prefill", color="orange"):
            # Only the part of the prompt that is not in the cache yet is processed.
            prefill_tokens = input_ids[:, past_key_values.get_seq_length():]
            prefill_length = prefill_tokens.size(1)
            if prefill_length == 0:
                # Without at least one uncached token there are no logits to sample from.
                raise ValueError(
                    "Nothing to prefill: past_key_values already covers the entire input sequence."
                )
            if self.prefill_chunk_size is None:
                chunk_size = prefill_length
            else:
                chunk_size = min(prefill_length, self.prefill_chunk_size)

            next_token_logits = None
            for start in range(0, prefill_length, chunk_size):
                chunk = prefill_tokens[:, start:start + chunk_size]
                current_kv_len = past_key_values.get_seq_length()
                chunk_end = current_kv_len + chunk.size(1)
                cache_position = torch.arange(
                    current_kv_len, chunk_end, dtype=torch.long, device=input_ids.device
                )
                am_so_far = attention_mask[:, :chunk_end]  # (B, cur_len_so_far)
                chunk_position_ids = position_ids[:, current_kv_len:chunk_end]

                is_last_chunk = start + chunk_size >= prefill_length
                if is_last_chunk:
                    # Only the final chunk needs logits (for the last position).
                    outputs = self.target_model.prefill_forward(
                        chunk,
                        past_key_values=past_key_values,
                        position_ids=chunk_position_ids,
                        cache_position=cache_position,
                        logits_to_keep=1,
                        attention_mask=am_so_far,
                    )
                    next_token_logits = outputs.logits
                    del outputs
                else:
                    # Intermediate chunk: no logits needed, just update the kv-cache.
                    self.target_model.model(
                        chunk,
                        past_key_values=past_key_values,
                        position_ids=chunk_position_ids,
                        cache_position=cache_position,
                        attention_mask=am_so_far,
                    )
                # The cache's length counter is advanced manually here.
                past_key_values.seq_len += chunk.size(1)

        target_prefill_end.record()
        # Store for profiling mixin
        self._target_prefill_event = (target_prefill_start, target_prefill_end)

        with nvtx.annotate("sample tokens"):
            next_tokens = self._sample_token(next_token_logits, logits_processor, do_sample)
        with nvtx.annotate("update data"):
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
            # New tensor (not a view of the prefill positions), safe to bump in place later.
            cache_position = cache_position[-1:] + 1
        position_ids = position_ids[:, -1:] + 1
        attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1, device=input_ids.device, dtype=attention_mask.dtype)], dim=1)

        enable_analysis = self.analysis_mode and bool(getattr(self, "profiling", False))
        if enable_analysis:
            # Imported lazily (analysis runs only) and once, rather than on every decode step.
            from ..utils.monkey_patch import CaptureAttentionContext

        # Decoding loop
        with nvtx.annotate("decoding"):
            finished = False
            while not finished:
                if enable_analysis:
                    tf_start = torch.cuda.Event(enable_timing=True)
                    tf_end = torch.cuda.Event(enable_timing=True)
                    tf_start.record()
                    capture_ctx = CaptureAttentionContext(
                        self.target_model,
                        capture_queries=False,
                        measure_latency=True,
                    )
                    with capture_ctx, nvtx.annotate("llm forward", color="orange"):
                        outputs = self.target_model(
                            next_tokens,
                            past_key_values=past_key_values,
                            position_ids=position_ids,
                            cache_position=cache_position,
                            attention_mask=attention_mask,
                        )
                        next_token_logits = outputs.logits
                    tf_end.record()
                    # elapsed_time requires both events to have completed.
                    tf_end.synchronize()
                    self._last_target_forward_ms = tf_start.elapsed_time(tf_end)

                    attn_latencies = getattr(self.target_model, "latest_attention_latencies", None)
                    if attn_latencies:
                        total_ms = sum(attn_latencies)
                        per_layer = ",".join(f"{v:.3f}" for v in attn_latencies)
                        logging.debug(
                            "Target self-attention latency: total=%.3f ms; per-layer=%s", total_ms, per_layer
                        )
                        self.all_attention_latencies.append(attn_latencies)
                    logging.debug("target_model forward latency: %.3f ms", self._last_target_forward_ms)
                else:
                    with nvtx.annotate("llm forward", color="orange"):
                        outputs = self.target_model(
                            next_tokens,
                            past_key_values=past_key_values,
                            position_ids=position_ids,
                            cache_position=cache_position,
                            attention_mask=attention_mask,
                        )
                        next_token_logits = outputs.logits

                with nvtx.annotate("sample tokens"):
                    next_tokens = self._sample_token(next_token_logits, logits_processor, do_sample)

                with nvtx.annotate("update data"):
                    input_ids = torch.cat([input_ids, next_tokens], dim=-1)
                    position_ids = position_ids[:, -1:] + 1
                    cache_position += 1
                    past_key_values.seq_len += 1
                    attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1, device=input_ids.device, dtype=attention_mask.dtype)], dim=1)

                # check stopping criteria
                with nvtx.annotate("stopping criteria"):
                    # Advance the counter only for sequences that were still running
                    # before this step, so it freezes at the step a sequence finishes.
                    finished_step += torch.logical_not(finished_arr).long()
                    finished_arr = torch.logical_or(stopping_criteria(input_ids, None), finished_arr)
                    finished = torch.all(finished_arr).item()

        # Turn every token after the finished step into the eos token. A single
        # broadcast mask replaces a per-row Python loop (and its per-row device sync):
        # column index > finished_step[i]  <=>  column index >= finished_step[i] + 1.
        column_idx = torch.arange(input_ids.size(1), device=input_ids.device)
        after_finish = column_idx.unsqueeze(0) > finished_step.unsqueeze(1)
        input_ids.masked_fill_(after_finish, self.tokenizer.eos_token_id)
        return input_ids

class NaiveGenerator(ProfilingMixin, NaiveGeneratorBase):
    """Concrete generator combining ``ProfilingMixin`` with ``NaiveGeneratorBase``.

    Adds no behaviour of its own. ``ProfilingMixin`` is listed first, so its
    attributes take precedence over the base generator's in the method
    resolution order.
    """