import sys
import os
from utils import set_seed, get_train_val_test
import numpy as np
import pickle
import time

# Directory that contains this file; used as the base for relative paths.
current = os.path.dirname(__file__)
# Make the sibling "../aggregations" directory importable (presumably where
# the GLAD, DS, MV, spam_removal and soft modules imported below live).
sys.path.append(os.path.join(current, "..", "aggregations"))
from GLAD import GLAD  # noqa
from DS import Dawid_Skene  # noqa
from MV import MV  # noqa
from spam_removal import SpamRm  # noqa
from soft import Soft  # noqa


def setup(
    nw,
    villain,
    names,
    trains,
    vals,
    tests,
    SEED,
    chosen_nw=None,
):
    """Create one simulated dataset, persist it, and return its crowd answers.

    The dataset is named ``simulated_nw_<nw>_villain_<villain>`` and everything
    is written under ``<current>/data/<name>/``:

    * one pickle per split (``<name>_<split>_seed_<SEED>.pkl``),
    * ``targets_train_<name>_<SEED>.csv`` holding the train targets as ints.

    Parameters
    ----------
    nw, villain :
        Forwarded to ``get_train_val_test`` (``nw`` and ``villain`` keywords);
        also used to build the dataset name.
    names, trains, vals, tests : list
        Accumulators mutated in place: the dataset name and the three splits
        are appended to them.
    SEED :
        Seed forwarded to ``get_train_val_test`` and embedded in file names.
    chosen_nw :
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    tuple
        ``(name, all_ans, path_res, train)`` where ``all_ans`` maps
        ``str(task) -> {str(worker): answer}``.
    """
    name = f"simulated_nw_{nw}_villain_{villain}"
    names.append(name)
    path_res = os.path.join(current, "data", f"{name}")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path_res, exist_ok=True)

    train, val, test = get_train_val_test(SEED, nw=nw, villain=villain)
    trains.append(train)
    vals.append(val)
    tests.append(test)

    for split, split_name in zip((train, val, test), ("train", "val", "test")):
        pkl_path = os.path.join(
            path_res, f"{name}_{split_name}_seed_{SEED}.pkl"
        )
        with open(pkl_path, "wb") as out:
            pickle.dump(split, out, pickle.HIGHEST_PROTOCOL)

    # The second element of every task entry is stored as the train target.
    true_targets = [task[1] for task in train.tasks.values()]
    np.savetxt(
        os.path.join(path_res, f"targets_train_{name}_{SEED}.csv"),
        np.array(true_targets).astype(int),
        delimiter=",",
    )

    # Re-key the crowd answers with strings and keep only the first element
    # of each worker's recorded value.
    all_ans = {
        str(task): {str(worker): value[0] for worker, value in answer.items()}
        for task, answer in train.crowd_data.items()
    }
    return name, all_ans, path_res, train


def run_DS(name, SEED, all_ans, nw, n_classes, path_res, method):
    """Fit the Dawid-Skene model when one of its consumers still needs it.

    The model is fitted when either

    * ``"DS"`` is in ``method`` and ``<name>_DS_seed_<SEED>.csv`` is missing, or
    * ``"DSrmSpam"`` is in ``method`` and ``<name>_DSrmSpam_seed_<SEED>.csv``
      is missing (the spam-removal step needs the fitted model).

    Whenever the model is fitted its predictions are (re)written to
    ``<path_res>/<name>_DS_seed_<SEED>.csv``.

    Parameters
    ----------
    name, SEED :
        Used to build the output file names.
    all_ans : dict
        Crowd answers passed to ``Dawid_Skene`` as ``answers``.
    nw, n_classes :
        Passed to ``Dawid_Skene`` as ``n_workers`` and ``n_classes``.
    path_res :
        Directory holding the cached CSV files.
    method :
        Container of requested strategy names.

    Returns
    -------
    The fitted ``Dawid_Skene`` instance, or ``None`` when nothing was run.
    """
    ds_csv = os.path.join(path_res, f"{name}_DS_seed_{SEED}.csv")
    rm_csv = os.path.join(path_res, f"{name}_DSrmSpam_seed_{SEED}.csv")

    # Tie each cache check to the strategy that actually consumes it, so a
    # DSrmSpam-only request still gets a model and a cached DS-only request
    # is not refitted just because the DSrmSpam file is absent.
    need_ds = "DS" in method and not os.path.exists(ds_csv)
    need_rm = "DSrmSpam" in method and not os.path.exists(rm_csv)
    if not (need_ds or need_rm):
        return None

    print("###### Running DS")
    t0 = time.perf_counter()
    ds = Dawid_Skene(
        answers=all_ans,
        n_workers=nw,
        n_classes=n_classes,
    )
    ds.run_em(maxiter=30, epsilon=1e-10)
    pred_ds = ds.get_predictions()
    np.savetxt(ds_csv, pred_ds.astype(int), delimiter=",")
    print(f"~~ Ended DS in {time.perf_counter() - t0}")
    return ds


def run_DSrmSpam(ds, name, SEED, all_ans, nw, n_classes, path_res, method):
    """Refit Dawid-Skene after dropping the workers flagged as spammers.

    Runs only when ``"DSrmSpam"`` is in ``method`` and
    ``<path_res>/<name>_DSrmSpam_seed_<SEED>.csv`` does not exist yet.

    Steps:

    1. ``SpamRm(ds.pi, nw, n_classes)`` scores the workers and returns the
       spammer ids via ``get_spammers()``.
    2. Answers from those workers are removed and the remaining worker ids
       are shifted down so they stay contiguous.
    3. A fresh ``Dawid_Skene`` model is fitted on the filtered answers and its
       predictions are saved as integers to the CSV above.

    Parameters
    ----------
    ds :
        Fitted ``Dawid_Skene`` model providing ``pi``. If ``None`` a model is
        fitted here on ``all_ans`` so the step does not fail.
    name, SEED :
        Used to build the output file name.
    all_ans : dict
        ``task -> {worker: answer}``; worker keys must be convertible to int.
    nw, n_classes :
        Number of workers / classes before spam removal.
    path_res :
        Directory holding the cached CSV file.
    method :
        Container of requested strategy names.
    """
    out_csv = os.path.join(path_res, f"{name}_DSrmSpam_seed_{SEED}.csv")
    if "DSrmSpam" not in method or os.path.exists(out_csv):
        return

    print("###### Running DS+rm spam")
    t0 = time.perf_counter()

    if ds is None:
        # No fitted model was handed over (e.g. the plain DS step was
        # skipped): fit one with the same EM settings as the plain DS run.
        ds = Dawid_Skene(
            answers=all_ans,
            n_workers=nw,
            n_classes=n_classes,
        )
        ds.run_em(maxiter=30, epsilon=1e-10)

    sp_rm = SpamRm(ds.pi, nw, n_classes)
    sp_rm.spam_score()
    who_spam = sp_rm.get_spammers()
    # Normalise to an array so the comparisons below also work if a plain
    # list of ids is returned.
    spam_ids = np.asarray(who_spam)

    ans_wo_spam = {}
    for task, votes in all_ans.items():
        kept = {}
        for worker, ans in votes.items():
            worker_id = int(worker)
            if worker_id in spam_ids:
                continue
            # Shift the id down by the number of removed workers with a
            # smaller id, keeping ids in [0, nw - len(who_spam)).
            n_removed_before = len(np.where(worker_id > spam_ids)[0])
            kept[str(worker_id - n_removed_before)] = ans
        ans_wo_spam[str(task)] = kept

    ds = Dawid_Skene(
        answers=ans_wo_spam,
        n_workers=nw - len(who_spam),
        n_classes=n_classes,
    )
    ds.run_em(maxiter=30, epsilon=1e-7)
    pred_ds = ds.get_predictions()
    np.savetxt(out_csv, pred_ds.astype(int), delimiter=",")
    print(f"~~ Ended DS+rm spam in {time.perf_counter() - t0}")


def run_MV(path_res, name, SEED, all_ans, method):
    """Run the ``MV`` aggregator (presumably majority vote) and cache it.

    Does nothing when ``"MV"`` is not in ``method`` or when
    ``<path_res>/<name>_MV_seed_<SEED>.csv`` already exists. Otherwise the
    answers returned by ``MV(answers=all_ans).get_answers()`` are cast to int
    and written to that CSV file.
    """
    out_csv = os.path.join(path_res, f"{name}_MV_seed_{SEED}.csv")
    if "MV" not in method or os.path.exists(out_csv):
        return

    print("###### Running MV")
    started = time.perf_counter()
    aggregator = MV(answers=all_ans)
    votes = np.array(aggregator.get_answers()).astype(int)
    np.savetxt(out_csv, votes, delimiter=",")
    print(f"~~ Ended MV in {time.perf_counter() - started}")


def run_soft(path_res, name, SEED, all_ans, method, n_classes):
    """Run the ``Soft`` aggregator and cache its output as CSV.

    Does nothing when ``"soft"`` is not in ``method`` or when
    ``<path_res>/<name>_soft_seed_<SEED>.csv`` already exists. Otherwise the
    result of ``Soft(answers=all_ans, n_classes=n_classes).get_answers()`` is
    written unchanged to that file.
    """
    out_csv = os.path.join(path_res, f"{name}_soft_seed_{SEED}.csv")
    if "soft" not in method or os.path.exists(out_csv):
        return

    print("###### Running Soft")
    started = time.perf_counter()
    aggregator = Soft(answers=all_ans, n_classes=n_classes)
    np.savetxt(out_csv, aggregator.get_answers(), delimiter=",")
    print(f"~~ Ended Soft in {time.perf_counter() - started}")


def run_GLAD(path_res, name, SEED, all_ans, method, n_classes, nw):
    """Run the ``GLAD`` model and cache the probabilities it returns.

    Does nothing when ``"GLAD"`` is not in ``method`` or when
    ``<path_res>/<name>_GLAD_seed_<SEED>.csv`` already exists. Otherwise a
    ``GLAD`` model is built with ``n_iter=10`` and one task per entry of
    ``all_ans``, run, and the output of ``get_probas()`` is written to that
    CSV file.
    """
    out_csv = os.path.join(path_res, f"{name}_GLAD_seed_{SEED}.csv")
    if "GLAD" not in method or os.path.exists(out_csv):
        return

    print("###### Running GLAD")
    started = time.perf_counter()
    model = GLAD(
        n_classes=n_classes,
        n_workers=nw,
        n_task=len(all_ans),
        answers=all_ans,
        n_iter=10,
    )
    model.run()
    np.savetxt(out_csv, model.get_probas(), delimiter=",")
    print(f"~~ Ended GLAD in {time.perf_counter() - started}")


def get_aggregated_votes(nw=30, SEED=42, method=None):
    """Simulate the datasets and run the requested aggregation strategies.

    For each villain level (0.2 and 0.9) a dataset is created with ``setup``
    and then every strategy runner is called; each runner decides by itself
    whether it is requested in ``method`` and whether its output is already
    cached on disk.

    Parameters
    ----------
    nw : int
        Number of workers passed on to the dataset setup and the models.
    SEED : int
        Seed given to ``set_seed`` and embedded in all output file names.
    method : container of str, optional
        Strategy names to run; the runners look for ``"DS"``, ``"DSrmSpam"``,
        ``"MV"``, ``"soft"`` and ``"GLAD"``. Defaults to ``["MV"]``.

    Returns
    -------
    tuple
        ``(trains, vals, tests, names, SEED)`` where the first four are lists
        with one entry per villain level.
    """
    # Use a None sentinel rather than a shared mutable default list.
    if method is None:
        method = ["MV"]

    set_seed(SEED)
    n_classes = 8
    names = []
    trains, vals, tests = [], [], []
    for villain in (0.2, 0.9):
        name, all_ans, path_res, _train = setup(
            nw,
            villain,
            names,
            trains,
            vals,
            tests,
            SEED,
        )
        # DS first: the spam-removal step reuses the fitted DS model.
        ds = run_DS(name, SEED, all_ans, nw, n_classes, path_res, method)
        run_DSrmSpam(ds, name, SEED, all_ans, nw, n_classes, path_res, method)
        run_MV(path_res, name, SEED, all_ans, method)
        run_soft(path_res, name, SEED, all_ans, method, n_classes)
        run_GLAD(path_res, name, SEED, all_ans, method, n_classes, nw)

    return trains, vals, tests, names, SEED
