# %%
import numpy as np
import os
import pandas as pd
import torch.nn as nn
import torchvision
from peerannot.models.MV import MV
from peerannot.models.Soft import Soft
from peerannot.models.DS import Dawid_Skene as DS
from peerannot.models.GLAD import GLAD
from peerannot.models.WAUM import WAUM
import torch.utils.data as data
from torchvision.datasets import ImageFolder
import torchvision.datasets as datasets
from torchvision.models import resnet18, vgg16_bn
from torch.utils.data import DataLoader
import torch
from copy import deepcopy
import pprint
import torchvision.transforms as T
from tqdm import tqdm

# Criterion shared by the training, validation and test loops of this script.
Loss = nn.CrossEntropyLoss()
# Folder holding this script; data, checkpoint and output paths are built on it.
DIR = os.path.dirname(os.path.abspath(__file__))
# Name of the sub-folder of DIR that receives the best checkpoints.
best_folder = "best"
device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 100
# Image pipeline applied to every split: resize, center crop, tensor
# conversion, then per-channel normalization.
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# %%


class CrowdLabelMe(data.Dataset):
    """Attach an externally supplied target to every item of a base dataset.

    Indexing returns ``(sample, supplied_target, base_target, index)`` where
    ``sample`` and ``base_target`` come from ``base_dataset[index]`` and
    ``supplied_target`` is ``targets[index]``.
    """

    def __init__(self, base_dataset, targets, **kwargs):
        # Extra keyword arguments are accepted but not used.
        self.targets = targets
        self.base_dataset = base_dataset

    def __len__(self):
        # The length is driven by the base dataset, not by ``targets``.
        return len(self.base_dataset)

    def __getitem__(self, index):
        sample, base_target = self.base_dataset[index]
        supplied_target = self.targets[index]
        return (sample, supplied_target, base_target, index)


class LabelMe(ImageFolder):
    """Image-folder view of one LabelMe split, ordered like its index files.

    The images are read from ``<root>/<split>/<class name>/<file name>``.
    After the regular ``ImageFolder`` initialisation, the samples are rebuilt
    so that they follow the order of ``<root>/filenames_<split>.txt``, with
    class names taken line by line from ``<root>/labels_<split>_names.txt``.

    Args:
        root: Folder holding the split folders and the index text files.
        split: Name of the split folder (this script uses ``"train"``,
            ``"valid"`` and ``"test"``).
        **kwargs: Forwarded unchanged to ``ImageFolder``.
    """

    def __init__(self, root, split, **kwargs):
        self.root = root
        # Kept apart from ``self.root``: the torchvision parent constructor
        # stores the folder it is given (the split folder) in ``self.root``.
        self.dataroot = root
        self.split = split
        super().__init__(self.split_folder, **kwargs)
        self.reorder_tasks()

    @property
    def split_folder(self):
        """Folder containing the class sub-folders of this split."""
        # Built from ``dataroot`` so the value stays correct after the parent
        # constructor has replaced ``self.root`` by the split folder itself.
        return os.path.join(self.dataroot, self.split)

    def reorder_tasks(self):
        """Rebuild ``samples``, ``imgs`` and ``targets`` in index-file order."""
        # ``ndmin=1`` keeps a one-line index file iterable instead of
        # collapsing it to a 0-d array.
        task_files = np.loadtxt(
            os.path.join(self.dataroot, f"filenames_{self.split}.txt"),
            dtype=str,
            ndmin=1,
        )
        task_labels = np.loadtxt(
            os.path.join(self.dataroot, f"labels_{self.split}_names.txt"),
            dtype=str,
            ndmin=1,
        )
        image_dir = self.split_folder
        samples = []
        targets = []
        for taskf, targetname in zip(task_files, task_labels):
            target = self.class_to_idx[targetname]
            targets.append(target)
            samples.append((os.path.join(image_dir, targetname, taskf), target))
        self.samples = samples
        self.imgs = samples
        self.targets = targets


def prepare_data():
    """Load the crowd answers, the training labels and the three splits.

    Everything is read from ``DIR/data``.

    Returns:
        A tuple ``(votes, y_train_truth, n_workers, train, val, test)`` where
        ``votes`` maps each task index to a ``{worker: label}`` dictionary,
        ``y_train_truth`` is the integer content of ``labels_train.txt``,
        ``n_workers`` is the number of columns of ``answers.txt`` and the
        last three items are ``LabelMe`` datasets.
    """
    data_dir = os.path.join(DIR, "data")
    answers = np.loadtxt(os.path.join(data_dir, "answers.txt"))
    y_train_truth = np.loadtxt(
        os.path.join(data_dir, "labels_train.txt")
    ).astype(int)

    # Translation applied to every recorded answer before it is stored.
    convert_labels = {0: 2, 1: 3, 2: 7, 3: 6, 4: 1, 5: 0, 6: 4, 7: 5}
    # One row per task, one column per worker; -1 marks "no answer".
    votes = {
        task_id: {
            worker: convert_labels[int(row[worker])]
            for worker in np.where(row != -1)[0]
        }
        for task_id, row in enumerate(answers)
    }
    n_workers = answers.shape[1]

    train = LabelMe(root=data_dir, split="train", transform=preprocess)
    val = LabelMe(root=data_dir, split="valid", transform=preprocess)
    test = LabelMe(root=data_dir, split="test", transform=preprocess)
    return (votes, y_train_truth, n_workers, train, val, test)


def setup_workspace():
    """Build the datasets and the validation/test loaders of one experiment.

    Returns:
        A tuple ``(train, val, test, test_set_, val_set_, val_loader,
        test_loader, all_ans)``; ``all_ans`` is the vote dictionary returned
        by ``prepare_data`` and is also stored on ``train.all_ans``.
    """
    all_ans, _, _, train, val, test = prepare_data()
    train.all_ans = all_ans

    # Validation and test items carry their own labels as supplied target.
    test_set_ = CrowdLabelMe(base_dataset=test, targets=test.targets)
    val_set_ = CrowdLabelMe(base_dataset=val, targets=val.targets)

    val_loader = DataLoader(
        val_set_,
        batch_size=64,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
    )
    test_loader = DataLoader(
        test_set_,
        batch_size=64,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
    )
    return (
        train,
        val,
        test,
        test_set_,
        val_set_,
        val_loader,
        test_loader,
        all_ans,
    )


# %%


def get_correct(pred, target):
    """Count the predictions that match the target labels.

    Args:
        pred: Tensor of predicted class indices; it is squeezed before the
            comparison.
        target: Either class indices, or a tensor with more than one
            dimension that is reduced with an argmax over ``dim=1``.

    Returns:
        The number of matches as a Python float.
    """
    labels = torch.argmax(target, dim=1) if target.dim() > 1 else target
    matches = torch.eq(pred.squeeze(), labels)
    return matches.float().sum().item()


def expected_calibration_error(y_pred, y_true, num_bins=15):
    """Compute a binned calibration error for class-probability predictions.

    Each sample is assigned to a confidence bin according to its largest
    predicted probability. Within every non-empty bin, the absolute value of
    the summed ``correct - confidence`` differences is accumulated, and the
    total is divided by the number of samples.

    Args:
        y_pred: Array whose last axis holds the class scores of each sample.
        y_true: Array of true class indices.
        num_bins: Number of evenly spaced bin edges placed on ``[0, 1]``
            (which yields ``num_bins - 1`` intervals between them).

    Returns:
        The accumulated gap divided by ``y_pred.shape[0]``.
    """
    confidence = np.max(y_pred, axis=-1)
    is_correct = np.argmax(y_pred, axis=-1) == y_true

    edges = np.linspace(start=0, stop=1.0, num=num_bins)
    # right=True: a confidence equal to an edge belongs to the bin ending there.
    bin_of_sample = np.digitize(confidence, bins=edges, right=True)

    total_gap = 0
    for bin_id in range(num_bins):
        members = bin_of_sample == bin_id
        if not np.any(members):
            continue
        total_gap += np.abs(np.sum(is_correct[members] - confidence[members]))

    return total_gap / y_pred.shape[0]


def _save_waum_diagnostics(waum, alpha):
    """Write the scores of a fitted WAUM strategy as JSON into DIR/outputs."""
    out_dir = os.path.join(DIR, "outputs")
    # Create the folder first: without it the dumps would fail only after the
    # whole WAUM run has completed.
    os.makedirs(out_dir, exist_ok=True)
    # Keys and values are cast to plain int/float so that json can encode them.
    payloads = {
        f"waum_labelme_{alpha}_vgg.json": {
            int(k): float(v) for k, v in waum.waum.items()
        },
        f"workertrust_labelme_{alpha}_vgg.json": {
            int(k): {int(t): float(v) for t, v in scores.items()}
            for k, scores in waum.score_per_worker.items()
        },
        f"aumperworker_labelme_{alpha}_vgg.json": {
            int(k): {int(t): float(v) for t, v in scores.items()}
            for k, scores in waum.aum_per_worker.items()
        },
    }
    for filename, payload in payloads.items():
        with open(os.path.join(out_dir, filename), "w") as stream:
            json.dump(payload, stream)


def get_labels(seed, train, alpha, all_ans):
    """Aggregate the crowd votes with WAUM and build the training dataset.

    The WAUM strategy is run on ``train`` with the votes ``all_ans``, its
    scores are dumped as JSON files, and the tasks listed in
    ``waum.too_hard`` are removed from a copy of ``train``. The other
    aggregation strategies (MV, Soft, DS, GLAD) are kept below as
    commented-out code.

    Args:
        seed: Accepted for interface compatibility; not used here.
        train: Training dataset exposing ``samples``, ``imgs`` and ``targets``.
        alpha: Value forwarded to ``waum.run`` and used in the file names.
        all_ans: Crowd votes, as returned by ``prepare_data``.

    Returns:
        A tuple ``([(train_set_waum, "WAUM")], [acc_waum], [waum])`` where
        ``acc_waum`` is the share of kept tasks whose aggregated label equals
        the label stored in ``train.targets``.
    """
    # mv = MV(all_ans, n_classes=8)
    # y_mv = mv.get_answers()
    # train_set_mv = CrowdLabelMe(base_dataset=train, targets=y_mv)
    # acc_mv = np.mean(y_mv == train_set_mv.base_dataset.targets)
    # print("MV", acc_mv)

    # soft = Soft(answers=all_ans, n_classes=8)
    # y_soft = soft.get_probas()
    # train_set_soft = CrowdLabelMe(base_dataset=train, targets=y_soft)
    # acc_soft = np.mean(
    #     np.argmax(y_soft, axis=1) == train_set_soft.base_dataset.targets
    # )
    # print(
    #     "Soft",
    #     acc_soft,
    # )
    # ds = DS(answers=all_ans, n_classes=8)
    # ds.run_em(epsilon=1e-6, maxiter=50)
    # y_ds = ds.get_probas()
    # train_set_ds = CrowdLabelMe(
    #     base_dataset=train,
    #     targets=y_ds,
    # )
    # acc_ds = np.mean(
    #     np.argmax(y_ds, axis=1) == train_set_ds.base_dataset.targets
    # )
    # print(
    #     "DS",
    #     acc_ds,
    # )

    # Placeholder targets: the aggregated labels only exist after the run.
    train_set_waum = CrowdLabelMe(base_dataset=train, targets=[0] * len(train))
    model, optimizer, scheduler = get_model_sgd()
    # NOTE(review): the positional 8 and 50 are presumably the number of
    # classes and of training epochs -- confirm against the WAUM signature.
    waum = WAUM(
        train_set_waum,
        all_ans,
        8,
        model,
        torch.nn.CrossEntropyLoss(),
        optimizer,
        50,
        verbose=True,
    )
    waum.run(alpha=alpha)
    _save_waum_diagnostics(waum, alpha)

    y_waum = waum.get_probas()
    # The pruned tasks are removed from the reference labels before comparing.
    acc_waum = np.mean(
        np.argmax(y_waum, axis=1)
        == np.delete(train.targets, waum.too_hard, axis=0)
    )
    print(
        "WAUM",
        acc_waum,
        "cut",
        len(waum.too_hard),
    )

    # Hash-based membership instead of rescanning ``too_hard`` per sample.
    pruned = set(np.asarray(waum.too_hard).ravel().tolist())
    train_mdf = deepcopy(train)
    train_mdf.samples = [
        samp for i, samp in enumerate(train_mdf.samples) if i not in pruned
    ]
    train_mdf.imgs = train_mdf.samples
    # Keep ``targets`` aligned with the filtered ``samples``.
    train_mdf.targets = [
        tgt for i, tgt in enumerate(train_mdf.targets) if i not in pruned
    ]
    train_set_waum = CrowdLabelMe(base_dataset=train_mdf, targets=y_waum)

    # glad = GLAD(
    #     n_classes=8,
    #     answers=all_ans,
    # )
    # glad.run_em(epsilon=1e-6, maxiter=50)
    # y_glad = glad.get_probas()
    # train_set_glad = CrowdLabelMe(
    #     base_dataset=train,
    #     targets=y_glad,
    # )
    # acc_glad = np.mean(
    #     np.argmax(y_glad, axis=1) == train_set_glad.base_dataset.targets
    # )
    # print(
    #     "GLAD",
    #     acc_glad,
    # )

    return (
        [
            # (train_set_mv, "MV"),
            # (train_set_soft, "Soft"),
            # (train_set_ds, "DS"),
            (train_set_waum, "WAUM"),
            # (train_set_glad, "GLAD"),
        ],
        [acc_waum],
        [waum]
        # [acc_mv, acc_soft, acc_ds, acc_waum, acc_glad],
        # [mv, soft, ds, waum, glad],
    )


# %%
def train_step(
    batch, epoch, epoch_loss, epoch_acc, n_epochs, model, optimizer, device
):
    """Run one optimisation step on ``batch`` and update the running totals.

    Args:
        batch: Four-item batch ``(inputs, targets, _, sample_ids)``; only the
            first two items are used.
        epoch: Accepted but not used.
        epoch_loss: Running sum of ``batch size * batch loss``.
        epoch_acc: Running count of correct predictions.
        n_epochs: Accepted but not used.
        model: Network to train; switched to training mode.
        optimizer: Optimizer stepped once on this batch.
        device: Device the inputs and targets are moved to.

    Returns:
        The updated ``(epoch_loss, epoch_acc)``.
    """
    model.train()
    with torch.enable_grad():
        optimizer.zero_grad()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        batch_loss = Loss(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Index of the highest score of every row, shaped (batch, 1).
        predicted = logits.data.topk(1, dim=1)[1]
        epoch_loss += labels.size(0) * batch_loss.item()
        epoch_acc += get_correct(predicted, labels)
    return epoch_loss, epoch_acc


def val_step(batch, epoch_loss, epoch_acc, epoch_ece, model, device):
    """Evaluate ``model`` on one batch and update the running totals.

    Loss, accuracy and calibration error are all measured against the third
    item of the batch; the second item is ignored.

    Args:
        batch: Four-item batch ``(inputs, _, true_targets, sample_ids)``.
        epoch_loss: Running sum of ``batch size * batch loss``.
        epoch_acc: Running count of correct predictions.
        epoch_ece: Running sum of ``batch size * batch calibration error``.
        model: Network to evaluate; switched to evaluation mode.
        device: Device the inputs and true targets are moved to.

    Returns:
        The updated ``(epoch_loss, epoch_acc, epoch_ece)``.
    """
    model.eval()
    with torch.no_grad():
        images, _, true_labels, _ = batch
        images = images.to(device)
        true_labels = true_labels.to(device)

        logits = model(images)
        batch_loss = Loss(logits, true_labels)
        predicted = logits.data.topk(1, dim=1)[1]
        n_items = true_labels.size(0)

        epoch_loss += n_items * batch_loss.item()
        epoch_acc += get_correct(predicted, true_labels)
        batch_ece = expected_calibration_error(
            logits.softmax(axis=1).data.cpu().numpy(),
            true_labels.cpu().numpy(),
        )
        epoch_ece += batch_ece * n_items
    return epoch_loss, epoch_acc, epoch_ece


def run_test(test_loader, test_set, name="MV"):
    """Evaluate the best checkpoint of strategy ``name`` on the test set.

    The weights are read from ``DIR/<best_folder>/<name>_vgg.pt``, the file
    written by ``run_train``.

    Args:
        test_loader: Loader yielding four-item batches as expected by
            ``val_step``.
        test_set: Dataset behind the loader; its length normalises the totals.
        name: Strategy name used in the checkpoint file name.

    Returns:
        A dictionary with the keys ``"accuracy"`` (in percent), ``"loss"``
        and ``"ECE"``, each holding a one-element list.
    """
    # model = resnet18(num_classes=8)
    # No pretrained download: the checkpoint loaded below provides the whole
    # state dict (loading is strict by default).
    model = vgg16_bn()
    model.classifier[6] = nn.Linear(4096, 8)
    checkpoint_path = os.path.join(DIR, best_folder, f"{name}_vgg.pt")
    # map_location allows a checkpoint saved from a GPU run to be loaded on a
    # machine without CUDA.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model = model.to(device)

    total_loss, total_correct, total_ece = 0, 0, 0
    # val_step switches to eval mode, disables gradients and scores each batch
    # against the true labels, exactly what the test pass requires.
    for batch in tqdm(test_loader, total=len(test_loader)):
        total_loss, total_correct, total_ece = val_step(
            batch, total_loss, total_correct, total_ece, model, device
        )

    n_samples = len(test_set)
    return {
        "accuracy": [total_correct * 100 / n_samples],
        "loss": [total_loss / n_samples],
        "ECE": [total_ece / n_samples],
    }


def run_train(
    model,
    optimizer,
    scheduler,
    n_epochs,
    train_dl,
    val_dl,
    train_set,
    val_set,
    name="MV",
):
    """Train ``model`` and checkpoint the epoch with the lowest validation loss.

    Args:
        model: Network to train.
        optimizer: Optimizer handed to ``train_step``.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        n_epochs: Number of epochs.
        train_dl: Training loader.
        val_dl: Validation loader.
        train_set: Training dataset; its length normalises the train totals.
        val_set: Validation dataset; its length normalises the val totals.
        name: Strategy name used in the checkpoint file name.

    Returns:
        ``(train_metrics, val_metrics)``: dictionaries of per-epoch lists,
        with the keys ``"accuracy"`` and ``"loss"`` (plus ``"ECE"`` for
        validation).
    """
    history_train = {"accuracy": [], "loss": []}
    history_val = {"accuracy": [], "loss": [], "ECE": []}
    lowest_val_loss = np.inf
    checkpoint_path = os.path.join(DIR, best_folder, f"{name}_vgg.pt")

    for epoch in tqdm(range(n_epochs), total=n_epochs):
        # Training pass.
        running_loss, running_acc = 0, 0
        for batch in tqdm(train_dl, leave=False, total=len(train_dl)):
            running_loss, running_acc = train_step(
                batch,
                epoch,
                running_loss,
                running_acc,
                n_epochs,
                model,
                optimizer,
                device,
            )
        history_train["accuracy"].append(running_acc * 100 / len(train_set))
        history_train["loss"].append(running_loss / len(train_set))
        scheduler.step()

        # Validation pass.
        running_loss, running_acc, running_ece = 0.0, 0.0, 0.0
        for batch in val_dl:
            running_loss, running_acc, running_ece = val_step(
                batch, running_loss, running_acc, running_ece, model, device
            )
        current_val_loss = running_loss / len(val_set)
        history_val["accuracy"].append(running_acc * 100 / len(val_set))
        history_val["loss"].append(current_val_loss)
        history_val["ECE"].append(running_ece / len(val_set))

        if current_val_loss < lowest_val_loss:
            lowest_val_loss = current_val_loss
            # NOTE: this creates the folder <best_folder>/<name>/, while the
            # checkpoint itself is written next to it as <name>_vgg.pt.
            os.makedirs(os.path.join(DIR, best_folder, name), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
    return history_train, history_val


def get_model_sgd():
    """Create the network, its SGD optimizer and its learning-rate schedule.

    Returns:
        ``(model, optimizer, scheduler)``: a pretrained VGG16-BN whose last
        classifier layer is replaced by an 8-output linear layer and moved to
        ``device``, an SGD optimizer over its parameters, and a
        ``MultiStepLR`` scheduler with milestones 50 and 75.
    """
    # model = resnet18(num_classes=8)
    net = vgg16_bn(pretrained=True)
    net.classifier[6] = nn.Linear(4096, 8)
    net = net.to(device)

    sgd = torch.optim.SGD(
        net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    # The learning rate is multiplied by gamma at each milestone.
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
        sgd, milestones=[50, 75], gamma=0.1
    )
    return net, sgd, lr_schedule


# %%
def run_model(
    seed, train_w_names, val_loader, val_set_, test_loader, test_set_
):
    """Train and test one fresh model per ``(dataset, name)`` pair.

    Args:
        seed: Accepted but not used.
        train_w_names: Iterable of ``(train_set, name)`` pairs.
        val_loader: Validation loader.
        val_set_: Validation dataset.
        test_loader: Test loader.
        test_set_: Test dataset.

    Returns:
        Three lists ``(train, val, test)`` holding, for every pair and in
        order, the metrics returned by ``run_train`` and ``run_test``.
    """
    all_train, all_val, all_test = [], [], []
    for train_set, name in train_w_names:
        model, optimizer, scheduler = get_model_sgd()
        loader = DataLoader(
            train_set,
            batch_size=64,
            shuffle=True,
            pin_memory=torch.cuda.is_available(),
            num_workers=0,
        )
        fit_train, fit_val = run_train(
            model,
            optimizer,
            scheduler,
            n_epochs,
            loader,
            val_loader,
            train_set,
            val_set_,
            name=name,
        )
        # run_test reloads the best checkpoint saved during training.
        fit_test = run_test(test_loader, test_set_, name=name)
        all_train.append(fit_train)
        all_val.append(fit_val)
        all_test.append(fit_test)
    return all_train, all_val, all_test


# %%


def run_all(SEEDS, alpha=0.1):
    """Repeat the whole WAUM experiment once per seed and collect the metrics.

    For every seed the NumPy and PyTorch random generators are seeded, the
    workspace is rebuilt, the crowd labels are aggregated, a model is trained
    and the test metrics are recorded.

    Args:
        SEEDS: Iterable of integer seeds, one repetition per seed.
        alpha: Value forwarded to ``get_labels``.

    Returns:
        A dictionary ``{"WAUM": {"accuracy": [...], "loss": [...],
        "ECE": [...], "train_recovery": [...]}}`` with one entry per seed in
        every list.
    """
    methods = ["WAUM"]
    monitor = {
        key: {"accuracy": [], "loss": [], "ECE": [], "train_recovery": []}
        for key in methods
    }
    for seed in SEEDS:
        # Apply the seed: it was previously only passed around, so the
        # repetitions could not be reproduced.
        np.random.seed(seed)
        torch.manual_seed(seed)

        (
            train,
            val,
            test,
            test_set_,
            val_set_,
            val_loader,
            test_loader,
            votes,
        ) = setup_workspace()

        # The third item lists the fitted strategies (currently only WAUM).
        train_w_names, accuracies, _strategies = get_labels(
            seed, train, alpha, votes
        )
        # Separate name so the metrics do not shadow the ``test`` dataset.
        _, _, test_metrics = run_model(
            seed, train_w_names, val_loader, val_set_, test_loader, test_set_
        )

        # monitor["MV"]["train_recovery"].append(np.mean(accuracies[0]))
        # monitor["MV"]["accuracy"].append(test[0]["accuracy"][0])
        # monitor["MV"]["loss"].append(test[0]["loss"][0])
        # monitor["MV"]["ECE"].append(test[0]["ECE"][0])

        # monitor["soft"]["train_recovery"].append(np.mean(accuracies[1]))
        # monitor["soft"]["accuracy"].append(test[1]["accuracy"][0])
        # monitor["soft"]["loss"].append(test[1]["loss"][0])
        # monitor["soft"]["ECE"].append(test[1]["ECE"][0])

        # monitor["DS"]["train_recovery"].append(np.mean(accuracies[2]))
        # monitor["DS"]["accuracy"].append(test[2]["accuracy"][0])
        # monitor["DS"]["loss"].append(test[2]["loss"][0])
        # monitor["DS"]["ECE"].append(test[2]["ECE"][0])

        monitor["WAUM"]["train_recovery"].append(np.mean(accuracies[0]))
        monitor["WAUM"]["accuracy"].append(test_metrics[0]["accuracy"][0])
        monitor["WAUM"]["loss"].append(test_metrics[0]["loss"][0])
        monitor["WAUM"]["ECE"].append(test_metrics[0]["ECE"][0])

        # monitor["GLAD"]["train_recovery"].append(np.mean(accuracies[4]))
        # monitor["GLAD"]["accuracy"].append(test[4]["accuracy"][0])
        # monitor["GLAD"]["loss"].append(test[4]["loss"][0])
        # monitor["GLAD"]["ECE"].append(test[4]["ECE"][0])
        print(monitor)
    return monitor


# %%
import json

if __name__ == "__main__":
    # Toggle to skip writing the per-alpha result files.
    save = True
    alphas = [0.01, 0.05]
    for alpha in alphas:
        # Two repetitions (seeds 0 and 1) for every alpha.
        all_res = run_all(list(range(2)), alpha)
        pprint.pprint(all_res)
        if not save:
            continue
        # Written relative to the current working directory.
        with open(f"labelme_results_{alpha}_vgg.json", "w") as outfile:
            json.dump(all_res, outfile, indent=2)

# %%
