import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dataset import DatasetGenerator, WebVisionDatasetLoader
from models import CNN, ResNet34
from torchvision.models import resnet50, resnet34
import loss_funcs
import random
from tensorboardX import SummaryWriter
from utils import Timer, init_seeds
from torch_scatter import scatter_sum
from torch.cuda.amp import GradScaler, autocast

parser = argparse.ArgumentParser(description='Robust loss for learning with noisy labels')
parser.add_argument('--data_root', type=str, default="./data-bin/", help='the data root')
parser.add_argument('--output_dir', type=str, default="./output/", help='the output_dir')

# Dataset / label-noise settings.
# `type=str.upper` normalises the name *before* argparse validates `choices`
# (argparse applies `type` first, also to a string default), so e.g.
# `--dataset cifar10` is accepted instead of being rejected as an invalid choice.
parser.add_argument('--dataset', type=str.upper, default="CIFAR100",
                    choices=["CIFAR100", "CIFAR10", "MNIST", "WEBVISION50", "WEBVISION200", "WEBVISION400"],
                    metavar='DATA', help='Dataset name (default: CIFAR100)')
parser.add_argument('--noise_type', type=str, default='symmetric', choices=["symmetric", "asymmetric", "human"])
parser.add_argument('--noise_rate', type=float, default=0.4, help='the noise rate')
parser.add_argument('--fix', type=str, default="none", choices=["none", "shift", "scale"],
                    help='Fix of underfitting.')
parser.add_argument('--tau', type=float, default=1,
                    help='Hyperparameter for underfitting fix.')

# Loss function: name of a callable in `loss_funcs` (upper-cased below).
parser.add_argument('--loss', type=str, default='CE')
parser.add_argument('--q', type=float, default=-1, help='Override hyperparameter of robust loss functions.')
parser.add_argument('--a', type=float, default=-1, help='Override hyperparameter of robust loss functions.')

# other learning settings
parser.add_argument('--num_workers', type=int, default=2, help='number of workers for loading data')
parser.add_argument('--grad_bound', type=float, default=5., help='the gradient norm bound')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--extract', action="store_true", default=False,
                    help='Whether extracting Delta_y of clean and noisy samples of the training set at each epoch')
# The -1 defaults below are sentinels meaning "keep the per-dataset default".
parser.add_argument('--epochs', type=int, default=-1,
                    help='Override the number of epochs; a non-positive value keeps the dataset default.')
parser.add_argument('--lr', type=float, default=-1,
                    help='Override the initial learning rate; a non-positive value keeps the dataset default.')
parser.add_argument('--momentum', type=float, default=-1,
                    help='Override the SGD momentum; a negative value keeps the dataset default.')
parser.add_argument('--schedule', type=str, default="cos", choices=["constant", "cos"],
                    help='Learning-rate schedule: cosine annealing or constant.')
parser.add_argument('--fp16', action="store_true", default=False,
                    help='Use mixed precision (autocast + GradScaler).')

args = parser.parse_args()
args.loss = args.loss.upper()
# Kept for safety; already upper-case thanks to `type=str.upper` above.
args.dataset = args.dataset.upper()
print(args)
os.makedirs(args.output_dir, exist_ok=True)

init_seeds(args.seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-dataset default hyper-parameters. epochs / lr / momentum may be
# overridden from the command line further down.
momentum = 0.9
if args.dataset == 'MNIST':
    batch_size = 128
    in_channels = 1
    num_classes = 10
    weight_decay = 1e-3
    lr = 0.01
    epochs = 50
    grad_acc = 1
elif args.dataset == 'CIFAR10':
    batch_size = 128
    in_channels = 3
    num_classes = 10
    weight_decay = 1e-4
    lr = 0.01
    epochs = 120
    grad_acc = 1
elif args.dataset == 'CIFAR100':
    batch_size = 128
    in_channels = 3
    num_classes = 100
    weight_decay = 1e-5
    lr = 0.1
    epochs = 200
    grad_acc = 1
elif 'WEBVISION' in args.dataset:
    in_channels = 3
    # e.g. "WEBVISION50" -> 50 classes
    num_classes = int(args.dataset.replace("WEBVISION", ""))
    weight_decay = 3e-5
    lr = 0.2
    epochs = 250
    batch_size = 64
    # Number of micro-batches whose gradients are accumulated per optimizer step.
    grad_acc = 2
else:
    raise ValueError('Invalid value {}'.format(args.dataset))

# Override with command line args. The CLI defaults are -1 sentinels that
# mean "keep the dataset default".
if args.epochs > 0:
    print("update epochs from {} to {}".format(epochs, args.epochs))
    epochs = args.epochs
if args.lr > 0:
    print("update lr from {} to {}".format(lr, args.lr))
    lr = args.lr
# Momentum 0 (plain SGD) is a legitimate override, so only negative values
# are treated as "not set". With `> 0`, `--momentum 0` would silently keep 0.9.
if args.momentum >= 0:
    print("update momentum from {} to {}".format(momentum, args.momentum))
    momentum = args.momentum

# Build the train / test loaders for the selected dataset. The synthetic-noise
# datasets take the noise settings; WebVision only needs its class count.
if "WEBVISION" in args.dataset:
    data_loader = WebVisionDatasetLoader(
        num_class=num_classes,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        data_path=os.path.join(args.data_root, "WEBVISION"),
        num_of_workers=args.num_workers,
    )
elif args.dataset in ('CIFAR100', 'CIFAR10', 'MNIST'):
    data_loader = DatasetGenerator(
        train_batch_size=batch_size,
        data_path=os.path.join(args.data_root, args.dataset),
        num_of_workers=args.num_workers,
        seed=args.seed,
        noise_type=args.noise_type,
        dataset=args.dataset,
        noise_rate=args.noise_rate,
    )
else:
    raise ValueError

# `data_loader` is rebound to the dict of loaders returned by the builder.
data_loader = data_loader.getDataLoader()
train_loader, test_loader = data_loader['train_dataset'], data_loader['test_dataset']

# Backbone per dataset: project CNN for CIFAR10/MNIST, project ResNet34 for
# CIFAR100, torchvision ResNet-50 for WebVision.
if args.dataset == 'CIFAR100':
    model = ResNet34(num_classes=num_classes)
elif args.dataset in ('CIFAR10', 'MNIST'):
    model = CNN(type=args.dataset)
elif "WEBVISION" in args.dataset:
    model = resnet50(num_classes=num_classes)
else:
    raise ValueError
model = model.to(device)

# The loss class is looked up by (upper-cased) name in `loss_funcs`.
loss_func = getattr(loss_funcs, args.loss)(
    a=args.a,
    q=args.q,
    fix=args.fix,
    tau=args.tau,
)
loss_func.update(num_classes)

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

if args.schedule != "cos":
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, total_iters=1000)
else:
    # Anneal from `lr` down to 0 over the full run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.0)


def evaluate(loader):
    """Return the top-1 accuracy of the global ``model`` on ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is expected to be an ``(inputs, targets, _)`` triple; the third
    element is ignored. Returns a float in [0, 1].
    """
    model.eval()
    num_correct, num_seen = 0., 0.
    with torch.no_grad():
        for inputs, targets, _ in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            with autocast(enabled=args.fp16):
                logits = model(inputs)
                predictions = logits.argmax(dim=1)
            num_seen += targets.size(0)
            num_correct += (predictions == targets).sum().item()
    return num_correct / num_seen


def extract(loader):
    """Collect the per-sample ``delta`` returned by ``loss_func`` over ``loader``.

    Each batch is an ``(x, y, n)`` triple. ``n`` is a per-sample flag: samples
    with ``n == 0`` go to ``"clean"`` and samples with a nonzero flag go to
    ``"noise"``.

    The model is put in train mode, so train-mode layers behave as they do
    during training. NOTE(review): if the model has BatchNorm layers, their
    running statistics are also updated by this pass; confirm that is intended.
    No optimizer step is taken.

    Returns:
        dict with keys ``"clean"`` and ``"noise"`` mapping to 1-D-concatenated
        CPU tensors of deltas.
    """
    model.train()
    clean_deltas = []
    noise_deltas = []
    # The deltas are only logged and were detached anyway, so skip building an
    # autograd graph for the whole forward pass.
    with torch.no_grad():
        for x, y, n in loader:
            x, y = x.to(device), y.to(device)
            with autocast(enabled=args.fp16):
                pred = model(x)
            _, _, delta = loss_func(pred, y, mode=args.fix)
            # Put the mask on the same device as `delta` before indexing.
            nmask = n.to(delta.device).bool()
            clean_deltas.append(delta[~nmask])
            noise_deltas.append(delta[nmask])
    clean_deltas = torch.cat(clean_deltas, dim=0)
    noise_deltas = torch.cat(noise_deltas, dim=0)
    return {"clean": clean_deltas.cpu(), "noise": noise_deltas.cpu()}


# Training loop.
step = 0          # global micro-batch counter (not reset per epoch)
weight_avg = 0    # running mean of the batch-mean weight, sampled every 50 steps
elr_avg = 0       # running mean of weight * lr, sampled every 50 steps
num_logged = 0    # number of samples folded into weight_avg / elr_avg
timer = Timer()
writer = SummaryWriter(os.path.join(args.output_dir, "result"))
# lr-weighted sums of sample counts and weights, split by the loader's flag
# (scatter index 0 -> clean, index 1 -> noise). Accumulated over the whole
# run; they are never reset per epoch.
total_clean_count = 0.
total_clean_weight = 0.
total_noise_count = 0.
total_noise_weight = 0.
scaler = GradScaler(enabled=args.fp16)
for epoch in range(1, epochs + 1):
    if args.extract:
        # Dump per-sample deltas before this epoch's parameter updates.
        data = extract(train_loader)
        torch.save(data, os.path.join(args.output_dir, "extract_{}.pt".format(epoch)))
        for key in ["clean", "noise"]:
            if len(data[key]) > 0:
                writer.add_histogram(key, data[key], epoch)
    model.train()
    total_loss = 0.  # sum over batches of (batch-mean loss / grad_acc)
    lr = scheduler.get_last_lr()[0]
    for batch_x, batch_y, batch_n in train_loader:
        step += 1
        batch_x, batch_y, batch_n = batch_x.to(device), batch_y.to(device), batch_n.to(device)
        with autocast(enabled=args.fp16):
            out = model(batch_x)
            losses, w, _ = loss_func(out, batch_y)
            loss = losses.mean() / grad_acc

        # Backpropagate EVERY micro-batch so gradients accumulate over
        # `grad_acc` micro-batches. Then unscale, clip, and step once. Running
        # backward only on the stepping micro-batch would drop the others'
        # gradients entirely.
        scaler.scale(loss).backward()
        if step % grad_acc == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_bound)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        if step % 50 == 0:
            weight = w.mean().item()
            elr = weight * lr
            # Incremental mean over the logged samples only. Dividing by
            # `step` would shrink the average ~50x, since only every 50th step
            # contributes a sample.
            num_logged += 1
            weight_avg += (weight - weight_avg) / num_logged
            elr_avg += (elr - elr_avg) / num_logged
            writer.add_scalar("weight", weight, step)
            writer.add_scalar("weight_avg", weight_avg, step)
            writer.add_scalar("elr", elr, step)
            writer.add_scalar("elr_avg", elr_avg, step)

        # Per-flag sums of weights and counts for this batch.
        wsum = scatter_sum(w, batch_n)
        csum = scatter_sum(torch.ones_like(batch_n), batch_n)
        total_clean_count += csum[0].item() * lr
        total_clean_weight += wsum[0].item() * lr
        if len(csum) > 1:  # batch contains at least one sample with flag >= 1
            total_noise_count += csum[1].item() * lr
            total_noise_weight += wsum[1].item() * lr
        total_loss += loss.item()
    test_acc = evaluate(test_loader)
    writer.add_scalar("loss", total_loss, step)
    writer.add_scalar("acc", test_acc, step)
    writer.add_scalar("lr", lr, step)
    clean_mean = total_clean_weight / total_clean_count if total_clean_count > 0 else 0
    noise_mean = total_noise_weight / total_noise_count if total_noise_count > 0 else 0
    snr = clean_mean / noise_mean if noise_mean > 0 else 0
    writer.add_scalar("clean", clean_mean, step)
    writer.add_scalar("noise", noise_mean, step)
    writer.add_scalar("snr", snr, step)
    writer.flush()
    elapse = timer.tick()
    print('Epoch {}: loss={:.4f}, lr={:g}, '
          'test_acc={:.2f}, duration={}'.format(epoch, total_loss, lr, test_acc * 100, elapse))
    scheduler.step()

# Release the event-file handle once training is done.
writer.close()
