from gsw_extended import GSW_extended
import numpy as np
import torch
from torch import optim

class WGSB(GSW_extended):
    """(Max-)sliced Bures discrepancies, optionally sample-weighted or kernelized.

    Slicing is delegated to machinery that is not defined in this class and is
    expected to come from ``GSW_extended``: ``get_slice``, ``random_slice``,
    ``homopoly``, ``simplex_norm`` and the attributes ``ftype``, ``degree``,
    ``radius``, ``device``, ``dp``, ``sigma``, ``weights`` and ``centers``.
    The kernel variants lazily fill ``weights``/``centers`` the first time
    they run (presumably ``GSW_extended`` initialises ``weights`` to None --
    TODO confirm). The max-sliced methods store the slice they found in
    ``self.theta``.
    """

    def sliced_bures(self, X, Y, theta=None):
        """Sliced Bures discrepancy between two empirical distributions.

        Projects both sample sets with ``get_slice`` and returns the squared
        difference of the root-mean-square of the projections. For a single
        slice this equals the squared Bures (2-Wasserstein) distance between
        the zero-mean 1-D Gaussians having those second moments.

        The two sets may have different numbers of samples: each mean is
        taken over its own samples only.

        Args:
            X: (N, d) tensor of samples.
            Y: (M, d) tensor of samples.
            theta: slice parameters; when None a fresh
                ``self.random_slice(d)`` is drawn.

        Returns:
            A scalar tensor.
        """
        # NOTE(review): with several slices the mean is pooled over all slices
        # and samples *before* the square root, i.e. this is not the average
        # of per-slice Bures distances. Confirm that this is intended.
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm
        if theta is None:
            theta = self.random_slice(dn)
        Xslices = self.get_slice(X, theta)
        Yslices = self.get_slice(Y, theta)
        return self._bures_from_slices(Xslices, Yslices)

    def max_sliced_bures(self, X, Y, iterations=50, lr=1e-4):
        """Max-sliced Bures discrepancy: search for the most discriminative slice.

        Starting from a random slice, runs ``iterations`` Adam steps of
        gradient ascent on ``sliced_bures`` w.r.t. one slice. After every step
        the slice is projected back to the sphere of radius ``self.radius``
        for ``ftype == 'circular'`` and radius 1 otherwise.

        The search runs on detached copies of X and Y, so it does not
        accumulate gradients into whatever produced them. The returned value
        is computed on the original tensors and stays differentiable w.r.t.
        X and Y.

        Args:
            X: (N, d) tensor of samples.
            Y: (M, d) tensor of samples.
            iterations: number of ascent steps.
            lr: Adam learning rate.

        Returns:
            ``sliced_bures(X, Y, self.theta)`` at the slice found.

        Raises:
            ValueError: if ``self.ftype`` is not 'linear', 'poly' or 'circular'.
        """
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm
        theta = self._init_max_slice(dn)
        X = X.to(self.device)
        Y = Y.to(self.device)
        X_fixed, Y_fixed = X.detach(), Y.detach()
        self._ascend_theta(
            lambda th: self.sliced_bures(X_fixed, Y_fixed, th),
            theta, iterations, lr, self._slice_radius())
        return self.sliced_bures(X, Y, self.theta)

    def sliced_kernel_bures(self, X, Y, theta=None):
        """Sliced Bures discrepancy computed in a random-feature space.

        On first use, draws two tensors on ``self.device``:

        - ``self.weights`` ~ N(0, 1/sigma^2), of shape (d, dp);
        - ``self.centers`` ~ U[0, 2*pi), of shape (dp,).

        These are presumably random Fourier features consumed by
        ``get_slice`` -- confirm in GSW_extended. Slices live in the
        dp-dimensional feature space.

        Args:
            X: (N, d) tensor of samples.
            Y: (M, d) tensor of samples.
            theta: slice parameters; when None one random unit-norm (1, dp)
                direction is drawn on ``self.device``.

        Returns:
            Squared difference of the RMS of the projections, as in
            ``sliced_bures``.
        """
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm
        self._ensure_kernel_features(dn)
        if theta is None:
            theta = self._random_kernel_slice()
        Xslices = self.get_slice(X, theta)
        Yslices = self.get_slice(Y, theta)
        return self._bures_from_slices(Xslices, Yslices)

    def max_sliced_kernel_bures(self, X, Y, iterations=50, lr=1e-4):
        """Max-sliced version of ``sliced_kernel_bures``.

        Adam ascent over one unit-norm (1, dp) slice, renormalised after
        every step. The search uses detached copies of X and Y. The returned
        value is computed on the original tensors with the kernel objective
        itself, which also initialises the random features when
        ``iterations == 0``.

        Args:
            X: (N, d) tensor of samples.
            Y: (M, d) tensor of samples.
            iterations: number of ascent steps.
            lr: Adam learning rate.

        Returns:
            ``sliced_kernel_bures(X, Y, self.theta)`` at the slice found.
        """
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm
        theta = self._random_kernel_slice().requires_grad_()
        X = X.to(self.device)
        Y = Y.to(self.device)
        X_fixed, Y_fixed = X.detach(), Y.detach()
        self._ascend_theta(
            lambda th: self.sliced_kernel_bures(X_fixed, Y_fixed, th),
            theta, iterations, lr)
        return self.sliced_kernel_bures(X, Y, self.theta)

    def sliced_bures_weighted(self, X, Y, beta, theta=None, epsilon=0.01):
        """Sliced Bures discrepancy where the samples of Y carry weights.

        Same as ``sliced_bures``, with two differences:

        - the second moment of the projected Y samples is a weighted average
          with ``nu = self.simplex_norm(beta)`` instead of a plain mean;
        - ``epsilon`` is added under both square roots.

        Args:
            X: (N, d) tensor of samples.
            Y: (M, d) tensor of samples; N and M may differ.
            beta: weight parameters for Y, mapped to ``nu`` by
                ``simplex_norm``. Presumably ``nu`` is an (M, 1)
                probability column -- TODO confirm in GSW_extended.
            theta: slice parameters; when None a fresh
                ``self.random_slice(d)`` is drawn.
            epsilon: added inside both square roots so they, and their
                gradients, stay finite when a second moment is ~0.

        Returns:
            ``(sqrt(eps + mean(Xs**2)) - sqrt(eps + nu^T Ys**2))**2``. If
            ``nu`` is an (M, 1) column this has shape (1, L) for L slices,
            i.e. a single element in the max-sliced case.
        """
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm
        if theta is None:
            theta = self.random_slice(dn)
        Xslices = self.get_slice(X, theta)
        Yslices = self.get_slice(Y, theta)
        return self._weighted_bures_from_slices(Xslices, Yslices, beta, epsilon)

    def max_sliced_bures_weighted(self, X, Y, beta, iterations=50, lr=1e-4,
                                  epsilon=0.01):
        """Max-sliced version of ``sliced_bures_weighted``.

        Same search as ``max_sliced_bures``. X, Y and beta are detached
        during the search, so no gradients leak into them. The final value
        is computed on the originals, so it stays differentiable w.r.t. all
        three (e.g. a learnable beta).

        Raises:
            ValueError: if ``self.ftype`` is not 'linear', 'poly' or 'circular'.
        """
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm
        theta = self._init_max_slice(dn)
        X = X.to(self.device)
        Y = Y.to(self.device)
        beta = beta.to(self.device)
        X_fixed, Y_fixed, beta_fixed = X.detach(), Y.detach(), beta.detach()
        self._ascend_theta(
            lambda th: self.sliced_bures_weighted(
                X_fixed, Y_fixed, beta_fixed, th, epsilon=epsilon),
            theta, iterations, lr, self._slice_radius())
        return self.sliced_bures_weighted(X, Y, beta, self.theta, epsilon=epsilon)

    def sliced_kernel_bures_weighted(self, X, Y, beta, theta=None, epsilon=0.01):
        """Kernelized (random-feature) version of ``sliced_bures_weighted``.

        Lazily draws ``self.weights``/``self.centers`` as described in
        ``sliced_kernel_bures``. When theta is None, one random unit-norm
        (1, dp) slice is drawn on ``self.device``.
        """
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm
        self._ensure_kernel_features(dn)
        if theta is None:
            theta = self._random_kernel_slice()
        Xslices = self.get_slice(X, theta)
        Yslices = self.get_slice(Y, theta)
        return self._weighted_bures_from_slices(Xslices, Yslices, beta, epsilon)

    def max_sliced_kernel_bures_weighted(self, X, Y, beta, iterations=50,
                                         lr=1e-4, epsilon=0.01):
        """Max-sliced version of ``sliced_kernel_bures_weighted``.

        The search runs on detached X, Y and beta. The returned value uses
        the kernel weighted objective on the original tensors.
        """
        _, dn = X.shape
        _, dm = Y.shape
        assert dn == dm
        theta = self._random_kernel_slice().requires_grad_()
        X = X.to(self.device)
        Y = Y.to(self.device)
        beta = beta.to(self.device)
        X_fixed, Y_fixed, beta_fixed = X.detach(), Y.detach(), beta.detach()
        self._ascend_theta(
            lambda th: self.sliced_kernel_bures_weighted(
                X_fixed, Y_fixed, beta_fixed, th, epsilon=epsilon),
            theta, iterations, lr)
        return self.sliced_kernel_bures_weighted(
            X, Y, beta, self.theta, epsilon=epsilon)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _bures_from_slices(Xslices, Yslices):
        """Squared difference of the RMS of two sets of projected samples."""
        x_rms = torch.sqrt(torch.mean(Xslices ** 2))
        y_rms = torch.sqrt(torch.mean(Yslices ** 2))
        return torch.square(x_rms - y_rms)

    def _weighted_bures_from_slices(self, Xslices, Yslices, beta, epsilon):
        """Like ``_bures_from_slices``, but Y's second moment is nu-weighted."""
        nu = self.simplex_norm(beta)
        x_rms = torch.sqrt(epsilon + torch.mean(Xslices ** 2))
        y_rms = torch.sqrt(epsilon + torch.matmul(nu.t(), Yslices ** 2))
        return torch.square(x_rms - y_rms)

    def _ensure_kernel_features(self, dim):
        """Draw the random-feature weights/centers once; no-op afterwards."""
        if self.weights is None:
            self.weights = torch.randn((dim, self.dp), device=self.device) / self.sigma
            self.centers = torch.rand(self.dp, device=self.device) * 2 * np.pi

    def _random_kernel_slice(self):
        """Return one random unit-norm (1, dp) direction on ``self.device``."""
        theta = torch.randn((1, self.dp), device=self.device)
        return theta / torch.sqrt(torch.sum(theta ** 2))

    def _slice_radius(self):
        """Norm at which the non-kernel max-sliced search keeps theta."""
        return self.radius if self.ftype == 'circular' else 1.0

    def _init_max_slice(self, dim):
        """Random leaf starting slice for the non-kernel max-sliced searches."""
        if self.ftype in ('linear', 'circular'):
            size = dim
        elif self.ftype == 'poly':
            size = self.homopoly(dim, self.degree)
        else:
            raise ValueError(
                "max-sliced search supports ftype 'linear', 'poly' or "
                f"'circular', got {self.ftype!r}")
        theta = torch.randn((1, size), device=self.device)
        theta = theta / torch.sqrt(torch.sum(theta ** 2)) * self._slice_radius()
        return theta.requires_grad_()

    def _ascend_theta(self, objective, theta, iterations, lr, radius=1.0):
        """Maximise ``objective(theta)`` with Adam; leave the result in ``self.theta``.

        After every step theta is projected back onto the sphere of the given
        radius. Presumably the objective grows with ||theta||, so without
        this the ascent could "win" just by scaling theta up.
        """
        self.theta = theta
        optimizer = optim.Adam([self.theta], lr=lr)
        for _ in range(iterations):
            optimizer.zero_grad()
            # Adam minimises, so ascend by minimising the negated objective.
            loss = -objective(self.theta)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                self.theta /= torch.sqrt(torch.sum(self.theta ** 2))
                self.theta *= radius
        
    