import sys
import logging
import copy
import torch
from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters, save_results_to_excel, convert_time, get_device_name
import os
import time
import json
import numpy as np


def train(args):
    """Train one model for a single dataset/seed configuration.

    ``args["model_dir"]`` is rewritten in place, on the caller's dict, to
    ``<model_dir>/<model_name>/seed_<seed>/``. The dataset-specific section
    ``args[args["dataset"]]`` is then merged over the top-level options, and a
    JSON-round-tripped copy of the merged dict is passed to ``_train``.

    Args:
        args: top-level configuration dict. It must contain "seed", "dataset",
            "model_dir", "model_name" and a sub-dict keyed by the dataset name.
    """
    torch.cuda.empty_cache()

    # The right-hand side reads the old model_dir before it is overwritten.
    args["model_dir"] = "{}/{}/seed_{}/".format(
        args["model_dir"], args["model_name"], args["seed"]
    )

    # Dataset-specific options win over top-level options with the same key.
    merged = {**args, **args[args["dataset"]]}
    # The JSON round trip gives _train a fully independent deep copy.
    # json.dumps raises TypeError on values that are not JSON-serializable.
    _train(json.loads(json.dumps(merged)))


def _configure_logging(logfilename, print_info):
    """Send root-logger INFO output to ``logfilename + ".log"``, and also to stdout if ``print_info``.

    ``logging.basicConfig`` does nothing once the root logger already has
    handlers, so only the first call in a process takes effect.
    """
    handlers = [logging.FileHandler(filename=logfilename + ".log")]
    if print_info:
        handlers.append(logging.StreamHandler(sys.stdout))
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=handlers,
    )


def _train(args):
    """Run one incremental-training session end to end and save its results.

    The function:
    1. Creates the log directory and configures logging.
    2. Seeds torch and resolves devices.
    3. Builds the data manager and the model.
    4. For each task, trains the model and evaluates it.
    5. Writes the accuracy curves to Excel via ``save_results_to_excel``.

    Args:
        args: merged configuration dict, as produced by ``train``. It is
            mutated in place: ``args["device"]`` is replaced by a list of
            ``torch.device`` objects.
    """
    # When the first task has the same size as the others, the log directory
    # uses 0 in place of init_cls.
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
    logs_name = "logs/{}/{}/{}/{}".format(
        args["model_name"], args["dataset"], init_cls, args["increment"]
    )
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(logs_name, exist_ok=True)

    logfilename = "{}/{}_{}_{}".format(
        logs_name, args["prefix"], args["seed"], args["convnet_type"]
    )
    _configure_logging(logfilename, args["print_info"])

    _set_random(args["seed"])
    # Read device names from the raw ids, before _set_device overwrites args["device"].
    device_name_list = get_device_name(args["device"])
    _set_device(args)

    print_args(args)

    data_manager = DataManager(
        args["dataset"],
        args["shuffle"],
        args["seed"],
        args["init_cls"],
        args["increment"],
    )

    model = factory.get_model(args["model_name"], args)

    result_for_record = {
        'top1_acc': [],
        'top5_acc': [],
    }
    cnn_curve = {"top1": [], "top5": []}
    nme_curve = {"top1": [], "top5": []}
    grouped_top1_acc = []
    start_time = time.time()
    for _ in range(data_manager.nb_tasks):
        logging.info("All params: {}".format(count_parameters(model._network)))
        logging.info("Trainable params: {}".format(count_parameters(model._network, True)))

        model.incremental_train(data_manager)
        # Expected shape of cnn_accy (per the original author's note):
        #   {
        #       'grouped': {"total": .., "old": .., "new": ..,
        #                   "0-9": .., "10-19": .., ...},
        #       'top1': ..,
        #       'top5': ..,
        #   }
        cnn_accy, nme_accy = model.eval_task()
        logging.info("ball_r:{}".format(str(model._balls._get_radius_orig())))
        model.after_task()

        if nme_accy is not None:
            logging.info("CNN: {}".format(cnn_accy["grouped"]))
            logging.info("NME: {}".format(nme_accy["grouped"]))

            # Record top-1 and top-5 for both classifiers, one entry each per task.
            cnn_curve["top1"].append(cnn_accy["top1"])
            cnn_curve["top5"].append(cnn_accy["top5"])
            nme_curve["top1"].append(nme_accy["top1"])
            nme_curve["top5"].append(nme_accy["top5"])

            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
            logging.info("NME top1 curve: {}".format(nme_curve["top1"]))
            # NOTE(review): unlike the branch below, "total"/"old"/"new" are kept
            # in the grouped dict here. Confirm this asymmetry is intended.
            grouped_top1_acc.append(cnn_accy["grouped"])
        else:
            logging.info("No NME accuracy.")
            logging.info("CNN: {}".format(cnn_accy["grouped"]))
            cnn_curve["top1"].append(cnn_accy["top1"])
            cnn_curve["top5"].append(cnn_accy["top5"])

            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
            logging.info("CNN top5 curve: {}".format(cnn_curve["top5"]))

            # Keep only the per-class-range buckets in the recorded grouped accuracy.
            for summary_key in ("total", "old", "new"):
                cnn_accy["grouped"].pop(summary_key)
            grouped_top1_acc.append(cnn_accy["grouped"])

    used_time = convert_time(time.time() - start_time)
    formatted_grouped_top1_acc = ",\n".join(str(item) for item in grouped_top1_acc)
    result_for_record["top1_acc"].append(
        (args['model_name'], str(args), cnn_curve["top1"], np.mean(cnn_curve["top1"]), formatted_grouped_top1_acc)
    )

    # Top-5 results are recorded only when NCM is disabled.
    # `== False` is kept on purpose so non-bool values behave as before.
    if args["NCM"] == False:  # noqa: E712
        result_for_record["top5_acc"].append(
            (args['model_name'], str(args), cnn_curve["top5"], np.mean(cnn_curve["top5"]), '-')
        )

    save_results_to_excel(
        args["dataset"],
        args["model_name"] + args["suffix_res_file"],
        str(args["init_cls"]) + "_" + str(args["increment"]),
        results=result_for_record,
        runing_time=used_time,
        seed=args["seed"],
        device=str(device_name_list),
        note=str([args["note"]]),
    )

def _set_device(args):
    device_type = args["device"] 
    gpus = []

    for device in device_type:
        if device_type == -1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda:{}".format(device))

        gpus.append(device)

    args["device"] = gpus


def _set_random(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_args(args):
    """Log each configuration entry as ``key: value`` at INFO level, in dict order."""
    for name in args:
        logging.info("{}: {}".format(name, args[name]))
