r"""
    Runtime configuration: gathers constants in one place so they are easy to override.
    Powered by Hydra. Prepared for NDT3.
"""
from typing import List, Any, Tuple, Dict
from enum import Enum
from pathlib import Path
from dataclasses import dataclass, field
from omegaconf import MISSING
from context_general_bci.config import RLConfig # This is tracked in NDT codebase to keep as much as possible outside of CLIMBER.

class Accelerator(Enum):
    r"""Inference backend used to serve the model (see `OnlineConfig.serve_framework`)."""
    vanilla = "vanilla"  # vanilla PyTorch
    slim = "slim"  # PyTorch with torch.compile
    onnx = "onnx"  # ONNX Runtime

@dataclass
class OnlineConfig:
    r"""
        Mostly immutable settings for the day. Pattern for mutable settings not yet defined.
    """
    seed: int = 0

    default_constraint_fbc: float = 1.0 # Should we assume module initializes in constrained mode? Yes.

    # Monkey patches - deprecated
    default_covariate_mean: float = 0.0
    default_covariate_max: float = 0.007 # Taken from cursor CO observation
    default_covariate_min: float = -0.007

    # Working directory on _local_ drive
    # training_dir: Path = Path("./data/training").resolve() # Training directory holds training data
    run_dir: Path = Path("./data/runs").resolve() # Working directory holds finetuning details
    backbone_dir: Path = Path("./data/pretrained").resolve()
    backbone_ckpt: str = "base_40M.pt" # To be transferred directly from remote cluster

    shared_root_dir: Path = Path("./motor_rtdnn")
    decoder_log: str = "decoder_id.txt"
    record_initial: bool = False # Attach a logger that will record model IO in first minute


    remote_cluster: str = "crc" # Cluster alias (configured in `.ssh/config`) for remote training.
    remote_wandb_checkpoint_dir: str = "/ihome/rREDACT/REDACT/projects/context_general_bci/data/runs/ndt3/{run_id}/"

    # CLIMBER System level
    exec_alias: str = "mercy-REDACTnet-main" # SystemConfig.json key for Exec
    exec_address: str = "" # If provided, will override default provided by alias
    exec_port: int = 7111

    rtma_timeout: float = 0.02 # * Parity Diff: SpikeExtraction refreshes at 10ms; but JW says this shouldn't matter much since this module is largely reactive and will update per message, not per timeout.
    # refresh rate allows actions like button presses without messages, but that's not relevant for this RTMA module

    # Runtime operation
    # serve_framework: Accelerator = Accelerator.slim
    serve_framework: Accelerator = Accelerator.vanilla
    serve_rl: bool = False # If true, load RL infra, attempt online tuning (WIP)
    rl: RLConfig = field(default_factory=RLConfig)

    arrays_to_use: List[str] = field(default_factory=lambda: [
        '{subject}-lateral_m1', '{subject}-medial_m1' # Match order in `context_general_bci.presets`
    ])

    # Readin
    expected_channel_limit: int = 256
    unit_slots_per_channel: int = 5

    # buffer_size_ms: int = 960 # 1s, short horizon
    buffer_size_ms: int = 15000
    # reference vs working split is dynamically allocated
    # buffer_size_ms: int = 3000
    # reference_timesteps: int = 200 # Draws tokens from reference with time under this. 1ms is still model_bin_size_ms.
    default_model_bin_size_ms: int = 20 # technically configurable, but try not to change this...
    system_bin_size_ms: int = 20

    # Misc

    wandb_user: str = "REDACT"
    wandb_project: str = "ndt3"
    wandb_api_key_path: Path = Path("/home/REDACT/.wandb_api").resolve()

    exp: Any = MISSING # delta config, provide via yaml and on CLI as `+exp=<test>.yaml`
    debug: bool = False # for debugging, don't log to wandb, don't save ckpts, etc
    slim: bool = True # Extra slim mode for real time op


# Mutable decoder settings (mutated from the GUI).
# Kept in a separate dataclass for bookkeeping.
@dataclass
class OnlineConfigMutable:
    r"""
        Decoder settings that may be changed while running.
    """
    # Checkpoint / prompt selection
    active_config: str = ""
    ckpt_name: str = "val_kinematic_r2"
    model_bin_size_ms: int = 20
    prompt_index: int = -1

    # Sampling
    kin_nucleus_p: float = 1.0
    kin_temp: float = 0.5

    # Return sampling, following Multi-Game Decision Transformers:
    #   paper:  https://arxiv.org/pdf/2205.15241.pdf
    #   colab:  https://github.com/google-research/google-research/blob/master/multi_game_dt/Multi_game_decision_transformers_public_colab.ipynb
    return_nucleus_p: float = 0.9  # 0.7 was also tried
    return_temp: float = 0.05  # kept low so performance does not fluctuate wildly (1.0 was also tried)
    return_logit_offset_kappa: float = 0.0

    # Largely defunct
    fixed_return: int = 0  # when > 0, this value is used instead of sampled returns
