from typing import Literal
from flax import struct
from jax import numpy as jnp
import yaml


# TODO: Figure out a way to attach comments and allowed options to the fields of these dataclasses

# Allowed string identifiers for the model block architecture. Used as the
# annotation of the ``block_type`` fields of the task configs in this file.
BLOCK_TYPE = (
    Literal["mamba"]
    | Literal["transformer"]
    | Literal["transformer-sota"]  # advanced architecture like griffin and mamba
    | Literal["transformer-qual"]  # GAU unit
    | Literal["transformer-linear"]
    # NOTE(review): "lamma" may be a misspelling of "llama" (listed below);
    # confirm that no config relies on it before removing.
    | Literal["lamma"]
    | Literal["retnet"]
    | Literal["rwkv"]
    | Literal["glu"]
    # used for pretraining
    | Literal["gemma"]
    | Literal["gemma-hugg"]  # use torch model directly from huggingface (for testing)
    | Literal["llama"]
)
# Allowed string identifiers for the attention mechanism. Used as the
# annotation of the ``attention_type`` fields of the configs in this file.
# The "_causal" / "_bid" suffixes distinguish causal from bidirectional variants.
ATT_TYPE = (
    Literal["latte_causal"]  # causal latte with absolute embeddings
    | Literal["latte_bid"]  # bidirectional latte with absolute embeddings
    # bidirectional variant; presumably the same convolution scheme as
    # "latte_convQR_causal" below — TODO confirm
    | Literal["latte_convQR_bid"]
    | Literal[
        "latte_convQR_causal"
    ]  # apply same convolution to X before using projections to K,Q
    | Literal[
        "latte_convAll_causal"
    ]  # apply same convolution to X, before using projections to K,Q,V
    | Literal["latte_mach_simple_causal"]
    | Literal["latte_mach_simple_bid"]
    | Literal["latte_mach_sliding_causal"]  # latte macchiato
    | Literal["latte_mach_sliding_bid"]  # latte macchiato
    | Literal["standard_causal"]
    | Literal["scan_standard_causal"]  # linear sequential att
    | Literal["standard_bid"]
    | Literal["linformer"]
)

# Allowed string identifiers for the positional embedding scheme. Used as the
# annotation of the ``embed_type`` fields of the configs in this file.
EMBED_TYPE = Literal["rope"] | Literal["xpos"] | Literal["absolute"] | Literal["nope"]


@struct.dataclass
class EmptyConfig:
    """Field-less base config providing YAML loading and basic validation.

    Subclasses declare the actual fields; ``load`` forwards the merged
    parameters to the subclass constructor.
    """

    @classmethod
    def load(cls, yaml_file, **kwargs):
        """Build a config instance from a YAML file plus keyword overrides.

        Args:
            yaml_file: Union[str, os.PathLike]
                Path to the yaml config.
            **kwargs: extra config parameters. When a key is present both in
                the file and in ``kwargs``, the value from ``kwargs`` takes
                precedence.

        Returns:
            An instance of ``cls`` built from the merged parameters.

        Raises:
            NotImplementedError: if ``name`` or ``base_dir`` is missing from
                the merged parameters (raised by ``validate``).
        """
        with open(yaml_file, "r", encoding="utf-8") as reader:
            # safe_load returns None for an empty document; fall back to an
            # empty mapping so that kwargs alone can supply the parameters.
            config = yaml.safe_load(reader) or {}
        # keyword overrides win over the values read from the file
        config.update(kwargs)

        config = cls.validate(config)
        return cls(**config)

    @classmethod
    def validate(cls, config):
        """Check that the mandatory keys are present.

        Args:
            config: mapping of config parameters.

        Returns:
            The same ``config`` object, unmodified.

        Raises:
            NotImplementedError: if ``name`` or ``base_dir`` is not a key of
                ``config``.
        """
        if "name" not in config:
            raise NotImplementedError(
                "Experiemnt must have a name. Default not supported"
            )
        if "base_dir" not in config:
            raise NotImplementedError(
                "Experiemnt must have a base_dir. Default not supported"
            )
        return config


@struct.dataclass
class Config:
    """Base experiment configuration: model size, optimiser and logging options.

    Instances are normally created through ``Config.load`` from a YAML file.
    ``name`` and ``base_dir`` are mandatory; every other field has a default.
    """

    # Name of the experiment.
    name: str
    # base directory where to dump training output. Experiment name will be a subfolder here.
    base_dir: str
    # The project under which run is saved
    project: str = "diffusion"
    # The team/account under which the project is saved
    entity: str = "baesian-learning"
    # tokenizer path: used for pretrained tokenizers
    tokenizer_path: str | None = None
    # name of the dataset
    dataset_name: str = "shakespeare"
    # number of epochs to train for
    epochs: int = 10
    # number of train steps. Should be set to None if we want to use epochs
    train_steps: int | None = None
    # number layers
    nlayers: int = 6
    # number heads
    nheads: int = 4
    num_key_value_heads: int = 4
    # attention dimension
    head_dim: int | None = None
    # Hidden dimension
    hidden_dim: int = 128
    # Dimension for the rotation matrix
    L: int = 10
    latte_nheads: int = 1
    # State dimension mamba models
    state_dim: int = 128
    # number of unrolls used for the scan operations in latte
    unroll: int = 100
    # maximum sequence length:
    max_seq_len: int = 1024
    # local attention for latte macchiato
    att_block_len: int = 128
    # maximum length used in positional embeddings
    pos_embed_max_len: int = 1024
    # number of steps between evaluation
    eval_steps: int = 10
    # The maximum number of checkpoints to save
    max_checkpoints: int = 1
    # dropout each layer
    dropout_att: float = 0.0
    dropout: float = 0.1
    # weight decay for optimizer
    weight_decay: float = 0.01
    # used for initialisation
    initializer_range: float = 0.02
    # adam optimizer parameters
    adam_b1: float = 0.9
    adam_b2: float = 0.999
    adam_eps: float = 1e-08
    # The learning rate
    lr: float = 3e-4
    # percentage of steps to do warmup out of total steps
    warmup_pc: float = 0
    # exact number of warmup steps, takes precedence over warmup_pc
    warmup: int = 0
    # learning rate decay function: "cosine", "constant" or "linear"
    # TODO: replace str with typing.Literal
    lr_decay_fn: str = "cosine"
    # training precision
    mixed_precision: Literal["bf16"] | None = None
    # end value used only for linear decay learning rate
    lr_end_value: float = 0.00001
    # use batchnorm or layer norm (a flag, hence bool rather than float)
    batchnorm: bool = True
    # normalize before or after mlp
    prenorm: bool = False
    # shuffle examples in train
    shuffle_train: bool = True
    # batch size per all devices
    batch_size: int = 32
    # gradient accumulation steps
    grad_accumulation_steps: int = 1
    # Path to the pretrained checkpoint, useful for resuming training
    check_path: str | None = None
    # wandb run id to resume if check_path specified
    run_id: str | None = None
    # Whether to use wandb logging
    wandb_log: bool = False
    # whether to process data with hugging face from scratch or not.
    # True = use cached versions
    # NOTE(review): the field name suggests True *disables* the cache, which
    # contradicts the line above — confirm against the data pipeline.
    disable_cache: bool = False

    @classmethod
    def load(cls, yaml_file, **kwargs):
        """Build a config instance from a YAML file plus keyword overrides.

        Args:
            yaml_file: Union[str, os.PathLike]
                Path to the yaml config.
            **kwargs: extra config parameters. When a key is present both in
                the file and in ``kwargs``, the value from ``kwargs`` takes
                precedence.

        Returns:
            An instance of ``cls`` built from the merged parameters.

        Raises:
            NotImplementedError: if ``name`` or ``base_dir`` is missing from
                the merged parameters (raised by ``validate``).
        """
        with open(yaml_file, "r", encoding="utf-8") as reader:
            # safe_load returns None for an empty document; fall back to an
            # empty mapping so that kwargs alone can supply the parameters.
            config = yaml.safe_load(reader) or {}
        # keyword overrides win over the values read from the file
        config.update(kwargs)

        config = cls.validate(config)
        return cls(**config)

    @classmethod
    def validate(cls, config):
        """Check that the mandatory keys are present.

        Args:
            config: mapping of config parameters.

        Returns:
            The same ``config`` object, unmodified.

        Raises:
            NotImplementedError: if ``name`` or ``base_dir`` is not a key of
                ``config``.
        """
        if "name" not in config:
            raise NotImplementedError(
                "Experiemnt must have a name. Default not supported"
            )
        if "base_dir" not in config:
            raise NotImplementedError(
                "Experiemnt must have a base_dir. Default not supported"
            )
        return config


@struct.dataclass
class LMTaskConfig(Config):
    """Language-modelling task config.

    Extends ``Config`` with task-specific fields and overrides several of the
    inherited defaults (e.g. ``hidden_dim``, ``nlayers``, ``nheads``,
    ``max_seq_len``).
    """

    # NOTE(review): "stable_latte" is not one of the literals listed in
    # ATT_TYPE in this file — confirm whether the default or the alias is stale.
    attention_type: ATT_TYPE = "stable_latte"
    block_type: BLOCK_TYPE = "transformer"
    # presumably the huggingface checkpoint name to load from — TODO confirm
    hugg_chk: str | None = None
    embed_type: EMBED_TYPE = "nope"
    eval_gen_len: int = 100
    # Go through a defined number of samples instead of the entire val loader
    eval_samples: int | None = None
    # path to the file containing test prompts
    # (the field name is spelled "promt_path"; kept as-is because it is the config key)
    promt_path: str | None = None

    attention_bias: bool = False
    hidden_dim: int = 512
    intermediate_dim: int = 2048
    projection_dim: int = 512
    nlayers: int = 12
    nheads: int = 8
    num_key_value_heads: int | None = None  # used by gemma
    head_dim: int | None = None
    dropout_att: float = 0.0
    dropout: float = 0.1
    hidden_act: str = "silu"
    initializer_range: float = 0.02
    # local attention for latte macchiato
    att_block_len: int = 128
    L: int = 128
    unroll: int = 100
    pos_embed_max_len: int = 1024
    max_seq_len: int = 224  # legacy note: context_len: int = 220
    text_vocab_size: int | None = None


@struct.dataclass
class VisionEncoderConfig(EmptyConfig):
    """Hyper-parameters of the vision encoder.

    Built directly from the ``vision_config`` mapping in
    ``LavaTaskConfig.load``.

    NOTE(review): the inherited ``load`` validates that ``name`` and
    ``base_dir`` are present, but this class declares neither field, so
    ``VisionEncoderConfig.load`` looks unusable as written — confirm.
    """

    attention_type: ATT_TYPE = "standard_bid"
    embed_type: EMBED_TYPE = "nope"
    attention_bias: bool = True
    image_size: int = 336
    patch_size: int = 14
    hidden_dim: int = 512
    intermediate_dim: int = 2048
    projection_dim: int = 512
    nlayers: int = 12
    nheads: int = 8
    num_key_value_heads: int | None = None  # used by gemma
    head_dim: int | None = None
    dropout_att: float = 0.0
    dropout: float = 0.1
    hidden_act: str = "silu"
    initializer_range: float = 0.02
    # local attention for latte macchiato
    att_block_len: int = 128
    L: int = 128
    unroll: int = 100
    pos_embed_max_len: int = 1024
    max_seq_len: int = 220  # legacy note: context_len: int = 220


@struct.dataclass
class TextDecoderConfig(EmptyConfig):
    """Hyper-parameters of the text decoder.

    Built directly from the ``text_config`` mapping in
    ``LavaTaskConfig.load``.

    NOTE(review): the inherited ``load`` validates that ``name`` and
    ``base_dir`` are present, but this class declares neither field, so
    ``TextDecoderConfig.load`` looks unusable as written — confirm.
    """

    attention_type: ATT_TYPE = "standard_causal"
    embed_type: EMBED_TYPE = "rope"
    attention_bias: bool = False
    hidden_dim: int = 512
    intermediate_dim: int = 2048
    projection_dim: int = 512
    nlayers: int = 12
    nheads: int = 8
    num_key_value_heads: int | None = None  # used by gemma
    head_dim: int | None = None
    dropout_att: float = 0.0
    dropout: float = 0.1
    hidden_act: str = "silu"
    initializer_range: float = 0.02
    # local attention for latte macchiato
    att_block_len: int = 128
    L: int = 128
    unroll: int = 100
    pos_embed_max_len: int = 1024
    max_seq_len: int = 220  # legacy note: context_len: int = 220
    text_vocab_size: int | None = None


@struct.dataclass
class LavaTaskConfig(EmptyConfig):
    """Config of the combined vision + text task.

    Holds the training/optimiser options plus two nested configs,
    ``vision_config`` and ``text_config``, which ``load`` builds from the
    sub-mappings of the same names.
    """

    @classmethod
    def load(cls, yaml_file, **kwargs):
        """Build a config instance from a YAML file plus keyword overrides.

        The ``vision_config`` and ``text_config`` entries of the merged
        parameters are expected to be mappings; they are converted to
        ``VisionEncoderConfig`` and ``TextDecoderConfig`` instances.

        Args:
            yaml_file: Union[str, os.PathLike]
                Path to the yaml config.
            **kwargs: extra config parameters. When a key is present both in
                the file and in ``kwargs``, the value from ``kwargs`` takes
                precedence.

        Returns:
            An instance of ``cls`` built from the merged parameters.

        Raises:
            KeyError: if ``vision_config`` or ``text_config`` is missing.
            NotImplementedError: if ``name`` or ``base_dir`` is missing
                (raised by the inherited ``validate``).
        """
        with open(yaml_file, "r", encoding="utf-8") as reader:
            # safe_load returns None for an empty document; fall back to an
            # empty mapping so that kwargs alone can supply the parameters.
            config = yaml.safe_load(reader) or {}
        # keyword overrides win over the values read from the file
        config.update(kwargs)
        # replace the raw sub-mappings by their typed config objects
        vis_conf = VisionEncoderConfig(**config["vision_config"])
        text_conf = TextDecoderConfig(**config["text_config"])
        config.update({"vision_config": vis_conf, "text_config": text_conf})

        config = cls.validate(config)
        return cls(**config)

    # Name of the experiment.
    name: str
    # base directory where to dump training output. Experiment name will be a subfolder here.
    base_dir: str
    # The project under which run is saved
    project: str = "diffusion"
    # The team/account under which the project is saved
    entity: str = "baesian-learning"

    block_type: BLOCK_TYPE = "transformer"
    hugg_chk: str | None = None  # name of the pretrained torch model to load from
    # Go through a defined number of samples instead of the entire val loader
    eval_samples: int | None = None
    vision_config: VisionEncoderConfig | None = None
    # load() stores a TextDecoderConfig here, so the annotation reflects that
    text_config: TextDecoderConfig | None = None
    text_vocab_size: int | None = None
    batchnorm: bool = False
    # tokenizer path: used for pretrained tokenizers
    tokenizer_path: str | None = None
    # name of the dataset
    dataset_name: str = "shakespeare"
    # number of epochs to train for
    epochs: int = 10
    # number of train steps. Should be set to None if we want to use epochs
    train_steps: int | None = None
    # number of steps between evaluation
    eval_steps: int = 10
    # The maximum number of checkpoints to save
    max_checkpoints: int = 1
    # weight decay for optimizer
    weight_decay: float = 0.01
    # adam optimizer parameters
    adam_b1: float = 0.9
    adam_b2: float = 0.999
    adam_eps: float = 1e-08
    # The learning rate
    lr: float = 3e-4
    # percentage of steps to do warmup out of total steps
    warmup_pc: float = 0
    # exact number of warmup steps, takes precedence over warmup_pc
    warmup: int = 0
    # learning rate decay function: "cosine", "constant" or "linear"
    # TODO: replace str with typing.Literal
    lr_decay_fn: str = "cosine"
    # training precision
    mixed_precision: Literal["bf16"] | None = None
    # end value used only for linear decay learning rate
    lr_end_value: float = 0.00001
    # shuffle examples in train
    shuffle_train: bool = True
    # batch size per all devices
    batch_size: int = 32
    # gradient accumulation steps
    grad_accumulation_steps: int = 1
    # NOTE(review): "bfloat32" is not a dtype name offered by jax.numpy
    # ("bfloat16" and "float32" are); the default is kept verbatim — confirm
    # how consumers interpret this string before changing it.
    dtype: str = "bfloat32"  # jnp.dtype = jnp.bfloat16
    # Path to the pretrained checkpoint, useful for resuming training
    check_path: str | None = None
    # wandb run id to resume if check_path specified
    run_id: str | None = None
    # Whether to use wandb logging
    wandb_log: bool = False
    # whether to process data with hugging face from scratch or not.
    # True = use cached versions
    # NOTE(review): the field name suggests True *disables* the cache, which
    # contradicts the line above — confirm against the data pipeline.
    disable_cache: bool = False


@struct.dataclass
class CopyTaskConfig(EmptyConfig):
    """Config of the synthetic copy / n-gram tasks.

    ``name``, ``base_dir``, ``train_task``, ``eval_task`` and ``model`` are
    mandatory; every other field has a default.
    """

    # Name of the experiment.
    name: str
    # base directory where to dump training output. Experiment name will be a subfolder here.
    base_dir: str
    # task variant used for training
    train_task: Literal["copy"] | Literal["prefix_ngram"] | Literal["suffix_ngram"]
    # task variant used for evaluation (additionally allows "duplicate_ngram")
    eval_task: (
        Literal["copy"]
        | Literal["prefix_ngram"]
        | Literal["suffix_ngram"]
        | Literal["duplicate_ngram"]
    )
    model: Literal["latte"] | Literal["mamba"]
    # The project under which run is saved
    project: str = "diffusion"
    # The team/account under which the project is saved
    entity: str = "baesian-learning"
    # NOTE(review): "stable_latte" is not one of the literals listed in
    # ATT_TYPE in this file — confirm whether the default or the alias is stale.
    attention_type: ATT_TYPE = "stable_latte"
    block_type: BLOCK_TYPE = "transformer"
    vocab_size: int = 26
    n_gram: int = 0
    length_answer: int = 0

    hidden_dim: int = 1024
    # Dimension for the rotation matrix
    L: int = 10
    # number heads
    nheads: int = 16
    # number layers
    nlayers: int = 12
    # The maximum number of checkpoints to save
    max_checkpoints: int = 1
    # dropout each layer
    dropout: float = 0.1
    dropout_att: float = 0.0
    # weight decay for optimizer
    weight_decay: float = 0.01
    # The learning rate
    lr: float = 3e-4
    # percentage of steps to do warmup out of total steps
    warmup_pc: float = 0
    # exact number of warmup steps, takes precedence over warmup_pc
    warmup: int = 0
    # learning rate decay function
    lr_decay_fn: str = (
        "cosine"  # "constant", "linear" TODO: replace str with typing.Literal
    )
    # end value used only for linear decay learning rate
    lr_end_value: float = 0.00001
    # use batchnorm or layer norm
    batchnorm: bool = False
    # normalize before or after mlp
    prenorm: bool = False
    batch_size: int = 32
    shuffle_train: bool = False
    # number of unrolls used for the scan operations in latte
    unroll: int = 100
    # gradient accumulation steps
    grad_accumulation_steps: int = 1
    # wandb run id to resume if check_path specified
    run_id: str | None = None
    # Whether to use wandb logging
    wandb_log: bool = False
    # whether to process data with hugging face from scratch or not.
    # True = use cached versions
    # NOTE(review): the field name suggests True *disables* the cache, which
    # contradicts the line above — confirm against the data pipeline.
    disable_cache: bool = False
    # Path to the pretrained checkpoint, useful for resuming training
    check_path: str | None = None

    epochs: int = 10
    # number of steps between evaluation
    eval_steps: int = 50
    train_steps: int | None = None
    pos_embed_max_len: int = 1024
    max_seq_len: int = 220  # legacy note: context_len: int = 220
    eval_num_batches: int = 3

    # presumably bounds on the sampled sequence lengths for train / eval — TODO confirm
    min_train_len: int = 5
    max_train_len: int = 20
    min_eval_len: int = 10
    max_eval_len: int = 20
    eval_context_len: int = 220


@struct.dataclass
class LRATaskConfig(Config):
    """Config of the LRA tasks: extends ``Config`` with the fields below."""

    # NOTE(review): "stable_latte" is not one of the literals listed in
    # ATT_TYPE in this file — confirm whether the default or the alias is stale.
    attention_type: ATT_TYPE = "stable_latte"
    block_type: BLOCK_TYPE = "transformer"
    embed_type: EMBED_TYPE = "nope"
    blocks: int = 16
    small_lr: float = 0.001
    att_block_len: int = 128
    # whether to divide by 255.0 or not
    normalize_img: bool = False
    # Whether to use tokens and embeddings like in nlp for images. (vocab size = 256)
    tokenize_img: bool = False
    # whether to use convolution instead of dense embedding for images,
    # like ViT transformers
    conv_embed: bool = False
    # pooling mode: "mean" or "last"
    pool: str = "last"
    # number of classes
    num_classes: int = 10
