import torch
import torch.nn as nn
import nvtx

from .base import DraftModelBase

    
class ClassicSDDraftModel(DraftModelBase):
    """Draft model that proposes a single greedy chain of tokens.

    At every step the next token is the argmax of the draft model's
    (temperature-scaled) softmax. Relies on attributes presumably provided by
    ``DraftModelBase``: ``model``, ``draft_params`` (``temperature``,
    ``max_depth``, optional ``generator_kwargs``), ``past_key_values`` and
    ``prefill_forward`` — TODO confirm against the base class.
    """

    def forward(self, input_ids, with_softmax=False, *model_args, **kwargs):
        """Run the wrapped draft model and return logits or probabilities.

        Args:
            input_ids: Token ids forwarded to ``self.model``.
            with_softmax: If True, return ``softmax(logits / temperature)``
                over the last dimension, where ``temperature`` is
                ``self.draft_params.temperature``; otherwise return raw logits.
            *model_args, **kwargs: Forwarded unchanged to ``self.model``.

        Returns:
            The ``.logits`` of the model output, or their softmax when
            ``with_softmax`` is set.
        """
        logits = self.model(input_ids, *model_args, **kwargs).logits
        if not with_softmax:
            return logits
        temperature = self.draft_params.temperature
        # Skip the division in the common T == 1 case.
        if temperature != 1.0:
            logits = logits / temperature
        return torch.softmax(logits, dim=-1)

    @torch.no_grad()
    def speculate(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
        """Greedily draft ``draft_params.max_depth`` tokens after ``input_ids``.

        First, any tokens of ``input_ids`` not yet in ``self.past_key_values``
        are prefilled, optionally in chunks of
        ``draft_params.generator_kwargs["prefill_chunk_size"]``; a missing or
        non-positive value means a single chunk. The draft model is then run
        one token at a time, taking the argmax of its probabilities.

        Args:
            input_ids: Token ids of shape ``(batch, input_len)``: the whole
                sequence so far, including tokens already in the draft cache.
            attention_mask: Mask over the whole sequence; column ``j``
                corresponds to ``input_ids[:, j]``. Defaults to all ones.
            position_ids: Position ids for the *uncached* suffix
                ``input_ids[:, kv_len:]``; column ``j`` is read for prefill
                token ``j``. Defaults to positions derived from
                ``attention_mask``. NOTE(review): a full-length tensor would be
                misread when ``kv_len > 0`` — confirm callers pass the suffix.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(all_token_ids, None)``. ``all_token_ids`` has shape
            ``(batch, 1 + max_depth)`` and holds the last input token followed
            by the ``max_depth`` drafted tokens. The second element is always
            ``None``.

        Raises:
            ValueError: If ``input_ids`` is empty or ``max_depth < 1``. Both
                checks run before the cache is modified.

        Side effects:
            Crops and extends ``self.past_key_values`` (including its
            ``seq_len`` counter), sets ``self.token_ids`` to the last drafted
            token, and appends a ``(start, end)`` CUDA event pair to
            ``self._prefill_events`` when that attribute exists.
        """
        # 1) Obtain and validate parameters before touching the cache.
        device = input_ids.device
        batch_size, input_len = input_ids.shape
        max_depth = self.draft_params.max_depth
        if input_len == 0:
            raise ValueError("speculate() needs at least one input token")
        if max_depth < 1:
            raise ValueError(f"draft_params.max_depth must be >= 1, got {max_depth}")

        # 2) Initialize kv_len: the number of input tokens already cached.
        with nvtx.annotate("Initialize kv_len"):
            # get_seq_length() may return a tensor; int() handles both cases.
            kv_len = int(self.past_key_values.get_seq_length())
            # The last input token must always be run through the model to get
            # next-token probabilities. Without this cap, a fully cached input
            # gives an empty prefill and nothing to sample from.
            kv_len = min(kv_len, input_len - 1)

        # Note: There will be bugs for dynamic cache if we do not crop here.
        self.past_key_values.crop(kv_len)

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, input_len), dtype=torch.long, device=device)
        if position_ids is None:
            # Positions count only unmasked tokens; masked slots get a dummy 1.
            # Keep just the uncached suffix, matching how position_ids is read below.
            full_positions = attention_mask.long().cumsum(-1) - 1
            full_positions.masked_fill_(attention_mask == 0, 1)
            position_ids = full_positions[:, kv_len:input_len]

        # 3) Prefill the uncached suffix input_ids[:, kv_len:].
        # Optional draft-prefill timing, enabled by the generator.
        record_prefill = hasattr(self, "_prefill_events")
        if record_prefill:
            pf_start = torch.cuda.Event(enable_timing=True)
            pf_end = torch.cuda.Event(enable_timing=True)
            pf_start.record()

        with nvtx.annotate("ssm first forward", color="red"):
            prefill_tokens = input_ids[:, kv_len:]
            prefill_length = prefill_tokens.size(1)  # >= 1 thanks to the kv_len cap
            generator_kwargs = getattr(self.draft_params, "generator_kwargs", None) or {}
            prefill_chunk_size = generator_kwargs.get("prefill_chunk_size")
            if prefill_chunk_size is None or prefill_chunk_size <= 0:
                # No (valid) chunk size: prefill everything in one pass.
                chunk_size = prefill_length
            else:
                chunk_size = min(prefill_length, prefill_chunk_size)

            sampled_probs = None
            for start in range(0, prefill_length, chunk_size):
                chunk = prefill_tokens[:, start:start + chunk_size]
                chunk_len = chunk.size(1)
                current_kv_len = int(self.past_key_values.get_seq_length())
                # Mask covers every cached token plus this chunk.
                am_so_far = attention_mask[:, :current_kv_len + chunk_len]
                pos_ids = position_ids[:, start:start + chunk_len]
                cache_position = torch.arange(
                    current_kv_len, current_kv_len + chunk_len,
                    dtype=torch.long, device=device,
                )
                if start + chunk_size < prefill_length:
                    # Not the last chunk: only fill the KV cache, no logits needed.
                    self.model(
                        chunk,
                        past_key_values=self.past_key_values,
                        attention_mask=am_so_far,
                        position_ids=pos_ids,
                        cache_position=cache_position,
                    )
                else:
                    # Last chunk: also get probabilities for the final position.
                    sampled_probs = self.prefill_forward(
                        chunk,
                        with_softmax=True,
                        past_key_values=self.past_key_values,
                        logits_to_keep=1,
                        attention_mask=am_so_far,
                        position_ids=pos_ids,
                        cache_position=cache_position,
                    )
                # Advance the cache's manually maintained seq_len counter.
                self.past_key_values.seq_len += chunk_len
            self.past_key_values.seq_len = input_len

        if record_prefill:
            pf_end.record()
            self._prefill_events.append((pf_start, pf_end))

        with nvtx.annotate("sample nodes", color="green"):
            self.token_ids = torch.argmax(sampled_probs[:, -1, :], dim=-1, keepdim=True)

        # 4) Autoregressive greedy drafting.
        root_id = input_ids[:, -1]
        all_token_ids = torch.empty((batch_size, 1 + max_depth), device=device, dtype=torch.long)
        all_token_ids[:, 0] = root_id
        all_token_ids[:, 1] = self.token_ids[:, 0]

        # One extra mask column per drafted token; built once and reused (cat copies).
        ones_col = torch.ones((batch_size, 1), device=device, dtype=attention_mask.dtype)
        pos_ids = pos_ids[:, -1:] + 1
        cache_position = cache_position[-1:] + 1
        attn_mask = torch.cat([attention_mask, ones_col], dim=1)
        for depth_i in range(max_depth - 1):
            with nvtx.annotate("ssm forward", color="red"):
                sampled_probs = self(
                    self.token_ids,
                    with_softmax=True,
                    past_key_values=self.past_key_values,
                    position_ids=pos_ids,
                    cache_position=cache_position,
                    attention_mask=attn_mask,
                )

            with nvtx.annotate("sample nodes", color="green"):
                self.token_ids = torch.argmax(sampled_probs[:, -1, :], dim=-1, keepdim=True)

            # Update internal state for the next step.
            self.past_key_values.seq_len += 1
            all_token_ids[:, depth_i + 2] = self.token_ids[:, 0]
            pos_ids = pos_ids[:, -1:] + 1
            cache_position = cache_position[-1:] + 1
            attn_mask = torch.cat([attn_mask, ones_col], dim=1)

        # 5) Return the drafted chain. This drafter does not return draft
        # probabilities, so the second element is always None.
        return all_token_ids, None