#!/usr/bin/env python3

from __future__ import annotations

import torch
from tabulate import tabulate
import re
from rich import print

from rescue.models.causal_model.map_to_NN import CausalMeanVarSurrogateNN

def load_causal_model(
    input_dim: int,
    output_dim: int,
    device: torch.device,
    dtype: torch.dtype,
    state_dict: dict
):
    r"""
    Build a ``CausalMeanVarSurrogateNN`` and populate it from saved weights.

    The network is constructed with the given dimensions, moved to the
    requested device and dtype, and only then loaded from ``state_dict``.

    Args:
        input_dim (int): Passed to the network constructor as ``input_dim``.
        output_dim (int): Passed to the network constructor as ``output_dim``.
        device (torch.device): Device the network is moved to.
        dtype (torch.dtype): Data type the network is converted to.
        state_dict (dict): Weights handed to ``load_state_dict``.

    Returns:
        CausalMeanVarSurrogateNN: The network with the loaded weights.
    """
    model = CausalMeanVarSurrogateNN(input_dim=input_dim, output_dim=output_dim)
    # Place the network on the target device/dtype before copying weights in.
    model = model.to(device=device, dtype=dtype)
    model.load_state_dict(state_dict)
    return model

def term_print(
    show_stats: bool,
    exp_stats: dict[str, float],
    budget: int | float,
    has_constraints: bool,
) -> None:
    r"""
    Print a one-line summary of selected experiment statistics.

    Only a fixed set of keys is ever displayed, always in the same order;
    keys missing from ``exp_stats`` are skipped. The selection is then
    narrowed:

    * if ``curr_nsga2_regret`` is present, every key containing ``hv`` and
      the ``causal_net_loss`` key are dropped;
    * if ``has_constraints`` is false, every key containing ``violation``
      is dropped;
    * ``causal_net_loss`` is also hidden whenever any remaining key
      contains ``hv`` or ``regret``.

    Args:
        show_stats (bool): When false, the function returns without printing.
        exp_stats (dict[str, float]): Statistics to pick the displayed
            values from.
        budget (int | float): Shown next to ``cost`` as ``Budget: cost/budget``.
        has_constraints (bool): Whether violation statistics are displayed.
    """
    if not show_stats:
        return

    # Display order of the statistics; anything not listed is never shown.
    display_order = (
        "cost",
        "best_nsga2_regret",
        "curr_nsga2_regret",
        "best_nsga2_hv",
        "curr_nsga2_hv",
        "curr_nsga2_violation",
        "observed_hv",
        "observed_regret",
        "observed_violation",
        "new_fidelity",
        "causal_net_loss",
        "acqu_value",
        "iteration",
    )
    selected = {
        name: exp_stats[name] for name in display_order if name in exp_stats
    }

    # Once the current regret is available, hypervolume and surrogate-loss
    # entries are left out of the summary.
    if "curr_nsga2_regret" in selected:
        selected = {
            name: value
            for name, value in selected.items()
            if not ("hv" in name or "causal_net_loss" in name)
        }
    if not has_constraints:
        selected = {
            name: value
            for name, value in selected.items()
            if "violation" not in name
        }

    # The loss column is only shown when no hv/regret column remains.
    hide_loss = any("hv" in name or "regret" in name for name in selected)

    # (pattern, replacement) pairs used to shorten the printed labels.
    abbreviations = (
        (r"(?i)acqu_value", "acqu_val"),
        (r"(?i)new_fidelity", "new_fid"),
        (r"(?i)regret", "rgrt"),
    )

    cells = []
    for name, value in selected.items():
        if hide_loss and name == "causal_net_loss":
            continue
        if name == "cost":
            cells.append(f"Budget: {value:.1f}/{budget}")
            continue
        if name == "iteration":
            cells.append(f"Iter: {value}")
            continue

        heading = name
        if heading.startswith("observed_"):
            heading = "obs_" + heading.partition("_")[2]
        for pattern, short in abbreviations:
            heading = re.sub(pattern, short, heading)

        if isinstance(value, float):
            cells.append(f"{heading}: {value:.3f}")
        else:
            cells.append(f"{heading}: {value}")

    print(tabulate([cells], tablefmt="plain"))