import logging
from pathlib import Path

import torch.multiprocessing as mp

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import Phase
from tabicl.core.trainer_pretrain_evaluate import create_config_benchmark_sweep
from tabicl.data.benchmarks import BENCHMARKS
from tabicl.results.decision_boundary_analysis import decision_boundary_analysis
from tabicl.sweeps.run_sweep import run_sweep


def test_model(cfg: ConfigPretrain, model_path: Path):
    """Run the test-phase analyses and benchmark sweeps for a saved model.

    The function works in two stages:

    1. For every task in ``cfg.testing.downstream_tasks``, run
       ``decision_boundary_analysis`` in its own child process.
    2. For every (benchmark, task) pair drawn from
       ``cfg.testing.benchmarks_test`` and ``cfg.testing.downstream_tasks``,
       build a sweep config in the ``TESTING`` phase and run it.

    Args:
        cfg: Pretraining configuration. This function reads
            ``testing.downstream_tasks``, ``testing.benchmarks_test``,
            ``output_dir`` and ``model_name`` from it.
        model_path: Path to the model weights. It is handed to the decision
            boundary analysis and passed as ``weights_path`` to each sweep.
    """
    logger = logging.getLogger(__name__)

    for task in cfg.testing.downstream_tasks:
        # Without a process wrapper, some CUDA memory keeps being allocated and not freed
        # NOTE(review): if the parent process has already initialized CUDA and
        # the default start method is "fork", the child may fail to use CUDA --
        # confirm that the caller configures a suitable start method.
        p = mp.Process(target=decision_boundary_analysis, args=(cfg, model_path, task))
        p.start()
        p.join()

        # A child that crashes or is killed does not raise in the parent, so
        # surface the failure here. We keep going so that the benchmark sweeps
        # below still run even if this analysis step failed.
        if p.exitcode != 0:
            logger.warning(
                "decision_boundary_analysis for task %s exited with code %s; "
                "continuing with the benchmark sweeps",
                task,
                p.exitcode,
            )

    for benchmark_name in cfg.testing.benchmarks_test:
        # The benchmark depends only on the outer loop variable, so look it up
        # once per benchmark rather than once per (benchmark, task) pair.
        benchmark = BENCHMARKS[benchmark_name]

        for task in cfg.testing.downstream_tasks:
            # Each (benchmark, task) pair writes to its own output directory.
            output_dir = cfg.output_dir / f"test_{benchmark_name.value}_{task.value}"

            cfg_sweep = create_config_benchmark_sweep(
                cfg=cfg,
                benchmark=benchmark,
                output_dir=output_dir,
                weights_path=model_path,
                plot_name=f"{cfg.model_name.value} Pretrain Test",
                phase=Phase.TESTING,
                task=task,
            )
            run_sweep(cfg_sweep)