# %%
import numpy as np
import torch
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from peerannot.models.Soft import Soft
from peerannot.models.GLAD import GLAD
from peerannot.models.WAUM import WAUM
from sklearn.datasets import (
    make_circles,
    make_classification,
    make_moons,
)
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.svm import SVC, LinearSVC
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.cm as cm
from sklearn.datasets import load_wine
import seaborn as sns

# Large axis titles/labels and LaTeX text rendering for the figures.
# NOTE: text.usetex=True makes matplotlib render text through an external
# LaTeX installation, which must be available.
plt.rcParams.update({"axes.titlesize": "xx-large"})
plt.rcParams.update({"axes.labelsize": "xx-large"})
plt.rcParams["text.usetex"] = True
plt.rcParams[
    "text.latex.preamble"
] = r"\usepackage{amsmath,amsfonts,amsthm,amssymb}"
# Module-level numpy generator: get_votes() draws the answering workers from
# this global `rng`; later cells rebind it to reseed individual experiments.
rng = np.random.default_rng(1)
torch.manual_seed(1)

# %%


class Toy_dataset(torch.utils.data.Dataset):
    """Map-style dataset over tabular points, shaped for torchvision transforms.

    Each item is the 4-tuple ``(sample, "lab", label, idx)``:

    * ``sample`` is row ``idx`` of ``tasks`` reshaped to ``(-1, 1, 1)``; when a
      ``transform`` is set it is applied and the result cast to float32;
    * ``"lab"`` is a constant string (the training loops in this file ignore it);
    * ``label`` is ``truth[idx]``, passed through ``target_transform`` if set;
    * ``idx`` lets training code look up per-task aggregated targets.
    """

    def __init__(self, tasks, truth, transform=None, target_transform=None):
        self.tasks, self.truth = tasks, truth
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        return self.tasks.shape[0]

    def __getitem__(self, idx):
        sample = self.tasks[idx].reshape(-1, 1, 1)
        if self.transform:
            sample = self.transform(sample).type(torch.FloatTensor)
        target = self.truth[idx]
        if self.target_transform:
            target = self.target_transform(target)
        return sample, "lab", target, idx


class model(nn.Module):
    """Small MLP: ``input_dim -> 30 -> 20 -> output_dim`` with ReLU activations.

    Inputs are flattened to ``(-1, input_dim)`` first, so extra singleton
    dimensions (e.g. from ``Toy_dataset``'s ``reshape(-1, 1, 1)`` followed by a
    transform) are dropped before the first linear layer. Returns raw logits.
    """

    def __init__(self, input_dim=2, output_dim=2):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        # Keep this registration order: parameter initialisation (and hence
        # torch RNG consumption) follows it.
        self.linear = torch.nn.Linear(input_dim, 30)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(30, 20)
        self.fc2 = nn.Linear(20, output_dim)

    def forward(self, x):
        flat = x.reshape(-1, self.input_dim)
        hidden = self.relu(self.fc1(self.relu(self.linear(flat))))
        return self.fc2(hidden)


def get_votes(
    workers,
    X,
    y,
    test_size=0.3,
    all_workers=True,
    train_transform=ToTensor(),
    test_transform=ToTensor(),
):
    """Fit the simulated workers and collect their votes on the training split.

    Parameters
    ----------
    workers : list
        Classifiers exposing ``fit``, ``score`` and ``predict``.
    X, y : array-like
        Features (rows are tasks) and labels of the whole dataset.
    test_size : float
        Fraction held out for testing (``train_test_split`` with
        ``random_state=42``).
    all_workers : bool or int
        ``True``: every worker answers every task. An int ``k``: each task
        receives a uniformly drawn number of answers in
        ``1..min(k, len(workers))`` from distinct workers. The count is
        clamped because distinct workers cannot be drawn beyond
        ``len(workers)``.
    train_transform : callable
        Not used here. The training dataset is built with a plain
        ``ToTensor()``. Kept for interface compatibility.
    test_transform : callable
        Transform of the returned test dataset.

    Returns
    -------
    tuple
        ``(votes, X_train, X_test, y_train, y_test, ds_train, ds_test)``.
        ``votes[i]`` maps a worker index (as ``str``) to that worker's
        predicted label for training task ``i``.

    Notes
    -----
    The number and identity of answering workers are drawn from the
    module-level ``rng``.
    """
    X_train_bg, X_test, y_train_bg, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42
    )
    toy_ds_train = Toy_dataset(X_train_bg, y_train_bg, transform=ToTensor())
    toy_ds_test = Toy_dataset(X_test, y_test, transform=test_transform)

    for worker in workers:
        worker.fit(X_train_bg, y_train_bg)
        score = worker.score(X_test, y_test)
        print(worker, f"{score:.3f}")

    n_workers = len(workers)
    if all_workers is not True:
        # Cannot draw more distinct workers than exist (replace=False below).
        max_answers = min(all_workers, n_workers)
    n_tasks = X_train_bg.shape[0]
    votes = {i: {} for i in range(n_tasks)}
    for i in range(n_tasks):
        n_ans = (
            n_workers
            if all_workers is True
            else rng.choice(range(1, max_answers + 1))
        )
        who = rng.choice(range(n_workers), size=n_ans, replace=False)
        task_features = X_train_bg[i].reshape(1, -1)
        for w in who:
            votes[i][str(w)] = workers[w].predict(task_features)[0]
    return (
        votes,
        X_train_bg,
        X_test,
        y_train_bg,
        y_test,
        toy_ds_train,
        toy_ds_test,
    )


def get_comparison(votes, y_train_bg, output_dim=2):
    """Aggregate ``votes`` with naive soft voting and with GLAD.

    Prints, for each method, the fraction of training tasks whose argmax
    aggregated label equals ``y_train_bg``. GLAD's EM is run with
    ``epsilon=1e-6`` and at most 100 iterations.

    Returns
    -------
    tuple
        ``(y_soft, soft, y_glad, glad)``: the aggregated label
        distributions and the fitted aggregation objects.
    """

    def _train_accuracy(probas):
        return np.mean(np.argmax(probas, axis=1) == y_train_bg)

    soft = Soft(answers=votes, n_classes=output_dim)
    y_soft = soft.get_probas()
    print("Soft", _train_accuracy(y_soft))

    glad = GLAD(n_classes=output_dim, answers=votes)
    glad.run_em(epsilon=1e-6, maxiter=100)
    y_glad = glad.get_probas()
    print("GLAD", _train_accuracy(y_glad))

    return y_soft, soft, y_glad, glad


@torch.no_grad()
def predict(dataloader, model):
    """Return the argmax class predicted by ``model`` for every sample.

    ``dataloader`` must yield 4-tuples whose first element is the input batch;
    the other three entries are ignored. The model is switched to eval mode.

    Returns
    -------
    np.ndarray
        Flat float64 array of predicted class indices in loader order (empty
        if the loader yields nothing). The dtype matches the historical
        behaviour of accumulating onto ``np.array([])``.
    """
    model.eval()
    # The empty float seed keeps the float64 dtype and makes an empty loader
    # return an empty array instead of failing in np.concatenate.
    chunks = [np.array([])]
    for x_batch, _, _, _ in dataloader:
        preds = torch.argmax(model(x_batch), dim=1)
        chunks.append(preds.numpy().ravel())
    # Single concatenation instead of re-stacking the whole array per batch.
    return np.concatenate(chunks)


def get_waum(
    toydataset,
    votes,
    y_train_bg,
    alpha=0.1,
    n_iter=30,
    lr=0.1,
    input_dim=2,
    output_dim=2,
):
    """Run WAUM on the training tasks with a freshly initialised ``model``.

    A new network is trained inside ``WAUM`` with cross-entropy and SGD at
    learning rate ``lr`` for ``n_iter`` iterations, on CPU. ``alpha`` is
    forwarded to ``WAUM.run``.

    Prints the fraction of kept tasks (all tasks except ``waum.too_hard``)
    whose argmax WAUM label equals the true label.

    Returns
    -------
    tuple
        ``(waum, y_waum)``: the fitted WAUM object and its label
        distributions for the kept tasks.
    """
    net = model(input_dim, output_dim)
    loss_fn = torch.nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=lr)
    waum = WAUM(
        toydataset,
        votes,
        output_dim,
        net,
        loss_fn,
        sgd,
        n_iter,
        DEVICE="cpu",
    )
    waum.run(alpha=alpha)
    y_waum = waum.get_probas()

    kept_truth = np.delete(y_train_bg, waum.too_hard, axis=0)
    print("WAUM", np.mean(np.argmax(y_waum, axis=1) == kept_truth))
    return waum, y_waum


def _hide_ticks(axes):
    """Remove ticks and tick labels from every axis in ``axes``."""
    for axis in axes:
        axis.tick_params(
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
            bottom=False,
        )


def _finish_figure(fig, filename, save, plot):
    """Optionally save ``fig`` to ``filename``, then show it or close it.

    Closing the figure when ``plot`` is False keeps repeated ``run`` calls
    from accumulating open figures that a later ``plt.show()`` would display.
    """
    if save:
        fig.savefig(filename)
    if plot:
        plt.show()
    else:
        plt.close(fig)


def train_plot(
    workers,
    X_train_bg,
    y_train_bg,
    X_test,
    y_test,
    waum,
    y_waum,
    y_glad,
    y_soft,
    name,
    toy_ds_train,
    toy_ds_test,
    input_dim,
    output_dim,
    train_transform=ToTensor(),
    votes=None,
    lr=0.1,
    n_epoch=1000,
    save=False,
    plot=True,
):
    """Train one network per aggregation strategy, report test accuracy, plot.

    For each aggregated label set (naive soft, GLAD, WAUM), a fresh ``model``
    is trained and evaluated:

    * it is trained with SGD and cross-entropy on the aggregated
      distributions for ``n_epoch`` epochs;
    * the learning rate is multiplied by 0.1 after epoch 500;
    * it is evaluated on ``toy_ds_test``.

    The aggregated training labels and the test predictions are then plotted
    side by side. For WAUM, the removed tasks are also marked, and a second
    figure shows the WAUM score density and the per-task WAUM map.

    Parameters
    ----------
    workers : list
        Fitted workers; their training-set predictions are plotted first.
    X_train_bg, y_train_bg, X_test, y_test : np.ndarray
        Train/test features and true labels. Columns 0 and 1 are plotted.
    waum : WAUM
        Fitted WAUM object. ``waum.too_hard`` indexes removed training tasks;
        ``waum.waum`` and ``waum.quantile`` feed the density figure.
    y_waum, y_glad, y_soft : np.ndarray
        Aggregated label distributions. ``y_waum`` excludes ``too_hard`` rows.
    name : str
        Suffix of the saved PDF filenames.
    toy_ds_train, toy_ds_test : Toy_dataset
        Train/test datasets. ``toy_ds_train.transform`` is overwritten with
        ``train_transform``.
    input_dim, output_dim : int
        Network input size and number of classes.
    train_transform : callable
        Transform applied to training samples.
    votes : dict, optional
        Only used by the "Redundancy" strategy, which is currently disabled.
    lr : float
        Initial SGD learning rate.
    n_epoch : int
        Training epochs per strategy.
    save : bool
        Save every figure as a PDF.
    plot : bool
        Show figures. When False they are closed instead of displayed.

    Returns
    -------
    list of float
        Test accuracy per strategy, in the order soft, GLAD, WAUM.
    """
    waum_dataset = Toy_dataset(
        np.delete(X_train_bg, waum.too_hard, axis=0),
        np.delete(y_train_bg, waum.too_hard, axis=0),
        transform=train_transform,
    )
    accuracy = []
    criterion = torch.nn.CrossEntropyLoss()
    methods = [
        "Aggregation = Naïve soft",
        "Aggregation = GLAD",
        "Aggregation = WAUM",
        # "Aggregation = Redundancy",
    ]

    # Ground truth and every worker's answers on the training set.
    fig, ax = plt.subplots(1, len(workers) + 1, figsize=(10, 5))
    _hide_ticks(ax)
    for axis in ax:
        axis.set_aspect("equal")
    ax[0].scatter(X_train_bg[:, 0], X_train_bg[:, 1], c=y_train_bg, s=10)
    ax[0].set_title(r"Ground truth $\mathcal{D}_{\text{train}}$")
    for id_worker, worker in enumerate(workers):
        ax[id_worker + 1].scatter(
            X_train_bg[:, 0],
            X_train_bg[:, 1],
            c=worker.predict(X_train_bg),
            s=10,
        )
        ax[id_worker + 1].set_title(worker.__class__.__name__)
    plt.tight_layout()
    _finish_figure(fig, f"binary_workers_workers_{name}.pdf", save, plot)

    # NOTE: zip stops at the three label arrays, so a re-enabled
    # "Redundancy" entry would also need a matching labels entry.
    for id_method, (method, labels) in enumerate(
        zip(methods, [y_soft, y_glad, y_waum])
    ):
        is_waum = method == "Aggregation = WAUM"
        simplenet = model(input_dim, output_dim)
        optimizer = torch.optim.SGD(simplenet.parameters(), lr=lr)
        if is_waum:
            waum_dataset.transform = train_transform
            train_set = waum_dataset
        elif method == "Aggregation = Redundancy":
            # One training sample per (task, answer) pair.
            rows, answers = [], []
            for i in range(X_train_bg.shape[0]):
                for ans in votes[i].values():
                    rows.append(X_train_bg[i])
                    answers.append(ans)
            Xred = np.array(rows, dtype=float).reshape(-1, X_train_bg.shape[1])
            yred = np.array(answers, dtype=float)
            train_set = Toy_dataset(Xred, yred, transform=train_transform)
        else:
            toy_ds_train.transform = train_transform
            train_set = toy_ds_train
        train_dataloader = torch.utils.data.DataLoader(
            train_set, batch_size=128
        )
        test_dataloader = torch.utils.data.DataLoader(
            toy_ds_test, shuffle=False
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[500], gamma=0.1
        )
        for epoch in range(n_epoch):
            for x, _, _, id_ in train_dataloader:
                # id_ indexes the dataset actually iterated, so it lines up
                # with the (possibly WAUM-filtered) label array.
                if method != "Aggregation = Redundancy":
                    y = torch.Tensor(labels[id_]).type(torch.FloatTensor)
                else:
                    y = torch.Tensor(yred[id_]).type(torch.long)
                optimizer.zero_grad()
                loss = criterion(simplenet(x), y)
                loss.backward()
                optimizer.step()
            scheduler.step()

        preds = predict(test_dataloader, simplenet)
        test_acc = np.mean(y_test == preds)
        print(method, test_acc)
        accuracy.append(test_acc)

        # Colour values for the prediction panel: predicted class when
        # multiclass, raw class-1 logit when binary.
        with torch.no_grad():
            logits = simplenet(torch.Tensor(X_test))
        Zt = (
            torch.argmax(logits, axis=1).reshape(-1, 1).numpy()
            if output_dim > 2
            else logits[:, 1].numpy()
        )

        fig, ax = plt.subplots(
            1, 2, figsize=(10, 5), sharex=True, sharey=True
        )
        _hide_ticks(ax)
        if not is_waum:
            for axis in ax:
                axis.set_aspect("equal")

        if is_waum:
            # Tasks discarded by WAUM, drawn as large red squares.
            ax[0].scatter(
                X_train_bg[waum.too_hard, 0],
                X_train_bg[waum.too_hard, 1],
                c="red",
                s=200,
                marker="s",
            )
            X_shown = np.delete(X_train_bg, waum.too_hard, axis=0)
        else:
            X_shown = X_train_bg
        # Binary: colour by P(class 1); multiclass: by the argmax class.
        yplot = labels[:, 1] if output_dim <= 2 else np.argmax(labels, axis=1)
        im = ax[0].scatter(X_shown[:, 0], X_shown[:, 1], c=yplot)
        ax[1].scatter(X_test[:, 0], X_test[:, 1], c=Zt)
        ax[0].set_title(r"$\mathcal{D}_{\text{train}}$")
        ax[0].set_ylabel(method)
        ax[1].set_title(r"Predictions")
        plt.tight_layout()
        fig.colorbar(im, ax=ax.ravel().tolist())
        _finish_figure(
            fig, f"binary_workers_pt{id_method+1}_{name}.pdf", save, plot
        )

        if is_waum:
            # WAUM score density (red dashed line at waum.quantile) and the
            # per-task WAUM map on the full training set.
            fig, ax = plt.subplots(1, 2, figsize=(10, 5))
            _hide_ticks(ax)
            waum_scores = list(waum.waum.values())
            g = sns.kdeplot(waum_scores, ax=ax[0])
            g.vlines(
                x=[waum.quantile],
                ymin=0,
                ymax=max(waum_scores),
                colors=["red"],
                ls="--",
            )
            plot_waum = ax[1].scatter(
                X_train_bg[:, 0],
                X_train_bg[:, 1],
                c=waum_scores,
                cmap=cm.plasma_r,
            )
            plt.colorbar(plot_waum)
            ax[0].set_xlabel("")
            ax[0].set_ylabel("")
            ax[1].set_title(r"$\mathrm{WAUM}\ \mathcal{D}_{\text{train}}$")
            ax[0].set_title(r"Density $\mathrm{WAUM}$")
            plt.tight_layout()
            _finish_figure(
                fig,
                f"binary_workers_pt{id_method+1}_density_{name}.pdf",
                save,
                plot,
            )
    return accuracy


def get_workers():
    """Build the simulated crowd of imperfect classifiers.

    Returns, in this order:

    * a linear SVM;
    * an RBF SVC stopped after a single solver iteration (``max_iter=1``);
    * a 5-tree gradient-boosting model.

    All are seeded with ``random_state=1``. Other candidates tried earlier
    include LogisticRegression(max_iter=1), 3-NN, and a polynomial-kernel SVC.
    """
    linear_worker = LinearSVC(random_state=1)
    truncated_svc = SVC(random_state=1, max_iter=1)
    small_boosting = GradientBoostingClassifier(n_estimators=5, random_state=1)
    return [linear_worker, truncated_svc, small_boosting]


# %%
def run(
    X,
    y,
    name,
    output_dim=2,
    input_dim=2,
    save=False,
    all_workers=True,
    train_transform=ToTensor(),
    test_transform=ToTensor(),
    plot=True,
    alpha=0.1,
):
    """Run the full pipeline on one dataset: votes, aggregation, WAUM, training.

    Steps: simulate worker votes (70/30 split), aggregate them with soft
    voting and GLAD, run WAUM (100 iterations, lr 0.1), then train and plot
    one network per aggregation.

    Returns
    -------
    tuple
        Eleven entries, by index:

        0. ``workers``
        1. ``votes``
        2. ``toy_ds_train``
        3. ``toy_ds_test``
        4. ``y_soft``
        5. ``soft``
        6. ``y_glad``
        7. ``glad``
        8. ``y_waum``
        9. ``waum``
        10. ``acc`` (list of test accuracies: soft, GLAD, WAUM)
    """
    workers = get_workers()
    split = get_votes(
        workers,
        X,
        y,
        test_size=0.3,
        all_workers=all_workers,
        train_transform=train_transform,
        test_transform=test_transform,
    )
    (
        votes,
        X_train_bg,
        X_test,
        y_train_bg,
        y_test,
        toy_ds_train,
        toy_ds_test,
    ) = split

    y_soft, soft, y_glad, glad = get_comparison(votes, y_train_bg, output_dim)
    waum, y_waum = get_waum(
        toy_ds_train,
        votes,
        y_train_bg,
        alpha=alpha,
        n_iter=100,
        lr=0.1,
        input_dim=input_dim,
        output_dim=output_dim,
    )

    acc = train_plot(
        workers,
        X_train_bg,
        y_train_bg,
        X_test,
        y_test,
        waum,
        y_waum,
        y_glad,
        y_soft,
        name,
        toy_ds_train,
        toy_ds_test,
        input_dim,
        output_dim,
        votes=votes,
        train_transform=train_transform,
        save=save,
        plot=plot,
    )
    results = (
        workers,
        votes,
        toy_ds_train,
        toy_ds_test,
        y_soft,
        soft,
        y_glad,
        glad,
        y_waum,
        waum,
        acc,
    )
    return results


# %%
from sklearn.datasets import load_wine
import seaborn as sns

# Candidate benchmark datasets as (X, y) pairs, all with fixed random_state.
# NOTE(review): `datasets` is not referenced by any later cell in this file;
# the experiments below regenerate their data inline.
wine = load_wine()
datasets = [
    make_circles(n_samples=500, noise=0.1, factor=0.5, random_state=1),
    make_classification(
        n_samples=500,
        n_features=2,
        n_redundant=0,
        n_informative=2,
        random_state=0,
        n_clusters_per_class=1,
    ),
    make_moons(n_samples=500, noise=0.2, random_state=0),
    # Two wine features (columns 0 and 9) so the data stays 2D.
    (wine.data[:, [0, 9]], wine.target),
]

# %%
# Circles where each task gets a random number of answers (all_workers=4).
# NOTE(review): get_workers() returns 3 workers, so drawing 4 distinct
# answers only works if get_votes caps the count at the number of workers.
res_circles = run(
    *make_circles(n_samples=500, noise=0.2, factor=0.4, random_state=1),
    "circles_up_4_votes",
    save=False,
    all_workers=4,
    alpha=0.01,
)

# %%
# Sensitivity of the WAUM-trained network's test accuracy to alpha:
# 10 repetitions per alpha on the same circles data. Global RNGs are not
# reseeded, so repetitions differ only through the continuing RNG streams.
alphas = [0.001, 0.01, 0.1, 0.25, 0.5]
all_acc = {al: [] for al in alphas}
for alpha in alphas:
    for rep in range(10):
        res_circles = run(
            *make_circles(
                n_samples=500, noise=0.2, factor=0.4, random_state=1
            ),
            "circles",
            save=False,
            plot=False,
            alpha=alpha,
        )
        # res_circles[-1] is the accuracy list (soft, GLAD, WAUM);
        # index 2 is the network trained on WAUM labels.
        all_acc[alpha].append(res_circles[-1][2])
print(all_acc)
# Mean and standard deviation of test accuracy per alpha.
for key, val in all_acc.items():
    print(key, np.mean(val), np.std(val))
# %% artificial ambiguity
# Single circles run with figures requested (plot=True), alpha=0.01.
res_circles = run(
    *make_circles(n_samples=500, noise=0.2, factor=0.4, random_state=1),
    "circles",
    save=False,
    plot=True,
    alpha=0.01,
)

# %%

# 10 repetitions on circles: test accuracy of the networks trained on soft,
# GLAD and WAUM labels (entries 0, 1, 2 of run()'s last return value).
all_acc = {"soft": [], "glad": [], "waum": []}
for rep in range(10):
    res_circles = run(
        *make_circles(n_samples=500, noise=0.2, factor=0.4, random_state=1),
        "circles",
        save=False,
        plot=False,
        alpha=0.01,
    )
    all_acc["soft"].append(res_circles[-1][0])
    all_acc["glad"].append(res_circles[-1][1])
    all_acc["waum"].append(res_circles[-1][2])
print(all_acc)
for key, val in all_acc.items():
    print(key, np.mean(val), np.std(val))


# %%

import pandas as pd

# Difficulty analysis of the last circles run.
# Left figure: GLAD beta per task (sorted), coloured by number of votes.
# Right figure: training points, with the 5 lowest-WAUM tasks in red and the
# 5 lowest-beta tasks in green.
answers = res_circles[1]
glad = res_circles[7]
# run() returns (..., y_waum, waum, acc): the WAUM object is at index -2;
# index -1 is the accuracy list, which has no `.waum` attribute.
waum = res_circles[-2]
# NOTE(review): color_palette() only returns a palette; it does not set it.
sns.color_palette("tab10")
df = {
    "beta": glad.beta,
    "n_votes": [len(answers[ans]) for ans in answers],
}
df = pd.DataFrame(df)
df = df.sort_values(by="beta")
# Tasks that received exactly one vote.
answ_1 = {
    task: value for task, value in answers.items() if len(value.keys()) == 1
}
plt.figure()
# NOTE(review): `sizes` has no effect unless a `size=` variable is given.
sns.scatterplot(
    x=range(0, len(df.beta)),
    y=df.beta,
    data=df,
    hue="n_votes",
    sizes=(400, 40),
    palette="tab10",
    linewidth=0,
)
plt.ylabel(r"$\log(\hat\beta)$")
plt.xlabel("Sorted task index")
plt.legend(title=r"$|\mathcal{A}(x)|$", loc="upper left")
# plt.savefig("circles_difficulty_by_nb_votes.pdf")
plt.show()

df = pd.DataFrame({"waum": waum.waum})
df = df.sort_values(by=["waum"])
plt.scatter(
    res_circles[2].tasks[:, 0],
    res_circles[2].tasks[:, 1],
    c=res_circles[2].truth,
)
for i in range(5):
    plt.scatter(
        res_circles[2].tasks[df.index[i]][0],
        res_circles[2].tasks[df.index[i]][1],
        c="red",
        s=200,
    )
    plt.scatter(
        res_circles[2].tasks[np.argsort(glad.beta)[i]][0],
        res_circles[2].tasks[np.argsort(glad.beta)[i]][1],
        c="green",
        s=200,
    )
plt.show()
# %%
votes = res_circles[1]
# WAUM object: index -2 of run()'s result (index -1 is the accuracy list).
waum = res_circles[-2]
# for i in range(len(waum.too_hard)):
#     print(i, len(votes[waum.too_hard[i]]))

# TODO: plot beta vs number of votes and WAUM vs number of votes.

# %% With and without data augmentation


class AddGaussianNoise:
    """Tensor transform that adds i.i.d. Gaussian noise.

    Returns ``tensor + N(0, 1) * std + mean``, with the standard normal noise
    drawn from torch's global RNG and shaped like ``tensor``.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        noise = torch.randn(tensor.size())
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{type(self).__name__}(mean={self.mean}, std={self.std})"


# Compare training without (all_acc) and with (all_acc_aug) Gaussian-noise
# augmentation on the same circles data, for seeds 0, 1, 2.
all_acc = {"soft": [], "glad": [], "waum": []}
all_acc_aug = {"soft": [], "glad": [], "waum": []}

data = make_circles(n_samples=500, noise=0.2, factor=0.4, random_state=1)
for seed in range(3):
    print("########", seed)
    # Rebind the module-level rng (read by get_votes) and reseed torch so
    # both runs of this seed start from the same random state.
    rng = np.random.default_rng(seed)
    torch.manual_seed(seed)
    res_circles_1 = run(
        *data,
        "circles_4_votes",
        save=False,
        all_workers=True,
        plot=False,
        alpha=0.01,
    )
    all_acc["soft"].append(res_circles_1[-1][0])
    all_acc["glad"].append(res_circles_1[-1][1])
    all_acc["waum"].append(res_circles_1[-1][2])

    # With probability 0.75, add N(0, 0.5^2) noise to each training sample.
    train_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.RandomApply([AddGaussianNoise(0.0, 0.5)], p=0.75),
        ]
    )
    test_transform = transforms.ToTensor()
    torch.manual_seed(seed)
    rng = np.random.default_rng(seed)
    res_circles_2 = run(
        *data,
        "circles_votes_aug",
        save=False,
        all_workers=True,
        train_transform=train_transform,
        test_transform=test_transform,
        plot=False,
        alpha=0.01,
    )
    all_acc_aug["soft"].append(res_circles_2[-1][0])
    all_acc_aug["glad"].append(res_circles_2[-1][1])
    all_acc_aug["waum"].append(res_circles_2[-1][2])

print(all_acc)
print(all_acc_aug)
# %%
# Mean / std test accuracy per method: first without, then with augmentation.
for dic in [all_acc, all_acc_aug]:
    for key, val in dic.items():
        print(key, np.mean(val), np.std(val))


# %% not very ambiguous
# Well-separated two-class data (class_sep=1.5, no label flips), 3 repetitions
# without reseeding.
X, y = make_classification(
    n_samples=500,
    n_features=2,
    n_redundant=0,
    n_informative=2,
    random_state=0,
    n_clusters_per_class=2,
    flip_y=0.0,
    class_sep=1.5,
    hypercube=True,
)
all_acc = {"soft": [], "glad": [], "waum": []}
for rep in range(3):
    res_gmm = run(
        X,
        y,
        "gmm",
        save=False,
        plot=False,
        alpha=0.1,
    )
    all_acc["soft"].append(res_gmm[-1][0])
    all_acc["glad"].append(res_gmm[-1][1])
    all_acc["waum"].append(res_gmm[-1][2])
print(all_acc)
for key, val in all_acc.items():
    print(key, np.mean(val), np.std(val))

# %% ambiguity necessary
# Moons: one run that saves its figures (save=True), then 10 unsaved
# repetitions for accuracy statistics. The first run's accuracies are not
# recorded in all_acc.
X, y = make_moons(n_samples=500, noise=0.2, random_state=0)
all_acc = {"soft": [], "glad": [], "waum": []}
res_moons = run(
    X,
    y,
    "moons",
    save=True,
    plot=False,
    alpha=0.1,
)

for rep in range(10):
    res_moons = run(
        X,
        y,
        "moons",
        save=False,
        plot=False,
        alpha=0.1,
    )
    all_acc["soft"].append(res_moons[-1][0])
    all_acc["glad"].append(res_moons[-1][1])
    all_acc["waum"].append(res_moons[-1][2])
print(all_acc)
for key, val in all_acc.items():
    print(key, np.mean(val), np.std(val))

# %%

# Three-class data from two make_circles draws. Label 1 of the second draw
# (factor=0.7) is relabelled 2 before stacking, giving classes {0, 1, 2}.
X_small, y_small = make_circles(
    n_samples=(125, 250), random_state=10, noise=0.2, factor=0.3
)
X_large, y_large = make_circles(
    n_samples=(125, 250), random_state=10, noise=0.2, factor=0.7
)
y_large[y_large == 1] = 2
X, y = np.vstack((X_small, X_large)), np.hstack((y_small, y_large))


# Same seed for both runs (module-level rng and torch), so only alpha differs.
seed = 5
rng = np.random.default_rng(seed)
torch.manual_seed(seed)
res_3circles = run(
    X, y, "3circles", save=False, plot=False, alpha=1e-2, output_dim=3
)
rng = np.random.default_rng(seed)
torch.manual_seed(seed)

res_3circles = run(
    X, y, "3circles", save=False, plot=False, alpha=1e-1, output_dim=3
)
# %%
# Same three-class construction with less noise (0.1) and random_state=0;
# 3 repetitions seeded with 5, 6, 7.
X_small, y_small = make_circles(
    n_samples=(125, 250), random_state=0, noise=0.1, factor=0.3
)
X_large, y_large = make_circles(
    n_samples=(125, 250), random_state=0, noise=0.1, factor=0.7
)
y_large[y_large == 1] = 2
X, y = np.vstack((X_small, X_large)), np.hstack((y_small, y_large))

all_acc = {"soft": [], "glad": [], "waum": []}
for rep in range(3):
    rng = np.random.default_rng(rep + 5)
    torch.manual_seed(rep + 5)
    res_3circles = run(
        X, y, "3circles", save=False, plot=False, alpha=0.01, output_dim=3
    )
    all_acc["soft"].append(res_3circles[-1][0])
    all_acc["glad"].append(res_3circles[-1][1])
    all_acc["waum"].append(res_3circles[-1][2])
print(all_acc)

for key, val in all_acc.items():
    print(key, np.mean(val), np.std(val))

# %%
# Alpha sweep on three-class circles. The two draws use different
# random_state (42, 1) and factors (0.2, 0.6); RNGs are not reseeded.
# Only the WAUM-trained network's accuracy (index 2) is recorded.
X_small, y_small = make_circles(
    n_samples=(125, 250), random_state=42, noise=0.1, factor=0.2
)
X_large, y_large = make_circles(
    n_samples=(125, 250), random_state=1, noise=0.1, factor=0.6
)
y_large[y_large == 1] = 2
X, y = np.vstack((X_small, X_large)), np.hstack((y_small, y_large))

alphas = [0.001, 0.01, 0.1, 0.25, 0.5]
all_acc = {al: [] for al in alphas}
for alpha in alphas:
    for rep in range(3):
        res = run(
            X, y, "3circles", save=False, plot=False, alpha=alpha, output_dim=3
        )
        all_acc[alpha].append(res[-1][2])
print(all_acc)
for key, val in all_acc.items():
    print(key, np.mean(val), np.std(val))


# %%
# Per-task difficulty maps for the most recent `res_circles` run:
# left panel exp(GLAD beta), right panel WAUM, each divided by its maximum.
# Only the WAUM scatter gets the shared colorbar on the right.
# Indices follow run()'s return tuple: -2 -> waum, -4 -> glad, 2 -> train set.
waum = res_circles[-2]
glad = res_circles[-4]
X_train_bg = res_circles[2].tasks
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for a, tt in enumerate(ax):
    tt.tick_params(
        left=False,
        right=False,
        labelleft=False,
        labelbottom=False,
        bottom=False,
    )
plot_glad = ax[0].scatter(
    X_train_bg[:, 0],
    X_train_bg[:, 1],
    c=np.exp(glad.beta) / np.max(np.exp(glad.beta)),
    cmap=cm.plasma_r,
)
# list / numpy scalar works because the numpy scalar converts the list to an
# array.
plot_waum = ax[1].scatter(
    X_train_bg[:, 0],
    X_train_bg[:, 1],
    c=list(waum.waum.values()) / np.max(list(waum.waum.values())),
    cmap=cm.plasma_r,
)
# Reserve the right 20% of the figure for a manually placed colorbar axis.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(plot_waum, cax=cbar_ax)
plt.show()

# %%
# Four-class make_classification data (one cluster per class, no label
# flips); single run with output_dim=4.
X, y = make_classification(
    n_samples=500,
    n_features=2,
    n_redundant=0,
    n_informative=2,
    random_state=0,
    n_clusters_per_class=1,
    flip_y=0.0,
    class_sep=1,
    hypercube=True,
    n_classes=4,
)
res_gmm4 = run(X, y, "gmm4", save=False, plot=False, alpha=0.1, output_dim=4)

# %%
from sklearn.datasets import make_blobs

# Four overlapping blobs (cluster_std=2), one run with figures requested,
# then the same GLAD-vs-WAUM difficulty maps as for circles.
X, y_true = make_blobs(
    n_samples=500, centers=4, cluster_std=2, random_state=56
)
res_gmm4 = run(
    X,
    y_true,
    "blobs",
    save=False,
    plot=True,
    alpha=0.1,
    output_dim=4,
    all_workers=True,
)
# Indices follow run()'s return tuple: -2 -> waum, -4 -> glad, 2 -> train set.
waum = res_gmm4[-2]
glad = res_gmm4[-4]
X_train_bg = res_gmm4[2].tasks
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for a, tt in enumerate(ax):
    tt.tick_params(
        left=False,
        right=False,
        labelleft=False,
        labelbottom=False,
        bottom=False,
    )
# Left: exp(GLAD beta) normalised by its maximum.
plot_glad = ax[0].scatter(
    X_train_bg[:, 0],
    X_train_bg[:, 1],
    c=np.exp(glad.beta) / np.max(np.exp(glad.beta)),
    cmap=cm.plasma_r,
)
# Right: WAUM normalised by its maximum (this scatter drives the colorbar).
plot_waum = ax[1].scatter(
    X_train_bg[:, 0],
    X_train_bg[:, 1],
    c=list(waum.waum.values()) / np.max(list(waum.waum.values())),
    cmap=cm.plasma_r,
)
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(plot_waum, cax=cbar_ax)
plt.show()


# %%
