import torch
import torch.nn as nn
import nvtx
import json
import os 

from .base import DraftModelBase

from ..utils.monkey_patch import CaptureAttentionContext

class TargetKVSDDraftModel(DraftModelBase):
    def forward(self, input_ids, with_softmax=False, *model_args, **kwargs):
        logits = self.model(input_ids, *model_args, **kwargs).logits
        if with_softmax:
            if self.draft_params.temperature == 1.0:
                logits = torch.softmax(logits, dim=-1)
            else:
                logits = torch.softmax(logits/self.draft_params.temperature, dim=-1)
            
        return logits
    
    def set_important_head_idx(self, filename):
        # check file existence
        if not os.path.exists(filename):
            print(f"Important heads file not found: {filename}, selecting all heads as important heads.")
            self.important_heads = torch.arange(self.model.config.num_attention_heads, device=self.device).unsqueeze(0).expand(self.model.config.num_hidden_layers, -1)
        else:
            with open(filename, 'r') as f:
                important_heads = json.load(f)
            self.important_heads = [important_heads[str(i)] for i in range(len(important_heads))]
            # if SRH_head_num is None, select all heads. if not, select top SRH_head_num heads
            if self.draft_params.generator_kwargs['SRH_head_num'] is None:
                self.important_heads = torch.tensor(self.important_heads, device=self.device)
            else:
                self.important_heads = torch.tensor(self.important_heads, device=self.device)[:,:self.draft_params.generator_kwargs['SRH_head_num']]
            self.important_heads = self.important_heads[self.important_layers.to("cpu")]
        print(f"self.important_heads:\n{self.important_heads}")

    def set_important_layers(self, important_layer):
        """Store the important layer indices as a tensor on ``self.device``.

        Args:
            important_layer: Either a single layer index or a list/tuple of
                layer indices.

        Side effects:
            Sets ``self.important_layers`` and prints it.
        """
        # isinstance (not an exact type check) so tuples and list subclasses
        # are treated as sequences rather than wrapped into a nested tensor.
        if isinstance(important_layer, (list, tuple)):
            layer_indices = list(important_layer)
        else:
            layer_indices = [important_layer]
        self.important_layers = torch.tensor(layer_indices, device=self.device)
        print(f"self.important_layers:\n{self.important_layers}")
        # NOTE: to mark every layer as important, pass
        # list(range(self.model.config.num_hidden_layers)).

    @torch.no_grad()
    def speculate(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
        # 1) Obtain necessary parameters
        device = input_ids.device
        dtype = self.model.lm_head.weight.dtype
        batch_size, input_len = input_ids.shape
        max_cache_len = getattr(self.past_key_values, "max_cache_len", None)
        capture_ctx = CaptureAttentionContext(self)

        # 2) Initialize kv_len
        with nvtx.annotate("Initialize kv_len"):
            kv_len = self.past_key_values.get_seq_length()
            # convert kv_len to int if it is a tensor
            if isinstance(kv_len, torch.Tensor):
                kv_len = kv_len.item()
        
        with capture_ctx:
            self._capture_enabled = False
            # 3) First forward pass
            # Note: There will be bugs for dynamic cache if we do not crop here.
            self.past_key_values.crop(kv_len)
            
            # Setup draft prefill timing if enabled by generator
            record_prefill = hasattr(self, '_prefill_events')
            if record_prefill:
                pf_start = torch.cuda.Event(enable_timing=True)
                pf_end = torch.cuda.Event(enable_timing=True)
                pf_start.record()
                
            with nvtx.annotate("ssm first forward", color="red"):
                # chunked prefill
                prefill_tokens = input_ids[:, kv_len:]
                prefill_length = prefill_tokens.size(1)
                prefill_chunk_size = (
                    self.draft_params.generator_kwargs.get("prefill_chunk_size", None)
                    if hasattr(self.draft_params, "generator_kwargs") else None
                )
                chunk_size = prefill_length if prefill_chunk_size is None else min(prefill_length, prefill_chunk_size)

                sampled_probs = None
                for start in range(0, prefill_length, chunk_size):
                    chunk = prefill_tokens[:, start:start + chunk_size]
                    current_kv_len = self.past_key_values.get_seq_length()
                    am_so_far = attention_mask[:, :current_kv_len + chunk.size(1)]  # (B, cur_len_so_far)
                    pos_ids = position_ids[:, start:start + chunk.size(1)]
                    cache_position = torch.arange(
                        current_kv_len, current_kv_len + chunk.size(1),
                        dtype=torch.long, device=input_ids.device
                    )
                    # last iteration
                    if start + chunk_size < prefill_length:
                        # does not need output logits, just update kv-cache
                        self.model(
                            chunk,
                            past_key_values=self.past_key_values,
                            attention_mask=am_so_far,
                            position_ids=pos_ids,
                            cache_position=cache_position,
                        )
                    else:
                        self._capture_enabled = True
                        sampled_probs = self.prefill_forward(
                            chunk,
                            with_softmax=True,
                            past_key_values=self.past_key_values,
                            logits_to_keep=1,
                            attention_mask=am_so_far,
                            position_ids=pos_ids,
                            cache_position=cache_position,
                        )

                    self.past_key_values.seq_len += chunk.size(1)
                kv_len = input_len
                self.past_key_values.seq_len = input_len

            if record_prefill:
                pf_end.record()
                self._prefill_events.append((pf_start, pf_end))

            with nvtx.annotate("sample nodes", color="green"):
                self.token_ids = torch.argmax(sampled_probs[:, -1, :], dim=-1, keepdim=True)

            # 4) Main loop
            root_id = input_ids[:, -1]
            total_tokens = 1 + self.draft_params.max_depth
            all_token_ids = torch.empty((batch_size, total_tokens), device=device, dtype=torch.long)
            all_token_ids[:, 0] = root_id
            all_token_ids[:, 1] = self.token_ids[:, 0]
            current_idx = 2

            pos_ids = pos_ids[:, -1:] + 1
            cache_position = cache_position[-1:] + 1
            
            attn_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=device, dtype=attention_mask.dtype)], dim=1)
            for depth_i in range(self.draft_params.max_depth-1):
                with nvtx.annotate("ssm forward", color="red"):
                    sampled_probs = self(
                        self.token_ids,
                        with_softmax=True,
                        past_key_values=self.past_key_values,
                        position_ids=pos_ids,
                        cache_position=cache_position,
                        attention_mask=attn_mask,
                    )

                with nvtx.annotate("sample nodes", color="green"):
                    self.token_ids = torch.argmax(sampled_probs[:, -1, :], dim=-1, keepdim=True)
                    # Update internal state
                self.past_key_values.seq_len += 1
                all_token_ids[:, current_idx] = self.token_ids[:, 0]
                current_idx += 1
                pos_ids = pos_ids[:, -1:] + 1
                cache_position = cache_position[-1:] + 1
                attn_mask = torch.cat([attn_mask, torch.ones((batch_size, 1), device=device, dtype=attention_mask.dtype)], dim=1)
        # 5) Return all speculated tokens and draft cumulative probabilities
        return all_token_ids, None