import os
import sys
import time
import numpy as np
from tqdm import tqdm
import copy

from arguments import get_args
from Alg.opt import *
from Alg.PCDG import PCDG
from Alg import modelopera
import Replay.alg as ReplayAlg
from datautil.getdataloader import get_img_dataloader
from utils.util import set_random_seed, save_checkpoint, print_args, train_valid_target_eval_names, alg_loss_dict, log_print
from train import train, balance_finetune
from utils.visual import save_plot_acc_epochs, fit_tSNE, visual_tSNE
from Alg.pLabel import assign_pseudo_label
from network.util import freeze_classifier, freeze_proxy

if __name__ == '__main__':
    # Script entry point: builds a PCDG model, trains it on the train loaders
    # one task after another, records validation accuracy for every validation
    # loader along the way, and finally logs the resulting accuracy matrix.
    args = get_args()
    set_random_seed(args.seed)
    log_print('################################################', args.log_file)
    log_print('############### Attention: arguments steps_per_epoch should be changed with batch_size and dataset ! ####################', args.log_file)
    log_print('command args: {}'.format(sys.argv[1:]), args.log_file)
    log_print('sourceAlg: {}   targetAlg: {}   pseudo_LabelAlg: {}   Replay algorithm: {}'.format(args.sourceAlg, args.targetAlg, args.pLabelAlg, args.replay), args.log_file)
    log_print('model: {}  pretrained: {}'.format(args.net, not args.no_pretrained), args.log_file)
    log_print('arguments: {}\n'.format(args), args.log_file, p=False)

    if args.targetAlg == 'LDAuCID':
        # `torch` is not imported by name at the top of this file (presumably it
        # arrives through `from Alg.opt import *`); import it explicitly so this
        # branch does not depend on that wildcard re-export.
        import torch
        torch.set_num_threads(50)

    # Get data
    train_loaders, eval_loaders, eval_name_dict, task_sequence_name = get_img_dataloader(args)

    # Model
    Alg_model = PCDG(args).cuda()
    old_model = None   # used for knowledge distillation algorithms
    Replay_algorithm_class = ReplayAlg.get_algorithm_class(args.replay)
    Replay_algorithm = Replay_algorithm_class(args)
    Alg_model.train()

    # Initial statistics / metrics.
    # NOTE: the two lists below are only referenced by commented-out logging.
    target_domain_acc_list = []
    source_domain_acc_list = []
    # One record list per validation set, e.g.
    # 'task0': [[initial acc], [acc along training of task0], [acc along training of task1], ...]
    all_val_acc_record = {}
    for tid, valid_name in enumerate(eval_name_dict['valid']):
        all_val_acc_record['task{}'.format(tid)] = [[modelopera.accuracy(Alg_model, eval_loaders[valid_name])]]
    if args.tsne:
        tSNE_dict = {'features': [], 'clabels': [], 'dlabels': []}
        tSNE_dict = fit_tSNE(args, Alg_model, eval_loaders, tSNE_dict)

    def _acc_row_after(trained_task):
        # Last recorded accuracy of every validation set for training task
        # `trained_task`; index 0 of each record holds the pre-training accuracy,
        # hence the +1 offset.
        return [all_val_acc_record['task{}'.format(i)][trained_task + 1][-1]
                for i in range(len(eval_name_dict['valid']))]

    # Number of tasks that have completed training. The final matrix is indexed
    # by this count rather than by the number of validation sets, so a mismatch
    # between the two can neither raise IndexError nor drop rows.
    num_trained_tasks = 0

    # Incrementally train on the different domains
    for task_id, dataloader in enumerate(train_loaders):

        # construct replay exemplars
        replay_dataset = Replay_algorithm.update_dataloader()
        # log_print('current training data size: {}'.format(len(curr_dataloader.dataset)), args.log_file)

        # Freezing is applied once, right before the second task (task_id == 1).
        if task_id == 1 and args.freeze:
            if args.targetAlg in args.PCL_net:
                freeze_proxy(Alg_model)
            elif args.targetAlg in args.ERM_net:
                freeze_classifier(Alg_model)

        # main training; every task after the first uses targetLR when it is set
        if args.targetLR is not None and task_id > 0:
            args.lr = args.targetLR
        Alg_model, val_acc_record, pseudo_dataloader = train(args, Alg_model, old_model, task_id, dataloader, replay_dataset, eval_loaders, eval_name_dict)
        for tid in range(len(eval_name_dict['valid'])):
            record_key = 'task{}'.format(tid)
            all_val_acc_record[record_key].append(val_acc_record[record_key])
        num_trained_tasks = task_id + 1

        # show intermediate results: one row per task trained so far
        for tid in range(num_trained_tasks):
            log_print('after task {}: {}'.format(tid, _acc_row_after(tid)), args.log_file)

        # finish task
        # Alg_model.after_train(dataloader, task_id)
        Alg_model.after_train(pseudo_dataloader, task_id)
        Replay_algorithm.update(Alg_model, task_id, pseudo_dataloader)

        # balance training
        # if args.balance_finetune and task_id > 0:
        #     Alg_model = balance_finetune(args, Alg_model, task_id, Replay_algorithm, eval_loaders, eval_name_dict)

        if args.tsne:
            tSNE_dict = fit_tSNE(args, Alg_model, eval_loaders, tSNE_dict)

        # see if the trained model can assign correct pseudo labels on its own training data.
        # replay_dataset = Replay_algorithm.update_dataloader()
        # _, _ = assign_pseudo_label(args, dataloader, replay_dataset, task_id, Alg_model, 0, cur=True)

        # Save the model after finishing a task and keep a copy of it (in eval
        # mode) that is passed to train() as `old_model` for the next task; it
        # is used by knowledge distillation algorithms.
        save_checkpoint(args.saved_model_name, Alg_model, args)
        old_model = copy.deepcopy(Alg_model)
        Alg_model.cuda()
        old_model.cuda().eval()

    # log_print('\nDG accuracy on new tasks: {}    average: {}'.format(target_domain_acc_list, np.mean(target_domain_acc_list)), args.log_file)
    save_plot_acc_epochs(args, all_val_acc_record, task_sequence_name)
    if args.tsne:
        visual_tSNE(args, tSNE_dict)

    # Final accuracy matrix: the pre-training row, then one row per trained task.
    log_print('\nDGaccuracy matrix: ', args.log_file)
    log_print('at start: {}'.format([all_val_acc_record['task{}'.format(tid)][0][0] for tid in range(len(eval_name_dict['valid']))]), args.log_file)
    for tid in range(num_trained_tasks):
        log_print('after task {}: {}'.format(tid, _acc_row_after(tid)), args.log_file)

    log_print('', args.log_file)




