#! /usr/bin/env python3

from __future__ import annotations

import torch
from torch import Tensor
import pandas as pd
from botorch.test_functions.base import (
    MultiObjectiveTestProblem,
    ConstrainedBaseTestProblem
)

def _load_tensors(
    f: str,
) -> dict[int, tuple[Tensor, Tensor, Tensor | None]]:
    """Read a serialized, seed-keyed tensor dictionary from disk onto the CPU.

    Args:
        f: Path to a file previously written with ``torch.save``.

    Returns:
        The object deserialized from ``f`` with all tensors mapped to the CPU.
        Callers in this module treat it as a dict keyed by sample seed; the
        annotated value shape is not validated here.
    """
    cpu = torch.device("cpu")
    # weights_only=False permits arbitrary pickled objects (and therefore
    # arbitrary code execution while loading); only use it on trusted files.
    loaded = torch.load(f, map_location=cpu, weights_only=False)
    return loaded

def load_observational_data(
    sample_seed: int,
    problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
) -> pd.DataFrame:
    r"""
    Load observational data for a given problem and sample seed.

    The data file is resolved from the problem's class name as
    ``<directory of this module>/observational_data/<ClassName>_obs_data_mf.pt``.
    It is loaded as a dict keyed by sample seed. The non-``None`` tensors stored
    for ``sample_seed`` are concatenated along their last dimension. The
    resulting number of columns must equal the number of variable names
    defined on ``problem``.

    Args:
        sample_seed (int): The seed corresponding to the desired data sample.
        problem (MultiObjectiveTestProblem | ConstrainedBaseTestProblem):
            The problem instance for which to load observational data. It must
            define ``design_var_names`` and ``objective_var_names``. A
            ``constraint_var_names`` attribute is appended when present and
            non-empty.

    Returns:
        pd.DataFrame: One column per design, objective, and (optional)
            constraint variable, in that order.

    Raises:
        ValueError: If ``problem`` lacks the required name attributes, if
            ``sample_seed`` is not present in the data file, if the entry holds
            no tensors, or if the data's shape does not match the variable
            names.
    """
    import os

    # Resolve the column names before touching the (potentially large) data
    # file, so a misconfigured problem fails fast. Unpacking accepts any
    # iterable of names (lists or tuples), unlike ``list + list``.
    try:
        columns = [
            *problem.design_var_names,
            *problem.objective_var_names,
            *(getattr(problem, "constraint_var_names", None) or []),
        ]
    except (AttributeError, TypeError) as e:
        raise ValueError(
            "Problem must have design_var_names, objective_var_names, and optionally constraint_var_names attributes."
        ) from e

    # The data file lives next to this module and is named after the class.
    problem_name = type(problem).__name__
    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(
        current_dir, "observational_data", f"{problem_name}_obs_data_mf.pt"
    )

    data_tensors_dict = _load_tensors(data_path)
    if sample_seed not in data_tensors_dict:
        valid_seeds = list(data_tensors_dict.keys())
        raise ValueError(
            f"No data found for sample_seed {sample_seed} in {data_path}. Valid seeds are: {valid_seeds}"
        )

    # Detach first: ``Tensor.numpy()`` refuses tensors that require grad.
    present = [
        t.detach().cpu() for t in data_tensors_dict[sample_seed] if t is not None
    ]
    if not present:
        raise ValueError(
            f"Entry for sample_seed {sample_seed} in {data_path} contains no tensors."
        )
    values = torch.cat(present, dim=-1).numpy()

    # Report a shape mismatch in the problem's own terms rather than
    # surfacing pandas' generic "Shape of passed values" error.
    if values.ndim != 2 or values.shape[1] != len(columns):
        raise ValueError(
            f"Data for sample_seed {sample_seed} in {data_path} has shape "
            f"{tuple(values.shape)}, but problem defines {len(columns)} "
            f"variable names: {columns}"
        )

    return pd.DataFrame(values, columns=columns)