from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

# import torchvision.models as models
from VGG import vgg19_bn
import random
import argparse
import numpy as np
import dataloader_animal10n as dataloader
from utils_mutual import LogFT
from functools import reduce
import torch.distributions as dist
from utils_partial import GenCompl
from sklearn.mixture import GaussianMixture

# Command-line configuration for the Animal-10N training script.
parser = argparse.ArgumentParser(description="PyTorch Animal10N Training")
parser.add_argument("--batch_size", default=64, type=int, help="train batchsize")
parser.add_argument("--lr", "--learning_rate", default=0.02, type=float, help="initial learning rate")
parser.add_argument("--alpha", default=4, type=float, help="parameter for Beta")
parser.add_argument("--lambda_u", default=25, type=float, help="weight for unsupervised loss")
parser.add_argument("--T", default=0.5, type=float, help="sharpening temperature")
parser.add_argument("--num_epochs", default=100, type=int)
parser.add_argument("--id", default="Aniaml10N")
parser.add_argument("--data_path", default="../../Clothing1M/data", type=str, help="path to dataset")
# type=int so a seed supplied on the command line is parsed as an integer
# (without it argparse would hand back a string, unlike the int default).
parser.add_argument("--seed", default=123, type=int, help="random seed")
parser.add_argument("--gpuid", default=0, type=int)
parser.add_argument("--num_class", default=10, type=int)
parser.add_argument("--generate_pt", action="store_true")
parser.add_argument("--topk", default=3, type=int)
parser.add_argument("--use_coteaching", action="store_true")
parser.add_argument("--beta", default=0.9, type=float)
args = parser.parse_args()

# Select the CUDA device that subsequent .cuda() calls will use.
torch.cuda.set_device(args.gpuid)
# NOTE(review): seeding is disabled below, so --seed currently has no effect
# in this script; re-enable these lines for reproducible runs.
# random.seed(args.seed)
# torch.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4)


class AverageMeter(object):
    """Track the latest value and the running (weighted) average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard every accumulated statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations.

        Args:
            val: the new value (typically a per-batch mean).
            n: how many observations ``val`` summarises; defaults to 1.
        """
        self.val = val
        # Weight the contribution by n so that sum / count stays a true mean;
        # adding plain ``val`` while count grows by n skews avg for n != 1.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class LLD(nn.Module):
    """Per-sample label distribution kept as an exponential moving average.

    ``latent`` holds one row of class probabilities per training sample. It
    starts uniform, is blended with new predictions by ``update_hist`` and is
    turned into a sampling distribution by ``sample_latent``.
    """

    def __init__(self, num_samples) -> None:
        """Allocate the per-sample tables on the current CUDA device.

        Args:
            num_samples: number of rows (one per training sample).

        Reads ``args.beta`` (EMA momentum), ``args.T`` (sharpening
        temperature) and ``args.num_class`` from the module-level ``args``.
        """
        super(LLD, self).__init__()
        self.beta = args.beta
        # Start from the uniform distribution over classes.
        self.latent = (torch.ones(num_samples, args.num_class) / args.num_class).cuda()
        self.temp = args.T
        self.init_H = dist.Categorical(probs=self.latent).entropy()
        self.H = dist.Categorical(probs=self.latent).entropy()
        self.H_gain = torch.zeros(num_samples, args.num_class).cuda()

    def update_hist(self, probs, index):
        """Blend new predictions into the stored rows selected by ``index``.

        Args:
            probs: class probabilities for the selected samples; clamped to
                [1e-4, 1 - 1e-4], detached and re-normalised per row.
            index: row indices of ``latent`` to update.
        """
        probs = torch.clamp(probs, 1e-4, 1.0 - 1e-4).detach()
        self.latent[index] = self.beta * self.latent[index] + (1 - self.beta) * (
            probs / probs.sum(1, keepdim=True)
        )

    def sample_latent(self, index=None, reverse=False):
        """Return a Categorical over the temperature-sharpened stored rows.

        Args:
            index: rows to use; ``None`` selects every row.
            reverse: when true (and ``index`` is given) the sharpened rows are
                replaced by their normalised complement ``1 - p``.
                NOTE(review): as in the original code, ``reverse`` is ignored
                when ``index`` is None - confirm whether that is intended.

        Returns:
            ``torch.distributions.Categorical`` with one row per selected sample.
        """
        rows = self.latent if index is None else self.latent[index]
        # Use the temperature stored on the instance rather than re-reading
        # the global args, so the object is self-consistent.
        sharpened = rows ** (1 / self.temp)
        if reverse and index is not None:
            complement = 1 - sharpened
            sharpened = complement / complement.sum(1, keepdim=True)

        norm_ld = sharpened / sharpened.sum(1, keepdim=True)
        return dist.Categorical(probs=norm_ld)


def warmup(epoch, net, optimizer, dataloader, latent, tilde_latent):
    """Run one warm-up epoch of ``net`` over ``dataloader``.

    For every batch the loss is ``cross_entropy(tildey, labels)`` plus
    ``partial_loss`` and ``m_kl_loss`` computed against two candidate-label
    priors. Besides the arguments, this function uses the module-level
    ``compl`` generator, ``args`` and the module-level meters for the
    progress line.

    Args:
        epoch: current epoch number (only used in the progress line).
        net: model returning ``(outputs, tildey, _)``.
        optimizer: optimizer stepping ``net``'s parameters.
        dataloader: yields ``(inputs, labels, path)``; ``path`` is used to
            index the per-sample tables of ``latent`` / ``tilde_latent``.
        latent: ``LLD`` tracking ``outputs`` predictions; also sampled here.
        tilde_latent: ``LLD`` tracking ``tildey`` predictions.
    """
    net.train()
    # len(dataloader) is the exact batch count; ``len(dataset) // batch_size
    # + 1`` over-counted by one whenever the dataset size was a multiple of
    # the batch size.
    num_iter = len(dataloader)
    for batch_idx, (inputs, labels, path) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs, tildey, _ = net(inputs)
        log_outputs = outputs.log_softmax(1)
        onehot_labels = F.one_hot(labels, args.num_class).float()

        # One extra candidate label per sample, drawn from the sharpened
        # running prediction history.
        categorical = latent.sample_latent(path)
        pos1 = F.one_hot(categorical.sample(), args.num_class).float()
        pos_set1 = torch.logical_or(pos1, onehot_labels).float()
        # Two calls to the complement generator give two priors; what it adds
        # is defined in utils_partial.GenCompl (not visible here).
        rnd_neg1 = compl.full_space_partial(pos_set1, path)
        rnd_neg2 = compl.full_space_partial(pos_set1, path)
        prior1, log_prior1 = build_partial(onehot_labels, pos1, rnd_neg1)
        _, log_prior2 = build_partial(onehot_labels, pos1, rnd_neg2)

        # Fold the current predictions into the per-sample histories.
        latent.update_hist(outputs.softmax(1), path)
        tilde_latent.update_hist(tildey.softmax(1), path)

        ce = F.cross_entropy(tildey, labels)
        par = partial_loss(log_outputs, log_prior1, log_prior2)
        m_kl = m_kl_loss(tildey, log_prior1, log_prior2, log_outputs)
        l = ce + par + m_kl

        l.backward()
        optimizer.step()

        metric_update(inputs, path, pos_set1, prior1, ce, par, m_kl)
        sys.stdout.write("\r")
        sys.stdout.write(
            f"Animal10N: | Epoch [{epoch}/{args.num_epochs}] Iter[{batch_idx+1}/{num_iter}]\t CE: {l_ce.avg:.2f} PAR: {l_par.avg:.2f} KL: {l_kl.avg:.2f}\t COV: {m_cov.avg + m_neg_cov.avg:.2f} SIZE: {m_pri_size.avg:.2f}"
        )
        sys.stdout.flush()


def build_partial(onehot_labels, pos, rnd_neg=None):
    """Combine label indicators into a candidate-set prior and its log form.

    Args:
        onehot_labels: indicator tensor of the given labels.
        pos: indicator tensor of additional candidate labels.
        rnd_neg: optional indicator tensor of further candidates.

    Returns:
        ``(prior, log_prior)`` where ``prior`` is the element-wise sum of the
        inputs clamped to [0, 1] and ``log_prior`` is the log of the
        row-normalised prior, floored at 1e-9 before taking the log.
    """
    combined = onehot_labels + pos
    if rnd_neg is not None:
        combined = combined + rnd_neg
    prior = combined.clamp(min=0.0, max=1.0)

    row_mass = prior.sum(1, keepdim=True)
    normalised = prior / row_mass
    # Floor at 1e-9 so excluded classes get a finite log value.
    log_prior = normalised.clamp(min=1e-9).log()
    return prior, log_prior


def train_partial(epoch, net, optimizer, dataloader, probs, latent, tilde_latent):
    """Run one post-warm-up training epoch of ``net`` over ``dataloader``.

    Same loss as ``warmup`` (cross-entropy on ``tildey`` plus ``partial_loss``
    and ``m_kl_loss``), but the complement generator additionally receives the
    per-sample values ``probs[path]``. Uses the module-level ``compl``,
    ``args`` and meters.

    Args:
        epoch: current epoch number (only used in the progress line).
        net: model returning ``(outputs, tildey, _)``.
        optimizer: optimizer stepping ``net``'s parameters.
        dataloader: yields ``(inputs, labels, path)``; ``path`` indexes the
            per-sample tables and ``probs``.
        probs: per-sample tensor; the caller in this file passes the other
            network's ``eval_train`` output.
        latent: ``LLD`` tracking ``outputs`` predictions; also sampled here.
        tilde_latent: ``LLD`` tracking ``tildey`` predictions.
    """
    # Nothing inside the loop changes the module mode, so set it once.
    net.train()
    # Exact batch count (the old ``len(dataset) // batch_size + 1`` was off
    # by one when the dataset size was a multiple of the batch size).
    num_iter = len(dataloader)
    for batch_idx, (inputs, labels, path) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs, tildey, _ = net(inputs)
        log_outputs = outputs.log_softmax(1)
        onehot_labels = F.one_hot(labels, args.num_class).float()

        # Given label plus one label sampled from the running history.
        categorical = latent.sample_latent(path)
        pos1 = F.one_hot(categorical.sample(), args.num_class).float()
        pos_set1 = (pos1 + onehot_labels).clamp(max=1.0)
        # Independent copies per call, as in the original, in case the
        # generator modifies its argument in place.
        fix1 = pos_set1.clone()
        fix2 = pos_set1.clone()

        rnd_neg1 = compl.full_space_partial(fix1, path, probs[path])
        rnd_neg2 = compl.full_space_partial(fix2, path, probs[path])
        prior1 = (fix1 + rnd_neg1).clamp(max=1.0)
        prior2 = (fix2 + rnd_neg2).clamp(max=1.0)

        # Row-normalise and take logs, flooring at 1e-9 to keep logs finite.
        log_prior1 = (prior1 / prior1.sum(1, keepdim=True)).clamp(1e-9).log()
        log_prior2 = (prior2 / prior2.sum(1, keepdim=True)).clamp(1e-9).log()

        # Fold the current predictions into the per-sample histories.
        latent.update_hist(outputs.softmax(1), path)
        tilde_latent.update_hist(tildey.softmax(1), path)

        ce = F.cross_entropy(tildey, labels)
        par = partial_loss(log_outputs, log_prior1, log_prior2)
        m_kl = m_kl_loss(tildey, log_prior1, log_prior2, log_outputs)
        l = ce + par + m_kl

        l.backward()
        optimizer.step()

        metric_update(inputs, path, pos_set1, prior1, ce, par, m_kl)

        sys.stdout.write("\r")
        sys.stdout.write(
            f"Animal10N | Epoch [{epoch}/{args.num_epochs}] Iter[{batch_idx+1}/{num_iter}]\t ce: {l_ce.avg:.2f} par: {l_par.avg:.2f} kl: {l_kl.avg:.2f}\t cov: {m_cov.avg:.3f}|{m_neg_cov.avg:.3f} size: {m_pri_size.avg:.3f} space: {base.avg:.3f}|{m_minus.avg:.3f}"
        )
        sys.stdout.flush()


def metric_update(inputs, path, pos_set1, prior1, ce, par, m_kl):
    """Feed the current batch statistics into the module-level meters.

    Args:
        inputs, path, pos_set1: unused; kept so existing call sites still work.
        prior1: candidate-set tensor; the mean number of positive entries per
            row is recorded in ``m_pri_size``.
        ce, par, m_kl: scalar loss values (tensors or numbers) for this batch.
    """
    m_pri_size.update((prior1 > 0).sum(1).float().mean().item())
    # Store plain Python floats: accumulating the loss tensors themselves
    # would keep their autograd history alive inside the meters' running sum
    # for the whole run.
    l_ce.update(float(ce))
    l_par.update(float(par))
    l_kl.update(float(m_kl))


def m_kl_loss(tildey, log_prior1, log_prior2, log_outputs):
    """Average KL term pulling the prior-reweighted ``tildey`` towards ``log_outputs``.

    For each prior, ``log_softmax(tildey) + log_prior`` is re-normalised along
    dim 1 and used as the ``input`` of ``F.kl_div``; the detached
    ``log_outputs`` is the (log-space) target. The two batch-mean divergences
    are averaged.

    Args:
        tildey: raw scores; normalised along dim 1 inside.
        log_prior1, log_prior2: log priors added to the normalised scores.
        log_outputs: log-probabilities used as target (no gradient flows
            into it).

    Returns:
        Scalar tensor: mean of the two KL terms.
    """
    target = log_outputs.detach()

    def _kl_against(log_prior):
        reweighted = (tildey.log_softmax(1) + log_prior).log_softmax(1)
        return F.kl_div(reweighted, target, reduction="batchmean", log_target=True)

    # Same evaluation order as before: prior 2 first, then prior 1.
    term_prior2 = _kl_against(log_prior2)
    term_prior1 = _kl_against(log_prior1)
    return (term_prior2 + term_prior1) / 2


def partial_loss(log_outputs, log_prior1, log_prior2):
    """Average KL term between ``log_outputs`` and two prior-weighted targets.

    Each target is ``log_prior + log_outputs.log_softmax(0)`` re-normalised
    along dim 1 (dim 0 is the batch dimension for the callers in this file).
    ``log_outputs`` is the ``input`` of ``F.kl_div`` and the target is given
    in log space. The target is not detached, so gradients flow through both
    arguments.

    Args:
        log_outputs: log-probabilities produced by the network.
        log_prior1, log_prior2: log priors used to build the two targets.

    Returns:
        Scalar tensor: mean of the two batch-mean KL terms.
    """

    def _kl_to_target(log_prior):
        target = (log_prior + log_outputs.log_softmax(0)).log_softmax(1)
        return F.kl_div(log_outputs, target, reduction="batchmean", log_target=True)

    first = _kl_to_target(log_prior1)
    second = _kl_to_target(log_prior2)
    return (first + second) / 2


def test(net1, test_loader, test_log, verbose=True, net2=None):
    """Evaluate top-1 accuracy on ``test_loader``.

    Args:
        net1: model returning ``(outputs, _, _)``.
        test_loader: yields ``(inputs, targets)`` batches.
        test_log: writable file object; receives one line per call when
            ``verbose`` is true.
        verbose: whether to write to ``test_log``. The log line reads the
            module-level ``epoch`` variable.
        net2: optional second model; when given, the two models' outputs are
            averaged before taking the arg-max.

    Returns:
        Accuracy in percent (0.0 when the loader yields no samples).
    """
    net1.eval()
    # net2 defaults to None, so it must not be used unconditionally.
    if net2 is not None:
        net2.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs, _, _ = net1(inputs)
            if net2 is not None:
                outputs2, _, _ = net2(inputs)
                outputs = (outputs + outputs2) / 2
            _, predicted = torch.max(outputs, 1)

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # Guard against an empty loader instead of dividing by zero.
    acc = 100.0 * correct / total if total else 0.0
    print("\n| Test Acc: %.2f%%\n" % (acc))
    if verbose:
        test_log.write("Epoch:%d   Accuracy:%.2f\n" % (epoch, acc))
        test_log.flush()
    return acc


def eval_train(net, dataloader):
    """Score every training sample by how likely it is to sit in the high-loss mode.

    Per-sample cross-entropy of ``tildey`` against the given targets is
    min-max normalised and a two-component Gaussian mixture is fitted to it.

    Args:
        net: model returning ``(outputs, tildey, _)``.
        dataloader: yields ``(inputs, targets, index)``; ``index`` gives each
            sample's position in ``dataloader.dataset.train_data``.

    Returns:
        CUDA tensor with one value per sample: ``1 - P(component with the
        smaller mean)`` under the fitted mixture.
    """
    net.eval()
    # Size the buffer from the loader being evaluated, not from the global
    # ``eval_loader`` the original code reached for.
    losses = torch.zeros(len(dataloader.dataset.train_data))
    with torch.no_grad():
        for inputs, targets, index in dataloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            _, tildey, _ = net(inputs)
            loss = F.cross_entropy(tildey, targets, reduction="none")
            # One device-to-host copy and one scatter per batch instead of a
            # Python loop copying element by element.
            rows = torch.as_tensor(index, dtype=torch.long, device="cpu")
            losses[rows] = loss.cpu()

    # Min-max normalise; the floor on the span avoids 0/0 when every loss is
    # identical and leaves the result unchanged otherwise.
    span = (losses.max() - losses.min()).clamp_min(1e-12)
    losses = (losses - losses.min()) / span
    input_loss = losses.reshape(-1, 1).numpy()

    # Fit a two-component GMM to the normalised loss.
    gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
    gmm.fit(input_loss)
    prob = gmm.predict_proba(input_loss)
    # Posterior of the low-mean component; the complement is returned.
    prob = prob[:, gmm.means_.argmin()]
    return 1 - torch.from_numpy(prob).cuda()


def create_model():
    """Build a ``vgg19_bn`` network with two extra linear heads, moved to CUDA.

    ``classifier2`` maps 4096 features to ``args.num_class`` outputs and
    ``transition`` maps ``4096 + args.num_class`` features to
    ``args.num_class`` outputs.
    """
    feature_dim = 4096
    n_cls = args.num_class

    net = vgg19_bn()
    net.classifier2 = nn.Linear(feature_dim, n_cls)
    net.transition = nn.Linear(feature_dim + n_cls, n_cls)
    net = net.cuda()
    return net


# Per-epoch test accuracy is written to this file by test().
log = open("./checkpoint/%s.txt" % args.id, "w")
logger = LogFT(log)

loader = dataloader.animal_dataloader(
    dataset=args.data_path, batch_size=args.batch_size, num_workers=5, saved=True
)

print("| Building net")
net1 = create_model()
net2 = create_model()
cudnn.benchmark = True

optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)

# Module-level meters read by warmup / train_partial / metric_update.
m_cov = AverageMeter()
m_neg_cov = AverageMeter()
m_pri_size = AverageMeter()
m_add = AverageMeter()
m_minus = AverageMeter()
m_y_add = AverageMeter()
m_y_minus = AverageMeter()
base = AverageMeter()

l_ce = AverageMeter()
l_par = AverageMeter()
l_kl = AverageMeter()

_ALL_METERS = (
    m_cov, m_neg_cov, m_pri_size, m_add, m_minus, m_y_add, m_y_minus, base,
    l_ce, l_par, l_kl,
)


def _reset_meters():
    """Clear every meter so each progress line covers one network's epoch."""
    for meter in _ALL_METERS:
        meter.reset()


warmup_loader = loader.run("warmup")
test_loader = loader.run("test")
eval_loader = loader.run("eval_train")
num_samples = len(eval_loader.dataset.train_data)
latent = LLD(num_samples)
latent2 = LLD(num_samples)
tilde_latent = LLD(num_samples)
tilde_latent2 = LLD(num_samples)
compl = GenCompl(num_samples, args.num_class)
# NOTE(review): compl2 is never passed anywhere in this script; warmup and
# train_partial read the module-level ``compl`` for both networks. Confirm
# whether net2 was meant to use compl2.
compl2 = GenCompl(num_samples, args.num_class)

try:
    for epoch in range(args.num_epochs + 1):
        # Step schedule: base learning rate, divided by 10 from epoch 50 on.
        lr = args.lr
        if epoch >= 50:
            lr /= 10
        for param_group in optimizer1.param_groups:
            param_group["lr"] = lr

        for param_group in optimizer2.param_groups:
            param_group["lr"] = lr

        if epoch < 10:  # warm up
            print("Warmup Net1")
            # Without a reset the meters would average over every previous
            # epoch and over both networks.
            _reset_meters()
            warmup(epoch, net1, optimizer1, warmup_loader, latent, tilde_latent)
            print("Warmup Net2")
            _reset_meters()
            warmup(epoch, net2, optimizer2, warmup_loader, latent2, tilde_latent2)

        else:
            probs = eval_train(net1, eval_loader)
            probs2 = eval_train(net2, eval_loader)

            # Each network is trained with the other network's sample scores.
            _reset_meters()
            train_partial(epoch, net1, optimizer1, warmup_loader, probs2.float(), latent, tilde_latent)
            _reset_meters()
            train_partial(epoch, net2, optimizer2, warmup_loader, probs.float(), latent2, tilde_latent2)

        acc = test(net1, test_loader, log, verbose=True, net2=net2)
finally:
    # Close the log even if training is interrupted or fails.
    log.close()

# log.write("Test Accuracy:%.2f\n" % (acc))
# log.flush()
