from typing import Dict
import torch
from src.measure_maps.base import MeasureMap


class CalibratedWeightedPCAMap(MeasureMap):
    """Measure map from score-weighted PCA followed by per-axis quantile calibration.

    After ``adapt`` the map is the affine transform ``T(x) = A @ x + b``.

    - The rows of ``A`` are the top weighted principal directions of the adapted
      embeddings.
    - Each row is rescaled so that the ``CALIBRATION_QUANTILES`` of the projected
      adaptation data land on -1 and +1 along that measure axis.
    """

    # Unweighted quantile probabilities of the projected adaptation data that are
    # mapped to -1 and +1 respectively. Subclasses may override this to widen or
    # narrow the calibrated range.
    CALIBRATION_QUANTILES = (0.05, 0.95)

    def __init__(self, measure_dim: int, embedding_dim: int, device):
        super().__init__(measure_dim, embedding_dim, device)
        # Until the first adapt() call the map is a plain linear projection onto
        # the base-class directions with zero bias.
        # NOTE(review): this ignores self.origin; presumably the base class
        # initializes origin to zero -- confirm.
        self.A = self.directions.clone().float()
        self.b = torch.zeros(self.measure_dim, device=self.device, dtype=torch.float32)

    @torch.no_grad()
    def adapt(self, embeddings: torch.Tensor, scores: torch.Tensor) -> Dict[str, float]:
        """Refit origin, directions and calibration from a scored batch.

        Args:
            embeddings: ``(N, embedding_dim)`` embeddings to fit on. They are
                cast to float32 so every matmul matches the float32
                ``directions`` / ``A``.
            scores: ``N`` scores, one per embedding. Any shape with ``N``
                elements is accepted. Higher scores get more weight in the
                fitted origin and PCA.

        Returns:
            The base class's change metrics plus ``"calibration_quantile_range"``.

        Raises:
            ValueError: if ``embeddings`` is not 2-D or is empty, or if the
                number of scores differs from the number of embeddings.
        """
        prev_origin = self.origin.clone()
        prev_directions = self.directions.clone()

        # Cast up front, as __call__ does. Otherwise e.g. float64 inputs reach a
        # float32 @ float64 matmul below and raise.
        embeddings = embeddings.to(self.device).float()
        scores = scores.to(self.device).float().reshape(-1)

        if embeddings.dim() != 2:
            raise ValueError(
                f"embeddings must be 2-D (N, embedding_dim), got shape {tuple(embeddings.shape)}"
            )
        n = embeddings.shape[0]
        if n == 0:
            raise ValueError("adapt() requires at least one embedding")
        if scores.shape[0] != n:
            # Without this check a length-1 score vector would silently broadcast.
            raise ValueError(f"expected {n} scores (one per embedding), got {scores.shape[0]}")

        # Turn raw scores into positive weights that sum to 1:
        # - shift so the minimum is a small positive value;
        # - rescale to (0, 1];
        # - floor every weight at 1/N so low-scoring samples still contribute;
        # - normalize.
        weights = scores - scores.min() + 1e-3
        weights = weights / weights.max()
        weights = weights.clamp(min=1.0 / n)
        weights = weights / weights.sum()

        # Weighted mean of the embeddings becomes the new origin.
        self.origin = (weights.unsqueeze(1) * embeddings).sum(dim=0)
        centered = embeddings - self.origin

        # Weighted PCA. Scaling each centered row by sqrt(weight) makes W^T W the
        # weighted covariance, so the right singular vectors of W are the weighted
        # principal directions. Singular values come back in descending order.
        weighted_centered = centered * weights.sqrt().unsqueeze(1)
        # torch.svd is deprecated; this is its documented equivalent for some=True.
        _, _, vh = torch.linalg.svd(weighted_centered, full_matrices=False)
        new_directions = vh[: self.measure_dim]
        # Only min(N, embedding_dim) components exist. If that is fewer than
        # measure_dim, pad with the trailing rows of the previous directions.
        missing = self.measure_dim - new_directions.shape[0]
        if missing > 0:
            new_directions = torch.cat(
                [new_directions, prev_directions[-missing:].to(new_directions)], dim=0
            )
        self.directions = new_directions.float()

        # Uncalibrated projections of the adaptation data: shape (N, measure_dim).
        projections = centered @ self.directions.T

        # Calibrate each measure axis with plain (unweighted) quantiles.
        # [q_low, q_high] is mapped onto [-1, 1].
        q_low_prob, q_high_prob = self.CALIBRATION_QUANTILES
        q_low = torch.quantile(projections, q_low_prob, dim=0)  # (measure_dim,)
        q_high = torch.quantile(projections, q_high_prob, dim=0)
        # The clamp guards against a zero-width range, e.g. a single sample.
        scale = 2.0 / (q_high - q_low).clamp(min=1e-6)
        # Chosen so that scale * q_low + calib_c == -1.
        calib_c = -1.0 - scale * q_low

        # Compose the final affine map:
        #   T(x) = diag(scale) @ directions @ (x - origin) + calib_c
        #        = A @ x + b
        # where A = diag(scale) @ directions and b = calib_c - A @ origin.
        self.A = (scale.unsqueeze(1) * self.directions).float()  # (measure_dim, embedding_dim)
        self.b = (calib_c - self.A @ self.origin).float()  # (measure_dim,)

        metrics = self.compute_change_metrics(prev_directions, prev_origin)
        metrics["calibration_quantile_range"] = q_high_prob - q_low_prob
        return metrics

    @torch.no_grad()
    def __call__(self, embedding: torch.Tensor) -> torch.Tensor:
        """Map embeddings to calibrated measures via ``A @ x + b``.

        Args:
            embedding: a single ``(embedding_dim,)`` vector, or a batch whose
                last dimension is ``embedding_dim`` (e.g. ``(N, embedding_dim)``).

        Returns:
            ``(measure_dim,)`` for a single vector. Otherwise the input's leading
            dimensions followed by ``measure_dim``.

        Raises:
            ValueError: if ``embedding`` is 0-dimensional.
        """
        embedding = embedding.to(self.device).float()
        if embedding.dim() == 0:
            raise ValueError("embedding must have at least one dimension")
        if embedding.dim() == 1:
            return torch.matmul(self.A, embedding) + self.b
        # matmul broadcasts over any leading batch dimensions.
        return torch.matmul(embedding, self.A.T) + self.b
