########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
from transformers import AutoTokenizer
import torch.nn as nn
from torch.nn import functional as F


def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample one index from a logits vector with top-p (nucleus) filtering.

    The logits are turned into probabilities with a softmax, every
    probability smaller than the top-p cutoff is zeroed, the survivors are
    optionally sharpened/flattened by ``temperature`` and renormalised, and
    one index is drawn with ``np.random.choice``.

    Args:
        out: 1-D torch tensor of unnormalised logits.
        temperature: the surviving probabilities are raised to the power
            ``1 / temperature``; ``1.0`` leaves them untouched.
        top_p: cumulative-probability threshold. The cutoff is the first
            probability (in descending order) at which the running sum
            exceeds ``top_p``.

    Returns:
        The sampled index, as returned by ``np.random.choice``.
    """
    # detach()/cpu() make the conversion safe for tensors that track gradients
    # or live on an accelerator; for a plain CPU tensor they are no-ops.
    probs = F.softmax(out, dim=-1).detach().cpu().numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)

    exceeds = cumulative_probs > top_p
    # If no prefix exceeds top_p (e.g. top_p >= 1), np.argmax of an all-False
    # mask would be 0 and silently collapse sampling to the single most likely
    # index; in that case nothing is filtered out.
    if exceeds.any():
        cutoff = float(sorted_probs[np.argmax(exceeds)])
        probs[probs < cutoff] = 0

    if temperature != 1.0:
        # probs is a NumPy array here: it has no .pow(), use the ** operator.
        probs = probs ** (1.0 / temperature)
    probs = probs / np.sum(probs)
    return np.random.choice(a=len(probs), p=probs)


class RWKV_RNN(nn.Module):
    """RNN-mode (one token at a time) inference for an RWKV-5 checkpoint.

    The weights are loaded once in ``__init__`` and stored as plain tensors in
    the nested namespace ``self.w`` (e.g. ``self.w.blocks[0].att.key.weight``);
    they are not registered as module parameters.

    The recurrent state is a single 2-D tensor with ``2 + head_size`` rows per
    layer and ``d_model`` columns. For layer ``i``, relative to row
    ``(2 + head_size) * i``:

    * row 0: previous input of the channel-mixing sub-block,
    * row 1: previous input of the time-mixing sub-block,
    * rows 2 .. 2 + head_size - 1: the per-head ``head_size x head_size``
      matrices of the time-mixing sub-block, stored reshaped.
    """

    def __init__(self, pretrained_name: str, n_layer=24, d_model=1024, vocab_size=65536):
        """Load the checkpoint at ``pretrained_name`` and prepare its weights.

        Args:
            pretrained_name: path of a checkpoint readable by ``torch.load``.
            n_layer, d_model, vocab_size: kept for backward compatibility.
                After loading, the corresponding attributes are overwritten
                with the sizes found in the checkpoint, so that a checkpoint
                of a different size than the defaults still works.
        """
        super().__init__()
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.eval()  # put the module in evaluation mode

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        w = torch.load(pretrained_name, map_location='cpu')
        for k in list(w.keys()):
            w[k] = w[k].float()  # convert to f32 type
            if '.time_' in k:
                w[k] = w[k].squeeze()
            if '.time_decay' in k:
                # stored as a log-log decay; turn it into a factor in (0, 1)
                w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
            if '.time_faaaa' in k:
                w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_decay'].shape[0]
        # The checkpoint is the source of truth for the model width; relying on
        # the constructor default breaks layer_norm for any other model size.
        self.d_model = w['blocks.0.ln1.weight'].shape[0]
        self.head_size = self.d_model // self.n_head
        self.vocab_size = w['emb.weight'].shape[0]

        # Build self.w from the flat key names, for example
        # "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
        self.w = types.SimpleNamespace()
        self.w.blocks = {}
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    # numeric components index the ``blocks`` dict
                    p = int(p)
                    if p not in here:
                        here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

        # Run every block present in the checkpoint instead of a fixed default.
        self.n_layer = len(self.w.blocks)

    def layer_norm(self, x, w):
        """Apply layer normalisation over ``d_model`` using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(x, (self.d_model,), weight=w.weight, bias=w.bias)

    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        """Run the channel-mixing sub-block of layer ``i`` for one token.

        ``x`` is interpolated with the previous input kept in ``state`` and the
        stored input is then replaced by ``x`` (``state`` is modified in place).

        Returns:
            The sub-block output, to be added to the residual stream.
        """
        i0 = (2+self.head_size)*i+0
        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        """Run the time-mixing sub-block of layer ``i`` for one token.

        Updates, in place, both the previous-input row and the per-head
        matrices that ``state`` holds for this layer.

        Returns:
            The sub-block output, to be added to the residual stream.
        """
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1
        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
        state[i1] = x

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        # per-head S x S matrices of this layer
        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        a = k @ v  # outer product of key and value, one per head
        x = r @ (time_first * a + s)
        s = a + time_decay * s

        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x

    def forward(self, token, state):
        """Feed one token through the network.

        Args:
            token: index into the embedding table.
            state: recurrent state returned by a previous call, or ``None`` to
                start from an all-zero state.

        Returns:
            ``(logits, state)``: the float logits produced by the output head
            and the state. A state passed in is updated in place and the same
            tensor is returned.
        """
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.n_layer * (2+self.head_size), self.d_model)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.n_layer):
                block = self.w.blocks[i]
                att = block.att
                x = x + self.time_mixing(self.layer_norm(x, block.ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = block.ffn
                x = x + self.channel_mixing(self.layer_norm(x, block.ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state


def main():
    """Sample a few continuations of a fixed prompt and stream them to stdout.

    The prompt is fed to the model token by token once; every trial then
    restarts from a copy of the resulting logits and state.

    Raises:
        ValueError: if the tokenizer produces no tokens for the prompt.
    """
    pretrained_name = '/mnt/data/user/tc_agi/cyf/rwkv5/RWKV-5-World-0.4B-v2-20231113-ctx4096.pt'
    prompt = "\nElon Musk has"
    # Alternative prompt (Chinese text): "\n我们发现"
    NUM_TRIALS = 3
    LENGTH_PER_TRIAL = 100
    TEMPERATURE = 1.0
    TOP_P = 0.7

    tokenizer = AutoTokenizer.from_pretrained('../../tokenizers/rwkv5-tok', trust_remote_code=True)
    # NOTE(review): RWKV_RNN keeps its weights in a SimpleNamespace rather than
    # as registered parameters, so .cuda() does not appear to move them and
    # inference presumably stays on the CPU — confirm before relying on a GPU.
    model = RWKV_RNN(pretrained_name).cuda()

    print('\nPreprocessing prompt (slow version. see v2/rwkv/model.py for fast version)')
    init_out, init_state = None, None
    for token in tokenizer.encode(prompt):
        init_out, init_state = model.forward(token, init_state)
    if init_out is None:
        # Without this check an empty prompt fails later on with an
        # unrelated-looking error when the initial logits are cloned.
        raise ValueError('prompt produced no tokens; nothing to condition the model on')

    for TRIAL in range(NUM_TRIALS):
        print(f'\n\n--[ Trial {TRIAL} ]-----------------', prompt, end="")
        all_tokens = []
        out_last = 0  # index of the first token that has not been printed yet
        out, state = init_out.clone(), init_state.clone()
        for i in range(LENGTH_PER_TRIAL):
            # sample_logits returns a NumPy integer; hand a plain int around
            token = int(sample_logits(out, TEMPERATURE, TOP_P))
            all_tokens.append(token)
            try:
                tmp = tokenizer.decode(all_tokens[out_last:])
                if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
                    print(tmp, end="", flush=True)
                    out_last = i + 1
            except Exception:
                # Decoding is best effort: the pending tokens are retried
                # together with the next one. KeyboardInterrupt / SystemExit
                # are deliberately not swallowed.
                pass
            out, state = model.forward(token, state)
    print('\n')


if __name__ == "__main__":
    main()
