
import json
import argparse
import logging
import os
import re
from tqdm import tqdm
from transformers import AutoTokenizer

# Emit timestamped, level-tagged log lines at INFO and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Append the directory one level above this file's directory to sys.path so
# that the `utils` import below can resolve (presumably that is where the
# utils module lives -- confirm against the repository layout).
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

from utils import (
    read_json_objects,
    write_data_to_json_file, 
    create_parent_directory
)


def load_tokenizer_and_vllm(config, eos_token=None):
    """Load the teacher model's tokenizer and a vLLM engine for it.

    The tokenizer is switched to left padding and its EOS token is also
    installed as the pad token.

    Args:
        config: Parsed config dict. Reads ``config["models"]["teacher"]`` and
            the ``tp_size``, ``enable_chunked_prefill``,
            ``gpu_memory_utilization``, ``trust_remote_code`` and
            ``enforce_eager`` entries of ``config["inference"]``.
        eos_token: Optional EOS token string. When given it overrides the
            tokenizer's own EOS token.

    Returns:
        A ``(tokenizer, llm)`` tuple.

    Raises:
        ValueError: If ``eos_token`` is not given and the tokenizer has no
            ``eos_token_id``.
    """
    from transformers import AutoTokenizer
    from vllm import LLM

    model_path = config["models"]["teacher"]
    logging.info(f"Loading ckpt and tokenizer: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = "left"

    if eos_token:
        eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    else:
        # Compare against None rather than relying on truthiness: token id 0
        # is a legitimate EOS id and must not be treated as "missing".
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
        if eos_token_id is None:
            raise ValueError("No available eos_token or eos_token_id.")
        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
        eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)

    # Best effort: some tokenizer implementations may reject these
    # assignments. Keep going with the tokenizer's defaults in that case, but
    # surface the reason at WARNING level instead of hiding it.
    try:
        tokenizer.eos_token = eos_token
        tokenizer.eos_token_id = eos_token_id
        tokenizer.pad_token = eos_token
        tokenizer.pad_token_id = eos_token_id
    except Exception as err:
        logging.warning(f"Cannot set tokenizer.eos_token: {err}")
    logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
    logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")

    inference_cfg = config["inference"]
    llm = LLM(
        model=model_path,
        tensor_parallel_size=inference_cfg["tp_size"],
        enable_chunked_prefill=inference_cfg["enable_chunked_prefill"],
        gpu_memory_utilization=inference_cfg["gpu_memory_utilization"],
        trust_remote_code=inference_cfg["trust_remote_code"],
        # dtype=torch.bfloat16,
        enforce_eager=inference_cfg["enforce_eager"],
        # max_model_len=inference_cfg["max_model_len"],
        # max_num_seqs=inference_cfg.get("max_num_seqs", 64),
    )
    logging.info("vLLM model loaded successfully")
    return tokenizer, llm


def generate_model_response_batch(tokenizer, llm, data_list, config):
    """Generate ``generate_n`` completions for every sample via ``llm.chat``.

    Each sample's ``"question"`` is sent as a single user message, prefixed
    with the configured prompt and a newline when a prompt is set.

    Args:
        tokenizer: Unused here; kept so existing callers keep working.
        llm: Engine object exposing ``chat(messages=..., sampling_params=...)``.
        data_list: List of dicts, each with a ``"question"`` key.
        config: Parsed config dict. From ``config["inference"]`` it reads the
            optional ``prompt`` (default ``""``), ``batch_size`` (default 8)
            and ``generate_n`` (default 1), plus the required ``temperature``,
            ``seed`` and ``max_new_tokens`` (only read when there is data).

    Returns:
        A list with one ``{"input": sample, "output": [text, ...]}`` dict per
        sample, in input order. Empty when ``data_list`` is empty.

    Raises:
        AssertionError: If ``generate_n`` is not positive.
    """
    inference_cfg = config["inference"]
    prompt = inference_cfg.get("prompt", "")
    batch_size = inference_cfg.get("batch_size", 8)
    generate_n = inference_cfg.get("generate_n", 1)
    assert generate_n > 0

    # Nothing to generate: return before touching vLLM or the sampling keys,
    # exactly as the batch loop would have done with zero batches.
    if not data_list:
        return []

    from vllm import SamplingParams

    # The sampling configuration does not change between batches, so build it
    # once instead of once per batch.
    sampling_params = SamplingParams(
        n=generate_n,
        # top_k=1,
        temperature=inference_cfg["temperature"],
        seed=inference_cfg["seed"],
        skip_special_tokens=False,
        ignore_eos=False,
        max_tokens=inference_cfg["max_new_tokens"],
    )

    outcomes = []
    batch_starts = range(0, len(data_list), batch_size)
    for begin in tqdm(batch_starts, desc="Generating responses"):
        batch = data_list[begin:begin + batch_size]
        conversations = []
        for sample in batch:
            content = prompt + '\n' + sample["question"] if prompt else sample["question"]
            conversations.append([{"role": "user", "content": content}])

        model_outputs = llm.chat(
            messages=conversations,
            sampling_params=sampling_params,
        )

        # extend() appends in place; the previous `outcomes = outcomes + ...`
        # copied the whole accumulated list on every batch.
        outcomes.extend(
            {'input': sample, 'output': [candidate.text for candidate in request_output.outputs]}
            for sample, request_output in zip(batch, model_outputs)
        )
    return outcomes


def worker(config, dp_rank):
    """Run inference for one data-parallel rank and write its shard of results.

    The rank restricts itself to ``tp_size`` consecutive GPU indices, takes
    its contiguous slice of the instruction file, generates responses and
    writes them to ``<output_path root>_<dp_rank><ext>``.

    Args:
        config: Parsed config dict. Reads ``dp_size`` and ``tp_size`` from
            ``config["inference"]`` and ``instruction_path`` and
            ``output_path`` from ``config["dataset"]``.
        dp_rank: Zero-based index of this data-parallel rank.
    """
    dp_size = config["inference"]["dp_size"]
    tp_size = config["inference"]["tp_size"]

    # Give each rank its own group of tp_size device indices; this is set
    # before the model is loaded in this process.
    first_device = dp_rank * tp_size
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(i) for i in range(first_device, first_device + tp_size)
    )
    logging.info(f"DP rank {dp_rank} uses device {os.environ['CUDA_VISIBLE_DEVICES']}")

    data_list = read_json_objects(config["dataset"]["instruction_path"])

    # Split the dataset into dp_size contiguous, near-even shards: the first
    # `remainder` ranks receive one extra sample each.
    floor, remainder = divmod(len(data_list), dp_size)

    def start(rank):
        return rank * floor + min(rank, remainder)

    data_list = data_list[start(dp_rank):start(dp_rank + 1)]
    logging.info(f"DP rank {dp_rank} needs to process {len(data_list)} data samples.")

    root, ext = os.path.splitext(config["dataset"]["output_path"])
    output_path = root + f"_{dp_rank}" + ext

    if not data_list:
        # More ranks than samples. Each rank runs its own independent engine,
        # so an idle rank has nothing to do. The old string placeholder
        # crashed on sample["question"] and left this rank without an output
        # file, which made the merge step abort. Write an empty shard instead
        # and skip loading the model.
        # NOTE(review): assumes write_data_to_json_file / read_json_objects
        # round-trip an empty list -- confirm against utils.
        write_data_to_json_file([], output_path)
        return

    # Load tokenizer and LLM
    tokenizer, llm = load_tokenizer_and_vllm(config)

    # Generate responses using LLM
    outputs = generate_model_response_batch(tokenizer, llm, data_list, config)

    # Write outputs to file
    write_data_to_json_file(outputs, output_path)


def flatten_and_convert_to_prompt_completion(prompt, data_list):
    """Expand generation results into one prompt/completion record per output.

    Args:
        prompt: Instruction prefix. When non-empty it is joined to the
            question with a newline; when empty the question is used alone,
            matching the message that was sent to the model at generation.
        data_list: Items shaped like
            ``{"input": {"question": ..., ["difficulty": ...]}, "output": [...]}``.

    Returns:
        A list of ``{"prompt", "completion"[, "difficulty"]}`` dicts, one for
        every entry of every item's ``"output"`` list, in input order. Items
        with no outputs contribute nothing.
    """
    results = []
    for item in data_list:
        sample = item['input']
        question = sample['question']
        full_prompt = prompt + '\n' + question if prompt else question
        for completion in item['output']:
            new_item = {'prompt': full_prompt, 'completion': completion}
            if 'difficulty' in sample:
                new_item['difficulty'] = sample['difficulty']
            # Append per completion. Appending once per item kept only the
            # last of the generate_n completions.
            results.append(new_item)
    return results


def main():
    """Run data-parallel teacher inference and build the distillation dataset.

    Steps:
      1. Parse ``--config`` and load the JSON config.
      2. Resolve ``config["dataset"]["output_path"]``, deriving it from
         ``labeled_path`` or ``trainset_path`` when it is not given.
      3. Spawn one ``worker`` process per data-parallel rank and wait for
         each (up to 240 minutes per join).
      4. Merge the per-rank output files into ``output_path``.
      5. Write prompt/completion data to ``labeled_path``, or split into
         ``trainset_path`` / ``valset_path`` (up to 100 validation questions
         per difficulty value).
      6. Delete the per-rank files and exit.

    Exit codes: the last non-zero worker exit code (1 for a timed-out
    worker), 1 when no output location can be determined, 2 when a rank's
    output file is missing, 3 when no responses were gathered, otherwise 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    # Context manager so the config file handle is closed deterministically.
    with open(args.config, encoding='utf-8') as config_file:
        config = json.load(config_file)

    import multiprocessing as mp
    mp.set_start_method('spawn')

    dp_size = config["inference"]["dp_size"]
    dataset_cfg = config['dataset']

    if 'output_path' not in dataset_cfg:
        if 'labeled_path' in dataset_cfg:
            base_path = dataset_cfg['labeled_path']
        elif 'trainset_path' in dataset_cfg and 'valset_path' in dataset_cfg:
            base_path = dataset_cfg['trainset_path']
        else:
            # Without any of these there is nowhere to write results; stop
            # here instead of failing later with a KeyError in every worker.
            logging.error(
                "config['dataset'] must define 'output_path', 'labeled_path', "
                "or both 'trainset_path' and 'valset_path'."
            )
            sys.exit(1)
        create_parent_directory(base_path)
        dataset_cfg['output_path'] = os.path.splitext(base_path)[0] + '_vllm_output.json'
    else:
        create_parent_directory(dataset_cfg['output_path'])

    # Launch one worker per data-parallel rank. `config` now carries the
    # resolved output_path, so every worker sees it.
    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(target=worker, args=(config, dp_rank))
        proc.start()
        procs.append(proc)

    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            proc.join()  # reap the killed child
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # merge results into one file
    output_path = dataset_cfg["output_path"]
    root, ext = os.path.splitext(output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]

    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            sys.exit(2)
        all_outputs.extend(read_json_objects(rank_output_path))

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        sys.exit(3)
    write_data_to_json_file(all_outputs, output_path)

    # Same lookup as the generation step, so a config without a prompt works
    # here as well.
    prompt = config['inference'].get('prompt', '')

    if 'labeled_path' in dataset_cfg:
        flattened = flatten_and_convert_to_prompt_completion(prompt, all_outputs)
        write_data_to_json_file(flattened, dataset_cfg['labeled_path'])

    elif 'trainset_path' in dataset_cfg and 'valset_path' in dataset_cfg:
        # for curriculum learning, split questions into train and val set
        # sample 100 questions from each bucket to create the val set
        from collections import defaultdict
        import random
        train_questions = []
        val_map = defaultdict(list)
        random.seed(dataset_cfg['seed'])
        random.shuffle(all_outputs)
        for item in all_outputs:
            bucket = val_map[item['input']['difficulty']]
            if len(bucket) < 100:
                bucket.append(item)
            else:
                train_questions.append(item)
        val_questions = [item for bucket in val_map.values() for item in bucket]

        # convert both train and val set to prompt completion format
        train_flattened = flatten_and_convert_to_prompt_completion(prompt, train_questions)
        val_flattened = flatten_and_convert_to_prompt_completion(prompt, val_questions)
        write_data_to_json_file(train_flattened, dataset_cfg['trainset_path'])
        write_data_to_json_file(val_flattened, dataset_cfg['valset_path'])
    else:
        logging.error('No labeled_path or trainset_path found.')

    # remove files generated by each rank
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    sys.exit(exit_code)


if __name__ == "__main__":
    main()
