# =============================================================================
# Dataset
# =============================================================================

from typing import Any
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, field

import torch
from torch import Tensor



# -----------------------------------------------------------------------------
# BiLevelObservedData
# -----------------------------------------------------------------------------

@dataclass
class BiLevelObservedData:
    """One observed data point in a bilevel optimization problem."""

    x: Tensor  # shape: [num_dims]
    y_upper: Tensor  # shape: [num_objectives[0]]
    y_lower: Tensor  # shape: [num_objectives[1]]
    c_upper: Tensor | None = None  # shape: [num_constraints[0]]
    c_lower: Tensor | None = None  # shape: [num_constraints[1]]
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:

        self.x = torch.atleast_1d(self.x.squeeze())
        self.y_upper = torch.atleast_1d(self.y_upper.squeeze())
        self.y_lower = torch.atleast_1d(self.y_lower.squeeze())
        if self.c_upper is not None:
            self.c_upper = torch.atleast_1d(self.c_upper.squeeze())
        if self.c_lower is not None:
            self.c_lower = torch.atleast_1d(self.c_lower.squeeze())



# -----------------------------------------------------------------------------
# BiLevelBODataset
# -----------------------------------------------------------------------------

class BiLevelBODataset:
    """Dataset for bilevel bayesian optimization tasks."""

    def __init__(
        self,
        num_dims: list[int],
        num_objectives: list[int],
        num_constraints: list[int],
    ) -> None:

        self.num_dims = num_dims
        self.num_objectives = num_objectives
        self.num_constraints = num_constraints

        self.d_in = sum(num_dims)
        self.d_out = sum(num_objectives) + sum(num_constraints)
        self.data: list[BiLevelObservedData] = []


    def add(
        self,
        X: Tensor,  # shape: [n, d_in]
        outputs: Tensor,  # shape: [n, d_out]
        metadata: dict[str, Any] | list[dict[str, Any]] = {},
    ) -> None:
        """Add a vector of output values to the dataset."""

        X = X.squeeze().view(-1, self.d_in)
        outputs = outputs.squeeze().view(-1, self.d_out)
        split_size = self.num_objectives + self.num_constraints
        if not isinstance(metadata, list):
            metadata = [metadata] * X.shape[0]
        
        for x, out, meta in zip(X, outputs, metadata):
            y_upper, y_lower, c_upper, c_lower = out.split(split_size)
            obs = BiLevelObservedData(
                x=x,
                y_upper=y_upper,
                y_lower=y_lower,
                c_upper=(c_upper if len(c_upper) > 0 else None),
                c_lower=(c_lower if len(c_lower) > 0 else None),
                metadata=meta,
            )
            self.data.append(obs)


    def get(
        self,
    ) -> list[tuple[Tensor, Tensor]]:
        """Retrieve model inputs and outputs from the dataset."""

        train_pairs = []
        data_X = torch.stack([obs.x for obs in self.data])
        data_Y_upper = torch.stack([obs.y_upper for obs in self.data])
        data_Y_lower = torch.stack([obs.y_lower for obs in self.data])
        for i, data_Y in enumerate([data_Y_upper, data_Y_lower]):
            for j in range(self.num_objectives[i]):
                mask = ~data_Y[:, j].isnan()
                train_X = data_X[mask, :]
                train_Y = data_Y[mask, j:j+1]
                train_pairs.append((train_X, train_Y))
        if self.num_constraints[0] > 0:
            data_C_upper = torch.stack([obs.c_upper for obs in self.data])
            for i in range(self.num_constraints[0]):
                mask = ~data_C_upper[:, i].isnan()
                train_X = data_X[mask, :]
                train_C_upper = data_C_upper[mask, i:i+1]
                train_pairs.append((train_X, train_C_upper))
        if self.num_constraints[1] > 0:
            data_C_lower = torch.stack([obs.c_lower for obs in self.data])
            for i in range(self.num_constraints[1]):
                mask = ~data_C_lower[:, i].isnan()
                train_X = data_X[mask, :]
                train_C_lower = data_C_lower[mask, i:i+1]
                train_pairs.append((train_X, train_C_lower))
        return train_pairs


    def save(
        self,
        path: Path,
    ) -> None:
        """Save the dataset to a file."""

        data = {
            "num_dims": self.num_dims,
            "num_objectives": self.num_objectives,
            "num_constraints": self.num_constraints,
            "X": torch.stack([obs.x for obs in self.data]),
            "Y_upper": torch.stack([obs.y_upper for obs in self.data]),
            "Y_lower": torch.stack([obs.y_lower for obs in self.data]),
            "timestamp": [obs.timestamp for obs in self.data],
            "metadata": [obs.metadata for obs in self.data],
        }
        if self.num_constraints[0] > 0:
            data["C_upper"] = torch.stack([obs.c_upper for obs in self.data])
        if self.num_constraints[1] > 0:
            data["C_lower"] = torch.stack([obs.c_lower for obs in self.data])
        torch.save(data, path)


    @classmethod
    def load(
        cls,
        path: Path,
    ) -> "BiLevelBODataset":
        """Load a `BiLevelBODataset` object from a file."""

        data = torch.load(path, weights_only=False)
        dataset = cls(
            num_dims=data["num_dims"],
            num_objectives=data["num_objectives"],
            num_constraints=data["num_constraints"],
        )
        for i in range(len(data["X"])):
            obs = BiLevelObservedData(
                x=data["X"][i],
                y_upper=data["Y_upper"][i],
                y_lower=data["Y_lower"][i],
                c_upper=(data["C_upper"][i] if "C_upper" in data else None),
                c_lower=(data["C_lower"][i] if "C_lower" in data else None),
                timestamp=data["timestamp"][i],
                metadata=data["metadata"][i],
            )
            dataset.data.append(obs)
        return dataset

