import math

import torch

# Small positive constant. It is not referenced by any of the code visible in
# this module.
# NOTE(review): presumably meant as a numerical-stability epsilon (e.g. to
# guard divisions or square roots) -- confirm where it is actually used.
EPSILON = 1e-8


def unsqueeze(x):
    """Return ``x`` with a trailing axis added when it is one-dimensional.

    A tensor for which ``x.dim() == 1`` is returned as ``x.unsqueeze(1)``;
    any other tensor is handed back unchanged (the very same object).
    """
    return x.unsqueeze(1) if x.dim() == 1 else x


def similarity(x, y, kernel='linear'):
    """Return the mean pairwise kernel value between rows of ``x`` and ``y``.

    Args:
        x: Tensor whose rows are samples. It must have a second dimension,
            because the kernel scale is ``gamma = 1 / x.size(1)``.
        y: Tensor whose rows are samples, with the same feature size as ``x``.
        kernel: One of

            * ``'linear'`` -- ``x @ y.T``
            * ``'rbf'``    -- ``exp(-gamma * ||x_i - y_j||^2)``
            * ``'poly'``   -- ``(gamma * x @ y.T + 1) ** 3``

    Returns:
        A scalar tensor: the mean over all pairwise kernel values.

    Raises:
        ValueError: If ``kernel`` is not one of the names above.
    """
    gamma = 1 / x.size(1)
    if kernel == 'linear':
        out = x.matmul(y.t())
    elif kernel == 'rbf':
        # The Gaussian RBF kernel is defined on the *squared* Euclidean
        # distance. Using the plain distance would give a Laplacian-style
        # kernel instead, which is not what the name 'rbf' denotes.
        out = torch.exp(-gamma * torch.cdist(x, y).pow(2))
    elif kernel == 'poly':
        coef0 = 1
        degree = 3
        out = (gamma * x.matmul(y.t()) + coef0) ** degree
    else:
        raise ValueError(kernel)
    return out.mean()


def squared_mmd(x, y, kernel='linear'):
    """Return the squared MMD estimate between ``x`` and ``y``, floored at 0.

    The estimate is ``k(x, x) + k(y, y) - 2 * k(x, y)``, where each ``k`` is
    the mean kernel value returned by ``similarity`` for the given ``kernel``.
    Negative values are clamped to zero before returning.
    """
    k_xx, k_yy, k_xy = (
        similarity(a, b, kernel) for a, b in ((x, x), (y, y), (x, y))
    )
    estimate = k_xx + k_yy - 2 * k_xy
    return estimate.clamp_min(0)


def distance(x, y, metric='euclidean', kernel='linear', aggregation='mean'):
    """Compute a distance between the rows of ``x`` and the rows of ``y``.

    Args:
        x: Tensor whose rows are samples (expected to be 2-D).
        y: Tensor whose rows are samples, same feature size as ``x``.
        metric: One of

            * ``'mmd'``       -- square root of ``squared_mmd(x, y, kernel)``
              (a single scalar, not a pairwise matrix)
            * ``'manhattan'`` -- pairwise L1 distance
            * ``'euclidean'`` -- pairwise L2 distance
            * ``'squared'``   -- pairwise squared L2 distance
            * ``'chebyshev'`` -- pairwise L-infinity distance
            * ``'cosine'``    -- pairwise ``1 - cosine_similarity``
        kernel: Kernel name forwarded to ``squared_mmd``; only used when
            ``metric == 'mmd'``.
        aggregation: ``'mean'`` returns the mean of the result as a scalar
            tensor; ``'none'`` returns the pairwise matrix unchanged (not
            allowed for ``'mmd'``).

    Returns:
        A scalar tensor (``aggregation='mean'``) or the pairwise distance
        matrix (``aggregation='none'``).

    Raises:
        ValueError: If ``metric`` or ``aggregation`` is not recognised.
        AssertionError: If ``aggregation='none'`` is combined with
            ``metric='mmd'``.
    """
    if metric == 'mmd':
        # NOTE(review): squared_mmd clamps its result at 0 and sqrt has an
        # unbounded derivative there, so gradients through this branch may be
        # non-finite when the estimate is exactly 0 -- confirm whether a
        # small offset is wanted.
        out = squared_mmd(x, y, kernel).sqrt()
    elif metric == 'manhattan':
        out = torch.cdist(x, y, p=1)
    elif metric == 'euclidean':
        out = torch.cdist(x, y, p=2)
    elif metric == 'squared':
        out = torch.cdist(x, y, p=2) ** 2
    elif metric == 'chebyshev':
        out = torch.cdist(x, y, p=torch.inf)
    elif metric == 'cosine':
        # Cosine distance needs unit-length rows; a raw dot product is only
        # a cosine similarity when the inputs are already normalised. For
        # unit-norm inputs this gives the same result as the plain product.
        x_unit = torch.nn.functional.normalize(x, dim=1)
        y_unit = torch.nn.functional.normalize(y, dim=1)
        out = 1 - x_unit.matmul(y_unit.t())
    else:
        raise ValueError(metric)

    if aggregation == 'mean':
        return out.mean()
    elif aggregation == 'none':
        # MMD is a single scalar, so there is no pairwise matrix to return.
        assert metric != 'mmd'
        return out
    else:
        raise ValueError(aggregation)


class EmbNormalizer:
    """Jointly centre and rescale three batches of embeddings.

    In ``'tpsd'`` mode the three inputs are concatenated along dim 0, the
    per-feature mean of the concatenation is subtracted from each input, and
    each input is divided by one shared scalar: the norm of the centred
    concatenation divided by the square root of its number of rows.

    Args:
        mode: Normalisation mode; only ``'tpsd'`` is implemented.
        eps: Lower bound applied to the scale before dividing, so that
            identical (collapsed) embeddings yield zeros instead of NaN.
    """

    def __init__(self, mode='tpsd', eps=1e-8):
        self.mode = mode
        self.eps = eps

    def __call__(self, emb_x, emb_y, emb_z):
        """Return the normalised ``(emb_x, emb_y, emb_z)`` as a tuple.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        if self.mode != 'tpsd':
            raise ValueError(self.mode)

        emb_all = torch.cat([emb_x, emb_y, emb_z])
        mean = emb_all.mean(0)
        std = torch.norm(emb_all - mean) / math.sqrt(emb_all.size(0))
        # When every row equals the mean the scale is 0 and the division
        # below would be 0 / 0 = NaN; floor it so the outputs stay finite.
        std = std.clamp_min(self.eps)
        return tuple((emb - mean) / std for emb in (emb_x, emb_y, emb_z))


class BaseLoss:
    """Distance between two embedding batches and a third reference batch.

    ``average='micro'`` concatenates ``emb_x`` and ``emb_y`` and measures a
    single distance to ``emb_z``; ``average='macro'`` measures the distance
    from each of them to ``emb_z`` separately and averages the two values.
    """

    def __init__(self, metric='mmd', kernel='linear', average='micro'):
        super().__init__()
        self.metric, self.kernel, self.average = metric, kernel, average

    def __call__(self, emb_x, emb_y, emb_z):
        """Return the loss value; raise ``ValueError`` for an unknown mode."""
        mode = self.average

        if mode == 'micro':
            pooled = torch.cat([emb_x, emb_y])
            return distance(pooled, emb_z, self.metric, self.kernel)

        if mode == 'macro':
            from_x = distance(emb_x, emb_z, self.metric, self.kernel)
            from_y = distance(emb_y, emb_z, self.metric, self.kernel)
            return (from_x + from_y) / 2

        raise ValueError(mode)


class MeanLoss:
    """Average distance from the mean of two batches to a reference batch.

    Each of ``emb_x`` and ``emb_y`` is reduced to its mean row (kept as a
    one-row matrix), the distance from that row to ``emb_z`` is computed with
    ``distance``, and the two results are averaged. The ``'mmd'`` metric is
    rejected at construction time via an assertion.
    """

    def __init__(self, metric, kernel=None):
        super().__init__()
        assert metric != 'mmd'
        self.metric, self.kernel = metric, kernel

    def __call__(self, emb_x, emb_y, emb_z):
        """Return the mean of the two centroid-to-``emb_z`` distances."""
        # Reduce both batches first, then measure, preserving the original
        # order of operations.
        centroids = [emb.mean(0, keepdim=True) for emb in (emb_x, emb_y)]
        dists = [
            distance(centroid, emb_z, self.metric, self.kernel)
            for centroid in centroids
        ]
        return (dists[0] + dists[1]) / 2
