import os
from types import SimpleNamespace
import warnings

import torch

# Turn on the rwkv package's RWKV_JIT_ON / RWKV_CUDA_ON switches. Keep this
# above the `rwkv.model` import below; presumably the package reads these at
# import time — TODO confirm against the rwkv package docs.
os.environ.update({"RWKV_JIT_ON": "1", "RWKV_CUDA_ON": "1"})

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS


class RwkvModel:
    """Adapter exposing an ``rwkv`` RWKV model through a Hugging Face-like API.

    It provides the small surface that generation code touches:
    - a ``config`` with ``is_encoder_decoder``;
    - ``to()``;
    - a ``__call__`` returning an object with ``logits`` / ``past_key_values``;
    - a ``generate()`` method built on FastChat's ``generate_stream``.
    """

    def __init__(self, model_path):
        """Load the RWKV checkpoint at ``model_path`` with the ``"cuda fp16"`` strategy.

        Args:
            model_path: Passed to ``rwkv.model.RWKV(model=...)``. It is also
                forwarded as the ``model`` field of the generation params.
        """
        warnings.warn(
            "Experimental support. Please use ChatRWKV if you want to chat with RWKV",
            stacklevel=2,
        )
        # Mimic the HF config attribute that generation code inspects.
        self.config = SimpleNamespace(is_encoder_decoder=False)
        self.model = RWKV(model=model_path, strategy="cuda fp16")

        # The tokenizer is loaded lazily on the first ``generate`` call.
        self.tokenizer = None
        self.model_path = model_path

    def to(self, target):
        """Validate a device-move request; placement already happened at load time.

        The weights are placed by the ``"cuda fp16"`` strategy in ``__init__``,
        so nothing is moved here. Returns ``self`` so the
        ``model = model.to(device)`` idiom also works.

        Raises:
            ValueError: If ``target`` is anything other than ``"cuda"``.
        """
        # str() accepts torch.device("cuda") as well as the plain string.
        if str(target) != "cuda":
            raise ValueError(f"RwkvModel only supports 'cuda', got {target!r}")
        return self

    def __call__(self, input_ids, use_cache, past_key_values=None):
        """Run one forward pass and return HF-style outputs.

        Args:
            input_ids: Token-id tensor. Only row 0 is used, so batch size 1
                is assumed.
            use_cache: Must be truthy. The RWKV state is always returned as
                ``past_key_values``.
            past_key_values: RWKV state from a previous call, or ``None`` to
                start fresh.

        Returns:
            SimpleNamespace with ``logits`` (unsqueezed twice at dim 0) and
            ``past_key_values`` (the updated RWKV state).

        Raises:
            ValueError: If ``use_cache`` is falsy.
        """
        if not use_cache:
            raise ValueError("RwkvModel requires use_cache=True")
        tokens = input_ids[0].detach().cpu().numpy()
        logits, state = self.model.forward(tokens, past_key_values)
        # Presumably forward() returns 1-D (vocab,) last-token logits. The
        # two unsqueezes give a (batch, seq, vocab) layout — TODO confirm.
        logits = logits.unsqueeze(0).unsqueeze(0)
        return SimpleNamespace(logits=logits, past_key_values=state)

    def generate(
        self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0
    ):
        """Generate a completion, mirroring the HF ``generate`` return shape.

        Steps:
        1. Decode ``input_ids`` back to text.
        2. Stream a completion through FastChat's ``generate_stream``, using
           the stop strings and ids of the ``rwkv`` conversation template.
        3. Re-encode the final text into token ids.

        Args:
            input_ids: Token-id tensor. Only row 0 is used.
            do_sample: If false, decoding is greedy (temperature forced to 0).
            temperature: Sampling temperature, used when ``do_sample`` is true.
            max_new_tokens: Forwarded to ``generate_stream``.
            repetition_penalty: Forwarded to ``generate_stream``.

        Returns:
            A one-element list holding the prompt token ids followed by the
            generated token ids.
        """
        from transformers import AutoTokenizer

        from fastchat.serve.inference import generate_stream
        from fastchat.conversation import get_conv_template

        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                "EleutherAI/pythia-160m", use_fast=True
            )
        prompt_ids = input_ids[0].tolist()
        prompt = self.tokenizer.decode(prompt_ids)
        conv = get_conv_template("rwkv")

        if not do_sample:
            # HF semantics: do_sample=False means greedy decoding. FastChat's
            # generate_stream decodes greedily when temperature is ~0.
            # NOTE(review): confirm against the installed FastChat version.
            temperature = 0.0

        gen_params = {
            "model": self.model_path,
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }

        # Keep only the text of the last yielded chunk; it presumably holds
        # the full completion. Default to "" so an empty stream cannot leave
        # the result unbound.
        output = ""
        for res in generate_stream(self, self.tokenizer, gen_params, "cuda"):
            output = res["text"]
        output_ids = self.tokenizer.encode(output)

        return [prompt_ids + output_ids]
