import argparse
import random
import torch
import glob
import tqdm
import os

from code.train import get_loaders
from code.evaluate import Ensemble, get_method, evaluate_loader


def get_curves(fnames, n=100, steps=50, do_curves=True):
    """Probe an ensemble along input interpolations, or report its loss.

    An ``Ensemble`` is built from ``fnames`` and put in eval mode. The
    training arguments stored in the first checkpoint (at ``[1][0]["args"]``)
    are used to rebuild the training loader through ``get_loaders``.

    Args:
        fnames: Checkpoint paths handed to ``Ensemble``; ``fnames[0]`` is
            additionally read with ``torch.load`` to recover the run's args.
        n: Forwarded as the second positional argument of ``get_loaders`` and
            used as the number of rows of the returned curves (presumably the
            loader batch size -- TODO confirm against ``get_loaders``).
        steps: Number of interpolation coefficients, evenly spaced in [0, 1].
        do_curves: Selects which of the two results is computed.

    Returns:
        If ``do_curves`` is true, an ``(n, steps)`` tensor computed from the
        first batch of the loader only. Entry ``[i, j]`` is the model output
        for sample ``i`` at the class of its randomly drawn partner, evaluated
        at ``lam_j * partner + (1 - lam_j) * sample``. The tensor stays all
        zeros if the loader yields no batch.
        Otherwise, a one-entry dict mapping ``len(model)`` to the value of
        ``evaluate_loader(..., what="loss")`` on the training loader.
    """
    model = Ensemble(fnames)
    model.eval()

    # Only the stored arguments are read from this checkpoint, so map it to
    # CPU: the lookup then does not depend on the device it was saved from.
    args = torch.load(fnames[0], map_location="cpu")[1][0]["args"]

    loader_tr, _, _ = get_loaders(
        args["data_dir"], n,
        args["class_probs"], args["random_labels"],
        in_evaluation=True)

    if not do_curves:
        with torch.no_grad():
            return {len(model): evaluate_loader(model, loader_tr,
                    alpha=args["alpha"], what="loss")}

    curves = torch.zeros(n, steps)
    with torch.no_grad():
        # Only the first batch is used; an empty loader leaves zeros.
        first_batch = next(iter(loader_tr), None)
        if first_batch is not None:
            x, y = first_batch
            # Pair every sample with a randomly chosen partner from the batch.
            per = torch.randperm(len(x))
            x0, x1 = x, x[per]
            # Class index of each partner (assumes y holds per-class scores
            # such as one-hot labels -- TODO confirm). Moved to CPU so it can
            # index the CPU copy of the predictions below.
            target = y.argmax(1)[per].cpu()
            rows = torch.arange(len(target))
            for l_, lam in enumerate(torch.linspace(0, 1, steps)):
                p = model(lam * x1 + (1 - lam) * x0).cpu()
                # Vectorized gather of p[i, target[i]] for every sample i.
                curves[:, l_] = p[rows, target]

    return curves


def _sample_ensemble(pool, num_members):
    """Draw `num_members` distinct runs from `pool` and flatten their paths."""
    return [path
            for run in random.sample(pool, num_members)
            for path in run]


def main():
    """Compute interpolation curves and ensemble losses for the saved runs.

    Writes ``curves.pt`` (a tuple of three lists of curves: single "erm"
    runs, single "mixup" runs, and ensembles of "erm" runs) and ``losses.pt``
    (a list of ``{ensemble_size: loss}`` dicts) to the working directory.
    """
    parser = argparse.ArgumentParser(description="Plot curves")
    parser.add_argument("--output_dir", default="results/jobs_100/", type=str)
    # NOTE(review): --debug and --num_evals are parsed but never read below.
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--num_trials", default=10, type=int)
    parser.add_argument("--num_members", default=10, type=int)
    parser.add_argument("--num_evals", default=100, type=int)
    parser.add_argument("--epoch", default="best", type=str)
    parser.add_argument("--max_members", default=20, type=int,
                        help="largest ensemble size in the loss sweep")
    args = vars(parser.parse_args())

    torch.manual_seed(0)
    random.seed(0)

    # glob returns files in arbitrary, filesystem-dependent order; sort first
    # so that the seeded shuffle yields the same order on every machine.
    log_suffix = ".train.out"
    raw_fnames = sorted(
        glob.glob(os.path.join(args["output_dir"], "*" + log_suffix)))
    random.shuffle(raw_fnames)

    # Group runs by method. Each run is stored as a one-element list holding
    # the checkpoint path derived from its log file name.
    fnames = {}
    for fname in raw_fnames:
        method = get_method(fname)
        checkpoint = (
            fname[:-len(log_suffix)] + ".train." + args["epoch"] + ".pt")
        fnames.setdefault(method, []).append([checkpoint])

    curves_erm = [get_curves(fnames["erm"][trial])
                  for trial in tqdm.tqdm(range(args["num_trials"]))]

    curves_mix = [get_curves(fnames["mixup"][trial])
                  for trial in tqdm.tqdm(range(args["num_trials"]))]

    curves_ens = [
        get_curves(_sample_ensemble(fnames["erm"], args["num_members"]))
        for _ in tqdm.tqdm(range(args["num_trials"]))]

    torch.save((curves_erm, curves_mix, curves_ens), "curves.pt")

    # Loss as a function of ensemble size, num_trials random draws per size.
    loss_results = []
    for nm in range(1, args["max_members"] + 1):
        for _ in range(args["num_trials"]):
            ens_fnames = _sample_ensemble(fnames["erm"], nm)
            loss_results.append(get_curves(ens_fnames, do_curves=False))
            print(loss_results[-1])

    torch.save(loss_results, "losses.pt")


if __name__ == "__main__":
    main()
