import torch
import torch.nn as nn
import torch.nn.functional as F


# Number of random unit directions used for the projections in this module:
# the column count of the projection matrix in ProjectedErrorFunction, the
# normalisation divisor in PerModule.backward, and the slice count WassDist
# passes to sliced_wasserstein_distance.
NUM_SLICES = 256


class LpNorm(nn.Module):
    """Mean L_p norm taken over the last dimension of the input.

    Arguments:
        p -- Order of the norm, forwarded unchanged to ``Tensor.norm``.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p

    def forward(self, x):
        """Return the average of the per-vector L_p norms of ``x``.

        The norm is computed along ``dim=-1`` and the resulting values are
        averaged into a single scalar tensor.
        """
        return x.norm(p=self.p, dim=-1).mean()


class ProjectedErrorFunction(nn.Module):
    """Project a batch onto random unit directions and feed it to ``PerModule``.

    The input is multiplied by a freshly sampled matrix whose ``NUM_SLICES``
    columns each have unit L2 norm, and the projected values are handed to the
    custom autograd function ``PerModule``.
    """

    def __init__(self):
        super().__init__()
        # Bound ``apply`` of the custom autograd function used in ``forward``.
        self.per_module = PerModule.apply

    def forward(self, x):
        """Return ``PerModule`` applied to random 1-D projections of ``x``.

        Arguments:
            x {torch.Tensor} -- A two-dimensional tensor of shape [N, C].
        Returns:
            torch.Tensor -- Whatever ``PerModule.apply`` returns for the
            [N, NUM_SLICES] projected tensor.
        Raises:
            ValueError -- If ``x`` is not two-dimensional.
        """
        if x.dim() != 2:
            raise ValueError(
                f"expected a 2-D tensor of shape [N, C], got {tuple(x.shape)}"
            )

        # Sample the directions in x's dtype and on x's device so the matmul
        # below also works for inputs that are not float32.
        directions = torch.randn(
            x.shape[1], NUM_SLICES, device=x.device, dtype=x.dtype
        )
        # Normalise along dim 0 so that every column has unit L2 norm.
        directions = F.normalize(directions, p=2, dim=0)

        projected = torch.matmul(x, directions)  # [N, C] @ [C, S] = [N, S]
        return self.per_module(projected)
    
class PerModule(torch.autograd.Function):
    """Autograd function with a constant forward value and a custom gradient.

    ``forward`` returns a dummy loss of one; the useful part is ``backward``,
    which supplies a hand-written gradient with respect to the input.
    """

    @staticmethod
    def forward(ctx, input):
        """Save ``input`` for the backward pass and return a dummy loss.

        Arguments:
            ctx -- Autograd context used to stash tensors for ``backward``.
            input {torch.Tensor} -- Tensor whose gradient ``backward`` defines.
        Returns:
            torch.Tensor -- A one-element tensor holding 1, placed on the same
            device as ``input``.
        """
        ctx.save_for_backward(input)

        # Dummy loss value. It lives on the input's device rather than a
        # hard-coded 'cuda' so the function also works on CPU and on
        # non-default GPUs.
        return torch.ones(1, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the saved input.

        With ``Phi`` the standard normal CDF, the returned value is
        ``grad_output * (2 * Phi(x / sqrt(2)) - 1) / (batch_size * NUM_SLICES)``
        where ``batch_size`` is the size of the input's first dimension.

        Arguments:
            ctx -- Autograd context holding the tensor saved in ``forward``.
            grad_output {torch.Tensor} -- Gradient of the loss with respect to
                the output of ``forward``.
        Returns:
            torch.Tensor -- Gradient with the same shape as the saved input.
        """
        saved, = ctx.saved_tensors
        scaled = saved / (2 ** 0.5)

        batch_size = scaled.size(0)
        std_normal = torch.distributions.normal.Normal(0, 1)

        # The result stays on the device of the saved input; no explicit
        # device transfer is needed.
        grad = (2 * std_normal.cdf(scaled) - 1) / float(batch_size)

        # Multiplication allocates a new tensor, so grad_output is not
        # modified and does not need to be cloned first.
        return grad_output * grad / float(NUM_SLICES)

    
class WassDist(nn.Module):
    """Sliced Wasserstein distance between a batch and standard-normal noise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Compare ``x`` with fresh Gaussian noise of the same shape.

        Noise is drawn with ``torch.randn_like(x)`` and both tensors are passed
        to ``sliced_wasserstein_distance`` with ``NUM_SLICES`` projections and
        ``p = 1``. The value returned by that function is returned unchanged.
        """
        gaussian_reference = torch.randn_like(x)
        return sliced_wasserstein_distance(
            x, gaussian_reference, NUM_SLICES, 1
        )


def sliced_wasserstein_distance(x, y, sample_cnt, p):
    """Calculate a stochastic sliced Wasserstein distance between x and y.

    Both batches are projected onto ``sample_cnt`` random unit directions.
    For every direction the projected values are sorted along the batch
    dimension, the L_p norm of the difference between the two sorted columns
    is taken, and the norms are averaged over all directions.

    Arguments:
        x {torch.Tensor} -- A tensor of shape [N, D]. Samples from the
            distribution p(X)
        y {torch.Tensor} -- A tensor of shape [N, D] with the same dtype and
            device as x. Samples from the distribution p(Y)
        sample_cnt {int} -- Number of random projection directions
        p {int} -- L_p is used to calculate sliced w-dist
    Returns:
        scalar -- The sliced wasserstein distance (with gradient)
    """
    # Draw the directions in x's dtype and on x's device so the projection
    # also works for inputs that are not float32 (e.g. float64).
    directions = torch.randn(
        x.shape[1], sample_cnt, device=x.device, dtype=x.dtype
    )
    # Normalise along dim 0 so that each column has unit L2 norm.
    directions = torch.nn.functional.normalize(directions, p=2, dim=0)

    projected_x = torch.matmul(x, directions)  # [N, D] @ [D, S] = [N, S]
    projected_y = torch.matmul(y, directions)

    # Only the sorted values are needed; the permutation indices are dropped.
    sorted_x, _ = projected_x.sort(dim=0)  # [N, S]
    sorted_y, _ = projected_y.sort(dim=0)  # [N, S]

    return (sorted_x - sorted_y).norm(p=p, dim=0).mean()
   