import os
import sys
sys.path.append(os.getcwd())
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import random
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import time
import models_CNN
import CL
import utils_CNN
import numpy as np

# cuDNN configuration: `benchmark` turns on cuDNN's convolution auto-tuner and
# `deterministic` restricts cuDNN to deterministic algorithms.
# NOTE(review): PyTorch's reproducibility notes suggest benchmark=False when
# run-to-run reproducibility matters -- confirm that enabling both is intended.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
import matplotlib
# Font type 42 makes matplotlib embed TrueType fonts in PDF/PS output
# (instead of the default Type 3 fonts), which keeps text editable.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42


def test_previous_tasks(args, model, criterion, device, current_task_idx, task_labels, test_dataset):
    """Evaluate the model on the test set of every task seen so far.

    Tasks ``0 .. current_task_idx`` (inclusive) are evaluated one after the
    other with ``evaluate``; the confusion matrix of each task is printed.

    Args:
        args: Parsed command-line arguments; ``args.batch_size`` is used to
            build the test loaders.
        model: Network to evaluate.
        criterion: Loss passed through to ``evaluate``.
        device: Device passed through to ``evaluate``.
        current_task_idx: Index of the most recently learned task.
        task_labels: Per-task lists of class ids.
        test_dataset: Indexable collection holding one test set per task.

    Returns:
        Tuple ``(tasks_acc, average_over_tasks, task_acc_TIL,
        average_over_tasks_TIL)``: per-task accuracies and their mean, first
        without and then with the task-restricted (TIL) prediction.
    """
    tasks_acc = []
    task_acc_TIL = []
    for task_id in range(current_task_idx + 1):
        print(f"Task {task_id}")
        loader = utils_CNN.get_task_load_test(test_dataset[task_id], args.batch_size)
        acc, acc_til, task_cnf_matrix = evaluate(
            args, task_id, task_labels, model, criterion, device, loader, is_test_set=True)
        print(task_cnf_matrix)
        tasks_acc.append(acc)
        task_acc_TIL.append(acc_til)

    # One accuracy entry was stored per evaluated task.
    cnt = len(tasks_acc)
    average_over_tasks = sum(tasks_acc) / cnt
    average_over_tasks_TIL = sum(task_acc_TIL) / cnt
    print(f"average acc over {cnt} tasks = {average_over_tasks}")
    print(f"average acc TIL over {cnt} tasks = {average_over_tasks_TIL}")
    return tasks_acc, average_over_tasks, task_acc_TIL, average_over_tasks_TIL

def evaluate(args,current_task, task_labels, model, criterion, device, test_loader, is_test_set=False):
    """Evaluate ``model`` on one task, with and without task identity.

    Two predictions are computed for every sample:

    * the plain argmax over all outputs, and
    * the task-incremental (TIL) argmax, restricted to the output units
      listed in ``task_labels[current_task]``.

    Both confusion matrices and both accuracies are printed.

    Args:
        args: Unused here; kept so existing callers keep working.
        current_task: Index into ``task_labels`` selecting the evaluated task.
        task_labels: Per-task lists of output-unit indices.
        model: Callable returning a 6-tuple whose first element holds the
            class scores with classes along dimension 1.
        criterion: Loss called as ``criterion(output, target)``.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches.
        is_test_set: Only selects the label used in the printed summary.

    Returns:
        Tuple ``(accuracy, accuracy_TIL, cnf_matrix)`` where ``cnf_matrix`` is
        the confusion matrix of the unrestricted prediction.
    """
    model.eval()
    test_loss = 0
    correct = 0
    correct_TIL = 0
    n = 0
    y_true = []
    y_pred = []
    y_pred_TIL = []
    # A mean-reduced criterion returns a per-batch average; it has to be
    # weighted by the batch size so that dividing by `n` below really yields
    # the per-sample average. Criteria that do not report `reduction == 'mean'`
    # are accumulated unchanged.
    mean_reduced = getattr(criterion, 'reduction', None) == 'mean'
    task_mask = None
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device,dtype=torch.int64)
            output,_ ,_,_,_,_ = model(data)
            batch_size = target.shape[0]

            batch_loss = criterion(output, target).item()
            test_loss += batch_loss * batch_size if mean_reduced else batch_loss

            pred = output.argmax(dim=1, keepdim=True) # index of the highest score
            y_true += target.tolist()
            y_pred += pred.view_as(target).tolist()
            correct += pred.eq(target.view_as(pred)).sum().item()
            n += batch_size

            # Task-incremental prediction: exclude every unit outside the
            # current task by setting its score to -inf. Multiplying by a 0/1
            # mask would leave those units at 0, which beats any negative
            # in-task score and yields an out-of-task prediction.
            if task_mask is None:
                task_mask = torch.zeros(output.shape[1], dtype=torch.bool, device=output.device)
                task_mask[task_labels[current_task]] = True
            til_output = output.masked_fill(~task_mask, float('-inf'))
            pred_TIL = til_output.argmax(dim=1, keepdim=True)
            correct_TIL += pred_TIL.eq(target.view_as(pred_TIL)).sum().item()
            y_pred_TIL += pred_TIL.view_as(target).tolist()

    cnf_matrix = confusion_matrix(y_true, y_pred)
    cnf_matrix_TIL = confusion_matrix(y_true, y_pred_TIL)
    print(cnf_matrix)
    print(cnf_matrix_TIL)
    test_loss /= float(n)
    print('\n{}: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        'Test evaluation' if is_test_set else 'Evaluation',
        test_loss, correct, n, 100. * correct / float(n)))
    print('\n{}: Average loss: {:.4f}, Accuracy ITL: {}/{} ({:.3f}%)\n'.format(
        'Test evaluation' if is_test_set else 'Evaluation',
        test_loss, correct_TIL, n, 100. * correct_TIL / float(n)))
    sys.stdout.flush()
    return correct / float(n), correct_TIL/float(n), cnf_matrix

def get_layer_reprsentation(CL_obj, model, device, test_loader, current_task):
    """Compute the per-class mean activation of five intermediate layers.

    For every class in ``CL_obj.task_labels[current_task]`` the activations
    returned by the model are summed over the samples of that class and
    divided by the number of such samples. The first four feature tensors are
    additionally summed over their two trailing dimensions, giving one value
    per channel; the fifth is used as returned.

    The accumulator widths are taken from dimension 1 of the tensors produced
    for the first batch, so the function is not tied to one layer
    configuration. If the loader yields no batch, the widths
    (32, 32, 64, 64, 512) are used.

    Args:
        CL_obj: Object exposing ``task_labels`` (per-task lists of class ids).
        model: Callable returning ``(output, f1, f2, f3, f4, l1)``.
        device: Device on which batches and accumulators live.
        test_loader: Iterable of ``(data, target)`` batches.
        current_task: Index into ``CL_obj.task_labels``.

    Returns:
        Tuple of five tensors, each of shape ``(num_classes, width)``. Rows of
        classes that never occur in the loader are NaN (0 / 0).
    """
    model.eval()
    class_ids = CL_obj.task_labels[current_task]
    num_classes_per_tasks = len(class_ids)
    fallback_widths = (32, 32, 64, 64, 512)
    sums = None
    total_samples_per_class = torch.zeros(num_classes_per_tasks, device=device)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device,dtype=torch.int64)
            _, f1, f2, f3, f4, l1 = model(data)
            features = (f1, f2, f3, f4, l1)
            if sums is None:
                # Size the accumulators from what the model actually produces.
                sums = [torch.zeros((num_classes_per_tasks, feat.shape[1]), device=device)
                        for feat in features]
            for i, class_id in enumerate(class_ids):
                selected = target == class_id
                samples_count = selected.sum()
                if samples_count > 0:
                    # Sum over the selected samples, then over the two
                    # trailing dimensions, leaving one value per channel.
                    for acc, feat in zip(sums[:4], features[:4]):
                        acc[i] += feat[selected].sum(dim=0).sum(dim=(1, 2))
                    sums[4][i] += l1[selected].sum(dim=0)
                    total_samples_per_class[i] += samples_count
        if sums is None:
            sums = [torch.zeros((num_classes_per_tasks, width), device=device)
                    for width in fallback_widths]
        # Turn the per-class sums into per-class means (row-wise division).
        per_class_count = total_samples_per_class.unsqueeze(1)
        for acc in sums:
            acc /= per_class_count
    return tuple(sums)

def to_categorical(y, num_classes):
    """One-hot encode a tensor of class indices.

    Each index in ``y`` selects the matching row of a ``num_classes`` x
    ``num_classes`` identity matrix, so the result is a ``uint8`` NumPy array
    with one extra trailing axis of length ``num_classes``.
    """
    one_hot_rows = np.eye(num_classes, dtype='uint8')
    indices = y.cpu()  # indexing below is done on the host
    return one_hot_rows[indices]

# Evaluate automatic detection of class relations using the approach of You et al. (2020), "Co-Tuning for Transfer Learning".
def class_relationship(CL_obj, model, val_loader, device):
    """Estimate how the classes of the next task map onto already-known classes.

    Samples of the upcoming task (``CL_obj.current_task + 1``) are passed
    through the current model and the prediction is restricted to the first
    ``num_prev_classes`` output units, where ``num_prev_classes`` is
    ``len(next task labels) * (index of the next task)``. For every new class
    the function reports the fraction of its samples assigned to each of those
    units.

    Targets and predictions are flattened before counting, so batched and
    un-batched targets are handled by the same code path and every sample is
    counted only for the class it actually belongs to.

    Args:
        CL_obj: Object exposing ``current_task`` and ``task_labels``.
        model: Callable returning a 6-tuple whose first element holds the
            class scores with classes along dimension 1.
        val_loader: Iterable of ``(data, target)`` batches of the next task.
        device: Device the batches are moved to.

    Returns:
        ``numpy.ndarray`` of shape ``(num_new_classes, num_prev_classes)``;
        row ``i`` is the normalized prediction histogram of new class ``i``.
        Rows of classes without any sample are NaN (0 / 0).
    """
    model.eval()
    current_task = CL_obj.current_task+1
    current_task_labels = CL_obj.task_labels[current_task]
    num_classes_per_tasks = len(current_task_labels)
    num_prev_classes = num_classes_per_tasks*current_task
    each_target_class_source_class = np.zeros((num_classes_per_tasks, num_prev_classes))
    num_samples_per_class = np.zeros((num_classes_per_tasks,))
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output,_,_,_,_,_ = model(data)
            pred = output[:,0:num_prev_classes].argmax(dim=1)
            # Work on flat host arrays so boolean masks and counts line up
            # one entry per sample regardless of the target's original shape.
            pred_np = pred.reshape(-1).cpu().numpy()
            target_np = target.reshape(-1).cpu().numpy()

            for i, label in enumerate(current_task_labels):
                hits = pred_np[target_np == label]
                # Histogram of predicted old classes for samples of this new
                # class (equivalent to summing their one-hot encodings).
                each_target_class_source_class[i] += np.bincount(hits, minlength=num_prev_classes)
                num_samples_per_class[i] += hits.shape[0]

        for i in range(num_classes_per_tasks):
            each_target_class_source_class[i]/=num_samples_per_class[i]
            print(current_task_labels[i])
            print(each_target_class_source_class[i])
        return each_target_class_source_class
def _str_to_bool(value):
    """Parse a command-line boolean (``true``/``false``, ``yes``/``no``, ``1``/``0``)."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Run the continual-learning experiment described by the CLI arguments.

    Steps, in order:

    1. Parse arguments, seed the random generators and create ``--save_path``.
    2. Split the benchmark into tasks following ``--class_order``.
    3. Build the CNN, the ``CL.CL`` controller and the optimizer.
    4. For every task: train for ``--epochs`` epochs (with a remove/add step
       and an evaluation after every epoch except the last), save the
       accuracy curves, evaluate all tasks seen so far and, unless it was the
       last task, prepare the controller for the next task.
    5. Evaluate all tasks once more and save the final accuracies.

    All result files are written inside ``--save_path``.
    """
    parser = argparse.ArgumentParser(description='KAN algorithm')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=10, metavar='S',
                         help='random seed (default: 10)')
    parser.add_argument('--benchmark', type=str, default= 'CIFAR10',
                        help= 'Options: CIFAR10, CIFAR100')
    parser.add_argument('--num_classes_per_task', type=int, default=2,
                        help='number of classes in each task (default: 2 for CIFAR)')
    parser.add_argument('--num_tasks', type=int, default=2,
                        help='number of tasks')
    parser.add_argument('--class_order', default='1,3,5,9',
                        help='new order for classes, None if orginial order required')
    parser.add_argument('--epochs', type=int, default=1, metavar='N',
                        help='number of epochs to train (default: 1)')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    # NOTE(review): --test-batch-size is parsed but never read in this
    # function; the test loaders below are built with --batch-size.
    parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum', type=float, default=0, metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--l2', type=float, default=0)
    parser.add_argument('--log_interval', type=int, default=50, metavar='N',
                        help='how many batches to wait before logging training status')
    # `choices` rejects unknown optimizers at parse time instead of failing
    # later with an undefined `optimizer` variable.
    parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'adam'],
                        help='The optimizer to use. Default: sgd. Options: sgd, adam.')
    # `type=bool` would turn every non-empty string (including "False") into
    # True, so a dedicated parser is used.
    parser.add_argument('--representation_relation', type=_str_to_bool, default=True)
    parser.add_argument('--selection_method_for_related_class', type=str, default='mostrelated',
                        help='Options: mostrelated, leastrelated, random.')
    parser.add_argument('--save_path', type=str, default='./')

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    os.environ['PYTHONHASHSEED']=str(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    print(args)

    if not os.path.exists(args.save_path):
        # Create a new directory because it does not exist
        os.makedirs(args.save_path)
        print("The new directory is created!")

    def save_metric(file_stem, values):
        """Write ``values`` to ``<save_path>/<file_stem>.txt``."""
        # os.path.join keeps the file inside save_path even when the path was
        # given without a trailing separator.
        np.savetxt(os.path.join(args.save_path, f"{file_stem}.txt"), values)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)

    ############### Tasks construction ###############
    ## construct the CL tasks with the specified order
    num_tasks = args.num_tasks
    num_classes_per_task = args.num_classes_per_task
    if args.class_order=='None':
        class_order = list(range(0, num_tasks*num_classes_per_task))
    else:
        class_order = list(map(int, args.class_order.split(',')))
    # Original dataset labels of every task, in the requested order.
    task_labels = [
        class_order[i*num_classes_per_task:(i+1)*num_classes_per_task]
        for i in range(num_tasks)
    ]
    # Consecutive labels 0..K-1 used as the model's output units.
    target_task_labels = [
        list(range(i, i+num_classes_per_task))
        for i in range(0, num_tasks*num_classes_per_task, num_classes_per_task)
    ]
    train_dataset,test_dataset = utils_CNN.task_construction(task_labels, target_task_labels, args.benchmark)

    ############### number of neurons for allocations ###############
    num_freezedNodes_per_layer=[0, 2, 2, 5, 5, 100, args.num_classes_per_task]
    selected_nodes_count = [3, 32, 32, 64, 64, 512, 2]
    additional_selected_nodes = [0, 0, 0, 10, 20, 0, 0] # selected from free per task
    no_neurons_reused_from_previous = [0, 0, 0, 15, 10, 0, 0] # c_candidate per class
    print("num_freezedNodes_per_layer", num_freezedNodes_per_layer)
    print("selected_nodes_count", selected_nodes_count)
    print("additional_nodes",additional_selected_nodes)
    print("no_reuse",no_neurons_reused_from_previous)

    ############### CL model ###############
    input_channels = 3
    num_classes = num_tasks*args.num_classes_per_task
    model = models_CNN.CNN(input_channels,num_classes).to(device)
    cl = CL.CL(device, num_freezedNodes_per_layer, selected_nodes_count, target_task_labels,model, additional_selected_nodes, no_neurons_reused_from_previous)
    # From here on only the remapped (consecutive) labels are used.
    task_labels = target_task_labels

    criterion = nn.CrossEntropyLoss()
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(cl.model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.l2, nesterov=False)
    else:  # 'adam' -- guaranteed by `choices` above
        optimizer = optim.Adam(cl.model.parameters(),lr=.001, betas=(0.9, .999))

    method = args.selection_method_for_related_class
    seed_tag = str(args.seed)

    ############### training ###############
    cl.set_init_network_weight()
    for task_idx in range(0,num_tasks):
        train_loader = utils_CNN.get_task_load_train(train_dataset[task_idx],args.batch_size)
        # "val" loader iterates the training split; it feeds the train-accuracy curves.
        val_loader = utils_CNN.get_task_load_train(train_dataset[task_idx],args.batch_size)
        test_loader = utils_CNN.get_task_load_test(test_dataset[task_idx],args.batch_size)
        cl.reset_importance()
        test_acc_along_training = []
        train_acc_along_training = []
        test_acc_along_training_TIL = []
        train_acc_along_training_TIL = []
        for epoch in range(args.epochs):
            cl.model.train()
            t0=time.time()
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device,dtype=torch.int64)
                # zero the parameter gradients
                optimizer.zero_grad()
                # save weights
                cl.set_old_weight()
                outputs, _,_,_,_,_ = cl.model(data)
                # Computes loss
                loss = criterion(outputs, target)
                #compute gradient
                loss.backward()
                cl.apply_mask_on_grad()
                optimizer.step()
                cl.calculate_importance()
                cl.recover_old_task_weight()
                if batch_idx % args.log_interval == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader)*args.batch_size,
                    100. * batch_idx / len(train_loader), loss.item()))
                    sys.stdout.flush()

            print('Current learning rate: {0}. Time taken for epoch: {1:.2f} seconds.\n'.format(optimizer.param_groups[0]['lr'], time.time() - t0))
            if epoch < args.epochs-1:
                #### drop and grow cycled from the SpaceNet algorithm #####
                cl.remove()
                cl.add()
                val_acc, val_acc_TIL, _ = evaluate(args,task_idx, task_labels, cl.model, criterion, device, test_loader, is_test_set=True)
                test_acc_along_training.append(val_acc)
                test_acc_along_training_TIL.append(val_acc_TIL)

                val_acc,val_acc_TIL ,_ = evaluate(args,task_idx, task_labels, cl.model, criterion, device, val_loader, is_test_set=True)
                train_acc_along_training.append(val_acc)
                train_acc_along_training_TIL.append(val_acc_TIL)

        task_tag = "_seed_" + seed_tag + "_task_" + str(task_idx)
        save_metric(method + "_test_acc_along_training" + task_tag, test_acc_along_training)
        save_metric(method + "_train_acc_along_training" + task_tag, train_acc_along_training)
        save_metric(method + "_test_acc_TIL_along_training" + task_tag, test_acc_along_training_TIL)
        save_metric(method + "_train_acc_TIL_along_training" + task_tag, train_acc_along_training_TIL)

        cl.set_classifer_to_all_learned_tasks()
        test_previous_tasks(args, cl.model, criterion, device,task_idx, task_labels, test_dataset)

        if task_idx<num_tasks-1:
            train_loader_val = utils_CNN.get_task_load_train(train_dataset[task_idx+1],args.batch_size)
            new_old_class_relation = class_relationship(cl, cl.model, train_loader_val, device)
            # calculate the activation of new classes on model f(t-1)
            if args.representation_relation:
                # A fresh loader is built, as before, for the second pass over the data.
                train_loader_val = utils_CNN.get_task_load_train(train_dataset[task_idx+1], args.batch_size)
                t2_F1, t2_F2, t2_F3, t2_F4, t2_L1 = get_layer_reprsentation(cl, model, device, train_loader_val, task_idx+1)
                t2_representations = [t2_F1, t2_F2, t2_F3, t2_F4, t2_L1]
            else:
                t2_representations = None
            cl.prepare_next_task(new_old_class_relation, args.selection_method_for_related_class, t2_representations)

    # Final evaluation over every task (index of the last task is num_tasks-1).
    each_task_acc, average_acc, each_task_acc_TIL, average_acc_TIL = test_previous_tasks(args, cl.model, criterion, device,num_tasks-1, task_labels, test_dataset)
    print("average_acc:", average_acc)
    print("average_acc_TIL:", average_acc_TIL)
    save_metric(method + "_final_acc_each_task_seed_" + seed_tag, each_task_acc)
    save_metric(method + "_average_acc_seed_" + seed_tag, [average_acc])
    save_metric(method + "_final_acc_TIL_each_task_seed_" + seed_tag, each_task_acc_TIL)
    save_metric(method + "_average_acc_TIL_seed_" + seed_tag, [average_acc_TIL])

# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
   main()