import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from itertools import cycle


def highlight_max(s):
    '''
    Mark every entry of a Series that equals the Series maximum.

    Returns a list with one CSS string per entry: 'font-weight: bold' for
    entries equal to ``s.max()`` and an empty string for all others.
    '''
    peak = s.max()
    return ['font-weight: bold' if at_peak else '' for at_peak in s == peak]

def highlight_positive_differences(s):
    '''
    Mark strictly positive entries in bold.

    Returns a list with one CSS string per entry of ``s``:
    'font-weight: bold' when the entry is greater than zero, '' otherwise.
    '''
    def _style_for(value):
        if value > 0:
            return 'font-weight: bold'
        return ''

    return [_style_for(value) for value in s]

def highlight_differences(s):
    '''
    Color entries by sign.

    Returns a list with one CSS string per entry of ``s``: green for values
    above zero, black for values equal to zero, red for everything else.
    '''
    def _color_for(value):
        if value > 0:
            return 'color: green'
        if value == 0:
            return 'color: black'
        return 'color: red'

    return [_color_for(value) for value in s]

def highlight_pvalues(s):
    '''
    Color p-values by significance level.

    Returns a list with one CSS string per entry of ``s``: green when
    p < 0.01, orange when p <= 0.05 (and not below 0.01), black otherwise.
    '''
    def _color_for(p_value):
        if p_value < 0.01:
            return 'color: green'
        if p_value <= 0.05:
            return 'color: orange'
        return 'color: black'

    return [_color_for(p_value) for p_value in s]


def highlight_best(df):
    '''
    Build a frame of CSS strings that bolds the best value of each metric column.

    Columns whose name starts with 'NDCG', 'MRR' or 'SCORE' are treated as
    higher-is-better (the column maximum is bolded); columns starting with
    'TTB' or 'AVG_RANK' are treated as lower-is-better (the column minimum is
    bolded). Ties are all bolded. Every other cell, including every cell of a
    column that matches neither group, gets an empty style string.

    Returns a DataFrame with the same index and columns as ``df``.
    NOTE(review): presumably meant for ``Styler.apply(..., axis=None)`` - confirm
    against callers.
    '''
    # Start from an all-empty style frame rather than a copy of the data, so
    # columns that are not metrics never leak their raw values into the
    # returned style table.
    styles = pd.DataFrame('', index=df.index, columns=df.columns)

    for col in df.columns:
        if col.startswith(('NDCG', 'MRR', 'SCORE')):
            best = df[col].max()  # higher is better
        elif col.startswith(('TTB', 'AVG_RANK')):
            best = df[col].min()  # lower is better
        else:
            continue
        styles[col] = ['font-weight: bold' if hit else '' for hit in df[col] == best]

    return styles

def highlight_best_differences(df):
    '''
    Build a frame of CSS colors marking whether each difference is an improvement.

    Columns starting with 'NDCG', 'MRR' or 'SCORE' are higher-is-better:
    positive values are green. Columns starting with 'TTB' or 'AVG_RANK' are
    lower-is-better: negative values are green. In both groups zero is black
    and any remaining value is red (this includes NaN, because neither
    comparison holds for it). Columns matching neither group get empty style
    strings.

    Returns a DataFrame with the same index and columns as ``df``.
    '''
    def _color(delta, higher_is_better):
        if delta == 0:
            return 'color: black'
        improved = delta > 0 if higher_is_better else delta < 0
        return 'color: green' if improved else 'color: red'

    # Start from empty styles instead of a copy of the data so that
    # non-metric columns do not carry raw values into the style frame.
    styles = pd.DataFrame('', index=df.index, columns=df.columns)

    for col in df.columns:
        if col.startswith(('NDCG', 'MRR', 'SCORE')):
            higher_is_better = True
        elif col.startswith(('TTB', 'AVG_RANK')):
            higher_is_better = False
        else:
            continue
        styles[col] = [_color(delta, higher_is_better) for delta in df[col]]

    return styles


def highlight_best_differences_reg(df):
    '''
    Variant of the difference coloring in which SCORE is lower-is-better.

    Columns starting with 'NDCG' or 'MRR' are higher-is-better: positive
    values are green. Columns starting with 'TTB', 'AVG_RANK' or 'SCORE' are
    lower-is-better: negative values are green. In both groups zero is black
    and any remaining value is red (this includes NaN, because neither
    comparison holds for it). Columns matching neither group get empty style
    strings.

    Returns a DataFrame with the same index and columns as ``df``.
    '''
    def _color(delta, higher_is_better):
        if delta == 0:
            return 'color: black'
        improved = delta > 0 if higher_is_better else delta < 0
        return 'color: green' if improved else 'color: red'

    # Start from empty styles instead of a copy of the data so that
    # non-metric columns do not carry raw values into the style frame.
    styles = pd.DataFrame('', index=df.index, columns=df.columns)

    for col in df.columns:
        if col.startswith(('NDCG', 'MRR')):
            higher_is_better = True
        elif col.startswith(('TTB', 'AVG_RANK', 'SCORE')):
            higher_is_better = False
        else:
            continue
        styles[col] = [_color(delta, higher_is_better) for delta in df[col]]

    return styles


# Render a sample as its mean followed by the standard deviation wrapped in
# "$^{(...)}$"; the first variant uses fixed-point notation (3 and 2 decimals),
# the second uses scientific notation with 2 decimals for both numbers.
mean_and_std_sup = lambda x: "{:.3f} $^{{({:.2f})}}$".format(np.mean(x), np.std(x))
mean_and_std_sup_not = lambda x: "{:.2e} $^{{({:.2e})}}$".format(np.mean(x), np.std(x))

# Render a sample as "mean±std", both with 3 decimals.
mean_and_std = lambda x: f"{np.mean(x):.3f}±{np.std(x):.3f}"


def rename_columns(df):
    '''
    Return ``df`` with the metric cutoffs in its column names relabelled.

    For each of the NDCG, MRR, SCORE, TTB and AVG_RANK metrics, the column
    '<metric>@10' becomes '<metric>@5' and '<metric>@100' becomes
    '<metric>@10'. Columns not listed are left unchanged.
    '''
    metric_names = ('NDCG', 'MRR', 'SCORE', 'TTB', 'AVG_RANK')
    cutoff_relabel = ((10, 5), (100, 10))
    column_map = {
        f'{metric}@{old_k}': f'{metric}@{new_k}'
        for metric in metric_names
        for old_k, new_k in cutoff_relabel
    }
    return df.rename(columns=column_map)

def get_score_pos_dataframe(scores_results, pos):
    '''
    Tabulate the score found at index ``pos`` for every approach and task.

    ``scores_results`` is read as ``scores_results[approach][task]['score'][pos]``.
    Returns a DataFrame with one row per approach and one column per task.
    '''
    score_at_pos = {
        approach: {task: outcome['score'][pos] for task, outcome in per_task.items()}
        for approach, per_task in scores_results.items()
    }
    return pd.DataFrame(score_at_pos).T

def prepare_data(data):
    '''
    Reshape a framework-by-task result table to long form and scale it per task.

    ``data`` is read as a DataFrame whose index labels become the 'framework'
    column and whose column labels become task ids (each label is turned into
    'openml.org/t/<label>' and the text after the last '/' is used as 'task').
    The 'metric' and 'constraint' columns are filled with the constants
    'balanced_accuracy' and 'sklearn_preprocessed'.

    Results are averaged per (framework, task, constraint, metric). A 'scaled'
    column is then added: within each (task, constraint) group,
    ``scaled = (result - lower) / (upper - lower)`` where ``lower`` is the
    group's mean result and ``upper`` is its maximum. When ``lower == upper``
    the scaled value is NaN.

    Returns the averaged frame with columns
    framework, task, constraint, metric, result, scaled.
    '''
    # Name the index before resetting it so the id column is always called
    # 'framework', even when the incoming index already has a different name.
    df_melted = data.rename_axis('framework').reset_index().melt(
        id_vars='framework', var_name='id', value_name='result')
    df_melted = df_melted.reindex(columns=['id', 'framework', 'result'])
    df_melted['id'] = df_melted['id'].apply(lambda x: 'openml.org/t/' + str(x))
    df_melted['metric'] = 'balanced_accuracy'
    df_melted['constraint'] = 'sklearn_preprocessed'
    df_melted['task'] = df_melted['id'].apply(lambda x: x.split('/')[-1])

    group_keys = ["framework", "task", "constraint", "metric"]
    mean_results = df_melted[group_keys + ["result"]].groupby(
        group_keys, as_index=False).agg({"result": "mean"})

    # Broadcast each (task, constraint) group's mean and max back onto its
    # rows; this replaces a per-row lookup that re-sliced the table each time.
    per_task = mean_results.groupby(["task", "constraint"])["result"]
    lower = per_task.transform("mean")
    upper = per_task.transform("max")

    scaled = (mean_results["result"] - lower) / (upper - lower)
    # A group whose mean equals its max has no spread to scale by.
    mean_results["scaled"] = scaled.where(lower != upper)

    return mean_results
    
def plot_three_scaled_boxes_horizontal(data1, data2, data3, labels=('', '', ''), replace=None,
                                       output_path='plot_three_scaled_boxes_horizontal.pdf'):
    '''
    Draw three side-by-side horizontal box plots of the scaled results.

    Each of ``data1``, ``data2`` and ``data3`` is passed through
    ``prepare_data``; its 'scaled' values are pivoted to one column per
    approach (in the fixed order of ``sorted_app`` below) and drawn as one
    panel. Boxes alternate between the "Rank" color and the "Score" color and
    the final box ('Random') is drawn in tomato.

    Parameters
    ----------
    data1, data2, data3 : inputs accepted by ``prepare_data``.
    labels : sequence of three panel titles.
    replace : optional mapping from an approach prefix (the text before the
        first '_' in the approach name) to the label shown on the y axis.
        Prefixes without an entry are shown unchanged.
    output_path : file the figure is saved to before it is shown.
    '''
    replace = {} if replace is None else replace

    dataframes = [prepare_data(data) for data in (data1, data2, data3)]
    sorted_app = [
        'MCTS_Rank', 'MCTS_Score',
        'BO_Rank', 'BO_Score',
        'LGBMRegressor_Rank', 'LGBMRegressor_Score',
        'GradientBoostingRegressor_Rank', 'GradientBoostingRegressor_Score',
        'RandomForestRegressor_Rank', 'RandomForestRegressor_Score',
        'Ridge_Rank', 'Ridge_Score',
        'Lasso_Rank', 'Lasso_Score',
        'LinearRegression_Rank', 'LinearRegression_Score',
        'Avg_Rank', 'Avg_Score',
        'Random'
    ]

    # One tick per Rank/Score pair, placed on the first box of the pair, plus
    # one for the trailing 'Random' box (positions 1, 3, 5, ...).
    tick_positions = list(range(1, len(sorted_app) + 1, 2))
    tick_prefixes = [sorted_app[position - 1].split('_')[0] for position in tick_positions]

    fig, axes = plt.subplots(1, 3, figsize=(9, 6), sharex=True)

    for panel_index, (ax, mean_results, data_name) in enumerate(zip(axes, dataframes, labels)):
        scaled_by_approach = (
            mean_results.pivot(index='framework', columns='task', values='scaled')
            .dropna(axis=1, how='all')
            .T.reindex(columns=sorted_app)
        )
        # Tick labels are set explicitly below, so none are passed here.
        boxplot = ax.boxplot(
            scaled_by_approach,
            vert=False,  # horizontal boxes
            patch_artist=True, showfliers=False
        )

        if panel_index == 0:
            # Only the left-most panel carries y tick labels and the legend.
            # Positions are fixed first and labels are derived from
            # sorted_app directly, so they always line up with the ticks.
            ax.set_yticks(tick_positions)
            ax.set_yticklabels([replace.get(prefix, prefix) for prefix in tick_prefixes], rotation=0)

            # facecolor/edgecolor are given separately: combining color= with
            # edgecolor= lets color override the requested black edge.
            random_patch = mpatches.Patch(facecolor='tomato', label='Random', edgecolor='k')
            ranked_patch = mpatches.Patch(facecolor='cornflowerblue', label='Ranked', edgecolor='k')
            scored_patch = mpatches.Patch(facecolor='darkorange', label='Scored', edgecolor='k')
            ax.legend(handles=[random_patch, scored_patch, ranked_patch], loc='lower left')
        else:
            ax.set_yticks([])
            ax.set_yticklabels([])

        # Alternate the two pair colors over every box except the last one,
        # which is colored separately.
        *paired_boxes, last_box = boxplot['boxes']
        for patch, color in zip(paired_boxes, cycle(['cornflowerblue', 'darkorange'])):
            patch.set_facecolor(color)
        last_box.set_facecolor('tomato')

        for element in ['whiskers', 'caps', 'medians']:
            plt.setp(boxplot[element], color='black')

        ax.axvline(x=0, color='gray', linestyle='--')  # vertical reference line at x = 0
        ax.set_title(data_name)

        ax.set_xlim(left=-5)

        # Dotted separator after every pair of boxes (y = 2.5, 4.5, ...).
        for upper_box in range(2, len(boxplot['boxes']), 2):
            ax.axhline(y=upper_box + 0.5, color='#C5C9C7', linestyle='dotted')

    fig.subplots_adjust(wspace=0, left=0.1, right=0.95, top=0.95, bottom=0.1)
    fig.savefig(output_path)
    plt.show()
    
    
# Short display names for the metric columns: each metric is abbreviated to a
# single letter and the cutoff after '@' (1, 10 or 100) is kept as is.
rename_metrics = {
    f'{metric}@{cutoff}': f'{letter}@{cutoff}'
    for metric, letter in (
        ('NDCG', 'N'),
        ('MRR', 'M'),
        ('SCORE', 'S'),
        ('TTB', 'T'),
        ('AVG_RANK', 'R'),
    )
    for cutoff in (1, 10, 100)
}

# Short display names with relabelled cutoffs: each metric is abbreviated to
# a single letter, and the cutoffs 1, 10 and 100 are shown as 1, 5 and 10.
rename_metrics_aslib = {
    f'{metric}@{source_cutoff}': f'{letter}@{shown_cutoff}'
    for metric, letter in (
        ('NDCG', 'N'),
        ('MRR', 'M'),
        ('SCORE', 'S'),
        ('TTB', 'T'),
        ('AVG_RANK', 'R'),
    )
    for source_cutoff, shown_cutoff in ((1, 1), (10, 5), (100, 10))
}