import sys
import argparse
import json
import random

from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn
import torch.utils.data

sys.path.append('.')
from image_uncertainty.utils.datasets import all_datasets
from image_uncertainty.utils.evaluate_ood import get_auroc_ood_dl
from image_uncertainty.cifar.cifar_datasets import get_training_dataloader, get_test_dataloader
from image_uncertainty.cifar.cifar_evaluate import load_model
from image_uncertainty.models.duq import (
    MultiLinearCentroids, LinearCentroids, benchmark, calc_gradient_penalty
)


def main(
    architecture,
    batch_size,
    length_scale,
    centroid_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
    output_dir,
    data_seed,
    epochs
):
    """Train a DUQ centroid head on a pretrained spectral ResNet-50 (CIFAR-100).

    After every epoch the accuracy and OOD AUROC of the validation set versus
    SVHN are printed.  After training, test-set accuracy/AUROC against the
    SVHN, LSUN and "smooth" OOD sets are printed.

    Args:
        architecture: ``"multilinear"`` selects ``MultiLinearCentroids``; any
            other value selects ``LinearCentroids``.
        batch_size: Batch size for every data loader and for the model.
        length_scale: RBF length scale, passed to the model as ``sigma``.
        centroid_size: Centroid dimensionality; ``None`` falls back to 128.
        learning_rate: Initial SGD learning rate.
        l_gradient_penalty: Weight of the gradient penalty; ``<= 0`` disables it.
        gamma: Passed to the model as ``gamma`` (described on the CLI as the
            decay factor for the exponential average).
        weight_decay: SGD weight decay.
        final_model: Accepted for CLI compatibility; not used here.
        output_dir: Accepted for CLI compatibility; not used here.
        data_seed: Seed forwarded to ``get_training_dataloader``.
        epochs: Number of training epochs.
    """
    # Only the class count is needed from the dataset registry; the actual
    # train/val/test/OOD loaders are built by the cifar_datasets helpers below.
    _, num_classes, _, _ = all_datasets["CIFAR100"]()
    print('classes', num_classes)

    # Pretrained spectral ResNet-50 (checkpoint path is hard-coded); its
    # `linear` layer is replaced with an identity so it outputs features
    # instead of class logits.
    weights = 'experiments/checkpoint/resnet50_spectral/lsun_44/model_0.pth'
    feature_extractor = load_model('resnet50_spectral', weights, True)
    feature_extractor.linear = nn.Identity()
    feature_extractor.eval()

    # LR decay points for the schedules used in experiments; any other epoch
    # count falls back to the 100-epoch schedule.
    milestones = {4: [1, 2, 3], 20: [5, 10, 15]}.get(epochs, [25, 50, 75])

    # Feature width fed to the centroid head (presumably the ResNet-50
    # penultimate width — confirm if the backbone changes).
    model_output_size = 2048

    if centroid_size is None:
        centroid_size = 128

    klass = MultiLinearCentroids if architecture == 'multilinear' else LinearCentroids

    model = klass(
        num_classes=num_classes,
        gamma=gamma,
        embedding_size=model_output_size,
        features=centroid_size,
        feature_extractor=feature_extractor,
        batch_size=batch_size,
        sigma=length_scale
    )

    model = model.cuda()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay
    )

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.2
    )

    train_loader, val_loader = get_training_dataloader(batch_size=batch_size, seed=data_seed)
    test_loader = get_test_dataloader(batch_size=batch_size, ood=False)
    ood_loader = get_test_dataloader(batch_size=batch_size, ood=True, ood_name='svhn')

    for e in range(epochs):
        for i, (x, y) in enumerate(tqdm(train_loader)):
            # update_embeddings() below runs in eval mode, so switch back to
            # train mode at the start of every step.
            model.train()
            optimizer.zero_grad()
            x, y = x.cuda(), y.cuda()

            # NOTE(review): presumably needed by the gradient penalty path;
            # the penalty itself is taken w.r.t. model.z — confirm.
            x.requires_grad_(True)

            y_pred = model(x)

            y = F.one_hot(y, num_classes).float()

            loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

            if l_gradient_penalty > 0:
                gp = calc_gradient_penalty(model.z, y_pred)
                loss = loss + l_gradient_penalty * gp

            loss.backward()
            optimizer.step()

            # Refresh the class centroids from this batch without tracking grads.
            x.requires_grad_(False)
            with torch.no_grad():
                model.eval()
                model.update_embeddings(x, y)

            if (i + 1) % 25 == 0:
                benchmark(val_loader, model, e, loss.item())
        accuracy, auroc = get_auroc_ood_dl(val_loader, ood_loader, model)
        print('OOD', accuracy, auroc)
        scheduler.step()

    for ood_name in ['svhn', 'lsun', 'smooth']:
        ood_loader = get_test_dataloader(batch_size=batch_size, ood=True, ood_name=ood_name)
        accuracy, auroc = get_auroc_ood_dl(test_loader, ood_loader, model)
        print(f'OOD {ood_name} acc: {accuracy:.3f}, {auroc:.3f}')


if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters, echo them, and forward
    # them to main() as keyword arguments (names must match main's parameters).
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--architecture",
        default="multilinear",
        choices=["multilinear", "linear"],
        help="Pick an duq variant (default: multilinear)",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="Batch size to use for training (default: 1024)",
    )

    parser.add_argument(
        '--epochs', type=int, default=4,
        help="Number of training epochs (default: 4)",
    )

    parser.add_argument(
        "--centroid_size",
        type=int,
        default=None,
        help="Size to use for centroids (default: 128)",
    )

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.05,
        help="Learning rate (default: 0.05)",
    )

    parser.add_argument(
        "--l_gradient_penalty",
        type=float,
        default=0.75,
        help="Weight for gradient penalty (default: 0.75)",
    )

    parser.add_argument(
        "--gamma",
        type=float,
        default=0.999,
        help="Decay factor for exponential average (default: 0.999)",
    )

    parser.add_argument(
        "--length_scale",
        type=float,
        default=0.1,
        help="Length scale of RBF kernel (default: 0.1)",
    )

    parser.add_argument(
        "--weight_decay", type=float, default=5e-4, help="Weight decay (default: 5e-4)"
    )

    # NOTE(review): main() accepts but does not currently use --output_dir.
    parser.add_argument(
        "--output_dir", type=str, default="results", help="set output folder"
    )
    parser.add_argument(
        '--data_seed', type=int, default=42,
        help="Seed for the training/validation data loader (default: 42)",
    )

    # This setting cannot be used for model selection, because the
    # validation set equals the test set.
    # NOTE(review): main() accepts but does not currently use --final_model.
    parser.add_argument(
        "--final_model",
        action="store_true",
        default=False,
        help="Use entire training set for final model",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    print("input args:\n", json.dumps(kwargs, indent=4, separators=(",", ":")))

    main(**kwargs)
