from dataclasses import dataclass
from typing import Any, Literal, Optional, Type

import torch
from typing_extensions import Self

import lit_gpt.model
from lit_gpt.utils import find_multiple


@dataclass
class Config:
    """Hyperparameters describing one model architecture preset.

    Every field has a default, so ``Config()`` is valid. After construction,
    ``__post_init__`` validates the head layout and fills in the fields that
    default to ``None`` (``padded_vocab_size``, ``n_query_groups`` and
    ``intermediate_size``).

    ``_norm_class`` and ``_mlp_class`` are stored as plain strings so the
    config stays JSON serializable; use the ``norm_class`` / ``mlp_class``
    properties to obtain the actual classes.
    """

    # Identification of the preset (organisation and model name).
    org: str = "Lightning-AI"
    name: str = "lit-GPT"
    # presumably the maximum sequence length in tokens -- TODO confirm against the model code
    block_size: int = 4096
    vocab_size: int = 50254
    # `padded_vocab_size` is derived from `vocab_size` and `padding_multiple`
    # in `__post_init__` unless it is given explicitly.
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    # Must be divisible by `n_head`; the quotient is exposed as `head_size`.
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    # to use multi-head attention (MHA), set this to `n_head` (default)
    # to use multi-query attention (MQA), set this to 1
    # to use grouped-query attention (GQA), set this to a value in between
    # Example with `n_head=4`
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │         │        │                 │
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
    # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
    # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
    # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
    # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
    #         MHA                    GQA                   MQA
    #   n_query_groups=4       n_query_groups=2      n_query_groups=1
    #
    # credit https://arxiv.org/pdf/2305.13245.pdf
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    # "FusedRMSNorm" is listed because `norm_class` resolves it and presets in this file use it.
    _norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
    # Defaults to `4 * n_embd` for "GptNeoxMLP"; mandatory for "LLaMAMLP".
    intermediate_size: Optional[int] = None
    # The fields below are only stored here; they are consumed by code outside
    # this file, so their exact semantics are not visible from this class.
    condense_ratio: int = 1
    # Values used by presets in this file: "", "strict", "dm1", "fix1", "fix2", "adaptive".
    intradoc_mask: str = ""
    # Values used by presets in this file: "no", "overlap".
    merge_method: str = "no"
    # Values used by presets in this file: "rope", "no".
    positional_embedding: str = "rope"
    rope_base: int = 10000
    window_size: int = -1

    def __post_init__(self) -> None:
        """Validate the head layout and fill in fields that were left as ``None``.

        Raises:
            AssertionError: if ``n_embd`` is not divisible by ``n_head``, or
                ``n_head`` is not divisible by an explicit ``n_query_groups``.
            ValueError: if ``_mlp_class`` is "LLaMAMLP" and ``intermediate_size``
                was not provided.
        """
        # error checking
        assert (
            self.n_embd % self.n_head == 0
        ), f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
        # Pad the vocabulary size based on `padding_multiple` when no explicit padded size
        # was given (presumably rounds up to the next multiple -- see `find_multiple`).
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
        # compute the number of query groups; without an explicit value this is plain MHA
        if self.n_query_groups is not None:
            assert (
                self.n_head % self.n_query_groups == 0
            ), f"n_head ({self.n_head}) must be divisible by n_query_groups ({self.n_query_groups})"
        else:
            self.n_query_groups = self.n_head
        # compute the intermediate size for MLP if not set
        if self.intermediate_size is None:
            if self._mlp_class == "LLaMAMLP":
                raise ValueError("The config needs to set the `intermediate_size`")
            self.intermediate_size = 4 * self.n_embd

    @property
    def head_size(self) -> int:
        """Per-head width: ``n_embd // n_head``."""
        return self.n_embd // self.n_head

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build a config from the preset registered under ``name``.

        The preset is looked up in the module-level ``name_to_config`` mapping,
        copied so the registry entry is not mutated, and then updated with
        ``kwargs`` so callers can override individual fields. A failed lookup
        propagates unchanged.
        """
        conf_dict = name_to_config[name].copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @property
    def mlp_class(self) -> Type:
        """Resolve ``_mlp_class`` to the class of that name in ``lit_gpt.model``."""
        # `self._mlp_class` cannot be the type to keep the config json serializable
        return getattr(lit_gpt.model, self._mlp_class)

    @property
    def norm_class(self) -> Type:
        """Resolve ``_norm_class`` to a normalization class.

        "RMSNorm" and "FusedRMSNorm" are imported from ``lit_gpt.rmsnorm`` only
        when requested; any other name is looked up on ``torch.nn``.
        """
        # `self._norm_class` cannot be the type to keep the config json serializable
        if self._norm_class == "RMSNorm":
            from lit_gpt.rmsnorm import RMSNorm

            return RMSNorm
        elif self._norm_class == "FusedRMSNorm":
            from lit_gpt.rmsnorm import FusedRMSNorm

            return FusedRMSNorm
        return getattr(torch.nn, self._norm_class)


########################
# Stability AI StableLM
########################
# Preset registry: every entry is a dict of keyword arguments whose keys match
# `Config` fields. Fields that are omitted keep the dataclass defaults.
configs = [
    # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
    {
        "org": "stabilityai",
        "name": "stablelm-base-alpha-3b",
        "padding_multiple": 512,
    },
    # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
    {
        "org": "stabilityai",
        "name": "stablelm-base-alpha-7b",
        "n_head": 48,
        "n_embd": 6144,
        "padding_multiple": 256,
    },
    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
    {
        "org": "stabilityai",
        "name": "stablelm-tuned-alpha-3b",
        "n_head": 32,
        "padding_multiple": 512,
    },
    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
    {
        "org": "stabilityai",
        "name": "stablelm-tuned-alpha-7b",
        "n_head": 48,
        "n_embd": 6144,
        "padding_multiple": 256,
    },
]

####################
# EleutherAI Pythia
####################
# Pythia presets. All of them use a 2048 block size and differ only in depth,
# width, head count and vocabulary padding multiple.
pythia = [
    # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
    {
        "org": "EleutherAI",
        "name": "pythia-70m",
        "block_size": 2048,
        "n_layer": 6,
        "n_embd": 512,
        "n_head": 8,
        "padding_multiple": 128,
    },
    # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
    {
        "org": "EleutherAI",
        "name": "pythia-160m",
        "block_size": 2048,
        "n_layer": 12,
        "n_embd": 768,
        "n_head": 12,
        "padding_multiple": 128,
    },
    # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
    {
        "org": "EleutherAI",
        "name": "pythia-410m",
        "block_size": 2048,
        "n_layer": 24,
        "n_embd": 1024,
        "n_head": 16,
        "padding_multiple": 128,
    },
    # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
    {
        "org": "EleutherAI",
        "name": "pythia-1b",
        "block_size": 2048,
        "n_layer": 16,
        "n_embd": 2048,
        "n_head": 8,
        "padding_multiple": 128,
    },
    # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
    {
        "org": "EleutherAI",
        "name": "pythia-1.4b",
        "block_size": 2048,
        "n_layer": 24,
        "n_embd": 2048,
        "n_head": 16,
        "padding_multiple": 128,
    },
    # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
    {
        "org": "EleutherAI",
        "name": "pythia-2.8b",
        "block_size": 2048,
        "n_layer": 32,
        "n_embd": 2560,
        "n_head": 32,
        "padding_multiple": 128,
    },
    # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
    {
        "org": "EleutherAI",
        "name": "pythia-6.9b",
        "block_size": 2048,
        "n_layer": 32,
        "n_embd": 4096,
        "n_head": 32,
        "padding_multiple": 256,
    },
    # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
    {
        "org": "EleutherAI",
        "name": "pythia-12b",
        "block_size": 2048,
        "n_layer": 36,
        "n_embd": 5120,
        "n_head": 40,
        "padding_multiple": 512,
    },
]
# Register the base presets first, then one "-deduped" twin per preset that
# shares every hyperparameter and differs only in its name.
configs += pythia
for c in pythia:
    copy = dict(c, name=f"{c['name']}-deduped")
    configs.append(copy)

####################################
# togethercomputer RedPajama INCITE
####################################
# RedPajama INCITE templates. The "{}" placeholder in each name is filled with
# the variant kind ("Base", "Chat", "Instruct") by the loop below.
redpajama_incite = [
    # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
    {
        "org": "togethercomputer",
        "name": "RedPajama-INCITE-{}-3B-v1",
        "block_size": 2048,
        "n_layer": 32,
        "n_embd": 2560,
        "n_head": 32,
        "padding_multiple": 256,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
    },
    # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
    {
        "org": "togethercomputer",
        "name": "RedPajama-INCITE-7B-{}",
        "block_size": 2048,
        "n_layer": 32,
        "n_embd": 4096,
        "n_head": 32,
        "padding_multiple": 256,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
    },
    # this redirects to the checkpoint above. kept for those who had the old weights already downloaded
    {
        "org": "togethercomputer",
        "name": "RedPajama-INCITE-{}-7B-v0.1",
        "block_size": 2048,
        "n_layer": 32,
        "n_embd": 4096,
        "n_head": 32,
        "padding_multiple": 256,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
    },
]
# Expand every template into its three named variants; only the name changes.
for c in redpajama_incite:
    for kind in ("Base", "Chat", "Instruct"):
        copy = dict(c, name=c["name"].format(kind))
        configs.append(copy)

#################
# TII UAE Falcon
#################
# Falcon templates. The "{}" placeholder in each name is filled with the
# variant suffix ("" or "-instruct") by the loop below.
falcon = [
    # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
    {
        "org": "tiiuae",
        "name": "falcon-7b{}",
        "block_size": 2048,
        "padded_vocab_size": 65024,
        "n_layer": 32,
        "n_head": 71,
        "n_embd": 4544,
        "rotary_percentage": 1.0,
        "parallel_residual": True,
        "n_query_groups": 1,
        "bias": False,
        # this is not in the config, but in the original model implementation, only for this config
        "shared_attention_norm": True,
    },
    # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
    {
        "org": "tiiuae",
        "name": "falcon-40b{}",
        "block_size": 2048,
        "padded_vocab_size": 65024,
        "n_layer": 60,
        "n_head": 128,
        "n_embd": 8192,
        "rotary_percentage": 1.0,
        "parallel_residual": True,
        "n_query_groups": 8,
        "bias": False,
    },
]
# Expand every template into its base and "-instruct" variants; only the name changes.
for c in falcon:
    for kind in ("", "-instruct"):
        copy = dict(c, name=c["name"].format(kind))
        configs.append(copy)

#############################
# StatNLP Research
#############################
tiny_LLaMA = [

    # https://twitter.com/cwolferesearch/status/1691929174175264858
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_1k",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_1k_intramask",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_2k",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_1k",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_1k_intramask",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_512",
        block_size=512,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),

    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_512_intramask",
        block_size=512,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_256",
        block_size=256,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_256_intramask",
        block_size=256,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_128",
        block_size=128,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_128_intramask",
        block_size=128,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask='strict',
    ), dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_64",
        block_size=64,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_64_intramask",
        block_size=64,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_2k",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_2k_intramask",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_4k",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_4k_intramask",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_dm1",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask='dm1',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k_rb5",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        rope_base=100000,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_1k_rb5",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        rope_base=100000,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_2k_rb5",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        rope_base=100000,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_nope",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        positional_embedding="no",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_nopefix2",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        positional_embedding="no",
        intradoc_mask="fix2",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_nopefix1",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        positional_embedding="no",
        intradoc_mask="fix1",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_1k_nope",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        positional_embedding="no",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_fix2",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask="fix2",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_fix1",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask="fix1",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_intramask",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask="strict",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_16k",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask="strict",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_16k_intramask",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask="strict",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_intramask_olm",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask="strict",
        merge_method="overlap",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_360M_8k_adamask",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=18,  # 16
        n_head=16,  # 16
        n_embd=1024,  # 1024
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=4096,
        n_query_groups=16,  # 16
        intradoc_mask="adaptive",
        merge_method="overlap",
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_512",
        block_size=512,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_512_intramask",
        block_size=512,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_256",
        block_size=256,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_256_intramask",
        block_size=256,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_128",
        block_size=128,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_128_intramask",
        block_size=128,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_64",
        block_size=64,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_64_intramask",
        block_size=64,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_2k",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_2k_intramask",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_1k",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_1k_intramask",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_1k_rb5",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        rope_base=100000,
    ),

    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_16k",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M_16k_intramask",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_4k",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_4k_intramask",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_32k_dm2",
        block_size=32768,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        rope_base=1000000,
        intradoc_mask='dm2',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_rb5",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        rope_base=100000,
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_2k_rb5",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        rope_base=100000,
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_nope",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        positional_embedding="no"
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_nopefix1",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        positional_embedding="no",
        intradoc_mask="fix1",
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_nopefix2",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        positional_embedding="no",
        intradoc_mask="fix2",
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_1k_nope",
        block_size=1024,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        positional_embedding="no"
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_intramask",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_dm1",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='dm1',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_dm2",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='dm2',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_dm4",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        # The mask value follows the `_dm4` suffix of the name, the same way
        # the neighbouring `_dm1` and `_dm2` entries do. It was previously
        # 'dm2', which made this entry differ from `tiny_LLaMA_120M_8k_dm2`
        # in name only.
        intradoc_mask='dm4',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_intranope",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='strict',
        positional_embedding="no"
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_fix2",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='fix2',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_fix2rerope",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='fix2rerope',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_fix1",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='fix1',
    ),
    dict(
        org="new",
        name="tiny_LLaMA_120M_8k_fix1rerope",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
        intradoc_mask='fix1rerope',
    ),
    dict(
        org="StatNLP-research",
        name="code_tiny_LLaMA_1b_8k",
        block_size=8192,
        vocab_size=49152,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        condense_ratio=4
    ),
    dict(
        org="StatNLP-research",
        name="coder_tinyllama_1b_8k",
        block_size=8192,
        vocab_size=49152,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        rope_base=1000000,
    ),
    dict(
        org="StatNLP-research",
        name="coder_tinyllama_1b_32k",
        block_size=32768,
        vocab_size=49152,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        rope_base=1000000,
    ),
    dict(
        org="StatNLP-research",
        name="code_tiny_LLaMA_1b_32k",
        block_size=32768,
        vocab_size=49152,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        condense_ratio=4
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_4k",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_512",
        block_size=512,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_512_intramask",
        block_size=512,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_256",
        block_size=256,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_256_intramask",
        block_size=256,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_128",
        block_size=128,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_128_intramask",
        block_size=128,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_64",
        block_size=64,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_64_intramask",
        block_size=64,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k_fix1",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask='fix1'
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k_fix2",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask='fix2'
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k_dm1",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask='dm1'
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_16k",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_32k",
        block_size=32768,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        rope_base=1000000,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_32k_intramask",
        block_size=32768,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        rope_base=1000000,
        intradoc_mask='strict',
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k_intramask",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="strict"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_4k_intramask",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="strict"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_2k_intramask",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="strict"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k_matchmask",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="match"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_16k_matchmask",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="match"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_16k_intramask",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="strict"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k_intramask_olm",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="strict",
        merge_method="overlap"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_8k_adamask",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="adaptive",
        merge_method="overlap"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b_16k_adamask",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        intradoc_mask="adaptive",
        merge_method="overlap"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_3b_8k",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=26,
        n_head=32,
        n_embd=3200,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=8640,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_3b_2k",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=26,
        n_head=32,
        n_embd=3200,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=8640,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_3b_8k_intramask",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=26,
        n_head=32,
        n_embd=3200,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=8640,
        n_query_groups=4,
        intradoc_mask="strict"
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_3b_2k_intramask",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=26,
        n_head=32,
        n_embd=3200,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=8640,
        n_query_groups=4,
        intradoc_mask="strict"
    ),
]

def _llama3_2_entry(name, block_size, vocab_size, rope_base, **extra):
    """Return one config dict for the ``llama3.2*_3b_*`` presets.

    Every preset in this family uses the same layer/head/width settings
    spelled out below; only ``name``, ``block_size``, ``vocab_size`` and
    ``rope_base`` vary. Keys passed through ``extra`` (here only
    ``intradoc_mask``) are placed after ``rope_base``.
    """
    return dict(
        name=name,
        block_size=block_size,
        vocab_size=vocab_size,
        n_layer=28,
        n_embd=3072,
        n_head=24,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=8192,
        rope_base=rope_base,
        **extra,
    )


# Each spec is (name stem, block_size, vocab_size, rope_base).
_LLAMA3_2_PLAIN_8K = ("llama3.2_3b_8k", 8192, 32000, 100000)
_LLAMA3_2_CODER_8K = ("llama3.2coder_3b_8k", 8192, 49152, 100000)
_LLAMA3_2_PLAIN_16K = ("llama3.2_3b_16k", 16384, 32000, 500000)
_LLAMA3_2_PLAIN_32K = ("llama3.2_3b_32k", 32768, 32000, 1000000)

# Presets without an `intradoc_mask` key.
llama3_2 = [
    _llama3_2_entry(*_spec)
    for _spec in (
        _LLAMA3_2_PLAIN_8K,
        _LLAMA3_2_CODER_8K,
        _LLAMA3_2_PLAIN_16K,
        _LLAMA3_2_PLAIN_32K,
    )
]

# One extra preset per (mask suffix, spec) pair, named "<stem>_<suffix>".
# The suffix is used verbatim as `intradoc_mask`, except that the
# 'intramask' suffix maps to the mask value 'strict'.
# Within each suffix the coder preset is emitted first, unlike the list above.
for dm_mask in ['intradm1', 'intradm2', 'intradm4', 'dm1', 'dm2', 'dm4', 'dm8', 'intradm8', 'intramask']:
    llama3_2 += [
        _llama3_2_entry(
            _stem + "_" + dm_mask,
            _block_size,
            _vocab_size,
            _rope_base,
            intradoc_mask='strict' if dm_mask == 'intramask' else dm_mask,
        )
        for _stem, _block_size, _vocab_size, _rope_base in (
            _LLAMA3_2_CODER_8K,
            _LLAMA3_2_PLAIN_8K,
            _LLAMA3_2_PLAIN_16K,
            _LLAMA3_2_PLAIN_32K,
        )
    ]
configs.extend(llama3_2)

# Generate the TinyLlama preset variants, one family per mask tag.  Each tag is
# appended to the preset name and stored verbatim in `intradoc_mask`.
# Every tag must be listed exactly once: "sc4" used to appear twice, which
# appended all of its presets to `configs` a second time under identical names.
for dm_mask in ['intradm1', 'intradm2', 'intradm4', 'dm1', 'dm2', 'dm4', 'fix2', "sc4",
                "dm8", "intradm8",
                'dm1st4', 'dm1st8', 'dm1st16', 'dm1st64', 'dm1st128', 'dm1st256', 'dm1st512',
                'dm2st4', 'dm2st8', 'dm2st16', 'dm2st64', 'dm2st128', 'dm2st256', 'dm2st512',
                'dm4st4', 'dm4st8', 'dm4st16', 'dm4st64', 'dm4st128', 'dm4st256', 'dm4st512',
                "exp2inc1024", "exp2", "sin2inc1024", "sin2", "dm2inc1024",
                "exp8inc1024", "exp8", "sin8inc1024", "sin8", "dm8inc1024",
                "cos8", "log8", "dm32", "dm5", "dm6", "dm7", "dm3", "inv8", "lin95p", 'lin90p', 'lin80p'
                ]:
    # 1b model with a 32k context (32000-token vocabulary).
    tiny_LLaMA.append(
        dict(
            org="StatNLP-research",
            name=f"tiny_LLaMA_1b_32k_{dm_mask}",
            block_size=32768,
            vocab_size=32000,
            padding_multiple=64,
            n_layer=22,
            n_head=32,
            n_embd=2048,
            rotary_percentage=1.0,
            parallel_residual=False,
            bias=False,
            _norm_class="FusedRMSNorm",
            norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
            _mlp_class="LLaMAMLP",
            intermediate_size=5632,
            n_query_groups=4,
            intradoc_mask=dm_mask,
            rope_base=1000000,
        ))
    # Same 1b/32k shape with the 49152-token vocabulary ("code_" presets).
    tiny_LLaMA.append(
        dict(
            org="StatNLP-research",
            name=f"code_tiny_LLaMA_1b_32k_{dm_mask}",
            block_size=32768,
            vocab_size=49152,
            padding_multiple=64,
            n_layer=22,
            n_head=32,
            n_embd=2048,
            rotary_percentage=1.0,
            parallel_residual=False,
            bias=False,
            _norm_class="FusedRMSNorm",
            norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
            _mlp_class="LLaMAMLP",
            intermediate_size=5632,
            n_query_groups=4,
            intradoc_mask=dm_mask,
            rope_base=1000000,
        )
    )
    # Shorter-context variants: each context length gets the 120M, 360M, 1b,
    # 3b and code-1b shapes, appended in that order.
    for context_length_str, context_length in [('1k', 1024), ('4k', 4096), ('8k', 8192), ('2k', 2048), ]:
        tiny_LLaMA.append(
            dict(
                org="new",
                name=f"tiny_LLaMA_120M_{context_length_str}_{dm_mask}",
                block_size=context_length,
                vocab_size=32000,
                padding_multiple=64,
                n_layer=12,
                n_head=12,
                n_embd=768,
                rotary_percentage=1.0,
                parallel_residual=False,
                bias=False,
                _norm_class="FusedRMSNorm",
                norm_eps=1e-5,
                _mlp_class="LLaMAMLP",
                intermediate_size=2048,
                n_query_groups=1,
                intradoc_mask=dm_mask,
            ))
        tiny_LLaMA.append(dict(
            org="StatNLP-research",
            name=f"tiny_LLaMA_360M_{context_length_str}_{dm_mask}",
            block_size=context_length,
            vocab_size=32000,
            padding_multiple=64,
            n_layer=18,  # 16
            n_head=16,  # 16
            n_embd=1024,  # 1024
            rotary_percentage=1.0,
            parallel_residual=False,
            bias=False,
            _norm_class="FusedRMSNorm",
            norm_eps=1e-5,
            _mlp_class="LLaMAMLP",
            intermediate_size=4096,
            n_query_groups=16,  # 16
            intradoc_mask=dm_mask,
        ))
        tiny_LLaMA.append(
            dict(
                org="StatNLP-research",
                name=f"tiny_LLaMA_1b_{context_length_str}_{dm_mask}",
                block_size=context_length,
                vocab_size=32000,
                padding_multiple=64,
                n_layer=22,
                n_head=32,
                n_embd=2048,
                rotary_percentage=1.0,
                parallel_residual=False,
                bias=False,
                _norm_class="FusedRMSNorm",
                norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
                _mlp_class="LLaMAMLP",
                intermediate_size=5632,
                n_query_groups=4,
                intradoc_mask=dm_mask,
            ))
        tiny_LLaMA.append(dict(
            org="StatNLP-research",
            name=f"tiny_LLaMA_3b_{context_length_str}_{dm_mask}",
            block_size=context_length,
            vocab_size=32000,
            padding_multiple=64,
            n_layer=26,
            n_head=32,
            n_embd=3200,
            rotary_percentage=1.0,
            parallel_residual=False,
            bias=False,
            _norm_class="FusedRMSNorm",
            norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
            _mlp_class="LLaMAMLP",
            intermediate_size=8640,
            n_query_groups=4,
            intradoc_mask=dm_mask,
        ))
        tiny_LLaMA.append(
            dict(
                org="StatNLP-research",
                name=f"code_tiny_LLaMA_1b_{context_length_str}_{dm_mask}",
                block_size=context_length,
                vocab_size=49152,
                padding_multiple=64,
                n_layer=22,
                n_head=32,
                n_embd=2048,
                rotary_percentage=1.0,
                parallel_residual=False,
                bias=False,
                _norm_class="FusedRMSNorm",
                norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
                _mlp_class="LLaMAMLP",
                intermediate_size=5632,
                n_query_groups=4,
                intradoc_mask=dm_mask,
                rope_base=1000000,
            )
        )

# Register every TinyLlama preset (including the variants built above) in `configs`.
configs.extend(tiny_LLaMA)

#############################
# OpenLM Research Open LLaMA
#############################
# The OpenLLaMA presets share every field except the name and the four size
# parameters, so they are generated from a small table.
open_LLaMA = [
    {
        "org": "openlm-research",
        "name": preset_name,
        "block_size": 2048,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": n_layer,
        "n_head": n_head,
        "n_embd": n_embd,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-6,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": intermediate_size,
    }
    # (name, n_layer, n_head, n_embd, intermediate_size)
    for preset_name, n_layer, n_head, n_embd, intermediate_size in (
        # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
        ("open_llama_3b", 26, 32, 3200, 8640),
        # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
        ("open_llama_7b", 32, 32, 4096, 11008),
        # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
        ("open_llama_13b", 40, 40, 5120, 13824),
    )
]
# Register the OpenLLaMA presets in `configs`, which `name_to_config` is built from.
configs.extend(open_LLaMA)

###############
# LMSYS Vicuna
###############
# Vicuna presets, generated from a table of the fields that differ between
# them.  The last column carries fields only some entries have (the 16k
# variants add `condense_ratio`); it is merged in after the common fields.
vicuna = [
    {
        "org": "lmsys",
        "name": preset_name,
        "block_size": block_size,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": n_layer,
        "n_head": n_head,
        "n_embd": n_embd,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": norm_eps,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": intermediate_size,
        **extra_fields,
    }
    # (name, block_size, n_layer, n_head, n_embd, norm_eps, intermediate_size, extra fields)
    for preset_name, block_size, n_layer, n_head, n_embd, norm_eps, intermediate_size, extra_fields in (
        # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
        ("vicuna-7b-v1.3", 2048, 32, 32, 4096, 1e-6, 11008, {}),
        # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
        ("vicuna-13b-v1.3", 2048, 40, 40, 5120, 1e-6, 13824, {}),
        # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
        ("vicuna-33b-v1.3", 2048, 60, 52, 6656, 1e-6, 17920, {}),
        ("vicuna-7b-v1.5", 4096, 32, 32, 4096, 1e-5, 11008, {}),
        ("vicuna-7b-v1.5-16k", 16384, 32, 32, 4096, 1e-5, 11008, {"condense_ratio": 4}),
        ("vicuna-13b-v1.5", 4096, 40, 40, 5120, 1e-5, 13824, {}),
        ("vicuna-13b-v1.5-16k", 16384, 40, 40, 5120, 1e-5, 13824, {"condense_ratio": 4}),
    )
]
# Register the Vicuna presets in `configs`, which `name_to_config` is built from.
configs.extend(vicuna)

#################
# LMSYS LongChat
#################
# The two LongChat presets differ only in name and size parameters; both use a
# 16384-token block size and `condense_ratio=8`.
long_chat = [
    {
        "org": "lmsys",
        "name": preset_name,
        "block_size": 16384,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": n_layer,
        "n_head": n_head,
        "n_embd": n_embd,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-6,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": intermediate_size,
        "condense_ratio": 8,
    }
    # (name, n_layer, n_head, n_embd, intermediate_size)
    for preset_name, n_layer, n_head, n_embd, intermediate_size in (
        # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
        ("longchat-7b-16k", 32, 32, 4096, 11008),
        # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
        ("longchat-13b-16k", 40, 40, 5120, 13824),
    )
]
# Register the LongChat presets in `configs`, which `name_to_config` is built from.
configs.extend(long_chat)

######################
# NousResearch Hermes
######################
nous_research = [
    # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
    # This preset sets `padded_vocab_size` directly and gives no `vocab_size`
    # or `padding_multiple`.
    {
        "org": "NousResearch",
        "name": "Nous-Hermes-13b",
        "block_size": 2048,
        "padded_vocab_size": 32001,
        "n_layer": 40,
        "n_head": 40,
        "n_embd": 5120,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-6,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": 13824,
    },
]
# Register the Nous-Hermes preset in `configs`, which `name_to_config` is built from.
configs.extend(nous_research)

###############
# Meta LLaMA 2
###############
# Llama 2 preset templates.  A `{}` in a name is a placeholder that is filled
# in when the templates are expanded into concrete presets.
llama_2 = [
    # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
    {
        "org": "meta-llama",
        "name": "Llama-2-7b{}-hf",
        "block_size": 4096,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": 32,
        "n_head": 32,
        "n_embd": 4096,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-5,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": 11008,
    },
    # Same shape as the 7b entry above, with an 8192-token block size.
    {
        "org": "meta-llama",
        "name": "Llama-2-7b{}-hf-8k",
        "block_size": 8192,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": 32,
        "n_head": 32,
        "n_embd": 4096,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-5,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": 11008,
    },
    # 7b shape with a 32016-token vocabulary; `padded_vocab_size` is pinned to
    # the same value.  The name has no `{}` placeholder.
    {
        "org": "meta-llama",
        "name": "CodeLlama-2-7b-hf",
        "block_size": 4096,
        "vocab_size": 32016,
        "padded_vocab_size": 32016,
        "padding_multiple": 64,
        "n_layer": 32,
        "n_head": 32,
        "n_embd": 4096,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-5,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": 11008,
    },
    # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
    {
        "org": "meta-llama",
        "name": "Llama-2-13b{}-hf",
        "block_size": 4096,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": 40,
        "n_head": 40,
        "n_embd": 5120,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-5,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": 13824,
    },
    # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
    {
        "org": "meta-llama",
        "name": "Llama-2-70b{}-hf",
        "block_size": 4096,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": 80,
        "n_head": 64,
        "n_embd": 8192,
        "n_query_groups": 8,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-5,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": 28672,
    },
]
# Expand every Llama 2 template into its base ("") and "-chat" variants by
# filling the `{}` placeholder in its name.  A template whose name has no
# placeholder (e.g. "CodeLlama-2-7b-hf") formats to the same name for both
# kinds, so it is registered once instead of being appended twice.
for c in llama_2:
    for kind in ("", "-chat"):
        if kind and c["name"].format(kind) == c["name"].format(""):
            # No placeholder in the name: this pass would add an exact duplicate.
            continue
        # Shallow copy so the template in `llama_2` keeps its unformatted name.
        copy = c.copy()
        copy["name"] = c["name"].format(kind)
        configs.append(copy)

##########################
# Stability AI FreeWilly2
##########################
freewilly_2 = [
    # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
    {
        "org": "stabilityai",
        "name": "FreeWilly2",
        "block_size": 4096,
        "vocab_size": 32000,
        "padding_multiple": 64,
        "n_layer": 80,
        "n_head": 64,
        "n_embd": 8192,
        "n_query_groups": 8,
        "rotary_percentage": 1.0,
        "parallel_residual": False,
        "bias": False,
        "_norm_class": "RMSNorm",
        "norm_eps": 1e-5,
        "_mlp_class": "LLaMAMLP",
        "intermediate_size": 28672,
    },
]
# Register the FreeWilly2 preset in `configs`, which `name_to_config` is built from.
configs.extend(freewilly_2)

# Index of every registered preset dict by its "name" field.  If two presets
# share a name, the one appearing later in `configs` is the one kept.
name_to_config = {preset["name"]: preset for preset in configs}
