from typing import Tuple

import torch
from scipy.stats import pearsonr, wasserstein_distance
from torch.nn.functional import cosine_similarity, mse_loss, softmax


def custom_cosine(point, anchor) -> float:
    """Cosine similarity between ``anchor`` and ``point``.

    Each input gets a new leading axis and the similarity is taken along
    ``dim=1``. For 1-D inputs the result is therefore a tensor of shape
    ``(1,)`` rather than a Python float, despite the annotation.
    """
    anchor_row = anchor[None]
    point_row = point[None]
    return cosine_similarity(anchor_row, point_row, dim=1)


def custom_euclidean(point, anchor) -> float:
    """Euclidean (2-/Frobenius) norm of ``point - anchor`` over all elements.

    Returns a 0-d tensor.
    """
    difference = point - anchor
    return difference.norm()


def custom_pearson(point, anchor) -> float:
    """Pearson correlation coefficient between ``point`` and ``anchor``.

    Inputs may be tensors (including ones that require grad or live on a
    non-CPU device), NumPy arrays, or sequences. They are handed to
    ``scipy.stats.pearsonr``, which expects 1-D inputs of equal length.

    Returns:
        A 0-d tensor holding the correlation. A NaN result (e.g. when one
        input is constant) is replaced by 0.
    """
    # SciPy converts its arguments with NumPy. That conversion raises for
    # tensors that require grad or are on an accelerator, so detach and move
    # to CPU first. ``as_tensor`` keeps array/list inputs working as before.
    point_np = torch.as_tensor(point).detach().cpu().numpy()
    anchor_np = torch.as_tensor(anchor).detach().cpu().numpy()
    correlation, _ = pearsonr(point_np, anchor_np)
    return torch.nan_to_num(torch.as_tensor(correlation), nan=0)


def custom_mse(point, anchor) -> float:
    """Mean squared error between ``point`` and ``anchor``.

    Delegates to ``torch.nn.functional.mse_loss`` with mean reduction and
    returns a 0-d tensor.
    """
    return mse_loss(input=point, target=anchor, reduction="mean")


def custom_wasserstein(point, anchor) -> float:
    """1-D Wasserstein distance between the softmax distributions of two vectors.

    Each input is turned into a probability distribution over its positions
    with ``softmax(dim=0)``. The distance is then computed between those
    distributions, with position ``i`` carrying weight ``softmax(x)[i]``.

    Inputs are presumably 1-D tensors; they may differ in length.

    Returns:
        The distance as returned by ``scipy.stats.wasserstein_distance``.
    """
    # SciPy needs plain CPU arrays; detach so grad-tracking tensors work too.
    point_weights = softmax(point, dim=0).detach().cpu().numpy()
    anchor_weights = softmax(anchor, dim=0).detach().cpu().numpy()
    # The softmax outputs are weights over index positions, not sample values.
    # Passing them positionally would treat them as sample values, so any two
    # inputs whose softmax outputs are permutations of each other would get
    # distance 0.
    point_support = torch.arange(point_weights.shape[0], dtype=torch.float64).numpy()
    anchor_support = torch.arange(anchor_weights.shape[0], dtype=torch.float64).numpy()
    return wasserstein_distance(
        point_support,
        anchor_support,
        u_weights=point_weights,
        v_weights=anchor_weights,
    )


def custom_l1(point, anchor) -> float:
    """L1 (Manhattan) distance: sum of absolute element-wise differences.

    Returns a 0-d tensor.
    """
    absolute_gap = (point - anchor).abs()
    return absolute_gap.sum()


def custom_lp_distance(x, y, p=2):
    """Mean p-norm of ``x - y`` taken along the last dimension.

    The p-norm is computed per row (last axis), and the mean over all
    remaining entries is returned as a 0-d tensor.
    """
    per_row = torch.norm(x - y, p=p, dim=-1)
    return per_row.mean()


def custom_normalize(anchors, points) -> Tuple[torch.Tensor, torch.Tensor]:
    """Center ``anchors`` and ``points`` using the mean of ``points``.

    The mean is taken over dim 0 of ``points`` only and subtracted from both
    inputs, so the anchors are expressed in the points' centered frame.

    Returns:
        ``(centered_anchors, centered_points)``.
    """
    center = points.mean(dim=0)
    return anchors - center, points - center
