import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dataset import DatasetGenerator, WebVisionDatasetLoader
from models import CNN, ResNet18, ResNet34, change_running_stats
from torchvision.models import resnet50
import criterions
import random

import nni

from torch.utils.tensorboard import SummaryWriter
from utils import Timer, init_seeds, Timestats
from torch_scatter import scatter_sum

torch.set_float32_matmul_precision('high')

# Dataset names accepted by --dataset and --param. For --param, "none" means
# "use the same value as --dataset".
DATASET_CHOICES = ["MNIST", "CIFAR100", "CIFAR10", "WEBVISION10", "WEBVISION50", "WEBVISION100",
                   "WEBVISION200", "WEBVISION400", "WEBVISION1000", "none"]

parser = argparse.ArgumentParser(description='Robust loss for learning with noisy labels')
parser.add_argument('--data_root', type=str, default="./data-bin/", help='the data root')
parser.add_argument('--output_dir', type=str, default=None, help='the output_dir')

parser.add_argument('--dataset', type=str, default="CIFAR100", choices=DATASET_CHOICES,
                    metavar='DATA', help='Dataset name (default: CIFAR100)')

parser.add_argument('--param', type=str, default="none", choices=DATASET_CHOICES,
                    metavar='DATA',
                    help='Dataset whose preset loss/regularizer hyperparameters are used '
                         '(default: none, i.e. same as --dataset)')

parser.add_argument('--noise_type', type=str, default='symmetric', choices=["symmetric", "asymmetric", "human"])
parser.add_argument('--noise_rate', type=float, default=0.4, help='the noise rate')

# --shift / --scale are the string "none" (disabled) or a number parsed with float() later.
parser.add_argument('--shift', default="none",
                    help='Hyperparameter for shift-fix of underfitting.')
parser.add_argument('--scale', default="none",
                    help='Hyperparameter for scale-fix of underfitting.')
# NOTE(review): --ema is parsed but not read anywhere else in this script.
parser.add_argument('--ema',  default="none",
                    help='Hyperparameter for ema-fix of underfitting.')

parser.add_argument('--reg', type=str, default='CST', choices=["LS", "NLS", "RMSE", "CR", "CST"])
parser.add_argument('--alpha', type=float, default=0.1)

parser.add_argument('--loss', type=str, default='CE')
parser.add_argument('--q', type=float, default=None, help='Override hyperparamter of robust loss functions.')
parser.add_argument('--a', type=float, default=None, help='Override hyperparamter of robust loss functions.')

# other learning settings
parser.add_argument('--num_workers', type=int, default=10, help='number of workers for loading data')
parser.add_argument('--grad_bound', type=float, default=1., help='the gradient norm bound')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--track', action="store_true", default=False,
                    help='Whether capturing metrics of training at each epoch')
# A value of -1 for the next four options means "keep the built-in default".
parser.add_argument('--epochs', type=int, default=-1)
parser.add_argument('--lr', type=float, default=-1)
parser.add_argument('--wd', type=float, default=-1)
parser.add_argument('--momentum', type=float, default=-1)
parser.add_argument('--schedule', type=str, default="cos", choices=["constant", "cos"])
parser.add_argument('--optim', type=str, default="sgd", choices=["sgd", "adam"])

args = parser.parse_args()
# Normalize names used as lookup keys (criterions attributes / params table).
args.loss = args.loss.upper()
args.dataset = args.dataset.upper()
print(args)
if args.output_dir is not None:
    os.makedirs(args.output_dir, exist_ok=True)

init_seeds(args.seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Defaults shared by every dataset. epochs, lr, weight_decay and momentum can
# be overridden from the command line below.
momentum = 0.9
batch_size = 128
weight_decay = 5e-4
lr = 0.1

# Per-dataset input channels, number of classes and default epoch budget.
if args.dataset == 'MNIST':
    in_channels = 1
    num_classes = 10
    epochs = 50
elif args.dataset == 'CIFAR10':
    in_channels = 3
    num_classes = 10
    epochs = 120
elif args.dataset == 'CIFAR100':
    in_channels = 3
    num_classes = 100
    epochs = 200
elif 'WEBVISION' in args.dataset:
    in_channels = 3
    # e.g. "WEBVISION50" -> 50 classes
    num_classes = int(args.dataset.replace("WEBVISION", ""))
    epochs = 250
else:
    raise ValueError('Invalid value {}'.format(args.dataset))

# override with command line args; their default of -1 keeps the values above
if args.epochs > 0:
    print("update epochs from {} to {}".format(epochs, args.epochs))
    epochs = args.epochs
if args.lr > 0:
    print("update lr from {} to {}".format(lr, args.lr))
    lr = args.lr

# 0 is a meaningful weight decay, hence >= rather than >.
if args.wd >= 0:
    print("update wd from {} to {}".format(weight_decay, args.wd))
    weight_decay = args.wd

# --momentum used to be parsed but never applied. 0 is a valid momentum
# (plain SGD), so mirror the weight-decay check.
if args.momentum >= 0:
    print("update momentum from {} to {}".format(momentum, args.momentum))
    momentum = args.momentum

if args.dataset in ['CIFAR100', 'CIFAR10', 'MNIST']:
    data_loader = DatasetGenerator(train_batch_size=batch_size,
                                   eval_batch_size=batch_size,
                                   data_path=os.path.join(args.data_root, args.dataset),
                                   num_of_workers=args.num_workers,
                                   seed=args.seed,
                                   noise_type=args.noise_type,
                                   dataset=args.dataset,
                                   noise_rate=args.noise_rate)
elif "WEBVISION" in args.dataset:
    # NOTE(review): seed / noise_type / noise_rate are not forwarded here;
    # presumably WebVision labels are used as-is. Confirm.
    data_loader = WebVisionDatasetLoader(num_class=num_classes,
                                         train_batch_size=batch_size,
                                         eval_batch_size=batch_size,
                                         data_path=os.path.join(args.data_root, "WEBVISION"),
                                         num_of_workers=args.num_workers)
else:
    raise ValueError('Invalid value {}'.format(args.dataset))

# getDataLoader() yields a mapping from split name to loader. "train_eval" is
# optional and falls back to the training loader.
data_loader = data_loader.getDataLoader()
train_loader = data_loader['train']
test_loader = data_loader['test']
train_eval_loader = data_loader.get("train_eval", train_loader)

if args.dataset in ['MNIST', 'CIFAR10']:
    model = CNN(dataset=args.dataset).to(device)
elif args.dataset == 'CIFAR100':
    model = ResNet34(num_classes=num_classes).to(device)
elif "WEBVISION" in args.dataset:
    model = resnet50(num_classes=num_classes).to(device)
else:
    raise ValueError('Invalid value {}'.format(args.dataset))

if args.param.lower() == "none":
    args.param = args.dataset
# Merge the preset loss and regularizer hyperparameters for args.param. The
# dict() copy keeps the overrides below from mutating criterions.params itself.
params = dict(criterions.params.get(args.param, dict()).get(args.loss, dict()))
params.update(criterions.params.get(args.param, dict()).get(args.reg, dict()))

print("Original:", params)
if args.a is not None:
    params.update({"a": args.a})
if args.q is not None:
    params.update({"q": args.q})
# NOTE(review): --alpha defaults to 0.1 (not None), so this override always
# fires and any preset "alpha" is ignored. Confirm this is intended.
if args.alpha is not None:
    params.update({"alpha": args.alpha})
# get_next_parameter() may return None when no trial parameters are
# available; treat that as "no overrides".
nni_params = nni.get_next_parameter() or dict()
params.update(nni_params)
print("Updated:", params)
# .to(device) instead of .cuda() so CPU-only runs work, matching the model.
loss_func = getattr(criterions, args.loss)(**params).to(device)
if args.shift != "none":
    loss_func = criterions.ShiftedWeightedLoss(loss_func, num_classes, nni_params.get("q", float(args.shift)))
elif args.scale != "none":
    loss_func = criterions.ScaledWeightedLoss(loss_func, num_classes, nni_params.get("q", float(args.scale)))

regularizer = getattr(criterions, args.reg)(dataloader=train_loader, **params).to(device)

if args.optim == "sgd":
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
elif args.optim == "adam":
    # Honour an explicit --lr. Otherwise keep Adam's own default learning rate
    # rather than the SGD-oriented 0.1 set above (the original behaviour).
    if args.lr > 0:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay)
else:
    raise ValueError('Invalid optimizer {}'.format(args.optim))

if args.schedule == "cos":
    # One scheduler step per epoch (see the training loop), annealing to 0.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.0)
elif args.schedule == "constant":
    # factor=1 leaves the learning rate unchanged.
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, total_iters=1000)
else:
    raise ValueError('Invalid schedule {}'.format(args.schedule))


def fetch_cuda(*largs, device=None):
    """Move every tensor in ``largs`` to ``device``; pass other values through.

    Args:
        *largs: Values to transfer. Non-tensor values are yielded unchanged.
        device: Target device. Defaults to CUDA when available and CPU
            otherwise (the same rule used for the model), so the script also
            runs on CPU-only machines instead of failing on ``.cuda()``.

    Returns:
        A generator yielding the values in their original order, intended to
        be unpacked directly, e.g. ``x, y = fetch_cuda(x, y)``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return (v.to(device) if isinstance(v, torch.Tensor) else v for v in largs)


def evaluate(loader):
    """Return the global ``model``'s top-1 accuracy on ``loader``, in percent.

    Each batch must unpack as ``(x, y, g, n, i)``. The first returned value
    compares the argmax prediction with ``y``, the second with ``g``
    (presumably the observed and the ground-truth labels; confirm against the
    dataset loaders). ``n`` and ``i`` are transferred but unused. Runs without
    autograd and leaves the model in eval mode.
    """
    model.eval()
    hits_y = hits_g = seen = 0.
    with torch.no_grad():
        for batch_x, batch_y, batch_g, batch_n, batch_i in loader:
            batch_x, batch_y, batch_g, batch_n, batch_i = fetch_cuda(
                batch_x, batch_y, batch_g, batch_n, batch_i)
            seen += batch_y.size(0)
            predicted = model(batch_x).argmax(dim=1)
            hits_y += predicted.eq(batch_y).sum().item()
            hits_g += predicted.eq(batch_g).sum().item()
    acc_y = hits_y / seen * 100
    acc_g = hits_g / seen * 100
    return acc_y, acc_g


def extract(loader):
    """Collect per-sample ``criterions.get_delta`` values of the global model.

    The model runs in train mode with BatchNorm running-stat tracking switched
    off via ``change_running_stats``, so these extra forward passes do not
    disturb the running statistics. Results are reordered by the per-sample
    index ``i`` and split by the noise flag ``n`` (nonzero means noisy).

    Returns:
        dict of CPU tensors with keys:
            "clean":   delta w.r.t. ``y`` for samples whose ``n`` is 0,
            "noise_y": delta w.r.t. ``y`` for samples whose ``n`` is nonzero,
            "noise_g": delta w.r.t. ``g`` for samples whose ``n`` is nonzero.
    """
    model.train()
    # extracting statistics should not affect the running stats of batchnorm
    change_running_stats(model, track=False)
    ids, masks, per_y, per_g = [], [], [], []
    with torch.no_grad():
        for x, y, g, n, i in loader:
            x, y, g, n, i = fetch_cuda(x, y, g, n, i)
            out = model(x)
            ids.append(i)
            masks.append(n.bool())
            per_y.append(criterions.get_delta(out, y))
            per_g.append(criterions.get_delta(out, g))
        order = torch.sort(torch.cat(ids, dim=0)).indices
        is_noisy = torch.cat(masks, dim=0)[order]
        all_y = torch.cat(per_y, dim=0)[order]
        all_g = torch.cat(per_g, dim=0)[order]
    change_running_stats(model, track=True)
    is_clean = ~is_noisy
    result = dict()
    result["clean"] = all_y[is_clean].cpu()
    result["noise_y"] = all_y[is_noisy].cpu()
    result["noise_g"] = all_g[is_noisy].cpu()
    return result


timer = Timer()
writer = None
metrics = []
# Most recent test accuracy. It is reported to NNI as the final result, even if
# training stops early; None until the first epoch has been evaluated.
final_acc = None
if args.output_dir is not None:
    writer = SummaryWriter(args.output_dir)

try:
    if args.track:
        data = extract(train_loader)
        metrics.append(({"delta": data, "epoch": 0}))
    for epoch in range(1, epochs + 1):
        metric = dict()
        model.train()
        total_loss = 0.
        total_reg = 0.
        total_delta_norm = 0.
        total_count = 0.
        lr = scheduler.get_last_lr()[0]
        for x, y, g, n, i in train_loader:
            x, y, g, n, i = fetch_cuda(x, y, g, n, i)
            logits = model(x)
            losses = loss_func(logits, y)
            loss = losses.mean()
            reg = regularizer(logits)
            (loss + reg).backward()

            # NOTE: each parameter tensor is clipped to grad_bound on its own
            # (not one global norm). gnorms holds the per-tensor norms returned
            # by clip_grad_norm_ for the last batch of the epoch.
            gnorms = [torch.nn.utils.clip_grad_norm_(p, args.grad_bound)
                      for p in model.parameters() if p.grad is not None]
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            total_reg += reg.item()
            # Gradient of the unregularized mean loss w.r.t. the logits, and its
            # per-sample L1 norm.
            delta_grad = torch.autograd.functional.jacobian(lambda s: loss_func(s, y).mean(), logits)
            delta_grad_norm = delta_grad.norm(p=1, dim=-1)
            # Bucket per-sample norms by the noise flag (0 -> clean, 1 -> noisy,
            # the same truthiness rule extract() uses). dim_size=2 keeps the
            # accumulators a fixed length even for batches with no noisy
            # sample; otherwise a length-1 result would broadcast into both
            # buckets or fail to add in place.
            bucket = n.bool().long()
            total_delta_norm += scatter_sum(delta_grad_norm, bucket, dim_size=2)
            total_count += scatter_sum(torch.ones_like(delta_grad_norm), bucket, dim_size=2)

        # Share (in %) of the accumulated loss-gradient norm from clean samples.
        signal = (total_delta_norm / torch.sum(total_delta_norm))[0].item() * 100
        norm = (torch.sum(total_delta_norm) / torch.sum(total_count)).item()
        # Filter on requires_grad rather than p.grad: optimizer.zero_grad() sets
        # grads to None by default in recent PyTorch, which would leave this
        # list empty.
        pnorms = [torch.norm(p.detach(), 2).item() for p in model.parameters() if p.requires_grad]
        metric["pnorms"] = pnorms
        metric["gnorms"] = gnorms
        metric["signal"] = signal
        metric["lr"] = lr
        metric["norm"] = norm
        metric["elr"] = lr * norm
        metric["epoch"] = epoch
        metric["acc"] = dict()

        if args.track:
            metric["delta"] = extract(train_loader)
            metric["acc"]["train"], metric["acc"]["train_g"] = evaluate(train_eval_loader)

        metric["acc"]["test"], _ = evaluate(test_loader)
        final_acc = metric["acc"]["test"]
        elapse = timer.tick()
        metrics.append(metric)
        nni.report_intermediate_result(metric["acc"]["test"])
        if writer is not None:
            writer.add_scalar("signal", signal, epoch)
            writer.add_scalar("lr", lr, epoch)
            writer.add_scalar("elr", metric["elr"], epoch)
            writer.add_scalar("norm", metric["norm"], epoch)
            writer.add_scalars("acc", metric["acc"], epoch)
            writer.flush()

        logs = [f'epoch {epoch}',
                f'loss={total_loss:.4f}',
                f'reg={total_reg:.4f}',
                f'lr={lr:g}',
                f'signal={signal:.2f}',
                f'norm={norm:g}',
                *(f'{k}_acc={v:.2f}' for k, v in metric["acc"].items()),
                f'elapse={elapse}']
        print(", ".join(logs))
        scheduler.step()
finally:
    if writer is not None:
        writer.close()
    if args.output_dir is not None and metrics:
        torch.save(metrics, os.path.join(args.output_dir, "metrics.pt"))
    # Guarded so an early failure/interrupt is not masked by a NameError or
    # KeyError on a half-built `metric`.
    if final_acc is not None:
        nni.report_final_result(final_acc)
