# %%
import numpy as np
import torch
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from peerannot.models.Soft import Soft
from peerannot.models.GLAD import GLAD
from peerannot.models.WAUM import WAUM
import pandas as pd
from sklearn.datasets import (
    make_circles,
    make_classification,
    make_moons,
)
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.svm import SVC, LinearSVC
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.cm as cm
from sklearn.datasets import load_wine
import seaborn as sns

# Matplotlib styling shared by every figure below: large axis titles/labels
# and LaTeX text rendering (needs a working LaTeX installation).
plt.rcParams.update(
    {
        "axes.titlesize": "xx-large",
        "axes.labelsize": "xx-large",
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{amsmath,amsfonts,amsthm,amssymb}",
    }
)

# Global random state: `rng` drives the vote simulation in get_votes, and
# torch.manual_seed seeds torch's global RNG (e.g. network initialisation).
rng = np.random.default_rng(1)
torch.manual_seed(1)

# %%
class Toy_dataset(torch.utils.data.Dataset):
    """Map-style dataset over feature rows ``tasks`` with labels ``truth``.

    Each row is reshaped to ``(n_features, 1, 1)`` so that an image transform
    such as ``ToTensor`` can be applied. Items are returned as
    ``(image, "lab", label, idx)``. The constant ``"lab"`` string is a
    placeholder in the second slot, presumably kept for the consumer's
    expected tuple layout (confirm).
    """

    def __init__(self, tasks, truth, transform=None, target_transform=None):
        self.tasks = tasks
        self.truth = truth
        # Optional callables applied to the reshaped row / to the label.
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.tasks.shape[0]

    def __getitem__(self, idx):
        sample = self.tasks[idx].reshape(-1, 1, 1)
        if self.transform:
            # Cast to float32 whatever dtype the transform produced.
            sample = self.transform(sample).type(torch.FloatTensor)
        target = self.truth[idx]
        if self.target_transform:
            target = self.target_transform(target)
        return sample, "lab", target, idx


class model(nn.Module):
    """Small MLP: ``input_dim -> 30 -> 20 -> output_dim`` with ReLU activations.

    Inputs of any shape are flattened to ``(-1, input_dim)`` before the first
    layer; the output is the raw (unnormalised) logits.
    """

    def __init__(self, input_dim=2, output_dim=2):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        # Submodules are registered in the same order as before so parameter
        # initialisation consumes torch's RNG identically.
        self.linear = torch.nn.Linear(input_dim, 30)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(30, 20)
        self.fc2 = nn.Linear(20, output_dim)

    def forward(self, x):
        flat = x.reshape(-1, self.input_dim)
        hidden = self.relu(self.fc1(self.relu(self.linear(flat))))
        return self.fc2(hidden)


def get_votes(
    workers,
    X,
    y,
    test_size=0.3,
    all_workers=True,
    train_transform=ToTensor(),
    test_transform=ToTensor(),
):
    """Simulate crowdsourced votes from a pool of sklearn-style classifiers.

    The data are split with a fixed ``random_state=42``. Every worker is
    fitted on the training split and its test accuracy is printed. Each
    training task then receives answers from a random subset of workers,
    drawn with the module-level ``rng``.

    Args:
        workers: objects exposing ``fit``, ``score`` and ``predict``.
        X, y: features and labels to split.
        test_size: fraction of the data held out for testing.
        all_workers: ``True`` to let every worker answer every task;
            otherwise an int ``m``, and each task gets a uniformly drawn
            number of answers in ``1..m``. ``m`` must not exceed
            ``len(workers)`` because workers are drawn without replacement.
        train_transform, test_transform: transforms handed to the training
            and test ``Toy_dataset`` respectively.

    Returns:
        ``(votes, X_train_bg, X_test, y_train_bg, y_test, toy_ds_train,
        toy_ds_test)``. ``votes[i]`` maps a worker index (as ``str``) to
        that worker's predicted label for training task ``i``.
    """
    X_train_bg, X_test, y_train_bg, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42
    )
    # Honour the caller's train_transform (previously a hard-coded ToTensor()).
    toy_ds_train = Toy_dataset(X_train_bg, y_train_bg, transform=train_transform)
    toy_ds_test = Toy_dataset(X_test, y_test, transform=test_transform)

    for worker in workers:
        worker.fit(X_train_bg, y_train_bg)
        score = worker.score(X_test, y_test)
        print(worker, f"{score:.3f}")

    # One vectorised predict per worker instead of one call per (task, worker).
    predictions = [worker.predict(X_train_bg) for worker in workers]

    n_tasks = X_train_bg.shape[0]
    n_workers = len(workers)
    votes = {i: {} for i in range(n_tasks)}
    for i in range(n_tasks):
        # rng is consumed in the same order as before, so votes stay reproducible.
        n_ans = (
            n_workers
            if all_workers is True
            else rng.choice(range(1, all_workers + 1))
        )
        who = rng.choice(range(n_workers), size=n_ans, replace=False)
        for w in who:
            votes[i][str(w)] = predictions[w][i]

    # The split is deterministic (random_state=42); re-splitting as before
    # would only recompute these same arrays.
    return (
        votes,
        X_train_bg,
        X_test,
        y_train_bg,
        y_test,
        toy_ds_train,
        toy_ds_test,
    )


def get_workers():
    """Return the pool of simulated annotators as unfitted sklearn classifiers.

    The pool holds three models, all seeded with ``random_state=1``:
    a ``LinearSVC``, a default-kernel ``SVC`` capped at ``max_iter=1``, and a
    5-stage ``GradientBoostingClassifier``. Candidates tried earlier and
    left out: LogisticRegression(max_iter=1), KNeighborsClassifier(n_neighbors=3),
    SVC(kernel="poly").
    """
    linear_svm = LinearSVC(random_state=1)
    capped_svc = SVC(random_state=1, max_iter=1)
    boosting = GradientBoostingClassifier(n_estimators=5, random_state=1)
    return [linear_svm, capped_svc, boosting]


# %%
def get_waum(
    toydataset,
    votes,
    y_train_bg,
    alpha=0.1,
    n_iter=30,
    lr=0.1,
    input_dim=2,
    output_dim=2,
):
    """Train a fresh ``model`` through peerannot's WAUM and print its accuracy.

    A new MLP is built and passed, with a cross-entropy loss, plain SGD at
    learning rate ``lr`` and ``n_iter``, to ``WAUM`` on CPU. Then
    ``run(alpha=alpha)`` is called. The printed accuracy compares the argmax
    of ``get_probas()`` with ``y_train_bg`` after deleting the rows listed in
    ``waum.too_hard``.

    Returns:
        ``(waum, y_waum)``: the WAUM object and its ``get_probas()`` output.
    """
    network = model(input_dim, output_dim)
    sgd = torch.optim.SGD(network.parameters(), lr=lr)
    waum = WAUM(
        toydataset,
        votes,
        output_dim,
        network,
        torch.nn.CrossEntropyLoss(),
        sgd,
        n_iter,
        DEVICE="cpu",
    )
    waum.run(alpha=alpha)
    y_waum = waum.get_probas()

    # The comparison assumes get_probas() omits the tasks in waum.too_hard,
    # so those rows are removed from the ground truth too (confirm).
    kept_truth = np.delete(y_train_bg, waum.too_hard, axis=0)
    print("WAUM", np.mean(np.argmax(y_waum, axis=1) == kept_truth))
    return waum, y_waum


def get_comparison(votes, y_train_bg, output_dim=2):
    """Aggregate ``votes`` with peerannot's Soft and GLAD strategies.

    For each strategy, prints its accuracy: the argmax of the returned
    probabilities compared with ``y_train_bg``. GLAD is fitted with
    ``run_em(epsilon=1e-6, maxiter=100)``.

    Returns:
        ``(y_soft, soft, y_glad, glad)``.
    """

    def _accuracy(probas):
        return np.mean(np.argmax(probas, axis=1) == y_train_bg)

    soft = Soft(answers=votes, n_classes=output_dim)
    y_soft = soft.get_probas()
    print("Soft", _accuracy(y_soft))

    glad = GLAD(n_classes=output_dim, answers=votes)
    glad.run_em(epsilon=1e-6, maxiter=100)
    y_glad = glad.get_probas()
    print("GLAD", _accuracy(y_glad))

    return y_soft, soft, y_glad, glad


# %%
# Three-class toy problem built from two make_circles datasets (375 points
# each: 125 + 250). The two calls differ only in `factor`, and both contribute
# points labelled 0.
output_dim = 3
X_small, y_small = make_circles(
    n_samples=(125, 250), random_state=0, noise=0.1, factor=0.3
)
X_large, y_large = make_circles(
    n_samples=(125, 250), random_state=0, noise=0.1, factor=0.7
)
# Relabel class 1 of the second dataset as class 2 so the labels are {0, 1, 2}.
y_large[y_large == 1] = 2
X, y = np.vstack((X_small, X_large)), np.hstack((y_small, y_large))

# Fit the simulated workers and collect one vote per (task, worker).
# NOTE: `workers` is rebound to an array of worker ids in a later cell.
workers = get_workers()
(
    votes,
    X_train_bg,
    X_test,
    y_train_bg,
    y_test,
    toy_ds_train,
    toy_ds_test,
) = get_votes(
    workers,
    X,
    y,
    test_size=0.3,
    all_workers=True,
    train_transform=ToTensor(),
    test_transform=ToTensor(),
)
# Baseline aggregations (Soft, GLAD), then the WAUM run whose `waum` object
# (e.g. waum.ds1.pi, waum.waum) is reused by the cells below.
y_soft, soft, y_glad, glad = get_comparison(votes, y_train_bg, output_dim)
waum, y_waum = get_waum(
    toy_ds_train,
    votes,
    y_train_bg,
    alpha=0.1,
    n_iter=100,
    lr=0.1,
    input_dim=2,
    output_dim=output_dim,
)

# %%


class Toy_dataset_red(torch.utils.data.Dataset):
    """Dataset with one item per individual worker answer.

    Row ``idx`` pairs the features ``tasks[idx]`` (reshaped to
    ``(n_features, 1, 1)``) with the answer ``ans[idx]``, the worker id
    ``workers[idx]`` and ``true_idx[idx]``, the index of the originating
    task. Items are ``(image, label, worker, true_idx, idx)``.
    """

    def __init__(
        self,
        tasks,
        ans,
        workers,
        true_idx,
        transform=None,
        target_transform=None,
    ):
        self.tasks = tasks
        self.ans = ans
        self.workers = workers
        self.true_idx = true_idx
        # Optional callables applied to the reshaped row / to the answer.
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.tasks.shape[0]

    def __getitem__(self, idx):
        sample = self.tasks[idx].reshape(-1, 1, 1)
        if self.transform:
            # Cast to float32 whatever dtype the transform produced.
            sample = self.transform(sample).type(torch.FloatTensor)
        answer = self.ans[idx]
        if self.target_transform:
            answer = self.target_transform(answer)
        return sample, answer, self.workers[idx], self.true_idx[idx], idx


# Flatten the per-task vote dictionaries into one row per (task, worker)
# answer, so the network can be trained directly on individual worker labels.
rows, answers, worker_ids, task_ids = [], [], [], []
for i in range(X_train_bg.shape[0]):
    task = votes[i]
    for j, ans in task.items():
        rows.append(X_train_bg[i])
        answers.append(ans)
        worker_ids.append(int(j))
        task_ids.append(i)
# Build each array once instead of np.vstack/np.hstack inside the loop
# (quadratic copying). float64 matches the dtypes the incremental build
# produced. reshape keeps a 2-D result even with no answers and works for
# any number of features.
Xred = np.asarray(rows, dtype=np.float64).reshape(-1, X_train_bg.shape[1])
yred = np.asarray(answers, dtype=np.float64)
# NOTE: rebinds `workers` (previously the classifier list) to worker ids.
workers = np.asarray(worker_ids, dtype=np.float64)
true_idx = np.asarray(task_ids, dtype=np.float64)
dataset = Toy_dataset_red(Xred, yred, workers, true_idx, transform=ToTensor())
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=128)

# %%


def get_psuccess(probas, pij):
    """Return ``probas @ np.diag(pij)``, computed without autograd tracking.

    For a 2-D ``pij``, ``np.diag`` extracts its diagonal, so with a 1-D
    ``probas`` this is ``sum_k probas[k] * pij[k, k]``.
    """
    diagonal = np.diag(pij)
    with torch.no_grad():
        success = probas @ diagonal
    return success


# %%

# Train a fresh network on the individual (task, worker) answers. For every
# answer and epoch, record the quantities needed to recompute the WAUM.
simplenet = model(X_train_bg.shape[1], output_dim)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(simplenet.parameters(), lr=0.1)
# NOTE(review): with 100 epochs the milestone at 500 is never reached, so the
# learning rate never decays. Confirm this is intended.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[500], gamma=0.1
)
AUM_recorder = {
    "task": [],
    "worker": [],
    "label": [],
    "epoch": [],
    "label_logit": [],
    "label_prob": [],
    "secondlogit": [],
    "secondprob": [],
    "score": [],
}
# Per-worker matrices from the WAUM run; get_psuccess uses their diagonals.
# Presumably Dawid-Skene confusion matrices (confirm).
pi = waum.ds1.pi
from tqdm import tqdm

for epoch in tqdm(range(100), total=100):
    for x, yy, ww, dd, id_ in train_dataloader:
        y = torch.Tensor(yy).type(torch.long)
        optimizer.zero_grad()
        dd = list(map(int, dd.tolist()))
        out = simplenet(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        len_ = x.shape[0]
        AUM_recorder["task"].extend(dd)
        AUM_recorder["label"].extend(y.tolist())
        AUM_recorder["worker"].extend(list(map(int, ww.tolist())))
        AUM_recorder["epoch"].extend([epoch] * len_)

        # Every batch is recorded, including a trailing batch of size 1.
        # .view(-1) keeps a 1-D tensor (unlike .squeeze()), so .tolist()
        # always returns a list and all columns keep the same length for
        # pd.DataFrame below.
        AUM_recorder["label_logit"].extend(
            out.gather(1, y.view(-1, 1)).view(-1).tolist()
        )
        probs = out.softmax(dim=1)
        AUM_recorder["label_prob"].extend(
            probs.gather(1, y.view(-1, 1)).view(-1).tolist()
        )
        # NOTE(review): these are the second-largest values over ALL classes.
        # They equal the largest *other* class only when the assigned label is
        # the argmax.
        second_logit = torch.sort(out, axis=1)[0][:, -2]
        second_prob = torch.sort(probs, axis=1)[0][:, -2]
        AUM_recorder["secondlogit"].extend(second_logit.tolist())
        AUM_recorder["secondprob"].extend(second_prob.tolist())
        for ll in range(len_):
            AUM_recorder["score"].append(
                get_psuccess(probs[ll], pi[int(ww[ll])]).numpy()
            )
    scheduler.step()
AUM_recorder = pd.DataFrame(AUM_recorder)

# %%
# Replace every record's trust score with the score its (task, worker) pair
# had at the last recorded epoch. Each pair is then weighted by a single
# end-of-training score.
recorder2 = AUM_recorder.copy()
last_epoch = recorder2["epoch"].max()
final_scores = (
    recorder2.loc[recorder2["epoch"] == last_epoch, ["task", "worker", "score"]]
    .drop_duplicates(subset=["task", "worker"])  # keep the first, as before
    .set_index(["task", "worker"])["score"]
)
# One vectorised lookup instead of re-masking the whole frame for every
# (task, worker) pair. A pair absent from the last epoch would get NaN.
pair_keys = pd.MultiIndex.from_frame(recorder2[["task", "worker"]])
recorder2["score"] = final_scores.reindex(pair_keys).to_numpy()
AUM_recorder = recorder2


# %%
# Recompute the WAUM by hand from the recorded margins. For each task:
# take each worker's mean of margin * score, sum over workers, and divide by
# the sum of the workers' scores.
aum_df = AUM_recorder
dico_cpt_aum = {}
aum_df["margin"] = np.array(aum_df["label_prob"]) - np.array(
    aum_df["secondprob"]
)
# groupby visits tasks and workers in sorted order (as np.unique did) and
# keeps row order within each group. It scans the frame once instead of once
# per task.
task_groups = aum_df.groupby("task")
for each_task, temp in tqdm(
    task_groups,
    total=task_groups.ngroups,
    desc="computing WAUM",
):
    avg = []
    score = []
    for j, tempj in temp.groupby("worker"):
        avg.append((np.array(tempj["margin"]) * tempj["score"]).mean())
        score.append(tempj["score"].iloc[0])
    dico_cpt_aum[each_task] = np.sum(avg) / sum(score)
waum_concat = dico_cpt_aum


# %%
# Left: WAUM from peerannot's WAUM object. Right: the hand-recomputed WAUM
# (`waum_concat`). Both are coloured per training task.
fig, ax = plt.subplots(1, 2)
# Draw on explicit Axes. plt.scatter targets the current (last) axes and
# returns a PathCollection, which has no set_title.
ax[0].scatter(
    X_train_bg[:, 0],
    X_train_bg[:, 1],
    # NOTE(review): assumes waum.waum iterates in training-task order.
    c=list(waum.waum.values()),
    cmap=cm.plasma_r,
)
ax[0].set_title("WAUM")

ax[1].scatter(
    X_train_bg[:, 0],
    X_train_bg[:, 1],
    # Tasks without a recomputed value get NaN, which scatter omits by
    # default (plotnonfinite=False).
    c=[waum_concat.get(t, np.nan) for t in range(X_train_bg.shape[0])],
    cmap=cm.plasma_r,
)
ax[1].set_title("WAUM (recomputed)")
