import torch
import time
import torch.nn.functional as F
from transformers import DynamicCache, SlidingWindowCache
from rnsa.qwen3 import RNSAQwen3ForCausalLM, RNSAQwen3Config
from rnsa.llama import RNSALlamaForCausalLM, RNSALlamaConfig
from rnsa.cache_utils import RNSACache
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers import Qwen3ForCausalLM, LlamaForCausalLM, LlamaConfig
from transformers import Qwen2ForCausalLM, Phi3ForCausalLM, Phi3Config
from datasets import load_dataset

from rnsa.cache_utils import RNSACache

def _resolve_end_token_ids(model, tokenizer, eos_token_id=None) -> set[int]:
    """Return the set of token ids that terminate a sequence.

    Merges the explicit ``eos_token_id`` argument, ``tokenizer.eos_token_id``
    and ``model.config.eos_token_id``. Each may be ``None``, a single id, or a
    list of ids (model configs may store several stop ids), so every candidate
    is flattened before being added.
    """
    end_ids: set[int] = set()
    for candidate in (
        eos_token_id,
        getattr(tokenizer, "eos_token_id", None),
        getattr(model.config, "eos_token_id", None),
    ):
        if candidate is None:
            continue
        if isinstance(candidate, (list, tuple, set)):
            end_ids.update(int(c) for c in candidate if c is not None)
        else:
            end_ids.add(int(candidate))
    return end_ids


def _top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Apply nucleus (top-p) filtering to a (B, V) probability tensor.

    Keeps, per row, the smallest prefix of most-likely tokens whose cumulative
    probability reaches ``top_p``, zeroes the rest and renormalises. The most
    likely token is always kept, so a row can never become all zeros (which
    would otherwise produce NaNs and make ``torch.multinomial`` fail).
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    # Probability mass strictly *before* each token in sorted order: a token is
    # dropped only once the tokens ahead of it already cover top_p.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    remove = mass_before >= top_p
    remove[..., 0] = False
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    filtered = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)


@torch.inference_mode()
def generate(
    model,
    tokenizer,
    prompts,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
    do_sample: bool = True,
    eos_token_id: int | list[int] | None = None,
    device: str | torch.device | None = None,
    stream: bool = True,
):
    """Decode up to ``max_new_tokens`` tokens per prompt with a KV cache.

    The prompt is run once (prefill) to populate the cache. After that, only
    the newly chosen token of each row is fed back per step. RNSA models get an
    ``RNSACache``; all other models get a ``DynamicCache``.

    Args:
        model: causal LM called as ``model(input_ids=..., attention_mask=...,
            past_key_values=..., use_cache=True)`` and returning ``.logits``.
        tokenizer: tokenizer used to batch-encode ``prompts`` with padding.
            The caller is expected to configure the padding side.
        prompts: a single prompt string or a list of prompt strings.
        max_new_tokens: maximum number of decode steps.
        temperature: softmax temperature, used only when ``do_sample`` is
            True. It must be > 0 in that case.
        top_p: nucleus-sampling threshold, used only when sampling. A value of
            1.0 disables the filter.
        do_sample: sample from the distribution if True, else pick the argmax.
        eos_token_id: extra stop id(s). These are added to the tokenizer's and
            model config's EOS ids rather than replacing them.
        device: device for the encoded inputs. Defaults to ``model.device``.
        stream: if True, print each new token of the first sequence as it is
            produced.

    Returns:
        One decoded string per prompt row. Each string contains the prompt
        tokens (including any padding, since special tokens are not skipped),
        the generated tokens, and the stop token if one was emitted.

    Raises:
        ValueError: if ``do_sample`` is True and ``temperature <= 0``.
    """
    if do_sample and temperature <= 0:
        raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")

    dev = device or model.device
    model.eval()

    model_inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).to(dev)
    input_ids = model_inputs.input_ids
    attention_mask = model_inputs.attention_mask
    # Use the tokenized batch size rather than len(prompts). A bare string is
    # encoded as a batch of one, but len() would return its character count.
    batch_size = input_ids.shape[0]

    end_token_ids = _resolve_end_token_ids(model, tokenizer, eos_token_id)

    if isinstance(model, (RNSAQwen3ForCausalLM, RNSALlamaForCausalLM)):
        pkv = RNSACache()
    else:
        pkv = DynamicCache()

    # NOTE(review): position_ids are not passed, so positions for left-padded
    # rows in a batch depend on how the model/cache derives them. Verify
    # before relying on batched (padded) generation.

    # Prefill: run the full prompt once to populate the cache.
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=pkv,
        use_cache=True,
    )
    next_token_logits = out.logits[:, -1, :]  # (B, V)

    generated = input_ids.tolist()  # per-row token lists, prompt included
    finished = [False] * batch_size  # per-row early-stop flags

    for step in range(max_new_tokens):
        if do_sample:
            # Sample in float32 so low-precision logits do not distort probs.
            probs = F.softmax(next_token_logits.float() / temperature, dim=-1)
            if top_p < 1.0:
                probs = _top_p_filter(probs, top_p)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
        else:
            # Greedy: a positive temperature cannot change the argmax, so the
            # raw logits are used directly (this also tolerates temperature=0).
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)  # (B, 1)

        if stream:
            # Per-token decoding; pieces of multi-token characters may render
            # imperfectly in the live stream.
            print(tokenizer.batch_decode(next_token, skip_special_tokens=True)[0], end='', flush=True)

        for i, tok_id in enumerate(next_token.squeeze(1).tolist()):
            if finished[i]:
                continue
            generated[i].append(tok_id)
            if tok_id in end_token_ids:
                finished[i] = True

        # Stop once every row has ended. Also skip the forward pass after the
        # final step, because its logits would never be used.
        if all(finished) or step == max_new_tokens - 1:
            break

        attention_mask = torch.cat(
            [
                attention_mask,
                torch.ones((batch_size, 1), dtype=torch.long, device=dev),
            ],
            dim=-1,
        )

        # Feed only the freshly generated tokens; earlier ones are in the cache.
        next_token_logits = model(
            input_ids=next_token,
            attention_mask=attention_mask,
            past_key_values=pkv,
            use_cache=True,
        ).logits[:, -1, :]  # (B, V)

    # The decoded text intentionally keeps the prompt (and any padding tokens).
    return [tokenizer.decode(seq, skip_special_tokens=False) for seq in generated]

# Math-problem prompt. `{question}` is filled via str.format; `{{}}` renders as
# a literal `{}` so the model sees an empty \boxed{} answer placeholder.
prompt_template = (
    "You are given a math problem.\n\n"
    "Problem: {question}\n\n"
    " You need to solve the problem step by step. "
    "First, you need to provide the chain-of-thought, then provide the final answer.\n\n"
    " Provide the final answer in the format: Final answer:  \\boxed{{}}"
)


def apply_chat_template(
    example,
    tokenizer,
    add_prefix: bool = False,
    prefix: str = "<think>\n",
) -> dict[str, str]:
    """Render ``example["messages"]`` into one chat-formatted string.

    Args:
        example: mapping with a ``"messages"`` list of ``{"role", "content"}``
            dicts. It is not modified.
        tokenizer: object exposing ``apply_chat_template``.
        add_prefix: if True, prepend ``prefix`` to the assistant turn at index 1.
            Data generated with vLLM using ``add_generation_prompt`` lacks the
            ``<think>`` opener in the assistant text, so this restores it. It
            is not needed when using the OpenR1-Math-220k messages directly.
        prefix: text prepended to the assistant content when ``add_prefix``.

    Returns:
        ``{"text": rendered}``, where ``rendered`` is the untokenized chat
        string.

    Raises:
        ValueError: if ``add_prefix`` is True but ``messages[1]`` is not an
            assistant turn.
    """
    # Copy each message so the prefix edit never mutates the caller's example.
    # Without the copy, a second call would prepend the prefix twice.
    messages = [dict(message) for message in example["messages"]]

    if add_prefix:
        if len(messages) < 2 or messages[1].get("role") != "assistant":
            raise ValueError(
                "add_prefix=True expects messages[1] to be the assistant turn"
            )
        messages[1]["content"] = prefix + messages[1]["content"]

    # NOTE(review): add_generation_prompt=True is also applied when an assistant
    # turn is present, which appends an extra generation prompt after it.
    # Confirm this is intended for training text.
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # Extra template variable; presumably only consulted by templates that
        # use it (e.g. Qwen3's thinking switch). Verify for other models.
        enable_thinking=False,
    )

    return {"text": rendered}

def apply_math_chat_template(
    example,
    tokenizer,
) -> dict[str, str]:
    prompt = prompt_template.format(question=example["problem"])
    messages = [
        { "role": "user", "content": prompt },
    ]
    text = tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True,
    )

    return {"text": text}



import torch
import numpy as np
import random
import argparse

parser = argparse.ArgumentParser(description="Run Qwen3 model with RNSA attention")
# In the active code path this only names the output file. The loaded model's
# attention backend is configured separately at load time.
parser.add_argument("--attn_impl", type=str, default="flex")
args = parser.parse_args()

# Compact, fixed-point tensor printing for debugging output.
torch.set_printoptions(linewidth=120, precision=4, sci_mode=False)

# Seed torch, NumPy and Python's RNGs (in that order) for reproducible runs.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)


# model_id = "Qwen/Qwen3-1.7B"
# model_id = "/gpfs/radev/home/nhb25/scratch/jpmc/outputs/models/Qwen3-1.7B/rnsa_openr1_Qwen3-1.7B_fg4_rnsa_flex_memory512_fw5.0_bias10.0_ff_ebs4_wd0.01_lr5e-4/"
# model_id = "/gpfs/radev/home/nhb25/scratch/jpmc/outputs/models/Qwen3-1.7B/rnsa_openr1_Qwen3-1.7B_fg2_rnsa_flex_memory512_fw5.0_bias7.0_ff_ebs4_wd0.01_lr5e-4/"

# model_id = "/home/nhb25/scratch/jpmc/outputs/models/Qwen3-1.7B/rnsa_openr1_Qwen3-1.7B_fw_logits_distil_16384_fg4_rnsa_flex_memory512_fw1.0_bias8.0_ff_ebs4_wd0.01_lr5e-4/"

# mdl = RNSAQwen3ForCausalLM.from_pretrained(
#     model_id,
#     load_rnsa_weights=True,
#     # config=config,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda:0",
# )

# model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
# model_id = "meta-llama/Llama-3.1-8B"
# model_id = "mobiuslabsgmbh/DeepSeek-R1-ReDistill-Llama3-8B-v1.1"
# model_id = "Qwen/Qwen3-1.7B"
# model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
model_id = "microsoft/Phi-4-reasoning"
# model_id = "microsoft/Phi-4-mini-reasoning"
mdl = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
# model_id = "/home/nhb25/scratch/jpmc/outputs/models/DeepSeek-R1-Distill-Llama-8B/rnsa_openr1_DeepSeek-R1-Distill-Llama-8B_fw_logits_distil_16384_fg4_rnsa_flex_memory512_fw1.0_bias10.0_ff_ebs4_wd0.0_lr1e-4/"
# model_id = "/gpfs/radev/home/nhb25/scratch/jpmc/outputs/models/DeepSeek-R1-Distill-Llama-8B/rnsa_openr1_DeepSeek-R1-Distill-Llama-8B_fw_logits_distil_16384_fg4_rnsa_flex_memory512_fw5.0_bias10.0_ff_ebs4_wd0.01_lr5e-4/"

# config = RNSALlamaConfig.from_pretrained(model_id)
# config.attn_impl = args.attn_impl
# config.compile_attn = False
# config.compress_memory = False
# config.compress_strategy = "lw_alpha"
# config.memory_size = 1024

# mdl = RNSALlamaForCausalLM.from_pretrained(
#     model_id,
#     load_rnsa_weights=True,
#     # config=config,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda:0",
# )

# mdl.config.attn_impl = args.attn_impl
# mdl.config.compress_memory = False
# mdl.config.compress_strategy = "lw_alpha"
# mdl.config.memory_size = 512


tok = AutoTokenizer.from_pretrained(getattr(mdl.config, "base_model", model_id))
# `generate` appends new tokens on the right and reads logits at the last
# position, so prompts are left-padded to keep that position a real token.
tok.padding_side = "left"
# `generate` tokenizes with padding=True, which needs a pad token. Fall back to
# EOS only when the tokenizer does not define one.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# prompts = [
#     # "Q: What is the capital of France?",
#     # "Answer the question: Who wrote 'To Kill a Mockingbird'?",
#     "Explain why the sky is blue in one sentence.",
# ]
# print("Prompt: ", prompts[0])
# prompts = [apply_chat_template({"messages": [{"role": "user", "content": p}]}, tok)["text"] for p in prompts]


dataset_name = "open-r1/OpenR1-Math-220k"
dataset = load_dataset(dataset_name, "default", split="train")
# print(dataset[0]['problem'])
# print(dataset[0]['answer'])
prompts = [apply_math_chat_template(dataset[0], tok)["text"]]

# prompts = [
    # "<｜begin▁of▁sentence｜><｜User｜>You are given a math problem.\n\nProblem: Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop.\n\n You need to solve the problem step by step. First, you need to provide the chain-of-thought, then provide the final answer.\n\n Provide the final answer in the format: Final answer:  \\boxed{}<｜Assistant｜><think>\n",
    # "<|im_start|>user\nYou are given a math problem.\n\nProblem: Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop.\n\n You need to solve the problem step by step. First, you need to provide the chain-of-thought, then provide the final answer.\n\n Provide the final answer in the format: Final answer:  \\boxed{}<|im_end|>\n<|im_start|>assistant\n",
# ]

print("Prompt after applying math chat template: ", prompts[0])


# perf_counter is monotonic and high-resolution, so it suits interval timing
# better than wall-clock time.time().
time_start = time.perf_counter()

text = generate(mdl, tok,
                prompts=prompts,
                max_new_tokens=16000,
                temperature=0.7,
                do_sample=False,  # greedy decoding; top_p is only used when sampling
                top_p=0.9)
time_end = time.perf_counter()

print(f"Time taken: {time_end - time_start:.2f} seconds")
print(text)

# Write UTF-8 explicitly: the generated text may contain non-ASCII characters
# that the platform's default encoding cannot represent.
with open(f"output_{args.attn_impl}.txt", "w", encoding="utf-8") as f:
    f.write(text[0])

# if isinstance(mdl, RNSAQwen3ForCausalLM) or isinstance(mdl, RNSALlamaForCausalLM):
#     pkv = RNSACache(
#         max_seq_len=32000,
#     )
# else:
#     pkv = None
# mdl.eval()
# output = mdl.generate(
#     input_ids=tok(prompts, return_tensors="pt", padding=True).input_ids.to("cuda"),
#     max_new_tokens=16000,
#     temperature=0.6,
#     do_sample=True,
#     num_return_sequences=4,
#     top_p=0.9,
#     use_cache=True,
#     attention_mask=None,
#     past_key_values=pkv,
# )
# output_text = tok.batch_decode(output, skip_special_tokens=True)
# print("Output text:")
# for i, text in enumerate(output_text):
#     print(f"Prompt {i+1}: {text}")

