
from typing import Union

import xarray as xr

from tabicl.config.config_benchmark_sweep import ConfigBenchmarkSweep
from tabicl.core.enums import BenchmarkOrigin
from tabicl.results.dataset_manipulations import (add_model_plot_name_as_model_name, average_out_the_cv_split,
                                                  create_normalized_score,
                                                  only_use_models_and_datasets_specified_in_cfg,
                                                  select_only_the_first_default_run_of_every_model_and_dataset,
                                                  take_run_with_best_validation_loss,
                                                  update_model_name_to_model_name_and_enum)
from tabicl.results.normalize_scores import create_normalization_info
from tabicl.results.reformat_results_get import get_reformatted_results


def merge_ds_bench_tabzilla_and_sweeps(cfg: ConfigBenchmarkSweep, ds_sweeps: Union[xr.Dataset, list[xr.Dataset]]) -> xr.Dataset:
    """Combine the TabZilla benchmark results with one or more sweep datasets.

    The function works in two passes over the same inputs:

    1. Everything (benchmark + sweeps) is merged and averaged over the CV
       split, and that merged dataset is handed to
       ``create_normalization_info`` to obtain ``score_min`` / ``score_max``.
    2. The benchmark and each sweep are then reduced separately (see the
       inline comments), merged again, and the ``score_min`` / ``score_max``
       from pass 1 are attached before ``create_normalized_score`` is applied.

    Args:
        cfg: Sweep configuration; ``cfg.benchmark.origin`` must be
            ``BenchmarkOrigin.TABZILLA``.
        ds_sweeps: A single sweep dataset or a list of them. Each one must
            carry a ``'model_plot_name'`` entry in its ``attrs``. The inputs
            are copied (``Dataset.copy()``) before being processed.

    Returns:
        The merged dataset returned by ``create_normalized_score``.

    Raises:
        AssertionError: If the configured benchmark origin is not TabZilla.
    """

    assert cfg.benchmark.origin == BenchmarkOrigin.TABZILLA

    # Accept a bare dataset as well as a list, and work on copies.
    sweep_inputs = ds_sweeps if isinstance(ds_sweeps, list) else [ds_sweeps]
    sweep_copies = [sweep.copy() for sweep in sweep_inputs]

    # Load the benchmark results and restrict them according to cfg.
    bench = only_use_models_and_datasets_specified_in_cfg(
        cfg,
        update_model_name_to_model_name_and_enum(get_reformatted_results(BenchmarkOrigin.TABZILLA)),
    )

    # Label every sweep with the plot name stored in its attrs.
    named_sweeps = [
        add_model_plot_name_as_model_name(sweep, sweep.attrs['model_plot_name'])
        for sweep in sweep_copies
    ]

    # Pass 1: normalization info is derived from the full merge, i.e. before
    # the benchmark / sweeps are reduced by the selection helpers below.
    everything = average_out_the_cv_split(xr.merge([bench, *named_sweeps]))
    norm_info = create_normalization_info(cfg, everything)

    # Pass 2: reduce benchmark and sweeps separately
    # (presumably to one run per model/dataset, going by the helper names).
    best_bench = take_run_with_best_validation_loss(average_out_the_cv_split(bench))
    default_sweeps = [
        average_out_the_cv_split(
            select_only_the_first_default_run_of_every_model_and_dataset(cfg, sweep)
        )
        for sweep in named_sweeps
    ]

    # Attributes are dropped on this merge (combine_attrs='drop').
    merged = xr.merge([best_bench, *default_sweeps], combine_attrs='drop')
    merged['score_min'] = norm_info['score_min']
    merged['score_max'] = norm_info['score_max']

    return create_normalized_score(merged)