from aggregate_votes import run_DSrmSpam, run_DS, run_soft, run_MV, run_GLAD
from utils import (
    get_train_val_test,
    set_seed,
    DatasetWithIndex,
    expected_calibration_error,
)
import os
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch
import json
from torchvision.models import resnet18
from alexnet import AlexNet
import numpy as np
from tqdm import tqdm
import pprint
from cifar10h_data import CrowdCIFAR
from copy import deepcopy
import sys

# Directory containing this script; data and checkpoint paths are resolved
# relative to it.
current = os.path.dirname(os.path.abspath(__file__))
path_data = os.path.join(current, "data")
# Make the sibling ``aggregations`` folder importable. This must run before the
# ``AUM`` / ``DS`` imports just below.
sys.path.append(os.path.join(current, "..", "aggregations"))
import AUM  # noqa
import DS  # noqa

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Number of training epochs for every aggregation method.
n_epochs = 150
# Experiment identifier: selects the answer file data/<name>.json and the
# results folder data/<name>/ (see ``path_res`` below).
name = "cifar10h_p_0_spam_60_nw_10"
# Single loss object shared by training, validation and testing.
Loss = torch.nn.CrossEntropyLoss()
with open(os.path.join(current, "data", f"{name}.json"), "r") as f:
    all_ans = json.load(f)
# Keep only tasks whose id (a string key) is below 9500.
# NOTE(review): presumably these are the training-split images — confirm
# against get_train_val_test.
all_ans = {key: val for key, val in all_ans.items() if int(key) < 9500}
# Training augmentation: tensor conversion, normalisation with the standard
# ImageNet channel statistics, then a random horizontal flip of the tensor.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
        transforms.RandomHorizontalFlip(),
    ]
)

# Same normalisation without augmentation; used for the WAUM scoring dataset.
train_transform_noaug = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)


# Evaluation transform for the validation and test sets (identical to
# ``train_transform_noaug``).
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
# Folder where the aggregation results for this experiment are stored.
path_res = os.path.join(current, "data", f"{name}")


def setup_workspace(SEED=42):
    """Seed the RNGs and build the datasets and evaluation loaders.

    The raw training split gets the filtered crowd answers attached as
    ``train.all_ans``. The validation and test splits are wrapped in
    ``CrowdCIFAR`` with their ``c10h_c10_targets`` as targets and
    ``test_transform`` (no augmentation). Each is served by a non-shuffled
    loader with batch size 64.

    Args:
        SEED: value passed to ``set_seed``.

    Returns:
        ``(train, val, test, test_set_, val_set_, val_loader, test_loader)``.
    """
    set_seed(SEED)

    data_dir = os.path.join(current, "data")
    train, val, test = get_train_val_test(data_dir, train_transform, None)
    train.all_ans = all_ans

    # Wrap test first, then val (same construction order as before).
    test_set_, val_set_ = (
        CrowdCIFAR(
            set=split,
            targets=split.c10h_c10_targets,
            transform=test_transform,
        )
        for split in (test, val)
    )

    pin = torch.cuda.is_available()
    val_loader, test_loader = (
        DataLoader(wrapped, batch_size=64, shuffle=False, pin_memory=pin)
        for wrapped in (val_set_, test_set_)
    )
    return train, val, test, test_set_, val_set_, val_loader, test_loader


def get_model_sgd():
    """Build a small-input ResNet-18 with its SGD optimizer and LR schedule.

    The torchvision ResNet-18 stem is adapted:
    - the first convolution becomes a 3x3, stride-1 layer;
    - the stem max-pool is replaced by an identity, so feature maps are not
      down-sampled early.

    Returns:
        ``(model, optimizer, scheduler)``:
        - ``model``: the network, moved to ``device``;
        - ``optimizer``: SGD with lr 0.1, momentum 0.9 and weight decay 5e-4;
        - ``scheduler``: a MultiStepLR that multiplies the learning rate by
          0.1 at steps 50 and 75. ``run_train`` steps it once per epoch.
    """
    net = resnet18(num_classes=10)
    # NOTE(review): padding=3 with a 3x3/stride-1 kernel enlarges a 32x32 map
    # to 36x36 (padding=1 would preserve it). Kept as-is so the architecture
    # stays identical to already-trained checkpoints.
    stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=3, bias=False)
    net.conv1 = stem
    net.maxpool = nn.Identity()
    net = net.to(device)

    sgd = torch.optim.SGD(
        net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    schedule = torch.optim.lr_scheduler.MultiStepLR(
        sgd, milestones=[50, 75], gamma=0.1
    )
    return net, sgd, schedule


# NOTE(review): this module-level model/optimizer/scheduler is built at import
# time, but no function in this file reads these globals (run_test and every
# get_res_* build their own network). Consider removing it unless another
# module imports these names.
model, optimizer, scheduler = get_model_sgd()


def get_correct(pred, target):
    """Return how many predictions match the targets, as a Python float.

    Args:
        pred: predicted class indices, e.g. the ``(B, 1)`` indices from
            ``topk(1)``; squeezed before comparison.
        target: either 1-D class indices, or a 2-D tensor of per-class
            scores whose row-wise argmax is taken as the label.
    """
    labels = target if target.dim() <= 1 else target.argmax(dim=1)
    hits = pred.squeeze() == labels
    return hits.float().sum().item()


def train_step(
    batch, epoch, epoch_loss, epoch_acc, n_epochs, model, optimizer, device
):
    """Apply one optimizer update on ``batch`` and update the running sums.

    ``batch`` is unpacked as four items; only the first (inputs) and second
    (training labels) are used. ``epoch`` and ``n_epochs`` are accepted for
    call-site compatibility but not used.

    Returns:
        ``(epoch_loss, epoch_acc)``. The loss is added weighted by batch
        size; the accuracy is the count of correct predictions.
    """
    model.train()
    with torch.enable_grad():
        optimizer.zero_grad()
        inputs, labels, _, _ = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        logits = model(inputs)
        batch_loss = Loss(logits, labels)
        batch_loss.backward()
        optimizer.step()

        top1 = logits.data.topk(1, dim=1)[1]
        n = labels.size(0)
        epoch_loss += n * batch_loss.item()
        epoch_acc += get_correct(top1, labels)
    return epoch_loss, epoch_acc


def val_step(batch, epoch_loss, epoch_acc, epoch_ece, model, device):
    """Evaluate ``model`` on one batch and add the results to running sums.

    ``batch`` is unpacked as four items; only the first (inputs) and third
    (the labels scored against) are used. The third item is presumably the
    ground-truth label — confirm in ``CrowdCIFAR``.

    Loss and ECE are accumulated weighted by batch size, so dividing by the
    dataset length gives per-sample averages.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_ece)``.
    """
    model.eval()
    with torch.no_grad():
        images, _, gold, _ = batch
        images = images.to(device)
        gold = gold.to(device)

        logits = model(images)
        n = gold.size(0)
        epoch_loss += n * Loss(logits, gold).item()
        epoch_acc += get_correct(logits.data.topk(1, dim=1)[1], gold)

        probs = logits.softmax(axis=1).data.cpu().numpy()
        epoch_ece += expected_calibration_error(probs, gold.cpu().numpy()) * n
    return epoch_loss, epoch_acc, epoch_ece


def run_train(
    model,
    optimizer,
    scheduler,
    n_epochs,
    train_dl,
    val_dl,
    train_set,
    val_set,
    name="MV",
):
    """Train ``model`` for ``n_epochs`` and checkpoint the best validation loss.

    Each epoch:
    1. Run ``train_step`` over ``train_dl``.
    2. Step the LR scheduler once.
    3. Evaluate on ``val_dl`` with ``val_step``.

    Whenever the mean validation loss improves on the best so far, the weights
    are saved to ``<script dir>/best/<name>.pt``.

    Args:
        model, optimizer, scheduler: objects as returned by ``get_model_sgd``.
        n_epochs: number of epochs.
        train_dl, val_dl: training and validation loaders.
        train_set, val_set: datasets whose lengths normalise the metrics.
        name: checkpoint file stem.

    Returns:
        ``(train_metrics, val_metrics)``, dicts of per-epoch lists:
        - ``train_metrics``: ``"accuracy"`` (%) and ``"loss"``;
        - ``val_metrics``: the same keys plus ``"ECE"``.
    """
    train_metrics = {"accuracy": [], "loss": []}
    val_metrics = {"accuracy": [], "loss": [], "ECE": []}
    best_error = np.inf

    # Checkpoints are written as best/<name>.pt, so only best/ itself has to
    # exist; create it once instead of an unused best/<name>/ folder per save.
    best_dir = os.path.join(current, "best")
    os.makedirs(best_dir, exist_ok=True)
    checkpoint_path = os.path.join(best_dir, f"{name}.pt")

    for epoch in tqdm(range(n_epochs), total=n_epochs):
        epoch_loss, epoch_acc = 0, 0
        for batch in tqdm(train_dl, leave=False, total=len(train_dl)):
            epoch_loss, epoch_acc = train_step(
                batch,
                epoch,
                epoch_loss,
                epoch_acc,
                n_epochs,
                model,
                optimizer,
                device,
            )
        train_metrics["accuracy"].append(epoch_acc * 100 / len(train_set))
        train_metrics["loss"].append(epoch_loss / len(train_set))
        scheduler.step()

        epoch_loss, epoch_acc, epoch_ece = 0.0, 0.0, 0.0
        for batch in val_dl:
            epoch_loss, epoch_acc, epoch_ece = val_step(
                batch, epoch_loss, epoch_acc, epoch_ece, model, device
            )
        val_metrics["accuracy"].append(epoch_acc * 100 / len(val_set))
        val_metrics["loss"].append(epoch_loss / len(val_set))
        val_metrics["ECE"].append(epoch_ece / len(val_set))

        # Model selection uses validation loss (lower is better).
        if val_metrics["loss"][-1] < best_error:
            best_error = val_metrics["loss"][-1]
            torch.save(model.state_dict(), checkpoint_path)
    return train_metrics, val_metrics


def run_test(test_loader, test_set, name="MV"):
    """Evaluate the best checkpoint saved by ``run_train`` on the test set.

    The network is rebuilt with ``get_model_sgd`` rather than a hand-copied
    architecture, so it always matches the trained model. The optimizer and
    scheduler that call also returns are discarded. Weights are loaded from
    ``<script dir>/best/<name>.pt`` directly onto ``device``. Per-batch
    evaluation reuses ``val_step``, so test and validation metrics are
    computed identically.

    Args:
        test_loader: loader over the test set.
        test_set: the test dataset; its length normalises the metrics.
        name: checkpoint file stem passed to ``run_train``.

    Returns:
        Dict with single-element lists under ``"accuracy"`` (%), ``"loss"``
        and ``"ECE"``.
    """
    model = get_model_sgd()[0]
    checkpoint = os.path.join(current, "best", f"{name}.pt")
    # map_location makes a checkpoint written on another device loadable here.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    test_metrics = {"accuracy": [], "loss": [], "ECE": []}
    epoch_loss, epoch_acc, epoch_ece = 0, 0, 0
    for batch in tqdm(test_loader, total=len(test_loader)):
        epoch_loss, epoch_acc, epoch_ece = val_step(
            batch, epoch_loss, epoch_acc, epoch_ece, model, device
        )
    test_metrics["loss"].append(epoch_loss / len(test_set))
    test_metrics["accuracy"].append(epoch_acc * 100 / len(test_set))
    test_metrics["ECE"].append(epoch_ece / len(test_set))
    return test_metrics


def get_res_mv(seed, train, val_loader, val_set_, test_loader, test_set_):
    """Train and evaluate a classifier on majority-vote labels.

    Steps:
    1. Run ``run_MV`` on ``all_ans``.
    2. Read the integer labels it writes to
       ``data/<name>/<name>_MV_seed_<seed>.csv``.
    3. Train a fresh model on them.
    4. Test the best checkpoint (stored under the name ``"MV"``).

    Returns:
        ``(acc_mv, metrics_train, metrics_val, metrics_test)``. ``acc_mv`` is
        a boolean array: whether each of the first 9500 aggregated labels
        equals ``c10h_c10_targets``.
    """
    method = "MV"
    run_MV(path_res, name, seed, all_ans, method)

    labels_csv = os.path.join(
        path_data, name, f"{name}_{method}_seed_{seed}.csv"
    )
    hard_labels = np.loadtxt(labels_csv, delimiter=",").astype(int)
    train_set_mv = CrowdCIFAR(
        set=train, targets=hard_labels, transform=train_transform
    )
    gold = train_set_mv.set.c10h_c10_targets
    acc_mv = train_set_mv.targets[:9500] == gold

    loader = DataLoader(
        train_set_mv,
        batch_size=64,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
    )
    net, opt, sched = get_model_sgd()
    metrics_train, metrics_val = run_train(
        net,
        opt,
        sched,
        n_epochs,
        loader,
        val_loader,
        train_set_mv,
        val_set_,
        name=method,
    )
    metrics_test = run_test(test_loader, test_set_, name=method)
    return acc_mv, metrics_train, metrics_val, metrics_test


def get_res_soft(seed, train, val_loader, val_set_, test_loader, test_set_):
    """Train and evaluate a classifier on the soft labels from ``run_soft``.

    Steps:
    1. Run ``run_soft`` on ``all_ans`` with 10 classes.
    2. Read the float label matrix it writes to
       ``data/<name>/<name>_soft_seed_<seed>.csv`` (one row per task).
    3. Train a fresh model on it.
    4. Test the best checkpoint (stored under the name ``"soft"``).

    Returns:
        ``(acc_soft, metrics_train, metrics_val, metrics_test)``. ``acc_soft``
        is a boolean array: whether the row-wise argmax of each of the first
        9500 label rows equals ``c10h_c10_targets``.
    """
    method = "soft"
    run_soft(path_res, name, seed, all_ans, method, 10)

    labels_csv = os.path.join(
        path_data, name, f"{name}_{method}_seed_{seed}.csv"
    )
    soft_targets = np.loadtxt(labels_csv, delimiter=",").astype(float)
    train_set_soft = CrowdCIFAR(
        set=train, targets=soft_targets, transform=train_transform
    )
    top_class = np.argmax(train_set_soft.targets[:9500], axis=1)
    acc_soft = top_class == train_set_soft.set.c10h_c10_targets

    net, opt, sched = get_model_sgd()
    loader = DataLoader(
        train_set_soft,
        batch_size=64,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
    )
    metrics_train, metrics_val = run_train(
        net,
        opt,
        sched,
        n_epochs,
        loader,
        val_loader,
        train_set_soft,
        val_set_,
        name=method,
    )
    metrics_test = run_test(test_loader, test_set_, name=method)
    return acc_soft, metrics_train, metrics_val, metrics_test


def get_res_DS(seed, train, val_loader, val_set_, test_loader, test_set_):
    """Train and evaluate a classifier on Dawid-Skene (``run_DS``) labels.

    Steps:
    1. Count the distinct worker ids in ``all_ans`` and print the count.
    2. Run ``run_DS`` with 10 classes and a spam argument of 60.
    3. Read the integer labels from
       ``data/<name>/<name>_DS_seed_<seed>.csv``.
    4. Train a fresh model and test the best checkpoint (name ``"DS"``).

    Returns:
        ``(acc_ds, metrics_train, metrics_val, metrics_test, ds, nw)``, where
        ``ds`` is the object returned by ``run_DS`` and ``nw`` the worker
        count.
    """
    # Each task maps worker id -> answer; collect every worker key.
    worker_ids = [worker for task in all_ans.values() for worker in task]
    nw = len(np.unique(worker_ids))
    print(f"{nw=}")

    method = "DS"
    ds = run_DS(name, seed, all_ans, nw, 10, 60, path_res, method)

    labels_csv = os.path.join(
        path_data, name, f"{name}_{method}_seed_{seed}.csv"
    )
    hard_labels = np.loadtxt(labels_csv, delimiter=",").astype(int)
    train_set_ds = CrowdCIFAR(
        set=train, targets=hard_labels, transform=train_transform
    )
    acc_ds = train_set_ds.targets[:9500] == train_set_ds.set.c10h_c10_targets

    net, opt, sched = get_model_sgd()
    loader = DataLoader(
        train_set_ds,
        batch_size=64,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
    )
    metrics_train, metrics_val = run_train(
        net,
        opt,
        sched,
        n_epochs,
        loader,
        val_loader,
        train_set_ds,
        val_set_,
        name=method,
    )
    metrics_test = run_test(test_loader, test_set_, name=method)
    return acc_ds, metrics_train, metrics_val, metrics_test, ds, nw


def get_res_glad(
    seed, train, val_loader, val_set_, test_loader, test_set_, nw
):
    """Train and evaluate a classifier on the labels produced by ``run_GLAD``.

    Steps:
    1. Run ``run_GLAD`` on ``all_ans`` (10 classes, ``nw`` workers,
       ``spam=60``).
    2. Read the float label matrix it writes to
       ``data/<name>/<name>_GLAD_seed_<seed>.csv``.
    3. Train a fresh model on it.
    4. Test the best checkpoint (stored under the name ``"GLAD"``).

    Returns:
        ``(acc_glad, (metrics_train, metrics_val), metrics_test)``. Unlike the
        other ``get_res_*`` helpers, the train and validation metrics come
        back together as one tuple. This three-element shape is kept because
        ``run_all`` unpacks exactly three values.
    """
    method = "GLAD"
    run_GLAD(path_res, name, seed, all_ans, method, 10, nw, spam=60)
    res_glad = os.path.join(
        path_data,
        os.path.basename(path_res),
        f"{name}_{method}_seed_{seed}.csv",
    )
    targets_glad = np.loadtxt(res_glad, delimiter=",").astype(float)
    train_set_glad = CrowdCIFAR(
        set=train, targets=targets_glad, transform=train_transform
    )
    # Compare only the first 9500 rows with the ground truth, as the MV/soft/
    # DS helpers do. Without the slice, any extra rows emitted by the
    # aggregator make the element-wise comparison fail.
    acc_glad = (
        np.argmax(train_set_glad.targets[:9500], axis=1)
        == train_set_glad.set.c10h_c10_targets
    )
    model, optimizer, scheduler = get_model_sgd()
    train_loader_glad = DataLoader(
        train_set_glad,
        batch_size=64,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
    )
    metrics_train_glad, metrics_val_glad = run_train(
        model,
        optimizer,
        scheduler,
        n_epochs,
        train_loader_glad,
        val_loader,
        train_set_glad,
        val_set_,
        name="GLAD",
    )
    metrics_test_glad = run_test(test_loader, test_set_, name="GLAD")
    return (
        acc_glad,
        (metrics_train_glad, metrics_val_glad),
        metrics_test_glad,
    )


def get_res_ds_rm_spam(
    seed, train, val_loader, val_set_, test_loader, test_set_, ds, nw
):
    """Train and evaluate a classifier on DS labels after spammer removal.

    Steps:
    1. Call ``run_DSrmSpam`` with the earlier DS result ``ds`` (``force=True``).
    2. Read its integer labels from
       ``data/<name>/<name>_DSrmSpam_seed_<seed>.csv``.
    3. Train a fresh model and test the best checkpoint, stored under the
       name ``"DSrm"``.

    Returns:
        ``(acc_dsrm, metrics_train, metrics_val, metrics_test, ds_rm,
        ans_wo_spam, nw, who_spam)``. The last four items come from
        ``run_DSrmSpam`` (``nw`` is passed through unchanged) and feed the
        WAUM step.
    """
    method = "DSrmSpam"
    ds_rm, ans_wo_spam, who_spam = run_DSrmSpam(
        ds, name, seed, all_ans, nw, 10, 60, path_res, method, force=True
    )

    labels_csv = os.path.join(
        path_data, name, f"{name}_{method}_seed_{seed}.csv"
    )
    hard_labels = np.loadtxt(labels_csv, delimiter=",").astype(int)
    train_set_dsrm = CrowdCIFAR(
        set=train, targets=hard_labels, transform=train_transform
    )
    gold = train_set_dsrm.set.c10h_c10_targets
    acc_dsrm = train_set_dsrm.targets[:9500] == gold

    net, opt, sched = get_model_sgd()
    loader = DataLoader(
        train_set_dsrm,
        batch_size=64,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
    )
    checkpoint_name = "DSrm"
    metrics_train, metrics_val = run_train(
        net,
        opt,
        sched,
        n_epochs,
        loader,
        val_loader,
        train_set_dsrm,
        val_set_,
        name=checkpoint_name,
    )
    metrics_test = run_test(test_loader, test_set_, name=checkpoint_name)
    return (
        acc_dsrm,
        metrics_train,
        metrics_val,
        metrics_test,
        ds_rm,
        ans_wo_spam,
        nw,
        who_spam,
    )


def get_res_waum(
    seed,
    train,
    val_loader,
    val_set_,
    test_loader,
    test_set_,
    ds_rm,
    ans_wo_spam,
    nw,
    who_spam,
):
    """Drop tasks flagged by WAUM, relabel the rest, then train and test.

    Steps, as visible in this function:
      1. Build ``AUM.WAUM`` from ``ds_rm.pi`` and the spam-free answers
         ``ans_wo_spam``, using a fresh model/optimizer, then compute the
         (weighted) AUM scores. ``cut_lowests(alpha=0.01)`` yields the task
         indices to drop.
      2. Re-run ``run_DS`` on the spam-free answers of the kept tasks. Its
         ``pi`` is passed to ``waum.get_final_labels``, after which
         ``waum.baseline`` is row-normalised into soft labels.
      3. Remove the cut tasks and the tasks whose baseline row sums to zero,
         delete the matching images from a copy of ``train``, then train a
         fresh model on the remaining soft labels and test the best
         checkpoint (stored under the name ``"WAUM"``).

    Args:
        seed: seed forwarded to ``run_DS``.
        train: raw training split (``get_train_val_test`` output).
        val_loader, val_set_, test_loader, test_set_: evaluation data.
        ds_rm, ans_wo_spam, who_spam: outputs of the DSrmSpam step.
        nw: total worker count; ``nw - len(who_spam)`` is the number passed
            on after spammer removal.

    Returns:
        ``(acc_waum, metrics_train, metrics_val, metrics_test, n_removed)``.
        ``acc_waum`` is a boolean array over the kept tasks (argmax of the
        soft label == CIFAR-10 target) and ``n_removed`` is ``len(all_cut)``.
    """
    method = "WAUM"
    # Only used below for ``.set.c10h_c10_targets``; the [None] targets are
    # placeholders, and the variable is rebuilt further down.
    train_set_waum = CrowdCIFAR(
        set=train, targets=[None] * len(train), transform=train_transform_noaug
    )
    # Fresh network/optimizer handed to WAUM for its own training pass.
    model, optimizer, scheduler = get_model_sgd()

    # NOTE(review): the 50 is presumably the number of AUM training epochs —
    # confirm in AUM.WAUM.
    waum = AUM.WAUM(
        ds_rm.pi,
        ans_wo_spam,
        DatasetWithIndex(train),
        10,
        Loss,
        model,
        50,
        optimizer,
        nw=nw - len(who_spam),  # workers left after removing spammers
    )
    waum.get_aum(batch_idx=[0, 1, 2], weighted=True)
    waum.get_psi5_waum()
    # Indices of tasks to drop (presumably the lowest-scoring alpha fraction
    # — confirm in AUM.WAUM.cut_lowests).
    cut = waum.cut_lowests(alpha=0.01)
    # Spam-free answers restricted to tasks that were not cut.
    all_ans_ds_waum = {
        key: val for key, val in ans_wo_spam.items() if int(key) not in cut
    }
    ds_waum = run_DS(
        name + "WAUM",
        seed,
        all_ans_ds_waum,
        nw - len(who_spam),
        10,
        0,
        path_res,
        method,
        force=True,
        save=False,
        cut=cut,
    )
    # Return value unused; waum.baseline is read right after.
    # NOTE(review): presumably get_final_labels fills waum.baseline — confirm.
    _ = waum.get_final_labels(ds_waum.pi)
    # Row-normalise the baseline into per-task label distributions (first
    # 9500 rows). Rows summing to zero become NaN; they are listed in cut_0
    # and masked out below.
    res = (waum.baseline / waum.baseline.sum(1).reshape(-1, 1))[:9500, :]
    cut_0 = np.where(waum.baseline.sum(1) == 0)[0]
    # Ground-truth labels of the tasks that survive both removals.
    # NOTE(review): ``idx not in cut`` is a linear scan per index; a set
    # would be faster if ``cut`` is large.
    ok_test = [
        i
        for idx, i in enumerate(train_set_waum.set.c10h_c10_targets)
        if (idx not in cut) and (idx not in cut_0)
    ]
    # Boolean mask over ``res`` with the removed rows zeroed; indexing with
    # its nonzero coordinates keeps only the surviving rows (10 columns).
    mask = np.ones_like(res).astype(bool)
    mask[cut] = 0
    mask[cut_0] = 0
    mask = np.where(mask)
    res = res[mask[0], mask[1]].reshape(-1, 10)
    acc_waum = np.argmax(res, 1) == np.array(ok_test).flatten()
    # NOTE(review): if ``cut`` and ``cut_0`` overlap, hstack keeps duplicates,
    # so len(all_cut) (returned below) may over-count removed tasks — confirm.
    all_cut = np.hstack((cut, cut_0)).tolist()
    # Copy of the training split with the removed images deleted, aligned
    # with the rows kept in ``res``.
    train_mdf = deepcopy(train)
    train_mdf.c10h_data = np.delete(train_mdf.c10h_data, all_cut, axis=0)
    train_set_waum = CrowdCIFAR(
        set=train_mdf, targets=res, transform=train_transform
    )
    model, optimizer, scheduler = get_model_sgd()

    train_loader_waum = DataLoader(
        train_set_waum,
        batch_size=64,
        shuffle=True,
        pin_memory=(torch.cuda.is_available()),
        num_workers=0,
    )
    metrics_train_waum, metrics_val_waum = run_train(
        model,
        optimizer,
        scheduler,
        n_epochs,
        train_loader_waum,
        val_loader,
        train_set_waum,
        val_set_,
        name="WAUM",
    )
    metrics_test_waum = run_test(test_loader, test_set_, name="WAUM")
    return (
        acc_waum,
        metrics_train_waum,
        metrics_val_waum,
        metrics_test_waum,
        len(all_cut),
    )


def run_all(SEEDS):
    """Run every aggregation pipeline for each seed and collect the results.

    For every seed, ``setup_workspace`` reseeds and rebuilds the data. Then
    MV, soft, DS, GLAD, DSrmSpam and WAUM run in that order:
    - the DS step supplies ``ds``/``nw`` to GLAD and DSrmSpam;
    - DSrmSpam supplies ``ds_rm``/``ans_wo_spam``/``who_spam`` to WAUM.

    Returns:
        Dict mapping each method name to per-seed lists:
        - ``"train_recovery"``: mean agreement of the aggregated training
          labels with the CIFAR-10 targets;
        - ``"accuracy"``, ``"loss"``, ``"ECE"``: test metrics;
        - ``"cut"`` (WAUM only): number of removed tasks.
    """
    methods = ["MV", "soft", "DS", "GLAD", "DSrmSpam", "WAUM"]
    monitor = {
        method: {"accuracy": [], "loss": [], "ECE": [], "train_recovery": []}
        for method in methods
    }
    monitor["WAUM"]["cut"] = []

    def record(method, recovered, test_metrics):
        # Store the label-recovery rate and the single test evaluation.
        slot = monitor[method]
        slot["train_recovery"].append(np.mean(recovered))
        for metric in ("accuracy", "loss", "ECE"):
            slot[metric].append(test_metrics[metric][0])

    for seed in SEEDS:
        (
            train,
            _val,
            _test,
            test_set_,
            val_set_,
            val_loader,
            test_loader,
        ) = setup_workspace(seed)
        common = (seed, train, val_loader, val_set_, test_loader, test_set_)

        acc_mv, _, _, test_mv = get_res_mv(*common)
        record("MV", acc_mv, test_mv)

        acc_soft, _, _, test_soft = get_res_soft(*common)
        record("soft", acc_soft, test_soft)

        acc_ds, _, _, test_ds, ds, nw = get_res_DS(*common)
        record("DS", acc_ds, test_ds)

        acc_glad, _, test_glad = get_res_glad(*common, nw)
        record("GLAD", acc_glad, test_glad)

        (
            acc_dsrm,
            _,
            _,
            test_dsrm,
            ds_rm,
            ans_wo_spam,
            nw,
            who_spam,
        ) = get_res_ds_rm_spam(*common, ds, nw)
        record("DSrmSpam", acc_dsrm, test_dsrm)

        acc_waum, _, _, test_waum, n_cut = get_res_waum(
            *common, ds_rm, ans_wo_spam, nw, who_spam
        )
        record("WAUM", acc_waum, test_waum)
        monitor["WAUM"]["cut"].append(n_cut)
    return monitor


if __name__ == "__main__":
    # Benchmark seeds 0 and 1, print the summary, and save it as JSON in the
    # current working directory.
    results = run_all(list(range(2)))
    pprint.pprint(results)
    out_path = "json_expe_0_spam_60_glad_aug.json"
    with open(out_path, "w") as outfile:
        json.dump(results, outfile, indent=2)
