import argparse
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import network, loss
from torch.utils.data import DataLoader
from data_list import ImageList, ImageList_idx
import random, pdb, math, copy
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F

from randaugment import rand_augment_transform
from utils import *

def Entropy(input_, epsilon=1e-5):
    """Return the per-row Shannon entropy of a batch of probability vectors.

    Args:
        input_: tensor whose dim 1 holds probabilities (e.g. a softmax output).
        epsilon: small constant added inside the log so that zero
            probabilities do not yield ``-inf``/NaN. Defaults to the value
            that used to be hard-coded (1e-5).

    Returns:
        ``-sum(p * log(p + epsilon))`` reduced over dim 1.
    """
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)


def op_copy(optimizer):
    """Remember each param group's starting learning rate under ``"lr0"``.

    ``lr_scheduler`` rescales from this stored value, so it must be called
    once right after the optimizer is built. The optimizer is modified in
    place and also returned for convenience.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group["lr"])
    return optimizer


def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply inverse-power LR decay and fixed SGD settings to every param group.

    Each group's ``lr`` becomes ``lr0 * (1 + gamma * iter_num / max_iter) ** -power``,
    where ``lr0`` was stored by ``op_copy``. Every group is also (re)set to
    weight_decay=1e-3, momentum=0.9 and nesterov=True. The optimizer is
    modified in place and returned.
    """
    factor = (1 + gamma * iter_num / max_iter) ** (-power)
    fixed_settings = {"weight_decay": 1e-3, "momentum": 0.9, "nesterov": True}
    for group in optimizer.param_groups:
        group["lr"] = group["lr0"] * factor
        group.update(fixed_settings)
    return optimizer


def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Build the training-time transform for (source) images.

    Pipeline: square resize to ``resize_size`` -> random ``crop_size`` crop ->
    random horizontal flip -> tensor -> normalization.

    Args:
        resize_size: side length of the square resize.
        crop_size: side length of the random crop.
        alexnet: if true, normalize with ``Normalize(meanfile=...)`` (not
            defined in this file; presumably provided by ``utils``) instead of
            the ImageNet mean/std constants.
    """
    if alexnet:
        normalize = Normalize(meanfile="./ilsvrc_2012_mean.npy")
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    steps = [
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def image_target(resize_size=256, crop_size=224):
    """Build the pair of training transforms applied to each target image.

    Both pipelines do: square resize to ``resize_size`` -> random
    ``crop_size`` crop -> random horizontal flip -> ... -> tensor ->
    normalization with the ImageNet mean/std constants. The second, strongly
    augmented pipeline additionally applies ``ColorJitter(0.4, 0.4, 0.4, 0.0)``
    (through ``RandomApply`` with its default probability) and
    ``rand_augment_transform('rand-n2-m10-mstd0.5', ...)`` before ``ToTensor``.

    Args:
        resize_size: side length of the square resize.
        crop_size: side length of the random crop. It also scales
            RandAugment's translate constant (45% of the crop side).

    Returns:
        Tuple ``(plain_transform, strong_transform)`` of ``transforms.Compose``.
    """
    rgb_mean = (0.485, 0.456, 0.406)
    normalize = transforms.Normalize(mean=list(rgb_mean), std=[0.229, 0.224, 0.225])
    ra_params = dict(
        # Was hard-coded to int(224 * 0.45); follow the real crop size so the
        # translation magnitude stays proportional for non-default crops.
        translate_const=int(crop_size * 0.45),
        # Normalization mean in 0-255 pixel units (presumably used by the
        # RandAugment ops as a fill colour -- confirm in randaugment.py).
        img_mean=tuple(min(255, round(255 * x)) for x in rgb_mean),
    )

    def _geometric():
        # Fresh instances for each pipeline; `normalize` is shared, as before.
        return [
            transforms.Resize((resize_size, resize_size)),
            transforms.RandomCrop(crop_size),
            transforms.RandomHorizontalFlip(),
        ]

    plain = transforms.Compose(_geometric() + [transforms.ToTensor(), normalize])
    strong = transforms.Compose(
        _geometric()
        + [
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.0)]),
            rand_augment_transform("rand-n{}-m{}-mstd0.5".format(2, 10), ra_params),
            transforms.ToTensor(),
            normalize,
        ]
    )
    return plain, strong


def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Build the deterministic evaluation transform.

    Pipeline: square resize to ``resize_size`` -> center ``crop_size`` crop ->
    tensor -> normalization (no random augmentation).

    Args:
        resize_size: side length of the square resize.
        crop_size: side length of the center crop.
        alexnet: if true, normalize with ``Normalize(meanfile=...)`` (not
            defined in this file; presumably provided by ``utils``) instead of
            the ImageNet mean/std constants.
    """
    if alexnet:
        normalize = Normalize(meanfile="./ilsvrc_2012_mean.npy")
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    steps = [
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def _read_lines(path):
    """Return every line of the text file at ``path`` (newlines kept), closing it."""
    with open(path) as fh:
        return fh.readlines()


def data_load(args):
    """Build the DataLoaders used for source-free adaptation.

    Reads the image-list files ``args.s_dset_path``, ``args.t_dset_path`` and
    ``args.test_dset_path`` and returns a dict of DataLoaders:

    * ``"source_tr"``: all source lines, ``image_train()``, shuffled.
    * ``"source_te"``: a random ~10% of the source lines (chosen with
      ``torch.utils.data.random_split``, so it depends on the torch RNG state),
      ``image_test()``, shuffled. It overlaps ``"source_tr"``, which uses
      every source line.
    * ``"target"``: target lines as ``ImageList_idx`` with the transform pair
      from ``image_target()``, shuffled.
    * ``"test"``: test lines as ``ImageList_idx`` with ``image_test()``,
      not shuffled, batch size ``3 * args.batch_size``.

    Every loader uses ``args.worker`` workers and keeps the last partial batch.
    """
    train_bs = args.batch_size
    txt_src = _read_lines(args.s_dset_path)
    txt_tar = _read_lines(args.t_dset_path)
    txt_test = _read_lines(args.test_dset_path)

    dsize = len(txt_src)
    tr_size = int(0.9 * dsize)
    # Only the held-out part of the split is used (for "source_te").
    _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
    tr_txt = txt_src

    def _loader(dataset, batch_size, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=args.worker,
            drop_last=False,
        )

    # Datasets/loaders are built in the same order as before.
    return {
        "source_tr": _loader(ImageList(tr_txt, transform=image_train()), train_bs, True),
        "source_te": _loader(ImageList(te_txt, transform=image_test()), train_bs, True),
        "target": _loader(ImageList_idx(txt_tar, transform=image_target()), train_bs, True),
        "test": _loader(ImageList_idx(txt_test, transform=image_test()), train_bs * 3, False),
    }


def hyper_decay(x, beta=-2, alpha=1):
    """Return ``alpha * (1 + 10 * x) ** (-beta)``.

    For ``x`` in [0, 1] (training progress) this decays from ``alpha`` when
    ``beta > 0``. With the default ``beta=-2`` it grows instead.
    """
    base = 1 + 10 * x
    return alpha * base ** (-beta)


def train_target(args):
    """Adapt the source-trained model to the unlabeled target domain.

    Loads ``source_{F,B,C}.pt`` from ``args.output_dir_src`` into netF/netB/netC,
    builds a ``network.ClusterNet`` reference model (netR), and trains them
    jointly on the target loader, which yields two views per image. The total
    loss is a weighted sum (``args.w_*``) of:

    * aad_loss: agreement of each prediction with the banked predictions of its
      ``args.K`` nearest banked features, plus a penalty on the off-diagonal
      prediction dot products within the batch, scaled by a scheduled ``alpha``.
    * clustering_loss: ``DistillLoss`` self-distillation plus a
      mean-entropy-maximisation regulariser on netR's cluster logits, and a
      contrastive term from ``info_nce_logits`` on netR's projections.
    * st_loss: self-training where first-view pseudo-labels with confidence
      >= 0.97 supervise the second view, for both netC and the cluster head.
    * icon_loss: ``SupConLoss`` between classifier and cluster probabilities,
      each masked by a pairwise similarity matrix derived from the other model.
    * inv_loss: variance of two agreement terms.

    Every ``max_iter // args.interval`` iterations (and at the end) the model
    is evaluated on ``dset_loaders["test"]``. The result is written to
    ``args.out_file`` and, on improvement, netF/netB/netC are saved to
    ``args.output_dir`` as ``target_{F,B,C}_2021_<args.tag>.pt``.

    Args:
        args: parsed CLI namespace. It must also carry ``output_dir_src``,
            ``output_dir``, ``out_file``, ``name`` and ``class_num``.

    Returns:
        ``(netF, netB, netC)`` as they are after the last iteration (not
        necessarily the best saved checkpoint).

    Raises:
        RuntimeError: if the total loss becomes non-finite.
    """
    dset_loaders = data_load(args)

    # ---- networks: source-trained netF/netB/netC plus the reference ClusterNet
    netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bottleneck(
        type=args.classifier,
        feature_dim=netF.in_features,
        bottleneck_dim=args.bottleneck,
    ).cuda()
    netC = network.feat_classifier(
        type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck
    ).cuda()
    for net, part in ((netF, "F"), (netB, "B"), (netC, "C")):
        net.load_state_dict(
            torch.load(osp.join(args.output_dir_src, "source_" + part + ".pt"))
        )

    netR = network.ClusterNet(
        class_num=args.class_num, backbone=args.ref_backbone
    ).cuda()

    # ---- optimizers
    # `optimizer`: backbone at 0.1x lr, bottleneck at 1x lr.
    # `optimizer_c`: classifier at 1x lr, netR's cluster head at args.cluster_lr.
    # From netR, only `cluster_head` parameters are handed to an optimizer.
    param_group = [{"params": v, "lr": args.lr * 0.1} for v in netF.parameters()]
    param_group += [{"params": v, "lr": args.lr * 1} for v in netB.parameters()]
    param_group_c = [{"params": v, "lr": args.lr * 1} for v in netC.parameters()]
    param_group_c += [
        {"params": netR.cluster_head.parameters(), "lr": args.cluster_lr}
    ]
    optimizer = op_copy(optim.SGD(param_group))
    optimizer_c = op_copy(optim.SGD(param_group_c))

    # ---- memory banks: L2-normalized bottleneck features (CPU) and softmax
    # predictions (GPU) for every target sample, indexed by dataset index.
    loader = dset_loaders["target"]
    num_sample = len(loader.dataset)
    # Sized from args (previously hard-coded to 256 x 12, i.e. only VISDA-C
    # with the default bottleneck). randn keeps the torch RNG stream as before;
    # each row visited by the pass below is overwritten.
    fea_bank = torch.randn(num_sample, args.bottleneck)
    score_bank = torch.randn(num_sample, args.class_num).cuda()

    netF.eval()
    netB.eval()
    netC.eval()
    with torch.no_grad():
        iter_test = iter(loader)
        for _ in range(len(loader)):
            data = next(iter_test)
            inputs = data[0][0].cuda()  # first of the two views
            indx = data[-1]  # dataset indices of this batch
            output = netB(netF(inputs))
            output_norm = F.normalize(output)
            outputs = nn.Softmax(dim=-1)(netC(output))

            fea_bank[indx] = output_norm.detach().clone().cpu()
            score_bank[indx] = outputs.detach().clone()

    # ---- schedule
    max_iter = args.max_epoch * len(dset_loaders["target"])
    # At least 1, so the evaluation check below can never divide by zero.
    interval_iter = max(1, max_iter // args.interval)
    iter_num = 0
    acc_log = 0
    accc_log = ""

    netF.train()
    netB.train()
    netC.train()
    netR.train()

    cluster_criterion = DistillLoss(
        5 * len(dset_loaders["target"]),
        max_iter,
        2,
        0.07,
        0.04,
    )

    # `iter_test` is the exhausted iterator from the bank pass, so the first
    # next() raises StopIteration and starts a fresh pass over the target set.
    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except StopIteration:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test[0].size(0) == 1:
            # Skip singleton batches (presumably so train-mode BatchNorm works).
            continue

        inputs_target = inputs_test[0].cuda()
        inputs_target_u = inputs_test[1].cuda()
        # Scheduled weight of the batch penalty (decays when args.beta > 0).
        alpha = (1 + 10 * iter_num / max_iter) ** (-args.beta) * args.alpha

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
        lr_scheduler(optimizer_c, iter_num=iter_num, max_iter=max_iter)

        # Both views share one forward pass; chunk(2) splits results back into
        # the first view (inputs_test[0]) and the second (inputs_test[1]).
        both_views = torch.cat((inputs_target, inputs_target_u), dim=0)
        features_test = netB(netF(both_views))
        features_f, _ = torch.chunk(features_test, chunks=2, dim=0)
        outputs_test = netC(features_test)
        output, output_u = torch.chunk(outputs_test, chunks=2, dim=0)
        softmax_test = nn.Softmax(dim=1)(outputs_test)
        softmax_out, softmax_out_u = torch.chunk(softmax_test, chunks=2, dim=0)

        out_r = netR(both_views)
        features_r, _ = torch.chunk(out_r["x_proj"], chunks=2, dim=0)
        cluster_logits_r, cluster_logits_r_u = torch.chunk(
            out_r["cluster_logits"], chunks=2, dim=0
        )

        # ---- neighbour retrieval from the memory banks (no gradient)
        with torch.no_grad():
            output_f_ = F.normalize(features_f).cpu().detach().clone()
            # Refresh this batch's entries first; each sample then normally
            # retrieves itself as top-1, which is dropped below.
            fea_bank[tar_idx] = output_f_
            score_bank[tar_idx] = softmax_out.detach().clone()

            distance = output_f_ @ fea_bank.T
            _, idx_near = torch.topk(distance, dim=-1, largest=True, k=args.K + 1)
            idx_near = idx_near[:, 1:]  # batch x K
            score_near = score_bank[idx_near]  # batch x K x C

        # ---- neighbour agreement + batch penalty (aad_loss)
        softmax_out_un = softmax_out.unsqueeze(1).expand(-1, args.K, -1)  # batch x K x C
        # kl_div(input, target) = target * (log(target) - input). Passing
        # probabilities (not log-probs) as `input` makes this, up to a constant,
        # minus the dot product with the neighbours' banked predictions.
        nn_loss = torch.mean(
            F.kl_div(softmax_out_un, score_near, reduction="none").sum(-1).sum(1)
        )

        batch_size = features_f.shape[0]
        off_diag = (1.0 - torch.eye(batch_size)).cuda()
        softmax_out_t = softmax_out.T
        dot_neg = ((softmax_out @ softmax_out_t) * off_diag).sum(-1)  # batch
        neg_pred = torch.mean(dot_neg)
        aad_loss = nn_loss + neg_pred * alpha

        cluster_scores = F.softmax(cluster_logits_r / 0.1, dim=1)
        cluster_scores_u = F.softmax(cluster_logits_r_u / 0.1, dim=1)

        # ---- self-training: confident first-view labels supervise the second view
        max_prob_net, pseudo_labels = torch.max(softmax_out, dim=-1)
        st_loss_net = (
            F.cross_entropy(output_u, pseudo_labels, reduction="none")
            * max_prob_net.ge(0.97).float().detach()
        ).mean()

        max_prob_cluster, cluster_idx = torch.max(cluster_scores, dim=-1)
        st_loss_cluster = (
            F.cross_entropy(cluster_logits_r_u, cluster_idx, reduction="none")
            * max_prob_cluster.ge(0.97).float().detach()
        ).mean()

        st_loss = st_loss_cluster + st_loss_net

        # ---- cluster-head training: self-distillation + mean-entropy maximisation
        c_out = out_r["cluster_logits"]
        c_out2 = out_r["cluster_logits"].detach()
        avg_probs = (c_out / 0.1).softmax(dim=1).mean(dim=0)
        # sum(p * log p) + log(num_clusters): zero when the mean assignment is uniform.
        me_max_loss = -torch.sum(torch.log(avg_probs ** (-avg_probs))) + math.log(
            float(len(avg_probs))
        )
        unsup_cluster_loss = (
            cluster_criterion(c_out, c_out2, iter_num - 1) + me_max_loss * 2.0
        )

        # unsupervised representation learning on netR's projections
        contrastive_logits, contrastive_labels = info_nce_logits(features=out_r["x_proj"])
        unsup_rep_loss = F.cross_entropy(contrastive_logits, contrastive_labels)

        clustering_loss = unsup_cluster_loss + unsup_rep_loss

        # ---- ICON: cross-model consistency masks
        if args.con_mode in ["prob", "sim"]:
            cluster_sim = get_ulb_sim_matrix_v2(
                args.con_mode, cluster_logits_r,
                sim_ratio=args.sim_ratio, diff_ratio=args.diff_ratio,
                sim_threshold=args.sim_threshold
            )
            cls_sim = get_ulb_sim_matrix_v2(
                args.con_mode, output,
                sim_ratio=args.sim_ratio, diff_ratio=args.diff_ratio,
                sim_threshold=args.sim_threshold, list="class"
            )
        elif args.con_mode in ["rank-k", "stats"]:
            cluster_sim = get_mat(features_r)
            cls_sim = get_mat(output)
        else:
            raise NotImplementedError()

        # classifier views, masked by the cluster-derived similarity
        icon_loss_u = SupConLoss(temperature=1.0, base_temperature=1.0)(
            torch.cat((softmax_out.unsqueeze(1), softmax_out_u.unsqueeze(1)), dim=1),
            mask=cluster_sim
        )
        # cluster views, masked by the classifier-derived similarity
        icon_loss_v = SupConLoss(temperature=1.0, base_temperature=1.0)(
            torch.cat((cluster_scores.unsqueeze(1), cluster_scores_u.unsqueeze(1)), dim=1),
            mask=cls_sim
        )
        icon_loss = icon_loss_u + icon_loss_v

        # ---- invariance: variance of two agreement terms
        irm_con_u = F.kl_div(softmax_out, cluster_scores, reduction="none").mean()
        irm_con_v = nn_loss  # identical to the neighbour term computed above
        inv_loss = torch.var(torch.stack([irm_con_u, irm_con_v]))

        total_loss = (
            aad_loss * args.w_aad
            + clustering_loss * args.w_cluster
            + st_loss * args.w_st
            + icon_loss * args.w_con
            + inv_loss * args.w_inv
        )

        loss_value = total_loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            raise RuntimeError("non-finite loss ({}) at iteration {}".format(loss_value, iter_num))

        optimizer.zero_grad()
        optimizer_c.zero_grad()
        total_loss.backward()
        optimizer.step()
        optimizer_c.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            netR.eval()

            (acc, accc), _ = cal_acc_both(
                dset_loaders["test"], "test",
                netF, netB, netC, None,
                is_visda=True,
            )

            log_str = (
                "Task: {}, Iter:{}/{};  Acc: {:.2f}".format(
                    args.name, iter_num, max_iter, acc
                )
                + "\n"
                + "T: "
                + accc
            )
            # Loss breakdown of this iteration. It is formatted only here, so the
            # seven .item() host syncs no longer run on every iteration.
            pstr = (
                f"aad_loss: {aad_loss.item():.4f}; "
                f"st_loss_net: {st_loss_net.item():.4f}; "
                f"st_loss_cluster: {st_loss_cluster.item():.4f}; "
                f"unsup_cluster_loss: {unsup_cluster_loss.item():.4f}; "
                f"unsup_rep_loss: {unsup_rep_loss.item():.4f}; "
                f"icon_loss: {icon_loss.item():.4f}; "
                f"inv_loss: {inv_loss.item():.4f}"
            )

            args.out_file.write(log_str + "\n")
            args.out_file.flush()
            print(pstr)
            print(log_str + "\n")
            netF.train()
            netB.train()
            netC.train()
            netR.train()
            if acc > acc_log:
                acc_log = acc
                accc_log = accc
                for net, part in ((netF, "F"), (netB, "B"), (netC, "C")):
                    torch.save(
                        net.state_dict(),
                        osp.join(
                            args.output_dir,
                            "target_" + part + "_" + "2021_" + str(args.tag) + ".pt",
                        ),
                    )

    print(acc_log)
    print(accc_log)

    return netF, netB, netC


def test_target(args):
    """Evaluate saved target-adapted networks on the target test split.

    Rebuilds netF/netB/netC with the same architecture as ``train_target``.
    It then loads the checkpoints that ``train_target`` saves
    (``target_{F,B,C}_2021_<args.tag>.pt``) from the directory
    ``args.load_path`` if that attribute is set and non-empty, otherwise from
    ``args.output_dir``. Prints the accuracy line.

    Returns:
        The accuracy reported by ``cal_acc_both``.
    """
    dset_loaders = data_load(args)
    ## set base network
    netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bottleneck(
        type=args.classifier,
        feature_dim=netF.in_features,
        bottleneck_dim=args.bottleneck,
    ).cuda()
    netC = network.feat_classifier(
        type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck
    ).cuda()

    # Each network has its own state dict. The previous code loaded one file
    # into all three (and the CLI never defines --load_path).
    load_dir = getattr(args, "load_path", None) or args.output_dir
    for net, part in ((netF, "F"), (netB, "B"), (netC, "C")):
        ckpt = osp.join(load_dir, "target_" + part + "_" + "2021_" + str(args.tag) + ".pt")
        net.load_state_dict(torch.load(ckpt))

    netF.eval()
    netB.eval()
    netC.eval()
    (acc, accc), _ = cal_acc_both(
        dset_loaders["test"], 'test',
        netF, netB, netC, None,
        is_visda=True,
    )
    log_str = (
        "Task: {}, Acc on target: {:.2f}".format(
            args.name, acc
        )
        + "\n"
        + "T: "
        + accc
    )

    print(log_str)
    return acc


def print_args(args):
    """Format every attribute of ``args`` as ``name:value`` lines under a rule.

    Attributes appear in the namespace's insertion order, one per line, each
    terminated by a newline.
    """
    header = "==========================================\n"
    body = "".join("{}:{}\n".format(name, value) for name, value in vars(args).items())
    return header + body


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LPA")
    parser.add_argument(
        "--gpu_id", type=str, nargs="?", default="0", help="device id to run"
    )
    parser.add_argument("--s", type=int, default=0, help="source")
    parser.add_argument("--t", type=int, default=1, help="target")
    parser.add_argument("--max_epoch", type=int, default=15, help="max iterations")
    parser.add_argument("--interval", type=int, default=15)
    parser.add_argument("--batch_size", type=int, default=64, help="batch_size")
    parser.add_argument("--worker", type=int, default=4, help="number of workers")
    parser.add_argument("--dset", type=str, default="VISDA-C")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--net", type=str, default="resnet101")
    parser.add_argument("--seed", type=int, default=2021, help="random seed")

    parser.add_argument("--bottleneck", type=int, default=256)
    parser.add_argument("--K", type=int, default=5)
    parser.add_argument("--epsilon", type=float, default=1e-5)
    parser.add_argument("--layer", type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument("--classifier", type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument("--output", type=str, default="weight/target/")
    parser.add_argument("--output_src", type=str, default="weight/source/")
    parser.add_argument("--tag", type=str, default="LPA")
    parser.add_argument("--da", type=str, default="uda")
    # NOTE(review): type=bool turns any non-empty string (even "False") into True.
    parser.add_argument("--issave", type=bool, default=True)
    parser.add_argument("--cc", default=False, action="store_true")
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=5.0)
    parser.add_argument("--alpha_decay", default=True)
    parser.add_argument("--nuclear", default=False, action="store_true")
    parser.add_argument("--var", default=False, action="store_true")

    parser.add_argument("--cluster_lr", type=float, default=0.001)
    parser.add_argument("--con_mode", type=str, default="prob", choices=["prob", "sim", "stats", "rank-k"])
    parser.add_argument("--sim_ratio", type=float, default=0.07)
    parser.add_argument("--diff_ratio", type=float, default=0.80)
    parser.add_argument("--sim_threshold", type=float, default=0.0)
    parser.add_argument("--w_aad", type=float, default=1.0)
    parser.add_argument("--w_cluster", type=float, default=1.0)
    parser.add_argument("--w_con", type=float, default=1.0)
    parser.add_argument("--w_st", type=float, default=1.0)
    parser.add_argument("--w_inv", type=float, default=0.)
    parser.add_argument("--dset_path", type=str, default="data/")
    parser.add_argument("--test", default=False, action="store_true")
    parser.add_argument("--ref_backbone", type=str, default="swinb", choices=["swinb", "resnet50", "resnet101"])

    args = parser.parse_args()

    # Domain names (indexed by --s) and class count per supported dataset.
    if args.dset == "office-home":
        names = ["Art", "Clipart", "Product", "RealWorld"]
        args.class_num = 65
    elif args.dset == "VISDA-C":
        names = ["train", "validation"]
        args.class_num = 12
    else:
        parser.error(
            "unsupported --dset {!r}; expected 'office-home' or 'VISDA-C'".format(args.dset)
        )
    if not 0 <= args.s < len(names):
        parser.error("--s must be in [0, {}) for --dset {}".format(len(names), args.dset))

    # Must be set before CUDA is first initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    torch.backends.cudnn.deterministic = True

    src_name = names[args.s]
    # Adapt from the source domain to every other domain of the dataset.
    for t_idx, tgt_name in enumerate(names):
        if t_idx == args.s:
            continue
        args.t = t_idx

        list_dir = osp.join(args.dset_path, args.dset)
        args.s_dset_path = osp.join(list_dir, src_name + "_list.txt")
        args.t_dset_path = osp.join(list_dir, tgt_name + "_list.txt")
        args.test_dset_path = osp.join(list_dir, tgt_name + "_list.txt")

        args.output_dir_src = osp.join(
            args.output_src, args.da, args.dset, src_name[0].upper(), "seed" + str(SEED)
        )
        args.output_dir = osp.join(
            args.output,
            args.da,
            args.dset,
            src_name[0].upper() + tgt_name[0].upper(),
            "seed" + str(SEED)
        )
        args.name = src_name[0].upper() + tgt_name[0].upper()

        if args.test:
            test_target(args)
        else:
            os.makedirs(args.output_dir, exist_ok=True)
            log_path = osp.join(args.output_dir, "log_{}.txt".format(args.tag))
            # The log stays open for the whole run and is closed afterwards.
            with open(log_path, "w") as out_file:
                args.out_file = out_file
                args.out_file.write(print_args(args) + "\n")
                args.out_file.flush()
                train_target(args)