import numpy as np
from tqdm import tqdm
import math
import random
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

from Alg import modelopera
from utils.util import save_checkpoint, log_print
from Replay.utils import construct_BF_dataloader
from DataAug.StyleTransfer import styleTransfer
from Alg.pLabel import assign_pseudo_label
from datautil.mydataloader import InfiniteDataLoader

def train(args, Alg_model, old_model, task_id, dataloader, replay_dataset, eval_loaders, eval_name_dict):
    '''
    Train ``Alg_model`` on a single task and track validation accuracy per epoch.

    Every epoch this function:
      1. (re)assigns pseudo labels when ``epoch % args.pseudo_fre == 0`` and
         ``epoch < args.pseudo_max_epoch``, rebuilding the training loader(s);
      2. runs ``args.steps_per_epoch`` optimisation steps, through
         ``Alg_model.train_source`` for task 0 and ``Alg_model.adapt`` otherwise;
      3. when ``args.sourceAlg == 'PCL2'``, ``args.mix_classifier`` is set and an
         ``old_model`` exists, blends the classifier weights with the old ones
         using ``args.classifier_mix_tau``;
      4. decays the learning rate once per epoch unless ``args.no_lr_sch``;
      5. records train/valid accuracy of the current task and the validation
         accuracy of every task.

    Args:
        args: experiment configuration namespace.
        Alg_model: model being trained (provides get_optimizer, optimizer,
            train, train_source / adapt, naug and, for PCL2 mixing, classifier).
        old_model: model from the previous task, or None.
        task_id: index of the current task (0 trains on the source task).
        dataloader: loader of the current task, passed to assign_pseudo_label.
        replay_dataset: replay-buffer dataset, or None.
        eval_loaders: indexable by the names stored in ``eval_name_dict``.
        eval_name_dict: holds 'train' and 'valid' lists of eval-loader names,
            indexed by task id.

    Returns:
        tuple: ``(Alg_model, all_val_acc_record, pseudo_dataloader)``.
        ``all_val_acc_record`` maps ``'task{i}'`` to a per-epoch list of
        validation accuracies. ``pseudo_dataloader`` is the last loader
        returned by assign_pseudo_label (None when ``max_epoch == 0``).

    Raises:
        ValueError: if ``max_epoch > 0`` but ``args.pseudo_max_epoch <= 0``.
            In that case pseudo labels are never assigned, so no training
            loader would exist.
    '''
    acc_record = {}
    all_val_acc_record = {'task{}'.format(tid): [] for tid in range(len(eval_name_dict['valid']))}

    max_epoch = args.max_epoch
    # Epoch 0 always satisfies `epoch % pseudo_fre == 0`, so the loaders are
    # built in the first epoch iff pseudo_max_epoch > 0. Fail loudly otherwise
    # instead of hitting an UnboundLocalError deep inside the loop.
    if max_epoch > 0 and args.pseudo_max_epoch <= 0:
        raise ValueError(
            'args.pseudo_max_epoch must be > 0 so that the training loader is '
            'built in the first epoch (got {})'.format(args.pseudo_max_epoch))

    Alg_model.get_optimizer(lr_decay=args.lr_decay1)
    Alg_model.optimizer = op_copy(Alg_model.optimizer)  # remember base lr as 'lr0' for lr_scheduler
    samples_per_epoch = args.batch_size * args.steps_per_epoch

    # Defaults keep the post-loop log / return well-defined even if max_epoch == 0.
    step_vals = {}
    plabel_sc = None
    pseudo_dataloader = None
    curr_dataloader = None
    replay_dataloader = None
    batch_iter = None

    with tqdm(range(max_epoch)) as tepoch:
        tepoch.set_description(f"Task {task_id}")
        for epoch in tepoch:

            # progressively (re)assign pseudo labels and rebuild the loaders
            if epoch % args.pseudo_fre == 0 and epoch < args.pseudo_max_epoch:
                pseudo_dataloader, plabel_sc = assign_pseudo_label(args, dataloader, replay_dataset, task_id, Alg_model, epoch)
                if args.targetAlg == 'LDAuCID' and task_id > 0:
                    curr_dataloader = pseudo_dataloader
                    replay_dataloader = InfiniteDataLoader(replay_dataset, weights=None, batch_size=args.batch_size, num_workers=args.N_WORKERS)
                else:
                    curr_dataloader = cat_pseudo_replay(args, pseudo_dataloader, replay_dataset)
                    replay_dataloader = None
                # Keep ONE iterator per loader. Calling iter() on every step
                # would restart a finite DataLoader each time (re-spawning
                # workers and, without shuffling, yielding the same first batch).
                batch_iter = iter(curr_dataloader)

            Alg_model.naug = 0 if task_id > 0 else samples_per_epoch
            # a fixed number of steps per epoch so every task gets the same training budget
            for _ in range(args.steps_per_epoch):
                try:
                    batch = next(batch_iter)
                except StopIteration:
                    # finite loader exhausted: start a new pass over it
                    batch_iter = iter(curr_dataloader)
                    batch = next(batch_iter)
                # batch layout is [imgs, class_label, domain_label]
                minibatches = list(batch)

                Alg_model.train()
                if task_id == 0:
                    step_vals = Alg_model.train_source(minibatches, task_id, epoch)
                else:
                    step_vals = Alg_model.adapt(minibatches, task_id, epoch, replay_dataloader, old_model)

            # blend the current classifier with the previous task's classifier
            if args.sourceAlg in ['PCL2'] and args.mix_classifier and old_model is not None:
                tau = args.classifier_mix_tau
                Alg_model.classifier.data = tau * Alg_model.classifier.data + (1 - tau) * old_model.classifier.data

            if not args.no_lr_sch:
                Alg_model.optimizer = lr_scheduler(Alg_model.optimizer, epoch, max_epoch)

            # accuracy of the current domain only
            for item in ['train', 'valid']:
                acc_record[item] = np.mean(np.array([modelopera.accuracy(Alg_model, eval_loaders[eval_name_dict[item][task_id]])]))

            naug_ratio = Alg_model.naug / samples_per_epoch if samples_per_epoch else 0.0
            extra_sc = {} if plabel_sc is None else plabel_sc  # pseudo-label accuracy, when available
            tepoch.set_postfix(**step_vals, **acc_record, **extra_sc, naug=naug_ratio)

            # record validation accuracy of every task along epochs
            for tid in range(len(eval_name_dict['valid'])):
                all_val_acc_record['task{}'.format(tid)].append(modelopera.accuracy(Alg_model, eval_loaders[eval_name_dict['valid'][tid]]))
            if args.dataset == 'idomain_net' and epoch == max_epoch - 1:
                print(all_val_acc_record)

    log_print('task{} training result on max_epoch{}: {} {}'.format(task_id, max_epoch, step_vals, acc_record), args.log_file, p=False)

    return Alg_model, all_val_acc_record, pseudo_dataloader


def cat_pseudo_replay(args, dataloader, replay_dataset):
    """Combine the pseudo-labelled data with the replay buffer, if one exists.

    Args:
        args: config namespace (reads ``batch_size`` and ``N_WORKERS``).
        dataloader: loader whose ``.dataset`` holds the pseudo-labelled samples.
        replay_dataset: replay-buffer dataset, or None.

    Returns:
        ``dataloader`` unchanged when ``replay_dataset`` is None. Otherwise, a
        new InfiniteDataLoader over ``ConcatDataset([dataloader.dataset,
        replay_dataset])``.
    """
    if replay_dataset is None:
        return dataloader

    # ConcatDataset indexes the pseudo-labelled samples first, then the replay
    # samples. The order in which they are drawn is decided by
    # InfiniteDataLoader's sampler (weights=None) -- presumably random; confirm.
    merged = torch.utils.data.ConcatDataset([dataloader.dataset, replay_dataset])
    return InfiniteDataLoader(
        dataset=merged,
        weights=None,
        batch_size=args.batch_size,
        num_workers=args.N_WORKERS,
    )

def balance_finetune():
    """Placeholder that is not implemented yet.

    It currently does nothing and returns ``None``. The name suggests a
    balanced fine-tuning stage is intended (TODO: confirm / implement).
    """
    return None

def op_copy(optimizer):
    """Store every param group's current learning rate under the key ``'lr0'``.

    ``lr_scheduler`` rescales from this saved base rate, so call this once
    after the optimizer is created.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).

    Returns:
        The same optimizer, mutated in place.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply the inverse-power decay ``lr = lr0 * (1 + gamma * t / T) ** -power``.

    ``t`` is ``iter_num`` and ``T`` is ``max_iter``. In this file it is called
    once per epoch with ``(epoch, max_epoch)``.

    Args:
        optimizer: object with ``param_groups``; every group must already
            contain ``'lr0'`` (see ``op_copy``).
        iter_num: current progress counter.
        max_iter: total length of the schedule (must be non-zero).
        gamma: scale applied to the progress fraction.
        power: decay exponent.

    Returns:
        The same optimizer, with each group's ``'lr'`` overwritten in place.
    """
    scale = (1 + gamma * iter_num / max_iter) ** (-power)
    for group in optimizer.param_groups:
        group.update(lr=group['lr0'] * scale)
    return optimizer
