"""
find_init_weights.py
====================================
Finding the weights of the to map an specific activation function
"""

import json
import numpy as np
from .utils import fit_rational_to_base_function
import torch
import os
from models.rational.numpy.rationals import Rational_version_A, Rational_version_B, \
    Rational_version_C, Rational_version_N


def plot_result(x_array, rational_array, target_array,
                original_func_name="Original function"):
    import matplotlib.pyplot as plt
    plt.plot(x_array, rational_array, label="Rational approx")
    plt.plot(x_array, target_array, label=original_func_name)
    plt.legend()
    plt.grid()
    plt.show()


def append_to_config_file(params, approx_name, w_params, d_params, overwrite=None):
    rational_full_name = f'Rational_version_{params["version"]}{params["nd"]}/{params["dd"]}'
    cfd = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    with open(f'{cfd}/rationals_config.json') as json_file:
        rationals_dict = json.load(json_file)  # rational_version -> approx_func
    approx_name = approx_name.lower()
    if rational_full_name in rationals_dict:
        if approx_name in rationals_dict[rational_full_name]:
            if overwrite is None:
                overwrite = input(f'Rational_{params["version"]} approximation of {approx_name} already exist. \
                                  \nDo you want to replace it ? (y/n)') in ["y", "yes"]
            if not overwrite:
                print("Parameters not stored")
                return
        else:
            rationals_params = {"init_w_numerator": w_params.tolist(),
                                "init_w_denominator": d_params.tolist(),
                                "ub": params["ub"], "lb": params["lb"]}
            rationals_dict[rational_full_name][approx_name] = rationals_params
            with open(f'{cfd}/rationals_config.json', 'w') as outfile:
                json.dump(rationals_dict, outfile, indent=1)
            print("Parameters stored in rationals_config.json")
            return
    rationals_dict[rational_full_name] = {}
    rationals_params = {"init_w_numerator": w_params.tolist(),
                        "init_w_denominator": d_params.tolist(),
                        "ub": params["ub"], "lb": params["lb"]}
    rationals_dict[rational_full_name][approx_name] = rationals_params
    with open(f'{cfd}/rationals_config.json', 'w') as outfile:
        json.dump(rationals_dict, outfile, indent=1)
    print("Parameters stored in rationals_config.json")



def typed_input(text, type, choice_list=None):
    assert isinstance(text, str)
    while True:
        try:
            inp = input(text)
            typed_inp = type(inp)
            if choice_list is not None:
                assert typed_inp in choice_list
            break
        except ValueError:
            print(f"Please provide an type: {type}")
            continue
        except AssertionError:
            print(f"Please provide a value within {choice_list}")
            continue
    return typed_inp


FUNCTION = None


def find_weights(function, function_name=None, degrees=None, bounds=None,
                 version=None, plot=None, save=None, overwrite=None):
    """
    Finds the weights of the numerator and the denominator of the rational function.
    Beside `function`, all parameters can be left to the default ``None``. \n
    In this case, user is asked to provide the params interactively.

    Arguments:
            function (callable):
                The function to approximate (e.g. from torch.functional).\n
            function_name (str):
                The name of this function (used at Rational initialisation)\n
            degrees (tuple of int):
                The degrees of the numerator (P) and denominator (Q).\n
                Default ``None``
            bounds (tuple of int):
                The bounds to approximate on (e.g. (-3,3)).\n
                Default ``None``
            version (str):
                Version of Rational to use. Rational(x) = P(x)/Q(x)\n
                `A`: Q(x) = 1 + \|b_1.x\| + \|b_2.x\| + ... + \|b_n.x\|\n
                `B`: Q(x) = 1 + \|b_1.x + b_2.x + ... + b_n.x\|\n
                `C`: Q(x) = 0.1 + \|b_1.x + b_2.x + ... + b_n.x\|\n
                `D`: like `B` with noise\n
            plot (bool):
                If True, plots the fitted and target functions.
                Default ``None``
            save (bool):
                If True, saves the weights in the config file.
                Default ``None``
            save (bool):
                If True, if weights already exist for this configuration, they are overwritten.
                Default ``None``
    Returns:
        tuple: (numerator, denominator) if not `save`, otherwise `None` \n

    """
    # To be changed by the function you want to approximate
    if function_name is None:
        function_name = input("approximated function name: ")
    FUNCTION = function

    def function_to_approx(x):
        # return np.heaviside(x, 0)
        x = torch.tensor(x)
        return FUNCTION(x)

    if degrees is None:
        nd = typed_input("degree of the numerator P: ", int)
        dd = typed_input("degree of the denominator Q: ", int)
        degrees = (nd, dd)
    else:
        nd, dd = degrees
    if bounds is None:
        print("On what range should the function be approximated ?")
        lb = typed_input("lower bound: ", float)
        ub = typed_input("upper bound: ", float)
    else:
        lb, ub = bounds
    nb_points = 100000
    step = (ub - lb) / nb_points
    x = np.arange(lb, ub, step)
    if version is None:
        version = typed_input("Rational Version: ", str,
                              ["A", "B", "C", "D", "N"])
    if version == 'A':
        rational = Rational_version_A
    elif version == 'B':
        rational = Rational_version_B
    elif version == 'C':
        rational = Rational_version_C
    elif version == 'D':
        rational = Rational_version_B
    elif version == 'N':
        rational = Rational_version_N

    w_params, d_params = fit_rational_to_base_function(rational, function_to_approx, x,
                                                       degrees=degrees,
                                                       version=version)
    print(f"Found coeffient :\nP: {w_params}\nQ: {d_params}")
    if plot is None:
        plot = input("Do you want a plot of the result (y/n)") in ["y", "yes"]
    if plot:
        plot_result(x, rational(x, w_params, d_params), function_to_approx(x),
                    function_name)
    params = {"version": version, "name": function_name, "ub": ub, "lb": lb,
              "nd": nd, "dd": dd}
    if save is None:
        save = input("Do you want to store them in the json file ? (y/n)") in ["y", "yes"]
    if save:
        append_to_config_file(params, function_name, w_params, d_params, overwrite)
    else:
        print("Parameters not stored")
        return w_params, d_params
