#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import numpy as np
import matplotlib.pyplot as plt
import argparse

from numpy.random import uniform
from itertools import product

from plots.plot_functions import dist, get_lemma_1_quadra, one_step


parser = argparse.ArgumentParser(description="experiments parameters")

parser.add_argument(
    "--n", type=int, help="number of participating clients", default=10)

parser.add_argument(
    "--m", type=int, help="number of selected clients", default=5)

parser.add_argument("--n_params", type=int,
                    help="dimension of the model", default=20)

# Half-width of the uniform distribution [-bound, bound) from which the
# script draws the clients' local minima and the initial model. Parsed as a
# float so that non-integer amplitudes (e.g. 0.5) are accepted.
parser.add_argument(
    "--bound",
    type=float,
    help="half-width of the uniform distribution used to draw the local "
    "minima and the initial model",
    default=1.0,
)

parser.add_argument("--eta_l", type=float,
                    help="Local learning rate", default=0.1)

parser.add_argument("--eta_g", type=float,
                    help="Global learning rate", default=1)

parser.add_argument("--K", type=int, help="Number of Local SGD", default=10)

parser.add_argument("--n_draw", type=int,
                    help="Number of simulations", default=1000)


def Decompo_theo(list_p: np.array, loc_min: np.array, args):
    """
    Theoretical expected distance to the FL optimum after one FL step.

    For every value p_1 in `list_p` and every scheme in `args.samplings`,
    client 1 receives importance p_1 and the other `args.n - 1` clients share
    1 - p_1 equally. The FL optimum is the importance-weighted combination
    of the rows of `loc_min`. The distance is computed with the
    Decomposition Theorem (calculus in Appendix) via `get_lemma_1_quadra`.
    The pseudo-scheme "Initial" instead reports `dist(args.theta_0, optimum)`.

    Returns a dict mapping each sampling name to a list holding one value
    per entry of `list_p`, in the same order.
    """
    results = dict((name, []) for name in args.samplings)

    for p in list_p:
        for name in args.samplings:
            share = (1 - p) / (args.n - 1)
            weights = np.array([p] + [share] * (args.n - 1))
            optimum = weights.dot(loc_min)

            value = (
                dist(args.theta_0, optimum)
                if name == "Initial"
                else get_lemma_1_quadra(
                    name, weights, loc_min, optimum, optimum, args
                )
            )
            results[name].append(value)

    return results


def practical_step(list_p: np.array, loc_min: np.array, args):
    """
    Simulated one-step distances, driven only by client-sampling randomness.

    For every value p_1 in `list_p` and every scheme in `args.samplings`,
    build the importances (p_1 for client 1, 1 - p_1 shared equally by the
    others) and the matching FL optimum. Then run `one_step`, which is
    initialized from `args.theta_0`.

    Returns a dict mapping each sampling name to a list with one entry per
    value of `list_p`, in order:
      - "Full": the result of a single `one_step` call (full participation,
        so there is no sampling randomness to average over);
      - any other scheme: a list of `args.n_draw` `one_step` results, to be
        averaged by the caller.
    """
    outcomes = {name: [] for name in args.samplings}

    for p in list_p:
        for name in args.samplings:
            share = (1 - p) / (args.n - 1)
            weights = np.array([p] + [share] * (args.n - 1))
            optimum = weights.dot(loc_min)

            is_full = name == "Full"
            n_calls = 1 if is_full else args.n_draw
            draws = [
                one_step(name, weights, loc_min, optimum, 1, args)
                for _ in range(n_calls)
            ]
            outcomes[name].append(draws[0] if is_full else draws)

    return outcomes


def plot_quadra_Decompo_theo(
    file_name: str,
    list_p: list,
    theo_dist_niid,
    theo_dist_iid,
    list_p_2,
    exp_dist_niid,
    exp_dist_iid,
    args,
):
    """
    Build the four-panel paper figure and save it to `plots/{file_name}.pdf`.

    Panels, left to right:
      (a) theoretical distances, iid local minima, against `list_p`;
      (b) theoretical distances, niid local minima, against `list_p`;
      (c) experimental distances averaged over draws, iid, against `list_p_2`;
      (d) experimental distances averaged over draws, niid, against
          `list_p_2`.
    Each panel has one curve per scheme in `args.samplings`. Only panel (a)
    carries the y-label and the legend.
    """
    # (x values, distances per sampling, average over draws?, y-range, title)
    panels = (
        (list_p, theo_dist_iid, False, (1.5, 4), "(a)"),
        (list_p, theo_dist_niid, False, (1, 4), "(b)"),
        (list_p_2, exp_dist_iid, True, (1.5, 4), "(c)"),
        (list_p_2, exp_dist_niid, True, (1, 4), "(d)"),
    )

    plt.figure(figsize=(9, 2.5))

    for idx, (xs, curves, averaged, y_range, tag) in enumerate(panels, 1):
        plt.subplot(1, 4, idx)
        for sampling in args.samplings:
            ys = curves[sampling]
            if averaged:
                # experimental entries hold the individual draws on axis 1
                ys = np.mean(ys, axis=1)
            plt.plot(xs, ys, label=sampling)
        plt.ylim(*y_range)
        plt.xlabel(r"$p_1$")
        plt.title(tag)
        if idx == 1:
            plt.ylabel(r"${\left\Vert\theta^1 - \theta^*\right\Vert}^2$")
            plt.legend()

    plt.tight_layout(pad=0)
    plt.savefig(f"plots/{file_name}.pdf")


def plot_variance_simulations(
    file_name: str,
    list_p: list,
    exp_dist_niid: np.array,
    exp_dist_iid: np.array,
    args,
):
    """
    Plot the spread of the simulated one-step distances and save the 2x2
    figure to `plots/{file_name}.pdf`.

    Only experimental distances are shown (no theoretical curves):
      - columns: (left) iid local minima, (right) niid local minima;
      - rows: (top) mean over draws with a +/- one standard deviation band,
        (bottom) mean over draws with the [min, max] envelope over draws.
    Each panel has one curve per scheme in `args.samplings`. Each
    `exp_dist_*[sampling]` is reduced along axis 1, the axis holding the
    draws for every p_1 value.
    """
    v_alpha = 0.5  # transparency of the shaded bands

    fig, ax = plt.subplots(2, 2)

    # (row, column, experimental distances); the row selects the band type.
    panels = (
        (0, 0, exp_dist_iid),
        (0, 1, exp_dist_niid),
        (1, 0, exp_dist_iid),
        (1, 1, exp_dist_niid),
    )
    for i, j, exp_dist in panels:
        for sampling in args.samplings:
            runs = np.asarray(exp_dist[sampling])
            mean = np.mean(runs, axis=1).reshape(-1)
            if i == 0:
                std = np.std(runs, axis=1).reshape(-1)
                low, high = mean - std, mean + std
            else:
                low = np.min(runs, axis=1).reshape(-1)
                high = np.max(runs, axis=1).reshape(-1)

            ax[i, j].plot(list_p, mean, label=sampling)
            ax[i, j].fill_between(list_p, low, high, alpha=v_alpha)

    ax[0, 0].legend()
    for i, j in product(range(2), range(2)):
        ax[i, j].set_ylim(0, 8)
        ax[i, j].set_title("(" + chr(97 + 2 * i + j) + ")")
        if i == 1:
            ax[i, j].set_xlabel(r"$p_1$")
        if j == 0:
            ax[i, j].set_ylabel(
                r"${\left\Vert\theta^1 - \theta^*\right\Vert}^2$")

    fig.tight_layout(pad=0)
    fig.savefig(f"plots/{file_name}.pdf")


if __name__ == "__main__":

    # Parse into a regular argparse.Namespace. The simulation helpers read
    # their hyper-parameters from it, and `theta_0` / `samplings` are
    # attached to it below.
    args = parser.parse_args()

    # Importances give client 1 weight p_1 and split 1 - p_1 over the other
    # n - 1 clients, which divides by zero when there is a single client.
    if args.n < 2:
        parser.error("--n must be at least 2")

    print("Number of clients:", args.n, "\nNumber of sampled clients:", args.m)

    # CREATE THE CLIENTS' LOCAL MINIMA
    # Seed the global NumPy RNG so the draws below are reproducible.
    np.random.seed(1)
    # niid: each client draws its own local minimum in [-bound, bound);
    # iid: every client reuses client 0's local minimum.
    loc_min_niid = uniform(-args.bound, args.bound,
                           size=(args.n, args.n_params))
    loc_min_iid = np.tile(loc_min_niid[0], (args.n, 1))

    # INITIAL MODEL FL STARTS FROM
    args.theta_0 = uniform(-args.bound, args.bound, size=(1, args.n_params))
    # Importance p_1 of the first client, swept over [0, 1].
    list_p = np.linspace(0, 1, 200)

    # ONE STEP EXPECTED IMPROVEMENT NIID AND IID
    args.samplings = ["Full", "MD", "Uniform"]

    theo_dist_niid = Decompo_theo(list_p, loc_min_niid, args)
    theo_dist_iid = Decompo_theo(list_p, loc_min_iid, args)

    # ONE STEP EXPERIMENTAL IMPROVEMENT AVERAGED OVER MANY RUNS NIID AND IID
    exp_dist_niid = practical_step(list_p, loc_min_niid, args)
    exp_dist_iid = practical_step(list_p, loc_min_iid, args)

    # PLOT THE FIGURE FOR THE PAPER WITH MEAN FOR PRACTICAL
    file_name = f"one_step_mean_{args.n_draw}"
    plot_quadra_Decompo_theo(
        file_name,
        list_p,
        theo_dist_niid,
        theo_dist_iid,
        list_p,
        exp_dist_niid,
        exp_dist_iid,
        args,
    )

    # PLOT FIGURE APPENDIX FOR EXPERIMENTAL INCLUDING VARIANCE
    file_name = f"one_step_variance_{args.n_draw}"
    plot_variance_simulations(
        file_name,
        list_p,
        exp_dist_niid,
        exp_dist_iid,
        args,
    )
