# %%
import torch
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

from peerannot_expe.models.Soft import Soft
from peerannot_expe.models.MV import MV
from peerannot_expe.models.DS import Dawid_Skene as DS
from peerannot_expe.models.GLAD import GLAD
from peerannot_expe.models.WAUM_perworker import WAUM_perworker as WAUM
from peerannot_expe.models.WAUM import WAUM_redundant
from tqdm import tqdm
import pandas as pd
from sklearn.datasets import (
    make_circles,
    make_classification,
    make_moons,
    make_blobs,
)
import matplotlib.pyplot as plt
import torch.nn as nn
from utils import (
    model,
    Toy_dataset,
    Toy_dataset_red,
    expected_calibration_error,
)
from sklearn.svm import SVC, LinearSVC
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.cm as cm
import os
import seaborn as sns
import numpy as np

# Global figure styling: large axis titles/labels and LaTeX text rendering,
# with the AMS packages loaded so \text, \mathcal, ... work in labels.
plt.rcParams.update(
    {
        "axes.titlesize": "xx-large",
        "axes.labelsize": "xx-large",
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{amsmath,amsfonts,amsthm,amssymb}",
    }
)
# Default seeds for torch and the module-level numpy generator `rng`
# (later cells re-seed both before each experiment).
torch.manual_seed(42)
rng = np.random.default_rng(42)

# %%
def _assign_tasks_nozeros(n_tasks, n_workers, mean_per_worker):
    """Pick up to ``mean_per_worker`` task indices per worker, least-answered first."""
    which = []
    howmany = np.zeros(n_tasks)
    lower = 0  # current answer-count level tasks are drawn from
    for _ in range(n_workers):
        candidates = np.where(howmany == lower)[0]
        if candidates.shape[0] == 0:
            # Every task already has more than `lower` answers: move up to
            # the next level that still contains tasks.
            while np.where(howmany == lower)[0].shape[0] == 0:
                lower += 1
            candidates = np.where(howmany == lower)[0]
        # Sampling without replacement raises if more tasks are requested
        # than exist at this level, so clamp the sample size.
        size = min(candidates.shape[0], mean_per_worker)
        which.append(rng.choice(candidates, size=size, replace=False))
        howmany[which[-1]] += 1
    return which


def _assign_tasks_mean(n_tasks, n_workers, mean_per_worker):
    """Pick ``mean_per_worker`` task indices per worker among under-answered tasks."""
    which = []
    howmany = np.zeros(n_tasks)
    for _ in range(n_workers):
        candidates = np.where(howmany < mean_per_worker)[0]
        if candidates.shape[0] == 0:
            # No under-answered task left: draw from all tasks.
            candidates = range(n_tasks)
        # Clamp so a small candidate pool does not make rng.choice raise.
        size = min(len(candidates), mean_per_worker)
        which.append(rng.choice(candidates, size=size, replace=False))
        howmany[which[-1]] += 1
    return which


def get_votes(
    workers,
    X,
    y,
    test_size=0.3,
    all_workers=True,
    mean_per_worker=None,
    nozeros=None,
    train_transform=ToTensor(),
    test_transform=ToTensor(),
):
    """Fit simulated workers and collect their votes on the training split.

    ``(X, y)`` is split with ``train_test_split(random_state=42)``. Each
    worker is fitted on the training part, then labels training tasks
    according to one of three schemes:

    * ``nozeros`` is not None: tasks are handed out from the lowest
      answer-count level first, up to ``mean_per_worker`` tasks per worker.
    * ``mean_per_worker`` is not None: each worker labels up to
      ``mean_per_worker`` tasks drawn among tasks that have fewer than
      ``mean_per_worker`` answers (or among all tasks once none is left).
    * otherwise: each task is labelled by every worker when
      ``all_workers is True``, else by a random number in
      ``1..all_workers`` of distinct random workers.

    Random draws use the module-level ``rng`` generator.

    Args:
        workers: sklearn-style classifiers (``fit``/``predict``/``score``).
        X, y: full dataset features and labels.
        test_size: test fraction passed to ``train_test_split``.
        all_workers: ``True`` or an int upper bound of answers per task.
        mean_per_worker: number of tasks per worker for the sparse schemes.
        nozeros: enables the least-answered-first scheme when not None.
        train_transform, test_transform: unused; kept for API compatibility.

    Returns:
        ``(votes, X_train, X_test, y_train, y_test, scores)``. ``votes`` maps
        each training task index to ``{str(worker_index): predicted_label}``.
        ``scores`` holds ``[train_accuracy, test_accuracy]`` per worker.
    """
    X_train_bg, X_test, y_train_bg, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42
    )
    n_tasks = X_train_bg.shape[0]
    votes = {i: {} for i in range(n_tasks)}

    scores = []
    for worker in workers:
        worker.fit(X_train_bg, y_train_bg)
        scores.append(
            [worker.score(X_train_bg, y_train_bg), worker.score(X_test, y_test)]
        )

    if nozeros is not None or mean_per_worker is not None:
        if nozeros is not None:
            which = _assign_tasks_nozeros(n_tasks, len(workers), mean_per_worker)
        else:
            which = _assign_tasks_mean(n_tasks, len(workers), mean_per_worker)
        # One batched predict per worker on its assigned tasks.
        for j, (worker, tasks) in enumerate(zip(workers, which)):
            pred_w = worker.predict(X_train_bg[tasks, :])
            for task, label in zip(tasks, pred_w):
                votes[task][str(j)] = label
    else:
        for i in range(n_tasks):
            n_ans = (
                len(workers)
                if all_workers is True
                else rng.choice(range(1, all_workers + 1))
            )
            who = rng.choice(range(len(workers)), size=n_ans, replace=False)
            for w in who:
                pred_w = workers[w].predict(X_train_bg[i].reshape(1, -1))
                votes[i][str(w)] = pred_w[0]

    # The split above is deterministic (random_state=42), so re-splitting
    # would return the very same arrays; return them directly.
    return (votes, X_train_bg, X_test, y_train_bg, y_test, scores)


# %%
def agg_mv(votes, n_classes, y_train_bg):
    """Aggregate ``votes`` by majority voting.

    Returns the fitted ``MV`` aggregator, its hard labels and their accuracy
    against the true training labels ``y_train_bg``.
    """
    aggregator = MV(votes, n_classes=n_classes)
    hard_labels = aggregator.get_answers()
    return aggregator, hard_labels, np.mean(hard_labels == y_train_bg)


def agg_soft(votes, n_classes, y_train_bg):
    """Aggregate ``votes`` with the ``Soft`` (vote-frequency) aggregator.

    Returns the aggregator, its per-task probability matrix and the accuracy
    of the arg-max labels against ``y_train_bg``.
    """
    aggregator = Soft(answers=votes, n_classes=n_classes)
    probas = aggregator.get_probas()
    predicted = np.argmax(probas, axis=1)
    return aggregator, probas, np.mean(predicted == y_train_bg)


def agg_ds(votes, n_classes, y_train_bg):
    """Aggregate ``votes`` with Dawid-Skene EM (maxiter=50, epsilon=1e-6).

    Returns the fitted aggregator, its probability matrix and the accuracy
    of the arg-max labels against ``y_train_bg``.
    """
    aggregator = DS(answers=votes, n_classes=n_classes)
    aggregator.run_em(maxiter=50, epsilon=1e-6)
    probas = aggregator.get_probas()
    accuracy = np.mean(np.argmax(probas, axis=1) == y_train_bg)
    return aggregator, probas, accuracy


def agg_glad(votes, n_classes, y_train_bg):
    """Aggregate ``votes`` with GLAD EM (epsilon=1e-6, maxiter=50).

    Returns the fitted aggregator, its probability matrix and the accuracy
    of the arg-max labels against ``y_train_bg``.
    """
    aggregator = GLAD(n_classes=n_classes, answers=votes)
    aggregator.run_em(epsilon=1e-6, maxiter=50)
    probas = aggregator.get_probas()
    accuracy = np.mean(np.argmax(probas, axis=1) == y_train_bg)
    return aggregator, probas, accuracy


def agg_waum(
    X_train_bg, votes, n_classes, y_train_bg, alpha, n_iter, inputdim=2
):
    """Run the per-worker WAUM pipeline and score its labels on kept tasks.

    A fresh ``model`` is trained (SGD, lr=0.1, cross-entropy, ``n_iter``
    iterations, CPU) inside ``WAUM`` and ``waum.run(alpha=alpha)`` prunes
    tasks listed afterwards in ``waum.too_hard``.

    Args:
        inputdim: unused; kept for API compatibility.

    Returns:
        ``(waum, probas, accuracy)``. ``accuracy`` compares the arg-max of
        ``probas`` with ``y_train_bg`` after removing the pruned rows.
    """
    dataset = Toy_dataset(X_train_bg, y_train_bg, transform=ToTensor())
    net = model(X_train_bg.shape[1], n_classes)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    waum = WAUM(
        dataset,
        votes,
        n_classes,
        net,
        loss_fn,
        optimizer,
        n_iter,
        DEVICE="cpu",
    )
    waum.run(alpha=alpha)
    probas = waum.get_probas()
    kept_truth = np.delete(y_train_bg, waum.too_hard, axis=0)
    accuracy = np.mean(np.argmax(probas, axis=1) == kept_truth)
    return waum, probas, accuracy


def agg_waum_ds(
    X_train_bg, votes, n_classes, y_train_bg, alpha, n_iter, inputdim=2
):
    """Prune tasks with WAUM, then aggregate the kept answers with Dawid-Skene.

    The WAUM steps are run one by one (``run_DS``, ``get_aum``,
    ``get_psi5_waum``, ``cut_lowests``) instead of ``waum.run`` so that the
    answers of the kept tasks can be handed to a separate DS model
    (epsilon=1e-6, maxiter=50).

    Args:
        inputdim: unused; kept for API compatibility.

    Returns:
        ``(waum, ds, probas, accuracy)``. ``accuracy`` compares the arg-max
        of the DS probabilities with ``y_train_bg`` minus the pruned rows.
    """
    dataset = Toy_dataset(X_train_bg, y_train_bg, transform=ToTensor())
    net = model(X_train_bg.shape[1], n_classes)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    waum = WAUM(
        dataset,
        votes,
        n_classes,
        net,
        loss_fn,
        optimizer,
        n_iter,
        DEVICE="cpu",
    )
    waum.run_DS()
    waum.ds1 = waum.ds
    waum.pi1 = waum.ds1.pi
    waum.get_aum()
    waum.get_psi5_waum()
    waum.cut_lowests(alpha)
    waum.answers_waum = {
        task: answers
        for task, answers in waum.answers.items()
        if task not in waum.too_hard
    }

    ds = DS(n_classes=n_classes, answers=waum.answers_waum)
    ds.run_em(epsilon=1e-6, maxiter=50)
    probas = ds.get_probas()
    kept_truth = np.delete(y_train_bg, waum.too_hard, axis=0)
    accuracy = np.mean(np.argmax(probas, axis=1) == kept_truth)
    return waum, ds, probas, accuracy


def agg_waum_glad(
    X_train_bg, votes, n_classes, y_train_bg, alpha, n_iter, inputdim=2
):
    """Prune tasks with WAUM, then aggregate the kept answers with GLAD.

    Same step-by-step WAUM pruning as ``agg_waum_ds``. The kept answers go
    to a GLAD model (epsilon=1e-5, maxiter=50).

    Args:
        inputdim: unused; kept for API compatibility.

    Returns:
        ``(waum, glad, probas, accuracy)``. ``accuracy`` compares the arg-max
        of the GLAD probabilities with ``y_train_bg`` minus the pruned rows.
    """
    dataset = Toy_dataset(X_train_bg, y_train_bg, transform=ToTensor())
    net = model(X_train_bg.shape[1], n_classes)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    waum = WAUM(
        dataset,
        votes,
        n_classes,
        net,
        loss_fn,
        optimizer,
        n_iter,
        DEVICE="cpu",
    )
    waum.run_DS()
    waum.ds1 = waum.ds
    waum.pi1 = waum.ds1.pi
    waum.get_aum()
    waum.get_psi5_waum()
    waum.cut_lowests(alpha)
    waum.answers_waum = {
        task: answers
        for task, answers in waum.answers.items()
        if task not in waum.too_hard
    }

    glad = GLAD(n_classes=n_classes, answers=waum.answers_waum)
    glad.run_em(epsilon=1e-5, maxiter=50)
    probas = glad.get_probas()
    kept_truth = np.delete(y_train_bg, waum.too_hard, axis=0)
    accuracy = np.mean(np.argmax(probas, axis=1) == kept_truth)
    return waum, glad, probas, accuracy


def _expand_votes(X_train_bg, votes):
    """Flatten ``votes`` into one row per (task, worker) answer.

    Returns float arrays ``(Xred, yred, workers, true_idx)``:
    * ``Xred``: the task's features, repeated once per answer.
    * ``yred``: the answered label.
    * ``workers``: the worker id (``int`` of the vote key).
    * ``true_idx``: the task index.
    """
    n_features = X_train_bg.shape[1]
    rows, labels, worker_ids, task_ids = [], [], [], []
    for i in range(X_train_bg.shape[0]):
        for j, ans in votes[i].items():
            rows.append(X_train_bg[i])
            labels.append(ans)
            worker_ids.append(int(j))
            task_ids.append(i)
    # reshape keeps a 2-D (0, n_features) array when there are no votes.
    Xred = np.asarray(rows, dtype=float).reshape(-1, n_features)
    return (
        Xred,
        np.asarray(labels, dtype=float),
        np.asarray(worker_ids, dtype=float),
        np.asarray(task_ids, dtype=float),
    )


def agg_waum_red(
    X_train_bg, votes, n_classes, y_train_bg, alpha, n_iter, inputdim=2
):
    """Run ``WAUM_redundant`` (one training sample per answer) and score it.

    Every answer becomes its own sample (task features, answered label,
    worker id, task index). These samples are batched (size 64) and used by
    ``WAUM_redundant`` to train a fresh ``model`` with SGD (lr=0.1) for
    ``n_iter`` epochs. ``run(alpha=alpha)`` then prunes the tasks listed in
    ``waum_red.too_hard``.

    Args:
        inputdim: unused; kept for API compatibility.

    Returns:
        ``(waum_red, probas, accuracy)``. ``accuracy`` compares the arg-max
        of ``probas`` with ``y_train_bg`` after removing the pruned rows.
    """
    # Arrays used to be grown with np.vstack / np.hstack per answer, which
    # copies them every time (quadratic in the number of votes) and crashed
    # when there were no votes at all.
    Xred, yred, worker_ids, true_idx = _expand_votes(X_train_bg, votes)
    dataset = Toy_dataset_red(
        Xred, yred, worker_ids, true_idx, transform=ToTensor()
    )
    toydl = torch.utils.data.DataLoader(dataset, batch_size=64)

    simplenet = model(X_train_bg.shape[1], n_classes)
    waum_red = WAUM_redundant(
        toydl,
        votes,
        n_classes=n_classes,
        model=simplenet,
        criterion=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.SGD(
            simplenet.parameters(),
            lr=0.1,
        ),
        n_epoch=n_iter,
    )
    waum_red.run(alpha=alpha)
    y_red = waum_red.get_probas()
    acc = np.mean(
        np.argmax(y_red, axis=1)
        == np.delete(y_train_bg, waum_red.too_hard, axis=0)
    )
    return waum_red, y_red, acc


def train(X_train, y_train, X_test, y_test, n_classes, n_epoch, inputdim=2):
    """Train a fresh ``model`` on (X_train, y_train) and evaluate it on the test set.

    ``y_train`` may be hard labels (1-D, cast to ``long``) or soft labels
    (2-D, cast to float). Training uses SGD (lr=0.1, batch size 64), with
    the learning rate divided by 10 at epochs 500 and 1000.

    Args:
        inputdim: unused; kept for API compatibility.

    Returns:
        ``(test_accuracy, Zt, y_pred, ece, loss)``:
        * ``Zt``: softmax probabilities on ``X_test`` (no autograd graph).
        * ``y_pred``: the arg-max of ``Zt``.
        * ``ece``: the expected calibration error.
        * ``loss``: the test cross-entropy computed from the logits.
    """
    dataset = Toy_dataset(X_train, y_train, transform=ToTensor())
    dl = torch.utils.data.DataLoader(dataset, batch_size=64)
    simplenet = model(X_train.shape[1], n_classes)
    optimizer = torch.optim.SGD(simplenet.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[500, 1000], gamma=0.1
    )
    simplenet.train()
    for _epoch in range(n_epoch):
        for x, _unused, y, _idx in dl:
            # Hard labels -> class indices; soft labels -> probability targets.
            y = (
                y.type(torch.long)
                if y.ndim == 1
                else y.type(torch.FloatTensor)
            )
            optimizer.zero_grad()
            output = simplenet(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
        scheduler.step()

    simplenet.eval()
    with torch.no_grad():
        logits = simplenet(torch.Tensor(X_test))
        Zt = logits.softmax(1)
        # CrossEntropyLoss applies log-softmax itself: feeding it the softmax
        # probabilities (as before) applied softmax twice and misreported
        # the test loss.
        test_loss = criterion(logits, torch.Tensor(y_test).type(torch.long))
    y_pred = torch.argmax(Zt, dim=1).numpy()
    test_accuracy = np.mean(y_test == y_pred)
    ece = expected_calibration_error(Zt.numpy(), y_test)

    return test_accuracy, Zt, y_pred, ece, test_loss.item()


# %%
def plot_workers(X_train, y_train, workers, name="", save=False):
    """Scatter the training points by ground truth and by each worker's labels.

    The first panel is coloured by ``y_train``. Each following panel is
    coloured by ``worker.predict(X_train)`` and titled with the worker's
    class name. The figure is written to
    ``binary_workers_workers_{name}.pdf`` when ``save`` is true.
    """
    fig, axes = plt.subplots(1, len(workers) + 1, figsize=(15, 5))
    for panel in axes:
        panel.tick_params(
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
            bottom=False,
        )
        panel.set_aspect("equal")

    truth_ax, *worker_axes = axes
    truth_ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, s=10)
    truth_ax.set_title(r"Ground truth $\mathcal{D}_{\text{train}}$")
    for panel, worker in zip(worker_axes, workers):
        panel.scatter(
            X_train[:, 0], X_train[:, 1], c=worker.predict(X_train), s=10
        )
        panel.set_title(type(worker).__name__)

    plt.tight_layout()
    if save:
        plt.savefig(f"binary_workers_workers_{name}.pdf")
    plt.show()


def plot_predictions(
    X_train,
    y_train,
    X_test,
    y_test,
    Zt,
    name="",
    method="MV",
    too_hard=None,
    save=False,
):
    """Show aggregated training labels next to the model's test predictions.

    Left panel: ``X_train`` coloured by ``y_train``.
    * 1-D hard labels are used as is.
    * Two-column soft labels are shown by their class-1 probability.
    * Wider soft labels are shown by their arg-max, with ties broken at
      random via ``np.random``.
    Optional ``too_hard`` coordinates are drawn as large red squares.

    Right panel: ``X_test`` coloured by ``Zt``. Two classes are shown by the
    class-1 probability; otherwise the predicted class is shown.

    Args:
        X_train, X_test: point coordinates (first two columns are plotted).
        y_train: aggregated training labels, 1-D or ``(n, n_classes)``.
        y_test: unused; kept for API compatibility.
        Zt: torch tensor of test probabilities, ``(n_test, n_classes)`` as
            returned by ``train``.
        name: suffix of the saved file name.
        method: aggregation name, shown as y-label and used in the file name.
        too_hard: optional ``(x, y)`` pair of coordinate sequences.
        save: when true, writes ``binary_workers_{method}_{name}.pdf``.
    """
    fig, ax = plt.subplots(1, 2, figsize=(15, 5), sharex=True, sharey=True)
    for panel in ax:
        panel.tick_params(
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
            bottom=False,
        )
        panel.set_aspect("equal")

    if y_train.ndim == 1:
        y_plot = y_train
    elif y_train.shape[1] <= 2:
        y_plot = y_train[:, 1]
    else:
        # More than two classes: colour by the most likely class.
        y_plot = [
            np.random.choice(np.flatnonzero(row == row.max()))
            for row in y_train
        ]
    ax[0].scatter(X_train[:, 0], X_train[:, 1], c=y_plot)
    # `if too_hard:` raised for numpy arrays (ambiguous truth value).
    if too_hard is not None and len(too_hard) > 0:
        ax[0].scatter(*too_hard, c="red", s=200, marker="s")

    # Zt is 2-D, so the former `Zt.ndim > 2` test never held and multi-class
    # predictions were plotted as P(class 1) only.
    if Zt.shape[1] > 2:
        Zt_plot = torch.argmax(Zt, dim=1).numpy()
    else:
        Zt_plot = Zt[:, 1].detach().numpy()
    im = ax[1].scatter(X_test[:, 0], X_test[:, 1], c=Zt_plot)
    ax[0].set_title(r"$\mathcal{D}_{\text{train}}$")
    ax[0].set_ylabel(f"Aggregation = {method}")
    ax[1].set_title(r"Predictions")
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(im, cax=cbar_ax)
    if save:
        # Use the raw method name (not the "Aggregation = ..." label, which
        # used to leak spaces and "=" into the file name).
        plt.savefig(f"binary_workers_{method}_{name}.pdf")
    plt.show()


def plot_waum(waum, X_train, name="", save=False):
    """Plot the WAUM distribution and the WAUM of each training point.

    Left: kernel density of the ``waum.waum`` values, with ``waum.quantile``
    (presumably the pruning threshold) as a dashed red vertical line.
    Right: ``X_train`` coloured by WAUM on the reversed plasma colormap.
    The figure is saved to ``binary_workers_density_{name}.pdf`` when
    ``save`` is true.
    """
    values = list(waum.waum.values())
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[1].tick_params(
        left=False,
        right=False,
        labelleft=False,
        labelbottom=False,
        bottom=False,
    )
    sns.kdeplot(values, ax=ax[0])
    # Span the whole height of the density axis. The former
    # vlines(ymax=max(values)) used a WAUM value as a density height, so the
    # line length was meaningless.
    ax[0].axvline(waum.quantile, color="red", ls="--")
    points = ax[1].scatter(
        X_train[:, 0],
        X_train[:, 1],
        c=values,
        cmap=cm.plasma_r,
    )
    ax[0].set_xlabel("")
    ax[0].set_ylabel("")
    ax[1].set_title(r"$\mathrm{WAUM}\ \mathcal{D}_{\text{train}}$")
    ax[0].set_title(r"Density $\mathrm{WAUM}$")
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(points, cax=cbar_ax)
    if save:
        plt.savefig(f"binary_workers_density_{name}.pdf")
    plt.show()


def plot_aum_workers(waum, X_train, workers, name="", save=False):
    """Draw one scatter of ``X_train`` per worker, coloured by that worker's AUM.

    Values come from ``waum.aum_per_worker``, a mapping
    ``task -> {worker: aum}``. They are drawn on a fixed [0, 1] scale with
    the reversed plasma colormap. The figure is saved to
    ``AUM_workers_{name}.pdf`` when ``save`` is true.
    """
    # squeeze=False keeps a 2-D array of axes even for a single worker; with
    # one worker plt.subplots used to return a bare Axes (no .flatten()).
    fig, axes = plt.subplots(1, len(workers), figsize=(15, 5), squeeze=False)
    ax = axes[0]
    for tt in ax:
        tt.tick_params(
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
            bottom=False,
        )
        tt.set_aspect("equal")

    for j in range(len(workers)):
        # NOTE(review): assumes worker j answered every task, otherwise the
        # colour list is shorter than X_train.
        aum_j = {
            i: l
            for i, val in waum.aum_per_worker.items()
            for w, l in val.items()
            if w == j
        }
        ax[j].set_title(fr"$\mathrm{{AUM}}$ {workers[j].__class__.__name__}")
        aumj = ax[j].scatter(
            X_train[:, 0],
            X_train[:, 1],
            c=list(aum_j.values()),
            cmap=cm.plasma_r,
            vmin=0,
            vmax=1,
            s=10,
        )
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.7])
    fig.colorbar(aumj, cax=cbar_ax)
    if save:
        plt.savefig(f"AUM_workers_{name}.pdf")
    plt.show()


def plot_scores(waum, X_train, workers, name="", save=False):
    """Draw one scatter of ``X_train`` per worker, coloured by its score ``s^(j)``.

    Values come from ``waum.score_per_worker``, a mapping
    ``task -> {worker: score}``. They are drawn on a fixed [0, 1] scale.
    The figure is saved to ``trust_{name}.pdf`` when ``save`` is true.
    """
    # squeeze=False keeps a 2-D array of axes even for a single worker.
    fig, axes = plt.subplots(1, len(workers), figsize=(15, 5), squeeze=False)
    ax = axes[0]
    for tt in ax:
        tt.tick_params(
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
            bottom=False,
        )
        tt.set_aspect("equal")
    for j in range(len(workers)):
        # NOTE(review): assumes worker j scored every task.
        sj = {
            i: l
            for i, val in waum.score_per_worker.items()
            for w, l in val.items()
            if w == j
        }
        ax[j].set_title(fr"$s^{{(j)}}$ {workers[j].__class__.__name__}")
        sjx = ax[j].scatter(
            X_train[:, 0],
            X_train[:, 1],
            c=list(sj.values()),
            vmin=0,
            vmax=1,
            s=10,
        )
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.7])
    fig.colorbar(sjx, cax=cbar_ax)
    if save:
        plt.savefig(f"trust_{name}.pdf")
    plt.show()


# %%


def get_workers(many=None):
    """Build the pool of simulated (unfitted) sklearn workers.

    Args:
        many: selects the pool.
            * ``None``: a fixed trio — a LinearSVC, a 1-iteration SVC and a
              5-tree GradientBoostingClassifier.
            * ``True``: 100 workers of each family.
            * An int ``k``: ``k`` workers of each family (LinearSVC, SVC,
              GradientBoostingClassifier), with hyper-parameters drawn from
              the module-level ``rng``.

    Returns:
        list of classifiers: 3 for ``None``, otherwise ``3 * many``.
    """
    if many is None:
        return [
            # LogisticRegression(random_state=1),
            LinearSVC(random_state=2),
            SVC(random_state=2, max_iter=1),
            GradientBoostingClassifier(n_estimators=5, random_state=2),
        ]

    # `many == True` also matched many=1 (since 1 == True), silently turning
    # a request for one worker per family into 100.
    if many is True:
        many = 100
    workers = []
    C = np.linspace(0.001, 3, 20)
    max_iter = np.arange(1, 100)
    # LinearSVC workers
    for j in range(many):
        ll = rng.choice(["hinge", "squared_hinge"])
        cc = rng.choice(C)
        iter_ = rng.choice(max_iter)
        workers.append(
            LinearSVC(random_state=j, loss=ll, C=cc, max_iter=iter_)
        )
    # SVC workers
    kernels = ["poly", "rbf", "sigmoid"]
    for j in range(many):
        ker = rng.choice(kernels)
        iter_ = rng.choice(max_iter)
        workers.append(SVC(random_state=j, kernel=ker, max_iter=iter_))
    # Gradient-boosting workers (max_depth drawn from the same 1..99 range)
    for j in range(many):
        lr = rng.choice([0.01, 0.1, 0.5])
        splt = rng.choice([2, 5, 10])
        depth = rng.choice(max_iter)
        nest = rng.choice([1, 2, 5, 10, 15, 20, 30, 50, 100])
        workers.append(
            GradientBoostingClassifier(
                n_estimators=nest,
                max_depth=depth,
                min_samples_split=splt,
                learning_rate=lr,
                random_state=j,
            )
        )
    return workers


def _collapse_scores(y):
    """Reduce labels/scores to one value per row.

    1-D input is returned unchanged. A 2-column matrix gives its second
    column (class-1 score). Wider matrices give the arg-max class.
    """
    if y.ndim == 1:
        return y
    return np.argmax(y, axis=1) if y.shape[1] > 2 else y[:, 1]


def save_all_csv(
    X_train,
    X_test,
    y_train,
    yws,
    ys,
    ypred,
    y_test,
    waum,
    waum_red,
    workers,
    name="",
):
    """Export points, aggregated labels and predictions as CSV files.

    Files written to ``./output/{name}/``:
    * ``{method}_train.csv`` for each ``(y, method)`` in ``ys``: training
      coordinates, the collapsed label column and its rounded value. Rows
      in ``waum.too_hard`` (method ``"WAUM"``) or ``waum_red.too_hard``
      (method ``"WAUMredundant"``) are dropped.
    * ``{method}_test.csv`` for each ``(Zt, method)`` in ``ypred``, with the
      same layout on the test points. ``Zt`` must be a torch tensor.
    * ``waumredvalues.csv`` and ``too_hard_red.csv``, built from ``waum_red``.
    * ``test.csv`` and ``train.csv``, with coordinates and true labels.

    ``yws`` and ``workers`` are unused and kept for API compatibility.
    ``waum`` is only read when ``ys`` contains a ``"WAUM"`` entry.
    """
    out_dir = f"./output/{name}"
    os.makedirs(out_dir, exist_ok=True)
    traincsv = {
        "Xtrain0": X_train[:, 0],
        "Xtrain1": X_train[:, 1],
        "truth": y_train,
    }
    testcsv = {"Xtest0": X_test[:, 0], "Xtest1": X_test[:, 1], "ytest": y_test}

    # Methods whose label vector excludes the tasks pruned by a WAUM object.
    pruned_by = {"WAUM": waum, "WAUMredundant": waum_red}
    for y, method in ys:
        if method in pruned_by:
            dropped = pruned_by[method].too_hard
            x0 = np.delete(X_train[:, 0], dropped, axis=0)
            x1 = np.delete(X_train[:, 1], dropped, axis=0)
        else:
            x0, x1 = X_train[:, 0], X_train[:, 1]
        col = _collapse_scores(y)
        pd.DataFrame(
            {"Xtrain0": x0, "Xtrain1": x1, method: col, "round": np.round(col)}
        ).to_csv(f"{out_dir}/{method}_train.csv", index=False)

    for y, method in ypred:
        col = _collapse_scores(y.detach().numpy())
        pd.DataFrame(
            {
                "Xtest0": X_test[:, 0],
                "Xtest1": X_test[:, 1],
                method: col,
                "round": np.round(col),
            }
        ).to_csv(f"{out_dir}/{method}_test.csv", index=False)

    pd.DataFrame({"WAUMvalues": list(waum_red.waum.values())}).to_csv(
        f"{out_dir}/waumredvalues.csv", index=False
    )
    too_hard_red = {
        "x": X_train[waum_red.too_hard, 0],
        "y": X_train[waum_red.too_hard, 1],
    }
    pd.DataFrame(too_hard_red).to_csv(f"{out_dir}/too_hard_red.csv", index=False)
    pd.DataFrame(testcsv).to_csv(f"{out_dir}/test.csv", index=False)
    pd.DataFrame(traincsv).to_csv(f"{out_dir}/train.csv", index=False)


# %%
save = True  # NOTE(review): not referenced by any visible cell
alpha = 0.1  # default WAUM pruning level; later cells reassign it

# Alternative dataset (two concentric circles), kept for reference:
# seed = 5
# n_classes = 2
# dataset = f"two_circles"  # "two_circles_max3"
# X, y = make_circles(n_samples=500, noise=0.2, factor=0.4, random_state=seed)


# Active dataset: "three circles". Two make_circles draws with different
# inner/outer radius ratios (`factor`) are stacked. Per sklearn, each draw
# labels its outer ring 0 and its inner ring 1; relabelling the second
# draw's inner ring to 2 gives three classes.
seed = 6
n_classes = 3
dataset = "three_circles"
X_small, y_small = make_circles(
    n_samples=(125, 250), random_state=10, noise=0.1, factor=0.3
)
X_large, y_large = make_circles(
    n_samples=(125, 250), random_state=10, noise=0.1, factor=0.7
)
y_large[y_large == 1] = 2
X, y = np.vstack((X_small, X_large)), np.hstack((y_small, y_large))

# Alternative dataset (Gaussian clusters), kept for reference:
# seed = 5
# n_classes = 2
# dataset = "gmm"
# X, y = make_classification(
#     n_samples=500,
#     n_features=2,
#     n_redundant=0,
#     n_informative=2,
#     random_state=0,
#     n_clusters_per_class=2,
#     flip_y=0.0,
#     class_sep=1.5,
#     hypercube=True,
# )


# Alternative dataset (two moons), kept for reference:
# seed = 6
# n_classes = 2
# dataset = "two_moons"
# X, y = make_moons(n_samples=500, noise=0.2, random_state=0)


def warn(*args, **kwargs):
    """No-op replacement for ``warnings.warn``; kept for backward compatibility."""
    pass


import warnings

# Silence warnings (e.g. from the deliberately under-trained sklearn
# workers) through the warnings filter instead of monkeypatching
# ``warnings.warn``. This matches the later cell that calls
# ``warnings.filterwarnings("ignore")``, and it also covers warnings that
# do not go through the ``warnings.warn`` attribute.
warnings.filterwarnings("ignore")

# Re-seed the module-level `rng` (used by get_workers / get_votes) and torch
# with the seed chosen by the dataset cell above.
rng = np.random.default_rng(seed)
torch.manual_seed(seed)


# Fixed trio of workers; every training task is labelled by all of them
# (all_workers=True, no sparse assignment).
workers = get_workers(None)
(votes, X_train, X_test, y_train, y_test, scores) = get_votes(
    workers,
    X,
    y,
    test_size=0.3,
    all_workers=True,
    mean_per_worker=None,
    train_transform=ToTensor(),
    test_transform=ToTensor(),
)

# %%
# Single WAUM run (alpha=0.1, n_iter=100), then train a classifier for 1000
# epochs on the kept tasks using the WAUM label probabilities.
waum, y_waum, acc_waum_train = agg_waum(
    X_train, votes, n_classes, y_train, 0.1, 100
)
acc_waum, Zt_waum, y_pred_waum, ece_waum, loss_waum = train(
    np.delete(X_train, waum.too_hard, axis=0),  # drop the pruned tasks
    y_waum,
    X_test,
    y_test,
    n_classes,
    1000,
)
# %% run for best alpha
rng = np.random.default_rng(6)
torch.manual_seed(6)

# Sweep the WAUM-redundant pruning level alpha. For each value, repeat
# aggregation + training 10 times and record:
# * "Agg": the aggregation accuracy on the train set,
# * "Acc": the test accuracy,
# * "ECE": the test expected calibration error.
alphas = [1e-2, 1e-1, 0.25]
res_alphas = {al: {"Acc": [], "ECE": [], "Agg": []} for al in alphas}
for alpha in alphas:
    for rep in range(10):
        waum_red, y_waum_red, acc_waum_train_red = agg_waum_red(
            X_train, votes, n_classes, y_train, alpha, 100
        )
        (
            acc_waum_red,
            Zt_waum_red,
            y_pred_waum_red,
            ece_waum_red,
            loss_waum_red,
        ) = train(
            np.delete(X_train, waum_red.too_hard, axis=0),
            y_waum_red,
            X_test,
            y_test,
            n_classes,
            1000,
        )
        print(acc_waum_train_red)
        res_alphas[alpha]["Agg"].append(acc_waum_train_red)
        res_alphas[alpha]["Acc"].append(acc_waum_red)
        res_alphas[alpha]["ECE"].append(ece_waum_red)
        if alpha == 0.1 and rep == 1:
            # Export one representative run (alpha = 0.1, second repetition).
            save_all_csv(
                X_train,
                X_test,
                y_train,
                votes,
                [
                    (y_waum_red, "WAUMredundant"),
                ],
                [
                    (Zt_waum_red, "WAUMredundant"),
                ],
                y_test,
                waum_red,
                waum_red,
                workers,
                name="3circles",
            )

for al, vals in res_alphas.items():
    print("### alpha = ", al)
    for metric, values in vals.items():
        print(
            metric, np.round(np.mean(values), 3), np.round(np.std(values), 3)
        )
# The export inside the loop is the only one for this sweep. The cell used
# to end with an unconditional save_all_csv(..., name="3circles") that
# overwrote it with the loop's last run (alpha = 0.25).


# %% repeat perf with alpha
from itertools import chain

# Compare WAUM (per-worker) with WAUM-redundant over `nrep` repetitions on
# the same votes. Each metric stores one [WAUM, WAUMredundant] pair per
# repetition.
seed = 3
nrep = 10
alpha = 0.1
rng = np.random.default_rng(seed)
torch.manual_seed(seed)
res_meth = {
    "train_acc": [],
    "accuracy": [],
    "ECE": [],
}


for rep in tqdm(range(nrep)):
    # NOTE(review): votes.copy() is shallow; the per-task inner dicts are
    # still shared between calls.
    waum, y_waum, acc_waum_train = agg_waum(
        X_train, votes.copy(), n_classes, y_train, alpha, 100
    )
    waum_red, y_red, acc_red_train = agg_waum_red(
        X_train, votes.copy(), n_classes, y_train, alpha, 100
    )

    # Baselines, disabled. NOTE(review): agg_waum_mv is not defined in the
    # visible part of this file.
    # mv, y_mv, acc_mv_train = agg_mv(votes.copy(), n_classes, y_train)
    # soft, y_soft, acc_soft_train = agg_soft(votes.copy(), n_classes, y_train)
    # ds, y_ds, acc_ds_train = agg_ds(votes.copy(), n_classes, y_train)
    # glad, y_glad, acc_glad_train = agg_glad(votes.copy(), n_classes, y_train)
    # wdswaum, wgds, y_wds, acc_wds_train = agg_waum_ds(
    #     X_train, votes.copy(), n_classes, y_train, 0.1, 100, inputdim=2
    # )
    # wmvwaum, wmvmv, y_wmv, acc_wmv_train = agg_waum_mv(
    #     X_train, votes.copy(), n_classes, y_train, 0.1, 100, inputdim=2
    # )

    acc_waum, Zt_waum, y_pred_waum, ece_waum, loss_waum = train(
        np.delete(X_train, waum.too_hard, axis=0),
        y_waum,
        X_test,
        y_test,
        n_classes,
        1000,
    )
    acc_red, Zt_red, y_pred_red, ece_red, loss_red = train(
        np.delete(X_train, waum_red.too_hard, axis=0),
        y_red,
        X_test,
        y_test,
        n_classes,
        1000,
    )
    # acc_mv, Zt_mv, y_pred_mv, ece_mv, loss_mv = train(
    #     X_train, y_mv, X_test, y_test, n_classes, 1000
    # )
    # acc_soft, Zt_soft, y_pred_soft, ece_soft, loss_soft = train(
    #     X_train, y_soft, X_test, y_test, n_classes, 1000
    # )
    # acc_glad, Zt_glad, y_pred_glad, ece_glad, loss_glad = train(
    #     X_train, y_glad, X_test, y_test, n_classes, 1000
    # )
    # acc_ds, Zt_ds, y_pred_ds, ece_ds, loss_ds = train(
    #     X_train, y_ds, X_test, y_test, n_classes, 1000
    # )
    # acc_wglad, Zt_wglad, y_pred_wglad, ece_wglad, loss_wglad = train(
    #     np.delete(X_train, wdswaum.too_hard, axis=0),
    #     y_wds,
    #     X_test,
    #     y_test,
    #     n_classes,
    #     1000,
    # )

    # acc_wmv, Zt_wmv, y_pred_wmv, ece_wmv, loss_wmv = train(
    #     np.delete(X_train, wmvwaum.too_hard, axis=0),
    #     y_wmv,
    #     X_test,
    #     y_test,
    #     n_classes,
    #     1000,
    # )

    res_meth["train_acc"].append(
        [
            # acc_mv_train,
            # acc_soft_train,
            # acc_ds_train,
            # acc_glad_train,
            acc_waum_train,
            # Aggregation accuracy on the train set; this used to record the
            # test accuracy `acc_red` by mistake.
            acc_red_train,
            # acc_wds_train,
            # acc_wmv_train,
        ]
    )
    res_meth["accuracy"].append(
        [acc_waum, acc_red]
        # [acc_mv, acc_soft, acc_ds, acc_glad, acc_waum]  # , acc_wglad, acc_wmv]
    )
    res_meth["ECE"].append(
        [ece_waum, ece_red]
        # [ece_mv, ece_soft, ece_ds, ece_glad, ece_waum]  # , ece_wglad, ece_wmv]
    )

# Flatten the per-repetition pairs into one column per metric, aligned with
# the "method" column below. This replaces a bare `except` that hid failures.
for t in res_meth:
    res_meth[t] = list(chain.from_iterable(res_meth[t]))
res_meth["method"] = ["WAUM", "WAUMredundant"] * nrep
pd.DataFrame(res_meth).groupby("method").mean()


# %%

from itertools import chain

# Compare WAUM pruning followed by Dawid-Skene with WAUM pruning followed
# by GLAD. Each metric stores one [DS, GLAD] pair per repetition.
res_meth = {
    "train_acc": [],
    "accuracy": [],
    "ECE": [],
}
nrep = 10

for rep in tqdm(range(nrep)):
    wdswaum, wgds, y_wds, acc_wds_train = agg_waum_ds(
        X_train, votes.copy(), n_classes, y_train, 0.1, 100, inputdim=2
    )
    wgladwaum, wgladglad, y_wglad, acc_wglad_train = agg_waum_glad(
        X_train, votes.copy(), n_classes, y_train, 0.1, 100, inputdim=2
    )
    acc_wds, Zt_wds, y_pred_wds, ece_wds, loss_wds = train(
        np.delete(X_train, wdswaum.too_hard, axis=0),
        y_wds,
        X_test,
        y_test,
        n_classes,
        1000,
    )
    acc_wglad, Zt_wglad, y_pred_wglad, ece_wglad, loss_wglad = train(
        np.delete(X_train, wgladwaum.too_hard, axis=0),
        y_wglad,
        X_test,
        y_test,
        n_classes,
        1000,
    )
    res_meth["train_acc"].append(
        [
            acc_wds_train,
            acc_wglad_train,
        ]
    )
    res_meth["accuracy"].append([acc_wds, acc_wglad])
    res_meth["ECE"].append([ece_wds, ece_wglad])

# Flatten the per-repetition pairs; no bare `except` hiding failures.
for t in res_meth:
    res_meth[t] = list(chain.from_iterable(res_meth[t]))
res_meth["method"] = ["WAUM + DS", "WAUM + GLAD"] * nrep
pd.DataFrame(res_meth).groupby("method").mean()
# %%
# Re-display the per-method mean metrics (same expression as the previous cell).
pd.DataFrame(res_meth).groupby("method").mean()

# %%
import pickle

# Persist the raw per-repetition results. pickle writes bytes, so the file
# must be opened in binary mode. The former open(..., "w") made every dump
# fail with a TypeError that the bare `except` reported as "Could not do it",
# and the file handle was never closed.
# NOTE(review): the file name says "2circles" while the active dataset is
# three circles; kept unchanged so existing readers still find it.
try:
    with open("res_2circles", "wb") as fh:
        pickle.dump(res_meth, fh, protocol=pickle.HIGHEST_PROTOCOL)
except (OSError, pickle.PicklingError, TypeError) as err:
    print("Could not do it", err)
# %%
# Per-method summary (mean over repetitions) of the experiment above.
# The former pd.DataFrame(res_meth, ["MV", "Naïve Soft", "DS", "GLAD",
# "WAUM"]) forced a 5-row index left over from a 5-method run. It raised a
# length mismatch on the 2 * nrep rows stored by the preceding cells.
res_meth = pd.DataFrame(res_meth).groupby("method").mean()
res_meth.round(3)
# %%
# Count how many tasks each worker answered (|T(w_j)|) and show it as a bar chart.
how_many = np.zeros(len(workers))
for task, labs in votes.items():
    for worker in list(labs.keys()):
        how_many[int(worker)] += 1  # vote keys are str(worker index)
plt.bar(range(how_many.shape[0]), how_many)
plt.ylabel(r"$|\mathcal{T}(w_j)|$")
plt.xlabel(r"$w_j$")

# %%
# gridspec = {"width_ratios": [1, 1, 1, 0.1]}
# fig, axs = plt.subplots(3, 4, figsize=(10, 8), gridspec_kw=gridspec)
# for i in range(4):
#     if i == 3:
#         plt.colorbar(im, cax=axs[0, i], fraction=0.046, pad=0.04)
#         continue
#     im = axs[0, i].imshow(waum.pi1[i], vmin=0, vmax=1, cmap="RdBu")
#     axs[0, i].set_title(rf"$w_{i}$")
#     axs[0, i].set_yticks([0, 1, 2])
#     axs[0, i].set_yticklabels([0, 1, 2])
#     if i == 0:
#         axs[0, i].set_ylabel(r"$\hat\pi^{(j)}_{\mathcal{D}_\text{train}}$")
#         axs[0, i].tick_params(
#             axis="x",  # changes apply to the x-axis
#             which="both",  # both major and minor ticks are affected
#             bottom=False,  # ticks along the bottom edge are off
#             top=False,
#             labelleft=False,  # ticks along the top edge are off
#             labelbottom=False,
#         )  #
#     if i > 0:
#         axs[0, i].tick_params(
#             axis="both",
#             which="both",
#             bottom=False,
#             top=False,
#             labelbottom=False,
#             left=False,
#             labelleft=False,
#         )

# for i in range(4):
#     if i == 3:
#         plt.colorbar(im, cax=axs[1, i], fraction=0.046, pad=0.04)
#         continue
#     im = axs[1, i].imshow(waum.pi2[i], vmin=0, vmax=1, cmap="RdBu")
#     axs[1, i].set_yticks([0, 1, 2])
#     axs[1, i].set_yticklabels([0, 1, 2])

#     if i == 0:
#         axs[1, i].set_ylabel(r"$\hat\pi^{(j)}_{\mathcal{D}_\text{pruned}}$")
#         axs[1, i].tick_params(
#             axis="x",  # changes apply to the x-axis
#             which="both",  # both major and minor ticks are affected
#             bottom=False,  # ticks along the bottom edge are off
#             top=False,
#             labelleft=False,  # ticks along the top edge are off
#             labelbottom=False,
#         )  #
#     if i > 0:
#         axs[1, i].tick_params(
#             axis="both",
#             which="both",
#             bottom=False,
#             top=False,
#             labelbottom=False,
#             left=False,
#             labelleft=False,
#         )

# for i in range(4):
#     if i == 3:
#         plt.colorbar(im, cax=axs[2, i], fraction=0.046, pad=0.04)
#         continue
#     im = axs[2, i].imshow(waum.pi1[i] - waum.pi2[i], vmin=-1, vmax=1, cmap="BrBG")
#     axs[2, i].set_yticks([0, 1, 2])
#     axs[2, i].set_yticklabels([0, 1, 2])

#     if i == 0:
#         axs[2, i].set_ylabel(r"$\hat\pi^{(j)}_{\mathcal{D}_\text{pruned}}$")
#     if i > 0:
#         axs[2, i].tick_params(
#             axis="both", which="both", top=False, left=False, labelleft=False
#         )
#     axs[2, i].set_xticks([0, 1, 2])
#     axs[2, i].set_xticklabels([0, 1, 2])
#     if i == 0:
#         axs[2, i].set_ylabel(
#             r"$\hat\pi^{(j)}_{\mathcal{D}_\text{train}} - \hat\pi^{(j)}_{\mathcal{D}_\text{pruned}}$"
#         )
# plt.tight_layout()

# %%
# Sparse-annotation setting: each task receives between 1 and 3 answers
# (1 <= |A(x_i)| <= 3), drawn by get_votes(all_workers=how_many).
seed = 5


import warnings

warnings.filterwarnings("ignore")


alpha = 0.1  # NOTE(review): the visible calls below pass 0.1 directly
n_classes = 3
# Same three-circles dataset as the earlier cell, rebuilt here (presumably so
# this cell can run on its own).
X_small, y_small = make_circles(
    n_samples=(125, 250), random_state=10, noise=0.1, factor=0.3
)
X_large, y_large = make_circles(
    n_samples=(125, 250), random_state=10, noise=0.1, factor=0.7
)
y_large[y_large == 1] = 2
X, y = np.vstack((X_small, X_large)), np.hstack((y_small, y_large))
rng = np.random.default_rng(seed)
torch.manual_seed(seed)

all_res = []
workers = get_workers(None)
list_howmany = [3]  # upper bound(s) on answers per task to sweep
for how_many in list_howmany:
    rng = np.random.default_rng(seed)
    torch.manual_seed(seed)
    print("## IN", how_many)
    (votes, X_train, X_test, y_train, y_test, scores) = get_votes(
        workers,
        X,
        y,
        test_size=0.3,
        all_workers=how_many,
        mean_per_worker=None,
        nozeros=None,
        train_transform=ToTensor(),
        test_transform=ToTensor(),
    )
    res_meth = {
        "train_acc": [],
        "accuracy": [],
        "ECE": [],
    }
    for rep in range(5):
        waum, y_waum, acc_waum_train = agg_waum(
            X_train, votes.copy(), n_classes, y_train, 0.1, 100
        )
        soft, y_soft, acc_soft_train = agg_soft(
            votes.copy(), n_classes, y_train
        )
        ds, y_ds, acc_ds_train = agg_ds(votes.copy(), n_classes, y_train)
        glad, y_glad, acc_glad_train = agg_glad(
            votes.copy(), n_classes, y_train
        )
        acc_waum, Zt_waum, y_pred_waum, ece_waum, loss_waum = train(
            np.delete(X_train, waum.too_hard, axis=0),
            y_waum,
            X_test,
            y_test,
            n_classes,
            1000,
        )
        acc_soft, Zt_soft, y_pred_soft, ece_soft, loss_soft = train(
            X_train, y_soft, X_test, y_test, n_classes, 1000
        )
        acc_glad, Zt_glad, y_pred_glad, ece_glad, loss_glad = train(
            X_train, y_glad, X_test, y_test, n_classes, 1000
        )
        acc_ds, Zt_ds, y_pred_ds, ece_ds, loss_ds = train(
            X_train, y_ds, X_test, y_test, n_classes, 1000
        )

        res_meth["train_acc"].append(
            [
                acc_soft_train,
                acc_ds_train,
                acc_glad_train,
                acc_waum_train,
            ]
        )
        res_meth["accuracy"].append(
            [
                acc_soft,
                acc_ds,
                acc_glad,
                acc_waum,
            ]
        )
        res_meth["ECE"].append(
            [
                ece_soft,
                ece_ds,
                ece_glad,
                ece_waum,
            ]
        )
        print(res_meth)
    all_res.append(res_meth)
# %%
# Gather the per-repetition results into equal-length columns. The next cell
# flattens them into a long-format table with one row per repetition x method.
#
# The row count is derived from the recorded results rather than hard-coded.
# "howmany" and "method" must have exactly as many entries as the metric
# columns (n_reps * n_methods, here 5 * 4 = 20); any mismatch makes
# pd.DataFrame raise "All arrays must be of the same length".
_method_names = ["NaiveSoft", "DS", "GLAD", "WAUM"]  # same order as res_meth rows
fullresult = {
    "howmany": [],
    "trainacc": [],
    "ECE": [],
    "Acc": [],
    "method": [],
}
for i, res in enumerate(all_res):
    _n_reps = len(res["accuracy"])
    fullresult["howmany"].append(
        [list_howmany[i]] * _n_reps * len(_method_names)
    )
    fullresult["trainacc"].append(res["train_acc"])
    fullresult["ECE"].append(res["ECE"])
    fullresult["Acc"].append(res["accuracy"])
    fullresult["method"].append(_method_names * _n_reps)
# %%
from itertools import chain

# Flatten each column in place. One level of nesting is always removed. The
# metric columns hold a list per repetition, so they need a second pass.
for t in fullresult:
    fullresult[t] = list(chain.from_iterable(fullresult[t]))
    if any(isinstance(entry, list) for entry in fullresult[t]):
        fullresult[t] = list(chain.from_iterable(fullresult[t]))

# %%
# Mean of every metric column per (howmany, method) group. This is left as
# the cell's final expression so an interactive (# %%) session displays it.
pd.DataFrame(fullresult).groupby(["howmany", "method"]).mean()
# %%
# Replace the column dict with a DataFrame. Uncomment the next line to save it
# as CSV.
fullresult = pd.DataFrame(fullresult)
# fullresult.to_csv("up_to_3_red.csv", index=False)


# %%

# Running with many workers
# Experiment 2: 5-class make_classification data (1500 samples, 5 informative
# features, one cluster per class), a worker pool from get_workers(50) (the
# meaning of 50 is defined by get_workers — confirm), and how_many=5 passed to
# get_votes as `all_workers`. A single repetition compares soft voting, DS,
# GLAD, WAUM and the agg_waum_red variant ("Red").
seed = 5
import warnings

# Silences every warning for the remainder of the session.
warnings.filterwarnings("ignore")


# NOTE(review): `alpha` is not used in this cell; agg_waum/agg_waum_red below
# receive a literal 0.1 instead.
alpha = 0.1
n_classes = 5
X, y = make_classification(
    1500,
    n_classes=5,
    random_state=423,
    n_features=5,
    n_informative=5,
    n_clusters_per_class=1,
    n_redundant=0,
)

# X_small, y_small = make_circles(
#     n_samples=(125, 250), random_state=10, noise=0.1, factor=0.3
# )
# X_large, y_large = make_circles(
#     n_samples=(125, 250), random_state=10, noise=0.1, factor=0.7
# )
# y_large[y_large == 1] = 2
# X, y = np.vstack((X_small, X_large)), np.hstack((y_small, y_large))
rng = np.random.default_rng(seed)
torch.manual_seed(seed)

all_res = []
workers = get_workers(50)
list_howmany = [5]
for how_many in list_howmany:
    # Reseed per setting so each how_many value starts from the same RNG state.
    rng = np.random.default_rng(seed)
    torch.manual_seed(seed)
    print("## IN", how_many)
    (votes, X_train, X_test, y_train, y_test, scores) = get_votes(
        workers,
        X,
        y,
        test_size=0.3,
        # nozeros=True,
        # mean_per_worker=5,
        all_workers=how_many,
        train_transform=ToTensor(),
        test_transform=ToTensor(),
    )
    # One row per repetition; within a row, the column order is
    # [soft, ds, glad, waum, red]. The aggregation cell's method labels must
    # follow this order.
    res_meth = {
        "train_acc": [],
        "accuracy": [],
        "ECE": [],
    }
    for rep in range(1):
        # votes.copy() is a shallow copy: the inner per-task dicts are still
        # shared between the aggregators.
        waum, y_waum, acc_waum_train = agg_waum(
            X_train,
            votes.copy(),
            n_classes,
            y_train,
            0.1,
            100,
        )
        # Presumably wraps the WAUM_redundant model imported at the top of the
        # file — confirm in agg_waum_red.
        waum_red, y_red, acc_red_train = agg_waum_red(
            X_train, votes.copy(), n_classes, y_train, 0.1, 100
        )
        soft, y_soft, acc_soft_train = agg_soft(
            votes.copy(), n_classes, y_train
        )
        ds, y_ds, acc_ds_train = agg_ds(votes.copy(), n_classes, y_train)
        glad, y_glad, acc_glad_train = agg_glad(
            votes.copy(), n_classes, y_train
        )
        # Both WAUM variants train on X_train with their own `too_hard` rows
        # removed.
        acc_waum, Zt_waum, y_pred_waum, ece_waum, loss_waum = train(
            np.delete(X_train, waum.too_hard, axis=0),
            y_waum,
            X_test,
            y_test,
            n_classes,
            1000,
        )
        acc_soft, Zt_soft, y_pred_soft, ece_soft, loss_soft = train(
            X_train, y_soft, X_test, y_test, n_classes, 1000
        )
        acc_glad, Zt_glad, y_pred_glad, ece_glad, loss_glad = train(
            X_train, y_glad, X_test, y_test, n_classes, 1000
        )
        acc_ds, Zt_ds, y_pred_ds, ece_ds, loss_ds = train(
            X_train, y_ds, X_test, y_test, n_classes, 1000
        )
        acc_red, Zt_red, y_pred_red, ece_red, loss_red = train(
            np.delete(X_train, waum_red.too_hard, axis=0),
            y_red,
            X_test,
            y_test,
            n_classes,
            1000,
        )

        res_meth["train_acc"].append(
            [
                acc_soft_train,
                acc_ds_train,
                acc_glad_train,
                acc_waum_train,
                acc_red_train,
            ]
        )
        res_meth["accuracy"].append(
            [
                acc_soft,
                acc_ds,
                acc_glad,
                acc_waum,
                acc_red,
            ]
        )
        res_meth["ECE"].append(
            [
                ece_soft,
                ece_ds,
                ece_glad,
                ece_waum,
                ece_red,
            ]
        )
        print(res_meth)
        # Side-by-side per-task scores for inspection: exp(glad.beta), and the
        # values of the `waum` dicts of the per-worker and redundant WAUM
        # models (presumably one entry per task — confirm). The dict is rebuilt
        # on every repetition, so only the last one survives, and the empty
        # lists are immediately overwritten.
        df = {"glad": [], "waum": [], "concat": []}
        df["glad"] = np.exp(glad.beta)
        df["waum"] = list(waum.waum.values())
        df["concat"] = list(waum_red.waum.values())
    all_res.append(res_meth)
# %%
# Gather the per-repetition results into equal-length columns. The next cell
# flattens them into a long-format table with one row per repetition x method.
#
# The row count is derived from the recorded results rather than hard-coded.
# "howmany" and "method" must have exactly as many entries as the metric
# columns (n_reps * n_methods, here 1 * 5 = 5); any mismatch makes
# pd.DataFrame raise "All arrays must be of the same length".
_method_names = ["NaiveSoft", "DS", "GLAD", "WAUM", "Red"]  # res_meth row order
fullresult = {
    "howmany": [],
    "trainacc": [],
    "ECE": [],
    "Acc": [],
    "method": [],
}
for i, res in enumerate(all_res):
    _n_reps = len(res["accuracy"])
    fullresult["howmany"].append(
        [list_howmany[i]] * _n_reps * len(_method_names)
    )
    fullresult["trainacc"].append(res["train_acc"])
    fullresult["ECE"].append(res["ECE"])
    fullresult["Acc"].append(res["accuracy"])
    fullresult["method"].append(_method_names * _n_reps)
# %%
from itertools import chain

# Flatten each column in place. One level of nesting is always removed. The
# metric columns hold a list per repetition, so they need a second pass.
for t in fullresult:
    fullresult[t] = list(chain.from_iterable(fullresult[t]))
    if any(isinstance(entry, list) for entry in fullresult[t]):
        fullresult[t] = list(chain.from_iterable(fullresult[t]))


# %%
# Mean of every metric column per (howmany, method) group. This is left as
# the cell's final expression so an interactive (# %%) session displays it.
pd.DataFrame(fullresult).groupby(["howmany", "method"]).mean()
# %%
# Replace the column dict with a DataFrame for further inspection.
fullresult = pd.DataFrame(fullresult)
# %%
# Expe 3: lots of workers (the get_workers(None) pool repeated 200 times) on the
# 5-class make_classification data; the 3-circles variant is kept commented out
# below. how_many=5 is passed to get_votes as `all_workers`, and 3 repetitions
# compare soft voting, DS, GLAD and WAUM.
seed = 5


import warnings

# Silences every warning for the remainder of the session.
warnings.filterwarnings("ignore")


# NOTE(review): `alpha` is not used in this cell; agg_waum below receives a
# literal 0.1 instead.
alpha = 0.1

n_classes = 5
X, y = make_classification(
    1500,
    n_classes=5,
    random_state=423,
    n_features=5,
    n_informative=5,
    n_clusters_per_class=1,
    n_redundant=0,
)

# n_classes = 3
# X_small, y_small = make_circles(
#     n_samples=(125, 250), random_state=10, noise=0.1, factor=0.3
# )
# X_large, y_large = make_circles(
#     n_samples=(125, 250), random_state=10, noise=0.1, factor=0.7
# )
# y_large[y_large == 1] = 2
# X, y = np.vstack((X_small, X_large)), np.hstack((y_small, y_large))
rng = np.random.default_rng(seed)
torch.manual_seed(seed)

all_res = []
# NOTE(review): if get_workers returns a list, `* 200` repeats references to
# the same estimator objects rather than creating independent copies.
# get_votes fits every entry on the same training split, so a deterministic
# estimator yields identical copies — confirm this is intended.
workers = get_workers(None) * 200
list_howmany = [5]
for how_many in list_howmany:
    # Reseed per setting so each how_many value starts from the same RNG state.
    rng = np.random.default_rng(seed)
    torch.manual_seed(seed)
    print("## IN", how_many)
    (votes, X_train, X_test, y_train, y_test, scores) = get_votes(
        workers,
        X,
        y,
        test_size=0.3,
        all_workers=how_many,
        mean_per_worker=None,
        nozeros=None,
        train_transform=ToTensor(),
        test_transform=ToTensor(),
    )
    # One row per repetition; within a row, the column order is
    # [soft, ds, glad, waum]. The aggregation cell's method labels must follow
    # this order.
    res_meth = {
        "train_acc": [],
        "accuracy": [],
        "ECE": [],
    }
    for rep in range(3):
        # votes.copy() is a shallow copy: the inner per-task dicts are still
        # shared between the aggregators.
        waum, y_waum, acc_waum_train = agg_waum(
            X_train, votes.copy(), n_classes, y_train, 0.1, 100
        )
        soft, y_soft, acc_soft_train = agg_soft(
            votes.copy(), n_classes, y_train
        )
        ds, y_ds, acc_ds_train = agg_ds(votes.copy(), n_classes, y_train)
        glad, y_glad, acc_glad_train = agg_glad(
            votes.copy(), n_classes, y_train
        )
        # WAUM trains on X_train with its `too_hard` rows removed.
        acc_waum, Zt_waum, y_pred_waum, ece_waum, loss_waum = train(
            np.delete(X_train, waum.too_hard, axis=0),
            y_waum,
            X_test,
            y_test,
            n_classes,
            1000,
        )
        acc_soft, Zt_soft, y_pred_soft, ece_soft, loss_soft = train(
            X_train, y_soft, X_test, y_test, n_classes, 1000
        )
        acc_glad, Zt_glad, y_pred_glad, ece_glad, loss_glad = train(
            X_train, y_glad, X_test, y_test, n_classes, 1000
        )
        acc_ds, Zt_ds, y_pred_ds, ece_ds, loss_ds = train(
            X_train, y_ds, X_test, y_test, n_classes, 1000
        )
        res_meth["train_acc"].append(
            [
                acc_soft_train,
                acc_ds_train,
                acc_glad_train,
                acc_waum_train,
            ]
        )
        res_meth["accuracy"].append(
            [
                acc_soft,
                acc_ds,
                acc_glad,
                acc_waum,
            ]
        )
        res_meth["ECE"].append(
            [
                ece_soft,
                ece_ds,
                ece_glad,
                ece_waum,
            ]
        )
        print(res_meth)
    all_res.append(res_meth)
# %%
# Gather the per-repetition results into equal-length columns. The next cell
# flattens them into a long-format table with one row per repetition x method.
#
# The row count is derived from the recorded results rather than hard-coded.
# "howmany" and "method" must have exactly as many entries as the metric
# columns (n_reps * n_methods, here 3 * 4 = 12); any mismatch makes
# pd.DataFrame raise "All arrays must be of the same length".
_method_names = ["NaiveSoft", "DS", "GLAD", "WAUM"]  # same order as res_meth rows
fullresult = {
    "howmany": [],
    "trainacc": [],
    "ECE": [],
    "Acc": [],
    "method": [],
}
for i, res in enumerate(all_res):
    _n_reps = len(res["accuracy"])
    fullresult["howmany"].append(
        [list_howmany[i]] * _n_reps * len(_method_names)
    )
    fullresult["trainacc"].append(res["train_acc"])
    fullresult["ECE"].append(res["ECE"])
    fullresult["Acc"].append(res["accuracy"])
    fullresult["method"].append(_method_names * _n_reps)
# %%
from itertools import chain

# Flatten each column in place. One level of nesting is always removed. The
# metric columns hold a list per repetition, so they need a second pass.
for t in fullresult:
    fullresult[t] = list(chain.from_iterable(fullresult[t]))
    if any(isinstance(entry, list) for entry in fullresult[t]):
        fullresult[t] = list(chain.from_iterable(fullresult[t]))

# Long-format table (one row per repetition x method), then the mean of every
# metric per (howmany, method) group. The mean is left as the cell's final
# expression so an interactive (# %%) session displays it.
fullresult = pd.DataFrame(fullresult)
fullresult.groupby(["howmany", "method"]).mean()

# %%
