#coding=utf-8

import os,sys
import time
import numpy as np
import torch
# import torch.multiprocessing
# torch.multiprocessing.set_sharing_strategy('file_system')
import json
import collections

from alg.opt import *
from alg import alg,modelopera
from utils.util import set_random_seed,get_args,save_checkpoint,print_row,print_args,train_valid_target_eval_names,alg_loss_dict,MyEncoder,has_been_trained,print_environ
from datautil.getdataloader_single import get_act_dataloader

if __name__ == '__main__':
    # Training driver. Each round runs three phases over `train_loader`
    # (phase names are taken from the progress messages printed below):
    #   1. "obtain all features"       -> algorithm.update_a with the TDBALL optimizer
    #   2. "domain splitting training" -> algorithm.update_d with the TDBADV optimizer,
    #      followed by algorithm.set_dlabel(train_loader)
    #   3. "MixDANN class training"    -> algorithm.update with the TDBCLS optimizer,
    #      evaluating train/valid/target accuracy after every local epoch and
    #      checkpointing the model with the best validation accuracy.
    # Results are appended to <output>/results.jsonl and a summary is written to
    # <output>/newdone when training completes.

    def _run_phase(loader, n_epochs, update_fn, optimizer, scheduler, loss_keys):
        """Run `n_epochs` passes of `update_fn` over `loader`.

        After each pass, print the losses named in `loss_keys`. They are read
        from the dict returned by the LAST batch of that pass.
        """
        print_row(['epoch'] + [key + '_loss' for key in loss_keys], colwidth=15)
        for step in range(n_epochs):
            for data in loader:
                loss_result_dict = update_fn(data, optimizer, scheduler)
            print_row([step] + [loss_result_dict[key] for key in loss_keys], colwidth=15)

    args = get_args()

    s = print_args(args, [])
    set_random_seed(args.seed)

    # Skip everything if has_been_trained() reports a finished run in args.output.
    if not has_been_trained(args.output):
        print_environ()
        print(s)
        # The batch size is derived from the number of latent domains
        # (it overrides whatever batch_size was passed in).
        if args.latent_domain_num < 6:
            args.batch_size = 32 * args.latent_domain_num
        else:
            args.batch_size = 16 * args.latent_domain_num

        if not args.task.startswith('cross'):
            # Only 'cross*' tasks have a data-loading path here. Without this
            # check, the loaders below would be undefined and a NameError would
            # surface much later.
            raise ValueError('unsupported task %r: only tasks starting with '
                             '"cross" are handled by this script' % args.task)
        (train_loader, train_loader_noshuffle, valid_loader, target_loader,
         traindata, validdata, testdata) = get_act_dataloader(args)

        algorithm_class = alg.get_algorithm_class(args.algorithm)
        algorithm = algorithm_class(args).cuda()
        algorithm.train()
        # One optimizer/scheduler pair per parameter group selected by `nettype`.
        optd = get_optimizer(algorithm, args, nettype='TDBADV')
        schd = get_scheduler(optd, args)
        opt = get_optimizer(algorithm, args, nettype='TDBCLS')
        sch = get_scheduler(opt, args)
        opta = get_optimizer(algorithm, args, nettype='TDBALL')
        # Each scheduler must wrap its own optimizer. This was previously built
        # on `opt` (TDBCLS), so update_a received a scheduler that was not
        # bound to `opta`.
        scha = get_scheduler(opta, args)

        # Best-model tracking spans ALL rounds, so 'modelbest.pkl' and the final
        # summary refer to the globally best validation accuracy rather than
        # only the best accuracy of the last round.
        best_valid_acc, target_acc = 0, 0
        epochs_path = os.path.join(args.output, 'results.jsonl')
        train_start = time.time()

        for round_idx in range(args.max_epoch):
            print('====round %d=====' % round_idx)
            print('====start obtain all features====')
            _run_phase(train_loader, args.local_epoch, algorithm.update_a,
                       opta, scha, ['class'])

            print('====start domain splitting training====')
            _run_phase(train_loader, args.local_epoch, algorithm.update_d,
                       optd, schd, ['total', 'dis', 'ent'])

            algorithm.set_dlabel(train_loader)

            print('====start MixDANN class training====')

            loss_list = alg_loss_dict(args)
            eval_dict = train_valid_target_eval_names(args)
            print_key = ['epoch']
            print_key.extend([item + '_loss' for item in loss_list])
            print_key.extend([item + '_acc' for item in eval_dict.keys()])
            print_key.append('total_cost_time')
            print_row(print_key, colwidth=15)

            # Per-round timer, reported in each row as 'total_cost_time'.
            round_start = time.time()
            for step in range(args.local_epoch):
                for data in train_loader:
                    step_vals = algorithm.update(data, opt, sch)

                results = {'epoch': step}
                results['train_acc'] = modelopera.accuracy(algorithm, train_loader_noshuffle, None)
                results['valid_acc'] = modelopera.accuracy(algorithm, valid_loader, None)
                results['target_acc'] = modelopera.accuracy(algorithm, target_loader, None)
                # Losses come from the last batch of this local epoch.
                for key in loss_list:
                    results[key + '_loss'] = step_vals[key]

                if results['valid_acc'] > best_valid_acc:
                    best_valid_acc = results['valid_acc']
                    target_acc = results['target_acc']
                    save_checkpoint('modelbest.pkl', algorithm, args, opt, sch, optd, schd)
                    # NOTE(review): the model is moved back to the GPU after
                    # saving, presumably because save_checkpoint may move it
                    # off-device — confirm in utils.util.
                    algorithm.cuda()
                results['total_cost_time'] = time.time() - round_start
                print_row([results[key] for key in print_key], colwidth=15)

                results['args'] = vars(args)
                with open(epochs_path, 'a') as f:
                    f.write(json.dumps(results, sort_keys=True, cls=MyEncoder) + "\n")

        save_checkpoint('model.pkl', algorithm, args, opt, sch, optd, schd)
        print('target acc:%.4f' % target_acc)
        with open(os.path.join(args.output, 'newdone'), 'w') as f:
            f.write('done\n')
            # Wall-clock time of the whole training loop (all rounds).
            f.write('total cost time:%s\n' % (str(time.time() - train_start)))
            f.write('target acc:%.4f\n' % (target_acc))
            f.write('valid acc:%.4f' % (best_valid_acc))