# %%
from peerannot.models.MV import MV
from peerannot.models.Soft import Soft
from peerannot.models.DS import Dawid_Skene as DS
from peerannot.models.GLAD import GLAD
from peerannot.models.WAUM import WAUM
import numpy as np
import torch
from torchvision.transforms import ToTensor
from sklearn.model_selection import train_test_split
from utils import (
    get_train_val_test,
    set_seed,
    DatasetWithIndex,
    expected_calibration_error,
)
import os
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch
import json
from torchvision.models import resnet18
from tqdm import tqdm
import pprint
from cifar10h_data import CrowdCIFAR
from copy import deepcopy
import sys

# %%
# Run configuration shared by every function below: which crowdsourced-answer
# file to load, where the data lives, the device and training length.
name = "cifar10h_p_0_spam_0_all_workers"
current = os.path.dirname(os.path.abspath(__file__))
path_data = os.path.join(current, "data")
device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 150
Loss = torch.nn.CrossEntropyLoss()

# Crowd answers keyed by task id. JSON object keys are always strings, so they
# are converted to int; only task ids below 9500 are kept, ordered by id.
with open(os.path.join(path_data, f"{name}.json"), "r") as answers_file:
    _raw_answers = json.load(answers_file)
_kept_answers = {
    int(task_id): answers
    for task_id, answers in _raw_answers.items()
    if int(task_id) < 9500
}
all_ans = {task_id: _kept_answers[task_id] for task_id in sorted(_kept_answers)}


def _normalized_tensor_transform():
    """Return a new ToTensor + Normalize pipeline (standard ImageNet mean/std)."""
    return transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )


# The three pipelines are currently identical. train_transform is the one
# meant to carry augmentation (transforms.RandomHorizontalFlip() is disabled).
train_transform = _normalized_tensor_transform()
train_transform_noaug = _normalized_tensor_transform()
test_transform = _normalized_tensor_transform()
path_res = os.path.join(path_data, name)

# %%


def setup_workspace(SEED=42):
    """Seed the RNGs and build the CIFAR-10H splits plus evaluation loaders.

    The training split gets the module-level ``all_ans`` attached as
    ``train.all_ans``. The validation and test splits are wrapped in
    ``CrowdCIFAR`` with their ``c10h_c10_targets`` as targets and
    ``test_transform``, then served by non-shuffled loaders of batch size 64.

    Args:
        SEED: seed forwarded to ``set_seed``.

    Returns:
        (train, val, test, test_set_, val_set_, val_loader, test_loader)
    """
    set_seed(SEED)
    data_dir = os.path.join(current, "data")
    # NOTE(review): the meaning of the third argument (None) is defined by
    # get_train_val_test in utils; confirm there.
    train, val, test = get_train_val_test(data_dir, train_transform, None)
    train.all_ans = all_ans

    def _wrap(split):
        # Evaluation splits are labelled with their own c10h_c10_targets.
        return CrowdCIFAR(
            set=split,
            targets=split.c10h_c10_targets,
            transform=test_transform,
        )

    def _eval_loader(dataset):
        return DataLoader(
            dataset,
            batch_size=64,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
        )

    test_set_ = _wrap(test)
    val_set_ = _wrap(val)
    val_loader = _eval_loader(val_set_)
    test_loader = _eval_loader(test_set_)
    return train, val, test, test_set_, val_set_, val_loader, test_loader


def get_model_sgd():
    """Build a CIFAR-adapted ResNet-18 on ``device`` with SGD and a step LR schedule.

    The optimizer is SGD (lr 0.1, momentum 0.9, weight decay 5e-4). The
    schedule multiplies the learning rate by 0.1 at epochs 50 and 100.

    Returns:
        (model, optimizer, scheduler)
    """
    net = resnet18(num_classes=10)
    # Replace the ImageNet stem with a 3x3 stride-1 conv and drop the max-pool
    # so the small CIFAR images are not downsampled aggressively.
    # NOTE(review): padding=3 with a 3x3 kernel grows a 32x32 input to 36x36;
    # padding=1 is the usual choice. run_test builds the same stem, so any
    # change must be made in both places. Confirm it is intended.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=3, bias=False)
    net.maxpool = nn.Identity()
    net = net.to(device)

    sgd = torch.optim.SGD(
        net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
        sgd, milestones=[50, 100], gamma=0.1
    )
    return net, sgd, lr_schedule


def get_correct(pred, target):
    """Return (as a float) how many predictions equal the target labels.

    Args:
        pred: predicted class indices, e.g. the ``(batch, 1)`` indices from
            ``topk(1, dim=1)``; all singleton dims are squeezed away.
        target: either class indices (1-D) or per-class scores / probabilities
            (2-D), in which case the arg-max along dim 1 is the label.
    """
    labels = target.argmax(dim=1) if target.ndim > 1 else target
    matches = pred.squeeze().eq(labels)
    return matches.float().sum().item()


def train_step(
    batch, epoch, epoch_loss, epoch_acc, n_epochs, model, optimizer, device
):
    """Run one SGD step on ``batch`` and add its loss/accuracy to the running sums.

    The batch is ``(input, target, _, sample_ids)``. Only the first two items
    are used: the loss is ``Loss(model(input), target)``. The target may be
    class indices or per-class probabilities; ``get_correct`` handles both.
    ``epoch`` and ``n_epochs`` are accepted for interface compatibility and
    are not used.

    Returns:
        (epoch_loss, epoch_acc): the loss sum weighted by batch size, and the
        running count of correct top-1 predictions.
    """
    model.train()
    with torch.enable_grad():
        optimizer.zero_grad()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        batch_loss = Loss(logits, labels)
        batch_loss.backward()
        optimizer.step()

        top1 = logits.data.topk(1, dim=1)[1]
        n_in_batch = labels.size(0)
        return (
            epoch_loss + n_in_batch * batch_loss.item(),
            epoch_acc + get_correct(top1, labels),
        )


def val_step(batch, epoch_loss, epoch_acc, epoch_ece, model, device):
    """Evaluate ``model`` on one batch and add its metrics to the running sums.

    The batch is ``(input, target, true_target, sample_ids)``. Loss, top-1
    accuracy and ECE are all measured against the third item. Loss and ECE
    are weighted by batch size so the caller can divide by the dataset length
    to get per-sample means.

    Returns:
        (epoch_loss, epoch_acc, epoch_ece) updated with this batch.
    """
    model.eval()
    with torch.no_grad():
        images, _, true_labels, _ = batch
        images = images.to(device)
        true_labels = true_labels.to(device)
        logits = model(images)
        n_in_batch = true_labels.size(0)

        batch_loss = Loss(logits, true_labels).item()
        top1 = logits.data.topk(1, dim=1)[1]
        probs = logits.softmax(dim=1).cpu().numpy()
        batch_ece = expected_calibration_error(
            probs, true_labels.cpu().numpy()
        )
        return (
            epoch_loss + n_in_batch * batch_loss,
            epoch_acc + get_correct(top1, true_labels),
            epoch_ece + batch_ece * n_in_batch,
        )


def run_train(
    model,
    optimizer,
    scheduler,
    n_epochs,
    train_dl,
    val_dl,
    train_set,
    val_set,
    name="MV",
):
    """Train ``model`` and checkpoint the epoch with the lowest validation loss.

    Each epoch does one pass of ``train_step`` over ``train_dl``, steps the LR
    scheduler once, then evaluates ``val_dl`` with ``val_step``. Whenever the
    mean validation loss improves, ``model.state_dict()`` is saved to
    ``<current>/best_interact/<name>.pt``, the path ``run_test`` loads from.

    Args:
        model, optimizer, scheduler: as returned by ``get_model_sgd``.
        n_epochs: number of training epochs.
        train_dl, val_dl: training / validation data loaders.
        train_set, val_set: datasets behind the loaders, used only to turn the
            summed metrics into per-sample means.
        name: checkpoint file stem (the aggregation method name).

    Returns:
        (train_metrics, val_metrics): dicts of per-epoch lists. The training
        dict has "accuracy" (percent) and "loss"; the validation dict also
        has "ECE".
    """
    train_metrics = {"accuracy": [], "loss": []}
    val_metrics = {"accuracy": [], "loss": [], "ECE": []}
    best_error = np.inf

    ckpt_path = os.path.join(current, "best_interact", f"{name}.pt")
    # Create the directory that actually contains the checkpoint file, once,
    # rather than an unused "best_interact/<name>/" directory.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    for epoch in tqdm(range(n_epochs), total=n_epochs):
        epoch_loss = 0
        epoch_acc = 0
        for batch in tqdm(train_dl, leave=False, total=len(train_dl)):
            epoch_loss, epoch_acc = train_step(
                batch,
                epoch,
                epoch_loss,
                epoch_acc,
                n_epochs,
                model,
                optimizer,
                device,
            )
        train_metrics["accuracy"].append(epoch_acc * 100 / len(train_set))
        train_metrics["loss"].append(epoch_loss / len(train_set))
        scheduler.step()  # per-epoch LR schedule

        epoch_loss, epoch_acc, epoch_ece = 0.0, 0.0, 0.0
        for batch in val_dl:
            epoch_loss, epoch_acc, epoch_ece = val_step(
                batch, epoch_loss, epoch_acc, epoch_ece, model, device
            )
        val_metrics["accuracy"].append(epoch_acc * 100 / len(val_set))
        val_metrics["loss"].append(epoch_loss / len(val_set))
        val_metrics["ECE"].append(epoch_ece / len(val_set))

        # A NaN loss never compares below best_error, so a diverged epoch is
        # never checkpointed.
        if val_metrics["loss"][-1] < best_error:
            best_error = val_metrics["loss"][-1]
            torch.save(model.state_dict(), ckpt_path)
    return train_metrics, val_metrics


def run_test(test_loader, test_set, name="MV"):
    """Reload the best checkpoint for ``name`` and evaluate it on the test set.

    The network is rebuilt with the same CIFAR-adapted ResNet-18 stem used in
    training. Weights come from ``<current>/best_interact/<name>.pt``. Loss,
    top-1 accuracy and ECE are measured against the third item of each batch
    ``(input, target, true_target, sample_ids)``.

    Returns:
        dict with keys "accuracy" (percent), "loss" and "ECE". Each maps to
        a one-element list holding the per-sample mean over ``test_set``.
    """
    net = resnet18(num_classes=10)
    # Same stem as get_model_sgd (including padding=3) so the checkpoint
    # behaves exactly as it did during training.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=3, bias=False)
    net.maxpool = nn.Identity()
    ckpt_path = os.path.join(current, "best_interact", f"{name}.pt")
    net.load_state_dict(torch.load(ckpt_path))
    net = net.to(device)
    net.eval()

    total_loss, n_correct = 0, 0
    total_ece = 0
    with torch.no_grad():
        for images, _, true_labels, _ in tqdm(
            test_loader, total=len(test_loader)
        ):
            images = images.to(device)
            true_labels = true_labels.to(device)
            logits = net(images)
            n_in_batch = true_labels.size(0)
            total_loss += n_in_batch * Loss(logits, true_labels).item()
            n_correct += get_correct(
                logits.data.topk(1, dim=1)[1], true_labels
            )
            total_ece += (
                expected_calibration_error(
                    logits.softmax(dim=1).cpu().numpy(),
                    true_labels.cpu().numpy(),
                )
                * n_in_batch
            )

    n_samples = len(test_set)
    return {
        "accuracy": [n_correct * 100 / n_samples],
        "loss": [total_loss / n_samples],
        "ECE": [total_ece / n_samples],
    }


# %%


def get_labels(seed, train):
    """Aggregate the crowd answers with WAUM and build the matching training set.

    Only the WAUM branch is active. The MV / Soft / DS / GLAD branches are
    kept commented out so they can be re-enabled.

    Steps:
      1. Wrap ``train`` in a ``CrowdCIFAR`` with placeholder all-zero targets
         and the non-augmenting transform, for WAUM to consume.
      2. Build a fresh model/optimizer with ``get_model_sgd`` and pass them,
         together with ``all_ans``, to ``WAUM``. The returned scheduler is
         not used here.
      3. ``waum.run(alpha=0.01)``, then ``waum.get_probas()``. The printed
         recovery accuracy compares the argmax of those probabilities with
         ``c10h_c10_targets`` after removing the ``waum.too_hard`` rows.
      4. Deep-copy ``train``, drop the ``waum.too_hard`` rows from
         ``c10h_data``, and wrap it with the WAUM probabilities as targets
         and ``train_transform``.

    Args:
        seed: not used inside this function.
        train: training split from ``setup_workspace``. It must expose
            ``c10h_data`` and ``c10h_c10_targets``.

    Returns:
        A 3-tuple of parallel per-method lists:
        ``([(train_set_waum, "WAUM")], [acc_waum], [waum])``.
    """
    # --- Alternative aggregators (currently disabled) -------------------
    # mv = MV(all_ans)
    # y_mv = mv.get_answers()
    # train_set_mv = CrowdCIFAR(
    #     set=train, targets=y_mv, transform=train_transform
    # )
    # acc_mv = np.mean(y_mv == train_set_mv.set.c10h_c10_targets)
    # print("MV", acc_mv)

    # soft = Soft(answers=all_ans, n_classes=10)
    # y_soft = soft.get_probas()
    # train_set_soft = CrowdCIFAR(
    #     set=train, targets=y_soft, transform=train_transform
    # )
    # acc_soft = np.mean(
    #     np.argmax(y_soft, axis=1) == train_set_soft.set.c10h_c10_targets
    # )
    # print(
    #     "Soft",
    #     acc_soft,
    # )
    # ds = DS(answers=all_ans, n_classes=10)
    # ds.run_em(epsilon=1e-6, maxiter=50)
    # y_ds = ds.get_probas()
    # train_set_ds = CrowdCIFAR(
    #     set=train, targets=y_ds, transform=train_transform
    # )
    # acc_ds = np.mean(
    #     np.argmax(y_ds, axis=1) == train_set_ds.set.c10h_c10_targets
    # )
    # print(
    #     "DS",
    #     acc_ds,
    # )

    # --- WAUM -------------------------------------------------------------
    # Placeholder targets: the labels here are all zeros; WAUM receives the
    # crowd answers separately through all_ans.
    train_set_waum = CrowdCIFAR(
        set=train, targets=[0] * len(train), transform=train_transform_noaug
    )
    model, optimizer, scheduler = get_model_sgd()
    # NOTE(review): the positional 10 and 50 presumably mean n_classes and
    # the number of training epochs; confirm against peerannot's WAUM signature.
    waum = WAUM(
        train_set_waum,
        all_ans,
        10,
        model,
        torch.nn.CrossEntropyLoss(),
        optimizer,
        50,
    )
    # run() must come first: get_probas() and too_hard are read after it.
    waum.run(alpha=0.01)
    y_waum = waum.get_probas()
    # y_waum has no rows for the too_hard tasks, so drop those rows from the
    # ground truth before comparing.
    acc_waum = np.mean(
        np.argmax(y_waum, axis=1)
        == np.delete(
            train_set_waum.set.c10h_c10_targets, waum.too_hard, axis=0
        )
    )
    print(
        "WAUM",
        acc_waum,
        "cut",
        len(waum.too_hard),
    )
    # Copy so the caller's `train` keeps every image.
    # NOTE(review): only c10h_data is pruned here; c10h_c10_targets and
    # all_ans on the copy still cover every task. If CrowdCIFAR indexes them
    # per item, they would be misaligned with the pruned data. Verify.
    train_mdf = deepcopy(train)
    train_mdf.c10h_data = np.delete(train_mdf.c10h_data, waum.too_hard, axis=0)
    train_set_waum = CrowdCIFAR(
        set=train_mdf, targets=y_waum, transform=train_transform
    )

    # glad = GLAD(
    #     n_classes=10,
    #     answers=all_ans,
    # )
    # glad.run_em(epsilon=1e-6, maxiter=50)
    # y_glad = glad.get_probas()
    # train_set_glad = CrowdCIFAR(
    #     set=train, targets=y_glad, transform=train_transform
    # )
    # acc_glad = np.mean(
    #     np.argmax(y_glad, axis=1) == train_set_glad.set.c10h_c10_targets
    # )
    # print(
    #     "GLAD",
    #     acc_glad,
    # )

    # Parallel lists: (training set, method name), recovery accuracy, and
    # fitted aggregator, one entry per enabled method.
    return (
        [
            # (train_set_mv, "MV"),
            # (train_set_soft, "Soft"),
            # (train_set_ds, "DS"),
            (train_set_waum, "WAUM"),
            # (train_set_glad, "GLAD"),
        ],
        [acc_waum],
        [waum]
        # [acc_mv, acc_soft, acc_ds, acc_waum, acc_glad],
        # [mv, soft, ds, waum, glad],
    )


def run_model(
    seed, train_w_names, val_loader, val_set_, test_loader, test_set_
):
    """Train and test one fresh network per aggregated training set.

    Args:
        seed: not used inside this function.
        train_w_names: iterable of ``(train_set, method_name)`` pairs. The
            name is also used as the checkpoint stem.
        val_loader, val_set_: validation loader and its dataset.
        test_loader, test_set_: test loader and its dataset.

    Returns:
        (train, val, test): lists of metric dicts in the same order as
        ``train_w_names``.
    """
    train_hist, val_hist, test_hist = [], [], []
    for dataset, method_name in train_w_names:
        net, opt, lr_sched = get_model_sgd()
        loader = DataLoader(
            dataset,
            batch_size=64,
            shuffle=True,
            pin_memory=torch.cuda.is_available(),
            num_workers=0,
        )
        fit_metrics, val_curve = run_train(
            net,
            opt,
            lr_sched,
            n_epochs,
            loader,
            val_loader,
            dataset,
            val_set_,
            name=method_name,
        )
        test_metrics = run_test(test_loader, test_set_, name=method_name)
        train_hist.append(fit_metrics)
        val_hist.append(val_curve)
        test_hist.append(test_metrics)
    return train_hist, val_hist, test_hist


# %%


def run_all(SEEDS):
    """Run the whole pipeline once per seed and collect test metrics per method.

    For each seed: build the splits (``setup_workspace``), aggregate the crowd
    labels (``get_labels``), train and evaluate one network per aggregated
    training set (``run_model``), then record for every returned method its
    label-recovery accuracy on the training set and its test accuracy, loss
    and ECE.

    Args:
        SEEDS: iterable of integer seeds passed to ``setup_workspace``.

    Returns:
        dict: ``monitor[method][metric]`` is a list with one value per seed,
        for the metrics "accuracy", "loss", "ECE" and "train_recovery".

    Raises:
        KeyError: if ``get_labels`` returns a method not listed in ``methods``.
    """
    methods = ["WAUM"]  # all available: ["MV", "Soft", "DS", "WAUM", "GLAD"]
    monitor = {
        method: {"accuracy": [], "loss": [], "ECE": [], "train_recovery": []}
        for method in methods
    }
    for seed in SEEDS:
        (
            train,
            val,
            test,
            test_set_,
            val_set_,
            val_loader,
            test_loader,
        ) = setup_workspace(seed)

        # The third item holds the fitted aggregators. They are not needed
        # here, and unpacking loosely works for any number of enabled methods.
        train_w_names, accuracies, _aggregators = get_labels(seed, train)
        _, _, test_metrics = run_model(
            seed, train_w_names, val_loader, val_set_, test_loader, test_set_
        )

        # get_labels returns parallel per-method lists and run_model keeps
        # their order, so zip pairs each method with its own results instead
        # of relying on hand-maintained indices.
        for (_, method), recovery, metrics in zip(
            train_w_names, accuracies, test_metrics
        ):
            record = monitor[method]
            record["train_recovery"].append(np.mean(recovery))
            record["accuracy"].append(metrics["accuracy"][0])
            record["loss"].append(metrics["loss"][0])
            record["ECE"].append(metrics["ECE"][0])
        print(monitor)
    return monitor


# %%
if __name__ == "__main__":
    # Five repetitions with seeds 0..4.
    all_res = run_all(list(range(5)))
    pprint.pprint(all_res)

    # Writing the results to JSON is opt-in: set `save` to True.
    save = False
    if save:
        # Alternative output file name: "json_expe_0_spam_60_glad_aug.json"
        with open("spam_0_peerannot_nw_all.json", "w") as outfile:
            json.dump(all_res, outfile, indent=2)

# %%
