from gsw import GSW
import numpy as np
import torch
from torch import optim

class GSW_extended(GSW):
    '''
    Extension of GSW that adds slicing through a random Fourier basis
    (ftype 'kernel') and distances in which the samples of Y carry weights.
    '''

    def __init__(self, ftype, nofbases=2000, sigma=.2):
        '''
        Inputs:
        ftype:    name of the defining function; 'kernel' enables the random
                  Fourier embedding used by the kernel_* methods
        nofbases: number of random Fourier bases (stored as self.dp); also
                  forwarded to the parent constructor as its 2nd argument
        sigma:    scale that divides the random Fourier weights; also
                  forwarded to the parent constructor as its 4th argument
        '''
        # NOTE(review): the meaning of these positional arguments (including
        # the empty list) is defined by GSW.__init__, which is not visible
        # from this file -- confirm against the parent class.
        super().__init__(ftype, nofbases, [], sigma)
        # Compare by value. `is` tests object identity, which only works for
        # interned strings and fails for names built at runtime (CLI/config).
        if self.ftype == 'kernel':
            self.weights = None   # sampled lazily once the data dimension is known
            self.centers = None
            self.dp = nofbases
            self.sigma = sigma

    def random_fourier(self, X):
        '''
        Embeds samples with the random Fourier basis:
        cos(X W + b) * sqrt(2/dp), with W = self.weights and b = self.centers.
        Inputs:
        X:  Nxd matrix of N data samples
        The basis must already have been sampled (the kernel_* methods do so
        on their first call).
        '''
        return torch.cos(torch.matmul(X, self.weights) + self.centers) * np.sqrt(2.0 / self.dp)

    def get_slice(self, X, theta):
        '''
        Slices samples from distribution X~P_X
        Inputs:
        X:  Nxd matrix of N data samples
        theta: parameters of g (e.g., a d vector in the linear case)
        For ftype 'kernel' the samples are first embedded with the random
        Fourier basis and then sliced linearly; every other ftype is handled
        by the parent class.
        '''
        if self.ftype == 'kernel':
            return self.linear(self.random_fourier(X), theta)
        return super().get_slice(X, theta)

    def simplex_norm(self, beta):
        '''
        Transform beta to a discrete distribution (nonnegative and sums to 1)
        '''
        positive = torch.nn.functional.softplus(beta)
        return positive / torch.sum(positive)

    def _ensure_fourier_basis(self, dn):
        '''
        Samples the random Fourier weights (dn x dp) and phases (dp) on first
        use; later calls reuse the stored basis.
        '''
        if self.weights is None:
            self.weights = torch.randn((self.dp and dn, self.dp), device=self.device, requires_grad=False) / self.sigma \
                if False else torch.randn((dn, self.dp), device=self.device, requires_grad=False) / self.sigma
            # Phases are uniform on [0, 2*pi).
            self.centers = torch.rand((self.dp), device=self.device, requires_grad=False) * (2 * np.pi)

    def _random_unit_theta(self, dim):
        '''
        Draws one Gaussian slice of size (1, dim) on self.device and scales it
        to unit Euclidean norm.
        '''
        theta = torch.randn((1, dim), device=self.device)
        return theta / torch.sqrt(torch.sum(theta ** 2))

    def _weighted_distance(self, X, Y, beta, theta):
        '''
        Sorts the slices of X and Y and accumulates their squared differences,
        weighted by nu = simplex_norm(beta) reordered with the sort indices of
        the Y slices.
        NOTE(review): assuming get_slice returns one column per slice, the
        matmul yields an (L, L) result for L slices, so only L == 1 gives a
        single distance -- confirm against callers passing several slices.
        '''
        Xslices_sorted = torch.sort(self.get_slice(X, theta), dim=0)[0]
        Yslices_sorted, indices = torch.sort(self.get_slice(Y, theta), dim=0)
        nu = self.simplex_norm(beta)
        return torch.sqrt(torch.matmul(nu[indices].t(), (Xslices_sorted - Yslices_sorted) ** 2))

    def _maximize_over_theta(self, objective, theta, iterations, lr):
        '''
        Runs `iterations` Adam steps that increase objective(theta), rescaling
        theta to unit norm after every step. The optimized tensor is stored in
        (and remains available as) self.theta.
        '''
        self.theta = theta
        optimizer = optim.Adam([self.theta], lr=lr)
        for _ in range(iterations):
            optimizer.zero_grad()
            loss = -objective(self.theta)   # Adam minimizes, so negate to maximize
            loss.backward()
            optimizer.step()
            # Rescale outside of autograd so the projection is not recorded.
            with torch.no_grad():
                self.theta.div_(torch.sqrt(torch.sum(self.theta ** 2)))
        return self.theta

    def kernel_gsw(self, X, Y, theta=None):
        '''
        Distance between two empirical distributions sliced through the random
        Fourier basis.
        Inputs:
        X:     Nxd matrix of N data samples
        Y:     Mxd matrix of M data samples
        theta: slice parameters of size dp; one random unit-norm slice is
               drawn on self.device when omitted
        Returns the square root of the summed squared differences between the
        sorted slices of X and Y.
        '''
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm, 'X and Y must have the same number of columns'
        self._ensure_fourier_basis(dn)
        if theta is None:
            # Drawn on self.device so it matches the stored basis.
            theta = self._random_unit_theta(self.dp)
        Xslices_sorted = torch.sort(self.get_slice(X, theta), dim=0)[0]
        Yslices_sorted = torch.sort(self.get_slice(Y, theta), dim=0)[0]
        return torch.sqrt(torch.sum((Xslices_sorted - Yslices_sorted) ** 2))

    def max_kernel_gsw(self, X, Y, iterations=50, lr=1e-4):
        '''
        Maximizes kernel_gsw over a single unit-norm slice with Adam and
        returns the distance at the optimized slice (kept in self.theta).
        Inputs:
        X, Y:       data matrices with the same number of columns
        iterations: number of Adam steps
        lr:         Adam learning rate
        '''
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm, 'X and Y must have the same number of columns'
        X = X.to(self.device)
        Y = Y.to(self.device)
        # The inner search only needs gradients w.r.t. theta. Detached copies
        # keep it from backpropagating into (and accumulating gradients on)
        # whatever graph produced X and Y.
        X_fixed, Y_fixed = X.detach(), Y.detach()
        theta = self._random_unit_theta(self.dp).requires_grad_(True)
        self._maximize_over_theta(
            lambda th: self.kernel_gsw(X_fixed, Y_fixed, th), theta, iterations, lr)
        return self.kernel_gsw(X, Y, self.theta)

    def gsw_weighted(self, X, Y, beta, theta=None):
        '''
        Calculates GSW between two empirical distributions.
        Y is weighted by nu which is proportional to beta through a softplus.
        Note that the number of samples is assumed to be equal
        (This is however not necessary and could be easily extended
        for empirical distributions with different number of samples)
        Inputs:
        X, Y:  Nxd matrices of N data samples each
        beta:  N unnormalized weights
        theta: slice parameters; taken from self.random_slice when omitted
        '''
        N, dn = X.shape
        M, dm = Y.shape
        P = beta.shape[0]
        assert dn == dm and M == N and P == N
        if theta is None:
            theta = self.random_slice(dn)
        return self._weighted_distance(X, Y, beta, theta)

    def max_gsw_weighted(self, X, Y, beta, iterations=50, lr=1e-4):
        '''
        Maximizes gsw_weighted over a single slice with Adam and returns the
        distance at the optimized slice (kept in self.theta).
        See gsw_weighted function
        Raises ValueError when ftype is not 'linear', 'poly' or 'circular'.
        '''
        N, dn = X.shape
        M, dm = Y.shape
        P = beta.shape[0]
        assert dn == dm and M == N and P == N
        if self.ftype == 'linear' or self.ftype == 'circular':
            dim = dn
        elif self.ftype == 'poly':
            dim = self.homopoly(dn, self.degree)
        else:
            # Previously this fell through and failed on an unbound `theta`.
            raise ValueError(
                "max_gsw_weighted supports ftype 'linear', 'poly' or 'circular', got %r"
                % (self.ftype,))
        theta = self._random_unit_theta(dim)
        if self.ftype == 'circular':
            theta = theta * self.radius
        theta.requires_grad_(True)
        # NOTE(review): the optimization loop rescales theta to unit norm, so
        # for 'circular' the radius scaling above only affects the first step
        # -- confirm whether projecting back to the radius was intended.
        X = X.to(self.device)
        Y = Y.to(self.device)
        beta = beta.to(self.device)
        # Detached copies: the search over theta must not backpropagate into
        # the graph that produced X, Y or beta.
        X_fixed, Y_fixed, beta_fixed = X.detach(), Y.detach(), beta.detach()
        self._maximize_over_theta(
            lambda th: self.gsw_weighted(X_fixed, Y_fixed, beta_fixed, th),
            theta, iterations, lr)
        return self.gsw_weighted(X, Y, beta, self.theta)

    def kernel_gsw_weighted(self, X, Y, beta, theta=None):
        '''
        Special case of slicing with random Fourier basis (RFB) embedding.
        'dp' is the number of bases
        Y is weighted by nu = simplex_norm(beta), as in gsw_weighted.
        theta: slice parameters of size dp; one random unit-norm slice is
               drawn on self.device when omitted
        '''
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm, 'X and Y must have the same number of columns'
        self._ensure_fourier_basis(dn)
        if theta is None:
            # Drawn on self.device so it matches the stored basis.
            theta = self._random_unit_theta(self.dp)
        return self._weighted_distance(X, Y, beta, theta)

    def max_kernel_gsw_weighted(self, X, Y, beta, iterations=50, lr=1e-4):
        '''
        Maximizes kernel_gsw_weighted over a single unit-norm slice with Adam
        and returns the distance at the optimized slice (kept in self.theta).
        See kernel_gsw_weighted function
        '''
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm, 'X and Y must have the same number of columns'
        X = X.to(self.device)
        Y = Y.to(self.device)
        beta = beta.to(self.device)
        # Detached copies: the search over theta must not backpropagate into
        # the graph that produced X, Y or beta.
        X_fixed, Y_fixed, beta_fixed = X.detach(), Y.detach(), beta.detach()
        theta = self._random_unit_theta(self.dp).requires_grad_(True)
        self._maximize_over_theta(
            lambda th: self.kernel_gsw_weighted(X_fixed, Y_fixed, beta_fixed, th),
            theta, iterations, lr)
        return self.kernel_gsw_weighted(X, Y, beta, self.theta)