#!/usr/bin/env python3

from __future__ import annotations

from typing import TypedDict

from torch import Tensor
import pandas as pd
from dowhy import gcm
import networkx as nx

from rescue.models.causal_gp.multitask import CausalMultitaskGP
from rescue.models.causal_gp.multitask_multifidelity import (
    CausalMultitaskMultifidelityGP
)
from rescue.models.causal_model.map_to_NN import CausalMeanVarSurrogateNN


class RescueState(TypedDict):
    """
    Typed view of the state dictionary used by the Rescue algorithm.

    NOTE: the state is populated dynamically, so a given state object is
    not guaranteed to contain every key declared below. The class is
    nevertheless declared with the default totality, which means static
    type checkers treat every key as required.

    The grouping comments in the body are inferred from the field names
    and annotations only -- confirm them against the code that fills the
    state before relying on them.
    """

    # Run mode and cost bookkeeping.
    is_multifidelity: bool
    budget: float
    initial_cost: float
    current_cost: float

    # Variable names, grouped by role; the KPI and constraint lists may
    # be None.
    design_variables: list[str]
    objective_variables: list[str]
    kpi_variables: list[str] | None
    constraint_variables: list[str] | None

    # Tensor / tabular data; constraints and observational data may be
    # None.
    train_x: Tensor
    train_obj: Tensor
    train_constraints: Tensor | None
    x_intervention: Tensor
    observational_data: pd.DataFrame | None

    # Causal structure and the models attached to the state.
    causal_graph: nx.DiGraph
    scm: gcm.StructuralCausalModel
    causal_net: CausalMeanVarSurrogateNN
    model: CausalMultitaskGP | CausalMultitaskMultifidelityGP

    # Optimisation setup (presumably search bounds and hypervolume
    # reference -- TODO confirm).
    bounds: Tensor
    ref_point: Tensor
    max_hv: float | None
    objective_indices: list
    constraints_indices: list | None

    # Per-run bookkeeping and tracked metrics.
    get_seed: int
    iteration: int
    met_best_nsga2_regret: float | None
    met_best_nsga2_hv: float | None
    causal_net_loss: float
    acquisition_value: Tensor | None
    new_fidelity: float | None


class RescueStateMultifidelity(RescueState, TypedDict):
    """
    ``RescueState`` extended with the additional fidelity-related keys.

    Every key of ``RescueState`` is inherited unchanged; the keys below
    are declared on top of them. Four of the six may hold ``None``.
    """

    is_discrete_fidelities: bool
    fidelity_param_name: str
    # Each entry maps an int to a float (presumably a feature index to
    # its fixed value -- TODO confirm against the caller).
    fixed_features_list: list[dict[int, float]] | None
    target_fid_x: Tensor | None
    target_fid_obj: Tensor | None
    target_fid_constraints: Tensor | None