

import xarray as xr

from tabicl.config.config_benchmark_sweep import ConfigBenchmarkSweep
from tabicl.core.enums import Task


def create_normalization_info(cfg: ConfigBenchmarkSweep, ds: xr.Dataset) -> xr.Dataset:

    scores = ds['score']

    match cfg.benchmark.task:
        case Task.REGRESSION:
            score_min = scores.quantile(0.5, dim=('run_id', 'model_name'), skipna=True)
        case Task.CLASSIFICATION:
            score_min = scores.quantile(0.1, dim=('run_id', 'model_name'), skipna=True)

    score_min = score_min.drop_vars('quantile')

    score_max = scores.max(dim=('run_id', 'model_name'))

    ds['score_min'] = score_min
    ds['score_max'] = score_max

    return ds