"""Evaluation for TAMO.

Functions:
    `evaluate_optimization`: Evaluate TAMO optimization on a test function
    `evaluate_prediction`: Evaluate TAMO predictions on datasets
"""

from dataclasses import asdict
import math
import gc
import os
import os.path as osp
import torch
from torch import Tensor
import wandb
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.pyplot import Figure
from model.module import get_prediction_mean_std
from TAMO.model.tamo import TAMO
from utils.log import Averager
from utils.config import get_train_x_range, get_train_y_range, build_dataloader
from utils.dataclasses import (
    ExperimentConfig,
    PredictionConfig,
    OptimizationConfig,
    DataConfig,
)
from utils.data import HDF5Dataset, set_all_seeds, get_datasets, has_nan_or_inf
from utils.types import FloatListOrNested, NestedFloatList
from typing import Optional, List
from utils.paths import params_to_filename
from utils.plot import (
    plot_prediction,
    plot_1d,
    plot_fronts,
    plot_acq_values,
    adapt_save_fig,
)
from data.function import TestFunction, get_function_environment
from data.function_preprocessing import make_range_nested_list
from data.function_sampling import get_num_subspace_points
from data.data_masking import generate_dim_mask
from policy_learning import select_next_query
from prediction import prepare_prediction_dataset, predict_with_metrics
from evaluation_utils import plot_optimization_batch, expand_metrics
from einops import repeat

# Default plotting period for the optimization loop. The default of -1 keeps
# periodic plotting off (see `_should_plot`, which requires a positive period).
PLOT_PER_N_STEPS: int = -1
GEN_X_MASK_MODE: str = "full"

# Input/output ranges the model was trained on, resolved once at import time.
x_range, y_range = get_train_x_range(), get_train_y_range()


def save_fig(
    fig: Figure,
    path: str,
    config,
    filename: str,
    override: bool = False,
    log: callable = print,
    log_to_wandb: bool = False,
):
    """Save ``fig`` as a PDF in a folder named after ``config``, then close it.

    The file is written to
    ``<path>/<params_to_filename(asdict(config))>/<filename>.pdf``. An existing
    file is only rewritten when ``override`` is True. Once a non-None figure is
    handed in, it is cleared and closed on every exit path: saved, skipped
    because the file exists, or failed. Loops that create many figures
    therefore do not accumulate open matplotlib figures.

    Args:
        fig: Figure to save. ``None`` is logged and ignored.
        path: Root output directory.
        config: Dataclass instance whose fields name the sub-folder
            (passed through ``asdict``).
        filename: File name without extension.
        override: Overwrite an existing file.
        log: Logging callable.
        log_to_wandb: After saving, also log the figure to wandb under
            ``filename``.

    Returns:
        The target PDF path, or ``None`` when ``fig`` is ``None``.
    """
    log(f"--- Saving figure {filename} ---")
    if fig is None:
        log("    Figure is None. Skipping save.")
        return None

    try:
        folder_name = params_to_filename(asdict(config))
        fig_path = osp.join(path, folder_name, f"{filename}.pdf")
        log(f"    Full path:\t{fig_path}")

        if osp.exists(fig_path):
            if not override:
                log("\n    Figure exists. Skipping save.")
                return fig_path
            log("\n    Figure exists but overrides.")
        else:
            os.makedirs(osp.join(path, folder_name), exist_ok=True)
            log("\n    Figure saves.")

        adapt_save_fig(fig, fig_path)
        if log_to_wandb:
            # Must happen before the figure is cleared in `finally`.
            wandb.log({filename: wandb.Image(fig)})
        return fig_path
    finally:
        # Release the figure on every path. Skipping the save used to return
        # early and leave the figure open.
        fig.clf()
        plt.close(fig)


def save_data(
    data: Tensor,
    path: str,
    filename: str,
    config,
    override: bool = False,
    log: callable = print,
):
    """Persist ``data`` with ``torch.save`` under a config-derived sub-folder.

    The target is ``<path>/<params_to_filename(asdict(config))>/<filename>.pt``.
    An existing file is left untouched unless ``override`` is True.

    Args:
        data: Object to save. ``None`` is logged and ignored.
        path: Root output directory.
        filename: File name without extension.
        config: Dataclass instance whose fields name the sub-folder.
        override: Overwrite an existing file.
        log: Logging callable.

    Returns:
        The target path, or ``None`` when ``data`` is ``None``.
    """
    log(f"--- Saving data {filename} ---")
    if data is None:
        log(f"    Data is None. Skipping save.")
        return None

    target_dir = osp.join(path, params_to_filename(asdict(config)))
    data_path = osp.join(target_dir, f"{filename}.pt")
    log(f"    Full path:\t{data_path}")

    if not osp.exists(data_path):
        os.makedirs(target_dir, exist_ok=True)
        log(f"    File saves.")
    elif override:
        log(f"    File exists but overrides.")
    else:
        log(f"    File exists. Skipping save.")
        return data_path

    torch.save(data, data_path)
    return data_path


def plot_prediction_batch(
    model: TAMO,
    nc: int,
    x: Tensor,
    y: Tensor,
    x_mask: Tensor,
    y_mask: Tensor,
    y_mask_tar: Optional[Tensor] = None,
    xc: Optional[Tensor] = None,
    yc: Optional[Tensor] = None,
    x_bounds: Optional[FloatListOrNested] = None,
    read_cache: bool = False,
    write_cache: bool = False,
    y_mask_history: Optional[Tensor] = None,  # [B, t, dy_max]
    plot_mean: bool = True,
    plot_order: bool = False,
    F: Optional[callable] = None,
) -> Figure:
    """Predict on ``x`` with ``model`` and plot the predictions against ``y``.

    Args:
        model: Model whose ``predict`` produces the plotted mean/std.
        nc: Number of leading points of ``x``/``y`` used as context when
            ``xc`` or ``yc`` is not given.
        x: Target inputs. Cloned before use.
        y: Target outputs. Cloned before use.
        x_mask: Input mask forwarded to ``model.predict`` and the plot.
        y_mask: Context-output mask forwarded to ``model.predict``.
        y_mask_tar: Target-output mask forwarded to ``model.predict`` and
            passed to the plot as ``y_mask``.
        xc: Explicit context inputs.
        yc: Explicit context outputs. If either ``xc`` or ``yc`` is None,
            BOTH are replaced by the first ``nc`` points of ``x``/``y``.
        x_bounds: Not used; kept so existing call sites keep working.
        read_cache: Forwarded to ``model.predict``.
        write_cache: Forwarded to ``model.predict``.
        y_mask_history: Observation-mask history tensor forwarded to
            ``plot_prediction``.
        plot_mean: Forwarded to ``plot_prediction``.
        plot_order: Forwarded to ``plot_prediction``.
        F: Forwarded to ``plot_prediction``.

    Returns:
        The figure produced by ``plot_prediction``.
    """
    x_plot, y_plot = x.clone(), y.clone()
    if xc is None or yc is None:
        xc, yc = x_plot[:, :nc], y_plot[:, :nc]

    out = model.predict(
        xc, yc, x_plot, x_mask, y_mask, y_mask_tar, read_cache, write_cache
    )
    mean, std = get_prediction_mean_std(out)

    return plot_prediction(
        mean=mean,
        std=std,
        x=x_plot,
        y=y_plot,
        xctx=xc,
        yctx=yc,
        x_mask=x_mask,
        y_mask=y_mask_tar,
        y_mask_history=y_mask_history,
        plot_mean=plot_mean,
        plot_order=plot_order,
        F=F,
    )


def evaluate_optimization(
    model: TAMO,
    plot_save_path,
    data_save_path,
    exp_cfg: ExperimentConfig,
    opt_cfg: OptimizationConfig,
    data_cfg: DataConfig,
    pred_cfg: Optional[PredictionConfig] = None,
    train_x: Optional[Tensor] = None,
    train_y: Optional[Tensor] = None,
    train_x_bounds: Optional[FloatListOrNested] = None,
    train_y_bounds: Optional[FloatListOrNested] = None,
    log: callable = print,
    plot_enabled: bool = False,
    plot_per_n_steps: int = PLOT_PER_N_STEPS,
    d: Optional[int] = None,
    cost: float = 1.0,
    cost_mode: bool = True,
    q: int = 1,
    fantasy: bool = False,
    **kwargs,
):
    """Evaluate TAMO optimization on one test function for one seed.

    Seeds all RNGs with ``exp_cfg.seed``. Builds the test-function environment
    from ``data_cfg`` and the optional ``train_*`` data/bounds. Then runs
    ``run_optimization`` with the model's training input/output ranges. The
    plot and data save paths are suffixed with the seed before being passed
    on.

    Args:
        model: Model to evaluate.
        plot_save_path: Root directory for plots (a ``<seed>`` sub-folder is
            appended).
        data_save_path: Root directory for data (a ``<seed>`` sub-folder is
            appended).
        exp_cfg: Experiment settings (seed, device, ...).
        opt_cfg: Optimization-loop settings.
        data_cfg: Test-function settings (name, sigma, dim).
        pred_cfg: Prediction settings. ``run_optimization`` is called with its
            default ``predict=True``, which asserts that ``pred_cfg`` is not
            None, so in practice this must be provided.
        train_x: Forwarded to ``get_function_environment``.
        train_y: Forwarded to ``get_function_environment``.
        train_x_bounds: Forwarded to ``get_function_environment``.
        train_y_bounds: Forwarded to ``get_function_environment``.
        log: Logging callable.
        plot_enabled: Forwarded to ``run_optimization``.
        plot_per_n_steps: Forwarded to ``run_optimization``.
        d: Forwarded to ``run_optimization``.
        cost: Forwarded to ``run_optimization``.
        cost_mode: Forwarded to ``run_optimization``.
        q: Forwarded to ``run_optimization``.
        fantasy: Forwarded to ``run_optimization``.
        **kwargs: Accepted and ignored.

    Returns:
        The results mapping returned by ``run_optimization``, so callers can
        use the metrics directly instead of only seeing them in the log.
    """
    set_all_seeds(exp_cfg.seed)

    # Keep outputs of different seeds apart.
    plot_save_path = osp.join(plot_save_path, str(exp_cfg.seed))
    data_save_path = osp.join(data_save_path, str(exp_cfg.seed))

    log(
        f"--- Optimization evaluation setup ---"
        f"\nFunction:\t{data_cfg.function_name}"
        f"\nSigma:\t{data_cfg.sigma}"
        f"\nSeed:\t{exp_cfg.seed}"
    )

    test_function = get_function_environment(
        function_name=data_cfg.function_name,
        sigma=data_cfg.sigma,
        train_x=train_x,
        train_y=train_y,
        train_x_bounds=train_x_bounds,
        train_y_bounds=train_y_bounds,
        dim=data_cfg.dim,
    )
    log(
        f"--- Test function details ---"
        f"\nInput dim:\t{test_function.x_dim}"
        f"\nOutput dim:\t{test_function.y_dim}"
        f"\nInput bounds:\t{test_function.x_bounds}"
        f"\nOutput bounds:\t{test_function.y_bounds}"
        f"\nMax hv:\t{test_function.max_hv:.4f}\n"
    )

    model_x_range = make_range_nested_list(x_range, test_function.x_dim)
    model_y_range = make_range_nested_list(y_range, test_function.y_dim)
    log(
        f"--- Model training details ---"
        f"\nModel trained on x_range: {model_x_range}"
        f"\nModel trained on y_range: {model_y_range}\n"
    )

    res = run_optimization(
        model=model,
        test_function=test_function,
        model_x_range=model_x_range,
        model_y_range=model_y_range,
        exp_cfg=exp_cfg,
        opt_cfg=opt_cfg,
        data_cfg=data_cfg,
        pred_cfg=pred_cfg,
        plot_save_path=plot_save_path,
        data_save_path=data_save_path,
        plot_per_n_steps=plot_per_n_steps,
        plot_enabled=plot_enabled,
        log=log,
        d=d,
        cost=cost,
        cost_mode=cost_mode,
        q=q,
        fantasy=fantasy,
    )

    def _log_final_metrics(res, exp_cfg, log=print):
        """Log the last entry of every non-empty list-valued metric in ``res``."""
        log(f"--- Final metrics summary (seed {exp_cfg.seed})---")
        for key, item in res.items():
            if isinstance(item, list) and item:  # list of tensor [B,]
                log(f"{key}: {item[-1]}")

    _log_final_metrics(res, exp_cfg, log=log)

    gc.collect()
    torch.cuda.empty_cache()
    return res


def get_context_n_target_dim_masks(
    dim: int,
    device: str,
    prec_ctx_mask: Optional[Tensor] = None,
    prec_tar_mask: Optional[Tensor] = None,
    mode: str = "full",
    single_obs_dim: Optional[int] = None,
):
    """Build the (context, target) dimension masks for a generation mode.

    The context mask is built with the ``_get_dim_mask`` mode that ``mode``
    maps to. The target mask is always built in ``"full"`` mode, so
    predictions are requested for every dimension.

    Args:
        dim: Total number of dimensions.
        device: Device for newly created masks.
        prec_ctx_mask: Previous context mask, forwarded as ``prev_mask``.
        prec_tar_mask: Previous target mask, forwarded as ``prev_mask``.
        mode: One of ``"full"``, ``"single_ctx_full_tar"``,
            ``"rs_ctx_full_tar"``, ``"alt_ctx_full_tar"``.
        single_obs_dim: Forwarded to both mask builds.

    Returns:
        Tuple ``(ctx_mask, tar_mask)``.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    ctx_mode_for = {
        "full": "full",
        "single_ctx_full_tar": "single",
        "rs_ctx_full_tar": "random",
        "alt_ctx_full_tar": "alternate",
    }
    if mode not in ctx_mode_for:
        raise ValueError(f"Unknown mode: {mode}")

    common = dict(dim=dim, device=device, single_obs_dim=single_obs_dim)
    ctx_mask = _get_dim_mask(
        prev_mask=prec_ctx_mask, mode=ctx_mode_for[mode], **common
    )
    # Targets are always fully predicted.
    tar_mask = _get_dim_mask(prev_mask=prec_tar_mask, mode="full", **common)
    return ctx_mask, tar_mask


def _get_dim_mask(
    dim: int,
    device: str,
    prev_mask: Optional[Tensor] = None,
    mode: str = "full",
    single_obs_dim: Optional[int] = None,
):
    """Get a boolean dimension mask for the requested mode.

    Args:
        dim: Total dimension.
        device: Device to create the mask on.
        prev_mask: Mask from the previous step, if any.
        mode: One of
            - "full": the full mask from ``generate_dim_mask(max_dim=dim)``,
              or ``prev_mask`` unchanged when given.
            - "single": only ``single_obs_dim`` set, or ``prev_mask``
              unchanged when given.
            - "random": one valid dimension drawn uniformly at random
              (``prev_mask`` is ignored).
            - "alternate": one valid dimension, cycling round-robin. This is
              the first valid index after the highest index set in
              ``prev_mask``, wrapping to the first valid index. With no (or an
              empty) ``prev_mask`` it is the first valid index.
        single_obs_dim: Dimension to set to True if mode is "single".

    Returns:
        dim_mask: [dim, ] boolean tensor (or ``prev_mask`` when reused).

    Raises:
        ValueError: If ``mode`` is unknown.
        AssertionError: In "single" mode, if ``single_obs_dim`` is not a valid
            dimension of the full mask.
    """
    full_mask, _ = generate_dim_mask(max_dim=dim, device=device)

    if mode == "full":
        # Reuse previous mask if available
        return full_mask if prev_mask is None else prev_mask

    if mode == "single":
        if prev_mask is not None:
            # Reuse previous mask if available
            return prev_mask

        full_indices = full_mask.int().nonzero(as_tuple=False)[:, 0]  # [num_valid,]
        assert single_obs_dim in full_indices, (
            f"single_obs_dim={single_obs_dim} is not a valid dimension"
        )

        mask = torch.zeros_like(full_mask, dtype=torch.bool, device=device)
        mask[single_obs_dim] = True
        return mask

    if mode == "random":
        full_indices = full_mask.int().nonzero(as_tuple=False)[:, 0]  # [num_valid,]
        rs_from_full_indices = torch.randperm(full_indices.shape[0], device=device)[0]
        valid_idx = full_indices[rs_from_full_indices]

        mask = torch.zeros_like(full_mask, dtype=torch.bool)
        mask[valid_idx] = True
        return mask

    if mode == "alternate":
        full_indices = full_mask.int().nonzero(as_tuple=False)[:, 0]
        # Default (and wrap-around target): the first valid dimension.
        next_idx = full_indices[0]
        if prev_mask is not None:
            prev_indices = prev_mask.int().nonzero(as_tuple=False)[:, 0]
            if prev_indices.numel() > 0:
                # Round-robin: move past the highest previously set dim. The
                # former "first index not in prev_mask" rule bounced between
                # the first two dims and never reached the rest when dim > 2.
                later = full_indices[full_indices > prev_indices.max()]
                if later.numel() > 0:
                    next_idx = later[0]

        mask = torch.zeros_like(full_mask, dtype=torch.bool, device=device)
        mask[next_idx] = True
        return mask

    raise ValueError(f"Unknown mode: {mode}")


def _run_prediction_at_opt_step(
    test_function: TestFunction,
    model: TAMO,
    x_ctx: Tensor,
    y_ctx: Tensor,
    x_mask: Tensor,
    y_mask: Tensor,
    y_mask_tar: Tensor,
    train_x_range: list,
    train_y_range: list,
    batch_size: int,
    read_cache: bool,
    write_cache: bool,
    x_bounds: Optional[FloatListOrNested] = None,
    x_tar: Optional[Tensor] = None,
    y_tar: Optional[Tensor] = None,
    num_subspace_points: int = 500,
    sigma: float = 0.0,
    plot_enabled: bool = False,
    y_mask_history: Optional[Tensor] = None,  # [B, t, dy_max]
):
    """Evaluate the model's target predictions at one optimization step.

    If ``x_tar`` or ``y_tar`` is missing, targets are sampled from
    ``test_function`` on a grid over ``train_x_range``. Their outputs are then
    scaled to ``train_y_range``. Caller-supplied targets are used as-is
    (presumably already on the model's output scale; confirm at call sites).
    The per-dimension masks are repeated over ``batch_size``.

    Returns:
        Tuple ``(nll_c, nll_t, mse_c, mse_t, figs)``:

        - ``nll_c`` and ``mse_c`` are placeholders filled with -1 (context
          metrics are not computed yet).
        - ``nll_t`` and ``mse_t`` are the first two outputs of
          ``predict_with_metrics``.
        - ``figs`` is ``{"mean": Figure, "std": Figure}`` when
          ``plot_enabled``, otherwise ``None``.
    """
    device = x_mask.device

    if x_tar is None or y_tar is None:
        (x_tar, y_tar, _, _) = test_function.sample(
            input_bounds=train_x_range,
            batch_size=batch_size,
            num_subspace_points=num_subspace_points,
            use_grid_sampling=True,
            use_factorized_policy=False,
            device=device,
            x_mask=x_mask,
            y_mask=y_mask_tar,
        )

        # NOTE scale to model train_y_range
        y_tar = test_function.scale_outputs(
            outputs=y_tar, output_bounds=train_y_range, sigma=sigma
        )

    # Batch-expanded masks are needed for sampled AND caller-supplied targets.
    # Building them only inside the sampling branch made passing
    # x_tar/y_tar raise UnboundLocalError below.
    x_mask_exp = repeat(x_mask, "d -> b d", b=batch_size)
    y_mask_exp = repeat(y_mask, "d -> b d", b=batch_size)
    y_mask_tar_exp = repeat(y_mask_tar, "d -> b d", b=batch_size)

    # TODO report context nll
    nll_c = torch.tensor([-1.0], device=device)
    mse_c = -torch.ones(test_function.y_dim, device=device)
    nll_t, mse_t, _, _ = predict_with_metrics(
        model=model,
        x_ctx=x_ctx,
        y_ctx=y_ctx,
        x_tar=x_tar,
        y_tar=y_tar,
        x_mask=x_mask_exp,
        y_mask=y_mask_exp,
        y_mask_tar=y_mask_tar_exp,
        compute_nll=True,
        compute_mse=True,
        compute_ktt=True,
        reduce_nll=True,
        reduce_mse=True,
        read_cache=read_cache,
        write_cache=write_cache,
    )

    figs = None
    if plot_enabled:
        figs = {}
        for plot_mean in (True, False):
            figs["mean" if plot_mean else "std"] = plot_prediction_batch(
                model=model,
                nc=x_ctx.shape[1],
                xc=x_ctx,
                yc=y_ctx,
                x=x_tar,
                y=y_tar,
                x_bounds=x_bounds,
                x_mask=x_mask_exp,
                y_mask=y_mask_exp,
                y_mask_tar=y_mask_tar_exp,
                read_cache=read_cache,
                write_cache=write_cache,
                y_mask_history=y_mask_history,
                plot_mean=plot_mean,
                plot_order=True,
                F=test_function,
            )

    return nll_c, nll_t, mse_c, mse_t, figs


def _log_opt_step(step: int, hv, regret, entropy, x_ctx=None, y_ctx=None, log=print):
    line = (
        f"--- Step {step} ---"
        f"\nMetrics:"
        f"\n  Hypervolume:\n{hv}"
        f"\n  Regret:\n{regret}"
        f"\n  Entropy:\n{entropy}"
    )
    if x_ctx is not None and y_ctx is not None:
        line += f"\nContext Points:" f"\n  x_ctx:\n{x_ctx}" f"\n  y_ctx:\n{y_ctx}"
    log(line)


def _should_plot(
    cost_used, cost_total, plot_per_n_unit_cost, plot_enabled, init_cost=1
):
    plot_enabled = plot_enabled and plot_per_n_unit_cost > 0
    at_plot_step = (
        cost_used == init_cost
        or cost_used == cost_total - 1
        or (
            cost_used < cost_total
            and (cost_used - init_cost) % plot_per_n_unit_cost == 0
        )
    )
    return plot_enabled and at_plot_step


class States:
    """Bookkeeping for observation masks and cost during an optimization run.

    Each ``step`` appends one output-dimension observation mask to
    ``y_mask_observed`` (shape ``[1, num_observations, y_dim]``) and charges
    its cost to ``cost_used`` (one entry per output dimension). ``__init__``
    performs ``num_init`` such steps for the initial observations.
    """

    def __init__(
        self,
        x_dim: int,
        y_dim: int,
        observation_mode="full",
        observed_y_dim=None,
        device: str = "cuda",
        num_init: int = 1,
        cost: float = 1,
        cost_mode: bool = True,
    ):
        """
        Args:
            x_dim: Input dimensionality.
            y_dim: Output dimensionality.
            observation_mode: One of "full", "single_ctx_full_tar",
                "rs_ctx_full_tar", "alt_ctx_full_tar".
            observed_y_dim: Output dimension to observe. Required for
                "single_ctx_full_tar" and must be None otherwise.
            device: Device for all mask and cost tensors.
            num_init: Number of initial observations to record (>= 1).
            cost: Per-dimension observation cost, used when ``cost_mode``.
            cost_mode: If True, charge ``cost`` for each observed dimension;
                otherwise charge 1 per step.
        """
        assert num_init >= 1, "num_init must be at least 1"
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.gen_mode = observation_mode
        self.observed_y_dim = observed_y_dim
        self.device = device
        self.cost_mode = cost_mode

        self._init_observed_y_mode(observation_mode, observed_y_dim)
        self._init_masks(x_dim, y_dim, device)
        self._init_costs(y_dim, cost, device)

        # Init for initial observations
        for _ in range(num_init):
            self.step(update_mask=True)

    def _init_masks(self, x_dim, y_dim, device):
        """Create the input, target and (empty) observed-output masks."""
        # No missing inputs
        self.x_mask, _ = generate_dim_mask(max_dim=x_dim, device=device)

        # No missing target predictions
        self.y_mask_target, _ = generate_dim_mask(max_dim=y_dim, device=device)

        # Enable partial observations: history of masks, [1, 0, y_dim] to start.
        self.y_mask_observed = torch.empty(1, 0, y_dim, device=device, dtype=torch.bool)

    def _init_costs(self, y_dim, cost, device):
        """Set the per-dimension cost and zero the per-dimension usage."""
        self.cost_each = torch.tensor(
            [cost] * y_dim, dtype=torch.float32, device=device
        )
        self.cost_used = torch.tensor([0.0] * y_dim, dtype=torch.float32, device=device)

    def _init_observed_y_mode(self, mode, observed_dim):
        """Map the generation mode to the ``_get_dim_mask`` mode for observations.

        Raises:
            AssertionError: If ``observed_dim`` is missing for
                "single_ctx_full_tar", or given for any other mode.
            ValueError: If ``mode`` is unknown.
        """
        if mode == "single_ctx_full_tar":
            self.observed_y_mode = "single"
            assert (
                observed_dim is not None
            ), f"observed_y_dim should be provided for gen_mode {mode}"
        else:
            assert (
                observed_dim is None
            ), f"observed_y_dim should be None for gen_mode {mode}"
            if mode == "full":
                self.observed_y_mode = "full"
            elif mode == "rs_ctx_full_tar":
                self.observed_y_mode = "random"
            elif mode == "alt_ctx_full_tar":
                self.observed_y_mode = "alternate"
            else:
                raise ValueError(f"Unknown gen_mode: {mode}")

    def get_cost_used(self):
        """Total cost charged so far, as a Python float."""
        return self.cost_used.sum().item()

    def get_current_y_mask_observed(self):
        """Mask of the most recent observation, shape ``[y_dim]``."""
        return self.y_mask_observed[0, -1]  # [dy, ]

    def step(self, update_mask: bool = True):
        """Record the observation mask for one new observation and charge its cost.

        Args:
            update_mask: If True, draw a new mask with ``_get_dim_mask``
                (passing the previous mask, if any). If False, repeat the
                previous mask. When nothing has been observed yet there is no
                mask to repeat, so a new one is drawn in either case.

        Returns:
            The ``[y_dim]`` boolean mask used for this observation.
        """
        if self.y_mask_observed.shape[1] == 0:
            prev_mask = None
        else:
            # Get previous observation mask
            prev_mask = self.y_mask_observed[0, -1]  # [dy, ]

        if update_mask or prev_mask is None:
            # Get new observation mask. Also used when there is no previous
            # mask, which previously crashed on `None.clone()`.
            mask = _get_dim_mask(
                dim=self.y_dim,
                device=self.device,
                prev_mask=prev_mask,
                mode=self.observed_y_mode,
                single_obs_dim=self.observed_y_dim,
            )
        else:
            # Do not update mask, just reuse previous mask
            mask = prev_mask.clone()

        mask_exp = mask.unsqueeze(0).unsqueeze(0)
        self.y_mask_observed = torch.cat([self.y_mask_observed, mask_exp], dim=1)

        # Update related cost for new observations
        self._update_cost(mask)

        return mask

    def _update_cost(self, mask):
        """Charge the cost of one observation with output mask ``mask``."""
        if self.cost_mode:
            # Update with cost associated with newly observed dimensions
            step_cost = self.cost_each * mask.int()  # [dy, ]
            self.cost_used += step_cost
        else:
            # unit cost per step
            self.cost_used[0] += 1.0


def run_optimization(
    model: TAMO,
    test_function: TestFunction,
    model_x_range: NestedFloatList,
    model_y_range: NestedFloatList,
    plot_save_path: str,
    data_save_path: str,
    exp_cfg: ExperimentConfig,
    opt_cfg: OptimizationConfig,
    data_cfg: DataConfig,
    pred_cfg: Optional[PredictionConfig] = None,
    plot_per_n_steps: int = PLOT_PER_N_STEPS,
    plot_enabled: bool = False,
    log: callable = print,
    d: Optional[int] = None,
    predict: bool = True,
    cost: float = 1.0,
    cost_mode: bool = True,
    q: int = 1,
    fantasy: bool = False,
):

    if predict:
        assert pred_cfg is not None, "`pred_cfg` must be provided if predict is True"

    device = exp_cfg.device
    model = model.to(device)

    if d is None:
        d = get_num_subspace_points(
            x_dim=test_function.x_dim,
            use_factorized_policy=opt_cfg.use_factorized_policy,
        )
    log(f"Number of subspace points (d) to use: {d} ")
    T = opt_cfg.T

    x_ctx, y_ctx, hv, regret = test_function.init(
        input_bounds=model_x_range,
        batch_size=opt_cfg.batch_size,
        num_initial_points=opt_cfg.num_initial_points,
        regret_type=opt_cfg.regret_type,
        compute_hv=True,
        compute_regret=True,
        device=device,
    )

    entropy = torch.zeros((opt_cfg.batch_size,), device=device)
    states = States(
        x_dim=test_function.x_dim,
        y_dim=test_function.y_dim,
        observation_mode=opt_cfg.dim_mask_gen_mode,
        observed_y_dim=opt_cfg.single_obs_y_dim,
        device=device,
        num_init=opt_cfg.num_initial_points,
        cost=cost,
        cost_mode=cost_mode,
    )
    init_cost = states.get_cost_used()
    hv = expand_metrics(hv, x_ctx)  # [B, q]
    regret = expand_metrics(regret, x_ctx)  # [B, q]
    entropy = expand_metrics(entropy, x_ctx)  # [B, q]
    results = {
        "hpvs_list": [hv],
        "instant_hpvs_list": [hv.clone()],
        "regr_list": [regret],
        "entr_list": [entropy],
        "time_list": [0.0],
        "nll_c_list": [],
        "nll_t_list": [],
        "mse_c_list": [],
        "mse_t_list": [],
    }
    _log_opt_step(states.get_cost_used(), hv, regret, entropy, x_ctx, y_ctx, log=log)

    q_chunk, q_chunk_mask, logit_mask = None, None, None

    model.eval()
    with torch.no_grad():
        # for t in tqdm(range(1, T + 1), f"Opimization loop for T={T} steps"):
        while states.get_cost_used() <= T:
            # NOTE scale function outputs to train_y_range
            y_ctx_scaled = test_function.scale_outputs(
                outputs=y_ctx, output_bounds=model_y_range, sigma=data_cfg.sigma
            )

            # # TODO minor bug when plotting if auto_clear_cache=True: optimization have cleared up cache at final step
            if predict:
                nll_c, nll_t, mse_c, mse_t, figs = _run_prediction_at_opt_step(
                    test_function=test_function,
                    model=model,
                    x_ctx=x_ctx,
                    y_ctx=y_ctx_scaled,
                    x_mask=states.x_mask,
                    y_mask=states.get_current_y_mask_observed(),
                    y_mask_tar=states.y_mask_target,
                    x_bounds=model_x_range,
                    train_x_range=model_x_range,
                    train_y_range=model_y_range,
                    batch_size=opt_cfg.batch_size,
                    read_cache=pred_cfg.read_cache,
                    write_cache=pred_cfg.write_cache,
                    sigma=data_cfg.sigma,
                    plot_enabled=_should_plot(
                        states.get_cost_used(),
                        T,
                        plot_per_n_steps,
                        plot_enabled,
                        init_cost,
                    ),
                    y_mask_history=states.y_mask_observed,
                )

                results["nll_c_list"].append(nll_c)  # +[1]
                results["nll_t_list"].append(nll_t)
                results["mse_c_list"].append(mse_c)
                results["mse_t_list"].append(mse_t)  # +[DY]

                # Optional plot at step t
                if figs is not None:
                    ctx_dx_tuple_str = "".join(
                        map(str, states.x_mask.nonzero(as_tuple=False)[:, 0].tolist())
                    )
                    ctx_dy_tuple_str = "".join(
                        map(
                            str,
                            states.get_current_y_mask_observed()
                            .nonzero(as_tuple=False)[:, 0]
                            .tolist(),
                        )
                    )
                    filename = f"context_dx{ctx_dx_tuple_str}dy{ctx_dy_tuple_str}_"
                    nllc_mean = nll_c.detach().mean().item()
                    nllt_mean = nll_t.detach().mean().item()
                    filename += f"nc{x_ctx.shape[1]}_t{states.get_cost_used()}T{T}_nllc{int(nllc_mean)}nllt{int(nllt_mean)}"

                    for key, fig in figs.items():
                        key_filename = f"{key}_{filename}"
                        save_fig(
                            fig,
                            path=plot_save_path,
                            config=opt_cfg,
                            filename=key_filename,
                            override=exp_cfg.override,
                            log=log,
                            log_to_wandb=exp_cfg.log_to_wandb,
                        )

            q_x_next = None
            q_entr_list = []
            inference_time = 0.0

            if fantasy:
                x_ctx_step = x_ctx.clone()
                y_ctx_scaled_step = y_ctx_scaled.clone()
                
                # Expand batch dim for masks 
                x_mask_exp = repeat(
                    states.x_mask,
                    "d -> b d",
                    b=opt_cfg.batch_size,
                )
                y_mask_exp = repeat(
                    states.get_current_y_mask_observed(),
                    "d -> b d",
                    b=opt_cfg.batch_size,
                )
                y_mask_tar_exp = repeat(
                    states.y_mask_target,
                    "d -> b d",
                    b=opt_cfg.batch_size,
                )

            for qi in range(q):
                if fantasy:
                    assert (
                        opt_cfg.write_cache == False
                    ), "Fantasy with write_cache=True not supported"
                    assert (
                        opt_cfg.read_cache == False
                    ), "Fantasy with read_cache=False not supported"
                    assert (
                        pred_cfg.read_cache == False
                    ), "Fantasy with read_cache=False not supported"

                    action_res = select_next_query(
                        model=model,
                        x_mask=states.x_mask,
                        y_mask=states.get_current_y_mask_observed(),
                        y_mask_tar=states.y_mask_target,
                        x_ctx=x_ctx_step,
                        y_ctx=y_ctx_scaled_step,
                        input_bounds=model_x_range,
                        t=states.get_cost_used(),
                        T=T,
                        use_grid_sampling=opt_cfg.use_grid_sampling,
                        use_time_budget=opt_cfg.use_time_budget,
                        use_factorized_policy=opt_cfg.use_factorized_policy,
                        use_fixed_query_set=opt_cfg.use_fixed_query_set,
                        epsilon=opt_cfg.epsilon,
                        read_cache=opt_cfg.read_cache,
                        write_cache=opt_cfg.write_cache,
                        d=d,
                        q_chunk=q_chunk,
                        q_chunk_mask=q_chunk_mask,
                        logit_mask=logit_mask,
                        evaluate=True,
                        auto_clear_cache=True,
                    )

                    x_next = action_res[0]
                    acq_values = action_res[4]  # [B, n, d]
                    entropy = action_res[3]
                    q_chunk = action_res[5]
                    q_chunk_mask = action_res[6]
                    infer_time = action_res[7]
                    logit_mask = action_res[8]

                    x_ctx_step = torch.cat([x_ctx_step, x_next], dim=1)

                    # Augment context with fantasized outcomes
                    out = model.predict(
                        x_ctx=x_ctx_step[:, :-1],  # Previous context
                        y_ctx=y_ctx_scaled_step,
                        x_tar=x_ctx_step,
                        x_dim_mask=x_mask_exp,
                        y_dim_mask=y_mask_exp,
                        y_dim_mask_tar=y_mask_tar_exp,
                        read_cache=False,
                        write_cache=False,
                    )

                    # Take fantasized outcome for the newly selected query point
                    mean, _ = get_prediction_mean_std(out)
                    mean = mean[:, -1:, :]

                    y_ctx_scaled_step = torch.cat([y_ctx_scaled_step, mean], dim=1)

                    def _update_q_context(
                        q_x_next,
                        q_entr_list,
                        q_inference_time,
                        x_next,
                        entropy,
                        inference_time,
                    ):
                        # Update q_x_next
                        if q_x_next is None:
                            q_x_next = x_next
                        else:
                            q_x_next = torch.cat([q_x_next, x_next], dim=1)

                        q_entr_list.append(entropy)
                        q_inference_time += inference_time

                        return q_x_next, q_entr_list, q_inference_time

                    q_x_next, q_entr_list, inference_time = _update_q_context(
                        q_x_next,
                        q_entr_list,
                        inference_time,
                        x_next,
                        entropy,
                        infer_time,
                    )

                    # q_entr_list.append(entropy)
                    # def _update_q_x_next(q_x_next, x_next):
                    #     # [B, q, x_dim]
                    #     if q_x_next is None:
                    #         return x_next
                    #     else:
                    #         return torch.cat([q_x_next, x_next], dim=1)
                    # q_x_next = _update_q_x_next(q_x_next, x_next)

                    # update mask only after last query
                    states.step(update_mask=(qi == q - 1))

                else:
                    action_res = select_next_query(
                        model=model,
                        x_mask=states.x_mask,
                        y_mask=states.get_current_y_mask_observed(),
                        y_mask_tar=states.y_mask_target,
                        x_ctx=x_ctx,
                        y_ctx=y_ctx_scaled,
                        input_bounds=model_x_range,
                        t=states.get_cost_used(),
                        T=T,
                        use_grid_sampling=opt_cfg.use_grid_sampling,
                        use_time_budget=opt_cfg.use_time_budget,
                        use_factorized_policy=opt_cfg.use_factorized_policy,
                        use_fixed_query_set=opt_cfg.use_fixed_query_set,
                        epsilon=opt_cfg.epsilon,
                        read_cache=opt_cfg.read_cache,
                        write_cache=opt_cfg.write_cache,
                        d=d,
                        q_chunk=q_chunk,
                        q_chunk_mask=q_chunk_mask,
                        logit_mask=logit_mask,
                        evaluate=True,
                        auto_clear_cache=True,
                    )

                    x_next = action_res[0]
                    acq_values = action_res[4]  # [B, n, d]
                    entropy = action_res[3]
                    q_chunk = action_res[5]
                    q_chunk_mask = action_res[6]
                    inference_time += action_res[7]
                    logit_mask = action_res[8]

                    q_entr_list.append(entropy)

                    def _update_q_x_next(q_x_next, x_next):
                        # [B, q, x_dim]
                        if q_x_next is None:
                            return x_next
                        else:
                            return torch.cat([q_x_next, x_next], dim=1)

                    q_x_next = _update_q_x_next(q_x_next, x_next)

                    # update mask only after last query
                    states.step(update_mask=(qi == q - 1))

            x_ctx, y_ctx, hv, regret = test_function.step(
                input_bounds=model_x_range,
                x_new=q_x_next,
                x_ctx=x_ctx,
                y_ctx=y_ctx,
                compute_hv=True,
                compute_regret=True,
                regret_type=opt_cfg.regret_type,
            )

            _log_opt_step(
                step=states.get_cost_used(),
                hv=hv,
                regret=regret,
                entropy=entropy,
                x_ctx=x_ctx,
                y_ctx=y_ctx,
                log=log,
            )

            hv = expand_metrics(hv, q_x_next)  # [B, q]
            regret = expand_metrics(regret, q_x_next)  # [B, q]
            entropy = torch.stack(q_entr_list, dim=1)  # [B, q]
            # entropy = expand_metrics(entropy, q_x_next)  # [B, q]
        
            results["hpvs_list"].append(hv)
            results["regr_list"].append(regret)
            results["entr_list"].append(entropy)
            results["time_list"].append(inference_time)

            instant_hv = test_function.compute_hv(
                solutions=y_ctx[:, -q:],  # [B, q, dy]
            )
            instant_hv = expand_metrics(instant_hv, q_x_next)
            results["instant_hpvs_list"].append(instant_hv)

            # Optional plot at step t
            if _should_plot(
                states.get_cost_used(), T, plot_per_n_steps, plot_enabled, init_cost
            ):
                # Plot optimization at step t
                fig = plot_optimization_batch(
                    test_function=test_function,
                    input_range_list=model_x_range,
                    x_query=x_ctx,
                    y_query=y_ctx,
                )

                filename = (
                    f"opt_plot_cost_{states.get_cost_used()}_T{T}_nc{x_ctx.shape[1]}"
                )
                save_fig(
                    fig,
                    path=plot_save_path,
                    config=opt_cfg,
                    filename=filename,
                    override=exp_cfg.override,
                    log=log,
                    log_to_wandb=exp_cfg.log_to_wandb,
                )

                # Plot acq_values at final step
                if acq_values is not None:
                    acq_fig = plot_acq_values(q_chunk=q_chunk, acq_values=acq_values)
                    filename = f"acq_heatmap_cost{states.get_cost_used()}"
                    save_fig(
                        acq_fig,
                        plot_save_path,
                        config=opt_cfg,
                        filename=filename,
                        override=exp_cfg.override,
                        log=log,
                        log_to_wandb=exp_cfg.log_to_wandb,
                    )

    log(f"--- Regret over time ---\n{regret}")

    # Summarize results
    hpvs_stack = torch.cat(results["hpvs_list"], dim=-1)  # [batch_size, T+1]
    regr_stack = torch.cat(results["regr_list"], dim=-1)  # [batch_size, T+1]
    entr_stack = torch.cat(results["entr_list"], dim=-1)  # [batch_size, T+1]
    time_stack = torch.tensor(results["time_list"])  # [T+1]

    instant_hpvs_stack = torch.cat(results["instant_hpvs_list"], dim=-1)  # [batch_size, T+1]

    result_dict = {
        "hv": hpvs_stack.detach().cpu(),
        "instant_hv": instant_hpvs_stack.detach().cpu(),
        "regret": regr_stack.detach().cpu(),
        "entropy": entr_stack.detach().cpu(),
        "time": time_stack.detach().cpu(),
        "x_ctx": x_ctx.detach().cpu(),
        "y_ctx": y_ctx.detach().cpu(),
    }

    if predict:
        nllc_stack = torch.stack(results["nll_c_list"], dim=-1)  # [T+1,]
        nllt_stack = torch.stack(results["nll_t_list"], dim=-1)
        msec_stack = torch.stack(results["mse_c_list"], dim=-1)  # [DY, T+1]
        mset_stack = torch.stack(results["mse_t_list"], dim=-1)
        result_dict.update(
            {
                "nll_c": nllc_stack.detach().cpu(),
                "nll_t": nllt_stack.detach().cpu(),
                "mse_c": msec_stack.detach().cpu(),
                "mse_t": mset_stack.detach().cpu(),
            }
        )

    # Save evaluation setup and results to data_save_path/res_filename.pt
    for key, val in result_dict.items():
        save_data(
            data=val.detach().cpu(),
            path=data_save_path,
            config=opt_cfg,
            filename=key,
            override=exp_cfg.override,
            log=log,
        )

    if plot_enabled:
        for b in range(opt_cfg.batch_size):
            batch_prefix = f"b{b}_" if opt_cfg.batch_size > 1 else ""

            plots = {
                f"{batch_prefix}hv": plot_1d(
                    y_vals=hpvs_stack[b],
                    title="Hypervolume over Iterations",
                    ylabel="Hypervolume",
                ),
                f"{batch_prefix}regret": plot_1d(
                    y_vals=regr_stack[b],
                    title="Regrets over Iterations",
                    ylabel="Regret",
                ),
                f"{batch_prefix}entropy": plot_1d(
                    y_vals=entr_stack[b],
                    title="Entropy over Iterations",
                    ylabel="Entropy",
                ),
                f"{batch_prefix}time": plot_1d(
                    y_vals=time_stack,
                    title="Inference Time over Iterations",
                    ylabel="Time (s)",
                ),
            }
            # Optionally plot pareto front
            front_fig = plot_fronts(
                function_name=data_cfg.function_name,
                dim=test_function.x_dim,
                solutions=y_ctx[b],  # [b]
            )
            if front_fig is not None:
                plots.update({f"{batch_prefix}pareto_front": front_fig})

        if predict:
            plots.update(
                {
                    f"nll_c": plot_1d(
                        y_vals=nllc_stack,
                        title="NLL Context",
                        ylabel="NLL Context",
                    ),
                    f"nll_t": plot_1d(
                        y_vals=nllt_stack,
                        title="NLL Target",
                        ylabel="NLL Target",
                    ),
                }
            )
            del nllc_stack, nllt_stack, msec_stack, mset_stack

        for name, fig in plots.items():
            save_fig(
                fig=fig,
                path=plot_save_path,
                config=opt_cfg,
                filename=name,
                override=exp_cfg.override,
                log=log,
                log_to_wandb=exp_cfg.log_to_wandb,
            )

        del hpvs_stack, regr_stack, entr_stack, time_stack

    del x_ctx, y_ctx
    del q_chunk, q_chunk_mask, logit_mask
    return results


def evaluate_prediction(
    model: TAMO,
    datapaths: List[str],
    data_save_path: str,
    plot_save_path: str,
    exp_cfg: ExperimentConfig,
    pred_cfg: PredictionConfig,
    data_cfg: DataConfig,
    num_workers: int = 0,
    prefetch_factor: Optional[int] = None,
    log: callable = print,
    plot_enabled: bool = False,
    plot_nc_list: Optional[List[int]] = None,
    **kwargs,
):
    """Evaluate TAMO model predictions on every dataset found in ``datapaths``.

    Seeds all RNGs with ``exp_cfg.seed``, loads the datasets (standardized,
    zero-mean, scaled to the training y-range) padded to the model's maximum
    x/y dimensionality, and runs ``run_prediction`` on each of them. Plots and
    data are written into a ``<seed>`` sub-directory of the given save paths.

    Args:
        model: Trained TAMO model; its ``max_x_dim``/``max_y_dim`` control
            dataset padding.
        datapaths: Paths of the datasets to evaluate on.
        data_save_path: Root directory for saved data (seed sub-folder is added).
        plot_save_path: Root directory for saved plots (seed sub-folder is added).
        exp_cfg: Experiment configuration (seed, device, logging flags, ...).
        pred_cfg: Prediction configuration forwarded to ``run_prediction``.
        data_cfg: Data configuration forwarded to ``run_prediction``.
        num_workers: Dataloader worker count.
        prefetch_factor: Dataloader prefetch factor.
        log: Logging callable.
        plot_enabled: Whether ``run_prediction`` should produce plots.
        plot_nc_list: Context sizes to plot; forwarded to ``run_prediction``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    set_all_seeds(exp_cfg.seed)

    seed_subdir = str(exp_cfg.seed)
    seeded_plot_path = osp.join(plot_save_path, seed_subdir)
    seeded_data_path = osp.join(data_save_path, seed_subdir)

    dataset_list = get_datasets(
        datapaths=datapaths,
        max_x_dim=model.max_x_dim,
        max_y_dim=model.max_y_dim,
        standardize=True,
        zero_mean=True,
        range_scale=get_train_y_range(),
    )

    # Arguments identical for every dataset; only `dataset` varies per call.
    shared_args = dict(
        model=model,
        plot_save_path=seeded_plot_path,
        data_save_path=seeded_data_path,
        exp_cfg=exp_cfg,
        pred_cfg=pred_cfg,
        data_cfg=data_cfg,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
        log=log,
        plot_enabled=plot_enabled,
        plot_nc_list=plot_nc_list,
    )

    # The progress bar is deliberately not bound to a name: it references the
    # dataset list, and the list must be freeable by the `del` below.
    for ds in tqdm(dataset_list, desc="Running prediction on datasets", unit="dataset"):
        log(f"Evaluating prediction on data from:\n{ds.hdf5_path}\n\n")
        run_prediction(dataset=ds, **shared_args)

    # Release dataset memory before clearing the CUDA cache.
    del dataset_list
    gc.collect()
    torch.cuda.empty_cache()


def run_prediction(
    model: TAMO,
    dataset: HDF5Dataset,
    plot_save_path: str,
    data_save_path: str,
    exp_cfg: ExperimentConfig,
    pred_cfg: PredictionConfig,
    data_cfg: DataConfig,
    num_workers: int = 0,
    prefetch_factor: Optional[int] = None,
    log: callable = print,
    plot_enabled: bool = False,
    plot_nc_list: Optional[List[int]] = None,
) -> Averager:
    """Evaluate TAMO predictions on one dataset and return the running averages.

    Batches whose ``x`` or ``y`` contain NaN/Inf are skipped. Each remaining
    batch is split by ``prepare_prediction_dataset`` into ``nc`` context points
    and the remaining target points. The model is then queried twice via
    ``predict_with_metrics``:

    * "context": the context points are used as targets too;
    * "target": the held-out points after ``nc`` are predicted.

    Per batch, the NLL, per-output MSE/RMSE and ``kt_tau`` values (presumably
    Kendall's tau) are fed into an ``Averager``. At the end, a summary is
    logged and, if enabled, the mean NLLs are sent to wandb.

    Args:
        model: TAMO model; moved to ``exp_cfg.device`` and put in eval mode.
        dataset: Dataset to evaluate on.
        plot_save_path: Directory where prediction plots are saved.
        data_save_path: Not used by this function body; kept for interface
            compatibility.
        exp_cfg: Experiment configuration (device, split mode, seed, logging).
        pred_cfg: Prediction configuration (batch size, context sizes, caching).
        data_cfg: Data configuration (``dim_scatter_mode``).
        num_workers: Dataloader worker count.
        prefetch_factor: Dataloader prefetch factor.
        log: Logging callable.
        plot_enabled: If True, plot predictions for the first evaluated batch.
        plot_nc_list: Context sizes to plot; defaults to that batch's ``nc``.

    Returns:
        The ``Averager`` holding the accumulated metrics.
    """
    dataloader = build_dataloader(
        dataset=dataset,
        batch_size=pred_cfg.batch_size,
        split=exp_cfg.mode,
        device=exp_cfg.device,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
    )

    ravg = Averager()

    model = model.to(exp_cfg.device)
    model.eval()

    # Plot once, on the first batch that is actually evaluated. Keying on the
    # batch index would silently skip plotting whenever batch 0 holds NaN/Inf.
    plotted = False

    with torch.no_grad():
        for x, y, valid_x_counts, valid_y_counts in dataloader:
            if has_nan_or_inf(x, "x") or has_nan_or_inf(y, "y"):
                continue

            x = x.to(exp_cfg.device)
            y = y.to(exp_cfg.device)
            valid_x_counts = valid_x_counts.to(exp_cfg.device)
            valid_y_counts = valid_y_counts.to(exp_cfg.device)

            x, y, x_mask, y_mask, nc = prepare_prediction_dataset(
                x=x,
                y=y,
                valid_x_counts=valid_x_counts,
                valid_y_counts=valid_y_counts,
                dim_scatter_mode=data_cfg.dim_scatter_mode,
                min_nc=pred_cfg.min_nc,
                max_nc=pred_cfg.max_nc,
                nc_fixed=pred_cfg.nc,
            )

            # Predict on context
            nll_c, mse_c, kt_tau_c, _ = predict_with_metrics(
                model=model,
                x_ctx=x[:, :nc],
                y_ctx=y[:, :nc],
                x_tar=x[:, :nc],
                y_tar=y[:, :nc],
                x_mask=x_mask,
                y_mask=y_mask,
                compute_nll=True,
                compute_mse=True,
                compute_ktt=True,
                reduce_nll=True,
                reduce_mse=True,
                read_cache=pred_cfg.read_cache,
                write_cache=pred_cfg.write_cache,
            )

            # Predict on target
            nll_t, mse_t, kt_tau_t, _ = predict_with_metrics(
                model=model,
                x_ctx=x[:, :nc],
                y_ctx=y[:, :nc],
                x_tar=x[:, nc:],
                y_tar=y[:, nc:],
                x_mask=x_mask,
                y_mask=y_mask,
                compute_nll=True,
                compute_mse=True,
                compute_ktt=True,
                reduce_nll=True,
                reduce_mse=True,
                read_cache=pred_cfg.read_cache,
                write_cache=pred_cfg.write_cache,
            )

            log_dict = {
                "nll_context": nll_c.detach().item(),
                "nll_target": nll_t.detach().item(),
            }

            # One entry per output dimension j.
            for j, (mse_c_val, mse_t_val, kt_c_val, kt_t_val) in enumerate(
                zip(mse_c, mse_t, kt_tau_c, kt_tau_t)
            ):
                log_dict[f"mse_context_{j}"] = mse_c_val
                log_dict[f"mse_target_{j}"] = mse_t_val

                log_dict[f"rmse_context_{j}"] = math.sqrt(mse_c_val)
                log_dict[f"rmse_target_{j}"] = math.sqrt(mse_t_val)

                log_dict[f"kt_tau_context_{j}"] = kt_c_val
                log_dict[f"kt_tau_target_{j}"] = kt_t_val

            ravg.batch_update(log_dict)

            if plot_enabled and not plotted:
                plotted = True
                for pnc in plot_nc_list or [nc]:
                    fig = plot_prediction_batch(
                        model=model,
                        nc=pnc,
                        x=x,
                        y=y,
                        x_mask=x_mask,
                        y_mask=y_mask,
                        read_cache=pred_cfg.read_cache,
                        write_cache=pred_cfg.write_cache,
                    )
                    save_fig(
                        fig=fig,
                        path=plot_save_path,
                        config=pred_cfg,
                        filename=f"nc{pnc}",
                        override=exp_cfg.override,
                        log=log,
                        log_to_wandb=exp_cfg.log_to_wandb,
                    )

                gc.collect()
                torch.cuda.empty_cache()

    # Log / print results
    log(f"[results, seed={exp_cfg.seed}]\n{ravg.info()}")

    if exp_cfg.log_to_wandb:
        wandb.log(
            {
                "eval/nll_context": ravg.get("nll_context"),
                "eval/nll_target": ravg.get("nll_target"),
            }
        )

    # Drop references to the last batch's tensors before asking the CUDA
    # caching allocator to release memory. Rebinding is used instead of `del`
    # because these names are unbound when no batch was evaluated (empty
    # dataloader or every batch skipped), where `del` would raise.
    x = y = x_mask = y_mask = None
    valid_x_counts = valid_y_counts = None
    nll_c = nll_t = mse_c = mse_t = None
    del dataloader
    gc.collect()
    torch.cuda.empty_cache()

    return ravg
