from calendar import c
from dataclasses import dataclass
from os import path
from typing import Any

from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, MISSING


@dataclass
class ModelConfig:
    """Base structured config shared by every node in the ``model`` group.

    Every subclass in this file overrides ``_target_`` with a dotted import
    path. Fields left at ``MISSING`` (OmegaConf's mandatory-value marker)
    must be filled in by a YAML config or command-line override before the
    value is read.
    """

    _target_: str = MISSING  # dotted path of the object to build (Hydra ``_target_`` convention)
    embedding_size: int = MISSING  # mandatory: no subclass in this file supplies a default
    is_causal: bool = True  # every subclass in this file keeps True as the default


@dataclass
class T5Config(ModelConfig):
    """Structured config for ``models.T5``, registered as ``model/base_T5``.

    ``enc_*`` fields presumably configure the encoder stack and ``dec_*``
    fields the decoder stack; both sides share the same defaults here.

    NOTE(review): every other model config in this file targets the
    ``model`` module, but this one targets ``models`` — confirm the path is
    intended, since an unimportable ``_target_`` fails at instantiation.
    NOTE(review): the inherited ``embedding_size`` is still MISSING
    (mandatory) even though this config also defines ``dim`` — confirm both
    are needed.
    """

    _target_: str = "models.T5"
    dim: Any = 768  # annotated Any, so OmegaConf will not type-check it; likely an int — confirm before narrowing
    enc_num_tokens: int = 512
    enc_depth: int = 6
    enc_heads: int = 12
    enc_dim_head: int = 64
    enc_mlp_mult: int = 4
    dec_num_tokens: int = 512
    dec_depth: int = 6
    dec_heads: int = 12
    dec_dim_head: int = 64
    dec_mlp_mult: int = 4
    dropout: float = 0.0
    tie_token_emb: bool = True  # presumably shares token embeddings between encoder and decoder — confirm in models.T5


@dataclass
class EncoderV2Config(ModelConfig):
    """Structured config for ``model.Encoder_v2``, registered as ``model/base_Encoder_v2``.

    The field set and defaults are identical to ``EncoderV3Config`` and
    ``AlphaEncoderConfig``; only ``_target_`` differs.
    """

    _target_: str = "model.Encoder_v2"
    n_heads: int = 8  # presumably attention heads per layer
    n_layers: int = 6  # presumably number of encoder layers
    dropout: float = 0.0
    batch_first: bool = True  # presumably the batch-dimension-first layout flag — confirm in model.Encoder_v2


@dataclass
class EncoderV3Config(ModelConfig):
    """Structured config for ``model.Encoder_v3``, registered as ``model/base_Encoder_v3``.

    The field set and defaults are identical to ``EncoderV2Config`` and
    ``AlphaEncoderConfig``; only ``_target_`` differs.
    """

    _target_: str = "model.Encoder_v3"
    n_heads: int = 8  # presumably attention heads per layer
    n_layers: int = 6  # presumably number of encoder layers
    dropout: float = 0.0
    batch_first: bool = True  # presumably the batch-dimension-first layout flag — confirm in model.Encoder_v3


@dataclass
class AlphaEncoderConfig(ModelConfig):
    """Structured config for ``model.AlphaEncoder``, registered as ``model/base_AlphaEncoder``.

    The field set and defaults are identical to ``EncoderV2Config`` and
    ``EncoderV3Config``; only ``_target_`` differs.
    """

    _target_: str = "model.AlphaEncoder"
    n_heads: int = 8  # presumably attention heads per layer
    n_layers: int = 6  # presumably number of encoder layers
    dropout: float = 0.0
    batch_first: bool = True  # presumably the batch-dimension-first layout flag — confirm in model.AlphaEncoder


@dataclass
class ShortCircuitConfig(ModelConfig):
    """Structured config for ``model.ShortCircuit``, registered as ``model/base_ShortCircuit``.

    ``act_fn`` is left MISSING, so it must be supplied by the composed config
    or an override.
    """

    _target_: str = "model.ShortCircuit"
    n_heads: int = 16
    n_layers: int = 4
    n_policy_layers: int = 3  # presumably depth of a separate policy head — confirm in model.ShortCircuit
    n_value_layers: int = 3  # presumably depth of a separate value head — confirm in model.ShortCircuit
    num_key_value_heads: int | None = None  # None presumably means "derive from n_heads" — confirm
    intermediate_size: int = 4096  # presumably the feed-forward hidden width
    dropout: float = 0.0
    is_causal: bool = True  # redeclares ModelConfig.is_causal with the same default (no effect on the value)
    rms_norm_eps: float = 1e-6  # presumably epsilon for RMS normalisation layers
    attention_bias: bool = False
    position_embeddings: bool = True
    act_fn: Any = MISSING  # mandatory; presumably a node from the ``activation`` group — confirm


@dataclass
class ShortCircuit2Config(ModelConfig):
    """Structured config for ``model.ShortCircuit2``, registered as ``model/base_ShortCircuit2``.

    Same fields and defaults as ``ShortCircuitConfig`` plus
    ``bool_embedding_size``. ``act_fn`` is left MISSING and must be supplied.
    """

    _target_: str = "model.ShortCircuit2"
    bool_embedding_size: int = 4  # only field not present on ShortCircuitConfig
    n_heads: int = 16
    n_layers: int = 4
    n_policy_layers: int = 3  # presumably depth of a separate policy head — confirm in model.ShortCircuit2
    n_value_layers: int = 3  # presumably depth of a separate value head — confirm in model.ShortCircuit2
    num_key_value_heads: int | None = None  # None presumably means "derive from n_heads" — confirm
    intermediate_size: int = 4096  # presumably the feed-forward hidden width
    dropout: float = 0.0
    is_causal: bool = True  # redeclares ModelConfig.is_causal with the same default (no effect on the value)
    rms_norm_eps: float = 1e-6  # presumably epsilon for RMS normalisation layers
    attention_bias: bool = False
    position_embeddings: bool = True
    act_fn: Any = MISSING  # mandatory; presumably a node from the ``activation`` group — confirm


@dataclass
class AlmaEncoderConfig(ModelConfig):
    """Structured config for ``model.AlmaEncoder``, registered as ``model/base_AlmaEncoder``.

    Same field set as ``ShortCircuitConfig`` but with different defaults:
    8 heads, 6 layers, 2 policy / 2 value layers, and
    ``position_embeddings`` off. ``act_fn`` is left MISSING and must be supplied.
    """

    _target_: str = "model.AlmaEncoder"
    n_heads: int = 8
    n_layers: int = 6
    n_policy_layers: int = 2  # presumably depth of a separate policy head — confirm in model.AlmaEncoder
    n_value_layers: int = 2  # presumably depth of a separate value head — confirm in model.AlmaEncoder
    num_key_value_heads: int | None = None  # None presumably means "derive from n_heads" — confirm
    intermediate_size: int = 4096  # presumably the feed-forward hidden width
    dropout: float = 0.0
    is_causal: bool = True  # redeclares ModelConfig.is_causal with the same default (no effect on the value)
    rms_norm_eps: float = 1e-6  # presumably epsilon for RMS normalisation layers
    attention_bias: bool = False
    position_embeddings: bool = False  # differs from the ShortCircuit configs, which default to True
    act_fn: Any = MISSING  # mandatory; presumably a node from the ``activation`` group — confirm


@dataclass
class ActivationConfig:
    """Base structured config for nodes in the ``activation`` group.

    Subclasses in this file set only ``_target_``; the base leaves it
    MISSING (mandatory).
    """

    _target_: str = MISSING  # dotted path of the activation to build (Hydra ``_target_`` convention)


@dataclass
class KLDivActivationConfig(ActivationConfig):
    """Activation node targeting ``model.kldiv_activation``; registered as ``activation/base_kldiv_activation``."""

    _target_: str = "model.kldiv_activation"


@dataclass
class NormalizeActionConfig(ActivationConfig):
    """Activation node targeting ``model.normalize_action``; registered as ``activation/base_normalize_action``."""

    _target_: str = "model.normalize_action"


@dataclass
class ValueActivationConfig(ActivationConfig):
    """Activation node targeting ``model.value_activation``; registered as ``activation/base_value_activation``."""

    _target_: str = "model.value_activation"


@dataclass
class ValueActivation2Config(ActivationConfig):
    """Activation node targeting ``model.value_activation2``; registered as ``activation/base_value_activation2``."""

    _target_: str = "model.value_activation2"


@dataclass
class BaseActivationConfig(ActivationConfig):
    """Activation node targeting ``model.base_activation``; registered as ``activation/base_activation``."""

    _target_: str = "model.base_activation"


def register_models_configs() -> None:
    """Store every model and activation structured config in Hydra's ConfigStore.

    Each node is stored under ``<group>/<name>`` (for example
    ``model/base_T5`` or ``activation/base_kldiv_activation``). Nodes are
    registered in the order listed in the table below.
    """
    store = ConfigStore.instance()

    # (group, name, node) triples; adding a new config only needs a row here.
    registrations = (
        ("model", "base_T5", T5Config),
        ("model", "base_Encoder_v2", EncoderV2Config),
        ("model", "base_Encoder_v3", EncoderV3Config),
        ("model", "base_AlphaEncoder", AlphaEncoderConfig),
        ("model", "base_ShortCircuit", ShortCircuitConfig),
        ("model", "base_ShortCircuit2", ShortCircuit2Config),
        ("model", "base_AlmaEncoder", AlmaEncoderConfig),
        ("activation", "base_kldiv_activation", KLDivActivationConfig),
        ("activation", "base_normalize_action", NormalizeActionConfig),
        ("activation", "base_value_activation", ValueActivationConfig),
        ("activation", "base_value_activation2", ValueActivation2Config),
        ("activation", "base_activation", BaseActivationConfig),
    )

    for group_name, node_name, node_cls in registrations:
        store.store(group=group_name, name=node_name, node=node_cls)
