import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import torch.nn.functional as F
import warnings
import logging
import math
import shutil
import copy
import argparse
import pickle
from gradients import *
from utils import *


# Silence every Python warning for the whole run. Note that this also hides
# deprecation and numerical warnings raised by the libraries used below.
warnings.filterwarnings("ignore")

def get_args():
    """Parse the experiment's command-line options.

    Every option is optional; when a flag is omitted its default from the
    table below is used unchanged.

    Returns:
        argparse.Namespace with one attribute per option name.
    """
    # One row per option: (name, type used to convert the CLI string, default, help text).
    option_table = (
        ('n_epochs', int, 101, 'n_epochs'),
        ('batch_size', int, 16, 'batch_size'),
        ('logger_name', int, 1, 'logger_name'),
        ('scale_factor', float, 6, 'scale_factor'),
        ('init_dist', str, "gaussian", 'init_dist'),

        ('CMSE_OUT', float, 10, 'CMSE_OUT'),
        ('CMSE_OUT2', float, 0.1, 'CMSE_OUT2'),
        ('CMSE_HIDDEN', float, 0.1, 'CMSE_HIDDEN'),
        ('CCOV_OUT', float, 0, 'CCOV_OUT'),
        ('CCOV_OUT2', float, 1e-7, 'CCOV_OUT2'),
        ('CCOV_HIDDEN', float, 1e-9, 'CCOV_HIDDEN'),
        ('CL1_OUT', float, 1e-10, 'Layer Activation L_1 Loss'),
        ('CL1_HIDDEN', float, 1e-11, 'Layer Activation L_1 Loss'),
        ('weight_decay', float, 1e-8, 'weight_decay'),

        ('Reh_gain_lin', float, 0.01, 'Reh_gain_lin'),
        ('Reh_gain', float, 0.01, 'Reh_gain'),
        ('Reh_lambda', float, 0.99999, 'Reh lambda'),
        ('Reh_lambda2', float, 0.99999, 'Reh lambda Entropy'),
        ('Reh_lambda_drop', float, 0.02, 'Reh lambda drop rate'),
        ('Reh_lambda_drop_every', int, 1, 'Reh lambda drop every'),
        ('Reh_ini', float, 1e-8, 'scale'),
        ('lr', float, 1e-4, 'lr'),
        ('lr_drop_every', float, 1, 'lr_drop_every'),
        ('lr_drop_rate', float, 0.97, 'lr_drop_rate'),
        ('method', str, "ebd", 'training mtd'),
    )

    cli = argparse.ArgumentParser(description='Modify default values of the script.')
    for opt_name, opt_type, opt_default, opt_help in option_table:
        cli.add_argument('--' + opt_name, type=opt_type, default=opt_default, help=opt_help)
    return cli.parse_args()

# Copy the parsed command-line values into module-level names; the model
# class and the training loop below read these globals directly.
args = get_args()

# Run identity and training method.
logger_name = args.logger_name
method = args.method

# Schedule and optimizer settings.
n_epochs = args.n_epochs
batch_size = args.batch_size
lr = args.lr
lr_drop_every = args.lr_drop_every
lr_drop_rate = args.lr_drop_rate
weight_decay = args.weight_decay

# Weight initialisation.
init_dist = args.init_dist
scale_factor = args.scale_factor

# Coefficients that weight the MSE / covariance / L1 gradient terms.
CMSE_OUT = args.CMSE_OUT
CMSE_OUT2 = args.CMSE_OUT2
CMSE_HIDDEN = args.CMSE_HIDDEN
CCOV_OUT = args.CCOV_OUT
CCOV_OUT2 = args.CCOV_OUT2
CCOV_HIDDEN = args.CCOV_HIDDEN
CL1_OUT = args.CL1_OUT
CL1_HIDDEN = args.CL1_HIDDEN

# Settings forwarded to the layer-wise loss objects.
Reh_ini = args.Reh_ini
Reh_gain = args.Reh_gain
Reh_gain_lin = args.Reh_gain_lin
Reh_lambda = args.Reh_lambda
Reh_lambda2 = args.Reh_lambda2
Reh_lambda_drop = args.Reh_lambda_drop
Reh_lambda_drop_every = args.Reh_lambda_drop_every

######################################################################################
# Loggers
# Each run writes its log, figures, pickled histories and a copy of this
# script into loggers_conv/<task_name>/.
task_name = "exp_mnist_v1_"+str(logger_name)+"_"+str(method) # change for each experiment    
logger_folder = "loggers_conv"
logger_folder_name = logger_folder+"/"+task_name
save_dir = logger_folder_name+"/experiment"
#########################################

# makedirs creates the parent and the run folder in one call and tolerates
# folders that already exist, so there is no check-then-create race when
# several runs start at the same time.
os.makedirs(logger_folder_name, exist_ok=True)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s | | %(levelname)s | | %(message)s')

# Mode 'w' truncates the log file left by a previous run with the same task name.
logger_file_name = os.path.join(logger_folder_name, "experiment")
file_handler = logging.FileHandler(logger_file_name,'w')
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.info('Code started \n')
# Snapshot the script that is actually running. Resolving it through __file__
# (instead of a hard-coded relative name) keeps the snapshot correct when the
# script is renamed or launched from another working directory.
shutil.copyfile(os.path.abspath(__file__), logger_folder_name+"/run.py")
######################################################################################


class MNISTModelCNN(nn.Module):
    """CNN classifier for single-channel images that also exposes its intermediate activations.

    Layers: conv(1->P0) -> ReLU -> avg-pool(2) -> conv(P0->P1) -> ReLU ->
    avg-pool(2) -> flatten -> linear(b_linear_size->1024) -> ReLU ->
    linear(1024->10). No activation is applied to the final linear output.

    Construction reads the module-level ``init_dist`` and ``scale_factor``;
    ``return_nalloss`` additionally reads the module-level ``Reh_lambda``,
    ``Reh_lambda2``, ``Reh_ini``, ``Reh_gain`` and ``batch_size``.
    """

    def __init__(self):
        # Initialise nn.Module first so that every attribute assigned below
        # goes through a fully set-up module.
        super(MNISTModelCNN, self).__init__()

        self.kernel_size = 3
        self.pad=1
        self.P0=64   # output channels of conv1
        self.P1=32   # output channels of conv2
        # Flattened feature size fed to fc3: P1 * 7 * 7 for the 28x28 inputs
        # that return_nalloss assumes (two 2x2 poolings: 28 -> 14 -> 7).
        self.b_linear_size = 1568

        self.conv1 = nn.Conv2d(1, self.P0, kernel_size=(self.kernel_size,self.kernel_size), stride=1, padding=self.pad)
        self.act1 = nn.ReLU()
        self.pool1 = nn.AvgPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(self.P0, self.P1, kernel_size=(self.kernel_size,self.kernel_size), stride=1, padding=self.pad)
        self.act2 = nn.ReLU()
        self.pool2 = nn.AvgPool2d(kernel_size=(2, 2))

        self.flat = nn.Flatten()
        self.fc3 = nn.Linear(self.b_linear_size, 1024)
        self.act4 = nn.ReLU()
        self.fc4 = nn.Linear(1024, 10)
        self.act5 = nn.ReLU()  # not applied in forward(); kept as an attribute
        self.nc = 10

        # Custom Kaiming initialization with smaller std
        self.initialize_weights(dist=init_dist, scale_factor=scale_factor)

    def initialize_weights(self, dist="gaussian", scale_factor=math.sqrt(1/6)):
        """Re-initialise every Conv2d and Linear layer with a scaled Kaiming-style scheme.

        Each weight tensor gets standard deviation
        ``sqrt(2 / scale_factor) / sqrt(fan_in)``; biases are set to zero.

        Args:
            dist: "gaussian" (normal samples) or "uniform" (samples from
                ``[-sqrt(3) * std, sqrt(3) * std]``).
            scale_factor: divisor applied to the Kaiming variance; larger
                values give smaller initial weights.

        Raises:
            ValueError: if ``dist`` is not one of the supported names. The
                check happens before any layer is modified, so a typo cannot
                silently leave the layers with their default initialisation.
        """
        if dist not in ("gaussian", "uniform"):
            raise ValueError(
                f"Unsupported init distribution {dist!r}; expected 'gaussian' or 'uniform'."
            )

        gain = math.sqrt(2.0) * math.sqrt(1/scale_factor)  # scaling to the gain
        with torch.no_grad():  # Ensure no gradients are tracked
            for layer in self.modules():
                if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                    continue
                fan_in = nn.init._calculate_correct_fan(layer.weight, mode='fan_in')
                std = gain / math.sqrt(fan_in)  # Standard deviation
                if dist == "uniform":
                    bound = math.sqrt(3.0) * std  # Calculate the bound for uniform
                    layer.weight.uniform_(-bound, bound)
                else:
                    # Same random draw as before (randn scaled by std), written
                    # in place rather than by re-assigning ``.data``.
                    layer.weight.copy_(torch.randn_like(layer.weight) * std)
                if layer.bias is not None:
                    layer.bias.fill_(0)

    def forward(self, x):
        """Run the network and collect the per-layer signals used by the layer-wise losses.

        Args:
            x: input batch passed straight to ``conv1``.

        Returns:
            A tuple ``(output, preh_list, hidden_list)`` where
            ``output`` is the raw output of ``fc4``;
            ``preh_list`` holds the pre-activation outputs of conv1, conv2,
            fc3 and fc4, in that order;
            ``hidden_list`` holds ReLU(conv1), ReLU(conv2), the flattened
            pooled features fed to fc3, and ReLU(fc3), in that order.
        """
        preh_list=[]
        hidden_list=[]

        # Convolutional stage 1
        conv1_pre = self.conv1(x)
        preh_list.append(conv1_pre)
        conv1_act = self.act1(conv1_pre)
        hidden_list.append(conv1_act)
        conv1_pooled = self.pool1(conv1_act)

        # Convolutional stage 2
        conv2_pre = self.conv2(conv1_pooled)
        preh_list.append(conv2_pre)
        conv2_act = self.act2(conv2_pre)
        hidden_list.append(conv2_act)
        conv2_pooled = self.pool2(conv2_act)

        # The flattened features are recorded as a "hidden" signal because
        # they are the input of the first fully connected layer.
        flat_features = self.flat(conv2_pooled)
        hidden_list.append(flat_features)

        fc3_pre = self.fc3(flat_features)
        preh_list.append(fc3_pre)
        fc3_act = self.act4(fc3_pre)
        hidden_list.append(fc3_act)

        # Output layer: returned without an activation.
        logits = self.fc4(fc3_act)
        preh_list.append(logits)

        return logits,preh_list,hidden_list

    # initialize layer losses
    def return_nalloss(self, method):
        """Build one layer-wise loss object per trainable layer.

        Args:
            method: training-method name, forwarded unchanged to every
                argument struct.

        Returns:
            A list of four loss objects in layer order: two
            ``LossConvolutive`` (conv1, conv2) followed by two
            ``LossFullyConnected`` (fc3, fc4).
        """
        nalloss=[]
        # Fixed Parameters
        num_classes=10

        # conv1 operates on 28x28 maps.
        # NOTE(review): the literal 3 sits in the position where the conv2
        # call below passes its input-channel count (self.P0), yet conv1 is
        # built with 1 input channel — confirm against nmseargsstructCNN_LC.
        inp_layerdim=[28,28]
        hidden_size=[28,28]
        argsn1=nmseargsstructCNN_LC(hidden_size,num_classes,inp_layerdim,Reh_lambda,Reh_ini,Reh_ini,3,self.P0,self.kernel_size,batch_size,1,Reh_gain, method)
        nalloss.append(LossConvolutive(argsn1))

        # conv2 operates on the 14x14 maps left after the first pooling.
        inp_layerdim=[14,14]
        hidden_size=[14,14]
        argsn1=nmseargsstructCNN_LC(hidden_size,num_classes,inp_layerdim,Reh_lambda,Reh_ini,Reh_ini,self.P0,self.P1,self.kernel_size,batch_size,1,Reh_gain, method)
        nalloss.append(LossConvolutive(argsn1))

        # fc3: b_linear_size inputs -> 1024 units.
        argsn2=nmseargsstruct(1024,num_classes,self.b_linear_size,Reh_lambda,Reh_ini,Reh_ini,Reh_gain,Reh_lambda2, method)
        nalloss.append(LossFullyConnected(argsn2))

        # fc4: 1024 inputs -> 10 outputs.
        argsn2=nmseargsstruct(10,num_classes,1024,Reh_lambda,Reh_ini,Reh_ini,Reh_gain,Reh_lambda2, method)
        nalloss.append(LossFullyConnected(argsn2))
        return nalloss
    
# Seed NumPy and PyTorch with the run id before anything random is created.
np.random.seed(logger_name)
torch.manual_seed(logger_name)

# Data, device, model and the per-layer loss objects (created in this order).
trainloader, testloader = get_loaders(batch_size, "mnist")
num_classes = 10
out_dim = 10
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
loss_fn = nn.MSELoss()
model = MNISTModelCNN().to(device).type(torch.float64)
nalloss = model.return_nalloss(method)

# Histories that get one entry per epoch.
trn_acc_list, tst_acc_list = [], []
trn_loss_list, tst_loss_list = [], []
time_list = []
nmse3_list, nmse2_list, nmse1_list, nmse0_list = [], [], [], []
nmse_list = [[] for _ in range(4)]   # one series per layer
angs = [[] for _ in range(4)]        # one series per layer

print("Training loop")
logger.info("Training loop")

logger_write = 150   # a progress line is emitted every `logger_write` iterations
lrist = lr
lri = lrist          # current learning rate, decayed by the epoch loop
tr_ar = 0            # running (exponentially averaged) training accuracy
optimizer = optim.Adam(
    model.parameters(),
    lr=lri,
    betas=(0.9, 0.999),
    weight_decay=weight_decay,
)

# Fail fast on an unknown training method. Without this check the loop would
# run a whole epoch without computing any gradient and only then crash with a
# NameError in the end-of-epoch bookkeeping.
if method not in ("bp", "ebd", "dfa1", "dfa2"):
    raise ValueError(f"Unknown training method {method!r}; expected 'bp', 'ebd', 'dfa1' or 'dfa2'.")

for epoch in range(n_epochs):    
    # Multiplicative learning-rate decay, applied from epoch 2 onwards.
    if epoch>1 and np.mod(epoch, lr_drop_every)==0:
        lri=lri*lr_drop_rate
    model.train()
    epoch_time = time.time()

    # Move each loss object's la_R (and la_R2 for the fully connected layers)
    # a fraction Reh_lambda_drop of the remaining distance towards 1.
    if epoch>1 and np.mod(epoch, Reh_lambda_drop_every)==0:
        for n in nalloss:
            n.la_R = n.la_R + Reh_lambda_drop*(1-n.la_R)
        for n in nalloss[2:]:
            n.la_R2 = n.la_R2 + Reh_lambda_drop*(1-n.la_R2)

    t = 0
    mse = 0
    covloss = 0
    optimizer.param_groups[0]['lr'] = lri
    lrv=optimizer.param_groups[0]['lr']
    print("epoch: ", epoch) 
    for inputs, labels in (trainloader):
        # The final batch of every epoch is skipped.
        # NOTE(review): presumably because it can be smaller than batch_size,
        # which the loss objects were constructed with — confirm.
        if(t>=len(trainloader)-1):
            break

        # forward, backward, and then weight update
        inputs = inputs.to(device).type(torch.float64)
        y_pred,preh_list,hidden_list = model(inputs)
        # One-hot target, converted and moved to the device once.
        y_one_hot = F.one_hot(labels, num_classes=10).type(torch.float64).to(device)
        optimizer.zero_grad()
        err = y_pred - y_one_hot
        loss = loss_fn(y_pred, y_one_hot)
        pred = torch.argmax(y_pred, dim=1).squeeze()        
        # Exponential moving average of the per-batch training accuracy.
        tr_ar = 0.99*tr_ar + 0.01*(labels == pred.to("cpu")).sum().item()/(batch_size)

        if(t%logger_write == 0):
            # mse / covloss are only accumulated by the "ebd" branch below,
            # so they are reported as 0 for the other methods.
            progress_msg = f"Iter: {t:02} | Time {int(time.time()-epoch_time)}s | lr: {lrv:.7f} | mse: {mse/(t+1):.3f}, cov: {covloss/(t+1):.3f}, trar: {tr_ar:.3f}"
            print(progress_msg)
            logger.info(progress_msg)
        t=t+1
        # Collect current network weights and parameters
        # Order: conv1.weight, conv1.bias, conv2.weight, conv2.bias,
        # fc3.weight, fc3.bias, fc4.weight, fc4.bias.
        param_list = list(model.parameters())
        
        if(method == "bp"):
            loss.backward()
            # Placeholders so the end-of-epoch bookkeeping and plots still run.
            mse_lin, mse_lin2, mse0, mse1 = 0,0,0,0
            angc0, angc1, angl2, angl1 = torch.tensor(0),torch.tensor(0),torch.tensor(0),torch.tensor(0)
            RW0, RW1, Ryk_lin, Ryk_lin2 = torch.zeros((1,1,1)),torch.zeros((1,1,1)),torch.zeros((1,1,1)),torch.zeros((1,1,1))
            
        elif(method == "ebd"):  
            with torch.no_grad():
                # Calculate Losses and Gradients for the output layer using the loss object
                mse_lin, covloss_lin, Reyk_lin, Ryk_lin, dWmse_lin, _, dWcov_lin, _, dWl1out_lin, _,angl1 = nalloss[-1](y_pred,preh_list[-1],err,hidden_list[-1],2.0)
                mse_lin2, covloss_lin2, Reyk_lin2, Ryk_lin2, dWmse_lin2, _, dWcov_lin2, _, dWl1out_lin2, _,angl2 = nalloss[-2](hidden_list[-1],preh_list[-2],err,hidden_list[-2],1.0)
                mse0,covloss0,dW0,db0,dW_cov0,db_cov0,dWL10,dbL10,RW0,angc0 = nalloss[0](param_list[0],hidden_list[0],preh_list[0],err,inputs,None)
                mse1,covloss1,dW1,db1,dW_cov1,db_cov1,dWL11,dbL11,RW1,angc1 = nalloss[1](param_list[2],hidden_list[1],preh_list[1],err,F.avg_pool2d(hidden_list[0],2),None)
                mse = mse + mse1.item() + mse0.item() + mse_lin.item() + mse_lin2.item()
                covloss = covloss + covloss1.item() + covloss0.item() + covloss_lin.item() + covloss_lin2.item()

                #update synaptic weights based on loss gradient
                # Only weight gradients are assigned; bias gradients stay unset.
                param_list[0].grad = CMSE_HIDDEN*dW0.permute(1,0,2,3) + CCOV_HIDDEN*dW_cov0 + CL1_HIDDEN*dWL10.permute(1,0,2,3)
                param_list[2].grad = CMSE_HIDDEN*dW1.permute(1,0,2,3) + CCOV_HIDDEN*dW_cov1 + CL1_HIDDEN*dWL11.permute(1,0,2,3)
                param_list[-4].grad = CMSE_OUT2*(dWmse_lin2) + CCOV_OUT2*dWcov_lin2 + CL1_OUT*dWl1out_lin2
                param_list[-2].grad = CMSE_OUT*(dWmse_lin) + CCOV_OUT*dWcov_lin
                
        elif(method == "dfa1"):
            with torch.no_grad():
                # Calculate Losses and Gradients for the output layer using the loss object
                mse_lin, _, Reyk_lin, Ryk_lin, gradWmse_lin, _, _, _, _, _,angl1 = nalloss[-1](y_pred,preh_list[-1],err,hidden_list[-1],2.0)
                mse_lin2, _, Reyk_lin2, Ryk_lin2, gradWmse_lin2, _, _, _, _, _,angl2 = nalloss[-2](hidden_list[-1],preh_list[-2],err,hidden_list[-2],1.0)
                mse0,_,dW0,db0,_,_,_,_,RW0,angc0 = nalloss[0](param_list[0],hidden_list[0],preh_list[0],err,inputs,None)
                mse1,_,dW1,db1,_,_,_,_,RW1,angc1 = nalloss[1](param_list[2],hidden_list[1],preh_list[1],err,F.avg_pool2d(hidden_list[0],2),None)

                #update synaptic weights based on loss gradient
                param_list[0].grad = CMSE_HIDDEN*dW0.permute(1,0,2,3)
                param_list[2].grad = CMSE_HIDDEN*dW1.permute(1,0,2,3)
                param_list[-4].grad = CMSE_OUT2*gradWmse_lin2
                param_list[-2].grad = CMSE_OUT*gradWmse_lin
        
        elif(method == "dfa2"):
            with torch.no_grad():
                # Calculate Losses and Gradients for the output layer using the loss object
                mse_lin, _, Reyk_lin, Ryk_lin, gradWmse_lin, _, dWcov_lin, _, _, _,angl1 = nalloss[-1](y_pred,preh_list[-1],err,hidden_list[-1],2.0)
                mse_lin2, _, Reyk_lin2, Ryk_lin2, gradWmse_lin2, _, dWcov_lin2, _, _, _,angl2 = nalloss[-2](hidden_list[-1],preh_list[-2],err,hidden_list[-2],1.0)
                mse0,_,dW0,db0,dW_cov0,_,_,_,RW0,angc0 = nalloss[0](param_list[0],hidden_list[0],preh_list[0],err,inputs,None)
                mse1,_,dW1,db1,dW_cov1,_,_,_,RW1,angc1 = nalloss[1](param_list[2],hidden_list[1],preh_list[1],err,F.avg_pool2d(hidden_list[0],2),None)

                #update synaptic weights based on loss gradient
                param_list[0].grad = CMSE_HIDDEN*dW0.permute(1,0,2,3) + CCOV_HIDDEN*dW_cov0
                param_list[2].grad = CMSE_HIDDEN*dW1.permute(1,0,2,3) + CCOV_HIDDEN*dW_cov1
                param_list[-4].grad = CMSE_OUT2*gradWmse_lin2 + CCOV_OUT2*dWcov_lin2
                param_list[-2].grad = CMSE_OUT*gradWmse_lin + CCOV_OUT*dWcov_lin
                
        optimizer.step() 

    time_list.append(time.time()-epoch_time)
    logger.info(f"training time: {int(time.time()-epoch_time)} sec")
    # END OF THE EPOCH
    # The per-layer statistics below come from the last processed batch only.
    nmse_list[0].append(mse0**2/batch_size)
    nmse_list[1].append(mse1**2/batch_size)
    nmse_list[2].append(mse_lin2**2/batch_size)
    nmse_list[3].append(mse_lin**2/batch_size)

    angs[0].append(angc0.cpu().item())
    angs[1].append(angc1.cpu().item())
    angs[2].append(angl2.cpu().item())
    angs[3].append(angl1.cpu().item())

    trn_acc, trn_loss = evaluateClassification(model, trainloader, device, True)
    tst_acc, tst_loss = evaluateClassification(model, testloader, device, True)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)
    trn_loss_list.append(trn_loss)
    tst_loss_list.append(tst_loss)

    logger.info(f"time: {int(time.time()-epoch_time)} sec")
    logger.info(f"Epoch: {epoch}, train mse:: {trn_loss}, test mse:: {tst_loss}")
    logger.info(f"Epoch: {epoch}, train accuracy:: {trn_acc*100}, test accuracy:: {tst_acc*100}")
    print("time: ", int(time.time()-epoch_time), " sec")
    print("Epoch %d: model accuracy %.2f%%" % (epoch, tst_acc*100))

    # PLOT FIGURES
    # A fresh 2x2 figure is drawn and saved every epoch.
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    # Unpack the axes array to individual axes
    ax1, ax2 = axes[0]
    ax3, ax4 = axes[1]

    # Plot for Training and Test Accuracy
    ax1.plot(trn_acc_list, label="Training Accuracy")
    ax1.plot(tst_acc_list, label="Test Accuracy")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Accuracy")
    ax1.legend()
    ftrain = "{:.4f}".format(trn_acc_list[-1])
    ftest = "{:.4f}".format(tst_acc_list[-1])
    ax1.set_title(f"Train:{ftrain} Test:{ftest}")
    ax1.grid(True)

    # Layer-wise NMSE curves. Non-tensor entries (the integer placeholders
    # stored by the "bp" method) are filtered out.
    nmse0_list_np = [item.cpu().numpy() for item in nmse_list[0] if torch.is_tensor(item)]
    nmse1_list_np = [item.cpu().numpy() for item in nmse_list[1] if torch.is_tensor(item)]
    nmse2_list_np = [item.cpu().numpy() for item in nmse_list[2] if torch.is_tensor(item)]
    nmse3_list_np = [item.cpu().numpy() for item in nmse_list[3] if torch.is_tensor(item)]

    ax2.plot(nmse0_list_np, label="NMSE0 Value")
    ax2.plot(nmse1_list_np, label="NMSE1 Value")
    ax2.plot(nmse2_list_np, label="NMSE2 Value")
    ax2.plot(nmse3_list_np, label="NMSE3 Value")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("NMSE Valus")
    ax2.legend()
    ax2.set_title(f"Layerwise NMSE Losses")
    ax2.grid(True)
    
    # Sorted eigenvalues of the matrix returned by each layer's loss object.
    # Plot order must follow the legend: Ryk_lin2 comes from nalloss[-2]
    # (layer 2) and Ryk_lin from nalloss[-1] (layer 3), matching the layer
    # order already used for nmse_list and angs.
    ax3.semilogy(np.sort(np.linalg.eig(RW0.cpu().numpy())[0]))
    ax3.semilogy(np.sort(np.linalg.eig(RW1.cpu().numpy())[0]))
    ax3.semilogy(np.sort(np.linalg.eig(Ryk_lin2.cpu().numpy())[0]))
    ax3.semilogy(np.sort(np.linalg.eig(Ryk_lin.cpu().numpy())[0]))
    ax3.legend(['Layer 0', 'Layer 1', 'Layer 2', 'Layer 3'])
    ax3.set_xlabel("Output Eigenvalue Index")
    ax3.set_title("1st Hidden Cov. Eigenvalue")
    ax3.grid(True)

    ax4.set_xlabel("Epoch")
    ax4.set_ylabel("Angle Value")
    ax4.plot(angs[0])
    ax4.plot(angs[1])
    ax4.plot(angs[2])
    ax4.plot(angs[3])
    ax4.legend(['Layer 0', 'Layer 1', 'Layer 2', 'Layer 3']) #
    ax4.set_title("Reh Cosine Angles")

    plt.tight_layout()
    plt.savefig(logger_folder_name+"/epoch_"+str(epoch)+"_.png")
    # Release the figure: pyplot keeps every figure alive until it is closed,
    # so one new figure per epoch would otherwise accumulate over the run.
    plt.close(fig)
    print("time end epoch: ", int(time.time()-epoch_time), " sec")
    torch.cuda.empty_cache()

# Persist the per-epoch histories and the final model weights in the run folder.
history = [trn_acc_list, tst_acc_list, trn_loss_list, tst_loss_list, nmse_list, angs, time_list]
with open(f"{logger_folder_name}/train_vars.pkl", 'wb') as history_file:
    pickle.dump(history, history_file)

weights_path = f"{logger_folder_name}/mnist10model.pth"
torch.save(model.state_dict(), weights_path)