from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from tqdm import tqdm
import sys
import os.path
import json
import os
import argparse

# Model location and visible GPU ids are taken from the environment so the
# script can be re-targeted without editing the source.
MODEL_PATH = os.environ.get("MODEL_PATH", "path/to/your/model")
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1")

# Instruction template and few-shot examples. The path is resolved against the
# current working directory. The encoding is given explicitly (as in
# read_jsonl/write_jsonl) so decoding does not depend on the platform locale.
with open("./nl_tuning_prompt.json", "r", encoding="utf-8") as f:
    prompt_data = json.load(f)

# Each example is rendered in the same "formal / template / output" layout
# that post_batch uses for the user turn.
fact_examples = ''.join(
    f"formal representation: {item['formal']}\ntemplate representation: {item['template']}\noutput: {item['nl']}\n" for item in prompt_data['fact_examples'])
rule_examples = ''.join(
    f"formal representation: {item['formal']}\ntemplate representation: {item['template']}\noutput: {item['nl']}\n" for item in prompt_data['rule_examples'])
# post_batch calls instruction.format(examples=...) on this template.
instruction = prompt_data['prompt']

# Initialize tokenizer for handling chat template
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, trust_remote_code=True, padding_side="left")

# Initialize vLLM model
# NOTE(review): this runs at import time, i.e. before the __main__ block
# parses arguments or checks that MODEL_PATH exists.
llm = LLM(
    model=MODEL_PATH,
    trust_remote_code=True,
    # One tensor-parallel shard per comma-separated id in CUDA_VISIBLE_DEVICES.
    tensor_parallel_size=len(gpus.split(",")),
    gpu_memory_utilization=0.8,
    max_model_len=8192,  # Adjust as needed
)

# Module-level switch read by post_batch and process_real_data; the __main__
# block overrides it from --enable-thinking.
enable_thinking = False


def extract_thinking_and_content(output_text):
    """
    Split a model completion into its reasoning part and its final answer.

    Qwen3 wraps its reasoning in ``<think>`` ... ``</think>``. The older
    ``<|thinking|>`` ... ``<|/thinking|>`` markers are still recognised so
    existing outputs keep parsing.

    Args:
        output_text: Raw text generated by the model.

    Returns:
        A ``(thinking_content, final_content)`` tuple, both stripped. When no
        closing tag is found, ``thinking_content`` is ``""`` and
        ``final_content`` is ``output_text`` unchanged.
    """
    tag_pairs = (
        ("<think>", "</think>"),
        ("<|thinking|>", "<|/thinking|>"),
    )

    for start_tag, end_tag in tag_pairs:
        start_idx = output_text.find(start_tag)
        # Some chat templates put the opening tag in the prompt, so the
        # completion only carries the closing tag; in that case everything
        # before the closing tag is treated as reasoning.
        thinking_from = 0 if start_idx == -1 else start_idx + len(start_tag)

        # Search for the closing tag *after* the opening one so a stray
        # earlier closing tag cannot produce a bogus split.
        end_idx = output_text.find(end_tag, thinking_from)
        if end_idx == -1:
            continue

        thinking_content = output_text[thinking_from:end_idx].strip()
        final_content = output_text[end_idx + len(end_tag):].strip()
        return thinking_content, final_content

    return "", output_text


def post_batch(question_list, fact_prompt, debug_print=False):
    """
    Generate one natural-language rewrite per (template, formal) pair.

    Args:
        question_list: Iterable of ``(nl, formal)`` pairs; ``nl`` is the
            template representation and ``formal`` the formal representation.
        fact_prompt: If true, the system prompt carries only the fact
            examples; otherwise it carries the fact and rule examples.
        debug_print: If true, print the prompt, user input and output of
            every call.

    Returns:
        List of answer strings, in the order the model outputs are returned.
    """
    shots = fact_examples if fact_prompt else fact_examples + rule_examples
    system_prompt = instruction.format(examples=shots)

    # One system + user conversation per question.
    conversations = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"formal representation: {formal}\ntemplate representation: {nl}\noutput: "},
        ]
        for nl, formal in question_list
    ]

    # Render each conversation to plain text through the chat template.
    rendered_prompts = [
        tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )
        for conversation in conversations
    ]

    # Thinking mode gets a larger token budget and different sampling values.
    if enable_thinking:
        sampling_params = SamplingParams(
            max_tokens=2048,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            stop=["<|endoftext|>", "<|im_end|>"],
        )
    else:
        sampling_params = SamplingParams(
            max_tokens=512,
            temperature=0.7,
            top_p=0.8,
            top_k=20,
            stop=["<|endoftext|>", "<|im_end|>"],
        )

    generations = llm.generate(rendered_prompts, sampling_params)

    answers = []
    for position, generation in enumerate(generations):
        raw_text = generation.outputs[0].text

        if enable_thinking:
            reasoning, answer = extract_thinking_and_content(raw_text)
        else:
            reasoning, answer = "", raw_text.strip()

        answers.append(answer)

        if not debug_print:
            continue

        system_turn, user_turn = conversations[position][0], conversations[position][1]
        print("\n================ LLM CALL ================")
        print(f"[Prompt]:\n{system_turn['content']}")
        print(f"[User Input]:\n{user_turn['content']}")
        if enable_thinking:
            print(f"[Thinking Content]:\n{reasoning}")
        print(f"[LLM Output]:\n{answer}")
        print("==========================================\n")

    return answers


def process_real_data(data_list, batch_size=100000):
    """
    Add tuned natural-language facts and rules to every record in place.

    Args:
        data_list: List of dicts. Each record is expected to provide
            ``facts-str``/``facts-nl`` and ``rules-str``/``rules-nl``.
        batch_size: Number of fact prompts sent per generation call. It is
            halved in thinking mode, and rules use half of the resulting
            value. Values that would drop below 1 are clamped to 1.

    Returns:
        The same ``data_list``, with ``facts-tuned-nl`` and
        ``rules-tuned-nl`` set on every record.

    Raises:
        AssertionError: If a record has no tuned result or the tuned list
            length differs from the corresponding ``*-str`` list.
    """
    # vLLM can usually handle larger batches
    if enable_thinking:
        batch_size = batch_size // 2  # Reduce batch size in thinking mode

    # Integer halving can reach 0 for small inputs, and a zero step makes
    # range() raise, so keep both batch sizes at 1 or more.
    facts_batch_size = max(1, batch_size)
    rules_batch_size = max(1, batch_size // 2)

    rules_results = process_batch_items(data_list, 'rules', rules_batch_size)
    facts_results = process_batch_items(data_list, 'facts', facts_batch_size)

    for data_idx, data in enumerate(data_list):
        data['facts-tuned-nl'] = facts_results.get(data_idx)
        assert data['facts-tuned-nl'] is not None, \
            f"Facts tuning failed for data index {data_idx}"
        assert len(data['facts-tuned-nl']) == len(data['facts-str']), \
            f"Facts tuning length mismatch for data index {data_idx}"
        data['rules-tuned-nl'] = rules_results.get(data_idx)
        assert data['rules-tuned-nl'] is not None, \
            f"Rules tuning failed for data index {data_idx}"
        assert len(data['rules-tuned-nl']) == len(data['rules-str']), \
            f"Rules tuning length mismatch for data index {data_idx}"

    return data_list


def process_batch_items(data_list, item_type, batch_size):
    """
    Run post_batch over every ``item_type`` entry of every record.

    Args:
        data_list: List of dicts. Records lacking ``{item_type}-str`` or
            ``{item_type}-nl`` are skipped and get no entry in the result.
        item_type: Key prefix, e.g. ``'facts'`` or ``'rules'``. ``'facts'``
            selects the fact-only prompt in post_batch.
        batch_size: Number of questions per post_batch call; values below 1
            are treated as 1.

    Returns:
        Dict mapping a record's index in ``data_list`` to its list of
        answers, aligned with the record's ``{item_type}-str`` list. A record
        whose lists are empty maps to ``[]``.

    Raises:
        AssertionError: If a record's two lists differ in length, or
            post_batch returns a different number of answers than questions.
    """
    formal_key = f"{item_type}-str"
    nl_key = f"{item_type}-nl"

    questions = []
    positions = []  # (data_idx, item_idx) for each entry in `questions`
    results_by_data = {}

    for data_idx, data in enumerate(data_list):
        if formal_key not in data or nl_key not in data:
            continue

        formal_items = data[formal_key]
        nl_items = data[nl_key]

        assert len(nl_items) == len(formal_items), \
            f"{item_type.capitalize()} and formal {item_type} length mismatch in data {data_idx}"

        # Pre-size the result so records with no items still get an (empty)
        # entry instead of being absent from the mapping.
        results_by_data[data_idx] = [""] * len(formal_items)

        for item_idx, (nl, formal) in enumerate(zip(nl_items, formal_items)):
            questions.append((nl, formal))
            positions.append((data_idx, item_idx))

    fact_prompt = item_type == 'facts'
    # A zero step makes range() raise and a negative one yields no batches.
    step = max(1, batch_size)
    all_results = []
    for i in tqdm(range(0, len(questions), step), desc=f"Processing {item_type.capitalize()}"):
        batch = questions[i:i + step]
        batch_answers = post_batch(batch, fact_prompt)
        # Because results are pre-sized, a short batch would otherwise leave
        # "" placeholders behind without any error.
        assert len(batch_answers) == len(batch), \
            f"Expected {len(batch)} {item_type} answers, got {len(batch_answers)}"
        all_results.extend(batch_answers)

    for (data_idx, item_idx), answer in zip(positions, all_results):
        results_by_data[data_idx][item_idx] = answer

    return results_by_data


def read_jsonl(path):
    """
    Load a JSON Lines file into a list of parsed objects.

    Blank or whitespace-only lines (for example a trailing empty line) are
    skipped instead of being passed to the JSON parser.

    Args:
        path: Path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        List with one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data_list = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data_list.append(json.loads(line))
    return data_list


def write_jsonl(path, data_list):
    """
    Write ``data_list`` to ``path`` as UTF-8 JSON Lines, one object per line.

    Missing parent directories are created. Non-ASCII characters are written
    as-is rather than escaped.

    Args:
        path: Destination file path; an existing file is overwritten.
        data_list: Iterable of JSON-serialisable objects.
    """
    parent_dir = os.path.dirname(path)
    # dirname is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        for data in data_list:
            line = json.dumps(data, ensure_ascii=False)
            f.write(line + '\n')


if __name__ == '__main__':
    # Script entry point: either run a small debug sample or tune every
    # raw-data file and write the result under ../data/.
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true',
                        help='Run in debug mode with example facts and rules')
    parser.add_argument('--enable-thinking', action='store_true',
                        help='Enable LLM thinking content')
    parser.add_argument('--batch-size', type=int, default=1000000,
                        help='Batch size for processing')
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    # Rebinds the module-level flag read by post_batch and process_real_data.
    enable_thinking = args.enable_thinking

    # NOTE(review): the tokenizer and model are loaded at import time, so by
    # the time this check runs a bad MODEL_PATH has normally already failed.
    if not os.path.exists(MODEL_PATH):
        print(
            f"Model path {MODEL_PATH} does not exist. Please set the MODEL_PATH environment variable correctly.")
        sys.exit(1)

    if args.debug:
        print("[DEBUG MODE] Only processing part of fact/rule examples from raw-data file and printing detailed input/output\n")
        debug_data_path = "../raw-data/el-en.jsonl"
        # Only the first record of the file is used in debug mode.
        with open(debug_data_path, 'r', encoding='utf-8') as f:
            first_line = f.readline()
            data = json.loads(first_line)
        fact_formals = data['facts-str'][:2]
        fact_templates = data['facts-nl'][:2]
        fact_questions = list(zip(fact_templates, fact_formals))
        rule_formals = data['rules-str'][:2]
        rule_templates = data['rules-nl'][:2]
        rule_questions = list(zip(rule_templates, rule_formals))
        print("\n[FACT EXAMPLES FROM FILE]")
        fact_outputs = post_batch(
            fact_questions, fact_prompt=True, debug_print=True)
        print("\n[RULE EXAMPLES FROM FILE]")
        rule_outputs = post_batch(
            rule_questions, fact_prompt=False, debug_print=True)
        print("\n[DEBUG MODE END]")
        sys.exit(0)

    data_names = [f"{logicai}-{numerical}.jsonl" for logicai in ['el', 'hl']
                  for numerical in ['en', 'hn']]
    data_names.extend(['train-el.jsonl', 'train-en.jsonl', 'depth-7.jsonl',
                      'depth-8.jsonl', 'depth-9.jsonl', 'depth-10.jsonl'])
    with tqdm(total=len(data_names), desc="Total Loop") as outer_bar:
        for data_name in data_names:
            path = f"../raw-data/{data_name}"
            data_list = read_jsonl(path)
            # Pass the CLI value through; previously --batch-size was parsed
            # but never used, so the function default always applied.
            data_list = process_real_data(
                data_list, batch_size=args.batch_size)
            write_jsonl(f"../data/{data_name}", data_list)
            outer_bar.update(1)
