import torch
from .util import norm_logits, sample


class KVCacheModel():
    """Wrap a model with an incrementally updated key/value cache.

    The wrapper keeps two pieces of state in sync along the sequence axis:

    * ``_past_key_values`` - the cache object returned by the model
      (presumably a Hugging Face ``Cache``; it must provide
      ``get_seq_length()`` and ``crop()`` - TODO confirm against the caller).
    * ``_prob_history`` - the per-position output of ``norm_logits`` for
      every position processed so far, indexed ``(batch, position, vocab)``.

    ``vocab_size`` is a public attribute callers may assign to truncate the
    model's logits along the last axis; ``None`` keeps every column.
    """

    def __init__(self, model : torch.nn.Module, temperature : float = 1, top_k : int = 0, top_p : float = 0) -> None:
        """Store the model and the sampling parameters passed to ``norm_logits``.

        Args:
            model: callable returning an object with ``logits`` and
                ``past_key_values`` attributes.
            temperature: forwarded to ``norm_logits``.
            top_k: forwarded to ``norm_logits``.
            top_p: forwarded to ``norm_logits``.
        """
        self._model = model
        self._past_key_values = None
        self._prob_history = None

        self._temperature = temperature
        self._top_k = top_k
        self._top_p = top_p

        # Upper bound of the logits slice taken in the forward pass. Defaults
        # to None ("no truncation") so the forward pass works even when the
        # caller never assigns it.
        self.vocab_size = None

    def _normalize_positions(self, logits : torch.Tensor) -> torch.Tensor:
        """Apply ``norm_logits`` in place to every position of ``logits``.

        ``norm_logits`` is called once per position on a ``(batch, vocab)``
        slice, matching how the helper is used elsewhere in this class.
        """
        for pos in range(logits.shape[-2]):
            logits[:, pos, :] = norm_logits(
                logits[:, pos, :], self._temperature, self._top_k, self._top_p
            )
        return logits

    def _forward_with_kvcache(self, input_ids : torch.Tensor) -> torch.Tensor:
        """Run the model on the not-yet-cached suffix of ``input_ids``.

        Args:
            input_ids: the full token sequence so far, indexed
                ``(batch, position)``.

        Returns:
            The normalized output for the last position, indexed
            ``(batch, vocab)``.
        """
        if self._past_key_values is None:
            # First call: process the whole prefix and start the history.
            # use_cache=True is passed explicitly, as in the incremental
            # branch, so a cache is requested regardless of model defaults.
            outputs = self._model(input_ids, use_cache=True)
            self._prob_history = self._normalize_positions(
                outputs.logits[:, :, :self.vocab_size]
            )
        else:
            # Later calls: only feed the tokens the cache has not seen yet.
            cached_len = self._past_key_values.get_seq_length()
            new_input_ids = input_ids[:, cached_len:]

            outputs = self._model(
                new_input_ids,
                past_key_values=self._past_key_values,
                use_cache=True,
            )
            new_probs = self._normalize_positions(
                outputs.logits[:, :, :self.vocab_size]
            )
            self._prob_history = torch.cat([self._prob_history, new_probs], dim=1)

        self._past_key_values = outputs.past_key_values
        return self._prob_history[:, -1, :]

    def _generate_with_kvcache(self, prefix : torch.Tensor,
                                    gamma : int) -> torch.Tensor:
        """Forward the model ``gamma`` times, appending one sampled token each time.

        Args:
            prefix (torch.Tensor): the prefix
            gamma (int): how many tokens to generate

        Returns:
            torch.Tensor: prefix + generated tokens
        """
        x = prefix

        for _ in range(gamma):
            q = self._forward_with_kvcache(x)
            next_tok = sample(q)
            x = torch.cat((x, next_tok), dim=1)
        return x

    @torch.no_grad()
    def generate(self, input : torch.Tensor, gamma : int) -> torch.Tensor:
        """Generate ``gamma`` tokens after ``input`` with gradients disabled."""
        return self._generate_with_kvcache(input, gamma)

    @torch.no_grad()
    def rollback(self, end_pos : int):
        """Truncate the cache and the probability history to ``end_pos`` positions.

        Raises:
            RuntimeError: if no forward pass has populated the cache yet.
        """
        if self._past_key_values is None or self._prob_history is None:
            raise RuntimeError("Cache is not initialized. Run a forward pass first.")

        self._past_key_values.crop(end_pos)
        self._prob_history = self._prob_history[:, :end_pos, :]

    @torch.no_grad()
    def set_cache_to_jth_sample(self, j: int):
        """Overwrite every batch row of the cached state with row ``j``.

        Both the probability history and each layer's key/value tensors are
        replaced by ``current_batch_size`` copies of sample ``j``.

        Args:
            j: index of the batch row to broadcast.

        Raises:
            RuntimeError: if no forward pass has populated the cache yet.
            ValueError: if ``j`` is outside ``[0, batch_size)``.
        """
        if self._past_key_values is None:
            raise RuntimeError("Cache is not initialized. Run a forward pass first.")

        current_batch_size = self._prob_history.shape[0]
        if not (0 <= j < current_batch_size):
            raise ValueError(f"Index j ({j}) out of range [0, {current_batch_size})")

        # Remember the concrete cache class so the rebuilt cache has the same type.
        orig_cache_type = type(self._past_key_values)

        # Take row j of every layer's key/value tensors and tile it back to
        # the original batch size. The repeat pattern assumes 4-D cache
        # tensors (presumably (batch, heads, seq, head_dim) - TODO confirm).
        new_kv_cache = [
            (
                k_cache[j:j+1].repeat(current_batch_size, 1, 1, 1),
                v_cache[j:j+1].repeat(current_batch_size, 1, 1, 1),
            )
            for k_cache, v_cache in self._past_key_values
        ]

        # repeat() rather than expand(): an expanded tensor aliases one row of
        # memory across the whole batch, so later in-place writes would be
        # unsafe. repeat() gives every row its own storage, like the KV copies.
        new_prob_history = self._prob_history[j:j+1].repeat(current_batch_size, 1, 1)

        # NOTE(review): this relies on the cache class accepting an iterable of
        # per-layer (key, value) pairs in its constructor - confirm for the
        # installed cache implementation.
        new_cache = orig_cache_type(new_kv_cache)

        # Assign only after everything was built so a failure above cannot
        # leave the history and the cache out of sync.
        self._prob_history = new_prob_history
        self._past_key_values = new_cache