# %%
from peerannot.models.MV import MV
from peerannot.models.Soft import Soft
from peerannot.models.DS import Dawid_Skene as DS
from peerannot.models.GLAD import GLAD
from peerannot.models.WAUM import WAUM
import numpy as np
import torch
from torchvision.transforms import ToTensor
from sklearn.model_selection import train_test_split
from utils import (
    get_train_val_test,
    set_seed,
    DatasetWithIndex,
    expected_calibration_error,
)
import os
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch
import json
from torchvision.models import resnet18
from tqdm import tqdm
import pprint
from cifar10h_data import CrowdCIFAR
from copy import deepcopy

# %%
# Experiment configuration. The functions below read these module-level names
# as globals; the `__main__` sweep re-binds all_ans, best_folder, path_res and
# name for each worker budget.
name = "cifar10h_p_0_spam_0_nw_10"
# Directory containing this script; data and checkpoint paths are built from it.
current = os.path.dirname(os.path.abspath(__file__))
# NOTE(review): path_data is not read anywhere in the visible code
# (setup_workspace rebuilds the same path itself).
path_data = os.path.join(current, "data")
device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 150
# Shared criterion used by train_step, val_step and run_test.
Loss = torch.nn.CrossEntropyLoss()
# Crowdsourced answers loaded from data/<name>.json. JSON object keys are
# strings, so they are converted to integer task indices below.
with open(os.path.join(current, "data", f"{name}.json"), "r") as f:
    all_ans = json.load(f)
# Keep only tasks with index < 9500.
# NOTE(review): presumably 9500 is the size of the training split returned by
# get_train_val_test — confirm.
all_ans = {int(key): val for key, val in all_ans.items() if int(key) < 9500}
# Sort tasks by integer index (presumably so the aggregators see them in
# index order).
all_ans = dict(sorted(all_ans.items()))
# Side length images are resized to; also part of the checkpoint folder name.
size = 32
best_folder = f"best_{name}_{size}"
# Transform used when training the classifiers. Despite its name it applies no
# augmentation (RandomHorizontalFlip is commented out). ToTensor runs before
# Resize here, so resizing happens on a tensor rather than a PIL image, unlike
# in the two transforms below. The mean/std values match the standard
# ImageNet channel statistics.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((size, size)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
        # transforms.RandomHorizontalFlip(),
    ]
)
# Un-augmented preprocessing; used for the dataset WAUM trains on.
train_transform_noaug = transforms.Compose(
    [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
# Evaluation preprocessing for the validation and test sets; identical to
# train_transform_noaug.
test_transform = transforms.Compose(
    [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
# NOTE(review): path_res is not read by any function in the visible code.
path_res = os.path.join(current, "data", f"{name}")

# %%


def setup_workspace(SEED=42):
    """Seed the run and build the data splits and evaluation loaders.

    Args:
        SEED: value passed to ``set_seed`` before anything else happens.

    Returns:
        A 7-tuple ``(train, val, test, test_set_, val_set_, val_loader,
        test_loader)``. ``train`` gets the module-level crowd answers attached
        as ``train.all_ans``. ``val_set_`` and ``test_set_`` wrap their split
        with its ``c10h_c10_targets`` and ``test_transform``. Both loaders use
        batches of 64 and are not shuffled.
    """
    set_seed(SEED)

    data_dir = os.path.join(current, "data")
    train, val, test = get_train_val_test(data_dir, train_transform, None)
    train.all_ans = all_ans

    def _wrap(split):
        # Pair a split with its own reference labels and eval preprocessing.
        return CrowdCIFAR(
            set=split,
            targets=split.c10h_c10_targets,
            transform=test_transform,
        )

    def _eval_loader(dataset):
        # Fixed-order loader: evaluation must not depend on shuffling.
        return DataLoader(
            dataset,
            batch_size=64,
            shuffle=False,
            pin_memory=(torch.cuda.is_available()),
        )

    test_set_ = _wrap(test)
    val_set_ = _wrap(val)
    val_loader = _eval_loader(val_set_)
    test_loader = _eval_loader(test_set_)
    return train, val, test, test_set_, val_set_, val_loader, test_loader


def get_model_sgd():
    """Build a CIFAR-sized ResNet-18 together with its optimiser and schedule.

    The torchvision ResNet-18 stem is replaced with a 3x3, stride-1 first
    convolution and the initial max-pool is removed, so small inputs are not
    downsampled early. The network is moved to the module-level ``device``.

    Returns:
        ``(model, optimizer, scheduler)``. The optimiser is SGD with lr 0.1,
        momentum 0.9 and weight decay 5e-4. The scheduler multiplies the
        learning rate by 0.1 at scheduler steps 50 and 100 (``run_train``
        steps it once per epoch).
    """
    net = resnet18(num_classes=10)
    # NOTE(review): padding=3 with a 3x3 kernel enlarges a 32x32 input to
    # 36x36; a 3x3/stride-1 conv only needs padding=1 to preserve size.
    # Kept unchanged for reproducibility. If changed, run_test must build the
    # identical architecture.
    net.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=3, bias=False
    )
    net.maxpool = nn.Identity()
    net = net.to(device)

    sgd = torch.optim.SGD(
        net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    schedule = torch.optim.lr_scheduler.MultiStepLR(
        sgd, milestones=[50, 100], gamma=0.1
    )
    return net, sgd, schedule


def get_correct(pred, target):
    """Count how many top-1 predictions match the target labels.

    Args:
        pred: predicted class indices; squeezed before comparison, so a
            ``(batch, 1)`` tensor from ``topk(1)`` is accepted.
        target: either class indices (0-D/1-D) or a 2-D tensor of per-class
            scores, which is reduced with ``argmax`` over dim 1.

    Returns:
        The number of matches as a Python float.
    """
    labels = target if target.dim() <= 1 else target.argmax(dim=1)
    hits = pred.squeeze() == labels
    return hits.float().sum().item()


def train_step(
    batch, epoch, epoch_loss, epoch_acc, n_epochs, model, optimizer, device
):
    """Run one optimisation step on ``batch`` and update the running sums.

    Args:
        batch: 4-tuple ``(inputs, targets, _, sample_ids)``. Only the first
            two items are used. ``targets`` may be class indices or per-class
            probability rows (``get_correct`` handles both).
        epoch, n_epochs: unused; kept for call-site compatibility.
        epoch_loss: running sum of ``batch_size * batch_loss``.
        epoch_acc: running count of correct top-1 predictions.
        model, optimizer, device: network, its optimiser, and target device.

    Returns:
        The updated ``(epoch_loss, epoch_acc)``.
    """
    model.train()
    with torch.enable_grad():
        optimizer.zero_grad()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        batch_loss = Loss(logits, labels)
        batch_loss.backward()
        optimizer.step()

        top1 = logits.detach().topk(1, dim=1)[1]
        n = labels.size(0)
        new_loss = epoch_loss + n * batch_loss.item()
        new_acc = epoch_acc + get_correct(top1, labels)
        return new_loss, new_acc


def val_step(batch, epoch_loss, epoch_acc, epoch_ece, model, device):
    """Evaluate one batch against its reference labels and update the sums.

    Args:
        batch: 4-tuple ``(inputs, training_target, reference_target, ids)``.
            Only the inputs and the reference (third) target are used.
        epoch_loss: running sum of ``batch_size * batch_loss``.
        epoch_acc: running count of correct top-1 predictions.
        epoch_ece: running sum of ``batch_size * per-batch ECE``. Dividing by
            the dataset size gives a batch-size-weighted mean of per-batch
            ECE, not the ECE of the whole set.
        model, device: network to evaluate and its device.

    Returns:
        The updated ``(epoch_loss, epoch_acc, epoch_ece)``.
    """
    model.eval()
    with torch.no_grad():
        images, _, ref_labels, _ = batch
        images = images.to(device)
        ref_labels = ref_labels.to(device)
        logits = model(images)
        batch_loss = Loss(logits, ref_labels)
        top1 = logits.detach().topk(1, dim=1)[1]
        n = ref_labels.size(0)

        new_loss = epoch_loss + n * batch_loss.item()
        new_acc = epoch_acc + get_correct(top1, ref_labels)
        probs = logits.softmax(dim=1).detach().cpu().numpy()
        batch_ece = expected_calibration_error(
            probs, ref_labels.cpu().numpy()
        )
        new_ece = epoch_ece + batch_ece * n
        return new_loss, new_acc, new_ece


def run_train(
    model,
    optimizer,
    scheduler,
    n_epochs,
    train_dl,
    val_dl,
    train_set,
    val_set,
    name="MV",
):
    """Train ``model`` and keep the weights with the lowest validation loss.

    Each epoch does the following:
    - runs ``train_step`` over ``train_dl``;
    - steps ``scheduler`` once;
    - evaluates on ``val_dl`` with ``val_step``.

    Whenever the epoch's mean validation loss improves, the state dict is
    written to ``<current>/<best_folder>/<name>.pt``, the file ``run_test``
    reloads.

    Args:
        model, optimizer, scheduler: as returned by ``get_model_sgd``.
        n_epochs: number of training epochs.
        train_dl, val_dl: training and validation ``DataLoader``s.
        train_set, val_set: datasets behind the loaders. Only their lengths
            are used, to normalise the summed metrics.
        name: checkpoint file stem (one per label-aggregation method).

    Returns:
        ``(train_metrics, val_metrics)``: dicts of per-epoch lists.
        ``train_metrics`` has ``"accuracy"`` (percent) and ``"loss"``;
        ``val_metrics`` also has ``"ECE"``.
    """
    train_metrics = {"accuracy": [], "loss": []}
    val_metrics = {"accuracy": [], "loss": [], "ECE": []}
    ckpt_dir = os.path.join(current, best_folder)
    ckpt_path = os.path.join(ckpt_dir, f"{name}.pt")
    best_error = np.inf
    saved = False
    for epoch in tqdm(range(n_epochs), total=n_epochs):
        epoch_loss, epoch_acc = 0.0, 0.0
        for batch in tqdm(train_dl, leave=False, total=len(train_dl)):
            epoch_loss, epoch_acc = train_step(
                batch,
                epoch,
                epoch_loss,
                epoch_acc,
                n_epochs,
                model,
                optimizer,
                device,
            )
        train_metrics["accuracy"].append(epoch_acc * 100 / len(train_set))
        train_metrics["loss"].append(epoch_loss / len(train_set))
        scheduler.step()

        epoch_loss, epoch_acc, epoch_ece = 0.0, 0.0, 0.0
        for batch in val_dl:
            epoch_loss, epoch_acc, epoch_ece = val_step(
                batch, epoch_loss, epoch_acc, epoch_ece, model, device
            )
        val_metrics["accuracy"].append(epoch_acc * 100 / len(val_set))
        val_metrics["loss"].append(epoch_loss / len(val_set))
        val_metrics["ECE"].append(epoch_ece / len(val_set))

        if val_metrics["loss"][-1] < best_error:
            best_error = val_metrics["loss"][-1]
            # Create the directory the checkpoint is actually written to.
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(), ckpt_path)
            saved = True

    if not saved:
        # The validation loss never beat +inf (e.g. it was NaN every epoch).
        # Save the final weights so run_test does not silently evaluate a
        # stale checkpoint left under the same name by an earlier seed.
        print(f"WARNING: no improving epoch for {name}; saving final weights")
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(model.state_dict(), ckpt_path)
    return train_metrics, val_metrics


def run_test(test_loader, test_set, name="MV"):
    """Evaluate the checkpoint saved by ``run_train`` on the test set.

    The network is rebuilt with ``get_model_sgd`` so the evaluated
    architecture can never drift from the trained one. The weights in
    ``<current>/<best_folder>/<name>.pt`` are then loaded onto ``device``.
    Each batch goes through ``val_step``, the same no-grad, eval-mode
    evaluation used for validation.

    Args:
        test_loader: ``DataLoader`` over the test set.
        test_set: the dataset behind ``test_loader``. Only its length is
            used, to normalise the summed metrics.
        name: checkpoint file stem to load.

    Returns:
        Dict with single-element lists under ``"accuracy"`` (percent),
        ``"loss"`` and ``"ECE"``. ECE is the batch-size-weighted mean of
        per-batch ECE.
    """
    model, _, _ = get_model_sgd()
    ckpt_path = os.path.join(current, best_folder, f"{name}.pt")
    # map_location lets a checkpoint written on a GPU be reloaded on a
    # CPU-only host (and vice versa).
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    epoch_loss, epoch_acc, epoch_ece = 0.0, 0.0, 0.0
    for batch in tqdm(test_loader, total=len(test_loader)):
        epoch_loss, epoch_acc, epoch_ece = val_step(
            batch, epoch_loss, epoch_acc, epoch_ece, model, device
        )
    return {
        "accuracy": [epoch_acc * 100 / len(test_set)],
        "loss": [epoch_loss / len(test_set)],
        "ECE": [epoch_ece / len(test_set)],
    }


# %%


def get_labels(seed, train, alpha):
    """Aggregate the crowd answers with five strategies and wrap each result.

    The strategies run on the module-level ``all_ans`` in this order: MV,
    Soft, Dawid-Skene (DS), WAUM, GLAD. For each one, the "recovery"
    accuracy is printed and returned. It is the fraction of training tasks
    whose aggregated label (arg-max for probability outputs) equals
    ``c10h_c10_targets``.

    Args:
        seed: unused; kept for call-site compatibility.
        train: training split from ``get_train_val_test``.
        alpha: forwarded to ``WAUM.run``. Tasks WAUM then lists in
            ``too_hard`` are removed from its training set and its accuracy.

    Returns:
        ``(datasets, accuracies, aggregators)``:
        - ``datasets``: a list of ``(CrowdCIFAR, method_name)`` pairs;
        - ``accuracies``: the matching recovery accuracies;
        - ``aggregators``: the fitted MV/Soft/DS/WAUM/GLAD objects.
    """

    def _hit_rate(labels, truth):
        # Fraction of tasks whose aggregated label equals the reference one.
        return np.mean(labels == truth)

    # Majority vote: hard labels.
    mv = MV(all_ans)
    y_mv = mv.get_answers()
    train_set_mv = CrowdCIFAR(
        set=train, targets=y_mv, transform=train_transform
    )
    acc_mv = _hit_rate(y_mv, train_set_mv.set.c10h_c10_targets)
    print("MV", acc_mv)

    # Soft: per-task class probabilities.
    soft = Soft(answers=all_ans, n_classes=10)
    y_soft = soft.get_probas()
    train_set_soft = CrowdCIFAR(
        set=train, targets=y_soft, transform=train_transform
    )
    acc_soft = _hit_rate(
        np.argmax(y_soft, axis=1), train_set_soft.set.c10h_c10_targets
    )
    print("Soft", acc_soft)

    # Dawid-Skene, fitted with EM.
    ds = DS(answers=all_ans, n_classes=10)
    ds.run_em(epsilon=1e-6, maxiter=50)
    y_ds = ds.get_probas()
    train_set_ds = CrowdCIFAR(
        set=train, targets=y_ds, transform=train_transform
    )
    acc_ds = _hit_rate(
        np.argmax(y_ds, axis=1), train_set_ds.set.c10h_c10_targets
    )
    print("DS", acc_ds)

    # WAUM trains its own network on an un-augmented view (dummy targets),
    # then flags tasks in ``too_hard`` that are dropped afterwards.
    waum_source = CrowdCIFAR(
        set=train, targets=[0] * len(train), transform=train_transform_noaug
    )
    waum_net, waum_opt, _ = get_model_sgd()
    waum = WAUM(
        waum_source,
        all_ans,
        10,
        waum_net,
        torch.nn.CrossEntropyLoss(),
        waum_opt,
        50,
    )
    waum.run(alpha=alpha)
    y_waum = waum.get_probas()
    kept_truth = np.delete(
        waum_source.set.c10h_c10_targets, waum.too_hard, axis=0
    )
    acc_waum = _hit_rate(np.argmax(y_waum, axis=1), kept_truth)
    print("WAUM", acc_waum, "cut", len(waum.too_hard))
    train_mdf = deepcopy(train)
    # NOTE(review): only c10h_data is pruned here. c10h_c10_targets keeps
    # every task, so the two are no longer index-aligned on train_mdf.
    # Confirm CrowdCIFAR does not read c10h_c10_targets for this split.
    train_mdf.c10h_data = np.delete(train_mdf.c10h_data, waum.too_hard, axis=0)
    train_set_waum = CrowdCIFAR(
        set=train_mdf, targets=y_waum, transform=train_transform
    )

    # GLAD, fitted with EM.
    glad = GLAD(n_classes=10, answers=all_ans)
    glad.run_em(epsilon=1e-6, maxiter=50)
    y_glad = glad.get_probas()
    train_set_glad = CrowdCIFAR(
        set=train, targets=y_glad, transform=train_transform
    )
    acc_glad = _hit_rate(
        np.argmax(y_glad, axis=1), train_set_glad.set.c10h_c10_targets
    )
    print("GLAD", acc_glad)

    datasets = [
        (train_set_mv, "MV"),
        (train_set_soft, "Soft"),
        (train_set_ds, "DS"),
        (train_set_waum, "WAUM"),
        (train_set_glad, "GLAD"),
    ]
    accuracies = [acc_mv, acc_soft, acc_ds, acc_waum, acc_glad]
    aggregators = [mv, soft, ds, waum, glad]
    return datasets, accuracies, aggregators


def run_model(
    seed, train_w_names, val_loader, val_set_, test_loader, test_set_
):
    """Train and test a fresh network for each aggregated training set.

    Args:
        seed: unused; kept for call-site compatibility.
        train_w_names: iterable of ``(dataset, method_name)`` pairs. The name
            is used as the checkpoint stem.
        val_loader, val_set_: validation loader and its dataset.
        test_loader, test_set_: test loader and its dataset.

    Returns:
        ``(train, val, test)``: lists of metric dicts, one entry per pair in
        ``train_w_names``, in the same order.
    """
    history = {"train": [], "val": [], "test": []}
    for dataset, label in train_w_names:
        net, sgd, schedule = get_model_sgd()
        loader = DataLoader(
            dataset,
            batch_size=64,
            shuffle=True,
            pin_memory=(torch.cuda.is_available()),
            num_workers=0,
        )
        fit_train, fit_val = run_train(
            net,
            sgd,
            schedule,
            n_epochs,
            loader,
            val_loader,
            dataset,
            val_set_,
            name=label,
        )
        history["train"].append(fit_train)
        history["val"].append(fit_val)
        history["test"].append(run_test(test_loader, test_set_, name=label))
    return history["train"], history["val"], history["test"]


# %%


def run_all(SEEDS, alpha=0.01):
    """Repeat the full pipeline for every seed and collect per-method results.

    Each seed does the following:
    - rebuilds the workspace;
    - aggregates the labels;
    - trains and tests one network per method.

    The monitor is printed after every seed.

    Args:
        SEEDS: iterable of seeds passed to ``setup_workspace``.
        alpha: forwarded to ``get_labels`` (and from there to WAUM).

    Returns:
        Dict keyed by ``"MV"``, ``"soft"``, ``"DS"``, ``"WAUM"`` and
        ``"GLAD"``. Each value maps ``"accuracy"``, ``"loss"``, ``"ECE"``
        (test metrics) and ``"train_recovery"`` to one entry per seed.
    """
    methods = ["MV", "soft", "DS", "WAUM", "GLAD"]
    monitor = {
        method: {"accuracy": [], "loss": [], "ECE": [], "train_recovery": []}
        for method in methods
    }
    for seed in SEEDS:
        (
            train_split,
            _val_split,
            _test_split,
            test_set_,
            val_set_,
            val_loader,
            test_loader,
        ) = setup_workspace(seed)
        train_w_names, accuracies, _aggregators = get_labels(
            seed, train_split, alpha
        )
        _, _, test_metrics = run_model(
            seed, train_w_names, val_loader, val_set_, test_loader, test_set_
        )
        # get_labels / run_model return results in the same order as methods.
        for method, recovery, metrics in zip(methods, accuracies, test_metrics):
            record = monitor[method]
            record["train_recovery"].append(np.mean(recovery))
            for key in ("accuracy", "loss", "ECE"):
                record[key].append(metrics[key][0])
        print(monitor)
    return monitor


# %%
if __name__ == "__main__":
    # Sweep over worker budgets. Because this runs at module scope, each
    # iteration re-binds the globals (name, all_ans, best_folder, path_res)
    # that the functions above read, then reruns the pipeline for seeds 0-4.
    from script_chose_workers import keep_n_workers, save_res
    save = True
    for n in [2, 5, 10, 25, 50]:
        name = f"cifar10h_p_0_spam_0_nw_{n}"

        # NOTE(review): `n` is re-bound to the value returned by
        # keep_n_workers. `name` (and so the JSON read below) uses the
        # requested n, while the output filename uses the returned one.
        # Presumably save_res writes data/<name>.json — confirm.
        new, _, n = keep_n_workers(n)
        save_res(new, n)
        # Same loading and filtering as the module-level configuration.
        with open(os.path.join(current, "data", f"{name}.json"), "r") as f:
            all_ans = json.load(f)
        all_ans = {
            int(key): val for key, val in all_ans.items() if int(key) < 9500
        }
        all_ans = dict(sorted(all_ans.items()))
        best_folder = f"best_{name}_{size}"
        path_res = os.path.join(current, "data", f"{name}")
        all_res = run_all([i for i in range(5)])
        pprint.pprint(all_res)
        if save:
            # NOTE(review): "nw_10" and "size_32" are hard-coded in this
            # filename regardless of the current budget or `size`. The path is
            # relative to the working directory, not to `current`.
            with open(
                f"spam_0_peerannot_nw_10_size_32_n_{n}.json", "w"
            ) as outfile:
                json.dump(all_res, outfile, indent=2)

# %%
