import argparse
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import random, pdb, math, copy
from tqdm import tqdm
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
from ofa.model_zoo import proxylessnas_net
from ofa.utils import replace_conv2d_with_my_conv2d, replace_bn_with_gn
from ofa.utils import init_models
from .utils import KL_Loss

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "./../")))
from model import LiteResidualModule, build_network_from_config
from utils import set_module_grad_status, enable_bn_update, enable_bias_update
import model.network as network
import utils.loss as loss


class DatasetIndex(Dataset):
    """Dataset wrapper that returns each sample together with its index.

    The wrapped dataset must yield ``(data, target)`` pairs; items of this
    wrapper are ``(data, target, index)`` triples.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        # Same length as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        return sample, label, index

def op_copy(optimizer):
    """Remember every param group's current learning rate under ``'lr0'``.

    The optimizer is modified in place and returned for convenience.
    """
    for group in optimizer.param_groups:
        initial_lr = group['lr']
        group['lr0'] = initial_lr
    return optimizer


def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply a polynomial learning-rate decay to every param group in place.

    Each group's ``'lr'`` becomes ``lr0 * (1 + gamma * iter_num / max_iter) ** -power``,
    where ``'lr0'`` is the base rate already stored in the group. The groups'
    ``weight_decay``, ``momentum`` and ``nesterov`` entries are overwritten
    with fixed values on every call. Returns the same optimizer object.
    """
    scale = (1 + gamma * iter_num / max_iter) ** (-power)
    for group in optimizer.param_groups:
        group.update({
            'lr': group['lr0'] * scale,
            'weight_decay': 1e-3,
            'momentum': 0.9,
            'nesterov': True,
        })
    return optimizer


def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Build the training-time image transform.

    The image is resized to ``resize_size`` x ``resize_size``, randomly
    cropped to ``crop_size``, randomly flipped horizontally, converted to a
    tensor and normalised.
    """
    if alexnet:
        # NOTE(review): `Normalize` is neither defined nor imported in the
        # visible part of this file, so this branch looks like it raises
        # NameError unless the name is provided elsewhere -- confirm.
        normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    steps = [
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Build the evaluation-time image transform.

    The image is resized to ``resize_size`` x ``resize_size``, centre-cropped
    to ``crop_size``, converted to a tensor and normalised. No random
    augmentation is applied.
    """
    if alexnet:
        # NOTE(review): `Normalize` is neither defined nor imported in the
        # visible part of this file, so this branch looks like it raises
        # NameError unless the name is provided elsewhere -- confirm.
        normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    steps = [
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def data_load(args):
    """Create the data loaders for the target domain.

    Reads ``args.root_path``, ``args.tar``, ``args.batch_size`` and
    ``args.worker``. Both loaders read the same image folder:

    * ``"target"``: shuffled, training transform, samples wrapped in
      ``DatasetIndex`` so each item also carries its index.
    * ``"test"``: unshuffled, evaluation transform, three times the batch size.

    Returns a dict mapping those two names to their ``DataLoader``.
    """
    image_dir = os.path.join(args.root_path, args.tar)
    batch_size = args.batch_size
    loaders = {}

    # Training view of the target data (augmented, index-aware).
    indexed_train_set = DatasetIndex(datasets.ImageFolder(image_dir, transform=image_train()))
    loaders["target"] = DataLoader(
        indexed_train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=args.worker,
        drop_last=False,
    )

    # Deterministic view of the same folder, used for evaluation / labelling.
    eval_set = datasets.ImageFolder(image_dir, transform=image_test())
    loaders["test"] = DataLoader(
        eval_set,
        batch_size=batch_size * 3,
        shuffle=False,
        num_workers=args.worker,
        drop_last=False,
    )

    return loaders

def cal_acc(loader, netF, netB, netC, flag=False, args=None):
    """Evaluate the ``netC(netB(netF(x)))`` pipeline on every batch of ``loader``.

    Args:
        loader: iterable of batches whose first two entries are inputs and labels.
        netF, netB, netC: feature extractor, bottleneck and classifier, applied
            in that order. Inputs are moved to CUDA before the forward pass.
        flag: when true, return per-class statistics instead of overall accuracy.
        args: optional namespace; if it has a ``class_savename`` attribute, the
            (prediction, label) pairs are saved to that path with ``np.save``.

    Returns:
        ``(mean_per_class_accuracy, per_class_accuracy_string)`` when ``flag``
        is true, otherwise ``(accuracy_percent, mean_entropy)``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    output_chunks = []
    label_chunks = []
    with torch.no_grad():
        # Iterate the loader directly: DataLoader iterators only implement the
        # standard iterator protocol (`__next__`), not a `.next()` method.
        for data in loader:
            inputs = data[0].cuda()
            labels = data[1]
            outputs = netC(netB(netF(inputs)))
            output_chunks.append(outputs.float().cpu())
            label_chunks.append(labels.float())

    if not output_chunks:
        raise ValueError("cal_acc() received an empty loader")

    # Concatenate once instead of re-allocating the running tensor per batch.
    all_output = torch.cat(output_chunks, 0)
    all_label = torch.cat(label_chunks, 0)

    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    # `loss.Entropy` comes from the project's utils.loss module; presumably it
    # returns a per-sample entropy -- verify against that module.
    mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()

    if hasattr(args, 'class_savename'):
        # Persist an (N, 2) array of [prediction, label] rows.
        np.save(args.class_savename,
                np.array([torch.squeeze(predict).float().cpu().numpy(), all_label.cpu().numpy()]).T)

    if flag:
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        per_class = matrix.diagonal() / matrix.sum(axis=1) * 100
        mean_per_class = per_class.mean()
        per_class_str = ' '.join(str(np.round(value, 2)) for value in per_class)
        return mean_per_class, per_class_str

    return accuracy * 100, mean_ent
def train_target(args, netF, netB, netC, optimizer, train_data_loader, test_data_loader, global_model):
    """Adapt ``netF``/``netB`` on target data for ``args.DA_max_epoch`` epochs.

    The networks and the optimizer are expected to be fully prepared by the
    caller. ``lr_scheduler`` reads ``'lr0'`` from every param group, so the
    optimizer must already carry that key (see ``op_copy``).

    The per-step loss is the sum of:

    * cross-entropy against pseudo labels from ``obtain_label`` scaled by
      ``args.cls_par`` (only when ``args.cls_par > 0``; labels are refreshed
      every ``interval_iter`` steps),
    * when ``args.ent`` is set, the mean prediction entropy (minus the entropy
      of the batch-mean prediction if ``args.gent`` is set) scaled by
      ``args.ent_par``,
    * when ``args.is_feedback`` is set, ``args.kd_ratio`` times a ``KL_Loss``
      between the local logits and ``global_model``'s logits.

    Args:
        args: namespace providing ``DA_max_epoch``, ``cls_par``, ``dset``,
            ``ent``, ``gent``, ``epsilon``, ``ent_par``, ``is_feedback`` and,
            when feedback is enabled, ``temperature`` and ``kd_ratio``;
            ``interval_const`` is optional.
        netF, netB, netC: feature extractor, bottleneck, classifier.
        optimizer: optimizer stepping the trainable parameters.
        train_data_loader: loader yielding ``(inputs, labels, index)`` batches.
        test_data_loader: loader passed to ``obtain_label`` for pseudo labels.
        global_model: teacher model, used only when ``args.is_feedback`` is set.

    Returns:
        The tuple ``(netF, netB, netC)``.

    Raises:
        RuntimeError: if a full epoch's worth of consecutive batches all have
            size 1, since no training step could ever be taken.
    """
    netF.train()
    netB.train()

    batches_per_epoch = len(train_data_loader)
    max_iter = args.DA_max_epoch * batches_per_epoch
    if hasattr(args, 'interval_const'):
        interval_iter = batches_per_epoch * args.interval_const
    else:
        interval_iter = batches_per_epoch
    iter_num = 0

    # Loss modules are loop-invariant, so build them once.
    ce_criterion = nn.CrossEntropyLoss()
    criterion_KL = KL_Loss(args.temperature) if args.is_feedback else None

    iter_train = None  # created lazily so nothing is spawned when max_iter == 0
    skipped_in_a_row = 0

    while iter_num < max_iter:
        # Pull the next batch; restart the loader when an epoch is exhausted.
        # Only exhaustion triggers a restart -- other errors propagate.
        batch = next(iter_train, None) if iter_train is not None else None
        if batch is None:
            iter_train = iter(train_data_loader)
            batch = next(iter_train)
        inputs_test, _, tar_idx = batch

        # Single-sample batches are skipped (presumably because batch-norm
        # layers cannot be trained on one sample -- TODO confirm).
        if inputs_test.size(0) == 1:
            skipped_in_a_row += 1
            if skipped_in_a_row >= batches_per_epoch:
                raise RuntimeError(
                    "train_target(): every batch has a single sample; "
                    "no training step can be taken")
            continue
        skipped_in_a_row = 0

        # Periodically refresh the pseudo labels with the networks in eval mode.
        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(test_data_loader, netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            # Pseudo labels are looked up by the dataset index of each sample.
            pred = mem_label[tar_idx]
            classifier_loss = ce_criterion(outputs_test, pred)
            classifier_loss *= args.cls_par
            # On VISDA-C the pseudo-label loss is disabled until the first
            # interval has elapsed.
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                # Subtract the entropy of the batch-mean prediction.
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        if args.is_feedback:
            # NOTE(review): the teacher forward pass runs with autograd enabled;
            # unless KL_Loss detaches its target, gradients also flow into
            # global_model. Consider torch.no_grad() here -- confirm intent.
            global_model.eval()
            global_logits = global_model(inputs_test)
            kd_loss = criterion_KL(outputs_test, global_logits)

            classifier_loss += args.kd_ratio * kd_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

    return netF, netB, netC


def print_args(args):
    """Render every attribute of ``args`` as ``name:value`` lines under a banner.

    Returns the text as a single string; nothing is printed.
    """
    banner = "==========================================\n"
    rows = ["{}:{}\n".format(name, value) for name, value in vars(args).items()]
    return banner + "".join(rows)


def obtain_label(loader, netF, netB, netC, args):
    """Compute centroid-based pseudo labels for every sample in ``loader``.

    Steps:
      1. Run ``netC(netB(netF(x)))`` over the loader, collecting bottleneck
         features, logits and ground-truth labels.
      2. Build one centroid per class as the softmax-weighted mean feature.
      3. Assign each sample to the nearest centroid among classes whose
         predicted count exceeds ``args.threshold``.
      4. Recompute centroids from those hard assignments and reassign once.

    The accuracy before and after relabelling is written to ``args.out_file``
    and printed.

    Args:
        loader: iterable of batches whose first two entries are inputs and
            labels; must be iterated in dataset order for the returned labels
            to line up with dataset indices.
        netF, netB, netC: feature extractor, bottleneck and classifier.
        args: namespace providing ``distance`` (``'cosine'`` or a metric name
            accepted by ``scipy.spatial.distance.cdist``), ``threshold`` and
            ``out_file``.

    Returns:
        Integer numpy array with one pseudo label per sample.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    fea_chunks = []
    output_chunks = []
    label_chunks = []
    with torch.no_grad():
        # Iterate the loader directly: DataLoader iterators only implement the
        # standard iterator protocol (`__next__`), not a `.next()` method.
        for data in loader:
            inputs = data[0].cuda()
            labels = data[1]
            feas = netB(netF(inputs))
            outputs = netC(feas)
            fea_chunks.append(feas.float().cpu())
            output_chunks.append(outputs.float().cpu())
            label_chunks.append(labels.float())

    if not fea_chunks:
        raise ValueError("obtain_label() received an empty loader")

    # Concatenate once instead of re-allocating the running tensors per batch.
    all_fea = torch.cat(fea_chunks, 0)
    all_output = nn.Softmax(dim=1)(torch.cat(output_chunks, 0))
    all_label = torch.cat(label_chunks, 0)

    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])

    if args.distance == 'cosine':
        # Append a constant 1 to every feature, then L2-normalise each row.
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()

    all_fea = all_fea.float().cpu().numpy()
    K = all_output.size(1)

    # Soft centroids: features averaged with the softmax probabilities as weights.
    aff = all_output.float().cpu().numpy()
    initc = aff.transpose().dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])

    # Only classes predicted more than `threshold` times are candidate labels.
    cls_count = np.eye(K)[predict.numpy()].sum(axis=0)
    labelset = np.where(cls_count > args.threshold)[0]

    dd = cdist(all_fea, initc[labelset], args.distance)
    pred_label = labelset[dd.argmin(axis=1)]

    # One refinement round using hard (one-hot) assignments.
    for _ in range(1):
        aff = np.eye(K)[pred_label]
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        dd = cdist(all_fea, initc[labelset], args.distance)
        pred_label = labelset[dd.argmin(axis=1)]

    acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
    log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)

    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    return pred_label.astype('int')


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into ``True``; this converter accepts the usual spellings instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='SHOT')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--s', type=int, default=0, help="source")
    parser.add_argument('--t', type=int, default=1, help="target")
    parser.add_argument('--root_path', type=str, default='/zxc/DA_test/data/')
    parser.add_argument('--src', type=str, default='amazon')
    parser.add_argument('--tar', type=str, default='webcam')
    parser.add_argument('--max_epoch', type=int, default=30, help="max iterations")
    parser.add_argument('--interval', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=8, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--dset', type=str, default='office',
                        choices=['VISDA-C', 'office', 'office-home', 'office-caltech'])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net', type=str, default='proxyless_mobile', help="alexnet, vgg16, resnet50, res101, proxyless_mobile")
    parser.add_argument('--seed', type=int, default=2020, help="random seed")

    parser.add_argument('--gent', type=_str2bool, default=True)
    parser.add_argument('--ent', type=_str2bool, default=True)
    parser.add_argument('--threshold', type=int, default=0)
    parser.add_argument('--cls_par', type=float, default=0.1)
    parser.add_argument('--ent_par', type=float, default=1.0)
    parser.add_argument('--lr_decay1', type=float, default=0.1)
    parser.add_argument('--lr_decay2', type=float, default=1.0)

    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--epsilon', type=float, default=1e-5)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
    parser.add_argument('--output', type=str, default='ckps/target/')
    parser.add_argument('--output_src', type=str, default='ckps/source/')
    parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
    parser.add_argument('--issave', type=_str2bool, default=True)
    # Lite residual module configs.
    parser.add_argument('--lite_residual_downsample', type=int, default=2)
    parser.add_argument('--lite_residual_expand', type=int, default=1)
    parser.add_argument('--lite_residual_groups', type=int, default=2)
    parser.add_argument('--lite_residual_ks', type=int, default=5)
    parser.add_argument('--random_init_lite_residual', action='store_true', default=True)
    args = parser.parse_args()

    # Domain names and class count for each supported benchmark.
    dataset_table = {
        'office-home': (['Art', 'Clipart', 'Product', 'RealWorld'], 65),
        'office': (['amazon', 'dslr', 'webcam'], 31),
        'VISDA-C': (['train', 'validation'], 12),
        'office-caltech': (['amazon', 'caltech', 'dslr', 'webcam'], 10),
    }
    names, args.class_num = dataset_table[args.dset]

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    # torch.backends.cudnn.deterministic = True

    # Adapt from the chosen source domain to every other domain in turn.
    for i in range(len(names)):
        if i == args.s:
            continue
        args.t = i

        folder = './data/'
        args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
        args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
        args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'

        if args.dset == 'office-home' and args.da == 'pda':
            args.class_num = 65
            args.src_classes = list(range(65))
            args.tar_classes = list(range(25))

        args.output_dir_src = osp.join(args.output_src, args.da, args.dset, args.net, names[args.s][0].upper())
        args.output_dir = osp.join(args.output, args.da, args.dset, args.net, names[args.s][0].upper() + names[args.t][0].upper())
        args.name = names[args.s][0].upper() + names[args.t][0].upper()

        # Create the directory tree without going through a shell.
        os.makedirs(args.output_dir, exist_ok=True)

        args.savename = 'par_' + str(args.cls_par)
        if args.da == 'pda':
            args.gent = ''
            args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold)

        # The log file is closed when this task finishes, even on error.
        with open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') as out_file:
            args.out_file = out_file
            args.out_file.write(print_args(args) + '\n')
            args.out_file.flush()
            # NOTE(review): train_target() is declared with eight parameters
            # (args, netF, netB, netC, optimizer, train_data_loader,
            # test_data_loader, global_model), so this one-argument call raises
            # TypeError. The networks, optimizer and loaders must be built
            # here, and `DA_max_epoch` / `is_feedback` are not defined by this
            # parser -- confirm the intended setup before using this driver.
            train_target(args)