# -*- coding: utf-8 -*-
from __future__ import annotations

import math
import os
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, List, NamedTuple
from typing import Any, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange

from transformers import AutoTokenizer, Trainer, TrainingArguments, PreTrainedModel, PretrainedConfig
from datasets import load_from_disk

from torch.nn.functional import scaled_dot_product_attention as sdpa

# --- Environment switches read further down in this module -----------------
# These assignments overwrite whatever the caller's environment already holds.
# "x070" makes the module try to compile/load the 'wind_backstepping' kernel.
os.environ["RWKV_MY_TESTING"] = "x070"
# "1" selects the TorchScript aliases for MyModule / MyFunction.
os.environ["RWKV_JIT_ON"] = "1"
# Head size the kernel is compiled for (-D_C_). RWKV7Attention passes its
# head_dim (CUSTOM_ATTN_HEAD_DIM) to the kernel, which asserts the two match.
os.environ["RWKV_HEAD_SIZE"] = "64"

# NOTE(review): NN / KK are not referenced in the visible part of this file.
NN = 8
KK = 8

# Recurrent-depth sampling settings. Not referenced in the visible part of
# this file; presumably consumed by a training loop elsewhere — confirm.
MEAN_RECURRENCE = 16
GRAD_STEPS = 8
MAX_ITERATIONS = 48
SAMPLING_SCHEME = "poisson-lognormal-filling"
INJECTION_TYPE = "linear"
INIT_STRATEGY = "zero"
INIT_ORTHOGONAL = False
ACTIVATION_CHECKPOINT_IMPL = "per-iteration"
LOGNORMAL_SIGMA = 0.5

# Training hyper-parameters (presumably fed to transformers.TrainingArguments
# further down the file — not visible here).
TRAIN_EPOCHS = 1
WARMUP = 0.001
PER_DEVICE_TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 1
WEIGHT_DECAY = 0.01
LR_SCHEDULER_TYPE = "cosine_with_min_lr"
LOGGING_STEPS = 200
LEARNING_RATE = 4e-4
SAVE_STRATEGY = "steps"
SAVE_STEP = 2500
USE_BF16 = True
DATALOADER_NUM_WORKERS = 64
REPORT_TO = "none"
MINIR = 0.5

# Defaults for AlternatingConfig's ut_* fields (read from its **kwargs).
UT_NUM_STEPS = 1
UT_USE_ACT = False
UT_ACT_THRESHOLD = 0.9
UT_ACT_MAX_STEPS = 1
UT_PONDER_TAU = 0.0
UT_MAX_STEPS_EMBED = 64

# Layer split validated by AlternatingConfig:
# bottom + (end - start + 1) + top must equal NUM_LAYERS (2 + 4 + 2 = 8).
UT_BOTTOM_LAYERS = 2
UT_TOP_LAYERS = 2
UT_RECURRENT_START = 2
UT_RECURRENT_END = 5

# Core model shape (AlternatingConfig defaults). DIMS is the intermediate width
# of the gated MLP (RWKV_CMix_x070).
VOCAB_SIZE = 32002
HIDDEN_SIZE = 2048
NUM_LAYERS = 8
MAX_POSITION_EMBEDDINGS = 4096
PAD_TOKEN_ID = 32000
DIMS = 6144

# Also used as the LayerNorm eps in the decoder layers.
RMS_NORM_EPS = 1e-5

INITIALIZER_RANGE = 0.02
# NOTE(review): the next two are not referenced in the visible part of the file.
RESCALE_PRENORM_RESIDUAL = True
NUM_RESIDUALS_PER_LAYER = 2

# RWKV-7 layers: head count / per-head width (head width must match HEAD_SIZE).
CUSTOM_ATTN_NUM_HEADS = 32
CUSTOM_ATTN_HEAD_DIM = 64

# LoRA ranks inside RWKV7Attention: decay (w1/w2), in-context rate (a1/a2),
# output gate (g1/g2), and the per-micro-step key/value adapters (LRA).
D_DECAY_LORA = 96
D_AAA_LORA = 96
D_GATE_LORA = 256
LRA = 128

# Sources compiled into the 'wind_backstepping' extension.
CUDA_FILE_PATH_CU = '/path/to/cuda/wkv7_cuda.cu'
CUDA_FILE_PATH_CPP = '/path/to/cuda/wkv7_op.cpp'

# Placeholder filesystem locations (presumably used by the training entry point).
TOKENIZER_PATH = "/path/to/tokenizer/"
DATASET_PATH = "/path/to/dataset/"
OUTPUT_DIR = "/path/to/output/"
LOGGING_DIR = "/path/to/logging/"

# Micro-steps per token in RWKV7Attention (AlternatingConfig default).
NUM_MICRO_STEPS = 3

# Kernel chunk length; the kernel asserts its (micro-step expanded) sequence
# length is a multiple of this.
CHUNK_LEN = 16

# GQA attention layers: per-head width and query / key-value head counts.
HEAD_DIM = 128
GQA_NUM_ATTENTION_HEADS = 16
GQA_NUM_KEY_VALUE_HEADS = 4

# Passed to the GQA attention layers. NOTE(review): window_size and
# global_interval are stored but not used by that layer's forward.
WINDOW_SIZE = 2048
GLOBAL_INTERVAL = 64
ROPE_BASE = 10000

EMBEDDING_DROPOUT = 0.0

# NOTE(review): not referenced in the visible part of this file.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Head size the WKV CUDA kernel is compiled for. Falls back to 64 when the
# variable is absent or cannot be parsed as an integer.
try:
    HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE"])
except (ValueError, KeyError):
    HEAD_SIZE = 64
    print("Warning: RWKV_HEAD_SIZE not set or invalid. Using default value 64.")

def __nop(ob):
    """Return *ob* untouched; identity stand-in for ``torch.jit.script_method``."""
    return ob


# Pick the module base class / method decorator depending on RWKV_JIT_ON.
if os.environ.get("RWKV_JIT_ON", "0") != "1":
    MyModule = nn.Module
    MyFunction = __nop
    print("Info: JIT (TorchScript) not enabled.")
else:
    print("Info: JIT (TorchScript) enabled.")
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method

def _make_missing_kernel_stub(message: str):
    """Build a ``RUN_CUDA_RWKV7g`` stand-in that always raises ``NotImplementedError(message)``.

    The message is captured when the stub is created. Building it lazily from an
    ``except ... as e`` name does not work: Python deletes ``e`` when the handler
    exits, so such a stub raised ``NameError`` instead of the intended error.
    """
    def RUN_CUDA_RWKV7g(q, w, k, v, a, b, head_size_for_reshape: int):
        raise NotImplementedError(message)

    return RUN_CUDA_RWKV7g


if 'x070' in os.environ.get("RWKV_MY_TESTING", ""):
    print(f"Info: RWKV_MY_TESTING contains 'x070'. Attempting to load 'wind_backstepping' CUDA kernel (HEAD_SIZE={HEAD_SIZE}, CHUNK_LEN={CHUNK_LEN}).")
    from torch.utils.cpp_extension import load
    flags = [
        '-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}",
        "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"
    ]

    cuda_sources_exist = os.path.exists(CUDA_FILE_PATH_CU) and os.path.exists(CUDA_FILE_PATH_CPP)

    if cuda_sources_exist:
        try:
            load(
                name="wind_backstepping",
                sources=[CUDA_FILE_PATH_CU, CUDA_FILE_PATH_CPP],
                is_python_module=False, verbose=True, extra_cuda_cflags=flags
            )
            print("CUDA kernel 'wind_backstepping' loaded successfully.")

            class WindBackstepping(torch.autograd.Function):
                """Autograd bridge to ``torch.ops.wind_backstepping.forward/backward``.

                All inputs must be contiguous bfloat16 tensors shaped ``(B, T, H, C)``
                with ``C == HEAD_SIZE`` and ``T`` a multiple of ``CHUNK_LEN``. The
                forward pass saves the kernel's float32 ``s``/``sa`` buffers for backward.
                """

                @staticmethod
                def forward(ctx, w, q, k, v, z, b):
                    B, T, H, C_head = q.shape
                    assert C_head == HEAD_SIZE, f"Tensor head size {C_head} != compiled HEAD_SIZE {HEAD_SIZE}"
                    assert T % CHUNK_LEN == 0, f"Sequence length T ({T}) must be multiple of CHUNK_LEN ({CHUNK_LEN})"
                    assert all(i.dtype == torch.bfloat16 for i in [w, q, k, v, z, b]), "All input tensors must be bfloat16"
                    assert all(i.is_contiguous() for i in [w, q, k, v, z, b]), "All input tensors must be contiguous"

                    y = torch.empty_like(v)
                    s = torch.empty(B, H, T // CHUNK_LEN, C_head, C_head, dtype=torch.float32, device=w.device)
                    sa = torch.empty(B, T, H, C_head, dtype=torch.float32, device=w.device)

                    torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa)
                    ctx.save_for_backward(w, q, k, v, z, b, s, sa)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    assert dy.dtype == torch.bfloat16, "dy must be bfloat16"
                    assert dy.is_contiguous(), "dy must be contiguous"

                    w, q, k, v, z, b, s, sa = ctx.saved_tensors
                    dw, dq, dk, dv, dz, db = [torch.empty_like(x) for x in [w, q, k, v, z, b]]
                    torch.ops.wind_backstepping.backward(w, q, k, v, z, b, dy, s, sa, dw, dq, dk, dv, dz, db)
                    return dw, dq, dk, dv, dz, db

            def RUN_CUDA_RWKV7g(r_in, w_in, k_in, v_in, neg_kk_in, pos_kk_a_in, head_size_for_reshape: int):
                """Run the WKV kernel on ``(B, T, H*C)`` inputs and return a ``(B, T, H*C)`` result.

                Each input is viewed as ``(B, T, H, head_size_for_reshape)``. Note the
                kernel's argument order is ``(w, q, k, v, z, b)``.
                """
                B, T_micro, HC = r_in.shape
                # Validate before dividing so the error names the real problem.
                assert HC % head_size_for_reshape == 0, f"Total channel dimension ({HC}) cannot be divided by head_size_for_reshape ({head_size_for_reshape})"
                num_heads = HC // head_size_for_reshape
                heads_shape = (B, T_micro, num_heads, head_size_for_reshape)

                output_cuda = WindBackstepping.apply(
                    w_in.view(heads_shape),
                    r_in.view(heads_shape),
                    k_in.view(heads_shape),
                    v_in.view(heads_shape),
                    neg_kk_in.view(heads_shape),
                    pos_kk_a_in.view(heads_shape),
                )
                return output_cuda.view(B, T_micro, HC)

        except Exception as e:
            print(f"Error: Cannot load or compile CUDA kernel 'wind_backstepping': {e}")
            print("Defining RUN_CUDA_RWKV7g placeholder function that raises NotImplementedError.")
            # Format the message now, while ``e`` is still bound.
            RUN_CUDA_RWKV7g = _make_missing_kernel_stub(
                f"CUDA kernel 'wind_backstepping' not successfully loaded or compiled ({e})."
            )
    else:
        print("Warning: CUDA source files not found.")
        print("Defining RUN_CUDA_RWKV7g placeholder function that raises NotImplementedError.")
        RUN_CUDA_RWKV7g = _make_missing_kernel_stub(
            "CUDA source files missing, 'wind_backstepping' kernel not compiled."
        )
else:
    print("Info: RWKV_MY_TESTING does not contain 'x070'. Not attempting to load CUDA kernel.")
    def RUN_CUDA_RWKV7g(q,w,k,v,a,b, head_size_for_reshape: int):
        """Shape-preserving placeholder used when the kernel path is disabled.

        NOTE(review): this is NOT a WKV recurrence — it only combines the inputs
        elementwise so code paths can run; its output is not meaningful.
        """
        print("Warning: RUN_CUDA_RWKV7g using CPU placeholder implementation.")
        output = (q * torch.sigmoid(k)) * v + w + a + b 
        return output

class AlternatingConfig(PretrainedConfig):
    """Configuration for the alternating RWKV-7 / GQA-attention decoder.

    Besides the explicit arguments, the settings ``ut_num_steps``, ``ut_use_act``,
    ``ut_act_threshold``, ``ut_act_max_steps``, ``ut_ponder_tau``,
    ``ut_max_steps_embed``, ``ut_bottom_layers``, ``ut_top_layers``,
    ``ut_recurrent_start`` and ``ut_recurrent_end`` are read from ``**kwargs``,
    falling back to the module-level ``UT_*`` constants.

    Raises:
        ValueError: if ``ut_recurrent_start >= ut_recurrent_end``, if
            ``ut_recurrent_end >= num_layers``, if ``ut_bottom_layers`` differs
            from ``ut_recurrent_start``, or if
            ``bottom + (end - start + 1) + top != num_layers``.
    """

    model_type = "alternating"
    
    def __init__(
        self,
        vocab_size=VOCAB_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_layers=NUM_LAYERS,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        pad_token_id=PAD_TOKEN_ID,
        rms_norm_eps=RMS_NORM_EPS,
        initializer_range=INITIALIZER_RANGE,
        custom_attn_num_heads=CUSTOM_ATTN_NUM_HEADS,
        custom_attn_head_dim=CUSTOM_ATTN_HEAD_DIM,
        num_micro_steps=NUM_MICRO_STEPS,
        gqa_num_attention_heads=GQA_NUM_ATTENTION_HEADS,
        gqa_num_key_value_heads=GQA_NUM_KEY_VALUE_HEADS,
        window_size=WINDOW_SIZE,
        global_interval=GLOBAL_INTERVAL,
        rope_base=ROPE_BASE,
        embedding_dropout=EMBEDDING_DROPOUT,
        **kwargs
    ):
        # Pop (not get): leaving "return_dict" in kwargs while also passing it
        # explicitly raised "got multiple values for keyword argument 'return_dict'"
        # whenever a caller or a saved config supplied it.
        return_dict = kwargs.pop("return_dict", True)
        super().__init__(pad_token_id=pad_token_id, return_dict=return_dict, **kwargs)
        
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Alias expected by code that reads the conventional HF attribute name.
        self.num_hidden_layers = num_layers
        self.max_position_embeddings = max_position_embeddings

        self.rms_norm_eps = rms_norm_eps

        self.initializer_range = initializer_range

        self.custom_attn_num_heads = custom_attn_num_heads
        self.custom_attn_head_dim = custom_attn_head_dim

        self.num_micro_steps = num_micro_steps

        self.gqa_num_attention_heads = gqa_num_attention_heads
        self.gqa_num_key_value_heads = gqa_num_key_value_heads

        self.window_size = window_size
        self.global_interval = global_interval
        self.rope_base = rope_base

        self.embedding_dropout = embedding_dropout

        self.ut_num_steps = kwargs.get("ut_num_steps", UT_NUM_STEPS)
        self.ut_use_act = kwargs.get("ut_use_act", UT_USE_ACT)
        self.ut_act_threshold = kwargs.get("ut_act_threshold", UT_ACT_THRESHOLD)
        self.ut_act_max_steps = kwargs.get("ut_act_max_steps", UT_ACT_MAX_STEPS)
        self.ut_ponder_tau = kwargs.get("ut_ponder_tau", UT_PONDER_TAU)
        self.ut_max_steps_embed = kwargs.get("ut_max_steps_embed", UT_MAX_STEPS_EMBED)
        
        self.ut_bottom_layers = kwargs.get("ut_bottom_layers", UT_BOTTOM_LAYERS)
        self.ut_top_layers = kwargs.get("ut_top_layers", UT_TOP_LAYERS)
        self.ut_recurrent_start = kwargs.get("ut_recurrent_start", UT_RECURRENT_START)
        self.ut_recurrent_end = kwargs.get("ut_recurrent_end", UT_RECURRENT_END)
        
        # The recurrent block is the inclusive range [start, end]; the layer
        # stack must be exactly bottom + recurrent + top.
        if self.ut_recurrent_start >= self.ut_recurrent_end:
            raise ValueError("ut_recurrent_start must be less than ut_recurrent_end")
        if self.ut_recurrent_end >= self.num_layers:
            raise ValueError("ut_recurrent_end must be less than total layers")
        if self.ut_bottom_layers != self.ut_recurrent_start:
            raise ValueError("ut_bottom_layers should equal ut_recurrent_start")
        if self.ut_bottom_layers + (self.ut_recurrent_end - self.ut_recurrent_start + 1) + self.ut_top_layers != self.num_layers:
            raise ValueError("bottom + recurrent + top layers must equal num_layers")

    
import math
import torch
import torch.nn as nn
from typing import Literal

class StepEmbedding(nn.Module):
    """Embedding of a step index ``t``, broadcast to ``(B, S, hidden_size)``.

    Modes:
        ``"learned"``: an ``nn.Embedding(max_steps, hidden_size)`` initialised from N(0, 0.02).
        ``"sin"``: a fixed sinusoidal table (sine on even channels, cosine on odd
            channels) kept in a non-persistent buffer.

    A step ``t >= max_steps`` grows the table on the fly (see ``_maybe_grow``).
    The returned vector is multiplied by ``scale``.
    """

    def __init__(
        self,
        hidden_size: int,
        max_steps: int = 64,
        mode: Literal["learned", "sin"] = "learned",
        scale: float = 1.0,
    ):
        super().__init__()
        assert hidden_size > 0 and max_steps > 0
        self.hidden_size = int(hidden_size)
        self.max_steps = int(max_steps)
        self.mode = mode
        self.scale = float(scale)

        if mode == "learned":
            self.E = nn.Embedding(self.max_steps, self.hidden_size)
            nn.init.normal_(self.E.weight, mean=0.0, std=0.02)
        elif mode == "sin":
            table = self._build_sinusoid_table(self.max_steps, self.hidden_size)
            self.register_buffer("sinusoidal_table", table, persistent=False)
        else:
            raise ValueError(f"Unsupported mode: {mode}")

    @torch.no_grad()
    def _build_sinusoid_table(self, steps: int, dim: int) -> torch.Tensor:
        """Return a ``(steps, dim)`` float32 sinusoidal table.

        Column ``2i`` holds ``sin(t * f_i)`` and column ``2i + 1`` holds
        ``cos(t * f_i)``, with ``f_i = 10000 ** (-2i / dim)``. For odd ``dim``
        the last sine column has no cosine partner.
        """
        position = torch.arange(steps, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
        )
        angles = position * div_term  # (steps, ceil(dim / 2))
        pe = torch.zeros(steps, dim, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angles)
        # Only dim // 2 odd columns exist; slicing keeps odd ``dim`` from
        # raising a shape-mismatch error (identical result for even ``dim``).
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return pe

    @torch.no_grad()
    def _maybe_grow(self, new_max_steps: int):
        """Enlarge the table to ``new_max_steps`` rows, keeping the existing rows.

        NOTE(review): in ``"learned"`` mode this replaces ``self.E`` with a new
        module, so an optimizer built before the growth still holds the old
        weight tensor. Confirm growth cannot happen after optimizer creation.
        """
        if new_max_steps <= self.max_steps:
            return
        if self.mode == "learned":
            old_weight = self.E.weight.data
            new_E = nn.Embedding(new_max_steps, self.hidden_size)
            nn.init.normal_(new_E.weight, mean=0.0, std=0.02)
            new_E.weight.data[: self.max_steps].copy_(old_weight)
            self.E = new_E.to(self.E.weight.device, dtype=self.E.weight.dtype)
        else:
            table = self._build_sinusoid_table(new_max_steps, self.hidden_size).to(
                self.sinusoidal_table.device
            )
            self.register_buffer("sinusoidal_table", table, persistent=False)
        self.max_steps = int(new_max_steps)

    def forward(
        self,
        t: int,
        B: int,
        S: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Return the (scaled) embedding of step ``t`` expanded to ``(B, S, hidden_size)``.

        Raises:
            ValueError: if ``t`` is negative.
        """
        if t < 0:
            raise ValueError("t must be non-negative")
        if t >= self.max_steps:
            # Grow geometrically (x1.25) so repeated growth is amortised.
            new_cap = max(t + 1, int(self.max_steps * 1.25))
            self._maybe_grow(new_cap)

        if self.mode == "learned":
            vec = self.E.weight[t]
            vec = vec.to(torch.float32)
        else:
            vec = self.sinusoidal_table[t]

        vec = (vec * self.scale).to(device=device, dtype=torch.float32)
        # expand() gives a broadcast view (no copy) before the final dtype cast.
        vec = vec.view(1, 1, -1).expand(B, S, -1)
        return vec.to(dtype)

    def extra_repr(self) -> str:
        return f"hidden_size={self.hidden_size}, max_steps={self.max_steps}, mode={self.mode!r}, scale={self.scale}"


class MyBaseModelOutputWithPast(NamedTuple):
    """Tuple-style base-model output.

    ``last_hidden_state`` is required; every other field defaults to ``None``.
    """

    last_hidden_state: torch.Tensor
    past_key_values: Union[Tuple[Union[Tuple[torch.Tensor, ...], None], ...], None] = None
    hidden_states: Union[Tuple[Union[torch.Tensor, None], ...], None] = None
    attentions: Union[Tuple[Union[torch.Tensor, None], ...], None] = None

class MyCausalLMOutputWithPast(NamedTuple):
    """Tuple-style causal-LM output.

    ``loss`` is the first field, so positional access ``outputs[0]`` yields the
    loss. Every field defaults to ``None``, so ``logits`` is annotated as optional.
    """

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    past_key_values: Optional[Tuple[Optional[Tuple[torch.Tensor, ...]], ...]] = None
    hidden_states: Optional[Tuple[Optional[torch.Tensor], ...]] = None
    attentions: Optional[Tuple[Optional[torch.Tensor], ...]] = None

class RWKV7Attention(nn.Module):
    """RWKV-7 style time-mixing layer that runs ``num_micro_steps`` sub-steps per token.

    Each token position is expanded into ``M = num_micro_steps`` consecutive
    positions before the WKV kernel is called (see ``_interleave_tensors``):

    * micro-step 0 carries the learned decay ``w``; later micro-steps use ``w = -inf``;
    * only the last micro-step carries the receptance ``r``;
    * each micro-step has its own LoRA deltas for the key (``keyA_*``/``keyB_*``),
      the value (``value_A_list``/``value_B_list``) and the in-context rate
      (``a0_*``/``a1_*``/``a2_*``);
    * the layer output is read from the last micro-step of each token.

    Args:
        hidden_size: model width, also used as the attention width.
        num_heads: accepted for interface compatibility; the head count actually
            used is ``hidden_size // head_dim``.
        head_dim: per-head width handed to ``RUN_CUDA_RWKV7g``.
        layer_idx: index of this layer (drives the depth-dependent init).
        num_hidden_layers: total number of layers (drives the depth-dependent init).
        num_micro_steps: number of micro-steps ``M`` per token.
        lora_rank: rank of the per-micro-step key/value LoRA adapters. ``None``
            (the default) uses the module-level ``LRA`` constant.
        **kwargs: ignored.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        layer_idx: int,
        num_hidden_layers: int,
        num_micro_steps: int,
        lora_rank: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__()

        class Args:
            pass

        args = Args()
        args.head_size = head_dim
        args.n_embd = hidden_size
        args.dim_att = hidden_size
        args.n_layer = num_hidden_layers
        args.num_micro_steps = num_micro_steps

        self.args = args
        self.layer_id = layer_idx
        self.num_micro_steps = args.num_micro_steps
        # The argument used to be silently ignored in favour of LRA. An explicit
        # value is now honoured; the default (None) keeps the LRA-sized adapters,
        # so existing callers and checkpoints see the same shapes.
        self.lora_rank = LRA if lora_rank is None else int(lora_rank)

        self.head_size = args.head_size
        assert args.dim_att % self.head_size == 0, "dim_att must be divisible by head_size"
        self.n_head = args.dim_att // self.head_size

        H = self.n_head
        N_hs = self.head_size
        C_emb = args.n_embd
        M = self.num_micro_steps
        default_dtype = torch.get_default_dtype()

        with torch.no_grad():
            ratio_0_to_1 = layer_idx / (args.n_layer - 1) if args.n_layer > 1 else 0.0
            ratio_1_to_almost0 = 1.0 - (layer_idx / args.n_layer) if args.n_layer > 0 else 1.0

            # Channel index in float64 (the precision of the former per-element
            # Python loops), rounded once to the default dtype.
            chan = torch.arange(C_emb, dtype=torch.float64)

            # ddd[0, 0, i] = i / C_emb
            ddd = (chan / C_emb).to(default_dtype).view(1, 1, C_emb)

            self.x_r = nn.Parameter(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
            self.x_w = nn.Parameter(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
            self.x_k = nn.Parameter(1.0 - torch.pow(ddd, 0.7 * ratio_1_to_almost0))
            self.x_v = nn.Parameter(1.0 - torch.pow(ddd, 0.7 * ratio_1_to_almost0))
            self.x_a = nn.Parameter(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
            self.x_g = nn.Parameter(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))

            # Zig-zag within each head: signed square of the channel's offset from
            # the head centre, normalised to [-1, 1].
            zz_half = (N_hs - 1.0) / 2.0 if N_hs > 1 else 1.0
            zigzag = (((chan % N_hs) - zz_half) / zz_half).to(default_dtype)
            zigzag = zigzag * zigzag.abs()

            # Decay base: -6 at channel 0 rising to 0 at the last channel, with a
            # layer-dependent exponent.
            if C_emb > 1:
                www = (
                    -6.0 + 6.0 * (chan / (C_emb - 1.0)) ** (1.0 + 1.0 * ratio_0_to_1 ** 0.3)
                ).to(default_dtype)
            else:
                www = torch.full((C_emb,), -6.0)

            self.w1 = nn.Parameter(torch.zeros(C_emb, D_DECAY_LORA))
            self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C_emb))
            nn.init.orthogonal_(self.w2, gain=0.1)
            self.w0 = nn.Parameter(www.reshape(1, 1, C_emb) + 0.5 + zigzag * 2.5)

            self.g1 = nn.Parameter(torch.zeros(C_emb, D_GATE_LORA))
            self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C_emb))
            nn.init.orthogonal_(self.g2, gain=0.1)

            self.k_ac = nn.Parameter(torch.zeros(1, 1, C_emb) + 1.02)
            self.r_k = nn.Parameter(torch.zeros(H, N_hs) - 0.04)

        self.key0   = nn.Linear(C_emb, C_emb, bias=False)
        self.value0 = nn.Linear(C_emb, C_emb, bias=False)

        num_lora_sets = M if M > 0 else 0

        self.value_A_list = nn.ParameterList()
        self.value_B_list = nn.ParameterList()

        if num_lora_sets > 0:
            for _ in range(num_lora_sets):
                self.value_A_list.append(nn.Parameter(torch.zeros(C_emb, self.lora_rank)))
                self.value_B_list.append(nn.Parameter(torch.zeros(self.lora_rank, C_emb)))

            for A_param in self.value_A_list:
                nn.init.kaiming_uniform_(A_param, a=math.sqrt(5))

            # Zero B matrices make every LoRA delta start at exactly zero.
            for B in self.value_B_list:
                nn.init.zeros_(B)

        self.keyA_b = nn.ParameterList()
        self.keyB_b = nn.ParameterList()
        self.keyA_c = nn.ParameterList()
        self.keyB_c = nn.ParameterList()
        self.a0_b = nn.ParameterList();  self.a1_b = nn.ParameterList();  self.a2_b = nn.ParameterList()
        self.a0_c = nn.ParameterList();  self.a1_c = nn.ParameterList();  self.a2_c = nn.ParameterList()
        # Registration / init order below is kept as-is so parameter names and
        # RNG consumption match previous runs.
        for _ in range(M):
            self.keyA_b.append(nn.Parameter(torch.zeros(C_emb, self.lora_rank)))
            self.keyB_b.append(nn.Parameter(torch.zeros(self.lora_rank, C_emb)))
            self.keyA_c.append(nn.Parameter(torch.zeros(C_emb, self.lora_rank)))
            self.keyB_c.append(nn.Parameter(torch.zeros(self.lora_rank, C_emb)))
            self.a0_b.append(nn.Parameter(torch.zeros(1,1,C_emb)))
            self.a1_b.append(nn.Parameter(torch.zeros(C_emb, D_AAA_LORA)))
            a2_b_param = torch.empty(D_AAA_LORA, C_emb)
            nn.init.orthogonal_(a2_b_param, gain=0.1)
            self.a2_b.append(nn.Parameter(a2_b_param))

            self.a0_c.append(nn.Parameter(torch.zeros(1,1,C_emb)))
            self.a1_c.append(nn.Parameter(torch.zeros(C_emb, D_AAA_LORA)))
            a2_c_param = torch.empty(D_AAA_LORA, C_emb)
            nn.init.orthogonal_(a2_c_param, gain=0.1)
            self.a2_c.append(nn.Parameter(a2_c_param))

        for A_param in self.keyA_b:
            nn.init.kaiming_uniform_(A_param, a=math.sqrt(5))
        for A_param in self.keyA_c:
            nn.init.kaiming_uniform_(A_param, a=math.sqrt(5))

        for B in self.keyB_b:
            nn.init.zeros_(B)
        for B in self.keyB_c:
            nn.init.zeros_(B)

        # Shifts the sequence axis down by one position (zero-filled at t = 0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(C_emb, C_emb, bias=False)
        self.output = nn.Linear(C_emb, C_emb, bias=False)

        self.ln_x = nn.GroupNorm(H, C_emb, eps=64e-5)

        init_val = 0.5 / (C_emb ** 0.5)
        key_init_val_abs = 0.05 / (C_emb ** 0.5)

        self.receptance.weight.data.uniform_(-init_val, init_val)
        self.key0.weight.data.uniform_(-key_init_val_abs, key_init_val_abs)
        self.value0.weight.data.uniform_(-init_val, init_val)

        # Zero output projection: the layer contributes nothing at init.
        self.output.weight.data.zero_()

    def _interleave_tensors(self, tensor_list: List[torch.Tensor], B: int, T: int, C: int) -> torch.Tensor:
        """Interleave M ``(B, T, C)`` tensors into ``(B, T * M, C)``.

        Token ``t``'s micro-step entries end up at positions ``t*M .. t*M + M - 1``.
        """
        stacked = torch.stack(tensor_list, dim=2)
        return stacked.contiguous().view(B, T * len(tensor_list), C)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, None, Optional[torch.Tensor]]:
        """Apply the layer to ``hidden_states`` of shape ``(B, T, hidden_size)``.

        Args:
            attention_mask: optional ``(B, T)`` mask (presumably 1 = real token —
                confirm with the caller). Masked positions get ``w = -inf`` and
                zeroed ``r``, ``v`` and in-context terms before the kernel call.
            past_key_values: returned unchanged; no recurrent state is cached here.
            use_cache, output_attentions: unused.

        Returns:
            ``(output, None, past_key_values)`` with ``output`` shaped like the input.
        """
        B, T, C_emb = hidden_states.size()
        H, N_hs, M = self.n_head, self.head_size, self.num_micro_steps

        # Token shift: xx[t] = x[t-1] - x[t], with x[-1] treated as zero.
        xx = self.time_shift(hidden_states) - hidden_states
        xr = hidden_states + xx * self.x_r
        xw = hidden_states + xx * self.x_w
        xk = hidden_states + xx * self.x_k
        xv = hidden_states + xx * self.x_v
        xa = hidden_states + xx * self.x_a
        xg = hidden_states + xx * self.x_g

        base_r = self.receptance(xr)
        base_w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5
        base_k = self.key0(xk)
        base_v_shared = self.value0(xv)

        r_list, w_list, k_list, v_list = [], [], [], []
        b_neg_kk_list, b_pos_kk_a_list = [], []

        for j in range(M):
            # Receptance only on the last micro-step; decay only on the first.
            r_list.append(base_r if j == M - 1 else torch.zeros_like(base_r))
            w_list.append(base_w if j == 0 else torch.full_like(base_w, float("-inf")))

            k_b = base_k + (xk @ self.keyA_b[j]) @ self.keyB_b[j]
            k_c = base_k + (xk @ self.keyA_c[j]) @ self.keyB_c[j]

            beta_b = torch.sigmoid(self.a0_b[j] + (xa @ self.a1_b[j]) @ self.a2_b[j])
            beta_c = torch.sigmoid(self.a0_c[j] + (xa @ self.a1_c[j]) @ self.a2_c[j])

            k_list.append(k_c * (1.0 + (beta_c - 1.0) * self.k_ac))

            # NOTE(review): the L2 norm is taken over the whole C_emb vector, not
            # per head. Left unchanged because trained weights depend on it;
            # confirm it is intended.
            k_b = k_b / (k_b.norm(p=2, dim=-1, keepdim=True) + 1e-6)
            b_neg_kk_list.append(-k_b)
            b_pos_kk_a_list.append(k_b * beta_b)

            # value_A_list / value_B_list hold exactly M adapters whenever M > 0.
            v_list.append(base_v_shared + (xv @ self.value_A_list[j]) @ self.value_B_list[j])

        if M > 0:
            r_cat = self._interleave_tensors(r_list, B, T, C_emb)
            w_cat = self._interleave_tensors(w_list, B, T, C_emb)
            k_cat = self._interleave_tensors(k_list, B, T, C_emb)
            v_cat = self._interleave_tensors(v_list, B, T, C_emb)
            neg_kk_cat = self._interleave_tensors(b_neg_kk_list, B, T, C_emb)
            pos_kk_a_cat = self._interleave_tensors(b_pos_kk_a_list, B, T, C_emb)

            if attention_mask is not None:
                # Repeat each token's mask value for its M micro-steps.
                keep = attention_mask.repeat_interleave(M, dim=1).unsqueeze(-1).bool()

                w_cat = torch.where(keep, w_cat, torch.full_like(w_cat, float("-inf")))
                v_cat = v_cat * keep.to(v_cat.dtype)
                neg_kk_cat = neg_kk_cat * keep.to(neg_kk_cat.dtype)
                pos_kk_a_cat = pos_kk_a_cat * keep.to(pos_kk_a_cat.dtype)
                r_cat = r_cat * keep.to(r_cat.dtype)

            # The CUDA kernel path asserts bfloat16 inputs.
            r_cat, w_cat, k_cat, v_cat, neg_kk_cat, pos_kk_a_cat = (
                t.to(torch.bfloat16)
                for t in (r_cat, w_cat, k_cat, v_cat, neg_kk_cat, pos_kk_a_cat)
            )

            x_inter = RUN_CUDA_RWKV7g(r_cat, w_cat, k_cat, v_cat, neg_kk_cat, pos_kk_a_cat, self.head_size)
            x_inter = x_inter.to(hidden_states.dtype)

            # Keep only the last micro-step of every token.
            x_att = x_inter.view(B, T, M, C_emb)[:, :, -1, :]
        else:
            x_att = torch.zeros(B, T, C_emb, device=hidden_states.device, dtype=hidden_states.dtype)

        x_att_gn = self.ln_x(x_att.view(B * T, C_emb)).view(B, T, C_emb)

        if M > 0:
            # Per-head bonus term: (sum_c r * k * r_k) * v, from the last micro-step.
            r_final = r_list[-1].view(B, T, H, N_hs)
            k_final = k_list[-1].view(B, T, H, N_hs)
            v_final = v_list[-1].view(B, T, H, N_hs)
            term2 = (r_final * k_final * self.r_k).sum(dim=-1, keepdim=True) * v_final
            x_out_intermediate = x_att_gn + term2.view(B, T, C_emb)
        else:
            x_out_intermediate = x_att_gn

        g = torch.sigmoid((xg @ self.g1) @ self.g2)
        x_out = self.output(x_out_intermediate * g)

        return x_out, None, past_key_values

class Qwen3RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-channel scale.

    The statistic is computed in float32. The normalised values are cast back to
    the input dtype before being multiplied by ``weight``, so the result dtype
    follows PyTorch type promotion between ``weight`` and the input.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(as_fp32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normed = (as_fp32 * inv_rms).to(orig_dtype)
        return self.weight * normed


def rotate_half(x):
    """Rotate the last dimension by halves: ``[a, b] -> [-b, a]``.

    For an odd last dimension ``d`` the first half has ``d // 2`` elements and
    the second half the remaining ``d - d // 2``.
    """
    split_at = x.shape[-1] // 2
    front, back = x.split((split_at, x.shape[-1] - split_at), dim=-1)
    return torch.cat((back.neg(), front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to ``q`` and ``k``.

    ``cos``/``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so they
    broadcast over the head axis. ``position_ids`` is accepted but unused.
    Returns ``(q_rotated, k_rotated)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


class Qwen3RotaryEmbedding(nn.Module):
    """Produce RoPE ``cos``/``sin`` tables for given position ids.

    ``forward(x, position_ids)`` returns two ``(batch, seq, head_dim)`` tensors in
    ``x``'s dtype, computed in float32 with autocast disabled. ``device`` is
    accepted but unused.
    """

    def __init__(self, head_dim, max_seq_len, rope_base=10000, device=None):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.rope_base = rope_base

        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        self.register_buffer("inv_freq", 1.0 / (rope_base ** exponents), persistent=False)
        self.attention_scaling = 1.0

    @torch.no_grad()
    def forward(self, x, position_ids):
        dev = x.device
        batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(dev)
        pos_row = position_ids[:, None, :].float()

        kind = dev.type
        autocast_device = kind if isinstance(kind, str) and kind != "mps" else "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            # (batch, head_dim/2, 1) @ (batch, 1, seq) -> (batch, seq, head_dim/2)
            angles = torch.matmul(freq_col.float(), pos_row.float()).transpose(1, 2)
            doubled = torch.cat((angles, angles), dim=-1)
            cos = doubled.cos() * self.attention_scaling
            sin = doubled.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class LocalGlobalRandomRoPEFlashAttention(nn.Module):
    """Grouped-query causal self-attention with per-head QK RMS-norm and RoPE.

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads. The per-head
    width is the module-level ``HEAD_DIM`` constant, not derived from ``embed_dim``.
    Attention is computed in float32 via ``scaled_dot_product_attention`` and then
    projected back to ``embed_dim``.

    NOTE(review): despite the class name, ``window_size`` and ``global_interval``
    are stored but not used by ``forward``; every query attends to all earlier
    positions. ``k_cache``/``v_cache``/``cache_positions`` are never read or
    written in this class either.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_kv_heads,
        window_size,
        global_interval,
        max_seq_len,
        rope_base=10000
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = HEAD_DIM
        assert self.head_dim % 2 == 0, "head_dim must be even for RoPE"

        self.num_key_value_groups = num_heads // num_kv_heads
        self.scaling = self.head_dim ** -0.5

        self.window_size = int(window_size)
        self.global_interval = int(global_interval)
        self.max_seq_len = max_seq_len

        self.q_proj = nn.Linear(embed_dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim, bias=False)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=1e-5)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=1e-5)
        self.out_proj = nn.Linear(num_heads * self.head_dim, embed_dim, bias=False)

        self.rotary_emb = Qwen3RotaryEmbedding(
            head_dim=self.head_dim,
            max_seq_len=max_seq_len,
            rope_base=rope_base
        )

        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.cache_positions: List[int] = []

    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Expand ``(B, kv_heads, S, D)`` to ``(B, kv_heads * n_rep, S, D)`` by repeating each KV head."""
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_key_value_heads, n_rep, slen, head_dim
        )
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    @staticmethod
    def _causal_key_padding_mask(padding_mask: torch.Tensor, seq_len: int, device: torch.device) -> torch.Tensor:
        """Build a boolean SDPA mask of shape ``(B, 1, S, S)``; True means "may attend".

        Combines causality with key padding (``padding_mask`` is ``(B, S)``,
        non-zero for real tokens). Every position may always attend to itself,
        so rows for padding queries are never fully masked; a fully masked row
        can make SDPA produce NaN, which would survive the later multiply-by-zero.
        """
        keep_key = padding_mask.to(device=device).bool()[:, None, None, :]
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
        diag = torch.eye(seq_len, dtype=torch.bool, device=device)
        return causal & (keep_key | diag)

    def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over ``x`` of shape ``(B, S, embed_dim)``.

        Args:
            x: input activations.
            padding_mask: optional ``(B, S)`` tensor, non-zero for real tokens.
                Padding keys are excluded from attention, and padding query
                rows are zeroed in the output.

        Returns:
            Tensor of shape ``(B, S, embed_dim)`` in ``x``'s dtype.
        """
        B, S, _ = x.shape
        device = x.device
        input_dtype = x.dtype

        if S == 0:
            return self.out_proj(torch.zeros(B, 0, self.embed_dim, device=device, dtype=input_dtype))

        q_proj_out = self.q_proj(x)
        k_proj_out = self.k_proj(x)
        v_proj_out = self.v_proj(x)

        q = self.q_norm(q_proj_out.view(B, S, self.num_heads, self.head_dim)).transpose(1, 2)
        k = self.k_norm(k_proj_out.view(B, S, self.num_kv_heads, self.head_dim)).transpose(1, 2)
        v = v_proj_out.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
        cos, sin = self.rotary_emb(x, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if self.num_key_value_groups > 1:
            k = self.repeat_kv(k, self.num_key_value_groups)
            v = self.repeat_kv(v, self.num_key_value_groups)

        q = q.to(torch.float32)
        k = k.to(torch.float32)
        v = v.to(torch.float32)

        if padding_mask is None:
            attn_out = sdpa(q, k, v, attn_mask=None, is_causal=True)
        else:
            # Previously padding was only zeroed after attention, so real queries
            # could still attend to padding keys (e.g. with left padding).
            attn_mask = self._causal_key_padding_mask(padding_mask, S, device)
            attn_out = sdpa(q, k, v, attn_mask=attn_mask, is_causal=False)

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
        attn_out = self.out_proj(attn_out).to(input_dtype)

        if padding_mask is not None:
            attn_out = attn_out * padding_mask.to(attn_out.dtype).unsqueeze(-1)

        return attn_out

class RWKV_CMix_x070(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: ``down(silu(gate(x)) * up(x))``.

    Despite the RWKV-style name, this is a plain gated MLP. The hidden width
    is the module-level ``DIMS`` constant. ``args`` must expose ``n_embd``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        width = args.n_embd
        inner = DIMS
        self.hidden_size = width
        self.intermediate_size = inner

        # Construction order (gate, up, act, down) is preserved so parameter
        # initialisation consumes the RNG in the same sequence.
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.act_fn = nn.SiLU()
        self.down_proj = nn.Linear(inner, width, bias=False)

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

class AlternatingDecoderLayer(nn.Module):
    """Pre-norm residual block whose token mixer alternates with depth.

    Even ``layer_idx`` uses ``RWKV7Attention``; odd ``layer_idx`` uses
    ``LocalGlobalRandomRoPEFlashAttention``, which is flagged by
    ``is_gqa_layer = True``. Either mixer is followed by an ``RWKV_CMix_x070``
    feed-forward. Each sub-layer is preceded by a LayerNorm and followed by a
    residual add.
    """

    def __init__(self, config: AlternatingConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

        use_gqa = layer_idx % 2 != 0
        if use_gqa:
            self.self_attn = LocalGlobalRandomRoPEFlashAttention(
                embed_dim=config.hidden_size,
                num_heads=config.gqa_num_attention_heads,
                num_kv_heads=config.gqa_num_key_value_heads,
                window_size=config.window_size,
                global_interval=config.global_interval,
                max_seq_len=config.max_position_embeddings,
                rope_base=config.rope_base,
            )
        else:
            self.self_attn = RWKV7Attention(
                hidden_size=config.hidden_size,
                num_heads=config.custom_attn_num_heads,
                head_dim=config.custom_attn_head_dim,
                layer_idx=layer_idx,
                num_hidden_layers=config.num_hidden_layers,
                num_micro_steps=config.num_micro_steps,
            )
        self.is_gqa_layer = use_gqa

        # RWKV_CMix_x070 takes an RWKV-style ``args`` object and reads ``n_embd`` from it.
        mlp_args = type("Args", (), {})()
        mlp_args.n_embd = config.hidden_size
        mlp_args.n_layer = config.num_hidden_layers

        self.mlp = RWKV_CMix_x070(mlp_args, layer_idx)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
    ):
        """Run mixer + MLP with residual connections.

        Args:
            hidden_states: layer input.
            attention_mask: optional mask, cast to ``hidden_states.dtype``. It is
                passed as ``padding_mask`` to the GQA mixer and as
                ``attention_mask`` to the RWKV mixer.
            position_embeddings: accepted for interface compatibility; unused here.
            output_attentions: forwarded to the RWKV mixer only.

        Returns:
            ``(hidden_states, attn_weights)``. ``attn_weights`` is always ``None`` on
            GQA layers.
        """
        normed = self.input_layernorm(hidden_states)
        mask = attention_mask if attention_mask is None else attention_mask.to(hidden_states.dtype)

        if self.is_gqa_layer:
            mixed = self.self_attn(normed, padding_mask=mask)
            weights = None
        else:
            # The RWKV mixer returns a 3-tuple; its third element is discarded.
            mixed, weights, _ = self.self_attn(
                normed,
                attention_mask=mask,
                output_attentions=output_attentions,
            )

        hidden_states = hidden_states + mixed
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + ffn_out, weights

class AlternatingModelCore(nn.Module):
    """Token embedding, a stack of ``AlternatingDecoderLayer`` blocks and a final LayerNorm.

    This is the run-every-layer-once backbone with no LM head.
    ``PartialRecurrentCompose`` reuses its ``embed_tokens``, ``embed_dropout``,
    ``layers`` and ``norm`` submodules.
    """

    def __init__(self, config: AlternatingConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.embed_dropout = nn.Dropout(config.embedding_dropout)

        self.layers = nn.ModuleList(
            [AlternatingDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers)]
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        apply_embed_dropout: bool = True,
    ):
        """Run every decoder layer once, then the final norm.

        Args:
            input_ids: 2-D token ids ``(batch, seq)``. Mutually exclusive with
                ``inputs_embeds``.
            attention_mask: forwarded unchanged to every decoder layer.
            inputs_embeds: 3-D precomputed embeddings ``(batch, seq, hidden)``.
            output_attentions: collect per-layer attention weights (default False).
            output_hidden_states: collect the input of every layer plus the final
                normed output (default False).
            return_dict: return ``MyBaseModelOutputWithPast`` (default True) instead
                of a tuple.
            apply_embed_dropout: apply ``embed_dropout`` to the embeddings.

        Returns:
            ``MyBaseModelOutputWithPast``, or a tuple
            ``(last_hidden_state[, hidden_states][, attentions])``. In both forms the
            optional collections are tuples.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are
                given, or if they have the wrong rank.
        """
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
        return_dict = return_dict if return_dict is not None else True

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("Specify input_ids or inputs_embeds, not both.")

        if input_ids is not None:
            if input_ids.dim() != 2:
                raise ValueError(
                    f"input_ids must be 2-D (batch, seq), got shape {tuple(input_ids.shape)}"
                )
        elif inputs_embeds is not None:
            if inputs_embeds.dim() != 3:
                raise ValueError(
                    f"inputs_embeds must be 3-D (batch, seq, hidden), got shape {tuple(inputs_embeds.shape)}"
                )
        else:
            raise ValueError("Must specify input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = self.embed_dropout(inputs_embeds) if apply_embed_dropout else inputs_embeds

        all_hidden_states_collector = [] if output_hidden_states else None
        all_attentions_collector = [] if output_attentions else None

        for decoder_layer in self.layers:
            if all_hidden_states_collector is not None:
                all_hidden_states_collector.append(hidden_states)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=None,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if all_attentions_collector is not None:
                all_attentions_collector.append(layer_outputs[1])

        hidden_states = self.norm(hidden_states)

        if all_hidden_states_collector is not None:
            all_hidden_states_collector.append(hidden_states)

        # Convert once so both return forms expose tuples. Previously the tuple
        # form returned raw lists while the dict form returned tuples.
        hidden_tuple = tuple(all_hidden_states_collector) if all_hidden_states_collector is not None else None
        attn_tuple = tuple(all_attentions_collector) if all_attentions_collector is not None else None

        if not return_dict:
            return tuple(v for v in (hidden_states, hidden_tuple, attn_tuple) if v is not None)

        return MyBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=hidden_tuple,
            attentions=attn_tuple,
        )


def _init_linear(linear: nn.Linear):
    """Initialise ``linear`` in place and zero its bias.

    When ``INIT_ORTHOGONAL`` is set, the weight gets ``nn.init.orthogonal_``.
    Otherwise it is drawn from a normal with std ``1/sqrt(in_features)``,
    truncated at +/-3 std. The draw is done in float32 and copied back, so
    low-precision weights are still initialised correctly.
    """
    if INIT_ORTHOGONAL:
        nn.init.orthogonal_(linear.weight)
    else:
        sigma = 1.0 / math.sqrt(linear.in_features)
        buf = linear.weight.data.to(torch.float32)
        nn.init.trunc_normal_(buf, mean=0.0, std=sigma, a=-3 * sigma, b=3 * sigma)
        linear.weight.data.copy_(buf.to(linear.weight.dtype))

    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

def _initialize_state_like(template: torch.Tensor) -> torch.Tensor:
    """Return a fresh recurrent state with ``template``'s shape, dtype and device.

    With ``INIT_STRATEGY == "takase"`` the state is drawn from a normal with std
    ``1/sqrt(template.size(-1))``, truncated at +/-3 std. The draw is made in
    float32 and then cast. With any other strategy the state is all zeros.
    """
    if INIT_STRATEGY != "takase":
        return torch.zeros_like(template)

    sigma = 1.0 / math.sqrt(template.size(-1))
    state32 = torch.empty_like(template, dtype=torch.float32)
    nn.init.trunc_normal_(state32, mean=0.0, std=sigma, a=-3 * sigma, b=3 * sigma)
    return state32.to(dtype=template.dtype)

def _poisson_lognormal_filling(mean_total: int, fixed_grad_steps: int, max_iterations: int) -> int:
    """Sample how many no-grad recurrence steps to run before the grad steps.

    A rate is drawn from a log-normal with sigma ``LOGNORMAL_SIGMA``. Its ``mu``
    is chosen so that the log-normal mean equals ``mean_total`` (floored at
    1e-6). A total step count is then drawn from ``Poisson(rate)``, raised to at
    least ``fixed_grad_steps`` and capped at ``max_iterations``. The function
    returns the part of that total that exceeds ``fixed_grad_steps`` (never
    negative).

    Sampling uses torch's global CPU RNG.
    """
    cpu = torch.device("cpu")
    mu = math.log(max(mean_total, 1e-6)) - LOGNORMAL_SIGMA ** 2 / 2

    rate = torch.zeros((1,), device=cpu).log_normal_(mean=mu, std=LOGNORMAL_SIGMA).item()
    drawn = int(torch.poisson(torch.tensor([rate], dtype=torch.float32)).item())

    total_steps = min(max(drawn, fixed_grad_steps), max_iterations)
    return max(total_steps - fixed_grad_steps, 0)

class PartialRecurrentCompose(nn.Module):
    """Bottom / recurrent / top split of an ``AlternatingModelCore``'s layers.

    The split is defined by the config:

    * ``layers[:ut_recurrent_start]`` run once as the bottom stack and produce ``h0``.
    * ``layers[ut_recurrent_start : ut_recurrent_end + 1]`` (inclusive end) form
      the recurrent block. Each iteration feeds it
      ``adapter(concat([state, h0]))``, where the adapter is ``Linear(2H -> H)``.
    * The remaining layers run once on the final state, followed by ``core.norm``.

    The three slices are ``nn.ModuleList`` slices and share module objects with
    ``core.layers``.

    NOTE(review): ``forward`` runs the fixed ``NN`` no-grad + ``KK`` grad
    iterations and never calls ``_sample_steps``, so ``MEAN_RECURRENCE`` /
    ``GRAD_STEPS`` / ``MAX_ITERATIONS`` have no effect here. Confirm that the
    fixed schedule is intended.

    NOTE(review): ``AlternatingForCausalLM.post_init`` applies its
    ``_init_weights`` to every Linear, which appears to overwrite the
    ``_init_linear`` call on ``adapter`` below. Confirm intent.
    """

    def __init__(self, config, core_once: 'AlternatingModelCore'):
        super().__init__()
        self.config = config
        self.core = core_once

        start = config.ut_recurrent_start
        stop = config.ut_recurrent_end + 1
        shared_layers = self.core.layers
        self.bottom_layers = shared_layers[:start]
        self.recurrent_layers = shared_layers[start:stop]
        self.top_layers = shared_layers[stop:]
        self.hidden_size = int(config.hidden_size)

        assert INJECTION_TYPE == "linear", "Only injection_type='linear' is supported"
        self.adapter = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=True)
        _init_linear(self.adapter)

    def _process_layers(self, hidden_states, layers, attention_mask=None):
        """Run ``hidden_states`` through ``layers`` in order and return the output."""
        h = hidden_states
        for block in layers:
            h = block(
                h,
                attention_mask=attention_mask,
                position_embeddings=None,
                output_attentions=False,
            )[0]
        return h

    def _embed_once(self, input_ids, inputs_embeds):
        """Return ``inputs_embeds`` if given, otherwise embed ``input_ids``."""
        return self.core.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

    def _iteration(self, x, h0, attention_mask):
        """One recurrence step: inject ``h0`` into ``x`` via the adapter, then run the recurrent layers."""
        injected = self.adapter(torch.cat((x, h0), dim=-1))
        return self._process_layers(injected, self.recurrent_layers, attention_mask)

    def _iter_with_ckpt(self, x, h0, attention_mask):
        """Run ``_iteration``, using activation checkpointing when it can help.

        Checkpointing is applied only when ``ACTIVATION_CHECKPOINT_IMPL`` is
        ``"per-iteration"`` and at least one input requires grad.
        """
        def step(state, anchor):
            return self._iteration(state, anchor, attention_mask)

        wants_ckpt = ACTIVATION_CHECKPOINT_IMPL == "per-iteration"
        has_grad = x.requires_grad or (isinstance(h0, torch.Tensor) and h0.requires_grad)

        if not (wants_ckpt and has_grad):
            return step(x, h0)
        return torch.utils.checkpoint.checkpoint(step, x, h0, use_reentrant=False)

    def _sample_steps(self) -> tuple[int, int]:
        """Return ``(no_grad_steps, grad_steps)`` sampled via ``_poisson_lognormal_filling``."""
        grad_steps = min(GRAD_STEPS, MAX_ITERATIONS)
        return (
            int(_poisson_lognormal_filling(MEAN_RECURRENCE, grad_steps, MAX_ITERATIONS)),
            int(grad_steps),
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):
        """Embed, run bottom layers, iterate the recurrent block, run top layers, normalise.

        Returns:
            ``(final_hidden_states, None)``. The second slot stands in for ACT
            statistics, which this module does not produce.
        """
        embedded = self.core.embed_dropout(self._embed_once(input_ids, inputs_embeds))
        # Unpacking enforces a 3-D embedding tensor, as before.
        _batch, _seq, _width = embedded.shape

        h0 = self._process_layers(embedded, self.bottom_layers, attention_mask)
        state = _initialize_state_like(h0).to(embedded.dtype)

        # Warm-up iterations: no autograd graph is built.
        with torch.no_grad():
            for _ in range(NN):
                state = self._iteration(state, h0, attention_mask)

        # Iterations that contribute gradients (optionally checkpointed).
        for _ in range(KK):
            state = self._iter_with_ckpt(state, h0, attention_mask)

        top_out = self._process_layers(state, self.top_layers, attention_mask)
        return self.core.norm(top_out), None
    
class AlternatingForCausalLM(PreTrainedModel):
    """Causal-LM head on top of the partially recurrent alternating RWKV/GQA stack.

    ``core_once`` owns the embedding, decoder layers and final norm.
    ``partial_recurrent`` runs those same modules in a bottom/recurrent/top
    schedule. ``lm_head`` maps the normed hidden states to vocabulary logits.
    """

    config_class = AlternatingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AlternatingDecoderLayer"]

    def __init__(self, config: AlternatingConfig):
        super().__init__(config)
        self.config = config

        self.core_once = AlternatingModelCore(config)
        self.partial_recurrent = PartialRecurrentCompose(config, self.core_once)

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self): return self.core_once.get_input_embeddings()
    def set_input_embeddings(self, value): self.core_once.set_input_embeddings(value)
    def get_output_embeddings(self): return self.lm_head
    def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings

    def post_init(self):
        """Initialise every submodule with ``_init_weights``.

        This deliberately replaces ``PreTrainedModel.post_init``.
        NOTE(review): it therefore skips any other setup the base implementation
        performs in the installed transformers version. Confirm that this is
        intended.

        ``nn.Module.apply`` does not de-duplicate shared children. Decoder layers
        registered under both ``core_once`` and ``partial_recurrent`` are
        therefore visited more than once. Each visit fully re-initialises them,
        so nothing accumulates.
        """
        self.apply(self._init_weights)

    @staticmethod
    def _init_via_fp32_(tensor: torch.Tensor, init_fn) -> None:
        """Run the in-place ``init_fn`` on a float32 buffer and write the result into ``tensor``.

        ``tensor.to(torch.float32)`` returns the tensor itself when it is already
        fp32, but a *copy* otherwise. Initialising that copy and discarding it
        would leave bf16/fp16 weights untouched, so the result is copied back
        explicitly.
        """
        w32 = tensor.data.to(torch.float32)
        init_fn(w32)
        tensor.data.copy_(w32)

    @staticmethod
    def _residual_output_weight(module: nn.Module) -> Optional[torch.Tensor]:
        """Return the weight of ``module``'s residual output projection, or ``None``.

        Checks ``output`` first, then ``out_proj``. For ``RWKV_CMix_x070`` it
        checks ``down_proj``, which is the MLP output projection, and falls back
        to the legacy ``value`` name.
        """
        for attr in ("output", "out_proj"):
            layer = getattr(module, attr, None)
            if isinstance(layer, nn.Linear):
                return layer.weight
        if isinstance(module, RWKV_CMix_x070):
            for attr in ("down_proj", "value"):
                layer = getattr(module, attr, None)
                if isinstance(layer, nn.Linear):
                    return layer.weight
        return None

    @torch.no_grad()
    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = RESCALE_PRENORM_RESIDUAL,
        num_residuals_per_layer: int = NUM_RESIDUALS_PER_LAYER,
    ):
        """Initialise a single module.

        * Embeddings: ``U(-1e-4, 1e-4)``.
        * ``lm_head``: orthogonal, with gain
          ``0.5 * sqrt(vocab / hidden)`` when ``vocab > hidden`` and ``0.5``
          otherwise.
        * Other Linear / Conv1d layers not flagged ``_in_rwkv_module``:
          ``N(0, initializer_range)``, with zero bias.
        * Any other module with ``reset_parameters`` (not flagged
          ``_in_rwkv_module``): its own ``reset_parameters()``.
        * When ``rescale_prenorm_residual`` is set, the residual output projection
          (see ``_residual_output_weight``) is re-drawn with Kaiming-uniform and
          divided by ``sqrt(num_residuals_per_layer * num_layers)``.
        """
        if isinstance(module, nn.Embedding):
            scale = -1e-4
            nn.init.uniform_(module.weight, a=scale, b=-scale)
        elif isinstance(module, nn.Linear) and hasattr(self, 'lm_head') and module is self.lm_head:
            if self.config.vocab_size > self.config.hidden_size:
                scale = 0.5 * math.sqrt(self.config.vocab_size / self.config.hidden_size)
            else:
                scale = 0.5
            original_dtype = module.weight.dtype
            module.weight.data = nn.init.orthogonal_(module.weight.data.to(torch.float32), gain=scale).to(original_dtype)
        elif isinstance(module, (nn.Linear, nn.Conv1d)) and getattr(module, '_in_rwkv_module', False) is False:
            std = self.config.initializer_range
            self._init_via_fp32_(module.weight, lambda w: nn.init.normal_(w, mean=0.0, std=std))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif hasattr(module, 'reset_parameters') and getattr(module, '_in_rwkv_module', False) is False:
            # `apply` only ever passes nn.Module instances, so the former
            # `isinstance(module, nn.Parameter)` branch was unreachable and has
            # been dropped.
            module.reset_parameters()

        if rescale_prenorm_residual:
            p = self._residual_output_weight(module)
            if p is not None:
                denom = math.sqrt(num_residuals_per_layer * self.config.num_layers)

                def _scaled_kaiming(w):
                    nn.init.kaiming_uniform_(w, a=math.sqrt(5))
                    w.div_(denom)

                self._init_via_fp32_(p, _scaled_kaiming)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Compute logits and, when possible, the next-token loss.

        Loss selection:

        * If ``labels`` is given, use cross-entropy on ``logits[:, :-1]`` against
          ``labels[:, 1:]`` (ignore index -100).
        * Otherwise, in training mode with ``input_ids``, build labels from the
          shifted ``input_ids``. Positions where the shifted ``attention_mask``
          is 0, or where the token equals ``config.pad_token_id``, are masked.
          With sequence length <= 1 the loss is a constant 0 that requires grad.

        ``output_attentions`` / ``output_hidden_states`` only add ``None``
        placeholders to the tuple output. Hidden states and attentions are never
        collected.
        """
        return_dict = return_dict if return_dict is not None else True

        hidden_states, act_stats = self.partial_recurrent(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_labels = labels[:, 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shifted_logits.view(-1, self.vocab_size), shifted_labels.view(-1))
        elif self.training and input_ids is not None:
            if input_ids.shape[1] > 1:
                shift_logits = logits[:, :-1, :].contiguous()
                # Always copy. With batch size 1, input_ids[:, 1:] is already
                # contiguous, so .contiguous() would return a view and the
                # in-place masking below would overwrite the caller's input_ids.
                shift_labels = input_ids[:, 1:].clone(memory_format=torch.contiguous_format)

                if attention_mask is not None:
                    loss_active_mask = attention_mask[:, 1:]
                    shift_labels[loss_active_mask == 0] = -100

                pad_id = getattr(self.config, 'pad_token_id', None)
                if pad_id is not None:
                    shift_labels[shift_labels == pad_id] = -100

                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(
                    shift_logits.view(-1, self.vocab_size),
                    shift_labels.view(-1),
                )
            else:
                loss = torch.tensor(0.0, device=logits.device, requires_grad=True)

        # act_stats is currently always None (see PartialRecurrentCompose.forward),
        # so this ponder-cost branch is inactive.
        if self.config.ut_use_act and self.config.ut_ponder_tau > 0.0 and act_stats is not None:
            rema, n_upd = act_stats
            ponder = n_upd.mean() * float(self.config.ut_ponder_tau)
            loss = ponder if loss is None else (loss + ponder)

        if not return_dict:
            output_tuple = (logits,)
            if output_hidden_states:
                output_tuple += (None,)
            if output_attentions:
                output_tuple += (None,)
            return ((loss,) + output_tuple) if loss is not None else output_tuple

        return MyCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

def create_model() -> AlternatingForCausalLM:
    """Build a freshly initialised model from the default ``AlternatingConfig`` and cast it to bfloat16."""
    return AlternatingForCausalLM(AlternatingConfig()).to(dtype=torch.bfloat16)

# Build the model at import time; this module-level instance is the one handed to the trainer below.
model = create_model()

import os
import torch
import json
from transformers import Trainer
from typing import Optional, Dict, Any


import os
import torch
from typing import Optional, Dict, Any
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

import os
import torch
import torch.distributed as dist
from typing import Optional, Dict, Any
from transformers import Trainer

class CustomCompatibleTrainer(Trainer):
    """``Trainer`` whose checkpoints are written with plain ``torch.save``.

    The full ``state_dict`` goes to ``<output_dir>/pytorch_model.bin``.
    Rank 0 additionally saves the config and the tokenizer/processor.
    NOTE(review): this presumably exists because the model registers the same
    decoder layers under several parents (shared tensors), which the stock
    safetensors save path rejects. Confirm.
    """

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save to ``output_dir`` (default ``args.output_dir``) via :meth:`_save` on every rank."""
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        self._save(output_dir)

    def _get_processing_class(self):
        """Return the attached tokenizer/processor, or ``None``.

        Recent ``transformers`` store it as ``processing_class``; there,
        ``tokenizer`` is a deprecated alias that warns and is removed in v5.
        Older releases only have ``tokenizer``.
        """
        if hasattr(self, "processing_class"):
            return self.processing_class
        return getattr(self, "tokenizer", None)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, Any]] = None):
        """Write weights (and, on rank 0, config + processor) to ``output_dir``.

        Ranks where ``args.should_save`` is false skip writing. Every rank hits
        exactly one ``dist.barrier()`` when a process group is initialised, so
        the ranks stay in step.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        model_to_save = self.model
        # Unwrap DDP-style wrappers that keep the real model in `.module`.
        if hasattr(model_to_save, "module"):
            model_to_save = model_to_save.module

        if not self.args.should_save:
            if dist.is_available() and dist.is_initialized():
                dist.barrier()
            return

        if state_dict is None:
            state_dict = model_to_save.state_dict()

        weights_path = os.path.join(output_dir, "pytorch_model.bin")
        torch.save(state_dict, weights_path)

        if self.is_world_process_zero():
            processor = self._get_processing_class()
            if processor is not None and hasattr(processor, "save_pretrained"):
                processor.save_pretrained(output_dir)
            cfg = getattr(model_to_save, "config", None)
            if cfg is not None and hasattr(cfg, "save_pretrained"):
                cfg.save_pretrained(output_dir)

        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    def save_state(self):
        return super().save_state()


def load_custom_model(model_path: str, model_class, config_class):
    """Build ``model_class`` and load ``<model_path>/pytorch_model.bin`` if it exists.

    The config is read with ``config_class.from_pretrained(model_path)``. If that
    fails, a default ``config_class()`` is used and the error is reported instead
    of being silently swallowed. Weights are loaded non-strictly, and any
    missing or unexpected keys are printed.

    Args:
        model_path: directory that holds the config and/or ``pytorch_model.bin``.
        model_class: callable that takes a config and returns an ``nn.Module``.
        config_class: class exposing ``from_pretrained`` and a no-arg constructor.

    Returns:
        The constructed model, with weights loaded when the file was present.
    """
    try:
        config = config_class.from_pretrained(model_path)
    except Exception as err:  # deliberate best-effort fallback; still report why
        print(
            f"[load_custom_model] could not load config from {model_path} "
            f"({type(err).__name__}: {err}); falling back to default "
            f"{getattr(config_class, '__name__', config_class)}()"
        )
        config = config_class()

    model = model_class(config)

    weights_path = os.path.join(model_path, "pytorch_model.bin")
    if os.path.exists(weights_path):
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state_dict = torch.load(weights_path, map_location="cpu")
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"[load_custom_model] missing keys: {missing}, unexpected keys: {unexpected}")
        print(f"Model weights loaded from {weights_path}")
    else:
        print(f"Warning: Weight file not found at {weights_path}")

    return model


def create_custom_trainer(model, training_args, train_dataset, tokenizer=None):
    """Return a ``CustomCompatibleTrainer`` wired to ``model``, ``training_args`` and ``train_dataset``.

    ``tokenizer`` is forwarded unchanged through the ``tokenizer=`` keyword.
    """
    return CustomCompatibleTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )

from datasets import load_from_disk

# Training entry point (runs at import time).
# NOTE(review): this path is hard-coded and differs from DATASET_PATH defined at the
# top of the file ("/path/to/dataset/"). Confirm which one is intended.
train_dataset = load_from_disk("/path/to/tokenized/dataset/")


print("Setting training parameters...")
# NOTE(review): OUTPUT_DIR and LOGGING_DIR are not defined in the visible parts of
# this file; presumably they are set elsewhere. Verify before running.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=TRAIN_EPOCHS,
    max_grad_norm=1.0,
    warmup_ratio=WARMUP,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    weight_decay=WEIGHT_DECAY,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    lr_scheduler_kwargs={"min_lr_rate": MINIR},
    logging_dir=LOGGING_DIR,
    logging_steps=LOGGING_STEPS,
    learning_rate=LEARNING_RATE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    save_strategy=SAVE_STRATEGY,
    save_steps=SAVE_STEP,
    bf16=USE_BF16,
    dataloader_num_workers=DATALOADER_NUM_WORKERS,
    report_to=REPORT_TO,
)

print("Creating trainer...")
# No tokenizer or data_collator is passed, so Trainer uses its default collator.
# Presumably train_dataset already holds equal-length input_ids
# (and optionally attention_mask/labels). TODO confirm.
trainer = CustomCompatibleTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Counts only parameters with requires_grad. model.parameters() yields each shared
# Parameter once, so layers reachable from both core_once and partial_recurrent
# are not double-counted.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")

trainer.train()

print("Training completed!")