import math
import torch
from typing import Optional
from torch.utils.data import DataLoader
import re
import numpy as np
import torch.distributed as dist
#from hydragen.utils import get_rank
from tqdm import tqdm
import sys
import typer
import os
import random
from evaluate import load
from transformers import AutoTokenizer
from pathlib import Path
from datasets import load_dataset
from collections import defaultdict
#from hydragen.haystack import make_multi_needle_haystack
#from hydragen.benchmark_utils import (
#    NeedlesBenchmarkResult,
#)
#from hydragen.utils import dtype_map, maybe_init_dist, save_yaml, dataclass_to_dict
#from hydragen.tp import from_pretrained_tp
#from hydragen.llama import HydragenLlamaForCausalLM, SharedCacheOp
from transformers import AutoTokenizer, AutoModelForCausalLM
from processor import PrePackProcessor
from model import CustomCausalLlamaModel

# Matches the shortest span enclosed by a pair of "###" delimiters on one line.
ANS_RE = re.compile(r"###(.*?)###")
# Sentinel returned by extract_answer when no delimited span is present.
INVALID_ANS = "[invalid]"

#os.environ['CUDA_LAUNCH_BLOCKING']="1"
#os.environ['TORCH_USE_CUDA_DSA'] = "1"

def extract_answer(completion):
    """Return the first ``###``-delimited answer found in *completion*.

    The captured text is stripped of surrounding whitespace and has every
    comma removed. When *completion* holds no delimited span, the
    ``INVALID_ANS`` sentinel is returned instead.
    """
    found = ANS_RE.search(completion)
    if found is None:
        return INVALID_ANS
    return found.group(1).strip().replace(",", "")


def is_correct(model_completion, gt_example):
    """Tell whether the model's extracted answer equals the reference answer.

    Both strings are passed through ``extract_answer``. The reference must
    contain a delimited answer; an ``AssertionError`` is raised otherwise.
    """
    expected = extract_answer(gt_example)
    assert expected != INVALID_ANS
    predicted = extract_answer(model_completion)
    return predicted == expected


def is_correct_batched(model_completions, answers):
    """Apply ``is_correct`` pairwise and return the list of boolean verdicts.

    Pairing stops at the shorter of the two inputs.
    """
    return list(map(is_correct, model_completions, answers))


# Markers that introduce a multiple-choice option inside a few-shot example line.
_CHOICE_MARKERS = (" A. ", " B. ", " C. ", " D. ", " E. ")
# Maps the string index used as a key of "input_choice_list" to its option letter.
_INDEX_TO_LABEL = {"0": "A", "1": "B", "2": "C", "3": "D", "4": "E"}
# A batch is closed before adding a context would bring it to this many questions.
_MAX_QUESTIONS_PER_BATCH = 40


def _first_matching_substring(text, candidates):
    """Return the first entry of *candidates* that occurs in *text*, else None."""
    for candidate in candidates:
        if candidate in text:
            return candidate
    return None


def _timed_cuda_call(fn):
    """Run *fn* between two CUDA events; return ``(result, elapsed_time)``."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = fn()
    end.record()
    # elapsed_time is only valid once both events have completed.
    torch.cuda.synchronize()
    return out, start.elapsed_time(end)


def _pause_if_interactive():
    """Wait for Enter when stdin is a terminal; do nothing otherwise.

    Keeps the manual inspection pauses for interactive runs without raising
    EOFError when the script is launched from a job scheduler or a pipe.
    """
    if sys.stdin is not None and sys.stdin.isatty():
        input()


def _batch_by_question_count(group_dict, max_questions=_MAX_QUESTIONS_PER_BATCH):
    """Split contexts into batches whose members share the same question count.

    Contexts are visited from most to fewest questions. A batch is closed when
    the next context has a different number of questions, or when adding it
    would bring the batch to *max_questions* questions or more.

    Returns two parallel lists: batches of contexts and batches of the
    corresponding per-context question lists.
    """
    doc_batches, question_batches = [], []
    current_docs, current_questions = [], []
    ordered = sorted(group_dict.items(), key=lambda item: len(item[1]), reverse=True)
    for context, entries in ordered:
        if current_docs:
            size_changed = len(current_questions[0]) != len(entries)
            would_overflow = (len(current_docs) + 1) * len(entries) >= max_questions
            if size_changed or would_overflow:
                doc_batches.append(current_docs)
                question_batches.append(current_questions)
                current_docs, current_questions = [], []
        current_docs.append(context)
        current_questions.append(entries)
    # Flush the trailing batch; without this the last group is never evaluated.
    if current_docs:
        doc_batches.append(current_docs)
        question_batches.append(current_questions)
    return doc_batches, question_batches


def main(
    pretrained_name: str = "princeton-nlp/Sheared-LLaMA-1.3B",
    device: str = "cuda",
    dtype: str = "bfloat16",
    graph: bool = True,
    tp_path: Optional[Path] = None,
    save_dir: str = "results",
    save_name: str = "needles",
    num_few_shot: int = 5,
    num_parallel_token: int = 64,
    disable_hydragen: bool = False,
    disable_attention: bool = False,
    base_prompt_length: int = 100,
    base_q_length: int = 10,
    num_timing_iters: int = 10,
    num_warmup_iters: int = 1,
    num_questions: Optional[int] = None,
    measure_unique_prefill: bool = True,
):
    """Benchmark parallel-question decoding on the AGIEval English details split.

    Questions are grouped by their shared few-shot context, packed into
    batches, generated with ``CustomCausalLlamaModel.generate`` and scored
    with the ``squad_v2`` metric. Generation time and peak CUDA memory are
    averaged per batch over the non-warmup iterations and reported at the end.

    Args:
        pretrained_name: Model/tokenizer identifier passed to ``from_pretrained``.
        device: Device the packed input tensors are moved to.
        base_prompt_length: Forwarded to the PrePackProcessor packing call.
        base_q_length: Forwarded to the PrePackProcessor packing call.
        num_timing_iters: Number of generation runs per batch (must be >= 1).
        num_warmup_iters: Leading runs excluded from the time/memory averages.
        measure_unique_prefill: Also time a one-new-token generation per run.

    The remaining parameters (``dtype``, ``graph``, ``tp_path``, ``save_dir``,
    ``save_name``, ``num_few_shot``, ``num_parallel_token``,
    ``disable_hydragen``, ``disable_attention``, ``num_questions``) are
    accepted for CLI compatibility but are not read by this function; the
    model is always loaded in float16.

    Raises:
        ValueError: If ``num_timing_iters`` is smaller than 1.
    """
    # Decoding uses the output of the last timing run, so at least one is needed.
    if num_timing_iters < 1:
        raise ValueError("num_timing_iters must be at least 1")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_name, padding_side='left')
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = 'left'
    custom_model = CustomCausalLlamaModel.from_pretrained(
        pretrained_name,
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    custom_model.generation_config.pad_token_id = tokenizer.pad_token_id
    custom_model.eval()

    data = load_dataset(
        "meta-llama/Llama-3.2-1B-evals",
        name="Llama-3.2-1B-evals__agieval_english__details",
        split="latest",
    )
    squad_metric = load("squad_v2")
    predictions, references = [], []
    # Shared few-shot context -> list of (question, answer, question hash).
    group_dict = defaultdict(list)

    for data_item in data:
        # Everything preceding the question text in the final prompt.
        raw_context = data_item["input_final_prompts"][0].split(data_item["input_question"])[0]
        list_context = raw_context.strip("\n ").split("\n\n")
        if "Answer:" in list_context[-1]:
            question = ""
            list_example = list_context
        else:
            # The last paragraph has no answer, so it is kept as a question preamble.
            question = list_context[-1]
            list_example = list_context[:-1]

        context_parts = []
        for example in list_example:
            choice = {}
            for item in example.split("\n"):
                target = _first_matching_substring(item, _CHOICE_MARKERS)
                if target is not None:
                    # target looks like " A. "; index 1 is the option letter.
                    choice[target[1]] = item.split(target)[-1]
                if "Answer: " in item and item[-1] in choice:
                    # Expand a bare answer letter with the text of that option.
                    context_parts.append(item + ". " + choice[item[-1]])
                else:
                    context_parts.append(item + "\n")
            context_parts.append("\n\n")
        context = "".join(context_parts)

        print("changed:")
        print(context)
        print("Q")
        choices = data_item["input_choice_list"]
        # NOTE(review): choices that are None are rendered as the text "None".
        question += (
            "\n\n" + data_item["input_question"] + "\n"
            + "\n".join(str(v) for v in choices.values())
        )
        print(question)

        answer_value = {}
        for k, v in choices.items():
            if v is None:
                continue
            # Slicing avoids an IndexError on a one-character choice.
            if v[1:2] == ".":
                answer_value[v[0]] = v[2:]
            else:
                answer_value[_INDEX_TO_LABEL[k]] = v
        correct_label = data_item["input_correct_responses"][0].split()[1]
        answer = [correct_label + ". " + answer_value[correct_label]]
        group_dict[context].append((question, answer, data_item["input_question_hash"]))

    print("finish")
    _pause_if_interactive()
    print(len(group_dict))
    for context_key in group_dict:
        print("key:")
        print(context_key)

    sorted_doc, sorted_qs = _batch_by_question_count(group_dict)

    result_time = []
    result_mem = []
    max_num_tokens = 10
    for doc, qs in zip(sorted_doc, sorted_qs):
        # Every context in a batch has the same number of questions, and all of
        # them are decoded in parallel within a single prompt.
        num_q_prompt = len(qs[0])
        parallel_token = num_q_prompt

        questions = [
            ["Question[0]: " + q[0] + "\nAnswer[0]:" for q in list_q[x:x + num_q_prompt]]
            for list_q in qs
            for x in range(0, len(list_q), num_q_prompt)
        ]
        batch_answers = [q[1] for list_q in qs for q in list_q]
        batch_ids = [q[2] for list_q in qs for q in list_q]
        # One copy of the context per chunk of questions built above.
        base_prompt = [
            item_doc
            for item_doc, list_q in zip(doc, qs)
            for _ in range(0, len(list_q), num_q_prompt)
        ]

        processor = PrePackProcessor(tokenizer)
        (
            _, _, independent_mask, _, _, _, f_tokens, f_positons, _,
        ) = processor.parralel_process_tensor_batch_efficient2(
            base_prompt, questions, parallel_token, base_prompt_length, base_q_length
        )

        f_tokens = f_tokens.to(device)
        independent_mask = independent_mask.to(device)
        f_positons = f_positons.to(device)

        generate_kwargs = dict(
            input_ids=f_tokens,
            attention_mask=independent_mask,
            position_ids=f_positons,
            use_cache=True,
            do_sample=False,
            do_parallel=True,
            num_return_sequences=1,
            output_scores=True,
            return_dict_in_generate=True,
            parallel_token=parallel_token,
        )

        times = []
        mems = []
        unique_prefill_times = []
        new_ids = None
        for _ in range(num_timing_iters):
            torch.cuda.reset_peak_memory_stats("cuda")
            torch.cuda.empty_cache()
            # The full generation always runs: its output is decoded below.
            new_ids, elapsed = _timed_cuda_call(
                lambda: custom_model.generate(
                    max_new_tokens=max_num_tokens * parallel_token,
                    **generate_kwargs,
                )
            )
            mems.append(torch.cuda.max_memory_allocated("cuda") / (1024**2))
            times.append(elapsed)

            if measure_unique_prefill:
                # A single new token approximates the prefill cost alone.
                _, unique_prefill_time = _timed_cuda_call(
                    lambda: custom_model.generate(max_new_tokens=1, **generate_kwargs)
                )
                # Collected per run; not included in the reported totals.
                unique_prefill_times.append(unique_prefill_time)

        # generate() is unpacked as (generation output, final attention mask).
        new_generated_output, final_attention_mask = new_ids
        model_answers = []
        generated_only = new_generated_output.sequences[:, f_tokens.shape[-1]:]
        for item_index, item in enumerate(generated_only):
            for mask in final_attention_mask[item_index]:
                # Keep only the generated positions selected by this mask.
                selected = torch.masked_select(item, mask[-item.shape[-1]:].ge(0.5))
                model_answers.append(tokenizer.decode(selected))

        for answer, batch_gold, batch_id in zip(model_answers, batch_answers, batch_ids):
            prediction_text = answer.split("\n\n")[0].strip().split("<|end_of_text|>")[0]
            dict_output = {
                "prediction_text": prediction_text,
                "id": batch_id,
                "no_answer_probability": 0.,
            }
            print("predict")
            print(prediction_text)
            print(dict_output)
            print("golden")
            # NOTE(review): answer_start is a placeholder; presumably the metric
            # only compares the answer text -- confirm.
            golden = {
                "answers": {"text": batch_gold, 'answer_start': [177, 177, 177]},
                "id": batch_id,
            }
            print(golden)
            predictions.append(dict_output)
            references.append(golden)

        # Running score over every batch processed so far.
        results = squad_metric.compute(predictions=predictions, references=references)
        print(results)
        _pause_if_interactive()
        result_time.append(np.mean(times[num_warmup_iters:]))
        result_mem.append(np.mean(mems[num_warmup_iters:]))

    total_time = np.sum(result_time)
    print(f"squad time : {total_time}")
    peak_mem = np.max(result_mem)
    print(f"PEAK MEM : {peak_mem}")

    results = squad_metric.compute(predictions=predictions, references=references)
    print(results)

    
if __name__ == "__main__":
    # Run the CLI; report runtime failures on stdout instead of a traceback,
    # collapsing CUDA out-of-memory errors to a short "OOM" marker.
    try:
        typer.run(main)
    except RuntimeError as err:
        print("OOM" if "out of memory" in str(err) else err)
