# %%
from peerannot.models.MV import MV
from peerannot.models.DS import Dawid_Skene as DS
from peerannot.models.GLAD import GLAD
from peerannot.models.WAUM import WAUM
from peerannot.models.DS_mc import DS_mc
from peerannot.models.GLAD_mc import GLAD_mc
from peerannot.models.ZenCrowd import ZC
from peerannot.models.DS_clust import Dawid_Skene_clust as DSWC
import numpy as np
from torchvision.transforms import ToTensor
import torch
import pymc3 as pm
import os
import json

# %% import data
# Locate the dataset directory next to this script. ``__file__`` is not
# defined when the ``# %%`` cells are executed interactively, so fall back
# to the current working directory in that case.
try:
    current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    current_dir = os.getcwd()
data_dir = os.path.join(current_dir, "data")

with open(
    os.path.join(data_dir, "extrahard_MC_500_5_4.json"), encoding="utf-8"
) as data_file:
    data = json.load(data_file)

# Print a short summary instead of dumping the whole answer dictionary.
print(f"Loaded answers for {len(data)} tasks")
# JSON object keys are always strings; convert the top-level task ids to int
# (presumably what the peerannot models expect — confirm against the models).
data = {int(key): val for key, val in data.items()}

# Reference (ground-truth) classes: prefer a copy stored alongside the answers
# file, and fall back to the pymc3 bundled example-data lookup otherwise.
reference_name = "extrahard_MC_500_5_4_reference_classes.npy"
reference_path = os.path.join(data_dir, reference_name)
if os.path.exists(reference_path):
    z_true = np.load(reference_path)
else:
    z_true = np.load(pm.get_data(reference_name))

# %%
# MV baseline (presumably majority voting): aggregate the answers and score
# the aggregated labels against the reference classes.
mv = MV(answers=data)
y_mv = mv.get_answers()
acc_mv = np.mean(y_mv == z_true)  # share of entries equal to the reference
print("#### MV : {:.3f}".format(acc_mv))
# %%

# Dawid-Skene fitted with EM over 4 classes; maxiter / epsilon presumably cap
# the iteration count and set the convergence tolerance.
ds = DS(n_classes=4, answers=data)
ds.run_em(epsilon=1e-6, maxiter=100)
y_ds = ds.get_answers()
acc_ds = np.mean(y_ds == z_true)  # share of entries equal to the reference
print("#### DS : {:.3f}".format(acc_ds))


# %%
# DS_mc: a Dawid-Skene variant run with its default settings; the second
# positional argument is presumably the number of classes (cf. n_classes=4).
ds_mc = DS_mc(data, 4)
ds_mc.run()
y_ds_mc = ds_mc.get_answers()
acc_ds_mc = np.mean(y_ds_mc == z_true)  # share of entries equal to the reference
print("#### DS MC: {:.3f}".format(acc_ds_mc))

# %%

# GLAD fitted with EM; the second positional argument is presumably the number
# of classes, and maxiter / epsilon presumably bound the EM loop.
glad = GLAD(data, 4)
glad.run_em(maxiter=100, epsilon=1e-6)
y_glad = glad.get_answers()
acc_glad = np.mean(y_glad == z_true)  # share of entries equal to the reference
print("#### GLAD: {:.3f}".format(acc_glad))

# %%

# GLAD_mc: a GLAD variant run with its default settings; the second positional
# argument is presumably the number of classes (cf. n_classes=4 for DS).
glad_mc = GLAD_mc(data, 4)
glad_mc.run()
y_glad_mc = glad_mc.get_answers()
acc_glad_mc = np.mean(y_glad_mc == z_true)  # share of entries equal to the reference
# Distinct label so this result is not confused with the EM-based GLAD above.
print(f"#### GLAD MC: {acc_glad_mc:.3f}")

# %%

# Dawid-Skene with worker clustering, evaluated for two values of its third
# argument (printed as "L"; presumably the number of worker clusters).
for n_clust in [1, 2]:
    dswc = DSWC(data, 4, n_clust)
    dswc.run(epsilon=1e-6)
    y_dswc = dswc.get_answers()
    acc_dswc = np.mean(y_dswc == z_true)  # share of entries equal to the reference
    print(f"#### DS + WC (L={n_clust}): {acc_dswc:.3f}")
# %%

# ZenCrowd aggregation, capped at 100 iterations via maxiter.
zc = ZC(data)
zc.run(maxiter=100)
y_zc = zc.get_answers()
acc_zc = np.mean(y_zc == z_true)  # share of entries equal to the reference
print("#### ZC: {:.3f}".format(acc_zc))

# %%
