from simulator import Simulate_multi
import random
import numpy as np
import torch
import argparse
import torch.nn as nn


def parse_args():
    """
    Build the command-line parser for this script and parse the process
    arguments (``argparse`` reads them from ``sys.argv``).

    Options are registered in the order listed below, so the generated
    ``--help`` output and the attribute order of the returned namespace
    follow this table.

    Returns:
        argparse.Namespace: one attribute per option, with dashes in the
        flag name turned into underscores (e.g. ``--data-dir`` ->
        ``data_dir``). ``method`` is ``None`` unless ``--method`` is given
        at least once, in which case it is the list of supplied values.
    """
    option_specs = [
        # Dataset
        (
            "--data-dir",
            {
                "type": str,
                "default": "./data/",
                "help": "where to download dataset",
            },
        ),
        # Output/logging file
        (
            "--log-interval",
            {
                "type": int,
                "default": 20,
                "help": "how many steps between logging to the console",
            },
        ),
        (
            "--output-dir",
            {
                "type": str,
                "default": "./output",
                "help": "where to save out the model, must be an existing directory.",
            },
        ),
        (
            "--detailed-aum",
            {
                "action": "store_true",
                "help": "if set, the AUM calculations will be done in non-compressed mode",
            },
        ),
        # Optimizer params
        (
            "--learning-rate",
            {"type": float, "default": 0.1, "help": "optimizer learning rate"},
        ),
        (
            "--momentum",
            {"type": float, "default": 0.9, "help": "momentum for optimizer"},
        ),
        (
            "--weight-decay",
            {"type": float, "default": 1e-4, "help": "decay for optimizer"},
        ),
        # Training Regime params
        (
            "--num-epochs",
            {
                "type": int,
                "default": 150,
                "help": "number of epochs to train over",
            },
        ),
        (
            "--train-batch-size",
            {"type": int, "default": 64, "help": "size of training batch"},
        ),
        # Repeatable flag: every occurrence appends one value to a list.
        ("--method", {"action": "append"}),
        # Validation Regime params
        (
            "--val-batch-size",
            {"type": int, "default": 64, "help": "size of val batch"},
        ),
        ("--nw", {"type": int, "default": 30, "help": "Number of workers"}),
        ("--architecture", {"type": str, "default": "logreg"}),
        ("--seed", {"type": int, "default": 42}),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def set_seed(seed: int):
    """
    Seed every random number generator this module relies on.

    The same ``seed`` is handed, in order, to Python's ``random``, NumPy's
    global generator, torch's default generator, and torch's CUDA generators
    on all devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class AverageMeter:
    """
    Tracks the latest value of a metric together with its running,
    count-weighted average.
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py

    Attributes (all reset to 0 by ``reset``):
        val: most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` as of the last update.
    """

    def __init__(self):
        # A new meter starts in the same state as a freshly reset one.
        self.reset()

    def reset(self):
        """Zero out the latest value, the sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` as the latest value, counted ``n`` times.

        ``val`` is treated as a per-item mean over ``n`` items, so it
        contributes ``val * n`` to the running sum and ``n`` to the count.
        """
        self.val = val
        # Augmented assignments keep the accumulators' in-place semantics.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def _simulate_split(nt, nw, villain, random_state):
    """Build one simulated split, create its crowd and collect its answers."""
    split = Simulate_multi(K=8, nt=nt, random_state=random_state, dmax=0.5)
    split()
    split.make_crowd(
        nw=nw,
        n_villain=villain,
        # Every worker that is not a villain is a citizen.
        n_citizen=1 - villain,
        n_wizard=0,
        n_warrior=0,
    )
    split.crowd_answers(rep_max=1)
    return split


def get_train_val_test(seed, nw=30, villain=0.2, nt=500):
    """
    Simulate the train, validation and test splits with crowd answers.

    Each split is a ``Simulate_multi`` instance (``K=8``, ``dmax=0.5``) that
    is called once, given a crowd made only of villains and citizens
    (no wizards or warriors), and then asked for crowd answers with
    ``rep_max=1``.

    Args:
        seed: base seed; every split uses ``random_state=seed + 3``.
        nw: ``nw`` value passed to ``make_crowd`` for the training split
            only. Validation and test always use 30.
        villain: value passed as ``n_villain``; ``n_citizen`` is set to
            ``1 - villain`` (presumably proportions of the crowd -- confirm
            against ``Simulate_multi.make_crowd``).
        nt: ``nt`` value for the training split. Validation uses 30 and
            test uses 100.

    Returns:
        tuple: ``(df_train, df_valid, df_test)``.
    """
    # NOTE(review): all three splits share the same random_state, exactly as
    # before; confirm that this does not make the splits overlap.
    random_state = seed + 3
    # Validation and test crowds have a fixed size, independent of ``nw``.
    eval_workers = 30

    df_train = _simulate_split(nt, nw, villain, random_state)
    df_valid = _simulate_split(30, eval_workers, villain, random_state)
    # Previously crowd_answers was called a second time on df_valid here,
    # so the test split never received its crowd answers.
    df_test = _simulate_split(100, eval_workers, villain, random_state)
    return df_train, df_valid, df_test


def expected_calibration_error(y_pred, y_true, num_bins=15):
    """
    Compute the expected calibration error of a set of predictions.

    The confidence of a sample is the largest value along the last axis of
    ``y_pred`` and its predicted class is the argmax along that axis.
    Samples are grouped by confidence using ``num_bins`` evenly spaced edges
    on [0, 1] (right-inclusive intervals). For every non-empty group the
    absolute value of the summed (correct - confidence) gap is accumulated,
    and the total is divided by the number of samples.

    Args:
        y_pred: array of per-class scores; must expose ``.shape`` (assumed
            to be probabilities of shape (n_samples, n_classes) -- confirm
            with callers).
        y_true: true class indices, comparable elementwise with the argmax
            of ``y_pred``.
        num_bins: number of bin edges generated on [0, 1].

    Returns:
        The accumulated gap divided by ``y_pred.shape[0]``.
    """
    confidence = np.max(y_pred, axis=-1)
    hits = np.argmax(y_pred, axis=-1) == y_true

    edges = np.linspace(start=0, stop=1.0, num=num_bins)
    # right=True: a confidence equal to an edge falls in the interval that
    # ends at that edge, so a confidence of exactly 1.0 gets the last index.
    bin_ids = np.digitize(confidence, bins=edges, right=True)

    total_gap = 0
    for bin_id in range(num_bins):
        members = bin_ids == bin_id
        if not np.any(members):
            continue
        # Over- and under-confidence cancel inside a bin before the abs.
        total_gap += np.abs(np.sum(hits[members] - confidence[members]))

    return total_gap / y_pred.shape[0]


class Toy_dataset_red(torch.utils.data.Dataset):
    """
    Dataset with one entry per row of ``tasks``, each carrying an answer,
    a worker and a true index alongside the task itself.

    ``tasks`` must expose ``.shape`` and its rows must support
    ``reshape``; ``ans``, ``workers`` and ``true_idx`` are indexed with the
    same position as ``tasks``.
    """

    def __init__(
        self,
        tasks,
        ans,
        workers,
        true_idx,
        transform=None,
        target_transform=None,
    ):
        # Row-aligned data: position i in each of these describes entry i.
        self.tasks, self.ans = tasks, ans
        self.workers, self.true_idx = workers, true_idx
        # Optional callables applied to the task / the answer on access.
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        """Number of entries, i.e. the size of the first axis of ``tasks``."""
        n_entries = self.tasks.shape[0]
        return n_entries

    def __getitem__(self, idx):
        """
        Return ``(image, label, worker, true_idx, idx)`` for entry ``idx``.

        The task row is reshaped to ``(-1, 1, 1)``. When a transform is set
        its output is additionally cast to ``torch.FloatTensor``.
        """
        sample = self.tasks[idx].reshape(-1, 1, 1)
        answer, truth_index = self.ans[idx], self.true_idx[idx]
        annotator = self.workers[idx]

        if self.transform:
            sample = self.transform(sample).type(torch.FloatTensor)
        if self.target_transform:
            answer = self.target_transform(answer)

        return sample, answer, annotator, truth_index, idx


class Toy_dataset(torch.utils.data.Dataset):
    """
    Dataset pairing each row of ``tasks`` with its entry in ``truth``.

    ``tasks`` must expose ``.shape`` and its rows must support ``reshape``;
    ``truth`` is indexed with the same position as ``tasks``.
    """

    def __init__(self, tasks, truth, transform=None, target_transform=None):
        self.tasks, self.truth = tasks, truth
        # Optional callables applied to the task / the label on access.
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        """Number of entries, i.e. the size of the first axis of ``tasks``."""
        n_entries = self.tasks.shape[0]
        return n_entries

    def __getitem__(self, idx):
        """
        Return ``(image, "lab", label, idx)`` for entry ``idx``.

        The task row is reshaped to ``(-1, 1, 1)``. When a transform is set
        its output is additionally cast to ``torch.FloatTensor``. The second
        element is always the constant string ``"lab"``.
        """
        sample = self.tasks[idx].reshape(-1, 1, 1)
        target = self.truth[idx]

        if self.transform:
            sample = self.transform(sample).type(torch.FloatTensor)
        if self.target_transform:
            target = self.target_transform(target)

        return sample, "lab", target, idx


class model(nn.Module):
    """
    Small fully connected network: ``input_dim -> 30 -> 20 -> output_dim``
    with a ReLU after each of the first two linear layers.
    """

    def __init__(self, input_dim=2, output_dim=2):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        # Layers are registered in this order; it determines the state_dict
        # key order and the order in which parameters are initialised.
        self.linear = nn.Linear(input_dim, 30)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(30, 20)
        self.fc2 = nn.Linear(20, output_dim)

    def forward(self, x):
        """
        Flatten ``x`` to rows of ``input_dim`` features and return the raw
        outputs of the last linear layer (no final activation is applied).
        """
        flat = x.reshape(-1, self.input_dim)
        hidden = self.relu(self.linear(flat))
        hidden = self.relu(self.fc1(hidden))
        return self.fc2(hidden)
