import torch
import json
import os
from transformers import AutoModelForCausalLM, AutoTokenizer

import argparse

# --- 1. Configuration ---
class Config:
    """Default settings for the simple top-p sampling baseline.

    All settings are plain class attributes. The ``__main__`` block copies
    MODEL_NAME, TEST_THEMES_FILE and OUTPUT_FILE from the command line onto an
    instance, while ``run_simple_top_p_sampling`` reads the sampling
    parameters directly from the class.
    """

    # Generation runs on the GPU when one is visible, otherwise on the CPU.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Default locations. Relative paths are resolved against the current
    # working directory (not this file's directory) when they are opened.
    MODEL_NAME: str = "../../Qwen3-1.7B"
    TEST_THEMES_FILE: str = "test_themes_qwen1.7b-50.json"
    OUTPUT_FILE: str = "Qwen3-1.7B_results/results_baseline_simple_top_p.json"

    # Sampling parameters forwarded to model.generate.
    MAX_NEW_TOKENS: int = 4096  # cap on the number of generated tokens per proposal
    TEMPERATURE: float = 0.7
    TOP_P: float = 0.9  # nucleus-sampling cumulative-probability cutoff

# --- 2. Core Generation Function ---

def run_simple_top_p_sampling(theme_object: dict, model, tokenizer, device) -> str:
    """Sample one research proposal for a theme with plain top-p sampling.

    This is the simplest baseline: one ``model.generate`` call with
    ``do_sample=True`` and exactly one returned sequence, with no reranking
    or search on top.

    Args:
        theme_object: Record providing the ``'theme'`` and ``'elaboration'``
            keys (a missing key raises ``KeyError``).
        model: Causal LM exposing ``generate``.
        tokenizer: Matching tokenizer; must support ``apply_chat_template``.
        device: Device the tokenized prompt is moved to before generation.

    Returns:
        The decoded continuation only. Prompt tokens are sliced off and special
        tokens are skipped.
    """
    theme, elaboration = theme_object['theme'], theme_object['elaboration']

    # Prompt text kept verbatim from the earlier prompt-design analysis.
    prompt = f"""You are a research scientist. Based on the initial research idea below, write a complete and detailed research proposal. The proposal should be well-structured, clear, and scientifically plausible.

Initial Research Idea:
Theme: {theme}
Elaboration: {elaboration}

Your Detailed Research Proposal:
"""

    chat = [{"role": "user", "content": prompt}]
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer([rendered], return_tensors="pt").to(device)
    prompt_len = encoded.input_ids.shape[-1]

    print(f"Generating a single candidate for theme: '{theme[:80]}...'\n")

    sampling_kwargs = dict(
        max_new_tokens=Config.MAX_NEW_TOKENS,
        do_sample=True,
        temperature=Config.TEMPERATURE,
        top_p=Config.TOP_P,
        pad_token_id=tokenizer.eos_token_id,  # EOS id doubles as the pad id
        num_return_sequences=1,  # exactly one sample: this is the 1-of-1 baseline
    )
    sequences = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **sampling_kwargs,
    )

    # generate() returns prompt + continuation; keep only the new tokens.
    continuation = sequences[:, prompt_len:]
    text = tokenizer.batch_decode(continuation, skip_special_tokens=True)[0]

    print("Generation complete.")
    return text

# --- 3. Main Execution ---

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='RUN baseline zero-shot prompting on test themes')
    parser.add_argument('--modelpath', type=str, required=True, help='Path of LLM model weights, e.g., ../../Qwen3-8B')
    # NOTE(review): --dbpath is parsed but never read by this baseline;
    # presumably kept so the CLI matches sibling experiment scripts.
    parser.add_argument('--dbpath', type=str, required=True, help='Path of databases, e.g., ./Qwen3-db')
    parser.add_argument('--testfile', type=str, required=True, help='Test themes filepath, e.g., ./test_set/test_themes_qwen1.7b-50.json')
    parser.add_argument('--outfile', type=str, required=True, help='Path to save the result, e.g., ./result/results_cgmcts.json')
    args = parser.parse_args()

    # Command-line values override the defaults on a Config instance.
    cfg = Config()
    cfg.MODEL_NAME = args.modelpath
    cfg.TEST_THEMES_FILE = args.testfile
    cfg.OUTPUT_FILE = args.outfile

    # Create the output directory up front. dirname() is "" when --outfile is a
    # bare filename, and os.makedirs("") raises, so only create a real directory.
    out_dir = os.path.dirname(cfg.OUTPUT_FILE)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Load Model and Tokenizer
    print(f"Loading model: {cfg.MODEL_NAME} on {cfg.DEVICE}...")
    tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.MODEL_NAME,
        dtype="auto",
        device_map="auto"
    ).eval()
    print("Model loaded successfully.")

    # Load Test Data
    print(f"Loading test themes from: {cfg.TEST_THEMES_FILE}")
    try:
        with open(cfg.TEST_THEMES_FILE, 'r', encoding='utf-8') as f:
            test_themes = json.load(f)
    except FileNotFoundError:
        print(f"Error: Test themes file not found at {cfg.TEST_THEMES_FILE}. Please ensure the file exists.")
        # Exit with a non-zero status so callers/schedulers see the failure;
        # the bare exit() builtin exited with status 0.
        raise SystemExit(1)

    all_results = []
    num_themes = len(test_themes)

    print("\n--- Starting SIMPLE Baseline Experiment: Top-p Sampling (1-of-1) ---")

    for i, theme_obj in enumerate(test_themes, start=1):
        print(f"\n--- Processing Theme {i}/{num_themes} (ID: {theme_obj.get('id', 'N/A')}) ---")

        # Run the simple generation
        output = run_simple_top_p_sampling(theme_obj, model, tokenizer, model.device)

        all_results.append({
            "id": theme_obj.get('id'),
            "theme": theme_obj.get('theme'),
            "elaboration": theme_obj.get('elaboration'),
            "baseline_method": "simple_top_p_sampling",
            "output": output,
        })

        # Save incrementally to avoid data loss on long runs. Dump to a sibling
        # temp file and atomically swap it in, so an interruption mid-write
        # cannot leave a truncated JSON file that loses all earlier results.
        print(f"Saving intermediate results to {cfg.OUTPUT_FILE}...")
        tmp_path = cfg.OUTPUT_FILE + ".tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f_out:
            json.dump(all_results, f_out, indent=2, ensure_ascii=False)
        os.replace(tmp_path, cfg.OUTPUT_FILE)

    print("\n--- Experiment finished successfully! ---")
    print(f"All results saved to {cfg.OUTPUT_FILE}")
