import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import torch.nn.functional as F
import math
import copy


#################################################################################################
# Class definitions for the layer-loss hyperparameters
class nmseargsstructCNN_LC:
    """Hyperparameter record for the locally connected (CNN) layer loss.

    The thirteen positional arguments are stored, in order, as the attributes
    ``layer_dim, out_dim, layerinp_dim, la_R, R_ini, R_eps_weight,
    input_channel, output_channel, kernel_size, batch_size, stride,
    Reh_gain, method``. Values are kept exactly as given (no validation).
    """

    def __init__(self, field1, field2, field3, field4, field5, field6, field7,
                 field8, field9, field10, field11, field12, field13):
        (self.layer_dim, self.out_dim, self.layerinp_dim,
         self.la_R, self.R_ini, self.R_eps_weight,
         self.input_channel, self.output_channel, self.kernel_size,
         self.batch_size, self.stride,
         self.Reh_gain, self.method) = (
            field1, field2, field3,
            field4, field5, field6,
            field7, field8, field9,
            field10, field11,
            field12, field13,
        )
        
class nmseargsstruct:
    """Hyperparameter record for the fully connected layer loss.

    The nine positional arguments are stored, in order, as the attributes
    ``layer_dim, out_dim, layerinp_dim, la_R, R_ini, R_eps_weight,
    Reh_gain, la_R2, method``. Values are kept exactly as given.
    """

    def __init__(self, field1, field2, field3, field4, field5, field6, field7, field8, field9):
        (self.layer_dim, self.out_dim, self.layerinp_dim,
         self.la_R, self.R_ini, self.R_eps_weight,
         self.Reh_gain, self.la_R2, self.method) = (
            field1, field2, field3,
            field4, field5, field6,
            field7, field8, field9,
        )
    
    
class LossFullyConnected(nn.Module):
    """Local loss and manual gradients for one fully connected layer.

    The module keeps exponentially averaged statistics of the layer output
    ``hk``:

    * ``Rhk``  -- output autocorrelation ``hk.T @ hk / B`` (``layer_dim x layer_dim``),
    * ``Rehk`` -- error/output cross-correlation ``er.T @ hk / B``
      (``out_dim x layer_dim``),

    and returns weight gradients for three terms: the broadcast-error
    decorrelation (NMSE) term, a log-determinant covariance term and an l1
    sparsity term. Which statistics and gradients are produced depends on
    ``args.method`` (``"ebd"``, ``"dfa1"`` or ``"dfa2"``).

    ``args`` must provide ``out_dim``, ``layer_dim``, ``R_ini``, ``Reh_gain``,
    ``la_R``, ``la_R2``, ``R_eps_weight`` and ``method``. An optional
    ``args.device`` selects where the state tensors live; it defaults to
    ``'cuda'`` to match the original behaviour.
    """

    def __init__(self, args):
        super(LossFullyConnected, self).__init__()
        # Optional override; defaults to the originally hard-coded CUDA device.
        device = getattr(args, 'device', 'cuda')
        opts = dict(dtype=torch.float64, device=device, requires_grad=False)
        # Running error/output cross-correlation, (out_dim, layer_dim).
        self.Rehk = torch.zeros((args.out_dim, args.layer_dim), **opts)
        # Running output autocorrelation, (layer_dim, layer_dim), starts at R_ini * I.
        self.Rhk = args.R_ini * torch.eye(args.layer_dim, **opts)
        # Candidate statistics computed in forward(); zero until first updated.
        self.new_Rhk = torch.zeros((args.layer_dim, args.layer_dim), **opts)
        self.new_Rehk = torch.zeros((args.out_dim, args.layer_dim), **opts)
        nn.init.normal_(self.Rehk, mean=0.0, std=args.Reh_gain)
        # Frozen copy of the random initial cross-correlation; the DFA methods
        # use it as a fixed feedback matrix.
        self.Rehk0 = copy.deepcopy(self.Rehk)
        self.la_R = args.la_R
        self.la_R2 = args.la_R2
        self.R_eps_weight = args.R_eps_weight
        # Diagonal regulariser added before the linear solve and the logdet.
        self.R_eps = self.R_eps_weight * torch.eye(args.layer_dim, **opts)
        self.method = args.method

    @staticmethod
    def _angle_deg(a, b):
        """Return the angle in degrees between ``a`` and ``b`` viewed as flat vectors.

        The cosine is clamped to [-1, 1] so floating-point rounding cannot push
        it outside the domain of ``acos`` (which would give NaN). Zero-norm
        inputs still yield NaN because the angle is undefined.
        """
        cos = torch.sum(a * b) / (torch.norm(a, p='fro') * torch.norm(b, p='fro'))
        return 180 * torch.acos(torch.clamp(cos, -1.0, 1.0)) / math.pi

    def forward(self, hk: torch.Tensor, uk: torch.Tensor, er: torch.Tensor, hkm1: torch.Tensor, layer: torch.Tensor):
        """Update the running statistics and compute the local gradients.

        Args:
            hk: layer output, 2-D ``(B, D)`` (unpacked via ``hk.size()``).
            uk: pre-activation with the same shape as ``hk``; its derivative is
                approximated by ``uk > 0`` (assumes a (bounded) ReLU).
            er: broadcast error, ``(B, out_dim)``.
            hkm1: previous layer output, ``(B, D_prev)``.
            layer: scalar compared against ``1.0``; when greater, ``er`` is used
                directly and the activation derivative is taken as one
                (presumably the output layer -- TODO confirm with caller).

        Returns:
            Tuple ``(NMSEloss, Cov_loss, Rehk, Rhk, gradWmse, gradbmse,
            gradWcov, gradbcov, gradWl1out, gradbl1out, angle)``. Bias gradients
            are always ``0`` and gradients of terms disabled by ``method`` are
            ``0``. ``angle`` is the angle in degrees between ``Rehk`` and its
            initial value ``Rehk0``.
        """
        # Forgetting factors of the exponential moving averages.
        la_R = self.la_R
        la_R2 = self.la_R2
        # B is the batch size, D the dimension of the current layer.
        B, D = hk.size()
        method = self.method
        tracks_cov = method in ("ebd", "dfa2")
        tracks_cross = method == "ebd"

        if tracks_cov:
            # Current layer autocorrelation update.
            Rhk_update = (hk.T @ hk) / B
            self.new_Rhk = la_R2 * self.Rhk + (1 - la_R2) * Rhk_update
            del Rhk_update

        if tracks_cross:
            # Error vs current layer cross-correlation update.
            Rehk_update = (er.T @ hk) / B
            self.new_Rehk = la_R * self.Rehk + (1 - la_R) * Rehk_update
            del Rehk_update

        # Transformed error broadcast to this layer.
        if layer > 1.0:
            Q = er
        elif method in ("dfa1", "dfa2"):
            Q = er @ self.Rehk0
        else:
            Q = er @ self.new_Rehk

        subhk = torch.sign(hk)
        if layer > 1.0:
            Fd = torch.ones(uk.shape, dtype=torch.float64, device=uk.device, requires_grad=False)
        else:
            Fd = (uk > 0)  # assuming bounded relu
        # Transformed error scaled by the nonlinearity derivative.
        Z = Fd * Q

        # Gradient for the broadcast-error decorrelation loss.
        gradWmse = Z.T @ hkm1 / B
        gradbmse = 0

        # Gradient for the layer entropy (log-determinant of the covariance).
        if tracks_cov:
            hkRhkinvFd = Fd * torch.linalg.solve(self.new_Rhk + self.R_eps, hk + 1e-8, left=False)
            gradWcov = -2 * hkRhkinvFd.T @ hkm1 / B / D
        else:
            gradWcov = 0
        gradbcov = 0

        # Subgradient of the layer-sparsifying l1 loss.
        if method == "ebd":
            gradWl1out = subhk.T @ hkm1 / B / D
        else:
            gradWl1out = 0
        gradbl1out = 0

        # Commit only the statistics that were actually recomputed this step.
        # Committing the never-updated, zero-initialised ``new_*`` buffers would
        # wipe the state for the DFA methods and make ``angle`` NaN.
        if tracks_cov:
            self.Rhk = self.new_Rhk.detach()
        if tracks_cross:
            self.Rehk = self.new_Rehk.detach()

        NMSEloss = torch.norm(self.Rehk, p='fro')
        # For methods that do not track the covariance, new_Rhk keeps its zero
        # initial value, so this reduces to -logdet(R_eps) / D.
        Cov_loss = -(torch.logdet(self.new_Rhk + self.R_eps)) / D
        angle = self._angle_deg(self.Rehk, self.Rehk0)
        return NMSEloss, Cov_loss, self.Rehk, self.Rhk, gradWmse, gradbmse, gradWcov, gradbcov, gradWl1out, gradbl1out, angle
    
    
class LossLocallyConnected(nn.Module):
    """Local loss and manual gradients for one locally connected layer.

    Tracks an exponentially averaged error/output cross-correlation ``ReHk``
    of shape ``(out_dim, output_channel, layer_dim[0], layer_dim[1])`` and
    produces weight gradients for:

    * the broadcast-error decorrelation term (all methods),
    * the subgradient of ``||h||_1 / ||h||_2`` per sample (``"ebd"`` only),
    * a ``-logdet(W W^T)`` weight term, gradient up to a factor of 2
      (``"ebd"`` and ``"dfa2"``).

    ``args`` must provide ``out_dim``, ``output_channel``, ``layer_dim``
    (indexable, two entries), ``R_ini``, ``Reh_gain``, ``la_R``,
    ``batch_size``, ``kernel_size``, ``stride`` and ``method``. An optional
    ``args.device`` selects where the state tensors live (default ``'cuda'``,
    matching the original behaviour).
    """

    def __init__(self, args):
        super(LossLocallyConnected, self).__init__()
        # Optional override; defaults to the originally hard-coded CUDA device.
        device = getattr(args, 'device', 'cuda')
        opts = dict(dtype=torch.float64, device=device, requires_grad=False)
        rows, cols = args.layer_dim[0], args.layer_dim[1]
        # Running error/output cross-correlation, (out_dim, output_channel, rows, cols).
        self.ReHk = torch.zeros((args.out_dim, args.output_channel, rows, cols), **opts)
        # Per-channel R_ini * I; expand() shares memory across channels (read-only view).
        self.RHk = args.R_ini * torch.eye(rows * cols, **opts).unsqueeze(0).expand(args.output_channel, -1, -1)
        # Candidate cross-correlation computed in forward(); zero until first updated.
        self.new_ReHk = torch.zeros((args.out_dim, args.output_channel, rows, cols), **opts)
        nn.init.xavier_uniform_(self.ReHk, gain=args.Reh_gain)
        # Frozen copy of the random initial cross-correlation; the DFA methods
        # use it as a fixed feedback tensor.
        self.ReHk0 = copy.deepcopy(self.ReHk)
        self.la_R = args.la_R
        self.batch_size = args.batch_size
        self.kernel_size = args.kernel_size
        self.eyemat = torch.eye(args.output_channel, **opts) * 1e-5
        self.stride = args.stride
        self.method = args.method

    @staticmethod
    def _angle_deg(a, b):
        """Return the angle in degrees between ``a`` and ``b`` viewed as flat vectors.

        The cosine is clamped to [-1, 1] so floating-point rounding cannot push
        it outside the domain of ``acos`` (which would give NaN). Zero-norm
        inputs still yield NaN because the angle is undefined.
        """
        cos = torch.sum(a * b) / (torch.norm(a, p='fro') * torch.norm(b, p='fro'))
        return 180 * torch.acos(torch.clamp(cos, -1.0, 1.0)) / math.pi

    def forward(self, Wk: torch.Tensor, Hk: torch.Tensor, Uk: torch.Tensor, er: torch.Tensor, Hkm1: torch.Tensor, poolmask: torch.Tensor) -> tuple:
        """Update ``ReHk`` (``"ebd"`` only) and compute the local gradients.

        Args:
            Wk: layer weights, 6-D ``(1, PO, PI, Mf, Nf, Wf2)`` (unpacked from
                ``Wk.shape``; ``Wf2`` presumably ``kernel_size**2`` -- TODO confirm).
            Hk: layer output ``(B, PO, Mf, Nf)``.
            Uk: pre-activation, same layout as ``Hk``; its derivative is
                approximated by ``Uk > 0``.
            er: broadcast error ``(B, out_dim)``.
            Hkm1: previous layer output ``(B, C_in, H, W)``; zero-padded and
                unfolded into ``kernel_size`` patches with ``stride``.
            poolmask: unused.

        Returns:
            ``(NMSELoss, COV_Loss, dW, db, dW_cov, db_cov, dWl1, dbl1, RW, angle)``.
            Bias gradients are always ``0``; gradients of terms disabled by
            ``method`` are ``0``; ``RW`` is a ``(1, 1, 1)`` zero placeholder when
            the covariance term is off.
        """
        # Forgetting factor of the exponential moving average.
        la_R = self.la_R
        # Average over the batch actually passed in: the last batch of an epoch
        # can be smaller than self.batch_size (using the nominal size broke the
        # view() below and mis-scaled every gradient).
        B = Hk.size(0)
        kk = self.kernel_size
        ss = self.stride
        _, PO, PI, Mf, Nf, Wf2 = Wk.shape
        use_ebd = self.method == "ebd"

        if use_ebd:
            # Error vs layer output cross-correlation update.
            Rehk_update = torch.einsum('bq,bpmn->qpmn', er, Hk) / B
            self.new_ReHk = la_R * self.ReHk + (1 - la_R) * Rehk_update  # q,p,m,n
            del Rehk_update

        # ReLU-style derivative mask, cast to float so einsum sees a single dtype.
        dfU = (Uk > 0).to(self.ReHk0.dtype)  # b,p,m,n
        # DFA variants project the error through the fixed initial tensor.
        feedback = self.ReHk0 if self.method in ("dfa1", "dfa2") else self.new_ReHk
        e_ReH_dfU = torch.einsum('bpmn,qpmn,bq->pbmn', dfU, feedback, er)  # p,b,m,n

        # Symmetric zero padding of (kk - 1) // 2 per spatial side, then
        # extract kk x kk patches with the configured stride.
        padding = (kk - 1) // 2
        self.padding_dims = (padding, padding, padding, padding)
        Hkm1_pad = F.pad(Hkm1, self.padding_dims, mode='constant', value=0)
        b, c = Hkm1_pad.shape[0], Hkm1_pad.shape[1]
        x = Hkm1_pad.unfold(2, kk, ss).unfold(3, kk, ss)
        patches = x.reshape(b, c, x.shape[2], x.shape[3], -1)  # b,i,m,n,w
        dW = torch.einsum('bimnw,obmn->oimnw', patches, e_ReH_dfU) / B
        db = 0
        del Hkm1_pad, e_ReH_dfU

        if use_ebd:
            # Subgradient of ||h||_1 / ||h||_2 per sample:
            # (||h||_2^2 * sign(h) - ||h||_1 * h) / ||h||_2^3, via row broadcasting
            # instead of building B x B diagonal matrices.
            hk = Hk.reshape(B, -1)
            l2 = torch.norm(hk, 2, dim=1, keepdim=True)
            l1 = torch.norm(hk, 1, dim=1, keepdim=True)
            subghk = (l2 ** 2 * torch.sign(hk) - l1 * hk) * (1 / (l2 ** 3 + 1e-10))
            Hkd = subghk.view(B, PO, Mf, Nf)
            dWl1 = torch.einsum('bimnw,bomn->oimnw', patches, Hkd) / B
            del subghk, hk
        else:
            dWl1 = 0
        dbl1 = 0
        del x, patches

        if self.method in ("ebd", "dfa2"):
            # -logdet(W W^T) decorrelation of the flattened output-channel filters.
            Wkm = Wk.view(PO, PI * Mf * Nf * Wf2)
            RW = Wkm @ Wkm.T
            Idm = 1e-6 * torch.eye(RW.shape[0], dtype=torch.float64, device=Wk.device, requires_grad=False)
            RWi = torch.inverse(RW + Idm)
            dW_cov = (-RWi @ Wkm).view(1, PO, PI, Mf, Nf, Wf2)
            COV_Loss = -torch.logdet(RW + Idm)
            del Wkm
        else:
            dW_cov = 0
            COV_Loss = 0
            RW = torch.zeros((1, 1, 1))
        db_cov = 0

        # Commit only when the cross-correlation was recomputed; committing the
        # never-updated, zero-initialised new_ReHk would wipe the state for the
        # DFA methods and make ``angle`` NaN.
        if use_ebd:
            self.ReHk = self.new_ReHk.detach()
        NMSELoss = torch.norm(self.ReHk, p='fro')
        angle = self._angle_deg(self.ReHk, self.ReHk0)
        return NMSELoss, COV_Loss, dW, db, dW_cov, db_cov, dWl1, dbl1, RW, angle