import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import torch.nn.functional as F
import warnings
import logging
import math
import shutil
import copy
import argparse
import pickle
from utils import *

# Silence every warning category process-wide.
# NOTE(review): this also hides useful diagnostics (e.g. matplotlib's
# "too many open figures" warning); consider narrowing it to the specific
# warnings that are actually noisy.
warnings.simplefilter("ignore")

def get_args(argv=None):
    """Parse the experiment's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) reads ``sys.argv[1:]``,
            exactly as before; passing a list is handy for tests/notebooks.

    Returns:
        argparse.Namespace with attributes ``n_epochs``, ``batch_size``,
        ``logger_name`` (also used below as the RNG seed and run id),
        ``scale_factor``, ``init_dist``, ``weight_decay``, ``lr``,
        ``lr_drop_every`` and ``lr_drop_rate``.

    Raises:
        SystemExit: on invalid arguments (standard argparse behaviour), including
            an ``--init_dist`` other than "gaussian"/"uniform", the only two
            distributions the model's initialiser implements.
    """
    parser = argparse.ArgumentParser(description='Modify default values of the script.')
    parser.add_argument('--n_epochs', type=int, default=51, help='n_epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
    parser.add_argument('--logger_name', type=int, default=1, help='logger_name')
    parser.add_argument('--scale_factor', type=float, default=1, help='scale_factor')
    parser.add_argument('--init_dist', type=str, default="gaussian",
                        choices=("gaussian", "uniform"), help='init_dist')
    parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay')
    parser.add_argument('--lr', type=float, default=1e-4, help='lr')
    parser.add_argument('--lr_drop_every', type=float, default=1, help='lr_drop_every')
    parser.add_argument('--lr_drop_rate', type=float, default=1, help='lr_drop_rate')
    return parser.parse_args(argv)

# Parse the CLI once and expose each option as a module-level name; the rest of
# the script (model construction, loaders, training loop) reads these globals.
args = get_args()
(logger_name, scale_factor, init_dist, lr, n_epochs,
 batch_size, lr_drop_every, lr_drop_rate, weight_decay) = (
    args.logger_name,
    args.scale_factor,
    args.init_dist,
    args.lr,
    args.n_epochs,
    args.batch_size,
    args.lr_drop_every,
    args.lr_drop_rate,
    args.weight_decay,
)

#################################################################################################
# Loggers
task_name = "exp_mnist_v1_" + str(logger_name)  # change for each experiment
logger_folder = "loggers_mlp"
logger_folder_name = logger_folder + "/" + task_name
save_dir = logger_folder_name + "/experiment"

# Create loggers_mlp/<task_name> (and the parent folder) in one call; no-op if present.
os.makedirs(logger_folder_name, exist_ok=True)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s | | %(levelname)s | | %(message)s')

# 'w' mode: each run with the same logger_name overwrites the previous log file.
logger_file_name = os.path.join(logger_folder_name, "experiment")
file_handler = logging.FileHandler(logger_file_name, 'w')
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.info('Code started \n')
# Snapshot the script that is actually running so the run folder records the
# exact code used. Resolving via __file__ works from any CWD and survives a
# rename; fall back to the historical filename when __file__ is undefined
# (e.g. interactive sessions).
_script_path = globals().get("__file__", "mlp_mnist.py")
shutil.copyfile(_script_path, logger_folder_name + "/run.py")
#################################################################################################


class MNISTModel(nn.Module):
    """Bias-free 3-layer ReLU MLP: 784 -> 512 -> 512 -> 10.

    ``forward`` returns the output together with every layer's pre-activation
    and the post-ReLU hidden activations, so callers can inspect them.
    """

    def __init__(self, dist=None, scale=None):
        """Build the layers and apply the custom scaled-Kaiming initialisation.

        Args:
            dist: "gaussian" or "uniform". ``None`` (default) uses the
                module-level ``init_dist`` parsed from the command line.
            scale: Divisor of the weight variance, forwarded to
                ``initialize_weights`` as ``scale_factor``. ``None`` (default)
                uses the module-level ``scale_factor``.
        """
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 512, bias=False)
        self.fc2 = nn.Linear(512, 512, bias=False)
        self.fc3 = nn.Linear(512, 10, bias=False)
        self.nc = 10  # number of output units

        # Fall back to the CLI-derived globals so `MNISTModel()` behaves exactly
        # as before; explicit arguments make the class usable without them.
        if dist is None:
            dist = init_dist
        if scale is None:
            scale = scale_factor
        self.initialize_weights(dist=dist, scale_factor=scale)

    def initialize_weights(self, dist="gaussian", scale_factor=math.sqrt(1/6)):
        """Re-initialise every Linear/Conv2d weight with a scaled Kaiming scheme.

        Each weight gets std = sqrt(2 / (scale_factor * fan_in)). With
        ``scale_factor == 1`` this is standard Kaiming (He) init for ReLU;
        values above 1 shrink the std and values below 1 (such as this
        method's default sqrt(1/6)) enlarge it. Biases, if any, are zeroed.

        Args:
            dist: "gaussian" draws N(0, std^2); "uniform" draws
                U(-sqrt(3)*std, sqrt(3)*std), which has the same std.
            scale_factor: Positive divisor of the weight variance.

        Raises:
            ValueError: if ``dist`` is neither "gaussian" nor "uniform"
                (such values used to silently keep PyTorch's default init).
        """
        if dist not in ("gaussian", "uniform"):
            raise ValueError(
                f"Unknown init distribution {dist!r}; expected 'gaussian' or 'uniform'"
            )
        # Loop-invariant: ReLU gain sqrt(2), divided by sqrt(scale_factor).
        gain = math.sqrt(2.0) * math.sqrt(1 / scale_factor)
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            fan_in = nn.init._calculate_correct_fan(layer.weight, mode='fan_in')
            std = gain / math.sqrt(fan_in)
            with torch.no_grad():  # initialisation must not be tracked by autograd
                if dist == "uniform":
                    bound = math.sqrt(3.0) * std  # U(-b, b) has std b / sqrt(3)
                    layer.weight.uniform_(-bound, bound)
                else:
                    # Kept as randn_like(...) * std so the RNG stream (and hence
                    # seeded results) is identical to earlier runs.
                    layer.weight.data = torch.randn_like(layer.weight) * std
                if layer.bias is not None:
                    layer.bias.fill_(0)

    def forward(self, x):
        """Run the MLP.

        Args:
            x: Batch whose first dimension is the batch size; all remaining
                dimensions are flattened and must total 784 (e.g. (B, 1, 28, 28)).

        Returns:
            ``(out, preh_list, hidden_list)``:
            ``out`` is the (B, 10) output of fc3; ``preh_list`` holds the
            pre-activations of fc1, fc2 and fc3 (its last element is ``out``);
            ``hidden_list`` holds the ReLU outputs of fc1 and fc2.
        """
        x = x.reshape(x.size(0), -1)
        preh_list = []
        hidden_list = []

        for fc in (self.fc1, self.fc2):
            x = fc(x)
            preh_list.append(x)
            x = F.relu(x)
            hidden_list.append(x)

        x = self.fc3(x)
        preh_list.append(x)

        return x, preh_list, hidden_list
    
    
# Seed NumPy and torch with the run id so every `--logger_name` run is reproducible.
# (Order matters below: loader creation and model init both happen after seeding.)
np.random.seed(logger_name)
torch.manual_seed(logger_name)

trainloader, testloader = get_loaders(batch_size, "mnist")
num_classes = 10
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
loss_fn = nn.MSELoss()  # regression of the 10 outputs onto one-hot targets
model = MNISTModel().to(device).type(torch.float64)  # whole model in double precision

# Per-epoch metric histories; element 0 is the baseline measured before training.
trn_acc_list, tst_acc_list = [], []
trn_cov_list, tst_cov_list = [], []
trn_loss_list, tst_loss_list = [], []

print("Training loop")
logger.info("Training loop")

logger_write = 50  # progress is printed/logged every `logger_write` iterations
lri = lrist = lr   # lri is the current LR, decayed by the schedule in the epoch loop
tr_ar = 0          # running (exponential moving) average of training-batch accuracy
optimizer = optim.SGD(model.parameters(), lr=lri, weight_decay=weight_decay)

# Baseline evaluation of the untrained model (before the first epoch).
# NOTE(review): evaluateClassificationCorrelation comes from `utils`; its return is
# presumably (accuracy, correlations, loss) — confirm against utils.
model.eval()

trn_acc, trn_corr, trn_loss = evaluateClassificationCorrelation(model, trainloader, device, True, 2)
tst_acc, tst_corr, tst_loss = evaluateClassificationCorrelation(model, testloader, device, True, 2)
trn_acc_list.append(trn_acc)
trn_cov_list.append(trn_corr)
trn_loss_list.append(trn_loss)
tst_acc_list.append(tst_acc)
tst_cov_list.append(tst_corr)
tst_loss_list.append(tst_loss)

    
def _save_epoch_figure(epoch, trn_acc_list, tst_acc_list, trn_cov_list, tst_cov_list, out_dir):
    """Save a 1x3 figure of the metric histories to ``<out_dir>/epoch_<epoch>_.png``.

    Panels: train/test accuracy; mean absolute value of entries [0] and [1] of
    each train correlation record; the same for test. Assumes each correlation
    record is indexable with tensor entries at [0] and [1] (as produced by
    ``evaluateClassificationCorrelation`` — TODO confirm its exact type).
    The figure is closed afterwards so figures do not pile up across epochs.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 8))

    # Training and test accuracy
    ax1.plot(trn_acc_list, label="Training Accuracy")
    ax1.plot(tst_acc_list, label="Test Accuracy")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Accuracy")
    ax1.legend()
    ax1.set_title(f"Train:{trn_acc_list[-1]:.4f} Test:{tst_acc_list[-1]:.4f}")
    ax1.grid(True)

    # Average absolute correlation, one panel per split
    for ax, cov_list, split in ((ax2, trn_cov_list, "Train"), (ax3, tst_cov_list, "Test")):
        ax.plot(np.array([torch.abs(c[0]).mean() for c in cov_list]), label=f"{split} COR0 Value")
        ax.plot(np.array([torch.abs(c[1]).mean() for c in cov_list]), label=f"{split} COR1 Value")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(f"{split} Correlation Values")
        ax.legend()
        ax.set_title("Avg. Abs. Correlation")
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(out_dir + "/epoch_" + str(epoch) + "_.png")
    plt.close(fig)  # release the figure; otherwise one figure leaks per epoch


for epoch in range(n_epochs):
    # Step decay: multiply the LR by lr_drop_rate every lr_drop_every epochs.
    # NOTE(review): `epoch > 1` means epoch 1 never triggers a drop even when
    # lr_drop_every == 1 — confirm this is intended (`epoch > 0` may have been meant).
    if epoch > 1 and np.mod(epoch, lr_drop_every) == 0:
        lri = lri * lr_drop_rate
    model.train()
    epoch_time = time.time()

    mse = 0      # running sum of per-batch training MSE; logged as a running mean
    covloss = 0  # no covariance loss is computed in this script, so this stays 0 in the log
    optimizer.param_groups[0]['lr'] = lri
    lrv = optimizer.param_groups[0]['lr']
    print("epoch: ", epoch)
    n_batches = len(trainloader)
    for t, (inputs, labels) in enumerate(trainloader):
        # The final batch of each epoch is deliberately skipped (typically the partial one).
        if t >= n_batches - 1:
            break

        # Forward pass in float64 against float64 one-hot targets (MSE loss).
        inputs = inputs.to(device).type(torch.float64)
        y_pred, preh_list, hidden_list = model(inputs)
        y_one_hot = F.one_hot(labels, num_classes=num_classes).to(device=device, dtype=torch.float64)
        optimizer.zero_grad()
        loss = loss_fn(y_pred, y_one_hot)
        mse += loss.item()

        # EMA of batch accuracy, normalised by the actual batch length.
        pred = torch.argmax(y_pred, dim=1).squeeze()
        tr_ar = 0.99 * tr_ar + 0.01 * (labels == pred.to("cpu")).sum().item() / labels.size(0)

        if t % logger_write == 0:
            msg = f"Iter: {t:02} | Time {int(time.time()-epoch_time)}s | lr: {lrv:.7f} | mse: {mse/(t+1):.3f}, cov: {covloss/(t+1):.3f}, trar: {tr_ar:.3f}"
            print(msg)
            logger.info(msg)

        loss.backward()
        optimizer.step()

    # END OF THE EPOCH: evaluate on both splits and extend the histories.
    model.eval()

    trn_acc, trn_corr, trn_loss = evaluateClassificationCorrelation(model, trainloader, device, True, 2)
    tst_acc, tst_corr, tst_loss = evaluateClassificationCorrelation(model, testloader, device, True, 2)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)
    trn_cov_list.append(trn_corr)
    tst_cov_list.append(tst_corr)
    trn_loss_list.append(trn_loss)
    tst_loss_list.append(tst_loss)

    logger.info(f"time: {int(time.time()-epoch_time)} sec")
    # Train and test MSE are reported on the same (unscaled) scale.
    logger.info(f"Epoch: {epoch}, train mse:: {trn_loss}, test mse:: {tst_loss}")
    logger.info(f"Epoch: {epoch}, train accuracy:: {trn_acc*100}, test accuracy:: {tst_acc*100}")
    print("time: ", int(time.time()-epoch_time), " sec")
    print("Epoch %d: model accuracy %.2f%%" % (epoch, tst_acc*100))

    _save_epoch_figure(epoch, trn_acc_list, tst_acc_list, trn_cov_list, tst_cov_list, logger_folder_name)
    print("time end epoch: ", int(time.time()-epoch_time), " sec")
    torch.cuda.empty_cache()

# Persist the metric histories (a 6-element list in the order below) and the
# final model weights into the run folder.
history = [trn_acc_list, tst_acc_list, trn_cov_list, tst_cov_list, trn_loss_list, tst_loss_list]
metrics_path = logger_folder_name + "/train_vars.pkl"
with open(metrics_path, 'wb') as fh:
    pickle.dump(history, fh)

weights_path = logger_folder_name + "/mnistmodel.pth"
torch.save(model.state_dict(), weights_path)
