import os
import logging
import json
from tqdm import tqdm
from math_verify import parse, verify
from pathlib import Path
import argparse

import sys
# Make the directory one level above this script importable so that the
# `from utils import (...)` statement below can resolve the project-local
# `utils` module.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

from utils import (
    read_json_objects,
    extract_content_from_tag, 
    write_data_to_jsonlines_file, 
    write_data_to_json_file, 
    create_parent_directory, 
    create_directory,
    get_parent_directory
)

import multiprocessing as mp


def math_verify(llm_res, ground_truth):
    """Return whether a model response matches the reference answer.

    Both texts are parsed with ``math_verify.parse`` and compared with
    ``math_verify.verify``.

    Args:
        llm_res: Text produced by the model.
        ground_truth: Reference answer (or solution text containing it).

    Returns:
        The boolean result of ``verify``.
    """
    gold = parse(ground_truth)
    prediction = parse(llm_res)
    # math_verify's signature is verify(gold, target) and the comparison is not
    # symmetric, so the reference answer must be passed first.
    return verify(gold, prediction)

def load_tokenizer_and_vllm(config, eos_token=None):
    """Load the Hugging Face tokenizer and a vLLM engine for ``config["model"]``.

    The tokenizer is switched to left padding, and its eos/pad tokens are set
    to ``eos_token`` when given, otherwise to the tokenizer's own eos token.

    Args:
        config: Must contain ``model`` (checkpoint path or hub id). Optional
            keys: ``tp_size`` (default 1), ``enable_chunked_prefill`` (True),
            ``gpu_memory_utilization`` (0.9), ``trust_remote_code`` (True),
            ``enforce_eager`` (True), plus ``dtype`` / ``max_model_len`` /
            ``max_num_seqs``, which are forwarded to ``vllm.LLM`` only when
            present in ``config``.
        eos_token: Optional token string to use as both eos and pad token.

    Returns:
        A ``(tokenizer, llm)`` tuple.

    Raises:
        ValueError: If no ``eos_token`` is given and the tokenizer has no
            ``eos_token_id``.
    """
    from transformers import AutoTokenizer
    from vllm import LLM

    model_path = config["model"]
    logging.info(f"Loading ckpt and tokenizer: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = "left"
    if eos_token:
        eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    elif getattr(tokenizer, "eos_token_id", None) is not None:
        # `is not None` instead of truthiness: token id 0 is a valid eos id.
        logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
        eos_token_id = tokenizer.eos_token_id
        eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
    else:
        raise ValueError("No available eos_token or eos_token_id.")
    try:
        tokenizer.eos_token = eos_token
        tokenizer.eos_token_id = eos_token_id
        tokenizer.pad_token = eos_token
        tokenizer.pad_token_id = eos_token_id
    except Exception as e:
        # Best effort: presumably some tokenizer classes reject these
        # assignments; keep going with whatever the tokenizer already has.
        logging.warning(f"Cannot set tokenizer.eos_token: {e}")
    logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
    logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")

    llm_kwargs = dict(
        model=model_path,
        tensor_parallel_size=config.get("tp_size", 1),
        enable_chunked_prefill=config.get("enable_chunked_prefill", True),
        gpu_memory_utilization=config.get("gpu_memory_utilization", 0.9),
        trust_remote_code=config.get("trust_remote_code", True),
        enforce_eager=config.get("enforce_eager", True),
    )
    # Optional engine settings: only forwarded when configured, so vLLM's own
    # defaults apply otherwise.
    for key in ("dtype", "max_model_len", "max_num_seqs"):
        if key in config:
            llm_kwargs[key] = config[key]
    llm = LLM(**llm_kwargs)
    logging.info("vLLM model loaded successfully")
    return tokenizer, llm


def generate_model_response_batch(tokenizer, llm, data_list, config):
    """Generate chat completions for every sample in ``data_list``.

    Each sample's ``"input"`` string is sent as a single user message through
    ``llm.chat``, ``batch_size`` samples at a time.

    Args:
        tokenizer: Not used here; kept for interface compatibility.
        llm: A ``vllm.LLM`` instance.
        data_list: Samples, each with an ``"input"`` string.
        config: Generation settings: ``batch_size`` (default 8),
            ``generate_n`` (1), ``temperature`` (0.2), ``seed`` (777),
            ``max_new_tokens`` (1024).

    Returns:
        A list of ``{'input': sample, 'output': response}`` dicts in input
        order. ``response`` is a string when ``generate_n == 1``, otherwise a
        list of ``generate_n`` strings.
    """
    from vllm import SamplingParams

    batch_size = config.get("batch_size", 8)
    generate_n = config.get("generate_n", 1)
    assert generate_n > 0
    # The sampling settings are the same for every batch, so build them once.
    sampling_params = SamplingParams(
        n=generate_n,
        # top_k=1,
        temperature=config.get("temperature", 0.2),
        seed=config.get("seed", 777),
        skip_special_tokens=False,
        ignore_eos=False,
        max_tokens=config.get("max_new_tokens", 1024),
    )
    outcomes = []
    batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
    for batch in tqdm(batches, desc="Generating responses"):
        messages = [[{"role": "user", "content": sample["input"]}] for sample in batch]
        model_outputs = llm.chat(messages=messages, sampling_params=sampling_params)
        for sample, output in zip(batch, model_outputs):
            if generate_n == 1:
                response = output.outputs[0].text
            else:
                response = [v.text for v in output.outputs]
            # Append in place; rebuilding the list each batch was quadratic.
            outcomes.append({'input': sample, 'output': response})
    return outcomes


def worker(config, data_list, dp_rank):
    """Generate responses for one data-parallel rank's shard of ``data_list``.

    Intended as an ``mp.Process`` target. The rank is pinned to GPUs
    ``[dp_rank * tp_size, (dp_rank + 1) * tp_size)`` and processes a contiguous
    slice of the data. Results go to ``<output_path root>_<dp_rank><ext>``.

    Args:
        config: Stage config with ``output_path`` plus optional ``dp_size``
            (default 8) and ``tp_size`` (default 1); also forwarded to the
            model loading and generation helpers.
        data_list: The full dataset; each item needs an ``"input"`` string.
        dp_rank: Index of this rank in ``range(dp_size)``.
    """
    dp_size = config.get("dp_size", 8)
    tp_size = config.get("tp_size", 1)
    # Set devices for this rank before vLLM (imported lazily inside
    # load_tokenizer_and_vllm) touches CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(i) for i in range(dp_rank * tp_size, (dp_rank + 1) * tp_size)
    )
    logging.info(f"DP rank {dp_rank} uses device {os.environ['CUDA_VISIBLE_DEVICES']}")

    # Each rank processes a different contiguous part of the dataset; the
    # first `remainder` ranks get one extra item.
    floor, remainder = divmod(len(data_list), dp_size)

    def start(rank):
        return rank * floor + min(rank, remainder)

    data_list = data_list[start(dp_rank) : start(dp_rank + 1)]

    root, ext = os.path.splitext(config["output_path"])
    output_path = root + f"_{dp_rank}" + ext

    if not data_list:
        # More ranks than samples. Each rank runs its own independent engine,
        # so there is nothing to keep in sync. Skip model loading and write an
        # empty file so the parent's merge step still finds this rank's output.
        # (A string placeholder here would crash on sample["input"].)
        # NOTE(review): assumes read_json_objects returns [] for this file.
        logging.info(f"DP rank {dp_rank} has no data samples; writing empty output.")
        write_data_to_json_file([], output_path)
        return
    logging.info(f"DP rank {dp_rank} needs to process {len(data_list)} data samples.")

    tokenizer, llm = load_tokenizer_and_vllm(config)
    outputs = generate_model_response_batch(tokenizer, llm, data_list, config)
    write_data_to_json_file(outputs, output_path)


def measure_original_difficulty(config):
    """Score the difficulty of every input question by sampling model answers.

    Each question from ``config['dataset']['input_path']`` is prefixed with
    ``config['difficulty']['prompt']`` and answered by ``dp_size`` worker
    processes. Difficulty is the number of sampled responses that
    ``math_verify`` judges wrong. Per-item results (question, answer, solution,
    difficulty) are written to ``<temp_dir>/original_difficulty_results.json``.

    For inputs whose path contains ``gsm8k``, the answer field is split on
    ``####`` into solution and final answer. Otherwise (orca-math style) the
    answer is extracted from ``solution`` with ``math_verify.parse``. Items
    whose answer cannot be extracted are logged and skipped.

    Returns:
        2 if a rank's output file is missing, 3 if no responses were gathered,
        otherwise 0 when every worker exited cleanly, 1 if a worker timed out,
        or a failing worker's exit code.
    """
    difficulty_config = config['difficulty']
    dp_size = difficulty_config["dp_size"]
    assert int(difficulty_config['generate_n']) > 0
    input_path = config['dataset']['input_path']
    temp_dir = Path(config['dataset']['temp_dir'])
    vllm_output_path = str(temp_dir / "original_difficulty_vllm_output.json")
    vllm_results_path = str(temp_dir / "original_difficulty_results.json")
    difficulty_config['output_path'] = vllm_output_path

    # Per-rank files written by worker(): "<root>_<rank><ext>".
    root, ext = os.path.splitext(vllm_output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]
    # Remove leftovers from an earlier aborted run; otherwise a worker that
    # crashes now would go unnoticed because its stale file still exists.
    for rank_output_path in rank_output_paths:
        if os.path.exists(rank_output_path):
            os.remove(rank_output_path)

    data_list = read_json_objects(input_path)
    prompt = difficulty_config["prompt"]
    for x in data_list:
        x['input'] = prompt + '\nQuestion:\n' + x['question']

    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(
            target=worker,
            args=(difficulty_config, data_list, dp_rank),
        )
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            proc.join()  # reap the killed child
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # merge results into one file
    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            return 2
        all_outputs += read_json_objects(rank_output_path)
    write_data_to_json_file(all_outputs, vllm_output_path)

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        return 3

    def count_correct(responses, answer):
        return sum(1 for res in responses if math_verify(res, answer))

    is_gsm8k = 'gsm8k' in input_path
    results = []
    for item in all_outputs:
        sample = item['input']
        responses = item['output'] if isinstance(item['output'], list) else [item['output']]
        if is_gsm8k:
            # openai/gsm8k answers look like "<worked solution>#### <final answer>".
            if '####' not in sample['answer']:
                logging.error(f"Invalid answer format for gsm8k: {sample['answer']}")
                continue
            solution, _, answer = sample['answer'].partition('####')
            answer = answer.strip()
            new_item = {
                'question': sample['question'],
                'answer': answer,
                'solution': solution.strip(),
                # Verify against the extracted final answer: that stored
                # 'answer' is what measure_difficulty verifies against later.
                'difficulty': len(responses) - count_correct(responses, answer),
            }
        else:
            # orca-math: no separate answer field, so extract it from the solution.
            parsed = parse(sample['solution'])
            if not parsed:
                logging.error(f"Cannot extract an answer from solution: {sample['solution']}")
                continue
            new_item = {
                'question': sample['question'],
                'answer': str(parsed[0]),
                'solution': sample['solution'],
                # Number of incorrect samples (the old `10 - correct` assumed generate_n == 10).
                'difficulty': len(responses) - count_correct(responses, sample['solution']),
            }
        results.append(new_item)
    write_data_to_json_file(results, vllm_results_path)

    # remove files generated by each rank
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    return exit_code


def rewrite(config):
    """Ask the model to rewrite each scored question into new questions.

    Reads ``<temp_dir>/original_difficulty_results.json``. Each item's prompt is
    ``config['rewrite']['prompt']`` followed by its question and solution.
    Every span tagged ``question`` in each model response (as returned by
    ``extract_content_from_tag``) becomes a new item carrying the source item's
    answer, question text (``original``) and difficulty
    (``original_difficulty``). Results go to ``<temp_dir>/rewrite_results.json``.

    Returns:
        2 if a rank's output file is missing, 3 if no responses were gathered,
        otherwise 0 when every worker exited cleanly, 1 if a worker timed out,
        or a failing worker's exit code.
    """
    rewrite_config = config['rewrite']
    dp_size = rewrite_config["dp_size"]
    temp_dir = Path(config['dataset']['temp_dir'])
    vllm_output_path = str(temp_dir / "rewrite_vllm_output.json")
    vllm_results_path = str(temp_dir / "rewrite_results.json")
    rewrite_config['output_path'] = vllm_output_path

    # Per-rank files written by worker(): "<root>_<rank><ext>".
    root, ext = os.path.splitext(vllm_output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]
    # Remove leftovers from an earlier aborted run so they cannot mask a crash.
    for rank_output_path in rank_output_paths:
        if os.path.exists(rank_output_path):
            os.remove(rank_output_path)

    data_list = read_json_objects(str(temp_dir / "original_difficulty_results.json"))
    prompt = rewrite_config["prompt"]
    for x in data_list:
        x['input'] = prompt + '\nQuestion:\n' + x['question'] + '\nSolution:\n' + x['solution']

    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(
            target=worker,
            args=(rewrite_config, data_list, dp_rank),
        )
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            proc.join()  # reap the killed child
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # merge results into one file
    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            return 2
        all_outputs += read_json_objects(rank_output_path)
    write_data_to_json_file(all_outputs, vllm_output_path)

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        return 3

    results = []
    for item in all_outputs:
        sample = item['input']
        # With generate_n > 1 the worker stores a list of responses; handle
        # both shapes the same way.
        responses = item['output'] if isinstance(item['output'], list) else [item['output']]
        for response in responses:
            for q in extract_content_from_tag('question', response):
                results.append({
                    'question': q,
                    'answer': sample['answer'],
                    'original': sample['question'],
                    'original_difficulty': sample['difficulty'],
                })
    write_data_to_json_file(results, vllm_results_path)

    # remove files generated by each rank
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    return exit_code


def dedup(config, data_list):
    """Drop near-duplicate questions using sentence-embedding clustering.

    Each item's ``'question'`` is embedded with a SentenceTransformer model and
    grouped with ``util.community_detection`` (cosine similarity). The first
    member of each cluster is kept as its representative.

    Args:
        config: Dict with optional keys ``print_groups`` (print each cluster),
            ``dedup_model`` (default ``'all-MiniLM-L6-v2'``) and
            ``dedup_threshold`` (similarity cutoff, default 0.75).
        data_list: Items that each have a ``'question'`` key.

    Returns:
        The kept items from ``data_list``, in their original order.
    """
    from sentence_transformers import SentenceTransformer, util
    import torch

    if not data_list:
        return []

    questions = [x['question'] for x in data_list]
    print_groups = bool(config.get("print_groups"))

    print("Loading sentence transformer model...")
    model = SentenceTransformer(config.get("dedup_model", 'all-MiniLM-L6-v2'))
    if torch.cuda.is_available():
        model = model.to('cuda')
        print("Model moved to GPU.")
    else:
        print("No GPU available, using CPU.")

    print("\nEncoding questions into embeddings...")
    embeddings = model.encode(questions, convert_to_tensor=True, show_progress_bar=True)
    print(f"Embeddings created with shape: {embeddings.shape}")

    # min_community_size=1 lets a question with no near-duplicates form its
    # own cluster; threshold is the cosine-similarity cutoff (0.75-0.90 is a
    # reasonable tuning range).
    print("\nClustering similar questions...")
    clusters = util.community_detection(
        embeddings,
        min_community_size=1,
        threshold=config.get("dedup_threshold", 0.75),
    )
    print(f"Found {len(clusters)} unique question clusters.")

    if print_groups:
        print("\n--- Unique Questions ---")
    kept_indices = []
    for i, cluster in enumerate(clusters):
        # `cluster` holds indices into `questions`; keep its first entry.
        representative_idx = cluster[0]
        kept_indices.append(representative_idx)
        if print_groups:
            print(f"\nCluster {i+1}: Kept '{questions[representative_idx]}'")
            if len(cluster) > 1:
                print("  Discarded as duplicates:")
                for duplicate_idx in cluster[1:]:
                    print(f"    - '{questions[duplicate_idx]}'")

    print("\n----------------------------------")
    print(f"Original question count: {len(questions)}")
    print(f"Deduplicated question count: {len(kept_indices)}")
    print("----------------------------------")

    return [data_list[idx] for idx in sorted(kept_indices)]


def verify_solvable(config):
    """Keep only rewritten questions that the model judges solvable.

    Reads ``<temp_dir>/rewrite_results.json``; each question is prefixed with
    ``config['verify']['prompt']``. An item is kept when the first span tagged
    ``solvable`` in the model response (via ``extract_content_from_tag``)
    equals ``true``, ignoring case and surrounding whitespace. Kept items go
    to ``<temp_dir>/verify_results.json``.

    NOTE(review): assumes ``generate_n`` is 1 for this stage, so that
    ``item['output']`` is a single string. Confirm against the config.

    Returns:
        2 if a rank's output file is missing, 3 if no responses were gathered,
        otherwise 0 when every worker exited cleanly, 1 if a worker timed out,
        or a failing worker's exit code.
    """
    verify_config = config['verify']
    dp_size = verify_config["dp_size"]
    temp_dir = Path(config['dataset']['temp_dir'])
    vllm_output_path = str(temp_dir / "verify_vllm_output.json")
    vllm_results_path = str(temp_dir / "verify_results.json")
    verify_config['output_path'] = vllm_output_path

    # Per-rank files written by worker(): "<root>_<rank><ext>".
    root, ext = os.path.splitext(vllm_output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]
    # Remove leftovers from an earlier aborted run so they cannot mask a crash.
    for rank_output_path in rank_output_paths:
        if os.path.exists(rank_output_path):
            os.remove(rank_output_path)

    data_list = read_json_objects(str(temp_dir / "rewrite_results.json"))
    prompt = verify_config["prompt"]
    for x in data_list:
        x['input'] = prompt + '\nQuestion:\n' + x['question']

    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(
            target=worker,
            args=(verify_config, data_list, dp_rank),
        )
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            proc.join()  # reap the killed child
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # merge results into one file
    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            return 2
        all_outputs += read_json_objects(rank_output_path)
    write_data_to_json_file(all_outputs, vllm_output_path)

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        return 3

    results = []
    for item in all_outputs:
        try:
            solvable = extract_content_from_tag('solvable', item['output'])[0].strip()
        except IndexError:
            # No solvable tag in the response: treat the question as not verified.
            solvable = ''
        # Case-insensitive so "True"/"TRUE" from the model are accepted too.
        if solvable.lower() == 'true':
            sample = item['input']
            results.append({
                'question': sample['question'],
                'answer': sample['answer'],
                'original': sample['original'],
                'original_difficulty': sample['original_difficulty'],
            })
    write_data_to_json_file(results, vllm_results_path)

    # remove files generated by each rank
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    return exit_code


def measure_difficulty(config):
    """Score the difficulty of the verified rewritten questions.

    Reads ``<temp_dir>/verify_results.json``; each question is prefixed with
    ``config['difficulty']['prompt']`` and answered by ``dp_size`` worker
    processes. Difficulty is the number of sampled responses that
    ``math_verify`` judges wrong against the item's ``answer``. Results go to
    ``<temp_dir>/difficulty_results.json``.

    Returns:
        2 if a rank's output file is missing, 3 if no responses were gathered,
        otherwise 0 when every worker exited cleanly, 1 if a worker timed out,
        or a failing worker's exit code.
    """
    difficulty_config = config['difficulty']
    dp_size = difficulty_config["dp_size"]
    assert int(difficulty_config['generate_n']) > 0
    temp_dir = Path(config['dataset']['temp_dir'])
    vllm_output_path = str(temp_dir / "difficulty_vllm_output.json")
    vllm_results_path = str(temp_dir / "difficulty_results.json")
    difficulty_config['output_path'] = vllm_output_path

    # Per-rank files written by worker(): "<root>_<rank><ext>".
    root, ext = os.path.splitext(vllm_output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]
    # Remove leftovers from an earlier aborted run so they cannot mask a crash.
    for rank_output_path in rank_output_paths:
        if os.path.exists(rank_output_path):
            os.remove(rank_output_path)

    data_list = read_json_objects(str(temp_dir / "verify_results.json"))
    prompt = difficulty_config["prompt"]
    for x in data_list:
        x['input'] = prompt + '\nQuestion:\n' + x['question']

    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(
            target=worker,
            args=(difficulty_config, data_list, dp_rank),
        )
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            proc.join()  # reap the killed child
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # merge results into one file
    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            return 2
        all_outputs += read_json_objects(rank_output_path)
    write_data_to_json_file(all_outputs, vllm_output_path)

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        return 3

    def count_correct(responses, answer):
        return sum(1 for res in responses if math_verify(res, answer))

    results = []
    for item in all_outputs:
        sample = item['input']
        responses = item['output'] if isinstance(item['output'], list) else [item['output']]
        results.append({
            'question': sample['question'],
            'answer': sample['answer'],
            # Number of incorrect samples (the old `10 - correct` assumed generate_n == 10).
            'difficulty': len(responses) - count_correct(responses, sample['answer']),
            'original': sample['original'],
            'original_difficulty': sample['original_difficulty'],
        })
    write_data_to_json_file(results, vllm_results_path)

    # remove files generated by each rank
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    return exit_code


def solve(config):
    """Generate model solutions for every question in the combined dataset.

    Reads ``<temp_dir>/combined_difficulty_results.json``; each question is
    prefixed with ``config['solve']['prompt']``. Every sampled response becomes
    its own result item (question, answer, solution, difficulty, plus
    ``original`` / ``original_difficulty`` when the source item has them).
    Results go to ``<temp_dir>/solve_results.json``.

    Returns:
        2 if a rank's output file is missing, 3 if no responses were gathered,
        otherwise 0 when every worker exited cleanly, 1 if a worker timed out,
        or a failing worker's exit code.
    """
    solve_config = config['solve']
    dp_size = solve_config["dp_size"]
    temp_dir = Path(config['dataset']['temp_dir'])
    vllm_output_path = str(temp_dir / "solve_vllm_output.json")
    vllm_results_path = str(temp_dir / "solve_results.json")
    solve_config['output_path'] = vllm_output_path

    # Per-rank files written by worker(): "<root>_<rank><ext>".
    root, ext = os.path.splitext(vllm_output_path)
    rank_output_paths = [root + f"_{rank}" + ext for rank in range(dp_size)]
    # Remove leftovers from an earlier aborted run so they cannot mask a crash.
    for rank_output_path in rank_output_paths:
        if os.path.exists(rank_output_path):
            os.remove(rank_output_path)

    data_list = read_json_objects(str(temp_dir / "combined_difficulty_results.json"))
    prompt = solve_config["prompt"]
    for x in data_list:
        x['input'] = prompt + '\nQuestion:\n' + x['question']

    procs = []
    for dp_rank in range(dp_size):
        proc = mp.Process(
            target=worker,
            args=(solve_config, data_list, dp_rank),
        )
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=14400)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 240 minutes.")
            proc.kill()
            proc.join()  # reap the killed child
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # merge results into one file
    all_outputs = []
    for rank, rank_output_path in enumerate(rank_output_paths):
        if not os.path.exists(rank_output_path):
            logging.error(f"Missing output file from rank {rank}.")
            return 2
        all_outputs += read_json_objects(rank_output_path)
    write_data_to_json_file(all_outputs, vllm_output_path)

    if not all_outputs:
        logging.error("Gathered 0 responses.")
        return 3

    results = []
    for item in all_outputs:
        sample = item['input']
        responses = item['output'] if isinstance(item['output'], list) else [item['output']]
        for res in responses:
            new_item = {
                'question': sample['question'],
                'answer': sample['answer'],
                'solution': res,
                'difficulty': sample['difficulty'],
            }
            # No trailing comma here: it would silently store a 1-tuple.
            if 'original' in sample:
                new_item['original'] = sample['original']
            if 'original_difficulty' in sample:
                new_item['original_difficulty'] = sample['original_difficulty']
            results.append(new_item)
    write_data_to_json_file(results, vllm_results_path)

    # remove files generated by each rank
    for rank_output_path in rank_output_paths:
        try:
            os.remove(rank_output_path)
            print(f"File '{rank_output_path}' deleted successfully.")
        except FileNotFoundError:
            print(f"File '{rank_output_path}' not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    return exit_code


# -------------------- test each step works ------------------- #

def main():
    """Run the question-simplification pipeline described by ``--config``.

    Stages run in order: measure original difficulty, rewrite, verify
    solvability, measure rewritten difficulty. Intermediate files live in
    ``<parent of dataset.output_path>/generate_temp``. Original and rewritten
    questions are then combined, and their question/answer/difficulty fields
    are written to ``dataset.output_path``. If any stage returns a nonzero
    code, the process exits with that code instead of continuing on missing or
    stale intermediate files.
    """
    # Workers initialise CUDA, so child processes must not be forked.
    mp.set_start_method('spawn')

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)

    temp_dir = get_parent_directory(config["dataset"]["output_path"]) / "generate_temp"
    create_directory(temp_dir)
    config['dataset']['temp_dir'] = str(temp_dir)

    # 1. measure original difficulty, 2. rewrite, 3. verify, 4. difficulty
    stages = [
        ("Measure original difficulty", measure_original_difficulty),
        ("Rewrite", rewrite),
        ("Verify", verify_solvable),
        ("Difficulty", measure_difficulty),
    ]
    for label, stage in stages:
        return_code = stage(config)
        print(f"{label} return code: {return_code}")
        if return_code != 0:
            logging.error(f"{label} failed; aborting pipeline.")
            sys.exit(return_code)

    # 5. combine original question and simplified questions
    original = read_json_objects(str(temp_dir / "original_difficulty_results.json"))
    simplified = read_json_objects(str(temp_dir / "difficulty_results.json"))
    combined = original + simplified
    write_data_to_json_file(combined, str(temp_dir / "combined_difficulty_results.json"))

    # write the results
    output = [
        {'question': x['question'], 'answer': x['answer'], 'difficulty': x['difficulty']}
        for x in combined
    ]
    write_data_to_json_file(output, config['dataset']['output_path'])
    print('Success.')

    # # 6. solve
    # return_code = solve(config)
    # print(f"Solve return code: {return_code}")

    # # write the results
    # solve_results = read_json_objects(str(temp_dir / "solve_results.json"))
    # output = []
    # for x in solve_results:
    #     new_item = {
    #         'question': x['question'],
    #         'answer': x['answer'],
    #         'solution': x['solution'],
    #         'difficulty': x['difficulty']
    #     }
    #     output.append(new_item)
    # write_data_to_json_file(output, config['dataset']['output_path'])
    # print('Success.')

# Script entry point. The guard also matters for multiprocessing: with the
# 'spawn' start method set in main(), child processes re-import this module
# and must not re-run the pipeline.
if __name__ == "__main__":
    main()

