########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import torch, types, os, gc, math
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import os
from einops import rearrange

# Import flash attention utilities.
# Each name is resolved on its own so that failing to find one of them does not
# discard a successful import of the other.
try:
    from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
except ImportError:
    try:
        from flash_attn import flash_attn_varlen_func
    except ImportError:
        # Define fallback function if not available
        def flash_attn_varlen_func(*args, **kwargs):
            raise NotImplementedError("flash_attn_varlen_func not available")

try:
    from transformers.modeling_flash_attention_utils import index_first_axis
except ImportError:
    try:
        # Location tried by earlier revisions of this file; kept for compatibility.
        from transformers.modeling_rope_utils import index_first_axis
    except ImportError:
        try:
            from flash_attn.bert_padding import index_first_axis
        except ImportError:
            def index_first_axis(tensor, indices):
                """Pure-PyTorch fallback: select rows of `tensor` along dim 0 at `indices`."""
                return tensor[indices]
def __nop(ob):
    """Identity decorator: hand ``ob`` back exactly as it was received.

    Bound to ``MyFunction`` below, which decorates methods of the RWKV
    modules at class-definition time; as bound here it leaves them unchanged.
    """
    return ob


MyFunction = __nop

# Global model configuration shared by the kernel loader and the layers below.
args = types.SimpleNamespace(
    n_layer=24,
    n_embd=2048,
    vocab_size=152064,
    ctx_len=32768,
)

########################################################################################################
# CUDA Kernel
########################################################################################################

# head_size_a is compiled into the WKV6 kernel as -D_N_ (ctx_len as -D_T_),
# so it must match the built extension.
args.head_size_a = 64 # don't change
args.head_size_divisor = 8 # don't change

import os
from torch.utils.cpp_extension import load

def load_wkv6_extension(args, search_paths=None):
    """Build (or load) the WKV6 CUDA extension with ``torch.utils.cpp_extension.load``.

    The sources ``wkv6_op.cpp`` and ``wkv6_cuda.cu`` are looked up in a ``cuda``
    sub-directory of each candidate base path, in order. The first directory
    that contains both files is used.

    Args:
        args: namespace providing ``head_size_a`` (compiled in as ``-D_N_``) and
            ``ctx_len`` (compiled in as ``-D_T_``).
        search_paths: optional iterable of base directories to search instead of
            the defaults (this file's directory, then a placeholder project path).

    Returns:
        The extension module returned by ``load``.

    Raises:
        FileNotFoundError: if no candidate ``cuda`` directory holds both sources.
    """
    # Try several possible locations for the CUDA sources.
    if search_paths is None:
        search_paths = [
            os.path.dirname(os.path.abspath(__file__)),  # Directory where the current file is located
            "/path/to/your/project/Hybrid_rwkv/model",  # Placeholder: parent directory of your cuda folder
            # You can add more possible paths here as needed
        ]
    source_names = ("wkv6_op.cpp", "wkv6_cuda.cu")

    searched = []
    for base_path in search_paths:
        cuda_dir = os.path.join(base_path, "cuda")
        searched.append(cuda_dir)
        # Require the actual sources, not just the directory, so a stray empty
        # `cuda` folder does not shadow a valid one later in the list.
        if all(os.path.isfile(os.path.join(cuda_dir, name)) for name in source_names):
            break
    else:
        raise FileNotFoundError(
            f"CUDA source files {list(source_names)} not found; searched: {searched}"
        )

    wkv6_cuda = load(
        name="wkv6",
        sources=[os.path.join(cuda_dir, name) for name in source_names],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            "--extra-device-vectorization",
            f"-D_N_={args.head_size_a}",
            f"-D_T_={args.ctx_len}"
        ]
    )
    return wkv6_cuda

# Build/load the WKV6 CUDA extension once, when this module is imported;
# WKV_6 below calls its forward/backward entry points.
wkv6_cuda = load_wkv6_extension(args=args)

class WKV_6(torch.autograd.Function):
    """Autograd wrapper around the compiled WKV6 CUDA kernels in ``wkv6_cuda``.

    Every tensor input must be bfloat16 and contiguous, and ``C // H`` must equal
    the global ``args.head_size_a`` (the head size the kernel is compiled with).
    ``B, T, C, H`` are plain integers and receive no gradient.
    """
    @staticmethod
    def forward(ctx, B, T, C, H, r, k, v, w, u): # forward: r, k, v, w, u => y
        """Run the WKV6 forward kernel.

        Args:
            B, T, C, H: integer sizes forwarded to the kernel; the output is
                allocated as (B, T, C) and ``C // H`` is the per-head size.
            r, k, v, w: bfloat16, contiguous; presumably (B, T, C) like the
                output buffer — TODO confirm against the kernel.
            u: bfloat16, contiguous; its gradient is returned as (H, C // H).

        Returns:
            y: (B, T, C) bfloat16 tensor written by ``wkv6_cuda.forward``.
        """
        with torch.no_grad():
            assert r.dtype == torch.bfloat16
            assert k.dtype == torch.bfloat16
            assert v.dtype == torch.bfloat16
            assert w.dtype == torch.bfloat16
            assert u.dtype == torch.bfloat16
            assert args.head_size_a == C // H
            # Sizes are kept on ctx because backward needs them to allocate gradients.
            ctx.B = B
            ctx.T = T
            ctx.C = C
            ctx.H = H
            assert r.is_contiguous()
            assert k.is_contiguous()
            assert v.is_contiguous()
            assert w.is_contiguous()
            assert u.is_contiguous()
            ctx.save_for_backward(r, k, v, w, u)
            # Output buffer is left uninitialised and handed to the kernel; the
            # commented-out uniform_ fill looks like a debugging aid for spotting
            # elements the kernel fails to write.
            y = torch.empty((B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
            wkv6_cuda.forward(B, T, C, H, r, k, v, w, u, y)
            return y

    @staticmethod
    def backward(ctx, gy): # backward: gy => gr, gk, gv, gw, gu
        """Run the WKV6 backward kernel.

        Args:
            gy: bfloat16, contiguous gradient w.r.t. ``y``.

        Returns:
            ``None`` for each of B, T, C, H, followed by gr, gk, gv, gw as
            (B, T, C) tensors and gu reshaped to (H, C // H).
        """
        with torch.no_grad():
            assert gy.dtype == torch.bfloat16
            B = ctx.B
            T = ctx.T
            C = ctx.C
            H = ctx.H
            assert gy.is_contiguous()
            r, k, v, w, u = ctx.saved_tensors
            # Gradient buffers are uninitialised; wkv6_cuda.backward fills them.
            gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
            gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
            gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
            gw = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
            # gu is produced per batch element as (B, C) ...
            gu = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)#.uniform_(-100, 100)
            wkv6_cuda.backward(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu)
            # ... then reduced over the batch and reshaped to u's (H, head_size) layout.
            gu = torch.sum(gu, 0).view(H, C//H)
            return (None, None, None, None, gr, gk, gv, gw, gu) # return gradients for r,k,v,w,u

def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
    """Functional entry point for the WKV6 CUDA op, with autograd support.

    Forwards all arguments unchanged to ``WKV_6.apply``. ``WKV_6.forward``
    asserts bfloat16, contiguous ``r, k, v, w, u`` and ``C // H == args.head_size_a``.
    """
    kernel_inputs = (B, T, C, H, r, k, v, w, u)
    return WKV_6.apply(*kernel_inputs)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned (no copy).
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Broadcast a new repeat axis next to the head axis, then fold it in.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states

########################################################################################################
# RWKV TimeMix
########################################################################################################

class RWKV_Tmix_x060b(nn.Module):
    """RWKV-6 time-mixing layer whose WKV recurrence runs in the ``wkv6_cuda`` kernel.

    ``args`` must provide ``n_embd``, ``dim_att``, ``n_layer`` and ``head_size_a``;
    the attention width ``dim_att`` is split into ``dim_att // head_size_a`` heads.
    ``WKV_6`` asserts bfloat16 inputs, so this layer presumably runs with bf16
    parameters — confirm against the training setup.
    """

    def __init__(self, args, layer_idx):
        super().__init__()
        self.args = args
        self.layer_idx = layer_idx

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        # The attention width must split into whole heads: time_faaaa below is
        # reshaped to (n_head, head_size) and the kernel asserts C // H == head_size.
        # (Checking dim_att % n_head instead lets e.g. dim_att=100, head_size=64 through.)
        assert args.dim_att % self.head_size == 0, (
            f"dim_att ({args.dim_att}) must be a multiple of head_size_a ({self.head_size})"
        )

        with torch.no_grad():
            # With a single layer there is nothing to interpolate towards; use 0.0
            # instead of dividing by zero.
            ratio_0_to_1 = layer_idx / (args.n_layer - 1) if args.n_layer > 1 else 0.0  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_idx / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1))
            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            D_MIX_LORA = 32
            self.time_maa_rkvw_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA*4))
            self.time_maa_rkvw_w2 = nn.Parameter(torch.zeros(4, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01))

            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed.reshape(1,1,args.dim_att))
            D_DECAY_LORA = 64
            self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
            self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01))

            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        # Pads one zero step at the front of the time axis and drops the last
        # step, i.e. shifts a (B, T, C) input by one token.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)

        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.ln_x = nn.LayerNorm(args.dim_att)

    @MyFunction
    def jit_func(self, x):
        """Token-shift mixing followed by the r/k/v projections and the decay LoRA.

        Returns r, k, v (each (B, T, dim_att)) and the data-dependent decay
        w = time_decay + tanh(x_w @ w1) @ w2, also (B, T, dim_att).
        """
        B, T, C = x.size()

        xx = self.time_shift(x) - x

        xxx = x + xx * self.time_maa_x
        # Low-rank mixing: (B, T, 4*32) -> (4, B*T, 32) -> bmm -> (4, B, T, C).
        xxx = torch.tanh(xxx @ self.time_maa_rkvw_w1).view(B*T, 4, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_rkvw_w2).view(4, B, T, C)

        r, k, v, w = xxx.unbind(dim=0)
        r = x + xx * (self.time_maa_r + r)
        k = x + xx * (self.time_maa_k + k)
        v = x + xx * (self.time_maa_v + v)
        w = x + xx * (self.time_maa_w + w)

        r = self.receptance(r)
        k = self.key(k)
        v = self.value(v)
        w = self.time_decay + torch.tanh(w @ self.time_decay_w1) @ self.time_decay_w2
        return r, k, v, w

    @MyFunction
    def jit_func_2(self, x):
        """Normalise the WKV output over dim_att and project back to n_embd."""
        x = self.ln_x(x)
        x = self.output(x)
        return x

    def forward(self, x):
        B, T, _ = x.size()
        H = self.n_head

        r, k, v, w = self.jit_func(x)
        # The kernel operates on the attention width (channel size of r/k/v/w),
        # which differs from the input width whenever dim_att != n_embd.
        x = RUN_CUDA_RWKV6(B, T, r.size(-1), H, r, k, v, w, u=self.time_faaaa)

        return self.jit_func_2(x)

########################################################################################################
# Attention Cross Attention
########################################################################################################
class ScaleDotProductCrossAttention(nn.Module):
    """Cross-attention computed with ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        layer_number: stored on the module; not read by ``forward``.
        softmax_scale: passed to SDPA as ``scale`` (``None`` keeps SDPA's default).
        attention_dropout: dropout probability, used only while training.
    """

    def __init__(self, layer_number, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.layer_number = layer_number
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v, attn_mask=None):
        """Attend from ``q`` to ``k``/``v`` and merge the heads.

        Arguments
        ---------
            q, k, v: given to SDPA as-is, i.e. laid out (B, Head, L, D); the
                output rearrange below relies on that layout.
            attn_mask: optional 3-D (B, L_q, L_k) mask; True means the position
                takes part in attention. A head axis is inserted and repeated
                ``q.shape[1]`` times before it reaches SDPA.

        Returns
        -------
            Contiguous tensor of shape (B, L, Head * D).
        """
        if attn_mask is not None:
            # (B, L_q, L_k) -> (B, Head, L_q, L_k)
            attn_mask = attn_mask[:, None, :, :].repeat(1, q.shape[1], 1, 1)

        dropout_p = self.dropout_p if self.training else 0.0

        if attn_mask is not None and q.device.type == "cuda":
            q, k, v = (t.contiguous() for t in (q, k, v))

        o = nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=False,
            scale=self.softmax_scale,
        )

        # (B, Head, L, D) -> (B, L, Head * D)
        return rearrange(o, 'B Head L D -> B L (Head D)').contiguous()



class FlashAttnCrossAttention(nn.Module):
    """Cross-attention through ``flash_attn_varlen_func`` with masked keys removed.

    Every query row is kept (queries are only flattened). Key/value rows are
    packed with the per-sample key mask, so masked keys never reach the kernel.

    Args:
        layer_number: stored on the module; not read by ``forward``.
        softmax_scale: forwarded to ``flash_attn_varlen_func`` (``None`` keeps the
            kernel's default scaling).
        attention_dropout: dropout probability, used only while training.
    """

    def __init__(self, layer_number, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.layer_number = layer_number
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def _get_unpad_data(self, attention_mask: torch.Tensor):
        """
        Retrieves indexing data required to pack the valid (unmasked) key rows.

        Arguments:
            attention_mask (`torch.Tensor`):
                Boolean or int tensor indexed as (batch_size, q_len, kv_len); non-zero means valid.
                Only the first query row (``[:, 0, :]``) is read, so the mask is assumed to be
                identical for every query position.

        Return:
            indices (`torch.Tensor`):
                Flat indices (into batch_size * kv_len) of the valid key positions.
            cu_seqlens (`torch.Tensor`):
                int32 cumulative valid-key lengths with a leading 0, shape (batch_size + 1,).
            max_seqlen_in_batch (`int`):
                Longest valid key length in the batch.
            seqlens_in_batch (`torch.Tensor`):
                int32 valid-key length per sample, shape (batch_size,).
        """
        key_mask = attention_mask[:, 0, :]  # same for every query; use the first
        seqlens_in_batch = key_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(key_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        return (
            indices,
            cu_seqlens,
            max_seqlen_in_batch,
            seqlens_in_batch
        )

    def unpad_q(self, q_layer):
        """Flatten queries for the varlen kernel; no query row is dropped.

        Arguments:
            q_layer: (batch_size, q_len, num_heads, head_dim).

        Return:
            q_layer reshaped to (batch_size * q_len, num_heads, head_dim),
            int32 cu_seqlens_q = [0, q_len, 2*q_len, ...] of shape (batch_size + 1,),
            and q_len as the maximum query length.
        """
        batch_size, q_seq_len, num_heads, head_dim = q_layer.shape
        cu_seqlens_q = torch.tensor([q_seq_len] * batch_size, dtype=torch.int32, device=q_layer.device)
        cu_seqlens_q = nn.functional.pad(torch.cumsum(cu_seqlens_q, dim=0, dtype=torch.int32), (1, 0))
        q_layer = q_layer.reshape(batch_size * q_seq_len, num_heads, head_dim)

        return (
            q_layer,
            cu_seqlens_q,
            q_seq_len)

    def unpad_kv(self, key_layer, value_layer, attn_mask):
        """Pack the valid key/value rows selected by ``attn_mask``.

        Arguments:
            key_layer, value_layer: (batch_size, kv_len, num_kv_heads, head_dim).
            attn_mask: mask accepted by ``_get_unpad_data``.

        Return:
            packed keys and values of shape (total_valid, num_kv_heads, head_dim),
            followed by the four values returned by ``_get_unpad_data``.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k, split_size = self._get_unpad_data(attn_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )

        return (
            key_layer,
            value_layer,
            indices_k,
            cu_seqlens_k,
            max_seqlen_in_batch_k,
            split_size)

    def forward(self, q, k, v, attn_mask=None):
        """
        Implements multi-head softmax cross-attention with the flash-attention varlen API.

        Arguments
        ---------
            q: (B, H, L_q, D); transposed to (B, L_q, H, D) below.
            k, v: (B, H_kv, L_k, D); transposed the same way.
            attn_mask: key-validity mask indexed as (B, L_q, L_k) (see `_get_unpad_data`),
                or None to attend to every key.

        Returns
        -------
            Contiguous tensor of shape (B, L_q, H * D).
        """
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # NOTE: don't know if it's necessary
        if q.device.type == "cuda" and attn_mask is not None:
            q = q.contiguous()
            k = k.contiguous()
            v = v.contiguous()

        batch_size, q_seq_len, head_num, head_dim = q.shape
        if attn_mask is None:
            # No mask given: every key position is valid.
            attn_mask = torch.ones(batch_size, 1, k.shape[1], dtype=torch.bool, device=k.device)

        # Flatten the queries and pack the valid keys/values; get cu_seqlens and indices.
        q, cu_seq_lens_q, max_seqlen_in_batch_q = self.unpad_q(q)
        k, v, indices_kv, cu_seq_lens_kv, max_seqlen_in_batch_kv, split_size = self.unpad_kv(k, v, attn_mask)

        attn_output = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seq_lens_q,
            cu_seqlens_k=cu_seq_lens_kv,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_kv,
            dropout_p=self.dropout_p if self.training else 0.0,
            softmax_scale=self.softmax_scale,  # was hard-coded to None, ignoring the constructor arg
            causal=False,
        )

        return attn_output.reshape(batch_size, q_seq_len, head_num, head_dim).flatten(2, 3).contiguous()

    
########################################################################################################
# RWKV ChannelMix
########################################################################################################

class RWKV_CMix_x060(nn.Module):
    """RWKV-6 channel-mixing (feed-forward) layer.

    out = sigmoid(receptance(xr)) * value(relu(key(xk)) ** 2), where xk and xr
    blend each token with its predecessor using the learned per-channel
    weights ``time_maa_k`` and ``time_maa_r``.
    """

    def __init__(self, args, layer_idx):
        super().__init__()
        self.args = args
        self.layer_idx = layer_idx
        # Pads one zero step at the front of the time axis and drops the last
        # step, i.e. shifts a (B, T, C) input by one token.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            n_embd = args.n_embd
            exponent = 1.0 - (layer_idx / args.n_layer)  # 1 at layer 0, towards 0 at the top
            channel_pos = torch.tensor([c / n_embd for c in range(n_embd)]).view(1, 1, n_embd)
            self.time_maa_k = nn.Parameter(1.0 - channel_pos.pow(exponent))
            self.time_maa_r = nn.Parameter(1.0 - channel_pos.pow(exponent))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    def forward(self, x):
        delta = self.time_shift(x) - x
        hidden = self.key(x + delta * self.time_maa_k).relu().pow(2)
        gate = torch.sigmoid(self.receptance(x + delta * self.time_maa_r))
        return gate * self.value(hidden)

########################################################################################################
# RWKV Block
########################################################################################################

class Block(nn.Module):
    """One RWKV-6 layer (time-mix + channel-mix), optionally followed by a gated
    cross-attention read-out over scene embeddings.

    The cross-attention branch exists only when ``args.is_hyper_enabled`` is true.
    It then also needs ``args.num_key_value_heads``, ``args.num_attention_heads``,
    ``args.hidden_size`` and ``args.attention_dropout``.
    """
    def __init__(self, args, layer_idx):
        super().__init__()
        self.args = args
        self.layer_idx = layer_idx

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_idx == 0:
            # Extra LayerNorm applied to the block input by the first block only.
            self.ln0 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x060b(args, layer_idx)
        self.ffn = RWKV_CMix_x060(args, layer_idx)

        # Add cross-attention components for scene embeddings
        self.is_hyper_enabled = getattr(args, "is_hyper_enabled", False)
        if self.is_hyper_enabled:
            self.cross_attn_gating_type = getattr(args, "cross_attn_gating_type", "whole-dynamic-sigmoid")
            self.num_key_value_heads = args.num_key_value_heads
            self.num_attention_heads = args.num_attention_heads
            self.head_dim = args.hidden_size // args.num_attention_heads
            self.num_key_value_groups = args.num_attention_heads // args.num_key_value_heads
            self.attention_dropout = args.attention_dropout

            # Initialize cross-attention components
            self.cross_attn_q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=True)
            # Keys and values share one projection; scene_cross_attn splits its
            # output channels as (head, kv, head_dim).
            self.cross_attn_kv_proj = nn.Linear(args.hidden_size, args.num_key_value_heads * self.head_dim * 2, bias=True)

            # NOTE(review): scene_cross_attn always uses cross_attn_gate_proj, which is
            # only created for "whole-dynamic*" gating types; and the default
            # "whole-dynamic-sigmoid" gets a bare Linear with no sigmoid — confirm intended.
            if self.cross_attn_gating_type.startswith("whole-dynamic"):
                if "tanh" in self.cross_attn_gating_type:
                    self.cross_attn_gate_proj = nn.Sequential(
                        nn.Linear(args.hidden_size, 1),
                        nn.Tanh()
                    )
                else:
                    self.cross_attn_gate_proj = nn.Sequential(
                        nn.Linear(args.hidden_size, 1),
                    )

                if self.cross_attn_gating_type.endswith("warmup"):
                    self.cross_attn_warm_up_gate = torch.nn.Parameter(torch.zeros(1))

            self.cross_attn_core_attention = FlashAttnCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout)

    def forward(self, hidden_states, scene_embeds=None, cross_attn_mask=None, token_type=None):
        """Run the RWKV sub-layers, then (optionally) the scene cross-attention.

        Args:
            hidden_states: (bsz, q_len, hidden) block input.
            scene_embeds: (bsz, vis_len, hidden) features to attend to, or None to
                skip cross-attention.
            cross_attn_mask: (bsz, vis_len) mask; non-zero marks scene positions
                that may be attended. None attends to every position.
            token_type: (bsz, q_len) token-type ids. Types <= 2 are treated as text;
                samples without any type-3 token get no cross-attention read-out.
                Cross-attention is skipped when this is None.

        Returns:
            Updated hidden states, same shape as ``hidden_states``.
        """
        if self.is_hyper_enabled:
            bsz, q_len, _ = hidden_states.size()
            # Queries are projected from the block input, before ln0 / att / ffn.
            cross_attn_query_states = self.cross_attn_q_proj(hidden_states)
            cross_attn_query_states = cross_attn_query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim)
        if self.layer_idx == 0:
            hidden_states = self.ln0(hidden_states)
        hidden_states = hidden_states + self.att(self.ln1(hidden_states))
        hidden_states = hidden_states + self.ffn(self.ln2(hidden_states))

        # Check token_type for None before reading its shape.
        run_cross_attn = (
            self.is_hyper_enabled
            and scene_embeds is not None
            and token_type is not None
            and hidden_states.shape[1] == token_type.shape[1]
        )
        if run_cross_attn:
            # True for samples containing at least one type-3 (visual) token;
            # samples flagged False get their cross-attention output zeroed.
            all_text_mask = (token_type == 3).sum(dim=-1).bool()

            hidden_states = self.scene_cross_attn(
                hidden_states,
                cross_attn_query_states,
                scene_embeds,
                token_type=token_type,
                cross_attn_mask=cross_attn_mask,
                all_text_mask=all_text_mask
            )
        return hidden_states

    def scene_cross_attn(self, text_state, text_query, scene_embeds, token_type, cross_attn_mask=None, all_text_mask=None):
        '''Add a gated cross-attention read-out over ``scene_embeds`` to the text tokens.

        text_query: [bs n h d]
        text_state: [bs n d]
        scene_embeds: [bs, vis_n, d]
        token_type: [bs, n]; positions with type <= 2 are the text tokens updated here
        cross_attn_mask: [bs, vis_n] key mask, or None to attend to every scene position
        all_text_mask: [bs] bool; False zeroes that sample's read-out. None applies no zeroing.
        '''
        if scene_embeds is None or not self.is_hyper_enabled:
            return text_state

        # 1. select the text tokens (type <= 2) of every sample
        text_mask = ((token_type - 2) <= 0).bool()
        pure_text_query = [sample_query[sample_mask] for sample_query, sample_mask in zip(text_query, text_mask)]

        # 2. pad them back into a dense batch; padding_attn_mask marks the real rows
        text_query = torch.nn.utils.rnn.pad_sequence(pure_text_query, batch_first=True)
        padding_attn_mask = torch.ones(text_query.shape[:-2], dtype=torch.bool, device=text_state.device)
        for i, tensor in enumerate(pure_text_query):
            padding_attn_mask[i, len(tensor):] = False  # Mark padded elements as False

        # obtain dynamic gate value: one scalar per selected text token
        gate_value = self.cross_attn_gate_proj(text_state[text_mask])
        # Same condition __init__ uses to create cross_attn_warm_up_gate.
        if self.cross_attn_gating_type.endswith("warmup"):
            gate_value = gate_value * self.cross_attn_warm_up_gate.tanh()

        scene_embeds = scene_embeds.contiguous()
        scene_embeds = self.cross_attn_kv_proj(scene_embeds)
        text_query = text_query.transpose(1, 2)  # [b, n_text, h, d] -> [b, h, n_text, d]

        vision_kv = rearrange(scene_embeds, 'BL Lv (H KV D) -> KV BL H Lv D', KV=2, H=self.num_key_value_heads)
        vision_key = vision_kv[0].contiguous() # [b h s d]
        vision_value = vision_kv[1].contiguous()

        vision_key = repeat_kv(vision_key, self.num_key_value_groups)
        vision_value = repeat_kv(vision_value, self.num_key_value_groups)

        # expand the key mask over the text-query axis: [b, s] -> [b, n_text, s]
        if cross_attn_mask is None:
            attention_mask = None
        else:
            attention_mask = cross_attn_mask[:, None, :].repeat(1, text_query.shape[2], 1)
        vision_context = self.cross_attn_core_attention(text_query, vision_key, vision_value, attn_mask=attention_mask)

        # zero the read-out for samples flagged False in all_text_mask
        if all_text_mask is not None:
            vision_context = all_text_mask[:, None, None] * vision_context

        # Apply dynamic gate and scatter back to the text-token positions
        extended_attn_output = torch.zeros_like(text_state, dtype=text_state.dtype, device=text_state.device)
        extended_attn_output[text_mask] = extended_attn_output[text_mask] + vision_context[padding_attn_mask] * gate_value
        text_state = text_state + extended_attn_output
        # NOTE Min: equivalent to the line below, written out-of-place to avoid an error under deepspeed zero3
        # text_state[text_mask] = text_state[text_mask] + vision_context[padding_attn_mask] * gate_value

        return text_state

########################################################################################################
# RWKV Model
########################################################################################################

class RWKV(nn.Module):
    """Stack of RWKV-6 ``Block`` layers between a token embedding and an LM head."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Derived sizes are written back onto ``args`` so every Block sees them.
        args.dim_att = args.n_embd
        args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)

        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        # self.init_params() # !!! When you train RWKV from scratch, try my initialization for best performance !!!

    def forward(self, idx, scene_embeds=None, cross_attn_mask=None, token_type=None):
        """Map token ids to logits.

        Args:
            idx: token ids accepted by ``nn.Embedding``.
            scene_embeds, cross_attn_mask, token_type: optional cross-attention
                inputs handed unchanged to every Block; with the defaults (None)
                only the RWKV path runs.

        Returns:
            Logits from ``self.head`` with last dimension ``vocab_size``.
        """
        x = self.emb(idx)

        for block in self.blocks:
            # Block.forward takes the cross-attention inputs positionally.
            x = block(x, scene_embeds, cross_attn_mask, token_type)

        x = self.ln_out(x)
        x = self.head(x)

        return x

    def init_params(self):
        """Apply the RWKV-6 initialisation recipe to this model's tensors, in place.

        For each state-dict entry (printed with its shape as it goes):
          * names containing ``ln_``, ``.ln`` or ``time_``, or ending in ``_w``,
            ``_w1``, ``_w2``, ``_bias``: left as constructed, except
            ``*ln_x.weight``, which is set to ``((layer + 1) / n_layer) ** 0.7``;
          * ``emb.weight``: uniform in [-1e-4, 1e-4];
          * ``head.weight``: orthogonal, gain ``0.5 * sqrt(vocab_size / n_embd)``
            (0.5 when ``vocab_size <= n_embd``);
          * tensors that are not 2-D ``.weight`` matrices (e.g. biases and the
            warm-up gate of the cross-attention branch): left as constructed;
          * remaining 2-D weights: zeros for ``.att.output.``, ``.ffn.value.`` and
            ``.ffn.receptance.``; otherwise orthogonal with gain 0.1 for
            ``.att.key.`` / ``.att.gate.`` and 1.0 elsewhere.
        """
        m = self.state_dict()
        n_params = 0

        # state_dict() tensors share storage with the module's parameters, so each
        # result is written into `p` in place. Rebinding m[n] to a fresh tensor
        # would only change this local dict and leave the model untouched.
        for n, p in m.items():
            shape = p.shape

            s0 = str(shape[0]) if len(shape) > 0 else ""
            s1 = str(shape[1]) if len(shape) > 1 else ""
            s2 = str(shape[2]) if len(shape) > 2 else ""
            print(f"{s0.ljust(5)} {s1.ljust(5)} {s2.ljust(5)} {n}", end="")

            scale = 1.0
            if "ln_" in n or ".ln" in n or "time_" in n or n.endswith('_w') or n.endswith('_w1') or n.endswith('_w2') or n.endswith('_bias'):
                if 'ln_x.weight' in n:
                    layer_scale = (1+int(n.split('.')[1])) / self.args.n_layer
                    with torch.no_grad():
                        p.copy_((p * 0.0) + (layer_scale ** 0.7))
                print()
            elif n == "emb.weight":
                scale = -1e-4
                nn.init.uniform_(p, a=scale, b=-scale) # !!! If you are using positional embedding, maybe it's better to remove block.0.ln0, and use default initialization for emb.weight instead of my uniform_(a=-1e-4, b=1e-4) !!!
                print(f" [scale {scale}]")
            elif n == "head.weight":
                if self.args.vocab_size > self.args.n_embd:
                    scale = 0.5 * math.sqrt(self.args.vocab_size / self.args.n_embd)
                else:
                    scale = 0.5
                nn.init.orthogonal_(p, gain=scale)
                print(f" [scale {scale}]")
            elif not n.endswith('.weight') or len(shape) != 2:
                # Not a plain weight matrix: keep the constructor's initialisation.
                print()
            else:
                for kk in [".att.output.", ".ffn.value.", ".ffn.receptance."]:
                    if kk in n:
                        scale = 0
                for kk in [".att.key."]:
                    if kk in n:
                        scale = 0.1
                for kk in [".att.gate."]:
                    if kk in n:
                        scale = 0.1

                print(f" [scale {scale}]")

                # Initialise in the default dtype, then copy into the parameter
                # (whatever its dtype).
                new_weight = torch.empty((shape[0], shape[1]), device=p.device)
                if scale == 0:
                    nn.init.zeros_(new_weight)
                else:
                    nn.init.orthogonal_(new_weight, gain=scale)
                with torch.no_grad():
                    p.copy_(new_weight)

            n_params += p.numel()

        print('model params', n_params)
        gc.collect()
        torch.cuda.empty_cache()
