from typing import Dict, List

import torch

from src.measure_maps.base import MeasureMap


class GreedyAnchorMap(MeasureMap):
    """Maps embeddings to measures using greedily selected anchor points.

    The algorithm works by:
    1. Selecting an initial anchor (the origin) that is furthest from the mean
       of all embeddings
    2. Iteratively selecting ``measure_dim`` additional anchors, where each new
       anchor maximizes its minimum distance to all previously selected anchors
       (farthest-point selection)
    3. Using the first anchor as origin and creating one normalized direction
       vector from the origin to each remaining anchor

    This is intended to ensure that:
    1. Anchors are well-spread in the embedding space
    2. Direction vectors capture major axes of variation
    3. Measures are interpretable as relative distances to anchor points
    """

    @torch.no_grad()
    def adapt(self, embeddings: torch.Tensor) -> Dict[str, float]:
        """Adapt mapping using greedy anchor selection.

        Overwrites ``self.origin`` with the first anchor and ``self.directions``
        with a ``(measure_dim, embedding_dim)`` tensor of direction vectors.

        Args:
            embeddings: Tensor of shape (n_policies, embedding_dim)

        Returns:
            The result of ``self.compute_change_metrics`` evaluated against the
            origin and directions that were in place before this call.
        """
        prev_origin, prev_directions = self.origin.clone(), self.directions.clone()
        embeddings = embeddings.to(self.device)

        # Select first anchor as furthest point from the mean embedding
        dists_to_mean = torch.norm(
            embeddings - embeddings.mean(0, keepdim=True), dim=1
        )
        origin = embeddings[torch.argmax(dists_to_mean)]

        # min_dists[i] is the distance from embedding i to its nearest anchor.
        # Keeping it as a running minimum avoids recomputing the distances to
        # every earlier anchor on each step while yielding the same values.
        min_dists = torch.norm(embeddings - origin.unsqueeze(0), dim=1)

        directions = []
        for _ in range(self.measure_dim):
            # Select candidate with maximum minimum distance
            anchor = embeddings[torch.argmax(min_dists)]

            # Normalized direction from origin to the new anchor. If the anchor
            # coincides with the origin (e.g. too few distinct embeddings) the
            # norm is zero; use a zero direction instead of dividing into NaN.
            v = anchor - origin
            norm = torch.norm(v)
            directions.append(torch.where(norm > 0, v / norm, torch.zeros_like(v)))

            min_dists = torch.minimum(
                min_dists, torch.norm(embeddings - anchor.unsqueeze(0), dim=1)
            )

        # Commit the new coordinate system only once selection has succeeded
        self.origin = origin
        self.directions = torch.stack(directions)

        return self.compute_change_metrics(prev_directions, prev_origin)


class SubspaceAnchorMap(MeasureMap):
    """Maps embeddings to measures using anchors maximizing distance to subspace.

    The algorithm works by:
    1. Selecting an initial anchor (the origin) furthest from the mean
    2. Iteratively selecting ``measure_dim`` additional anchors, where each new
       anchor maximizes its distance to the subspace spanned by the previous
       direction vectors
    3. Using the first anchor as origin and creating one normalized direction
       vector from the origin to each remaining anchor

    This is intended to ensure that:
    1. Each new direction is as orthogonal as possible to previous directions
    2. The measure space captures independent dimensions of variation
    """

    @torch.no_grad()
    def _distances_to_subspace(
        self, points: torch.Tensor, basis: List[torch.Tensor]
    ) -> torch.Tensor:
        """Compute the distance of every row of ``points`` to span(basis).

        Args:
            points: Tensor of shape (m, n), one point per row.
            basis: List of vectors of shape (n,) spanning the subspace. The
                vectors need not be linearly independent.

        Returns:
            Tensor of shape (m,) with the distance of each point to the subspace.
        """
        if not basis:
            # The span of no vectors is {0}, so the distance is the norm itself
            return torch.norm(points, dim=1)

        # Each column of V is a basis vector. Shape: (n, k)
        V = torch.stack(basis, dim=1)

        # V^+ x gives the least-squares coordinates of x in the basis, so
        # V V^+ x is the orthogonal projection of x onto span(V). Applying it
        # to all points at once needs a single pseudoinverse and never builds
        # the (n, n) projection matrix.
        coeffs = torch.linalg.pinv(V) @ points.T  # Shape: (k, m)
        residuals = points.T - V @ coeffs  # Shape: (n, m)

        return torch.norm(residuals, dim=0)

    @torch.no_grad()
    def distance_to_subspace(
        self, point: torch.Tensor, basis: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the distance from a point to the linear subspace spanned by a set of basis vectors.

        The algorithm works by:
        1. Converting the basis vectors into a matrix V where each column is a basis vector
        2. Projecting the point onto the subspace as V V^+ x, where V^+ is the pseudoinverse
        3. Taking the residual x - V V^+ x, i.e. the component in the orthogonal complement
        4. Returning the norm of this residual

        The pseudoinverse is used instead of the regular inverse to handle cases where
        the basis vectors may be linearly dependent.

        Args:
            point (torch.Tensor): The point to compute distance from. Shape: (n,)
            basis (List[torch.Tensor]): List of basis vectors spanning the subspace.
                Each vector should have shape (n,). The vectors need not be linearly independent.

        Returns:
            torch.Tensor: The shortest distance from the point to any point in the subspace.
                Returns a scalar tensor. With an empty basis this is the norm of the point.

        Note:
            This implementation assumes the subspace passes through the origin. This is
            ensured by subtracting the selected origin from the vectors in adapt().

        Mathematical background:
            Given a subspace W = span{v₁, ..., vₖ}, the distance from a point p to W is:
            dist(p, W) = ‖p - proj_W(p)‖ = ‖(I - P)p‖
            where P = V V^+ = V(V^TV)^+V^T is the orthogonal projection matrix onto W.
        """
        return self._distances_to_subspace(point.unsqueeze(0), basis)[0]

    @torch.no_grad()
    def adapt(self, embeddings: torch.Tensor) -> Dict[str, float]:
        """Adapt mapping using subspace distance maximization.

        Overwrites ``self.origin`` with the first anchor and ``self.directions``
        with a ``(measure_dim, embedding_dim)`` tensor of direction vectors.

        Args:
            embeddings: Tensor of shape (n_policies, embedding_dim)

        Returns:
            The result of ``self.compute_change_metrics`` evaluated against the
            origin and directions that were in place before this call.
        """
        prev_origin, prev_directions = self.origin.clone(), self.directions.clone()
        embeddings = embeddings.to(self.device)

        # Select first anchor furthest from mean
        dists_to_mean = torch.norm(
            embeddings - embeddings.mean(0, keepdim=True), dim=1
        )
        origin = embeddings[torch.argmax(dists_to_mean)]

        # Shift once so that the subspace of directions passes through zero
        centered = embeddings - origin

        # Normalized directions selected so far; they span the current subspace
        basis = []

        for _ in range(self.measure_dim):
            # Distances of all embeddings to the current subspace, computed in
            # one batch (one pseudoinverse per iteration) on the same device.
            dists = self._distances_to_subspace(centered, basis)

            # Select point with maximum distance
            v = centered[torch.argmax(dists)]

            # Add normalized direction to basis. If the chosen point coincides
            # with the origin (e.g. too few distinct embeddings) the norm is
            # zero; use a zero direction instead of dividing into NaN.
            norm = torch.norm(v)
            basis.append(torch.where(norm > 0, v / norm, torch.zeros_like(v)))

        # Commit the new coordinate system only once selection has succeeded
        self.origin = origin
        self.directions = torch.stack(basis)

        return self.compute_change_metrics(prev_directions, prev_origin)
