import pprint
from typing import Annotated, Dict, List, Optional, Type, Union, Any
from multinav.data.dataset import DatasetConfig
from multinav.model.model_base import MultiNavModel
from multinav.model.heads.heads import MULTINAV_HEADS, AnyHeadConfig
from flax.struct import dataclass, field
import tyro


def make_config(actor_head: str, heads: List[str]) -> Type[MultiNavModel.Config]:
    """Build a ``MultiNavModel.Config`` subclass preset for one model variant.

    The returned class overrides two defaults of ``MultiNavModel.Config``:
    ``actor_head_name`` defaults to ``actor_head``, and ``head_configs``
    defaults to a mapping from each name in ``heads`` to a default-constructed
    ``MULTINAV_HEADS[name].Config()``.

    Args:
        actor_head: Default value of the ``actor_head_name`` field.
        heads: Names of the heads whose default configs form the default
            ``head_configs`` mapping. Each name is looked up in
            ``MULTINAV_HEADS``. The lookup happens lazily inside the default
            factory, so an unknown name only fails when the returned class is
            instantiated without an explicit ``head_configs``.

    Returns:
        A new flax ``struct.dataclass`` subclass of ``MultiNavModel.Config``.
    """
    # Snapshot the head names. The default factory below runs on every
    # instantiation, so closing over the caller's (mutable) list directly would
    # let later mutations of that list silently change the defaults.
    head_names = tuple(heads)

    @dataclass
    class MyConfig(MultiNavModel.Config):
        actor_head_name: str = actor_head
        head_configs: Dict[str, AnyHeadConfig] = field(
            default_factory=lambda: {h: MULTINAV_HEADS[h].Config() for h in head_names}
        )

    return MyConfig


# Model variants, keyed by the subcommand name under which tyro exposes them.
# Each value is a generated config class. In every entry the actor head is also
# the first name in the head list.
model_configs = dict(
    ar=make_config("ar", ["ar", "td"]),
    bc=make_config("bc", ["bc", "td"]),
    cql=make_config("cql", ["cql", "td", "bc"]),
    discrete_cql=make_config("discrete_cql", ["discrete_cql"]),
)
# Union of all model variants, used as the type of ``TrainingConfig.config``.
# Each member is the common base ``MultiNavModel.Config`` annotated with a tyro
# subcommand. The subcommand's name is the registry key in ``model_configs``
# and its constructor is the generated config class, so the CLI offers one
# subcommand per registered variant. The union is built at runtime from the
# registry, so only tyro (not a static type checker) can resolve it.
ModelConfig = Union[
    tuple(
        Annotated[
            MultiNavModel.Config,
            tyro.conf.subcommand(name=config_name, constructor=config_cls),
        ]
        for config_name, config_cls in model_configs.items()
    )
]


@dataclass
class TrainingConfig:
    """Top-level training configuration, parsed from the command line by tyro.

    The model variant is chosen with a subcommand, one per entry of
    ``model_configs``. The remaining fields are options with defaults.
    """

    # Model configuration; its concrete class is selected by subcommand.
    config: ModelConfig
    # Dataset configuration.
    data_config: DatasetConfig

    # Number of training epochs.
    num_epochs: int = 10
    # Batch size on each device.
    batch_size_per_device: int = 64
    # Indices of the devices to use; the behaviour for None is decided by the training code.
    devices: Optional[List[int]] = None
    # Logging interval (unit defined by the training loop).
    log_interval: int = 10
    # Model saving interval (unit defined by the training loop).
    save_interval: int = 1000
    # Evaluation interval (unit defined by the training loop).
    eval_interval: int = 1000
    # Evaluation ratio, presumably the fraction of data reserved for evaluation.
    eval_ratio: float = 0.1
    # Optional model path (save location or checkpoint to load; see the training code).
    model_path: Optional[str] = None
    # Number of saved models to retain; None presumably means no limit.
    models_to_keep: Optional[int] = 10
    # Optional path to the training data.
    data_path: Optional[str] = None
    # Random seed.
    seed: int = 42


if __name__ == "__main__":
    # Parse the command line into a TrainingConfig and pretty-print it, so that
    # running this module directly shows the fully resolved configuration
    # instead of silently discarding it.
    pprint.pprint(tyro.cli(TrainingConfig))
