"""Test script."""

import torch
from omegaconf import DictConfig
import hydra
from wandb_wrapper import init as wandb_init
from utils.log import get_logger, log_fn
from utils.config import load_checkpoint, build_model
from utils.dataclasses import (
    ExperimentConfig,
    PredictionConfig,
    OptimizationConfig,
    DataConfig,
)
from utils.paths import (
    get_exp_path,
    get_log_filepath,
    get_result_plot_path,
    get_result_data_path,
    get_split_dataset_path,
)
from evaluate import evaluate_prediction, evaluate_optimization
from utils.data import set_all_seeds
from evaluation_utils import (
    get_opt_dataset,
    from_function_name_to_datapaths,
    from_seed_to_data,
)


def _get_function_name_parts(function_name, subfolder, ckpt_name, extra=None):
    """Build the result-path suffix for an evaluation run.

    The suffix has the form ``function_name[/subfolder][/ckpt_stem][/extra]``:

    * ``subfolder`` and ``extra`` are appended only when they are not ``None``.
    * The checkpoint stem (``ckpt_name`` without its final extension) is
      appended only when ``ckpt_name`` differs from the default
      ``"ckpt.tar"``, so runs on the default checkpoint keep the short path.

    Args:
        function_name: Name of the evaluated function (first path component).
        subfolder: Optional dataset subfolder.
        ckpt_name: Checkpoint file name, e.g. ``"ckpt.tar"``.
        extra: Optional free-form suffix appended last.

    Returns:
        The ``/``-joined suffix string.
    """
    parts = [function_name]

    if subfolder is not None:
        parts.append(subfolder)

    if ckpt_name != "ckpt.tar":
        # Drop only the final extension. ``split(".")[0]`` would map both
        # "model.v1.tar" and "model.v2.tar" to "model", making their results
        # overwrite each other.
        parts.append(ckpt_name.rsplit(".", 1)[0])

    if extra is not None:
        parts.append(extra)

    # str() mirrors f-string formatting, so non-str config values still work.
    return "/".join(str(part) for part in parts)


def _load_model(config, exp_cfg, opt_cfg, exp_path, log):
    """Build the model and restore its weights from the configured checkpoint.

    Args:
        config: Full Hydra config. ``config.model`` supplies the model kwargs
            and ``config.extra.ckpt_name`` the checkpoint file name.
        exp_cfg: Experiment config (model name, device, resume flag).
        opt_cfg: Optimization config (``use_factorized_policy``).
        exp_path: Experiment directory passed to ``load_checkpoint``.
        log: Logging callable taking a single message string.

    Returns:
        The model, moved to ``exp_cfg.device``, with weights loaded and
        switched to eval mode.

    Raises:
        RuntimeError: If the checkpoint is empty or has no ``"model"`` entry.
    """
    ckpt = load_checkpoint(
        exp_path=exp_path,
        device=exp_cfg.device,
        resume=exp_cfg.resume,
        ckpt_name=config.extra.ckpt_name,
    )
    # ``or {}`` also routes a ``None`` return into the descriptive error below
    # instead of an AttributeError.
    # NOTE(review): confirm whether load_checkpoint can actually return None.
    model_state_dict = (ckpt or {}).get("model", {})
    if not model_state_dict:
        raise RuntimeError(
            f"Invalid checkpoint loaded from {exp_path}. "
            "Checkpoint is either empty or missing the 'model' key."
        )

    model = build_model(
        model_name=exp_cfg.model_name,
        model_kwargs=config.model,
        use_factorized_policy=opt_cfg.use_factorized_policy,
    )
    model = model.to(exp_cfg.device)

    # strict=False tolerates key mismatches. Evaluating a partially restored
    # model silently would produce meaningless results, so report them.
    incompatible = model.load_state_dict(model_state_dict, strict=False)
    if incompatible.missing_keys:
        log(
            "WARNING: keys missing from checkpoint (left at their initial "
            f"values): {incompatible.missing_keys}"
        )
    if incompatible.unexpected_keys:
        log(
            "WARNING: unexpected checkpoint keys (ignored): "
            f"{incompatible.unexpected_keys}"
        )

    # Run evaluation with dropout/normalization layers in inference mode.
    # This is idempotent if the evaluate_* functions also call eval().
    model.eval()
    return model


@hydra.main(version_base=None, config_path="configs", config_name="test.yaml")
def main(config: DictConfig):
    """Evaluate a trained checkpoint on a prediction or optimization task.

    Steps:
      1. Check that ``experiment.mode`` is ``"test"``, configure global torch
         defaults and seed all RNGs.
      2. Build the typed config objects and set up the log file and result
         paths.
      3. Optionally initialise Weights & Biases.
      4. Load the checkpoint and restore the model weights.
      5. Dispatch to ``evaluate_prediction`` or ``evaluate_optimization``
         according to ``experiment.task``.

    Args:
        config: Hydra config composed from ``configs/test.yaml``.

    Raises:
        ValueError: If ``experiment.mode`` is not ``"test"``, the function
            name has no registered datapaths, or the task is unsupported.
        RuntimeError: If the checkpoint is empty or has no ``"model"`` entry.
    """
    # Explicit check rather than ``assert`` so it still fires under ``python -O``.
    if config.experiment.mode != "test":
        raise ValueError(
            f"Set mode to 'test'! Got mode={config.experiment.mode!r}."
        )
    torch.set_printoptions(threshold=torch.inf)
    torch.set_default_dtype(torch.float32)
    torch.set_default_device("cpu")
    set_all_seeds(config.experiment.seed)

    exp_cfg = ExperimentConfig(**config.experiment)
    pred_cfg = PredictionConfig(**config.prediction)
    opt_cfg = OptimizationConfig(**config.optimization)
    data_cfg = DataConfig(**config.data)

    log_filename = get_log_filepath(
        group_name=exp_cfg.model_name,
        expid=exp_cfg.expid,
        prefix=exp_cfg.task,
    )
    log = log_fn(get_logger(file_name=log_filename, mode="w"))

    suffix = _get_function_name_parts(
        function_name=data_cfg.function_name,
        subfolder=config.extra.subfolder,
        ckpt_name=config.extra.ckpt_name,
        extra=config.extra.save_suffix,
    )
    plot_save_path = get_result_plot_path(
        model_name=exp_cfg.model_name,
        expid=exp_cfg.expid,
        task_type=exp_cfg.task,
        suffix=suffix,
    )
    data_save_path = get_result_data_path(
        model_name=exp_cfg.model_name,
        expid=exp_cfg.expid,
        task_type=exp_cfg.task,
        suffix=suffix,
    )
    exp_path = get_exp_path(model_name=exp_cfg.model_name, expid=exp_cfg.expid)
    dataset_path = get_split_dataset_path(split=exp_cfg.mode)
    log(
        f"--- Setup logging and paths ---"
        f"\nlog_filename:\t{log_filename}"
        f"\nexp_path:\t{exp_path}"
        f"\nplot_save_path:\t{plot_save_path}"
        f"\ndata_save_path:\t{data_save_path}"
        f"\ndataset_path:\t{dataset_path}"
    )

    if exp_cfg.log_to_wandb:
        log(f"--- Setup W&B ---\n{config.wandb}")
        wandb_init(config=config, **config.wandb)

    model = _load_model(
        config=config,
        exp_cfg=exp_cfg,
        opt_cfg=opt_cfg,
        exp_path=exp_path,
        log=log,
    )

    if exp_cfg.task == "prediction":
        datapaths = from_function_name_to_datapaths(
            function_name=data_cfg.function_name,
            split=exp_cfg.mode,
            subfolder=config.extra.subfolder,
        )
        if datapaths is None:
            raise ValueError(f"Unsupported function name: {data_cfg.function_name}")

        evaluate_prediction(
            model=model,
            datapaths=datapaths,
            data_save_path=data_save_path,
            plot_save_path=plot_save_path,
            exp_cfg=exp_cfg,
            pred_cfg=pred_cfg,
            data_cfg=data_cfg,
            log=log,
            **config.logging,
        )

    elif exp_cfg.task == "optimization":
        train_x, train_y, train_x_bounds, train_y_bounds = get_opt_dataset(
            function_name=data_cfg.function_name,
            seed=exp_cfg.seed,
            subfolder=config.extra.subfolder,
            device=exp_cfg.device,
        )

        evaluate_optimization(
            model=model,
            plot_save_path=plot_save_path,
            data_save_path=data_save_path,
            exp_cfg=exp_cfg,
            opt_cfg=opt_cfg,
            data_cfg=data_cfg,
            pred_cfg=pred_cfg,
            train_x=train_x,
            train_y=train_y,
            train_x_bounds=train_x_bounds,
            train_y_bounds=train_y_bounds,
            log=log,
            d=config.extra.d,
            cost=config.extra.cost,
            cost_mode=config.extra.cost_mode,
            q=config.extra.q,
            fantasy=config.extra.fantasy,
            **config.logging,
        )
    else:
        raise ValueError(
            f"Unsupported task: {exp_cfg.task}. "
            "Supported tasks are 'prediction' and 'optimization'."
        )


if __name__ == "__main__":
    main()
