from tqdm import tqdm
from joblib import Parallel, delayed

from math import sqrt
import numpy as np

import matplotlib.pyplot as plt


from Tools.compute_optimal_polynomial import compute_optimal_polynomial


def compute_eigen_rate(sqrt_m,
                       h_list,
                       eig,
                       ):
    r"""
    Compute the value of the polynomial in eig corresponding to the tuning \sqrt{m} and h_list

    The value is obtained as half the trace of the product, taken in the order of h_list, of the
    2x2 matrices [[2 M (1 - h eig), -1], [1, 0]] where M = (\sqrt{m} + 1 / \sqrt{m}) / 2.
    With an empty h_list the product is the identity and the returned value is 1.

    Args:
        sqrt_m: (float) the square root of the momentum parameter, also the rate if the corresponding polynomial supremum on \Lambda is smaller than 1
        h_list: (list of floats) list of the step sizes
        eig: (float) the input on which we evaluate the polynomial, also a possible eigenvalue of the hessian

    Returns: (float) evaluation of the above-mentioned polynomial on eig

    """

    # The rate only depends on M and not m
    M = (sqrt_m + 1 / sqrt_m) / 2

    # Accumulate the k steps transition matrix with plain ndarrays and the @ operator
    # (np.matrix is no longer recommended by NumPy, and its in-place product silently
    # casts the result back to the dtype of the accumulator)
    A = np.eye(2)

    for h in h_list:

        A = A @ np.array([[2 * M * (1 - h * eig), -1], [1, 0]])

    # Return its half trace corresponding to the evaluation of the polynomial on eig
    return np.trace(A) / 2


def compute_rate(sqrt_m,
                 h_list,
                 eig_list,
                 ):
    r"""
    Compute the supremum value of the HB polynomial on \Lambda

    The supremum is taken over the finite set eig_list: the polynomial is evaluated with
    compute_eigen_rate on every element and the largest absolute value is returned.

    Args:
        sqrt_m: (float) the square root of the momentum parameter, also the rate if the corresponding polynomial supremum on \Lambda is smaller than 1
        h_list: (list of floats) list of the step sizes
        eig_list: (list of float) the inputs on which we evaluate the polynomial, also the discretization of \Lambda

    Returns: (float) supremum value of the above-mentioned polynomial on \Lambda

    Raises:
        ValueError: if eig_list is empty, since the supremum over an empty set is undefined here

    """

    # Evaluate the HB polynomial on all of the given eigenvalues
    eigen_values = [compute_eigen_rate(sqrt_m=sqrt_m, h_list=h_list, eig=eig) for eig in eig_list]

    # Fail with an explicit message instead of NumPy's generic "zero-size array" reduction error
    if len(eigen_values) == 0:
        raise ValueError("eig_list must contain at least one eigenvalue")

    # Return the max of the absolute values (np.max propagates NaN, which callers may rely on)
    return float(np.max(np.abs(eigen_values)))


def optimize_h(sqrt_m,
               nb_step_sizes,
               eig_list,
               grid_size,
               n_jobs=-1,
               ):
    r"""
    Find the list of step sizes h which minimizes the supremum of the HB polynomial for a given \sqrt(m).
    The knowledge of \sqrt(m) fixes P(0).
    If the \sqrt(m) is the right one, this supremum is exactly 1.
    The right \sqrt(m) is the maximum one for which such list of h exists.
    The finding of h here is done using grid search: each h_i is taken among the inverses of
    grid_size points evenly spaced between the smallest and the largest element of eig_list.

    Args:
        sqrt_m: (float) the square root of the momentum parameter, also the rate if the corresponding polynomial supremum on \Lambda is smaller than 1
        nb_step_sizes: (int) the number of step sizes used for HB, also the number of h we need to determine here
        eig_list: (list of float) the inputs on which we evaluate the polynomial, also the discretization of \Lambda
        grid_size: (int) the number of tried values of h_i. Hence we try grid_size^nb_step_sizes different combinations of h
        n_jobs: (int) the number of used cpus. All are used if n_jobs=-1

    Returns: (tuple) the list of h as well as the corresponding supremum that should be exactly one if the correct \sqrt(m) is used.
        Combinations whose supremum is NaN are ignored when selecting the minimum.

    """

    # Define smallest and largest eigenvalues
    mu = np.min(eig_list)
    L = np.max(eig_list)

    # Create one ndarray of candidate values per step size
    h_values_list = [1 / np.linspace(mu, L, grid_size) for _ in range(nb_step_sizes)]

    # Meshgrid h_values to explore all possible combinations and flatten the obtained matrices
    h_values_list = np.meshgrid(*h_values_list)
    h_values_list = [h_values.reshape(-1) for h_values in h_values_list]

    # Create a list of (multidimensional) h to explore
    h_values_list = [list(h_list) for h_list in zip(*h_values_list)]

    # Create the list of values
    values_list = Parallel(n_jobs=n_jobs)(delayed(compute_rate)(sqrt_m=sqrt_m, h_list=h_list, eig_list=eig_list) for h_list in tqdm(h_values_list))

    # The np.float alias was removed in NumPy 1.24; the builtin float designates the same dtype
    values_list = np.array(values_list, dtype=float)

    # Find the index of the smallest value
    best_position = int(np.nanargmin(values_list))

    # Return the best h, and the corresponding value
    return h_values_list[best_position], values_list[best_position]


def find_best_tuning(eig_list,
                     nb_step_sizes=1,
                     grid_size=100,
                     n_jobs=-1,
                     plot_polynomial=False,
                     ):
    r"""
    Find the best tuning m and h from \Lambda and the number of h we want

    Args:
        eig_list: (list of float) the inputs on which we evaluate the polynomial, also the discretization of \Lambda
        nb_step_sizes: (int) the number of step sizes used for HB, also the number of h we need to determine here
        grid_size: (int) the number of tried values of h_i. Hence we try grid_size^nb_step_sizes different combinations of h
        n_jobs: (int) the number of used cpus. All are used if n_jobs=-1
        plot_polynomial: (bool) False by default. If True, print the supremum reached by the best tuning and plot the best polynomial as well as the one corresponding to the best tuning for a sanity check. Both must be the same.

    Returns: (tuple) the tuning parameters \sqrt(m) and h

    Raises:
        ValueError: if P(0)^2 < 1, in which case \sqrt(m) cannot be recovered from P(0)

    """

    # Compute the coefficients of the extremal polynomial
    # (P is indexed by increasing degree below, so P[0] is used as P(0))
    P = compute_optimal_polynomial(eig_list=eig_list, deg=nb_step_sizes)

    # From P(0), we recover \sqrt(m)
    # math.sqrt would fail with a bare "math domain error" on a negative argument: say why instead
    if P[0] ** 2 - 1 < 0:
        raise ValueError("Cannot recover sqrt(m): P(0) = {} has absolute value smaller than 1".format(P[0]))
    sqrt_m = (P[0] - sqrt(P[0] ** 2 - 1)) ** (1 / nb_step_sizes)

    # From \sqrt(m), we optimize h and return it as well as the supremum value of the polynomial over \Lambda
    best_h_list, should_be_one = optimize_h(sqrt_m=sqrt_m,
                                            nb_step_sizes=nb_step_sizes,
                                            eig_list=eig_list,
                                            grid_size=grid_size,
                                            n_jobs=n_jobs,
                                            )

    # If plot polynomial is True, then run a couple of sanity checks
    if plot_polynomial:

        # This last value must be one. The next print is a sanity check
        print("This value should be close to one: {}".format(should_be_one))

        # Define an ndarray of inputs
        x = np.linspace(0, min(eig_list) + max(eig_list), 30)

        # Compute the values of the extremal polynomial from its coefficients
        V = np.array([x ** k for k in range(nb_step_sizes + 1)]).T
        y1 = V@P

        # Compute the values of the extremal polynomial from m and h
        y2 = [compute_eigen_rate(sqrt_m=sqrt_m, h_list=best_h_list, eig=x0) for x0 in x]

        # Plot the 2 polynomials
        # They must be identical to each other
        plt.figure()
        plt.plot(x, y1, 'c')
        plt.plot(x, y2, 'r')
        plt.title("Comparison of the ways of computing the same extremal polynomial")
        plt.xlabel("Eigenvalues")
        plt.ylabel("Polynomial values")
        plt.legend(["Extremal polynomial computed from its coefficients", "Extremal polynomial computed from m and h"], loc='best')
        plt.show()

    # Return \sqrt(m) and h
    # \sqrt(m) is also the rate of HB
    return sqrt_m, best_h_list


if __name__ == "__main__":

    # Number of step sizes we want to use to tune HB
    nb_step_sizes = 2

    # Discretization of \Lambda: 10 evenly spaced points in [1, 2] followed by 10 in [4, 5]
    eig_list = [*np.linspace(1, 2, 10), *np.linspace(4, 5, 10)]

    # Search for the best tuning of HB, with the sanity-check print and plot enabled
    tuning = find_best_tuning(
        eig_list=eig_list,
        nb_step_sizes=nb_step_sizes,
        grid_size=30,
        n_jobs=-1,
        plot_polynomial=True,
    )
    sqrt_m, best_h_list = tuning

    # Report the rate, the momentum coefficient (square of the rate) and the step sizes
    summary = "HB reaches a rate of {} using the momentum coefficient {} and the {} step sizes {}".format(sqrt_m, sqrt_m**2, nb_step_sizes, best_h_list)
    print(summary)
