import os
import argparse
import torch
import torch.nn as nn
from tqdm import tqdm
from utils_cifar10 import get_dataset, get_network, get_tasks_dataset, TensorDataset
import copy
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


def _gd_step(model, optimizer, loader, criterion, device, track_loss=False):
    """Take one optimizer step using gradients accumulated over all batches of `loader`.

    Args:
        model: network being trained.
        optimizer: optimizer bound to `model`'s parameters.
        loader: iterable of batches; element 0 is the image tensor, element 1 the labels.
        criterion: loss callable applied to (output, labels).
        device: device the batch tensors are moved to.
        track_loss: when True, sum the per-batch loss values (costs one
            host sync per batch); when False nothing is read back.

    Returns:
        Sum of the per-batch loss values if `track_loss` is True, else 0.
    """
    optimizer.zero_grad()
    summed = 0
    for datum in loader:
        img = datum[0].float().to(device)
        lab = datum[1].long().to(device)
        loss = criterion(model(img), lab)
        if track_loss:
            summed += loss.detach().item()
        # Each iteration builds a fresh graph, so there is no need to retain it.
        loss.backward()
    optimizer.step()
    return summed


def _summed_loss(model, loader, criterion, device):
    """Return the sum of per-batch loss values of `model` over `loader` (no backward pass)."""
    total = 0
    for datum in loader:
        img = datum[0].float().to(device)
        lab = datum[1].long().to(device)
        total += criterion(model(img), lab).detach().item()
    return total


def _save_curves(curves, xlabel, ylabel, yscale, path):
    """Plot `curves` on a single 10x7 axes and save the figure to `path`.

    Args:
        curves: iterable of (values, label, color) triples; `color` may be
            None to use matplotlib's default colour cycle.
        xlabel, ylabel: axis labels.
        yscale: matplotlib y-scale name (e.g. 'log', 'linear').
        path: output file path.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    for values, label, color in curves:
        style = {} if color is None else {'c': color}
        ax.plot(values, linewidth=3, label=label, **style)
    ax.set_yscale(yscale)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


def main(args):
    """Train a jointly-trained reference model, then train a second model cyclically over tasks.

    Both networks start from the same weights. `joint_net` is trained for
    `args.JT` full-batch steps on the joint dataset. `net` is then trained for
    `args.J` cycles; in every cycle each of the `args.num_task` tasks is
    trained for `args.K` full-batch steps in turn. After each task, the loss
    on every task and the output distance to `joint_net` (plus the classifier
    parameter distance when `args.model == 'Linear'`) are recorded. Loss and
    distance curves are periodically saved as PDFs under
    `<result_path>/<framework>/<model>/<data_per_task>/<lr>/`.

    Args:
        args: namespace providing `dataset`, `data_path`, `result_path`,
            `framework`, `model`, `data_per_task`, `lr`, `num_task`,
            `num_classes`, `device`, `K`, `J` and `JT`.
    """
    channel, im_size, num_classes, mean, std, dst_train, dst_test, testloader = get_dataset(
        args.dataset, args.data_path, args=args)

    save_dir = os.path.join(args.result_path, args.framework, args.model, str(args.data_per_task), str(args.lr))
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)

    train_datasets, joint_datasets = get_tasks_dataset(
        dst_train, args.num_task, args.data_per_task, args.framework, args.num_classes)

    criterion = nn.CrossEntropyLoss().to(args.device)

    # One loader per task; batch size equals data_per_task (GD for CL).
    trainloader = []
    for i in range(args.num_task):
        sets = TensorDataset(copy.deepcopy(train_datasets[i][0]), copy.deepcopy(train_datasets[i][1]))
        trainloader.append(torch.utils.data.DataLoader(sets, batch_size=args.data_per_task))
    sets = TensorDataset(copy.deepcopy(joint_datasets[0]), copy.deepcopy(joint_datasets[1]))
    joint_trainloader = torch.utils.data.DataLoader(sets, batch_size=args.data_per_task * args.num_task)

    # Both networks start from identical (deep-copied) layers.
    net = get_network(args.model, channel, args.num_classes, im_size).to(args.device)
    a, b = net._get_layers()
    a, b = copy.deepcopy(a), copy.deepcopy(b)
    joint_net = get_network(args.model, channel, args.num_classes, im_size).to(args.device)
    joint_net._set_layers(a, b)
    net.train()
    joint_net.train()
    net_optimizer = torch.optim.SGD(net.parameters(), lr=args.lr)
    joint_net_optimizer = torch.optim.SGD(joint_net.parameters(), lr=args.lr)
    net_optimizer.zero_grad()
    joint_net_optimizer.zero_grad()

    losses = {x: [] for x in range(args.num_task)}
    joint_losses = []
    differences_output = []
    differences_parameter = []

    # ---- Phase 1: joint training (only joint_net receives gradients) ----
    for param in joint_net.parameters():
        param.requires_grad = True
    for param in net.parameters():
        param.requires_grad = False
    for epo in tqdm(range(int(args.JT))):
        joint_losses.append(
            _gd_step(joint_net, joint_net_optimizer, joint_trainloader, criterion, args.device, track_loss=True))

        if epo % 100 == 0 or epo == args.JT - 1:
            _save_curves(
                [(joint_losses, "joint task loss", None)],
                "Number of Iterations", "Loss", 'log',
                os.path.join(save_dir, f"K{args.K}J{args.J}JT{args.JT}_joint_loss.pdf"))

    # ---- Phase 2: cyclic per-task training (only net receives gradients) ----
    for param in joint_net.parameters():
        param.requires_grad = False
    for param in net.parameters():
        param.requires_grad = True

    for cycle in range(args.J):
        for i in range(args.num_task):
            for epo in tqdm(range(args.K)):
                _gd_step(net, net_optimizer, trainloader[i], criterion, args.device)

            # NOTE: net stays in train mode here; only gradient tracking is disabled.
            with torch.no_grad():
                for j in range(args.num_task):
                    losses[j].append(_summed_loss(net, trainloader[j], criterion, args.device))
                differences_output.append(output_compare(net, joint_net, testloader))
                if args.model == 'Linear':
                    differences_parameter.append(param_compare(net, joint_net))

        if cycle % 30 == 0 or cycle == args.J - 1:
            print(f"cycle {cycle} done")
            c = ['g', 'orange', 'b', 'pink', 'purple', 'black']
            # Wrap around the palette so more than len(c) tasks does not raise IndexError.
            _save_curves(
                [(losses[i], f"task{i} loss", c[i % len(c)]) for i in range(args.num_task)],
                "Number of cycles*M", "Loss", 'log',
                os.path.join(save_dir, f"K{args.K}J{args.J}JT{args.JT}_loss.pdf"))

            if args.model == 'Linear':
                _save_curves(
                    [(differences_parameter, "parameter difference", None)],
                    "Number of cycles*M", "l2-distance", 'linear',
                    os.path.join(save_dir, f"K{args.K}J{cycle}JT{args.JT}_parameters.pdf"))

            # Output differences are recorded for every model type, so the
            # curve is saved for every model type as well.
            _save_curves(
                [(differences_output, "output difference", None)],
                "Number of cycles*M", "l2-distance", 'linear',
                os.path.join(save_dir, f"K{args.K}J{cycle}JT{args.JT}_diff_output.pdf"))

            
def output_compare(model1, model2, testloader, device=None, num_samples=None):
    """Return the sample-normalised L2 distance between two models' outputs.

    The squared output differences are summed over every batch of
    `testloader`; the square root of that total is divided by the number of
    samples.

    Args:
        model1: first module, called on each image batch.
        model2: second module, called on the same batches.
        testloader: iterable of batches whose element 0 is the image tensor.
        device: device the images are moved to. Defaults to the device of
            `model1`'s first parameter (CPU if it has no parameters), so the
            function no longer depends on a module-level `args`.
        num_samples: divisor used for normalisation. Defaults to the number
            of samples actually iterated (size of dimension 0 of each image
            batch — assumes batch-first tensors).

    Returns:
        0-dim CPU tensor holding the distance.

    Raises:
        ValueError: if `num_samples` is not given and `testloader` yields no samples.
    """
    if device is None:
        first_param = next(model1.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Remember the current modes so they can be restored instead of forcing train mode.
    previous_modes = (model1.training, model2.training)
    model1.eval()
    model2.eval()
    try:
        with torch.no_grad():
            squared_total = torch.zeros((), device=device)
            seen = 0
            for datum in testloader:
                img = datum[0].float().to(device)
                squared_total = squared_total + (model1(img) - model2(img)).pow(2).sum()
                seen += img.shape[0]
            if num_samples is None:
                if seen == 0:
                    raise ValueError("output_compare: testloader yielded no samples")
                num_samples = seen
            dist = torch.sqrt(squared_total) / num_samples
    finally:
        model1.train(previous_modes[0])
        model2.train(previous_modes[1])
    return dist.cpu()

def param_compare(model1, model2):
    """Return the L2 distance between the classifier parameters of two models.

    The second element returned by each model's `_get_layers()` is treated as
    the classifier; the distance is taken over the concatenation of its
    'weight' and 'bias' state-dict entries. Both models are switched to eval
    mode for the comparison and to train mode afterwards.

    Returns:
        0-dim CPU tensor holding the distance.
    """
    models = (model1, model2)
    with torch.no_grad():
        for model in models:
            model.eval()

        head_a = model1._get_layers()[1]
        head_b = model2._get_layers()[1]
        state_a = head_a.state_dict()
        state_b = head_b.state_dict()

        weight_sq = torch.norm(state_a['weight'] - state_b['weight']) ** 2
        bias_sq = torch.norm(state_a['bias'] - state_b['bias']) ** 2
        diff = torch.sqrt(weight_sq + bias_sq)

        for model in models:
            model.train()
    return diff.cpu()
    
    
if __name__ == '__main__':
    # Command-line options as (flag, type, default, help) rows; every option
    # is optional and falls back to its default.
    option_specs = [
        ('--device', str, 'cpu', 'device: cpu / cuda:#'),
        ('--dataset', str, 'CIFAR10', 'dataset'),
        ('--data_path', str, 'data', 'path of dataset to be saved'),
        ('--result_path', str, './result', 'path of result to be saved'),
        ('--model', str, 'Linear', 'ConvNet / ShallowCNN / ReLU / Linear'),
        ('--framework', str, 'TIL_naive', 'TIL_naive / TIL_unif / CIL'),
        ('--num_classes', int, 2, 'number of classes'),
        ('--num_task', int, 3, 'M: 3 or 5'),
        ('--data_per_task', int, 512, 'number of data per task'),
        ('--lr', float, 0.01, 'learning rate'),
        ('--K', int, 1200, 'epoch for each task training'),
        ('--J', int, 500, 'total cycle'),
        ('--JT', int, 1200, 'epoch for joint task'),
    ]

    parser = argparse.ArgumentParser(description='Parameter Processing')
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    # Kept as the module-level name `args`: other functions in this file may read it.
    args = parser.parse_args()
    main(args)
    