# %%
from peerannot.models.MV import MV
from peerannot.models.DS import Dawid_Skene as DS
from peerannot.models.GLAD import GLAD
from peerannot.models.WAUM import WAUM
from peerannot.models.DS_mc import DS_mc
from peerannot.models.GLAD_mc import GLAD_mc
from peerannot.models.ZenCrowd import ZC
from peerannot.models.DS_clust import Dawid_Skene_clust as DSWC
import numpy as np
from torchvision.transforms import ToTensor
import torch
import pymc3 as pm
import os
import json
import pandas as pd

# %% import data

# Class count declared for each dataset name that `which` may select.
datasets = {
    # Datasets declared with two classes.
    **dict.fromkeys(
        (
            "BlueBirds",
            "AmazonSentimentBook",
            "AmazonSentimentNegative",
            "LonelinessSlrTech",
            "LonelinessSlrIntervention",
            "LonelinessSlrOlder",
            "HITspamCrowdflower",
            "HITspamMTurk",
            "RTE",
            "SentimentPopularity",
            "TemporalOrdering",
        ),
        2,
    ),
    # Datasets declared with more than two classes.
    "WebRelevance2010": 3,
    "AdultContent2": 5,
    "AdultContent3": 4,
    "WeatherSentiment": 5,
}

# Name of the dataset to load; it is also used later as a key of `datasets`.
which = "RTE"

# The CSV files are looked up in a "data" folder next to this script.
# `__file__` may be undefined when the cells are executed in an interactive
# kernel, so fall back to the current working directory in that case.
try:
    current_dir = os.path.dirname(__file__)
except NameError:
    current_dir = os.getcwd()
data = pd.read_csv(os.path.join(current_dir, "data", f"{which}.csv"))


# %%

# Re-encode raw labels, worker ids and task ids as contiguous integers
# (0, 1, 2, ...) following the sorted order returned by np.unique.

# Labels: the codes are derived from the observed responses and the same
# mapping is applied to the gold labels.
labels = np.unique(data["response"])
converter_lab = dict(zip(labels, range(len(labels))))
data["response"] = data["response"].map(converter_lab)
data["goldLabel"] = data["goldLabel"].map(converter_lab)

# Workers.
workers = np.unique(data["workerID"])
converter_worker = dict(zip(workers, range(len(workers))))
data["workerID"] = data["workerID"].map(converter_worker)

# Tasks.
tasks = np.unique(data["taskID"])
converter_task = dict(zip(tasks, range(len(tasks))))
data["taskID"] = data["taskID"].map(converter_task)

# Build the structures consumed by the aggregation models and the evaluation:
#   answers[task][worker] -> encoded label given by `worker` to `task`
#   z_true[task]          -> encoded gold label (left at 0.0 when none is found)
n_tasks = len(tasks)
z_true = np.zeros(n_tasks)
# Tasks are created in ascending id order, so `answers` is sorted by task id
# by construction.
answers = {task: {} for task in range(n_tasks)}
# Ids of the tasks for which a gold label has been recorded.
tasks_with_gold = set()

# Select the columns by name instead of relying on their position in the CSV.
columns = ["workerID", "taskID", "response", "goldLabel"]
rows = data[columns].itertuples(index=False, name=None)
for worker, task, ans, gold in rows:
    task, worker, ans = int(task), int(worker), int(ans)
    # Keep the first non-missing gold label encountered for the task.
    if task not in tasks_with_gold and pd.notna(gold):
        z_true[task] = int(gold)
        tasks_with_gold.add(task)
    answers[task][worker] = ans

# `z_true` always holds one slot per task, so its length says nothing about
# the presence of ground truth: check whether any gold label was found.
# NOTE: tasks without a gold label keep z_true == 0 and are still part of
# every accuracy computed against `z_true`.
compute_acc = bool(tasks_with_gold)
if not compute_acc:
    print(f"Dataset {which} does not contain ground truth")

# %%
# Aggregate the crowd answers with the MV model.
mv = MV(answers=answers)
y_mv = mv.get_answers()
if compute_acc:
    # Accuracy = mean of the element-wise comparison with the gold labels.
    matches_mv = y_mv == z_true
    acc_mv = np.mean(matches_mv)
    print(f"#### MV : {acc_mv:.3f}")
# %%

# Aggregate the crowd answers with the DS model, fitted through run_em
# (maxiter=100, epsilon=1e-6).
ds = DS(answers=answers, n_classes=datasets[which])
ds.run_em(maxiter=100, epsilon=1e-6)
y_ds = ds.get_answers()
if compute_acc:
    # Accuracy = mean of the element-wise comparison with the gold labels.
    matches_ds = y_ds == z_true
    acc_ds = np.mean(matches_ds)
    print(f"#### DS : {acc_ds:.3f}")

# %%
# Aggregate the crowd answers with the DS_mc model (default run settings).
ds_mc = DS_mc(answers, datasets[which])
ds_mc.run()
y_ds_mc = ds_mc.get_answers()
if compute_acc:
    # Accuracy = mean of the element-wise comparison with the gold labels.
    matches_ds_mc = y_ds_mc == z_true
    acc_ds_mc = np.mean(matches_ds_mc)
    print(f"#### DS MC: {acc_ds_mc:.3f}")

# %%

# Aggregate the crowd answers with the GLAD model, fitted through run_em
# (epsilon=1e-10, maxiter=1000).
glad = GLAD(answers, datasets[which])
glad.run_em(epsilon=1e-10, maxiter=1000)
y_glad = glad.get_answers()
if compute_acc:
    # Accuracy = mean of the element-wise comparison with the gold labels.
    matches_glad = y_glad == z_true
    acc_glad = np.mean(matches_glad)
    print(f"#### GLAD: {acc_glad:.3f}")

# %%

# Aggregate the crowd answers with the GLAD_mc model (default run settings).
glad_mc = GLAD_mc(answers, datasets[which])
glad_mc.run()
y_glad_mc = glad_mc.get_answers()
if compute_acc:
    # Accuracy = mean of the element-wise comparison with the gold labels.
    acc_glad_mc = np.mean(y_glad_mc == z_true)
    # Labelled "GLAD MC" so the line is not confused with the GLAD result.
    print(f"#### GLAD MC: {acc_glad_mc:.3f}")

# %%
# Run the DSWC model once per value of its third constructor argument
# (printed as "L") and collect the aggregated labels of every run.
all_y_dswc = []
for l_value in (5,):
    dswc = DSWC(answers, datasets[which], l_value)
    dswc.run(epsilon=1e-6)
    y_dswc = dswc.get_answers()
    all_y_dswc.append(y_dswc)
    if compute_acc:
        # Accuracy = mean of the element-wise comparison with the gold labels.
        matches_dswc = y_dswc == z_true
        acc_dswc = np.mean(matches_dswc)
        print(f"#### DS + WC (L={l_value}): {acc_dswc:.3f}")
# %%

# Aggregate the crowd answers with the ZC model (maxiter=100).
zc = ZC(answers)
zc.run(maxiter=100)
y_zc = zc.get_answers()
if compute_acc:
    # Accuracy = mean of the element-wise comparison with the gold labels.
    matches_zc = y_zc == z_true
    acc_zc = np.mean(matches_zc)
    print(f"#### ZC: {acc_zc:.3f}")

# %%
