import math
import torch
from typing import Optional
from torch.utils.data import DataLoader
import re
import numpy as np
import torch.distributed as dist
from hydragen.utils import get_rank
from tqdm import tqdm
import sys
import typer
import os
import random
from transformers import AutoTokenizer
from pathlib import Path

from hydragen.haystack import make_needle_haystack
from hydragen.haystack import make_multi_needle_haystack
from hydragen.benchmark_utils import (
    NeedlesBenchmarkResult,
)
from hydragen.utils import dtype_map, maybe_init_dist, save_yaml, dataclass_to_dict
from hydragen.tp import from_pretrained_tp
from hydragen.llama import HydragenLlamaForCausalLM, SharedCacheOp

# Answers are expected to be wrapped in triple hashes, e.g. "### 42 ###".
ANS_RE = re.compile(r"###(.*?)###")
# Sentinel returned when a completion contains no delimited answer.
INVALID_ANS = "[invalid]"


def extract_answer(completion):
    """Return the first ``###``-delimited answer found in ``completion``.

    The captured text is stripped of surrounding whitespace and all commas
    are removed. If no delimited answer is present, ``INVALID_ANS`` is
    returned instead.
    """
    found = ANS_RE.search(completion)
    if found is None:
        return INVALID_ANS
    return found.group(1).strip().replace(",", "")


def is_correct(model_completion, gt_example):
    """Check whether a completion's extracted answer matches the ground truth.

    Both strings are passed through ``extract_answer``. The ground-truth
    example must contain a ``###``-delimited answer; an ``AssertionError``
    is raised otherwise.
    """
    expected = extract_answer(gt_example)
    assert expected != INVALID_ANS
    predicted = extract_answer(model_completion)
    return predicted == expected


def is_correct_batched(model_completions, answers):
    """Apply ``is_correct`` pairwise and return the list of results.

    Pairing stops at the shorter of the two inputs.
    """
    return list(map(is_correct, model_completions, answers))


def _timed_cuda_call(fn):
    """Run ``fn`` once and measure it with CUDA events.

    Peak-memory statistics are reset and the caching allocator is emptied
    before the call so the reported peak belongs to this call only.

    Returns:
        A ``(result, elapsed_ms, peak_mib)`` tuple: the value returned by
        ``fn``, the elapsed time between the two CUDA events in
        milliseconds, and the peak allocated CUDA memory in MiB.
    """
    torch.cuda.reset_peak_memory_stats("cuda")
    torch.cuda.empty_cache()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    out = fn()

    end.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()
    peak_memory = torch.cuda.max_memory_allocated("cuda") / (1024**2)

    return out, start.elapsed_time(end), peak_memory


def main(
    pretrained_name: str = "princeton-nlp/Sheared-LLaMA-1.3B",
    device: str = "cuda",
    dtype: str = "bfloat16",
    graph: bool = True,
    tp_path: Optional[Path] = None,
    save_dir: str = "results",
    save_name: str = "needles",
    num_few_shot: int = 5,
    disable_hydragen: bool = False,
    disable_attention: bool = False,
    base_prompt_string_length: int = 50000,
    base_prompt_length: int = 1000,
    base_q_length: int = 10,
    num_timing_iters: int = 10,
    num_warmup_iters: int = 5,
    num_questions: Optional[int] = None,
    measure_unique_prefill: bool = True,
):
    """Time multi-needle generation over a sweep of batch shapes.

    For every combination of question count (powers of two from 2 to 1024)
    and unique-document count (1 to 10), a haystack/question batch is built,
    generation is run ``num_timing_iters`` times, and the mean time and mean
    peak memory of the post-warmup iterations are printed. A configuration
    that raises ``RuntimeError`` is reported and the sweep continues.

    Args:
        pretrained_name: Model/tokenizer identifier passed to
            ``from_pretrained``.
        device: Device the tokenized inputs (and the non-TP model) are
            placed on.
        dtype: Key into ``dtype_map`` selecting the model dtype.
        graph: Whether to call ``model.graph()`` after setting up caches.
        tp_path: Required when a distributed run is detected; passed to
            ``from_pretrained_tp``.
        num_few_shot: Forwarded to ``make_multi_needle_haystack``.
        base_prompt_length: Token length every haystack prompt is padded or
            truncated to; one tenth of it is passed as the target context
            length of the haystack builder.
        base_q_length: Token length every question is padded or truncated to.
        num_timing_iters: Number of generation runs per configuration.
        num_warmup_iters: Leading runs excluded from the reported means.
        save_dir, save_name, disable_hydragen, disable_attention,
        base_prompt_string_length, num_questions, measure_unique_prefill:
            Accepted for CLI compatibility but not used by this script; the
            question count always comes from the built-in sweep.

    Raises:
        ValueError: If ``num_timing_iters`` does not exceed
            ``num_warmup_iters`` (there would be nothing to average).
    """
    if num_timing_iters <= num_warmup_iters:
        raise ValueError(
            "num_timing_iters must be greater than num_warmup_iters "
            f"(got {num_timing_iters} and {num_warmup_iters})"
        )

    rank = maybe_init_dist()
    use_tp = rank is not None

    dtype = dtype_map[dtype]
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    tokenizer.add_bos_token = False
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    if use_tp:
        assert tp_path is not None
        model = from_pretrained_tp(pretrained_name, tp_path, dtype)
        torch.manual_seed(9)  # needed to make sure sampling is same on each device
    else:
        model = HydragenLlamaForCausalLM.from_pretrained(
            pretrained_name, torch_dtype=dtype, device_map=device
        )

    random.seed(9)
    unique_doc_range = list(range(1, 11))
    num_parallel_token_list = [2**i for i in range(1, 11)]
    # Tokens generated per sequence; also sizes the unique (per-sequence) cache.
    max_new_tokens = 5

    # The loop variable is deliberately not called ``num_questions`` so the
    # CLI parameter of that name is not overwritten.
    for num_q in num_parallel_token_list:
        for unique in unique_doc_range:
            try:
                list_haystack, list_needles = make_multi_needle_haystack(
                    target_context_length=int(base_prompt_length / 10),
                    num_few_shot=num_few_shot,
                    num_needles=num_q,
                    num_unique_doc=unique,
                )
                questions = [
                    "Question: " + needle.question + "\nAnswer:"
                    for needles in list_needles
                    for needle in needles
                ]

                input_ids = []
                seq_lens = []
                # Level 0 is the haystack batch, level 1 the question batch;
                # each is padded/truncated to its own fixed length.
                for prompts, max_length in (
                    (list_haystack, base_prompt_length),
                    (questions, base_q_length),
                ):
                    prompt_inputs = tokenizer(
                        prompts,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        max_length=max_length,
                    ).to(device)
                    input_ids.append(prompt_inputs.input_ids)
                    # Every position (padding included) is counted, so each
                    # sequence length equals the padded length.
                    seq_lens.append(
                        torch.ones_like(prompt_inputs.attention_mask)
                        .to(model.device)
                        .sum(dim=1)
                    )

                times = []
                mems = []
                for _ in range(num_timing_iters):
                    # Caches are rebuilt before every run, as in the original
                    # measurement procedure.
                    model.setup_caches(
                        max_unique_batch_size=input_ids[-1].shape[0],
                        max_unique_seq_length=max_new_tokens,
                        max_shared_batch_sizes=[ids.shape[0] for ids in input_ids],
                        max_shared_seq_lengths=[ids.shape[1] for ids in input_ids],
                    )
                    if graph:
                        model.graph()

                    _, elapsed_ms, peak_mib = _timed_cuda_call(
                        lambda: model.generate(
                            input_ids=input_ids,
                            seq_lens=seq_lens,
                            num_return_sequences=1,
                            max_new_tokens=max_new_tokens,
                            temperature=0,
                            shared_cache_op=SharedCacheOp.PRESERVE,
                        )
                    )
                    times.append(elapsed_ms)
                    mems.append(peak_mib)

                # Drop warmup iterations before averaging.
                mean_time = np.mean(times[num_warmup_iters:])
                mean_mem = np.mean(mems[num_warmup_iters:])
                print(f"num_doc: {unique}, num_q: {num_q} Time: {mean_time} Mem: {mean_mem}")
            except RuntimeError as e:
                # Keep the sweep going, but only label genuine out-of-memory
                # failures as OOM; surface any other error message.
                if "out of memory" in str(e).lower():
                    print(f"num_doc: {unique} num_q: {num_q} OOM")
                else:
                    print(f"num_doc: {unique} num_q: {num_q} failed: {e}")
# Script entry point: run ``main`` through typer when executed directly.
if __name__ == "__main__":
    typer.run(main)
