from typing import Dict

import torch

from src.measure_maps.base import MeasureMap


class PCAMap(MeasureMap):
    """Maps embeddings to measures using PCA projection.

    The algorithm:
    1. Centers the embeddings by subtracting their mean
    2. Computes the reduced SVD of the centered embeddings
    3. Uses the top k right singular vectors as direction vectors

    This ensures:
    1. The measure space captures the directions of maximum variance
    2. The direction vectors taken from the SVD are orthonormal
    3. The measures are uncorrelated over the embeddings the map was adapted on

    Mathematical guarantees:
    1. Minimizes squared reconstruction error among all k-dimensional linear projections
    2. Among k-dimensional orthogonal projections, best preserves pairwise distances
       in the total-squared-distance sense
    3. Captures maximum variance in the data

    Notes:
        - If the SVD yields fewer right singular vectors than ``measure_dim``
          (that is, ``min(n_policies, embedding_dim) < measure_dim``), the
          remaining rows of ``directions`` keep their previous values. Those
          rows are NOT re-orthogonalized against the new PCA directions, so
          guarantee 2 then only holds for the rows filled by the SVD.
        - Singular vectors are only defined up to sign. A direction may flip
          sign between ``adapt`` calls even when the spanned subspace is unchanged.
    """

    @torch.no_grad()
    def adapt(self, embeddings: torch.Tensor) -> Dict[str, float]:
        """Adapt mapping using PCA.

        Args:
            embeddings: Tensor of shape (n_policies, embedding_dim)

        Returns:
            The result of ``self.compute_change_metrics(prev_directions, prev_origin)``,
            computed after ``origin`` and ``directions`` have been updated.

        Raises:
            ValueError: If ``embeddings`` is not a 2-D tensor with at least one row.
                In that case the mapping's state is left unchanged.
        """
        # Reject malformed input up front. Otherwise an empty batch would silently
        # set the origin to NaN, and non-2-D input would fail deep inside the SVD.
        if embeddings.dim() != 2 or embeddings.shape[0] == 0:
            raise ValueError(
                "embeddings must be a non-empty 2-D tensor of shape "
                f"(n_policies, embedding_dim), got shape {tuple(embeddings.shape)}"
            )

        prev_origin, prev_directions = self.origin.clone(), self.directions.clone()
        embeddings = embeddings.to(self.device)

        # Fit into locals first. If the SVD raises, `origin` is not left
        # updated while `directions` is stale.
        origin = embeddings.mean(0)
        centered = embeddings - origin.unsqueeze(0)

        # Reduced SVD. The rows of Vh are the right singular vectors, ordered by
        # descending singular value; shape (min(n_policies, embedding_dim), embedding_dim).
        # torch.linalg.svd replaces the deprecated torch.svd. Vh[:k] == V[:, :k].T.
        _, _, Vh = torch.linalg.svd(centered, full_matrices=False)

        # Handle case where there are fewer singular vectors than measure_dim
        available_dims = min(Vh.shape[0], self.measure_dim)

        # Create new directions tensor with the same shape as before
        new_directions = torch.zeros_like(prev_directions)

        # Use available right singular vectors
        new_directions[:available_dims] = Vh[:available_dims]

        # For any missing dimensions keep previous directions (not re-orthogonalized)
        if available_dims < self.measure_dim:
            new_directions[available_dims:] = prev_directions[available_dims:]

        # Commit both pieces of state together, after all fallible work is done.
        self.origin = origin
        self.directions = new_directions

        return self.compute_change_metrics(prev_directions, prev_origin)
