from code.train import SoftCrossEntropyLoss, get_loaders, get_network
from code.train import mean, accuracy, run_jobs, Tee
from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
import torchvision
import argparse
import hashlib
import random
import torch
import json
import glob
import sys
import os


def find_prototypes(single_model_fname, k=5):
    """Pick ``k`` representative training examples for one checkpoint.

    The checkpoint's network is used as a feature extractor (its ``fc`` head
    is replaced by an identity), the training-set features are clustered with
    k-means, and for every cluster centre the position of the closest
    training feature is returned.

    Args:
        single_model_fname: path to a checkpoint stored as
            ``(state_dict, history)``; ``history[0]["args"]`` holds the
            training arguments used to rebuild the data loaders.
        k: number of clusters (and therefore of returned indices).

    Returns:
        A list of ``k`` integer positions in the iteration order of the
        training loader. NOTE(review): callers use these as dataset indices,
        which assumes the evaluation loader does not shuffle -- confirm in
        ``get_loaders``.
    """
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # hosts; the network is moved to the chosen device explicitly below.
    state_dict, history = torch.load(single_model_fname, map_location="cpu")
    args = history[0]["args"]

    loader_tr, _, _ = get_loaders(
        args["data_dir"],
        args["batch_size"],
        args["class_probs"],
        args["random_labels"],
        in_evaluation=True)

    network = get_network(loader_tr.dataset.num_classes)
    network.load_state_dict(state_dict)
    network.fc = torch.nn.Identity()  # expose penultimate-layer features
    network.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    network.to(device)

    # Pure inference: no autograd graph is needed for the features.
    with torch.no_grad():
        all_f = torch.cat(
            [network(x.to(device)).cpu() for x, _ in loader_tr])

    km = KMeans(n_clusters=k).fit(all_f.numpy())
    all_c = torch.as_tensor(km.cluster_centers_, dtype=all_f.dtype)

    # min over dim 0: for each centre (column), the closest feature row.
    return torch.cdist(all_f, all_c).min(0).indices.tolist()


class Member(torch.nn.Module):
    """One trained checkpoint wrapped as a probability-producing model.

    The network lives on the CPU between calls and is moved to
    ``self.device`` (CUDA when available) only for the duration of a forward
    or gradient computation.
    """

    def __init__(self, fname):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Weights are copied into a freshly built CPU network below, so load
        # onto the CPU; this also lets GPU-saved checkpoints load anywhere.
        state_dict, history = torch.load(fname, map_location="cpu")
        args = history[0]["args"]
        self.method = args["method"]

        # One output per class with non-zero sampling probability, plus one
        # extra output unit for "extra2" models.
        num_classes = sum(p > 0 for p in args["class_probs"])
        self.network = get_network(
            num_classes + int(self.method == "extra2"))
        self.network.load_state_dict(state_dict)
        self.loss = SoftCrossEntropyLoss(reduction="sum")

    def predict_(self, x):
        """Return softmax probabilities of the network on ``x``."""
        return self.network(x).softmax(dim=1)

    def grad(self, x, y, loss=True):
        """Gradient w.r.t. the input ``x``.

        Args:
            x: input batch (left untouched; a detached copy is used).
            y: targets passed to the loss when ``loss`` is True.
            loss: if True, differentiate the summed soft cross-entropy of the
                predictions; otherwise differentiate the sum of raw network
                outputs.

        Returns:
            A detached CPU tensor shaped like ``x``.
        """
        # Work on a detached leaf copy: the caller's tensor is not mutated and
        # non-leaf inputs (e.g. mixed batches) are accepted.
        x = x.detach().to(self.device).requires_grad_(True)
        y = y.to(self.device)
        self.network.to(self.device)
        try:
            if loss:
                term = self.loss(self.predict_(x), y, logits=False)
            else:
                term = self.network(x).sum()
            # Only a first-order gradient is needed (it is detached below),
            # so no higher-order graph is kept.
            grad = torch.autograd.grad(term, x)[0]
        finally:
            self.network.to("cpu")
        return grad.detach().cpu()

    def forward(self, x):
        """Return detached CPU class probabilities for ``x``."""
        self.network.to(self.device)
        try:
            prediction = self.predict_(x.to(self.device))
        finally:
            self.network.to("cpu")
        return prediction.detach().cpu()


class Ensemble(torch.nn.Module):
    """Uniform average of several member models.

    Members are held in a ``ModuleList``; when ``init_models`` is False the
    list starts empty and subclasses fill it themselves.
    """

    def __init__(self, fnames, init_models=True):
        super().__init__()
        self.members = torch.nn.ModuleList(
            [Member(fname) for fname in fnames] if init_models else [])

    def method(self):
        """Training method recorded by the first member."""
        first = self.members[0]
        return first.method

    def __len__(self):
        return len(self.members)

    def __getitem__(self, i):
        return self.members[i]

    def _averaged_grad(self, x, y, loss):
        # Average of the members' input gradients (loss or raw-output sum).
        return mean([m.grad(x, y, loss=loss) for m in self.members])

    def loss_grad(self, x, y):
        """Member-averaged gradient of the loss w.r.t. ``x``."""
        return self._averaged_grad(x, y, loss=True)

    def jacobian(self, x, y):
        """Member-averaged gradient of the summed outputs w.r.t. ``x``."""
        return self._averaged_grad(x, y, loss=False)

    def member_predictions(self, x):
        """List with one prediction tensor per member, in member order."""
        return [m(x) for m in self.members]

    def forward(self, x):
        per_member = self.member_predictions(x)
        return mean(per_member)


class MemberTTM(Member):
    """Test-time-mixup member: predicts by mixing inputs with a fixed reference.

    For an input batch ``x`` the network is evaluated on
    ``lam * aug(x) + (1 - lam) * ref_x``. The reference's share
    ``(1 - lam) * ref_l`` is subtracted from that output, the result is
    rescaled by ``1 / lam``, and a softmax is applied.

    The reference is ``(x, y)`` by default, or ``(x_adv, y_adv)`` when
    ``use_attacker`` is True. ``ref_l`` is the stored label when ``use_y`` is
    True, otherwise (the default) the network's own output on the reference.
    """

    def __init__(self, fname, x, y, x_adv, y_adv, lam, aug=False):
        super().__init__(fname)
        # Identity by default; random crop + horizontal flip when aug=True.
        self.aug = torchvision.transforms.Lambda(lambda x: x)
        if aug:
            self.aug = torchvision.transforms.Compose([
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip()
            ])

        # Stored with a leading batch dimension of 1 so they broadcast
        # against input batches in predict_.
        self.x = x.clone().unsqueeze(0)
        self.y = y.clone().unsqueeze(0)
        self.x_adv = x_adv.clone().unsqueeze(0)
        self.y_adv = y_adv.clone().unsqueeze(0)

        # Reference outputs are computed in eval mode and without autograd.
        # In train mode, normalisation layers such as BatchNorm would use
        # single-example statistics and update their running statistics,
        # changing all later eval-mode predictions of this member.
        was_training = self.network.training
        self.network.eval()
        with torch.no_grad():
            self.p = self.network(self.x).cpu()
            self.p_adv = self.network(self.x_adv).cpu()
        self.network.train(was_training)

        self.use_attacker = False
        self.use_y = False
        self.lam = lam

    def predict_(self, x):
        if self.use_attacker:
            ref_x = self.x_adv
            ref_l = self.y_adv if self.use_y else self.p_adv
        else:
            ref_x = self.x
            # Mirrors the branch above: label when use_y, otherwise the
            # network output on the reference.
            ref_l = self.y if self.use_y else self.p

        pm = self.network(self.lam * self.aug(x) + (1 - self.lam) *
                          ref_x.to(self.device))
        return ((pm - (1 - self.lam) *
                 ref_l.to(self.device)) / self.lam).softmax(dim=1)


class EnsembleTTM(Ensemble):
    """Ensemble of test-time-mixup members built from a single checkpoint.

    Only ``fnames[0]`` is used. Each member's primary reference is one of the
    ``k`` k-means prototypes from ``find_prototypes``. Its alternative
    reference is a uniformly random example of ``loader.dataset``.
    """

    def __init__(self, fnames, loader, k=10, lam=0.8, aug=False):
        super().__init__(fnames, init_models=False)
        dataset = loader.dataset
        anchors = find_prototypes(fnames[0], k)
        partners = torch.randperm(len(dataset))
        # zip stops after the k prototypes, so only k random partners are used.
        for anchor, partner in zip(anchors, partners):
            x, y = dataset[anchor]
            x_other, y_other = dataset[partner]
            self.members.append(
                MemberTTM(fnames[0], x, y, x_other, y_other, lam, aug))


class MemberTTERM(Member):
    """Test-time-augmentation member: predicts on a randomly augmented input.

    Every call to ``predict_`` applies a random 32x32 crop (4-pixel padding)
    and a random horizontal flip before the network. NOTE(review): torchvision
    samples one set of transform parameters per call, so a whole batch shares
    the same crop/flip.
    """

    def __init__(self, fname):
        super().__init__(fname)
        self.aug = torchvision.transforms.Compose([
            torchvision.transforms.RandomCrop(32, padding=4),
            torchvision.transforms.RandomHorizontalFlip()
        ])

    def predict_(self, x):
        return self.network(self.aug(x)).softmax(dim=1)


class EnsembleTTERM(Ensemble):
    """Ensemble of ``k`` test-time-augmentation copies of one checkpoint.

    Only ``fnames[0]`` is used. Every member loads the same weights, so the
    members differ only through their random augmentation at prediction time.
    """

    def __init__(self, fnames, k=10):
        super().__init__(fnames, init_models=False)
        self.members.extend(MemberTTERM(fnames[0]) for _ in range(k))


def evaluate_ece(model, loader, n_bins=15):
    """Expected calibration error of ``model`` over ``loader``.

    Confidences (maximum predicted probability) are grouped into ``n_bins``
    equal-width bins ``(lower, upper]`` over [0, 1]. The ECE is the sum over
    non-empty bins of ``|mean confidence - accuracy| * fraction of samples``.
    A prediction counts as correct when its argmax equals ``y.argmax(1)``, so
    targets are per-class vectors (e.g. one-hot).

    Returns:
        The ECE as a Python float.
    """
    batches = [(model(x), y) for x, y in loader]
    probs = torch.cat([p for p, _ in batches])
    labels = torch.cat([t for _, t in batches])

    conf, pred_cls = torch.max(probs, 1)
    hits = pred_cls.eq(labels.argmax(1))

    edges = torch.linspace(0, 1, n_bins + 1)
    total = 0
    for lo, hi in zip(edges[:-1].tolist(), edges[1:].tolist()):
        mask = conf.gt(lo) & conf.le(hi)
        share = mask.float().mean()
        if share.item() > 0:
            gap = conf[mask].mean() - hits[mask].float().mean()
            total += gap.abs() * share

    return total.item()


def evaluate_loader(model, loader, alpha=None, eps=0.1, what="loss"):
    """Average a per-batch statistic of ``model`` over ``loader``.

    Args:
        model: ensemble-like callable returning class probabilities; it must
            also provide ``jacobian``/``loss_grad``/``method`` for the
            statistics that use them.
        loader: iterable of ``(x, y)`` batches; ``y`` is mixed like ``x`` and
            argmax'd along dim 1, so presumably per-class label vectors.
        alpha: if not None, every batch is mixed with a shuffled copy of
            itself using ``lam ~ Beta(alpha, alpha)`` (mixup).
        eps: step size of the one-step sign-gradient attack ("adversarial").
        what: one of
            "loss"        summed soft cross-entropy on the (mixed) batch,
            "jacobian"    log-norm of the averaged input gradient per batch,
            "inter_pred"  norm of f(mix) - mix(f) (requires ``alpha``),
            "inter_label" norm of f(mix) - mixed labels (requires ``alpha``),
            "adversarial" number of correct predictions after the attack.

    Returns:
        The sum of the per-batch values divided by the number of samples.

    Raises:
        NotImplementedError: unknown ``what``.
        ValueError: "inter_pred"/"inter_label" without ``alpha``, or a loader
            that yields no samples.
    """
    kinds = ("jacobian", "loss", "inter_pred", "inter_label", "adversarial")
    if what not in kinds:
        raise NotImplementedError(what)
    if alpha is None and what in ("inter_pred", "inter_label"):
        # These compare against the mixup interpolation (lam, per), which
        # only exists when alpha is given.
        raise ValueError(
            "what={!r} requires alpha (mixup) to be set".format(what))

    # Only the "loss" statistic needs the loss module.
    loss = SoftCrossEntropyLoss(reduction="sum") if what == "loss" else None
    value = 0
    total = 0

    for x, y in loader:
        if alpha is not None:
            lam = torch.distributions.beta.Beta(alpha, alpha).sample()
            per = torch.randperm(len(x))
            xm = lam * x + (1 - lam) * x[per]
            ym = lam * y + (1 - lam) * y[per]
        else:
            xm, ym = x, y

        if what == "jacobian":
            with torch.enable_grad():
                value += model.jacobian(xm, ym).norm().log().item()
        elif what == "loss":
            value += loss(model(xm), ym, logits=False).item()
        elif what == "inter_pred":
            # Call order (mixed, clean, permuted) matches the original so
            # stochastic models consume randomness identically.
            mixed_pred = model(xm)
            interp_pred = lam * model(x) + (1 - lam) * model(x[per])
            value += (mixed_pred - interp_pred).norm().item()
        elif what == "inter_label":
            preds = model(xm)
            if model.method() == "extra2":
                # Drop the extra output column so shapes match the labels.
                preds = preds[:, :-1]
            value += (preds - ym).norm().item()
        else:  # "adversarial"
            # The gradient is needed even when called under torch.no_grad().
            with torch.enable_grad():
                grad = model.loss_grad(x, y)
            x_adv = x + eps * grad.sign()
            t = y.argmax(1)
            value += model(x_adv).argmax(1).eq(t).float().sum().item()

        total += len(x)

    if total == 0:
        raise ValueError("loader yielded no samples")
    return value / total


def evaluate_ood_detection(model, loader_in, loader_out):
    """ROC AUC for telling ``loader_out`` samples apart from ``loader_in`` ones.

    In-distribution samples are labelled 0 and out-of-distribution samples 1.
    For "extra2" models the score is the probability in the last output
    column. Otherwise it is the negated maximum class probability, so less
    confident predictions score as more out-of-distribution.
    """
    use_extra_column = model.method() == "extra2"

    def score(probs):
        if use_extra_column:
            return probs[:, -1]
        return probs.max(1).values.view(-1).mul(-1)

    labels = []
    scores = []
    for make_labels, loader in ((torch.zeros, loader_in),
                                (torch.ones, loader_out)):
        for x, _ in loader:
            labels.append(make_labels(len(x)))
            scores.append(score(model(x)).view(-1))

    return roc_auc_score(
        torch.cat(labels).view(-1), torch.cat(scores).view(-1))


def evaluate_diversity(model, loader):
    """Measure how similarly the ensemble members predict on ``loader``.

    Builds the matrix of mean row-wise cosine similarities between every
    pair of members' predictions over the whole loader, subtracts the
    identity, and returns the Frobenius norm of the result. Identical members
    give a large value. Members with orthogonal predictions give 0.
    """
    n_members = len(model)
    similarity = torch.zeros(n_members, n_members)

    per_member = [[] for _ in range(len(model))]
    for x, _ in loader:
        for idx, pred in enumerate(model.member_predictions(x)):
            per_member[idx].append(pred)
    stacked = [torch.cat(chunks) for chunks in per_member]

    for i, a in enumerate(stacked):
        for j, b in enumerate(stacked):
            similarity[i, j] = torch.nn.functional.cosine_similarity(
                a, b).mean()

    off_diagonal = similarity - torch.eye(similarity.size(1))
    return off_diagonal.norm().item()


def evaluate_model(fnames, eval_dir, ensemble_type="normal", k=10, lam=0.8):
    """Evaluate one checkpoint / ensemble configuration and cache the result.

    The configuration is identified by an md5 hash of ``fnames`` together
    with ``(ensemble_type, k, lam)``. If ``<eval_dir>/<hash>.eval.json``
    already exists, nothing is computed and ``None`` is returned. Otherwise
    stdout/stderr are tee'd to ``<hash>.eval.out`` / ``<hash>.eval.err``
    while the evaluation runs. The stats dictionary is printed and written
    to ``<hash>.eval.json``.

    Args:
        fnames: list of checkpoint paths; the training args stored in
            ``fnames[0]`` configure the data loaders.
        eval_dir: directory receiving the output files.
        ensemble_type: "ttm", "ttm_aug", "tterm", or anything else for a
            plain average of the checkpoints in ``fnames``.
        k: number of test-time members. It is not used by the plain ensemble,
            but it is still part of the cache key and the recorded stats.
        lam: mixing weight for the "ttm" variants (also recorded).

    Returns:
        The result dictionary, or ``None`` if a cached result existed.
    """
    kwargs = "{}_{}_{}".format(ensemble_type, k, lam)
    filename = hashlib.md5(
        "_".join(fnames + [kwargs]).encode("utf-8")).hexdigest()

    outfile = os.path.join(eval_dir, filename)
    if os.path.exists(outfile + ".eval.json"):
        return None

    args = torch.load(fnames[0], map_location="cpu")[1][0]["args"]

    # Restore the real streams afterwards; otherwise repeated in-process
    # calls wrap Tee around Tee and later jobs' output leaks into the logs of
    # earlier ones.
    saved_stdout, saved_stderr = sys.stdout, sys.stderr
    sys.stdout = Tee(outfile + ".eval.out", sys.stdout)
    sys.stderr = Tee(outfile + ".eval.err", sys.stderr)
    try:
        loader_tr, loader_te, loader_ne = get_loaders(
            args["data_dir"], args["batch_size"],
            args["class_probs"], args["random_labels"],
            in_evaluation=True)

        # NOTE(review): "1 - p" selects exactly the held-out classes only if
        # class_probs entries are 0/1 -- confirm for unbalanced settings.
        class_probs_ood = [1 - p for p in args["class_probs"]]

        loader_tr_out, loader_te_out, loader_ne_out = get_loaders(
            args["data_dir"], args["batch_size"],
            class_probs_ood, args["random_labels"],
            in_evaluation=True)

        if ensemble_type == "ttm":
            model = EnsembleTTM(fnames, loader_tr, k=k, lam=lam, aug=False)
            args["method"] = "ttm"
        elif ensemble_type == "ttm_aug":
            model = EnsembleTTM(fnames, loader_tr, k=k, lam=lam, aug=True)
            args["method"] = "ttm_aug"
        elif ensemble_type == "tterm":
            model = EnsembleTTERM(fnames, k=k)
            args["method"] = "tterm"
        else:
            model = Ensemble(fnames)
        model.eval()

        method_fmt = {
            "erm": "ERM",
            "mixup": "mixup",
            "extra1": "Extra-mixup-v1",
            "extra2": "Extra-mixup-v2",
            "ttm": "TestTime-mixup",
            "ttm_aug": "TestTime-mixup-aug",
            "tterm": "TestTime-ERM"
        }[args["method"]] + (" (ensemble)" if len(fnames) > 1 else "")

        # Entries are evaluated in dict order; several of them consume RNG
        # (mixup sampling, random augmentation), so the order is kept as is.
        with torch.no_grad():
            acc_tr = accuracy(model, loader_tr)
            acc_te = accuracy(model, loader_te)
            acc_ne = accuracy(model, loader_ne)

            result = {
                "args": args,
                "fnames": fnames,
                "method": method_fmt,
                #
                "stats": {
                    "method": method_fmt,
                    "k": k,
                    "lam": lam,
                    #
                    "avg acc (tr)": mean(acc_tr.values()),
                    "avg acc (te)": mean(acc_te.values()),
                    "avg acc (ne)": mean(acc_ne.values()),
                    #
                    "worst acc (tr)": min(acc_tr.values()),
                    "worst acc (te)": min(acc_te.values()),
                    "worst acc (ne)": min(acc_ne.values()),
                    #
                    "adv acc (tr)": evaluate_loader(model, loader_tr, what="adversarial"),
                    "adv acc (te)": evaluate_loader(model, loader_te, what="adversarial"),
                    "adv acc (ne)": evaluate_loader(model, loader_ne, what="adversarial"),
                    #
                    "ood acc (tr)": evaluate_ood_detection(model, loader_tr, loader_tr_out),
                    "ood acc (te)": evaluate_ood_detection(model, loader_te, loader_te_out),
                    "ood acc (ne)": evaluate_ood_detection(model, loader_ne, loader_ne_out),
                    #
                    "diversity (tr)": evaluate_diversity(model, loader_tr),
                    "diversity (te)": evaluate_diversity(model, loader_te),
                    "diversity (ne)": evaluate_diversity(model, loader_ne),
                    #
                    "OOD diversity (tr)": evaluate_diversity(model, loader_tr_out),
                    "OOD diversity (te)": evaluate_diversity(model, loader_te_out),
                    "OOD diversity (ne)": evaluate_diversity(model, loader_ne_out),
                    #
                    "data loss (tr)": evaluate_loader(model, loader_tr, what="loss"),
                    "data loss (te)": evaluate_loader(model, loader_te, what="loss"),
                    "data loss (ne)": evaluate_loader(model, loader_ne, what="loss"),
                    #
                    "mixup loss (tr)": evaluate_loader(model, loader_tr, alpha=args["alpha"], what="loss"),
                    "mixup loss (te)": evaluate_loader(model, loader_te, alpha=args["alpha"], what="loss"),
                    "mixup loss (ne)": evaluate_loader(model, loader_ne, alpha=args["alpha"], what="loss"),
                    #
                    "ECE (tr)": evaluate_ece(model, loader_tr),
                    "ECE (te)": evaluate_ece(model, loader_te),
                    "ECE (ne)": evaluate_ece(model, loader_ne)
                }
            }

        print(json.dumps(result))
        with open(outfile + ".eval.json", "w") as f:
            f.write(json.dumps(result))
    finally:
        sys.stdout, sys.stderr = saved_stdout, saved_stderr

    return result


def get_method(fname):
    """Return the training method recorded in a ``.train.out`` log.

    The first line of the file is parsed as JSON, and its
    ``["args"]["method"]`` entry is returned.
    """
    with open(fname, "r") as handle:
        header = json.loads(handle.readline())
    return header["args"]["method"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Unbalanced CIFAR10 training")
    parser.add_argument("--output_dir", default="results/jobs_100/", type=str)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--num_ensembles", default=10, type=int)
    # nargs="+" so a value given on the command line is a list like the
    # default (without it, "--k_choices 5" yields a bare int and the loops
    # below fail).
    parser.add_argument("--k_choices",
                        default=[2, 5, 10, 20],
                        nargs="+",
                        type=int)
    parser.add_argument("--lam_choices",
                        default=[.1, .2, .3, .4, .5, .6, .7, .8, .9],
                        nargs="+",
                        type=float)
    parser.add_argument("--epoch", default="best", type=str)
    args = vars(parser.parse_args())

    torch.manual_seed(0)
    random.seed(0)

    # glob order is filesystem-dependent; sorting first makes the seeded
    # shuffle (and every later random.sample) reproducible.
    raw_fnames = sorted(
        glob.glob(os.path.join(args["output_dir"], "*.train.out")))
    random.shuffle(raw_fnames)

    # Group checkpoints by training method. Each entry is a one-element list
    # so it can be passed directly as the ``fnames`` of evaluate_model.
    fnames = {}
    for fname in raw_fnames:
        checkpoint = (fname[:-len(".train.out")] + ".train." +
                      args["epoch"] + ".pt")
        fnames.setdefault(get_method(fname), []).append([checkpoint])

    def get_eval_foo(ensemble_type="normal", k=1, lam=0.8):
        """Bind evaluation settings into a one-argument job function."""
        def foo(fnames):
            return evaluate_model(
                fnames,
                args["output_dir"],
                ensemble_type=ensemble_type,
                k=k,
                lam=lam)
        return foo

    # 1) Every checkpoint on its own.
    commands = [fname for method in fnames for fname in fnames[method]]
    run_jobs(get_eval_foo(), commands)

    # 2) Random ERM ensembles of each size k.
    for k in args["k_choices"]:
        commands_ens = [
            [member[0] for member in random.sample(fnames["erm"], k)]
            for _ in range(args["num_ensembles"])]
        run_jobs(get_eval_foo(k=k), commands_ens)

    # 3) Test-time mixup on (up to num_ensembles) mixup checkpoints; slicing
    # avoids an IndexError when fewer checkpoints are available.
    commands_ttm = fnames["mixup"][:args["num_ensembles"]]
    for k in args["k_choices"]:
        for lam in args["lam_choices"]:
            run_jobs(get_eval_foo("ttm", k, lam), commands_ttm)
            run_jobs(get_eval_foo("ttm_aug", k, lam), commands_ttm)

    # 4) Test-time augmentation on (up to num_ensembles) ERM checkpoints.
    commands_tterm = fnames["erm"][:args["num_ensembles"]]
    for k in args["k_choices"]:
        run_jobs(get_eval_foo("tterm", k), commands_tterm)
