import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from typing import Optional, List, Any, Literal, Tuple, Union
from . import (
    register_model,
    register_model_configuration,
    register_model_architecture,
)
from transformers import (
    Cache,
    GenerationMixin,
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func

from .chain_of_linear import CoLMLinear

@register_model_configuration("CoLMTransformer")
class ModelConfig(PretrainedConfig):
    """Hyper-parameters of the CoLM transformer.

    Every constructor argument is stored on the instance under the same
    name; ``hidden_size`` is additionally set as an alias of ``dim``. The
    special token ids, ``tie_word_embeddings`` and any extra ``kwargs`` are
    forwarded to ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Size of the token vocabulary.
        dim: Model width (also exposed as ``hidden_size``).
        hidden_dim: Inner width of the feed-forward block.
        n_layers: Number of transformer layers.
        n_head: Number of attention (query) heads.
        n_kv_head: Number of key/value heads.
        dropout: Dropout probability handed to the attention modules.
        norm_eps: Epsilon used by the RMS norms.
        theta: Base of the rotary frequency schedule.
        max_position_embeddings: Maximum sequence length; the model
            precomputes rotary tables for twice this many positions.
        tie_word_embeddings: Forwarded to ``PretrainedConfig``.
        chain_setting: Number of attention heads in each chain.
        pad_token_id / bos_token_id / eos_token_id: Special token ids.
        attn_implementation: Key into ``ATTENTION_CLASSES``
            (``"eager"``, ``"flash_attn"`` or ``"sdpa"``).
        bias: Stored on the config.
        return_all_logits: Whether the LM head returns one logits tensor
            per chain instead of a single tensor.
        air: Enables the "air" key/value projection in the attention layer.
    """

    def __init__(
        self,
        vocab_size: int = 32000,
        dim: int = 4096,
        hidden_dim: int = 11008,
        n_layers: int = 32,
        n_head: int = 32,
        n_kv_head: int = 8,
        dropout: float = 0.0,
        norm_eps: float = 1e-5,
        theta: float = 500000.0,
        max_position_embeddings: int = 16384,
        tie_word_embeddings: bool = False,
        chain_setting: List[int] = [1, 1, 2, 4, 8, 16],
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        attn_implementation: str = "flash_attn",
        bias: bool = False,
        return_all_logits: bool = False,
        air: bool = False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.hidden_size = dim
        self.n_layers = n_layers
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.dropout = dropout
        self.norm_eps = norm_eps
        self.theta = theta
        self.max_position_embeddings = max_position_embeddings
        # Copy the list so that instances never alias the shared default
        # argument (or the caller's list) and cannot mutate each other.
        self.chain_setting = (
            list(chain_setting) if chain_setting is not None else None
        )
        self.attn_implementation = attn_implementation
        self.bias = bias
        self.return_all_logits = return_all_logits
        self.air = air
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


def convert_heads_to_dims(
    dim: int,
    heads: List[int] = [8, 8, 16],
    cumsum: bool = False
) -> np.ndarray:
    """Map colm heads to the corresponding dimensions (inputs / outputs).

    ``dim`` is split evenly over ``sum(heads)`` heads; each chain then
    receives ``heads[i]`` of those slices.

    Args:
        dim: Total width to distribute; must be divisible by ``sum(heads)``.
        heads: Number of heads in each chain. The list is not modified.
        cumsum: When true, return cumulative widths (the width visible up to
            and including each chain) instead of per-chain widths.

    Returns:
        Integer array with one entry per chain.

    Raises:
        AssertionError: If ``dim`` is not divisible by ``sum(heads)``.
    """
    total_heads = sum(heads)
    assert dim % total_heads == 0, f"{dim} should be divisible by the sum of {heads}."
    head_dim = dim // total_heads
    counts = np.cumsum(heads) if cumsum else heads
    return np.array(counts) * head_dim


def create_causal_mask(seq_len, device: torch.device) -> torch.Tensor:
    """Build an additive causal mask for a language model.

    Returns a ``(seq_len, seq_len)`` tensor on ``device`` that is ``-inf``
    strictly above the main diagonal and ``0`` on and below it.
    """
    neg_inf = torch.full((seq_len, seq_len), float("-inf"), device=device)
    # diagonal=1 keeps only the entries above the main diagonal; the rest
    # (diagonal and below) are zeroed.
    return torch.triu(neg_inf, diagonal=1)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile a 4-D key/value tensor ``n_rep`` times along dimension 2.

    The copies are laid out block-wise (``h1 h2 .. h1 h2 ..``), not
    interleaved. When ``n_rep`` is 1 the input tensor itself is returned.
    """
    return x if n_rep == 1 else x.repeat(1, 1, n_rep, 1)


def reshape_for_boardcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    ndim = x.ndim
    assert 0 <= 1 < ndim
    seqlen = x.shape[1]

    freqs_cis = freqs_cis[0:seqlen]
    assert freqs_cis.shape == (seqlen, x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute unit-magnitude complex rotary factors.

    Args:
        dim: Per-head feature width; ``dim // 2`` frequencies are produced.
        end: Number of positions to tabulate.
        theta: Base of the geometric frequency schedule.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[t, i]``
        is ``exp(1j * t * theta ** (-2 * i / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    # polar(abs=1, angle) yields cos(angle) + 1j * sin(angle)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the precomputed complex rotary factors.

    Adjacent pairs along the last axis are treated as (real, imag) parts of
    a complex number, multiplied by ``freqs_cis`` and flattened back. The
    broadcast shape of ``freqs_cis`` is derived from ``xq`` and reused for
    ``xk``. Outputs are cast back to the dtypes of their inputs.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., d / 2, 2) -> complex (..., d / 2), in float32
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)

    rotation = reshape_for_boardcast(freqs_cis, q_complex)

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(q_complex.ndim - 1)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(k_complex.ndim - 1)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


class ChainOps:
    """Channel-slicing helper for chained (nested-width) modules.

    ``input_dims`` holds the cumulative width visible at each chain and
    ``output_dims`` the per-chain width of ``hidden_dim`` (which defaults
    to ``dim``). Calling the object truncates the last axis of a tensor to
    the width of the requested chain.
    """

    def __init__(
        self,
        dim: int = 4096,
        hidden_dim: int = None,
        chain_setting: List[int] = [1, 1, 2, 4, 8, 16],
    ):
        self.dim = dim
        self.hidden_dim = hidden_dim or dim
        self.chain_setting = chain_setting

        self.input_dims = convert_heads_to_dims(dim, chain_setting, cumsum=True)
        self.output_dims = convert_heads_to_dims(self.hidden_dim, chain_setting)

    def __call__(self, x: torch.Tensor, head: int = None) -> torch.Tensor:
        """Return ``x`` cut to the first ``input_dims[head]`` channels.

        With ``head=None`` the tensor is returned untouched.
        """
        if head is not None:
            assert (
                0 <= head < len(self.input_dims)
            ), f"The head id should be in [0, {len(self.input_dims)})."
            width = self.input_dims[head]
            x = x[..., :width]
        return x

    def __len__(self):
        """Number of chains."""
        return len(self.input_dims)

    @property
    def base_dim(self):
        """Width of the first chain."""
        return self.input_dims[0]

class Linear(nn.Module):
    """Chained linear layer built from one ``nn.Linear`` per chain.

    Chain ``i`` reads the first ``input_dims[i]`` input channels (cumulative
    widths) and emits ``output_dims[i]`` channels; the per-chain outputs are
    concatenated along the last axis.
    """

    def __init__(
        self, 
        dim: int = 4096,
        hidden_dim: int = None,
        heads: List[int] = [1, 1, 2, 4, 8, 16],
        bias=False
    ):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim or dim
        self.chain_setting = heads

        self.input_dims = convert_heads_to_dims(self.dim, heads, cumsum=True)
        self.output_dims = convert_heads_to_dims(self.hidden_dim, heads)

        projections = []
        for in_features, out_features in zip(self.input_dims, self.output_dims):
            projections.append(nn.Linear(in_features, out_features, bias=bias))
        self.mlps = nn.ModuleList(projections)

    def forward(self, x: torch.Tensor, head: int = None) -> torch.Tensor:
        """Apply chains ``0..head`` (all chains when ``head`` is ``None``)."""
        if head is None:
            active_dims = self.input_dims
        else:
            assert (
                0 <= head < len(self.input_dims)
            ), f"The head id should be in [0, {len(self.input_dims)})."
            active_dims = self.input_dims[:head + 1]
        pieces = [
            proj(x[..., :width]) for proj, width in zip(self.mlps, active_dims)
        ]
        return torch.cat(pieces, dim=-1)

    def init_weights(self, init_std):
        """Re-draw every chain's weight from a truncated normal (mean 0)."""
        for proj in self.mlps:
            nn.init.trunc_normal_(proj.weight, mean=0.0, std=init_std)

class CoLMEmbedding(nn.Embedding):
    """Token embedding that can return only the channels of a chain prefix.

    Note the constructor takes ``dim`` before ``vocab_size``, the reverse of
    ``nn.Embedding``.
    """

    def __init__(
        self,
        dim: int,
        vocab_size: int, 
        chain_setting: List[int] = [1, 1, 2, 4, 8, 16],
        padding_idx: Optional[int] = None,
    ):
        super().__init__(vocab_size, dim, padding_idx)

        self.n_heads = len(chain_setting)
        # Cumulative embedding width available at each chain.
        self.output_dims = convert_heads_to_dims(dim, chain_setting, cumsum=True)

    def forward(self, tokens: torch.Tensor, head: int = None) -> torch.Tensor:
        """Look up ``tokens``; with ``head`` set, keep only the first
        ``output_dims[head]`` embedding channels."""
        table = self.weight
        if head is not None:
            assert (
                head < self.n_heads
            ), f"The head should be smaller than {self.n_heads}"
            table = table[:, :self.output_dims[head]]
        return F.embedding(tokens, table, self.padding_idx)


class CoLMRMSNorm(nn.Module):
    """RMS normalisation applied independently to blocks of ``base_dim`` channels.

    The input is first truncated to the requested chain width, then every
    consecutive group of ``ops.base_dim`` channels is RMS-normalised on its
    own and scaled by the matching slice of the learnable ``weight``.
    """

    # Can be optimized
    def __init__(self, dim: int, eps: float = 1e-6, chain_setting: List[int] = [1, 1, 2, 4, 8, 16]):
        super().__init__()
        self.eps = eps
        self.ops = ChainOps(dim, chain_setting=chain_setting)
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise groups of ``base_dim`` channels; returns the grouped view."""
        # Split the last axis into [n_groups, base_dim].
        grouped = x.view(*x.shape[:-1], -1, self.ops.base_dim)
        mean_square = grouped.pow(2).mean(-1, keepdim=True)
        return grouped * torch.rsqrt(mean_square + self.eps)
    
    def forward(self, x: torch.Tensor, head: int = None) -> torch.Tensor:
        """Truncate to the chain width, normalise in float32, then scale."""
        x = self.ops(x, head=head)
        normed = self._norm(x.float()).type_as(x).view_as(x)
        scale = self.weight[:x.shape[-1]]
        return normed * scale

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

    def reset_parameters(self):
        """Reset the scale to all ones."""
        torch.nn.init.ones_(self.weight)

class CoLMFFN(nn.Module):
    """Gated SiLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    All three projections are ``CoLMLinear`` layers and receive the same
    chain index.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        chain_setting: List[int] = [1, 1, 2, 4, 8, 16],
    ):
        super().__init__()

        # Creation order (w1, w2, w3) is kept so parameter registration and
        # random initialisation order are unchanged.
        self.w1 = CoLMLinear(dim, hidden_dim, heads=chain_setting)
        self.w2 = CoLMLinear(hidden_dim, dim, heads=chain_setting)
        self.w3 = CoLMLinear(dim, hidden_dim, heads=chain_setting)

    def forward(self, x: torch.Tensor, head: int = None) -> torch.Tensor:
        """Run the gated feed-forward on ``x`` for the given chain."""
        gate = F.silu(self.w1(x, head))
        up = self.w3(x, head)
        return self.w2(gate * up, head)

    def init_weights(self, init_std: float):
        """Initialise ``w1`` with std 0.02 and ``w2``/``w3`` with ``init_std``."""
        self.w1.init_weights(0.02)
        for projection in (self.w2, self.w3):
            projection.init_weights(init_std)

class CoLMAttention(nn.Module):
    """Reference ("eager") attention layer built from chained projections.

    Queries and the output projection are ``CoLMLinear`` layers. Keys and
    values are ``CoLMLinear`` layers as well, unless ``air`` is set, in
    which case they are plain ``nn.Linear`` layers that read only the first
    ``kv_dim`` input channels. Subclasses override :meth:`attention` to
    swap the attention kernel.

    Args:
        dim: Model width; ``head_dim`` is ``dim // n_head``.
        n_head: Number of query heads at full width.
        n_kv_head: Number of key/value heads.
        chain_setting: Heads per chain, forwarded to ``ChainOps`` and
            ``CoLMLinear``.
        air: Project keys/values from the first chain's channels only.
        dropout: Passed to :meth:`attention`; the eager kernel ignores it.
        layer_id: Stored on the module; not otherwise used here.
    """

    def __init__(
        self,
        dim: int,
        n_head: int,
        n_kv_head: int,
        chain_setting: List[int] = [1, 1, 2, 4, 8, 16],
        air: bool = False,
        dropout: float = 0.0,
        layer_id: Optional[int] = None,
    ):
        super().__init__()

        self.layer_id = layer_id
        self.dim = dim
        self.air = air

        self.first_head = chain_setting[0]
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.dropout = dropout
        self.ops = ChainOps(
            self.dim, chain_setting=chain_setting
        )
        self.head_dim = self.dim // self.n_head

        # Width of the input slice used for keys/values in the air setting.
        self.kv_dim = self.first_head * self.head_dim

        self.wq = CoLMLinear(
            dim=self.dim,
            heads=chain_setting,
            bias=False,
        )
        if self.air:
            # For key / value, we only use the first colm head in air setting.
            self.wk = nn.Linear(
                self.first_head * self.head_dim,
                self.n_kv_head * self.head_dim,
                bias=False,
            )
            self.wv = nn.Linear(
                self.first_head * self.head_dim,
                self.n_kv_head * self.head_dim,
                bias=False,
            )
        else:
            # Previously these assignments ran unconditionally and silently
            # replaced the air projections created above.
            self.wk = CoLMLinear(
                dim=self.n_head * self.head_dim,
                hidden_dim=self.n_kv_head * self.head_dim,
                heads=chain_setting,
                bias=False,
            )
            self.wv = CoLMLinear(
                dim=self.n_head * self.head_dim,
                hidden_dim=self.n_kv_head * self.head_dim,
                heads=chain_setting,
                bias=False,
            )
        self.wo = CoLMLinear(
            dim=self.dim,
            heads=chain_setting,
            bias=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        chain_head: int = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape ``(bsz, seqlen, width)``.

        Args:
            x: Input activations; truncated to the width of ``chain_head``.
            freqs_cis: Rotary factors, as built by ``precompute_freqs_cis``.
            chain_head: Chain index, or ``None`` for the full width.
            mask: Optional additive attention mask.

        Returns:
            Output of the ``wo`` projection.
        """
        x = self.ops(x, head=chain_head)
        bsz, seqlen, dim = x.shape
        # Number of query heads present at the (possibly truncated) width.
        n_q_head = dim // self.head_dim

        xq = self.wq(x, chain_head)

        if self.air:
            x0 = x[..., :self.kv_dim]
            xk, xv = self.wk(x0), self.wv(x0)
        else:
            # NOTE(review): chain_head is not forwarded to wk / wv here;
            # confirm this is intended given CoLMLinear's semantics.
            xk, xv = self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, n_q_head, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        # Group Query Attention is:
        #   Q1 Q2 Q3 Q4 Q5 Q6 Q7 Q8
        #   K1 K1 K2 K2 K3 K3 K4 K4
        #
        # colm Attention is:
        #   Q1 Q2 Q3 Q4 Q5 Q6 Q7 Q8
        #   K1 K2 K3 K4 K1 K2 K3 K4
        if self.air:
            n_rep = xq.shape[2] // xk.shape[2]
            xk = repeat_kv(xk, n_rep)
            xv = repeat_kv(xv, n_rep)
        # [bsz, seqlen, n_head, head_dim]
        output = self.attention(
            xq, xk, xv, mask=mask, dropout=self.dropout, causal=True,
        )
        # [bsz, seqlen, dim]
        output = output.view(bsz, seqlen, -1)
        return self.wo(output, chain_head)

    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout: float = 0.0,
        causal: bool = True,
        mask: Optional[torch.Tensor] = None,
    ):
        """Scaled dot-product attention via explicit matmul and softmax.

        Inputs and output are laid out ``[bsz, seq, n_head, head_dim]``.
        Key/value heads are repeated (interleaved) to match the number of
        query heads. If ``mask`` is ``None`` and ``causal`` is true, a
        causal mask is created. ``dropout`` is accepted but not applied.
        """
        n_rep = q.shape[2] // k.shape[2]
        k = torch.repeat_interleave(k, repeats=n_rep, dim=2)
        v = torch.repeat_interleave(v, repeats=n_rep, dim=2)
        # [bsz, seq, n_head, head_dim] -> [bsz, n_head, seq, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if mask is None and causal:
            mask = create_causal_mask(q.shape[2], q.device)
        scores = torch.matmul(q, k.transpose(2, 3)) / (k.shape[-1] ** 0.5)
        if mask is not None:
            scores += mask
        # Softmax in float32 for numerical stability, then back to q's dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(q)
        output = torch.matmul(scores, v)
        output = output.transpose(1, 2).contiguous()
        return output

    def init_weights(self, init_std: float):
        """Initialise wq/wk/wv with std 0.02 and ``wo`` with ``init_std``.

        ``wk`` and ``wv`` are initialised through their ``weight`` attribute.
        """
        self.wq.init_weights(0.02)
        nn.init.normal_(self.wk.weight, std=0.02)
        nn.init.normal_(self.wv.weight, std=0.02)
        self.wo.init_weights(init_std=init_std)

class CoLMFlashAttention(CoLMAttention):
    """Attention layer whose kernel is ``flash_attn_func``.

    The ``air`` argument is accepted for signature parity but the parent is
    always constructed with ``air=False``.
    """

    def __init__(
        self,
        dim: int,
        n_head: int,
        n_kv_head: int,
        chain_setting: List[int] = [1, 1, 2, 4, 8, 16],
        air: bool = False,
        dropout: float = 0.0,
        layer_id: Optional[int] = None,
    ):
        parent_kwargs = dict(
            dim=dim,
            n_head=n_head,
            n_kv_head=n_kv_head,
            chain_setting=chain_setting,
            air=False,
            dropout=dropout,
            layer_id=layer_id,
        )
        super().__init__(**parent_kwargs)

    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout: float = 0.0,
        causal: bool = True,
        mask: Optional[torch.Tensor] = None,
    ):
        """Run Flash Attention 2 on ``[bsz, seq, n_head, head_dim]`` inputs.

        ``mask`` is not used; causality is controlled by ``causal`` only.

        Raises:
            NotImplementedError: If Flash Attention 2 is unavailable.
        """
        flash_available = is_flash_attn_2_available()
        if flash_available is not True:
            raise NotImplementedError("Flash Attention 2 is not installed.")

        # Layout stays [bsz, seq, n_head, head_dim] in and out.
        return flash_attn_func(q, k, v, dropout_p=dropout, causal=causal)
    

class CoLMSdpaAttention(CoLMAttention):
    """Attention layer whose kernel is ``F.scaled_dot_product_attention``.

    The ``air`` argument is accepted for signature parity but the parent is
    always constructed with ``air=False``.
    """

    def __init__(
        self,
        dim: int,
        n_head: int,
        n_kv_head: int,
        chain_setting: List[int] = [1, 1, 2, 4, 8, 16],
        air: bool = False,
        dropout: float = 0.0,
        layer_id: Optional[int] = None,
    ):
        parent_kwargs = dict(
            dim=dim,
            n_head=n_head,
            n_kv_head=n_kv_head,
            chain_setting=chain_setting,
            air=False,
            dropout=dropout,
            layer_id=layer_id,
        )
        super().__init__(**parent_kwargs)
    
    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout: float = 0.0,
        causal: bool = True,
        mask: Optional[torch.Tensor] = None,
    ):
        """Run PyTorch SDPA with grouped-query support enabled.

        Inputs and output are ``[bsz, seq, n_head, head_dim]``. ``mask`` is
        not used; causality is controlled by ``causal`` only.
        """
        # [bsz, seq, n_head, head_dim] -> [bsz, n_head, seq, head_dim]
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        attended = F.scaled_dot_product_attention(
            q, k, v, dropout_p=dropout, is_causal=causal, enable_gqa=True
        )
        # Back to [bsz, seq, n_head, head_dim]
        return attended.transpose(1, 2).contiguous()


# Lookup table from ``ModelConfig.attn_implementation`` to the attention
# module class that ``CoLMTransformerLayer`` instantiates.
ATTENTION_CLASSES = {
    "eager": CoLMAttention,
    "flash_attn": CoLMFlashAttention,
    "sdpa": CoLMSdpaAttention,
}


class CoLMTransformerLayer(nn.Module):
    """One pre-norm transformer layer: attention and FFN, each residual.

    Args:
        args: Model configuration.
        layer_id: Zero-based depth of this layer. It scales the init std of
            the output projections as ``0.02 / sqrt(2 * (layer_id + 1))``.
            When ``None``, the total depth ``args.n_layers`` is used in
            place of ``layer_id + 1``.
    """

    def __init__(self, args: ModelConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.attention = ATTENTION_CLASSES[args.attn_implementation](
            dim=args.dim,
            n_head=args.n_head,
            n_kv_head=args.n_kv_head,
            chain_setting=args.chain_setting,
            air=args.air,
            dropout=args.dropout,
            layer_id=layer_id,
        )
        self.ffn = CoLMFFN(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            chain_setting=args.chain_setting,
        )
        self.attention_norm = CoLMRMSNorm(
            dim=args.dim,
            eps=args.norm_eps,
            chain_setting=args.chain_setting
        )
        self.ffn_norm = CoLMRMSNorm(
            dim=args.dim,
            eps=args.norm_eps,
            chain_setting=args.chain_setting
        )
        self.ops = ChainOps(args.dim, chain_setting=args.chain_setting)

        if self.layer_id is None:
            # ``layer_id`` defaults to None; ``None + 1`` used to raise a
            # TypeError here. Fall back to scaling by the total depth.
            self.weight_init_std = 0.02 / (2 * args.n_layers) ** 0.5
        else:
            self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        head: int = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the layer to ``x`` truncated to the width of chain ``head``.

        Args:
            x: Input activations.
            freqs_cis: Rotary factors passed through to the attention.
            head: Chain index, or ``None`` for the full width.
            mask: Optional additive attention mask.
        """
        x = self.ops(x, head=head)
        h = x + self.attention(
            self.attention_norm(x, head=head),
            freqs_cis=freqs_cis,
            chain_head=head,
            mask=mask,
        )
        out = h + self.ffn(
            self.ffn_norm(h, head=head),
            head=head
        )
        return out

    def init_weights(self):
        """Reset both norms and initialise attention and FFN weights."""
        for norm in (self.attention_norm, self.ffn_norm):
            norm.reset_parameters()
        self.attention.init_weights(self.weight_init_std)
        self.ffn.init_weights(self.weight_init_std)

class CoLMLMHead(nn.Module):
    """Output projection that can emit logits per chain prefix.

    Args:
        dim: Width of the incoming features.
        vocab_size: Number of output classes.
        chain_setting: Heads per chain, forwarded to ``ChainOps``.
        weight: Optional projection weight. When omitted, a fresh weight is
            taken from a bias-free ``nn.Linear(dim, vocab_size)``.
    """

    def __init__(
        self,
        dim: int = 4096,
        vocab_size: int = None,
        chain_setting: List[int] = [1, 1, 2, 4, 8, 16],
        weight: torch.nn.Parameter = None,
    ):
        super().__init__()
        self.ops = ChainOps(dim, chain_setting=chain_setting)
        self.dim = dim
        self.vocab_size = vocab_size
        if weight is None:
            weight = nn.Linear(dim, vocab_size, bias=False).weight
        self.weight = weight

    def forward(self, x: torch.Tensor, head: int = None, return_all_logits: bool = True) -> Any:
        """Project features to vocabulary logits.

        Args:
            x: Features; truncated to the width of chain ``head``.
            head: Chain index, or ``None`` for the full width.
            return_all_logits: When false, return a single logits tensor for
                the selected width. When true, return a list with one
                logits tensor per chain ``0..head`` (all chains when
                ``head`` is ``None``), where entry ``i`` uses the channels
                of chains ``0..i``.

        Returns:
            A tensor, or a list of tensors when ``return_all_logits`` is true.
        """
        x = self.ops(x, head=head)

        if not return_all_logits:
            if head is None:
                return F.linear(x, self.weight)
            dim = self.ops.input_dims[head]
            return F.linear(x, self.weight[..., :dim])

        # Can be optimized
        # The accumulation loop used to sit inside the ``head is not None``
        # branch, so ``head=None`` fell through and returned None.
        n_chains = len(self.ops) if head is None else head + 1

        outputs, input_offset, logits = [], 0, 0.0
        for indim in self.ops.output_dims[:n_chains]:
            # Each chain contributes the partial product of its own channel
            # slice; the running sum is the logits of the chain prefix.
            logits = logits + F.linear(
                x[..., input_offset:input_offset + indim],
                self.weight[..., input_offset:input_offset + indim]
            )
            input_offset += indim
            outputs.append(logits)
        return outputs

    def __str__(self):
        return f"CoLMLMHead({self.extra_repr()})"
    
    def extra_repr(self):
        return f"dim={self.dim}, vocab_size={self.vocab_size}"


class CoLMTransformerPretrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the CoLM transformer classes."""

    config_class = ModelConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Default init for ``nn.Linear`` and ``nn.Embedding`` submodules.

        Weights are drawn from N(0, 0.02). Linear biases are zeroed, as is
        the embedding row at ``padding_idx`` when one is set. Other module
        types are left untouched.
        """
        std = 0.02
        is_linear = isinstance(module, nn.Linear)
        if not (is_linear or isinstance(module, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if is_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()


class CoLMTransformerModel(CoLMTransformerPretrainedModel):
    """Decoder stack: token embedding, transformer layers and final norm."""

    def __init__(self, args: ModelConfig):
        super().__init__(args)

        self.args = args
        self.n_layers = args.n_layers
        self.vocab_size = args.vocab_size

        self.tok_embeddings = CoLMEmbedding(
            args.dim, args.vocab_size, args.chain_setting,
        )

        # Layers are keyed by their stringified index and inserted in order.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(args.n_layers):
            self.layers[str(layer_id)] = CoLMTransformerLayer(args, layer_id)
        
        self.norm = CoLMRMSNorm(
            args.dim, args.norm_eps, args.chain_setting,
        )

        # Rotary table covers twice the configured maximum position count.
        rotary_table = precompute_freqs_cis(
            args.dim // args.n_head, args.max_position_embeddings * 2, args.theta
        )
        self.register_buffer("freqs_cis", rotary_table, persistent=True)

        self.post_init()

    def get_input_embeddings(self):
        """Return the embedding weight (the parameter, not the module)."""
        return self.tok_embeddings.weight

    def set_input_embeddings(self, weight: nn.Parameter):
        """Replace the embedding weight."""
        self.tok_embeddings.weight = weight

    def forward(
        self,
        input_ids: torch.Tensor,
        head: int = None,
        **kwargs
    ):
        """Encode ``input_ids`` and return the final normalised features.

        Args:
            input_ids: Token ids; the last axis is the sequence.
            head: Chain index passed to every layer and the final norm, or
                ``None`` for the full width.
            **kwargs: Accepted and ignored.
        """
        seqlen = input_ids.shape[-1]
        if self.tok_embeddings:
            h = self.tok_embeddings(input_ids)
        else:
            h = input_ids

        # Keep the rotary table on the same device as the activations.
        self.freqs_cis = self.freqs_cis.to(h.device)

        causal_mask = create_causal_mask(seqlen, device=input_ids.device)

        for block in self.layers.values():
            h = block(h, freqs_cis=self.freqs_cis, mask=causal_mask, head=head)

        if self.norm:
            h = self.norm(h, head=head)
        return h


@register_model("CoLMTransformer")
class CoLMTransformer(CoLMTransformerPretrainedModel, GenerationMixin):
    """CoLM decoder with language-model head.

    Wraps ``CoLMTransformerModel`` (stored as ``self.model``) and a
    ``CoLMLMHead``. Construction validates that the chain head counts sum
    to ``n_head`` and that each chain maps to a whole number of KV heads.
    """

    def __init__(self, args: ModelConfig):
        super().__init__(args)
        self.args = args
        assert (
            sum(args.chain_setting) == args.n_head
        ), f"The sum of colm head ({args.chain_setting}) should be equal to the attention heads ({args.n_head})."
        for head in args.chain_setting:
            assert (
                head * args.n_kv_head % args.n_head == 0
            ), f"The kv of {head} head is not disvisible"

        self.vocab_size = args.vocab_size
        self.model = CoLMTransformerModel(args)
        self.return_all_logits = args.return_all_logits

        self.lm_head = CoLMLMHead(
            args.dim, args.vocab_size, args.chain_setting
        )
        self.post_init()
        self.init_weights()

    def get_input_embeddings(self):
        """Return the token-embedding weight of the wrapped decoder."""
        # The embedding lives on ``self.model``; this class has no
        # ``tok_embeddings`` attribute of its own.
        return self.model.tok_embeddings.weight

    def set_input_embeddings(self, weight: nn.Parameter):
        """Replace the token-embedding weight of the wrapped decoder."""
        self.model.tok_embeddings.weight = weight

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.lm_head

    def set_output_embeddings(self, weight: nn.Parameter):
        """Replace the LM head weight."""
        self.lm_head.weight = weight

    def set_decoder(self, decoder: nn.Module):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @classmethod
    def build_model(cls, args: ModelConfig):
        """Alternate constructor used by the model registry."""
        return cls(args)

    def init_weights(
        self,
        buffer_device: Optional[torch.device] = None,
    ):
        """
        [Note: On ``init_weights`` vs. ``reset_parameters``]
        Modules may define ``reset_parameters`` to initialize parameter values.
        ``reset_parameters`` is meant to only initialize directly owned
        parameters/buffers, not those of their child modules, and it can be
        used to give the initial values for these tensors.
        Separately, users may want custom initialization for their modules,
        different from that in ``reset_parameters``. For this, we define
        ``init_weights``. We only call it in the constructor of this
        ``Transformer`` root module to avoid reinitializing tensors.

        Args:
            buffer_device: Device on which the rotary table is rebuilt;
                defaults to the current device of ``model.freqs_cis``.
        """
        buffer_device = buffer_device or self.model.freqs_cis.device
        with torch.device(buffer_device):
            self.model.freqs_cis = precompute_freqs_cis(
                self.args.dim // self.args.n_head,
                self.args.max_position_embeddings * 2,
                self.args.theta,
            )
        if self.model.tok_embeddings is not None:
            nn.init.normal_(self.model.tok_embeddings.weight)
        for layer in self.model.layers.values():
            if layer is not None:
                layer.init_weights()
        if self.model.norm is not None:
            self.model.norm.reset_parameters()
        # LM head: truncated normal with std dim**-0.5, cut at 3 std.
        final_out_std = self.args.dim**-0.5
        cutoff_factor = 3
        if self.lm_head is not None:
            nn.init.trunc_normal_(
                self.lm_head.weight,
                mean=0.0,
                std=final_out_std,
                a=-cutoff_factor * final_out_std,
                b=cutoff_factor * final_out_std,
            )

    def forward(
        self,
        tokens: torch.Tensor,
        head: int = None,
        return_all_logits: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Compute logits for ``tokens``.

        Args:
            tokens: Input token ids.
            head: Chain index, or ``None`` for the full width.
            return_all_logits: Accepted for interface compatibility but
                ignored; ``self.return_all_logits`` (from the config)
                decides what the LM head returns.
            **kwargs: Accepted and ignored.

        Returns:
            The LM head output, or the raw features if there is no LM head.
        """
        features = self.model(tokens, head=head)
        if not self.lm_head:
            return features
        output = self.lm_head(features, head=head, return_all_logits=self.return_all_logits)
        return output
        

@register_model_architecture("CoLMTransformer", "CoLMTransformer")
def base_architecture(args: ModelConfig):
    """Base preset: return the configuration unchanged."""
    return args


@register_model_architecture("CoLMTransformer", "gpt2_xxl")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``gpt2_xxl``: a single chain of 25 heads, 48 layers, width 1600."""
    args.dim = 1600
    args.hidden_dim = 6400
    args.n_layers = 48
    args.n_head, args.n_kv_head = 25, 25
    args.chain_setting = [25]
    return args

@register_model_architecture("CoLMTransformer", "llama3.2-1B")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``llama3.2-1B``: a single chain of 32 heads, 16 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [32]
    return args

@register_model_architecture("CoLMTransformer", "colm_16-16")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``colm_16-16``: chains of 16+16 heads, 16 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [16, 16]
    return args

@register_model_architecture("CoLMTransformer", "colm_8-24")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``colm_8-24``: chains of 8+24 heads, 16 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 24]
    return args

@register_model_architecture("CoLMTransformer", "colm_8-8-16")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``colm_8-8-16``: chains of 8+8+16 heads, 16 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 8, 16]
    return args

@register_model_architecture("CoLMTransformer", "colm_8-8-8-8")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``colm_8-8-8-8``: four chains of 8 heads, 16 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 8, 8, 8]
    return args

@register_model_architecture("CoLMTransformer", "colm_8-8-8-8_all_heads")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``colm_8-8-8-8_all_heads``: four chains of 8 heads, 16 layers,
    width 2048, with ``return_all_logits`` switched on."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 8, 8, 8]
    args.return_all_logits = True
    return args

@register_model_architecture("CoLMTransformer", "colm_16-16_2560")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``colm_16-16_2560``: chains of 16+16 heads, 16 layers, width 2560."""
    args.dim = 2560
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [16, 16]
    return args

@register_model_architecture("CoLMTransformer", "same-colm_16-16-filter_size")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``same-colm_16-16-filter_size``: chains of 16+16 heads,
    16 layers, width 2048, feed-forward width 12288."""
    args.dim = 2048
    args.hidden_dim = 12288
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [16, 16]
    return args

@register_model_architecture("CoLMTransformer", "colm_8-8-8-8_20L")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``colm_8-8-8-8_20L``: four chains of 8 heads, 20 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 20
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 8, 8, 8]
    return args

@register_model_architecture("CoLMTransformer", "same-colm_8-8-8-8-filter_size")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``same-colm_8-8-8-8-filter_size``: four chains of 8 heads,
    16 layers, width 2048, feed-forward width 14336."""
    args.dim = 2048
    args.hidden_dim = 14336
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 8, 8, 8]
    return args

@register_model_architecture("CoLMTransformer", "colm_16-16_20L")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``colm_16-16_20L``: chains of 16+16 heads, 20 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 20
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [16, 16]
    return args

@register_model_architecture("CoLMTransformer", "same-colm_8-24")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``same-colm_8-24``: chains of 8+24 heads, 21 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 21
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 24]
    return args

@register_model_architecture("CoLMTransformer", "same-colm_8-8-16")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``same-colm_8-8-16``: chains of 8+8+16 heads, 24 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 24
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 8, 16]
    return args

@register_model_architecture("CoLMTransformer", "same-colm_8-8-8-8")
def gpt2_xxl_architecture(args: ModelConfig):
    """Preset ``same-colm_8-8-8-8``: four chains of 8 heads, 26 layers, width 2048."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 26
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 8, 8, 8]
    return args

@register_model_architecture("CoLMTransformer", "colm_32H_16-16")
def colm_transformer_32H_16x16(args: ModelConfig):
    """Preset ``colm_32H_16-16``: chains of 16+16 heads, 16 layers,
    width 2048, 16 key/value heads."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 16
    args.chain_setting = [16, 16]
    return args


@register_model_architecture("CoLMTransformer", "colm_32H_8-8-8-8")
def colm_transformer_32H_8x8x8x8(args: ModelConfig):
    """Preset ``colm_32H_8-8-8-8``: four chains of 8 heads, 16 layers,
    width 2048, 16 key/value heads."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head, args.n_kv_head = 32, 16
    args.chain_setting = [8, 8, 8, 8]
    return args


@register_model_architecture("CoLMTransformer", "colm_32H_8-8-16")
def colm_transformer_32H_16x16(args: ModelConfig):
    """Preset ``colm_32H_8-8-16``: chains of 8+8+16 heads, 16 layers,
    width 2048, 16 key/value heads."""
    # Was [8, 8, 8, 8], which made this preset identical to
    # "colm_32H_8-8-8-8" and contradicted its registered name.
    args.chain_setting = [8, 8, 16]
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 16
    args.n_head = 32
    args.n_kv_head = 16
    return args


@register_model_architecture("CoLMTransformer", "colm_40L_32H_16-16")
def colm_transformer_1600M_standard(args: ModelConfig):
    """Preset ``colm_40L_32H_16-16``: chains of 16+16 heads, 40 layers,
    width 2048, 16 key/value heads."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 40
    args.n_head, args.n_kv_head = 32, 16
    args.chain_setting = [16, 16]
    return args

@register_model_architecture("CoLMTransformer", "colm_40L_32H_16-16_air")
def colm_transformer_1600M_standard(args: ModelConfig):
    """Preset ``colm_40L_32H_16-16_air``: chains of 16+16 heads, 40 layers,
    width 2048, 16 key/value heads, with ``air`` switched on."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 40
    args.n_head, args.n_kv_head = 32, 16
    args.chain_setting = [16, 16]
    args.air = True
    return args

@register_model_architecture("CoLMTransformer", "colm_48L_32H_8-8-8-8")
def colm_transformer_1600M_standard(args: ModelConfig):
    """Preset ``colm_48L_32H_8-8-8-8``: four chains of 8 heads, 48 layers,
    width 2048, 8 key/value heads."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 48
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 8, 8, 8]
    return args


@register_model_architecture("CoLMTransformer", "colm_40L_32H_8-24")
def colm_transformer_1600M_standard(args: ModelConfig):
    """Preset ``colm_40L_32H_8-24``: chains of 8+24 heads, 40 layers,
    width 2048, 8 key/value heads."""
    args.dim = 2048
    args.hidden_dim = 8192
    args.n_layers = 40
    args.n_head, args.n_kv_head = 32, 8
    args.chain_setting = [8, 24]
    return args


@register_model_architecture("CoLMTransformer", "colm_606M")
def colm_transformer_606M_standard(args: ModelConfig):
    """Preset ``colm_606M``: a single chain of 16 heads, 40 layers, width 1024."""
    args.dim = 1024
    args.hidden_dim = 4096
    args.n_layers = 40
    args.n_head, args.n_kv_head = 16, 16
    args.chain_setting = [16]
    return args


@register_model_architecture("CoLMTransformer", "colm_4B")
def colm_transformer_4B_standard(args: ModelConfig):
    """Preset ``colm_4B``: four chains of 8 heads; every other field keeps
    the value already present on ``args``."""
    four_chains_of_eight = [8, 8, 8, 8]
    args.chain_setting = four_chains_of_eight
    return args


@register_model_architecture("CoLMTransformer", "colm_7B")
def colm_transformer_4B_standard(args: ModelConfig):
    """Preset ``colm_7B``: four chains of 8 heads and 64 layers; every other
    field keeps the value already present on ``args``."""
    four_chains_of_eight = [8, 8, 8, 8]
    args.chain_setting = four_chains_of_eight
    args.n_layers = 64
    return args