from typing import List

import matplotlib.ticker as mticker

from path_learning.utils import create_dir
from analysis.analysis_utils import gather_analysis_results_df, ANALYSIS_CREATORS, MODEL_ANALYSIS_CREATORS, \
    NONLINEAR_DYNAMICS_ANALYSIS
from analysis.plotting import plot_model_step, performance_stability_plot, ANALYSIS_PLOTS, vector_plot, \
    step_number_effect

# Metadata columns requested for every experiment when gathering analysis results.
# NOTE(review): the leading spaces in " Experiment Time" / " Task Time" and the
# spelling "Dataloder" look unusual; they presumably mirror the headers of the
# stored results, so verify against the data source before changing any of them.
COLUMNS: List[str] = [
    "Experiment Date",
    " Experiment Time",
    "Task Date",
    " Task Time",
    "Experiment",
    "Model-type",
    "Model-task-step",
    "Seed",
    "Hash",
    "Status",
    "Task-step",
    "Dataloder",
    "Domains",
    "Loss-fct",
    "Task Duration",
    "User",
    "Logdir",
]


# TODO: read which result columns to load from the config instead of deriving them from the hard-coded dicts (ANALYSIS_CREATORS etc.)



class MathTextSciFormatter(mticker.Formatter):
    r"""Tick formatter that renders values in mathtext scientific notation.

    Each tick value is first formatted with ``fmt`` (a ``%``-style format that is
    expected to produce exponent notation, e.g. the default ``"%1.2e"``) and the
    result is rewritten as ``$<significand>{\times}10^{<exponent>}$``. A zero
    exponent is dropped, so ``1.0`` becomes ``$1.00$`` while ``1234.5`` becomes
    ``$1.23{\times}10^{3}$`` and ``0.001`` becomes ``$1.00{\times}10^{-3}$``.

    Formatted strings without an exponent part (``"inf"``, ``"nan"``, or output of
    a non-scientific ``fmt``) are wrapped in ``$...$`` unchanged instead of
    raising ``IndexError``.
    """

    def __init__(self, fmt="%1.2e"):
        # %-style format string applied to every tick value.
        self.fmt = fmt

    def __call__(self, x, pos=None):
        """Return the mathtext label for tick value ``x``; ``pos`` is unused."""
        s = self.fmt % x
        # Accept both 'e' (from %e) and 'E' (from %E) as the exponent marker.
        significand, _, exp_part = s.replace('E', 'e').partition('e')
        if not exp_part:
            # Nothing to rewrite (e.g. non-finite values format as 'inf'/'nan').
            return "${}$".format(s)
        # '%#.0e' style output leaves a trailing '.', e.g. '1.e+03'.
        significand = significand.rstrip('.')
        # Exponent part looks like '+03' / '-03': drop '+', keep '-'.
        sign = exp_part[0].replace('+', '')
        exponent = exp_part[1:].lstrip('0')
        if exponent:
            exponent = '10^{%s%s}' % (sign, exponent)
        if significand and exponent:
            s = r'%s{\times}%s' % (significand, exponent)
        else:
            # Zero exponent: the significand alone is the label.
            s = r'%s%s' % (significand, exponent)
        return "${}$".format(s)


def visualize_analysis_df(title: str, exp_names: list, analysis_config: str, plot_type: str):
    """
    Gather the stored analysis results of the selected experiments and plot them.

    :param title: title forwarded to each plotting function.
    :param exp_names: experiments whose results are collected via
        ``gather_analysis_results_df``.
    :param analysis_config: forwarded unchanged to each plotting function.
    :param plot_type: currently NOT forwarded -- the active plotting calls pass
        ``plot_type=None``; only the commented-out ``plot_model_step`` call uses it.
    :return: None. Creates ``ANALYSIS_PLOTS``, prints the first gathered row and
        calls the plotting helpers; returns early (after a message) when no
        results were gathered.
    """
    # 1. Build the result-column names to collect, then gather all results.
    # Every analysis key gets a "-train" and a "-test" column; all "-train"
    # columns come first.
    analysis_cols = [key + mode for mode in ("-train", "-test") for key in ANALYSIS_CREATORS]
    model_cols = list(MODEL_ANALYSIS_CREATORS)
    # Every nonlinear-dynamics key yields a "-weights" column followed by a "-grads" column.
    non_dyn_cols = [key + suffix for key in NONLINEAR_DYNAMICS_ANALYSIS for suffix in ("-weights", "-grads")]

    df = gather_analysis_results_df(exp_names, COLUMNS,
                                    analysis_cols, model_cols,
                                    non_dyn_cols)
    create_dir(ANALYSIS_PLOTS)

    # df.iloc[0] raises IndexError when nothing was gathered; report it instead.
    if len(df) == 0:
        print(f"No analysis results found for experiments: {exp_names}")
        return
    print(f"DF : {df.iloc[0]}")

    # TODO: implement individual plotting functions and automatic way to select them
    # TODO: plotting config to specify what to plot
    # 2. Plotting
    # NOTE(review): the caller's plot_type is not forwarded below (plot_type=None) -- confirm intended.
    # plot_model_step(df, title, analysis_config, plot_type, model_step=1)
    performance_stability_plot(df, title, analysis_config, plot_type=None)
    step_number_effect(df, title, analysis_config, plot_type=None)

    # vector_plot(df, title, analysis_config, plot_type=None)




