import argparse
import json
import math
import os
import random
from contextlib import nullcontext
from pathlib import Path
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
from sklearn import metrics
from torch.backends import cudnn
from torchvision import datasets, models, transforms
from tqdm import tqdm

import wandb
from models.convnet import convnet
from models.resnet import resnet18
from models.linear import LinearModel
from models.simplenet import SimpleNet
from sc_methods.dg import deep_gambler_loss
from sc_methods.nntd import (
    calculate_nntd_sum_score,
    calculate_nntd_max_score,
    get_mean_a_t,
)

from sc_methods.sat import SelfAdativeTraining
from sc_methods.sn import SelectiveNetLoss
from utils.datasets_utils import load_data
from utils.eval_utils import (
    accuracy_coverage_tradeoff,
    calculate_sc_performance,
    compute_mean_conf_pred,
    compute_mi_auc_score,
    compute_mi_precision_score,
    save_cov_acc_tradeoff,
    sn_coverage_accuracy,
    save_scores,
    save_et_vt_scores,
    plot_decision_boundary,
    sn_accuracy_for_coverage,
    save_targets_preds_scores,
)
from utils.wandb_plotting_utils import (
    plot_cov_acc_tradeoff,
    plot_e_t_metric,
    plot_roc_confusion_matrix,
    plot_sc_mi_auc_performance,
    plot_sc_mi_precision_performance,
    plot_score_dist,
    plot_v_t_metric,
)


def init_wandb(args):
    """Start or resume the W&B run for this configuration.

    The run id and display name are both ``args.identifier``; with
    ``resume="allow"`` W&B continues an existing run with that id instead of
    creating a new one. The full argument namespace is stored as the run
    config.

    Args:
        args: namespace providing ``wandb_name`` (project) and ``identifier``.

    Returns:
        The run object returned by ``wandb.init``.
    """
    run_config = vars(args)
    return wandb.init(
        id=args.identifier,
        name=args.identifier,
        project=args.wandb_name,
        config=run_config,
        resume="allow",
    )


def determine_derived_params(args):
    """Fill in values derived from the CLI arguments and create the results dir.

    Mutates and returns ``args``:

    * ``seed``: replaced by a random integer in ``[1, 10000]`` when it is -1.
    * ``dp``: True iff ``epsilon`` differs from the -1 "no DP" sentinel.
    * ``identifier``: ``<dataset>_<sc_method>_seed<seed>_epsilon<eps>`` with
      ``eps == "inf"`` when DP is off, followed by ``_stop<early_stop_acc>``
      when ``early_stop_acc < 1`` and ``_ci<class_imb>`` when ``class_imb < 1``.
    * ``results_path``: ``base_results_path + identifier + "/"``; the directory
      (including parents) is created if it does not exist.

    Calling it again once the seed is fixed recomputes the same values.
    """
    if args.seed == -1:
        args.seed = random.randint(1, 10000)

    args.dp = args.epsilon != -1
    epsilon_tag = args.epsilon if args.dp else "inf"

    suffix_parts = []
    if args.early_stop_acc < 1:
        suffix_parts.append(f"_stop{args.early_stop_acc}")
    if args.class_imb < 1:
        suffix_parts.append(f"_ci{args.class_imb}")

    args.identifier = (
        f"{args.dataset}_{args.sc_method}_seed{args.seed}_epsilon{epsilon_tag}"
        + "".join(suffix_parts)
    )
    args.results_path = f"{args.base_results_path}{args.identifier}/"
    Path(args.results_path).mkdir(parents=True, exist_ok=True)
    return args


def seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    When CUDA is available every CUDA device is seeded as well and the cuDNN
    autotuner is switched on (``cudnn.benchmark = True``).

    NOTE(review): the autotuner may select non-deterministic kernels, so GPU
    runs are not guaranteed to be bit-for-bit reproducible despite seeding.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = True


def prediction_entropy(logits):
    """Return the Shannon entropy (in nats) of ``softmax(logits)`` for each row.

    Uses ``log_softmax`` so that probabilities which underflow to exactly zero
    contribute ``0`` instead of ``0 * log(0) = nan``. The naive
    ``p * log(p)`` form turns the whole loss into NaN even when the entropy
    term is multiplied by a zero weight.

    Args:
        logits: 2-D tensor of unnormalised scores; softmax and the entropy sum
            are both taken over dim 1.

    Returns:
        1-D tensor with one entropy value per row of ``logits``.
    """
    log_probs = F.log_softmax(logits, dim=1)
    return -torch.sum(log_probs.exp() * log_probs, dim=1)


def train(
    args,
    net,
    train_loader,
    test_loader,
    opt,
    criterion,
    device,
    run,
    epoch,
    privacy_engine,
):
    """Train ``net`` for one epoch, evaluating on ``test_loader`` periodically.

    Every ``args.batch_checkpoint_freq`` batches (starting with batch 0) the
    model is evaluated with ``test`` and the returned prediction vector is
    collected.

    Args:
        args: run configuration (reads ``dataset``, ``sc_method``, ``sn_alpha``,
            ``sat_dg_pretrain``, ``dg_reward``, ``beta_entropy``,
            ``batch_checkpoint_freq`` and ``delta``).
        net: model to train.
        train_loader: yields ``(data, target, indices)`` batches.
        test_loader: loader used for the intermediate evaluations.
        opt: optimizer, stepped once per batch.
        criterion: method-specific loss built in ``main``.
        device: device the batches are moved to.
        run: W&B run used for logging.
        epoch: current epoch index; SAT and DG use plain cross-entropy on the
            class logits while ``epoch < args.sat_dg_pretrain``.
        privacy_engine: Opacus ``PrivacyEngine``, or ``None`` without DP.

    Returns:
        ``(train_loss, train_accuracy, test_loss, test_accuracy, test_preds)``;
        the test metrics come from the last intermediate evaluation and
        ``test_preds`` stacks all intermediate prediction vectors along dim 1.

    Raises:
        ValueError: if ``args.sc_method`` is not a supported method.
    """
    loss = 0
    correct = 0
    test_preds = []
    for batch_idx, (data, target, indices) in enumerate(
        tqdm(train_loader, leave=False)
    ):
        # ``test`` switches the model to eval mode, so re-enable train mode
        # on every batch.
        net.train()
        data = data.to(device)
        if "utkface" in args.dataset:
            target = target[:, 1]
        target = target.to(device)
        opt.zero_grad()
        net.zero_grad()
        if args.sc_method == "sn":
            output, out_select, out_aux = net(data)
            selective_loss = criterion(output, out_select, target, device)
            selective_loss *= args.sn_alpha
            ce_loss = nn.CrossEntropyLoss()(out_aux, target)
            ce_loss *= 1.0 - args.sn_alpha
            batch_loss = selective_loss + ce_loss
            class_logits = output
        else:
            output = net(data)
            class_logits = output
            if args.sc_method in ("sr", "de", "nntd", "mcdo"):
                batch_loss = criterion(output, target)
            elif args.sc_method in ("sat", "dg"):
                # The last output unit is the extra abstention class, which is
                # never a valid class prediction.
                class_logits = output[:, :-1]
                if epoch < args.sat_dg_pretrain:
                    # Pre-training phase: plain cross-entropy on class logits.
                    output = class_logits
                    batch_loss = F.cross_entropy(output, target)
                elif args.sc_method == "sat":
                    batch_loss = criterion(output, target, indices)
                else:
                    batch_loss = criterion(output, target, args.dg_reward)
            else:
                raise ValueError(f"Unsupported sc_method: {args.sc_method}")
        # Entropy regularisation over ``output`` (for SAT/DG after
        # pre-training this includes the abstention unit).
        ent_reg = args.beta_entropy * torch.sum(prediction_entropy(output))
        batch_loss = batch_loss + ent_reg
        # argmax over logits equals argmax over the softmax probabilities.
        pred = class_logits.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        loss += batch_loss.item()
        batch_loss.backward()
        opt.step()

        if batch_idx % args.batch_checkpoint_freq == 0:
            test_loss, test_accuracy, test_pred = test(
                args, net, test_loader, criterion, device, run, "test"
            )
            test_preds.append(test_pred)

    test_preds = torch.stack(test_preds, dim=1)
    # NOTE: the sum of per-batch losses is divided by the number of examples,
    # the same normalisation ``test`` uses.
    loss /= len(train_loader.dataset)
    accuracy = correct / len(train_loader.dataset)
    if privacy_engine is not None:
        eps = privacy_engine.accountant.get_epsilon(delta=args.delta)
        run.log({"epsilon": eps})
    run.log({"train_loss": loss})
    run.log({"train_accuracy": accuracy})
    return loss, accuracy, test_loss, test_accuracy, test_preds


def optimizer(args, net):
    """Build an SGD optimizer over all parameters of ``net``.

    Learning rate, momentum and weight decay are read from ``args.lr``,
    ``args.momentum`` and ``args.weight_decay``.
    """
    sgd_kwargs = dict(
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    return optim.SGD(net.parameters(), **sgd_kwargs)


def test(args, net, loader, criterion, device, run, mode="test"):
    """Evaluate ``net`` on ``loader`` and log loss and accuracy to W&B.

    The per-batch loss depends on ``args.sc_method``:

    * ``sn``: ``sn_alpha``-weighted ``criterion`` on the selective outputs plus
      ``(1 - sn_alpha)``-weighted cross-entropy on the auxiliary head;
    * ``sr``/``de``/``nntd``/``mcdo``: ``criterion(output, target)``;
    * ``sat``: cross-entropy on the outputs without the last column;
    * ``dg``: ``criterion(output, target, args.dg_reward)``.

    For SAT and DG the prediction also ignores the last output column.

    Args:
        args: run configuration.
        net: model to evaluate (put into eval mode).
        loader: yields ``(data, target, indices)`` batches.
        criterion: method-specific loss.
        device: device the batches are moved to.
        run: W&B run used for logging.
        mode: prefix of the logged keys (``"<mode>/loss"``, ``"<mode>/accuracy"``).

    Returns:
        ``(loss, accuracy, predictions)``: the sum of per-batch losses divided
        by the number of examples, the fraction of correct predictions, and a
        flat tensor of predicted class indices.
    """
    net.eval()
    total_loss = 0
    num_correct = 0
    with torch.no_grad():
        batch_preds = []
        for data, target, indices in tqdm(loader, leave=False):
            data = data.to(device)
            if "utkface" in args.dataset:
                target = target[:, 1]
            target = target.to(device)
            method = args.sc_method
            if method == "sn":
                output, out_select, out_aux = net(data)
                weighted_selective = (
                    criterion(output, out_select, target, device) * args.sn_alpha
                )
                weighted_aux = nn.CrossEntropyLoss()(out_aux, target) * (
                    1.0 - args.sn_alpha
                )
                total_loss += (weighted_selective + weighted_aux).item()
                pred = output.argmax(dim=1, keepdim=True)
            else:
                output = net(data)
                if method in ("sr", "de", "nntd", "mcdo"):
                    batch_loss = criterion(output, target)
                    class_logits = output
                elif method == "sat":
                    class_logits = output[:, :-1]
                    batch_loss = F.cross_entropy(class_logits, target)
                elif method == "dg":
                    batch_loss = criterion(output, target, args.dg_reward)
                    class_logits = output[:, :-1]
                total_loss += batch_loss.item()
                pred = class_logits.argmax(dim=1, keepdim=True)
            num_correct += pred.eq(target.view_as(pred)).sum().item()
            batch_preds.append(pred)
        predictions = torch.cat(batch_preds).flatten()
        total_loss /= len(loader.dataset)
        accuracy = num_correct / len(loader.dataset)
        run.log({f"{mode}/loss": total_loss})
        run.log({f"{mode}/accuracy": accuracy})
        return total_loss, accuracy, predictions


def get_sc_scores(args, net, loader, criterion, device):
    """Collect targets, predictions, class probabilities and selection scores.

    The score appended per example depends on ``args.sc_method``:

    * ``sr``/``de``/``nntd``/``mcdo``: ``1 - max`` softmax probability;
    * ``sat``/``dg``: ``1 - max`` class probability if ``args.softmax_score``,
      otherwise the softmax probability of the last (abstention) column;
    * ``sn``: ``1 - max`` softmax probability if ``args.softmax_score``,
      otherwise the network's ``out_select`` output.

    For ``mcdo`` the model is put in eval mode but every ``nn.Dropout`` layer is
    switched back to train mode (and left that way), so repeated calls give
    stochastic outputs. For UTKFace datasets the full target tensors are also
    saved to ``<results_path>full_targets.npy``.

    Args:
        args: run configuration.
        net: trained model.
        loader: yields ``(data, target, indices)`` batches.
        criterion: unused.
        device: device the batches are moved to.

    Returns:
        ``(true_targets, predicted_targets, softmaxes, scores_targets)``; the
        softmaxes exclude the abstention column for SAT and DG.
    """
    net.eval()

    if args.sc_method == "mcdo":

        def _enable_dropout(module):
            if type(module) is nn.Dropout:
                module.train()

        net.apply(_enable_dropout)

    is_utkface = "utkface" in args.dataset
    targets, preds, probs, scores, raw_targets = [], [], [], [], []
    with torch.no_grad():
        for data, target, _ in loader:
            data = data.to(device)
            if is_utkface:
                raw_targets.append(target)
                target = target[:, 1]
            target = target.to(device)
            if args.sc_method == "sn":
                logits, out_select, _ = net(data)
                batch_probs = F.softmax(logits, dim=1)
                conf, pred = batch_probs.max(dim=1)
                probs.append(batch_probs)
                scores.append(1 - conf if args.softmax_score else out_select)
            else:
                full_probs = F.softmax(net(data), dim=1)
                if args.sc_method in ("sr", "de", "nntd", "mcdo"):
                    conf, pred = full_probs.max(dim=1)
                    scores.append(1 - conf)
                    probs.append(full_probs)
                elif args.sc_method in ("sat", "dg"):
                    class_probs = full_probs[:, :-1]
                    conf, pred = class_probs.max(dim=1)
                    probs.append(class_probs)
                    scores.append(
                        1 - conf if args.softmax_score else full_probs[:, -1]
                    )
            targets.append(target)
            preds.append(pred)
        true_targets = torch.cat(targets).flatten()
        predicted_targets = torch.cat(preds).flatten()
        scores_targets = torch.cat(scores).flatten()
        softmaxes = torch.cat(probs, dim=0)
        if is_utkface:
            full_targets = torch.cat(raw_targets, dim=0)
            np.save(f"{args.results_path}full_targets.npy", full_targets.cpu().numpy())
    return true_targets, predicted_targets, softmaxes, scores_targets


def main(args):
    """Run the train/evaluate pipeline for one selective-classification setup.

    For each run (``args.ens_size`` runs for Deep Ensembles, one per entry of
    ``args.sn_coverages`` for SelectiveNet, otherwise a single run) this:

    1. loads the data and builds the model, wrapping it with Opacus for DP-SGD
       when ``args.dp`` is set;
    2. trains for ``args.epochs`` epochs, saving a resumable checkpoint after
       every epoch and recording test/validation predictions, or loads a
       previously finished model and predictions from ``args.results_path``;
    3. computes selective-classification scores on the validation and test
       sets (averaging ``args.ens_size`` stochastic passes for MC dropout).

    Afterwards the scores are aggregated across Deep-Ensemble members, the
    accuracy/coverage trade-offs are computed, and results are saved below
    ``args.results_path``.

    Args:
        args: parsed command-line arguments. They are re-read from
            ``wandb.config`` after W&B initialisation so that sweeps can
            override them.

    Raises:
        NotImplementedError: for datasets without a model definition
            (``eicu``, ``mimic``).
    """
    # Determine random seed in case it is not set
    args = determine_derived_params(args)

    # Initialize W&B calls
    run = init_wandb(args)
    args = SimpleNamespace(**wandb.config)

    # Call param derivation again since a hyperparameter sweep from init_wandb
    # might have overridden these
    args = determine_derived_params(args)

    # Print the full parameter dict
    print("Parameters")
    print(json.dumps(vars(args), indent=4, sort_keys=True))

    # Determine compute device
    if args.require_gpu and (not torch.cuda.is_available()):
        print("No GPU found, not running")
        raise SystemExit(-1)
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    print(f"On {device}")

    # Methods whose selection score is derived purely from the softmax output
    # (as opposed to an extra abstention/selection output).
    softmax_scored_methods = ("sr", "de", "nntd", "mcdo")

    # Set number of loops in case we are using Deep Ensembles or SelectiveNet
    if args.sc_method == "de":
        num_runs = args.ens_size
    elif args.sc_method == "sn":
        num_runs = len(args.sn_coverages)
    else:
        num_runs = 1

    ens_softmaxes_te = []
    ens_softmaxes_val = []
    sn_coverages_val = []
    sn_accuracies_val = []
    sn_coverages_te = []
    sn_accuracies_te = []
    for r in range(num_runs):
        # Set seed for this random run
        seed(args.seed + r)
        if args.sc_method == "sn":
            # All SelectiveNet runs share one seed; only the target coverage changes.
            seed(args.seed)
            args.sn_coverage = args.sn_coverages[r]
            print(f"SN coverage {args.sn_coverage}")

        # Load data
        (
            train_set,
            train_loader,
            train_loader_validation,
            val_set,
            validation_loader,
            test_set,
            test_loader,
            test_set_unnorm,
            num_classes,
            class_names,
        ) = load_data(args)

        # SAT and DG need one extra output unit for the abstention class.
        num_effective_classes = (
            num_classes
            if args.sc_method in softmax_scored_methods or args.sc_method == "sn"
            else num_classes + 1
        )

        # Initialize model
        if args.dataset == "breastcancer":
            net = SimpleNet(dims=[9, 9, num_effective_classes])
        elif args.dataset in ("eicu", "mimic"):
            # No model is defined for these datasets; fail here instead of with
            # a NameError on ``net`` further down.
            raise NotImplementedError(f"No model defined for dataset {args.dataset}")
        elif "gauss" in args.dataset:
            net = LinearModel(2, num_effective_classes)
        elif args.dataset in ["cars", "food"]:
            net = models.resnet18(
                pretrained=args.pretrained, num_classes=num_effective_classes
            )
        else:
            net = resnet18(
                num_classes=num_effective_classes,
                device=device,
                is_selectivenet=args.sc_method == "sn",
            )
            if "utkface" in args.dataset:
                net.conv1 = nn.Conv2d(
                    1, 64, kernel_size=3, stride=1, padding=1, bias=False
                )
            if args.dataset in ["mnist", "fashionmnist"]:
                net.conv1 = nn.Conv2d(
                    1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
                )

        if args.sc_method == "mcdo":
            net.fc = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(
                    in_features=net.fc.in_features, out_features=net.fc.out_features
                ),
            )

        net = net.to(device)

        # Initialize optimizer
        opt = optimizer(args, net)

        privacy_engine = None
        if args.dp:
            # Set up Opacus wrapping. ``fix`` can swap incompatible layers for
            # newly constructed ones, which are not guaranteed to be on
            # ``device``, so move the model again before building the optimizer.
            net = ModuleValidator.fix(net).to(device)
            opt = optimizer(args, net)
            secure_rng = False
            privacy_engine = PrivacyEngine(secure_mode=secure_rng)

            args.effective_delta = 1 / len(train_loader.dataset)

            if args.sc_method == "de" or args.sc_method == "sn":
                # ``num_runs`` models are trained on the same data: epsilon is
                # divided by sqrt(num_runs) and delta by num_runs per model.
                args.effective_epsilon = args.epsilon / np.sqrt(num_runs)
                args.effective_delta = args.effective_delta / num_runs
            else:
                args.effective_epsilon = args.epsilon

            net, opt, train_loader = privacy_engine.make_private_with_epsilon(
                module=net,
                optimizer=opt,
                data_loader=train_loader,
                epochs=args.epochs,
                target_epsilon=args.effective_epsilon,
                target_delta=args.effective_delta,
                max_grad_norm=args.max_grad_norm,
            )

        print(net)

        # Add learning rate scheduler: decay by ``schedule_gamma`` every
        # ``schedule_interval`` epochs.
        interval = args.schedule_interval
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            opt,
            milestones=[e + interval for e in range(0, args.epochs, interval)],
            gamma=args.schedule_gamma,
        )

        # Specify loss based on selective classification method
        if args.sc_method in softmax_scored_methods:
            criterion = nn.CrossEntropyLoss()
        elif args.sc_method == "sat":
            criterion = SelfAdativeTraining(
                num_examples=len(train_set),
                num_classes=num_classes,
                mom=args.sat_momentum,
            )
        elif args.sc_method == "dg":
            criterion = deep_gambler_loss
        elif args.sc_method == "sn":
            criterion = SelectiveNetLoss(nn.CrossEntropyLoss(), args.sn_coverage)

        intermediate_checkpoint_path = f"{args.results_path}checkpoint_run{r}.pt"

        # Handle preemption: resume from the last per-epoch checkpoint.
        if os.path.exists(intermediate_checkpoint_path):
            print("Job was preempted, loading checkpoint.")
            if args.dp:
                checkpoint = privacy_engine.load_checkpoint(
                    path=intermediate_checkpoint_path, module=net, optimizer=opt
                )
            else:
                checkpoint = torch.load(
                    intermediate_checkpoint_path, map_location=device
                )
                net.load_state_dict(checkpoint["model_state_dict"])
                opt.load_state_dict(checkpoint["optimizer_state_dict"])
            net = net.to(device)
            args = checkpoint["args"]
            scheduler.load_state_dict(checkpoint["scheduler"])
            preds_te = [e.to(device) for e in checkpoint["preds_te"]]
            preds_val = [e.to(device) for e in checkpoint["preds_val"]]
        else:
            print("Job is new.")
            preds_te = []
            preds_val = []
            args.current_epoch = -1

        pbar = tqdm(range(args.current_epoch + 1, args.epochs))
        final_checkpoint_path = f"{args.results_path}model_epoch{args.epochs}_run{r}.pt"
        final_preds_te_path = (
            f"{args.results_path}preds_te_epoch{args.epochs}_run{r}.pt"
        )
        final_preds_val_path = (
            f"{args.results_path}preds_val_epoch{args.epochs}_run{r}.pt"
        )
        if not os.path.exists(final_checkpoint_path):
            print(
                f"No prior final model found. Training from epoch {args.current_epoch + 1} to epoch {args.epochs}."
            )
            # Train over number of epochs and evaluate on test after each epoch
            for epoch in pbar:
                current_epoch_checkpoint_path = (
                    f"{args.results_path}checkpoint_run{r}_epoch{epoch}.pt"
                )
                args.current_epoch = epoch
                train_loss, train_accuracy, test_loss, test_accuracy, test_pred = train(
                    args,
                    net,
                    train_loader,
                    test_loader,
                    opt,
                    criterion,
                    device,
                    run,
                    epoch,
                    privacy_engine,
                )
                _, val_accuracy, val_pred = test(
                    args, net, validation_loader, criterion, device, run, "val"
                )
                preds_te.append(test_pred)
                preds_val.append(val_pred)
                if args.log_checkpoints:
                    torch.save(
                        {
                            "args": args,
                            "model_state_dict": net.state_dict(),
                            "optimizer_state_dict": opt.state_dict(),
                        },
                        current_epoch_checkpoint_path,
                    )
                scheduler.step()
                if args.dp:
                    privacy_engine.save_checkpoint(
                        path=intermediate_checkpoint_path,
                        module=net,
                        optimizer=opt,
                        checkpoint_dict={
                            "args": args,
                            "scheduler": scheduler.state_dict(),
                            "preds_val": preds_val,
                            "preds_te": preds_te,
                        },
                    )
                else:
                    torch.save(
                        {
                            "args": args,
                            "scheduler": scheduler.state_dict(),
                            "model_state_dict": net.state_dict(),
                            "optimizer_state_dict": opt.state_dict(),
                            "preds_val": preds_val,
                            "preds_te": preds_te,
                        },
                        intermediate_checkpoint_path,
                    )
                pbar.set_postfix(
                    {
                        "tr_loss": train_loss,
                        "te_loss": test_loss,
                        "tr_acc": train_accuracy,
                        "te_acc": test_accuracy,
                        "val_acc": val_accuracy,
                    }
                )
                run.log({"epoch": epoch})
                # NOTE(review): early stopping is decided on test accuracy.
                if args.early_stop_acc < 1:
                    if test_accuracy >= args.early_stop_acc:
                        break
            # Each entry is one prediction vector; concatenate along dim 1.
            preds_te = torch.cat(preds_te, dim=1)
            preds_val = torch.cat(preds_val, dim=1)
            torch.save(preds_te, final_preds_te_path)
            torch.save(preds_val, final_preds_val_path)
            torch.save(
                {
                    "args": args,
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                },
                final_checkpoint_path,
            )
            # Training finished, so the resumable checkpoint is no longer needed.
            if os.path.exists(intermediate_checkpoint_path):
                os.remove(intermediate_checkpoint_path)
        else:
            print("Prior model found.")
            checkpoint = torch.load(final_checkpoint_path, map_location=device)
            net.load_state_dict(checkpoint["model_state_dict"])
            opt.load_state_dict(checkpoint["optimizer_state_dict"])
            preds_te = torch.load(final_preds_te_path, map_location=device)
            preds_val = torch.load(final_preds_val_path, map_location=device)

        # MC dropout averages several stochastic passes; other methods use one.
        if args.sc_method == "mcdo":
            num_random_runs_mcdo = args.ens_size
        else:
            num_random_runs_mcdo = 1

        mcdo_softmaxes_te = []
        mcdo_softmaxes_val = []

        # ``mc_run`` (not ``r``) so the outer run index is not shadowed.
        for mc_run in range(num_random_runs_mcdo):
            seed(args.seed + mc_run)

            # Compute the score for selective classification for test set
            (
                true_targets_te,
                predicted_targets_te,
                softmaxes_te,
                scores_targets_te,
            ) = get_sc_scores(args, net, test_loader, criterion, device)

            (
                true_targets_val,
                predicted_targets_val,
                softmaxes_val,
                scores_targets_val,
            ) = get_sc_scores(args, net, validation_loader, criterion, device)

            mcdo_softmaxes_te.append(softmaxes_te)
            mcdo_softmaxes_val.append(softmaxes_val)

        # Average the softmax outputs over the passes for softmax-scored
        # methods. SAT, DG and SelectiveNet keep the scores returned by
        # ``get_sc_scores`` (which honour ``args.softmax_score``).
        if args.sc_method in softmax_scored_methods:
            scores_targets_te, predicted_targets_te = compute_mean_conf_pred(
                mcdo_softmaxes_te
            )
            scores_targets_val, predicted_targets_val = compute_mean_conf_pred(
                mcdo_softmaxes_val
            )

        # NNTD requires special handling
        if args.sc_method == "nntd":
            e_t_corr_te, v_t_corr_te = get_mean_a_t(
                preds_te[predicted_targets_te == true_targets_te]
            )
            e_t_incorr_te, v_t_incorr_te = get_mean_a_t(
                preds_te[predicted_targets_te != true_targets_te]
            )

            save_et_vt_scores(e_t_corr_te.cpu(), v_t_corr_te.cpu(), args, "corr")
            save_et_vt_scores(e_t_incorr_te.cpu(), v_t_incorr_te.cpu(), args, "incorr")

            plot_e_t_metric(e_t_corr_te, e_t_incorr_te, run)
            plot_v_t_metric(v_t_corr_te, v_t_incorr_te, run)

            if args.nntd_max:
                scores_targets_val = calculate_nntd_max_score(preds_val, device, args)
                scores_targets_te = calculate_nntd_max_score(preds_te, device, args)
            else:
                scores_targets_val = calculate_nntd_sum_score(
                    preds_val,
                    args.nntd_eval_checkpoints_start,
                    args.nntd_eval_checkpoints_step,
                    device,
                    args,
                )
                scores_targets_te = calculate_nntd_sum_score(
                    preds_te,
                    args.nntd_eval_checkpoints_start,
                    args.nntd_eval_checkpoints_step,
                    device,
                    args,
                )

            if args.nntd_eval_checkpoints_step == 1 and args.nntd_k == 3:
                if args.dataset == "food":
                    num_samp = 1000
                else:
                    num_samp = len(test_set_unnorm)
                print("Load all test points")
                # Indices of the first ``num_samp`` test points sorted by
                # ascending score.
                accept_to_reject_ordering = (
                    torch.argsort(scores_targets_te[:num_samp]).cpu().numpy()
                )
                te_samp = []
                te_lab = []
                for i in tqdm(range(num_samp)):
                    img, target, index = test_set_unnorm[i]
                    te_samp.append(img)
                    te_lab.append(target)
                te_samp = np.stack(te_samp, axis=0)
                te_lab = np.array(te_lab)
                np.save(f"{args.results_path}labels.npy", te_lab)
                num_points = 20
                te_samp_accept = te_samp[accept_to_reject_ordering][:num_points]
                te_samp_reject = te_samp[accept_to_reject_ordering][::-1][:num_points]
                te_lab_accept = te_lab[accept_to_reject_ordering][:num_points]
                te_lab_reject = te_lab[accept_to_reject_ordering][::-1][:num_points]
                pred_accept = preds_te.cpu().numpy()[accept_to_reject_ordering][
                    :num_points
                ]
                pred_reject = preds_te.cpu().numpy()[accept_to_reject_ordering][::-1][
                    :num_points
                ]

                np.save(f"{args.results_path}te_samp_accept.npy", te_samp_accept)
                np.save(f"{args.results_path}te_samp_reject.npy", te_samp_reject)
                np.save(f"{args.results_path}te_lab_accept.npy", te_lab_accept)
                np.save(f"{args.results_path}te_lab_reject.npy", te_lab_reject)
                np.save(f"{args.results_path}pred_accept.npy", pred_accept)
                np.save(f"{args.results_path}pred_reject.npy", pred_reject)

                for i in range(20):
                    plt.plot(preds_te[i].cpu())
                    plt.title(f"{scores_targets_te[i].cpu()}")
                    plt.savefig(f"{args.results_path}preds_{i}.png")
                    plt.clf()

        ens_softmaxes_te.append(softmaxes_te)
        ens_softmaxes_val.append(softmaxes_val)

        if args.sc_method == "sn":
            sn_cov_val, sn_acc_val = sn_coverage_accuracy(
                true_targets_val, predicted_targets_val, scores_targets_val
            )
            sn_coverages_val.append(sn_cov_val)
            sn_accuracies_val.append(sn_acc_val.cpu())
            sn_cov_te, sn_acc_te = sn_accuracy_for_coverage(
                true_targets_te,
                predicted_targets_te,
                scores_targets_te,
                args.sn_coverage,
            )
            sn_coverages_te.append(sn_cov_te[0])
            sn_accuracies_te.append(sn_acc_te[0])

    # Aggregate scores across ensembles
    if args.sc_method == "de":
        scores_targets_te, predicted_targets_te = compute_mean_conf_pred(
            ens_softmaxes_te
        )
        scores_targets_val, predicted_targets_val = compute_mean_conf_pred(
            ens_softmaxes_val
        )

    if args.sc_method == "sn":
        covs_at_fc_val = np.array(sn_coverages_val)
        accs_at_fc_val = np.array(sn_accuracies_val)
        covs_at_fc_te = np.array(sn_coverages_te)
        accs_at_fc_te = np.array(sn_accuracies_te)
    else:
        (
            _,
            _,
            covs_at_fc_val,
            accs_at_fc_val,
        ) = accuracy_coverage_tradeoff(
            true_targets_val, predicted_targets_val, scores_targets_val
        )
        _, _, covs_at_fc_te, accs_at_fc_te = accuracy_coverage_tradeoff(
            true_targets_te, predicted_targets_te, scores_targets_te
        )

    if args.sc_method == "nntd":
        if args.nntd_max:
            meth_app = "max"
        else:
            meth_app = f"sum_{args.nntd_k}_{args.nntd_eval_checkpoints_step}"
    else:
        meth_app = ""

    if args.class_imb < 1 or "utkface" in args.dataset:
        save_targets_preds_scores(
            true_targets_te, predicted_targets_te, scores_targets_te, args
        )

    save_scores(scores_targets_val.cpu(), args, f"val_{meth_app}")
    save_scores(scores_targets_te.cpu(), args, f"test_{meth_app}")

    save_scores(
        scores_targets_te[true_targets_te == predicted_targets_te].cpu(),
        args,
        f"test_corr_{meth_app}",
    )
    save_scores(
        scores_targets_te[true_targets_te != predicted_targets_te].cpu(),
        args,
        f"test_incorr_{meth_app}",
    )
    save_cov_acc_tradeoff(covs_at_fc_val, accs_at_fc_val, args, f"val_{meth_app}")
    save_cov_acc_tradeoff(covs_at_fc_te, accs_at_fc_te, args, f"test_{meth_app}")

    if "gauss" in args.dataset:
        plot_decision_boundary(
            args,
            device,
            train_set.tensors[0],
            train_set.tensors[1],
            net,
            steps=1000,
            cmap="RdBu",
        )


# Grid sweeps selectable via ``--sweep``.
# Maps sweep name -> (``--sc-method`` the sweep applies to, ``args`` attribute
# swept over, grid values).
_SWEEP_GRIDS = {
    "nntd-k": ("nntd", "nntd_k", [1, 2, 3, 4, 5, 8, 10, 15, 20]),
    "dg-reward": (
        "dg",
        "dg_reward",
        [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 7.5, 10, 20, 50, 100],
    ),
}


def _build_arg_parser():
    """Build the command-line parser for this experiment script.

    Returns:
        argparse.ArgumentParser: parser whose namespace is handed to ``main``,
        either directly or through a wandb sweep agent.
    """
    parser = argparse.ArgumentParser(
        description="Selective Classification Under Differential Privacy Constraints"
    )
    parser.add_argument(
        "--wandb-name",
        default="sc-dp",
        type=str,
        help="name of wandb project",
    )
    parser.add_argument(
        "--sweep",
        type=str,
        choices=list(_SWEEP_GRIDS),
        help="run a wandb grid sweep over the named hyperparameter",
    )
    parser.add_argument(
        "-d",
        "--dataset",
        default="cifar10",
        type=str,
        choices=[
            "cifar10",
            "cifar100",
            "svhn",
            "gtsrb",
            "mnist",
            "fashionmnist",
            "breastcancer",
            "mimic",
            "eicu",
            "food",
            "cars",
            "imagenet",
            "2d_gauss",
            "utkface_age",
            "utkface_race",
            "utkface_gender",
        ],
    )
    parser.add_argument(
        "--dataset-path",
        default="/mfsnic/u/stephan/datasets/",
        type=str,
        help="path to datasets",
    )
    parser.add_argument(
        "--base-results-path",
        default="/mfsnic/u/stephan/sc-dp/",
        type=str,
        help="path to results",
    )
    parser.add_argument(
        "--epochs",
        default=200,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="momentum"
    )
    parser.add_argument(
        "--sat-momentum", default=0.9, type=float, help="momentum for sat"
    )
    parser.add_argument(
        "--sat-dg-pretrain",
        default=100,
        type=int,
        metavar="N",
        help="num of cross-entropy pretraining runs",
    )
    parser.add_argument(
        "--train-batch", default=128, type=int, metavar="N", help="train batchsize"
    )
    parser.add_argument(
        "--test-batch", default=512, type=int, metavar="N", help="test batchsize"
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="initial learning rate",
    )
    parser.add_argument(
        "--weight-decay",
        "--wd",
        default=5e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 5e-4)",
    )
    parser.add_argument(
        "--gpu-id", default="0", type=str, help="id(s) for CUDA_VISIBLE_DEVICES"
    )
    parser.add_argument("--seed", default=-1, type=int, help="seed")
    parser.add_argument(
        "--workers",
        default=4,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 4)",
    )
    parser.add_argument(
        "--epsilon", default=-1.0, type=float, metavar="M", help="target epsilon"
    )
    parser.add_argument("--delta", default=1e-5, type=float, metavar="M", help="delta")
    parser.add_argument(
        "--sigma", default=1.0, type=float, metavar="M", help="Noise multiplier"
    )
    parser.add_argument(
        "--max-grad-norm",
        default=10.0,
        type=float,
        metavar="M",
        help="Maximum per sample grad norm",
    )
    parser.add_argument(
        "--sc-method",
        default="sr",
        type=str,
        choices=["sr", "sat", "dg", "de", "sn", "nntd", "mcdo"],
    )
    parser.add_argument(
        "--dg-reward",
        default=2.2,
        type=float,
        metavar="M",
        help="deep gambler reward",
    )
    parser.add_argument(
        "--early-stop-acc",
        default=1.0,
        type=float,
        metavar="M",
        help="accuracy threshold for early stopping",
    )
    parser.add_argument(
        "--sn-coverages",
        # Coverages are fractions in (0, 1]; ``int`` would reject e.g. 0.5.
        type=float,
        nargs="+",
        default=[0.1, 0.25, 0.5, 0.75, 1.0],
        help="selectivenet optimized coverages",
    )
    parser.add_argument(
        "--sn-alpha",
        default=0.5,
        type=float,
        metavar="M",
        help="tradeoff between selectivenet loss and cross entropy loss",
    )
    parser.add_argument(
        "--beta-entropy",
        default=0.01,
        type=float,
        metavar="M",
        help="beta for entropy regularization",
    )
    parser.add_argument(
        "--class-imb",
        default=1,
        type=float,
        metavar="M",
        help="class imbalance for first class",
    )
    parser.add_argument(
        "--ens-size",
        default=5,
        type=int,
        metavar="N",
        help="ensemble size",
    )
    parser.add_argument(
        "--schedule-interval",
        default=25,
        type=int,
        metavar="N",
        help="learning rate scheduling interval",
    )
    parser.add_argument(
        "--schedule-gamma",
        default=0.5,
        type=float,
        metavar="M",
        help="learning rate scheduling gamma",
    )
    parser.add_argument(
        "--nntd-k",
        default=3,
        type=int,
        metavar="N",
        help="nntd weighting exponent",
    )
    parser.add_argument(
        "--nntd-eval-checkpoints-step",
        default=1,
        type=int,
        metavar="N",
        help="Checkpoint step for evaluation",
    )
    parser.add_argument(
        "--nntd-eval-checkpoints-start",
        default=0,
        type=int,
        metavar="N",
        help="Checkpoint start for evaluation",
    )
    parser.add_argument(
        "--softmax-score",
        action="store_true",
        help="Report SC score from softmax and not from individual selection mechanism",
    )
    parser.add_argument(
        "--pretrained", action="store_true", help="Start from pretrained model"
    )
    parser.add_argument(
        "--require-gpu", action="store_true", help="Require a GPU, crash otherwise"
    )
    parser.add_argument(
        "--nntd-max",
        action="store_true",
        help="Use NNTD max score instead of sum score",
    )
    parser.add_argument(
        "--log-checkpoints",
        action="store_true",
        help="Store checkpoints after every epoch",
    )
    parser.add_argument(
        "--val-frac",
        default=0.1,
        type=float,
        metavar="M",
        help="Fraction of validation points to be taken from training set",
    )
    parser.add_argument(
        "--batch-checkpoint-freq",
        default=50,
        type=int,
        metavar="N",
        help="After how many batches should we create a checkpoint",
    )
    return parser


if __name__ == "__main__":
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.sweep is None:
        main(args)
    else:
        # ``choices`` on --sweep guarantees the key exists.
        sweep_method, sweep_param, sweep_values = _SWEEP_GRIDS[args.sweep]
        if args.sc_method != sweep_method:
            # Without this check a mismatched pair would silently do nothing.
            parser.error(
                f"--sweep {args.sweep} requires --sc-method {sweep_method} "
                f"(got {args.sc_method})"
            )
        sweep_configuration = {
            "name": f"{args.dataset}_{args.sc_method}_{args.sweep}",
            "method": "grid",
            "metric": {"goal": "minimize", "name": "val/sc_performance"},
            "parameters": {sweep_param: {"values": sweep_values}},
        }
        # Create the sweep in the same project init_wandb logs runs to.
        sweep_id = wandb.sweep(sweep=sweep_configuration, project=args.wandb_name)
        # NOTE(review): every agent run receives the same ``args`` object;
        # confirm main() applies the swept value from wandb.config to args,
        # otherwise each run trains with the CLI value.
        wandb.agent(sweep_id, function=lambda: main(args))
