from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.models as models
import random
import argparse
import numpy as np
import dataloader_clothing1M as dataloader
from utils_mutual import LogFT
from functools import reduce
import torch.distributions as dist
from utils_partial import GenCompl
from ResNet import CustomResNet
from sklearn.mixture import GaussianMixture

parser = argparse.ArgumentParser(description="PyTorch Clothing1M Training")
parser.add_argument("--batch_size", default=64, type=int, help="train batchsize")
parser.add_argument("--lr", "--learning_rate", default=0.002, type=float, help="initial learning rate")
parser.add_argument("--alpha", default=0.5, type=float, help="parameter for Beta")
parser.add_argument("--lambda_u", default=0, type=float, help="weight for unsupervised loss")
parser.add_argument("--p_threshold", default=0.5, type=float, help="clean probability threshold")
parser.add_argument("--T", default=0.5, type=float, help="sharpening temperature")
parser.add_argument("--num_epochs", default=40, type=int)
parser.add_argument("--id", default="clothing1m_mixup")
parser.add_argument("--data_path", default="../../Clothing1M/data", type=str, help="path to dataset")
# type=int so a seed given on the command line seeds the RNGs with the same
# integer as the default would (random.seed("123") != random.seed(123)).
parser.add_argument("--seed", default=123, type=int)
parser.add_argument("--gpuid", default=0, type=int)
parser.add_argument("--num_class", default=14, type=int)
parser.add_argument("--num_batches", default=3000, type=int)

# NOTE(review): --lambda_u, --topk, --bce_weight and --use_coteaching are not
# read anywhere in this script.
parser.add_argument("--topk", default=3, type=int)
parser.add_argument("--bce_weight", default=1, type=int)
parser.add_argument("--use_coteaching", action="store_true")
parser.add_argument("--beta", default=0.9, type=float)
args = parser.parse_args()

torch.cuda.set_device(args.gpuid)
# Seed every RNG the script draws from: `random`, NumPy (mixup coefficient in
# train_partial, GaussianMixture initialisation) and torch on CPU and all GPUs.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4)


class AverageMeter(object):
    """Computes and stores the running average and the most recent value.

    ``update(val, n)`` adds ``val`` to the sum as-is (it is NOT multiplied by
    ``n``) and adds ``n`` to the count, so when ``n > 1`` pass the *total* over
    those ``n`` items; ``avg`` is then the per-item mean.

    Tensor values are detached before being stored: callers pass loss tensors
    that still carry autograd history, and accumulating them directly would
    chain every iteration's graph into ``sum`` and leak memory.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` (a total over ``n`` items) and refresh ``avg``."""
        if torch.is_tensor(val):
            # Scalars become Python floats; anything larger is at least detached.
            val = val.item() if val.numel() == 1 else val.detach()
        self.val = val
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count


class LLD(nn.Module):
    """Per-sample running estimate of the label distribution (the "latent" history).

    ``latent`` holds one row per tracked sample. :meth:`update_hist` folds new
    predictions into those rows as an exponential moving average, and
    :meth:`sample_latent` turns (optionally mixed and/or reversed) rows into a
    temperature-sharpened ``Categorical`` from which labels can be drawn.

    Args:
        num_samples: number of rows to track.
        num_class: entries per row; defaults to ``args.num_class``.
        beta: EMA weight kept on the old row; defaults to ``args.beta``.
        temp: sharpening temperature, rows are raised to ``1 / temp``;
            defaults to ``args.T``.
        device: where the buffers live; defaults to ``"cuda"`` (the current
            CUDA device), i.e. the original ``.cuda()`` placement.
    """

    def __init__(self, num_samples, num_class=None, beta=None, temp=None, device="cuda") -> None:
        super(LLD, self).__init__()
        if num_class is None:
            num_class = args.num_class
        self.beta = args.beta if beta is None else beta
        # Every row starts as the uniform distribution over classes.
        self.latent = (torch.ones(num_samples, num_class) / num_class).to(device)
        self.temp = args.T if temp is None else temp
        # init_H / H / H_gain are not read by the code in this file; kept for compatibility.
        self.init_H = dist.Categorical(probs=self.latent).entropy()
        self.H = dist.Categorical(probs=self.latent).entropy()
        self.H_gain = torch.zeros(num_samples, num_class, device=device)

    def update_hist(self, probs, index):
        """Blend ``probs`` into rows ``index`` of ``latent`` with EMA weight ``beta``.

        ``probs`` is clamped to [1e-4, 1 - 1e-4], renormalised per row and
        detached, so history updates never carry gradients.
        """
        probs = torch.clamp(probs, 1e-4, 1.0 - 1e-4).detach()
        self.latent[index] = self.beta * self.latent[index] + (1 - self.beta) * (
            probs / probs.sum(1, keepdim=True)
        )

    def sample_latent(self, index=None, reverse=False, lam=None, mix_idx=None):
        """Return a ``Categorical`` over classes built from the stored history.

        Args:
            index: rows to use; ``None`` uses every row.
            reverse: if true, put mass on the classes the (sharpened) history
                considers *unlikely* (``1 - p``, renormalised). Applied for
                both ``index=None`` and explicit indices.
            lam: mixup coefficient; when given together with ``index``, rows
                are mixed as ``lam * rows + (1 - lam) * rows[mix_idx]``.
            mix_idx: permutation of the selected rows used for mixing.
        """
        if index is None:
            base = self.latent
        elif lam is not None:
            rows = self.latent[index]
            base = lam * rows + (1 - lam) * rows[mix_idx]
        else:
            base = self.latent[index]
        latent_distribution = base ** (1 / self.temp)
        if reverse:
            latent_distribution = (1 - latent_distribution) / (1 - latent_distribution).sum(
                1, keepdim=True
            )

        norm_ld = latent_distribution / latent_distribution.sum(1, keepdim=True)
        return dist.Categorical(probs=norm_ld)


def warmup(net, optimizer, dataloader, latent, tilde_latent):
    """Train ``net`` for one warm-up pass over ``dataloader`` (no mixup).

    Per batch: draw one extra candidate label from ``latent``, build two
    partial-label priors (given label + candidate + a random complement from
    the module-level ``compl``), push the current predictions into ``latent``
    / ``tilde_latent``, and optimise cross-entropy on ``tildey`` plus the
    partial and KL losses. Progress is written to stdout using the
    module-level ``epoch`` and loss meters.
    """
    net.train()
    total_iters = len(dataloader.dataset) // dataloader.batch_size + 1
    for step, (x, y, path, _) in enumerate(dataloader, start=1):
        x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()
        logits, tilde_logits, _ = net(x)
        log_probs = logits.log_softmax(1)
        y_onehot = F.one_hot(y, args.num_class).float()

        # NOTE(review): train_partial indexes the latent history with `id_raw`
        # (4th batch element) while this pass uses `path` — confirm they agree.
        sampled = F.one_hot(latent.sample_latent(path).sample(), args.num_class).float()
        candidate_set = torch.logical_or(sampled, y_onehot).float()
        neg_a = compl.full_space_partial(candidate_set, path)
        neg_b = compl.full_space_partial(candidate_set, path)
        _, log_prior_a = build_partial(y_onehot, sampled, neg_a)
        _, log_prior_b = build_partial(y_onehot, sampled, neg_b)
        latent.update_hist(logits.softmax(1), path)
        tilde_latent.update_hist(tilde_logits.softmax(1), path)

        loss_ce = F.cross_entropy(tilde_logits, y)
        loss_par = partial_loss(log_probs, log_prior_a, log_prior_b)
        loss_kl = m_kl_loss(tilde_logits, log_prior_a, log_prior_b, log_probs)
        (loss_ce + loss_par + loss_kl).backward()
        optimizer.step()

        l_ce.update(loss_ce)
        l_par.update(loss_par)
        l_kl.update(loss_kl)
        sys.stdout.write("\r")
        sys.stdout.write(
            f"C1M | Epoch [{epoch}/{args.num_epochs}] Iter[{step}/{total_iters}]\t CE: {l_ce.avg:.2f} PAR: {l_par.avg:.2f} KL: {l_kl.avg:.2f}\t COV: {m_cov.avg + m_neg_cov.avg:.2f} SIZE: {m_pri_size.avg:.2f}"
        )
        sys.stdout.flush()


def build_partial(onehot_labels, pos, rnd_neg=None):
    """Combine label masks into a candidate-set prior.

    Args:
        onehot_labels: one-hot given labels.
        pos: extra candidate-label mask.
        rnd_neg: optional additional mask to include.

    Returns:
        ``(prior, log_prior)``: ``prior`` is the summed mask clipped to [0, 1];
        ``log_prior`` is the log of ``prior`` normalised over dim 1, with
        zeros floored at 1e-9 before the log.
    """
    combined = onehot_labels + pos
    if rnd_neg is not None:
        combined = combined + rnd_neg
    prior = combined.clamp(min=0.0, max=1.0)
    normalised = prior / prior.sum(1, keepdim=True)
    return prior, normalised.clamp(1e-9).log()


def train_partial(epoch, net, optimizer, dataloader, probs, latent, tilde_latent):
    """Train ``net`` for one mixup pass over ``dataloader`` with partial-label priors.

    Per batch: mix inputs with a Beta(``args.alpha``, ``args.alpha``)
    coefficient ``lam >= 0.5``, sample a candidate label from the mixed latent
    history, build two priors (given label + candidate + a random complement
    from the module-level ``compl``, which also receives ``probs[path]``),
    update ``latent`` / ``tilde_latent`` with the current predictions, and
    optimise soft cross-entropy on ``tildey`` plus the partial and KL losses.

    Args:
        epoch: epoch number, used only for the progress line.
        probs: per-sample probabilities indexed by the batch ``path`` values
            (in the driver loop these come from the *peer* network's
            ``eval_train``).
        latent, tilde_latent: ``LLD`` histories indexed by ``id_raw``.
    """
    # Nothing inside the loop switches the model to eval mode, so set it once.
    net.train()
    num_iter = (len(dataloader.dataset) // dataloader.batch_size) + 1
    for batch_idx, (inputs, labels, path, id_raw) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()

        # Mixup; lam >= 0.5 keeps each mixed input dominated by its own sample.
        lam = np.random.beta(args.alpha, args.alpha)
        lam = max(lam, 1 - lam)
        batch_size = inputs.size(0)
        mix_index = torch.randperm(batch_size).cuda()
        inputs = lam * inputs + (1 - lam) * inputs[mix_index]

        optimizer.zero_grad()
        outputs, tildey, _ = net(inputs)
        log_outputs = outputs.log_softmax(1)
        onehot_labels = F.one_hot(labels, args.num_class).float()
        mixed_labels = lam * onehot_labels + (1 - lam) * onehot_labels[mix_index]

        categorical = latent.sample_latent(id_raw, lam=lam, mix_idx=mix_index)
        pos1 = F.one_hot(categorical.sample(), args.num_class).float()

        # Two independent copies: full_space_partial receives each one separately.
        fix1 = (pos1 + onehot_labels).clamp(max=1.0)
        fix2 = (pos1 + onehot_labels).clamp(max=1.0)

        rnd_neg1 = compl.full_space_partial(fix1, path, probs[path])
        rnd_neg2 = compl.full_space_partial(fix2, path, probs[path])
        prior1 = (fix1 + rnd_neg1).clamp(max=1.0)
        prior2 = (fix2 + rnd_neg2).clamp(max=1.0)
        tmp_prior1 = prior1 / prior1.sum(1, keepdim=True)
        tmp_prior2 = prior2 / prior2.sum(1, keepdim=True)
        log_prior1 = tmp_prior1.clamp(1e-9).log()
        log_prior2 = tmp_prior2.clamp(1e-9).log()

        latent.update_hist(outputs.softmax(1), id_raw)
        tilde_latent.update_hist(tildey.softmax(1), id_raw)

        # Soft-target cross-entropy against the mixed one-hot labels.
        ce = torch.mean(-torch.sum(mixed_labels * F.log_softmax(tildey, dim=1), dim=-1))
        par = partial_loss(log_outputs, log_prior1, log_prior2)
        m_kl = m_kl_loss(tildey, log_prior1, log_prior2, log_outputs)
        l = ce + par + m_kl

        l.backward()

        optimizer.step()

        l_ce.update(ce)
        l_par.update(par)
        l_kl.update(m_kl)

        sys.stdout.write("\r")
        sys.stdout.write(
            f"C1M | Epoch [{epoch}/{args.num_epochs}] Iter[{batch_idx+1}/{num_iter}]\t ce: {l_ce.avg:.2f} par: {l_par.avg:.2f} kl: {l_kl.avg:.2f}\t cov: {m_cov.avg:.3f}|{m_neg_cov.avg:.3f} size: {m_pri_size.avg:.3f} space: {base.avg:.3f}|{m_minus.avg:.3f}"
        )
        sys.stdout.flush()


def metric_update(inputs, path, pos_set1, prior1, ce, par, m_kl):
    """Record the mean candidate-set size of ``prior1`` and the three loss values.

    ``inputs``, ``path`` and ``pos_set1`` are currently unused. They fed the
    coverage meters (``m_cov`` / ``m_neg_cov``) and ``cav_acc``, which compared
    against ground-truth labels ``x_ct``; that name is not defined in this
    script, so those metrics are disabled.
    """
    mean_set_size = (prior1 > 0).sum(1).float().mean().item()
    m_pri_size.update(mean_set_size)
    for meter, value in ((l_ce, ce), (l_par, par), (l_kl, m_kl)):
        meter.update(value)


def m_kl_loss(tildey, log_prior1, log_prior2, log_outputs):
    """Average of two KL terms pulling prior-reweighted ``tildey`` towards ``log_outputs``.

    For each prior, the input is ``log_softmax(log_softmax(tildey) + log_prior)``
    over dim 1. The target is ``log_outputs`` detached, so no gradient reaches
    the branch that produced ``log_outputs``. Uses ``batchmean`` reduction
    with log-space targets.
    """

    def _term(log_prior):
        return F.kl_div(
            (tildey.log_softmax(1) + log_prior).log_softmax(1),
            log_outputs.detach(),
            reduction="batchmean",
            log_target=True,
        )

    return (_term(log_prior2) + _term(log_prior1)) / 2


def partial_loss(log_outputs, log_prior1, log_prior2):
    """Average KL from prior-reweighted targets to ``log_outputs``, over two priors.

    Each target is ``log_softmax(log_prior + log_softmax(log_outputs, dim=0), dim=1)``.
    Reduction is ``batchmean`` with log-space targets.

    NOTE(review): the inner ``log_softmax`` runs over dim 0, which is the batch
    dimension if rows are samples. Unlike ``m_kl_loss``, the target is not
    detached, so gradients also flow through it. Confirm both are intended.
    """

    def _term(log_prior):
        target = (log_prior + log_outputs.log_softmax(0)).log_softmax(1)
        return F.kl_div(log_outputs, target, reduction="batchmean", log_target=True)

    return (_term(log_prior1) + _term(log_prior2)) / 2


def val(net, val_loader, k, net2=None):
    """Compute validation accuracy and checkpoint ``net`` on a new best.

    Predictions come from ``net`` or, when ``net2`` is given, from the average
    of both networks' first outputs. If the accuracy beats
    ``best_acc[k - 1]`` (module-level), that entry is updated and ``net``'s
    state dict is saved to ``./checkpoint/<args.id>_net<k>.pth.tar``.

    Returns:
        Validation accuracy in percent.
    """
    net.eval()
    if net2 is not None:
        net2.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            logits, _, _ = net(inputs)
            if net2 is not None:
                logits2, _, _ = net2(inputs)
                logits = (logits + logits2) / 2
            predicted = logits.max(1).indices
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()
    acc = 100.0 * n_correct / n_seen
    print("\n| Validation\t Net%d  Acc: %.2f%%" % (k, acc))
    if acc <= best_acc[k - 1]:
        return acc
    best_acc[k - 1] = acc
    print("| Saving Best Net%d ..." % k)
    torch.save(net.state_dict(), "./checkpoint/%s_net%d.pth.tar" % (args.id, k))
    return acc


def test(net1, test_loader, test_log, verbose=True, net2=None):
    """Report test accuracy of ``net1`` or of the summed ``net1 + net2`` outputs.

    When ``verbose`` is true the result is also appended to ``test_log``,
    tagged with the module-level ``epoch``.

    Returns:
        Test accuracy in percent.
    """
    net1.eval()
    if net2 is not None:
        net2.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            logits, _, _ = net1(inputs)
            if net2 is not None:
                logits2, _, _ = net2(inputs)
                logits = logits + logits2
            predicted = logits.max(1).indices
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()
    acc = 100.0 * n_correct / n_seen
    print("\n| Test Acc: %.2f%%\n" % (acc))
    if verbose:
        test_log.write("Epoch:%d   Accuracy:%.2f\n" % (epoch, acc))
        test_log.flush()
    return acc


def eval_train(epoch, model, data_loader=None):
    """Score every evaluated sample with a 2-component GMM fitted on its loss.

    Per-sample cross-entropy of ``model``'s ``tildey`` output is min-max
    normalised and fitted with a two-component ``GaussianMixture``.

    Args:
        epoch: unused; kept for interface compatibility.
        model: network returning ``(outputs, tildey, _)``.
        data_loader: loader yielding ``(inputs, targets, path, _)``; defaults
            to the module-level ``eval_loader``.

    Returns:
        ``(probs, paths)``. ``probs`` is a float64 CUDA tensor holding, per
        evaluated sample, ``1 - P(lower-mean component)``, i.e. the posterior
        of the high-loss component. ``paths`` lists each sample's ``path``
        value in the same order.

    Raises:
        RuntimeError: if the loader yields no samples.
    """
    if data_loader is None:
        data_loader = eval_loader
    model.eval()
    batch_losses = []
    paths = []
    with torch.no_grad():
        for batch_idx, (inputs, targets, path, _) in enumerate(data_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            _, tildey, _ = model(inputs)
            loss = F.cross_entropy(tildey, targets, reduction="none")
            # One device->host copy per batch instead of one sync per sample.
            batch_losses.append(loss.float().cpu())
            paths.extend(path.tolist())
            sys.stdout.write("\r")
            sys.stdout.write("| Evaluating loss Iter %3d\t" % (batch_idx))
            sys.stdout.flush()

    if not batch_losses:
        raise RuntimeError("eval_train: the evaluation loader yielded no samples")
    # Exactly one entry per evaluated sample: no zero padding skewing min/max
    # or the GMM, and no overflow if the loader is larger than expected.
    losses = torch.cat(batch_losses)
    low, span = losses.min(), losses.max() - losses.min()
    # Guard the all-equal case, which would otherwise produce 0/0 = NaN.
    losses = (losses - low) / span if span > 0 else losses - low
    losses = losses.reshape(-1, 1).numpy()

    gmm = GaussianMixture(n_components=2, max_iter=10, reg_covar=5e-4, tol=1e-2)
    gmm.fit(losses)
    prob = gmm.predict_proba(losses)
    prob = prob[:, gmm.means_.argmin()]
    return 1 - torch.from_numpy(prob).cuda(), paths


def create_model():
    """Build a ``CustomResNet`` with ``args.num_class`` outputs on the current CUDA device."""
    return CustomResNet(args.num_class).cuda()


log = open("./checkpoint/%s.txt" % args.id, "w")
logger = LogFT(log)

loader = dataloader.clothing_dataloader(
    root=args.data_path, batch_size=args.batch_size, num_workers=5, num_batches=args.num_batches
)

print("| Building net")
net1 = create_model()
net2 = create_model()
cudnn.benchmark = True

# Progress meters read by warmup() / train_partial() / metric_update().
m_cov = AverageMeter()
m_neg_cov = AverageMeter()
m_pri_size = AverageMeter()
m_add = AverageMeter()
m_minus = AverageMeter()
m_y_add = AverageMeter()
m_y_minus = AverageMeter()
base = AverageMeter()

l_ce = AverageMeter()
l_par = AverageMeter()
l_kl = AverageMeter()

cav_acc = AverageMeter()
eval_loader = loader.run("eval_train")
num_samples = eval_loader.dataset.num_raw_example
latent = LLD(num_samples)
latent2 = LLD(num_samples)
total_space = torch.ones(num_samples, args.num_class).cuda().float()
tilde_latent = LLD(num_samples)
tilde_latent2 = LLD(num_samples)
compl = GenCompl(num_samples, args.num_class)

optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)


def _reset_loss_meters():
    """Clear the running loss averages so each training pass reports its own."""
    for meter in (l_ce, l_par, l_kl):
        meter.reset()


best_acc = [0, 0]
test_loader = loader.run("test")
for epoch in range(args.num_epochs + 1):
    # Step schedule: divide the learning rate by 10 from epoch 7 on.
    lr = args.lr / 10 if epoch >= 7 else args.lr
    for opt in (optimizer1, optimizer2):
        for param_group in opt.param_groups:
            param_group["lr"] = lr

    if epoch < 1:  # warm up
        train_loader = loader.run("warmup")
        print("Warmup Net1")
        _reset_loss_meters()
        warmup(net1, optimizer1, train_loader, latent, tilde_latent)
        train_loader = loader.run("warmup")
        print("\nWarmup Net2")
        _reset_loss_meters()
        warmup(net2, optimizer2, train_loader, latent2, tilde_latent2)
    else:
        # eval_train() scores the module-level `eval_loader`, so both nets are
        # evaluated on the same freshly drawn subset, and probs / probs2 share
        # the index space of the single train_loader built from it. (A second
        # eval/train loader pair used to be built here but was never consumed.)
        # NOTE(review): assumes eval_loader iterates in a fixed order, so both
        # passes list samples identically — confirm in dataloader_clothing1M.
        eval_loader = loader.run("eval_train")
        probs, path = eval_train(epoch, net1)
        probs2, _ = eval_train(epoch, net2)

        train_loader = loader.run("train", paths=path, eval_loader=eval_loader)
        # Co-divide: each net is trained with its peer's per-sample probabilities.
        _reset_loss_meters()
        train_partial(epoch, net1, optimizer1, train_loader, probs2.float(), latent, tilde_latent)
        _reset_loss_meters()
        train_partial(epoch, net2, optimizer2, train_loader, probs.float(), latent2, tilde_latent2)

    val_loader = loader.run("val")  # validation
    acc1 = val(net1, val_loader, 1)
    acc2 = val(net2, val_loader, 2)
    log.write("Validation Epoch:%d      Acc1:%.2f  Acc2:%.2f\n" % (epoch, acc1, acc2))
    log.flush()
    acc = test(net1, test_loader, log, verbose=True, net2=net2)
    log.write("Test Epoch:%d      Acc1:%.2f\n" % (epoch, acc))
    log.flush()


# Final report: reload the best-validation checkpoint of each net and test the ensemble.
net1.load_state_dict(torch.load("./checkpoint/%s_net1.pth.tar" % args.id))
net2.load_state_dict(torch.load("./checkpoint/%s_net2.pth.tar" % args.id))
acc = test(net1, test_loader, log, verbose=True, net2=net2)
