from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Union


@dataclass
class ContextualConfig:
    """Flat bag of settings for an experiment run.

    Every field up to and including ``sa_demo`` has a default, so the class
    can be instantiated with no arguments. The trailing fields are declared
    with ``field(init=False)`` and no default: they are not accepted by
    ``__init__`` and do not exist as attributes until some caller assigns
    them. Reading one before that (including via ``repr()``) raises
    ``AttributeError``.

    NOTE(review): the meaning of individual options is not visible in this
    file; the grouping comments below are inferred from names only.
    """

    # --- config file / domain descriptions ---------------------------------
    config: Optional[Path] = None
    # Fresh list per instance (never shared between instances).
    domains: List[Dict] = field(default_factory=list)
    adapt_domains: List[Dict] = field(default_factory=list)

    # --- experiment flavour flags -------------------------------------------
    n_domains: int = 2
    add_domain: bool = False
    complex_task: bool = False
    robot: bool = False

    # --- run identification --------------------------------------------------
    name: Optional[str] = None
    goal: int = 6

    # --- data / model / runtime ----------------------------------------------
    n_tasks: int = -1  # presumably -1 is a "no limit" sentinel; TODO confirm
    n_traj: int = 1000
    train_ratio: float = 0.9
    batch_size: int = 64
    latent_dim: int = 256
    hid_dim: int = 256
    activation: str = "gelu"
    device: str = "cuda:0"
    comet: bool = False

    # --- schedule --------------------------------------------------------------
    epochs: int = 500
    eval_interval: int = 50
    n_eval_episodes: int = 50
    n_render_episodes: int = 10

    # --- optimisation ----------------------------------------------------------
    lr: float = 0.001

    # Not used for training, so a small number is enough.
    test_n_traj: int = 200

    # --- image observations / evaluation ---------------------------------------
    image_observation: bool = False
    image_state_dim: int = 1024
    use_image_decoder: bool = False
    image_recon_coef: float = 1.0
    use_coord_conv: bool = False
    evaluate: bool = True
    # Off by default to avoid a bug in parallelization.
    evaluate_parallel: bool = False
    amp: bool = True

    # --- multi-environment options ---------------------------------------------
    multienv: bool = False
    n_task_ids: Optional[int] = None
    target_goal_id: Optional[int] = None
    task_id_offset_list: Optional[List[int]] = None
    target_task_id_offset: Optional[int] = None

    # --- layer / head counts ---------------------------------------------------
    n_enc_layers: int = 3
    n_dec_layers: int = 3
    n_head: int = 8
    sa_demo: bool = True

    # --- set in the script -----------------------------------------------------
    # Excluded from __init__ and given no default: unset until assigned.
    max_obs_dim: int = field(init=False)
    max_action_dim: int = field(init=False)
    max_seq_len: int = field(init=False)
    train_goal_ids: List[int] = field(init=False)
    logdir: Union[str, Path] = field(init=False)
