from typing import Any, List, Optional

import torch
import torch.nn as nn

class TokenwiseTemperatureScaling(nn.Module):
    """Apply a learnable per-token temperature to a frozen model's logits.

    ``forward`` runs ``base_model`` under ``torch.no_grad()`` and divides the
    last dimension of ``out.logits`` element-wise by a temperature vector of
    length ``vocab_size``. Only the temperatures take part in autograd, so the
    wrapper can be fitted (e.g. for post-hoc calibration) while the base model
    stays fixed.

    Args:
        base_model: Wrapped model. Its output must expose ``.logits``, an
            ``.items()`` view of its fields, and a constructor accepting
            ``(logits=..., **other_fields)``.
        vocab_size: Length of the temperature vector. If ``None``, it is
            inferred from ``base_model`` (config, output embeddings, a
            ``vocab_size`` attribute, then input embeddings, in that order).
        temp: Initial value of every learnable temperature.
        top_p_token_ids: If given, only these token ids get a learnable
            temperature; every other id uses a fixed temperature of 1.0
            (left unscaled). Ids must lie in ``[0, vocab_size)`` and should be
            unique -- with duplicate indices, which value is written is
            undefined.

    Raises:
        ValueError: If ``vocab_size`` cannot be inferred, or if any id in
            ``top_p_token_ids`` is out of range.
    """

    def __init__(
        self,
        base_model: nn.Module,
        vocab_size: Optional[int] = None,
        temp: float = 1.0,
        top_p_token_ids: Optional[list[int]] = None,
    ):
        super().__init__()
        self.base_model = base_model

        if vocab_size is None:
            vocab_size = self._infer_vocab_size(base_model)
            if vocab_size is None:
                raise ValueError(
                    "Cannot automatically obtain vocab_size from base_model. "
                    "Please ensure base_model has 'config.vocab_size', 'get_output_embeddings().out_features', "
                    "'vocab_size' attribute, or 'get_input_embeddings().num_embeddings' / 'get_input_embeddings().weight.shape[0]', "
                    "or explicitly provide the vocab_size parameter."
                )
        self.vocab_size = vocab_size

        # float() guards against an integer `temp` (e.g. 1): torch.full would
        # otherwise build an int64 tensor, which cannot require gradients.
        init_temp = float(temp)

        if top_p_token_ids is not None:
            ids = torch.as_tensor(top_p_token_ids, dtype=torch.long).reshape(-1).clone()
            if ids.numel() > 0 and (int(ids.min()) < 0 or int(ids.max()) >= vocab_size):
                raise ValueError(
                    f"top_p_token_ids must lie in [0, {vocab_size}); "
                    f"got min={int(ids.min())}, max={int(ids.max())}."
                )
            # Non-persistent buffer: follows .to()/.cuda() with the module but
            # stays out of state_dict, so checkpoint keys are unchanged.
            self.register_buffer("top_p_token_ids", ids, persistent=False)
            # Temperature 1.0 for every token without a learnable entry.
            self.register_buffer("fixed_temp", torch.ones(vocab_size))
            self._temp = nn.Parameter(torch.full((ids.numel(),), init_temp))
        else:
            self.register_buffer("top_p_token_ids", None, persistent=False)
            self._temp = nn.Parameter(torch.full((vocab_size,), init_temp))

    @staticmethod
    def _infer_vocab_size(base_model: nn.Module) -> Optional[int]:
        """Return the first vocab size found on ``base_model``, or ``None``.

        Each source is tried in turn, so a source that exists but yields
        nothing (e.g. ``get_output_embeddings()`` returning ``None``) falls
        through to the next one instead of ending the search.
        """
        config = getattr(base_model, "config", None)
        size = getattr(config, "vocab_size", None)
        if size is not None:
            return size

        if hasattr(base_model, "get_output_embeddings"):
            size = getattr(base_model.get_output_embeddings(), "out_features", None)
            if size is not None:
                return size

        size = getattr(base_model, "vocab_size", None)
        if size is not None:
            return size

        if hasattr(base_model, "get_input_embeddings"):
            input_embeddings = base_model.get_input_embeddings()
            if input_embeddings is not None:
                size = getattr(input_embeddings, "num_embeddings", None)
                if size is not None:
                    return size
                weight = getattr(input_embeddings, "weight", None)
                if weight is not None:
                    return weight.shape[0]
        return None

    @property
    def temp(self) -> torch.Tensor:
        """Full temperature vector of length ``vocab_size``.

        In top-p mode the learnable values are scattered into a copy of the
        fixed all-ones vector; the in-place write keeps the result connected
        to ``_temp`` in the autograd graph.
        """
        if self.top_p_token_ids is None:
            return self._temp
        temp_full = self.fixed_temp.clone()
        temp_full[self.top_p_token_ids] = self._temp
        return temp_full

    def get_trainable_temp(self):
        """Return the learnable temperature parameters as a list.

        Use this rather than ``self.parameters()``, which also yields the
        parameters of the registered ``base_model`` submodule.
        """
        return [self._temp]

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``base_model`` and return its output with temperature-scaled logits.

        All positional and keyword arguments are forwarded to ``base_model``.
        The result is a new object of the same type as the base model's
        output, with ``logits`` replaced and every other field carried over.

        Raises:
            ValueError: If any temperature is not strictly positive (NaN
                included), or if the logits' last dimension differs from the
                temperature vector length.
            AttributeError: If the base model's output has no ``logits``.
            RuntimeError: If the output object cannot be rebuilt.
        """
        temp_vec = self.temp

        # `~(t > 0)` rather than `t <= 0` so NaN temperatures are rejected too.
        detached = temp_vec.detach()
        invalid = ~(detached > 0)
        if invalid.any().item():
            raise ValueError(
                "All elements in the temperature vector must be positive. "
                f"Found non-positive values: {detached[invalid]} at indices: {torch.where(invalid)[0]}."
            )

        # The base model is frozen: no autograd graph is built through it.
        with torch.no_grad():
            out = self.base_model(*args, **kwargs)

        if not hasattr(out, "logits"):
            raise AttributeError("The output of base_model does not have a 'logits' attribute.")
        logits = out.logits

        if logits.shape[-1] != temp_vec.shape[0]:
            raise ValueError(
                f"The last dimension of logits ({logits.shape[-1]}) must match "
                f"the length of temperature vector ({temp_vec.shape[0]}) "
                f"(i.e., vocab_size: {self.vocab_size})."
            )

        # Follow the logits' device rather than the base model's first
        # parameter: the LM head may sit on another device under multi-device
        # placement, and a parameter-less model has no parameter to query.
        scaled_logits = logits / temp_vec.to(logits.device)

        try:
            other_attrs = {k: v for k, v in out.items() if k != "logits"}
            return type(out)(logits=scaled_logits, **other_attrs)
        except (AttributeError, TypeError) as err:
            raise RuntimeError(
                f"Cannot reconstruct output object of type {type(out)}. Please ensure it supports .items() "
                f"and its constructor accepts the form (logits=..., **other_attributes)."
            ) from err