from pathlib import Path

import torch.multiprocessing as mp

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import Phase
from tabicl.core.trainer_pretrain_evaluate import create_config_benchmark_sweep
from tabicl.data.benchmarks import BENCHMARKS
from tabicl.results.decision_boundary_analysis import decision_boundary_analysis
from tabicl.sweeps.run_sweep import run_sweep


def test_model(cfg: ConfigPretrain, model_path: Path) -> None:
    """
    Evaluate a pretrained model on the configured test benchmarks and downstream tasks.

    This function is placed outside the trainer because we have problems with DDP not
    properly cleaning up CUDA memory. Keeping it separate from the trainer (and running
    the decision boundary analysis in child processes) lets CUDA memory be cleaned up
    before the tests run.

    Steps:
    1. If ``cfg.testing.decision_boundary_analysis_enabled`` is set, run
       ``decision_boundary_analysis(cfg, model_path, task)`` once per task in
       ``cfg.testing.downstream_tasks``, each in its own child process.
    2. For every benchmark in ``cfg.testing.benchmarks_test`` and every task in
       ``cfg.testing.downstream_tasks``, build a testing sweep config with
       ``create_config_benchmark_sweep`` and run it with ``run_sweep``. Output goes to
       ``cfg.output_dir / f"test_{benchmark_name.value}_{task.value}"``.

    Args:
        cfg: Pretraining configuration; its ``testing`` section selects what to run.
        model_path: Path to the model weights to evaluate.

    Raises:
        RuntimeError: If a decision boundary analysis child process exits with a
            non-zero exit code.
    """

    if cfg.testing.decision_boundary_analysis_enabled:
        for task in cfg.testing.downstream_tasks:
            # Without a process wrapper, some CUDA memory keeps being allocated and not freed.
            # NOTE(review): CUDA cannot be used in a child created with the "fork" start
            # method; presumably "spawn" is configured elsewhere (e.g. via
            # mp.set_start_method) — confirm.
            p = mp.Process(target=decision_boundary_analysis, args=(cfg, model_path, task))
            p.start()
            p.join()
            # An exception in the child never propagates to the parent; the exit code is
            # the only signal, so surface failures instead of silently continuing.
            if p.exitcode != 0:
                raise RuntimeError(
                    f"Decision boundary analysis for task {task.value} failed "
                    f"(exit code {p.exitcode})"
                )

    for benchmark_name in cfg.testing.benchmarks_test:
        # The benchmark depends only on its name, so resolve it once per benchmark.
        benchmark = BENCHMARKS[benchmark_name]
        for task in cfg.testing.downstream_tasks:
            output_dir = cfg.output_dir / f"test_{benchmark_name.value}_{task.value}"

            cfg_sweep = create_config_benchmark_sweep(
                cfg=cfg,
                benchmark=benchmark,
                output_dir=output_dir,
                weights_path=model_path,
                plot_name=f"{cfg.model_name.value} Pretrain Test",
                phase=Phase.TESTING,
                task=task,
            )
            run_sweep(cfg_sweep)