from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit
from scipy.stats import linregress

import cfg
from src.figures import metric_change


def fig_al_schematic(ax):
    """Draw a schematic comparing a random-sampling and an active-learning curve.

    Two synthetic curves of the form ``a * (1 - exp(a_0 - x / b))`` are drawn,
    a slowly converging one for random sampling (large ``b``) and a fast one for
    AL (small ``b``). For each curve the plot also shows noisy sample points,
    those points connected by dashed lines, and the noise-free curve. A
    dash-dotted horizontal line marks the common ceiling ``a``.

    Args:
        ax: Matplotlib axes to draw into.
    """
    # Private fixed-seed generator. It yields exactly the same noise as the former
    # ``np.random.seed(4)`` + ``np.random.normal`` calls, but does not reseed the
    # global NumPy RNG for the rest of the program.
    rng = np.random.RandomState(4)

    def compute_approximation_curve(a, b, x):
        a_0 = -0.2
        return a * (1 - np.exp(a_0 - x / b))

    # curve parameters: shared ceiling a; a larger b converges more slowly
    a = 0.9
    b_random = 600
    b_al = 150
    x = np.linspace(0, 1400, 20)

    # noise-free curves
    y_random = compute_approximation_curve(a, b_random, x)
    y_al = compute_approximation_curve(a, b_al, x)

    # noisy datapoints, clipped to [0, a] so they never exceed the ceiling
    noise_level = 0.04
    y_random_points = np.clip(y_random + rng.normal(loc=0, scale=noise_level, size=len(x)), 0, a)
    y_al_points = np.clip(y_al + rng.normal(loc=0, scale=noise_level, size=len(x)), 0, a)
    # force both noisy series to share their first point (x = 0)
    y_al_points[0] = y_random_points[0]

    # axes styling
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xlabel('# added training samples (x)')
    ax.set_ylabel('Performance (p)')

    # ceiling line (raw strings keep the LaTeX backslashes without invalid escapes)
    ax.axhline(y=a, color='silver', linestyle='-.', label=r'Ceiling Perf. ($a_{\infty}$)')

    # random sampling
    color_rand = cfg.figures['random']['color']
    marker_rand = cfg.figures['random']['marker']
    label_rand = cfg.figures['random']['label']
    ax.scatter(x, y_random_points, marker=marker_rand, facecolors='none', edgecolors=color_rand,
               label=label_rand + ' (scatter)')
    ax.plot(x, y_random_points, linestyle='--', label=label_rand + ' (connected)', color=color_rand)
    ax.plot(x, y_random, label=label_rand + r' ($\hat{p}_{\text{rand}}$)', color=color_rand)

    # active learning
    color_al = 'k'
    marker_al = '*'
    ax.scatter(x, y_al_points, marker=marker_al, label='AL (scatter)', facecolors='none', edgecolors=color_al)
    ax.plot(x, y_al_points, linestyle='--', label='AL (connected)', color=color_al)
    ax.plot(x, y_al, label=r'AL ($\hat{p}_{\text{AL}}$)', color=color_al)

    ax.set_xlim([0, None])

    ax.legend(frameon=False, labelspacing=0, bbox_to_anchor=(.48, .7))


def fig_performance_over_training_samples(ax, tuple, df, qms, linestyle='-', moving_avg=False):
    """Plot macro-F1 against the number of training samples, one line per result row.

    Every row kept by ``filter_df_results`` is drawn as a line with a shaded band
    of +/- spread. The spread comes from ``evaluation_f1_macro_std`` when the
    numeric suffix of ``random_seed`` (text after the first ``'_'``) is at least
    30, and from ``evaluation_f1_macro_sem`` otherwise.

    Args:
        ax: Matplotlib axes to draw into.
        tuple: Experiment-configuration key forwarded to ``filter_df_results``.
        df: Results dataframe.
        qms: Query methods to include.
        linestyle: Line style used for every curve.
        moving_avg: If True, smooth the curve and the spread with a 5-point
            moving average (``'valid'`` mode) and drop the first 4 x-values so
            the lengths match.
    """
    selected = filter_df_results(df, tuple, qms)

    smoothing_window = 5
    kernel = np.ones(smoothing_window) / smoothing_window

    for _, experiment in selected.iterrows():
        x_values = experiment['nr_training_samples']
        f1 = experiment['evaluation_f1_macro']
        seed_suffix = int(experiment['random_seed'].split('_')[1])
        spread_column = 'evaluation_f1_macro_std' if seed_suffix >= 30 else 'evaluation_f1_macro_sem'
        spread = experiment[spread_column]

        if moving_avg:
            f1 = np.convolve(f1, kernel, mode='valid')
            spread = np.convolve(spread, kernel, mode='valid')
            x_values = x_values[smoothing_window - 1:]

        style = cfg.figures[experiment['qm']]
        line_color = style['color']
        line_marker = style['marker']
        line_label = style['label']

        # roughly five markers per curve, but at least every point
        every = max(1, len(x_values) // 5)

        ax.plot(x_values, f1, label=line_label, color=line_color, linestyle=linestyle, marker=line_marker,
                markevery=every)
        ax.fill_between(x_values, f1 - spread, f1 + spread, color=line_color, alpha=0.1)


def fig_frac_al_rand_over_performance(ax, tuple, df, qms, linestyle='-', moving_avg=False):
    """Plot, per query method, the ratio of added samples needed compared with random.

    For target macro-F1 values 0.00, 0.01, ..., 1.00 the number of *added*
    training samples needed to reach each value is obtained by linearly
    interpolating the inverted learning curve (performance -> added samples).
    The inverted curve of every row is divided by that of the ``'random'`` row.
    Targets outside a curve's performance range yield NaN (``bounds_error=False``).
    All rows reuse the x-values of the random row, truncated to the shorter length.

    Args:
        ax: Matplotlib axes to draw into.
        tuple: Experiment-configuration key forwarded to ``filter_df_results``.
        df: Results dataframe.
        qms: Query methods to include; must contain ``'random'``.
        linestyle: Line style; any style other than ``'-'`` hides the legend entry.
        moving_avg: If True, smooth performance with a 5-point moving average
            (``'valid'`` mode) and drop the first 4 x-values to match.

    Raises:
        IndexError: if no ``'random'`` row survives the filter.
    """
    selected = filter_df_results(df, tuple, qms)

    smoothing_window = 5
    kernel = np.ones(smoothing_window) / smoothing_window

    # reference: inverted learning curve of random sampling
    random_idx = selected[selected['qm'] == 'random'].index[0]
    samples = selected.loc[random_idx, 'nr_training_samples']
    added = samples - samples[0]
    perf_random = selected.loc[random_idx, 'evaluation_f1_macro']
    if moving_avg:
        perf_random = np.convolve(perf_random, kernel, mode='valid')
        added = added[smoothing_window - 1:]

    target_perf = np.arange(0, 1.01, 0.01)
    x_rand = interp1d(perf_random, added, kind='linear', bounds_error=False)(target_perf)
    every = max(1, int(len(target_perf) / 20))

    for _, experiment in selected.iterrows():
        perf = experiment['evaluation_f1_macro']
        if moving_avg:
            perf = np.convolve(perf, kernel, mode='valid')

        # curves may differ in length: only the common prefix is inverted
        n_common = min(len(perf), len(added))
        inverse = interp1d(perf[:n_common], added[:n_common], kind='linear', bounds_error=False)
        ratio = np.divide(inverse(target_perf), x_rand)

        style = cfg.figures[experiment['qm']]
        line_color = style['color']
        line_marker = style['marker']
        line_label = style['label']
        if linestyle != '-':
            line_label = '_nolegend_'

        ax.plot(target_perf, ratio, label=line_label, color=line_color, linestyle=linestyle, marker=line_marker,
                markevery=every)


def fig_learning_curve_approximation(ax, tuple, df, qms):
    """Plot measured learning curves together with their fitted approximation.

    For every row kept by ``filter_df_results`` this draws:
      * a thinned scatter (about 21 points) of the measured macro-F1,
      * a shaded +/- spread band (``evaluation_f1_macro_std`` when the numeric
        suffix of ``random_seed`` is at least 30, else ``evaluation_f1_macro_sem``),
      * the fitted curve of the model named in the row's ``best_p`` column,
        evaluated via ``get_approximation_curve``.
    Rows without a ``best_p`` (empty, None or NaN) are skipped. For every
    non-random row the speedup factor ``sf_<best_p>`` is written as ``S=...`` in
    the top-left corner.

    Args:
        ax: Matplotlib axes to draw into.
        tuple: Experiment-configuration key forwarded to ``filter_df_results``.
        df: Results dataframe.
        qms: Query methods to include.

    Returns:
        The speedup factor of the last non-random row that has a fit, or ``None``
        if there is no such row.
    """
    # use only requested experiments
    df = filter_df_results(df, tuple, qms)

    # returned as None (instead of raising UnboundLocalError) when no AL row has a fit
    sf_al = None

    for index, row in df.iterrows():
        # number of training samples, performance and its spread
        nr_training_samples = row['nr_training_samples']
        performance = row['evaluation_f1_macro']
        if int(row['random_seed'].split('_')[1]) >= 30:
            performance_var = row['evaluation_f1_macro_std']
        else:
            performance_var = row['evaluation_f1_macro_sem']

        # approximation curve; a missing fit may be '' / None or NaN (NaN is truthy)
        best_p = df.loc[index, 'best_p']
        if pd.isna(best_p) or not best_p:
            continue
        a_0 = df.loc[index, 'a_0_' + best_p]
        a_inf = df.loc[index, 'a_inf']
        b = df.loc[index, 'b_' + best_p]
        x_fit, y_fit = get_approximation_curve(nr_training_samples, best_p, a_inf, a_0, b)
        sf = df.loc[index, 'sf_' + best_p]

        # plot styling
        qm = row['qm']
        color = cfg.figures[qm]['color']
        marker = cfg.figures[qm]['marker']
        label = cfg.figures[qm]['label']
        # Branch on the query method itself (like the S annotation below), not on the
        # configurable display label, so random's sf can never be returned as sf_al.
        if qm == 'random':
            label_fit = r'$\hat{p}^' + best_p[-1] + r'_{\text{rand}}$'
        else:
            label_fit = r'$\hat{p}^' + best_p[-1] + r'_{\text{' + label + '}}$'
            sf_al = sf

        # thin the scatter to ~21 points; max(1, ...) avoids a zero slice step
        # (ValueError) for curves with fewer than 21 points
        step = max(1, int(len(nr_training_samples) / 21))
        ax.scatter(nr_training_samples[::step], performance[::step], color=color, marker=marker, label=label)
        ax.plot(x_fit, y_fit, color=color, label=label_fit)
        ax.fill_between(nr_training_samples, performance - performance_var, performance + performance_var,
                        color=color, alpha=0.1)

        # speedup factor annotation
        if qm != 'random':
            ax.text(0.02, 1, f'$S={np.round(sf, 2)}$', transform=ax.transAxes,
                    horizontalalignment='left', verticalalignment='top')

    ax.legend(frameon=False, loc='lower right', labelspacing=0.2)

    return sf_al


def fig_performance_metrics_over_training_samples(ax, tuple, df, qms, linestyle='-'):
    """Plot how far several learning-curve metrics still are from their final value.

    ``metric_change.load`` provides ``nr_training_samples`` and one series per
    metric. The series are assumed to be arrays aligned with ``nr_training_samples``;
    TODO confirm. Only points with at least 700 training samples are kept. Each
    series is expressed as ``100 * |1 - value / final_value|``, the percent
    deviation from its last value. For ``'cutpoints'`` the series is first
    binarised at 0.95 and inverted if its final value would be 0, so that value is 1.

    The legend is rebuilt with the display label of ``qms[1]`` as a header entry.

    Args:
        ax: Matplotlib axes to draw into.
        tuple: Experiment-configuration key forwarded to ``filter_df_results``.
        df: Results dataframe.
        qms: Query methods to include; ``qms[1]`` names the legend header.
        linestyle: Line style; any style other than ``'-'`` hides the metric
            legend entries.
    """
    # proxy artist for the legend header (not added to any axes)
    from matplotlib.lines import Line2D

    # use only requested experiments
    df = filter_df_results(df, tuple, qms)

    # load or generate dict performance metric change
    result_dict = metric_change.load(df)

    metrics = ['lc_mean', 'lc_aulc_norm', 'cutpoints', 's_fix', 's']
    nr_training_samples = result_dict['nr_training_samples']

    # loop-invariant: keep only values from 700 training samples onwards
    mask = nr_training_samples >= 700
    nr_training_samples_filtered = nr_training_samples[mask]
    label_qm = cfg.figures[qms[1]]['label']
    # ~5 markers on the curve that is actually plotted (the filtered one)
    markevery = max(1, int(len(nr_training_samples_filtered) / 5))

    for metric in metrics:
        performance_filtered = result_dict[metric][mask]

        # cutpoints: binarise so the final value is always 1
        if metric == 'cutpoints':
            performance_filtered = np.where(performance_filtered > 0.95, 1, 0)
            if performance_filtered[-1] == 0:
                performance_filtered = 1 - performance_filtered

        # percent deviation from the final value
        performance_filtered = 100 * np.abs(1 - performance_filtered / performance_filtered[-1])

        # styling; 'qm' in the metric label is replaced by the qm display label
        color = cfg.figures[metric]['color']
        marker = cfg.figures[metric]['marker']
        label_metric = cfg.figures[metric]['label'].replace('qm', label_qm)
        if linestyle != '-':
            label_metric = '_nolegend_'

        ax.plot(nr_training_samples_filtered, performance_filtered, label=label_metric, color=color, marker=marker,
                markevery=markevery, linestyle=linestyle)

    # Rebuild the legend once with the qm label as an invisible header entry. The
    # former per-metric ``plt.plot([])`` drew an empty line onto whatever axes was
    # current on every iteration.
    handles, labels = ax.get_legend_handles_labels()
    header = Line2D([], [], marker='', linestyle='')
    ax.legend([header] + handles, [label_qm] + labels, frameon=False, labelspacing=0)



def fig_sf_over_events(ax, folder, dataset, sampling_method, tag='evaluation'):
    """Scatter the speedup factor against samples per class, with a log-linear fit.

    Reads ``results_merged.pkl`` from ``cfg.path_exp / folder / dataset`` and keeps
    the runs with ``nr_classes == 1`` and the given sampling method. Rows whose
    ``sf_best_p`` is NaN or greater than 3 are dropped. ``sf_best_p`` is then fit
    against ``log10(samples_per_class)`` with ``scipy.stats.linregress``. The
    regression legend entry shows r^2 and the p-value of the fit.

    Args:
        ax: Matplotlib axes to draw into.
        folder: Experiment folder below ``cfg.path_exp``.
        dataset: Dataset sub-folder.
        sampling_method: Sampling method to plot (also the ``cfg.figures`` key).
        tag: Unused; kept for signature compatibility.
    """
    # read merged df
    df_all = pd.read_pickle(Path(cfg.path_exp, folder, dataset, 'results_merged.pkl'))
    # single-class runs of the selected sampling method
    df = df_all[(df_all['nr_classes'] == 1) & (df_all['sampling_method'] == sampling_method)]

    # float casts so np.isnan also works when the columns are stored as object dtype
    x_data = df['samples_per_class'].to_numpy(dtype=np.float64)
    sf = df['sf_best_p'].to_numpy(dtype=np.float64)
    invalid = np.isnan(sf) | (sf > 3)
    x_data = x_data[~invalid]
    y_data = sf[~invalid]

    # regression sf ~ log10(samples_per_class)
    slope, intercept, r_value, p_value, _ = linregress(np.log10(x_data), y_data)
    x_line = np.linspace(x_data.min(), x_data.max(), 100)
    regression_line = slope * np.log10(x_line) + intercept

    # plot data
    color = cfg.figures[sampling_method]['color']
    marker = cfg.figures[sampling_method]['marker']
    label = cfg.figures[sampling_method]['label']
    # linregress returns Pearson's r; the label advertises r^2, so square it
    label_regression = r'$r^2$: ' + str(np.round(r_value ** 2, 2)) + '\np: ' + "{:.1e}".format(p_value)

    ax.scatter(x_data, y_data, color=color, marker=marker, label=label)
    ax.plot(x_line, regression_line, c='black', label=label_regression)
    ax.legend(frameon=False, loc='lower right', labelspacing=0.2)


def fig_events_over_training_samples(ax, folder, dataset, hashs, sampling_methods, tag='evaluation'):
    """Plot the percentage of events in the training set against added samples.

    Reads ``results_merged.pkl`` from ``cfg.path_exp / folder / dataset``. For every
    ``(hash, sampling_method)`` pair it takes the first matching row and plots
    ``100 * nr_events_in_training_samples / nr_training_samples`` over the number
    of samples added since the first entry. A band of +/- the corresponding
    ``..._sem`` values is shaded around each line.

    Args:
        ax: Matplotlib axes to draw into.
        folder: Experiment folder below ``cfg.path_exp``.
        dataset: Dataset sub-folder.
        hashs: Experiment hashes to plot.
        sampling_methods: Sampling methods to plot (also ``cfg.figures`` keys).
        tag: Unused; kept for signature compatibility.

    Raises:
        IndexError: if a ``(hash, sampling_method)`` pair has no row.
        ValueError: if the event counts hold more than one class column. Previously
            this raised NameError or silently re-plotted the previous run's values.
    """
    def _single_class(values, name):
        # accept a 1-D series or an (n, 1) single-class column; reject anything else
        values = np.asarray(values)
        if values.ndim == 2 and values.shape[1] == 1:
            values = values[:, 0]
        if values.ndim != 1:
            raise ValueError(f'{name} has shape {values.shape}; only single-class runs can be plotted')
        return values

    # read merged df
    df = pd.read_pickle(Path(cfg.path_exp, folder, dataset, 'results_merged.pkl'))

    for hash_ in hashs:
        for sampling_method in sampling_methods:
            # first row of this experiment hash and sampling method
            index = int(df.query('hash == @hash_ and sampling_method == @sampling_method').index[0])

            nr_training_samples = df.loc[index, 'nr_training_samples']
            nr_added_training_samples = nr_training_samples - nr_training_samples[0]
            nr_events = _single_class(df.loc[index, 'nr_events_in_training_samples'],
                                      'nr_events_in_training_samples')
            nr_events_sem = _single_class(df.loc[index, 'nr_events_in_training_samples_sem'],
                                          'nr_events_in_training_samples_sem')
            frac_events = 100 * nr_events / nr_training_samples
            frac_events_sem = 100 * nr_events_sem / nr_training_samples

            # plot data
            color = cfg.figures[sampling_method]['color']
            marker = cfg.figures[sampling_method]['marker']
            label = cfg.figures[sampling_method]['label']

            ax.plot(nr_added_training_samples, frac_events, color=color, label=label, marker=marker, markevery=10)
            ax.fill_between(nr_added_training_samples, frac_events - frac_events_sem, frac_events + frac_events_sem,
                            color=color, alpha=0.1)
    ax.legend(frameon=False, labelspacing=0.2)


def fig_events_over_training_samples_one_class(ax, folder, dataset, sampling_method, dataset_size, label):
    """Plot the share of one class in the training set, one line per number of classes.

    Reads ``results_merged.pkl`` from ``cfg.path_exp / folder / dataset`` and keeps
    runs of ``sampling_method`` with at most 4 classes. For every class count 1..4
    it averages, over the runs whose ``classes`` contain ``cfg.label_prefix + label``,
    that class's column of ``nr_events_in_training_samples`` and of its ``_sem``
    variant. The result is plotted as a percentage of ``nr_training_samples``
    against the number of added samples, with one line style per class count.
    Class counts without any such run are skipped. A horizontal line marks
    ``100 * samples_per_class / dataset_size``, taken from the last matching run.

    Args:
        ax: Matplotlib axes to draw into.
        folder: Experiment folder below ``cfg.path_exp``.
        dataset: Dataset sub-folder.
        sampling_method: Sampling method to plot (also a ``cfg.figures`` key).
        dataset_size: Denominator of the dataset-fraction reference line.
        label: Class name, without ``cfg.label_prefix``.
    """
    # read merged df
    df = pd.read_pickle(Path(cfg.path_exp, folder, dataset, 'results_merged.pkl'))

    # filter by sampling method and maximal number of classes
    nr_classes_max = 4
    df = df.query('sampling_method == @sampling_method and nr_classes <= @nr_classes_max')

    label_str = cfg.label_prefix + label

    # loop-invariant plot styling (one line style per class count)
    linestyle_dict = {1: 'solid',
                      2: 'dashed',
                      3: 'dashdot',
                      4: 'dotted'}
    color = cfg.figures[sampling_method]['color']
    method_label = cfg.figures[sampling_method]['label']
    frac_events_dataset = None

    for nr_classes in range(1, nr_classes_max + 1):
        df_iter = df.query('nr_classes == @nr_classes')

        events_all = []
        events_sem_all = []
        for _, row in df_iter.iterrows():
            # skip runs that do not contain the target class
            if label_str not in row['classes']:
                continue
            class_idx = row['classes'].index(label_str)

            nr_training_samples = row['nr_training_samples']
            events_all.append(row['nr_events_in_training_samples'][:, class_idx])
            events_sem_all.append(row['nr_events_in_training_samples_sem'][:, class_idx])
            samples_per_class = row['samples_per_class']
            # a scalar is used as-is; a per-class sequence is indexed by the class position
            if np.ndim(samples_per_class) > 0:
                samples_per_class = samples_per_class[class_idx]

        # No run with this class count contains the target class. Skip it rather than
        # plot NaN means against stale or undefined x-values.
        if not events_all:
            continue

        nr_added_training_samples = nr_training_samples - nr_training_samples[0]
        # mean over runs of the event counts and of their SEM
        events_mean = np.mean(events_all, axis=0)
        events_sem_mean = np.mean(events_sem_all, axis=0)

        frac_events = 100 * events_mean / nr_training_samples
        frac_events_std = 100 * events_sem_mean / nr_training_samples
        frac_events_dataset = 100 * samples_per_class / dataset_size

        line_label = method_label + ', classes: ' + str(nr_classes)
        ax.plot(nr_added_training_samples, frac_events, color=color, linestyle=linestyle_dict[nr_classes],
                label=line_label)
        ax.fill_between(nr_added_training_samples, frac_events - frac_events_std, frac_events + frac_events_std,
                        color=color, alpha=0.1)

    # reference line only if at least one run contained the class
    if frac_events_dataset is not None:
        ax.axhline(y=frac_events_dataset, color=cfg.figures['random']['color'], label='Fraction in dataset',
                   linewidth=2)
    ax.legend(frameon=False, labelspacing=0.2)


def fig_sf_over_classes(ax, folder, dataset, sampling_method, dataset_size, tag='evaluation'):
    """Scatter the speedup factor against the number of classes, coloured by class share.

    Reads ``results_merged.pkl`` from ``cfg.path_exp / folder / dataset`` and keeps
    the runs of ``sampling_method``. Rows whose ``sf_best_p`` is NaN or greater
    than 1.5 are dropped. Point colour (``'cool'`` colormap) encodes
    ``log10(100 * mean(samples_per_class) / dataset_size)``. A colorbar with
    dataset-specific ticks, labelled in percent, is added next to ``ax``.

    Args:
        ax: Matplotlib axes to draw into.
        folder: Experiment folder below ``cfg.path_exp``.
        dataset: Dataset sub-folder; also selects the colorbar ticks.
        sampling_method: Sampling method to plot (also a ``cfg.figures`` key).
        dataset_size: Denominator of the percentage shown in colour.
        tag: Unused; kept for signature compatibility.
    """
    # read merged df
    df = pd.read_pickle(Path(cfg.path_exp, folder, dataset, 'results_merged.pkl'))
    df = df.query("sampling_method == @sampling_method")

    # sf is cast to float so np.isnan also works on object-dtype columns
    nr_classes = df['nr_classes'].to_numpy()
    sf = df['sf_best_p'].to_numpy(dtype=np.float64)
    nr_samples = df['samples_per_class'].map(np.mean).to_numpy()
    invalid = np.isnan(sf) | (sf > 1.5)
    x_data = nr_classes[~invalid]
    y_data = sf[~invalid]
    color_data = 100 * nr_samples[~invalid] / dataset_size

    label = cfg.figures[sampling_method]['label']

    # colour is on a log10 scale; the colorbar ticks below undo it for labelling
    sc = ax.scatter(x_data, y_data, alpha=0.3, c=np.log10(color_data), cmap='cool', label=label)

    # colorbar ticks per dataset (in percent), with a generic default
    ticks_by_dataset = {
        'carina': [10, 20, 40, 60],
        'mscoco': [0.3, 1, 3, 10, 30],
        'reuters': [1, 3, 10],
        'scene': [16, 18, 20],
    }
    cbar_ticks_original = ticks_by_dataset.get(dataset, [10, 20, 40, 60])
    cbar_ticks = [np.log10(val) for val in cbar_ticks_original]
    cbar_labels = [str(val).lstrip('0') for val in cbar_ticks_original]
    # attach to ax's own figure rather than pyplot's current figure
    cbar = ax.figure.colorbar(sc, ax=ax, pad=0)
    cbar.set_label('positive samples [%]')
    cbar.set_ticks(cbar_ticks)
    cbar.set_ticklabels(cbar_labels)

    ax.legend(frameon=False, loc='lower right', labelspacing=0.2)





def fig_processing_time_over_training_samples(ax, tuple, df, qms):
    """Plot each query method's processing time relative to random sampling.

    Each row's ``processing_time`` is divided by that of the ``'random'`` row, so
    random itself appears as a constant 1. The result is plotted against
    ``nr_training_samples``. Both are presumably equal-length arrays; TODO confirm.

    Args:
        ax: Matplotlib axes to draw into.
        tuple: Experiment-configuration key forwarded to ``filter_df_results``.
        df: Results dataframe.
        qms: Query methods to include; must contain ``'random'``.

    Raises:
        IndexError: if no ``'random'`` row survives the filter.
    """
    # use only requested experiments
    df = filter_df_results(df, tuple, qms)

    # reference processing time of random sampling
    processing_time_random = df.loc[df['qm'] == 'random', 'processing_time'].values[0]

    for _, row in df.iterrows():
        nr_training_samples = row['nr_training_samples']
        processing_time_norm = row['processing_time'] / processing_time_random

        # plot styling
        qm = row['qm']
        color = cfg.figures[qm]['color']
        marker = cfg.figures[qm]['marker']
        label = cfg.figures[qm]['label']

        # roughly five markers per curve, but at least every point
        markevery = max(1, int(len(nr_training_samples) / 5))

        ax.plot(nr_training_samples, processing_time_norm, label=label, color=color, marker=marker, markevery=markevery)



def get_approximation_curve(x_data, p, a_inf, a_0, b_qm, num=100):
    """Evaluate a learning-curve model on an evenly spaced grid over ``x_data``'s range.

    Models:
        ``'p1'``: ``a_inf * (1 - exp(a_0 - x / b_qm))``
        ``'p2'``: ``a_inf / (1 + exp(a_0 - x / b_qm))``

    Args:
        x_data: Sample positions (array-like); only their min and max are used.
        p: Model name, ``'p1'`` or ``'p2'``.
        a_inf: Scale parameter of the model.
        a_0: Offset inside the exponential.
        b_qm: Divisor of ``x`` inside the exponential.
        num: Number of grid points (default 100).

    Returns:
        Tuple ``(x, y)`` of 1-D arrays of length ``num``.

    Raises:
        ValueError: if ``p`` is not ``'p1'`` or ``'p2'``. Previously ``y`` was left
            unbound, which raised an UnboundLocalError.
    """
    x = np.linspace(np.min(x_data), np.max(x_data), num)
    if p == 'p1':
        y = a_inf * (1 - np.exp(a_0 - x / b_qm))
    elif p == 'p2':
        y = a_inf * (1 / (1 + np.exp(a_0 - x / b_qm)))
    else:
        raise ValueError(f"unknown approximation model {p!r}; expected 'p1' or 'p2'")
    return x, y


def filter_df_results(df, tuple, qms):
    """Select the result rows of one experiment configuration and a set of query methods.

    Args:
        df: Results dataframe with one column per configuration field plus ``qm``
            and ``classes``.
        tuple: 10-tuple ``(dataset, dataset_type, init_train_samples,
            add_train_samples, max_train_samples, train_paradigm, weight_init,
            training, nr_classes, classes)``. All fields but the last are matched by
            equality. ``classes`` is either ``'ALL'`` (no class filter) or a class
            name, in which case only rows whose ``classes`` contain
            ``cfg.label_prefix + classes`` are kept.
        qms: Query methods to keep (matched against the ``qm`` column).

    Returns:
        The matching rows with a fresh 0..n-1 index.

    Raises:
        ValueError: if ``tuple`` does not have exactly 10 elements.
    """
    (dataset, dataset_type, init_train_samples, add_train_samples, max_train_samples, train_paradigm, weight_init,
     training, nr_classes, classes) = tuple

    # configuration and query-method filter
    df_filter = df[(df['dataset'] == dataset) &
                   (df['dataset_type'] == dataset_type) &
                   (df['init_train_samples'] == init_train_samples) &
                   (df['add_train_samples'] == add_train_samples) &
                   (df['max_train_samples'] == max_train_samples) &
                   (df['train_paradigm'] == train_paradigm) &
                   (df['weight_init'] == weight_init) &
                   (df['training'] == training) &
                   (df['nr_classes'] == nr_classes) &
                   (df['qm'].isin(qms))]

    if classes != 'ALL':
        class_label = f'{cfg.label_prefix}{classes}'
        # astype(bool): apply() on an empty frame can return an object-dtype Series,
        # which pandas would not treat as a boolean mask
        contains = df_filter['classes'].apply(lambda x: class_label in x).astype(bool)
        df_filter = df_filter[contains]

    # reset once, after all filtering, so the index is always 0..n-1
    return df_filter.reset_index(drop=True)