from typing import Dict

import torch

from src.measure_maps.base import MeasureMap


class WeightedPCAMap(MeasureMap):
    """
    Maps embeddings to measures using weighted PCA projection.

    ``adapt`` recomputes ``self.origin`` (score-weighted mean of the
    embeddings) and ``self.directions`` (leading right-singular vectors of the
    weighted, centered embeddings). ``self.device``, ``self.measure_dim`` and
    ``compute_change_metrics`` are provided by ``MeasureMap``.
    """

    @torch.no_grad()
    def adapt(self, embeddings: torch.Tensor, scores: torch.Tensor) -> Dict[str, float]:
        """Adapt mapping using Weighted PCA with adaptive clipping.

        Scores are shifted so the minimum is a small positive value, scaled to
        (0, 1], clipped from below at ``1 / n_policies`` so every policy keeps
        some influence, and finally normalized to sum to 1. These weights are
        used for the weighted mean (new origin) and the weighted SVD (new
        directions).

        Args:
            embeddings: Tensor of shape (n_policies, embedding_dim)
            scores: Tensor of shape (n_policies,). Not modified.

        Returns:
            The metrics produced by ``compute_change_metrics`` comparing the
            previous directions/origin against the newly fitted ones.
        """
        prev_origin, prev_directions = self.origin.clone(), self.directions.clone()
        embeddings = embeddings.to(self.device)
        scores = scores.to(self.device)
        # Integer scores cannot hold the fractional weights computed below.
        if not scores.is_floating_point():
            scores = scores.float()

        # NOTE: all weight arithmetic is out-of-place. ``Tensor.to`` returns the
        # caller's tensor unchanged when it is already on ``self.device``, so
        # in-place ops here would silently mutate the caller's ``scores``.
        weights = scores - (scores.min() - 1e-3)  # min becomes 1e-3, so max > 0
        weights = weights / weights.max()  # (0, 1]

        # Ensure each point gets at least the minimum contribution
        weights = weights.clamp(min=1.0 / weights.shape[0])

        # Re-normalize the weights after clipping to sum to 1
        weights = weights / weights.sum()

        # Compute weighted mean
        self.origin = (weights.unsqueeze(1) * embeddings).sum(dim=0).float()

        centered_embeddings = embeddings - self.origin

        # Scaling rows by sqrt(w) makes X^T X equal the weighted covariance,
        # so the right-singular vectors are the weighted principal axes.
        weighted_centered = centered_embeddings * weights.sqrt().unsqueeze(1)

        # Reduced SVD; rows of ``vh`` are right-singular vectors ordered by
        # decreasing singular value. ``vh`` has shape (min(n, dim), dim).
        _, _, vh = torch.linalg.svd(weighted_centered, full_matrices=False)

        available_dims = min(vh.shape[0], self.measure_dim)

        # Start from the previous directions so any dimension the SVD cannot
        # supply (fewer policies than ``measure_dim``) keeps its old value.
        # Assumes ``self.directions`` has ``measure_dim`` rows -- TODO confirm.
        new_directions = prev_directions.clone()
        new_directions[:available_dims] = vh[:available_dims]

        self.directions = new_directions.float()

        metrics = self.compute_change_metrics(prev_directions, prev_origin)
        return metrics
