import os
import re
import json
import string
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path

import torch
import tiktoken
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams
from openai import OpenAI, AzureOpenAI

from extract import extract_ans_from_response

from collections import Counter
from prompts import PROMPTS

# Root of the local LongBench copy; passed to `load_dataset` when building the
# benchmark (see load_data).
LONGBENCH_LOCAL_PATH = os.getenv("LONGBENCH_LOCAL_PATH")
print(LONGBENCH_LOCAL_PATH)
if not LONGBENCH_LOCAL_PATH:
    # SystemExit with a message prints it to stderr and exits with status 1,
    # so calling shell scripts can detect the misconfiguration (an empty
    # value is rejected too, since it cannot be a usable dataset path).
    raise SystemExit("LONGBENCH_LOCAL_PATH is None or empty!!!")

def normalize_answer(s):
    """Return a canonical form of *s* for lenient answer matching.

    Steps, in order: lower-case; delete every character in
    ``string.punctuation``; replace the whole words "a", "an" and "the" with a
    space; collapse whitespace runs to single spaces and trim both ends.
    """
    lowered = s.lower()
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())

def f1_score(prediction, ground_truth, **kwargs):
    """Token-overlap F1 between two token sequences.

    Overlap is counted with multiplicity (multiset intersection). Returns the
    integer ``0`` when the sequences share no token, otherwise the harmonic
    mean of precision (overlap / len(prediction)) and recall
    (overlap / len(ground_truth)). Extra keyword arguments are ignored.
    """
    shared = Counter(prediction) & Counter(ground_truth)
    overlap = sum(shared.values())
    if not overlap:
        return 0
    precision = overlap / len(prediction)
    recall = overlap / len(ground_truth)
    return 2 * precision * recall / (precision + recall)

def qa_f1_score(prediction, ground_truth, **kwargs):
    """Token F1 between two answer strings after ``normalize_answer``.

    Both strings are normalized, split on whitespace and scored with
    ``f1_score``. Extra keyword arguments are ignored.
    """
    pred_tokens, gold_tokens = (
        normalize_answer(text).split() for text in (prediction, ground_truth)
    )
    return f1_score(pred_tokens, gold_tokens)


def ensure_path_exists(path, *, is_file=None):
    """Create the directory a path needs.

    For a file path the parent directory is created; for a directory path the
    directory itself is created. Missing intermediate directories are created
    and existing ones are left untouched.

    Args:
        path: ``str`` or ``os.PathLike`` naming a file or a directory.
        is_file: force how *path* is interpreted. When ``None`` (default), an
            existing path is classified by what it is on disk; a path that does
            not exist yet is treated as a file when it has a suffix
            (e.g. ``out/results.json``) and as a directory otherwise.
    """
    path_obj = Path(path)
    if is_file is None:
        if path_obj.exists():
            is_file = path_obj.is_file()
        else:
            # `Path.is_file()` is False for files that do not exist yet, so
            # without this a future "results.json" would be created as a
            # directory and block the later file write.
            is_file = bool(path_obj.suffix)
    target = path_obj.parent if is_file else path_obj
    target.mkdir(parents=True, exist_ok=True)


def inference_vllm(dataset, model, max_tokens, temperature, top_k=32, top_p=1):
    """Generate one completion per ``prompt`` row with vLLM and attach the results.

    Args:
        dataset: ``datasets.Dataset`` with a ``prompt`` string column.
        model: a ``vllm.LLM`` engine.
        max_tokens: maximum number of generated tokens per prompt.
        temperature: sampling temperature forwarded to ``SamplingParams``.
        top_k: top-k cutoff forwarded to ``SamplingParams`` (default 32).
        top_p: nucleus cutoff forwarded to ``SamplingParams`` (default 1,
            i.e. the full distribution is kept).

    Returns:
        A new dataset with three added columns: ``prediction`` (the generated
        text), ``prompt_token_length`` (number of prompt token ids reported by
        vLLM) and ``pred_ans`` (``extract_ans_from_response(prediction)``).
    """
    sampling_params = SamplingParams(
        temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p
    )
    # Column access avoids decoding every row (including the long contexts)
    # just to read the prompt.
    outputs = model.generate(list(dataset["prompt"]), sampling_params)

    # Reduce the vLLM RequestOutput objects to plain str/int lists here, so the
    # map function below (shipped to each worker when num_proc > 1) captures
    # only these lists rather than the full request outputs.
    predictions = [out.outputs[0].text for out in outputs]
    prompt_lengths = [len(out.prompt_token_ids) for out in outputs]

    def process(item, idx):
        item["prediction"] = predictions[idx]
        item["prompt_token_length"] = prompt_lengths[idx]
        item["pred_ans"] = extract_ans_from_response(item["prediction"])
        return item

    return dataset.map(process, with_indices=True, num_proc=8)

def load_data(args, tokenizer):
    """Load the LongBench QA subsets for ``args.benchmark`` and build chat prompts.

    Every subset's ``test`` split is loaded from ``LONGBENCH_LOCAL_PATH`` and
    concatenated. Each row's ``context`` is cut to at most
    ``args.max_model_input_len`` tokens (keeping the beginning). Then the
    ``short2long_qa`` template is filled with the context and ``input`` and
    wrapped by the tokenizer's chat template into a new ``prompt`` column.

    Args:
        args: namespace with ``benchmark`` and ``max_model_input_len``.
        tokenizer: Hugging Face tokenizer used for truncation and the chat template.

    Returns:
        The concatenated ``datasets.Dataset`` with the added ``prompt`` column.

    Raises:
        ValueError: if ``args.benchmark`` is not a supported benchmark name.
    """
    benchmark_subsets = {
        "longbench_qa": ["narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "musique"],
        "longbench_e_qa": ["qasper_e", "multifieldqa_en_e", "hotpotqa_e", "2wikimqa_e"],
    }
    if args.benchmark not in benchmark_subsets:
        raise ValueError(
            f"Unsupported benchmark {args.benchmark!r}; "
            f"expected one of {sorted(benchmark_subsets)}"
        )

    dataset = concatenate_datasets([
        load_dataset(LONGBENCH_LOCAL_PATH, d_name, split="test", trust_remote_code=True)
        for d_name in benchmark_subsets[args.benchmark]
    ])
    max_len = args.max_model_input_len

    def process(item):
        context = item["context"]
        # Character count is used as a cheap gate before tokenizing: only
        # contexts longer than max_len characters are tokenized.
        # NOTE(review): this assumes a text never has more tokens than
        # characters, which is not guaranteed for every tokenizer/script.
        if len(context) > max_len:
            tokens = tokenizer(context)["input_ids"]
            if len(tokens) > max_len:
                context = tokenizer.decode(tokens[:max_len], skip_special_tokens=True)

        item["prompt"] = tokenizer.apply_chat_template(
            [{"role": "user", "content": PROMPTS["short2long_qa"].format(context=context, input=item['input'])}],
            add_generation_prompt=True,
            tokenize=False
        )
        return item

    return dataset.map(process, num_proc=8)

def inference(args):
    """Generate answers for the configured LongBench QA benchmark and save them.

    Builds a vLLM engine and a Hugging Face tokenizer from
    ``args.model_name_or_path``, builds prompts with ``load_data``, generates
    with ``inference_vllm`` and writes the resulting dataset to
    ``args.output_path`` with ``Dataset.save_to_disk``. Every loaded example
    is processed.
    """
    model = LLM(model=args.model_name_or_path, tensor_parallel_size=args.tensor_parallel_size, trust_remote_code=True)
    # Same remote-code policy as the engine, so models whose tokenizer ships
    # custom code can be loaded here as well.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)

    dataset = load_data(args, tokenizer)

    if len(dataset) > 0:
        print("================example prompt================")
        print(dataset[0]["prompt"])
        print("================example prompt================")
    dataset = inference_vllm(dataset, model, max_tokens=args.max_tokens, temperature=args.temperature)
    print("================end inference_vllm================")
    dataset.save_to_disk(args.output_path)
    


def substring_exact_match_score(prediciton, ground_truth):
    """Return True when the normalized ground truth occurs inside the normalized prediction.

    Both strings go through ``normalize_answer`` before the substring test.
    """
    needle = normalize_answer(ground_truth)
    haystack = normalize_answer(prediciton)
    return needle in haystack

def drqa_metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Return the best ``metric_fn(prediction, answer)`` over all reference answers.

    Args:
        metric_fn: callable ``(prediction, ground_truth) -> score``.
        prediction: model answer, passed unchanged to ``metric_fn``.
        ground_truths: a single answer string, a list of answer strings, or a
            list whose items are answer strings and/or lists (or tuples) of
            answer strings; nested sequences are flattened one level.

    Returns:
        The maximum score returned by ``metric_fn`` (first one on ties).

    Raises:
        ValueError: if no reference answer is supplied.
    """
    if isinstance(ground_truths, str):
        candidates = [ground_truths]
    else:
        candidates = []
        for answer in ground_truths:
            # Flatten per item so mixed inputs such as ["a", ["b", "c"]] work,
            # not only lists whose first element happens to be a list.
            if isinstance(answer, (list, tuple)):
                candidates.extend(answer)
            else:
                candidates.append(answer)

    if not candidates:
        raise ValueError("ground_truths must contain at least one reference answer")
    return max(metric_fn(prediction, answer) for answer in candidates)

def extract_ans_from_prediction(example):
    """Pass-through placeholder for answer extraction.

    Hands *example* back unchanged (the same object, not a copy) without
    reading or writing any field. Answer extraction in this script is done
    by ``extract_ans_from_response`` instead.
    """
    untouched = example
    return untouched

def evaluate_sub_em(args):
    """Score saved predictions with substring exact match per LongBench sub-dataset.

    Loads the dataset saved by the inference stage from ``args.dataset``. For
    each row:

    - ``pred_ans`` is re-extracted from ``prediction`` when it is blank.
    - ``null_ans`` records whether it is still blank afterwards.
    - ``sub_em`` is the best substring exact match against ``answers``.

    Rows with a blank answer are dumped to ``<args.dataset>_null_ans.jsonl``.
    Per sub-dataset results are written to ``<output_path stem>_<bucket>.json``:
    the mean ``sub_em`` in percent, the sample count and the blank-answer count.
    """
    dataset = load_from_disk(args.dataset)
    print(dataset)

    def is_blank(text):
        # None and any whitespace-only string both count as "no answer".
        return not (text or "").strip()

    def process(item):
        if is_blank(item["pred_ans"]):
            item["pred_ans"] = extract_ans_from_response(item["prediction"])
        item["null_ans"] = is_blank(item["pred_ans"])
        item["sub_em"] = drqa_metric_max_over_ground_truths(substring_exact_match_score, item["pred_ans"], item["answers"])
        return item

    dataset = dataset.map(process, num_proc=1)

    dataset.filter(lambda x: x["null_ans"]).to_json(f"{args.dataset}_null_ans.jsonl")

    # Prompt-length buckets, measured in prompt tokens; None means "no filter".
    kib = 1024
    length_filters = {
        "4k": lambda n: n <= 4 * kib,
        "8k": lambda n: 4 * kib < n <= 8 * kib,
        "16k": lambda n: 8 * kib < n <= 16 * kib,
        "32k": lambda n: 16 * kib < n <= 32 * kib,
        "32k+": lambda n: n > 32 * kib,
        "all": None,
    }
    # splitext only touches the final extension, unlike str.replace(".json", ...).
    stem, ext = os.path.splitext(args.output_path)

    for bucket in ["all"]:
        keep = length_filters[bucket]
        subset = dataset if keep is None else dataset.filter(keep, input_columns="prompt_token_length")

        # Column access instead of row iteration avoids decoding the contexts.
        grouped = {}
        for name, score, is_null in zip(subset["dataset"], subset["sub_em"], subset["null_ans"]):
            scores, nulls = grouped.setdefault(name, ([], []))
            scores.append(score)
            nulls.append(is_null)

        metrics = {
            name: {
                "sub_em": np.mean(scores) * 100,
                "num_samples": len(scores),
                "null_ans": sum(nulls),
            }
            for name, (scores, nulls) in sorted(grouped.items())
        }

        with open(f"{stem}_{bucket}{ext}", "w", encoding="utf-8") as fw:
            json.dump(metrics, fw, indent=4)

def _write_qa_summary(new_metrics, txt_path):
    """Print a LaTeX row and write a tab-separated sub_em/f1 table to *txt_path*.

    ``new_metrics`` maps sub-dataset name to a dict with ``sub_em`` and ``f1``
    (already in percent). Known LongBench QA subsets come first in a fixed
    order. Any other subset present is appended under its raw name. The
    ``AVG`` column is the unweighted mean over the columns that are present.
    """
    order = [
        ("qasper", "Qasper"),
        ("narrativeqa", "NarrativeQA"),
        ("multifieldqa_en", "MultiFieldQA-En"),
        ("hotpotqa", "HotpotQA"),
        ("musique", "MuSiQue"),
        ("2wikimqa", "2WikiMQA"),
    ]
    # Tolerate benchmarks that do not contain all six subsets (e.g. the
    # four-subset longbench_e_qa) instead of failing with a KeyError.
    known = dict(order)
    columns = [(key, label) for key, label in order if key in new_metrics]
    columns += [(key, key) for key in sorted(new_metrics) if key not in known]
    if not columns:
        print("No sub-dataset results to summarize.")
        return

    sub_em_values = [new_metrics[key]["sub_em"] for key, _ in columns]
    f1_values = [new_metrics[key]["f1"] for key, _ in columns]
    avg_sub_em = sum(sub_em_values) / len(columns)
    avg_f1 = sum(f1_values) / len(columns)

    header = "\t".join(label for _, label in columns) + "\tAVG"
    cells = [f"{s:.1f}/{f:.1f}" for s, f in zip(sub_em_values, f1_values)]
    cells.append(f"{avg_sub_em:.1f}/{avg_f1:.1f}")
    row = "\t".join(cells)

    latex_cells = [f"{f:.1f}" for f in f1_values] + [f"{avg_f1:.1f}"]
    latex_line = "&" + "&".join(latex_cells) + "&"
    print(latex_line)

    with open(txt_path, "w", encoding="utf-8") as fw:
        # End every line so the LaTeX row does not run into the table row.
        fw.write(f"{header}\n{row}\n")
        fw.write(latex_line + "\n")


def evaluate_all(args):
    """Score saved predictions with substring-EM and token F1 per LongBench sub-dataset.

    Reads the dataset saved by the inference stage from ``args.dataset`` and
    re-extracts ``pred_ans`` from ``prediction`` when it is blank. Each row is
    scored against ``answers``, keeping the best score over all references.

    Outputs:

    - Per sub-dataset means (in percent) and sample counts go to
      ``<output_path stem>_<bucket>.json``.
    - For the ``all`` bucket, a summary table plus a LaTeX row go to
      ``<output_path stem>_all.txt``, and the LaTeX row is also printed.
    """
    dataset = load_from_disk(args.dataset)
    print(dataset)

    def process(item):
        # None and whitespace-only answers both count as missing.
        if not (item["pred_ans"] or "").strip():
            item["pred_ans"] = extract_ans_from_response(item["prediction"])
        item["sub_em"] = drqa_metric_max_over_ground_truths(substring_exact_match_score, item["pred_ans"], item["answers"])
        item["f1"] = drqa_metric_max_over_ground_truths(qa_f1_score, item["pred_ans"], item["answers"])
        return item

    dataset = dataset.map(process, num_proc=1)

    # Prompt-length buckets, measured in prompt tokens; None means "no filter".
    kib = 1024
    length_filters = {
        "4k": lambda n: n <= 4 * kib,
        "8k": lambda n: 4 * kib < n <= 8 * kib,
        "8k+": lambda n: n > 8 * kib,
        "all": None,
    }
    # splitext only touches the final extension, unlike str.replace(".json", ...).
    stem, ext = os.path.splitext(args.output_path)

    for bucket in ["all"]:
        keep = length_filters[bucket]
        subset = dataset if keep is None else dataset.filter(keep, input_columns="prompt_token_length")

        # Column access instead of row iteration avoids decoding the contexts.
        grouped = {}
        for name, sub_em, f1 in zip(subset["dataset"], subset["sub_em"], subset["f1"]):
            sub_em_scores, f1_scores = grouped.setdefault(name, ([], []))
            sub_em_scores.append(sub_em)
            f1_scores.append(f1)

        new_metrics = {
            name: {
                "sub_em": np.mean(sub_em_scores) * 100,
                "f1": np.mean(f1_scores) * 100,
                "num_samples": len(f1_scores),
            }
            for name, (sub_em_scores, f1_scores) in sorted(grouped.items())
        }

        with open(f"{stem}_{bucket}{ext}", "w", encoding="utf-8") as fw:
            json.dump(new_metrics, fw, indent=4)

        if bucket == "all":
            _write_qa_summary(new_metrics, f"{stem}_{bucket}.txt")

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", type=str, default="longbench_qa")
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--split", type=str)
    parser.add_argument("--stage", type=str, choices=["inference", "evaluation"])
    parser.add_argument("--eval_strategy", type=str, choices=["sub_em", "all"])
    parser.add_argument("--output_path", type=str)

    parser.add_argument("--model_name_or_path", type=str)
    # Without a default, an unset flag would be forwarded to LLM(...) as None.
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--max_model_input_len", type=int, default=31000)
    parser.add_argument("--temperature", type=float, default=0)

    parser.add_argument("--openai", type=str)

    parser.add_argument("--prompt", type=str)
    args = parser.parse_args()

    # Arguments each stage reads unconditionally: fail fast with a usage
    # message instead of silently doing nothing or crashing deep inside.
    required_by_stage = {
        "inference": ["model_name_or_path", "output_path"],
        "evaluation": ["dataset", "output_path", "eval_strategy"],
    }
    if args.stage is None:
        parser.error("--stage is required (choose from 'inference', 'evaluation')")
    missing = [f"--{name}" for name in required_by_stage[args.stage] if getattr(args, name) is None]
    if missing:
        parser.error(f"--stage {args.stage} requires: {', '.join(missing)}")

    if args.stage == "inference":
        inference(args)
    elif args.eval_strategy == "sub_em":
        evaluate_sub_em(args)
    else:
        evaluate_all(args)