import os
import json
from tqdm import tqdm
import requests
import argparse


# === Argument Parsing ===
parser = argparse.ArgumentParser(description="Run inference with LLaMA-Factory API.")
parser.add_argument("--task", type=str, required=False,
                    help="Task name, e.g. task19 or task19.jsonl. If not specified, runs all tasks 1-19")
parser.add_argument("--model_name_or_path", type=str, required=True,
                    help="Model name used for API, e.g. meta-llama/Llama-3.1-8B-Instruct")
parser.add_argument('--dataset_folder', type=str, required=False,
                    default='../sci2pol_data', help='Dataset folder path')
parser.add_argument('--output_folder', type=str, required=False,
                    default='Inference_Results', help='Output folder path')
args = parser.parse_args()

# Task files to process. With no --task, run task1.jsonl through task19.jsonl.
if args.task is None:
    tasks_to_run = [f"task{i}.jsonl" for i in range(1, 20)]
else:
    # Accept both "task19" and "task19.jsonl": appending the suffix
    # unconditionally would turn the latter into "task19.jsonl.jsonl".
    tasks_to_run = [args.task if args.task.endswith(".jsonl") else f"{args.task}.jsonl"]

model_name_or_path = args.model_name_or_path

# === Configs ===
# No tokenizer is used directly in this script (it only calls the HTTP API);
# presumably kept to quiet Hugging Face `tokenizers` parallelism warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Directory that receives one "<task>_response.jsonl" file per task.
output_dir = args.output_folder
os.makedirs(output_dir, exist_ok=True)

# Chat-completions endpoint that every inference request is POSTed to.
API_URL = "http://localhost:8000/v1/chat/completions"  # Replace with your actual API URL

# Per-task generation budget, sent as the API's "max_tokens". A task with no
# entry here falls back to 1024 tokens in process_single_task.
# NOTE(review): the default run covers task1-task19, but task19 has no entry,
# so it gets the 1024 fallback. Confirm that this is intended.
# NOTE(review): task17's prompt asks for a JSON object including 'evidence'
# sentences, which likely does not fit in 16 tokens. Verify against the evaluator.
MAX_NEW_TOKENS_MAPPING = {
    f"task{number}": token_budget
    for first, last, token_budget in (
        (1, 6, 16),       # multiple choice, sentence ordering, classification, MCQ
        (7, 10, 128),     # summarization
        (11, 15, 1024),   # section / full policy-brief generation
        (16, 18, 16),     # claim verification
    )
    for number in range(first, last + 1)
}

# System prompt sent ahead of every query of a task, keyed by task name
# (the file name without ".jsonl"). A task with no entry gets a generic
# fallback prompt in process_single_task.
SYSTEM_PROMPT_MAPPING = dict(
    # Autocompletion (scientific/policy)
    task1="You are a helpful assistant for scientific multiple-choice tasks. Your task is to choose the best option. Strictly reply with only a single uppercase letter: A, B, C, D, or E. No explanation, no extra words.",
    task2="You are a helpful assistant for policy multiple-choice tasks. Your task is to choose the best option. Strictly reply with only a single uppercase letter: A, B, C, D, or E. No explanation, no extra words.",

    # Sentence recombination
    task3="You are a helpful assistant for ordering scientific sentences. Select the most logical sequence of A, B, C. Reply only with the permutation (e.g., ABC, BAC). No explanation, no extra words.",
    task4="You are a helpful assistant for ordering policy sentences. Select the most logical sequence of A, B, C. Reply only with the permutation (e.g., ABC, BAC). No explanation, no extra words.",

    # Sentence classification
    task5="You are a helpful assistant for classifying scientific text. Reply with exactly one of: Policy Problem | Scientific Research Findings | Scientific Research Study Methods | Policy Implications | None. No explanation, no extra words.",

    # Scientific MCQ
    task6="You are a helpful assistant for answering scientific multiple-choice questions. Strictly reply with a single uppercase letter (A, B, C, etc). No explanation, no extra words, no extra punctuation.",

    # Summarization (policy problem, findings, methods, implications)
    task7="You are a helpful assistant that summarizes policy problems using clear, policy-brief style language. Reply only with your summary sentence(s). No explanation, no extra words.",
    task8="You are a helpful assistant that summarizes research findings using clear, policy-brief style language. Reply only with your summary sentence(s). No explanation, no extra words.",
    task9="You are a helpful assistant that summarizes study methods using clear, policy-brief style language. Reply only with your summary sentence(s). No explanation, no extra words.",
    task10="You are a helpful assistant that summarizes policy implications using clear, policy-brief style language. Reply only with your summary sentence(s). No explanation, no extra words.",

    # Policy-brief generation (single sections, then the full brief)
    task11="You are a helpful assistant tasked with writing the 'Policy Problem' section for a policy brief. Reply only with your short paragraph. No explanation, no extra words.",
    task12="You are a helpful assistant tasked with writing the 'Scientific Research Findings' section for a policy brief. Reply only with your short paragraph. No explanation, no extra words.",
    task13="You are a helpful assistant tasked with writing the 'Scientific Research Study Methods' section for a policy brief. Reply only with your short paragraph. No explanation, no extra words.",
    task14="You are a helpful assistant tasked with writing the 'Policy Implications' section for a policy brief. Reply only with your short paragraph. No explanation, no extra words.",
    task15="You are a helpful assistant tasked with generating a full policy brief, including title, Policy Problem, Scientific Research Findings, Study Methods, and Policy Implications. Reply only with the full policy brief. No explanation, no extra words.",

    # Claim verification
    task16="You are a helpful assistant tasked with verifying scientific claims. Reply SUPPORT or CONTRADICT. No explanation, no extra words.",
    task17="You are a helpful assistant tasked with verifying COVID-19 claims. Reply with a JSON object containing 'verdict' (SUPPORT/CONTRADICT) and 'evidence' (relevant abstract sentences). No explanation, no extra words.",
    task18="You are a helpful assistant tasked with verifying policy implications. Reply SUPPORT or CONTRADICT. No explanation, no extra words.",
)

def _load_jsonl(path):
    """Parse a UTF-8 JSON Lines file, skipping blank lines (e.g. a trailing newline)."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _query_model(sample, model_name_or_path, max_new_tokens, timeout):
    """POST one chat request to API_URL and return the stripped reply text.

    Returns None (after printing the error) when the request fails or the
    response does not have the expected ``choices[0].message.content`` shape.
    """
    payload = {
        "model": model_name_or_path,
        "messages": sample["messages"],
        "temperature": 0.95,
        "top_p": 0.7,
        "top_k": 50,
        "num_beams": 1,
        "repetition_penalty": 1.0,
        "max_tokens": max_new_tokens,
    }
    try:
        response = requests.post(API_URL, json=payload, timeout=timeout)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"].strip()
    except Exception as e:
        # Deliberately best-effort: one failed record is logged and stored as
        # None rather than aborting the whole task.
        print(f"Error at idx {sample['id']}: {e}")
        return None


def process_single_task(task, model_name_or_path, output_dir, dataset_folder=None,
                        request_timeout=600):
    """Run chat-completion inference for every record of one task file.

    Reads ``<dataset_folder>/<task>`` as JSON Lines. Each record must have
    ``query``, ``id`` and ``answer`` keys. Each ``query`` is sent to
    ``API_URL`` with the task's system prompt. One JSON line per record,
    with keys ``idx``, ``expected`` and ``response``, is written to
    ``<output_dir>/<task_base>_response.jsonl``.

    Args:
        task: Task file name relative to the dataset folder, e.g. "task1.jsonl".
        model_name_or_path: Value sent as the ``model`` field of every request.
        output_dir: Directory that receives the response file.
        dataset_folder: Folder containing the task files. Defaults to the
            ``--dataset_folder`` command-line argument.
        request_timeout: Seconds to wait for each API call. A record whose
            call times out is stored with ``"response": null``.

    Returns:
        None. If the task file does not exist, an error is printed and no
        output file is written.
    """
    task_base = os.path.basename(task).replace(".jsonl", "")  # e.g., "task1"
    max_new_tokens = MAX_NEW_TOKENS_MAPPING.get(task_base, 1024)
    system_prompt = SYSTEM_PROMPT_MAPPING.get(task_base, "You are a helpful assistant. Please respond appropriately.")

    # Load data from local file
    if dataset_folder is None:
        dataset_folder = args.dataset_folder
    task_file_path = os.path.join(dataset_folder, task)
    if not os.path.exists(task_file_path):
        print(f"Error: Task file not found: {task_file_path}")
        return

    dataset = _load_jsonl(task_file_path)
    requests_data = [
        {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": item["query"]},
            ],
            "id": item["id"],
            "expected_answer": item["answer"],
        }
        for item in dataset
    ]

    print(f"Prepared {len(requests_data)} records for inference on {task}.")

    out_file = os.path.join(output_dir, f"{task_base}_response.jsonl")

    # Requests are always sent one at a time. batch_size only sets how often
    # results are flushed to disk and how coarse the progress bar is. The
    # long-form generation tasks flush after every record.
    batch_size = 1 if task_base in {"task11", "task12", "task13", "task14", "task15"} else 10

    # Records are dumped with ensure_ascii=False, so the file must be UTF-8
    # regardless of the platform's default encoding.
    with open(out_file, "w", encoding="utf-8") as f_out:
        for start in tqdm(range(0, len(requests_data), batch_size), desc=f"Processing {task}"):
            for sample in requests_data[start:start + batch_size]:
                answer = _query_model(sample, model_name_or_path, max_new_tokens, request_timeout)
                record = {
                    "idx": sample["id"],
                    "expected": sample["expected_answer"],
                    "response": answer,
                }
                f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
            f_out.flush()

    print(f"Results saved to {out_file}")

# === Entry point: process every selected task in order ===
print(f"Running inference for {len(tasks_to_run)} task(s): {tasks_to_run}")

for task_file in tasks_to_run:
    process_single_task(task_file, model_name_or_path=model_name_or_path, output_dir=output_dir)

print("All tasks completed!")