from abc import ABC, abstractmethod
from typing import Dict

import torch


class MeasureMap(ABC):
    """Base class for mapping high-dimensional embeddings to low-dimensional measures.

    A measure map defines an affine transformation from a high-dimensional
    embedding space to a lower-dimensional measure space. The transformation is
    defined by:
    1. An origin point in the embedding space
    2. A set of direction vectors that span the measure space

    For an embedding e, the measure is computed as:
        m_i = <e - origin, direction_i>
    where <,> denotes inner product.

    The measure map can be adapted using a set of example embeddings to better
    capture the structure of the policy space; subclasses implement this in
    ``adapt``.

    Attributes:
        measure_dim: Dimension of the measure space (k).
        embedding_dim: Dimension of the embedding space (n).
        device: Device on which ``origin`` and ``directions`` are stored.
        origin: Tensor of shape (embedding_dim,).
        directions: Tensor of shape (measure_dim, embedding_dim); each row is
            one direction vector.
    """

    def __init__(self, measure_dim: int, embedding_dim: int, device):
        """Initialize the measure map with a random orthonormal mapping.

        Args:
            measure_dim: Dimension of the measure space (k)
            embedding_dim: Dimension of the embedding space (n)
            device: cuda or cpu

        Raises:
            ValueError: If ``measure_dim`` is not in ``[1, embedding_dim]``.
        """
        # A reduced QR of an (n, k) matrix yields only min(n, k) orthonormal
        # columns, so k > n would silently produce fewer directions than
        # requested. Reject such configurations up front.
        if measure_dim < 1 or measure_dim > embedding_dim:
            raise ValueError(
                "measure_dim must satisfy 1 <= measure_dim <= embedding_dim; "
                f"got measure_dim={measure_dim}, embedding_dim={embedding_dim}."
            )

        self.measure_dim = measure_dim
        self.embedding_dim = embedding_dim
        self.device = device
        self._initialize_random_mapping()

    @torch.no_grad()
    def _initialize_random_mapping(self):
        """Initialize a random orthogonal mapping.

        The random mapping:
        1. Uses origin at zero
        2. Creates k random orthogonal directions by:
           a. Sampling from standard normal distribution
           b. Using QR decomposition to orthogonalize
        This ensures:
        1. All embedding dimensions are considered equally
        2. Direction vectors are orthonormal
        3. Initial mapping is random but well-conditioned
        """
        self.origin = torch.zeros(self.embedding_dim, device=self.device)

        # The matrix is sampled first and moved afterwards (rather than sampled
        # directly on ``self.device``) so the sampling call itself is unchanged
        # regardless of the target device.
        random_matrix = torch.randn(self.embedding_dim, self.measure_dim).to(
            self.device
        )
        Q, _ = torch.linalg.qr(random_matrix)

        # Columns of Q are orthonormal by construction; store them as rows so
        # that ``directions`` has shape (measure_dim, embedding_dim).
        self.directions = Q[:, : self.measure_dim].T

    @abstractmethod
    def adapt(self, embeddings: torch.Tensor) -> Dict[str, float]:
        """Adapt the mapping using current policy embeddings.

        Args:
            embeddings: Tensor of shape (n_policies, embedding_dim)
                      Each row is the embedding of a policy

        Returns:
            metrics: dictionary containing useful info about the change in the affine map
        """
        pass

    @torch.no_grad()
    def __call__(self, embedding: torch.Tensor) -> torch.Tensor:
        """Map one embedding, or a batch of embeddings, to measure space.

        The measure is computed by:
        1. Centering the embedding with respect to the origin
        2. Computing inner products with each direction vector

        Args:
            embedding: Tensor of shape (embedding_dim,) or
                (..., embedding_dim) for a batch of embeddings.

        Returns:
            measures: Tensor of shape (measure_dim,) for a single embedding, or
                (..., measure_dim) for a batch.

        Raises:
            ValueError: If ``origin`` or ``directions`` is ``None``. The
                constructor always sets both, so this only triggers if they
                were cleared afterwards.
        """
        if self.origin is None or self.directions is None:
            raise ValueError("Measure map not initialized.")

        centered = embedding - self.origin
        # Contract over the last (embedding) axis so that any leading batch
        # dimensions are carried through unchanged.
        return centered @ self.directions.T

    @staticmethod
    def _orthonormal_basis(directions: torch.Tensor) -> torch.Tensor:
        """Return orthonormal columns obtained by a reduced QR of ``directions.T``.

        For linearly independent direction rows the columns span the same
        subspace as the rows of ``directions``. For rank-deficient input some
        columns are not determined by the input.
        """
        basis, _ = torch.linalg.qr(directions.T)
        return basis

    @torch.no_grad()
    def compute_change_metrics(
        self,
        prev_directions: torch.Tensor,
        prev_origin: torch.Tensor,
    ) -> Dict[str, float]:
        """Compute metrics describing how the map has changed.

        Args:
            prev_directions: Previous direction vectors, one per row
            prev_origin: Previous origin vector

        Returns:
            Dictionary containing metrics:
            - mmap_origin_displacement: Euclidean distance the origin has moved

            - mmap_rotation_angle: Average principal angle between old and new
              subspaces (radians). Both direction sets are orthonormalized
              first, so the value depends only on the spanned subspaces and
              not on the scaling or correlation of the direction vectors.
                Near 0: stable subspace
                Near π/2: major rotation/change in subspace

            - mmap_direction_orthogonality: Mean absolute entry of
              (Gram matrix - identity), averaged over all entries including
              the diagonal, i.e. the deviation from orthonormality.
                0: perfectly orthonormal directions
                >0: directions becoming correlated or losing unit norm

            - mmap_singular_value_ratio: Ratio of largest to smallest singular
              value of the current directions
                Near 1: well-balanced directions
                Large values: some directions dominate others

            - mmap_effective_rank: Number of singular values above 1% of the
              largest singular value (an integer count)
                Should be close to measure_dim
                Lower values indicate some directions aren't being used
        """
        metrics = {}

        # Origin displacement (in embedding space)
        metrics["mmap_origin_displacement"] = torch.norm(
            self.origin - prev_origin
        ).item()

        # Subspace rotation via principal angles. The singular values of
        # A^T B equal the cosines of the principal angles only when A and B
        # have orthonormal columns, so orthonormalize both bases first.
        prev_basis = self._orthonormal_basis(prev_directions)
        curr_basis = self._orthonormal_basis(self.directions)
        cosines = torch.linalg.svd(prev_basis.T @ curr_basis, full_matrices=False)[1]
        # Clamp guards against round-off pushing a cosine slightly above 1.
        principal_angles = torch.arccos(torch.clamp(cosines, -1, 1))
        metrics["mmap_rotation_angle"] = principal_angles.mean().item()

        # Deviation of the Gram matrix from the identity
        gram = self.directions @ self.directions.T
        deviation = torch.abs(gram - torch.eye(gram.shape[0], device=gram.device))
        metrics["mmap_direction_orthogonality"] = deviation.mean().item()

        # Singular value analysis. Only the singular values are needed, so the
        # reduced SVD avoids building an (embedding_dim, embedding_dim) factor.
        singular_values = torch.linalg.svd(self.directions, full_matrices=False)[1]
        metrics["mmap_singular_value_ratio"] = (
            singular_values[0] / singular_values[-1]
        ).item()

        # Effective rank (threshold-based)
        threshold = 1e-2 * singular_values[0]  # Relative to largest singular value
        metrics["mmap_effective_rank"] = torch.sum(singular_values > threshold).item()

        return metrics
