#Discriminator in PyTorch
import torch
import torch.nn as nn
import numpy as np

#custom layer
class GradRBFLayer(nn.Module):
    """Weighted sum of radially scaled offsets between inputs and RBF centres.

    For every input row ``x`` the layer returns::

        sum_i  w_i * phi(||x - c_i||) * (x - c_i)

    where ``c_i`` are the rows of ``centres`` and ``w_i`` the entries of
    ``rbf_weights``. With ``k`` the kernel order (``rbf_pow`` if given,
    otherwise ``2 * order_m - n``), ``r`` the Euclidean distance,
    ``sign = -1`` when ``k < 0`` (else ``1``) and ``eps = 0.1`` when
    ``k < 2`` (else ``0``), the radial factor is:

    * ``sign * k * (r + eps) ** (k - 2)`` when ``n`` is odd, ``k < 0`` or
      ``rbf_pow`` was given;
    * ``sign * (r + eps) ** (k - 2) * (k * log(r + 1e-5) + 1)`` otherwise.

    NOTE(review): this looks like the analytic gradient of a polyharmonic
    RBF expansion -- confirm against the derivation this was ported from.

    Args:
        num_centres: Number of RBF centres ``N``.
        input_dim: Shape every input is reshaped to; ``input_dim[1]`` is
            used as the feature dimension ``n``.
        output_dim: Stored and reported by ``get_config``; not used in the
            computation.
        order_m: Order ``m`` used to derive ``k`` when ``rbf_pow`` is None.
        batch_size: Accepted for interface compatibility; unused.
        device: Device the parameters are moved to inside ``forward``.
        const: Stored as ``self.const``; unused by this class.
        rbf_pow: Explicit kernel order ``k``; overrides ``order_m``.
        **kwargs: Ignored.
    """

    # Offset inside the logarithm so that log(0) is never evaluated.
    _LOG_EPS = 1e-5

    def __init__(self, num_centres, input_dim, output_dim, order_m, batch_size, device, const = 0, rbf_pow = None,  **kwargs):

        super().__init__()
        self.device = device

        self.m = order_m
        self.const = const
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_hidden = num_centres
        self.rbf_pow = rbf_pow

        # Feature dimension; cast so NumPy integers (e.g. from np.prod) behave
        # like plain ints everywhere below.
        self.n = int(self.input_dim[1])

        self.centres = nn.Parameter(torch.empty(self.num_hidden, self.n))
        self.rbf_weights = nn.Parameter(torch.empty(self.num_hidden))

        # Centres start uniform in [0, 1); weights start at one.
        nn.init.uniform_(self.centres)
        nn.init.ones_(self.rbf_weights)

    def forward(self, X):
        """Evaluate the weighted RBF gradient field at ``X``.

        Args:
            X: Tensor that can be reshaped to ``self.input_dim``.

        Returns:
            Float tensor of shape ``(batch, n)``. The batch and feature axes
            are kept even when either has size 1.
        """
        X = torch.reshape(X, self.input_dim)
        centres = self.centres.to(self.device)

        # Broadcasting (batch, 1, n) - (1, N, n) -> (batch, N, n) gives the
        # pairwise offsets without materialising tiled copies of X and C.
        tau = X.unsqueeze(1) - centres.unsqueeze(0)

        # order = k
        order = (2 * self.m) - self.n if self.rbf_pow is None else self.rbf_pow
        sign = -1. if order < 0 else 1.
        # Low orders give a negative exponent below; epsilon keeps the power
        # finite when an input coincides with a centre.
        epsilon = 1e-1 if order < 2 else 0.

        norm_tau = torch.norm(tau, p=2, dim=2, keepdim=True)  # (batch, N, 1)
        radial = torch.pow(norm_tau + epsilon, order - 2)

        if self.n % 2 == 1 or order < 0 or self.rbf_pow is not None:
            phi = sign * order * radial
        else:
            log_term = order * torch.log(norm_tau + self._LOG_EPS) + 1
            phi = sign * radial * log_term

        grad = (tau * phi).float()  # (batch, N, n)
        weights = self.rbf_weights.float().to(self.device)

        # Weighted sum over the centre axis N.
        weighted_grad = torch.einsum('bNn,N->bn', grad, weights)

        # Kept from the original: release cached CUDA blocks after each call.
        torch.cuda.empty_cache()

        return weighted_grad

    #TODO.
    def compute_output_shape(self):
        """Return ``(n, output_dim)``.

        NOTE(review): the original referenced an undefined name ``n`` and
        always raised NameError; ``self.n`` is assumed to be the intended
        value -- confirm.
        """
        return (self.n, self.output_dim)

    #TODO. Look into nn containers.
    def get_config(self):
        """Return the layer configuration as a dict."""
        config = {
            'output_dim': self.output_dim
        }
        return config

#model that uses the custom layer
class discriminator_model_RBF(nn.Module):
    """Pair of RBF gradient layers, one for target data and one for generated data.

    Both layers are built with ``rbf_pow = 1`` and are meant to have their
    centres and weights set explicitly through ``set_cw`` before use.

    Args:
        num_centres: Number of RBF centres given to each layer.
        input_dim: Shape of one sample, excluding the batch dimension.
        device: Device forwarded to both ``GradRBFLayer`` instances.
    """

    def __init__(self, num_centres, input_dim, device):

        super().__init__()

        self.num_centres = num_centres
        self.input_dim = input_dim #excluding batch_size

        flat_dim = int(np.prod(self.input_dim))
        # Kept for compatibility: (batch_size, flattened sample size) with the
        # batch size taken to be num_centres.
        self.input_dim_layer = (num_centres, flat_dim)

        # The layers reshape their input to this shape. Using -1 for the batch
        # axis lets any batch size through, not only one equal to num_centres.
        layer_input_dim = (-1, flat_dim)

        #initialize the custom layers
        self.rbf_real = GradRBFLayer(num_centres = self.num_centres, input_dim = layer_input_dim, output_dim = 1, order_m = 0, batch_size = num_centres, device = device, rbf_pow = 1)
        self.rbf_fake = GradRBFLayer(num_centres = self.num_centres, input_dim = layer_input_dim, output_dim = 1, order_m = 0, batch_size = num_centres, device = device, rbf_pow = 1)

    @staticmethod
    def _assign(layer, centres, weights):
        """Replace the centres and weights of ``layer``, placed on its device."""
        # Moving once here avoids a device transfer on every forward pass
        # when the caller built the tensors elsewhere (e.g. on the CPU).
        layer.centres = nn.Parameter(centres.to(layer.device))
        layer.rbf_weights = nn.Parameter(weights.to(layer.device))

    def set_cw(self, cw):
        """Set centres and weights of both layers.

        Args:
            cw: Sequence ``(C_d, D_d, C_g, D_g)``: target centres, target
                weights, generated centres, generated weights, in that order.
        """
        #unpack the list. Target(centre, weight) and Generated(centre, weight)
        C_d, D_d, C_g, D_g = cw

        with torch.no_grad():
            self._assign(self.rbf_real, C_d, D_d)
            self._assign(self.rbf_fake, C_g, D_g)

            torch.cuda.empty_cache()

    def forward(self, x):
        """Return ``(real_grad, fake_grad)`` for ``x``, computed without autograd."""
        with torch.no_grad():

            real_grad = self.rbf_real(x)
            fake_grad = self.rbf_fake(x)

            torch.cuda.empty_cache()

            return real_grad, fake_grad
            
def find_rbf_centres_weights(target_data, generator_data, N_centres):
    """Build RBF centres and uniform weights from two batches of samples.

    The first ``N_centres`` samples of each batch (or all of them, if the
    batch is smaller) are flattened to one row per sample and used as
    centres. Each set of centres gets uniform weights ``1 / number_of_rows``.

    Args:
        target_data: Tensor of target samples, batch dimension first.
        generator_data: Tensor of generated samples, batch dimension first.
        N_centres: Maximum number of centres taken from each batch.

    Returns:
        Tuple ``(C_d, C_g, D_d, D_g)``: target centres, generated centres,
        target weights, generated weights. Note the order differs from the
        ``(C_d, D_d, C_g, D_g)`` order that ``set_cw`` unpacks.

    Raises:
        ZeroDivisionError: If either selection contains no samples.
    """
    # flatten(start_dim=1) keeps the batch axis and collapses the rest, with
    # no need to compute the flattened size by hand.
    C_d = torch.flatten(target_data[0:N_centres], start_dim=1) #target data centres
    C_g = torch.flatten(generator_data[0:N_centres], start_dim=1) #generated data centres

    num_d = C_d.shape[0]
    num_g = C_g.shape[0]

    # Allocate the weights on the same device as their centres so the two
    # never end up split between CPU and GPU.
    D_d = torch.full((num_d,), 1 / num_d, device=C_d.device) #target data weights
    D_g = torch.full((num_g,), 1 / num_g, device=C_g.device) #generated data weights

    torch.cuda.empty_cache()

    return C_d, C_g, D_d, D_g

################################################ CHECKS for DIMENSIONS ################################################
'''
x1 = torch.rand(128,32, 32, 3).to("cuda:0") #assume target data
x2 = torch.rand(128,32, 32, 3).to("cuda:0") #assume prev data
x3 = torch.rand(128,32, 32, 3).to("cuda:0") #assume curr data

#step:1 -> instantiate the model
#the model in-turn calls the custom layer.
discriminator_rbf = discriminator_model_RBF(128, (3072,), 'cuda:0')

#step:2 -> find centers and weights
C_d, C_g, D_d, D_g = find_rbf_centres_weights(x1, x2, 1280) #find centre and weights of RBF x_prev and real_batch
#print("shapes of centres and weights: ",C_d.size(), D_d.size(), C_g.size(), D_g.size())
#print(D_d)

#step:3 -> manually set the centres and weights
discriminator_rbf.set_cw([C_d, D_d,C_g, D_g]) #set centre and weight of RBF
#print(discriminator_rbf.rbf_real.rbf_weights)

#step:4 -> obtain the gradients
real_grad, fake_grad = discriminator_rbf(x3) 
grad_disc = (fake_grad - real_grad)
grad_disc = torch.reshape(grad_disc, (128, 32, 32, 3))
#print(grad_disc)
'''
################################################ CHECKS for DIMENSIONS ################################################
"""
#from plot import get_plot
device = "cpu"

# check if gradients are correct. Plots stored in grad_checks folder
x_r = torch.normal(5,1, size = (1280,2)).to(device) #real
x_f = torch.normal(-5,1, size = (1280,2)).to(device) #fake

discriminator_rbf = discriminator_model_RBF(1280, (2,), device)
C_d, C_g, D_d, D_g = find_rbf_centres_weights(x_r, x_f, 1280)
discriminator_rbf.set_cw([C_d, D_d,C_g, D_g])

get_plot(discriminator_rbf, x_r, x_f, "grad_checks/")
"""