# %%
from peerannot.models.MV import MV
from peerannot.models.Soft import Soft
from peerannot.models.DS import Dawid_Skene as DS
from peerannot.models.GLAD import GLAD
from peerannot.models.WAUM import WAUM
import numpy as np
from torchvision.transforms import ToTensor
import torch

import os
import json

# %% import data
# Resolve the data file relative to this script rather than the current
# working directory. NOTE: relies on ``__file__`` being defined (i.e. run as a
# script or from an IDE that sets it).
current_dir = os.path.dirname(__file__)
# JSON is UTF-8 by specification; pass the encoding explicitly so loading does
# not depend on the platform's locale default (e.g. cp1252 on Windows).
with open(
    os.path.join(current_dir, "data", "toy_data.json"), encoding="utf-8"
) as data_file:
    toy_data = json.load(data_file)

print(toy_data)
# One scalar feature per task, stored as a (5, 1) column vector.
# NOTE(review): presumably row i describes task i of ``toy_data`` — keep the
# two in sync if the data file changes.
tasks = np.array([-2, 0.1, 2, 2.1, -1]).reshape(-1, 1)
print(tasks)

# %% Shared settings
# Number of label classes in the toy problem, used by every model below.
n_classes = 2
# Convergence settings shared by the EM-based aggregators (DS and GLAD).
em_epsilon = 1e-5
em_maxiter = 30

# %% Majority vote

mv = MV(toy_data)
print(mv.get_answers())

# %% Soft labelling

soft = Soft(toy_data, n_classes)
print(soft.get_answers())

# %% DS model

ds = DS(toy_data, n_classes)
ds.run_em(epsilon=em_epsilon, maxiter=em_maxiter)
print(ds.get_answers())
# NOTE(review): presumably the fitted per-worker confusion matrices — verify.
print(ds.pi)

# %% GLAD

glad = GLAD(toy_data, n_classes)
glad.run_em(epsilon=em_epsilon, maxiter=em_maxiter)
print(glad.get_answers())
# NOTE(review): presumably the fitted per-worker ability parameters — verify.
print(glad.alpha)

# %% WAUM

# One input feature per task (``tasks`` is reshaped to a single column).
input_dim = 1
# The classifier outputs one logit per label class.
output_dim = n_classes


class LogisticRegression(torch.nn.Module):
    """Single linear layer mapping ``input_dim`` features to ``output_dim`` logits.

    No activation is applied in ``forward``: raw logits are returned, which
    is the input form expected by ``torch.nn.CrossEntropyLoss``.

    Args:
        input_dim: number of features per sample.
        output_dim: number of output logits (one per class).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return logits of shape ``(N, output_dim)``.

        ``x`` is first flattened to ``(N, input_dim)``. Extra singleton
        dimensions, such as per-item ``(1, 1, 1)`` image-like tensors, are
        therefore accepted.
        """
        # Use the layer's own input width instead of a hard-coded 1 so the
        # model also works when input_dim > 1 (identical for input_dim == 1).
        outputs = self.linear(x.reshape(-1, self.linear.in_features))
        return outputs


class Toy_dataset(torch.utils.data.Dataset):
    """Map-style dataset over an array of per-task features.

    Indexing yields the 4-tuple ``(image, "lab", label, idx)``, where
    ``image`` is the task's row reshaped to ``(-1, 1, 1)``.

    Args:
        tasks: array indexed by task; ``tasks.shape[0]`` is the dataset size.
        truth: labels, indexed the same way as ``tasks``.
        transform: optional callable applied to the reshaped row. Its result
            is cast with ``.type(torch.FloatTensor)``.
        target_transform: optional callable applied to the label.
    """

    def __init__(self, tasks, truth, transform=None, target_transform=None):
        self.tasks, self.truth = tasks, truth
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        n_tasks = self.tasks.shape[0]
        return n_tasks

    def __getitem__(self, idx):
        sample = self.tasks[idx].reshape(-1, 1, 1)
        target = self.truth[idx]
        transform, target_transform = self.transform, self.target_transform
        # Optional transforms are skipped when falsy (e.g. None).
        if transform:
            sample = transform(sample).type(torch.FloatTensor)
        if target_transform:
            target = target_transform(target)
        # NOTE(review): the constant "lab" fills the second slot, presumably to
        # match the item layout WAUM expects — confirm before changing.
        return sample, "lab", target, idx


model = LogisticRegression(input_dim, output_dim)
waum = WAUM(
    # Placeholder labels (all 0), sized from ``tasks`` so the dataset stays
    # aligned if tasks are added or removed.
    Toy_dataset(tasks, [0] * len(tasks), transform=ToTensor()),
    toy_data,
    # Number of classes; tied to the model's output width so the two agree.
    output_dim,
    model,
    torch.nn.CrossEntropyLoss(),
    torch.optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
    ),
    # NOTE(review): presumably the number of training epochs — confirm
    # against WAUM's signature.
    5,
)
waum.run(alpha=0.01)
print(waum.get_answers())
print(waum.too_hard)
print(waum.get_probas())

# %%
