import json
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Literal

@dataclass(init=True, repr=True, eq=True)
class ExperimentInputs:
    """Data locations and seed for a single experiment run.

    All four values are required; ``get_exp_configs`` reads them from the
    top level of the JSON config file and raises ``ValueError`` if any of
    them is absent.
    """

    # Location of the training data (kept as a plain string, not a Path).
    train_data_path: str
    # Location of the validation data.
    val_data_path: str
    # Location of the test data.
    test_data_path: str
    # Seed value for the run; presumably used to seed RNGs — confirm at call site.
    seed: int

@dataclass
class TrainConfig:
    """All training hyper-parameters for an experiment, grouped by concern.

    Every field is required (no defaults). ``Literal`` annotations document
    the expected values for ``device`` and ``precision`` but are not
    enforced at runtime by ``@dataclass``.
    """

    #------------------------
    # Experiment settings
    device: Literal["cpu", "gpu"]
    precision: Literal["float32", "float64"]
    use_wandb: bool
    epochs: int
    batch_size: int
    anchor_after: int
    noise_level: float
    crop_window_size: int
    #------------------------
    # Model settings
    emulator_lr: float
    emulator_config: dict[str, Any]
    emulator_optimizer_type: str
    #------------------------
    # Summary settings
    summary_config: dict[str, Any]
    summary_optimizer_type: str
    summary_lr: float
    #------------------------
    # Distance settings
    distance_config: dict[str, Any]
    critic_optimizer_type: str
    critic_lr: float
    #------------------------
    # Saddle training settings
    ot_rollout: str
    ot_warm_up: int
    ot_horizon: int
    adversarial_steps: int
    lambda_ot: float
    #------------------------

    def __post_init__(self) -> None:
        """Post-construction hook invoked by the dataclass-generated ``__init__``.

        Currently performs no validation or normalization.
        """
        # The hook was previously named ``post_init`` (no dunders), which the
        # dataclass machinery never calls; only ``__post_init__`` is invoked.
        pass

    # Backward-compatible alias so any explicit ``cfg.post_init()`` calls keep working.
    post_init = __post_init__

def _filter_to_dataclass_keys(raw: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict holding only the entries of *raw* named after ``TrainConfig`` fields.

    Unrecognized keys are dropped without error; the retained entries keep
    their original insertion order. *raw* itself is not modified.
    """
    field_names = frozenset(spec.name for spec in fields(TrainConfig))
    kept: dict[str, Any] = {}
    for name, value in raw.items():
        if name not in field_names:
            continue
        kept[name] = value
    return kept

def get_exp_configs(config_paths: list[Path]) -> tuple[TrainConfig, ExperimentInputs]:
    """Load ``TrainConfig`` and ``ExperimentInputs`` from a JSON config file.

    Only ``config_paths[0]`` is read. The file must contain a JSON object.
    The ``ExperimentInputs`` fields are taken (and removed) from its top
    level. An optional nested ``"train_config"`` object is then merged over
    the remaining top-level keys, with nested values taking precedence. The
    result is filtered down to ``TrainConfig`` field names before
    construction, so unknown keys are ignored.

    Args:
        config_paths: Paths to JSON config files; only the first is used.

    Returns:
        A ``(TrainConfig, ExperimentInputs)`` tuple.

    Raises:
        ValueError: If *config_paths* is empty, the file's top-level JSON
            value is not an object, or a required ``ExperimentInputs`` or
            ``TrainConfig`` key is missing. Malformed JSON also surfaces
            here as ``json.JSONDecodeError`` (a ``ValueError`` subclass).
        OSError: If the config file cannot be opened.
    """
    if not config_paths:
        raise ValueError("get_exp_configs requires at least one config path")
    # NOTE(review): paths after the first are ignored — confirm whether
    # multiple files were meant to be merged.
    config_path = config_paths[0]

    with config_path.open("r", encoding="utf-8") as f:
        raw: Any = json.load(f)

    # dict() on a non-object JSON value either fails obscurely or (for a
    # list of pairs) silently builds a bogus mapping; reject it explicitly.
    if not isinstance(raw, dict):
        raise ValueError(
            f"Config file must contain a JSON object, got {type(raw).__name__}\n"
            f"Config file: {config_path}"
        )

    # Work on a shallow copy so popping keys does not touch `raw`.
    train_cfg: dict[str, Any] = dict(raw)

    # Experiment inputs must live at the top level. Deriving the key list from
    # the dataclass keeps it in sync with ExperimentInputs' definition.
    exp_inputs_cfg: dict[str, Any] = {}
    for key in (spec.name for spec in fields(ExperimentInputs)):
        if key not in train_cfg:
            raise ValueError(f"Experiment inputs config missing required key: {key}\nConfig file: {config_path}")
        exp_inputs_cfg[key] = train_cfg.pop(key)

    # Optional nested block overrides top-level keys. A non-dict value
    # (including null) is dropped without error.
    nested = train_cfg.pop("train_config", None)
    if isinstance(nested, dict):
        train_cfg.update(nested)

    # NOTE(review): unknown keys (including typos, or experiment-input keys
    # placed inside "train_config") are discarded silently here.
    train_cfg = _filter_to_dataclass_keys(train_cfg)

    train_missing = [spec.name for spec in fields(TrainConfig) if spec.name not in train_cfg]
    if train_missing:
        raise ValueError(
            f"Train config missing required keys: {train_missing}\n"
            f"Config file: {config_path}"
        )

    train_config_obj = TrainConfig(**train_cfg)
    exp_inputs_obj = ExperimentInputs(**exp_inputs_cfg)

    return train_config_obj, exp_inputs_obj
