# recommended by
# https://arxiv.org/pdf/1511.08861.pdf

# requires pytorch-mssim
# https://github.com/VainF/pytorch-msssim

from pytorch_msssim import SSIM, MS_SSIM
import torch
import torch.nn as nn

# from utils.dataloaders import Cifar10_Inverse_Covariances


class Distance(nn.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, *args, **kwargs):
        return self.dist(*args, **kwargs)

    def dist(self, X, Y, *args, **kwargs):
        raise NotImplementedError()

    def get_config(self):
        raise NotImplementedError()

    def compute_distance_matrix(self, X, Y, *args, **kwargs):
        D = X.new_zeros((X.shape[0], Y.shape[0]))
        for y_idx in range(Y.shape[0]):
            y = Y[y_idx, :]
            D[:, y_idx] = self.dist(X, y.expand_as(X))
        return D


class SquaredEuclideanDistance(Distance):
    def dist(self, X, Y, *args, **kwargs):
        N = X.shape[0]
        diff = X.view(N, -1) - Y.view(N, -1)
        return torch.sum(diff**2, dim=1)


class LPDistance(Distance):
    def __init__(self, p=2.0):
        super().__init__()
        self.p = float(p)

    def dist(self, X, Y, *args, **kwargs):
        # X, Y batches [N, C, H, W]
        N = X.shape[0]
        return (X.view(N, -1) - Y.view(N, -1)).norm(p=self.p, dim=1)

    def get_config(self):
        return {"Distance": "LP", "p": self.p}


class SSIMDistance(Distance):
    def __init__(self, kernel_w=3, sigma=1.5, channels=1):
        super().__init__()
        self.kernel_w = kernel_w
        self.ssim_d = SSIM(
            win_size=kernel_w,
            win_sigma=sigma,
            data_range=1.0,
            size_average=False,
            channel=channels,
        )
        self.config = {"Distance": "SSIM", "win_size": kernel_w, "win_sigma": sigma}

    def dist(self, X, Y, *args, **kwargs):
        # X, Y batches [N, C, H, W]
        N = X.shape[0]
        d1 = torch.clamp_min(1 - (self.ssim_d(X, Y)).view(N, -1).mean(dim=1), 0.0)
        return d1


class MSSSIMDistance(Distance):
    def __init__(self, kernel_w=3, sigma=1.5, channels=1, weights=None):
        super().__init__()
        self.kernel_w = kernel_w
        self.sigma = sigma

        # number of weights determines the depth of the pyramid
        # standard are 5, too deep for MNist resolution
        if weights is None:
            self.weights = [0.0516, 0.32949, 0.34622, 0.27261]
        else:
            self.weights = weights
        self.ssim_d = MS_SSIM(
            win_size=kernel_w,
            win_sigma=sigma,
            data_range=1.0,
            channel=channels,
            weights=self.weights,
            size_average=False,
        )

        self.config = {"Distance": "SSIM", "win_size": kernel_w, "win_sigma": sigma}

    def dist(self, X, Y, *args, **kwargs):
        # X, Y batches [N, C, H, W]
        N = X.shape[0]
        d1 = torch.clamp_min(1 - (self.ssim_d(X, Y)).view(N, -1).mean(dim=1), 0.0)
        return d1


class ReconstructionLoss(Distance):
    def __init__(self, alpha=0.84, kernel_w=3, sigma=1.5, channels=1):
        super.__init__()
        self.alpha = alpha
        self.kernel_w = kernel_w
        self.ssim_d = SSIM(
            win_size=kernel_w,
            win_sigma=sigma,
            data_range=1.0,
            size_average=False,
            channel=channels,
        )
        self.config = {
            "Distance": "L1 + SSIM",
            "alpha": alpha,
            "win_size": kernel_w,
            "win_sigma": sigma,
        }

    def dist(self, X, Y, *args, **kwargs):
        # X, Y batches [N, C, H, W]
        N = X.shape[0]
        numel = X.shape[1] * X.shape[2] * X.shape[3]
        d1 = torch.clamp_min(1 - (self.ssim_d(X, Y)).view(N, -1).mean(dim=1), 0.0)
        d2 = (self.kernel_w / numel) * torch.sum(
            torch.abs(X.view(N, -1) - Y.view(N, -1)), dim=1
        )
        return self.alpha * d1 + (1.0 - self.alpha) * d2

    def get_config(self):
        return self.config


# multiscale
class MSReconstructionLoss(Distance):
    def __init__(self, alpha=0.84, kernel_w=3, sigma=1.5, channels=1, weights=None):
        super().__init__()
        self.alpha = alpha
        self.kernel_w = kernel_w
        self.sigma = sigma

        # number of weights determines the depth of the pyramid
        # standard are 5, too deep for MNist resolution
        if weights is None:
            self.weights = [0.0516, 0.32949, 0.34622, 0.27261]
        else:
            self.weights = weights

        self.config = {
            "Distance": "L1 + MS_SSIM",
            "alpha": alpha,
            "win_size": kernel_w,
            "win_sigma": sigma,
        }

    def dist(self, X, Y, *args, **kwargs):
        # X, Y batches [N, C, H, W]

        N = X.shape[0]
        numel = X.shape[1] * X.shape[2] * X.shape[3]

        weights = torch.FloatTensor(self.weights).to(X.device, dtype=X.dtype)
        d1 = torch.clamp_min(
            ms_ssim(
                X,
                Y,
                win_size=self.kernel_w,
                win_sigma=self.sigma,
                data_range=1.0,
                size_average=False,
                weights=weights,
            ),
            0.0,
        )
        d1 = 1.0 - d1.view(N, -1).mean(dim=1)

        d2 = (self.kernel_w / numel) * torch.sum(
            torch.abs(X.view(N, -1) - Y.view(N, -1))
        )

        return self.alpha * d1 + (1.0 - self.alpha) * d2

    def get_config(self):
        return self.config
