from __future__ import annotations

import pandas as pd
import xarray as xr

from tabicl.config.config_benchmark_sweep import ConfigBenchmarkSweep
from tabicl.core.enums import DataSplit
from tabicl.results.dataset_manipulations import (add_model_plot_name_as_model_name, average_out_the_cv_split,
                                                  create_normalized_score,
                                                  only_use_models_and_datasets_specified_in_cfg,
                                                  select_only_default_runs_and_average_over_them,
                                                  select_only_the_first_default_run_of_every_model_and_dataset,
                                                  update_model_name_to_model_name_and_enum)
from tabicl.results.normalize_scores import create_normalization_info
from tabicl.results.reformat_results_get import get_reformatted_results
from tabicl.results.results_sweep import ResultsSweep
from tabicl.utils.paths_and_filenames import DEFAULT_RESULTS_TEST_FILE_NAME, DEFAULT_RESULTS_VAL_FILE_NAME


def make_default_results(cfg: ConfigBenchmarkSweep, results_sweep: ResultsSweep) -> None:
    """Merge benchmark and sweep results and write one results table per data split.

    The benchmark results are loaded from ``cfg.benchmark.origin`` and the
    sweep results are taken from ``results_sweep.ds``. The normalization
    bounds (``score_min`` / ``score_max``) are derived from the merge of both
    datasets *before* any default-run selection takes place; they are then
    attached to the merge of the default-run selections, from which the
    normalized score is computed.

    Args:
        cfg: Sweep configuration; supplies the benchmark origin, the model
            plot name and (via ``make_df_results``) the output directory.
        results_sweep: Sweep results whose ``ds`` attribute holds the dataset.
    """

    # Benchmark side: load, rework the model names, restrict to the configured subset.
    benchmark = only_use_models_and_datasets_specified_in_cfg(
        cfg,
        update_model_name_to_model_name_and_enum(get_reformatted_results(cfg.benchmark.origin)),
    )
    # Sweep side: label the sweep dataset with the configured plot name.
    sweep = add_model_plot_name_as_model_name(results_sweep.ds, cfg.model_plot_name)

    # Normalization info comes from the full merged data, not only from the default runs.
    normalization = create_normalization_info(
        cfg,
        average_out_the_cv_split(xr.merge([benchmark, sweep])),
    )

    # Note the differing order: benchmark averages the cv split first and then
    # selects, the sweep selects first and then averages the cv split.
    benchmark_defaults = select_only_default_runs_and_average_over_them(
        average_out_the_cv_split(benchmark)
    )
    sweep_defaults = average_out_the_cv_split(
        select_only_the_first_default_run_of_every_model_and_dataset(cfg, sweep)
    )

    combined = xr.merge([benchmark_defaults, sweep_defaults], combine_attrs='drop')
    for bound in ('score_min', 'score_max'):
        combined[bound] = normalization[bound]
    combined = create_normalized_score(combined)

    for split in (DataSplit.VALID, DataSplit.TEST):
        make_df_results(cfg, combined, split)



def make_df_results(cfg: ConfigBenchmarkSweep, ds: xr.Dataset, data_split: DataSplit) -> pd.DataFrame:
    """Write the score table of one data split to CSV and return it.

    Selects ``data_split`` from ``ds``, converts the ``score`` variable to a
    pandas table, appends an ``aggregate`` column holding the mean of
    ``normalized_score`` over ``openml_dataset_id``, rounds to four decimals
    and writes the result to ``cfg.output_dir`` under the file name returned
    by ``get_results_file_name``.

    Args:
        cfg: Configuration providing ``output_dir``.
        ds: Dataset with a ``data_split`` coordinate and the variables
            ``score``, ``normalized_score`` and ``model_name``.
        data_split: Split to export; its ``name`` is used as the selection label.

    Returns:
        The table that was written to disk.
    """

    # Select one split and drop the now-scalar coordinate so it does not end up in the table.
    ds = ds.sel(data_split=data_split.name).reset_coords('data_split', drop=True)

    # NOTE(review): assumes 'score' is 2-D after the selection, so that
    # to_pandas() yields a DataFrame rather than a Series - confirm.
    df = ds['score'].to_pandas()
    normalized_score = ds['normalized_score'].mean(dim='openml_dataset_id').to_dataframe()

    # Column assignment aligns on the index of both frames.
    df['aggregate'] = normalized_score['normalized_score']
    df = df.set_index(ds['model_name'].values)           # type: ignore
    df = df.round(4)

    df.to_csv(cfg.output_dir / get_results_file_name(data_split), mode='w', header=True)

    # The signature promises a DataFrame; previously the function implicitly returned None.
    return df


def get_results_file_name(data_split: DataSplit):

    match data_split:
        case DataSplit.VALID:
            return DEFAULT_RESULTS_VAL_FILE_NAME
        case DataSplit.TEST:
            return DEFAULT_RESULTS_TEST_FILE_NAME
