import os
import sys
import time
import copy
import torch
import random
import logging
import platform
import numpy as np
import datetime
from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters,tsne
import copy
import matplotlib.pyplot as plt
import seaborn as sns

        
from utils.toolkit import compute_hit_matrix, information_metrics


def train(args):
    """Run one training or testing session per configured seed.

    ``args["seed"]`` is expected to be an iterable of seeds.  For every seed
    the entry is overwritten with that single seed and ``args["device"]`` is
    reset to the value it had on entry, then ``_test`` is called when
    ``args["test"]`` is truthy and ``_train`` otherwise.

    Args:
        args: Mutable run configuration; its "seed" and "device" entries are
            rewritten in place on every iteration.
    """
    # Snapshot both entries up front: they are overwritten inside the loop.
    all_seeds = copy.deepcopy(args["seed"])
    initial_device = copy.deepcopy(args["device"])

    for current_seed in all_seeds:
        args["seed"] = current_seed
        args["device"] = initial_device

        run_once = _test if args['test'] else _train
        run_once(args)
        

def _train(args):
    """Train incrementally over every task for one seed and report accuracy.

    Sets up a per-run log file, seeds the RNGs, builds the data manager and
    the model, then for each task: trains, evaluates, logs the grouped
    accuracies and running top-1/top-5 curves, and optionally saves the
    network's state dict.  After the last task the accuracy matrix and the
    forgetting value are printed and logged for CNN (and NME when available).

    Args:
        args: Run configuration mapping.  Keys read here: "init_cls",
            "increment", "log_name", "model_name", "dataset", "prefix",
            "seed", "convnet_type", "dnm", "shuffle", "save_model" and, when
            ``args["dnm"]`` is truthy, "num_branch", "synapse_activation",
            "dendritic_activation", "soma", "sn", "dn".  "device" is
            rewritten by ``_set_device``.
    """
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]

    logs_name = "{}/{}/{}_{}_{}".format(
        args["log_name"], args["model_name"], args["dataset"], init_cls, args['increment']
    )
    os.makedirs(logs_name, exist_ok=True)

    # The dnm-specific settings are appended to the file name only for dnm runs.
    name_parts = [args["prefix"], args["seed"], args["convnet_type"]]
    if args['dnm']:
        name_parts += [
            args["num_branch"],
            args["synapse_activation"],
            args["dendritic_activation"],
            args["soma"],
            args["sn"],
            args["dn"],
        ]
    logfilename = "{}/{}".format(
        logs_name, "_".join("{}".format(part) for part in name_parts)
    )

    # logging.basicConfig() does nothing when the root logger already has
    # handlers, so a second seed would keep writing into the first seed's log
    # file.  Drop the previous handlers so every run gets its own file.
    root_logger = logging.getLogger()
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s %(lineno)d] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )

    # Save run environments
    logging.info('python version: {}'.format(sys.version))
    logging.info('pytorch version: {}'.format(torch.__version__))
    logging.info('OS info: {}'.format(platform.uname()))

    start_ = time.time()

    _set_random(args["seed"])
    _set_device(args)
    print_args(args)

    data_manager = DataManager(
        args["dataset"],
        args["shuffle"],
        args["seed"],
        args["init_cls"],
        args["increment"],
        args
    )

    model = factory.get_model(args["model_name"], args)

    cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []}
    cnn_matrix, nme_matrix = [], []

    for task in range(data_manager.nb_tasks):
        logging.info("All params: {}".format(count_parameters(model._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(model._network, True))
        )
        model.incremental_train(data_manager)
        cnn_accy, nme_accy = model.eval_task()
        model.after_task()

        if nme_accy is not None:
            logging.info("CNN: {}".format(cnn_accy["grouped"]))
            logging.info("NME: {}".format(nme_accy["grouped"]))

            cnn_matrix.append(_grouped_accuracy_row(cnn_accy["grouped"]))
            nme_matrix.append(_grouped_accuracy_row(nme_accy["grouped"]))

            cnn_curve["top1"].append(cnn_accy["top1"])
            cnn_curve["top5"].append(cnn_accy["top5"])

            nme_curve["top1"].append(nme_accy["top1"])
            nme_curve["top5"].append(nme_accy["top5"])

            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
            logging.info("CNN top5 curve: {}".format(cnn_curve["top5"]))
            logging.info("NME top1 curve: {}".format(nme_curve["top1"]))
            logging.info("NME top5 curve: {}\n".format(nme_curve["top5"]))

            cnn_average = sum(cnn_curve["top1"]) / len(cnn_curve["top1"])
            nme_average = sum(nme_curve["top1"]) / len(nme_curve["top1"])

            print('Average Accuracy (CNN):', cnn_average)
            print('Average Accuracy (NME):', nme_average)

            logging.info("Average Accuracy (CNN): {}".format(cnn_average))
            logging.info("Average Accuracy (NME): {}".format(nme_average))
        else:
            logging.info("No NME accuracy.")
            logging.info("CNN: {}".format(cnn_accy["grouped"]))

            cnn_matrix.append(_grouped_accuracy_row(cnn_accy["grouped"]))

            cnn_curve["top1"].append(cnn_accy["top1"])
            cnn_curve["top5"].append(cnn_accy["top5"])

            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
            logging.info("CNN top5 curve: {}\n".format(cnn_curve["top5"]))

            cnn_average = sum(cnn_curve["top1"]) / len(cnn_curve["top1"])
            print('Average Accuracy (CNN):', cnn_average)
            logging.info("Average Accuracy (CNN): {}".format(cnn_average))

        if args['save_model']:
            # One checkpoint per task, named after the log file of this run.
            ckpt_suffix = "_dnm.ckpt" if args['dnm'] else ".ckpt"
            ckpt_file_name = (
                f"{logfilename}_{data_manager.dataset_name}_{task}"
                f"_{data_manager.nb_tasks}{ckpt_suffix}"
            )
            torch.save(model._network.state_dict(), ckpt_file_name)

    logging.info('Total Time: {}'.format(datetime.timedelta(seconds=int(time.time()-start_))))

    # `task` still holds the index of the last task here; the matrices are
    # non-empty only if the loop above ran at least once.
    if len(cnn_matrix) > 0:
        np_acctable, forgetting = _accuracy_table(cnn_matrix, task)
        print('Accuracy Matrix (CNN):')
        print(np_acctable)
        logging.info('Accuracy Matrix (CNN):\n{}'.format(np_acctable))
        print('Forgetting (CNN):', forgetting)
        logging.info('Forgetting (CNN): {}'.format(forgetting))
    if len(nme_matrix) > 0:
        np_acctable, forgetting = _accuracy_table(nme_matrix, task)
        print('Accuracy Matrix (NME):')
        logging.info('Accuracy Matrix (NME):\n{}'.format(np_acctable))
        print(np_acctable)
        print('Forgetting (NME):', forgetting)
        logging.info('Forgetting (NME): {}'.format(forgetting))


def _grouped_accuracy_row(grouped):
    """Return the values of the '-'-containing keys of `grouped`, in sorted key order."""
    # NOTE(review): keys are sorted as strings; presumably they are zero-padded
    # class ranges so that string order equals task order -- confirm.
    range_keys = sorted(key for key in grouped.keys() if '-' in key)
    return [grouped[key] for key in range_keys]


def _accuracy_table(rows, last_task):
    """Build the transposed, zero-padded accuracy matrix and its forgetting value.

    Row ``i`` of `rows` holds the accuracies recorded after task ``i``.  The
    returned table is the transpose of the zero-padded stack of those rows.
    Forgetting is the mean, over the first `last_task` rows of the returned
    table, of (row maximum - value in column `last_task`).  With a single
    task that slice is empty and ``np.mean`` yields ``nan``.
    """
    table = np.zeros([last_task + 1, last_task + 1])
    for stage, row in enumerate(rows):
        table[stage, :len(row)] = np.array(row)
    table = table.T
    forgetting = np.mean((np.max(table, axis=1) - table[:, last_task])[:last_task])
    return table, forgetting


def _test(args):
    """Load the per-task checkpoints for one seed and dump network outputs.

    Sets up a ``*_test.log`` file, seeds the RNGs, builds the data manager and
    the model, then for each task: prepares the model with
    ``incremental_test``, loads that task's checkpoint, runs ``eval_task`` and
    one pass over ``model.test_loader``, and saves the collected features,
    labels and logits (plus ``sa_x`` for dnm runs) as ``.npy`` files next to
    the log file.

    Args:
        args: Run configuration mapping.  Keys read here: "init_cls",
            "increment", "log_name", "model_name", "dataset", "prefix",
            "seed", "convnet_type", "dnm", "shuffle" and, when
            ``args["dnm"]`` is truthy, "num_branch", "synapse_activation",
            "dendritic_activation", "soma", "sn", "dn".  "device" is
            rewritten by ``_set_device``.
    """
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]

    logs_name = "{}/{}/{}_{}_{}".format(
        args["log_name"], args["model_name"], args["dataset"], init_cls, args['increment']
    )
    os.makedirs(logs_name, exist_ok=True)

    # The dnm-specific settings are appended to the file name only for dnm runs.
    name_parts = [args["prefix"], args["seed"], args["convnet_type"]]
    if args['dnm']:
        name_parts += [
            args["num_branch"],
            args["synapse_activation"],
            args["dendritic_activation"],
            args["soma"],
            args["sn"],
            args["dn"],
        ]
    logfilename = "{}/{}".format(
        logs_name, "_".join("{}".format(part) for part in name_parts)
    )

    # logging.basicConfig() does nothing when the root logger already has
    # handlers, so a second seed would keep writing into the first seed's log
    # file.  Drop the previous handlers so every run gets its own file.
    root_logger = logging.getLogger()
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s %(lineno)d] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + "_test.log"),
            logging.StreamHandler(sys.stdout),
        ],
    )

    # Save run environments
    logging.info('python version: {}'.format(sys.version))
    logging.info('pytorch version: {}'.format(torch.__version__))
    logging.info('OS info: {}'.format(platform.uname()))

    start_ = time.time()

    _set_random(args["seed"])
    _set_device(args)
    print_args(args)

    data_manager = DataManager(
        args["dataset"],
        args["shuffle"],
        args["seed"],
        args["init_cls"],
        args["increment"],
        args
    )

    model = factory.get_model(args["model_name"], args)

    ckpt_suffix = "_dnm.ckpt" if args['dnm'] else ".ckpt"

    for task in range(data_manager.nb_tasks):
        logging.info("All params: {}".format(count_parameters(model._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(model._network, True))
        )

        model.incremental_test(data_manager)
        # Keep the network on the same device the inputs are moved to below,
        # instead of a hard-coded 'cuda'.
        model._network.to(model._device)

        ckpt_tail = f"{data_manager.dataset_name}_{task}_{data_manager.nb_tasks}{ckpt_suffix}"
        path = os.path.join(logs_name, f"{args['model_name']}_{ckpt_tail}")
        if not os.path.exists(path):
            # _train() saves checkpoints as "<logfilename>_<dataset>_<task>_<nb_tasks>...";
            # fall back to that name when the "<model_name>_..." file is absent.
            trained_path = f"{logfilename}_{ckpt_tail}"
            if os.path.exists(trained_path):
                path = trained_path

        print('load:',path)
        model._network.load_state_dict(
            torch.load(path, map_location=model._device, weights_only=True)
        )

        # The returned accuracies are not reported in test mode; the call is
        # kept because eval_task() is opaque here and may have side effects.
        model.eval_task()

        # t-SNE generation (utils.toolkit.tsne) is disabled in this mode.

        # DNM analysis: collect the raw network outputs for every test batch.
        features_list = []
        logits_list = []
        sa_x_list = []
        true_labels_list = []

        for _, inputs, targets in model.test_loader:
            inputs = inputs.to(model._device)

            with torch.no_grad():
                outputs = model._network(inputs)

            if args['dnm']:
                sa_x_list.append(outputs['sa_x'].cpu().numpy())

            logits_list.append(outputs['logits'].cpu().numpy())
            features_list.append(outputs['features'].cpu().numpy())
            true_labels_list.append(targets.detach().cpu().numpy())

        # NOTE(review): np.array() over the per-batch list presumably relies on
        # every batch having the same size -- confirm against the test loader.
        if args['dnm']:
            sa_x_arr = np.squeeze(np.array(sa_x_list))
            print(sa_x_arr.shape)

        features_arr = np.squeeze(np.array(features_list))
        true_labels_arr = np.squeeze(np.array(true_labels_list))
        logits_arr = np.squeeze(np.array(logits_list))
        print(features_arr.shape,true_labels_arr.shape,logits_arr.shape)

        if args['dnm']:
            np.save(f'{logfilename}_sa_x_{task}.npy',sa_x_arr)
        np.save(f'{logfilename}_features_{task}.npy',features_arr)
        np.save(f'{logfilename}_labels_{task}.npy',true_labels_arr)
        np.save(f'{logfilename}_logits_{task}.npy',logits_arr)

        model.after_task()

    logging.info('Total Time: {}'.format(datetime.timedelta(seconds=int(time.time()-start_))))

def _set_device(args):
    """Replace ``args["device"]`` with a list of ``torch.device`` objects.

    Each entry equal to ``-1`` becomes the CPU device; every other entry is
    formatted into ``"cuda:<entry>"``.

    Args:
        args: Run configuration; its "device" entry (an iterable of device
            ids) is overwritten in place.
    """
    args["device"] = [
        torch.device("cpu") if device_id == -1
        else torch.device("cuda:{}".format(device_id))
        for device_id in args["device"]
    ]


def _set_random(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and the
            PyTorch CPU/CUDA generators.
    """
    random.seed(seed)
    # np.random.seed() seeds NumPy's global generator; np.random.random(seed)
    # would only draw `seed` random floats and leave the generator unseeded.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for reproducible kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
def print_args(args):
    """Log every entry of `args` at INFO level as ``"<key>: <value>"``."""
    for name in args:
        logging.info(f"{name}: {args[name]}")
