import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torchvision
import torch.nn.functional as F

import os
from datetime import datetime
import time
import math

from itertools import repeat
from torch.nn.parameter import Parameter
import collections
from torch_utilsn import *
import matplotlib


class EBDCorInfoMaxHopfield():
    """CorInfoMax network with Hopfield-style neural dynamics and error-broadcast learning.

    Every synaptic collection (``Wff``, ``Wfb``, ``B``, ``Web``, ...) is a 1-D
    ``numpy`` object array whose entries are ``{'weight': torch.Tensor}``
    dictionaries, one per transition between consecutive entries of
    ``architecture``. Neuron states are stored column-wise: each layer is a
    ``(layer_size, batch_size)`` tensor (see ``init_neurons``).
    """

    def __init__(self, architecture, lambda_, lambda_eb_, epsilon, lr_fr, lr_er, lr_er2=[0, 0, 0],
                 lr_weight_l1=[0, 0, 0], lr_weight_fb_l1=[0, 0, 0], lr_weight_bb_l1=[0, 0, 0],
                 lr_weight_l2=[0, 0, 0], lr_weight_fb_l2=[0, 0, 0], lr_weight_bb_l2=[0, 0, 0],
                 include_forw=1, include_back=0, use_preact=0, subt_mean=0, br_update_lateral=0,
                 activation=hard_sigmoid, act_l1_lr_ff=[0, 0, 0], act_l1_lr_fb=[0, 0, 0],
                 momentum_ff=0, momentum_fb=0, lateral_init_scale=1.0, lateral_scale=[1.0, 1.0, 1.0],
                 Wff_init_scale=[1.0, 1.0, 1.0], act_pow_lr_ff=[0.0, 0.0, 0.0],
                 layer_pow_target=[0.1, 0.1, 0.1], non_fun_args=[1, 0], activation_power=1.0,
                 device=None):
        """Store hyperparameters and initialise every synaptic matrix.

        Args:
            architecture: layer sizes, input first and output last.
            lambda_: forgetting factor of the lateral/correlation updates; also
                defines ``gam_ = (1 - lambda_) / lambda_``.
            lambda_eb_: forgetting factor of the error-broadcast weights and of
                the running means (``err_out_mu``, ``layer_mu``).
            epsilon: noise-variance-like constant; ``1 / epsilon`` scales the
                neural dynamics.
            lr_fr: weight of the local (prediction-error) term in Wff/Wfb updates.
            lr_er, lr_er2: per-layer weights of the broadcast-error terms.
            lr_weight_*: per-layer regularisation rates, stored on the instance
                (``batch_step_hopfield`` takes its own copies as arguments).
            activation: stored as ``self.activation`` (not used by the methods here).
            lateral_init_scale: scale of the initial lateral matrices ``B``.
            Wff_init_scale: per-layer scale of the initial feedforward weights.
            non_fun_args: ``(frequency, phase)`` of the sinusoid in ``non_fun``.
            activation_power: exponent used by ``activation2``/``activation_der``.
            device: torch device; defaults to ``cuda:0`` when available, else CPU.
        """
        self.architecture = architecture
        self.lambda_ = lambda_
        self.lambda_eb_ = lambda_eb_
        self.gam_ = (1 - lambda_) / lambda_
        self.epsilon = epsilon
        self.one_over_epsilon = 1 / epsilon
        self.lr_er = lr_er
        self.lr_er2 = lr_er2
        self.lr_fr = lr_fr
        self.include_forw = include_forw
        self.include_back = include_back
        self.use_preact = use_preact
        self.subt_mean = subt_mean
        self.br_update_lateral = br_update_lateral
        self.act_l1_lr_ff = act_l1_lr_ff
        self.act_l1_lr_fb = act_l1_lr_fb
        self.momentum_ff = momentum_ff
        self.momentum_fb = momentum_fb
        self.lateral_scale = lateral_scale
        self.act_pow_lr_ff = act_pow_lr_ff
        self.layer_pow_target = layer_pow_target
        # Frequency and phase of the sinusoidal nonlinearity (non_fun / non_fun_der).
        self.args_w = non_fun_args[0]
        self.args_ph = non_fun_args[1]
        self.non_fun_der_scale = 5.0
        # Honour the caller's exponent (it used to be overwritten with 1.0).
        self.activation_power = activation_power
        self.lr_weight_l1 = lr_weight_l1
        self.lr_weight_fb_l1 = lr_weight_fb_l1
        self.lr_weight_bb_l1 = lr_weight_bb_l1
        self.lr_weight_l2 = lr_weight_l2
        self.lr_weight_fb_l2 = lr_weight_fb_l2
        self.lr_weight_bb_l2 = lr_weight_bb_l2

        self.activation = activation
        if device is None:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        n_layers = len(architecture)

        # Feedforward synapses Wff[idx]: layer idx -> idx + 1, plus momentum
        # buffers (Vff) and per-layer power buffers (Pow).
        # The allocation order below fixes the RNG stream; keep it stable.
        Wff, Vff, Pow = [], [], []
        for idx in range(n_layers - 1):
            fan_out, fan_in = architecture[idx + 1], architecture[idx]
            weight = torch.randn(fan_out, fan_in, requires_grad=False).to(self.device)
            torch.nn.init.xavier_uniform_(weight)
            momentum_buf = torch.zeros((fan_out, fan_in), requires_grad=False).to(self.device)
            powv = torch.zeros((fan_out, 1), requires_grad=False).to(self.device)
            Wff.append({'weight': Wff_init_scale[idx] * weight})
            Vff.append({'weight': momentum_buf})
            Pow.append(powv)
        Wff = np.array(Wff)
        Vff = np.array(Vff)

        # Running per-layer means (input included) subtracted from activities;
        # they stay zero unless subt_mean == 1 (see batch_step_hopfield).
        layer_mu = [torch.zeros((size, 1), requires_grad=False).to(self.device) for size in architecture]

        # Feedback synapses Wfb[idx]: layer idx + 1 -> idx, plus momentum buffers.
        Wfb, Vfb = [], []
        for idx in range(n_layers - 1):
            weight = torch.eye(architecture[idx], architecture[idx + 1], requires_grad=False).to(self.device)
            torch.nn.init.xavier_uniform_(weight)
            momentum_buf = torch.zeros((architecture[idx], architecture[idx + 1]), requires_grad=False).to(self.device)
            Wfb.append({'weight': weight})
            Vfb.append({'weight': momentum_buf})
        Wfb = np.array(Wfb)
        Vfb = np.array(Vfb)

        # Lateral synapses: symmetric PSD matrices W @ W.T scaled by lateral_init_scale.
        B = []
        for idx in range(n_layers - 1):
            size = architecture[idx + 1]
            weight = torch.randn(size, size, requires_grad=False).to(self.device)
            torch.nn.init.xavier_uniform_(weight)
            weight = lateral_init_scale * weight @ weight.T
            B.append({'weight': weight})
        B = np.array(B)

        # Error-broadcast synapses: map the output error to each layer
        # (Web: linear path, Web2: path through non_fun).
        Web = []
        for idx in range(n_layers - 1):
            weight = torch.zeros((architecture[idx + 1], architecture[-1]), requires_grad=False).to(self.device)
            torch.nn.init.xavier_uniform_(weight)
            Web.append({'weight': weight})
        Web = np.array(Web)

        Web2 = []
        for idx in range(n_layers - 1):
            weight = torch.zeros((architecture[idx + 1], architecture[-1]), requires_grad=False).to(self.device)
            torch.nn.init.xavier_uniform_(weight)
            Web2.append({'weight': weight})
        Web2 = np.array(Web2)

        # Correlation matrices (debugging only), initialised to the identity.
        Rfree = np.array([
            {'weight': 1.0 * torch.eye(architecture[idx + 1], architecture[idx + 1], requires_grad=False).to(self.device)}
            for idx in range(n_layers - 1)
        ])
        Rnudged = np.array([
            {'weight': 1.0 * torch.eye(architecture[idx + 1], architecture[idx + 1], requires_grad=False).to(self.device)}
            for idx in range(n_layers - 1)
        ])

        # Running mean of the output error (see batch_step_hopfield).
        err_out_mu = torch.zeros((architecture[-1], 1), requires_grad=False).to(self.device)

        self.Wff = Wff
        self.Vff = Vff
        self.Wfb = Wfb
        self.Vfb = Vfb
        self.B = B
        self.Pow = Pow
        self.Web = Web
        self.Web2 = Web2
        self.Rfree = Rfree
        self.Rnudged = Rnudged
        self.layer_mu = layer_mu
        self.err_out_mu = err_out_mu

        ############ Some Debugging Logs ##########################
        self.forward_backward_angles = []
        self.layerwise_forward_corinfo_list_free = []
        self.layerwise_backward_corinfo_list_free = []
        self.layerwise_forward_corinfo_list_nudged = []
        self.layerwise_backward_corinfo_list_nudged = []

        self.neural_dynamics_free_forward_info_list = []
        self.neural_dynamics_free_backward_info_list = []
        self.neural_dynamics_nudged_forward_info_list = []
        self.neural_dynamics_nudged_backward_info_list = []

    ###############################################################
    ############### HELPER METHODS ################################
    ###############################################################
    def activation2(self, neurons):
        """Return ``min(relu(neurons) ** activation_power, 50)`` elementwise."""
        out = F.hardtanh(torch.pow(F.relu(neurons), self.activation_power) / 50.0) * 50.0
        return out

    def activation_der(self, neurons, check_out):
        """Derivative of ``activation2`` with respect to its input.

        Args:
            neurons: pre-activation tensor.
            check_out: when truthy the layer is treated as linear (the output
                layer) and a tensor of ones is returned.

        Returns:
            ``p * x**(p-1)`` where ``x > 0`` and ``x**p <= 50``, else 0.
        """
        if check_out:
            return torch.ones_like(neurons)
        power = self.activation_power
        positive = neurons > 0.0
        # Substitute 1 for non-positive entries before exponentiating: a
        # fractional power of a negative number is NaN, and NaN * 0 stays NaN.
        # Those entries are zeroed by the `positive` mask anyway.
        safe = torch.where(positive, neurons, torch.ones_like(neurons))
        below_clip = torch.pow(safe, power) <= 50.0
        return power * positive * below_clip * torch.pow(safe, power - 1)

    def copy_neurons(self, neurons):
        """Return fresh tensors holding copies of ``neurons`` (sharing no storage)."""
        return [torch.empty_like(n).copy_(n.data) for n in neurons]

    def non_fun(self, inp):
        """Sinusoidal nonlinearity ``sin(args_w * inp + args_ph)``."""
        return torch.sin(self.args_w * inp + self.args_ph)

    def non_fun_der(self, inp):
        """Derivative of ``non_fun``, divided by ``non_fun_der_scale``."""
        return self.args_w * torch.cos(self.args_w * inp + self.args_ph) / self.non_fun_der_scale

    def init_neurons(self, mbs, random_initialize=False, device='cuda'):
        """Create one ``(size, mbs)`` state tensor per non-input layer.

        Args:
            mbs: mini-batch size.
            random_initialize: draw states from N(0, 1) instead of zeros.
            device: device on which to allocate the tensors.
        """
        factory = torch.randn if random_initialize else torch.zeros
        return [factory((mbs, size), requires_grad=False, device=device).T for size in self.architecture[1:]]

    ###############################################################
    ############### REQUIRED FUNCTIONS FOR DEBUGGING ##############
    ###############################################################
    def angle_between_two_matrices(self, A, B):
        """Computes the angle between two matrices A and B.

        Args:
            A (torch.Tensor): Pytorch tensor of size m times n
            B (torch.Tensor): Pytorch tensor of size m times n

        Returns:
            angle: angle in degrees between the matrices A and B, given by
                (180/pi) * acos[ Tr(A @ B.T) / sqrt(Tr(A @ A.T) * Tr(B @ B.T))]
        """
        cosine = torch.trace(A @ B.T) / torch.sqrt(torch.trace(A @ A.T) * torch.trace(B @ B.T))
        # Rounding can push |cosine| marginally above 1, where acos returns NaN.
        cosine = torch.clamp(cosine, -1.0, 1.0)
        return (180 / math.pi) * torch.acos(cosine)

    def layerwise_forward_and_backward_correlative_information(self, layers, phase="free"):
        """Per-layer forward and backward correlative-information measures.

        Args:
            layers: ``[x] + neurons`` with column-wise ``(size, batch)`` tensors.
            phase: ``"free"`` uses ``Rfree``, ``"nudged"`` uses ``Rnudged``.

        Returns:
            ``(forward_info_list, backward_info_list)``, each a list of
            ``len(architecture) - 2`` Python floats.

        Raises:
            ValueError: if ``phase`` is neither ``"free"`` nor ``"nudged"``.
        """
        Wff = self.Wff
        Wfb = self.Wfb
        if phase == "free":
            R = self.Rfree
        elif phase == "nudged":
            R = self.Rnudged
        else:
            raise ValueError(f"Unknown phase: {phase!r}")
        epsilon = self.epsilon
        one_over_epsilon = self.one_over_epsilon
        device = self.device
        n_terms = len(self.architecture) - 2

        batch_size = layers[0].shape[1]
        log_epsilon = np.log(epsilon)

        forward_info_list = []
        for jj in range(n_terms):
            Identity_Matrix = epsilon * torch.eye(*R[jj + 1]['weight'].shape).to(device)
            forward_info_jj = (torch.logdet(R[jj + 1]['weight'] + Identity_Matrix) - (1 / batch_size) * (one_over_epsilon * torch.norm(layers[jj + 2] - Wff[jj + 1]['weight'] @ layers[jj + 1]) ** 2 - layers[jj + 2].shape[0] * log_epsilon)).item()
            forward_info_list.append(forward_info_jj)

        backward_info_list = []
        for jj in range(n_terms):
            Identity_Matrix = epsilon * torch.eye(*R[jj]['weight'].shape).to(device)
            # NOTE(review): this term uses shape[0] + 1 where the forward term uses
            # shape[0]; possibly a leftover bias dimension — confirm.
            backward_info_jj = (torch.logdet(R[jj]['weight'] + Identity_Matrix) - (1 / batch_size) * (one_over_epsilon * torch.norm((layers[jj + 1]) - Wfb[jj + 1]['weight'] @ layers[jj + 2]) ** 2 - (layers[jj + 1].shape[0] + 1) * log_epsilon)).item()
            backward_info_list.append(backward_info_jj)

        # Return only after *both* loops have run over every layer.
        return forward_info_list, backward_info_list

    ###############################################################
    ############### NEURAL DYNAMICS ALGORITHMS ####################
    ###############################################################
    def run_neural_dynamics_hopfield(self, x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop,
                                     lr_rule="constant", lr_decay_multiplier=0.1,
                                     neural_dynamic_iterations=10, beta=1, take_debug_logs=False):
        """Integrate the neural dynamics for a fixed number of Euler steps.

        Hidden layers pass through ``activation2``; the last layer stays linear
        and is nudged towards ``y`` by ``-beta * (output - y)``.
        ``neurons`` is updated in place.

        Args:
            x: input batch, ``(input_size, batch)``.
            y: target batch, ``(output_size, batch)``.
            neurons: initial states, one tensor per non-input layer.
            hopfield_g: leak/self-coupling strength.
            neural_lr_start, neural_lr_stop: step-size schedule bounds.
            lr_rule: ``"constant"``, ``"divide_by_loop_index"`` or
                ``"divide_by_slow_loop_index"``.
            lr_decay_multiplier: slope of the slow decay schedule.
            neural_dynamic_iterations: number of Euler steps.
            beta: nudging strength; ``0`` is the free phase.
            take_debug_logs: record information measures after every step.

        Returns:
            ``(neurons, forward_info, backward_info, err_out, neurons_intermediate)``
            where ``err_out = output - y`` and ``neurons_intermediate`` holds the
            pre-activation states.

        Raises:
            ValueError: on an unknown ``lr_rule``.
        """
        include_forw = self.include_forw
        include_back = self.include_back
        lateral_scale = self.lateral_scale
        layer_mu = self.layer_mu
        # Contrastive convention: beta == 0 is the free phase, beta != 0 the nudged one.
        phase = "free" if beta == 0 else "nudged"
        forward_info = []
        backward_info = []

        Wff = self.Wff
        Wfb = self.Wfb
        B = self.B
        gam_ = self.gam_
        epsilon = self.epsilon
        one_over_epsilon = self.one_over_epsilon
        last = len(neurons) - 1

        neurons_intermediate = self.copy_neurons(neurons)
        layers = [x] + neurons  # concatenate the input to other layers
        for iter_count in range(neural_dynamic_iterations):

            if lr_rule == "constant":
                neural_lr = neural_lr_start
            elif lr_rule == "divide_by_loop_index":
                neural_lr = max(neural_lr_start / (iter_count + 1), neural_lr_stop)
            elif lr_rule == "divide_by_slow_loop_index":
                neural_lr = max(neural_lr_start / (iter_count * lr_decay_multiplier + 1), neural_lr_stop)
            else:
                raise ValueError(f"Unknown neural lr_rule: {lr_rule!r}")

            with torch.no_grad():
                for jj in range(len(neurons)):
                    basal_voltage = include_forw * Wff[jj]['weight'] @ (layers[jj] - layer_mu[jj])
                    if jj == last:
                        # Output layer: lateral + leak terms, plus the nudge towards y.
                        apical_voltage = (lateral_scale[jj] * gam_ * B[jj]['weight'] @ (layers[jj + 1] - layer_mu[jj + 1]) + hopfield_g * layers[jj + 1]) - beta * (layers[jj + 1] - y)
                        gradient_neurons = -hopfield_g * neurons_intermediate[jj] + one_over_epsilon * (basal_voltage - neurons_intermediate[jj]) + (apical_voltage - neurons_intermediate[jj])
                        neurons_intermediate[jj] = neurons_intermediate[jj] + neural_lr * gradient_neurons
                        neurons[jj] = neurons_intermediate[jj]  # linear output layer
                    else:
                        # Hidden layer: lateral + leak terms, plus feedback from layer jj + 2.
                        apical_voltage = epsilon * (2 * lateral_scale[jj] * gam_ * B[jj]['weight'] @ (layers[jj + 1] - layer_mu[jj + 1]) + hopfield_g * layers[jj + 1]) + include_back * (Wfb[jj + 1]['weight'] @ (layers[jj + 2] - layer_mu[jj + 2]))
                        gradient_neurons = -hopfield_g * neurons_intermediate[jj] + one_over_epsilon * (basal_voltage - neurons_intermediate[jj]) + one_over_epsilon * (apical_voltage - neurons_intermediate[jj])
                        neurons_intermediate[jj] = neurons_intermediate[jj] + neural_lr * gradient_neurons
                        neurons[jj] = self.activation2(neurons_intermediate[jj])
                    layers = [x] + neurons  # concatenate the input to other layers

            if take_debug_logs:
                info_measures = self.layerwise_forward_and_backward_correlative_information(layers, phase)
                forward_info.append(np.sum(info_measures[0]))
                backward_info.append(np.sum(info_measures[1]))

        err_out = layers[len(neurons)] - y
        return neurons, forward_info, backward_info, err_out, neurons_intermediate

    ###############################################################
    ############### BATCH STEP ALGORITHMS #########################
    ###############################################################
    def batch_step_hopfield(self, x, y, hopfield_g, lr, neural_lr_start, neural_lr_stop, neural_lr_rule="constant",
                            neural_lr_decay_multiplier=0.1, neural_dynamic_iterations_free=20,
                            neural_dynamic_iterations_nudged=10, beta=1, use_three_phase=False,
                            take_debug_logs=False, weight_decay=False, lr_weight_l1=[0, 0, 0], lr_weight_fb_l1=[0, 0, 0],
                            lr_weight_bb_l1=[0, 0, 0], lr_weight_l2=[0, 0, 0], lr_weight_fb_l2=[0, 0, 0], lr_weight_bb_l2=[0, 0, 0]):
        """Run the free phase on one batch and update every synaptic matrix in place.

        Order of work: free dynamics (beta = 0), Rfree update, error-broadcast
        weights (Web/Web2), feedforward (Wff/Vff), feedback (Wfb/Vfb) and
        lateral (B) weights, then optional debug logs.

        Args:
            lr: mapping with per-layer learning rates ``lr['ff'][jj]`` and ``lr['fb'][jj]``.
            beta, use_three_phase: only scale the Wff/Wfb steps by
                ``1 / (beta * (use_three_phase + 1))``.
            neural_dynamic_iterations_nudged: accepted for interface
                compatibility; not used here.
            lr_weight_*: per-layer regularisation rates for this step (the
                values stored on the instance are not consulted here).
            Remaining arguments are forwarded to ``run_neural_dynamics_hopfield``.

        Returns:
            The free-phase neuron states.
        """
        Wff, Vff, Wfb, Vfb, B = self.Wff, self.Vff, self.Wfb, self.Vfb, self.B
        Web, Web2 = self.Web, self.Web2
        layer_mu, err_out_mu = self.layer_mu, self.err_out_mu
        lr_fr, lr_er, lr_er2 = self.lr_fr, self.lr_er, self.lr_er2
        lambda_ = self.lambda_
        lambda_eb_ = self.lambda_eb_
        gam_ = self.gam_
        epsilon = self.epsilon
        use_preact = self.use_preact
        subt_mean = self.subt_mean
        br_update_lateral = self.br_update_lateral
        act_l1_lr_ff = self.act_l1_lr_ff
        act_l1_lr_fb = self.act_l1_lr_fb
        momentum_ff = self.momentum_ff
        momentum_fb = self.momentum_fb
        act_pow_lr_ff = self.act_pow_lr_ff
        layer_pow_target = self.layer_pow_target
        Rfree = self.Rfree  # debugging only: correlation matrices

        neurons = self.init_neurons(x.size(1), device=self.device)

        # Free phase: beta = 0 (positional, matching the method's parameter order).
        (neurons,
         free_forward_info,
         free_backward_info,
         err_out, neurons_intermediate
         ) = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, neural_lr_rule,
                                               neural_lr_decay_multiplier, neural_dynamic_iterations_free, 0, take_debug_logs)

        # Neither list is mutated below, so no defensive copies are needed.
        layers_free = [x] + neurons
        layers_intermediate_free = [x] + neurons_intermediate

        for jj in range(len(B)):
            Rfree[jj]['weight'] = lambda_ * Rfree[jj]['weight'] + (1 - lambda_) * torch.mean(outer_prod_broadcasting(neurons[jj].T, neurons[jj].T), axis=0)

        ## Local prediction errors
        forward_errors_free = [layers_free[jj + 1] - (Wff[jj]['weight'] @ layers_free[jj]) for jj in range(len(Wff))]
        backward_errors_free = [(layers_free[jj]) - (Wfb[jj]['weight'] @ layers_free[jj + 1]) for jj in range(1, len(Wfb))]

        # Mean-centred output error (the running mean only moves when subt_mean == 0).
        if subt_mean == 0:
            err_out_mu = lambda_eb_ * err_out_mu + (1 - lambda_eb_) * torch.mean(err_out, axis=1).unsqueeze(1)
        err_hat = err_out - err_out_mu  # broadcasts the (n, 1) mean over the batch

        #################################################################################
        #                     ERROR BROADCAST WEIGHT-UPDATES                            #
        #################################################################################
        for jj in range(1, len(Web)):
            if use_preact == 1:
                Web[jj - 1]['weight'] = lambda_eb_ * Web[jj - 1]['weight'] + (1 - lambda_eb_) * torch.mean(outer_prod_broadcasting(layers_intermediate_free[jj].T, err_out.T), axis=0)
            else:
                Web[jj - 1]['weight'] = lambda_eb_ * Web[jj - 1]['weight'] + (1 - lambda_eb_) * torch.mean(outer_prod_broadcasting((layers_free[jj] - layer_mu[jj]).T, err_out.T), axis=0)
                Web2[jj - 1]['weight'] = lambda_eb_ * Web2[jj - 1]['weight'] + (1 - lambda_eb_) * torch.mean(outer_prod_broadcasting(self.non_fun(layers_free[jj]).T, err_out.T), axis=0)

        # Common step scale of the Wff / Wfb updates.
        step_scale = 1 / (beta * (int(use_three_phase) + 1))

        #################################################################################
        #                  FEEDFORWARD-UPDATES                                          #
        #################################################################################
        n_ff = len(Wff)
        for jj in range(n_ff):
            is_output = jj == n_ff - 1
            if subt_mean == 1:
                layer_mu[jj] = lambda_eb_ * layer_mu[jj] + (1 - lambda_eb_) * torch.mean(layers_free[jj], axis=1).unsqueeze(1)

            if is_output:
                # Hard-coded scale for the output layer's own error.
                err_out_jp1 = err_hat * 0.03
                err_out_jp2 = err_hat * 0.03
            else:
                err_out_jp1 = Web[jj]['weight'] @ err_hat
                err_out_jp2 = (Web[jj]['weight'] @ err_hat if False else Web2[jj]['weight'] @ err_hat) * self.non_fun_der(layers_free[jj + 1])

            centred_pre = (layers_free[jj] - layer_mu[jj]).T
            if use_preact == 1:
                Wff[jj]['weight'] += step_scale * lr['ff'][jj] * \
                    torch.mean(lr_fr * outer_prod_broadcasting(forward_errors_free[jj].T, layers_free[jj].T)
                               - lr_er[jj] * outer_prod_broadcasting(((layers_free[jj + 1]) * err_out_jp1).T, centred_pre), axis=0)
            else:
                act_der = self.activation_der(layers_intermediate_free[jj + 1], is_output)
                Vff[jj]['weight'] = momentum_ff * Vff[jj]['weight'] + (1 - momentum_ff) * \
                    torch.mean(lr_fr * outer_prod_broadcasting(forward_errors_free[jj].T, layers_free[jj].T)
                               - lr_er[jj] * outer_prod_broadcasting((act_der * err_out_jp1).T, centred_pre)
                               - lr_er2[jj] * outer_prod_broadcasting((act_der * err_out_jp2).T, centred_pre), axis=0)
                Wff[jj]['weight'] += step_scale * lr['ff'][jj] * Vff[jj]['weight']
            if weight_decay:
                Wff[jj]['weight'] -= lr_weight_l2[jj] * Wff[jj]['weight']
            Wff[jj]['weight'] -= lr_weight_l1[jj] * epsilon * torch.sign(Wff[jj]['weight'])
            # Activity-sparsity and activity-power penalties.
            Wff[jj]['weight'] -= lr['ff'][jj] * act_l1_lr_ff[jj] * torch.mean(outer_prod_broadcasting(torch.sign(layers_free[jj + 1]).T, layers_free[jj].T), axis=0)
            Wff[jj]['weight'] -= lr['ff'][jj] * act_pow_lr_ff[jj] * torch.mean(outer_prod_broadcasting(((layers_free[jj + 1] * layers_free[jj + 1] - layer_pow_target[jj]) * layers_free[jj + 1]).T, layers_free[jj].T), axis=0)

        #################################################################################
        #                     FEEDBACK-UPDATES                                          #
        #################################################################################
        for jj in range(1, len(Wfb)):
            err_out_j = Web[jj - 1]['weight'] @ err_hat
            if use_preact == 1:
                Wfb[jj]['weight'] += step_scale * lr['fb'][jj] * torch.mean(lr_fr * outer_prod_broadcasting(backward_errors_free[jj - 1].T, layers_free[jj + 1].T) - lr_er[jj - 1] * outer_prod_broadcasting(((layers_free[jj]) * err_out_j).T, layers_free[jj + 1].T), axis=0)
            else:
                # jj < len(Wfb) == len(Wff), so layers_intermediate_free[jj] is a
                # hidden layer: use its real activation derivative (check_out=False).
                act_der = self.activation_der(layers_intermediate_free[jj], False)
                Vfb[jj]['weight'] = momentum_fb * Vfb[jj]['weight'] + (1 - momentum_fb) * torch.mean(lr_fr * outer_prod_broadcasting(backward_errors_free[jj - 1].T, layers_free[jj + 1].T) - lr_er[jj - 1] * outer_prod_broadcasting((act_der * err_out_j).T, layers_free[jj + 1].T), axis=0)
                Wfb[jj]['weight'] += step_scale * lr['fb'][jj] * Vfb[jj]['weight']

            if weight_decay:
                Wfb[jj]['weight'] -= lr_weight_fb_l2[jj] * Wfb[jj]['weight']

            Wfb[jj]['weight'] -= lr_weight_fb_l1[jj] * epsilon * torch.sign(Wfb[jj]['weight'])
            Wfb[jj]['weight'] -= lr['fb'][jj] * act_l1_lr_fb[jj] * torch.mean(outer_prod_broadcasting(torch.sign(layers_free[jj]).T, layers_free[jj + 1].T), axis=0)

        #################################################################################
        #                             LATERAL WEIGHT-UPDATES                            #
        #################################################################################
        for jj in range(len(B)):
            # Rank-update of B followed by a 1/lambda_ rescale.
            z = B[jj]['weight'] @ neurons[jj]
            B_update = torch.mean(outer_prod_broadcasting(z.T, z.T), axis=0)
            B[jj]['weight'] = (1 / lambda_) * (B[jj]['weight'] - gam_ * B_update)
            if weight_decay:
                B[jj]['weight'] -= lr_weight_bb_l2[jj] * B[jj]['weight']
            B[jj]['weight'] -= lr_weight_bb_l1[jj] * torch.sign(B[jj]['weight'])

            if br_update_lateral > 0:
                is_output = jj == len(Wff) - 1
                if is_output:
                    err_out_jp1 = err_hat * 0.07
                else:
                    err_out_jp1 = Web[jj]['weight'] @ err_hat
                # Per-layer derivative flag, as in the feedforward update.
                gated_err = self.activation_der(layers_intermediate_free[jj + 1], is_output) * err_out_jp1
                centred = layers_free[jj + 1] - layer_mu[jj + 1]
                # Symmetrised update keeps B symmetric.
                B[jj]['weight'] = B[jj]['weight'] - br_update_lateral * torch.mean(outer_prod_broadcasting(gated_err.T, centred.T) + outer_prod_broadcasting(centred.T, gated_err.T), axis=0)

        # All other containers were updated in place; only err_out_mu was rebound.
        self.err_out_mu = err_out_mu

        if take_debug_logs:
            self.forward_backward_angles.append([
                self.angle_between_two_matrices(self.Wff[jj]['weight'], self.Wfb[jj]['weight'].T).item()
                for jj in range(1, len(Wff))
            ])

            (forward_info_list_free,
             backward_info_list_free,
             ) = self.layerwise_forward_and_backward_correlative_information(layers_free, "free")

            self.layerwise_forward_corinfo_list_free.append(forward_info_list_free)
            self.layerwise_backward_corinfo_list_free.append(backward_info_list_free)

            self.neural_dynamics_free_forward_info_list.append(free_forward_info)
            self.neural_dynamics_free_backward_info_list.append(free_backward_info)
        return neurons

    def save_model_weights(self, pickle_name="EBDCorInfoWeights"):
        """Pickle Wff, Wfb and B (as numpy arrays) to ``pickle_name + ".pkl"``."""
        model_params = pd.DataFrame(columns=['Wff', 'Wfb', 'B'])
        model_params['Wff'] = [{'weight': torch2numpy(p['weight'])} for p in self.Wff]
        model_params['Wfb'] = [{'weight': torch2numpy(p['weight'])} for p in self.Wfb]
        model_params['B'] = [{'weight': torch2numpy(p['weight'])} for p in self.B]

        model_params.to_pickle(pickle_name + ".pkl")

    def load_model_weights(self, pickle_name):
        """Load Wff, Wfb and B written by ``save_model_weights``.

        NOTE: ``pd.read_pickle`` unpickles arbitrary objects; only load trusted files.
        """
        model_params_load = pd.read_pickle(pickle_name + ".pkl")
        for key, params in (('Wff', self.Wff), ('Wfb', self.Wfb), ('B', self.B)):
            for idx in range(len(params)):
                params[idx]['weight'] = torch.tensor(model_params_load[key].iloc[idx]['weight'], requires_grad=False).to(self.device)
                 
class ContrastiveCorInfoMaxHopfieldSparse(EBDCorInfoMaxHopfield):
    """Contrastive CorInfoMax Hopfield network with optional sparse layers.

    ``sparse_layers`` holds indices into ``[x] + neurons``, so ``1`` is the
    first neuron layer. For those layers the neural dynamics do two things
    differently:

    - They apply ``F.relu`` instead of ``self.activation``.
    - They maintain a non-negative per-sample variable in
      ``self.STLAMBD_list`` (shape ``(1, batch)``). It is subtracted from the
      apical voltage and driven by ``sum(neurons[layer], dim=0) - 1``.

    NOTE(review): the STLAMBD variables look like Lagrange multipliers for a
    sum constraint on the layer activity; confirm against the paper.
    """

    def __init__(self, architecture, lambda_, epsilon, activation = hard_sigmoid, sparse_layers = None, device = None, **base_kwargs):
        """Store the sparse-layer indices and initialize the base network.

        Args:
            architecture: Layer sizes, forwarded to the base class.
            lambda_: Forwarded to the base class.
            epsilon: Forwarded to the base class by keyword.
            activation: Nonlinearity used for the non-sparse layers.
            sparse_layers: Indices (into ``[x] + neurons``) of the sparse
                layers. ``None`` means no sparse layers.
            device: Torch device, forwarded to the base class.
            **base_kwargs: Extra keyword arguments for
                ``EBDCorInfoMaxHopfield.__init__``. That constructor also
                requires ``lambda_eb_``, ``lr_fr`` and ``lr_er``.
        """
        self.sparse_layers = [] if sparse_layers is None else sparse_layers
        # The base signature is (architecture, lambda_, lambda_eb_, epsilon,
        # lr_fr, lr_er, ...). Everything after lambda_ must go by keyword;
        # passing it positionally would bind epsilon -> lambda_eb_,
        # activation -> epsilon and device -> lr_fr.
        super().__init__(architecture, lambda_, epsilon = epsilon, activation = activation, device = device, **base_kwargs)

    @staticmethod
    def _scheduled_neural_lr(lr_rule, iter_count, neural_lr_start, neural_lr_stop, lr_decay_multiplier):
        """Return the neural step size for iteration ``iter_count`` under ``lr_rule``."""
        if lr_rule == "constant":
            return neural_lr_start
        if lr_rule == "divide_by_loop_index":
            return max(neural_lr_start / (iter_count + 1), neural_lr_stop)
        if lr_rule == "divide_by_slow_loop_index":
            return max(neural_lr_start / (iter_count * lr_decay_multiplier + 1), neural_lr_stop)
        raise ValueError(f"Unknown neural lr_rule: {lr_rule!r}")

    ###############################################################
    ############### NEURAL DYNAMICS ALGORITHMS ####################
    ###############################################################
    def run_neural_dynamics_hopfield(self, x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, STlambda_lr_list,
                                     lr_rule = "constant", lr_decay_multiplier = 0.1, 
                                     neural_dynamic_iterations = 10, beta = 1, take_debug_logs = False):
        """Run the layer-by-layer neural dynamics for one phase.

        ``x`` is (input_size, batch); the batch size is read from ``x.size(1)``.
        Each iteration updates the layers in order, so later layers see the
        already-updated earlier ones.

        Args:
            x: Input tensor.
            y: Target; the output layer is nudged by ``-beta * (output - y)``.
            neurons: List of layer states. It is updated in place and returned.
            hopfield_g: Leak/self-coupling coefficient of the dynamics.
            neural_lr_start: Initial neural step size.
            neural_lr_stop: Lower bound for the decaying schedules.
            STlambda_lr_list: Per-layer step sizes of the sparsity variables.
            lr_rule: "constant", "divide_by_loop_index" or
                "divide_by_slow_loop_index".
            lr_decay_multiplier: Decay rate for "divide_by_slow_loop_index".
            neural_dynamic_iterations: Number of iterations.
            beta: Nudging strength. ``0`` runs the free phase and resets the
                sparsity variables to zero. Otherwise they continue from
                ``self.STLAMBD_list``.
            take_debug_logs: When true, record the summed forward and backward
                information after every iteration.

        Returns:
            tuple: ``(neurons, forward_info, backward_info)``. The info lists
            are empty unless ``take_debug_logs`` is set.

        Raises:
            ValueError: If ``lr_rule`` is not recognized.
        """
        mbs = x.size(1)
        if beta == 0:
            STLAMBD_list = [torch.zeros(1, mbs).to(self.device) for _ in range(len(neurons))]
        else:
            STLAMBD_list = self.STLAMBD_list
        STLAMBD_list_intermediate = self.copy_neurons(STLAMBD_list)
        # beta == 0 is the free phase (see batch_step_hopfield). The label
        # selects which correlation matrix the debug measures use.
        phase = "free" if beta == 0 else "nudged"
        forward_info = []
        backward_info = []

        Wff, Wfb, B = self.Wff, self.Wfb, self.B
        gam_ = self.gam_
        epsilon = self.epsilon
        one_over_epsilon = self.one_over_epsilon
        output_idx = len(neurons) - 1

        neurons_intermediate = self.copy_neurons(neurons)
        layers = [x] + neurons  # concatenate the input to other layers
        for iter_count in range(neural_dynamic_iterations):
            neural_lr = self._scheduled_neural_lr(lr_rule, iter_count, neural_lr_start, neural_lr_stop, lr_decay_multiplier)

            with torch.no_grad():
                for jj in range(len(neurons)):
                    is_sparse = (jj + 1) in self.sparse_layers
                    state = layers[jj + 1]
                    basal_voltage = Wff[jj]['weight'] @ layers[jj]
                    if jj == output_idx:
                        # Output layer: lateral term plus the nudge towards y;
                        # the apical error enters with weight 1 (not 1/epsilon).
                        apical_voltage = (gam_ * B[jj]['weight'] @ (state) + hopfield_g * state)
                        if is_sparse:
                            apical_voltage = apical_voltage - STLAMBD_list[jj]
                        apical_voltage = apical_voltage - beta * (state - y)
                        gradient_neurons = -hopfield_g * neurons_intermediate[jj] + one_over_epsilon * (basal_voltage - neurons_intermediate[jj]) + (apical_voltage - neurons_intermediate[jj])
                    else:
                        # Hidden layer: lateral term plus feedback from the next layer.
                        apical_voltage = epsilon * (2 * gam_ * B[jj]['weight'] @ (state) + hopfield_g * state) + (Wfb[jj + 1]['weight'] @ layers[jj + 2])
                        if is_sparse:
                            apical_voltage = apical_voltage - epsilon * STLAMBD_list[jj]
                        gradient_neurons = - hopfield_g * neurons_intermediate[jj] + one_over_epsilon * (basal_voltage - neurons_intermediate[jj]) + one_over_epsilon * (apical_voltage - neurons_intermediate[jj])

                    neurons_intermediate[jj] = neurons_intermediate[jj] + neural_lr * gradient_neurons
                    if is_sparse:
                        neurons[jj] = F.relu(neurons_intermediate[jj])
                        # Drive the sparsity variable with the per-sample (column) sum minus 1.
                        STLAMBD_list_intermediate[jj] = STLAMBD_list_intermediate[jj] + STlambda_lr_list[jj] * (-STLAMBD_list_intermediate[jj] + (torch.sum(neurons[jj], 0).view(1, -1) - 1) + STLAMBD_list[jj])
                        STLAMBD_list[jj] = F.relu(STLAMBD_list_intermediate[jj])
                    else:
                        neurons[jj] = self.activation(neurons_intermediate[jj])
                    # Rebuild so the next layer sees this layer's new state.
                    layers = [x] + neurons

            if take_debug_logs:
                info_measures = self.layerwise_forward_and_backward_correlative_information(layers, phase)
                forward_info.append(np.sum(info_measures[0]))
                backward_info.append(np.sum(info_measures[1]))

        self.STLAMBD_list = STLAMBD_list
        return neurons, forward_info, backward_info

    ###############################################################
    ############### BATCH STEP ALGORITHMS #########################
    ###############################################################
    def batch_step_hopfield(self, x, y, hopfield_g, lr, neural_lr_start, neural_lr_stop, STlambda_lr_list, neural_lr_rule = "constant", 
                            neural_lr_decay_multiplier = 0.1, neural_dynamic_iterations_free = 20, 
                            neural_dynamic_iterations_nudged = 10, beta = 1, use_three_phase = False, 
                            take_debug_logs = False, weight_decay = False):
        """Process one mini-batch: free phase, nudged phase(s), then weight updates.

        1. Free phase (``beta = 0``) starting from ``init_neurons``. Its
           states update ``Rfree``.
        2. Nudged phase (``+beta``) starting from the free states.
        3. If ``use_three_phase`` is set, an extra ``-beta`` phase follows.
           Its states replace the free states in the contrastive update,
           which is then scaled by ``1 / (2 beta)`` instead of ``1 / beta``.
        4. Contrastive updates of ``Wff`` and ``Wfb[1:]`` (plus optional
           weight decay), then the lateral ``B`` update and the ``Rnudged``
           update from the nudged states.

        Args:
            lr: Dict with per-layer learning-rate lists under ``'ff'`` and ``'fb'``.
            Other arguments are forwarded to ``run_neural_dynamics_hopfield``.

        Returns:
            The neuron list after the last phase.
        """
        Wff, Wfb, B = self.Wff, self.Wfb, self.B
        lambda_ = self.lambda_
        gam_ = self.gam_
        epsilon = self.epsilon

        Rfree = self.Rfree # For debugging to check the correlation matrices vs inverse correlation matrices
        Rnudged = self.Rnudged # For debugging to check the correlation matrices vs inverse correlation matrices

        neurons = self.init_neurons(x.size(1), device = self.device)
        dynamics_kwargs = dict(lr_rule = neural_lr_rule, lr_decay_multiplier = neural_lr_decay_multiplier,
                               take_debug_logs = take_debug_logs)

        # Free phase.
        (neurons,
         free_forward_info,
         free_backward_info
        ) = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, STlambda_lr_list,
                                              neural_dynamic_iterations = neural_dynamic_iterations_free, beta = 0, **dynamics_kwargs)

        # Shallow copy: later phases reassign list entries, not tensors in place.
        neurons1 = neurons.copy()
        layers_free_ = [x] + neurons1

        for jj in range(len(B)):
            Rfree[jj]['weight'] = lambda_ * Rfree[jj]['weight'] + (1 - lambda_) * torch.mean(outer_prod_broadcasting(neurons1[jj].T, neurons1[jj].T), axis = 0)

        # Nudged phase, starting from the free equilibrium.
        (neurons,
         nudged_forward_info,
         nudged_backward_info
        ) = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, STlambda_lr_list,
                                              neural_dynamic_iterations = neural_dynamic_iterations_nudged, beta = beta, **dynamics_kwargs)
        neurons2 = neurons.copy()

        if use_three_phase:
            neurons, _, _ = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, STlambda_lr_list,
                                                              neural_dynamic_iterations = neural_dynamic_iterations_nudged, beta = -beta, **dynamics_kwargs)
            neurons3 = neurons.copy()
            layers_free = [x] + neurons3
        else:
            layers_free = [x] + neurons1

        layers_nudged = [x] + neurons2

        ## Compute forward errors
        forward_errors_free = [layers_free[jj + 1] - (Wff[jj]['weight'] @ layers_free[jj]) for jj in range(len(Wff))]
        forward_errors_nudged = [layers_nudged[jj + 1] - (Wff[jj]['weight'] @ layers_nudged[jj]) for jj in range(len(Wff))]
        ## Compute backward errors
        backward_errors_free = [(layers_free[jj]) - (Wfb[jj]['weight'] @ layers_free[jj + 1]) for jj in range(1, len(Wfb))]
        backward_errors_nudged = [(layers_nudged[jj]) - (Wfb[jj]['weight'] @ layers_nudged[jj + 1]) for jj in range(1, len(Wfb))]

        # Contrastive scale: 1/beta, or 1/(2 beta) for the symmetric three-phase variant.
        scale = 1 / (beta * (int(use_three_phase) + 1))

        ### Learning updates for feed-forward and backward weights
        for jj in range(len(Wff)):
            Wff[jj]['weight'] += -scale * lr['ff'][jj] * torch.mean(outer_prod_broadcasting(forward_errors_free[jj].T, layers_free[jj].T) - outer_prod_broadcasting(forward_errors_nudged[jj].T, layers_nudged[jj].T), axis = 0)
            if weight_decay:
                Wff[jj]['weight'] -= lr['ff'][jj] * epsilon * Wff[jj]['weight']

        for jj in range(1, len(Wfb)):
            Wfb[jj]['weight'] += -scale * lr['fb'][jj] * torch.mean(outer_prod_broadcasting(backward_errors_free[jj - 1].T, layers_free[jj + 1].T) - outer_prod_broadcasting(backward_errors_nudged[jj - 1].T, layers_nudged[jj + 1].T), axis = 0)
            if weight_decay:
                Wfb[jj]['weight'] -= lr['fb'][jj] * epsilon * Wfb[jj]['weight']

        ### Lateral Weight Updates
        for jj in range(len(B)):
            z = B[jj]['weight'] @ (neurons2[jj])
            B_update = torch.mean(outer_prod_broadcasting(z.T, z.T), axis = 0)
            B[jj]['weight'] = (1 / lambda_) * (B[jj]['weight'] - gam_ * B_update)

            Rnudged[jj]['weight'] = lambda_ * Rnudged[jj]['weight'] + (1 - lambda_) * torch.mean(outer_prod_broadcasting(neurons2[jj].T, neurons2[jj].T), axis = 0)

        self.B = B
        self.Wff = Wff
        self.Wfb = Wfb
        self.Rfree = Rfree
        self.Rnudged = Rnudged

        if take_debug_logs:
            instant_forward_backward_angles = []
            for jj in range(1, len(Wff)):
                instant_forward_backward_angles.append(self.angle_between_two_matrices(self.Wff[jj]['weight'], self.Wfb[jj]['weight'].T).item())

            self.forward_backward_angles.append(instant_forward_backward_angles)

            (forward_info_list_free,
             backward_info_list_free,
            ) = self.layerwise_forward_and_backward_correlative_information(layers_free_, "free")

            # The nudged measures must be computed on the nudged states.
            (forward_info_list_nudged,
             backward_info_list_nudged,
            ) = self.layerwise_forward_and_backward_correlative_information(layers_nudged, "nudged")

            self.layerwise_forward_corinfo_list_free.append(forward_info_list_free)
            self.layerwise_backward_corinfo_list_free.append(backward_info_list_free)
            self.layerwise_forward_corinfo_list_nudged.append(forward_info_list_nudged)
            self.layerwise_backward_corinfo_list_nudged.append(backward_info_list_nudged)

            self.neural_dynamics_free_forward_info_list.append(free_forward_info)
            self.neural_dynamics_free_backward_info_list.append(free_backward_info)
            self.neural_dynamics_nudged_forward_info_list.append(nudged_forward_info)
            self.neural_dynamics_nudged_backward_info_list.append(nudged_backward_info)
        return neurons

class EP(torch.nn.Module):
    """Equilibrium Propagation (EP) network built from fully connected layers.

    Modified from https://github.com/Laborieux-Axel/Equilibrium-Propagation/blob/master/model_utils.py
    The linked implementation runs the neural dynamics as a fixed-point
    iteration, s_(t+1) = sigma( dPhi/ds ). This class instead takes a
    gradient step on Phi and applies the activation afterwards:
    s_(t+1) = sigma( s_t - neural_lr * dPhi/ds ).

    Neural states are (batch, size) tensors, one per non-input layer.
    """
    def __init__(self, architecture, activation = hard_sigmoid):
        """Create one Linear layer per pair of consecutive layers.

        Args:
            architecture (list[int]): Layer sizes ``[input, hidden..., output]``.
                The last entry is stored as the number of classes ``self.nc``.
            activation (callable): Nonlinearity applied after each neural step.
        """
        super(EP, self).__init__()

        self.activation = activation
        self.architecture = architecture
        self.nc = self.architecture[-1]

        # Feedforward and feedback synapses: Xavier-uniform weights, zero biases.
        self.W = torch.nn.ModuleList()
        for idx in range(len(architecture)-1):
            m = torch.nn.Linear(architecture[idx], architecture[idx+1], bias=True)
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.mul_(0)
            self.W.append(m)

    def Phi(self, x, y, neurons, beta, criterion):
        """Return the per-sample energy.

        phi = 0.5 * sum_l |s_l|^2 - sum_l <W_l(layer_l), layer_{l+1}> (+ beta * loss)

        The loss term is added only when ``beta != 0``. For MSE criteria,
        ``y`` holds class indices and is one-hot encoded with ``self.nc``
        classes.
        """
        x = x.view(x.size(0),-1)

        layers = [x] + neurons

        phi = 0.0
        for idx in range(len(neurons)):
            phi += 0.5*torch.sum( neurons[idx] * neurons[idx], dim=1).squeeze()
        for idx in range(len(self.W)):
            phi -= torch.sum( self.W[idx](layers[idx]) * layers[idx+1], dim=1).squeeze()

        if beta!=0.0:
            if criterion.__class__.__name__.find('MSE')!=-1:
                y = F.one_hot(y, num_classes=self.nc)
                L = criterion(layers[-1].float(), y.float()).sum(dim=1).squeeze()
            else:
                L = criterion(layers[-1].float(), y).squeeze()
            phi += beta*L

        return phi

    def forward(self, x, y, neurons, T, neural_lr = 0.5, beta=0.0, criterion=torch.nn.MSELoss(reduction='none'), check_thm=False):
        """Run ``T`` steps of the neural dynamics.

        Args:
            x, y: Input batch and targets.
            neurons: List of (batch, size) states. It is updated in place.
            T: Number of dynamics steps.
            neural_lr: Step size on dPhi/ds.
            beta: Nudging strength passed to ``Phi``.
            criterion: Loss used in ``Phi``. For non-MSE criteria the output
                state is set to the raw gradient dPhi/ds.
            check_thm: Keep the updates on the autograd graph and call
                ``retain_grad()`` on every state.

        Returns:
            list[torch.Tensor]: ``neurons`` (the same list object).
        """
        not_mse = (criterion.__class__.__name__.find('MSE')==-1)
        mbs = x.size(0)
        device = x.device
        # grad_outputs of ones: differentiate sum(phi) w.r.t. the states.
        init_grads = torch.ones(mbs, dtype=torch.float, device=device)

        for _ in range(T):
            phi = self.Phi(x, y, neurons, beta, criterion)
            grads = torch.autograd.grad(phi, neurons, grad_outputs=init_grads, create_graph=check_thm)

            # With check_thm the updates must stay on the graph (retain_grad()
            # needs a tensor that requires grad). Otherwise run them untracked.
            with torch.set_grad_enabled(bool(check_thm)):
                for idx in range(len(neurons)-1):
                    neurons[idx] = self.activation(neurons[idx] - neural_lr * grads[idx])  # s_(t+1) = sigma( s_t - neural_lr * dPhi/ds )
                    # Every hidden state must require grad for the next autograd.grad call.
                    if check_thm:
                        neurons[idx].retain_grad()
                    else:
                        neurons[idx].requires_grad = True

                if not_mse:
                    neurons[-1] = grads[-1]
                else:
                    neurons[-1] = self.activation(neurons[-1] - neural_lr * grads[-1])

            if check_thm:
                neurons[-1].retain_grad()
            else:
                neurons[-1].requires_grad = True

        return neurons

    def init_neurons(self, mbs, device):
        """Return zero states of shape (mbs, size) for every non-input layer, requiring grad."""
        return [torch.zeros((mbs, size), requires_grad=True, device=device) for size in self.architecture[1:]]

    def compute_syn_grads(self, x, y, neurons_1, neurons_2, betas, criterion, check_thm=False):
        """Populate parameter ``.grad`` with the EP estimate.

        The estimate is (mean phi_2 - mean phi_1) / (beta_2 - beta_1). When
        ``check_thm`` is set, phi_1 is also evaluated at ``beta_2``. No
        optimizer step is taken here.
        """
        beta_1, beta_2 = betas

        self.zero_grad()
        if not(check_thm):
            phi_1 = self.Phi(x, y, neurons_1, beta_1, criterion)
        else:
            phi_1 = self.Phi(x, y, neurons_1, beta_2, criterion)
        phi_1 = phi_1.mean()

        phi_2 = self.Phi(x, y, neurons_2, beta_2, criterion)
        phi_2 = phi_2.mean()

        delta_phi = (phi_2 - phi_1)/(beta_2 - beta_1)
        delta_phi.backward()

class CSM(torch.nn.Module):
    """
    Contrastive Similarity Matching for Supervised Learning.
    Paper :                             https://arxiv.org/abs/2002.10378
    Published Official Theano Code :    https://github.com/Pehlevan-Group/Supervised-Similarity-Matching

    Neural states are (batch, size) tensors, one per non-input layer. The
    neural step is s_(t+1) = sigma( s_t - neural_lr * dPhi/ds ).
    """
    def __init__(self, architecture, activation, alphas_W, alphas_M, task = "classification"):
        """Build feedforward (``W``) and lateral (``M``) synapses and an SGD optimizer.

        Args:
            architecture (list[int]): Layer sizes ``[input, hidden..., output]``.
            activation (callable): Nonlinearity applied after each neural step.
            alphas_W: Per-layer learning rates for ``W``.
            alphas_M: Per-hidden-layer learning rates for ``M``.
            task (str): With ``"classification"``, MSE targets are one-hot encoded.
        """
        super(CSM, self).__init__()

        self.activation = activation
        self.architecture = architecture
        self.nc = self.architecture[-1]
        self.task = task
        # Feedforward and feedback synapses: Xavier-uniform weights, zero biases.
        self.W = torch.nn.ModuleList()
        for idx in range(len(architecture)-1):
            m = torch.nn.Linear(architecture[idx], architecture[idx+1], bias=True)
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.mul_(0)
            self.W.append(m)

        # Lateral synapses for hidden layers, initialized as W @ W.T (symmetric PSD).
        self.M = torch.nn.ModuleList()
        for idx in range(1,len(architecture)-1):
            m = torch.nn.Linear(architecture[idx], architecture[idx], bias = False)
            torch.nn.init.xavier_uniform_(m.weight)
            m.weight.data = m.weight.data @ m.weight.data.T
            self.M.append(m)

        # Copies of M used in the free phase (beta == 0). Their weight data is
        # shared with M and their parameters are not given to the optimizer.
        self.M_copy = torch.nn.ModuleList()
        for idx in range(1, len(architecture) - 1):
            m = torch.nn.Linear(architecture[idx], architecture[idx], bias = False)
            m.weight.data = self.M[idx-1].weight.data
            m.weight.data.requires_grad_(False)
            self.M_copy.append(m)

        optim_params = []
        for idx in range(len(self.W)):
            optim_params.append(  {'params': self.W[idx].parameters(), 'lr': alphas_W[idx]}  )

        for idx in range(len(self.M)):
            optim_params.append(  {'params': self.M[idx].parameters(), 'lr': alphas_M[idx]}  )

        optimizer = torch.optim.SGD( optim_params, momentum=0.0 )
        self.optimizer = optimizer

    def Phi(self, x, y, neurons, beta, criterion):
        """Return the per-sample energy.

        Terms: 0.5 |s|^2, minus the feedforward coupling, plus 0.5 <M s, s>.
        The lateral term uses ``M`` when ``beta != 0`` and ``M_copy``
        otherwise. ``beta * loss`` is added when ``beta != 0``.
        """
        x = x.view(x.size(0),-1)

        layers = [x] + neurons

        phi = 0.0
        for idx in range(len(neurons)):
            phi += 0.5*torch.sum( neurons[idx] * neurons[idx], dim=1).squeeze()
        for idx in range(len(self.W)):
            phi -= torch.sum( self.W[idx](layers[idx]) * layers[idx+1], dim=1).squeeze()
        for idx in range(len(self.M)):
            if beta != 0.0:
                phi += 0.5*torch.sum( self.M[idx](layers[idx+1]) * layers[idx+1], dim=1).squeeze()
            else:
                phi += 0.5*torch.sum( self.M_copy[idx](layers[idx+1]) * layers[idx+1], dim=1).squeeze()

        if beta!=0.0:
            if criterion.__class__.__name__.find('MSE')!=-1:
                if self.task == "classification":
                    y = F.one_hot(y, num_classes=self.nc)
                L = criterion(layers[-1].float(), y.float()).sum(dim=1).squeeze()
            else:
                L = criterion(layers[-1].float(), y).squeeze()
            phi += beta*L

        return phi

    def forward(self, x, y, neurons, T, neural_lr = 0.5, beta=0.0, criterion=torch.nn.MSELoss(reduction='none'), check_thm=False):
        """Run ``T`` steps of the neural dynamics and return ``neurons`` (updated in place).

        With ``check_thm`` the updates stay on the autograd graph and every
        state gets ``retain_grad()``. For non-MSE criteria the output state is
        set to the raw gradient dPhi/ds.
        """
        not_mse = (criterion.__class__.__name__.find('MSE')==-1)
        mbs = x.size(0)
        device = x.device
        # grad_outputs of ones: differentiate sum(phi) w.r.t. the states.
        init_grads = torch.ones(mbs, dtype=torch.float, device=device)

        for _ in range(T):
            phi = self.Phi(x, y, neurons, beta, criterion)
            grads = torch.autograd.grad(phi, neurons, grad_outputs=init_grads, create_graph=check_thm)

            # retain_grad() needs a tensor that requires grad, so only track
            # the updates when check_thm is set.
            with torch.set_grad_enabled(bool(check_thm)):
                for idx in range(len(neurons)-1):
                    neurons[idx] = self.activation(neurons[idx] - neural_lr * grads[idx])  # s_(t+1) = sigma( s_t - neural_lr * dPhi/ds )
                    # Every hidden state must require grad for the next autograd.grad call.
                    if check_thm:
                        neurons[idx].retain_grad()
                    else:
                        neurons[idx].requires_grad = True

                if not_mse:
                    neurons[-1] = grads[-1]
                else:
                    neurons[-1] = self.activation(neurons[-1] - neural_lr * grads[-1])

            if check_thm:
                neurons[-1].retain_grad()
            else:
                neurons[-1].requires_grad = True

        return neurons

    def init_neurons(self, mbs, device):
        """Return zero states of shape (mbs, size) for every non-input layer, requiring grad."""
        return [torch.zeros((mbs, size), requires_grad=True, device=device) for size in self.architecture[1:]]

    def compute_syn_grads(self, x, y, neurons_1, neurons_2, betas, alphas_M, criterion, check_thm=False):
        """Apply one CSM weight update.

        Steps:
        1. Backpropagate (mean phi_2 - mean phi_1) / (beta_2 - beta_1).
        2. Take an SGD step.
        3. Add the extra lateral term ``alphas_M[k] * M / (2 |beta_2|)`` to each M.
        4. Re-share the updated M data with ``M_copy``.
        """
        beta_1, beta_2 = betas

        self.zero_grad()            # p.grad is zero
        if not(check_thm):
            phi_1 = self.Phi(x, y, neurons_1, beta_1, criterion)
        else:
            phi_1 = self.Phi(x, y, neurons_1, beta_2, criterion)
        phi_1 = phi_1.mean()

        phi_2 = self.Phi(x, y, neurons_2, beta_2, criterion)
        phi_2 = phi_2.mean()

        delta_phi = (phi_2 - phi_1)/(beta_2 - beta_1)
        delta_phi.backward()

        self.optimizer.step()
        # Additional CSM lateral-weight term, applied AFTER the optimizer step.
        # NOTE(review): confirm this ordering against line 306 of
        # https://github.com/Pehlevan-Group/Supervised-Similarity-Matching/blob/master/Main/model_wlat_smep_mod.py
        with torch.no_grad():
            for kk in range(len(self.M)):
                Mweight = self.M[kk].weight.data
                self.M[kk].weight.data = Mweight + (alphas_M[kk]) * Mweight/(2 * np.abs(beta_2))

        # The assignment above created new tensors; re-share them with M_copy.
        for idx in range(len(self.M)):
            self.M_copy[idx].weight.data = self.M[idx].weight.data
            self.M_copy[idx].weight.data.requires_grad_(False)

###### Debugging #####

class ContrastiveCorInfoMaxHopfieldDebug1():
    """This is the algorithm to be used in the paper. The summary will be added later.
    """

    def __init__(self, architecture, lambda_, epsilon, activation = hard_sigmoid, device = None):
        """Initialize hyper-parameters, synaptic weights and debugging buffers.

        Args:
            architecture (list[int]): Layer sizes ``[input, hidden..., output]``.
            lambda_ (float): Used to derive ``gam_ = (1 - lambda_) / lambda_``.
            epsilon (float): Stored together with its reciprocal ``one_over_epsilon``.
            activation (callable): Nonlinearity applied to the neural states.
            device: Torch device. ``None`` selects ``cuda:0`` when CUDA is
                available, otherwise CPU.

        ``Wff``, ``Wfb``, ``B``, ``Rfree`` and ``Rnudged`` are NumPy object
        arrays holding one ``{'weight': tensor}`` dict per non-input layer.
        """
        self.architecture = architecture
        self.lambda_ = lambda_
        self.gam_ = (1 - lambda_) / lambda_
        self.epsilon = epsilon
        self.one_over_epsilon = 1 / epsilon
        self.activation = activation
        self.device = device if device is not None else torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        def xavier(tensor):
            # Overwrite the tensor in place with Xavier-uniform values and hand it back.
            torch.nn.init.xavier_uniform_(tensor)
            return tensor

        # (fan_in, fan_out) for every pair of consecutive layers.
        layer_pairs = list(zip(architecture[:-1], architecture[1:]))

        # Feedforward synapses: fan_out x fan_in.
        Wff = np.array([{'weight': xavier(torch.randn(fan_out, fan_in, requires_grad = False).to(self.device))}
                        for fan_in, fan_out in layer_pairs])

        # Feedback synapses: fan_in x fan_out (the identity start is fully overwritten by the Xavier init).
        Wfb = np.array([{'weight': xavier(torch.eye(fan_in, fan_out, requires_grad = False).to(self.device))}
                        for fan_in, fan_out in layer_pairs])

        # Lateral synapses: symmetric PSD matrices W @ W.T of size fan_out x fan_out.
        lateral = []
        for _, fan_out in layer_pairs:
            w = xavier(torch.randn(fan_out, fan_out, requires_grad = False).to(self.device))
            lateral.append({'weight': w @ w.T})
        B = np.array(lateral)

        # Correlation matrices (only for debugging), starting at the identity.
        Rfree = np.array([{'weight': 1.0 * torch.eye(fan_out, fan_out, requires_grad = False).to(self.device)}
                          for _, fan_out in layer_pairs])
        Rnudged = np.array([{'weight': 1.0 * torch.eye(fan_out, fan_out, requires_grad = False).to(self.device)}
                            for _, fan_out in layer_pairs])

        self.Wff = Wff
        self.Wfb = Wfb
        self.B = B
        self.Rfree = Rfree
        self.Rnudged = Rnudged

        # Debugging logs filled by the training loop.
        self.forward_backward_angles = []
        self.layerwise_forward_corinfo_list_free = []
        self.layerwise_backward_corinfo_list_free = []
        self.layerwise_forward_corinfo_list_nudged = []
        self.layerwise_backward_corinfo_list_nudged = []

        self.neural_dynamics_free_forward_info_list = []
        self.neural_dynamics_free_backward_info_list = []
        self.neural_dynamics_nudged_forward_info_list = []
        self.neural_dynamics_nudged_backward_info_list = []

    ###############################################################
    ############### HELPER METHODS ################################
    ###############################################################
    def copy_neurons(self, neurons):
        """Return independent copies of the tensors in ``neurons``.

        Each copy is a fresh ``torch.empty_like`` tensor filled from the
        source's ``.data``. It does not share storage with the source and
        does not require grad.
        """
        return [torch.empty_like(tensor).copy_(tensor.data) for tensor in neurons]
        
    def init_neurons(self, mbs, random_initialize = False, device = 'cuda'):
        """Create initial neural states, one ``(size, mbs)`` tensor per non-input layer.

        Each tensor is allocated as ``(mbs, size)`` and transposed, so every
        column is one sample. Values are standard normal when
        ``random_initialize`` is true, otherwise zeros.
        """
        make_state = torch.randn if random_initialize else torch.zeros
        return [make_state((mbs, size), requires_grad = False, device = device).T
                for size in self.architecture[1:]]

    ###############################################################
    ############### REQUIRED FUNCTIONS FOR DEBUGGING ##############
    ###############################################################
    def angle_between_two_matrices(self, A, B):
        """Computes the angle between two matrices A and B.

        Uses the Frobenius inner product:
            (180/pi) * acos[ Tr(A @ B.T) / sqrt(Tr(A @ A.T) * Tr(B @ B.T)) ]
        ``Tr(A @ B.T)`` equals ``sum(A * B)``, which avoids forming the m x m
        product. The cosine is clamped to [-1, 1] before ``acos``, so
        round-off for (nearly) parallel matrices yields 0 or 180 degrees
        instead of NaN.

        Args:
            A (torch.Tensor): Pytorch tensor of size m times n
            B (torch.Tensor): Pytorch tensor of size m times n

        Returns:
            torch.Tensor: 0-d tensor with the angle in degrees. An all-zero
            matrix still gives NaN (0/0).
        """
        cosine = torch.sum(A * B) / torch.sqrt(torch.sum(A * A) * torch.sum(B * B))
        angle = (180 / torch.pi) * torch.acos(torch.clamp(cosine, -1.0, 1.0))
        return angle

    def layerwise_forward_and_backward_correlative_information(self, layers, phase = "free"):
        """Compute per-layer forward and backward correlative-information estimates.

        For each adjacent pair of non-input layers, a measure is computed as
        ``logdet(R + epsilon * I)`` minus a scaled squared prediction error
        (plus a ``log(epsilon)`` correction). The forward measure uses the
        ``Wff`` prediction and the backward measure uses the ``Wfb``
        prediction.

        Args:
            layers (list[torch.Tensor]): ``[x] + neurons``. The batch size is
                read from ``layers[0].shape[1]``.
            phase (str): ``"free"`` uses ``self.Rfree``; ``"nudged"`` uses
                ``self.Rnudged``.

        Returns:
            tuple[list[float], list[float]]: Forward and backward measures,
            ``len(self.architecture) - 2`` entries each.

        Raises:
            ValueError: If ``phase`` is neither ``"free"`` nor ``"nudged"``.
        """
        if phase == "free":
            R = self.Rfree
        elif phase == "nudged":
            R = self.Rnudged
        else:
            raise ValueError(f"phase must be 'free' or 'nudged', got {phase!r}")
        Wff = self.Wff
        Wfb = self.Wfb
        epsilon = self.epsilon
        one_over_epsilon = self.one_over_epsilon
        device = self.device
        n_pairs = len(self.architecture) - 2

        batch_size = layers[0].shape[1]
        log_epsilon = np.log(epsilon)

        forward_info_list = []
        for jj in range(n_pairs):
            eps_identity = epsilon * torch.eye(*R[jj + 1]['weight'].shape).to(device)
            forward_error = layers[jj + 2] - Wff[jj + 1]['weight'] @ layers[jj + 1]
            forward_info_jj = (torch.logdet(R[jj + 1]['weight'] + eps_identity) - (1 / batch_size) * (one_over_epsilon * torch.norm(forward_error) ** 2 - layers[jj + 2].shape[0] * log_epsilon)).item()
            forward_info_list.append(forward_info_jj)

        backward_info_list = []
        for jj in range(n_pairs):
            eps_identity = epsilon * torch.eye(*R[jj]['weight'].shape).to(device)
            backward_error = (layers[jj + 1]) - Wfb[jj + 1]['weight'] @ layers[jj + 2]
            backward_info_jj = (torch.logdet(R[jj]['weight'] + eps_identity) - (1 / batch_size) * (one_over_epsilon * torch.norm(backward_error) ** 2 - (layers[jj + 1].shape[0] + 1) * log_epsilon)).item()
            backward_info_list.append(backward_info_jj)

        # Return only after both loops have filled their lists.
        return forward_info_list, backward_info_list

    ###############################################################
    ############### NEURAL DYNAMICS ALGORITHMS ####################
    ###############################################################
    def run_neural_dynamics_hopfield(self, x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, 
                                     lr_rule = "constant", lr_decay_multiplier = 0.1, 
                                     neural_dynamic_iterations = 10, beta = 1, take_debug_logs = False):
        """Run the layer-by-layer neural dynamics for one phase.

        Each iteration updates the layers in order, so later layers see the
        already-updated earlier ones.

        Args:
            x: Input tensor (first entry of ``layers``).
            y: Target; the output layer is nudged by ``-beta * (output - y)``.
            neurons: List of layer states. It is updated in place and returned.
            hopfield_g: Leak/self-coupling coefficient of the dynamics.
            neural_lr_start: Initial neural step size.
            neural_lr_stop: Lower bound for the decaying schedules.
            lr_rule: "constant", "divide_by_loop_index" or
                "divide_by_slow_loop_index".
            lr_decay_multiplier: Decay rate for "divide_by_slow_loop_index".
            neural_dynamic_iterations: Number of iterations.
            beta: Nudging strength. ``0`` is the free phase.
            take_debug_logs: When true, record the summed forward and backward
                information after every iteration.

        Returns:
            tuple: ``(neurons, forward_info, backward_info)``. The info lists
            are empty unless ``take_debug_logs`` is set.

        Raises:
            ValueError: If ``lr_rule`` is not recognized.
        """
        # beta == 0 is the free phase. The label selects Rfree / Rnudged in the debug measures.
        phase = "free" if beta == 0 else "nudged"
        forward_info = []
        backward_info = []

        Wff = self.Wff
        Wfb = self.Wfb
        B = self.B
        gam_ = self.gam_
        epsilon = self.epsilon
        one_over_epsilon = self.one_over_epsilon
        output_idx = len(neurons) - 1

        neurons_intermediate = self.copy_neurons(neurons)
        layers = [x] + neurons  # concatenate the input to other layers
        for iter_count in range(neural_dynamic_iterations):

            if lr_rule == "constant":
                neural_lr = neural_lr_start
            elif lr_rule == "divide_by_loop_index":
                neural_lr = max(neural_lr_start / (iter_count + 1), neural_lr_stop)
            elif lr_rule == "divide_by_slow_loop_index":
                neural_lr = max(neural_lr_start / (iter_count * lr_decay_multiplier + 1), neural_lr_stop)
            else:
                raise ValueError(f"Unknown neural lr_rule: {lr_rule!r}")

            with torch.no_grad():
                for jj in range(len(neurons)):
                    basal_voltage = Wff[jj]['weight'] @ layers[jj] #+ Wff[jj]['bias']
                    if jj == output_idx:
                        # Output layer: lateral term plus the nudge towards y;
                        # the apical error enters with weight 1 (not 1/epsilon).
                        apical_voltage = (gam_ * B[jj]['weight'] @ ( layers[jj + 1]) + hopfield_g * layers[jj + 1]) - beta * (layers[jj + 1] - y)
                        gradient_neurons = -hopfield_g * neurons_intermediate[jj] + one_over_epsilon * (basal_voltage - neurons_intermediate[jj]) + (apical_voltage - neurons_intermediate[jj])
                    else:
                        # Hidden layer: lateral term plus feedback from the next layer.
                        apical_voltage = epsilon * (2 * gam_ * B[jj]['weight'] @ (layers[jj + 1]) + hopfield_g * layers[jj + 1]) + (Wfb[jj + 1]['weight'] @ layers[jj + 2]) #+ Wfb[jj + 1]['bias']
                        gradient_neurons = - hopfield_g * neurons_intermediate[jj] + one_over_epsilon * (basal_voltage - neurons_intermediate[jj]) + one_over_epsilon * (apical_voltage - neurons_intermediate[jj])
                    neurons_intermediate[jj] = neurons_intermediate[jj] + neural_lr * gradient_neurons
                    neurons[jj] = self.activation(neurons_intermediate[jj])
                    # Rebuild so the next layer sees this layer's new state.
                    layers = [x] + neurons

            if take_debug_logs:
                info_measures = self.layerwise_forward_and_backward_correlative_information(layers, phase)
                forward_info.append(np.sum(info_measures[0]))
                backward_info.append(np.sum(info_measures[1]))

        return neurons, forward_info, backward_info

    ###############################################################
    ############### BATCH STEP ALGORITHMS #########################
    ###############################################################
    def batch_step_hopfield(self, x, y, hopfield_g, lr, neural_lr_start, neural_lr_stop, neural_lr_rule = "constant", 
                            neural_lr_decay_multiplier = 0.1, neural_dynamic_iterations_free = 20, 
                            neural_dynamic_iterations_nudged = 10, beta = 1, use_three_phase = False, 
                            take_debug_logs = False, weight_decay = False):
        """Run one contrastive learning step on a mini-batch and update all synapses.

        Phases:
        * free phase (beta = 0), starting from ``init_neurons``;
        * nudged phase (``beta``), continued from the free state;
        * if ``use_three_phase``, a third phase with ``-beta``. It continues from the
          nudged state and replaces the free state in the contrastive update.

        Args:
            x: input batch; ``x.size(1)`` is used as the mini-batch size.
            y: target for the output layer.
            hopfield_g: forwarded to ``run_neural_dynamics_hopfield``.
            lr: mapping with per-connection rates ``lr['ff'][jj]`` and ``lr['fb'][jj]``.
            neural_lr_start, neural_lr_stop, neural_lr_rule, neural_lr_decay_multiplier:
                neural step-size schedule forwarded to the dynamics.
            neural_dynamic_iterations_free: sweeps in the free phase.
            neural_dynamic_iterations_nudged: sweeps in the nudged (and third) phase.
            beta: nudging strength. Must be nonzero, because the update divides by it.
            use_three_phase: enable the third (-beta) phase described above.
            take_debug_logs: if True, append angles and information measures to the
                instance's debug lists.
            weight_decay: if True, shrink Wff/Wfb by ``lr * epsilon * W`` after the update.

        Returns:
            The neuron states at the end of the last phase that was run.
        """
        Wff, Wfb, B = self.Wff, self.Wfb, self.B
        lambda_ = self.lambda_
        gam_ = self.gam_
        epsilon = self.epsilon

        Rfree = self.Rfree # For debugging to check the correlation matrices vs inverse correlation matrices
        Rnudged = self.Rnudged # For debugging to check the correlation matrices vs inverse correlation matrices

        neurons = self.init_neurons(x.size(1), device = self.device)

        # Free phase (beta = 0).
        (neurons,
         free_forward_info,
         free_backward_info
        ) = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                                             neural_lr_decay_multiplier, neural_dynamic_iterations_free, 0, take_debug_logs)

        # Shallow copy of the list: the next phase assigns new entries into `neurons`,
        # so this snapshot keeps the free-phase states.
        neurons1 = neurons.copy()
        layers_free_ = [x] + neurons1

        for jj in range(len(B)):
            Rfree[jj]['weight'] = lambda_ * Rfree[jj]['weight'] + (1 - lambda_) * torch.mean(outer_prod_broadcasting(neurons1[jj].T, neurons1[jj].T), axis = 0)

        # Nudged phase (beta), continued from the free state.
        (neurons,
         nudged_forward_info,
         nudged_backward_info 
        ) = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                                              neural_lr_decay_multiplier, neural_dynamic_iterations_nudged, beta, take_debug_logs)

        neurons2 = neurons.copy()

        if use_three_phase:
            # Third phase (-beta) replaces the free state in the contrastive update.
            neurons, _, _ = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                                                              neural_lr_decay_multiplier, neural_dynamic_iterations_nudged, -beta, take_debug_logs)

            neurons3 = neurons.copy()

            layers_free = [x] + neurons3
        else:
            layers_free = [x] + neurons1

        layers_nudged = [x] + neurons2

        ## Compute forward errors
        forward_errors_free = [layers_free[jj + 1] - (Wff[jj]['weight'] @ layers_free[jj]) for jj in range(len(Wff))]
        forward_errors_nudged = [layers_nudged[jj + 1] - (Wff[jj]['weight'] @ layers_nudged[jj]) for jj in range(len(Wff))]
        ## Compute backward errors
        backward_errors_free = [(layers_free[jj]) - (Wfb[jj]['weight'] @ layers_free[jj + 1]) for jj in range(1, len(Wfb))]
        backward_errors_nudged = [(layers_nudged[jj]) - (Wfb[jj]['weight'] @ layers_nudged[jj + 1]) for jj in range(1, len(Wfb))]

        # Contrastive scale: -1 / (beta * number_of_contrasted_phase_pairs).
        contrastive_scale = -(1/(beta * (int(use_three_phase) + 1)))

        ### Learning updates for feed-forward and backward weights
        for jj in range(len(Wff)):
            Wff[jj]['weight'] += contrastive_scale * lr['ff'][jj] * torch.mean(outer_prod_broadcasting(forward_errors_free[jj].T, layers_free[jj].T) - outer_prod_broadcasting(forward_errors_nudged[jj].T, layers_nudged[jj].T), axis = 0)
            if weight_decay:
                Wff[jj]['weight'] -= lr['ff'][jj] * epsilon * Wff[jj]['weight']

        # Wfb[0] (feedback onto the input) is never trained.
        for jj in range(1, len(Wfb)):
            Wfb[jj]['weight'] += contrastive_scale * lr['fb'][jj] * torch.mean(outer_prod_broadcasting(backward_errors_free[jj - 1].T, layers_free[jj + 1].T) - outer_prod_broadcasting(backward_errors_nudged[jj - 1].T, layers_nudged[jj + 1].T), axis = 0)
            if weight_decay:
                Wfb[jj]['weight'] -= lr['fb'][jj] * epsilon * Wfb[jj]['weight']

        ### Lateral Weight Updates (from the nudged states)
        for jj in range(len(B)):
            z = B[jj]['weight'] @ (neurons2[jj])
            B_update = torch.mean(outer_prod_broadcasting(z.T, z.T), axis = 0)
            B[jj]['weight'] = (1 / lambda_) * (B[jj]['weight'] - gam_ * B_update)

            Rnudged[jj]['weight'] = lambda_ * Rnudged[jj]['weight'] + (1 - lambda_) * torch.mean(outer_prod_broadcasting(neurons2[jj].T, neurons2[jj].T), axis = 0)

        self.B = B
        self.Wff = Wff
        self.Wfb = Wfb
        self.Rfree = Rfree
        self.Rnudged = Rnudged

        if take_debug_logs:
            instant_forward_backward_angles = []
            for jj in range(1, len(Wff)):
                instant_forward_backward_angles.append(self.angle_between_two_matrices(self.Wff[jj]['weight'], self.Wfb[jj]['weight'].T).item())

            self.forward_backward_angles.append(instant_forward_backward_angles)

            (forward_info_list_free, 
             backward_info_list_free, 
            ) = self.layerwise_forward_and_backward_correlative_information(layers_free_, "free")

            # The nudged measures must be evaluated on the nudged states (not the free ones).
            (forward_info_list_nudged, 
             backward_info_list_nudged, 
            ) = self.layerwise_forward_and_backward_correlative_information(layers_nudged, "nudged")

            self.layerwise_forward_corinfo_list_free.append(forward_info_list_free)
            self.layerwise_backward_corinfo_list_free.append(backward_info_list_free)
            self.layerwise_forward_corinfo_list_nudged.append(forward_info_list_nudged)
            self.layerwise_backward_corinfo_list_nudged.append(backward_info_list_nudged)

            self.neural_dynamics_free_forward_info_list.append(free_forward_info)
            self.neural_dynamics_free_backward_info_list.append(free_backward_info)
            self.neural_dynamics_nudged_forward_info_list.append(nudged_forward_info)
            self.neural_dynamics_nudged_backward_info_list.append(nudged_backward_info)
        return neurons

    def save_model_weights(self, pickle_name = "CorInfoWeights"):
        """Pickle the synaptic weights to ``pickle_name + ".pkl"``.

        The file holds a pandas DataFrame with columns 'Wff', 'Wfb' and 'B'. Each
        cell is a ``{'weight': ndarray}`` dict produced by ``torch2numpy``, with one
        row per entry of the weight lists.
        """
        model_params = pd.DataFrame(columns = ['Wff', 'Wfb', 'B'])
        for column, synapses in (('Wff', self.Wff), ('Wfb', self.Wfb), ('B', self.B)):
            model_params[column] = [{'weight': torch2numpy(synapse['weight'])} for synapse in synapses]
        model_params.to_pickle(pickle_name + ".pkl")

    def load_model_weights(self, pickle_name):
        """Restore weights written by ``save_model_weights``.

        The DataFrame is read from ``pickle_name + ".pkl"``. Each ``'weight'`` entry
        is written into the existing synapse dicts (``self.Wff``, ``self.Wfb``,
        ``self.B``) as a tensor on ``self.device``.
        """
        # NOTE(review): pd.read_pickle unpickles arbitrary objects; only load trusted files.
        saved = pd.read_pickle(pickle_name + ".pkl")
        for column, synapses in zip(('Wff', 'Wfb', 'B'), (self.Wff, self.Wfb, self.B)):
            stored_column = saved[column]
            for idx, synapse in enumerate(synapses):
                synapse['weight'] = torch.tensor(stored_column.iloc[idx]['weight'], requires_grad = False).to(self.device)
       
class ContrastiveCorInfoMaxHopfieldDebugWithBias():
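    """Contrastive CorInfoMax network with Hopfield-style neural dynamics and bias terms.

    This is the algorithm to be used in the paper. Feedforward synapses carry an
    explicit bias. The feedback and lateral matrices have an extra row/column for a
    constant-1 entry appended to the layer states. Learning contrasts a free phase
    (beta = 0) with a nudged phase, optionally adding a third -beta phase. The
    correlation matrices Rfree/Rnudged and the angle/information lists are kept for
    debugging only.
    """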

    def __init__(self, architecture, lambda_, epsilon, activation = hard_sigmoid, output_sparsity = False, STlambda_lr = 0.01, device = None):
        """Create synapses, debug correlation matrices and debug log buffers.

        Args:
            architecture: sequence of layer widths; ``architecture[0]`` is the input width.
            lambda_: forgetting factor; also sets ``gam_ = (1 - lambda_) / lambda_``.
            epsilon: error-variance parameter; ``one_over_epsilon`` is cached.
            activation: callable applied to the neuron states during the dynamics.
            output_sparsity, STlambda_lr: stored as attributes (not read by the
                methods visible in this part of the file).
            device: torch device. Defaults to ``cuda:0`` when CUDA is available,
                otherwise CPU.
        """
        self.architecture = architecture
        self.lambda_ = lambda_
        self.gam_ = (1 - lambda_) / lambda_
        self.epsilon = epsilon
        self.one_over_epsilon = 1 / epsilon
        self.activation = activation
        self.output_sparsity = output_sparsity
        self.STlambda_lr = STlambda_lr
        if device is not None:
            self.device = device
        else:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        dev = self.device
        # (width of lower layer, width of upper layer) for every connection.
        connections = [(architecture[k], architecture[k + 1]) for k in range(len(architecture) - 1)]

        def _feedforward_synapse(n_lower, n_upper):
            # randn is immediately overwritten by xavier_uniform_, but both consume the RNG.
            w = torch.randn(n_upper, n_lower, requires_grad = False).to(dev)
            torch.nn.init.xavier_uniform_(w)
            b = torch.zeros(n_upper, 1, requires_grad = False).to(dev)
            return {'weight': w, 'bias': b}

        def _feedback_synapse(n_lower, n_upper):
            # One extra output row predicts the constant 1 appended to the lower layer.
            w = torch.eye(n_lower + 1, n_upper, requires_grad = False).to(dev)
            torch.nn.init.xavier_uniform_(w)
            return {'weight': w}

        def _lateral_synapse(n_upper):
            w = torch.randn(n_upper + 1, n_upper + 1, requires_grad = False).to(dev)
            torch.nn.init.xavier_uniform_(w)
            return {'weight': w @ w.T}

        def _identity_matrix(n_upper):
            return {'weight': 1.0*torch.eye(n_upper, n_upper, requires_grad = False).to(dev)}

        # Built in the original order (all Wff, then Wfb, then B) so the random
        # number stream is consumed identically.
        Wff = np.array([_feedforward_synapse(n_lower, n_upper) for n_lower, n_upper in connections])
        Wfb = np.array([_feedback_synapse(n_lower, n_upper) for n_lower, n_upper in connections])
        B = np.array([_lateral_synapse(n_upper) for _, n_upper in connections])
        # Correlation Matrices (Only for debugging)
        Rfree = np.array([_identity_matrix(n_upper) for _, n_upper in connections])
        Rnudged = np.array([_identity_matrix(n_upper) for _, n_upper in connections])

        self.Wff = Wff
        self.Wfb = Wfb
        self.B = B
        self.Rfree = Rfree
        self.Rnudged = Rnudged

        ############ Some Debugging Logs ##########################
        self.forward_backward_angles = []
        self.layerwise_forward_corinfo_list_free = []
        self.layerwise_backward_corinfo_list_free = []
        self.layerwise_forward_corinfo_list_nudged = []
        self.layerwise_backward_corinfo_list_nudged = []

        self.neural_dynamics_free_forward_info_list = []
        self.neural_dynamics_free_backward_info_list = []
        self.neural_dynamics_nudged_forward_info_list = []
        self.neural_dynamics_nudged_backward_info_list = []

    ###############################################################
    ############### HELPER METHODS ################################
    ###############################################################
    def copy_neurons(self, neurons):
        """Return a new list holding a copy of every tensor in ``neurons``.

        ``torch.empty_like`` keeps each tensor's dtype, device and layout. The
        values are copied from ``.data`` (the detached view).
        """
        return [torch.empty_like(state).copy_(state.data) for state in neurons]
        
    def init_neurons(self, mbs, random_initialize = False, device = 'cuda'):
        """Create initial neuron states for a mini-batch.

        Args:
            mbs: mini-batch size.
            random_initialize: draw states from a standard normal if True, else zeros.
            device: device on which the tensors are allocated.

        Returns:
            list with one tensor per non-input layer, shaped (layer_width, mbs).
            Each is allocated as (mbs, width) and transposed, so it is not contiguous.
        """
        sampler = torch.randn if random_initialize else torch.zeros
        return [sampler((mbs, width), requires_grad=False, device=device).T
                for width in self.architecture[1:]]

    def append_ones_row_vector_to_tensor(self, inp_vector):
        """Return ``inp_vector`` with one extra row of ones appended along dim 0.

        This augments a (width, batch) state with a constant-1 channel for the bias
        row/column of the feedback and lateral matrices. The ones row is created
        directly with the input's dtype and device. That avoids a host allocation
        plus transfer on every call inside the dynamics loop, and it avoids dtype
        promotion of low-precision inputs.
        """
        ones_row = torch.ones(1, inp_vector.shape[1], dtype = inp_vector.dtype, device = inp_vector.device)
        return torch.cat((inp_vector, ones_row), 0)
    
    ###############################################################
    ############### REQUIRED FUNCTIONS FOR DEBUGGING ##############
    ###############################################################
    def angle_between_two_matrices(self, A, B):
        """Compute the angle (in degrees) between two matrices A and B.

        The matrices are compared as vectors under the Frobenius inner product:
            (180/pi) * acos[ Tr(A @ B.T) / sqrt(Tr(A @ A.T) * Tr(B @ B.T)) ]
        using Tr(A @ B.T) == sum(A * B), which avoids forming full matrix products.

        Args:
            A (torch.Tensor): Pytorch tensor of size m times n
            B (torch.Tensor): Pytorch tensor of size m times n

        Returns:
            angle: 0-dim tensor with the angle in degrees. It is NaN when either
                matrix is all zeros (0/0 cosine).
        """
        cosine = torch.sum(A * B) / torch.sqrt(torch.sum(A * A) * torch.sum(B * B))
        # Rounding can push |cosine| slightly past 1 (e.g. A == B), where acos is NaN.
        cosine = torch.clamp(cosine, -1.0, 1.0)
        angle = (180 / torch.pi) * torch.acos(cosine)
        return angle

    def layerwise_forward_and_backward_correlative_information(self, layers, phase = "free"):
        """Compute per-layer forward and backward correlative-information estimates.

        For each hidden-to-upper connection ``jj`` (``len(architecture) - 2`` of them):
        * forward: logdet(R + eps*I) of the upper layer, minus the batch-averaged,
          1/eps-scaled squared prediction error of ``Wff`` (bias included);
        * backward: the same with the lower layer's R and the ``Wfb`` prediction of
          the lower layer augmented with a constant-1 row.

        Args:
            layers: ``[x, h_1, ..., h_L]``. The batch size is read from
                ``layers[0].shape[1]`` (columns are samples).
            phase: "free" uses ``self.Rfree``, "nudged" uses ``self.Rnudged``.

        Returns:
            (forward_info_list, backward_info_list): lists of Python floats, each of
            length ``len(self.architecture) - 2`` (empty when that is <= 0).

        Raises:
            ValueError: if ``phase`` is neither "free" nor "nudged".
        """
        if phase == "free":
            R = self.Rfree
        elif phase == "nudged":
            R = self.Rnudged
        else:
            raise ValueError(f"Unknown phase {phase!r}; expected 'free' or 'nudged'")
        Wff = self.Wff
        Wfb = self.Wfb
        epsilon = self.epsilon
        one_over_epsilon = self.one_over_epsilon
        device = self.device
        num_connections = len(self.architecture) - 2

        batch_size = layers[0].shape[1]
        log_epsilon = np.log(epsilon)

        forward_info_list = []
        for jj in range(num_connections):
            R_upper = R[jj + 1]['weight']
            Identity_Matrix = epsilon * torch.eye(*R_upper.shape, device = device)
            forward_residual = layers[jj + 2] - Wff[jj + 1]['weight'] @ layers[jj + 1] - Wff[jj + 1]['bias']
            forward_info_jj = (torch.logdet(R_upper + Identity_Matrix) - (1 / batch_size) * (one_over_epsilon * torch.norm(forward_residual) ** 2 - layers[jj + 2].shape[0] * log_epsilon)).item()
            forward_info_list.append(forward_info_jj)

        backward_info_list = []
        for jj in range(num_connections):
            R_lower = R[jj]['weight']
            Identity_Matrix = epsilon * torch.eye(*R_lower.shape, device = device)
            # Wfb predicts the lower layer augmented with a constant-1 row (hence the +1 below).
            backward_residual = self.append_ones_row_vector_to_tensor(layers[jj + 1]) - Wfb[jj + 1]['weight'] @ layers[jj + 2]
            backward_info_jj = (torch.logdet(R_lower + Identity_Matrix) - (1 / batch_size) * (one_over_epsilon * torch.norm(backward_residual) ** 2 - (layers[jj + 1].shape[0] + 1) * log_epsilon)).item()
            backward_info_list.append(backward_info_jj)

        # Previously this return sat inside the backward loop, truncating the list
        # to a single entry (and returning None when there were no connections).
        return forward_info_list, backward_info_list

    ###############################################################
    ############### NEURAL DYNAMICS ALGORITHMS ####################
    ###############################################################
    def run_neural_dynamics_hopfield(self, x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, 
                                     lr_rule = "constant", lr_decay_multiplier = 0.1, 
                                     neural_dynamic_iterations = 10, beta = 1, take_debug_logs = False):
        """Relax the neuron states with discretized Hopfield-style dynamics (bias variant).

        Every iteration sweeps the layers in order. For each layer it takes one step
        of size ``neural_lr`` on the internal (pre-activation) state and stores
        ``self.activation`` of that state in ``neurons``. Feedforward predictions
        include ``Wff[jj]['bias']``. The lateral matrix acts on the state augmented
        with a constant-1 row, with its bias row dropped. The feedback prediction's
        extra bias row is dropped as well. The last layer is pulled towards ``y``
        with strength ``beta``. ``beta == 0`` is the free phase; any other value is
        treated as a nudged phase.

        Args:
            x: input batch, shaped (input_dim, batch).
            y: target for the last layer; must broadcast against ``neurons[-1]``.
            neurons: list of per-layer states shaped (width, batch). Entries are
                replaced during the dynamics and the same list is returned.
            hopfield_g: leak coefficient applied to the internal states.
            neural_lr_start: initial neural step size.
            neural_lr_stop: lower bound on the step size for the decaying rules.
            lr_rule: "constant", "divide_by_loop_index" or "divide_by_slow_loop_index".
            lr_decay_multiplier: slope used by "divide_by_slow_loop_index".
            neural_dynamic_iterations: number of sweeps over the layers.
            beta: nudging strength towards ``y``.
            take_debug_logs: if True, record the summed layer-wise information
                measures after every sweep.

        Returns:
            (neurons, forward_info, backward_info). The info lists are empty unless
            ``take_debug_logs`` is True.

        Raises:
            ValueError: if ``lr_rule`` is not one of the supported rules.
        """
        valid_rules = ("constant", "divide_by_loop_index", "divide_by_slow_loop_index")
        if lr_rule not in valid_rules:
            raise ValueError(f"Unknown lr_rule {lr_rule!r}; expected one of {valid_rules}")

        # batch_step_hopfield runs the free phase with beta = 0; every nonzero beta
        # (the nudged phase or the optional -beta phase) is a nudged phase.
        phase = "free" if beta == 0 else "nudged"
        forward_info = []
        backward_info = []

        Wff = self.Wff
        Wfb = self.Wfb
        B = self.B
        gam_ = self.gam_
        epsilon = self.epsilon
        one_over_epsilon = self.one_over_epsilon
        last_layer = len(neurons) - 1

        neurons_intermediate = self.copy_neurons(neurons)
        layers = [x] + neurons  # concatenate the input to other layers
        for iter_count in range(neural_dynamic_iterations):

            if lr_rule == "constant":
                neural_lr = neural_lr_start
            elif lr_rule == "divide_by_loop_index":
                neural_lr = max(neural_lr_start / (iter_count + 1), neural_lr_stop)
            else:  # "divide_by_slow_loop_index"
                neural_lr = max(neural_lr_start / (iter_count * lr_decay_multiplier + 1), neural_lr_stop)

            with torch.no_grad():
                for jj in range(len(neurons)):
                    basal_voltage = Wff[jj]['weight'] @ layers[jj] + Wff[jj]['bias']
                    if jj == last_layer:
                        # Output layer: lateral term plus the nudge towards the target y.
                        apical_voltage = (gam_ * B[jj]['weight'][:-1] @ self.append_ones_row_vector_to_tensor( layers[jj + 1]) + hopfield_g * layers[jj + 1]) - beta * (layers[jj + 1] - y)
                        gradient_neurons = -hopfield_g * neurons_intermediate[jj] + one_over_epsilon * (basal_voltage - neurons_intermediate[jj]) + (apical_voltage - neurons_intermediate[jj])
                    else:
                        # Hidden layer: lateral term plus feedback from the layer above
                        # (the feedback's last row, the bias channel, is dropped).
                        apical_voltage = epsilon * (2 * gam_ * B[jj]['weight'][:-1] @ self.append_ones_row_vector_to_tensor(layers[jj + 1]) + hopfield_g * layers[jj + 1]) + (Wfb[jj + 1]['weight'] @ layers[jj + 2])[:-1]
                        gradient_neurons = - hopfield_g * neurons_intermediate[jj] + one_over_epsilon * (basal_voltage - neurons_intermediate[jj]) + one_over_epsilon * (apical_voltage - neurons_intermediate[jj])
                    neurons_intermediate[jj] = neurons_intermediate[jj] + neural_lr * gradient_neurons
                    neurons[jj] = self.activation(neurons_intermediate[jj])
                    # Rebuild so the next layer in this sweep sees the updated state.
                    layers = [x] + neurons

            if take_debug_logs:
                info_measures = self.layerwise_forward_and_backward_correlative_information(layers, phase)
                forward_info.append(np.sum(info_measures[0]))
                backward_info.append(np.sum(info_measures[1]))

        return neurons, forward_info, backward_info

    ###############################################################
    ############### BATCH STEP ALGORITHMS #########################
    ###############################################################
    def batch_step_hopfield(self, x, y, hopfield_g, lr, neural_lr_start, neural_lr_stop, neural_lr_rule = "constant", 
                            neural_lr_decay_multiplier = 0.1, neural_dynamic_iterations_free = 20, 
                            neural_dynamic_iterations_nudged = 10, beta = 1, use_three_phase = False, 
                            take_debug_logs = False, weight_decay = False):
        """Run one contrastive learning step on a mini-batch and update all synapses.

        Phases:
        * free phase (beta = 0), starting from ``init_neurons``;
        * nudged phase (``beta``), continued from the free state;
        * if ``use_three_phase``, a third phase with ``-beta``. It continues from the
          nudged state and replaces the free state in the contrastive update.

        Both ``Wff`` weights and biases are trained. ``Wfb`` predicts the lower layer
        augmented with a constant-1 row.

        Args:
            x: input batch shaped (input_dim, batch); ``x.size(1)`` is the batch size.
            y: target for the output layer.
            hopfield_g: forwarded to ``run_neural_dynamics_hopfield``.
            lr: mapping with per-connection rates ``lr['ff'][jj]`` and ``lr['fb'][jj]``.
            neural_lr_start, neural_lr_stop, neural_lr_rule, neural_lr_decay_multiplier:
                neural step-size schedule forwarded to the dynamics.
            neural_dynamic_iterations_free: sweeps in the free phase.
            neural_dynamic_iterations_nudged: sweeps in the nudged (and third) phase.
            beta: nudging strength. Must be nonzero, because the update divides by it.
            use_three_phase: enable the third (-beta) phase described above.
            take_debug_logs: if True, append angles and information measures to the
                instance's debug lists.
            weight_decay: if True, shrink Wff (weights and biases) and Wfb by
                ``lr * epsilon * W`` after the update.

        Returns:
            The neuron states at the end of the last phase that was run.
        """
        Wff, Wfb, B = self.Wff, self.Wfb, self.B
        lambda_ = self.lambda_
        gam_ = self.gam_
        epsilon = self.epsilon

        Rfree = self.Rfree # For debugging to check the correlation matrices vs inverse correlation matrices
        Rnudged = self.Rnudged # For debugging to check the correlation matrices vs inverse correlation matrices

        neurons = self.init_neurons(x.size(1), device = self.device)

        # Free phase (beta = 0).
        (neurons,
         free_forward_info,
         free_backward_info
        ) = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                                             neural_lr_decay_multiplier, neural_dynamic_iterations_free, 0, take_debug_logs)

        # Shallow copy of the list: the next phase assigns new entries into `neurons`,
        # so this snapshot keeps the free-phase states.
        neurons1 = neurons.copy()
        layers_free_ = [x] + neurons1

        for jj in range(len(B)):
            Rfree[jj]['weight'] = lambda_ * Rfree[jj]['weight'] + (1 - lambda_) * torch.mean(outer_prod_broadcasting(neurons1[jj].T, neurons1[jj].T), axis = 0)

        # Nudged phase (beta), continued from the free state.
        (neurons,
         nudged_forward_info,
         nudged_backward_info 
        ) = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                                              neural_lr_decay_multiplier, neural_dynamic_iterations_nudged, beta, take_debug_logs)

        neurons2 = neurons.copy()

        if use_three_phase:
            # Third phase (-beta) replaces the free state in the contrastive update.
            neurons, _, _ = self.run_neural_dynamics_hopfield(x, y, neurons, hopfield_g, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                                                              neural_lr_decay_multiplier, neural_dynamic_iterations_nudged, -beta, take_debug_logs)

            neurons3 = neurons.copy()

            layers_free = [x] + neurons3
        else:
            layers_free = [x] + neurons1

        layers_nudged = [x] + neurons2

        ## Compute forward errors
        forward_errors_free = [layers_free[jj + 1] - (Wff[jj]['weight'] @ layers_free[jj] + Wff[jj]['bias']) for jj in range(len(Wff))]
        forward_errors_nudged = [layers_nudged[jj + 1] - (Wff[jj]['weight'] @ layers_nudged[jj] + Wff[jj]['bias']) for jj in range(len(Wff))]
        ## Compute backward errors (targets augmented with a constant-1 row)
        backward_errors_free = [self.append_ones_row_vector_to_tensor(layers_free[jj]) - (Wfb[jj]['weight'] @ layers_free[jj + 1]) for jj in range(1, len(Wfb))]
        backward_errors_nudged = [self.append_ones_row_vector_to_tensor(layers_nudged[jj]) - (Wfb[jj]['weight'] @ layers_nudged[jj + 1]) for jj in range(1, len(Wfb))]

        # Contrastive scale: -1 / (beta * number_of_contrasted_phase_pairs).
        contrastive_scale = -(1/(beta * (int(use_three_phase) + 1)))

        ### Learning updates for feed-forward and backward weights
        for jj in range(len(Wff)):
            Wff[jj]['weight'] += contrastive_scale * lr['ff'][jj] * torch.mean(outer_prod_broadcasting(forward_errors_free[jj].T, layers_free[jj].T) - outer_prod_broadcasting(forward_errors_nudged[jj].T, layers_nudged[jj].T), axis = 0)
            Wff[jj]['bias'] += contrastive_scale * lr['ff'][jj] * torch.mean(forward_errors_free[jj] - forward_errors_nudged[jj], axis = 1, keepdims = True) 
            if weight_decay:
                Wff[jj]['weight'] -= lr['ff'][jj] * epsilon * Wff[jj]['weight']
                Wff[jj]['bias'] -= lr['ff'][jj] * epsilon * Wff[jj]['bias']

        # Wfb[0] (feedback onto the input) is never trained.
        for jj in range(1, len(Wfb)):
            Wfb[jj]['weight'] += contrastive_scale * lr['fb'][jj] * torch.mean(outer_prod_broadcasting(backward_errors_free[jj - 1].T, layers_free[jj + 1].T) - outer_prod_broadcasting(backward_errors_nudged[jj - 1].T, layers_nudged[jj + 1].T), axis = 0)
            if weight_decay:
                Wfb[jj]['weight'] -= lr['fb'][jj] * epsilon * Wfb[jj]['weight']

        ### Lateral Weight Updates (from the nudged states, augmented with a constant-1 row)
        for jj in range(len(B)):
            z = B[jj]['weight'] @ self.append_ones_row_vector_to_tensor(neurons2[jj])
            B_update = torch.mean(outer_prod_broadcasting(z.T, z.T), axis = 0)
            B[jj]['weight'] = (1 / lambda_) * (B[jj]['weight'] - gam_ * B_update)

            Rnudged[jj]['weight'] = lambda_ * Rnudged[jj]['weight'] + (1 - lambda_) * torch.mean(outer_prod_broadcasting(neurons2[jj].T, neurons2[jj].T), axis = 0)

        self.B = B
        self.Wff = Wff
        self.Wfb = Wfb
        self.Rfree = Rfree
        self.Rnudged = Rnudged

        if take_debug_logs:
            instant_forward_backward_angles = []
            for jj in range(1, len(Wff)):
                # Compare [W | b] with the transposed (augmented) feedback matrix.
                instant_forward_backward_angles.append(self.angle_between_two_matrices(torch.cat((self.Wff[jj]['weight'], self.Wff[jj]['bias']), 1), self.Wfb[jj]['weight'].T).item())

            self.forward_backward_angles.append(instant_forward_backward_angles)

            (forward_info_list_free, 
             backward_info_list_free, 
            ) = self.layerwise_forward_and_backward_correlative_information(layers_free_, "free")

            # The nudged measures must be evaluated on the nudged states (not the free ones).
            (forward_info_list_nudged, 
             backward_info_list_nudged, 
            ) = self.layerwise_forward_and_backward_correlative_information(layers_nudged, "nudged")

            self.layerwise_forward_corinfo_list_free.append(forward_info_list_free)
            self.layerwise_backward_corinfo_list_free.append(backward_info_list_free)
            self.layerwise_forward_corinfo_list_nudged.append(forward_info_list_nudged)
            self.layerwise_backward_corinfo_list_nudged.append(backward_info_list_nudged)

            self.neural_dynamics_free_forward_info_list.append(free_forward_info)
            self.neural_dynamics_free_backward_info_list.append(free_backward_info)
            self.neural_dynamics_nudged_forward_info_list.append(nudged_forward_info)
            self.neural_dynamics_nudged_backward_info_list.append(nudged_backward_info)
        return neurons

    def save_model_weights(self, pickle_name = "CorInfoWeights"):
        """Pickle the synaptic parameters to ``pickle_name + ".pkl"``.

        The file holds a pandas DataFrame with columns 'Wff', 'Wfb' and 'B'.
        'Wff' cells are ``{'weight', 'bias'}`` dicts; 'Wfb' and 'B' cells are
        ``{'weight'}`` dicts. All arrays are converted with ``torch2numpy``.
        """
        ff_records = [{'weight': torch2numpy(synapse['weight']), 'bias': torch2numpy(synapse['bias'])} for synapse in self.Wff]
        fb_records = [{'weight': torch2numpy(synapse['weight'])} for synapse in self.Wfb]
        lateral_records = [{'weight': torch2numpy(synapse['weight'])} for synapse in self.B]

        model_params = pd.DataFrame(columns = ['Wff', 'Wfb', 'B'])
        for column, records in zip(('Wff', 'Wfb', 'B'), (ff_records, fb_records, lateral_records)):
            model_params[column] = records

        model_params.to_pickle(pickle_name + ".pkl")

    def load_model_weights(self, pickle_name):
        """Restore parameters written by ``save_model_weights``.

        The DataFrame is read from ``pickle_name + ".pkl"``. Its arrays are written
        into the existing synapse dicts as tensors on ``self.device``: Wff weight
        and bias, then Wfb weight, then B weight.
        """
        # NOTE(review): pd.read_pickle unpickles arbitrary objects; only load trusted files.
        saved = pd.read_pickle(pickle_name + ".pkl")

        def _as_tensor(array):
            return torch.tensor(array, requires_grad = False).to(self.device)

        for idx, synapse in enumerate(self.Wff):
            stored = saved['Wff'].iloc[idx]
            synapse['weight'] = _as_tensor(stored['weight'])
            synapse['bias'] = _as_tensor(stored['bias'])

        for idx, synapse in enumerate(self.Wfb):
            synapse['weight'] = _as_tensor(saved['Wfb'].iloc[idx]['weight'])

        for idx, synapse in enumerate(self.B):
            synapse['weight'] = _as_tensor(saved['B'].iloc[idx]['weight'])
        
