import numpy as np
import xarray as xr
from matplotlib import pyplot as plt

from tabicl.config.config_benchmark_sweep import ConfigBenchmarkSweep
from tabicl.core.enums import DataSplit


def make_combined_dataset_plot_data(cfg: ConfigBenchmarkSweep, ds: xr.Dataset) -> xr.DataArray:
    """Summarize normalized test-score curves per model across all datasets.

    Each ``score_sequence`` is min-max normalized using the test-split
    ``score_min`` / ``score_max``, clipped to [0, 1], and averaged over the
    ``openml_dataset_id`` dimension. For every model, the mean and the
    ``1 - confidence_bound`` / ``confidence_bound`` quantiles are then taken
    over axis 0 of that model's remaining array.

    Args:
        cfg: Benchmark sweep config; ``plotting.whytrees.n_runs`` and
            ``plotting.whytrees.confidence_bound`` are read.
        ds: Dataset holding ``score_sequence``, ``score_min`` and ``score_max``
            variables and a ``model_name`` coordinate.

    Returns:
        DataArray with dims ``('plot_data', 'model_name', 'run_id')``, where the
        ``plot_data`` coordinate is ``sequence_mean``, ``sequence_lower_bound``
        and ``sequence_upper_bound``.
    """
    test_split = DataSplit.TEST.name
    lo = ds['score_min'].sel(data_split=test_split)
    hi = ds['score_max'].sel(data_split=test_split)

    # Min-max normalize against the test-split range, then squash into [0, 1].
    # A degenerate range (hi == lo) yields inf/NaN before the clip.
    normalized = ((ds['score_sequence'] - lo) / (hi - lo)).clip(0, 1)
    averaged_over_datasets = normalized.mean(dim='openml_dataset_id')

    whytrees_cfg = cfg.plotting.whytrees
    model_names = ds.coords['model_name'].values
    summary = np.empty((3, len(model_names), whytrees_cfg.n_runs))

    for position in range(len(model_names)):
        # NOTE(review): axis 0 of `curves` is treated as the repetition axis and the
        # remaining axis must have length n_runs — TODO confirm against the producer of ds.
        curves = averaged_over_datasets.sel(model_name=model_names[position]).values
        summary[0, position, :] = np.mean(curves, axis=0)
        summary[1, position, :] = np.quantile(curves, q=1 - whytrees_cfg.confidence_bound, axis=0)
        summary[2, position, :] = np.quantile(curves, q=whytrees_cfg.confidence_bound, axis=0)

    return xr.DataArray(
        summary,
        dims=('plot_data', 'model_name', 'run_id'),
        coords={
            'plot_data': ['sequence_mean', 'sequence_lower_bound', 'sequence_upper_bound'],
            'model_name': model_names,
            'run_id': np.arange(whytrees_cfg.n_runs),
        },
    )
    

def make_combined_dataset_plot(cfg: ConfigBenchmarkSweep, plot_data: xr.DataArray) -> plt.Figure:
    """Plot every model's averaged normalized test-score curve with its band.

    For each entry of the ``model_name`` coordinate, the ``sequence_mean``
    curve is drawn as a line and the region between ``sequence_lower_bound``
    and ``sequence_upper_bound`` is shaded in the same colour. The x axis is
    log-scaled and limited to ``[1, n_runs]``; a shared legend is placed at
    the bottom of the figure.

    Args:
        cfg: Benchmark sweep config; ``plotting.whytrees.plot_default_value``,
            ``plotting.whytrees.n_runs`` and ``benchmark.name`` are read.
        plot_data: DataArray with a ``plot_data`` dimension (labels
            ``sequence_mean``, ``sequence_lower_bound``,
            ``sequence_upper_bound``) and a ``model_name`` dimension.

    Returns:
        The created figure (not closed; the caller owns it).
    """
    whytrees_cfg = cfg.plotting.whytrees

    fig, ax = plt.subplots(figsize=(25, 25))

    for model in plot_data.coords['model_name']:

        model_data = plot_data.sel(model_name=model)
        sequence_mean = model_data.sel(plot_data='sequence_mean').values
        sequence_lower_bound = model_data.sel(plot_data='sequence_lower_bound').values
        sequence_upper_bound = model_data.sel(plot_data='sequence_upper_bound').values

        # Shift the run index by the configured offset (the x axis is log-scaled).
        epochs = np.arange(len(sequence_mean)) + whytrees_cfg.plot_default_value

        (line,) = ax.plot(epochs, sequence_mean, label=model.item(), linewidth=12)
        # Colour the band from its own line instead of relying on matplotlib's
        # separate line and patch colour cycles staying in step.
        ax.fill_between(
            x=epochs,
            y1=sequence_lower_bound,
            y2=sequence_upper_bound,
            facecolor=line.get_color(),
            alpha=0.2,
        )

    ax.set_title(f"Averaged Normalized Test Score \n for all datasets of benchmark {cfg.benchmark.name}", fontsize=50)
    ax.set_xlabel("Number of hyperparameter search runs", fontsize=50)
    ax.set_ylabel("Normalized Test score", fontsize=50)
    ax.tick_params(axis='both', which='major', labelsize=40)

    ax.set_xscale('log')
    ax.set_xlim([1, whytrees_cfg.n_runs])
    # FuncFormatter's callback must return a string label.
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: str(int(x))))     # type: ignore
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=3, fontsize=40, handlelength=3)
    # Reserve the bottom 16% of the figure for the legend.
    fig.tight_layout(pad=2.0, rect=(0, 0.16, 1, 0.98))

    return fig














