#!/usr/bin/env python3
# -*- coding: utf-8 -*-


from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
import argparse

from numpy.random import uniform
from itertools import product

from plots.plot_functions import dist, get_lemma_1_quadra, one_step, update, sample, dist_opt


# Command-line interface of the experiment. Every option has a default, so
# the script can be launched without any argument.
parser = argparse.ArgumentParser(description="experiments parameters")

parser.add_argument(
    "--n", type=int, help="number of participating clients", default=10)

parser.add_argument(
    "--m", type=int, help="number of selected clients", default=5)

parser.add_argument("--n_params", type=int,
                    help="dimension of the model", default=20)

# The local minima and the initial model are drawn with
# uniform(-bound, bound), hence the wording of the help text.
parser.add_argument(
    "--bound",
    type=int,
    help="half-width of the uniform interval [-bound, bound] used to draw "
         "the local minima and the initial model",
    default=1,
)

parser.add_argument("--eta_l", type=float,
                    help="Local learning rate", default=0.1)

# Float default so that the value has the declared type whether or not the
# option is given on the command line.
parser.add_argument("--eta_g", type=float,
                    help="Global learning rate", default=1.0)

parser.add_argument("--K", type=int, help="Number of Local SGD", default=10)

parser.add_argument("--n_draw", type=int,
                    help="Number of simulations", default=1)

# T is the length of the optimization loop of a single run, not a number of
# simulations (that is --n_draw).
parser.add_argument("--T", type=int,
                    help="Number of global aggregation rounds", default=10)


def one_process(p: float, loc_min: np.array, args):
    """Run the aggregation loop once per sampling scheme and record the loss.

    The first client is given importance ``p`` and the ``args.n - 1`` other
    clients share ``1 - p`` equally. The reference point ``glob_min`` is the
    importance-weighted combination of the rows of ``loc_min``.

    Args:
        p: importance of the first client.
        loc_min: array combined with the importances through ``dot``;
            presumably one row per client, i.e. ``(args.n, n_params)``
            (TODO confirm against the caller).
        args: namespace providing ``n``, ``m``, ``eta_l``, ``eta_g``, ``K``,
            ``T``, ``theta_0`` and ``samplings``.

    Returns:
        A dict mapping every name in ``args.samplings`` to a list of
        ``args.T + 1`` values of ``dist_opt(theta, glob_min) / 2``: one for
        the initial model and one after each round.
    """
    n_others = args.n - 1
    importances = np.array([p] + n_others * [(1 - p) / n_others])
    glob_min = importances.dot(loc_min)

    # Factor applied to (loc_min - theta) to obtain the clients' updates.
    # NOTE(review): looks like the closed form of K local gradient steps
    # with step size eta_l on a quadratic -- confirm.
    shrink = 1 - (1 - args.eta_l) ** args.K

    def run(scheme):
        # Each scheme starts from its own copy of the initial model, since
        # the model is updated in place below.
        theta = deepcopy(args.theta_0)
        trace = [dist_opt(theta, glob_min) / 2]

        for _ in range(args.T):
            # `sample` is defined in plots.plot_functions; its result is
            # used as the weights applied to the clients' updates.
            weights = sample(scheme, args.m, importances)

            # Aggregate the weighted client updates and move the model.
            theta += args.eta_g * weights.dot(shrink * (loc_min - theta))

            trace.append(dist_opt(theta, glob_min) / 2)

        return trace

    return {scheme: run(scheme) for scheme in args.samplings}


def plot_paper_quadratic(
    file_name: str,
    losses_iid: np.array,
    losses_niid: np.array,
    args,
):
    """
    Draw the loss histories of the iid (left) and niid (right) runs side by
    side, one curve per entry of ``args.samplings``, on a log-scaled y axis.

    ``losses_iid`` and ``losses_niid`` are indexed by sampling name. The
    figure is written to ``plots/<file_name>.pdf`` and then displayed.
    """

    plt.figure(figsize=(7, 2))

    panels = ((1, losses_iid), (2, losses_niid))
    for position, histories in panels:
        is_left = position == 1

        plt.subplot(1, 2, position)
        for name in args.samplings:
            plt.plot(histories[name], label=name)
        plt.yscale("log")

        # Only the left panel carries the y label and the legend.
        if is_left:
            plt.ylabel(r"${\left\Vert\theta^t - \theta^*\right\Vert}^2$")
        plt.xlabel(r"$t$")
        if is_left:
            plt.legend()

    plt.tight_layout(pad=0)
    plt.savefig(f"plots/{file_name}.pdf")

    plt.show()


if __name__ == "__main__":

    # Parse the command line into a regular argparse.Namespace. The run-time
    # settings that are not CLI options (theta_0, samplings) are attached to
    # the same object below so that a single `args` is passed around.
    args = parser.parse_args()

    print("Number of clients:", args.n, "\nNumber of sampled clients:", args.m)

    # Create the clients' local minima. The seed is fixed so that every run
    # draws the same minima and the same initial model.
    np.random.seed(1)
    loc_min_niid = uniform(-args.bound, args.bound,
                           size=(args.n, args.n_params))
    # iid case: every client shares the first client's minimum.
    loc_min_iid = np.tile(loc_min_niid[0], (args.n, 1))

    # Initial model the optimization starts from.
    args.theta_0 = uniform(-args.bound, args.bound, size=(1, args.n_params))

    # Sampling schemes to compare.
    args.samplings = ["Full", "MD", "Uniform"]

    # Importance of the first client; the others share 1 - p equally.
    p = 0.9
    print(args.eta_l)
    losses_niid = one_process(p, loc_min_niid, args)
    losses_iid = one_process(p, loc_min_iid, args)

    # Plot both settings and save the figure as plots/full_convergence.pdf.
    file_name = "full_convergence"

    plot_paper_quadratic(
        file_name,
        losses_iid,
        losses_niid,
        args,
    )
