from __future__ import print_function
from functools import reduce
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import argparse
import numpy as np
from PreResNet import ResNet18
import dataloader_redmini as dataloader
from sklearn.mixture import GaussianMixture

import torch.distributions as dist
from utils_mutual import LogFT
from utils_partial import GenCompl

parser = argparse.ArgumentParser(description="PyTorch Red Mini-ImageNet Training")
parser.add_argument("--batch_size", default=64, type=int, help="train batchsize")
parser.add_argument("--lr", "--learning_rate", default=0.02, type=float, help="initial learning rate")
parser.add_argument("--noise_mode", default="sym")
parser.add_argument("--alpha", default=0.5, type=float, help="parameter for Beta")
parser.add_argument("--lambda_u", default=25, type=float, help="weight for unsupervised loss")
parser.add_argument("--p_threshold", default=0.5, type=float, help="clean probability threshold")
parser.add_argument("--T", default=0.5, type=float, help="sharpening temperature")
parser.add_argument("--num_epochs", default=200, type=int)
parser.add_argument("--r", default=0.5, type=float, help="noise ratio")
parser.add_argument("--id", default="")
# type=int so a seed given on the command line is parsed as a number, not a string.
parser.add_argument("--seed", default=123, type=int)
parser.add_argument("--gpuid", default=0, type=int)
parser.add_argument("--num_class", default=100, type=int)
parser.add_argument(
    "--data_path",
    default="/run/media/Data/cifar-10-batches-py",
    type=str,
    help="path to dataset",
)
parser.add_argument("--topk", default=3, type=int)
parser.add_argument("--warmup_epochs", default=10, type=int)
parser.add_argument("--desc", default="", type=str)
parser.add_argument("--bce_weight", default=10, type=int)
parser.add_argument("--beta", default=0.9, type=float)

# ablation
parser.add_argument("--use_coteaching", action="store_true")
args = parser.parse_args()

torch.cuda.set_device(args.gpuid)
# Seeding is currently disabled; uncomment these lines for reproducible runs.
# random.seed(args.seed)
# torch.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4)


class AverageMeter(object):
    """Computes and stores the average and current value.

    ``update(val, n)`` treats ``val`` as the mean of ``n`` observations, so ``avg``
    is the ``n``-weighted mean of everything recorded since the last ``reset``.
    Values exposing ``.item()`` (torch / NumPy scalars) are converted to plain
    Python numbers first, so the meter never holds on to an autograd graph or
    device memory.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        # Accumulating a graph-attached loss tensor would chain every
        # iteration's autograd history into ``self.sum``.
        if hasattr(val, "item"):
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class LLD(nn.Module):
    """Per-sample running class distribution.

    ``self.latent`` holds one row per sample over ``args.num_class`` classes and
    starts uniform. ``update_hist`` mixes new predictions into selected rows with
    an exponential moving average, and ``sample_latent`` turns rows into
    temperature-sharpened ``torch.distributions.Categorical`` objects. All state
    tensors are created with ``.cuda()``.
    """

    def __init__(self, num_samples) -> None:
        super(LLD, self).__init__()
        num_class = args.num_class
        # Weight kept on the previous row in ``update_hist``.
        self.beta = args.beta
        self.latent = (torch.ones(num_samples, num_class) / num_class).cuda()
        # Stored for reference; ``sample_latent`` reads ``args.T`` directly.
        self.temp = args.T
        uniform = dist.Categorical(probs=self.latent)
        self.init_H = uniform.entropy()
        self.H = uniform.entropy()
        self.H_gain = torch.zeros(num_samples, num_class).cuda()

    def update_hist(self, probs, index):
        """EMA-update rows ``index`` of ``self.latent`` with ``probs``.

        ``probs`` is detached, clipped to ``[1e-4, 1 - 1e-4]`` and renormalised
        along dim 1, then blended as
        ``latent = beta * latent + (1 - beta) * probs``.
        """
        clipped = torch.clamp(probs, 1e-4, 1.0 - 1e-4).detach()
        renormed = clipped / clipped.sum(1, keepdim=True)
        previous = self.latent[index]
        self.latent[index] = self.beta * previous + (1 - self.beta) * renormed

    def sample_latent(self, index=None, reverse=False, lam=None, mix_idx=None):
        """Return a Categorical built from (a subset of) the stored rows.

        Args:
            index: rows to use; ``None`` uses every row.
            reverse: when ``index`` is given, use ``1 - p`` (renormalised)
                instead of ``p``. Ignored when ``index`` is ``None``.
            lam: if given (with ``index``), rows are mixed as
                ``lam * rows + (1 - lam) * rows[mix_idx]`` before sharpening.
            mix_idx: permutation/index into the selected rows used with ``lam``.

        Rows are raised to the power ``1 / args.T`` and normalised along dim 1.
        """
        exponent = 1 / args.T
        if index is None:
            unnormalised = self.latent ** exponent
        else:
            rows = self.latent[index]
            if lam is None:
                mixed = rows
            else:
                mixed = lam * rows + (1 - lam) * rows[mix_idx]
            unnormalised = mixed ** exponent
            if reverse:
                flipped = 1 - unnormalised
                unnormalised = flipped / flipped.sum(1, keepdim=True)

        return dist.Categorical(probs=unnormalised / unnormalised.sum(1, keepdim=True))



def warmup(epoch, net, optimizer, dataloader, latent, tilde_latent):
    """Run one warm-up epoch for ``net`` over ``dataloader``.

    For every batch:

    1. Forward pass.
    2. EMA-update ``latent`` from the ``outputs`` head and ``tilde_latent`` from
       the ``tildey`` head.
    3. Sample one extra class per sample from ``latent``.
    4. Build two label-set priors with ``build_partial``. They differ only in the
       result of two ``compl.full_space_partial`` calls, presumably random
       draws (confirm in ``GenCompl``).
    5. Take an optimiser step on ``ce + par + m_kl``.

    Running loss averages go to the global ``l_ce``/``l_par``/``l_kl`` meters,
    and a progress line is written to stdout.

    Args:
        epoch: current epoch, used only for the progress line.
        net: model returning ``(outputs, tildey, _)``.
        optimizer: optimiser for ``net``.
        dataloader: yields ``(inputs, labels, path)``; ``path`` is used as the
            row index into ``latent``/``tilde_latent``/``compl``.
        latent, tilde_latent: ``LLD`` histories updated in place.
    """
    net.train()
    # Exact batch count (``len(dataset) // batch_size + 1`` over-counts when
    # the dataset size is a multiple of the batch size).
    num_iter = len(dataloader)
    for batch_idx, (inputs, labels, path) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs, tildey, _ = net(inputs)
        log_outputs = outputs.log_softmax(1)
        latent.update_hist(outputs.softmax(1), path)
        tilde_latent.update_hist(tildey.softmax(1), path)
        onehot_labels = F.one_hot(labels, args.num_class).float()
        categorical = latent.sample_latent(path)
        pos1 = F.one_hot(categorical.sample(), args.num_class).float()
        pos_set1 = torch.logical_or(pos1, onehot_labels).float()
        rnd_neg1 = compl.full_space_partial(pos_set1, path)
        rnd_neg2 = compl.full_space_partial(pos_set1, path)
        _, log_prior1 = build_partial(onehot_labels, pos1, rnd_neg1)
        _, log_prior2 = build_partial(onehot_labels, pos1, rnd_neg2)

        ce = F.cross_entropy(tildey, labels)
        par = partial_loss(log_outputs, log_prior1, log_prior2)
        m_kl = m_kl_loss(tildey, log_prior1, log_prior2, log_outputs)
        loss = ce + par + m_kl

        loss.backward()
        optimizer.step()

        # Record plain floats: feeding graph-attached tensors to the meters
        # would accumulate autograd history across the whole epoch.
        l_ce.update(ce.item())
        l_par.update(par.item())
        l_kl.update(m_kl.item())
        sys.stdout.write("\r")
        sys.stdout.write(
            f"Red:{args.r} | Epoch [{epoch}/{args.num_epochs}] Iter[{batch_idx+1}/{num_iter}]\t CE: {l_ce.avg:.2f} PAR: {l_par.avg:.2f} KL: {l_kl.avg:.2f}\t COV: {m_cov.avg + m_neg_cov.avg:.2f} SIZE: {m_pri_size.avg:.2f}"
        )
        sys.stdout.flush()


def eval_train(net, dataloader):
    """Score every training sample with a 2-component GMM on its loss.

    Computes the per-sample cross-entropy of the ``tildey`` head over
    ``dataloader`` and min-max normalises it to ``[0, 1]``. It then fits a
    two-component ``GaussianMixture`` to the normalised losses. The result is
    ``1 - P(component with the lower mean)`` per sample, as a float64 CUDA
    tensor.

    NOTE(review): because this is one minus the low-loss component's
    posterior, larger values correspond to higher loss. Confirm that this is
    the orientation the consumers of the scores expect.

    Args:
        net: model returning ``(outputs, tildey, _)``.
        dataloader: yields ``(inputs, targets, index)``; its dataset must
            expose ``train_imgs`` (one entry per sample).
    """
    net.eval()
    # Size from the loader actually passed in (previously read the global
    # ``eval_loader``).
    losses = torch.zeros(len(dataloader.dataset.train_imgs))
    with torch.no_grad():
        for inputs, targets, index in dataloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            _, tildey, _ = net(inputs)
            loss = F.cross_entropy(tildey, targets, reduction="none")
            losses[index] = loss.cpu()
    # Guard the degenerate all-equal case, which would otherwise give NaNs and
    # make the GMM fit fail.
    span = (losses.max() - losses.min()).clamp_min(1e-12)
    losses = (losses - losses.min()) / span
    input_loss = losses.reshape(-1, 1)

    gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
    gmm.fit(input_loss)
    prob = gmm.predict_proba(input_loss)
    prob = prob[:, gmm.means_.argmin()]
    return 1 - torch.from_numpy(prob).cuda()


def build_partial(onehot_labels, pos, rnd_neg=None):
    """Combine label indicators into a 0/1 prior and its normalised log.

    Args:
        onehot_labels: indicator of the given label(s), shape ``(B, C)``.
        pos: additional positive indicator, same shape.
        rnd_neg: optional extra indicator added on top, same shape.

    Returns:
        ``(prior, log_prior)``, where:

        - ``prior`` is the sum of the inputs clipped to ``[0, 1]``.
        - ``log_prior`` is the log of ``prior`` normalised per row along dim 1.
          Zeros are floored at ``1e-9`` before the log.
    """
    combined = onehot_labels + pos
    if rnd_neg is not None:
        combined = combined + rnd_neg
    prior = combined.clamp(min=0.0, max=1.0)

    row_normalised = prior / prior.sum(1, keepdim=True)
    return prior, row_normalised.clamp(1e-9).log()


def train_partial(epoch, net, optimizer, dataloader, probs, latent, tilde_latent):
    """Run one post-warm-up training epoch for ``net``.

    Same objective as the warm-up (``ce + par + m_kl``). The differences are:

    - The label-set priors are built inline from
      ``fix = given label OR sampled class`` plus the result of
      ``compl.full_space_partial(fix, path, probs[path])``.
    - ``latent``/``tilde_latent`` are EMA-updated after the priors are drawn.

    Running loss averages go to the global meters and a progress line is
    written to stdout.

    Args:
        epoch: current epoch, used only for the progress line.
        net: model returning ``(outputs, tildey, _)``.
        optimizer: optimiser for ``net``.
        dataloader: yields ``(inputs, labels, path)``; ``path`` indexes
            ``probs``, ``latent``, ``tilde_latent`` and ``compl``.
        probs: per-sample scores indexed by ``path``. Presumably the output of
            ``eval_train`` for the peer network; confirm at the call site.
        latent, tilde_latent: ``LLD`` histories updated in place.
    """
    # Nothing in the loop changes the module mode, so set it once.
    net.train()
    # Exact batch count (``len(dataset) // batch_size + 1`` over-counts when
    # the dataset size is a multiple of the batch size).
    num_iter = len(dataloader)
    for batch_idx, (inputs, labels, path) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs, tildey, _ = net(inputs)
        log_outputs = outputs.log_softmax(1)
        onehot_labels = F.one_hot(labels, args.num_class).float()

        categorical = latent.sample_latent(path)
        pos1 = F.one_hot(categorical.sample(), args.num_class).float()
        fix = torch.logical_or(pos1, onehot_labels).float()

        rnd_neg1 = compl.full_space_partial(fix, path, probs[path])
        rnd_neg2 = compl.full_space_partial(fix, path, probs[path])

        prior1 = (fix + rnd_neg1).clamp(max=1.0)
        prior2 = (fix + rnd_neg2).clamp(max=1.0)

        tmp_prior1 = prior1 / prior1.sum(1, keepdim=True)
        tmp_prior2 = prior2 / prior2.sum(1, keepdim=True)
        log_prior1 = tmp_prior1.clamp(1e-9).log()
        log_prior2 = tmp_prior2.clamp(1e-9).log()

        latent.update_hist(outputs.softmax(1), path)
        tilde_latent.update_hist(tildey.softmax(1), path)

        ce = F.cross_entropy(tildey, labels)
        par = partial_loss(log_outputs, log_prior1, log_prior2)
        m_kl = m_kl_loss(tildey, log_prior1, log_prior2, log_outputs)
        loss = ce + par + m_kl

        loss.backward()
        optimizer.step()

        # Record plain floats so the meters do not keep autograd history alive.
        l_ce.update(ce.item())
        l_par.update(par.item())
        l_kl.update(m_kl.item())

        sys.stdout.write("\r")
        sys.stdout.write(
            f"Red:{args.r} | Epoch [{epoch}/{args.num_epochs}] Iter[{batch_idx+1}/{num_iter}]\t ce: {l_ce.avg:.2f} par: {l_par.avg:.2f} kl: {l_kl.avg:.2f}\t cov: {m_cov.avg:.3f}|{m_neg_cov.avg:.3f} size: {m_pri_size.avg:.3f} space: {base.avg:.3f}|{m_minus.avg:.3f}"
        )
        sys.stdout.flush()




def metric_update(inputs, path, pos_set1, prior1, ce, par, m_kl):
    """Record the mean prior-set size and the three loss terms in the global meters.

    ``inputs``, ``path`` and ``pos_set1`` are accepted but not used.

    Args:
        prior1: ``(B, C)`` prior; entries ``> 0`` count towards the set size.
        ce, par, m_kl: scalar losses (Python numbers or one-element tensors).
    """
    m_pri_size.update((prior1 > 0).sum(1).float().mean().item())
    # float() drops any autograd graph so the meters do not accumulate history.
    l_ce.update(float(ce))
    l_par.update(float(par))
    l_kl.update(float(m_kl))


def m_kl_loss(tildey, log_prior1, log_prior2, log_outputs):
    """Mean over two priors of ``KL(p_out || q_k)``.

    ``p_out = softmax(log_outputs)`` is treated as a fixed target (detached).
    ``q_k = softmax(log_softmax(tildey) + log_prior_k)``.
    Each term uses ``reduction="batchmean"``. Gradients reach ``tildey`` (and
    the priors, if they require grad) but not ``log_outputs``.
    """
    target = log_outputs.detach()
    log_tilde = tildey.log_softmax(1)

    def _kl_against_target(log_prior):
        restricted = (log_tilde + log_prior).log_softmax(1)
        return F.kl_div(restricted, target, reduction="batchmean", log_target=True)

    first = _kl_against_target(log_prior2)
    second = _kl_against_target(log_prior1)
    return (first + second) / 2


def partial_loss(log_outputs, log_prior1, log_prior2):
    """Mean over two priors of ``KL(t_k || softmax(log_outputs))``.

    ``t_k = softmax_dim1(log_prior_k + log_softmax_dim0(log_outputs))``.
    Each term uses ``reduction="batchmean"``.

    NOTE(review): the inner ``log_softmax`` normalises over dim 0 (the batch
    axis), and the target is not detached, so gradients also flow through
    ``t_k``. Both differ from ``m_kl_loss``; confirm this is intentional.
    """
    over_batch = log_outputs.log_softmax(0)
    terms = []
    for log_prior in (log_prior1, log_prior2):
        target = (log_prior + over_batch).log_softmax(1)
        terms.append(
            F.kl_div(log_outputs, target, reduction="batchmean", log_target=True)
        )
    return (terms[0] + terms[1]) / 2


def test(epoch, net1, verbose=True, net2=None):
    """Report top-1 accuracy on the global ``test_loader``.

    If ``net2`` is given, the first outputs of both networks are summed before
    taking the argmax (a simple two-model ensemble). The accuracy is always
    printed. When ``verbose`` is true it is also written to the global
    ``test_log`` and flushed.
    """
    models = (net1,) if net2 is None else (net1, net2)
    for model in models:
        model.eval()

    n_right = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in test_loader:  # type: ignore
            inputs = inputs.cuda()
            targets = targets.cuda()
            scores, _, _ = net1(inputs)
            if net2 is not None:
                scores2, _, _ = net2(inputs)
                scores = scores + scores2

            predicted = torch.max(scores, 1)[1]
            n_seen += targets.size(0)
            n_right += predicted.eq(targets).cpu().sum().item()

    acc = 100.0 * n_right / n_seen
    print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" % (epoch, acc))
    if not verbose:
        return
    test_log.write("Epoch:%d   Accuracy: %.2f%%\n" % (epoch, acc))
    test_log.flush()


def create_model():
    """Build a ``ResNet18`` with ``args.num_class`` outputs on the current CUDA device."""
    net = ResNet18(num_classes=args.num_class)
    return net.cuda()


# Output logs (per-epoch statistics and test accuracy) share one path prefix.
log_prefix = "./checkpoint/%s_%.1f_%s_%s" % ("RedMini", args.r, args.noise_mode, args.desc)
stats_log = open(log_prefix + "_stats.txt", "w")
test_log = open(log_prefix + "_acc.txt", "w")
logger = LogFT(stats_log)

warm_up = args.warmup_epochs

loader = dataloader.red_dataloader(
    r=args.r,
    noise_mode=args.noise_mode,
    batch_size=args.batch_size,
    num_workers=5,
    root_dir=args.data_path,
    log=stats_log,
    noise_file="%s/%.1f_%s.json" % (args.data_path, args.r, args.noise_mode),
)

print("| Building net")
net1 = create_model()
net2 = create_model()
optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
print("test loader")
test_loader = loader.run("test")
print("warmup loader")
warmup_trainloader = loader.run("warmup")
print("eval_train loader")
eval_loader = loader.run("eval_train")

# Per-sample running class distributions, one pair per network:
# latent* are fed from the ``outputs`` head, tilde_latent* from ``tildey``.
num_samples = len(eval_loader.dataset.train_imgs)
latent = LLD(num_samples)
latent2 = LLD(num_samples)
tilde_latent = LLD(num_samples)
tilde_latent2 = LLD(num_samples)
compl = GenCompl(num_samples, args.num_class)

m_cov = AverageMeter()
m_neg_cov = AverageMeter()
m_pri_size = AverageMeter()
m_add = AverageMeter()
m_minus = AverageMeter()
m_y_add = AverageMeter()
m_y_minus = AverageMeter()
base = AverageMeter()

l_ce = AverageMeter()
l_par = AverageMeter()
l_kl = AverageMeter()

cav_acc = AverageMeter()

# Every meter above is cleared at the end of each epoch.
epoch_meters = [
    m_cov, m_neg_cov, m_pri_size, l_ce, l_par, l_kl, cav_acc,
    m_add, m_minus, m_y_add, m_y_minus, base,
]

try:
    for epoch in range(args.num_epochs + 1):
        # Step schedule: learning rate drops 10x from epoch 100 onwards.
        lr = args.lr / 10 if epoch >= 100 else args.lr
        for opt in (optimizer1, optimizer2):
            for param_group in opt.param_groups:
                param_group["lr"] = lr

        if epoch < warm_up:
            print("Warmup Net1")
            warmup(epoch, net1, optimizer1, warmup_trainloader, latent, tilde_latent)
            print("Warmup Net2")
            warmup(epoch, net2, optimizer2, warmup_trainloader, latent2, tilde_latent2)
        else:
            # Each network is trained with the per-sample scores of its peer.
            probs = eval_train(net1, eval_loader)
            probs2 = eval_train(net2, eval_loader)

            train_partial(epoch, net1, optimizer1, warmup_trainloader, probs2.float(), latent, tilde_latent)
            train_partial(epoch, net2, optimizer2, warmup_trainloader, probs.float(), latent2, tilde_latent2)

        test(epoch, net1, verbose=True, net2=net2)
        stats_log.write(
            "Epoch:%d   Pos-cov:%.2f  Neg-cov:%.2f  Size:%.2f\n"
            % (epoch, m_cov.avg, m_neg_cov.avg, m_pri_size.avg)
        )
        stats_log.flush()
        for meter in epoch_meters:
            meter.reset()
finally:
    # Release the log files even if training is interrupted.
    stats_log.close()
    test_log.close()