import torch
import torch.nn as nn

class TemperatureScaling(nn.Module):
    """Wrap a frozen model and divide its output logits by a learnable temperature.

    ``temp`` is the only trainable parameter. Every parameter of ``base_model``
    has ``requires_grad`` switched off, and the base model is kept in eval mode,
    even when ``train()`` is called on this wrapper. Fitting the temperature
    therefore cannot alter the wrapped model.

    The base model's output must expose its logits either as a ``logits``
    attribute (e.g. a Hugging Face-style ``ModelOutput``) or as a ``"logits"``
    key of a ``dict``. It must also be rebuildable as
    ``type(out)(logits=..., **other_items)`` from ``out.items()``.
    """

    def __init__(self, base_model, temp: float = 1.0):  # default temperature = 1
        """
        Args:
            base_model: Module whose outputs are rescaled. It is frozen in place:
                it is switched to eval mode and all of its parameters get
                ``requires_grad = False``.
            temp: Initial temperature. 1.0 leaves the logits unchanged.
        """
        super().__init__()
        self.base_model = base_model
        # The only trainable parameter of the wrapper (a 0-d tensor).
        self.temp = nn.Parameter(torch.tensor(float(temp)))

        self.base_model.eval()
        for param in self.base_model.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        """Set the wrapper's training mode while keeping ``base_model`` in eval mode.

        ``nn.Module.train`` recurses into child modules. Without this override,
        ``wrapper.train()`` would re-enable dropout in the frozen model. It would
        also let any BatchNorm layers update their running statistics, which
        happens even under ``torch.no_grad()``.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.base_model.eval()
        return self

    @staticmethod
    def _extract_logits(out):
        """Return the logits of ``out`` from its ``logits`` attribute or ``"logits"`` key."""
        if hasattr(out, "logits"):
            return out.logits
        if isinstance(out, dict) and "logits" in out:
            return out["logits"]
        raise AttributeError("There is no 'logits' in base_model.")

    def forward(self, *args, **kwargs):
        """Run ``base_model`` without gradients and return its output with logits / temp.

        All arguments are forwarded unchanged to ``base_model``. The returned
        object has the same type as the base model's output. Its ``logits`` are
        replaced by the scaled logits and every other item is carried over.

        Raises:
            ValueError: If the temperature is not strictly positive (NaN included).
            AttributeError: If the base model's output carries no logits.
            RuntimeError: If the output object cannot be rebuilt with the scaled
                logits. The underlying error is chained as ``__cause__``.
        """
        temp_value = self.temp.item()
        # `not x > 0` rather than `x <= 0`: it also rejects NaN, which an optimizer
        # can produce and which would otherwise silently turn every logit into NaN.
        if not temp_value > 0:
            raise ValueError(
                f"Temperature should be positive. temp={temp_value}。"
                f"Please check the model or the temperature scaling module."
            )

        # The base model is frozen, so skip building its autograd graph.
        # Gradients still reach `temp` through the division below.
        with torch.no_grad():
            out = self.base_model(*args, **kwargs)

        logits = self._extract_logits(out) / self.temp  # temperature scaling

        try:
            other_attrs = {k: v for k, v in out.items() if k != "logits"}
            return out.__class__(logits=logits, **other_attrs)
        except (AttributeError, TypeError) as err:
            # AttributeError: no `.items()`; TypeError: the constructor rejects
            # the (logits=..., **other_attributes) keyword form.
            raise RuntimeError(
                f"Cannot reconstruct output object of type {type(out)}. Please ensure it supports .items() "
                f"and its constructor accepts the form (logits=..., **other_attributes)."
            ) from err