import math
import torch
from typing import Optional
from torch.utils.data import DataLoader
import re
import numpy as np
import torch.distributed as dist
#from hydragen.utils import get_rank
from tqdm import tqdm
import sys
import typer
import os
import random
from evaluate import load
from transformers import AutoTokenizer
from pathlib import Path
from datasets import load_dataset
from collections import defaultdict
#from hydragen.haystack import make_multi_needle_haystack
#from hydragen.benchmark_utils import (
#    NeedlesBenchmarkResult,
#)
#from hydragen.utils import dtype_map, maybe_init_dist, save_yaml, dataclass_to_dict
#from hydragen.tp import from_pretrained_tp
#from hydragen.llama import HydragenLlamaForCausalLM, SharedCacheOp
from transformers import AutoTokenizer, AutoModelForCausalLM
from processor import PrePackProcessor
from model import CustomCausalLlamaModel

# Answers are expected between triple hashes, e.g. "... ###42### ...".
ANS_RE = re.compile(r"###(.*?)###")
# Sentinel returned by extract_answer when no ###...### span is present.
INVALID_ANS = "[invalid]"

# Debugging toggles (uncomment to enable):
#os.environ['CUDA_LAUNCH_BLOCKING']="1"
#os.environ['TORCH_USE_CUDA_DSA'] = "1"


def extract_answer(completion):
    """Return the first ``###...###`` answer in *completion*.

    The captured text is whitespace-stripped and has all commas removed
    (so "1,234" compares equal to "1234"). Returns ``INVALID_ANS`` when no
    such span exists.
    """
    match = ANS_RE.search(completion)
    if match is None:
        return INVALID_ANS
    return match.group(1).strip().replace(",", "")


def is_correct(model_completion, gt_example):
    """Return True when the answer extracted from *model_completion* equals the gold one.

    Raises:
        ValueError: if *gt_example* contains no ``###...###`` answer. This is an
            explicit raise rather than ``assert`` so the check survives ``python -O``.
    """
    gt_answer = extract_answer(gt_example)
    if gt_answer == INVALID_ANS:
        raise ValueError(f"gold example has no ###answer### span: {gt_example!r}")
    return extract_answer(model_completion) == gt_answer


def is_correct_batched(model_completions, answers):
    """Apply ``is_correct`` pairwise; extra items in the longer input are ignored (zip)."""
    return [
        is_correct(completion, answer)
        for completion, answer in zip(model_completions, answers)
    ]


def _chunk_size(num_items: int, min_chunk: int = 32) -> int:
    """Return the smallest divisor of ``num_items`` that is at least ``min_chunk``.

    Falls back to ``num_items`` (a single chunk holding the whole group) when
    no such divisor exists, e.g. for a prime group size or a group smaller than
    ``min_chunk``; previously the size was left undefined or stale in that case.
    """
    for size in range(min_chunk, num_items + 1):
        if num_items % size == 0:
            return size
    return num_items


def _build_batches(group_dict, min_chunk: int = 32):
    """Turn ``{context: [(question, answer, q_hash), ...]}`` into benchmark batches.

    Groups are visited largest first. Each group becomes exactly one batch: a
    list of equal-sized chunks of question hashes (size from ``_chunk_size``).
    Every group is emitted once -- no group is repeated in a later batch and
    the final group is not dropped.
    """
    batches = []
    for items in sorted(group_dict.values(), key=len, reverse=True):
        if not items:
            continue
        size = _chunk_size(len(items), min_chunk)
        batches.append(
            [
                [entry[-1] for entry in items[start:start + size]]
                for start in range(0, len(items), size)
            ]
        )
    return batches


def _largest_proper_divisor(n: int) -> int:
    """Return the largest divisor of ``n`` smaller than ``n`` (1 when there is none)."""
    return max((d for d in range(1, n) if n % d == 0), default=1)


def _mean_after_warmup(values, num_warmup: int) -> float:
    """Mean of ``values`` after dropping the first ``num_warmup`` entries; NaN if none remain."""
    kept = values[num_warmup:]
    return float(np.mean(kept)) if kept else float("nan")


def _cuda_time(fn):
    """Call ``fn()`` and return ``(result, elapsed)`` measured with CUDA events.

    Synchronizes before reading the events so the elapsed time is available.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = fn()
    end.record()
    torch.cuda.synchronize()
    return out, start.elapsed_time(end)


def _decode_parallel_answers(tokenizer, sequences, prompt_len, answer_masks):
    """Split each generated row into one decoded string per question.

    ``answer_masks[row]`` is iterated as one mask per question; the trailing
    ``row_len`` entries of each mask select (>= 0.5) that question's tokens from
    the generated part of the row. Mask layout is inferred from earlier usage
    in this script -- TODO confirm against whatever produces the masks.
    """
    decoded = []
    for row_idx, row in enumerate(sequences[:, prompt_len:]):
        for mask in answer_masks[row_idx]:
            keep = mask[-row.shape[-1]:].ge(0.5)
            decoded.append(tokenizer.decode(torch.masked_select(row, keep)))
    return decoded


def _build_squad_records(model_answers, batch_answers, batch_ids):
    """Pair decoded answers with gold answers as squad_v2 prediction/reference dicts."""
    preds, refs = [], []
    for answer, batch_gold, batch_id in zip(model_answers, batch_answers, batch_ids):
        dict_output = {
            "prediction_text": answer.split("\n\n")[0].strip().split("<|end_of_text|>")[0],
            "id": batch_id,
            "no_answer_probability": 0.0,
        }
        print("predict")
        print(dict_output)
        golden = {
            # Placeholder offsets (one per answer text): multiple-choice answers
            # have no span in a passage.
            "answers": {"text": batch_gold, "answer_start": [177] * len(batch_gold)},
            "id": batch_id,
        }
        print("golden")
        print(golden)
        preds.append(dict_output)
        refs.append(golden)
    return preds, refs


def main(
    pretrained_name: str = "princeton-nlp/Sheared-LLaMA-1.3B",
    device: str = "cuda",
    dtype: str = "bfloat16",
    graph: bool = True,
    tp_path: Optional[Path] = None,
    save_dir: str = "results",
    save_name: str = "needles",
    num_few_shot: int = 5,
    num_parallel_token: int = 64,
    disable_hydragen: bool = False,
    disable_attention: bool = False,
    base_prompt_length: int = 100,
    base_q_length: int = 10,
    num_timing_iters: int = 10,
    num_warmup_iters: int = 1,
    num_questions: Optional[int] = None,
    measure_unique_prefill: bool = True,
):
    """Benchmark parallel-question generation on Llama-3.2 MMLU eval prompts.

    Questions from the ``meta-llama/Llama-3.2-1B-Instruct-evals`` MMLU details
    split are grouped by the prompt text preceding each question. Each group is
    packed by ``PrePackProcessor`` and run through ``custom_model.generate``
    ``num_timing_iters`` times, recording CUDA-event latency and peak allocated
    memory (MiB). Decoded answers are scored with the ``squad_v2`` metric.

    Args:
        pretrained_name: Hugging Face model id for the tokenizer and model.
        device: Device the packed input tensors are moved to.
        base_prompt_length, base_q_length: Forwarded to the processor.
        num_timing_iters: Timed generate() calls per batch.
        num_warmup_iters: Leading iterations excluded from the averages.
        measure_unique_prefill: When False, no generate() call is made or timed.

    NOTE(review): ``dtype``, ``graph``, ``tp_path``, ``save_dir``, ``save_name``,
    ``num_few_shot``, ``num_parallel_token``, ``disable_hydragen``,
    ``disable_attention`` and ``num_questions`` are accepted for CLI
    compatibility but unused; the model is always loaded in float16.
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_name, padding_side="left")
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    custom_model = CustomCausalLlamaModel.from_pretrained(
        pretrained_name,
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Baseline alternative: AutoModelForCausalLM.from_pretrained with the same arguments.
    custom_model.generation_config.pad_token_id = tokenizer.pad_token_id
    custom_model.eval()

    data = load_dataset(
        "meta-llama/Llama-3.2-1B-Instruct-evals",
        name="Llama-3.2-1B-Instruct-evals__mmlu__details",
        split="latest",
    )
    squad_metric = load("squad_v2")

    group_dict = defaultdict(list)  # context prefix -> [(question, answer, q_hash), ...]
    dict_hash_doc = {}  # q_hash -> context prefix
    dict_hash_qa = {}  # q_hash -> (question, answer)
    print(len(data))
    for data_item in data:
        q_hash = data_item["input_question_hash"]
        choices = data_item["input_choice_list"]
        # Everything before this item's "Question: ..." is the prompt prefix shared by its group.
        context = data_item["input_final_prompts"][0].split(
            "Question: " + data_item["input_question"]
        )[0]
        # NOTE(review): no separator is inserted between the question text and
        # "The candidate answers:" -- confirm this prompt format is intended.
        question = (
            data_item["input_question"]
            + "The candidate answers:\n"
            + "\n".join([str(k + ". " + v) for k, v in choices.items()])
        )
        # The last character of the first correct response is used as the
        # answer letter (it keys input_choice_list).
        correct_letter = data_item["input_correct_responses"][0][-1]
        answer = [correct_letter, choices[correct_letter]]
        group_dict[context].append((question, answer, q_hash))
        dict_hash_doc[q_hash] = context
        dict_hash_qa[q_hash] = (question, answer)
    print(len(group_dict))

    sorted_qs = _build_batches(group_dict)

    max_num_tokens = 10  # new-token budget per parallel question
    predictions, references = [], []
    result_time = []
    result_mem = []
    for qs_hash_list in sorted_qs:
        # Every chunk comes from one context group, so its first hash identifies the doc.
        doc = [dict_hash_doc[chunk[0]] for chunk in qs_hash_list]
        qs = [[dict_hash_qa[q_hash] for q_hash in chunk] for chunk in qs_hash_list]
        # Questions per prompt row; all chunks in a batch have the same size.
        num_q_prompt = _largest_proper_divisor(len(qs[0]))
        parallel_token = num_q_prompt

        questions = [
            [
                "Question[0]: "
                + q[0]
                + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nAnswer[0]:"
                for q in list_q[x:x + num_q_prompt]
            ]
            for list_q in qs
            for x in range(0, len(list_q), num_q_prompt)
        ]
        # One prompt row per question slice above, so repeat each doc once per slice.
        base_prompt = [
            item_doc
            for item_doc, list_q in zip(doc, qs)
            for _ in range(0, len(list_q), num_q_prompt)
        ]
        batch_answers = [q[1] for list_q in qs for q in list_q]
        # One id per question, in the same order as batch_answers (and hashable,
        # as the metric needs).
        batch_ids = [q_hash for chunk in qs_hash_list for q_hash in chunk]

        print("DOC:")
        print(len(base_prompt))
        print("QS:")
        print(len(questions))
        for q in questions:
            print(len(q))
        print("parallel_token")
        print(parallel_token)

        processor = PrePackProcessor(tokenizer)
        (
            _packed_tokens,
            _restart_positions,
            independent_mask,
            _,
            _,
            _final_tokens,
            f_tokens,
            f_positions,
            _doc_position_ids,
        ) = processor.parralel_process_tensor_batch_efficient2(
            base_prompt, questions, parallel_token, base_prompt_length, base_q_length
        )
        print("input tensor")
        print(f_tokens.shape)

        f_tokens = f_tokens.to(device)
        independent_mask = independent_mask.to(device)
        f_positions = f_positions.to(device)

        times, mems, unique_prefill_times = [], [], []
        new_ids = None
        for _ in range(num_timing_iters):
            if not measure_unique_prefill:
                continue
            torch.cuda.reset_peak_memory_stats("cuda")
            torch.cuda.empty_cache()
            new_ids, elapsed = _cuda_time(
                lambda: custom_model.generate(
                    input_ids=f_tokens,
                    attention_mask=independent_mask,
                    position_ids=f_positions,
                    max_new_tokens=max_num_tokens * parallel_token,
                    use_cache=True,
                    do_sample=False,
                    do_parallel=True,
                    num_return_sequences=1,
                    output_scores=True,
                    return_dict_in_generate=True,
                    parallel_token=parallel_token,
                )
            )
            mems.append(torch.cuda.max_memory_allocated("cuda") / (1024**2))
            times.append(elapsed)

            # Prefill-only pass (a single new token); recorded but not reported.
            _, prefill_elapsed = _cuda_time(
                lambda: custom_model.generate(
                    input_ids=f_tokens,
                    attention_mask=independent_mask,
                    position_ids=f_positions,
                    max_new_tokens=1,
                    use_cache=True,
                    do_sample=False,
                    do_parallel=True,
                    num_return_sequences=1,
                    output_scores=True,
                    return_dict_in_generate=True,
                    parallel_token=parallel_token,
                )
            )
            unique_prefill_times.append(prefill_elapsed)

        # TODO(review): splitting each generated row into per-question answers
        # needs per-row, per-question masks. Earlier code read them from
        # `final_attention_mask`, which this script never assigns (it also read
        # an undefined `new_generated_output`), so it always raised NameError.
        # Presumably the masks are one of the processor outputs discarded as `_`
        # above; wire them up here. Until then scoring is skipped, not crashed.
        answer_masks = None
        if new_ids is None or answer_masks is None:
            print(
                "WARNING: generated output or per-question answer masks "
                "unavailable; skipping scoring for this batch"
            )
        else:
            model_answers = _decode_parallel_answers(
                tokenizer, new_ids.sequences, f_tokens.shape[-1], answer_masks
            )
            batch_preds, batch_refs = _build_squad_records(
                model_answers, batch_answers, batch_ids
            )
            predictions.extend(batch_preds)
            references.extend(batch_refs)

        result_time.append(_mean_after_warmup(times, num_warmup_iters))
        result_mem.append(_mean_after_warmup(mems, num_warmup_iters))

    total_time = np.sum(result_time)
    print(f"squad time : {total_time}")
    peak_mem = np.max(result_mem) if result_mem else float("nan")
    print(f"PEAK MEM : {peak_mem}")

    if predictions:
        results = squad_metric.compute(predictions=predictions, references=references)
        print(results)
    else:
        print("No predictions collected; skipping squad_v2 scoring.")

    
if __name__ == "__main__":
    # Run the CLI. Any RuntimeError is reported on stdout rather than as a
    # traceback; messages mentioning "out of memory" collapse to "OOM".
    try:
        typer.run(main)
    except RuntimeError as err:
        message = str(err)
        print("OOM" if "out of memory" in message else message)
