import torch
import sys

from ResNet import resnet_cifar34
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import dataloader_cifarN as dataloader
from utils_mutual import LogFT
import argparse
import torch.distributions as dist

from utils_partial import GenCompl
import random
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt


# Command-line configuration. ``args`` is read as a module-level global by the
# functions and the training script below.
parser = argparse.ArgumentParser(description="PyTorch CIFAR Training")
parser.add_argument("--batch_size", default=64, type=int, help="train batchsize")
parser.add_argument("--lr", "--learning_rate", default=0.02, type=float, help="initial learning rate")
parser.add_argument("--noise_mode", default="sym")
parser.add_argument("--T", default=0.5, type=float, help="sharpening temperature")
parser.add_argument("--num_epochs", default=120, type=int)
parser.add_argument("--r", default=0.5, type=float, help="noise ratio")
parser.add_argument("--p_threshold", default=0.6, type=float)
# type=int: without it a seed given on the command line is kept as a str,
# unlike the int default.
parser.add_argument("--seed", default=123, type=int)
parser.add_argument("--gpuid", default=0, type=int)
parser.add_argument("--num_class", default=10, type=int)
parser.add_argument(
    "--data_path",
    default="/run/media/Data/cifar-10-batches-py",
    type=str,
    help="path to dataset",
)
parser.add_argument("--dataset", default="cifar10", type=str)
parser.add_argument("--topk", default=50, type=int)
parser.add_argument("--warmup_epochs", default=10, type=int)
parser.add_argument("--desc", default="", type=str)
parser.add_argument("--beta", default=0.9, type=float)
parser.add_argument("--sample_trial", default=10, type=int)
parser.add_argument("--alpha", default=4, type=float)
parser.add_argument("--aug", default="w", type=str)
parser.add_argument("--target", default="noisy_label", type=str)
args = parser.parse_args()

torch.cuda.set_device(args.gpuid)
# Seeding is currently disabled; uncomment these lines (and the cuDNN flags
# below) for repeatable runs.
# random.seed(args.seed)
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)
torch.set_printoptions(precision=2, sci_mode=False)
np.set_printoptions(precision=2, suppress=True)
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True


class AverageMeter(object):
    """Keep the most recent value together with a running total, count and mean.

    ``update(val, n)`` adds ``val`` to the total unchanged (it is NOT multiplied
    by ``n``) and adds ``n`` to the count. To get a per-item mean, pass ``val``
    as the already-summed quantity for ``n`` items, e.g. a batch total with
    ``n`` equal to the batch size.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as covering ``n`` items and refresh the running mean."""
        self.val = val
        self.sum += val
        self.count += n
        running_total, items_seen = self.sum, self.count
        self.avg = running_total / items_seen


class LLD(nn.Module):
    """Per-sample latent label distribution, tracked as an EMA of predictions.

    ``latent`` holds one row per training sample (``num_samples x num_class``).
    It starts uniform and is blended towards the model's softmax output by
    ``update_hist``. ``sample_latent`` turns some or all of these rows into a
    temperature-sharpened ``Categorical`` distribution.

    Args:
        num_samples: number of rows (dataset samples) to track.
        num_class: classes per row; defaults to ``args.num_class``.
        beta: EMA weight kept on the old history; defaults to ``args.beta``.
        temp: sharpening temperature (rows are raised to ``1 / temp``);
            defaults to ``args.T``.
        device: device for the state tensors; defaults to ``"cuda"``.
    """

    def __init__(self, num_samples, num_class=None, beta=None, temp=None, device=None) -> None:
        super().__init__()
        num_class = args.num_class if num_class is None else num_class
        device = "cuda" if device is None else device
        self.beta = args.beta if beta is None else beta
        self.temp = args.T if temp is None else temp
        # self.latent = F.one_hot(x_nt, args.num_class).float().cuda()
        self.latent = (torch.ones(num_samples, num_class) / num_class).to(device)
        # Entropy of the initial (uniform) rows plus a zeroed gain buffer; these
        # three attributes are not read anywhere else in the visible code.
        self.init_H = dist.Categorical(probs=self.latent).entropy()
        self.H = self.init_H.clone()
        self.H_gain = torch.zeros(num_samples, num_class, device=device)

    def update_hist(self, probs, index):
        """Blend ``probs`` into the history rows selected by ``index``.

        ``probs`` is detached, clamped to ``[1e-4, 1 - 1e-4]`` and renormalised
        per row, then ``latent[index] = beta * latent[index] + (1 - beta) * probs``.
        """
        probs = torch.clamp(probs, 1e-4, 1.0 - 1e-4).detach()
        self.latent[index] = self.beta * self.latent[index] + (1 - self.beta) * (
            probs / probs.sum(1, keepdim=True)
        )

    def sample_latent(self, index=None, reverse=False):
        """Return a ``Categorical`` over classes for the selected rows.

        Rows are sharpened by raising them to ``1 / temp`` and then renormalised.

        Args:
            index: rows to use; ``None`` uses every row.
            reverse: if true, use ``1 - sharpened`` instead, which puts more
                mass on classes the history considers unlikely. Applied whether
                or not ``index`` is given.
        """
        rows = self.latent if index is None else self.latent[index]
        latent_distribution = rows ** (1 / self.temp)
        if reverse:
            latent_distribution = 1 - latent_distribution
        norm_ld = latent_distribution / latent_distribution.sum(1, keepdim=True)
        return dist.Categorical(probs=norm_ld)


def warmup(epoch, net, optimizer, dataloader):
    """Run one warm-up epoch: CE on the ``tildey`` head plus the two prior losses.

    For each batch, one extra positive class per sample is drawn from the
    global ``latent`` history. That class is joined with the given label and
    passed to ``compl.full_space_partial`` twice to build two candidate-set
    priors. The loss is ``ce + partial_loss + m_kl_loss``. Both latent
    histories are updated and the running metrics are printed in place.

    Args:
        epoch: epoch index, used only for the progress line.
        net: model returning ``(outputs, tildey, _)``.
        optimizer: optimizer stepped once per batch.
        dataloader: yields ``(inputs, labels, path)``; ``path`` indexes the
            per-sample state in ``latent`` / ``compl``.
    """
    net.train()
    # len(dataloader) is the exact batch count; "N // bs + 1" over-counts by one
    # when N is a multiple of the batch size.
    num_iter = len(dataloader)
    for batch_idx, (inputs, labels, path) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs, tildey, _ = net(inputs)
        log_outputs = outputs.log_softmax(1)
        onehot_labels = F.one_hot(labels, args.num_class).float()

        # One sampled class per example, united with the given label.
        categorical = latent.sample_latent(path)
        pos1 = F.one_hot(categorical.sample(), args.num_class).float()
        pos_set1 = torch.logical_or(pos1, onehot_labels).float()
        # Two calls -> two candidate-set extensions (presumably randomised
        # inside GenCompl, which is not visible here).
        rnd_neg1 = compl.full_space_partial(pos_set1, path)
        rnd_neg2 = compl.full_space_partial(pos_set1, path)
        prior1, log_prior1 = build_partial(onehot_labels, pos1, rnd_neg1)
        _, log_prior2 = build_partial(onehot_labels, pos1, rnd_neg2)
        latent.update_hist(outputs.softmax(1), path)
        tilde_latent.update_hist(tildey.softmax(1), path)

        ce = F.cross_entropy(tildey, labels)
        par = partial_loss(log_outputs, log_prior1, log_prior2)
        m_kl = m_kl_loss(tildey, log_prior1, log_prior2, log_outputs)
        loss = ce + par + m_kl

        loss.backward()

        optimizer.step()

        metric_update(inputs, path, pos_set1, prior1, ce, par, m_kl)
        sys.stdout.write("\r")
        sys.stdout.write(
            f"{args.dataset}:{args.r}-{args.noise_mode} | Epoch [{epoch}/{args.num_epochs}] Iter[{batch_idx+1}/{num_iter}]\t CE: {l_ce.avg:.2f} PAR: {l_par.avg:.2f} KL: {l_kl.avg:.2f}\t COV: {m_cov.avg + m_neg_cov.avg:.2f} SIZE: {m_pri_size.avg:.2f}"
        )
        sys.stdout.flush()


def build_partial(onehot_labels, pos, rnd_neg=None):
    """Merge label indicators into a candidate set and its log-prior.

    Args:
        onehot_labels: indicator of the given label per row.
        pos: indicator of an extra positive class per row.
        rnd_neg: optional further indicator added to the set.

    Returns:
        ``(prior, log_prior)``. ``prior`` is the element-wise sum clamped to
        ``[0, 1]``. ``log_prior`` is the log of ``prior`` normalised to sum to
        1 per row, with entries clamped to at least ``1e-9`` before the log.
    """
    candidate = onehot_labels + pos
    if rnd_neg is not None:
        candidate = candidate + rnd_neg
    prior = candidate.clamp(min=0.0, max=1.0)
    row_normalised = prior / prior.sum(1, keepdim=True)
    return prior, row_normalised.clamp(1e-9).log()


# Normalised per-sample loss vectors produced by eval_train, oldest first.
# eval_train keeps only the latest 5, which it averages when args.r == 0.9.
all_loss = []


def eval_train(net, dataloader):
    """Score every training sample with a 2-component GMM fitted to its loss.

    The per-sample cross-entropy of the ``tildey`` head against the loader's
    targets is min-max normalised to ``[0, 1]``. It is appended to the
    module-level ``all_loss`` history and modelled with a Gaussian mixture.

    Args:
        net: model returning ``(outputs, tildey, _)``; switched to eval mode.
        dataloader: yields ``(inputs, targets, index)``, where ``index`` is
            each sample's position in ``dataloader.dataset``.

    Returns:
        float64 CUDA tensor of shape ``(len(dataloader.dataset),)`` holding
        ``1 - P(low-mean component | loss)``, i.e. the posterior of the
        higher-loss component, per sample.
    """
    net.eval()
    losses = torch.zeros(len(dataloader.dataset))
    with torch.no_grad():
        for inputs, targets, index in dataloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            _, tildey, _ = net(inputs)
            loss = F.cross_entropy(tildey, targets, reduction="none")
            # Scatter the whole batch at once rather than element by element.
            losses[index] = loss.cpu()
    # Min-max normalise; the clamp avoids 0/0 if every loss is identical.
    loss_min = losses.min()
    losses = ((losses - loss_min) / (losses.max() - loss_min).clamp_min(1e-12)).unsqueeze(1)
    all_loss.append(losses)
    del all_loss[:-5]  # only the most recent 5 entries are ever read back

    if args.r == 0.9:  # average loss over last 5 epochs to improve convergence stability
        input_loss = torch.stack(all_loss).mean(0).reshape(-1, 1)
    else:
        input_loss = losses.reshape(-1, 1)

    # fit a two-component GMM to the loss
    gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
    gmm.fit(input_loss)
    prob = gmm.predict_proba(input_loss)
    prob = prob[:, gmm.means_.argmin()]
    return 1 - torch.from_numpy(prob).cuda()


def train_partial(epoch, net, optimizer, dataloader, probs):
    """Run one post-warm-up epoch with sampled partial-label priors.

    For each batch, one extra positive class per sample is drawn from the
    global ``latent`` history and joined with the given label. Two candidate
    sets are built through ``compl.full_space_partial``, which also receives
    that batch's ``probs``. The loss is
    ``cross_entropy(tildey) + partial_loss + m_kl_loss``.

    Args:
        epoch: epoch index, used only for the progress line.
        net: model returning ``(outputs, tildey, _)``.
        optimizer: optimizer stepped once per batch.
        dataloader: yields ``(inputs, labels, path)``; ``path`` indexes the
            per-sample state and ``probs``.
        probs: per-sample scores indexed by ``path`` (the caller passes the
            output of ``eval_train``). How GenCompl uses them is not visible here.
    """
    net.train()
    # len(dataloader) is the exact batch count; "N // bs + 1" over-counts by one
    # when N is a multiple of the batch size.
    num_iter = len(dataloader)
    for batch_idx, (inputs, labels, path) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs, tildey, _ = net(inputs)
        log_outputs = outputs.log_softmax(1)
        onehot_labels = F.one_hot(labels, args.num_class).float()

        # One sampled class per example from its latent history.
        categorical = latent.sample_latent(path)
        pos1 = F.one_hot(categorical.sample(), args.num_class).float()
        # Three identical {label} ∪ {sampled class} masks kept as separate tensors.
        # NOTE(review): presumably because GenCompl.full_space_partial may modify
        # its input in place (not visible here) — confirm before merging them.
        pos_set1 = (pos1 + onehot_labels).clamp(max=1.0)
        fix1 = (pos1 + onehot_labels).clamp(max=1.0)
        fix2 = (pos1 + onehot_labels).clamp(max=1.0)

        rnd_neg1 = compl.full_space_partial(fix1, path, probs[path])
        rnd_neg2 = compl.full_space_partial(fix2, path, probs[path])
        prior1 = (fix1 + rnd_neg1).clamp(max=1.0)
        prior2 = (fix2 + rnd_neg2).clamp(max=1.0)

        # Row-normalise each candidate set and take logs, clamped away from log(0).
        log_prior1 = (prior1 / prior1.sum(1, keepdim=True)).clamp(1e-9).log()
        log_prior2 = (prior2 / prior2.sum(1, keepdim=True)).clamp(1e-9).log()

        latent.update_hist(outputs.softmax(1), path)
        tilde_latent.update_hist(tildey.softmax(1), path)

        ce = F.cross_entropy(tildey, labels)
        par = partial_loss(log_outputs, log_prior1, log_prior2)
        m_kl = m_kl_loss(tildey, log_prior1, log_prior2, log_outputs)
        loss = ce + par + m_kl

        loss.backward()

        optimizer.step()

        metric_update(inputs, path, pos_set1, prior1, ce, par, m_kl)
        sys.stdout.write("\r")
        sys.stdout.write(
            f"{args.dataset}:{args.r} | Epoch [{epoch}/{args.num_epochs}] Iter[{batch_idx+1}/{num_iter}]\t ce: {l_ce.avg:.2f} par: {l_par.avg:.2f} kl: {l_kl.avg:.2f}\t cov: {m_cov.avg:.3f}|{m_neg_cov.avg:.3f} size: {m_pri_size.avg:.3f}"
        )
        sys.stdout.flush()


def metric_update(inputs, path, pos_set1, prior1, ce, par, m_kl):
    """Accumulate per-batch diagnostics into the module-level meters.

    - ``m_cov``: clean labels (``x_ct``) inside the positive part of the prior.
    - ``m_neg_cov``: clean labels inside the rest of the prior.
    - ``m_pri_size``: mean number of non-zero prior entries per sample.
    - ``l_ce`` / ``l_par`` / ``l_kl``: the three loss terms.

    The coverage counts are batch totals with ``n`` set to the batch size, so
    the meters' ``avg`` is a per-sample rate.
    """
    batch_size = inputs.shape[0]
    clean_onehot = F.one_hot(x_ct[path], args.num_class)
    m_cov.update(
        torch.logical_and(prior1 * pos_set1, clean_onehot).sum().item(),
        batch_size,
    )
    m_neg_cov.update(
        torch.logical_and(prior1 * (1 - pos_set1), clean_onehot).sum().item(),
        batch_size,
    )
    m_pri_size.update((prior1 > 0).sum(1).float().mean().item())
    # Convert to Python floats: accumulating the loss tensors themselves keeps
    # their autograd history alive for the whole epoch (memory growth) and
    # makes each meter's running average a grad-tracking tensor.
    l_ce.update(float(ce))
    l_par.update(float(par))
    l_kl.update(float(m_kl))


def m_kl_loss(tildey, log_prior1, log_prior2, log_outputs):
    """Average of two KL terms pulling the prior-weighted ``tildey`` head towards ``outputs``.

    For each log-prior, ``tildey``'s log-softmax is shifted by the log-prior
    and renormalised over classes. It is then compared with
    ``F.kl_div(..., log_target=True, reduction="batchmean")`` against the
    detached ``log_outputs``, so no gradient reaches the ``outputs`` head
    through this term.
    """
    log_q = tildey.log_softmax(1)
    frozen_target = log_outputs.detach()
    per_prior = [
        F.kl_div(
            (log_q + log_prior).log_softmax(1),
            frozen_target,
            reduction="batchmean",
            log_target=True,
        )
        for log_prior in (log_prior1, log_prior2)
    ]
    return (per_prior[0] + per_prior[1]) / 2


def partial_loss(log_outputs, log_prior1, log_prior2):
    """Average KL(target || outputs) over two prior-shaped targets.

    Each target is ``log_softmax_dim1(log_prior + log_softmax_dim0(log_outputs))``.
    This is the output distribution renormalised across the batch axis,
    masked or weighted by the prior, and renormalised over classes.

    NOTE(review): ``log_softmax(0)`` normalises over the batch dimension, and
    the target is not detached (unlike ``m_kl_loss``), so gradients also flow
    through it. Confirm both are intentional.
    """
    over_batch = log_outputs.log_softmax(0)

    def _kl_to_prior(log_prior):
        target = (log_prior + over_batch).log_softmax(1)
        return F.kl_div(log_outputs, target, reduction="batchmean", log_target=True)

    return (_kl_to_prior(log_prior1) + _kl_to_prior(log_prior2)) / 2


def test(epoch, net1, verbose=True):
    """Report the top-1 accuracy of ``net1`` on the module-level ``test_loader``.

    Only the first network output is scored. The accuracy is always printed;
    when ``verbose`` is true it is also appended to ``test_log`` and flushed.
    """
    net1.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, gold in test_loader:  # type: ignore
            images = images.cuda()
            gold = gold.cuda()
            logits, _, _ = net1(images)
            predicted = logits.max(1)[1]
            n_seen += gold.size(0)
            n_correct += predicted.eq(gold).cpu().sum().item()
    acc = 100.0 * n_correct / n_seen
    print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" % (epoch, acc))
    if not verbose:
        return
    test_log.write("Epoch:%d   Accuracy: %.2f%%\n" % (epoch, acc))
    test_log.flush()


def create_model():
    """Build ``resnet_cifar34`` for ``args.num_class`` classes and move it to the GPU."""
    return resnet_cifar34(args.num_class).cuda()


import os

# open(..., "w") below fails if the output directory does not exist yet.
os.makedirs("./checkpoint", exist_ok=True)
stats_log = open(
    "./checkpoint/%s_%.1f_%s_%s" % (args.dataset, args.r, args.noise_mode, args.desc) + "_stats.txt",
    "w",
)
test_log = open(
    "./checkpoint/%s_%.1f_%s_%s" % (args.dataset, args.r, args.noise_mode, args.desc) + "_acc.txt",
    "w",
)
logger = LogFT(stats_log)

warm_up = args.warmup_epochs

loader = dataloader.cifar_dataloader(
    args.dataset,
    batch_size=args.batch_size,
    num_workers=5,
    root_dir=args.data_path,
    log=stats_log,
)

net1 = create_model()
optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
print("test loader")
test_loader = loader.run("test")
print("warmup loader")
warmup_trainloader = loader.run("warmup", target=args.target)
print("eval_train loader")
eval_loader = loader.run("eval_train", target=args.target)
x_nt = torch.tensor(eval_loader.dataset.noise_label).cuda()
x_ct = torch.tensor(eval_loader.dataset.clean_label).cuda()

# One entry per training sample; derived from the label tensor so the plot
# ranges below always match the clean/noisy masks.
num_samples = x_nt.shape[0]
latent = LLD(num_samples)
tilde_latent = LLD(num_samples)
compl = GenCompl(num_samples, args.num_class)
m_cov = AverageMeter()
m_neg_cov = AverageMeter()
m_pri_size = AverageMeter()
m_add = AverageMeter()
m_minus = AverageMeter()
m_y_add = AverageMeter()
m_y_minus = AverageMeter()
base = AverageMeter()

l_ce = AverageMeter()
l_par = AverageMeter()
l_kl = AverageMeter()

cav_acc = AverageMeter()


def _plot_clean_vs_noisy(per_sample, filename):
    """Scatter ``per_sample`` values: truly-clean samples in blue, then truly-noisy in red; save to ``filename``."""
    clean_mask = x_ct == x_nt
    n_clean = clean_mask.nonzero().shape[0]
    plt.scatter(
        range(n_clean),
        per_sample[clean_mask].detach().cpu().numpy(),
        color="blue",
        label="Clean Data",
        alpha=0.05,
    )
    plt.scatter(
        range(n_clean, num_samples),
        per_sample[x_ct != x_nt].detach().cpu().numpy(),
        color="red",
        label="Noisy Data",
        alpha=0.05,
    )
    plt.legend()
    plt.savefig(filename)
    plt.clf()


for epoch in range(args.num_epochs + 1):
    # Step schedule: the learning rate drops 10x from epoch 80 onwards.
    lr = args.lr / 10 if epoch >= 80 else args.lr
    for param_group in optimizer1.param_groups:
        param_group["lr"] = lr

    if epoch < args.warmup_epochs:
        warmup(epoch, net1, optimizer1, warmup_trainloader)
    else:
        probs = eval_train(net1, eval_loader)
        _plot_clean_vs_noisy(probs, "distance.png")
        _plot_clean_vs_noisy(compl.keep_list.sum(1), "weight.png")

        train_partial(epoch, net1, optimizer1, warmup_trainloader, probs.float())

    test(epoch, net1, verbose=True)
    stats_log.write(
        "Epoch:%d   Pos-cov:%.2f  Neg-cov:%.2f  Size:%.2f\n"
        % (epoch, m_cov.avg, m_neg_cov.avg, m_pri_size.avg)
    )
    stats_log.flush()
    for meter in (
        m_cov,
        m_neg_cov,
        m_pri_size,
        l_ce,
        l_par,
        l_kl,
        cav_acc,
        m_add,
        m_minus,
        m_y_add,
        m_y_minus,
        base,
    ):
        meter.reset()

stats_log.close()
test_log.close()
