import copy
import logging
import os
import os.path
import sys
import time
from utils.toolkit import  accuracy_domain

import torch
from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters
import shutil 

def train(args):
    seed_list = copy.deepcopy(args['seed'])
    device = copy.deepcopy(args['device'])
    for seed in seed_list:
        args['seed'] = seed
        args['device'] = device
        if(args["prefix"]=="prefix_one_prompt"):
            _prefix_prompt_train(args)
            return
        else:
            _train(args)
            #raise ValueError('Unknown net: {}.'.format(args["net_type"]))
    myseed = 42069  # set a random seed for reproducibility
    # deterministic=true 每次返回的卷积算法将是确定的，即默认算法。
    # 如果配合上设置 Torch 的随机种子为固定值的话，应该可以保证每次运行网络的时候相同输入的输出是固定的。

    torch.backends.cudnn.deterministic = True
    torch.manual_seed(myseed)# # sets the seed for generating random numbers
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(myseed)# # Sets the seed for generating random numbers on all GPUs.
       
def _prefix_prompt_train(args):
    logfilename = './logs/{}_{}_{}_{}_'.format(args['model_name'],args['query_type'],
                                                args['dataset'], args['init_cls'])+ time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    # logfilename = './logs/{}_{}_{}_{}_{}_{}_{}_'.format(args['prefix'], args['seed'], args['model_name'], args['net_type'],
    #     args['dataset'], args['init_cls'], args['increment'])+ time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())                                            
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(filename)s] => %(message)s',
        handlers=[
            logging.FileHandler(filename=logfilename + '.log'),
            logging.StreamHandler(sys.stdout)
        ]
    )
    os.makedirs(logfilename)
    print(logfilename)
    _set_random()
    _set_device(args)
    print_args(args)

    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'], args)
    args['class_order'] = data_manager._class_order
    model = factory.get_model(args['model_name'], args)
    cnn_curve, nme_curve = {'top1': []}, {'top1': []}

    # configs选了5个任务,选了5个域,nb_tasks因此设置为5
    for task in range(data_manager.nb_tasks):
        logging.info('All params: {}'.format(count_parameters(model._network)))
        logging.info('Trainable params: {}'.format(count_parameters(model._network, True)))
        model.begin_incremental(data_manager)
        # 简单三步走,训练\评估\事后处理(为了域增量第二阶段)
        # 模型执行，模型会加载到GPU中，并启动多个线程操作
        model.incremental_train(data_manager)
        cnn_accy, nme_accy = model.eval_task()
        model.after_task()

        if nme_accy is not None:
            logging.info('CNN: {}'.format(cnn_accy['grouped']))
            logging.info('NME: {}'.format(nme_accy['grouped']))#直接根据平均向量距离选
            cnn_curve['top1'].append(cnn_accy['grouped']['total'])
            nme_curve['top1'].append(nme_accy['top1'])
            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('NME top1 curve: {}'.format(nme_curve['top1']))
        else:
            logging.info('CNN: {}'.format(cnn_accy['grouped']))
            cnn_curve['top1'].append(cnn_accy['grouped']['total'])# 记录历史CNN top1 curve
            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
        #保存的模型参数数量不变，这里的task_序号代表已经训练了多少任务的模型
    torch.save(model, os.path.join(logfilename, "task_{}.pth".format(int(task))))

  
def _evaluate(model,y_pred, y_true):
    ret = {}
    grouped = accuracy_domain(y_pred.T[0], y_true, model._known_classes, class_num=model.class_num)
    ret['grouped'] = grouped
    ret['top1'] = grouped['total']
    #ret['top{}'.format(self.topk)] = np.around((y_pred.T == np.tile(y_true, (self.topk, 1))).sum()*100/len(y_true), decimals=2)
    return ret

def _train(args):
    logfilename = './logs/{}_{}_{}_{}_{}_{}_{}_'.format(args['prefix'], args['seed'], args['model_name'],args['net_type'],
                                                args['dataset'], args['init_cls'], args['increment'])+ time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    # logfilename = './logs/{}_{}_{}_{}_{}_{}_{}_'.format(args['prefix'], args['seed'], args['model_name'], args['net_type'],
    #     args['dataset'], args['init_cls'], args['increment'])+ time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())                                            
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(filename)s] => %(message)s',
        handlers=[
            logging.FileHandler(filename=logfilename + '.log'),
            logging.StreamHandler(sys.stdout)
        ]
    )
    os.makedirs(logfilename)
    print(logfilename)
    _set_random()
    _set_device(args)
    print_args(args)
    # 初始化数据加载器、模型实例、评估曲线
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'], args)
    args['class_order'] = data_manager._class_order
    model = factory.get_model(args['model_name'], args)
    cnn_curve, nme_curve = {'top1': []}, {'top1': []}

    # configs选了5个任务,选了5个域,nb_tasks因此设置为5
    for task in range(data_manager.nb_tasks):
        logging.info('All params: {}'.format(count_parameters(model._network)))
        logging.info('Trainable params: {}'.format(count_parameters(model._network, True)))
        model.begin_incremental(data_manager)
        # 简单三步走,训练\评估\事后处理(为了域增量第二阶段)
        model.incremental_train(data_manager)# 模型执行，模型会加载到GPU中，并启动多个线程操作
        
        cnn_accy, nme_accy = model.eval_task()
        model.after_task()

        if nme_accy is not None:
            logging.info('CNN: {}'.format(cnn_accy['grouped']))
            logging.info('NME: {}'.format(nme_accy['grouped']))#直接根据平均向量距离选
            cnn_curve['top1'].append(cnn_accy['top1'])
            nme_curve['top1'].append(nme_accy['top1'])
            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('NME top1 curve: {}'.format(nme_curve['top1']))
        else:
            logging.info('CNN: {}'.format(cnn_accy['grouped']))
            cnn_curve['top1'].append(cnn_accy['top1'])# 记录历史CNN top1 curve
            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
        #保存的模型参数数量不变，这里的task_序号代表已经训练了多少任务的模型
    torch.save(model, os.path.join(logfilename, "task_{}.pth".format(int(task))))

def _set_device(args):
    device_type = args['device']
    gpus = []

    for device in device_type:
        if device_type == -1:
            device = torch.device('cpu')
        else:
            device = torch.device('cuda:{}'.format(device))

        gpus.append(device)

    args['device'] = gpus


def _set_random():
    torch.manual_seed(1) #为CPU中设置种子，生成随机数
    torch.cuda.manual_seed(1) #为特定GPU设置种子，生成随机数
    torch.cuda.manual_seed_all(1) #为所有GPU设置种子，生成随机数
    # 设置随机种子是为了确保每次生成固定的随机数，这就使得每次实验结果显示一致了，有利于实验的比较和改进
    # 每次返回的卷积算法将是确定的，即默认算法
    # 保证每次运行网络的时候,对相同输入，模型输出是固定的。
    # benchmark 设置False，是为了保证不使用选择卷积算法的机制，使用固定的卷积算法。
    # 但是，就算是固定的卷积算法，由于其实现不同，也可能是不可控制的，即相同的值，同一个算法卷积出来有细微差别，
    # deterministic设置True保证使用确定性的卷积算法，二者配合起来，才能保证卷积操作的一致性
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_args(args):
    for key, value in args.items():
        logging.info('{}: {}'.format(key, value))
