
import json
import argparse
import logging
import os
from tqdm import tqdm
from math_verify import parse, verify
   
# Configure the root logger at import time: INFO level and above, with each
# record formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def read_json_objects(filename, field_names=None):
    """Load a list of records from a ``.jsonl`` or ``.json`` file.

    Args:
        filename: Path to the input file. The extension selects the parser:
            ``.jsonl`` is decoded one JSON document per line (blank lines are
            skipped), ``.json`` is decoded as a single document whose
            top-level elements become the records.
        field_names: Optional list of keys. When it is a list, every record
            that is a dict is reduced to those keys; keys absent from a
            record are simply left out. Any other value leaves the records
            untouched.

    Returns:
        The list of records. An empty list is returned when the extension is
        not recognised or when the file cannot be opened or decoded; the
        failure is logged instead of being raised.
    """
    file_extension = os.path.splitext(filename)[1]
    if file_extension not in ('.jsonl', '.json'):
        logging.error(f"Unknown file extension {file_extension}")
        return []

    is_jsonl = file_extension == '.jsonl'
    format_label = "JSONL" if is_jsonl else "JSON"
    try:
        with open(filename, 'r') as file:
            if is_jsonl:
                # A trailing or stray blank line is not a record; skipping it
                # keeps one empty line from invalidating the whole file.
                items = [json.loads(line) for line in file if line.strip()]
            else:
                items = list(json.load(file))
    except FileNotFoundError:
        logging.error("The file was not found.")
        return []
    except json.JSONDecodeError:
        logging.error(f"There was an error decoding the {format_label} file.")
        return []
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return []

    if isinstance(field_names, list):
        # Project each dict record onto the requested keys; non-dict records
        # (e.g. bare strings or numbers) are passed through unchanged.
        items = [
            {name: item[name] for name in field_names if name in item}
            if isinstance(item, dict) else item
            for item in items
        ]
    return items


def write_data_to_json_file(data, file_path):
    """Serialise ``data`` to ``file_path`` as pretty-printed JSON.

    The document is indented by four spaces and non-ASCII characters are
    emitted unescaped. Any failure is logged and swallowed, so the function
    returns ``None`` whether or not the write succeeded.

    Args:
        data: JSON-serialisable object to store.
        file_path: Destination path; an existing file is overwritten.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    try:
        with open(file_path, 'w') as handle:
            json.dump(data, handle, **dump_options)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as err:
        logging.error(f"An error occurred: {err}")


def create_parent_directory(file_path):
    """
    Make sure the directory that would contain ``file_path`` exists.

    The file itself is never created; only its parent directory (and any
    missing ancestors) is. A confirmation line is printed afterwards.

    Args:
        file_path (str or Path): The path to the file.
    """
    from pathlib import Path

    target_dir = Path(file_path).parent

    # parents=True builds every missing ancestor; exist_ok=True turns the call
    # into a no-op when the directory is already there.
    target_dir.mkdir(parents=True, exist_ok=True)
    print(f"Parent directory '{target_dir}' ensured to exist.")


def load_tokenizer_and_vllm(config, eos_token=None):
    """Load the tokenizer and the vLLM engine described by ``config``.

    The tokenizer is set to left padding and its pad token is aliased to the
    EOS token.

    Args:
        config: Nested dict. Reads ``config["inference"]`` keys ``model``,
            ``tp_size``, ``enable_chunked_prefill``,
            ``gpu_memory_utilization``, ``trust_remote_code`` and
            ``enforce_eager``.
        eos_token: Optional EOS token string. When given it overrides the
            tokenizer's own EOS token; otherwise the tokenizer's
            ``eos_token_id`` is used.

    Returns:
        A ``(tokenizer, llm)`` tuple.

    Raises:
        ValueError: If ``eos_token`` is not given and the tokenizer exposes
            no ``eos_token_id``.
    """
    from vllm import LLM
    from transformers import AutoTokenizer

    inference_cfg = config["inference"]
    model_path = inference_cfg["model"]
    logging.info(f"Loading ckpt and tokenizer: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = "left"
    if eos_token:
        eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    elif getattr(tokenizer, "eos_token_id", None) is not None:
        # Compare against None rather than relying on truthiness: token id 0
        # is a legitimate EOS id and must not be treated as "missing".
        logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
        eos_token_id = tokenizer.eos_token_id
        eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
    else:
        raise ValueError("No available eos_token or eos_token_id.")
    try:
        tokenizer.eos_token = eos_token
        tokenizer.eos_token_id = eos_token_id
        tokenizer.pad_token = eos_token
        tokenizer.pad_token_id = eos_token_id
    except Exception as e:
        # Best effort: some tokenizers reject these assignments. Only ordinary
        # errors are tolerated so Ctrl-C / SystemExit still propagate.
        logging.warning(f"Cannot set tokenizer.eos_token: {e}")
    logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
    logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")

    llm = LLM(
        model=model_path,
        tensor_parallel_size=inference_cfg["tp_size"],
        enable_chunked_prefill=inference_cfg["enable_chunked_prefill"],
        gpu_memory_utilization=inference_cfg["gpu_memory_utilization"],
        trust_remote_code=inference_cfg["trust_remote_code"],
        enforce_eager=inference_cfg["enforce_eager"],
    )
    logging.info("vLLM model loaded successfully")
    return tokenizer, llm


def generate_model_response_batch(tokenizer, llm, data_list, config):
    """Generate model completions for ``data_list`` in fixed-size batches.

    Args:
        tokenizer: Not used by this function; kept for call compatibility.
        llm: Engine exposing ``generate(prompts=..., sampling_params=...)``.
        data_list: Sequence of dict samples. Each must carry ``"answer"``,
            plus ``"question"`` when a prompt prefix is configured or
            ``"prompt"`` when it is not.
        config: Nested dict. Reads ``config["inference"]`` keys ``prompt``
            (default ``""``), ``batch_size`` (default 8), ``generate_n``
            (default 1), ``temperature``, ``seed`` and ``max_new_tokens``.

    Returns:
        A list with one dict per sample holding ``prompt``, ``answer`` and
        ``output``. ``output`` is a single string when ``generate_n`` is 1
        and a list of strings otherwise.

    Raises:
        ValueError: If ``generate_n`` is not positive.
    """
    inference_cfg = config["inference"]
    prompt = inference_cfg.get("prompt", "")
    batch_size = inference_cfg.get("batch_size", 8)
    generate_n = inference_cfg.get("generate_n", 1)
    # Explicit check instead of ``assert`` so it survives ``python -O``; done
    # before the vllm import so a bad config fails fast.
    if not generate_n > 0:
        raise ValueError(f"generate_n must be positive, got {generate_n}")

    from vllm import SamplingParams

    # The sampling configuration is the same for every batch, so build it once.
    sampling_params = SamplingParams(
        n=generate_n,
        temperature=inference_cfg["temperature"],
        seed=inference_cfg["seed"],
        skip_special_tokens=False,
        ignore_eos=False,
        max_tokens=inference_cfg["max_new_tokens"],
    )

    outcomes = []
    batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
    for batch in tqdm(batches, desc="Generating responses"):
        if prompt:
            batch_prompts = [prompt + ' ' + sample["question"] for sample in batch]
        else:
            batch_prompts = [sample["prompt"] for sample in batch]

        model_outputs = llm.generate(
            prompts=batch_prompts,
            sampling_params=sampling_params,
        )

        if generate_n == 1:
            model_responses = [output.outputs[0].text for output in model_outputs]
        else:
            model_responses = [[v.text for v in output.outputs] for output in model_outputs]

        # extend() appends in place; the previous ``outcomes + gen_data``
        # re-copied the whole accumulated list on every batch.
        outcomes.extend(
            {'prompt': batch_prompt, 'answer': sample['answer'], 'output': response}
            for batch_prompt, sample, response in zip(batch_prompts, batch, model_responses)
        )
    return outcomes


def math_verify_fn(llm_res, ground_truth):
    """Check a model response against the reference answer with Math-Verify.

    Args:
        llm_res: Text produced by the model.
        ground_truth: Reference answer text.

    Returns:
        The result of ``verify`` for the two parsed answers.
    """
    llm_answer = parse(llm_res)
    ground_truth_answer = parse(ground_truth)
    # verify() is not symmetric: its first argument is the gold (reference)
    # answer and the second is the candidate being judged.
    return verify(ground_truth_answer, llm_answer)


def worker(config, dp_rank):
    """Run inference for one data-parallel rank and write its shard of results.

    The rank restricts itself to ``tp_size`` consecutive GPUs through
    ``CUDA_VISIBLE_DEVICES``, takes its contiguous slice of the input
    dataset, generates responses and writes them to the configured output
    path with ``_<dp_rank>`` inserted before the extension.

    Args:
        config: Nested dict. Reads ``config["inference"]["dp_size"]``,
            ``config["inference"]["tp_size"]``,
            ``config["dataset"]["input_path"]`` and
            ``config["dataset"]["output_path"]``; the whole dict is also
            forwarded to the model-loading and generation helpers.
        dp_rank: Zero-based index of this rank.
    """
    dp_size = config["inference"]["dp_size"]
    tp_size = config["inference"]["tp_size"]

    # Rank r owns GPUs [r * tp_size, (r + 1) * tp_size).
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(i) for i in range(dp_rank * tp_size, (dp_rank + 1) * tp_size)
    )
    logging.info(f"DP rank {dp_rank} uses device {os.environ['CUDA_VISIBLE_DEVICES']}")

    # The reader logs failures instead of raising and may hand back a falsy
    # value; normalise that to an empty dataset.
    data_list = read_json_objects(config["dataset"]["input_path"]) or []

    # With DP, each rank processes a different contiguous part of the dataset.
    # The first ``remainder`` ranks receive one extra sample each.
    floor, remainder = divmod(len(data_list), dp_size)

    def start(rank):
        return rank * floor + min(rank, remainder)

    data_list = data_list[start(dp_rank):start(dp_rank + 1)]
    logging.info(f"DP rank {dp_rank} needs to process {len(data_list)} data samples.")

    output_path = config["dataset"]["output_path"]
    root, ext = os.path.splitext(output_path)
    output_path = root + f"_{dp_rank}" + ext

    if not data_list:
        # Nothing to do for this rank. A bare placeholder string cannot be
        # used here because generation indexes every sample by key, so write
        # an empty result file and skip loading the model altogether.
        write_data_to_json_file([], output_path)
        return

    # Load tokenizer and LLM
    tokenizer, llm = load_tokenizer_and_vllm(config)

    # Generate responses using LLM
    outputs = generate_model_response_batch(tokenizer, llm, data_list, config)

    # Write outputs to file
    write_data_to_json_file(outputs, output_path)


def main():
    """CLI entry point: run data-parallel inference, merge and score results.

    Steps:
        1. Parse ``--model``, ``--input_path`` and ``--output_path``.
        2. Spawn one ``worker`` process per data-parallel rank and wait for
           each (up to 14400 seconds per join) before killing stragglers.
        3. Merge the per-rank result files into ``--output_path``.
        4. Delete the per-rank files once the merged file is on disk.
        5. Log the Math-Verify accuracy over the merged results.

    Exit status:
        0 on success; a worker's non-zero exit code, or 1 if a worker had to
        be killed; 2 if a rank's result file is missing; 3 if no responses
        were gathered; 4 if the merged output file could not be produced.

    Note:
        Per-rank files are written as JSON arrays and read back according to
        their extension, so ``--output_path`` is expected to end in ``.json``.
    """
    import multiprocessing as mp
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='path to the model directory')
    parser.add_argument('--input_path', type=str, default='data/gsm8k/test.jsonl', help='path to the input file')
    parser.add_argument('--output_path', type=str, required=True, help='path to the output file')
    args = parser.parse_args()

    config = {
        "job_type": "eval_model_vllm",
        "dataset": {
            "input_path": args.input_path,
            "output_path": args.output_path,
            "seed": 42
        },
        "inference": {
            "model": args.model,
            # NOTE(review): this is a plain string that is never passed through
            # str.format, so the model sees the doubled braces literally.
            # Confirm that is intended before changing it, since the prompt
            # text affects evaluation results.
            "prompt": "Solve the following question step by step and put your final answer in \\boxed{{}}.\n",
            "batch_size": 96,
            "generate_n": 1,
            "dp_size": 8,
            "tp_size": 1,
            "enable_chunked_prefill": True,
            "seed": 777,
            "gpu_memory_utilization": 0.8,
            "temperature": 0.2,
            "trust_remote_code": True,
            "enforce_eager": True,
            "max_new_tokens": 1024
        }
    }

    mp.set_start_method('spawn')

    dp_size = config["inference"]["dp_size"]

    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(target=worker, args=(config, dp_rank))
        proc.start()
        procs.append(proc)

    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            proc.join()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # merge results into one file
    output_path = config["dataset"]["output_path"]
    root, ext = os.path.splitext(output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]

    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            sys.exit(2)
        # The reader may return a falsy value on failure; treat it as empty.
        all_outputs += read_json_objects(rank_output_path) or []

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        sys.exit(3)
    write_data_to_json_file(all_outputs, output_path)

    # The writer logs failures instead of raising, so confirm the merged file
    # exists BEFORE deleting the per-rank files it was built from, and report
    # the failure through a non-zero exit status.
    if not os.path.exists(output_path):
        logging.error(f"Output file {output_path} does not exist")
        sys.exit(4)

    # remove files generated by each rank
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    # Score the merged results straight from memory; they are exactly what was
    # just written, and the list is known to be non-empty at this point.
    total = len(all_outputs)
    correct = sum(
        1 for item in all_outputs
        if math_verify_fn(item['output'], item['answer'])
    )
    logging.info(f"Total: {total}, correct: {correct}, Math-verify accuracy: {float(correct) / total * 100} %")

    sys.exit(exit_code)


# Run the evaluation only when this file is executed as a script. main()
# launches workers with the 'spawn' start method, so this guard also keeps
# child processes from re-running main() when they import this module.
if __name__ == "__main__":
    main()
