import argparse
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import network, loss
from torch.utils.data import DataLoader
from data_list import ImageList, ImageList_idx
import random, pdb, math, copy
from tqdm import tqdm
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
from sklearn.metrics.cluster import normalized_mutual_info_score
from utils_func import compute_accuracy, get_model
from plot_tsne import plot_tsne

def op_copy(optimizer):
    """Record each parameter group's current learning rate as ``lr0``.

    ``lr_scheduler`` rescales ``lr0`` on every step, so this snapshot must be
    taken once, right after the optimizer is built.

    Args:
        optimizer: an optimizer exposing ``param_groups`` (list of dicts).

    Returns:
        The same optimizer object, mutated in place.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply a polynomial learning-rate decay and fixed SGD settings.

    Each group's ``lr`` becomes ``lr0 * (11 + gamma * iter_num / max_iter) ** -power``.
    ``weight_decay``, ``momentum`` and ``nesterov`` are also forced to
    ``1e-3``, ``0.9`` and ``True``.

    NOTE(review): the offset is the literal 11 (= 1 + gamma for the default
    gamma), not the 1 of the common ``(1 + gamma * p) ** -power`` schedule,
    so training starts at the rate that schedule would end at. Confirm this
    is intended.

    Args:
        optimizer: optimizer whose groups already carry ``lr0`` (see ``op_copy``).
        iter_num: current iteration.
        max_iter: total number of iterations.
        gamma: decay speed.
        power: decay exponent.

    Returns:
        The same optimizer object, mutated in place.
    """
    progress = gamma * iter_num / max_iter
    factor = (11 + progress) ** (-power)
    for group in optimizer.param_groups:
        group.update(
            lr=group['lr0'] * factor,
            weight_decay=1e-3,
            momentum=0.9,
            nesterov=True,
        )
    return optimizer

def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Return the training-time image transform pipeline.

    Resizes to ``resize_size`` x ``resize_size``, takes a random
    ``crop_size`` crop, applies a random horizontal flip, converts to a
    tensor and normalises with fixed per-channel mean/std (the usual
    ImageNet statistics).

    Args:
        resize_size: side length images are resized to before cropping.
        crop_size: side length of the random crop.
        alexnet: requests the mean-file normalisation used for AlexNet,
            which this script does not provide.

    Returns:
        A ``torchvision.transforms.Compose`` pipeline.

    Raises:
        NotImplementedError: if ``alexnet`` is true.
    """
    if alexnet:
        # The AlexNet branch needs a `Normalize(meanfile=...)` transform that
        # is neither defined nor imported in this module, so it could only
        # ever fail with a NameError; fail with an explicit message instead.
        raise NotImplementedError(
            "alexnet=True requires a mean-file Normalize transform "
            "('./ilsvrc_2012_mean.npy'), which is not available in this script")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Return the deterministic evaluation-time image transform pipeline.

    Resizes to ``resize_size`` x ``resize_size``, takes a centred
    ``crop_size`` crop, converts to a tensor and normalises with fixed
    per-channel mean/std (the usual ImageNet statistics).

    Args:
        resize_size: side length images are resized to before cropping.
        crop_size: side length of the centre crop.
        alexnet: requests the mean-file normalisation used for AlexNet,
            which this script does not provide.

    Returns:
        A ``torchvision.transforms.Compose`` pipeline.

    Raises:
        NotImplementedError: if ``alexnet`` is true.
    """
    if alexnet:
        # The AlexNet branch needs a `Normalize(meanfile=...)` transform that
        # is neither defined nor imported in this module, so it could only
        # ever fail with a NameError; fail with an explicit message instead.
        raise NotImplementedError(
            "alexnet=True requires a mean-file Normalize transform "
            "('./ilsvrc_2012_mean.npy'), which is not available in this script")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

def data_load(args):
    """Build the target-domain DataLoaders used during adaptation.

    Args:
        args: parsed namespace; reads ``batch_size``, ``t_dset_path``,
            ``test_dset_path``, ``sub_t`` and ``worker``.

    Returns:
        dict of DataLoaders:
          - ``"target"``: target list, training transforms, shuffled.
          - ``"test"``: test list, evaluation transforms, not shuffled,
            batch size ``3 * batch_size``.
          - ``"target_te"``: target list, evaluation transforms, shuffled.
    """
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size

    # Read the image lists inside context managers so the file handles are
    # closed right away rather than whenever the garbage collector runs.
    with open(args.t_dset_path) as f:
        txt_tar = f.readlines()
    with open(args.test_dset_path) as f:
        txt_test = f.readlines()

    # Every sample is mapped to domain id 0 (one target domain).
    tar_to_domain_id = {i: 0 for i in range(len(txt_tar))}
    test_to_domain_id = {i: 0 for i in range(len(txt_test))}

    dsets["target"] = ImageList_idx(txt_tar, idx_to_domain=tar_to_domain_id, transform=image_train(), subsample=args.sub_t)
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["test"] = ImageList_idx(txt_test, idx_to_domain=test_to_domain_id, transform=image_test(), subsample=args.sub_t)
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 3, shuffle=False, num_workers=args.worker, drop_last=False)
    # NOTE(review): uses ImageList (not ImageList_idx) with the same keyword
    # arguments; confirm ImageList accepts idx_to_domain/subsample.
    dsets["target_te"] = ImageList(txt_tar, idx_to_domain=tar_to_domain_id, transform=image_test(), subsample=args.sub_t)
    dset_loaders["target_te"] = DataLoader(dsets["target_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)

    return dset_loaders

def cal_acc(loader, netF, netB, netC, flag=False, class_num=None):
    """Evaluate ``netC(netB(netF(x)))`` on every batch of ``loader``.

    Args:
        loader: iterable of batches whose first two items are the inputs and
            the class labels.
        netF, netB, netC: backbone, bottleneck and classifier. Inputs are
            moved to the GPU with ``.cuda()``.
        flag: unused; kept for call compatibility.
        class_num: number of classes. Defaults to the classifier's output
            width, so the function does not rely on a module-level ``args``
            that only exists when this file is run as a script.

    Returns:
        ``(mean_ent, predict, mean_ent, nmi, cluster_acc)``:
          - ``mean_ent``: mean ``loss.Entropy`` of the softmax outputs divided
            by ``log(class_num)``. Assuming ``loss.Entropy`` is per-sample
            Shannon entropy in nats, this lies in [0, 1].
          - ``predict``: arg-max class per sample.
          - ``nmi``: normalized mutual information between labels and
            predictions.
          - ``cluster_acc``: ``compute_accuracy`` score multiplied by 100.
        The first element duplicates ``mean_ent``; existing callers unpack five
        values.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    output_chunks = []
    label_chunks = []
    with torch.no_grad():
        # Plain iteration: DataLoader iterators no longer provide `.next()`.
        for data in loader:
            inputs = data[0].cuda()
            labels = data[1]
            outputs = netC(netB(netF(inputs)))
            output_chunks.append(outputs.float().cpu())
            label_chunks.append(labels.float())

    if not output_chunks:
        raise ValueError("cal_acc: loader yielded no batches")

    # Concatenate once instead of growing a tensor per batch (quadratic copying).
    all_output = torch.cat(output_chunks, 0)
    all_label = torch.cat(label_chunks, 0)
    if class_num is None:
        class_num = all_output.size(1)

    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = torch.max(all_output, 1)

    # Normalise by the maximum possible entropy over `class_num` classes. The
    # number of samples is not a meaningful divisor, and log(1) == 0 would
    # divide by zero for a single sample.
    ent_norm = np.log(class_num) if class_num > 1 else 1.0
    mean_ent = torch.mean(loss.Entropy(all_output)).item() / ent_norm

    nmi = normalized_mutual_info_score(all_label, predict)
    cluster_acc = compute_accuracy(predict, all_label, class_num, class_num)
    return mean_ent, predict, mean_ent, nmi, cluster_acc * 100

def _evaluate_and_log(args, loader, netF, netB, netC, iter_num, max_iter):
    """Evaluate on ``loader`` in eval mode, log the metrics, restore train mode.

    Writes one summary line to ``args.out_file`` (flushed) and to stdout.

    Returns:
        The predictions tensor returned by ``cal_acc`` (its second element).
    """
    nets = (netF, netB, netC)
    for net in nets:
        net.eval()
    _, pry, mean_ent, nmi, cluster_acc = cal_acc(loader, netF, netB, netC, False)
    log_str = 'Task: {}, Iter:{}/{}; NMI={:.2f}, Cluster Acc={:.2f}, Ent={:.3f}'.format(
        args.name, iter_num, max_iter, nmi, cluster_acc, mean_ent)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')
    for net in nets:
        net.train()
    return pry


def train_target(args):
    """Adapt the source-trained networks to the unlabeled target domain.

    1. Loads ``source_F.pt``, ``source_B.pt`` and ``source_C.pt`` from
       ``args.output_dir``.
    2. Minimises ``lambda_im * (mean entropy - entropy of the batch-mean
       prediction) + lambda_proto * proto_loss`` over the target loader.
    3. Evaluates and logs about ten times. Stops early when the test
       predictions are identical at two consecutive evaluations.
    4. Produces a t-SNE plot via ``plot_tsne``.

    Args:
        args: parsed namespace. ``epochs`` and ``iters_per_epoch`` are also
            set on it, presumably for ``loss.ProtoMultiDomainLoss``, which
            receives ``args``.

    Returns:
        Tuple ``(netF, netB, netC)`` of the adapted networks.
    """
    dset_loaders = data_load(args)

    netF = get_model(model_name=args.net, init_type=args.target_init).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    # Start from the source-trained weights.
    for net, fname in ((netF, 'source_F.pt'), (netB, 'source_B.pt'), (netC, 'source_C.pt')):
        net.load_state_dict(torch.load(args.output_dir + '/' + fname))

    # The backbone is fine-tuned at a tenth of the bottleneck/classifier rate.
    param_group = [{'params': v, 'lr': args.lr * 0.1} for v in netF.parameters()]
    param_group += [{'params': v, 'lr': args.lr} for v in netB.parameters()]
    param_group += [{'params': v, 'lr': args.lr} for v in netC.parameters()]
    optimizer = op_copy(optim.SGD(param_group))

    iters_per_epoch = len(dset_loaders["target"])
    max_iter = args.max_epoch * iters_per_epoch
    # Evaluate roughly ten times; max(1, ...) avoids a modulo-by-zero below
    # when there are fewer than ten iterations in total.
    interval_iter = max(1, max_iter // 10)
    iter_num = 0

    _evaluate_and_log(args, dset_loaders['test'], netF, netB, netC, iter_num, max_iter)

    old_pry = None
    args.epochs = args.max_epoch
    args.iters_per_epoch = iters_per_epoch
    proto_criterion = loss.ProtoMultiDomainLoss(num_domains=1, num_classes=args.class_num, ot_reg=args.ot_reg, args=args)
    iter_test = iter(dset_loaders["target"])
    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx, domain_id = next(iter_test)
        except StopIteration:
            # Epoch exhausted: restart the (reshuffled) target loader. Only
            # StopIteration is caught so real data-loading errors surface.
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx, domain_id = next(iter_test)

        # Single-sample batches are skipped, presumably because the "bn"
        # bottleneck cannot normalise one sample in train mode.
        if inputs_test.size(0) == 1:
            continue

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=0.75)

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        # Mean per-sample entropy minus the entropy of the batch-mean
        # prediction: confident yet diverse predictions lower this term.
        softmax_out = nn.Softmax(dim=1)(outputs_test)
        entropy_loss = torch.mean(loss.Entropy(softmax_out))
        msoftmax = softmax_out.mean(dim=0)
        gentropy_loss = -torch.sum(msoftmax * torch.log(msoftmax + 1e-5))
        im_loss = entropy_loss - gentropy_loss

        # Pseudo-labels are detached so they act as fixed targets.
        pred_labels = torch.argmax(outputs_test, dim=1).detach()
        proto_loss = proto_criterion(netC.fc.weight, features_test, pred_labels, domain_id)

        total_loss = args.lambda_proto * proto_loss + args.lambda_im * im_loss
        total_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            pry = _evaluate_and_log(args, dset_loaders['test'], netF, netB, netC, iter_num, max_iter)
            # Stop once predictions are unchanged between two evaluations. A
            # None sentinel (not 0) keeps the first check from firing just
            # because every prediction happens to be class 0.
            if old_pry is not None and torch.equal(pry, old_pry):
                break
            old_pry = pry.clone()

    encoder = nn.Sequential(netF, netB)
    plot_tsne(netC, encoder, dset_loaders['test'], args)
    return netF, netB, netC

def print_args(args):
    """Format every attribute of ``args`` as ``name:value`` lines under a banner.

    Despite its name, nothing is printed; the formatted string is returned.

    Args:
        args: any object with a ``__dict__`` (typically ``argparse.Namespace``).

    Returns:
        A banner line of ``=`` followed by one ``name:value`` line per attribute.
    """
    body = "".join("{}:{}\n".format(key, value) for key, value in vars(args).items())
    return "==========================================\n" + body

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='DINE')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--t', type=int, default=0, help="target")
    parser.add_argument('--max_epoch', type=int, default=30, help="maximum number of training epochs")
    parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--dset', type=str, default='office-home', choices=['VISDA-C', 'office', 'image-clef', 'office-home', 'office-caltech', 'PACS'])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet50, resnext50")
    parser.add_argument('--net_src', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
    parser.add_argument('--seed', type=int, default=2021, help="random seed")

    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument('--output', type=str, default='san')
    parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
    parser.add_argument('--lambda_im', type=float, default=1.0)
    parser.add_argument('--lambda_proto', type=float, default=1.0)
    parser.add_argument('--target_init', type=str, default="sup", choices=["ssl", "sup"])
    parser.add_argument('--ot_reg', type=float, default=0.01)
    parser.add_argument('--sub_t', action='store_true',
                        help='whether to subsample target data')
    parser.add_argument('--beta', type=float, default=0.99)

    args = parser.parse_args()

    # Domain names and class counts for the datasets this script can run.
    # Other --dset choices are accepted by the parser but have no entry here,
    # so they are rejected up front with a clear message.
    dset_info = {
        'office-home': (['Art', 'Clipart', 'Product', 'RealWorld'], 65),
        'office': (['amazon', 'dslr', 'webcam'], 31),
        'PACS': (['art_painting', 'cartoon', 'photo', 'sketch'], 7),
    }
    if args.dset not in dset_info:
        parser.error('--dset {!r} is not supported by this script; expected one of {}'.format(
            args.dset, sorted(dset_info)))
    names, args.class_num = dset_info[args.dset]
    if not 0 <= args.t < len(names):
        parser.error('--t must be in [0, {}) for --dset {}'.format(len(names), args.dset))

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    # torch.backends.cudnn.deterministic = True

    # All domains other than the target are the sources.
    args.s = [x for x in range(len(names)) if x != args.t]
    folder = './data/'
    args.s_dset_path_list = []
    args.name_src = "".join([names[s][0].upper() for s in args.s])

    # Only the target domain (the single index not in args.s) passes the
    # filter, so this runs once.
    for i in range(len(names)):
        if i in args.s:
            continue
        args.t = i
        args.s_dset_path = folder + args.dset + '/' + names[i] + '_list.txt'
        args.s_dset_path_list.append(args.s_dset_path)
        args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
        args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'

        args.output_dir = osp.join(args.output, args.net_src + '_' + args.net, str(args.seed), args.da, args.dset, args.name_src+names[args.t][0].upper())
        args.name = args.name_src+names[args.t][0].upper()

        # os.makedirs instead of `os.system('mkdir -p ...')`: no shell is
        # involved, so a crafted --output value cannot inject commands.
        os.makedirs(args.output_dir, exist_ok=True)

        # The log file is closed once training for this target finishes.
        with open(osp.join(args.output_dir, 'log_finetune.txt'), 'w') as out_file:
            args.out_file = out_file
            args.out_file.write(print_args(args)+'\n')
            args.out_file.flush()

            train_target(args)
