# %%
from utils import (
    get_train_val_test,
    set_seed,
    expected_calibration_error,
)
from copy import deepcopy
import os
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch
import json
from torchvision.models import resnet18
import numpy as np
from tqdm import tqdm
from cifar10h_data import CrowdCIFAR
from peerannot.models.WAUM import WAUM

# Directory containing this script; data files and checkpoints are resolved
# relative to it.
current = os.path.dirname(os.path.abspath(__file__))

# Experiment identifier: selects the crowd-answer file ``data/<name>.json``
# loaded below.
name = "cifar10h_p_0_spam_60_nw_10"
# Not read by any active code visible in this file (only by commented-out code).
path_res = os.path.join(current, "data", f"{name}")

# Seed the RNGs before any data split or model construction happens.
SEED = 42
set_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Train and test pipelines are identical: tensor conversion followed by
# per-channel normalisation (no augmentation).
# NOTE(review): these mean/std values look like the commonly used ImageNet
# statistics rather than CIFAR-10 ones -- confirm this is intended.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)

# NOTE(review): the third positional argument is passed as None; its meaning
# is defined by utils.get_train_val_test -- confirm against that helper.
train, val, test = get_train_val_test(
    os.path.join(current, "data"),
    train_transform,
    None,
)
# Criterion shared by train_step, val_step and run_test.
Loss = torch.nn.CrossEntropyLoss()
with open(os.path.join(current, "data", f"{name}.json"), "r") as f:
    all_ans = json.load(f)
# JSON object keys are strings: cast them to int and keep only ids below 9500
# (presumably the size of the training split -- TODO confirm).
# The comprehension variable ``val`` lives in the comprehension's own scope,
# so it does not overwrite the ``val`` split defined above.
all_ans = {int(key): val for key, val in all_ans.items() if int(key) < 9500}
# Attach the filtered crowd answers to the training split object.
train.all_ans = all_ans

# Evaluation datasets: both use the ``c10h_c10_targets`` attribute of their
# split as targets and the test-time transform.
test_set_ = CrowdCIFAR(
    set=test,
    targets=test.c10h_c10_targets,
    transform=test_transform,
)
val_set_ = CrowdCIFAR(
    set=val,
    targets=val.c10h_c10_targets,
    transform=test_transform,
)
# Evaluation loaders keep a fixed order; memory is pinned only when CUDA is
# available.
val_loader = DataLoader(
    val_set_,
    batch_size=64,
    shuffle=False,
    pin_memory=(torch.cuda.is_available()),
)
test_loader = DataLoader(
    test_set_,
    batch_size=64,
    shuffle=False,
    pin_memory=(torch.cuda.is_available()),
)
# Number of training epochs passed to run_train by multiple_monitorings.
n_epochs = 150


def get_model_sgd():
    """Create a fresh ResNet-18 with its SGD optimizer and LR schedule.

    The torchvision ResNet-18 (10 output classes) gets its first convolution
    replaced by a stride-1 3x3 convolution and its max-pooling layer replaced
    by an identity, then is moved to the module-level ``device``.

    Returns:
        tuple: ``(model, optimizer, scheduler)`` where ``optimizer`` is SGD
        (lr 0.1, momentum 0.9, weight decay 5e-4) and ``scheduler`` multiplies
        the learning rate by 0.1 at epochs 50 and 100.
    """
    # The backbone is built first and the replacement stem second, so random
    # initialisation draws happen in the same order on every call.
    net = resnet18(num_classes=10)
    # NOTE(review): padding=3 with a 3x3 kernel enlarges each spatial
    # dimension by 4 pixels; CIFAR variants of ResNet usually use padding=1.
    # Confirm this is intended before changing it.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=3,
        bias=False,
    )
    net.maxpool = nn.Identity()  # no early downsampling
    net = net.to(device)

    sgd = torch.optim.SGD(
        net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
        sgd, milestones=[50, 100], gamma=0.1
    )
    return net, sgd, lr_schedule


def get_correct(pred, target):
    """Count the predictions of a batch that match their labels.

    Args:
        pred: Tensor of predicted class indices; in this file it is the
            index output of ``topk(1, dim=1)``. Size-1 dimensions are
            squeezed away before comparing.
        target: Tensor of class indices, or, when it has more than one
            dimension, per-class scores whose argmax along dim 1 is taken
            as the label.

    Returns:
        float: Number of positions where prediction and label agree.
    """
    labels = target.argmax(dim=1) if target.dim() > 1 else target
    hits = pred.squeeze() == labels
    return hits.float().sum().item()


def train_step(
    batch, epoch, epoch_loss, epoch_acc, n_epochs, model, optimizer, device
):
    """Run one optimisation step on ``batch`` and update the running totals.

    Args:
        batch: 4-tuple from the training loader; only the first two items
            (inputs and training labels) are used.
        epoch: Current epoch index (accepted for the caller, not used here).
        epoch_loss: Running sum of per-sample loss for the current epoch.
        epoch_acc: Running count of correct predictions for the epoch.
        n_epochs: Total number of epochs (accepted, not used here).
        model: Network to train; switched to train mode.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: Updated ``(epoch_loss, epoch_acc)``.
    """
    model.train()
    images, labels, _, _ = batch
    with torch.enable_grad():
        optimizer.zero_grad()
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        batch_loss = Loss(logits, labels)
        batch_loss.backward()
        optimizer.step()

        top1 = logits.data.topk(1, dim=1)[1]
        n_samples = labels.size(0)
        # The criterion returns a batch mean; scale by the batch size so the
        # caller can divide the total by the dataset length.
        epoch_loss += n_samples * batch_loss.item()
        epoch_acc += get_correct(top1, labels)
    return epoch_loss, epoch_acc


def val_step(batch, epoch_loss, epoch_acc, epoch_ece, model, device):
    """Evaluate ``model`` on one batch and update the running totals.

    Args:
        batch: 4-tuple from an evaluation loader; the first item (inputs)
            and third item (labels used for evaluation) are read.
        epoch_loss: Running sum of per-sample loss.
        epoch_acc: Running count of correct predictions.
        epoch_ece: Running sum of per-batch expected calibration error,
            each weighted by its batch size.
        model: Network to evaluate; switched to eval mode.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: Updated ``(epoch_loss, epoch_acc, epoch_ece)``.
    """
    model.eval()
    images, _, true_labels, _ = batch
    with torch.no_grad():
        images = images.to(device)
        true_labels = true_labels.to(device)

        logits = model(images)
        batch_loss = Loss(logits, true_labels)
        top1 = logits.data.topk(1, dim=1)[1]
        n_samples = true_labels.size(0)

        epoch_loss += n_samples * batch_loss.item()
        epoch_acc += get_correct(top1, true_labels)

        probabilities = logits.softmax(dim=1).data.cpu().numpy()
        batch_ece = expected_calibration_error(
            probabilities, true_labels.cpu().numpy()
        )
        # Weighted by batch size so the caller can normalise by dataset size.
        epoch_ece += batch_ece * n_samples
    return epoch_loss, epoch_acc, epoch_ece


def run_train(
    model,
    optimizer,
    scheduler,
    n_epochs,
    train_dl,
    val_dl,
    train_set,
    val_set,
    name="MV",
):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Each epoch runs ``train_step`` over ``train_dl``, steps the scheduler
    once, then runs ``val_step`` over ``val_dl``. Whenever the validation
    loss strictly improves, the model's ``state_dict`` is written to
    ``<script dir>/best/<name>.pt``.

    Args:
        model: Network to train.
        optimizer: Optimizer bound to ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        n_epochs: Number of epochs to run.
        train_dl: Loader yielding training batches.
        val_dl: Loader yielding validation batches.
        train_set: Training dataset; its length normalises train metrics.
        val_set: Validation dataset; its length normalises val metrics.
        name: Stem of the checkpoint file name.

    Returns:
        tuple: ``(train_metrics, val_metrics)``, dicts of per-epoch lists.
        ``train_metrics`` has keys ``"accuracy"`` (percent) and ``"loss"``;
        ``val_metrics`` additionally has ``"ECE"``.
    """
    train_metrics = {"accuracy": [], "loss": []}
    val_metrics = {"accuracy": [], "loss": [], "ECE": []}
    best_error = np.inf
    checkpoint_path = os.path.join(current, "best", f"{name}.pt")
    for epoch in tqdm(range(n_epochs), total=n_epochs):
        epoch_loss = 0
        epoch_acc = 0
        for batch in tqdm(train_dl, leave=False, total=len(train_dl)):
            epoch_loss, epoch_acc = train_step(
                batch,
                epoch,
                epoch_loss,
                epoch_acc,
                n_epochs,
                model,
                optimizer,
                device,
            )
        train_metrics["accuracy"].append(epoch_acc * 100 / len(train_set))
        train_metrics["loss"].append(epoch_loss / len(train_set))
        scheduler.step()
        epoch_loss, epoch_acc, epoch_ece = 0.0, 0.0, 0.0
        for batch in val_dl:
            epoch_loss, epoch_acc, epoch_ece = val_step(
                batch, epoch_loss, epoch_acc, epoch_ece, model, device
            )
        val_metrics["accuracy"].append(epoch_acc * 100 / len(val_set))
        val_metrics["loss"].append(epoch_loss / len(val_set))
        val_metrics["ECE"].append(epoch_ece / len(val_set))
        if val_metrics["loss"][-1] < best_error:
            best_error = val_metrics["loss"][-1]
            # Create the directory that will hold the checkpoint file itself.
            # The previous code created an unused ``best/<name>/`` directory
            # next to ``best/<name>.pt``.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
    return train_metrics, val_metrics


def run_test(test_loader, test_set, name="MV"):
    """Evaluate the checkpoint ``best/<name>.pt`` on the test data.

    A network with the same architecture as the one built by
    ``get_model_sgd`` is created, the saved weights are loaded into it, and
    it is scored batch by batch with ``val_step``.

    Args:
        test_loader: Loader yielding test batches.
        test_set: Test dataset; its length normalises the metrics.
        name: Stem of the checkpoint file written by ``run_train``.

    Returns:
        dict: Keys ``"accuracy"`` (percent), ``"loss"`` and ``"ECE"``, each
        mapped to a one-element list.
    """
    model = resnet18(num_classes=10)  # cifar is too small
    model.conv1 = nn.Conv2d(
        3,
        64,
        kernel_size=3,
        stride=1,
        padding=3,
        bias=False,
    )
    model.maxpool = nn.Identity()  # avoid hard downsampling
    # Load onto the CPU first so a checkpoint saved from a CUDA model can
    # also be restored on a machine without a GPU; the model is moved to
    # ``device`` right after.
    state_dict = torch.load(
        os.path.join(current, "best", f"{name}.pt"), map_location="cpu"
    )
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    epoch_loss, epoch_acc, epoch_ece = 0, 0, 0
    for batch in tqdm(test_loader, total=len(test_loader)):
        # val_step performs the same no-grad forward pass and accumulation
        # used for validation, so both paths compute metrics identically.
        epoch_loss, epoch_acc, epoch_ece = val_step(
            batch, epoch_loss, epoch_acc, epoch_ece, model, device
        )

    n_samples = len(test_set)
    test_metrics = {
        "accuracy": [epoch_acc * 100 / n_samples],
        "loss": [epoch_loss / n_samples],
        "ECE": [epoch_ece / n_samples],
    }
    return test_metrics


def multiple_monitorings(waum):
    """Train and test a classifier for several WAUM ``alpha`` values and seeds.

    For every alpha, ``waum.run(alpha=1 - alpha)`` is executed, the tasks
    listed in ``waum.too_hard`` are removed from a copy of the training
    split, and five networks (seeds 0-4) are trained on the labels returned
    by ``waum.get_probas()`` and scored on the test set.

    Args:
        waum: WAUM object exposing ``run``, ``get_probas`` and ``too_hard``.

    Returns:
        dict: ``{str(alpha): {str(seed): [[test_accuracy, test_loss,
        test_ECE, train_label_accuracy]]}}``.
    """
    alphas = [0.5, 0.75, 0.8, 0.9, 0.99, 0.995]
    monitor = {str(alpha): {} for alpha in alphas}

    # Placeholder dataset: only its ``.set.c10h_c10_targets`` is read, during
    # the first pass of the loop below.
    train_set_waum = CrowdCIFAR(
        set=train, targets=[0] * len(train), transform=train_transform
    )
    for _, alpha in tqdm(enumerate(alphas), desc="alpha"):
        waum.run(alpha=1 - alpha)
        y_waum = waum.get_probas()

        # Accuracy of the WAUM labels on the tasks that were kept.
        kept_targets = np.delete(
            train_set_waum.set.c10h_c10_targets, waum.too_hard, axis=0
        )
        acc_waum = np.mean(np.argmax(y_waum, axis=1) == kept_targets)

        # Drop the pruned tasks from a copy so ``train`` itself stays intact.
        train_mdf = deepcopy(train)
        train_mdf.c10h_data = np.delete(
            train_mdf.c10h_data, waum.too_hard, axis=0
        )
        train_set_waum = CrowdCIFAR(
            set=train_mdf, targets=y_waum, transform=train_transform
        )

        # NOTE: a disabled alternative that re-ran Dawid-Skene on the pruned
        # answers (``waum.cut_lowests`` + ``run_DS``) used to be commented out
        # here; it has been removed for readability.
        alpha_key = str(alpha)
        for seed in range(5):
            set_seed(seed)
            seed_key = str(seed)
            monitor[alpha_key][seed_key] = []
            model, optimizer, scheduler = get_model_sgd()

            train_loader_waum = DataLoader(
                train_set_waum,
                batch_size=64,
                shuffle=True,
                pin_memory=(torch.cuda.is_available()),
                num_workers=0,
            )
            # The returned per-epoch metrics are not used; the run matters
            # for the checkpoint it writes under the name "WAUM".
            run_train(
                model,
                optimizer,
                scheduler,
                n_epochs,
                train_loader_waum,
                val_loader,
                train_set_waum,
                val_set_,
                name="WAUM",
            )
            metrics_test_waum = run_test(test_loader, test_set_, name="WAUM")
            summary = [
                metrics_test_waum["accuracy"][0],
                metrics_test_waum["loss"][0],
                metrics_test_waum["ECE"][0],
                acc_waum,
            ]
            monitor[alpha_key][seed_key].append(summary)
            print(monitor[alpha_key][seed_key])
    return monitor


# Script entry point: build a WAUM object on the training split, sweep the
# alpha values with multiple_monitorings, and dump the results as JSON.
if __name__ == "__main__":
    import pprint

    # Dataset handed to WAUM; every target is the placeholder value 0.
    train_set_waum = CrowdCIFAR(
        set=train, targets=[0] * len(train), transform=train_transform
    )
    # The scheduler returned here is not used below.
    model, optimizer, scheduler = get_model_sgd()
    # NOTE(review): arguments are positional and follow peerannot's WAUM
    # signature; 10 is presumably the number of classes and 50 the number of
    # epochs -- confirm against the library.
    waum = WAUM(
        train_set_waum,
        all_ans,
        10,
        model,
        torch.nn.CrossEntropyLoss(),
        optimizer,
        50,
    )

    # --- Disabled legacy pipeline (Dawid-Skene / spam removal), kept for
    # --- reference. It uses names that are not imported in this file
    # --- (run_DS, run_DSrmSpam, AUM, DatasetWithIndex).
    # all_workers = []
    # for id_, task_ in all_ans.items():
    #     for worker, ans in task_.items():
    #         all_workers.append(worker)
    # nw = len(np.unique(all_workers))
    # print(f"{nw=}")

    # method = "DS"
    # ds = run_DS(name, SEED, all_ans, nw, 10, 60, path_res, method)

    # method = "DSrmSpam"
    # ds_rm, ans_wo_spam, who_spam = run_DSrmSpam(
    #     ds, name, SEED, all_ans, nw, 10, 60, path_res, method, force=True
    # )

    # method = "DSrmSpam"
    # res_DSrmSpam = os.path.join(
    #     current, "data", f"{name}", f"{name}_{method}_seed_{SEED}.csv"
    # )
    # targets = np.loadtxt(res_DSrmSpam, delimiter=",").astype(int)
    # train_set_waum = CrowdCIFAR(
    #     set=train, targets=targets, transform=train_transform
    # )

    # model, optimizer, scheduler = get_model_sgd()

    # waum = AUM.WAUM(
    #     ds_rm.pi,
    #     ans_wo_spam,
    #     DatasetWithIndex(train),
    #     10,
    #     Loss,
    #     model,
    #     50,
    #     optimizer,
    #     nw=nw - len(who_spam),
    # )
    # waum.get_aum(batch_idx=[0, 1, 2], weighted=True)
    # waum.get_psi5_waum()
    res_monitor = multiple_monitorings(waum)
    pprint.pprint(res_monitor)
    # The path is relative, so the file lands in the current working
    # directory rather than next to this script.
    with open("alpha_importance_60_10.json", "w") as outfile:
        json.dump(res_monitor, outfile, indent=2)

# %%
