import torch
import torch.nn as nn
import torch.nn.functional as F


class Recalibrator(nn.Module):
    """Recalibrate predicted class distributions with an input-dependent correction.

    Each input ``x`` is softly assigned to ``K`` correction directions via
    ``b(x) = softmax(x @ w.T)``. Over the batch, the residual
    ``y_true - p_prev`` is regressed onto ``b(x)`` by (ridge-regularized)
    least squares. The fitted residual is then added to the
    log-probabilities, and the result is renormalized with a softmax.
    """

    def __init__(self, K: int, input_dim: int):
        """
        Args:
            K (int): Number of correction directions (actions).
            input_dim (int): Dimensionality of the input feature vector x.
        """
        super().__init__()
        self.K = K
        self.input_dim = input_dim

        # Learnable weight vectors for softmax-based direction selection
        self.w = nn.Parameter(torch.randn(K, input_dim))  # [K, D]

    def forward(self, x: torch.Tensor, p_prev: torch.Tensor, y_true: torch.Tensor):
        """
        Perform recalibration on model predictions.

        Args:
            x: [B, D] input features (before classifier)
            p_prev: [B, C] model's predicted distribution (softmax)
            y_true: [B, C] one-hot ground truth

        Returns:
            p_calibrated: [B, C] calibrated prediction
            info_dict: dict with keys "R" ([K, C]), "D" ([K, K]) and
                "b" ([B, K] soft attention weights)
        """
        B, C = p_prev.shape
        K = self.K

        # 2. Compute soft attention b(x, a) = softmax(<x, w_a>)
        logits_b = x @ self.w.T  # [B, K]
        b_xa = torch.softmax(logits_b, dim=-1)  # [B, K]

        # 1. Residual: true - predicted. Cast to the attention dtype so the
        #    matmuls below never mix dtypes (e.g. float32 x with float64 p_prev).
        residual = (y_true - p_prev).to(b_xa.dtype)  # [B, C]

        # 3. Batch moments, vectorized:
        #      R[a, c]  = mean_i b_i[a] * residual_i[c]   (K x C)
        #      D[a, a'] = mean_i b_i[a] * b_i[a']         (K x K)
        #    Built directly from b_xa so they inherit its dtype/device.
        R = b_xa.T @ residual / B
        D = b_xa.T @ b_xa / B

        # 4. Correction = b(x) @ D^{-1} @ R, i.e. the least-squares fit of the
        #    residual onto the attention weights (D^{-1} R minimizes the batch
        #    mean of ||residual - b @ W||^2). A small ridge keeps the inverse
        #    well-behaved when D is near-singular.
        ridge = 1e-6 * torch.eye(K, device=D.device, dtype=D.dtype)
        D_inv = torch.linalg.pinv(D + ridge)  # regularized inverse
        b_proj = b_xa @ D_inv.T  # [B, K]
        correction = b_proj @ R  # [B, C]

        # 5. Calibrate logits: logits + correction --> softmax
        #    (1e-6 avoids log(0) for zero-probability classes)
        logits_adjusted = torch.log(p_prev + 1e-6) + correction  # [B, C]
        p_calibrated = torch.softmax(logits_adjusted, dim=-1)

        return p_calibrated, {"R": R, "D": D, "b": b_xa}
