import argparse
import os
import json
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from tqdm import tqdm


def prepare_input_data(data, include_reasoning=False):
    """Render one dataset record as a plain-text prompt section.

    The text lists the record's facts and rules as 1-based numbered lines and
    ends with a question about the queried entity's attribute. When
    ``include_reasoning`` is true, the record's reasoning text and its boxed
    answer are appended, so the record reads as a worked example.

    Args:
        data: Mapping with the keys ``'facts-tuned-nl'`` and
            ``'rules-tuned-nl'`` (iterables of strings) and ``'query'`` (an
            ``(entity, attribute)`` pair). When ``include_reasoning`` is
            true, ``'reasoning_process_nl'`` and
            ``data['values'][entity][attribute]`` are read as well.
        include_reasoning: Whether to append the reasoning and the answer.

    Returns:
        The assembled text; every section ends with a newline.
    """
    parts = ["Facts:\n"]
    parts.extend(
        f"{number}. {fact}\n"
        for number, fact in enumerate(data['facts-tuned-nl'], start=1)
    )
    parts.append("Rules:\n")
    parts.extend(
        f"{number}. {rule}\n"
        for number, rule in enumerate(data['rules-tuned-nl'], start=1)
    )

    entity, attribute = data['query']
    parts.append(f"Query:\nWhat is the value of {entity}'s {attribute}?\n")

    if include_reasoning:
        parts.append(
            "After a detailed explanation, you would conclude as follows.\n")
        parts.append("Reasoning:\n" + data["reasoning_process_nl"] + "\n")
        # Doubled braces are literal braces; the innermost pair interpolates.
        solution = data["values"][entity][attribute]
        parts.append(f"Answer: \\boxed{{{solution}}}\n")

    return "".join(parts)


def load_data(data_path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        data_path: Path to a UTF-8 encoded file holding one JSON document
            per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        # Blank or whitespace-only lines (e.g. a stray trailing newline)
        # carry no record; passing them to json.loads would raise.
        return [json.loads(text) for text in map(str.strip, f) if text]


def load_prompt(data_name, example_num):
    """Build the instruction prompt, optionally followed by few-shot examples.

    The base instruction is read from ``../prompt/instruction.txt``. When
    ``example_num`` is positive, the first ``example_num`` records of
    ``../prompt/{data_name}.jsonl`` are rendered with their reasoning and
    answer and appended after the instruction. Both paths are relative, so
    they resolve against the current working directory.

    Args:
        data_name: Dataset name used to locate the few-shot example file.
        example_num: Number of examples to append; values that are not
            greater than zero yield the bare instruction.

    Returns:
        The prompt text.

    Raises:
        AssertionError: If the example file holds fewer than ``example_num``
            records.
    """
    with open("../prompt/instruction.txt", 'r', encoding='utf-8') as f:
        instruction = f.read()

    # Zero-shot: the instruction alone is the whole prompt.
    if not example_num > 0:
        return instruction

    examples = load_data(f"../prompt/{data_name}.jsonl")
    assert len(examples) >= example_num, f"Not enough examples for {data_name}"

    rendered_examples = [
        prepare_input_data(example, include_reasoning=True)
        for example in examples[:example_num]
    ]
    return (
        instruction
        + "Here are some examples:\n"
        + "\n".join(rendered_examples)
        + "\nPlease follow the same format to conclude the answer at last:\n"
    )


def build_inputs(data, prompt):
    """Turn dataset records into chat-style message lists.

    Args:
        data: Iterable of dataset records accepted by ``prepare_input_data``.
        prompt: Text used as the system message of every conversation.

    Returns:
        A list with one two-message conversation per record: a ``system``
        message holding ``prompt`` followed by a ``user`` message holding the
        record rendered without reasoning.
    """
    return [
        [
            {"role": "system", "content": prompt},
            {
                "role": "user",
                "content": prepare_input_data(record, include_reasoning=False),
            },
        ]
        for record in data
    ]


def generate_outputs(llm, tokenizer, inputs, enable_thinking=None):
    """Generate one completion per conversation and return the stripped texts.

    Args:
        llm: Object whose ``generate(prompts, sampling_params)`` method
            produces the completions.
        tokenizer: Object whose ``apply_chat_template`` method converts the
            conversations into prompt strings.
        inputs: Conversations to complete, as built by ``build_inputs``.
        enable_thinking: When not ``None``, forwarded to
            ``apply_chat_template`` as ``enable_thinking``; when ``None`` the
            keyword is not passed at all.

    Returns:
        A list holding, for each output returned by ``llm.generate``, the
        whitespace-stripped text of its first completion.
    """
    sampling_params = SamplingParams(
        temperature=0.3,
        top_p=0.8,
        top_k=20,
        max_tokens=8192,
    )

    # Only pass enable_thinking when the caller set it explicitly.
    template_options = {"tokenize": False, "add_generation_prompt": True}
    if enable_thinking is not None:
        template_options["enable_thinking"] = enable_thinking
    text_inputs = tokenizer.apply_chat_template(inputs, **template_options)
    print(f"Example inputs: {text_inputs[:2]}")

    completions = llm.generate(text_inputs, sampling_params)
    return [completion.outputs[0].text.strip() for completion in completions]


def main():
    """Run few-shot inference over the four datasets and save the results.

    Command-line flags select the model, the number of few-shot examples, the
    output JSON file and the vLLM resource settings. Results are stored as a
    mapping from dataset name to a list of records, each being the input
    item plus an ``'llm_output'`` field. An existing output file is reloaded
    before each dataset so items whose ``'id'`` is already stored are
    skipped, and the file is rewritten after each dataset completes.

    Raises:
        ValueError: If the model path contains ``base`` or contains none of
            the accepted chat/instruct keywords.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str,
                        required=True, help='model path')
    parser.add_argument(
        '--shot-num', type=int, required=True, help='number of shots for few-shot learning')
    parser.add_argument('--output-path', type=str,
                        required=True, help='output file path')
    parser.add_argument('--tensor-parallel-size', type=int, default=1,
                        help='number of GPUs to use for tensor parallelism')
    parser.add_argument('--gpu-memory-utilization', type=float, default=0.9,
                        help='GPU memory utilization ratio')

    args = parser.parse_args()

    # The model kind is guessed from substrings of the lower-cased path.
    if 'base' in args.model_path.lower():
        raise ValueError(
            "base model is not supported, please use a tuned model.")
    elif not any(keyword in args.model_path.lower() for keyword in ['chat', 'instruct', '-it', 'qwen3', 'phi']):
        raise ValueError(
            "Unsupported model type, please use a chat or instruct model.")

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        padding_side="left"
    )
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model from {args.model_path}...")
    llm_kwargs = {
        'model': args.model_path,
        'tensor_parallel_size': args.tensor_parallel_size,
        'gpu_memory_utilization': args.gpu_memory_utilization,
        'trust_remote_code': True
    }

    llm = LLM(**llm_kwargs)
    print("Model loaded successfully!")

    for data_name in tqdm(['el-en', 'el-hn', 'hl-en', 'hl-hn'], desc="Processing datasets"):
        data_path = f"../data/{data_name}.jsonl"
        data = load_data(data_path)
        prompt = load_prompt(data_name, args.shot_num)

        if os.path.exists(args.output_path):
            with open(args.output_path, 'r', encoding='utf-8') as f:
                results = json.load(f)
        else:
            # A bare filename has an empty dirname, and os.makedirs('')
            # raises FileNotFoundError, so only create a directory when the
            # path actually names one.
            output_dir = os.path.dirname(args.output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            results = {}

        if results.get(data_name) is None:
            results[data_name] = []
            print(
                f"For dataset {data_name}: Processed 0 items, remaining {len(data)} items to process.")
        else:
            # A set keeps the resume filter linear in the dataset size.
            # NOTE(review): assumes item ids are hashable (str/int) - confirm.
            processed_data_id = {item['id'] for item in results[data_name]}
            data = [item for item in data if item['id']
                    not in processed_data_id]
            print(
                f"For dataset {data_name}: Processed {len(results[data_name])} items, remaining {len(data)} items to process.")

        if not data:
            print(f"Dataset {data_name} already fully processed, skipping...")
            continue

        inputs = build_inputs(data, prompt)

        print(
            f"Generating outputs for {len(inputs)} items in dataset {data_name}...")
        outputs = generate_outputs(llm, tokenizer, inputs)

        for item, output in zip(data, outputs):
            item['llm_output'] = output
            results[data_name].append(item)

        # Persist after every dataset so a later run can resume from here.
        with open(args.output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)

        print(f"Completed processing dataset {data_name}")

    print("All datasets processed successfully!")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
