import argparse
import os
import re
import json
import random
import torch
import evaluate

import time
from pathlib import Path
import jsonlines

from eval.utils import (
    generate_completions,
    load_hf_lm_and_tokenizer,
    query_openai_chat_model,
    dynamic_import_function,
)
from eval.gsm.examplars import EXAMPLARS as GSM_EXAMPLARS

## WizardMath
# Zero-shot chain-of-thought template. `{instruction}` is its only placeholder;
# the response section is pre-seeded with "Let's think step by step.".
wizard_math_cot_prompt = (
    "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request.\n\n"
    "### Instruction:\n"
    "{instruction}\n\n"
    "### Response: Let's think step by step."
)

def create_prompts_wiz_math(args, test_examples):
    """Build instruction/response style prompts for every test example.

    Args:
        args: namespace providing ``n_shot`` (truthy enables the few-shot
            demonstrations taken from ``GSM_EXAMPLARS``) and ``no_cot``
            (selects short answers and drops the step-by-step cue).
        test_examples: iterable of dicts with ``"id"`` and ``"question"`` keys.

    Returns:
        A list of ``{"id": ..., "prompt": ...}`` dicts, one per test example,
        in input order.
    """
    if args.no_cot:
        response_prefix = "### Response: "
        answer_key = "short_answer"
    else:
        response_prefix = "### Response: Let's think step by step. "
        answer_key = "cot_answer"

    def render(question, completion=""):
        # One "### Instruction / ### Response" section; the completion is left
        # empty for the question the model has to answer.
        return "### Instruction:\n" + question.strip() + "\n\n" + response_prefix + completion

    header = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
    )
    if args.n_shot:
        shots = [render(shot["question"], shot[answer_key]) for shot in GSM_EXAMPLARS]
        header = header + "\n\n".join(shots) + "\n\n"

    return [
        {"id": item['id'], 'prompt': header + render(item["question"])}
        for item in test_examples
    ]

def create_prompts_open_instruct(args, test_examples):
    """Build "Question: ... Answer:" prompts for every test example.

    Args:
        args: namespace providing ``n_shot`` (truthy enables the few-shot
            demonstrations taken from ``GSM_EXAMPLARS``), ``no_cot`` (use the
            short answers instead of the chain-of-thought answers),
            ``use_chat_format`` and ``chat_formatting_function`` (dotted path
            resolved with ``dynamic_import_function`` when chat format is on).
        test_examples: iterable of dicts with ``"id"`` and ``"question"`` keys.

    Returns:
        A list of ``{"id": ..., "prompt": ...}`` dicts, one per test example,
        in input order.
    """
    if args.n_shot:
        answer_key = "short_answer" if args.no_cot else "cot_answer"
        # Both demonstration variants use the same "Question: " label as the
        # test question below (the no-CoT variant used to spell it "Quesion: ").
        demonstrations = [
            "Question: " + example["question"] + "\n" + "Answer: " + example[answer_key]
            for example in GSM_EXAMPLARS
        ]
        prompt_prefix = "Answer the following questions.\n\n" + "\n\n".join(demonstrations) + "\n\n"
    else:
        prompt_prefix = "Answer the following question.\n\n"

    chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None

    prompts = []
    for example in test_examples:
        prompt = prompt_prefix + "Question: " + example["question"].strip()
        if args.use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, add_bos=False)
            # Start the assistant turn with "Answer:", adding a separating
            # space only when the chat template did not already end on
            # whitespace.
            if prompt[-1] in ["\n", " "]:
                prompt += "Answer:"
            else:
                prompt += " Answer:"
        else:
            prompt += "\nAnswer:"

        prompts.append({"id": example['id'], 'prompt': prompt})

    return prompts


# Exact-match metric, loaded once at import time and used by compute_scores().
exact_match = evaluate.load("exact_match")

def is_multiprocess(args):
    """Return True when the run is split across several processes.

    A run counts as multi-process when ``args.world_size`` is greater than 1.
    In that case the function also asserts that ``args.process_idx`` lies in
    ``[0, world_size)`` and that ``args.run_id`` is set.
    """
    if not args.world_size > 1:
        return False

    assert 0 <= args.process_idx < args.world_size
    assert args.run_id is not None
    return True

def get_predictions_file_path(args, task_name, limit, process_idx=None):
    """Return the path of the predictions file for one process.

    Args:
        args: namespace providing ``save_dir``, ``process_idx``, ``run_id`` and
            the fields read by ``is_multiprocess``.
        task_name: prefix of the file name.
        limit: example limit encoded in the file name; ``None`` omits it.
        process_idx: process whose file is requested. Defaults to
            ``args.process_idx`` when ``None``.

    Returns:
        ``Path(args.save_dir) / <file name>``. In multi-process runs the name
        also carries the run id and the process index.
    """
    # Compare against None explicitly: process index 0 is falsy, so an
    # `or`-style default would silently replace an explicit 0 with the
    # caller's own index.
    if process_idx is None:
        process_idx = args.process_idx

    limit_str = f"_limit{limit}" if limit is not None else ""
    if is_multiprocess(args):
        file_name = f"{task_name}_predictions{limit_str}_{args.run_id}_{process_idx}.jsonl"
    else:
        file_name = f"{task_name}_predictions{limit_str}.jsonl"
    return Path(args.save_dir) / file_name

def get_existing_predcition_files(args, task_name, limit):
    """Return the prediction files of the run, or None if any is missing.

    Args:
        args: namespace providing ``world_size`` plus the fields read by
            ``is_multiprocess`` and ``get_predictions_file_path``.
        task_name: prefix of the prediction file names.
        limit: example limit encoded in the file names (``None`` for no limit).

    Returns:
        A list of ``Path`` objects (one per process in multi-process runs, a
        single entry otherwise) when every expected file exists, else ``None``.
    """
    # Single-process runs have exactly one file; its name does not depend on
    # the process index, so the default (None) is passed through.
    process_ids = range(args.world_size) if is_multiprocess(args) else [None]

    all_file_paths = []
    for p_id in process_ids:
        curr_file_path = get_predictions_file_path(args, task_name, limit=limit, process_idx=p_id)
        if not curr_file_path.exists():
            return None
        all_file_paths.append(curr_file_path)

    return all_file_paths



def add_multiprocess_args(parser):
    """Register the sharding / backend command-line options on ``parser``.

    Adds ``--process_idx``, ``--world_size``, ``--run_id`` and
    ``--max_wait_retry`` as valued options, and ``--use_vllm``,
    ``--use_wizmath_prompt`` and ``--debug`` as boolean flags.

    Returns:
        The same parser object, for chaining.
    """
    valued_options = (
        ("--process_idx", int, 0),
        ("--world_size", int, 1),
        ('--run_id', str, None),
        ('--max_wait_retry', int, 10),
    )
    for option, value_type, default_value in valued_options:
        parser.add_argument(option, type=value_type, default=default_value)

    for flag in ('--use_vllm', '--use_wizmath_prompt', '--debug'):
        parser.add_argument(flag, action="store_true")

    return parser


def read_predictions_files(existing_predcition_files, expected_len, args=None):
    """Load and concatenate the predictions stored in several JSONL files.

    Args:
        existing_predcition_files: iterable of paths to JSONL prediction files.
        expected_len: total number of predictions expected across all files.
        args: optional run namespace used to prefix the progress messages via
            ``p_print``. When omitted, a module-level ``args`` is used if one
            exists (as it does when this file is run as a script); otherwise
            the messages are printed without a prefix.

    Returns:
        A list with every object read, in file order.

    Raises:
        ValueError: if the number of objects read differs from ``expected_len``.
    """
    if args is None:
        # The function used to rely on a global `args` unconditionally, which
        # raised NameError whenever the module was imported instead of run.
        args = globals().get("args")

    all_test_predictions = []
    for prediction_file in existing_predcition_files:
        with jsonlines.open(prediction_file) as reader:
            file_predictions = list(reader)
        all_test_predictions.extend(file_predictions)

        message = f"Read {len(file_predictions)} from {prediction_file}"
        if args is not None:
            p_print(args, message)
        else:
            print(message)

    if len(all_test_predictions) != expected_len:
        raise ValueError('Read test prediction len missmatch')

    return all_test_predictions

def p_print(args, msg):
    """Print ``msg``, prefixed with the process index in multi-process runs."""
    line = f"Process {args.process_idx}: {msg}" if is_multiprocess(args) else msg
    print(line)

def get_limit_str(args):
    """Return the ``_limit<N>`` file-name suffix, or ``""`` when no limit is set."""
    limit = args.max_num_examples
    return "" if limit is None else f"_limit{limit}"

def load_data(args):
    """Load the processed GSM test set, creating the cached copy if needed.

    The processed data is cached as
    ``<data_dir>/processed_test_data[_limit<N>].jsonl``. In multi-process runs
    every process other than index 0 waits for that file to appear instead of
    generating it.

    Args:
        args: namespace providing ``data_dir``, ``max_num_examples``,
            ``process_idx`` and the fields read by ``is_multiprocess``.

    Returns:
        A list of ``{"id", "question", "answer"}`` dicts.

    Raises:
        ValueError: if an answer in ``test.jsonl`` is not a valid number.
    """
    print("Loading data...")

    test_data_file = Path(args.data_dir) / f"processed_test_data{get_limit_str(args)}.jsonl"
    if is_multiprocess(args) and args.process_idx != 0:
        # The cache file is published atomically below, so once it exists it
        # is complete and no extra grace period is needed.
        while not test_data_file.exists():
            time.sleep(5)
            p_print(args, "Waiting for data creation by main process")

    if test_data_file.exists():
        p_print(args, f"Loading cached data {test_data_file}")
        with jsonlines.open(test_data_file) as f:
            return list(f)

    p_print(args, "Generating test data")
    test_data = []
    with open(os.path.join(args.data_dir, "test.jsonl"), encoding="utf-8") as fin:
        for idx, line in enumerate(fin):
            example = json.loads(line)
            answer = example["answer"].split("####")[1].strip()
            # some numbers are in the `x,xxx` format, and we want to remove the comma
            answer = re.sub(r"(\d),(\d)", r"\1\2", answer)
            try:
                # Only parseability matters: checking the truthiness of the
                # float would wrongly reject a legitimate answer of 0.
                float(answer)
            except ValueError as err:
                raise ValueError(f"answer is not a valid number: {answer}") from err
            test_data.append({
                "id": f"gsm_{idx}",
                "question": example["question"],
                "answer": answer,
            })

    if args.max_num_examples and len(test_data) > args.max_num_examples:
        test_data = random.sample(test_data, args.max_num_examples)

    # Write to a temporary file and rename it into place so that processes
    # polling for the cache never observe a partially written file.
    tmp_data_file = test_data_file.with_name(test_data_file.name + ".tmp")
    with jsonlines.open(tmp_data_file, 'w') as f:
        f.write_all(test_data)
    os.replace(tmp_data_file, test_data_file)

    return test_data

def shard_data(args, test_data):
    """Return the contiguous slice of ``test_data`` owned by this process.

    Every process gets ``len(test_data) // world_size`` examples; the last
    process additionally takes whatever remains. Only valid in multi-process
    runs (asserted).
    """
    assert is_multiprocess(args)

    total = len(test_data)
    per_process = total // args.world_size
    start_idx = args.process_idx * per_process
    if args.process_idx == args.world_size - 1:
        # The last process absorbs the remainder of the integer division.
        end_idx = total
    else:
        end_idx = start_idx + per_process

    shard = test_data[start_idx:end_idx]
    print(f'Process {args.process_idx}: [{start_idx}, {end_idx}[ -> Len={len(shard)}')
    return shard

def compute_scores(args, predictions):
    """Compute the exact-match score and write it to the metrics file.

    Args:
        args: namespace providing ``save_dir`` and ``run_id`` (plus the fields
            read by ``p_print``).
        predictions: iterable of dicts with ``'answer'`` and ``'prediction'``.

    The score is stored under ``"exact_match"`` in
    ``<save_dir>/metrics_<run_id>.json``; when ``run_id`` is ``None`` the file
    name is literally ``metrics_None.json``.
    """
    p_print(args, "Calculating accuracy...")

    pairs = [(item['answer'], item['prediction']) for item in predictions]
    targets = [answer for answer, _ in pairs]
    pred_list = [prediction for _, prediction in pairs]

    result = exact_match.compute(
        predictions=pred_list,
        references=targets,
        ignore_case=True,
        ignore_punctuation=True,
    )
    em_score = result["exact_match"]
    p_print(args, f"Exact match : {em_score}")

    metrics_path = os.path.join(args.save_dir, f"metrics_{args.run_id}.json")
    with open(metrics_path, "w") as fout:
        json.dump({"exact_match": em_score}, fout, indent=4)

def batch_data(data_list, batch_size=1):
    """Split ``data_list`` into consecutive batches of at most ``batch_size``.

    Every batch except possibly the last holds exactly ``batch_size`` items;
    the last one holds the remainder. An empty input yields no batches.

    Args:
        data_list: sliceable sequence to split.
        batch_size: maximum number of items per batch; must be at least 1.

    Returns:
        A list of slices of ``data_list``, in order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Plain fixed-stride chunking: no batch ever exceeds batch_size, which the
    # earlier "last batch takes the remainder" scheme did not guarantee.
    return [
        data_list[start:start + batch_size]
        for start in range(0, len(data_list), batch_size)
    ]

def predict_outputs(args, prompt_list, prompt_ids):
    """Generate one completion per prompt with the configured backend.

    Backend selection:
      * no ``args.model_name_or_path``: OpenAI chat model (``args.openai_engine``);
      * ``args.debug``: no model is loaded, every output is the string ``'debug'``;
      * ``args.use_vllm``: vLLM greedy decoding (temperature 0);
      * otherwise: Hugging Face model through ``generate_completions``, stopping
        at the first newline token.

    Args:
        args: parsed command-line namespace.
        prompt_list: list of prompt strings.
        prompt_ids: ids matching ``prompt_list``; not used by this function.

    Returns:
        A list of output strings.
    """
    if not args.model_name_or_path:
        results = query_openai_chat_model(
            engine=args.openai_engine,
            instances=prompt_list,
            batch_size=args.eval_batch_size or 10,
            output_path=os.path.join(args.save_dir, "openai_results.jsonl"),
        )
        return [result["output"] for result in results]

    if args.debug:
        p_print(args, '(((((((((((((((((((((DEBUG)))))))))))))))))))))')
        return ['debug'] * len(prompt_list)

    max_new_tokens = 512

    if not args.use_vllm:
        p_print(args, "Loading model and tokenizer...")
        device_map = "balanced_low_0" if torch.cuda.device_count() > 1 else "auto"
        model, tokenizer = load_hf_lm_and_tokenizer(
            model_name_or_path=args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            load_in_8bit=args.load_in_8bit,
            device_map=device_map,
            gptq_model=args.gptq
        )
        # get the last token because the tokenizer may add space tokens at the start.
        new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1]
        return generate_completions(
            model=model,
            tokenizer=tokenizer,
            prompts=prompt_list,
            max_new_tokens=max_new_tokens,
            batch_size=args.eval_batch_size,
            stop_id_sequences=[[new_line_token]]
        )

    p_print(args, "using VLLM")
    from vllm import LLM, SamplingParams

    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=max_new_tokens, stop=stop_tokens)
    llm = LLM(model=args.model_name_or_path, tensor_parallel_size=1, dtype='bfloat16')

    outputs = []
    for batch in batch_data(prompt_list, batch_size=args.eval_batch_size):
        completions = llm.generate(batch, sampling_params)
        # Keep only the first candidate's text of each request.
        outputs.extend(completion.outputs[0].text for completion in completions)
    return outputs

def main(args):
    """Run the GSM evaluation: load data, generate predictions, score them.

    In multi-process runs each process handles its own shard and writes its own
    predictions file; process 0 then waits for every shard file, merges them
    and computes the score. If all shard files already exist at start-up (for
    example after a crash or preemption), generation is skipped and process 0
    only computes the score.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).

    Raises:
        RuntimeError: if process 0 gives up waiting for the other processes
            after ``args.max_wait_retry`` polls.
    """
    random.seed(42)
    task_name = 'gsm'

    test_data = load_data(args)

    os.makedirs(args.save_dir, exist_ok=True)

    global GSM_EXAMPLARS
    if len(GSM_EXAMPLARS) > args.n_shot:
        # Use a dedicated, freshly seeded RNG: the global RNG may or may not
        # have been advanced by load_data (it samples only when it generates
        # the data), which would give different demonstrations to different
        # processes of the same run.
        GSM_EXAMPLARS = random.Random(42).sample(GSM_EXAMPLARS, args.n_shot)

    multiprocess = is_multiprocess(args)
    test_data_shard = shard_data(args, test_data) if multiprocess else test_data

    # Compute score if the files exists (case of crash or job preempt)
    if multiprocess:
        existing_predcition_files = get_existing_predcition_files(args, task_name, limit=args.max_num_examples)
        if existing_predcition_files is not None:
            if args.process_idx == 0:
                p_print(args, 'Prediction files exists. Computing scores')
                test_predictions = read_predictions_files(existing_predcition_files, len(test_data))
                compute_scores(args, test_predictions)
            else:
                p_print(args, 'Prediction files exists leaving first process handle it. Done!')
            return

    if args.use_wizmath_prompt:
        prompts = create_prompts_wiz_math(args, test_data_shard)
    else:
        prompts = create_prompts_open_instruct(args, test_data_shard)

    prompt_list = [p['prompt'] for p in prompts]
    prompt_ids = [p['id'] for p in prompts]

    outputs = predict_outputs(args, prompt_list, prompt_ids)

    predictions = []
    test_data_dict = {e['id']: e for e in test_data_shard}
    for prompt_id, output in zip(prompt_ids, outputs):
        # replace numbers like `x,xxx` with `xxxx`
        output = re.sub(r"(\d),(\d)", r"\1\2", output)
        # The last number in the completion is taken as the predicted answer;
        # when there is none, the whole output is kept.
        numbers = re.findall(r"[-+]?\d*\.\d+|\d+", output)
        pred = numbers[-1] if numbers else output
        example = test_data_dict[prompt_id]
        predictions.append({'id': prompt_id, "prediction": pred, 'answer': example["answer"], 'question': example['question']})

    predictions_file = get_predictions_file_path(args, task_name, limit=args.max_num_examples)
    p_print(args, f'Writing predtions to {predictions_file}')
    # Publish the file atomically: process 0 treats "file exists" as "shard is
    # complete", so it must never see a partially written file.
    tmp_predictions_file = predictions_file.with_name(predictions_file.name + ".tmp")
    with jsonlines.open(tmp_predictions_file, 'w') as fout:
        fout.write_all(predictions)
    os.replace(tmp_predictions_file, predictions_file)

    if multiprocess:
        if args.process_idx != 0:
            p_print(args, 'Done creating predictions. Leave scrore for main process')
            return

        retry = 0
        existing_predcition_files = get_existing_predcition_files(args, task_name, limit=args.max_num_examples)
        p_print(args, "Start waiting for other process")
        while existing_predcition_files is None:
            retry += 1
            if retry > args.max_wait_retry:
                raise RuntimeError(f"{args.process_idx} Reached max retry {retry} while waiting for other processes")
            p_print(args, f'Still waiting {retry}/{args.max_wait_retry}')
            time.sleep(10)
            # Re-check after every sleep so that no wait is wasted.
            existing_predcition_files = get_existing_predcition_files(args, task_name, limit=args.max_num_examples)
        p_print(args, f'Found all files len={len(existing_predcition_files)} word_size={args.world_size}')

        predictions = read_predictions_files(existing_predcition_files, len(test_data))

    compute_scores(args, predictions)

    print('Done!!')

def process_args(args):
    """Resolve path arguments against the workspace and apply model defaults.

    ``data_dir``, ``model_name_or_path`` and ``tokenizer_name_or_path`` are
    joined onto the workspace root (environment variable ``WS_PATH``, default
    ``/mnt/workspace``) when they are not ``None``. vLLM is switched on when
    the resolved model path contains ``'wizard'``.

    Args:
        args: parsed command-line namespace; modified in place.

    Returns:
        The same ``args`` object.
    """
    ws_path = Path(os.environ.get("WS_PATH", "/mnt/workspace"))
    for key in ("data_dir", "model_name_or_path", "tokenizer_name_or_path"):
        value = getattr(args, key)
        if value is not None:
            setattr(args, key, (ws_path / value).as_posix())

    # model_name_or_path is None when an OpenAI engine is used; a substring
    # test on None would raise TypeError.
    # NOTE: the test runs on the resolved path, so it also matches a workspace
    # root that itself contains 'wizard'.
    model_path = args.model_name_or_path
    if model_path is not None and 'wizard' in model_path:
        # To use the same way Wizard models are tested
        args.use_vllm = True

    return args

if __name__ == "__main__":
    # Command-line entry point: declare the options, validate them, resolve
    # paths and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data/gsm")
    parser.add_argument(
        "--max_num_examples", type=int, default=None,
        help="maximum number of examples to evaluate.",
    )
    parser.add_argument("--save_dir", type=str, default="results/gsm")
    parser.add_argument(
        "--model_name_or_path", type=str, default=None,
        help="if specified, we will load the model to generate the predictions.",
    )
    parser.add_argument(
        "--tokenizer_name_or_path", type=str, default=None,
        help="if specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--openai_engine", type=str, default=None,
        help="if specified, we will use the OpenAI API to generate the predictions.",
    )
    parser.add_argument(
        "--n_shot", type=int, default=8,
        help="max number of examples to use for demonstration.",
    )
    parser.add_argument(
        "--no_cot", action="store_true",
        help="If given, we're evaluating a model without chain-of-thought.",
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=1,
        help="batch size for evaluation.",
    )
    parser.add_argument(
        "--load_in_8bit", action="store_true",
        help="load model in 8bit mode, which will reduce memory and speed up inference.",
    )
    parser.add_argument(
        "--gptq", action="store_true",
        help="If given, we're evaluating a 4-bit quantized GPTQ model.",
    )
    parser.add_argument(
        "--use_chat_format", action="store_true",
        help="If given, we will use the chat format for the prompts.",
    )
    parser.add_argument(
        "--chat_formatting_function", type=str,
        default="eval.templates.create_prompt_with_tulu_chat_format",
        help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`.",
    )

    parser = add_multiprocess_args(parser)

    # Keep the parsed namespace bound to the module-level name `args`:
    # read_predictions_files looks that name up when reporting progress.
    args = parser.parse_args()

    # model_name_or_path and openai_engine cannot be both None or both not None.
    # Reported through the parser (not an assert) so the check also runs under
    # `python -O`.
    if (args.model_name_or_path is None) == (args.openai_engine is None):
        parser.error("Either model_name_or_path or openai_engine should be specified.")

    args = process_args(args)
    main(args)
