import os
import sys
import json
import torch
import hashlib
import submitit
import argparse
import torchvision

import numpy as np
from PIL import Image


def mean(x):
    """Return the arithmetic mean of the items in ``x``.

    ``x`` must be iterable and support ``len()``; an empty ``x`` raises
    ``ZeroDivisionError``.
    """
    total = sum(x)
    return total / len(x)


def parse_args():
    """Parse the command line and return the options as a plain dict.

    The dict keys are the option names without the leading dashes, in the
    order the options are declared below.
    """
    # (flag, add_argument keyword arguments), in declaration order.
    options = (
        ('--lr', dict(default=0.1, type=float, help='learning rate')),
        ('--debug', dict(action="store_true")),
        ('--data_dir', dict(default="data/", type=str)),
        ('--output_dir', dict(default="results/jobs_01/", type=str)),
        ('--batch_size', dict(default=128, type=int)),
        ('--method', dict(default="erm", type=str)),
        ('--alpha', dict(default=1, type=float)),
        ('--num_epochs', dict(default=500, type=int)),
        ('--eval_every', dict(default=10, type=int)),
        ('--random_labels', dict(default=0, type=int)),
        ('--class_probs', dict(default=[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                               nargs="+", type=float)),
        ('--n_seeds', dict(default=100, type=int)),
    )

    parser = argparse.ArgumentParser(description='CIFAR10 training')
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    namespace = parser.parse_args()
    return dict(vars(namespace))


class Tee:
    """Duplicate everything written to it onto a stream and a file.

    ``fname`` is opened with ``mode`` (append/read by default) and kept open
    for the lifetime of the object; ``stream`` is any object with ``write``
    and ``flush`` methods.
    """

    def __init__(self, fname, stream, mode="a+"):
        self.stream = stream
        self.file = open(fname, mode)

    def write(self, message):
        """Write ``message`` to the stream, then the file, then flush both."""
        for sink in (self.stream, self.file):
            sink.write(message)
        self.flush()

    def flush(self):
        """Flush the stream first and the file second."""
        for sink in (self.stream, self.file):
            sink.flush()


class CIFAR10NEG:
    """Dataset backed by the ``cifar10_neg.npz`` archive found under ``root``.

    The archive must contain a ``"data"`` array (images) and a ``"labels"``
    array (targets). ``train`` and ``download`` are accepted only so the
    constructor can be called with the same keyword arguments as
    ``torchvision.datasets.CIFAR10``; they are ignored.
    """

    def __init__(self, root, train=None, download=None, transform=None):
        # np.load on an .npz returns an archive object holding an open file;
        # the ``with`` block closes it once both arrays have been read.
        with np.load(os.path.join(root, "cifar10_neg.npz")) as data:
            self.inputs = data["data"]
            self.targets = data["labels"]
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The image is converted to a PIL image and passed through
        ``transform`` when one was given; with ``transform=None`` the PIL
        image is returned as is.
        """
        x = Image.fromarray(self.inputs[index])
        if self.transform is not None:
            x = self.transform(x)
        return x, self.targets[index]

    def __len__(self):
        """Return the number of labelled examples."""
        return len(self.targets)


class MyCIFAR10:
    """CIFAR-10 style dataset with class subsampling and label remapping.

    Args:
        root: directory passed to the underlying dataset.
        train: selects the training split; also controls subsampling.
        transform: image transform passed to the underlying dataset.
        class_probs: per-class keep probability (defaults to keeping all ten
            classes). Classes with probability 0 are dropped entirely.
        neg: use ``CIFAR10NEG`` instead of ``torchvision.datasets.CIFAR10``.
        random_labels: when positive, every example of the underlying
            dataset gets a random label drawn from ``random_labels`` classes
            and ``num_classes`` is set to that value.
        one_hots: return labels as float one-hot vectors instead of ints.
    """

    def __init__(self, root, train, transform,
                 class_probs=None, neg=False, random_labels=0, one_hots=True):
        dataset_cls = CIFAR10NEG if neg else torchvision.datasets.CIFAR10
        self.dataset = dataset_cls(
            root=root, train=train, download=False, transform=transform)

        if class_probs is None:
            class_probs = [1] * 10
        binary_probs = [float(cp > 0) for cp in class_probs]
        self.num_classes = int(sum(binary_probs))

        # Outside of training, keep every example of each retained class so
        # that as much data as possible is available for evaluation.
        keep_probs = class_probs if train else binary_probs

        # One Bernoulli draw per example, in dataset order.
        self.indices = [
            position
            for position, target in enumerate(self.dataset.targets)
            if torch.zeros(1).bernoulli_(keep_probs[target]).item()
        ]

        # Map the retained original class ids onto 0..num_classes-1.
        retained = [cls for cls, flag in enumerate(binary_probs) if flag]
        self.transform_label = {
            cls: new_id for new_id, cls in enumerate(retained)}

        self.random_labels = None
        if random_labels > 0:
            self.num_classes = random_labels
            # One random label per example of the *underlying* dataset, so
            # it is indexed with the original (unfiltered) position.
            self.random_labels = torch.zeros(
                len(self.dataset)).random_(random_labels).long().tolist()

        self.one_hots = one_hots

    def __getitem__(self, i):
        """Return ``(x, y)`` for the ``i``-th retained example."""
        source_index = self.indices[i]
        x, label = self.dataset[source_index]
        label = self.transform_label[label]
        if self.random_labels is not None:
            label = self.random_labels[source_index]
        if not self.one_hots:
            return x, label
        encoded = torch.nn.functional.one_hot(
            torch.LongTensor([label]), self.num_classes)
        return x, encoded.view(-1).float()

    def __len__(self):
        """Return the number of retained examples."""
        return len(self.indices)


def accuracy(network, loader, device="cuda"):
    """Compute the per-class accuracy of ``network`` over ``loader``.

    Args:
        network: callable mapping an input batch to per-class scores.
        loader: iterable of ``(x, y)`` batches exposing
            ``loader.dataset.num_classes``; the true class of each example
            is taken as ``y.argmax(1)``.
        device: device the inputs are moved to before the forward pass.

    Returns:
        Dict mapping class index to the fraction of its examples whose
        arg-max prediction matches the label. Classes that have no examples
        in ``loader`` are omitted instead of triggering a division by zero.
    """
    num_classes = loader.dataset.num_classes
    correct = {k: 0 for k in range(num_classes)}
    total = {k: 0 for k in range(num_classes)}

    # Evaluation only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for x, y in loader:
            t = y.argmax(1)
            hits = network(x.to(device)).detach().cpu().argmax(1).eq(t)
            # tolist() converts the whole batch at once instead of one
            # .item() call per element.
            for hit, label in zip(hits.tolist(), t.tolist()):
                correct[label] += hit
                total[label] += 1

    return {k: correct[k] / total[k] for k in correct if total[k] > 0}


class SoftCrossEntropyLoss(torch.nn.Module):
    """Cross-entropy against soft (probability-vector) targets.

    ``reduction="mean"`` averages the per-example losses; any other value
    sums them.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.lsm = torch.nn.LogSoftmax(dim=1)
        self.reduction = reduction

    def forward(self, inputs, targets, logits=True):
        """Return the reduced loss.

        Only the first ``targets.size(1)`` columns of ``inputs`` are used.
        With ``logits=True`` they are passed through a log-softmax;
        otherwise they are treated as probabilities and their log is taken.
        """
        width = targets.size(1)
        scores = inputs[:, :width]
        log_probs = self.lsm(scores) if logits else scores.log()
        per_example = -(targets * log_probs).sum(1)

        if self.reduction == "mean":
            return per_example.mean()
        return per_example.sum()


def get_network(num_classes):
    """Build a ResNet-50 with ``num_classes`` outputs and a 3x3 first conv.

    The model's ``conv1`` is replaced by a bias-free convolution with 3
    input channels, 64 output channels, a 3x3 kernel, stride 1 and padding 1.
    """
    resnet = torchvision.models.resnet.resnet50(num_classes=num_classes)
    stem = torch.nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    resnet.conv1 = stem
    return resnet


def get_loaders(data_dir, batch_size, class_probs=None, random_labels=0,
                in_evaluation=False):
    """Create the train, test and "neg" data loaders.

    Args:
        data_dir: root directory handed to every dataset.
        batch_size: batch size shared by the three loaders.
        class_probs: per-class keep probabilities forwarded to ``MyCIFAR10``.
        random_labels: forwarded to ``MyCIFAR10``.
        in_evaluation: when true, the training split uses the test-time
            transform and is not shuffled.

    Returns:
        ``(loader_tr, loader_te, loader_ne)``.
    """
    tfm = torchvision.transforms

    def tensor_steps():
        # Tensor conversion followed by normalisation with fixed
        # per-channel mean / std constants.
        return [
            tfm.ToTensor(),
            tfm.Normalize(
                (0.4914, 0.4822, 0.4465),
                (0.2023, 0.1994, 0.2010)),
        ]

    augmentation = [
        tfm.RandomCrop(32, padding=4),
        tfm.RandomHorizontalFlip(),
    ]
    transform_train = tfm.Compose(augmentation + tensor_steps())
    transform_test = tfm.Compose(tensor_steps())

    # Datasets are built in train / test / neg order; each construction
    # draws from the global RNG for class subsampling.
    data_tr = MyCIFAR10(
        root=data_dir, train=True,
        transform=transform_test if in_evaluation else transform_train,
        class_probs=class_probs, neg=False, random_labels=random_labels)
    data_te = MyCIFAR10(
        root=data_dir, train=False, transform=transform_test,
        class_probs=class_probs, neg=False, random_labels=random_labels)
    data_ne = MyCIFAR10(
        root=data_dir, train=False, transform=transform_test,
        class_probs=class_probs, neg=True, random_labels=random_labels)

    def make_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2)

    loader_tr = make_loader(data_tr, not in_evaluation)
    loader_te = make_loader(data_te, False)
    loader_ne = make_loader(data_ne, False)

    return loader_tr, loader_te, loader_ne


def mix_data(x, y, method, alpha):
    """Build the training batch for the requested augmentation ``method``.

    Args:
        x: input batch (first dimension is the batch dimension).
        y: 2-D batch of label vectors, one row per example.
        method: one of ``"erm"``, ``"mixup"``, ``"extra1"``, ``"extra2"``.
        alpha: concentration of the ``Beta(alpha, alpha)`` distribution the
            mixing coefficient is drawn from.

    Returns:
        ``(xm, ym)``, both detached.

        * ``"erm"``: the batch unchanged.
        * ``"mixup"``: convex combination of the batch with a random
          permutation of itself.
        * ``"extra1"``: with probability 0.5 the same as ``"mixup"``;
          otherwise mixes ``x`` with ``2 * x - x[perm]`` and ``y`` with a
          uniform label vector.
        * ``"extra2"``: like ``"extra1"`` but labels gain one extra trailing
          column, and the alternative branch mixes ``y`` with the one-hot
          vector of that extra column.

    Raises:
        NotImplementedError: if ``method`` is not one of the names above.
    """
    def mix_(x1, y1, x2, y2, lam):
        xm = lam * x1 + (1 - lam) * x2
        ym = lam * y1 + (1 - lam) * y2
        return xm, ym

    # Both draws happen for every method (including "erm") so the sequence
    # of random numbers consumed per batch does not depend on the method.
    lam = torch.distributions.beta.Beta(alpha, alpha).sample().item()
    per = torch.randperm(len(x))

    if method == "erm":
        # Nothing to mix: skip building the permuted copies of the batch.
        return x.detach(), y.detach()

    if method not in ("mixup", "extra1", "extra2"):
        raise NotImplementedError(
            "unknown mixing method: {!r}".format(method))

    x1, y1, x2, y2 = x, y, x[per], y[per]

    if method == "mixup":
        xm, ym = mix_(x1, y1, x2, y2, lam)
    elif method == "extra1":
        if torch.zeros(1).bernoulli_(0.5).item():
            xm, ym = mix_(x1, y1, x2, y2, lam)
        else:
            um = torch.zeros_like(y).fill_(1. / y.size(1))
            xm, ym = mix_(x1, y1, 2 * x1 - x2, um, lam)
    else:  # "extra2"
        # new_zeros keeps the padding column on y's device and in y's dtype,
        # so this also works when the batch does not live on the CPU.
        pad = y.new_zeros(y.size(0), 1)
        y1 = torch.cat((y1, pad), -1)
        y2 = torch.cat((y2, pad), -1)

        if torch.zeros(1).bernoulli_(0.5).item():
            xm, ym = mix_(x1, y1, x2, y2, lam)
        else:
            um = torch.zeros_like(y1)
            um[:, -1] = 1
            xm, ym = mix_(x1, y1, 2 * x1 - x2, um, lam)

    return xm.detach(), ym.detach()


def train_model(args):
    """Train one network as described by ``args`` and save its checkpoints.

    ``args`` is the option dict (it must also contain ``"seed"``). All
    output files share a prefix made of ``args["output_dir"]`` and the md5
    digest of the JSON-serialised ``args``:

    * ``<prefix>.train.out`` / ``<prefix>.train.err``: copies of everything
      written to stdout / stderr (both are replaced by ``Tee`` objects and
      are not restored afterwards).
    * ``<prefix>.train.best.pt``: weights and history at the best test
      accuracy seen at an evaluation epoch.
    * ``<prefix>.train.last.pt``: weights and history after the last epoch.
    """
    out_dir = args["output_dir"]
    os.makedirs(out_dir, exist_ok=True)

    # The digest covers every option (seed included), so each distinct
    # configuration writes to its own set of files.
    args_blob = json.dumps(args, sort_keys=True).encode('utf-8')
    run_hash = os.path.join(out_dir, hashlib.md5(args_blob).hexdigest())

    sys.stdout = Tee(run_hash + ".train.out", sys.stdout)
    sys.stderr = Tee(run_hash + ".train.err", sys.stderr)

    torch.manual_seed(args["seed"])

    loader_tr, loader_te, loader_ne = get_loaders(
        args["data_dir"], args["batch_size"],
        args["class_probs"], args["random_labels"])

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # "extra2" trains with one additional output for its extra label column.
    n_outputs = loader_tr.dataset.num_classes
    if args["method"] == "extra2":
        n_outputs += 1
    network = get_network(n_outputs)
    network.to(device)

    loss = SoftCrossEntropyLoss()
    loss.to(device)

    optimizer = torch.optim.SGD(
        network.parameters(), lr=args["lr"], momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 150, 0.1)

    best_path = run_hash + ".train.best.pt"
    last_path = run_hash + ".train.last.pt"

    acc_te_best = 0
    history = []
    for epoch in range(args["num_epochs"]):
        for x, y in loader_tr:
            xm, ym = mix_data(x, y, args["method"], args["alpha"])
            loss_value = loss(network(xm.to(device)), ym.to(device))
            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()

        scheduler.step()

        # Evaluate after the first epoch and then every ``eval_every`` epochs.
        is_eval_epoch = (
            epoch == 0 or (epoch + 1) % args["eval_every"] == 0)
        if not is_eval_epoch:
            continue

        network.eval()
        acc_tr, acc_te, acc_ne = [
            mean(accuracy(network, loader, device).values())
            for loader in (loader_tr, loader_te, loader_ne)
        ]
        stats = {"args": args, "epoch": epoch,
                 "acc_tr": acc_tr, "acc_te": acc_te, "acc_ne": acc_ne}
        print(json.dumps(stats))
        history.append(stats)
        network.train()

        if acc_te > acc_te_best:
            torch.save((network.state_dict(), history), best_path)
            acc_te_best = acc_te

    torch.save((network.state_dict(), history), last_path)


def run_jobs(function, commands):
    """Submit ``function(command)`` for every command as a Slurm job array.

    Uses a ``submitit.SlurmExecutor`` that writes to the ``submitit/``
    folder; the resource settings below are fixed.
    """
    slurm_options = {
        "time": 3 * 24 * 60,  # presumably minutes (three days); confirm
        "gpus_per_node": 1,
        "array_parallelism": 512,
        "cpus_per_task": 8,
        "partition": "partition",
    }

    executor = submitit.SlurmExecutor(folder="submitit/")
    executor.update_parameters(**slurm_options)
    executor.map_array(function, commands)


def run_sweep(args):
    """Launch one training job per (method, seed) combination.

    For each of the four methods and each seed in ``range(args["n_seeds"])``
    a copy of ``args`` with ``"method"`` and ``"seed"`` set is submitted to
    ``train_model`` through ``run_jobs``. ``args`` itself is left untouched.
    """
    # dict(args, ...) builds a fresh dict per job, so the caller's ``args``
    # is not modified as a side effect of preparing the sweep.
    commands = [
        dict(args, method=method, seed=seed)
        for method in ("erm", "mixup", "extra1", "extra2")
        for seed in range(args["n_seeds"])
    ]

    run_jobs(train_model, commands)


if __name__ == "__main__":
    args = parse_args()
    if not args["debug"]:
        # Normal mode: submit the full method x seed sweep.
        run_sweep(args)
    else:
        # Debug mode: a short two-epoch run in this process, evaluated
        # every epoch and written to a dedicated directory.
        args.update({
            "output_dir": "results/debug/",
            "num_epochs": 2,
            "eval_every": 1,
            "seed": 0,
        })
        train_model(args)
