#############################################
#                                           #
# Error Broadcast & Decorrelate Algorithm   #
#  MLP & CIFAR-10 (current loaders: MNIST)  #
# architecture = 32*32*3, 1024,1024,1024 10 #
# architecture (MNIST) = 784, 1024, 512, 10 #
#                                           #
#############################################


####################################
# IMPORT RELEVANT PYTHON LIBRARIES #
####################################
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
import argparse
import matplotlib
import pandas as pd
from tqdm import tqdm
import glob
import os
from datetime import datetime
from itertools import product
import time
import math
import sys
sys.path.append("./src")
import warnings
warnings.filterwarnings("ignore")
from torch.utils.data import Dataset, DataLoader
import torch.nn.utils as utils
from IPython.display import clear_output
import matplotlib.pyplot as plt
import torch.nn as nn
import math
import argparse
import pickle
from torch.autograd import Variable
from utils import *

def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"`` and ``"0"``) into ``True``; this converter accepts the usual
    spellings explicitly and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Parse the command-line hyperparameters of the EBD training script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the real command line.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='Modify default values of the script.')
    parser.add_argument('--Reh_initialization_gain', type=float, default=1.0, help='Reh initialization gain')
    parser.add_argument('--Reh_initialization_gain2', type=float, default=1.0,
                        help='Reh initialization gain for deeper hidden layers (1.0 = reuse --Reh_initialization_gain)')
    parser.add_argument('--Reh_lambda', type=float, default=0.99999, help='Reh lambda')
    parser.add_argument('--COV_SCALEV', type=float, default=0.1, help='COV SCALE ')
    parser.add_argument('--error_gain', nargs='+', type=float, default=[1.0, 1.0, 1.0, 1.0], help='error gain')
    parser.add_argument('--CMSE_HIDDENV', nargs='+', type=float, default=[40.0, 15.0, 15.0], help='Decor. Loss gain for hidden layers')
    parser.add_argument('--CL1_HIDDENV', nargs='+', type=float, default=[8e-1, 3e-1, 3e-1], help='L1 Loss gain for hidden layers')
    parser.add_argument('--CMSE_OUTV', type=float, default=3500, help='MSE Loss Gain for Output ')
    parser.add_argument('--pickle_name', type=str, default="NewANNMLPCIFAR10Angleseed72.pkl", help='Pickle file name for results')
    parser.add_argument('--data_directory', type=str, default="data/", help='Directory prepended to --test_file_name')
    parser.add_argument('--test_file_name', type=str, default="NewANNTest.pkl", help='Pickle file name for results')
    parser.add_argument('--INPUT_SCALE', type=float, default=1.0, help='Scale Input ')
    parser.add_argument('--OUTPUT_SCALE', type=float, default=1.0, help='Scale Output ')
    parser.add_argument('--WEIGHT_SCALE', nargs='+', type=float, default=[1.0, 1.0, 1.0, 1.0], help='Scale Weights ')
    parser.add_argument('--TARGET_POW_HIDDEN_LRV', nargs='+', type=float, default=[4e-3, 6e-3, 6e-3],
                        help='Target power loss learning rate for hidden layers ')
    parser.add_argument('--R_eps_weight_hiddenv', type=float, default=1e-6, help='Eps value for Entropy ')
    parser.add_argument('--bias_lr_scale', type=float, default=1.0, help='bias learning rate scale')
    parser.add_argument('--TARGET_POW_HIDDENV', nargs='+', type=float, default=[2.5e-1, 2.5e-1, 2.5e-1],
                        help='Target power for hidden layers ')
    parser.add_argument('--g_exp', type=float, default=1.0, help='g function s exponent ')
    parser.add_argument('--g_const', type=float, default=0.0, help='g functions constant')
    parser.add_argument('--coeff_Wl2_reg0v', type=float, default=1.6e-4, help='weight decay parameter ')
    parser.add_argument('--proj_update_period', type=int, default=20, help='update period of projection matrix')
    parser.add_argument('--normalized_Rehk', type=int, default=0, help='normalize Rehks columns')
    parser.add_argument('--CMSE_FORWARD', type=float, default=1000, help='Forward Projection based scaling ')
    parser.add_argument('--momentumvv', type=float, default=0.9999, help='initial momentum for gradient descent ')
    parser.add_argument('--momentumvv2', type=float, default=-1.0, help='final momentum for gradient descent ')
    parser.add_argument('--momentumscale', type=float, default=5000, help='momentum adjust ')
    parser.add_argument('--LR_SCALE', type=float, default=1/1.5, help='learningrate time scale ')
    parser.add_argument('--DLR_SCALE', type=float, default=30000.0, help='learningrate time scale ')
    parser.add_argument('--L2LR_SCALE', type=float, default=300.0, help='learningrate time scale ')
    parser.add_argument('--WEIGHT_SPARSITYV', type=float, default=1.0, help='percentage of zero weights ')
    # _str2bool instead of type=bool: bool("False") would be True.
    parser.add_argument('--out_activation', type=_str2bool, default=False,
                        help='include (relu) output activation (true/false) ')
    parser.add_argument('--LRV_SCALE', type=float, default=1.0, help='learningrate time scale ')
    parser.add_argument('--INPUT_OUTPUT_SCALE', type=float, default=1.0, help='learningrate time scale ')
    parser.add_argument('--DBP_ON', type=int, default=0, help='decorrelation back prop included ')
    parser.add_argument('--GRAD_NORM', type=int, default=0, help='Turn on gradient normalization ')
    parser.add_argument('--GRAD_SCALE', type=float, default=1.0, help='Scaling for all gradients ')
    parser.add_argument('--SEED_V', type=int, default=7, help='seed value ')
    return parser.parse_args(argv)


# Define class for the Error Broadcast & Decorrelate Algorithm 
class AllLossNew(nn.Module):
    """Per-layer state and local-gradient computation for Error Broadcast & Decorrelate.

    One instance is kept per layer. It holds running statistics (means, the
    activation autocorrelation ``Rhk`` and the error/activation
    cross-correlation ``Rehk``). In :meth:`forward` it turns the broadcast
    output error into weight/bias gradients for that layer.

    ``args`` is an ``nmseargsstruct``-like object. Fields read here:
    ``layer_dim``, ``out_dim``, ``layerinp_dim``, ``R_ini``,
    ``Reh_initialization_gain``, ``la_R``, ``la_mu``, ``la_R2``, ``la_mu2``,
    ``R_eps_weight``, ``layer_power_target``, ``include_mean``, ``sigmaT``,
    ``alph``, ``error_gain``, ``input_scale``, ``output_scale``, ``g_exp``,
    ``g_const``, ``normalized_Rehk``, ``out_activation`` and, optionally,
    ``device``.
    """

    def __init__(self, args):
        super().__init__()
        # State tensors are placed on ``args.device`` when that attribute exists.
        # Otherwise they go on CUDA when available, else CPU (the same rule the
        # script uses for ``device``). A hard-coded 'cuda' crashed on CPU-only hosts.
        device = getattr(args, "device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        f64 = {"dtype": torch.float64, "device": self.device, "requires_grad": False}
        n_layer, n_out, n_in = args.layer_dim, args.out_dim, args.layerinp_dim

        # Error / hidden-activation cross-correlation matrix, Xavier-initialised.
        self.Rehk = torch.zeros((n_out, n_layer), **f64)
        torch.nn.init.xavier_uniform_(self.Rehk, gain=args.Reh_initialization_gain)
        # Running means: error, activations, g(activations), previous-layer activations.
        self.mue = torch.zeros(n_out, **f64)
        self.muhk = torch.zeros(n_layer, **f64)
        self.mughk = torch.zeros(n_layer, **f64)
        self.muhkm1 = torch.zeros(n_in, **f64)
        # Activation autocorrelation, initialised to R_ini * I, plus a same-sized identity.
        self.Rhk = args.R_ini * torch.eye(n_layer, **f64)
        self.Ihk = torch.eye(n_layer, **f64)
        # "new_*" tensors hold the most recently updated statistics.
        self.new_Rhk = torch.zeros((n_layer, n_layer), **f64)
        self.new_Rehk = torch.zeros((n_out, n_layer), **f64)
        self.new_mue = torch.zeros(n_out, **f64)
        self.new_muhk = torch.zeros(n_layer, **f64)
        self.new_mughk = torch.zeros(n_layer, **f64)
        self.new_muhkm1 = torch.zeros(n_in, **f64)

        # Forgetting factors for the autocorrelation / mean / cross-correlation updates.
        # Only la_mu is used by forward(); the covariance updates are currently disabled.
        self.la_R = args.la_R
        self.la_mu = args.la_mu
        self.la_R2 = args.la_R2
        # Alternative mean forgetting factor (not used in forward()).
        self.la_mu2 = args.la_mu2
        # epsilon * I perturbation added to the autocorrelation in the entropy (logdet) loss.
        self.R_eps_weight = args.R_eps_weight
        self.R_eps = self.R_eps_weight * torch.eye(n_layer, **f64)
        # Target variance for layer power normalisation (not used in forward()).
        self.layer_power_target = args.layer_power_target
        # Mean-subtraction flag (forward() currently never subtracts means).
        self.include_mean = args.include_mean
        # Parameters of an alternative entropy/variance objective (not used in forward()).
        self.sigmaT = args.sigmaT
        self.alph = args.alph
        # Multiplier applied to the output error before broadcasting.
        self.error_gain = args.error_gain
        self.input_scale = args.input_scale
        self.output_scale = args.output_scale
        # g(h) = (h + g_const) ** g_exp, stored as (1, 1) float64 tensors.
        self.g_exp = args.g_exp * torch.ones((1, 1), **f64)
        self.g_const = args.g_const * torch.ones((1, 1), **f64)
        self.normalized_Rehk = args.normalized_Rehk
        self.out_activation = args.out_activation

    # Gradient calculations for the current layer
    def forward(self, hk: torch.Tensor, uk: torch.Tensor, er: torch.Tensor, hkm1: torch.Tensor,
                layer: torch.Tensor, Phik: torch.Tensor, Wk: torch.Tensor) -> tuple:
        """Compute this layer's broadcast-error gradients and diagnostic losses.

        Args:
            hk: layer activations, shape (B, D).
            uk: layer pre-activations. They are combined elementwise with (B, D)
                terms, so presumably (B, D) too; confirm with the caller.
            er: output error, shape (B, out_dim).
            hkm1: previous-layer activations, shape (B, layerinp_dim).
            layer: layer index. ``layer > 1`` selects the output-layer error
                path (``Q = e * 3e-2``), and ``layer > 2`` selects the
                output-layer derivative paths.
                NOTE(review): the two thresholds differ. For the 3-layer
                MNIST net (output index 2), the output layer gets the ReLU
                derivative ``uk > 0``. Confirm which index the caller passes.
            Phik, Wk: only used for the returned diagnostic
                ``Phikn = (Phik @ Wk) * Fd[:, None, :]``. Their shapes must broadcast.

        Side effects:
            Does one exponential-moving-average step (factor ``la_mu``) into
            ``new_muhk``, ``new_gmuhk`` (also exposed as ``new_mughk``),
            ``new_mue`` and ``new_muhkm1``. ``Rhk``, ``Rehk`` and ``new_Rhk``
            are not updated here.

        Returns:
            A 16-tuple ``(NMSEloss, Cov_loss, Outl1_loss, Rehk, Rhk, gradWmse,
            gradbmse, gradWcov, gradbcov, gradWl1out, gradbl1out, gradWmse2,
            gradbmse2, Phikn, gradWpow, gradbpow)``. The cov, l1out, mse2 and
            pow gradients are disabled and returned as ``0``.
        """
        device = hk.device
        la_mu = self.la_mu
        B, D = hk.shape
        n_prev = hkm1.shape[0]

        # Nonlinear transform g(h) = (h + g_const) ** g_exp and its derivative.
        ghk = torch.pow(hk + self.g_const, self.g_exp)
        if layer > 2:
            ghkd = torch.ones(uk.shape, device=device)
        else:
            ghkd = self.g_exp * (hk + self.g_const) ** (self.g_exp - 1.0)

        # One EMA step of the running means (bookkeeping; not used below).
        self.new_muhk = la_mu * self.muhk + (1 - la_mu) * torch.mean(hk, 0)
        self.new_gmuhk = la_mu * self.mughk + (1 - la_mu) * torch.mean(ghk, 0)
        # Keep the name pre-allocated in __init__ in sync (it was never updated before).
        self.new_mughk = self.new_gmuhk
        self.new_mue = la_mu * self.mue + (1 - la_mu) * torch.mean(er, 0)
        self.new_muhkm1 = la_mu * self.muhkm1 + (1 - la_mu) * torch.mean(hkm1, 0)

        # Broadcast error mapped into this layer's coordinates (no mean subtraction).
        e_hat = self.error_gain * er
        if layer > 1.0:
            # Output layer: use the (scaled) error directly.
            Q = e_hat * 3e-2
        elif self.normalized_Rehk == 1:
            # Hidden layer, with Rehk columns normalised to unit l2 norm.
            col_norms = torch.norm(self.Rehk, p=2, dim=0, keepdim=True)
            Q = e_hat @ (self.Rehk / (col_norms + 1e-12)) * 0.002
        else:
            Q = e_hat @ self.Rehk

        # (Sub)gradient of the activation: ReLU derivative, rescaled like the activations.
        Fd = (uk > 0) / self.input_scale * self.output_scale
        if layer > 2.0 and not self.out_activation:
            # Linear output layer: derivative is identically 1.
            Fd = torch.ones(Fd.shape, dtype=torch.float64, device=device) / self.input_scale * self.output_scale

        # Error scaled by the activation and g-derivatives.
        Z = Fd * Q * ghkd
        # Cast once so both products are float64 (Z could be float32 on the output path).
        ZT = Z.T.type(torch.float64)
        ones_prev = torch.ones(n_prev, 1, dtype=torch.float64, device=device)

        # Gradients of the broadcast-error decorrelation loss.
        gradWmse = ZT @ hkm1.type(torch.float64) / B
        gradbmse = ZT @ ones_prev / B

        # Disabled terms (layer entropy, power target, activation l1, propagated
        # decorrelation) are returned as 0 placeholders to keep the tuple layout.
        gradWcov = gradbcov = 0
        gradWpow = gradbpow = 0
        gradWl1out = gradbl1out = 0
        gradWmse2 = gradbmse2 = 0

        # Diagnostic only (used to compare against the true gradient, not for learning).
        Phikn = torch.matmul(Phik, Wk) * Fd.unsqueeze(1)

        NMSEloss = torch.norm(self.Rehk, p='fro')
        # new_Rhk is not updated in this method, so this uses whatever the caller stored.
        Cov_loss = -torch.logdet(self.new_Rhk + self.R_eps) / D
        Outl1_loss = torch.sum(torch.abs(hk)) / B / D

        return (NMSEloss, Cov_loss, Outl1_loss, self.Rehk, self.Rhk, gradWmse, gradbmse,
                gradWcov, gradbcov, gradWl1out, gradbl1out, gradWmse2, gradbmse2, Phikn,
                gradWpow, gradbpow)


# MLP Class
class MLP(torch.nn.Module):
    """Multi Layer Perceptron with ReLU hidden layers and scaled activations.

    Args:
        architecture: layer widths ``[in, h1, ..., out]``. It needs at least two
            entries; two entries give a single linear layer.
        activation: stored as ``self.activation`` but NOT used by ``forward``,
            which always applies ReLU. NOTE(review): AllLossNew's learning
            rule uses the ReLU derivative ``uk > 0``, so changing this would
            need matching changes there.
        final_layer_activation: also apply ReLU to the output layer.
        weight_std: per-layer multipliers applied after Kaiming-uniform init.
            Hidden layer ``k`` uses ``weight_std[k]``. The output layer reuses
            the last hidden layer's entry (``weight_std[len(architecture) - 3]``,
            or ``weight_std[0]`` with no hidden layer).
            NOTE(review): a dedicated output entry would be
            ``weight_std[len(architecture) - 2]``. The existing indexing is kept
            for reproducibility.
        input_scale, output_scale: every activation is computed as
            ``act / input_scale * output_scale``.
    """

    def __init__(self, architecture, activation=F.relu, final_layer_activation=False,
                 weight_std=(1.0, 1.0, 1.0), input_scale=1.0, output_scale=1.0):
        super(MLP, self).__init__()

        self.activation = activation
        self.architecture = architecture
        self.nc = self.architecture[-1]
        self.final_layer_activation = final_layer_activation
        self.input_scale = input_scale
        self.output_scale = output_scale

        n_linear = len(architecture) - 1
        # Weight-scale index for the output layer (see class docstring). Computed
        # explicitly instead of reusing the hidden loop's variable, which was
        # undefined (NameError) when there are no hidden layers.
        out_scale_idx = max(n_linear - 2, 0)

        self.linear_layers = torch.nn.ModuleList()
        for idx in range(n_linear):
            m = torch.nn.Linear(architecture[idx], architecture[idx + 1], bias=True)
            torch.nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
            is_output = idx == n_linear - 1
            m.weight.data *= weight_std[out_scale_idx if is_output else idx]
            torch.nn.init.constant_(m.bias, 0)
            self.linear_layers.append(m)

    def forward(self, x):
        """Run the network.

        Args:
            x: input batch. It is flattened to ``(batch, -1)``.

        Returns:
            ``(y, hidden_list, prey, preh_list)``:
            ``y`` is the network output. ``hidden_list`` holds the post-ReLU
            (scaled) hidden activations. ``prey`` is the unscaled output
            pre-activation when ``final_layer_activation`` is set, and otherwise
            the scaled linear output (the same tensor as ``y``). ``preh_list``
            holds the hidden-layer pre-activations.
        """
        x = x.view(x.size(0), -1)  # flatten the input
        preh_list = []
        hidden_list = []

        for linear in self.linear_layers[:-1]:
            preh = linear(x)
            preh_list.append(preh)
            # ReLU hidden activation, rescaled; it is the next layer's input.
            x = F.relu(preh) / self.input_scale * self.output_scale
            hidden_list.append(x)

        # The output layer is evaluated once (it used to be computed twice with activation).
        prey = self.linear_layers[-1](x)
        if self.final_layer_activation:
            y = F.relu(prey) / self.input_scale * self.output_scale
        else:
            prey = prey / self.input_scale * self.output_scale
            y = prey
        return y, hidden_list, prey, preh_list


# Class definition for nonlinear-mse layer-loss hyperparameter
class nmseargsstruct:
    """Plain attribute bag holding the hyperparameters of one layer loss (AllLossNew).

    The 21 positional fields are stored, in order, as:
    layer_dim, out_dim, layerinp_dim, la_R, la_mu, R_ini, R_eps_weight,
    alph, sigmaT, la_R2, la_mu2, include_mean, layer_power_target,
    Reh_initialization_gain, error_gain, input_scale, output_scale,
    g_exp, normalized_Rehk, g_const, out_activation.
    """

    def __init__(self, field1, field2, field3, field4, field5, field6, field7,
                 field8, field9, field10, field11, field12, field13, field14,
                 field15, field16, field17, field18, field19, field20, field21):
        # Sizes: this layer, the network output, the previous layer.
        self.layer_dim, self.out_dim, self.layerinp_dim = field1, field2, field3
        # Convex-combination factors for autocorrelation and mean updates, and
        # the diagonal value used to initialise the autocorrelation matrix.
        self.la_R, self.la_mu, self.R_ini = field4, field5, field6
        # Weight of the epsilon*I perturbation on the autocorrelation matrix.
        self.R_eps_weight = field7
        # Parameters of the alternative variance-regularised entropy (not used).
        self.alph, self.sigmaT = field8, field9
        # Cross-correlation update factor and the alternative mean factor.
        self.la_R2, self.la_mu2 = field10, field11
        # Mean-subtraction flag and the target power for layer normalisation.
        self.include_mean, self.layer_power_target = field12, field13
        # Reh initialisation gain and error gain.
        self.Reh_initialization_gain, self.error_gain = field14, field15
        # Activation input/output scaling.
        self.input_scale, self.output_scale = field16, field17
        # g-function exponent, Rehk column normalisation flag, g-function
        # constant, output activation flag.
        self.g_exp, self.normalized_Rehk, self.g_const, self.out_activation = (
            field18, field19, field20, field21)


###############################################################################################
#                                                                                             #
#                 SCRIPT STARTS HERE                                                          #
#                                                                                             #
###############################################################################################

# Arguments
# Arguments: unpack every command-line option into a module-level global of the same name.
args = get_args()
Reh_initialization_gain = args.Reh_initialization_gain
Reh_initialization_gain2 = args.Reh_initialization_gain2
Reh_lambda = args.Reh_lambda
pickle_name = args.pickle_name
error_gain = args.error_gain
COV_SCALEV = args.COV_SCALEV
CMSE_HIDDENV = args.CMSE_HIDDENV
CMSE_OUTV = args.CMSE_OUTV
INPUT_SCALE = args.INPUT_SCALE
OUTPUT_SCALE = args.OUTPUT_SCALE
WEIGHT_SCALE = args.WEIGHT_SCALE
test_file_name = args.test_file_name
CL1_HIDDENV = args.CL1_HIDDENV
TARGET_POW_HIDDEN_LRV = args.TARGET_POW_HIDDEN_LRV
R_eps_weight_hiddenv = args.R_eps_weight_hiddenv
TARGET_POW_HIDDENV = args.TARGET_POW_HIDDENV
bias_lr_scale = args.bias_lr_scale
g_exp = args.g_exp
coeff_Wl2_reg0v = args.coeff_Wl2_reg0v
proj_update_period = args.proj_update_period
normalized_Rehk = args.normalized_Rehk
CMSE_FORWARD = args.CMSE_FORWARD
momentumvv = args.momentumvv
momentumvv2 = args.momentumvv2
momentumscale = args.momentumscale
LR_SCALE = args.LR_SCALE
DLR_SCALE = args.DLR_SCALE
L2LR_SCALE = args.L2LR_SCALE
WEIGHT_SPARSITYV = args.WEIGHT_SPARSITYV
g_const = args.g_const
data_directory = args.data_directory
out_activation = args.out_activation
LRV_SCALE = args.LRV_SCALE
INPUT_OUTPUT_SCALE = args.INPUT_OUTPUT_SCALE
DBP_ON = args.DBP_ON
GRAD_NORM = args.GRAD_NORM
GRAD_SCALE = args.GRAD_SCALE
SEED_V = args.SEED_V

# Place the test-results file inside the data directory. os.path.join gives the
# same result as plain concatenation for the default "data/", and also handles a
# --data_directory given without a trailing slash (concatenation produced e.g.
# "dataNewANNTest.pkl").
test_file_name = os.path.join(data_directory, test_file_name)


# A negative final momentum (the default -1.0) means "same as the initial momentum".
if momentumvv2 < 0:
    momentumvv2 = momentumvv

# Powers of the input scale (bet) and output scale (alph), used below to rescale
# the gains and targets.
bet = INPUT_SCALE
bet2 = bet * bet
bet3 = bet2 * bet
bet4 = bet3 * bet
bet6 = bet3 * bet3

alph = OUTPUT_SCALE
alph2 = alph * alph
alph3 = alph2 * alph
alph4 = alph2 * alph2
alph5 = alph3 * alph2
alph6 = alph3 * alph3

# error_gain is the same list object as args.error_gain, so this rescales it in
# place. NOTE(review): only entries 0 and 1 are rescaled; confirm that the
# remaining entries are meant to stay unscaled.
error_gain[0] = error_gain[0] / alph4
error_gain[1] = error_gain[1] / alph5

# 1.0 (the default) is treated as "not set": reuse the first hidden layer's gain.
if Reh_initialization_gain2 == 1.0:
    Reh_initialization_gain2 = Reh_initialization_gain


# Per-hidden-layer initialisation gains for the Rehk matrices.
Reh_initialization_gain_list = [Reh_initialization_gain * bet2,
                                Reh_initialization_gain2 * bet * alph,
                                Reh_initialization_gain2 * bet * alph]

CMSE_OUTV = CMSE_OUTV / alph6

# Check GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Adjust file/directory organization. working_path is a hook: as written it is the
# current directory, so the chdir is a no-op unless working_path is changed.
current_directory = os.getcwd()
working_path = current_directory
os.chdir(working_path)
# exist_ok=True avoids the check-then-create race when several runs (seeds) are
# launched at the same time.
os.makedirs("../Results", exist_ok=True)
# Filename to save simulation variables
pickle_name_for_results = pickle_name

# Data structure to store simulation variables
RESULTS_DF = pd.DataFrame(columns=['setting_number', 'seed', 'Model', 'Hyperparams', 'Trn_ACC_list', 'Tst_ACC_list'])


# NOTE(review): the file header mentions CIFAR-10, but MNIST is what is loaded here.
# Normalize with mean 0 and std 1 is an identity transform (kept for easy tuning).
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,))])

# The batch size (20) is a literal here; the later `batchsize` variable is not used by these loaders.
mnist_dset_train = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(mnist_dset_train, batch_size=20, shuffle=True, num_workers=0)

mnist_dset_test = torchvision.datasets.MNIST('data', train=False, transform=transform, target_transform=None, download=True)
test_loader = torch.utils.data.DataLoader(mnist_dset_test, batch_size=20, shuffle=False, num_workers=0)



############# HYPERPARAMS  #########################
# Fixed Parameters

# number of output classes
num_classes=10
# input dimension
# NOTE(review): 3072 (=32*32*3) is the CIFAR-10 input size, but `architecture`
# below uses a 784-dimensional (28*28, MNIST) input -- confirm which is intended.
inp_dim=3072 #32*32*3
# number of training epochs
n_epochs = 121

# Only a single seed (10*SEED_V) is trained per invocation; several seeds are
# covered by running the script repeatedly with different SEED_V values.
#seed_list = [10*(j+7) for j in range(10)]
seed_list=[10*SEED_V]

# counter of hyperparameter settings visited by the sweep loop
setting_number = 0

# MLP layer sizes (input, hidden..., output)
#architecture=[3072, 1024,512,512, 10]
architecture=[784, 1024,512, 10]

# Number of layers (excluding the input)
NL=len(architecture)-1

# Flag to include propagated component of the decorrelation loss
INCLUDE_PROP_GRAD=0.0 #1.0

#############################################
# Layer-loss initialization hyperparameters #
#############################################

# Initial value for layer variance
# for the output layer
argsR_ini_out=0.3*alph6
# for the hidden layers
argsR_ini_hidden=0.03

# Lambda parameter for covariance autoregressive update
argsla_R=0.99999
argsla_R2=Reh_lambda #0.99999
# Lambda parameter for mean autoregressive update
argsla_mu=0.99999
argsla_mu2=Reh_lambda #0.99999

# Correlation(Covariance) regularizer terms
# for the output
argsR_eps_weight_out=1e-6*alph6
# for the hidden layers
argsR_eps_weight_hidden=R_eps_weight_hiddenv #1e-6


################################
#       Loss Parameters        #
################################
# Multiplier for l1-norm loss for layer weights
coeff_Wl1_reg=0
# Multiplier for l2-norm loss for layer weights
coeff_Wl2_reg0=coeff_Wl2_reg0v # 8e-4/5

SYNAPTIC_PRUNING_THRESHOLD=0.0001
# fraction of weights kept by the sparsity masks (WEIGHT_SPARSITYV is a percentage)
WEIGHT_NON_SPARSITY=1.0-(WEIGHT_SPARSITYV)/100.0 #0.99 #0.8 #0.9

# FINAL (OUTPUT) LAYER LOSS WEIGHTS
# MSE Loss (former hard-coded values: 200*2*10, 3500*10/4*4/10)
CMSE_OUT=CMSE_OUTV
# Layer Entropy Objective
CCOV_OUT=5e-4/4.0*10
# Layer Activation l_1 Loss
CL1_OUT=0
#Learning rate scaling for the output layer
LR_OUT_SCALE=1.0

# Target power for the output layer
TARGET_POW_OUT=0.1*alph6
# Target power loss learning parameter for the output layer
TARGET_POW_OUT_LR=1e-10


# HIDDEN LAYER LOSS WEIGHTS (one entry per hidden layer; lists hold three
# entries, only the first NL-1 are used by the current architecture)
# Decorrelation loss weight for hidden layers
CMSE_HIDDEN=CMSE_HIDDENV #[40,15]
CMSE_HIDDEN=[CMSE_HIDDEN[0]/bet4, CMSE_HIDDEN[1]/bet2,CMSE_HIDDEN[2]/bet2]

# Layer Entropy Objective weight adjustment parameters (decreased as a function of iterations)
COV_SCALE=COV_SCALEV
CCOV_HIDDEN_MAX=[2.5e-3*COV_SCALE,1.5e-2*COV_SCALE,1.5e-2*COV_SCALE]
CCOV_HIDDEN_MIN=[2.5e-3*COV_SCALE,7e-3*COV_SCALE,7e-3*COV_SCALE]
CCOV_LAMBD_STEP=10
# Layer Activation l_1 Loss
CL1_HIDDEN=CL1_HIDDENV #[8e-1,3e-1]
CL1_HIDDEN=[CL1_HIDDEN[0]/bet2/alph, CL1_HIDDEN[1]/bet/alph2,CL1_HIDDEN[2]/bet/alph2]

# Learning rate scaling for the hidden layers
LR_HIDDEN_SCALE=1.0

# Target power for hidden layers
TARGET_POW_HIDDEN=TARGET_POW_HIDDENV #[2.5e-1,2.5e-1]
TARGET_POW_HIDDEN=[TARGET_POW_HIDDEN[0]*bet4*alph2, TARGET_POW_HIDDEN[1]*bet2*alph4,TARGET_POW_HIDDEN[2]*bet4*alph6]
# Target power loss learning rate
TARGET_POW_HIDDEN_LR=TARGET_POW_HIDDEN_LRV #[4e-3,6e-3]
TARGET_POW_HIDDEN_LR=[TARGET_POW_HIDDEN_LR[0]/bet6/alph3,TARGET_POW_HIDDEN_LR[1]/bet3/alph6,TARGET_POW_HIDDEN_LR[2]/bet3/alph6]
# Batch size
batchsize=20

# Per-hidden-layer initial variance and covariance regularizer values
argsR_ini_hidden_list=[argsR_ini_hidden*bet4*alph2, argsR_ini_hidden*bet2*alph4,argsR_ini_hidden*bet2*alph4]
argsR_eps_weight_hidden_list=[argsR_eps_weight_hidden*alph2*bet4, argsR_eps_weight_hidden*alph4*bet2,argsR_eps_weight_hidden*alph4*bet2]

####################################
# Learning Rate Related Parameters #
####################################

lr_list = [1.0]
lr_decay_gamma_list = [0.9]
lr_decay_scheduler_step_list = [200000]


INCLUDE_MEAN=0

# initial momentum value (recomputed every batch inside the training loop)
momentumv=momentumvv # 0.9999


# Sweep over learning-rate settings. lr_decay and lr_decay_step are only
# recorded in hyperparams_dict: no LR scheduler is created (kept for legacy).
for lr, lr_decay, lr_decay_step in product(lr_list, lr_decay_gamma_list, lr_decay_scheduler_step_list):
    setting_number += 1
    # dictionary containing hyperparameters
    hyperparams_dict = {"lr" : lr, "lr_decay" : lr_decay,
                        "lr_decay_step" : lr_decay_step, "batchsize": batchsize,
                        "argsla_R" : argsla_R,
                        "coeff_Wl1_reg" : coeff_Wl1_reg, "coeff_Wl2_reg0" : coeff_Wl2_reg0,
                        "WEIGHT_SPARSITYV" : WEIGHT_SPARSITYV,"CMSE_OUT": CMSE_OUT,
                        "architecture" : architecture,
                        "CCOV_OUT" : CCOV_OUT,
                        "CL1_OUT": CL1_OUT,
                        "LR_OUT_SCALE": LR_OUT_SCALE,
                        "TARGET_POW_OUT": TARGET_POW_OUT,
                        "TARGET_POW_OUT_LR": TARGET_POW_OUT_LR,
                        "CMSE_HIDDEN": CMSE_HIDDEN,
                        "COV_SCALE": COV_SCALE,
                        "CCOV_HIDDEN_MAX": CCOV_HIDDEN_MAX,
                        "CCOV_HIDDEN_MIN": CCOV_HIDDEN_MIN,
                        "CCOV_LAMBD_STEP": CCOV_LAMBD_STEP,
                        "CL1_HIDDEN": CL1_HIDDEN,
                        "TARGET_POW_HIDDEN": TARGET_POW_HIDDEN,
                        "LR_HIDDEN_SCALE": LR_HIDDEN_SCALE,
                        "TARGET_POW_HIDDEN_LR": TARGET_POW_HIDDEN_LR
                        }

    # Loop for different seeds
    for seed_ in seed_list:
        # initialize random generators with the current seed
        np.random.seed(seed_)
        torch.manual_seed(seed_)
        # Random projection matrices, one per hidden layer. PM_list is not used
        # further in this section, but these torch.randn draws advance the RNG,
        # so removing them would change every subsequent random initialization.
        PM_list=[]
        #PMX=torch.randn((architecture[-1],architecture[0]),dtype=torch.float64, device='cuda', requires_grad=False)
        for i in range(NL-1):
            PMi=torch.randn((architecture[i],architecture[-1]),dtype=torch.float64, device='cuda', requires_grad=False)/np.sqrt(architecture[i]*1.0)
            PM_list.append(PMi)

        # initialize layer (loss) objects: one per hidden layer plus one for the output
        nalloss=[]
        sigmaT=5e-4 # not used
        alphav=100/3.0
        # Generate layer objects for hidden layers
        for i in range(NL-1):
            # previous layer dimension
            inp_layerdim=architecture[i]
            # current layer dimension
            hidden_size=architecture[i+1]
            argsn1=nmseargsstruct(hidden_size,num_classes,inp_layerdim, argsla_R,argsla_mu,argsR_ini_hidden_list[i],argsR_eps_weight_hidden_list[i],alphav, sigmaT,argsla_R2,argsla_mu2,INCLUDE_MEAN,TARGET_POW_HIDDEN[i],Reh_initialization_gain_list[i], error_gain[i],INPUT_SCALE,OUTPUT_SCALE,g_exp,normalized_Rehk,g_const,out_activation)
            obj=AllLossNew(argsn1)
            nalloss.append(obj)

        # Generate layer object for the output layer
        sigmaT=0.99 # not used
        alphav=0.5 # not used
        # The output layer takes the error gain at its own layer index NL-1
        # (identical to the former hard-coded index 2 for the 3-layer architecture).
        argsn2=nmseargsstruct(num_classes,num_classes, architecture[-2],argsla_R,argsla_mu,argsR_ini_out,argsR_eps_weight_out,alphav,sigmaT,argsla_R2,argsla_mu2,INCLUDE_MEAN,TARGET_POW_OUT,Reh_initialization_gain,error_gain[NL-1],INPUT_SCALE,OUTPUT_SCALE,g_exp,normalized_Rehk,g_const,out_activation)
        obj=AllLossNew(argsn2)
        nalloss.append(obj)

        # Store initial values for Reh matrices
        # NOTE(review): .detach() (and the plain reference for Reh1) share storage
        # with Rehk; if Rehk is updated in place these will not stay "initial".
        # Use .clone() if a real snapshot is needed.
        # first layer
        Reh0=nalloss[0].Rehk.detach()
        # Frobenius norm of the initial Reh0
        Reh0n=torch.norm(Reh0,'fro')
        # second layer
        Reh1=nalloss[1].Rehk
        # Frobenius norm of the initial Reh1
        Reh1n=torch.norm(Reh1,'fro')


        ########################################
        #        PERFORMANCE CHECK             #
        ########################################

        # performance tracking lists (one entry per evaluation)
        trn_acc_list = []
        tst_acc_list = []
        trn_loss_list=[]
        tst_loss_list=[]
        time_list=[]

        # Activation sparsity (fraction of exact zeros) per batch, exponentially averaged
        # first hidden layer
        sp_ar0=0
        # second hidden layer
        sp_ar1=0
        # network output
        sp_ar2=0

        # Checking gradient alignment related parameters
        # number of times gradients calculated
        grad_count=0
        # number of times broadcast gradient is an ascent direction for each layer
        succ_count=np.zeros(NL)
        grad_norm_ratio=np.zeros(NL)
        inner_product=np.zeros(NL)

        # define model
        model = MLP(architecture, final_layer_activation = out_activation, weight_std=WEIGHT_SCALE,input_scale=INPUT_SCALE,output_scale= OUTPUT_SCALE).to(device)
        # Plain SGD: step() applies the .grad tensors assembled manually below
        # (they already include the learning-rate factor lrv).
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,weight_decay=1e-20)

        # list containing network parameters
        # (index convention: even -> synaptic weights, odd -> biases)
        param_list = [param for param in model.parameters()]

        # momentum-averaged search directions, one per parameter tensor
        SD = [torch.zeros_like(param) for param in model.parameters()]

        # Sparsifying weights: each weight entry is kept with probability
        # ~WEIGHT_NON_SPARSITY; biases get an all-ones mask. All parameters are
        # then rescaled by 1/sqrt(WEIGHT_NON_SPARSITY) to compensate.
        masks_list=[]
        for indx, param in enumerate(param_list):
            if np.mod(indx,2)==0:
                curmask=torch.rand(param.data.shape,dtype=torch.float64, device='cuda', requires_grad=False)
                curmask=(curmask>(1-WEIGHT_NON_SPARSITY))*1.0
            else:
                curmask=torch.ones(param.data.shape,dtype=torch.float64, device='cuda', requires_grad=False)

            masks_list.append(curmask)
            param.data*=curmask*1/np.sqrt(WEIGHT_NON_SPARSITY)

        # NOTE(review): .data aliases the live parameters, so W0i/W1i follow the
        # training updates; clone() them if an initial-weight snapshot is needed.
        W0i=param_list[0].data
        W1i=param_list[2].data
        W0in=torch.norm(W0i,'fro')
        W1in=torch.norm(W1i,'fro')

        #############################################################################
        #                                  TRAINING LOOP                            #
        #############################################################################

        # learning rate parameters: counters related to how fast learning rate to be reduced
        cnt=0
        cnt2=0
        #  list of the estimates of the training accuracy (a new value added every batch)
        trn_acc_est=[]
        # training accuracy estimate based on the current batch (exponentially averaged)
        tr_ar=0
        INPUT_SCALE2=INPUT_SCALE*INPUT_SCALE
        INPUT_SCALE3=INPUT_SCALE2*INPUT_SCALE
        INPUT_SCALE4=INPUT_SCALE2*INPUT_SCALE2

        # Evaluate the untrained model (logged as epoch -1)
        trn_acc, trn_loss = evaluateClassification(model, train_loader, device, True)
        tst_acc, tst_loss = evaluateClassification(model, test_loader, device, True)

        trn_acc_list.append(trn_acc)
        tst_acc_list.append(tst_acc)
        trn_loss_list.append(trn_loss)
        tst_loss_list.append(tst_loss)

        # Append to the per-seed text log
        with open(str(seed_)+'output_log_mnist.txt', 'a') as file:
            file.write(f"Epoch: {-1}, train mse:: {trn_loss}, test mse:: {tst_loss}\n")
            file.write(f"Epoch: {-1}, train accuracy:: {trn_acc*100}, test accuracy:: {tst_acc*100}\n")


        for epoch_ in range(n_epochs):
            model.train()
            mycnt=0
            epoch_time = time.time()

            for idx, (x, y) in tqdm(enumerate(train_loader)):
                # update batch counter
                cnt=cnt+1
                # cnt2 is incremented once every 10 batches
                if np.mod(cnt,10)==9:
                    cnt2=cnt2+1
                # learning rate for the current batch (decays as 1/(1.5*cnt2/LRV_SCALE+1))
                lrv=1/(cnt2*1.5/LRV_SCALE+1)
                # momentum moves from momentumvv2 towards momentumvv as cnt2 grows
                lambdm=(1-1/(cnt2/momentumscale+1))
                momentumv=lambdm*momentumvv+(1-lambdm)*momentumvv2

                # scaling for decorrelation and entropy loss gradients
                scv=1.0*(cnt2/DLR_SCALE+1)
                # coefficient for the weight-l2-regularization
                coeff_Wl2_reg=coeff_Wl2_reg0/(cnt2/L2LR_SCALE+1)
                # input batch and corresponding labels to the device
                x, y = x.to(device), y.to(device)
                # actual number of samples in this batch (may be smaller than
                # `batchsize` for a final partial batch)
                cur_batchsize=x.size(0)
                # one-hot representation of the labels
                y_one_hot = F.one_hot(y, num_classes=model.nc)*alph4
                # set gradients zero
                optimizer.zero_grad()
                x=x*INPUT_SCALE3
                # calculate model output and intermediate signals
                y_hat,hidden_list,prey,preh_list=model(x)

                # calculate error between labels and the model output
                err=y_hat-y_one_hot.to(torch.float64)
                # index of the peak model output component
                q=torch.argmax(y_hat,axis=1)
                # calculate the estimate of the training accuracy through exponential averaging
                tr_ar=0.99*tr_ar+0.01*torch.sum(1.0*(q==y)).item()/cur_batchsize
                # update activation sparsity for the first hidden layer
                sp_ar0=0.99*sp_ar0+0.01*(torch.sum(hidden_list[0]==0)/hidden_list[0].numel()).item()
                # update activation sparsity for the second hidden layer
                sp_ar1=0.99*sp_ar1+0.01*(torch.sum(hidden_list[1]==0)/hidden_list[1].numel()).item()
                # update activation sparsity for the output
                sp_ar2=0.99*sp_ar2+0.01*(torch.sum(y_hat==0)/y_hat.numel()).item()

                # Disabled: periodic re-initialization of the Reh projection matrices
                # if np.mod(cnt,proj_update_period)==(proj_update_period-1):
                #     torch.nn.init.xavier_uniform(nalloss[0].Rehk,gain = Reh_initialization_gain_list[0])
                #     torch.nn.init.xavier_uniform(nalloss[1].Rehk,gain = Reh_initialization_gain_list[1])

                # Vectorize the input
                xx= x.view(x.size(0),-1)

                # now the gradient calculations for each layer
                with torch.no_grad():
                    # Collect current network weights and parameters
                    param_list = [param for param in model.parameters()]
                    grad_count+=1

                    ###################################################
                    #        LOSS AND GRADIENT CALCULATIONS           #
                    ###################################################

                    # FINAL (OUTPUT) LAYER
                    # Learning Rate for the Output Layer
                    lrv2=lrv*LR_OUT_SCALE
                    # parameter list index calculations for final layer:
                    # Layer Index
                    ind=NL-1
                    # Index for Synaptic Weights
                    ind0=int(ind*2)
                    # Index for Biases
                    ind1=int(ind*2+1)

                    #Initialization of the backpropagated gradient parameter phi
                    identitymatrix = torch.eye(num_classes,dtype=torch.float64,device='cuda')
                    #Assign identity matrix to each sample of the current batch
                    Phik = identitymatrix.unsqueeze(0).repeat(cur_batchsize, 1, 1)
                    # Weight matrix for after out layer is identity (due to how phi recursive update is structured)
                    Wk=torch.eye (num_classes,dtype=torch.float64,device='cuda')# param_list[ind0].data.to(torch.float64)

                    # Select the loss object for the final (output) layer
                    lossobj=nalloss[NL-1]
                    # Calculate Losses and Gradients for the output layer using the loss object
                    NMSEloss2,Cov_loss2, Outl1_loss2,  Reyk2,Ryk2,gradWmse2,gradbmse2, gradWcov2,gradbcov2,gradWl1out2,gradbl1out2,gradWmse22,gradbmse22,Phik,gradWpow2,gradbpow2=lossobj(y_hat,prey,err,hidden_list[NL-2],3.0,Phik,Wk)

                    # optional gradient normalization (0.5 avoids division by ~0)
                    if (GRAD_NORM>0):
                        gradWmse2=GRAD_SCALE*gradWmse2/(torch.norm(gradWmse2,'fro')+0.5)
                        gradbmse2=GRAD_SCALE*gradbmse2/(torch.norm(gradbmse2,'fro')+0.5)

                    # calculate search direction for decorrelation based on momentum
                    SD[ind0]=momentumv*SD[ind0]+(1-momentumv)*(gradWmse2)
                    SD[ind1]=momentumv*SD[ind1]+(1-momentumv)*gradbmse2.squeeze().data.to(torch.float32)
                    # calculate the loss (without weight regularizers) gradient wrt synaptic weights
                    gradW2=scv*CMSE_OUT*(SD[ind0])
                    # calculate the loss gradient wrt biases
                    gradb2=bias_lr_scale*alph2*(scv*CMSE_OUT*(torch.reshape(SD[ind1],gradbmse2.shape)))

                    #update gradient part in the model parameters
                    #first using weight l1 regularization gradient
                    #param_list[ind0].grad = lrv2*coeff_Wl1_reg*torch.sign(param_list[ind0].data)
                    # then weight l2 regularization gradient
                    param_list[ind0].grad = coeff_Wl2_reg*(param_list[ind0].data)
                    #update synaptic weight gradients based on the loss gradient
                    param_list[ind0].grad += lrv2*gradW2
                    #update biases based on loss gradient
                    param_list[ind1].grad = lrv2*gradb2.squeeze().data.to(torch.float32)

                    # Learning Rate for the Hidden Layers
                    lrv1=lrv*LR_HIDDEN_SCALE

                    nmsehlosses=[]
                    # Gradient update for hidden layers, starting from the layer before output and backwards
                    for ind in range(NL-2, -1, -1):
                        #Select the loss object for the current hidden layer
                        lossobj=nalloss[ind]
                        # input for the current hidden layer (network input or the output of the previous layer)
                        if (ind==0):
                            inpvec=xx
                        else:
                            inpvec=hidden_list[ind-1]

                        # parameter list index for the layer weights
                        ind0=int(ind*2)
                        # parameter list index for the layer biases
                        ind1=int(ind*2+1)
                        # Select the next layer weights for the backpropagation component
                        Wk=param_list[ind0+2].data.to(torch.float64)
                        # Loss and gradient computations for the current hidden layer
                        # (Phik returned here is fed to the next, earlier layer)
                        NMSEloss1,Cov_loss1,Outl1_loss1,Reyk1,Ryk1,gradWmse1,gradbmse1, gradWcov1,gradbcov1,gradWl1out1,gradbl1out1,gradWmse12,gradbmse12,Phik,gradWpow1,gradbpow1=lossobj(hidden_list[ind],preh_list[ind],err,inpvec,ind,Phik,Wk)
                        nmsehlosses.append(NMSEloss1)

                        # optionally average with the second gradient component returned by lossobj
                        if (DBP_ON>0):
                            gradWmse1=0.5*gradWmse1+0.5*gradWmse12
                            gradbmse1=0.5*gradbmse1+0.5*gradbmse12
                        if (GRAD_NORM>0):
                            gradWmse1=GRAD_SCALE*gradWmse1/(torch.norm(gradWmse1,'fro')+0.5)
                            gradbmse1=GRAD_SCALE*gradbmse1/(torch.norm(gradbmse1,'fro')+0.5)

                        # Calculate current entropy objective weight
                        # NOTE(review): CCOV_HIDDEN is not used below -- the entropy
                        # term does not currently enter gradW1/gradb1.
                        CCOV_LAMBD=1/(np.sqrt(epoch_/CCOV_LAMBD_STEP)+1)
                        CCOV_HIDDEN=CCOV_LAMBD*CCOV_HIDDEN_MAX[ind]+(1.0-CCOV_LAMBD)*CCOV_HIDDEN_MIN[ind]
                        # momentum-averaged search directions for this layer
                        SD[ind0]=momentumv*SD[ind0]+(1-momentumv)*(gradWmse1)
                        SD[ind1]=momentumv*SD[ind1]+(1-momentumv)*gradbmse1.squeeze().data.to(torch.float32)
                        # bias gradients are scaled by alph except for the first layer
                        scale_b=alph
                        if ind==0:
                            scale_b=1.0
                        gradW1=scv*CMSE_HIDDEN[ind]*(SD[ind0])
                        gradb1=bias_lr_scale*scale_b*(scv*CMSE_HIDDEN[ind]*(torch.reshape(SD[ind1],gradbmse1.shape)))

                        # Update the current layer weight model parameter's gradient
                        # based on the weight l1 regularization loss subgradient first
                        #param_list[ind0].grad = lrv1*coeff_Wl1_reg*torch.sign(param_list[ind0].data)
                        # based on the weight l2 regularization loss gradient
                        param_list[ind0].grad = coeff_Wl2_reg*(param_list[ind0].data)
                        # and the loss gradient
                        param_list[ind0].grad += lrv1*gradW1
                        # update the bias model parameter's gradient component based on the bias gradients
                        param_list[ind1].grad = lrv1*gradb1.squeeze().data.to(torch.float32)

                #perform gradient update
                optimizer.step()

                # every 6th batch, re-apply the sparsity masks so pruned weights stay zero
                if (np.mod(idx,6)==5):
                    for param, mask in zip(param_list, masks_list):
                        param.data*=mask
                        #maxp=torch.max(torch.abs(param))
                        #indc=(torch.abs(param)>(maxp/150/(indx*8+1)))
                        #param.data*=indc

            time_list.append(time.time()-epoch_time)
            # At the end of the epoch evaluate on both splits
            trn_acc, trn_loss = evaluateClassification(model, train_loader, device, True)
            tst_acc, tst_loss = evaluateClassification(model, test_loader, device, True)

            trn_acc_list.append(trn_acc)
            tst_acc_list.append(tst_acc)
            trn_loss_list.append(trn_loss)
            tst_loss_list.append(tst_loss)

            # Append to the per-seed text log
            with open(str(seed_)+'output_log_mnist.txt', 'a') as file:
                file.write(f"Epoch: {epoch_}, train mse:: {trn_loss}, test mse:: {tst_loss}, time :: {time_list[-1]}\n")
                file.write(f"Epoch: {epoch_}, train accuracy:: {trn_acc*100}, test accuracy:: {tst_acc*100}\n")


        # Persist learning curves and the final model for this seed
        with open(str(seed_)+"train_vars_mnist.pkl", 'wb') as f:  # Python 3: open(..., 'wb')
            pickle.dump([trn_acc_list, tst_acc_list, trn_loss_list, tst_loss_list, time_list], f)
        torch.save(model.state_dict(), str(seed_)+"mnistmodel.pth")