import argparse
import os
import re
import json
import tqdm
import glob
import torch
import random
import evaluate
import time
from pathlib import Path
from eval.utils import (
    load_hf_lm_and_tokenizer,
    generate_completions,
    query_openai_chat_model,
    dynamic_import_function,
)


# Exact-match metric loaded once at import time through the `evaluate` library;
# both eval_hf_model and eval_openai_chat_engine score predictions with it.
exact_match = evaluate.load("exact_match")

def is_multiprocess(args):
    """Report whether this run is sharded across several processes.

    A run counts as multi-process when ``args.world_size`` exceeds one. In that
    case the shard index must lie in ``[0, world_size)`` and a ``run_id`` must be
    set (both are asserted).

    Args:
        args: parsed arguments exposing ``world_size``, ``process_idx`` and ``run_id``.

    Returns:
        True for a multi-process run, False otherwise.
    """
    if args.world_size <= 1:
        return False

    assert 0 <= args.process_idx < args.world_size
    assert args.run_id is not None
    return True

def get_predictions_file_path(args, task_name, limit, process_idx=None):
    """Return the path of the JSONL predictions file for ``task_name``.

    Args:
        args: parsed arguments; reads ``save_dir``, ``process_idx``, ``run_id`` and
            the multi-process settings checked by ``is_multiprocess``.
        task_name: task name used as the file-name prefix.
        limit: per-task example limit, or None; encoded as ``_limit{limit}``.
        process_idx: shard index to build the path for. Defaults to
            ``args.process_idx`` when None. Only used in multi-process runs.

    Returns:
        A ``pathlib.Path`` inside ``args.save_dir``. In multi-process runs the file
        name also carries ``run_id`` and the shard index.
    """
    # Compare against None explicitly: shard 0 is a valid index and must not be
    # replaced by this process's own index (which `process_idx or ...` would do).
    if process_idx is None:
        process_idx = args.process_idx

    limit_str = f"_limit{limit}" if limit is not None else ""
    if is_multiprocess(args):
        file_name = f"{task_name}_predictions{limit_str}_{args.run_id}_{process_idx}.jsonl"
    else:
        file_name = f"{task_name}_predictions{limit_str}.jsonl"
    return Path(args.save_dir) / file_name

def get_existing_predcition_files(args, task_name, limit):
    """Return the prediction files for ``task_name`` if every expected one exists.

    In a multi-process run one file per shard (``0 .. world_size - 1``) is
    expected; otherwise a single file is expected.

    Args:
        args: parsed arguments (multi-process settings, ``save_dir``, ...).
        task_name: task whose prediction files are looked up.
        limit: per-task example limit forwarded to ``get_predictions_file_path``.

    Returns:
        A list of ``pathlib.Path`` objects, or None as soon as one expected file
        is missing.
    """
    # None lets get_predictions_file_path fall back to this process's index;
    # single-process file names do not include the index anyway.
    process_ids = range(args.world_size) if is_multiprocess(args) else [None]

    all_file_paths = []
    for p_id in process_ids:
        curr_file_path = get_predictions_file_path(args, task_name, limit=limit, process_idx=p_id)
        if not curr_file_path.exists():
            return None
        all_file_paths.append(curr_file_path)

    return all_file_paths


def add_multiprocess_args(parser):
    """Register the command-line flags that control sharded evaluation.

    Adds ``--process_idx``, ``--world_size``, ``--run_id``, ``--max_wait_retry``
    and ``--debug`` to ``parser`` in place.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser object, so calls can be chained.
    """
    flag_specs = (
        ("--process_idx", {"type": int, "default": 0}),
        ("--world_size", {"type": int, "default": 1}),
        ("--run_id", {"type": str, "default": None}),
        ("--max_wait_retry", {"type": int, "default": 10}),
        ("--debug", {"action": "store_true"}),
    )
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)
    return parser

def read_predictions_files(existing_predcition_files, expected_len, args=None):
    """Load and concatenate JSONL prediction files.

    Args:
        existing_predcition_files: iterable of paths to JSONL files, one JSON
            object per line.
        expected_len: total number of records expected across all files.
        args: optional parsed arguments. When given, progress messages go through
            ``p_print`` (process-prefixed); otherwise they are printed directly.

    Returns:
        The decoded records, in file order and then line order.

    Raises:
        ValueError: if the total number of records differs from ``expected_len``.
    """
    all_test_predictions = []
    for existing_predcition_file in existing_predcition_files:
        curr_cnt = 0
        with open(existing_predcition_file, "r", encoding="utf-8") as fin:
            for line in fin:
                # Skip blank lines (e.g. a stray empty line) instead of failing to decode them.
                if not line.strip():
                    continue
                all_test_predictions.append(json.loads(line))
                curr_cnt += 1

        msg = f"Read {curr_cnt} from {existing_predcition_file}"
        if args is not None:
            p_print(args, msg)
        else:
            print(msg)

    if len(all_test_predictions) != expected_len:
        raise ValueError(
            f"Read test prediction len mismatch: expected {expected_len}, "
            f"got {len(all_test_predictions)}"
        )

    return all_test_predictions

def p_print(args, msg):
    """Print ``msg``, prefixed with ``Process <idx>:`` in multi-process runs.

    Args:
        args: parsed arguments, forwarded to ``is_multiprocess``; ``process_idx``
            is read for the prefix.
        msg: message to print.
    """
    if not is_multiprocess(args):
        print(msg)
        return
    print(f"Process {args.process_idx}: {msg}")

def get_limit_str(args):
    """Return the ``_limit<N>`` file-name suffix for the per-task example limit.

    Args:
        args: parsed arguments exposing ``max_num_examples_per_task``.

    Returns:
        An empty string when no limit is set (None), otherwise ``_limit<N>``.
    """
    limit = args.max_num_examples_per_task
    return "" if limit is None else f"_limit{limit}"


def shard_data(args, test_data):
    """Return this process's contiguous slice of ``test_data``.

    Items are spread as evenly as possible: shard sizes differ by at most one,
    and the shards of processes ``0 .. world_size - 1`` concatenate back to
    ``test_data`` in order. Only valid in a multi-process run (asserted).

    Args:
        args: parsed arguments exposing ``process_idx`` and ``world_size``.
        test_data: a sliceable sequence (e.g. a list of task names).

    Returns:
        The slice ``test_data[start:end]`` assigned to ``args.process_idx``.
    """
    assert is_multiprocess(args)

    num_items = len(test_data)
    # Proportional integer boundaries spread the remainder across shards instead
    # of piling it onto the last process (with floor-sized chunks, 27 items over
    # 16 processes would give the last one 12 items and every other one 1).
    start_idx = args.process_idx * num_items // args.world_size
    end_idx = (args.process_idx + 1) * num_items // args.world_size
    shard = test_data[start_idx:end_idx]
    print(f'Process {args.process_idx}: [{start_idx}, {end_idx}[ -> Len={len(shard)}')
    return shard
    

@torch.no_grad()
def eval_hf_model(args, model, tokenizer, examples, task_prompt, save_path=None):
    """Generate answers with a Hugging Face model for one task and score them.

    Each example is rendered as ``task_prompt`` + ``"\\n\\nQ: "`` + its input, then
    an ``A:`` cue. When ``args.use_chat_format`` is set, the text is first wrapped
    by the dynamically imported ``args.chat_formatting_function``. The prediction
    is the text between the first "So the answer is " and the next period. If
    that phrase is absent, the first paragraph of the raw output is used instead.

    Note: ``raw_output`` and ``prediction`` keys are added to each dict in
    ``examples`` in place.

    Args:
        args: parsed arguments (reads ``use_chat_format``,
            ``chat_formatting_function``, ``no_cot``, ``eval_batch_size``).
        model: passed through to ``generate_completions``.
        tokenizer: passed through to ``generate_completions``; also used to encode
            the no-CoT stop sequence.
        examples: list of dicts with at least "input" and "target" keys.
        task_prompt: few-shot prompt text for the task.
        save_path: if given, each annotated example is written there as one JSON line.

    Returns:
        The "exact_match" value from ``exact_match.compute`` (case- and
        punctuation-insensitive).
    """
    targets = [example["target"] for example in examples]

    chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
    prompts = []
    for example in examples:
        prompt = task_prompt.strip() + "\n\nQ: " + example["input"]
        if args.use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, add_bos=False)
            # Avoid a doubled separator when the chat template already ends in whitespace.
            prompt += "A:" if prompt[-1] in ["\n", " "] else " A:"
        else:
            prompt += "\nA:"
        prompts.append(prompt)

    if args.no_cot:
        # Keep only the last two ids: the tokenizer may add space tokens at the start.
        stop_sequnce = tokenizer.encode("\n\n", add_special_tokens=False)[-2:]
    else:
        # No stop sequence for CoT: it is too inefficient when the generation is long.
        # The answer is extracted by post-processing below instead.
        stop_sequnce = None

    outputs = generate_completions(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        max_new_tokens=512,
        batch_size=args.eval_batch_size if args.eval_batch_size else 1,
        # NOTE(review): stop_sequnce is already a list of token ids (or None), so this
        # nests it one level deeper than a flat list of id sequences. Confirm against
        # generate_completions' expected format before relying on the no-CoT early stop.
        stop_id_sequences=[[stop_sequnce]],
    )

    predictions = []
    saved_lines = []
    for example, output in zip(examples, outputs):
        example["raw_output"] = output

        # Take the first answer after `So the answer is` and before the next period.
        match = re.search(r"So the answer is (.*?)\.", output)
        if match:
            prediction = match.group(1).strip()
        else:
            # Otherwise keep only the first paragraph - mainly for vanilla language models.
            prediction = output.strip().split("\n\n")[0].strip()

        example["prediction"] = prediction
        predictions.append(prediction)
        saved_lines.append(json.dumps(example) + "\n")

    if save_path:
        # A `with` block guarantees the file is flushed and closed.
        with open(save_path, "w") as fout:
            fout.writelines(saved_lines)

    assert len(predictions) == len(targets), "number of predictions and targets are not the same."
    return exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]


def eval_openai_chat_engine(args, examples, task_prompt, save_path=None):
    """Query an OpenAI chat model for one task and score its answers.

    Each prompt is ``task_prompt`` + ``"\\n\\nQ: "`` + the example input +
    ``"\\nA:"``. The prediction is the text between the first "So the answer is "
    and the next period, or the whole stripped output when that phrase is absent.

    Note: ``raw_output`` and ``prediction`` keys are added to each dict in
    ``examples`` in place.

    Args:
        args: parsed arguments (reads ``openai_engine`` and ``eval_batch_size``).
        examples: list of dicts with at least "input" and "target" keys; an
            optional "id" is forwarded to the API helper (the index is used otherwise).
        task_prompt: few-shot prompt text for the task.
        save_path: if given, annotated examples are written there as JSON lines,
            and raw API results go to ``<stem>_openai_results.jsonl`` beside it.

    Returns:
        The "exact_match" value from ``exact_match.compute`` (case- and
        punctuation-insensitive).
    """
    targets = [example["target"] for example in examples]
    instances = [
        {
            "id": example["id"] if "id" in example else i,
            "prompt": task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA:",
        }
        for i, example in enumerate(examples)
    ]

    openai_result_save_path = None
    if save_path:
        openai_result_save_path = os.path.join(
            os.path.dirname(save_path),
            os.path.basename(save_path).split(".")[0] + "_openai_results.jsonl",
        )

    results = query_openai_chat_model(
        engine=args.openai_engine,
        instances=instances,
        batch_size=args.eval_batch_size if args.eval_batch_size else 10,
        output_path=openai_result_save_path,
    )

    outputs = [result["output"] for result in results]
    assert len(outputs) == len(targets), "number of predictions and targets are not the same."

    predictions = []
    saved_lines = []
    for example, output in zip(examples, outputs):
        example["raw_output"] = output
        # Take the first answer after `So the answer is` and before the next period;
        # fall back to the raw output when there is no such answer.
        match = re.search(r"So the answer is (.*?)\.", output)
        prediction = match.group(1).strip() if match else output.strip()
        example["prediction"] = prediction
        predictions.append(prediction)
        saved_lines.append(json.dumps(example) + "\n")

    if save_path:
        # A `with` block guarantees the file is flushed and closed.
        with open(save_path, "w") as fout:
            fout.writelines(saved_lines)

    assert len(predictions) == len(targets), "number of predictions and targets are not the same."
    return exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]


def _strip_cot_from_prompt(task_prompt, task_name):
    """Replace each few-shot exemplar's reasoning with just its final answer.

    Blank-line-separated paragraphs that start with "Q:" become the question part
    (text before the first newline + "A:"), a newline, "A: " and the text after
    the last "So the answer is". All other paragraphs are kept unchanged.
    """
    new_prompt_fields = []
    for prompt_field in task_prompt.split("\n\n"):
        if prompt_field.startswith("Q:"):
            assert "So the answer is" in prompt_field, f"`So the answer is` not found in prompt field of {task_name}.txt."
            assert "\nA:" in prompt_field, "`\nA:` not found in prompt field."
            answer = prompt_field.split("So the answer is")[-1].strip()
            question = prompt_field.split("\nA:")[0].strip()
            new_prompt_fields.append(question + "\nA: " + answer)
        else:
            new_prompt_fields.append(prompt_field)
    return "\n\n".join(new_prompt_fields)


def _dump_json_atomically(obj, path):
    """Write ``obj`` as JSON to ``path`` through a temporary sibling file.

    ``os.replace`` moves the finished file into place, so processes polling for
    ``path`` never observe a partially written file.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with tmp_path.open("w") as f:
        json.dump(obj, f)
    os.replace(tmp_path, path)


def load_data(args):
    """Load task examples and few-shot prompts, caching them under ``args.data_dir``.

    In a multi-process run only process 0 builds the cache. The other processes
    poll until both cache files exist and then read them, so every process sees
    the same sampled examples.

    Args:
        args: parsed arguments (reads ``data_dir``, ``no_cot``,
            ``max_num_examples_per_task`` and the multi-process settings).

    Returns:
        ``(all_tasks, all_prompts)``: dicts keyed by task name, mapping to the list
        of examples and to the prompt string respectively.
    """
    limit_str = get_limit_str(args)
    # The prompts are rewritten under --no_cot, so that variant needs its own cache
    # file; a shared name would silently reuse CoT prompts for a no-CoT run.
    cot_str = "_nocot" if args.no_cot else ""
    proccessed_tasks_file = Path(args.data_dir) / f"processed_test_tasks_data{limit_str}.jsonl"
    processed_prompts_file = Path(args.data_dir) / f"processed_test_promtps_data{limit_str}{cot_str}.jsonl"

    if is_multiprocess(args) and args.process_idx != 0:
        while not (proccessed_tasks_file.exists() and processed_prompts_file.exists()):
            time.sleep(5)
            p_print(args, "Waiting for data creation by main process")
        time.sleep(3)

    if proccessed_tasks_file.exists() and processed_prompts_file.exists():
        p_print(args, f"Loading cached data {proccessed_tasks_file} and {processed_prompts_file}")

        with proccessed_tasks_file.open() as f:
            all_tasks = json.load(f)

        with processed_prompts_file.open() as f:
            all_prompts = json.load(f)

    else:
        p_print(args, "Generating test data")
        all_tasks = {}
        # Sorted so random.sample consumes the RNG in the same task order on every filesystem.
        task_files = sorted(glob.glob(os.path.join(args.data_dir, "bbh", "*.json")))
        for task_file in tqdm.tqdm(task_files, desc="Loading tasks"):
            task_name = os.path.basename(task_file).split(".")[0]
            with open(task_file, "r") as f:
                task_examples = json.load(f)["examples"]
            if args.max_num_examples_per_task:
                # Clamp so a task with fewer examples than the limit does not make random.sample raise.
                sample_size = min(args.max_num_examples_per_task, len(task_examples))
                task_examples = random.sample(task_examples, sample_size)
            all_tasks[task_name] = task_examples

        all_prompts = {}
        cot_prompt_files = sorted(glob.glob(os.path.join(args.data_dir, "cot-prompts", "*.txt")))
        for cot_prompt_file in tqdm.tqdm(cot_prompt_files, desc="Loading prompts"):
            task_name = os.path.basename(cot_prompt_file).split(".")[0]
            with open(cot_prompt_file, "r") as f:
                # The first two lines are dropped; presumably a header - TODO confirm.
                task_prompt = "".join(f.readlines()[2:])
            if args.no_cot:
                task_prompt = _strip_cot_from_prompt(task_prompt, task_name)
            all_prompts[task_name] = task_prompt

        # Tasks first, prompts last: waiting processes only proceed once both exist.
        _dump_json_atomically(all_tasks, proccessed_tasks_file)
        _dump_json_atomically(all_prompts, processed_prompts_file)

    assert set(all_tasks.keys()) == set(all_prompts.keys()), "task names in task data and task prompts are not the same."

    return all_tasks, all_prompts



def _wait_for_score_files(args, score_file_paths):
    """Poll every 10 s, at most ``args.max_wait_retry`` times, until all score files exist.

    Raises:
        FileNotFoundError: listing the files still missing after the last retry.
    """
    retry = 0
    while retry < args.max_wait_retry and not all(p.exists() for p in score_file_paths):
        p_print(args, 'Score files not yet created. Waiting')
        time.sleep(10)
        retry += 1

    missing = [str(p) for p in score_file_paths if not p.exists()]
    if missing:
        raise FileNotFoundError(
            f"Score files still missing after {args.max_wait_retry} retries: {missing}"
        )


def _merge_task_scores(score_file_paths):
    """Read the per-task score files and merge them into one dict.

    The ``process_idx`` entry of each file is dropped from the merged result.

    Raises:
        RuntimeError: if two files report the same score key.
    """
    all_scores = {}
    for score_file_path in score_file_paths:
        with score_file_path.open() as fin:
            score = json.load(fin)
        process = score.pop('process_idx')
        for k in score:
            if k in all_scores:
                raise RuntimeError(
                    f'Duplicate score key {k!r} in {score_file_path} (created by process {process})'
                )
        all_scores.update(score)
    return all_scores


def main(args):
    """Run the evaluation and, on the merging process, aggregate all task scores.

    Every process evaluates its tasks (all tasks, or its shard in a multi-process
    run) and writes ``metrics_<task>.json`` to ``args.save_dir``. The sole process,
    or process 0 in a multi-process run, then waits for all per-task files,
    merges them, adds ``average_exact_match`` and writes ``metrics_<run_id>.json``.

    Args:
        args: parsed command-line arguments.
    """
    random.seed(42)

    all_tasks, all_prompts = load_data(args)

    if args.model_name_or_path:
        if args.debug:
            p_print(args, '(((((((((((((((((((((DEBUG)))))))))))))))))))))')
            model, tokenizer = None, None
        else:
            p_print(args, "Loading model and tokenizer...")
            model, tokenizer = load_hf_lm_and_tokenizer(
                model_name_or_path=args.model_name_or_path,
                tokenizer_name_or_path=args.tokenizer_name_or_path,
                load_in_8bit=args.load_in_8bit,
                device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
                gptq_model=args.gptq
            )

    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(os.path.join(args.save_dir, "predictions"), exist_ok=True)

    all_task_names = list(all_tasks.keys())
    # In a multi-process run each process handles a disjoint shard of whole tasks.
    curr_tasks = shard_data(args, all_task_names) if is_multiprocess(args) else all_task_names

    for task_name in tqdm.tqdm(curr_tasks, desc="Evaluating"):
        task_examples = all_tasks[task_name]
        prompt = all_prompts[task_name]
        save_path = os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl")
        if args.model_name_or_path:
            if args.debug:
                task_perf = 0.0
            else:
                task_perf = eval_hf_model(args, model, tokenizer, task_examples, prompt, save_path=save_path)
        else:
            task_perf = eval_openai_chat_engine(args, task_examples, prompt, save_path=save_path)
        p_print(args, f"Task {task_name} - EM: {task_perf}")
        with open(os.path.join(args.save_dir, f"metrics_{task_name}.json"), "w") as fout:
            json.dump({f'{task_name}_em': task_perf, 'process_idx': args.process_idx}, fout, indent=4)

    # Only the sole process, or process 0 of a multi-process run, merges the scores.
    if is_multiprocess(args) and args.process_idx != 0:
        return

    # NOTE(review): per-task metric files are not namespaced by run_id, so files left
    # in save_dir by an earlier run also satisfy this wait - confirm save_dir is fresh.
    all_score_file_paths = [Path(args.save_dir) / f"metrics_{t}.json" for t in all_task_names]
    _wait_for_score_files(args, all_score_file_paths)
    all_scores = _merge_task_scores(all_score_file_paths)
    p_print(args, f'Loaded {len(all_scores)} scores')

    # Compute the average before opening the output so a failure cannot leave an empty file.
    all_scores["average_exact_match"] = sum(all_scores.values()) / len(all_scores)
    p_print(args, f"Average EM: {all_scores['average_exact_match']}")
    with open(os.path.join(args.save_dir, f"metrics_{args.run_id}.json"), "w") as fout:
        json.dump(all_scores, fout, indent=4)

    p_print(args, 'Done! BBH')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data/bbh")
    parser.add_argument("--save_dir", type=str, default="results/bbh")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="if specified, we will load the model to generate the predictions.",
    )
    parser.add_argument(
        "--tokenizer_name_or_path",
        type=str,
        default=None,
        help="if specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--openai_engine",
        type=str,
        default=None,
        help="if specified, we will use the OpenAI API to generate the predictions.",
    )
    parser.add_argument(
        "--no_cot",
        action="store_true",
        help="if specified, chain of thoughts will be removed from the prompts.",
    )
    parser.add_argument(
        "--max_num_examples_per_task",
        type=int,
        default=None,
        help="maximum number of examples to evaluate per task.",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="batch size for evaluation.",
    )
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="load model in 8bit mode, which will reduce memory and speed up inference.",
    )
    parser.add_argument(
        "--gptq",
        action="store_true",
        help="If given, we're evaluating a 4-bit quantized GPTQ model.",
    )
    parser.add_argument(
        "--use_chat_format",
        action="store_true",
        help="If given, we will use the chat format for the prompts.",
    )
    parser.add_argument(
        "--chat_formatting_function",
        type=str,
        default="eval.templates.create_prompt_with_tulu_chat_format",
        help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`.",
    )

    parser = add_multiprocess_args(parser)

    args = parser.parse_args()

    # Exactly one backend must be chosen: a local model or the OpenAI API.
    # parser.error (not assert) keeps the check under `python -O` and prints usage.
    if (args.model_name_or_path is None) == (args.openai_engine is None):
        parser.error("Exactly one of --model_name_or_path and --openai_engine should be specified.")
    main(args)
