#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt


def CIFAR_appendix(
    folder: str,
    samplings: list,
    n_iter: int,
    n_SGD: int,
    lr_local: float,
    batch_size: int,
    n_seeds: int,
    lr_global=1.0,
    decay=1.0,
):
    """
    Plot seed-averaged, importance-weighted histories for three CIFAR10
    variants and save the figure to ``plots/CIFAR_appendix.pdf``.

    The figure has one subplot per dataset variant ("CIFAR10_0.1",
    "CIFAR10_0.01", "CIFAR10_0.001"). Each subplot has one curve per entry
    of `samplings`. A curve is built in three steps:

    1. The histories returned by `get_hist` are averaged over `n_seeds` seeds.
    2. They are averaged across clients, weighted by `clients_importances`.
    3. The result is smoothed with a 5-round moving average.

    Parameters
    ----------
    folder : str
        Forwarded to `get_hist` to select which stored history to load.
    samplings : list
        Sampling scheme names. There is one curve per entry, and the entries
        are also used as the figure legend labels.
    n_iter, n_SGD, lr_local, batch_size, lr_global, decay :
        Hyperparameters forwarded to `get_hist` to identify the saved runs.
        `n_iter` also fixes the number of rows of the accumulated history
        (``n_iter + 1``).
    n_seeds : int
        Number of seeds averaged per curve. Must be at least 1.

    Raises
    ------
    ValueError
        If `n_seeds` is smaller than 1 (the seed average would divide by 0).
    """
    if n_seeds < 1:
        raise ValueError(f"n_seeds must be at least 1, got {n_seeds}")

    from py_func.importances import clients_importances
    from plots.plot_functions import get_hist

    # Number of consecutive rounds averaged when smoothing each curve.
    smoothing_window = 5

    def rolling_window(a, window):
        """Return the stacked length-`window` sliding windows of 1-D `a`."""
        return np.array(
            [a[i: i + window] for i in range(len(a) - window + 1)]
        )

    def one_subplot(ax, dataset: str, weight_type: str, lr: float, n_sampled: int):
        """
        Plot one curve per sampling scheme on `ax`.

        A sampling curve is the average of `n_seeds` simulations, weighted
        across clients by their importances and smoothed over
        `smoothing_window` rounds.
        """
        importances = clients_importances(weight_type, dataset)

        for sampling in samplings:
            # NOTE(review): assumes get_hist returns an array broadcastable
            # to (n_iter + 1, n_clients) -- confirm against plot_functions.
            sampled_clients = np.zeros((n_iter + 1, len(importances)))

            for seed in range(n_seeds):
                sampled_clients += get_hist(
                    folder,
                    dataset,
                    sampling,
                    n_iter,
                    n_SGD,
                    batch_size,
                    lr_global,
                    lr,
                    n_sampled,
                    weight_type,
                    decay,
                    seed,
                )

            sampled_clients /= n_seeds
            # Weighted mean over clients (axis 1) -> one value per round.
            hist = np.average(sampled_clients, axis=1, weights=importances)
            hist_plot = np.mean(rolling_window(hist, smoothing_window), axis=1)
            ax.plot(hist_plot, label=sampling)

    # One subplot row per dataset variant.
    list_dataset = ["CIFAR10_0.1", "CIFAR10_0.01", "CIFAR10_0.001"]
    list_alpha = [0.1, 0.01, 0.001]
    n_rows = len(list_dataset)
    list_weights = ["ratio"] * n_rows
    list_lr = [lr_local] * n_rows
    list_m = [10] * n_rows

    fig, axes = plt.subplots(n_rows, 1, figsize=(4, 6))

    for idx, (weight_type, lr, m, dataset, alpha) in enumerate(
        zip(list_weights, list_lr, list_m, list_dataset, list_alpha)
    ):
        ax = axes[idx]

        # PLOT THE AVERAGED SAMPLING CURVES
        one_subplot(ax, dataset, weight_type, lr, m)

        # FORMAT THE SUBPLOT: "(a)", "(b)", ... prefix via chr(97 + idx)
        ax.set_title(rf"({chr(97 + idx)}) - $\alpha$ = {alpha}", pad=0.0)
        ax.set_ylabel(r"$\mathcal{L}(\theta^t)$")

    axes[-1].set_xlabel("# rounds")

    # Every subplot draws the same samplings in the same order, so the first
    # subplot's line handles serve as the figure-wide legend. Handles are
    # passed explicitly: combining a positional argument with `labels=` is
    # deprecated in Matplotlib, and the positional value is discarded.
    handles, _ = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        samplings,
        ncol=3,
        bbox_to_anchor=(1.02, 0.042),
    )

    fig.tight_layout(pad=0.0)
    fig.savefig("plots/CIFAR_appendix.pdf", bbox_inches="tight")


# Experiment configuration for the CIFAR10 appendix figure.
samplings = ["MD", "Uniform", "Clustered"]
n_iter, n_SGD = 1000, 100
lr_local = 0.05
batch_size = 64
n_seeds = 30

# NOTE(review): this call runs at module level, so importing this file also
# builds and saves the figure; consider an `if __name__ == "__main__":` guard.
CIFAR_appendix(
    folder="loss",
    samplings=samplings,
    n_iter=n_iter,
    n_SGD=n_SGD,
    lr_local=lr_local,
    batch_size=batch_size,
    n_seeds=n_seeds,
)
