import math
import torch
from typing import Optional
from torch.utils.data import DataLoader
import re
import numpy as np
import torch.distributed as dist
#from hydragen.utils import get_rank
from tqdm import tqdm
import sys
import typer
import os
import random
from evaluate import load
from transformers import AutoTokenizer
from pathlib import Path
from datasets import load_dataset
from collections import defaultdict
#from hydragen.haystack import make_multi_needle_haystack
#from hydragen.benchmark_utils import (
#    NeedlesBenchmarkResult,
#)
#from hydragen.utils import dtype_map, maybe_init_dist, save_yaml, dataclass_to_dict
#from hydragen.tp import from_pretrained_tp
#from hydragen.llama import HydragenLlamaForCausalLM, SharedCacheOp
from transformers import AutoTokenizer, AutoModelForCausalLM
from processor import PrePackProcessor
from model import CustomCausalLlamaModel

# A model answer is expected between two "###" delimiters, e.g. "###42###".
ANS_RE = re.compile(r"###(.*?)###")
# Sentinel returned by extract_answer when no delimited answer is found.
INVALID_ANS = "[invalid]"

#os.environ['CUDA_LAUNCH_BLOCKING']="1"
#os.environ['TORCH_USE_CUDA_DSA'] = "1"

def extract_answer(completion):
    """Pull the first ``###...###``-delimited answer out of ``completion``.

    The captured text is stripped of surrounding whitespace and every comma is
    removed (so "1,234" becomes "1234"). If no delimited answer exists,
    ``INVALID_ANS`` is returned instead.
    """
    found = ANS_RE.search(completion)
    if found is None:
        return INVALID_ANS
    return found.group(1).strip().replace(",", "")


def is_correct(model_completion, gt_example):
    """Return True when the model's delimited answer equals the reference's.

    ``gt_example`` must itself contain a ``###...###`` answer; this is checked
    with an ``assert`` before the model completion is parsed.
    """
    expected = extract_answer(gt_example)
    assert expected != INVALID_ANS
    predicted = extract_answer(model_completion)
    return predicted == expected


def is_correct_batched(model_completions, answers):
    """Apply ``is_correct`` pairwise and return the list of booleans.

    Pairing stops at the shorter of the two sequences.
    """
    return list(map(is_correct, model_completions, answers))


def _batch_documents_by_question_count(group_dict, max_questions_per_batch=40):
    """Pack documents into batches whose members all have the same question count.

    Documents are visited in descending order of question count (ties keep the
    dict's insertion order because ``sorted`` is stable). A new batch is started
    whenever the question count changes, or when adding the next document would
    bring the batch's total question count to ``max_questions_per_batch`` or more.
    A document is always placed in some batch, even if on its own it exceeds
    the limit.

    Args:
        group_dict: mapping of document text -> list of per-question entries.
        max_questions_per_batch: soft cap on questions per batch.

    Returns:
        ``(batched_docs, batched_questions)``: parallel lists where
        ``batched_docs[b][i]`` is a document text and ``batched_questions[b][i]``
        is that document's question list.
    """
    batched_docs, batched_questions = [], []
    current_docs, current_questions = [], []
    current_count = 0
    ordered = sorted(group_dict.items(), key=lambda kv: len(kv[1]), reverse=True)
    for doc_text, doc_questions in ordered:
        count = len(doc_questions)
        starts_new_batch = bool(current_docs) and (
            count != current_count
            or (len(current_docs) + 1) * count >= max_questions_per_batch
        )
        if starts_new_batch:
            batched_docs.append(current_docs)
            batched_questions.append(current_questions)
            current_docs, current_questions = [], []
        current_docs.append(doc_text)
        current_questions.append(doc_questions)
        current_count = count
    # Flush the trailing batch; without this the last group of documents is lost.
    if current_docs:
        batched_docs.append(current_docs)
        batched_questions.append(current_questions)
    return batched_docs, batched_questions


def _time_cuda(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` between two CUDA events.

    Returns ``(result, elapsed_ms)``. ``torch.cuda.synchronize()`` runs after
    the end event is recorded, so the elapsed time covers the queued GPU work.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return out, start.elapsed_time(end)


def _parallel_generate(model, input_ids, attention_mask, position_ids,
                       parallel_token, max_new_tokens):
    """Run the custom model's parallel greedy decoding with the settings used by ``main``."""
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        do_sample=False,
        do_parallel=True,
        num_return_sequences=1,
        output_scores=True,
        return_dict_in_generate=True,
        parallel_token=parallel_token,
    )


def main(
    pretrained_name: str = "princeton-nlp/Sheared-LLaMA-1.3B",
    device: str = "cuda",
    dtype: str = "bfloat16",
    graph: bool = True,
    tp_path: Optional[Path] = None,
    save_dir: str = "results",
    save_name: str = "needles",
    num_few_shot: int = 5,
    num_parallel_token: int = 64,
    disable_hydragen: bool = False,
    disable_attention: bool = False,
    base_prompt_length: int = 100,
    base_q_length: int = 10,
    num_timing_iters: int = 10,
    num_warmup_iters: int = 1,
    num_questions: Optional[int] = None,
    measure_unique_prefill: bool = True,
    max_answer_tokens: int = 10,
    max_questions_per_batch: int = 40,
):
    """Benchmark parallel question answering on the NarrativeQA validation split.

    Workflow:
    - Questions are grouped by document text.
    - Documents with equal question counts are packed into batches.
    - Each batch is answered in one ``generate(do_parallel=True, ...)`` call, with
      all of a document's questions decoded in parallel.
    - Every batch is generated ``num_timing_iters`` times; the first
      ``num_warmup_iters`` runs are excluded from the averaged time (ms) and
      peak CUDA memory (MiB).
    - Answers from the last run are scored with the ``squad_v2`` metric.

    Args:
        pretrained_name: checkpoint passed to tokenizer/model ``from_pretrained``.
        device: device that inputs are moved to and whose CUDA memory is measured.
        base_prompt_length, base_q_length: forwarded to ``PrePackProcessor``.
        num_timing_iters: generation runs per batch (must be >= 1).
        num_warmup_iters: leading runs excluded from the time/memory averages.
        measure_unique_prefill: also time a 1-new-token generate per run.
        max_answer_tokens: new-token budget per question (scaled by the number
            of parallel questions).
        max_questions_per_batch: soft cap on questions packed into one batch.

    ``dtype``, ``graph``, ``tp_path``, ``save_dir``, ``save_name``,
    ``num_few_shot``, ``num_parallel_token``, ``disable_hydragen``,
    ``disable_attention`` and ``num_questions`` are accepted for CLI
    compatibility but are currently unused.

    Raises:
        ValueError: if ``num_timing_iters < 1`` (the answers come from the timed runs).
    """
    if num_timing_iters < 1:
        raise ValueError("num_timing_iters must be >= 1; answers are taken from the timed runs")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_name, padding_side="left")
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    # NOTE(review): `dtype` is not honoured; the model is always loaded in float16.
    custom_model = CustomCausalLlamaModel.from_pretrained(
        pretrained_name,
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    custom_model.generation_config.pad_token_id = tokenizer.pad_token_id
    custom_model.eval()

    dataset = load_dataset("narrativeqa", split="validation")
    squad_metric = load("squad_v2")
    predictions, references = [], []

    # Group questions by document text so every question about a document is
    # answered against the same packed background prompt.
    group_dict = defaultdict(list)
    for item_index, data_item in enumerate(dataset):
        answer_texts = [item["text"] for item in data_item["answers"]]
        # squad_v2 declares answer_start as int32; offsets are not used for scoring.
        gold = {"text": answer_texts, "answer_start": [0] * len(answer_texts)}
        # squad_v2 keys predictions and references by "id". The document id is
        # shared by all questions about that document, so make it unique per question.
        question_id = f"{data_item['document']['id']}_{item_index}"
        group_dict[data_item["document"]["text"]].append(
            (data_item["question"]["text"], gold, question_id)
        )

    sorted_doc, sorted_qs = _batch_documents_by_question_count(
        group_dict, max_questions_per_batch
    )

    result_time = []
    result_mem = []
    result_prefill_time = []
    for doc, qs in zip(sorted_doc, sorted_qs):
        # All documents in a batch share the same question count, and all of a
        # document's questions are packed into one prompt and decoded in parallel.
        num_q_prompt = len(qs[0])
        parallel_token = num_q_prompt

        questions = [
            ["Q[0]: " + q[0] + "\n\nA[0]:" for q in list_q[x:x + num_q_prompt]]
            for list_q in qs
            for x in range(0, len(list_q), num_q_prompt)
        ]
        batch_answers = [q[1] for list_q in qs for q in list_q]
        batch_ids = [q[2] for list_q in qs for q in list_q]
        # One background prompt per question chunk, aligned with `questions`.
        base_prompt = [
            "Answer each question using information in the preceding background paragraph.\n\nBackground: " + item_doc + "\n\n"
            for item_doc, list_q in zip(doc, qs)
            for x in range(0, len(list_q), num_q_prompt)
        ]

        processor = PrePackProcessor(tokenizer)
        (
            _packed_tokens, _restart_positions, independent_mask, _, _,
            _final_tokens, f_tokens, f_positons, _doc_position_ids,
        ) = processor.parralel_process_tensor_batch_efficient2(
            base_prompt, questions, parallel_token, base_prompt_length, base_q_length
        )
        model_inputs = tokenizer.batch_decode(f_tokens)
        print("actuall input:")
        print(model_inputs)

        f_tokens = f_tokens.to(device)
        independent_mask = independent_mask.to(device)
        f_positons = f_positons.to(device)

        times, mems, unique_prefill_times = [], [], []
        new_ids = None
        for _ in range(num_timing_iters):
            # Full decode; always run so `new_ids` exists regardless of the prefill flag.
            torch.cuda.reset_peak_memory_stats(device)
            torch.cuda.empty_cache()
            new_ids, elapsed = _time_cuda(
                _parallel_generate, custom_model, f_tokens, independent_mask,
                f_positons, parallel_token, max_answer_tokens * parallel_token,
            )
            mems.append(torch.cuda.max_memory_allocated(device) / (1024 ** 2))
            times.append(elapsed)

            if measure_unique_prefill:
                # Generating a single token approximates the prefill-only cost.
                _, prefill_elapsed = _time_cuda(
                    _parallel_generate, custom_model, f_tokens, independent_mask,
                    f_positons, parallel_token, 1,
                )
                unique_prefill_times.append(prefill_elapsed)

        # The custom generate returns (generate output, per-row answer masks).
        new_generated_output, final_attention_mask = new_ids
        model_answers = []
        generated = new_generated_output.sequences[:, f_tokens.shape[-1]:]
        for item_index, item in enumerate(generated):
            # presumably one mask per parallel answer stream in this row — TODO confirm
            for mask in final_attention_mask[item_index]:
                selected = torch.masked_select(item, mask[-item.shape[-1]:].ge(0.5))
                model_answers.append(tokenizer.decode(selected))

        for answer, batch_gold, batch_id in zip(model_answers, batch_answers, batch_ids):
            # Keep only the first paragraph of the generated answer.
            prediction_text = answer.split("\n\n")[0].strip()
            print(prediction_text)
            golden = {"answers": batch_gold, "id": batch_id}
            print(golden)
            predictions.append({
                "prediction_text": prediction_text,
                "id": batch_id,
                "no_answer_probability": 0.,
            })
            references.append(golden)

        # Running score over every batch processed so far.
        print(squad_metric.compute(predictions=predictions, references=references))
        result_time.append(np.mean(times[num_warmup_iters:]))
        result_mem.append(np.mean(mems[num_warmup_iters:]))
        if measure_unique_prefill:
            result_prefill_time.append(np.mean(unique_prefill_times[num_warmup_iters:]))

    total_time = np.sum(result_time)
    print(f"squad time : {total_time}")
    if measure_unique_prefill:
        print(f"prefill time : {np.sum(result_prefill_time)}")
    peak_mem = np.max(result_mem)
    print(f"PEAK MEM : {peak_mem}")

    results = squad_metric.compute(predictions=predictions, references=references)
    print(results)

    
if __name__ == "__main__":
    try:
        typer.run(main)
    except RuntimeError as err:
        # A RuntimeError whose message mentions "out of memory" (e.g. a CUDA
        # allocation failure) is reported with the short marker "OOM".
        # Any other RuntimeError has its message printed.
        reason = str(err)
        print("OOM" if "out of memory" in reason else reason)
