import random
import numpy as np
import torch
import argparse
from cifar10h_data import CIFAR10H
from torch.utils.data import Dataset


def parse_args():
    """
    Build the command-line parser and parse the process arguments.

    Every option is declared in a single table of ``(flag, keyword-arguments)``
    pairs and registered in table order, so the ``--help`` listing follows the
    table from top to bottom.

    Returns:
        argparse.Namespace holding one attribute per option (dashes in the
        flag become underscores, e.g. ``--data-dir`` -> ``data_dir``).
    """
    option_table = [
        # Dataset
        (
            "--data-dir",
            dict(type=str, default="./data/", help="where to download dataset"),
        ),
        # Output/logging file
        (
            "--log-interval",
            dict(
                type=int,
                default=20,
                help="how many steps between logging to the console",
            ),
        ),
        (
            "--output-dir",
            dict(
                type=str,
                default="./output",
                help="where to save out the model, must be an existing directory.",
            ),
        ),
        (
            "--detailed-aum",
            dict(
                action="store_true",
                help="if set, the AUM calculations will be done in non-compressed mode",
            ),
        ),
        # Optimizer params
        (
            "--learning-rate",
            dict(type=float, default=0.1, help="optimizer learning rate"),
        ),
        (
            "--momentum",
            dict(type=float, default=0.9, help="momentum for optimizer"),
        ),
        (
            "--weight-decay",
            dict(type=float, default=1e-4, help="decay for optimizer"),
        ),
        # Training Regime params
        (
            "--num-epochs",
            dict(type=int, default=150, help="number of epochs to train over"),
        ),
        (
            "--train-batch-size",
            dict(type=int, default=64, help="size of training batch"),
        ),
        # Repeatable flag: each occurrence appends to a list (None if absent).
        ("--method", dict(action="append")),
        # Validation Regime params
        (
            "--val-batch-size",
            dict(type=int, default=64, help="size of val batch"),
        ),
        ("--architecture", dict(type=str, default="logreg")),
        ("--seed", dict(type=int, default=42)),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def set_seed(seed: int):
    """
    Seed every random number generator this module relies on.

    The same value is handed, in order, to Python's ``random``, NumPy's global
    generator, ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.

    Args:
        seed: integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class AverageMeter:
    """
    Tracks the latest value of a quantity together with its running average.
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py

    Attributes (all reset to 0 by ``reset``):
        val:   most recent value passed to ``update``
        sum:   weighted sum of all values seen so far
        count: total weight (sum of every ``n``) seen so far
        avg:   ``sum / count`` as of the last ``update``
    """

    def __init__(self):
        # A fresh meter starts in the same state as a reset one.
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record a new observation.

        Args:
            val: the observed value (e.g. a per-batch mean).
            n:   weight of the observation, i.e. how many samples ``val``
                 stands for; defaults to 1.
        """
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def get_train_val_test(data_dir, train_transform=None, target_transform=None):
    """
    Instantiate the three ``CIFAR10H`` datasets used by the training script.

    The datasets are constructed in the order train, val, test. All three
    receive the same ``transform`` (``train_transform``), the same
    ``target_transform`` and ``download=True``; they differ only in the
    ``which_set`` / ``train`` arguments:

        train -> which_set="train", train=True
        val   -> which_set="val",   train=True
        test  -> which_set="test",  train=False

    Args:
        data_dir: forwarded as the first positional argument of ``CIFAR10H``.
        train_transform: forwarded as ``transform`` to every split.
        target_transform: forwarded as ``target_transform`` to every split.

    Returns:
        Tuple ``(df_train, df_val, df_test)`` of ``CIFAR10H`` instances.
    """
    # NOTE(review): the val and test splits also get ``train_transform``;
    # confirm this is intended if that transform includes augmentation.
    common_kwargs = dict(
        transform=train_transform,
        target_transform=target_transform,
        download=True,
    )
    split_specs = (("train", True), ("val", True), ("test", False))

    df_train, df_val, df_test = (
        CIFAR10H(data_dir, which_set=split, train=is_train, **common_kwargs)
        for split, is_train in split_specs
    )
    return df_train, df_val, df_test


def expected_calibration_error(y_pred, y_true, num_bins=15):
    """
    Compute the expected calibration error (ECE) of a set of predictions.

    The confidence of each prediction (the maximum of ``y_pred`` along its
    last axis) is placed into one of ``num_bins`` equal-width bins covering
    [0, 1]. For each bin the absolute difference between the number of correct
    predictions and the summed confidence is accumulated; the total is divided
    by the number of samples, which equals the usual
    ``sum_b (n_b / N) * |acc_b - conf_b|``.

    Args:
        y_pred: array-like of per-class scores; argmax / max are taken over
            the last axis. Scores are expected to be probabilities in [0, 1].
        y_true: array-like of integer labels, compared element-wise with the
            argmax of ``y_pred``.
        num_bins: number of equal-width confidence bins (must be >= 1).

    Returns:
        The ECE as a float.

    Raises:
        ValueError: if ``num_bins`` is smaller than 1 or ``y_pred`` is empty.
    """
    if num_bins < 1:
        raise ValueError("num_bins must be a positive integer")

    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    if y_pred.size == 0:
        raise ValueError("y_pred must contain at least one prediction")

    predicted = np.argmax(y_pred, axis=-1)
    confidence = np.max(y_pred, axis=-1)
    correct = (predicted == y_true).astype(float)

    # num_bins bins need num_bins + 1 edges; using only num_bins edges would
    # silently produce one bin fewer than requested.
    edges = np.linspace(start=0.0, stop=1.0, num=num_bins + 1)
    # digitize(right=True) maps (edges[i-1], edges[i]] to i, 0.0 to 0 and
    # anything above 1.0 to num_bins + 1. Clipping folds those two outliers
    # into the first / last bin so no sample is dropped from the sum while
    # still being counted in the denominator.
    bin_index = np.clip(np.digitize(confidence, bins=edges, right=True), 1, num_bins) - 1

    total_gap = 0.0
    for k in range(num_bins):
        in_bin = bin_index == k
        if np.any(in_bin):
            # |#correct - sum(confidence)| == n_bin * |acc_bin - conf_bin|
            total_gap += np.abs(np.sum(correct[in_bin] - confidence[in_bin]))

    return total_gap / y_pred.shape[0]


class DatasetWithIndex(Dataset):
    """
    Wrapper that appends the sample index to every item of another dataset.

    Items of the wrapped dataset must be iterable (e.g. an ``(input, target)``
    tuple); ``wrapper[i]`` is then a tuple of those elements followed by ``i``.
    The length is that of the wrapped dataset.
    """

    def __init__(self, base_dataset):
        # Dataset being wrapped; it is indexed directly, never copied.
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        sample = self.base_dataset[index]
        return tuple(sample) + (index,)
