import os
import logging
import json
import shutil
from tqdm import tqdm
from math_verify import parse, verify
from pathlib import Path
import argparse
import re
import string

import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

from utils import (
    read_json_objects,
    extract_content_from_tag, 
    write_data_to_jsonlines_file, 
    write_data_to_json_file, 
    create_parent_directory, 
    create_directory,
    get_parent_directory
)

import multiprocessing as mp


# ------------ answer verification functions ------------ #

def math_verify(llm_res, ground_truth):
    """Return whether a model response matches the ground truth mathematically.

    Both strings are parsed with ``math_verify.parse`` and compared with
    ``math_verify.verify``.

    Args:
        llm_res: Raw model response text.
        ground_truth: Reference answer (or solution) text.

    Returns:
        The result of ``verify`` (truthy when the answers are equivalent).
    """
    gold_answer = parse(ground_truth)
    llm_answer = parse(llm_res)
    # math_verify's API is verify(gold, target) and the check is not
    # symmetric, so the reference answer must be passed first.
    return verify(gold_answer, llm_answer)

def is_correct_math(llm_res, ground_truth):
    """Return whether ``llm_res`` is mathematically equivalent to ``ground_truth``.

    Thin alias of ``math_verify`` so both graders share one implementation
    and cannot drift apart.

    Args:
        llm_res: Raw model response text.
        ground_truth: Reference answer text.

    Returns:
        The result of ``math_verify(llm_res, ground_truth)``.
    """
    return math_verify(llm_res, ground_truth)

def normalize(text: str) -> str:
    """Canonicalize free-form text for answer comparison (SQuAD-style).

    Lower-cases the text, deletes ASCII punctuation (``string.punctuation``),
    drops the whole words "a", "an" and "the", and collapses every whitespace
    run to a single space with no leading or trailing space.
    """
    drop_punctuation = str.maketrans('', '', string.punctuation)
    cleaned = text.lower().strip().translate(drop_punctuation)
    cleaned = re.sub(r'\s+', ' ', cleaned)
    cleaned = re.sub(r'\b(a|an|the)\b', ' ', cleaned)
    # str.split() with no separator splits on the same Unicode whitespace as
    # \s, so join/split both collapses runs and trims the ends.
    return ' '.join(cleaned.split())

def exact_match(pred, gold):
    """Return 1 if ``pred`` and ``gold`` are identical after ``normalize``, else 0."""
    same_text = normalize(pred) == normalize(gold)
    return 1 if same_text else 0

def f1_score(pred, gold):
    """Token-level F1 between ``pred`` and ``gold`` after ``normalize`` (SQuAD-style).

    Overlapping tokens are counted with multiplicity: each distinct token
    contributes min(count in pred, count in gold).

    Returns:
        1.0 when both strings normalize to no tokens, 0.0 when they share no
        tokens, otherwise the harmonic mean of token precision and recall.
    """
    from collections import Counter

    p_tokens = normalize(pred).split()
    g_tokens = normalize(gold).split()
    if not p_tokens and not g_tokens:
        return 1.0
    # Counter intersection keeps min(count_p, count_g) per token in O(n),
    # instead of calling list.count() for every predicted token (O(n^2)).
    num_same = sum((Counter(p_tokens) & Counter(g_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(p_tokens)
    recall = num_same / len(g_tokens)
    return 2 * precision * recall / (precision + recall)

def is_semantically_equivalent(pred, gold, nli_model=None, emb_model=None, threshold=0.88):
    """Decide whether ``pred`` and ``gold`` mean the same thing.

    If ``nli_model`` is given, both texts must entail each other (checked via
    ``nli_model.entails(premise, hypothesis)``). Otherwise, if ``emb_model``
    is given, the cosine similarity of ``emb_model.encode`` vectors must be at
    least ``threshold``. With neither model, returns False.

    Args:
        pred: Predicted text.
        gold: Reference text.
        nli_model: Optional object exposing ``entails(premise, hypothesis)``.
        emb_model: Optional object exposing ``encode(text)`` returning a
            vector (assumed 1-D and numpy-convertible — TODO confirm for the
            embedding model actually used).
        threshold: Minimum cosine similarity to count as equivalent
            (default 0.88; tune per embedding model).

    Returns:
        A Python bool.
    """
    if nli_model is not None:
        # Bidirectional entailment; skip the second call if the first fails.
        return bool(nli_model.entails(pred, gold) and nli_model.entails(gold, pred))
    if emb_model is not None:
        import numpy as np  # local import: only needed for the embedding path

        v1 = np.asarray(emb_model.encode(pred))
        v2 = np.asarray(emb_model.encode(gold))
        denom = np.linalg.norm(v1) * np.linalg.norm(v2)
        if denom == 0:
            # Cosine similarity is undefined for a zero vector.
            return False
        cos = float(v1 @ v2) / float(denom)
        return cos >= threshold
    return False

def is_correct_nl(pred, gold, nli_model=None, emb_model=None):
    """Grade a natural-language answer against a reference.

    The answer counts as correct if any check passes, tried cheapest first:
    normalized exact match, token F1 >= 0.8, then semantic equivalence via
    the optional NLI / embedding models.

    Returns:
        True or False.
    """
    return bool(
        exact_match(pred, gold)
        or f1_score(pred, gold) >= 0.8
        or is_semantically_equivalent(pred, gold, nli_model, emb_model)
    )

# -------------------- vLLM functions ------------------- #


def load_tokenizer_and_vllm(config, eos_token=None):
    """Load the Hugging Face tokenizer and a vLLM engine for ``config["model"]``.

    The tokenizer is switched to left padding. Its eos and pad tokens are both
    set to ``eos_token`` when given, otherwise to the tokenizer's own eos
    token.

    Args:
        config: Generation config. Reads ``"model"`` (path or hub id) and,
            optionally, ``"tp_size"`` (default 1), ``"enable_chunked_prefill"``
            (True), ``"gpu_memory_utilization"`` (0.9), ``"trust_remote_code"``
            (True) and ``"enforce_eager"`` (True).
        eos_token: Optional token string to use as eos/pad instead of the
            tokenizer's default.

    Returns:
        ``(tokenizer, llm)``.

    Raises:
        ValueError: if no ``eos_token`` is given and the tokenizer has none.
    """
    from transformers import AutoTokenizer
    from vllm import LLM

    model_path = config["model"]
    trust_remote_code = config.get("trust_remote_code", True)
    logging.info(f"Loading ckpt and tokenizer: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=trust_remote_code)
    tokenizer.padding_side = "left"
    if eos_token:
        eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    elif getattr(tokenizer, "eos_token_id", None) is not None:
        # Compare with None rather than truthiness: token id 0 is valid.
        eos_token_id = tokenizer.eos_token_id
        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
        eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
    else:
        raise ValueError("No available eos_token or eos_token_id.")
    try:
        tokenizer.eos_token = eos_token
        tokenizer.eos_token_id = eos_token_id
        tokenizer.pad_token = eos_token
        tokenizer.pad_token_id = eos_token_id
    except Exception as err:
        # Best effort: some tokenizer classes may reject these assignments;
        # keep their defaults rather than aborting the run.
        logging.warning(f"Cannot set tokenizer.eos_token/pad_token: {err}")
    logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
    logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")

    llm = LLM(
        model=model_path,
        tensor_parallel_size=config.get("tp_size", 1),
        enable_chunked_prefill=config.get("enable_chunked_prefill", True),
        gpu_memory_utilization=config.get("gpu_memory_utilization", 0.9),
        trust_remote_code=trust_remote_code,
        enforce_eager=config.get("enforce_eager", True),
    )
    logging.info("vLLM model loaded successfully")
    return tokenizer, llm

def generate_model_response_batch(tokenizer, llm, data_list, config):
    """Generate chat completions for every sample in ``data_list`` with vLLM.

    Each sample's ``"input"`` string is sent as a single user message via
    ``llm.chat``. Samples are processed in chunks of ``config["batch_size"]``
    (default 8).

    Args:
        tokenizer: Unused here; kept for interface compatibility.
        llm: A vLLM ``LLM`` instance.
        data_list: Sample dicts, each with an ``"input"`` prompt string.
        config: Reads ``batch_size`` (8), ``generate_n`` (1), ``temperature``
            (0.2), ``seed`` (777) and ``max_new_tokens`` (1024).

    Returns:
        A list of ``{"input": sample, "output": response}`` dicts in
        ``data_list`` order. ``response`` is a str when ``generate_n == 1``
        and a list of ``generate_n`` strs otherwise.

    Raises:
        ValueError: if ``generate_n`` is not positive.
        RuntimeError: if vLLM returns a different number of outputs than
            prompts in a batch.
    """
    from vllm import SamplingParams

    batch_size = config.get("batch_size", 8)
    # Callers elsewhere cast generate_n with int(), so it may arrive as a str.
    generate_n = int(config.get("generate_n", 1))
    if generate_n <= 0:
        raise ValueError(f"generate_n must be positive, got {generate_n}")
    # The sampling settings are identical for every batch; build them once.
    sampling_params = SamplingParams(
        n=generate_n,
        temperature=config.get("temperature", 0.2),
        seed=config.get("seed", 777),
        skip_special_tokens=False,
        ignore_eos=False,
        max_tokens=config.get("max_new_tokens", 1024),
    )

    outcomes = []
    batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
    for batch in tqdm(batches, desc="Generating responses"):
        conversations = [[{"role": "user", "content": sample["input"]}] for sample in batch]
        model_outputs = llm.chat(messages=conversations, sampling_params=sampling_params)
        if len(model_outputs) != len(batch):
            raise RuntimeError(
                f"vLLM returned {len(model_outputs)} outputs for {len(batch)} prompts"
            )
        for sample, output in zip(batch, model_outputs):
            texts = [completion.text for completion in output.outputs]
            # extend-style append keeps accumulation linear (the old
            # `outcomes = outcomes + gen_data` re-copied the list every batch).
            outcomes.append({'input': sample, 'output': texts[0] if generate_n == 1 else texts})
    return outcomes


def worker(config, data_list, dp_rank):
    """Run generation for one data-parallel rank (intended as a Process target).

    The function:
    - pins this process to GPUs ``[dp_rank * tp_size, (dp_rank + 1) * tp_size)``
      through ``CUDA_VISIBLE_DEVICES``;
    - takes this rank's contiguous shard of ``data_list`` (shard sizes differ
      by at most one across ranks);
    - generates responses with vLLM;
    - writes them to ``<root>_<dp_rank><ext>``, derived from
      ``config["output_path"]``.

    Args:
        config: Generation config. Reads ``dp_size`` (default 8), ``tp_size``
            (default 1) and ``output_path``, plus the settings consumed by
            ``load_tokenizer_and_vllm`` / ``generate_model_response_batch``.
        data_list: The full sample list (dicts with an ``"input"`` prompt).
            Every rank receives all of it and slices out its own shard.
        dp_rank: This rank's index in ``[0, dp_size)``.
    """
    dp_size = config.get("dp_size", 8)
    tp_size = config.get("tp_size", 1)
    # Set before vLLM (imported lazily in load_tokenizer_and_vllm) touches CUDA.
    first_gpu = dp_rank * tp_size
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(gpu) for gpu in range(first_gpu, first_gpu + tp_size)
    )
    logging.info(f"DP rank {dp_rank} uses device {os.environ['CUDA_VISIBLE_DEVICES']}")

    # Split the dataset into dp_size contiguous shards; the first `remainder`
    # ranks get one extra sample each.
    base, remainder = divmod(len(data_list), dp_size)

    def start(rank):
        return rank * base + min(rank, remainder)

    shard = data_list[start(dp_rank):start(dp_rank + 1)]

    root, ext = os.path.splitext(config["output_path"])
    output_path = root + f"_{dp_rank}" + ext

    if not shard:
        # More ranks than samples. Generating for a placeholder string would
        # crash (generate_model_response_batch reads sample["input"]) and
        # leave no file for the merge step. Write an empty shard instead and
        # skip loading the model.
        # NOTE(review): assumes read_json_objects returns [] for an empty
        # JSON list — confirm.
        logging.info(f"DP rank {dp_rank} has no data samples; writing empty output.")
        write_data_to_json_file([], output_path)
        return
    logging.info(f"DP rank {dp_rank} needs to process {len(shard)} data samples.")

    tokenizer, llm = load_tokenizer_and_vllm(config)
    outputs = generate_model_response_batch(tokenizer, llm, shard, config)
    write_data_to_json_file(outputs, output_path)



# -------------------- difficulty measure functions ------------------- #


def measure_original_difficulty(config):
    """Score the difficulty of every question in the raw (gsm8k / orca-math) dataset.

    Steps:
    1. Each question is prompted as ``prompt + "\\nQuestion:\\n" + question``
       and answered ``generate_n`` times by ``dp_size`` data-parallel
       ``worker`` processes.
    2. The per-rank outputs are merged into
       ``<temp_dir>/original_difficulty_vllm_output.json``.
    3. Every response is graded with ``math_verify``.
    4. One record per question (``question``, ``answer``, ``solution``,
       ``difficulty``) is written to
       ``<temp_dir>/original_difficulty_results.json``. ``difficulty`` is the
       number of incorrect responses (``generate_n - correct``).

    Input layouts:
    - openai/gsm8k, when "gsm8k" appears in the input path: the final answer
      follows ``####`` in the ``answer`` field.
    - orca-math otherwise: the answer is extracted from ``solution`` with
      ``math_verify.parse``.

    Args:
        config: Full run config.
            - ``config["difficulty"]``: reads ``dp_size``, ``generate_n``,
              ``prompt`` and the worker's generation settings;
              ``output_path`` is set in place.
            - ``config["dataset"]``: reads ``temp_dir`` and ``input_path``.

    Returns:
        - 2 if a rank's output file is missing.
        - 3 if no responses were gathered.
        - Otherwise 0, or the status of the last failing worker (1 if it was
          killed after the 4-hour timeout).

    Raises:
        ValueError: if ``generate_n`` is not positive.
    """
    dp_size = config["difficulty"]["dp_size"]
    temp_dir = Path(config['dataset']['temp_dir'])
    vllm_output_path = str(temp_dir / "original_difficulty_vllm_output.json")
    vllm_results_path = str(temp_dir / "original_difficulty_results.json")

    difficulty_config = config['difficulty']
    generate_n = int(difficulty_config['generate_n'])
    if generate_n <= 0:
        raise ValueError(f"difficulty.generate_n must be positive, got {generate_n}")
    difficulty_config['output_path'] = vllm_output_path

    data_list = read_json_objects(config['dataset']['input_path'])
    prompt = difficulty_config["prompt"]
    for x in data_list:
        x['input'] = prompt + '\nQuestion:\n' + x['question']

    # One worker per DP rank; each writes <output_root>_<rank><ext>.
    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(target=worker, args=(difficulty_config, data_list, dp_rank))
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)  # 4 hours
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            proc.join()  # reap the killed process
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # Merge the per-rank files into one output file.
    root, ext = os.path.splitext(vllm_output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]
    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            return 2
        all_outputs += read_json_objects(rank_output_path)
    write_data_to_json_file(all_outputs, vllm_output_path)

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        return 3

    is_gsm8k = 'gsm8k' in config['dataset']['input_path']

    def correct_count(responses, answer):
        return sum(1 for res in responses if math_verify(res, answer))

    results = []
    for item in all_outputs:
        sample = item['input']
        responses = item['output'] if isinstance(item['output'], list) else [item['output']]

        if is_gsm8k:
            # openai/gsm8k: "<rationale>\n#### <final answer>"
            if '####' not in sample['answer']:
                logging.error(f"Invalid answer format for gsm8k: {sample['answer']}")
                continue
            parts = sample['answer'].split('####')
            solution = parts[0].strip()
            final_answer = parts[1].strip()
            new_item = {
                'question': sample['question'],
                'answer': final_answer,
                'solution': solution,
                # Grade against the extracted final answer, not the whole
                # rationale (which contains many intermediate numbers).
                'difficulty': generate_n - correct_count(responses, final_answer),
            }
        else:
            # orca-math: extract the answer from the worked solution.
            parsed = parse(sample['solution'])
            if not parsed:
                logging.error(f"Could not extract an answer from solution: {sample['solution']}")
                continue
            new_item = {
                'question': sample['question'],
                'answer': str(parsed[0]),
                'solution': sample['solution'],
                'difficulty': generate_n - correct_count(responses, sample['solution']),
            }
        results.append(new_item)
    write_data_to_json_file(results, vllm_results_path)

    # Remove the per-rank files now that they are merged.
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    return exit_code


def measure_difficulty(config):
    """Score the difficulty of every question in ``config["dataset"]["input_path"]``.

    Steps:
    1. Each sample is prompted as
       ``prompt + question_prefix + question + answer_prefix`` (the prefixes
       come from ``config["difficulty"]`` and default to "").
    2. The prompt is answered ``generate_n`` times by ``dp_size``
       data-parallel ``worker`` processes. The merged raw outputs are written
       to ``<temp_dir>/vllm_output.json``.
    3. Responses are graded with ``is_correct_math`` when the dataset name or
       input path mentions "gsm8k", and with ``is_correct_nl`` otherwise.
    4. One record per question (``question``, ``answer``, ``difficulty`` and,
       if present, ``original_id``) is written to
       ``<temp_dir>/results.json``. ``difficulty`` is the number of incorrect
       responses (``generate_n - correct``).

    Args:
        config: Full run config.
            - ``config["difficulty"]``: reads ``dp_size``, ``generate_n``,
              ``prompt``, optional prefixes and the worker's generation
              settings; ``output_path`` is set in place.
            - ``config["dataset"]``: reads ``temp_dir``, ``input_path`` and
              optional ``name``.

    Returns:
        - 2 if a rank's output file is missing.
        - 3 if no responses were gathered.
        - Otherwise 0, or the status of the last failing worker (1 if it was
          killed after the 4-hour timeout).

    Raises:
        ValueError: if ``generate_n`` is not positive.
    """
    dp_size = config["difficulty"]["dp_size"]
    temp_dir = Path(config['dataset']['temp_dir'])
    vllm_output_path = str(temp_dir / "vllm_output.json")
    vllm_results_path = str(temp_dir / "results.json")

    difficulty_config = config['difficulty']
    generate_n = int(difficulty_config['generate_n'])
    if generate_n <= 0:
        raise ValueError(f"difficulty.generate_n must be positive, got {generate_n}")
    difficulty_config['output_path'] = vllm_output_path

    data_list = read_json_objects(config['dataset']['input_path'])
    prompt = difficulty_config["prompt"]
    question_prefix = difficulty_config.get("question_prefix", "")
    answer_prefix = difficulty_config.get("answer_prefix", "")
    for x in data_list:
        x['input'] = prompt + question_prefix + x['question'] + answer_prefix

    # One worker per DP rank; each writes <output_root>_<rank><ext>.
    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(target=worker, args=(difficulty_config, data_list, dp_rank))
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)  # 4 hours
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            proc.join()  # reap the killed process
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # Merge the per-rank files into one output file.
    root, ext = os.path.splitext(vllm_output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]
    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            return 2
        all_outputs += read_json_objects(rank_output_path)
    write_data_to_json_file(all_outputs, vllm_output_path)

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        return 3

    dataset_config = config['dataset']
    is_math = (
        'gsm8k' in dataset_config.get('name', '')
        or 'gsm8k' in dataset_config['input_path']
    )
    # Select the grader directly: the old string tags disagreed ('nl' vs
    # 'nli'), so natural-language answers were never counted as correct.
    is_correct = is_correct_math if is_math else is_correct_nl

    results = []
    for item in all_outputs:
        sample = item['input']
        responses = item['output'] if isinstance(item['output'], list) else [item['output']]
        # Both graders expect text; answers read from JSON may be numbers.
        gold = str(sample['answer'])
        num_correct = sum(1 for res in responses if is_correct(res, gold))
        new_item = {
            'question': sample['question'],
            'answer': sample['answer'],
            'difficulty': generate_n - num_correct,
        }
        if 'original_id' in sample:
            new_item['original_id'] = sample['original_id']
        results.append(new_item)
    write_data_to_json_file(results, vllm_results_path)

    # Remove the per-rank files now that they are merged.
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    return exit_code


# -------------------- main function ------------------- #

def main():
    """CLI entry point: measure difficulty for the dataset described by --config.

    Steps:
    1. Read the JSON config.
    2. Create ``<parent of dataset.output_path>/measure_difficulty_temp``.
    3. Run ``measure_difficulty``.
    4. Copy the resulting ``results.json`` to ``config["dataset"]["output_path"]``.

    Returns:
        The return code of ``measure_difficulty``, or 1 if it produced no
        results file.
    """
    # CUDA cannot be re-initialised in forked children; workers must be spawned.
    mp.set_start_method('spawn')

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    with open(args.config, encoding='utf-8') as config_file:
        config = json.load(config_file)

    temp_dir = get_parent_directory(config["dataset"]["output_path"]) / "measure_difficulty_temp"
    create_directory(temp_dir)
    config['dataset']['temp_dir'] = str(temp_dir)

    # Drop results from a previous run so a failed run can't copy stale data.
    results_path = temp_dir / "results.json"
    if os.path.exists(results_path):
        os.remove(results_path)

    return_code = measure_difficulty(config)
    print(f"Difficulty measurement return code: {return_code}")

    if not os.path.exists(results_path):
        logging.error(f"No results file at {results_path}; nothing copied.")
        return return_code or 1

    shutil.copy(str(results_path), config['dataset']['output_path'])
    return return_code


if __name__ == "__main__":
    sys.exit(main())
