# %%
import json
import numpy as np
import sys
import os
from tqdm.notebook import tqdm

# Directory containing this script.
# NOTE(review): `current` is not referenced anywhere else in the visible code
# -- confirm whether it is still needed.
current = os.path.dirname(__file__)

from peerannot.models.GLAD import GLAD
from peerannot.models.DS import Dawid_Skene as DS


def generate_adverse(
    ratio_diff=0.5, p_random=0, p_bad=0.25, nw=50, nt=1000, seed=42
):
    """Simulate binary crowdsourced answers with uneven tasks and workers.

    Every task receives a true label in {0, 1} and a difficulty among
    "easy", "hard" and "random"; every worker is either "good" or "bad".
    Each task is answered by between 1 and ``min(10, nw)`` distinct workers.
    An answer is the true label, flipped with probability:

    * 0 for an easy task,
    * 0.25 (good worker) or 0.45 (bad worker) for a hard task,
    * 0.5 for a random task.

    Parameters
    ----------
    ratio_diff : float
        Ratio between easy and hard tasks: P(easy) = ratio_diff * P(hard).
    p_random : float
        Probability that a task is "random".
    p_bad : float
        Probability that a worker is "bad".
    nw : int
        Number of workers.
    nt : int
        Number of tasks.
    seed : int
        Seed given to ``np.random.default_rng``.

    Returns
    -------
    labels : numpy.ndarray
        True label (0 or 1) of each of the ``nt`` tasks.
    difficulty : numpy.ndarray
        Difficulty ("easy", "hard" or "random") of each task.
    quality : numpy.ndarray
        Quality ("good" or "bad") of each of the ``nw`` workers.
    answers : dict
        ``{str(task): {str(worker): answer}}`` with answers in {0, 1}.
    """
    rng = np.random.default_rng(seed)
    labels = rng.choice([0, 1], size=nt)
    # P(easy) + P(hard) + P(random) = 1 with P(easy) = ratio_diff * P(hard).
    p_hard = (1 - p_random) / (ratio_diff + 1)
    difficulty = rng.choice(
        ["easy", "hard", "random"],
        p=[ratio_diff * p_hard, p_hard, p_random],
        size=nt,
    )
    quality = rng.choice(["good", "bad"], size=nw, p=[1 - p_bad, p_bad])

    # Workers are drawn without replacement, so a task cannot collect more
    # answers than there are workers. For nw >= 10 this is the usual 1..10.
    max_votes = min(10, nw)

    answers = {}
    for i in range(nt):
        n_ans = rng.choice(range(1, max_votes + 1))
        who = rng.choice(range(nw), size=n_ans, replace=False)
        votes = {}
        for worker in who:
            if difficulty[i] == "easy":
                p_switch = 0
            elif difficulty[i] == "hard":
                p_switch = 0.45 if quality[worker] == "bad" else 0.25
            else:  # random
                p_switch = 0.5
            votes[str(worker)] = rng.choice(
                [labels[i], 1 - labels[i]], p=[1 - p_switch, p_switch]
            )
        answers[str(i)] = votes
    return labels, difficulty, quality, answers


nw = 50  # number of simulated workers
nt = 1000  # number of simulated tasks

# %%
# GLAD accuracy on the full simulated dataset ("without removal", accu_wo)
# versus on the dataset stripped of its "random" tasks ("with removal",
# accu_w).
accu_wo = []
accu_w = []
for rep in tqdm(range(1)):
    labels, difficulty, quality, answers = generate_adverse(
        2, 0.2, 0.1, nw, nt, seed=rep
    )
    # generate_adverse keys tasks with strings; re-key them with integers.
    answers = {int(key): val for key, val in answers.items()}

    # The simulation is binary (labels and answers take values in {0, 1}),
    # so the model is declared with 2 classes, as the Dawid-Skene runs of
    # this file already are.
    glad = GLAD(
        n_classes=2,
        answers=answers,
    )
    glad.run_em(maxiter=50, epsilon=1e-6)
    probas = glad.get_probas()
    accu_glad = (np.argmax(probas, axis=1) == labels).sum() / nt
    accu_wo.append(accu_glad)

    # Drop the "random" tasks and renumber the remaining ones from 0.
    # The difficulty is looked up by task id, so the filter does not depend
    # on the iteration order of the dictionary.
    kept_tasks = [task for task in answers if difficulty[task] != "random"]
    answers_2 = {
        new_id: dict(answers[task]) for new_id, task in enumerate(kept_tasks)
    }
    nt_2 = len(kept_tasks)
    labels_2 = labels[kept_tasks]  # aligned with the new task ids

    glad2 = GLAD(
        n_classes=2,
        answers=answers_2,
    )
    glad2.run_em(maxiter=50, epsilon=1e-6)
    probas = glad2.get_probas()
    accu_glad = (np.argmax(probas, axis=1) == labels_2).sum() / nt_2
    accu_w.append(accu_glad)

# Mean accuracy -/+ 1.96 standard deviations over the repetitions
# (with a single repetition the standard deviation is 0).
print(
    np.mean(accu_wo) - 1.96 * np.std(accu_wo),
    np.mean(accu_wo) + 1.96 * np.std(accu_wo),
)

print(
    np.mean(accu_w) - 1.96 * np.std(accu_w),
    np.mean(accu_w) + 1.96 * np.std(accu_w),
)

# %% impact on DS model

# Dawid-Skene accuracy on the full simulated dataset (accu_wo_ds) versus on
# the dataset stripped of its "random" tasks (accu_w_ds), over 40 seeds.
accu_wo_ds = []
accu_w_ds = []
for rep in tqdm(range(40)):
    # This cell uses its own *_ds names on purpose: the plotting cells further
    # down pair `glad.beta` with the module-level `difficulty` and `answers`,
    # so those must keep describing the simulation GLAD was fitted on.
    labels_ds, difficulty_ds, quality_ds, answers_ds = generate_adverse(
        0.25, 0.3, 0.1, nw, nt, seed=rep
    )
    # generate_adverse keys tasks with strings; re-key them with integers.
    answers_ds = {int(key): val for key, val in answers_ds.items()}

    ds = DS(
        answers=answers_ds,
        n_classes=2,
    )
    _ = ds.run_em(maxiter=50, epsilon=1e-6)
    pred_ds = ds.get_probas()
    pi_ds = ds.pi  # kept for the plotting cells (value of the last repetition)
    accu_wo_ds.append(np.mean(np.argmax(pred_ds, axis=1) == labels_ds))

    # Drop the "random" tasks and renumber the remaining ones from 0.
    kept_tasks_ds = [
        task for task in answers_ds if difficulty_ds[task] != "random"
    ]
    answers_ds_2 = {
        new_id: dict(answers_ds[task])
        for new_id, task in enumerate(kept_tasks_ds)
    }
    labels_ds_2 = labels_ds[kept_tasks_ds]  # aligned with the new task ids

    ds2 = DS(
        answers=answers_ds_2,
        n_classes=2,
    )
    _ = ds2.run_em(maxiter=50, epsilon=1e-6)
    pred_ds2 = ds2.get_probas()
    pi_ds2 = ds2.pi  # kept for the plotting cells (value of the last repetition)
    accu_w_ds.append(np.mean(np.argmax(pred_ds2, axis=1) == labels_ds_2))

# Mean accuracy -/+ 1.96 standard deviations over the repetitions.
print(
    np.mean(accu_wo_ds) - 1.96 * np.std(accu_wo_ds),
    np.mean(accu_wo_ds) + 1.96 * np.std(accu_wo_ds),
)

print(
    np.mean(accu_w_ds) - 1.96 * np.std(accu_w_ds),
    np.mean(accu_w_ds) + 1.96 * np.std(accu_w_ds),
)


# %%
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap

# %%
# Side-by-side heatmaps of entry `who` of the DS `pi` estimates:
# left = model fitted on all tasks, right = model fitted without the
# "random" tasks (both taken from the last repetition of the DS loop).
who = 2  # 2, 4
# NOTE(review): the colormap returned here is discarded, so the heatmaps use
# seaborn's default colormap -- confirm whether `cmap=` was intended.
sns.color_palette("Spectral", as_cmap=True)
fig, (ax1, ax2) = plt.subplots(1, 2)
heatmap_options = dict(annot=True, fmt=".3f", square=True, vmin=0, vmax=1)
for axis, matrix in ((ax1, pi_ds[who]), (ax2, pi_ds2[who])):
    sns.heatmap(matrix, ax=axis, **heatmap_options)
plt.tight_layout()
plt.show()
# %%
# Scatter plots comparing parameters estimated on all tasks (x axis) with
# those estimated after removing the "random" tasks (y axis). Each figure
# gets a dotted y = x reference line.
sns.set()
identity_line = dict(slope=1, color="black", linestyle="dotted")

# GLAD alpha estimates.
plt.figure()
plt.scatter(glad.alpha, glad2.alpha)
plt.axline((1, 1), **identity_line)
plt.ylim([0, 2])
plt.xlabel(r"$\hat\alpha_j$ without removal")
plt.ylabel(r"$\hat\alpha_j$ with removal")

# GLAD beta estimates, exponentiated. Only the non-random tasks exist in
# `glad2`, so the full fit is restricted to them.
# NOTE(review): `difficulty` must be the array of the simulation `glad` was
# fitted on, otherwise the two vectors do not match -- verify.
plt.figure()
non_random = np.nonzero(difficulty != "random")
beta_without_removal = np.exp(glad.beta[non_random])
beta_with_removal = np.exp(glad2.beta)
plt.scatter(beta_without_removal, beta_with_removal)
plt.axline((1, 1), **identity_line)
plt.xlabel(r"$\hat\beta_i$ without removal")
plt.ylabel(r"$\hat\beta_i$ with removal")

# Diagonal entries [0, 0] and [1, 1] of every matrix in the DS `pi`.
plt.figure()
for k, legend_label in (
    (0, r"$\hat\pi^{(j)}_{00}$"),
    (1, r"$\hat\pi^{(j)}_{11}$"),
):
    without_removal = [pi[k, k] for pi in ds.pi]
    with_removal = [pi2[k, k] for pi2 in ds2.pi]
    plt.scatter(without_removal, with_removal, label=legend_label)
plt.axline((1, 1), **identity_line)
plt.legend()
plt.xlabel(r"Without removal")
plt.ylabel(r"With removal")
plt.show()

# %%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# GLAD `beta` estimates against the simulated task difficulty: one violin
# plot per difficulty level, then a scatter of the sorted values.
# NOTE(review): `difficulty` and `answers` are assumed to belong to the same
# simulation `glad` was fitted on -- verify.
sns.set_style("ticks")
log_beta_label = r"$\log(\hat\beta)$"

# One row per task, ordered by increasing beta; the index keeps the task id.
df = pd.DataFrame({"beta": glad.beta, "difficulty": difficulty}).sort_values(
    by="beta"
)

# Flag the tasks that received exactly one answer.
answ_1 = {task: votes for task, votes in answers.items() if len(votes) == 1}
df["one vote"] = False
df.loc[[int(task) for task in answ_1], "one vote"] = True

plt.figure()
sns.violinplot(data=df, x="difficulty", y="beta")
plt.ylabel(log_beta_label)
sns.despine()
plt.savefig("distrib_simu_beta_glad.pdf")
plt.show()

plt.figure()
sns.scatterplot(
    data=df,
    x=range(len(df)),
    y=df["beta"],
    hue="difficulty",
    size="one vote",
    sizes=(400, 40),
    linewidth=0,
)
plt.ylabel(log_beta_label)
plt.xlabel("Sorted task index")
plt.tight_layout()
# plt.title(fr"$p_{{hard}}=${(1-0.1)/(2+1):.3f}")
plt.savefig("beta_glad_scatter.pdf")
plt.show()

# %%
# Kernel density estimate of the GLAD `beta` values, saved as a PDF.
plt.figure()
beta_axes = sns.kdeplot(glad.beta)
beta_axes.set_title(r"Density $\log(\hat\beta)$")
plt.savefig("glad_density_beta.pdf")


# %%
# Kernel density estimate of the GLAD `alpha` values, saved as a PDF.
plt.figure()
alpha_axes = sns.kdeplot(glad.alpha)
alpha_axes.set_title(r"Density $\alpha$")
plt.savefig("glad_density_alpha.pdf")

# %%
