import functools
from typing import Callable, Optional, Tuple

import torch

from latent_invariances.utils.wasserstein import wasserstein_distance_vec


def abs_to_rel(
    anchors: torch.Tensor,
    points: torch.Tensor,
    normalizing_func: Optional[Callable[..., Tuple[torch.Tensor, torch.Tensor]]],
    dist_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Build a relative representation by applying ``dist_func`` to every (point, anchor) pair.

    Args:
        anchors: Anchor set; iterated along its first dimension.
        points: Points to re-express; iterated along its first dimension.
        normalizing_func: Optional preprocessing step, called as
            ``normalizing_func(anchors=..., points=...)``. It must return the
            transformed ``(anchors, points)`` pair.
        dist_func: Called as ``dist_func(point=..., anchor=...)`` for each pair.
            Each result must be convertible by ``torch.as_tensor``
            (e.g. a Python float or a 0-d tensor).

    Returns:
        Tensor whose entry ``[i][j]`` is ``dist_func(points[i], anchors[j])``,
        with the dtype of the (possibly normalized) ``anchors``. When there are
        no points, an empty ``(0, len(anchors))`` tensor is returned, so the
        anchor axis is kept.

    Note:
        ``torch.as_tensor`` on nested Python lists builds a fresh tensor,
        so no autograd graph flows back through ``dist_func``.
    """
    if normalizing_func is not None:
        anchors, points = normalizing_func(anchors=anchors, points=points)

    # Without this guard, ``torch.as_tensor([])`` would collapse the result to
    # shape (0,) and lose the anchor dimension.
    if len(points) == 0:
        return torch.empty((0, len(anchors)), dtype=anchors.dtype)

    relative_points = [
        [dist_func(point=point, anchor=anchor) for anchor in anchors]
        for point in points
    ]
    return torch.as_tensor(relative_points, dtype=anchors.dtype)


def basis_change_lstsq(anchors: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    """Express each point as a least-squares combination of the anchors.

    Both sets are first centered on the anchor centroid and then scaled to unit
    norm along ``dim=1``. The function then solves
    ``normalized_anchors.T @ X = normalized_points.T`` with
    ``torch.linalg.lstsq`` and returns ``X.T``, i.e. one row of anchor
    coefficients per point.

    Vectors that coincide with the anchor centroid are mapped to the zero vector
    by ``F.normalize`` (its ``eps`` clamp) instead of producing a 0/0 NaN row.
    This always happens for the single anchor when only one anchor is given.

    Returns:
        Tensor of shape ``(points.shape[0], anchors.shape[0])``. If the solver
        raises ``RuntimeError``, zeros of that shape are returned on the same
        dtype/device as ``points``.
    """
    anchor_mean = anchors.mean(dim=0)

    normalized_points = torch.nn.functional.normalize(points - anchor_mean, dim=1)
    normalized_anchors = torch.nn.functional.normalize(anchors - anchor_mean, dim=1)

    try:
        solution = torch.linalg.lstsq(normalized_anchors.T, normalized_points.T).solution
    except RuntimeError:
        # Best-effort fallback kept from the original. Match dtype/device so
        # downstream ops do not fail on a CPU/float32 tensor.
        return torch.zeros(
            (points.shape[0], anchors.shape[0]),
            dtype=points.dtype,
            device=points.device,
        )
    return solution.T


def abs_to_rel_cosine(
    anchors: torch.Tensor,
    points: torch.Tensor,
) -> torch.Tensor:
    """Cosine similarity between every point and every anchor.

    Both inputs are projected onto the unit sphere along their last dimension.
    The result is the matrix product ``unit_points @ unit_anchors.T``. For 2-D
    inputs this is a ``(num_points, num_anchors)`` similarity matrix.
    """
    unit_points = torch.nn.functional.normalize(points, dim=-1)
    unit_anchors = torch.nn.functional.normalize(anchors, dim=-1)
    similarities = torch.matmul(unit_points, unit_anchors.T)
    return similarities


def abs_to_rel_center_cosine(
    anchors: torch.Tensor,
    points: torch.Tensor,
) -> torch.Tensor:
    """Cosine similarity after shifting both sets by the centroid of ``points``.

    The centroid is the mean of ``points`` over dim 0. It is subtracted from both
    points and anchors before they are unit-normalized along the last dimension.
    """
    centroid = points.mean(dim=0)

    unit_anchors = torch.nn.functional.normalize(anchors - centroid, dim=-1)
    unit_points = torch.nn.functional.normalize(points - centroid, dim=-1)

    similarities = torch.matmul(unit_points, unit_anchors.T)
    return similarities


def abs_to_rel_lp(
    anchors: torch.Tensor,
    points: torch.Tensor,
    p: int,
) -> torch.Tensor:
    """Pairwise p-norm distances from each point (rows) to each anchor (columns).

    Thin wrapper over ``torch.cdist``. ``p`` selects the norm (e.g. 1 for L1,
    2 for Euclidean).
    """
    pairwise_distances = torch.cdist(x1=points, x2=anchors, p=p)
    return pairwise_distances


def abs_to_rel_normalized_euclidean(
    anchors: torch.Tensor,
    points: torch.Tensor,
) -> torch.Tensor:
    """Euclidean distances between centered, unit-normalized points and anchors.

    Both sets are shifted by the mean of ``points`` over dim 0. They are then
    projected onto the unit sphere along the last dimension before
    ``torch.cdist(p=2)`` is applied.
    """
    centroid = points.mean(dim=0)

    unit_anchors = torch.nn.functional.normalize(anchors - centroid, dim=-1)
    unit_points = torch.nn.functional.normalize(points - centroid, dim=-1)

    distances = torch.cdist(unit_points, unit_anchors, p=2)
    return distances


def abs_to_rel_std_euclidean(
    anchors: torch.Tensor,
    points: torch.Tensor,
) -> torch.Tensor:
    """Euclidean distances after standardizing features with the statistics of ``points``.

    Both sets are shifted by the per-feature mean of ``points`` (dim 0). They
    are then divided by the per-feature (unbiased) standard deviation of
    ``points`` before ``torch.cdist(p=2)`` is applied.

    Features with no positive spread are left unscaled, i.e. divided by 1.
    This covers a constant column (std 0) and a single point (std NaN). Without
    it they would turn into 0/0 = NaN or x/0 = inf and contaminate every
    distance.
    """
    centroid = points.mean(dim=0)
    centered_anchors = anchors - centroid
    centered_points = points - centroid

    scale = centered_points.std(dim=0)
    # ``NaN > 0`` is False, so a NaN std also falls back to 1.
    scale = torch.where(scale > 0, scale, torch.ones_like(scale))

    norm_anchors = centered_anchors / scale
    norm_points = centered_points / scale

    return torch.cdist(norm_points, norm_anchors, p=2)


def abs_to_rel_wasserstein(
    anchors: torch.Tensor,
    points: torch.Tensor,
    batch_size: int = 1024,
) -> torch.Tensor:
    """Wasserstein distance between every point and every anchor.

    Each row of ``points`` and ``anchors`` is turned into a probability vector
    with ``softmax`` over the last dimension. Points are then processed in
    batches of at most ``batch_size`` rows. Each batch is paired with every
    anchor (point-major order) and handed to ``wasserstein_distance_vec``.

    Args:
        anchors: Anchor set, one row per anchor.
        points: Points to compare, one row per point.
        batch_size: Maximum number of points per batch. This bounds the size
            of the repeated pair tensors built in each iteration.

    Returns:
        Tensor reshaped to ``(points.shape[0], anchors.shape[0])``.
        NOTE(review): assumes ``wasserstein_distance_vec`` returns one value
        per row pair, concatenable along dim 0. Confirm against its definition.
    """
    points = points.softmax(dim=-1)
    anchors = anchors.softmax(dim=-1)
    num_anchors = anchors.shape[0]

    wass_dists = []
    # ``torch.split`` yields consecutive blocks of at most ``batch_size`` rows.
    # ``torch.chunk`` would instead yield ``batch_size`` blocks whose size grows
    # with the number of points.
    for point_batch in torch.split(points, batch_size, dim=0):
        # Row i*num_anchors + j pairs point_batch[i] with anchors[j].
        repeated_points = point_batch.repeat_interleave(repeats=num_anchors, dim=0).contiguous()
        repeated_anchors = anchors.repeat(repeats=(point_batch.shape[0], 1)).contiguous()
        wass_dists.append(wasserstein_distance_vec(repeated_points, repeated_anchors))
    return torch.cat(wass_dists, dim=0).reshape(points.shape[0], num_anchors)


# Registry from a display name to a projection callable. Callers are expected
# to pass ``anchors`` and ``points`` by keyword. The two "Absolute" entries
# ignore the anchors and return the points unchanged or unit-normalized along
# the last dimension. Insertion order is significant if callers iterate it.
# An "L3" entry (``functools.partial(abs_to_rel_lp, p=3)``) is intentionally
# left out.
PROJECTION_TYPE = dict(
    [
        ("Wasserstein", abs_to_rel_wasserstein),
        ("Cosine", abs_to_rel_cosine),
        ("Center Cosine", abs_to_rel_center_cosine),
        ("Euclidean", functools.partial(abs_to_rel_lp, p=2)),
        ("Normalized Euclidean", abs_to_rel_normalized_euclidean),
        ("Standardized Euclidean", abs_to_rel_std_euclidean),
        ("L1", functools.partial(abs_to_rel_lp, p=1)),
        ("CoB Lstsq", basis_change_lstsq),
        ("Absolute", lambda points, **kwargs: points),
        ("Normalized Absolute", lambda points, **kwargs: torch.nn.functional.normalize(points, dim=-1)),
    ]
)
