"""
Batch Generation Script for Hypothesis Composition (Teacher Distillation / Self-Sampling)

This script generates hypothesis outputs for subsequent filtering and SFT training.
Supports two modes:
1. Local model: Load model locally and generate with transformers
2. API mode: Call external API (e.g., SGLang, vLLM, or cloud APIs)

Key features:
- Batch processing with dynamic batching (local mode)
- Concurrent API calls (API mode)
- Support for multi-GPU with data parallelism
- Multiple samples per data point for rejection sampling

Usage:
    # Local model mode (if_use_api=0)
    python hypothesis_composition_sampling.py \
        --if_use_api 0 \
        --model_path /path/to/model \
        --output_dir /path/to/output \
        --batch_size 4 \
        --num_samples 8

    # API mode (if_use_api=1, e.g., SGLang with 32B teacher)
    python hypothesis_composition_sampling.py \
        --if_use_api 1 \
        --api_base_url http://localhost:1234/v1 \
        --api_model_name deepseek-r1-distill-qwen-32b \
        --output_dir /path/to/output \
        --num_samples 8 \
        --max_workers 32
"""

import os
import sys
import json
import argparse
from typing import List, Dict, Tuple, Optional
from tqdm import tqdm
from dataclasses import dataclass
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# Extend sys.path with directories derived from this file's own location so the
# project-local imports below resolve independently of the working directory.
# Every entry is inserted at position 0, so the effective search order is the
# reverse of the insertion order: Legacy/ first, then this file's directory,
# then its parent directory.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parent_dir)
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
sys.path.insert(0, os.path.join(current_dir, 'Legacy'))

from prompt_store import instruction_prompts
from hypothesis_composition_reasoning_trace import sample_one_MDP_for_one_paper_from_hypothesis_components
from common_utils import extract_hypothesis_from_response

# Heavy local-mode dependencies are bound lazily: the names exist at module
# level as None placeholders so API mode never pays for importing torch.
torch = DataLoader = AutoTokenizer = AutoModelForCausalLM = None


def _import_torch_deps():
    """Bind torch / transformers names into the module globals on first use.

    Subsequent calls are no-ops. The globals are only published after every
    import has succeeded, so a failed import leaves all placeholders as None.
    """
    global torch, DataLoader, AutoTokenizer, AutoModelForCausalLM
    if torch is not None:
        return

    import torch as torch_module
    from torch.utils.data import DataLoader as data_loader_cls
    from transformers import AutoTokenizer as tokenizer_cls, AutoModelForCausalLM as causal_lm_cls

    torch, DataLoader, AutoTokenizer, AutoModelForCausalLM = (
        torch_module,
        data_loader_cls,
        tokenizer_cls,
        causal_lm_cls,
    )


@dataclass
class GenerationSample:
    """One prompt to generate from, plus the metadata copied into each result row.

    Instances are created by HypothesisDataset; the three trailing fields are
    only populated when the dataset is loaded in bounded composition mode.
    """
    file_name: str  # Source JSON file the sample was built from
    step_idx: int  # Step position in the MDP road (normal mode) or the inspiration's 'idx' (bounded mode)
    prompt: str  # Fully assembled prompt text sent to the model
    gt_hypothesis: str  # Ground-truth delta hypothesis for this step (original note: "for later TC")
    prev_hypothesis: Optional[str]  # Previous-hypothesis text that was embedded in the prompt
    inspiration_title: str  # Title of the inspiration shown in the prompt
    inspiration_abstract: str  # Abstract of the inspiration shown in the prompt
    # Bounded composition fields (optional)
    tier: Optional[str] = None  # 'hard', 'medium', 'easy' for bounded mode
    bounded_similarity: Optional[float] = None  # 'similarity' value stored with the bounded selection
    gt_inspiration_title: Optional[str] = None  # Original (ground-truth) inspiration title


class HypothesisDataset:
    """Prompt dataset for batch hypothesis generation, built from per-paper JSON files.

    Implements ``__len__`` and ``__getitem__`` so it can be handed to a torch
    DataLoader (local mode) or simply indexed / iterated (API mode).

    Two loading modes are supported:
    1. Normal mode: uses the ground-truth inspirations found in ``sft_qa_data_dir``.
    2. Bounded mode: uses the per-tier bounded inspirations found in
       ``bounded_selections_dir``.
    """

    # Text placed in the "previous hypothesis" slot of the prompt for the first step.
    _NO_PREV_HYPOTHESIS = "No previous hypothesis."

    def __init__(
        self,
        sft_qa_data_dir: str,
        prompts: List[str],
        file_list: Optional[List[str]] = None,
        max_samples: Optional[int] = None,
        # Bounded composition mode
        use_bounded: bool = False,
        bounded_selections_dir: Optional[str] = None,
        bounded_tiers: Optional[List[str]] = None  # ['hard', 'medium', 'easy']
    ):
        """
        Args:
            sft_qa_data_dir: Directory containing SFT QA data (research question, background survey,
                             inspirations, hypothesis_components)
            prompts: List of prompt templates from instruction_prompts; indices 0-5 are
                     interleaved with the sample fields to form each prompt
            file_list: Optional list of specific files to process (default: every
                       ``*.json`` file in the relevant directory, sorted by name)
            max_samples: Optional limit on number of samples (``None`` or 0 means no limit)
            use_bounded: If True, use bounded inspirations instead of GT
            bounded_selections_dir: Directory containing bounded inspiration selections
                                    (required when ``use_bounded`` is True)
            bounded_tiers: Which tiers to include (default: all)

        Raises:
            ValueError: If ``use_bounded`` is True but ``bounded_selections_dir`` is None.
        """
        self.sft_qa_data_dir = sft_qa_data_dir
        self.prompts = prompts
        self.samples: List[GenerationSample] = []
        self.use_bounded = use_bounded

        if use_bounded:
            self._load_bounded_samples(
                bounded_selections_dir,
                bounded_tiers or ['hard', 'medium', 'easy'],
                file_list,
                max_samples
            )
        else:
            self._load_normal_samples(file_list, max_samples)

    @staticmethod
    def _read_json(path: str):
        """Parse the JSON file at ``path``, always decoding it as UTF-8."""
        # An explicit encoding keeps loading independent of the process locale.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _list_json_files(directory: str, file_list: Optional[List[str]]) -> List[str]:
        """Return ``file_list`` if given, else the sorted ``*.json`` names in ``directory``."""
        if file_list is not None:
            return file_list
        return sorted(f for f in os.listdir(directory) if f.endswith('.json'))

    def _build_prompt(
        self,
        research_question: str,
        background_survey: str,
        prev_hyp: str,
        title: str,
        abstract: str
    ) -> str:
        """Interleave the six prompt templates with the per-sample fields."""
        return (
            self.prompts[0] + research_question +
            self.prompts[1] + background_survey +
            self.prompts[2] + prev_hyp +
            self.prompts[3] + title +
            self.prompts[4] + abstract +
            self.prompts[5]
        )

    def _limit_reached(self, max_samples: Optional[int]) -> bool:
        """Return True once ``max_samples`` is set and that many samples are loaded."""
        return bool(max_samples) and len(self.samples) >= max_samples

    def _load_normal_samples(
        self,
        file_list: Optional[List[str]] = None,
        max_samples: Optional[int] = None
    ):
        """Load samples using ground truth inspirations (normal mode).

        Files that are missing, lack inspirations / hypothesis components, or
        fail MDP-road validation (AssertionError) are counted as skipped.
        """
        qa_files = self._list_json_files(self.sft_qa_data_dir, file_list)

        print(f"Loading samples from {len(qa_files)} files...")

        skipped_files = 0
        loaded_files = 0  # files actually turned into an MDP road (excludes unvisited files)
        for file_name in tqdm(qa_files, desc="Loading data"):
            sft_qa_path = os.path.join(self.sft_qa_data_dir, file_name)
            if not os.path.exists(sft_qa_path):
                skipped_files += 1
                continue

            sft_qa_data = self._read_json(sft_qa_path)

            # Extract required fields
            research_question = sft_qa_data.get("research_question", "")
            background_survey = sft_qa_data.get("background_survey", "")
            inspirations = sft_qa_data.get("inspiration", [])
            hypothesis_components = sft_qa_data.get("hypothesis_components", {})

            # Skip if missing required fields
            if not inspirations or not hypothesis_components:
                skipped_files += 1
                continue

            # Build MDP road using the shared function (v2 format: sequential order)
            try:
                mdp_road = sample_one_MDP_for_one_paper_from_hypothesis_components(
                    inspirations, hypothesis_components, file_name
                )
            except AssertionError as e:
                print(f"Skipping {file_name}: {e}")
                skipped_files += 1
                continue

            loaded_files += 1

            # Each road entry is (inspiration index, delta hypothesis for that step)
            for step_idx, (insp_id, delta_hyp) in enumerate(mdp_road):
                cur_insp = inspirations[insp_id]
                title = cur_insp.get("found_title", "")
                abstract = cur_insp.get("found_abstract", "")
                gt_hypothesis = delta_hyp

                if not gt_hypothesis:
                    continue

                # Previous hypothesis is cumulative: all earlier deltas joined together
                if step_idx > 0:
                    prev_hyp = "\n\n".join(mdp_road[j][1] for j in range(step_idx))
                else:
                    prev_hyp = self._NO_PREV_HYPOTHESIS

                # Same prompt layout as training/evaluation
                prompt = self._build_prompt(
                    research_question, background_survey, prev_hyp, title, abstract
                )

                self.samples.append(GenerationSample(
                    file_name=file_name,
                    step_idx=step_idx,
                    prompt=prompt,
                    gt_hypothesis=gt_hypothesis,
                    prev_hypothesis=prev_hyp,
                    inspiration_title=title,
                    inspiration_abstract=abstract
                ))

                if self._limit_reached(max_samples):
                    break

            if self._limit_reached(max_samples):
                break

        print(f"Loaded {len(self.samples)} samples from {loaded_files} files")
        if skipped_files > 0:
            print(f"Skipped {skipped_files} files (missing data or validation errors)")

    def _load_bounded_samples(
        self,
        bounded_selections_dir: str,
        tiers: List[str],
        file_list: Optional[List[str]] = None,
        max_samples: Optional[int] = None
    ):
        """Load samples using bounded inspirations (bounded composition mode).

        Each selection file is read for "research_question", "background_survey"
        and "inspirations"; every inspiration contributes one sample per
        requested tier that has both a title and an abstract.

        Raises:
            ValueError: If ``bounded_selections_dir`` is None.
        """
        # os.listdir(None) would silently scan the current working directory.
        if bounded_selections_dir is None:
            raise ValueError("bounded_selections_dir is required when use_bounded=True")

        sel_files = self._list_json_files(bounded_selections_dir, file_list)

        print(f"[Bounded Mode] Loading samples from {len(sel_files)} files...")
        print(f"[Bounded Mode] Tiers: {tiers}")

        loaded_files = 0  # files that were read and had inspirations
        for file_name in tqdm(sel_files, desc="Loading bounded data"):
            sel_path = os.path.join(bounded_selections_dir, file_name)
            if not os.path.exists(sel_path):
                continue

            sel_data = self._read_json(sel_path)

            research_question = sel_data.get("research_question", "")
            background_survey = sel_data.get("background_survey", "")
            inspirations = sel_data.get("inspirations", [])

            if not inspirations:
                continue

            loaded_files += 1

            # Ground-truth deltas seen so far in this file; they form the previous hypothesis
            prev_deltas: List[str] = []

            for insp in inspirations:
                idx = insp.get('idx', 0)
                gt_hypothesis = insp.get('delta_hypothesis', '')
                gt_title = insp.get('gt_title', '')

                if not gt_hypothesis:
                    continue

                prev_hyp = "\n\n".join(prev_deltas) if prev_deltas else self._NO_PREV_HYPOTHESIS
                bounded_selections = insp.get('bounded_selections', {})

                # One sample per requested tier that has a usable selection
                for tier in tiers:
                    bounded = bounded_selections.get(tier)
                    if bounded is None:
                        continue

                    bounded_title = bounded.get('title', '')
                    bounded_abstract = bounded.get('abstract', '')
                    similarity = bounded.get('similarity', 0)

                    if not bounded_title or not bounded_abstract:
                        continue

                    # Prompt uses the bounded inspiration in place of the ground-truth one
                    prompt = self._build_prompt(
                        research_question, background_survey, prev_hyp,
                        bounded_title, bounded_abstract
                    )

                    self.samples.append(GenerationSample(
                        file_name=file_name,
                        step_idx=idx,
                        prompt=prompt,
                        gt_hypothesis=gt_hypothesis,
                        prev_hypothesis=prev_hyp,
                        inspiration_title=bounded_title,
                        inspiration_abstract=bounded_abstract,
                        tier=tier,
                        bounded_similarity=similarity,
                        gt_inspiration_title=gt_title
                    ))

                    if self._limit_reached(max_samples):
                        break

                # The next inspiration builds on the GT delta, not on generated text
                prev_deltas.append(gt_hypothesis)

                if self._limit_reached(max_samples):
                    break

            if self._limit_reached(max_samples):
                break

        print(f"[Bounded Mode] Loaded {len(self.samples)} samples from {loaded_files} files")

        # Stats by tier
        tier_counts = {}
        for s in self.samples:
            if s.tier:
                tier_counts[s.tier] = tier_counts.get(s.tier, 0) + 1
        print(f"[Bounded Mode] Samples by tier: {tier_counts}")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.samples[idx]


def collate_fn(batch: List[GenerationSample]) -> List[GenerationSample]:
    """Return the batch unchanged, as a plain list of GenerationSample objects.

    NOTE(review): presumably passed as ``collate_fn`` to a torch DataLoader so
    the samples are not merged into tensors -- confirm against the caller.
    """
    return batch


class BatchGenerator:
    """Efficient batch generator for hypothesis generation (local model mode).

    Loads a causal LM and its tokenizer with transformers and samples
    ``num_samples`` completions for every prompt in a batch.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        max_length: int = 16384,
        max_new_tokens: int = 4096,
        temperature: float = 0.6,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        num_samples: int = 1
    ):
        """
        Args:
            model_path: Path or identifier passed to ``from_pretrained``
            device: Device the tokenized inputs are moved to (the model itself
                    is placed by ``device_map="auto"``)
            max_length: Maximum prompt length in tokens; longer prompts are truncated
            max_new_tokens: Maximum number of tokens generated per response
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold
            repetition_penalty: Repetition penalty passed to ``generate``
            num_samples: Number of responses generated per prompt
        """
        # Import torch dependencies
        _import_torch_deps()

        self.device = device
        self.max_length = max_length
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.num_samples = num_samples

        print(f"Loading model from {model_path}...")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_fast=False
        )

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Decoder-only models continue from the last position of each row, so
        # batched prompts must be padded on the left; with right padding the
        # model would be asked to continue after pad tokens.
        self.tokenizer.padding_side = "left"

        # Load model with efficient settings
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        self.model.eval()

        print("Model loaded successfully")

    def format_prompt(self, prompt: str) -> str:
        """Wrap ``prompt`` as a single user turn and open the assistant turn.

        The chat template is rendered without its own generation prompt; the
        literal assistant tag is appended manually instead.
        """
        messages = [{"role": "user", "content": prompt}]
        formatted = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        formatted += "<｜Assistant｜>"
        return formatted

    def generate_batch(self, samples: List['GenerationSample']) -> List[Dict]:
        """Generate responses for a batch of samples.

        The whole batch is generated ``num_samples`` times; results are ordered
        by sampling round first, then by position in ``samples``.

        Args:
            samples: Samples whose prompts are generated from

        Returns:
            One result dict per (sample, sampling round); empty list for an empty batch.
        """
        if not samples:
            return []

        # Format all prompts
        formatted_prompts = [self.format_prompt(s.prompt) for s in samples]

        # Tokenize with padding
        inputs = self.tokenizer(
            formatted_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        ).to(self.device)

        # generate() returns prompt + continuation; every row shares the padded
        # prompt width, so the continuation starts at this column for all rows.
        prompt_width = inputs['input_ids'].shape[1]

        all_results = []

        for sample_idx in range(self.num_samples):
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    do_sample=True,
                    top_p=self.top_p,
                    repetition_penalty=self.repetition_penalty,
                    num_beams=1,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            for sample, output in zip(samples, outputs):
                response = self.tokenizer.decode(output[prompt_width:], skip_special_tokens=True)

                # Extract hypothesis and reasoning using shared utility
                generated_hypothesis, reasoning_trace = extract_hypothesis_from_response(response)

                result = {
                    'file_name': sample.file_name,
                    'step_idx': sample.step_idx,
                    'sample_idx': sample_idx,  # Which sampling round (0, 1, 2, ...)
                    'reasoning_trace': reasoning_trace or "",
                    'generated_hypothesis': generated_hypothesis or "",
                    'gt_hypothesis': sample.gt_hypothesis,
                    'inspiration_title': sample.inspiration_title,
                    'inspiration_abstract': sample.inspiration_abstract,
                    'raw_response': response,
                    'error': False
                }
                # Add bounded composition fields if present
                if sample.tier:
                    result['tier'] = sample.tier
                    result['bounded_similarity'] = sample.bounded_similarity
                    result['gt_inspiration_title'] = sample.gt_inspiration_title
                all_results.append(result)

        return all_results


class APIGenerator:
    """Hypothesis generator backed by an OpenAI-compatible chat endpoint (SGLang, vLLM, ...)."""

    def __init__(
        self,
        api_base_url: str,
        api_key: str = "EMPTY",
        model_name: str = "default",
        max_tokens: int = 4096,
        temperature: float = 0.6,
        top_p: float = 0.9,
        num_samples: int = 1,
        max_workers: int = 16
    ):
        """Create the API client and remember the sampling configuration.

        Args:
            api_base_url: Base URL of the endpoint
            api_key: API key handed to the client
            model_name: Model name sent with every request
            max_tokens: Maximum number of generated tokens per response
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold
            num_samples: Number of responses requested per prompt (``n``)
            max_workers: Size of the thread pool used per batch
        """
        from openai import OpenAI

        self.client = OpenAI(api_key=api_key, base_url=api_base_url)
        self.model_name, self.max_tokens = model_name, max_tokens
        self.temperature, self.top_p = temperature, top_p
        self.num_samples, self.max_workers = num_samples, max_workers

        print(f"API Generator initialized:")
        print(f"  Base URL: {api_base_url}")
        print(f"  Model: {model_name}")
        print(f"  Max tokens: {max_tokens}")
        print(f"  Temperature: {temperature}")
        print(f"  Num samples per data point: {num_samples}")
        print(f"  Max concurrent workers: {max_workers}")

    @staticmethod
    def _make_result(
        sample_info: dict,
        sample_idx: int,
        reasoning_trace: str,
        generated_hypothesis: str,
        raw_response: str,
        error: bool
    ) -> Dict:
        """Assemble one output row from the prompt metadata and one response."""
        row = {
            'file_name': sample_info['file_name'],
            'step_idx': sample_info['step_idx'],
            'sample_idx': sample_idx,
            'reasoning_trace': reasoning_trace,
            'generated_hypothesis': generated_hypothesis,
            'gt_hypothesis': sample_info['gt_hypothesis'],
            'inspiration_title': sample_info['inspiration_title'],
            'inspiration_abstract': sample_info.get('inspiration_abstract', ''),
            'raw_response': raw_response,
            'error': error
        }
        # Bounded composition fields are only attached when a tier is set
        if sample_info.get('tier'):
            row['tier'] = sample_info['tier']
            row['bounded_similarity'] = sample_info.get('bounded_similarity')
            row['gt_inspiration_title'] = sample_info.get('gt_inspiration_title')
        return row

    def _generate_single_with_n(self, prompt: str, sample_info: dict) -> List[Dict]:
        """Request ``num_samples`` completions for one prompt in a single API call.

        Any exception raised while calling the API or processing its choices is
        recorded as ``num_samples`` error rows rather than propagated.
        """
        collected = []
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                n=self.num_samples,  # the server produces all n samples for this prompt
            )

            for choice_idx, choice in enumerate(completion.choices):
                text = choice.message.content or ""

                # Split the raw text into hypothesis and reasoning via the shared utility
                hypothesis, trace = extract_hypothesis_from_response(text)

                collected.append(self._make_result(
                    sample_info,
                    choice_idx,
                    reasoning_trace=trace or "",
                    generated_hypothesis=hypothesis or "",
                    raw_response=text,
                    error=False
                ))

        except Exception as e:
            # One error row per expected sample keeps downstream counts aligned
            failure_text = f'ERROR: {str(e)}'
            collected.extend(
                self._make_result(
                    sample_info,
                    missing_idx,
                    reasoning_trace='',
                    generated_hypothesis='',
                    raw_response=failure_text,
                    error=True
                )
                for missing_idx in range(self.num_samples)
            )

        return collected

    def generate_batch(self, samples: List['GenerationSample']) -> List[Dict]:
        """Generate responses for a batch of samples using concurrent API calls.

        Each prompt becomes one request (with ``n=num_samples``); requests run
        on a thread pool and their rows are gathered in completion order.
        """
        # Pair every prompt with the metadata copied into its result rows
        tasks = [
            (
                sample.prompt,
                {
                    'file_name': sample.file_name,
                    'step_idx': sample.step_idx,
                    'gt_hypothesis': sample.gt_hypothesis,
                    'inspiration_title': sample.inspiration_title,
                    'inspiration_abstract': sample.inspiration_abstract,
                    # Bounded composition fields (may be None for normal mode)
                    'tier': sample.tier,
                    'bounded_similarity': sample.bounded_similarity,
                    'gt_inspiration_title': sample.gt_inspiration_title
                },
            )
            for sample in samples
        ]

        all_results = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            pending = [
                pool.submit(self._generate_single_with_n, prompt, info)
                for prompt, info in tasks
            ]
            for finished in as_completed(pending):
                all_results.extend(finished.result())

        return all_results


def save_results(results: List[Dict], output_path: str, mode: str = 'a'):
    """Write ``results`` to ``output_path`` as JSON Lines (one object per line).

    Args:
        results: JSON-serializable dicts to write
        output_path: Destination file
        mode: Text mode passed to ``open``; the default ``'a'`` appends

    The file is always written as UTF-8: non-ASCII characters are emitted
    verbatim (``ensure_ascii=False``), so relying on the locale's default
    encoding could fail or produce mis-encoded output.
    """
    with open(output_path, mode, encoding='utf-8') as f:
        f.writelines(
            json.dumps(result, ensure_ascii=False) + '\n'
            for result in results
        )


def main():
    """Command-line entry point: sample hypotheses for every dataset item.

    Steps:
      1. Parse CLI arguments and create the output directory.
      2. If ``--resume_from`` is given, collect the (file_name, step_idx, tier)
         keys already present in that JSONL file and, when it is a different
         file from the output file, copy it to the output location. Without
         ``--resume_from`` an existing output file is never overwritten.
      3. Build the HypothesisDataset (optionally restricted by ``--file_list``
         and/or bounded-composition settings) and drop already processed items.
      4. Generate batch by batch with either APIGenerator or the local
         BatchGenerator, appending results to ``generations.jsonl`` every
         ``--save_every`` batches and once more at the end.

    Raises:
        FileNotFoundError: If ``--resume_from`` or ``--file_list`` names a file
            that does not exist.
        FileExistsError: If the output file already exists and
            ``--resume_from`` was not given.
    """
    parser = argparse.ArgumentParser(description='Batch generation for hypothesis composition')
    
    # Mode selection (0 = local model, 1 = API)
    parser.add_argument("--if_use_api", type=int, default=1, choices=[0, 1],
                       help="Use API mode (1) or local model (0)")
    
    # Local model configuration
    parser.add_argument("--model_path", type=str, 
                       default="/pfs/training-data/hf/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
                       help="Path to local model (ignored in API mode)")
    
    # API configuration
    parser.add_argument("--api_base_url", type=str, default="http://localhost:1234/v1",
                       help="Base URL for API endpoint (e.g., SGLang, vLLM)")
    parser.add_argument("--api_key", type=str, default="EMPTY",
                       help="API key (use 'EMPTY' for local deployments)")
    parser.add_argument("--api_model_name", type=str, default="default",
                       help="Model name for API calls")
    parser.add_argument("--max_workers", type=int, default=128,
                       help="Max concurrent workers for API calls (tested: 128 workers = 265 gen/s)")
    
    # Data paths
    parser.add_argument("--sft_qa_data_dir", type=str,
                       default="<YOUR_SFT_QA_DATA_DIR>",
                       help="Directory containing SFT QA data (train set)")
    
    # Output configuration
    parser.add_argument("--output_dir", type=str,
                       default="<YOUR_HC_GENERATION_DIR>",
                       help="Directory to save generation results")
    
    # Generation parameters
    parser.add_argument("--batch_size", type=int, default=128,
                       help="Batch size for generation (recommend same as max_workers for API mode)")
    parser.add_argument("--max_length", type=int, default=16384,
                       help="Maximum input sequence length (local mode)")
    parser.add_argument("--max_new_tokens", type=int, default=8192,
                       help="Maximum new tokens to generate (8192 for long reasoning)")
    parser.add_argument("--temperature", type=float, default=0.7,
                       help="Generation temperature (0.7 recommended for diversity in rejection sampling)")
    parser.add_argument("--top_p", type=float, default=0.9,
                       help="Top-p sampling parameter")
    parser.add_argument("--repetition_penalty", type=float, default=1.2,
                       help="Repetition penalty (local mode only)")
    
    # Processing options
    parser.add_argument("--max_samples", type=int, default=None,
                       help="Maximum number of samples to process (for testing)")
    parser.add_argument("--save_every", type=int, default=50,
                       help="Save results every N batches")
    parser.add_argument("--resume_from", type=str, default=None,
                       help="Path to existing output file to resume from (skip already processed samples)")
    parser.add_argument("--gpu_id", type=int, default=None,
                       help="Specific GPU to use (for local multi-GPU parallel runs)")
    parser.add_argument("--file_list", type=str, default=None,
                       help="JSON file containing list of specific files to process (for distributed runs)")
    parser.add_argument("--num_samples", type=int, default=8,
                       help="Number of samples to generate per data point (8 for rejection sampling)")
    
    # Bounded composition mode
    parser.add_argument("--use_bounded", type=int, default=0, choices=[0, 1],
                       help="Use bounded inspirations instead of GT (0=normal, 1=bounded)")
    parser.add_argument("--bounded_selections_dir", type=str,
                       default="<YOUR_BOUNDED_INSP_DIR>/selections",
                       help="Directory containing bounded inspiration selections (when use_bounded=1)")
    parser.add_argument("--bounded_tiers", type=str, default="hard,medium,easy",
                       help="Comma-separated tiers to include for bounded mode (hard,medium,easy)")
    
    args = parser.parse_args()
    
    # Set GPU (only for local mode)
    if not args.if_use_api and args.gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Output file path
    output_file = os.path.join(args.output_dir, "generations.jsonl")
    
    # Load prompts
    prompts = instruction_prompts("prepare_HC_sft_data_to_go_comprehensive_v2_delta")
    
    # Print mode info
    if args.if_use_api:
        print(f"\n{'='*60}")
        print("MODE: API (calling external service)")
        print(f"{'='*60}")
        print(f"API URL: {args.api_base_url}")
        print(f"Model: {args.api_model_name}")
    else:
        print(f"\n{'='*60}")
        print("MODE: Local Model")
        print(f"{'='*60}")
        print(f"Model: {args.model_path}")
    
    # Check for resume
    processed_keys = set()
    is_resuming = False
    
    if args.resume_from:
        # Explicit resume from specified file
        if not os.path.exists(args.resume_from):
            raise FileNotFoundError(f"Resume file not found: {args.resume_from}")
        
        print(f"Resuming from: {args.resume_from}")
        with open(args.resume_from, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    data = json.loads(line)
                    # For bounded mode, include tier in key to allow different tiers;
                    # a missing/empty tier is normalized to None.
                    key = (data['file_name'], data['step_idx'], data.get('tier') or None)
                    processed_keys.add(key)
        print(f"Found {len(processed_keys)} already processed samples")
        
        # Copy to output file if different. Compare resolved paths rather than
        # raw strings: two spellings of the same file (e.g. "./out/x" vs
        # "out/x") would otherwise make shutil.copy raise SameFileError.
        if os.path.realpath(args.resume_from) != os.path.realpath(output_file):
            import shutil
            shutil.copy(args.resume_from, output_file)
            print(f"Copied existing results to {output_file}")
        is_resuming = True
        
    elif os.path.exists(output_file):
        # Output file exists but no resume specified - raise error to prevent accidental overwrite
        raise FileExistsError(
            f"Output file already exists: {output_file}\n"
            f"To resume from it, use: --resume_from {output_file}\n"
            f"To overwrite, delete the file first or use a different --output_dir"
        )
    
    # Load file list if specified (for distributed runs)
    file_list = None
    if args.file_list:
        # Fail loudly on a missing list: silently ignoring it would make this
        # worker process the whole dataset instead of its own shard.
        if not os.path.exists(args.file_list):
            raise FileNotFoundError(f"File list not found: {args.file_list}")
        with open(args.file_list, 'r', encoding='utf-8') as f:
            file_list = json.load(f)
        print(f"Using file list with {len(file_list)} files from {args.file_list}")
    
    # Load dataset
    print("Loading dataset...")
    use_bounded = bool(args.use_bounded)
    bounded_tiers = [t.strip() for t in args.bounded_tiers.split(',')] if use_bounded else None
    
    if use_bounded:
        print(f"\n{'='*60}")
        print("BOUNDED COMPOSITION MODE")
        print(f"{'='*60}")
        print(f"Selections dir: {args.bounded_selections_dir}")
        print(f"Tiers: {bounded_tiers}")
    
    dataset = HypothesisDataset(
        sft_qa_data_dir=args.sft_qa_data_dir,
        prompts=prompts,
        file_list=file_list,
        max_samples=args.max_samples,
        use_bounded=use_bounded,
        bounded_selections_dir=args.bounded_selections_dir if use_bounded else None,
        bounded_tiers=bounded_tiers
    )
    
    # Filter out already processed samples if resuming
    if processed_keys:
        original_len = len(dataset.samples)
        # Key includes tier for bounded mode
        dataset.samples = [s for s in dataset.samples 
                          if (s.file_name, s.step_idx, s.tier) not in processed_keys]
        print(f"Filtered to {len(dataset.samples)} remaining samples (was {original_len})")
    
    if len(dataset) == 0:
        print("No samples to process!")
        return
    
    # Initialize generator and create dataloader based on mode
    if args.if_use_api:
        # API mode: use simple batch iterator
        generator = APIGenerator(
            api_base_url=args.api_base_url,
            api_key=args.api_key,
            model_name=args.api_model_name,
            max_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            num_samples=args.num_samples,
            max_workers=args.max_workers
        )
        
        # Simple batching for API mode (no torch dependency)
        dataloader = [
            dataset.samples[i:i + args.batch_size]
            for i in range(0, len(dataset.samples), args.batch_size)
        ]
    else:
        # Local mode: use torch DataLoader
        _import_torch_deps()
        from torch.utils.data import DataLoader as TorchDataLoader
        
        generator = BatchGenerator(
            model_path=args.model_path,
            max_length=args.max_length,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            num_samples=args.num_samples
        )
        
        dataloader = TorchDataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=collate_fn
        )
    
    # Generate
    print(f"\nStarting generation for {len(dataset)} samples...")
    print(f"Batch size: {args.batch_size}")
    print(f"Output file: {output_file}")
    
    all_results = []
    total_time = 0
    first_save = True  # Track first save to decide write mode
    
    pbar = tqdm(dataloader, desc="Generating")
    for batch_idx, batch in enumerate(pbar):
        start_time = time.time()
        
        # Only the generation call is guarded: a failure in the bookkeeping
        # below must not add error records for a batch whose real results have
        # already been stored.
        try:
            results = generator.generate_batch(batch)
        except Exception as e:
            print(f"\nError in batch {batch_idx}: {e}")
            # Save individual samples that failed
            for sample in batch:
                result = {
                    'file_name': sample.file_name,
                    'step_idx': sample.step_idx,
                    'sample_idx': 0,
                    'reasoning_trace': '',
                    'generated_hypothesis': '',
                    'gt_hypothesis': sample.gt_hypothesis,
                    'inspiration_title': sample.inspiration_title,
                    'inspiration_abstract': sample.inspiration_abstract,
                    'raw_response': f'ERROR: {str(e)}',
                    'error': True
                }
                # Add bounded composition fields if present
                if sample.tier:
                    result['tier'] = sample.tier
                    result['bounded_similarity'] = sample.bounded_similarity
                    result['gt_inspiration_title'] = sample.gt_inspiration_title
                all_results.append(result)
        else:
            all_results.extend(results)
            
            batch_time = time.time() - start_time
            total_time += batch_time
            
            # Update progress bar (guard against a zero-length interval from
            # a coarse clock)
            samples_per_sec = len(batch) / batch_time if batch_time > 0 else float('inf')
            pbar.set_postfix({
                'samples/s': f'{samples_per_sec:.2f}',
                'total': len(all_results)
            })
        
        # Periodic save
        if (batch_idx + 1) % args.save_every == 0:
            # First save: use 'w' if not resuming, 'a' if resuming
            save_mode = 'a' if (is_resuming or not first_save) else 'w'
            save_results(all_results, output_file, mode=save_mode)
            print(f"\nSaved {len(all_results)} results to {output_file}")
            all_results = []
            first_save = False
    
    # Final save
    if all_results:
        save_mode = 'a' if (is_resuming or not first_save) else 'w'
        save_results(all_results, output_file, mode=save_mode)
        print(f"\nSaved final {len(all_results)} results")
    
    # Print summary
    print("\n" + "="*60)
    print("GENERATION COMPLETE")
    print("="*60)
    print(f"Total samples processed: {len(dataset)}")
    print(f"Total time: {total_time:.2f}s")
    print(f"Average time per sample: {total_time/len(dataset):.3f}s")
    print(f"Output saved to: {output_file}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

