import argparse
import os
import os.path as osp
from distutils.util import strtobool

import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import network
from models.build_model import build_model
import random

from util import *
from trainer import *

def _log_line(args, text):
    """Print ``text`` and append it as one line to ``args.out_file``, flushing immediately."""
    print(text)
    args.out_file.write(text + '\n')
    args.out_file.flush()


def _load_source_models(args):
    """Build one frozen feature extractor / classifier pair per source domain.

    Weights are read from ``<args.output_dir_src[i]>/source_F.pt`` and
    ``<args.output_dir_src[i]>/source_C.pt``. Every network is moved to CUDA,
    switched to eval mode and has ``requires_grad`` disabled on all parameters.

    Returns:
        ``(netF_list, netC_list)``, each with ``len(args.src)`` entries.

    Raises:
        ValueError: if ``args.net`` starts with neither ``'res'`` nor ``'vit'``.
    """
    n_src = len(args.src)
    if args.net[0:3] == 'res':
        netF_list = [network.ResBase(res_name=args.net).cuda() for _ in range(n_src)]
    elif args.net[0:3] == 'vit':
        netF_list = [network.ViTBase(vit_name=args.net).cuda() for _ in range(n_src)]
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError("unsupported source network: {!r}".format(args.net))

    netC_list = [
        network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()
        for netF in netF_list
    ]

    for i, (netF, netC) in enumerate(zip(netF_list, netC_list)):
        modelpath = args.output_dir_src[i] + '/source_F.pt'
        print(modelpath)
        netF.load_state_dict(torch.load(modelpath))
        netF.eval()

        modelpath = args.output_dir_src[i] + '/source_C.pt'
        print(modelpath)
        netC.load_state_dict(torch.load(modelpath))
        netC.eval()

        # Source models are not trained here, so freeze them.
        for param in netF.parameters():
            param.requires_grad = False
        for param in netC.parameters():
            param.requires_grad = False

    return netF_list, netC_list


def _log_initial_accuracy(args, writer, acc_s_te):
    """Report the accuracy of the source predictions before warm-up.

    ``acc_s_te`` may be a sized collection (one accuracy per source domain) or a
    single scalar; each value is written to stdout, ``args.out_file`` and
    TensorBoard at step ``-1``.
    """
    try:
        # Only len() is guarded: a scalar accuracy raises TypeError here, while
        # any other failure (or Ctrl-C) now propagates instead of being swallowed.
        n_acc = len(acc_s_te)
    except TypeError:
        n_acc = None

    if n_acc is None:
        log_str = 'Task: {}, Source prediction evaluation (Before Warmup) [Test] Accuracy = {:.2f}% '.format(args.name,  acc_s_te )
        _log_line(args, log_str)
        writer.add_scalar("Pseudo labeler/test accuracy", acc_s_te,-1)
        return

    for i in range(n_acc):
        name = '{} to {}'.format(args.src[i], args.tar[0])
        log_str = 'Task: {}, Source prediction evaluation (Before Warmup) [Test] Accuracy = {:.2f}% '.format(name, acc_s_te[i]  )
        _log_line(args, log_str)
        writer.add_scalar("Pseudo labeler/test accuracy-{}".format(args.src[i]), acc_s_te[i],-1)


def _sgd_param_groups(net, lr, trainable):
    """Return one SGD param group per parameter of ``net``, or freeze ``net``.

    When ``trainable`` is false every parameter gets ``requires_grad = False``
    and an empty list is returned.
    """
    groups = []
    for _, param in net.named_parameters():
        if trainable:
            groups.append({'params': param, 'lr': lr})
        else:
            param.requires_grad = False
    return groups


def _run_training(args, writer):
    """Run the full adaptation pipeline using an already opened ``writer``."""
    netF_list, netC_list = _load_source_models(args)

    # load data
    print("Load images and source predictions....")
    dset_loaders = data_load(args, netF_list, netC_list)

    # The source models are not referenced again below; drop them here.
    del netF_list
    del netC_list

    # set target network
    if args.tar_net[0:3] == 'res':
        target_netF = network.ResBase(res_name=args.tar_net).cuda()
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError("unsupported target network: {!r}".format(args.tar_net))

    target_netB = network.feat_bottleneck(type=args.classifier, feature_dim=target_netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    target_netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    # set pseudo network
    # NOTE(review): dr_rate is fed from args.lr_decay1 -- confirm this is intended.
    prn = build_model(device=torch.device('cuda:0'), d_model=args.class_num,
                      n_layer=args.pseudo_layer, dr_rate=args.lr_decay1,
                      residual_block=args.residual_block)

    # Add network graph to tensorboard
    iter_test_random = iter(dset_loaders["target"])
    # Built-in next() instead of the Python-2 style ``.next()`` method, which
    # current iterator objects do not provide.
    inputs_test, logits_test, preds_test = next(iter_test_random)
    output = prn(preds_test.cuda())

    # NOTE(review): the graph is traced with prn's *output* as its example
    # input -- confirm whether preds_test was meant here.
    writer.add_graph(prn, output.cuda())

    # Estimate quality of initial source predictions
    acc_s_te = cal_acc_pseudo(args, dset_loaders['test'], False, False, None)
    before_warmup_acc = acc_s_te
    _log_initial_accuracy(args, writer, acc_s_te)

    # target model parameters: the feature extractor uses a scaled learning
    # rate gated by lr_decay1; bottleneck and classifier are gated by lr_decay2.
    param_group = []
    param_group += _sgd_param_groups(target_netF, args.lr * args.lr_decay1, args.lr_decay1 > 0)
    param_group += _sgd_param_groups(target_netB, args.lr, args.lr_decay2 > 0)
    param_group += _sgd_param_groups(target_netC, args.lr, args.lr_decay2 > 0)

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer, args)

    # pseudo model parameters
    pseudo_param_group = _sgd_param_groups(prn, args.pseudo_lr, args.lr_decay2 > 0)
    pseudo_optimizer = optim.SGD(pseudo_param_group)
    pseudo_optimizer = op_copy(pseudo_optimizer, args)

    # pseudo net warm up step
    after_warmup_acc = warmup_prn(args, writer, pseudo_optimizer, dset_loaders, prn, target_netF, target_netB, target_netC, before_warmup_acc)

    # target training
    print("Train target model...")
    train_target(args, writer, optimizer, pseudo_optimizer, dset_loaders, prn, target_netF, target_netB, target_netC, after_warmup_acc)

    if args.issave:
        for prefix, net in (("target_F_", target_netF),
                            ("target_B_", target_netB),
                            ("target_C_", target_netC),
                            ("pseudo_", prn)):
            torch.save(net.state_dict(), osp.join(args.output_dir, prefix + args.savename + "_final.pt"))

    return target_netF, target_netB, target_netC, prn


def train(args):
    """Train the target model and the pseudo-label network from frozen source models.

    Steps, in order: load the frozen source networks, build the data loaders
    from them, create the target networks and the pseudo network, log the
    initial source-prediction accuracy, run the pseudo-network warm-up, run
    target training, and (when ``args.issave`` is true) save the final weights
    under ``args.output_dir``.

    Args:
        args: namespace carrying the run configuration. Attributes read here
            include ``output_dir``, ``output_dir_src``, ``src``, ``tar``,
            ``net``, ``tar_net``, ``class_num``, ``bottleneck``, ``classifier``,
            ``layer``, the learning-rate settings, ``out_file`` (an open,
            writable log file), ``issave`` and ``savename``.

    Returns:
        Tuple ``(target_netF, target_netB, target_netC, prn)``.

    Raises:
        ValueError: if ``args.net`` or ``args.tar_net`` names an unsupported
            architecture.
    """
    writer = SummaryWriter(args.output_dir)
    try:
        return _run_training(args, writer)
    finally:
        # Flush and release the event file even when training fails.
        writer.close()

def print_args(args):
    """Return a banner line followed by one ``name:value`` line per attribute of ``args``.

    Attributes are taken from the instance ``__dict__`` in insertion order and
    every line, including the banner, ends with a newline.
    """
    banner = "==========================================\n"
    entries = ["{}:{}\n".format(key, value) for key, value in vars(args).items()]
    return banner + "".join(entries)

# Script entry point: parse the command line, derive the source/target domain
# split and all input/output paths, seed the RNGs, open the log file and run
# train() once for the selected target domain.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='SHOT')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--seed', type=int, default=2020, help="random seed")

    parser.add_argument('--s', type=int, default=0, help="source")
    parser.add_argument('--t', type=int, default=0, help="target")
    parser.add_argument('--dset', type=str, default='office', choices=[ 'office', 'office-home', 'office-caltech', 'domainnet'])
    parser.add_argument('--net', type=str, default='resnet101', choices = ["resnet101", "vit16"], help="source network")
    parser.add_argument('--tar_net', type=str, default='resnet101', choices=['resnet101'], help="target network")

    parser.add_argument('--max_epoch', type=int, default=2, help="max iterations")
    parser.add_argument('--warmup_epoch', type=int, default=0, help="warmup iterations")
    parser.add_argument('--refine_epoch', type=int, default=1, help="refine iterations")
    parser.add_argument('--interval', type=int, default=50)
    parser.add_argument('--pseudo_train_epoch_iter', type=int, default=1)

    parser.add_argument('--batch_size', type=int, default=16, help="batch_size")
    parser.add_argument('--lr', type=float, default=1e-3, help="learning rate")
    parser.add_argument('--pseudo_lr', type=float, default=1e-3, help="pseudo labeler learning rate")
    parser.add_argument('--warmup_lr', type=float, default=1e-1, help="pseudo labeler learning rate")
    parser.add_argument('--lr_decay1', type=float, default=0.1)
    parser.add_argument('--lr_decay2', type=float, default=0.1)
    parser.add_argument('--gamma', type=float, default=1.0)

    # Help text previously said "batch_size" (copy/paste slip).
    parser.add_argument('--pseudo_layer', type=int, default=1, help="number of pseudo labeler layers")
    parser.add_argument('--cross_attention', type=lambda x:bool(strtobool(x)), nargs='?', const=True, default=True)
    parser.add_argument('--residual_block', type=lambda x:bool(strtobool(x)), nargs='?', const=True, default=True)

    parser.add_argument('--lambda_u', type=float, default=1.0)
    parser.add_argument('--lambda_d', type=float, default=1.0)
    parser.add_argument('--lambda_s', type=float, default=1.0)

    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--epsilon', type=float, default=1e-5)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])

    parser.add_argument('--output', type=str, default='./ckps/tar')
    # NOTE(review): this default lacks the leading "./" that --output has
    # ('.ckps/src/' vs './ckps/tar'); confirm whether that is intended.
    parser.add_argument('--output_src', type=str, default='.ckps/src/')
    parser.add_argument('--data_dir', type=str, default='./data/')

    # type=bool would turn any non-empty string (including "False") into True,
    # so parse the value the same way as the other boolean flags.
    parser.add_argument('--issave', type=lambda x:bool(strtobool(x)), nargs='?', const=True, default=True)

    args = parser.parse_args()

    # Per-dataset domain names, number of classes and number of source domains.
    dataset_specs = {
        'office-home': (['Art', 'Clipart', 'Product', 'Real_World'], 65, 3),
        'office': (['amazon', 'dslr', 'webcam'], 31, 2),
        'office-caltech': (['amazon', 'caltech', 'dslr', 'webcam'], 10, 3),
        'domainnet': (['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'], 345, 5),
    }
    names, args.class_num, args.src_num = dataset_specs[args.dset]

    # An out-of-range target index used to make the script exit silently
    # without training; report it as a usage error instead.
    if not 0 <= args.t < len(names):
        parser.error("--t must be in [0, {}] for dataset '{}'".format(len(names) - 1, args.dset))

    # Every domain except the target one acts as a source.
    args.src = [domain for idx, domain in enumerate(names) if idx != args.t]
    args.tar = [names[args.t]]
    target_initial = names[args.t][0].upper()

    print(len(args.src))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # The same list file is used for adaptation and for evaluation.
    args.t_dset_path = os.path.join(args.data_dir, args.dset , names[args.t] + '_list.txt')
    args.test_dset_path =  os.path.join(args.data_dir, args.dset, names[args.t] + '_list.txt')
    print(args.t_dset_path)

    # multi source: one checkpoint directory per source, keyed by its initial letter
    args.output_dir_src = [
        osp.join(args.output_src, args.dset, domain[0].upper()) for domain in args.src
    ]
    print(args.output_dir_src)

    print_name = 'pseudo_layer_' +str(args.pseudo_layer) +\
                 '_pseudo_refine_'+ str(args.pseudo_train_epoch_iter) +\
                 '_warmup_' +str(args.warmup_epoch) + '_refine_' +str(args.refine_epoch) +\
                 'lambda_u_' +str(args.lambda_u) +'lambda_d_' +str(args.lambda_d)+'lambda_s_' +str(args.lambda_s)

    args.output_dir = osp.join(args.output, args.dset, target_initial, str(args.seed), print_name )
    print(args.output_dir)

    args.name = ''.join(domain[0].upper() for domain in args.src) + ' to ' + target_initial

    # Create the output directory without going through a shell.
    os.makedirs(args.output_dir, exist_ok=True)

    # The context manager closes the log file even if training raises.
    with open(osp.join(args.output_dir, 'log.txt'), 'w') as log_file:
        args.out_file = log_file
        args.out_file.write(print_args(args)+'\n')
        args.out_file.flush()

        train(args)