import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import torch.nn.functional as F
import math
import copy


#################################################################################################
# Class definitions for layer-loss hyperparameters
class nmseargsstructCNN_LC:
    """Hyperparameter bundle for a convolutional layer loss.

    Presumably used as the ``args`` of ``LossConvolutive``. The thirteen
    positional fields are stored unchanged, in order, as:

    ``layer_dim``, ``out_dim``, ``layerinp_dim``, ``la_R``, ``R_ini``,
    ``R_eps_weight``, ``input_channel``, ``output_channel``,
    ``kernel_size``, ``batch_size``, ``stride``, ``Reh_gain``, ``method``.
    """

    def __init__(self, field1, field2, field3, field4, field5, field6, field7,
                 field8, field9, field10, field11, field12, field13):
        (
            self.layer_dim,
            self.out_dim,
            self.layerinp_dim,
            self.la_R,
            self.R_ini,
            self.R_eps_weight,
            self.input_channel,
            self.output_channel,
            self.kernel_size,
            self.batch_size,
            self.stride,
            self.Reh_gain,
            self.method,
        ) = (field1, field2, field3, field4, field5, field6, field7,
             field8, field9, field10, field11, field12, field13)
        
class nmseargsstruct:
    """Hyperparameter bundle for a fully connected layer loss.

    Presumably used as the ``args`` of ``LossFullyConnected``. The nine
    positional fields are stored unchanged, in order, as:

    ``layer_dim``, ``out_dim``, ``layerinp_dim``, ``la_R``, ``R_ini``,
    ``R_eps_weight``, ``Reh_gain``, ``la_R2``, ``method``.
    """

    def __init__(self, field1, field2, field3, field4, field5, field6, field7,
                 field8, field9):
        (
            self.layer_dim,
            self.out_dim,
            self.layerinp_dim,
            self.la_R,
            self.R_ini,
            self.R_eps_weight,
            self.Reh_gain,
            self.la_R2,
            self.method,
        ) = (field1, field2, field3, field4, field5, field6, field7,
             field8, field9)
    
    
class LossFullyConnected(nn.Module):
    """Layer-local loss statistics and hand-computed gradients for a dense layer.

    Keeps exponentially averaged (convex-combination) estimates of the
    layer-output covariance ``Rhk`` and of the error / layer-output cross
    covariance ``Rehk``. ``forward`` returns weight gradients for three
    terms: a broadcast-error term, a log-determinant ("layer entropy") term
    and an l1 sparsity subgradient. Which terms are active depends on
    ``args.method`` (``"ebd"``, ``"dfa1"`` or ``"dfa2"``).

    Args:
        args: hyperparameter container (e.g. ``nmseargsstruct``) providing
            ``layer_dim``, ``out_dim``, ``la_R``, ``la_R2``, ``R_ini``,
            ``R_eps_weight``, ``Reh_gain`` and ``method``. An optional
            ``device`` attribute selects where the state tensors live; it
            defaults to ``'cuda'``.
    """

    def __init__(self, args):
        super(LossFullyConnected, self).__init__()
        self.device = getattr(args, 'device', 'cuda')
        opts = dict(dtype=torch.float64, device=self.device, requires_grad=False)
        # Error vs layer-output cross covariance, shape (out_dim, layer_dim), random init.
        self.Rehk = torch.zeros((args.out_dim, args.layer_dim), **opts)
        nn.init.normal_(self.Rehk, mean=0.0, std=args.Reh_gain)
        # Frozen copy of the initial cross covariance; the fixed feedback for dfa1/dfa2.
        self.Rehk0 = self.Rehk.clone()
        # Layer-output covariance estimate, initialised to R_ini * I.
        self.Rhk = args.R_ini * torch.eye(args.layer_dim, **opts)
        # Candidate statistics for the current step (only refreshed by the methods that use them).
        self.new_Rhk = torch.zeros((args.layer_dim, args.layer_dim), **opts)
        self.new_Rehk = torch.zeros((args.out_dim, args.layer_dim), **opts)
        self.la_R = args.la_R
        self.la_R2 = args.la_R2
        self.R_eps_weight = args.R_eps_weight
        # Diagonal regulariser added before solving / taking the log-determinant.
        self.R_eps = self.R_eps_weight * torch.eye(args.layer_dim, **opts)
        self.method = args.method

    @staticmethod
    def _angle_deg(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return the angle in degrees between ``a`` and ``b`` viewed as flat vectors."""
        cos = torch.sum(a * b) / (torch.norm(a, p='fro') * torch.norm(b, p='fro'))
        # Rounding can push |cos| marginally above 1, where acos returns NaN.
        return 180 * torch.acos(torch.clamp(cos, -1.0, 1.0)) / math.pi

    def forward(self, hk: torch.Tensor, uk: torch.Tensor, er: torch.Tensor, hkm1: torch.Tensor, layer: torch.Tensor):
        """Update the running statistics and return the layer-local gradients.

        Args:
            hk: layer output, shape (B, D) with B the batch size and D the layer width.
            uk: layer pre-activation; only ``uk > 0`` is used (assumes a ReLU-type
                nonlinearity). Must broadcast against the transformed error.
            er: output error, shape (B, out_dim) when the cross covariance is used.
            hkm1: previous-layer output, shape (B, D_in).
            layer: scalar layer indicator. Values > 1 use the raw error ``er`` and
                a unit activation derivative; the error must then match ``uk``'s
                shape (NOTE(review): presumably the top layer — confirm with caller).

        Returns:
            Tuple ``(NMSEloss, Cov_loss, Rehk, Rhk, gradWmse, gradbmse, gradWcov,
            gradbcov, gradWl1out, gradbl1out, angle)``. Bias gradients are always
            0, and inactive terms are returned as 0. ``angle`` is the angle in
            degrees between ``Rehk`` and its initial value.
        """
        B, D = hk.size()
        uses_cov = self.method in ("ebd", "dfa2")
        uses_cross = self.method == "ebd"

        # Convex-combination updates of the running second-order statistics.
        if uses_cov:
            self.new_Rhk = self.la_R2 * self.Rhk + (1 - self.la_R2) * ((hk.T @ hk) / B)
        if uses_cross:
            self.new_Rehk = self.la_R * self.Rehk + (1 - self.la_R) * ((er.T @ hk) / B)

        # Transformed error Q and activation derivative Fd.
        if layer > 1.0:
            Q = er
            Fd = torch.ones(uk.shape, dtype=torch.float64, device=uk.device, requires_grad=False)
        else:
            feedback = self.Rehk0 if self.method in ("dfa1", "dfa2") else self.new_Rehk
            Q = er @ feedback
            Fd = uk > 0
        Z = Fd * Q

        # Gradient of the broadcast-error term.
        gradWmse = Z.T @ hkm1 / B
        gradbmse = 0

        # Gradient of the log-det (layer entropy) term: hk (Rhk + eps I)^{-1}, masked by Fd.
        if uses_cov:
            hkRhkinvFd = Fd * torch.linalg.solve(self.new_Rhk + self.R_eps, hk + 1e-8, left=False)
            gradWcov = -2 * hkRhkinvFd.T @ hkm1 / B / D
        else:
            gradWcov = 0
        gradbcov = 0

        # Subgradient of the sparsifying l1 term.
        if uses_cross:
            gradWl1out = torch.sign(hk).T @ hkm1 / B / D
        else:
            gradWl1out = 0
        gradbl1out = 0

        # Commit only the statistics refreshed this step. Committing the untouched
        # zero placeholders would wipe the R_ini * I init and make the angle 0/0 = NaN.
        if uses_cov:
            self.Rhk = self.new_Rhk.detach()
        if uses_cross:
            self.Rehk = self.new_Rehk.detach()

        NMSEloss = torch.norm(self.Rehk, p='fro')
        Cov_loss = -(torch.logdet(self.new_Rhk + self.R_eps)) / D
        angle = self._angle_deg(self.Rehk, self.Rehk0)
        return NMSEloss, Cov_loss, self.Rehk, self.Rhk, gradWmse, gradbmse, gradWcov, gradbcov, gradWl1out, gradbl1out, angle
    
    
class LossConvolutive(nn.Module):
    """Layer-local loss statistics and hand-computed gradients for a conv layer.

    Keeps an exponentially averaged (convex-combination) cross covariance
    ``ReHk`` between the output error and the layer output. ``forward``
    returns weight gradients for three terms: a broadcast-error term, an
    l1/l2 sparsity term and a log-determinant term on the weight Gram
    matrix. Which terms are active depends on ``args.method`` (``"ebd"``,
    ``"dfa1"`` or ``"dfa2"``).

    Args:
        args: hyperparameter container (e.g. ``nmseargsstructCNN_LC``)
            providing ``out_dim``, ``output_channel``, ``layer_dim`` (a pair
            giving the layer's output height and width), ``R_ini``, ``la_R``,
            ``batch_size``, ``kernel_size``, ``stride``, ``Reh_gain`` and
            ``method``. An optional ``device`` attribute selects where the
            state tensors live; it defaults to ``'cuda'``.
    """

    def __init__(self, args):
        super(LossConvolutive, self).__init__()
        self.device = getattr(args, 'device', 'cuda')
        opts = dict(dtype=torch.float64, device=self.device, requires_grad=False)
        stat_shape = (args.out_dim, args.output_channel, args.layer_dim[0], args.layer_dim[1])
        # Error vs layer-output cross covariance, shape (out_dim, C_out, M, N), random init.
        self.ReHk = torch.randn(stat_shape, **opts) * args.Reh_gain
        n_pix = args.layer_dim[0] * args.layer_dim[1]
        # Per-channel covariance initialised to R_ini * I (not read by forward).
        self.RHk = args.R_ini * torch.eye(n_pix, **opts).unsqueeze(0).expand(args.output_channel, -1, -1)
        self.new_ReHk = torch.zeros(stat_shape, **opts)
        # Frozen copy of the initial cross covariance; the fixed feedback for dfa1/dfa2.
        self.ReHk0 = self.ReHk.clone()
        self.la_R = args.la_R
        self.batch_size = args.batch_size
        self.out_dim = args.out_dim
        self.kernel_size = args.kernel_size
        self.eyemat = 1e-5 * torch.eye(args.output_channel, **opts)
        self.stride = args.stride
        self.method = args.method

    @staticmethod
    def _angle_deg(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return the angle in degrees between ``a`` and ``b`` viewed as flat vectors."""
        cos = torch.sum(a * b) / (torch.norm(a, p='fro') * torch.norm(b, p='fro'))
        # Rounding can push |cos| marginally above 1, where acos returns NaN.
        return 180 * torch.acos(torch.clamp(cos, -1.0, 1.0)) / math.pi

    @staticmethod
    def _l1_over_l2_grad(Hk: torch.Tensor) -> torch.Tensor:
        """Per-sample gradient of ||h||_1 / ||h||_2 w.r.t. h, returned in Hk's shape."""
        hk = Hk.reshape(Hk.shape[0], -1)
        n1 = torch.norm(hk, 1, dim=1, keepdim=True)
        n2 = torch.norm(hk, 2, dim=1, keepdim=True)
        # sign(h) / ||h||_2 - ||h||_1 h / ||h||_2^3, with a small guard on the denominator.
        grad = (n2 ** 2 * torch.sign(hk) - n1 * hk) / (n2 ** 3 + 1e-10)
        return grad.view(Hk.shape)

    @staticmethod
    def _dilate(x: torch.Tensor, size, st: int) -> torch.Tensor:
        """Scatter ``x`` onto every ``st``-th position of a zero (…, size[0], size[1]) grid."""
        out = torch.zeros((x.shape[0], x.shape[1], size[0], size[1]), dtype=torch.float64, device=x.device, requires_grad=False)
        out[:, :, 0::st, 0::st] = x
        return out

    def forward(self, Wk: torch.Tensor, Hk: torch.Tensor, Uk: torch.Tensor, er: torch.Tensor, Hkm1: torch.Tensor, poolmask: torch.Tensor) -> tuple:
        """Update the running statistics and return the layer-local gradients.

        Args:
            Wk: conv weights, shape (PO, PI, Mf, Nf).
            Hk: layer output, shape (B, C_out, M, N) with (M, N) == layer_dim.
            Uk: layer pre-activation, same shape as ``Hk``; only ``Uk > 0`` is used.
            er: output error, shape (B, out_dim).
            Hkm1: layer input, shape (B, C_in, H, W).
            poolmask: unused.

        Returns:
            Tuple ``(NMSE_Loss, COV_Loss, dW, db, dW_cov, db_cov, dW0l1, db0l1,
            RW, angle)``. Bias gradients are always 0, and inactive terms are 0.
            ``dW``/``dW0l1`` come from ``conv2d`` with ``Hkm1`` as input, so their
            leading axis is C_in (NOTE(review): presumably transposed to Wk's
            layout by the caller). ``angle`` is the angle in degrees between
            ``ReHk`` and its initial value.
        """
        la_R = self.la_R
        # Use the actual batch of this call. self.batch_size is wrong for a
        # trailing partial batch, and reshaping with it silently mixes samples.
        B = Hk.shape[0]
        kk = self.kernel_size
        PO, PI, Mf, Nf = Wk.shape
        st = self.stride
        pad = int((kk - 1) / 2)
        is_ebd = self.method == "ebd"

        if is_ebd:
            Rehk_update = torch.einsum('bq,bpmn->qpmn', er, Hk) / B
            self.new_ReHk = la_R * self.ReHk + (1 - la_R) * Rehk_update  # q,p,m,n

        # Error projected through the feedback and masked by the activation derivative.
        feedback = self.ReHk0 if self.method in ("dfa1", "dfa2") else self.new_ReHk
        dfU = (Uk > 0).to(feedback.dtype)
        e_ReH_dfU = torch.einsum('bpmn,qpmn,bq->pbmn', dfU, feedback, er)  # p,b,m,n

        if is_ebd:
            Hkdp = self._l1_over_l2_grad(Hk).permute(1, 0, 2, 3)  # p,b,m,n

        # Strided conv: re-insert the skipped positions so a stride-1 conv yields the weight gradient.
        if st > 1:
            full = (Hkm1.shape[2] + 2 * pad - kk + 1, Hkm1.shape[3] + 2 * pad - kk + 1)
            e_ReH_dfU = self._dilate(e_ReH_dfU, full, st)
            if is_ebd:
                Hkdp = self._dilate(Hkdp, full, st)

        # Correlate the input with the per-position signals (batch axis acts as conv channels).
        Hkm1_t = Hkm1.permute(1, 0, 2, 3)
        dW = F.conv2d(Hkm1_t, e_ReH_dfU, bias=None, stride=1, padding=pad) / B
        db = 0
        if is_ebd:
            dW0l1 = F.conv2d(Hkm1_t, Hkdp, bias=None, stride=1, padding=pad) / B
        else:
            dW0l1 = 0
        db0l1 = 0

        if self.method in ("ebd", "dfa2"):
            # -logdet of the smaller Gram matrix of the flattened weights, and its gradient.
            Wkm = Wk.view(PO, PI * Mf * Nf)
            tall = PO > PI * Mf * Nf
            RW = Wkm.T @ Wkm if tall else Wkm @ Wkm.T
            Idm = 1e-5 * torch.eye(RW.shape[0], dtype=torch.float64, device=RW.device, requires_grad=False)
            RWi = torch.inverse(RW + Idm)
            dW_covm = -Wkm @ RWi if tall else -RWi @ Wkm
            dW_cov = dW_covm.view(PO, PI, Mf, Nf)
            db_cov = db * 0
            COV_Loss = -torch.logdet(RW + Idm)
        else:
            dW_cov = 0
            db_cov = 0
            COV_Loss = 0
            RW = torch.zeros((1, 1, 1))  # placeholder Gram matrix

        # Commit only a refreshed estimate. Committing the untouched zero
        # placeholder would make the angle 0/0 = NaN for the DFA methods.
        if is_ebd:
            self.ReHk = self.new_ReHk.detach()
        NMSE_Loss = torch.norm(self.ReHk, p='fro') / torch.norm(Hk, p='fro')
        angle = self._angle_deg(self.ReHk, self.ReHk0)
        return NMSE_Loss, COV_Loss, dW, db, dW_cov, db_cov, dW0l1, db0l1, RW, angle