import sys
import os
from utils import set_seed, get_train_val_test
import numpy as np
import pickle
import time
import json
import torch.nn as nn
import torchvision
import torch

# Directory containing this script; the "data" folder used by setup() and
# the sibling "aggregations" folder are both resolved relative to it.
current = os.path.dirname(__file__)
# Put ../aggregations on the import path so the aggregation modules below
# (GLAD, DS, MV, spam_removal, AUM, soft) can be imported by bare name.
sys.path.append(os.path.join(current, "..", "aggregations"))
from GLAD import GLAD  # noqa
from DS import Dawid_Skene  # noqa
from MV import MV  # noqa
from spam_removal import SpamRm  # noqa
from AUM import WAUM  # noqa
from soft import Soft  # noqa


def run_DS(
    name,
    SEED,
    all_ans,
    nw,
    n_classes,
    spam,
    path_res,
    method,
    force=False,
    save=True,
    cut=None,
):
    """Fit a Dawid-Skene model on the crowdsourced answers.

    The model is fitted when ``"DS" in method`` and at least one of the
    ``DS``, ``DSrmSpam`` or ``WAUM`` result files for this ``name``/``SEED``
    is missing from ``path_res``. The fitted model is reused by the
    DSrmSpam and WAUM steps. The model is also fitted unconditionally when
    ``force`` is true.

    Args:
        name: Experiment name used to build the result file names.
        SEED: Seed value embedded in the result file names.
        all_ans: Answers forwarded to ``Dawid_Skene`` as ``answers``.
        nw: Number of workers.
        n_classes: Number of classes.
        spam: Forwarded to ``Dawid_Skene`` as ``n_spam``.
        path_res: Directory holding the result CSV files.
        method: Collection of requested method names.
        force: Fit even when the result files already exist.
        save: Write the integer predictions to
            ``{name}_DS_seed_{SEED}.csv`` in ``path_res``.
        cut: Forwarded to ``Dawid_Skene``. ``None`` (the default) means an
            empty list.

    Returns:
        The fitted ``Dawid_Skene`` instance, or ``None`` when skipped.
    """
    # Build a fresh list per call. A shared ``[]`` default would carry any
    # in-place modification made downstream into later calls.
    cut = [] if cut is None else cut
    expected_outputs = [
        os.path.join(path_res, f"{name}_{suffix}_seed_{SEED}.csv")
        for suffix in ("DS", "DSrmSpam", "WAUM")
    ]
    missing_output = not all(os.path.exists(p) for p in expected_outputs)
    if not (("DS" in method and missing_output) or force):
        return None

    print("###### Running DS")
    t0 = time.perf_counter()
    ds = Dawid_Skene(
        answers=all_ans,
        n_workers=nw,
        n_classes=n_classes,
        n_spam=spam,
        cut=cut,
    )
    _ = ds.run_em(maxiter=30, epsilon=1e-10)
    pred_ds = ds.get_predictions()
    if save:
        np.savetxt(
            expected_outputs[0],
            pred_ds.astype(int),
            delimiter=",",
        )
    print(f"~~ Ended DS in {time.perf_counter() - t0}")
    return ds


def _drop_spammers(all_ans, who_spam, nw, spam):
    """Remove answers from flagged workers and re-index the remaining ones.

    Args:
        all_ans: Mapping ``task -> {worker_id: answer}``. Worker ids are
            strings, either integer-like or of the form ``"spam<k>"``.
        who_spam: Integer ids of the workers to drop. Presumably a NumPy
            array: the ``>`` comparison below relies on elementwise
            semantics.
        nw: Total number of workers.
        spam: Number of ``"spam<k>"`` workers.

    Returns:
        A new mapping ``str(task) -> {str(new_worker_id): answer}``. Each
        kept worker's id is shifted down by the number of dropped workers
        with a smaller id, so the surviving ids stay contiguous.
    """
    # Pass 1: resolve every worker id to an integer and drop spammers.
    kept = {}
    for task, answers in all_ans.items():
        task_key = str(task)
        kept[task_key] = {}
        for worker, ans in answers.items():
            if worker.startswith("spam"):
                # Map "spam<k>" to the integer id nw - spam - 1 + k.
                # NOTE(review): if k starts at 0, "spam0" collides with id
                # nw - spam - 1; confirm the spammer naming scheme.
                worker = nw - spam - 1 + int(worker.split("spam")[1])
            if int(worker) not in who_spam:
                kept[task_key][str(worker)] = ans

    # Pass 2: shift ids down past the removed workers.
    reindexed = {}
    for task_key, answers in kept.items():
        reindexed[task_key] = {}
        for worker, ans in answers.items():
            worker_id = int(worker)
            n_removed_below = len(np.where(worker_id > who_spam)[0])
            reindexed[task_key][str(worker_id - n_removed_below)] = ans
    return reindexed


def run_DSrmSpam(
    ds, name, SEED, all_ans, nw, n_classes, spam, path_res, method, force=False
):
    """Refit Dawid-Skene after removing the workers flagged as spammers.

    Runs when ``"DSrmSpam" in method`` and
    ``{name}_DSrmSpam_seed_{SEED}.csv`` is missing from ``path_res``, or
    whenever ``force`` is true. Spammers are chosen by ``SpamRm`` from
    ``ds.pi``, the parameters of the already-fitted model ``ds``
    (presumably per-worker confusion matrices). Their answers are dropped
    and the remaining workers are re-indexed. A new Dawid-Skene model is
    then fitted and its integer predictions are saved as CSV.

    Returns:
        ``(new_ds, ans_wo_spam, who_spam)`` when run, otherwise
        ``(None, None, None)``.

    Raises:
        ValueError: If the step has to run but ``ds`` is ``None``. This
            happens when the base Dawid-Skene model was not fitted.
    """
    out_file = os.path.join(path_res, f"{name}_DSrmSpam_seed_{SEED}.csv")
    if not (
        ("DSrmSpam" in method and not os.path.exists(out_file)) or force
    ):
        return None, None, None
    if ds is None:
        raise ValueError(
            "run_DSrmSpam requires a fitted Dawid-Skene model; "
            "include 'DS' in method"
        )

    print("###### Running DS+rm spam")
    t0 = time.perf_counter()
    sp_rm = SpamRm(ds.pi, nw, n_classes)
    sp_rm.spam_score()
    who_spam = sp_rm.get_spammers(k=2)
    ans_wo_spam = _drop_spammers(all_ans, who_spam, nw, spam)
    ds = Dawid_Skene(
        answers=ans_wo_spam,
        n_workers=nw - len(who_spam),
        n_classes=n_classes,
    )
    _ = ds.run_em(maxiter=30, epsilon=1e-7)
    pred_ds = ds.get_predictions()
    np.savetxt(
        out_file,
        pred_ds.astype(int),
        delimiter=",",
    )
    print(f"~~ Ended DS+rm spam in {time.perf_counter() - t0}")
    return ds, ans_wo_spam, who_spam


def run_MV(path_res, name, SEED, all_ans, method):
    """Aggregate the answers by majority vote and write the labels to CSV.

    Does nothing unless ``"MV"`` is requested in ``method`` and the file
    ``{name}_MV_seed_{SEED}.csv`` does not exist yet in ``path_res``.
    Always returns ``None``.
    """
    if "MV" not in method:
        return None
    out_file = os.path.join(path_res, f"{name}_MV_seed_{SEED}.csv")
    if os.path.exists(out_file):
        return None

    print("###### Running MV")
    started = time.perf_counter()
    labels = MV(answers=all_ans).get_answers()
    np.savetxt(out_file, np.array(labels).astype(int), delimiter=",")
    print(f"~~ Ended MV in {time.perf_counter() - started}")
    return None


def run_GLAD(
    path_res, name, SEED, all_ans, method, n_classes, nw, spam, force=False
):
    """Fit the GLAD model and write the output of ``get_probas`` to CSV.

    GLAD runs when ``"GLAD"`` is requested in ``method`` and
    ``{name}_GLAD_seed_{SEED}.csv`` is not yet present in ``path_res``, or
    whenever ``force`` is set. Nothing is returned.
    """
    requested = "GLAD" in method
    missing = requested and not os.path.exists(
        os.path.join(path_res, f"{name}_GLAD_seed_{SEED}.csv")
    )
    if not (missing or force):
        return None

    print("###### Running GLAD")
    start = time.perf_counter()
    model = GLAD(
        n_classes=n_classes,
        n_workers=nw,
        n_task=len(all_ans),
        answers=all_ans,
        n_iter=10,
        n_spam=spam,
    )
    model.run()
    np.savetxt(
        os.path.join(path_res, f"{name}_GLAD_seed_{SEED}.csv"),
        model.get_probas(),
        delimiter=",",
    )
    print(f"~~ Ended GLAD in {time.perf_counter() - start}")
    return None


def run_soft(path_res, name, SEED, all_ans, method, n_classes):
    """Compute soft labels from the answers and write them to CSV.

    Skipped unless ``"soft"`` is requested in ``method`` and
    ``{name}_soft_seed_{SEED}.csv`` is absent from ``path_res``.
    Always returns ``None``.
    """
    if "soft" not in method:
        return None
    target = os.path.join(path_res, f"{name}_soft_seed_{SEED}.csv")
    if os.path.exists(target):
        return None

    print("###### Running Soft")
    begin = time.perf_counter()
    aggregator = Soft(answers=all_ans, n_classes=n_classes)
    soft_labels = aggregator.get_answers()
    np.savetxt(target, soft_labels, delimiter=",")
    print(f"~~ Ended Soft in {time.perf_counter() - begin}")
    return None


def run_WAUM(
    pi, path_res, name, SEED, all_ans, method, n_classes, train_set, n_epoch=30
):
    """Compute WAUM-filtered labels with a ResNet-18 and write them to CSV.

    Runs only when ``"WAUM" in method`` and
    ``{name}_WAUM_seed_{SEED}.csv`` is missing from ``path_res``.

    Args:
        pi: Forwarded to ``WAUM``. Presumably the worker confusion
            matrices of a fitted Dawid-Skene model (the caller passes
            ``ds.pi``).
        path_res: Directory holding the result CSV files.
        name: Experiment name used in the output file name.
        SEED: Seed value embedded in the output file name.
        all_ans: Crowdsourced answers forwarded to ``WAUM``.
        method: Collection of requested method names.
        n_classes: Number of classes. Used both for ``WAUM`` and for the
            network's output layer.
        train_set: Training dataset forwarded to ``WAUM``.
        n_epoch: Number of training epochs forwarded to ``WAUM``.

    Returns:
        The ``WAUM`` instance when run, otherwise ``None``.

    Raises:
        ValueError: If WAUM has to run but ``pi`` is ``None``.
    """
    if "WAUM" not in method:
        return None
    out_file = os.path.join(path_res, f"{name}_WAUM_seed_{SEED}.csv")
    if os.path.exists(out_file):
        return None
    if pi is None:
        raise ValueError(
            "run_WAUM requires pi from a fitted Dawid-Skene model; "
            "include 'DS' in method"
        )

    print("#### Running WAUM")
    t0 = time.perf_counter()
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): `pretrained=` is deprecated in recent torchvision
    # releases in favour of `weights=`; kept for compatibility.
    model = torchvision.models.resnet18(pretrained=True)
    # Size the classification head from the requested number of classes
    # instead of a hard-coded 10.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    # Replace the stem with a stride-1 3x3 convolution. The original comment
    # noted that CIFAR images are too small for the default stem.
    # NOTE(review): with kernel_size=3, padding=3 enlarges each spatial
    # dimension by 4 (padding=1 would preserve it); confirm this is intended.
    model.conv1 = nn.Conv2d(
        3,
        64,
        kernel_size=3,
        stride=1,
        padding=3,
        bias=False,
    )
    model.maxpool = nn.Identity()  # avoid hard downsampling
    parameters = [p for p in model.parameters() if p.requires_grad]

    # NOTE(review): momentum=1e-4 is effectively no momentum; confirm it is
    # not swapped with another hyper-parameter.
    optimizer = torch.optim.SGD(
        parameters,
        lr=0.1,
        momentum=1e-4,
        weight_decay=1e-5,
    )

    waum = WAUM(
        pi,
        all_ans,
        train_set,
        n_classes,
        criterion,
        model,
        n_epoch,
        optimizer,
    )
    waum.get_aum(batch_idx=[0, 1, 2], weighted=True)
    waum.get_psi5_waum()
    waum.cut_lowests()
    final_labels = waum.get_final_labels()
    np.savetxt(out_file, final_labels, delimiter=",")
    print(f"~~ Ended WAUM in {time.perf_counter() - t0}")
    return waum


def setup(
    names,
    p,
    spam,
    train_transform,
    target_transform,
    trains,
    vals,
    tests,
    SEED,
    chosen_nw=None,
):
    """Prepare one CIFAR-10H experiment variant and load its answers.

    Side effects:
        * appends the experiment name to ``names`` and the train/val/test
          splits to ``trains``/``vals``/``tests``;
        * creates ``data/<name>`` next to this script if needed;
        * pickles the three splits and writes the train targets
          (``train.c10h_c10_targets``) as an integer CSV into that folder;
        * attaches the loaded answers to the train split as ``train.all_ans``.

    Args:
        names, trains, vals, tests: Lists extended in place.
        p: Value embedded in the experiment name.
        spam: Value embedded in the experiment name.
        train_transform, target_transform: Forwarded to
            ``get_train_val_test``.
        SEED: Seed value embedded in the pickle file names.
        chosen_nw: ``None`` selects the ``..._all_workers`` variant;
            otherwise the name ends with ``_nw_{chosen_nw}``.

    Returns:
        ``(name, all_ans, nw, path_res, train)``. ``all_ans`` is the JSON
        answers file ``data/<name>.json``, ``nw`` is the number of distinct
        worker ids in it, and ``path_res`` is the result folder.
    """
    suffix = "all_workers" if chosen_nw is None else f"nw_{chosen_nw}"
    name = f"cifar10h_p_{p}_spam_{spam}_{suffix}"
    names.append(name)
    data_dir = os.path.join(current, "data")
    path_res = os.path.join(data_dir, name)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path_res, exist_ok=True)

    train, val, test = get_train_val_test(
        data_dir,
        train_transform,
        target_transform,
    )
    trains.append(train)
    vals.append(val)
    tests.append(test)
    for split, split_name in ((train, "train"), (val, "val"), (test, "test")):
        split_file = os.path.join(
            path_res, f"{name}_{split_name}_seed_{SEED}.pkl"
        )
        with open(split_file, "wb") as out:
            pickle.dump(split, out, pickle.HIGHEST_PROTOCOL)

    np.savetxt(
        os.path.join(path_res, f"targets_train_{name}_{SEED}.csv"),
        np.array(train.c10h_c10_targets).astype(int),
        delimiter=",",
    )

    with open(
        os.path.join(data_dir, f"{name}.json"), "r", encoding="utf-8"
    ) as f:
        all_ans = json.load(f)
    train.all_ans = all_ans
    # Distinct worker ids across all tasks.
    nw = len({worker for task in all_ans.values() for worker in task})
    print("nw=", nw)
    return name, all_ans, nw, path_res, train


def _run_all_aggregations(
    name, all_ans, nw, path_res, train, spam, method, SEED, n_classes
):
    """Run every requested aggregation strategy for one prepared variant."""
    # DS runs first: its fitted model feeds the DSrmSpam and WAUM steps.
    ds = run_DS(name, SEED, all_ans, nw, n_classes, spam, path_res, method)
    run_DSrmSpam(
        ds, name, SEED, all_ans, nw, n_classes, spam, path_res, method
    )
    run_MV(path_res, name, SEED, all_ans, method)
    run_soft(path_res, name, SEED, all_ans, method, n_classes)
    run_GLAD(path_res, name, SEED, all_ans, method, n_classes, nw, spam)
    # run_DS returns None when skipped ("DS" not requested, or all of its
    # outputs already exist). Reading ds.pi unconditionally would then raise
    # AttributeError even when WAUM is not requested.
    pi = None if ds is None else ds.pi
    run_WAUM(pi, path_res, name, SEED, all_ans, method, n_classes, train)


def get_aggregated_votes(
    SEED=42, train_transform=None, method=None, target_transform=None
):
    """Prepare every CIFAR-10H variant and run the requested aggregations.

    Variants processed, in order: ``p=0`` with ``spam`` in ``0, 20, 60``
    using all workers, then ``p=0, spam=60`` with ``chosen_nw=10``.

    Args:
        SEED: Seed passed to ``set_seed`` and embedded in file names.
        train_transform, target_transform: Forwarded to ``setup``.
        method: Collection of method names to run. ``None`` (the default)
            means ``["MV"]``.

    Returns:
        ``(trains, vals, tests, names, SEED)``, with one entry per variant
        in each list, in processing order.
    """
    if method is None:
        method = ["MV"]
    set_seed(SEED)
    n_classes = 10
    names = []
    trains, vals, tests = [], [], []
    for p in [0]:
        for spam in [0, 20, 60]:
            name, all_ans, nw, path_res, train = setup(
                names,
                p,
                spam,
                train_transform,
                target_transform,
                trains,
                vals,
                tests,
                SEED,
            )
            _run_all_aggregations(
                name, all_ans, nw, path_res, train, spam, method, SEED,
                n_classes,
            )

        # Extra variant with chosen_nw=10 (selects the "..._nw_10" files).
        spam = 60
        set_seed(SEED)
        name, all_ans, nw, path_res, train = setup(
            names,
            p,
            spam,
            train_transform,
            target_transform,
            trains,
            vals,
            tests,
            SEED,
            chosen_nw=10,
        )
        _run_all_aggregations(
            name, all_ans, nw, path_res, train, spam, method, SEED, n_classes
        )

    return trains, vals, tests, names, SEED
