import os
import json
import pickle
import torch
from vllm import LLM, SamplingParams
from tqdm import tqdm
import time
import argparse
from threading import Lock
from typing import Dict, List, Any, Union

# --- Run configuration ---------------------------------------------------
# main() overwrites NUM_RESPONSES and BATCH_SIZE_QUESTIONS from the CLI.

# Sampled completions generated for every question.
NUM_RESPONSES: int = 20
# Intended checkpoint interval, measured in questions.
SAVE_EVERY_N_QUESTIONS: int = 5
# Questions grouped into one generate_with_logprobs() call.
BATCH_SIZE_QUESTIONS: int = 2

# Held by the pickle-saving helpers so two saves never interleave.
save_lock = Lock()

# Prompt templates
#
# PROMPT_TEMPLATE: instruction prompt with a one-shot example (`1 + 1 =` ->
# `Answer(final_answer=2)`) asking the model to reply with a Python
# `ReasoningGraph(...)` literal. It is filled with str.format() using the
# {question} and {context} fields, so any literal brace added to the text
# must be doubled ({{ / }}).
# NOTE(review): "structure your respond" looks like a typo for "response",
# and the ```python fence opened before `class ReasoningNode` is never
# closed. Left as-is because editing the text changes the model input (and
# would mix two prompt versions inside a resumed checkpoint).
PROMPT_TEMPLATE = """Please structure your respond to the last INPUT_OBJECT based on the CONTEXT with OUTPUT_OBJECT according to OUTPUT_TYPE.

INSTRUCTIONS:
- Only output an object like `ReasoningGraph(...)`. No reasoning, explanation or thinking.
- Do NOT define or repeat any class or function.
- ONLY produce an OUTPUT_OBJECT that instantiates the OUTPUT_TYPE.
- The output must be valid Python using the given type names.
- Do NOT generate code, explanation, or helper variables.

INPUT_OBJECT:
  1 + 1 =

OUTPUT_TYPE:
  Answer

  ```python
  class Answer:
    final_answer: int
  ```

OUTPUT_OBJECT:
  ```python
  Answer(
    final_answer=2
  )
  ```

INPUT_OBJECT:
{question}

CONTEXT:
{context}

OUTPUT_TYPE:
ReasoningGraph

```python
class ReasoningNode:
  id: int
  description: str
  output: Union[int, float, str]
  depends_on: list[int]

class ReasoningGraph:
  nodes: list[ReasoningNode]
  final_answer: Union[int, float, str]

OUTPUT_OBJECT:
"""

# PROMPT_TEMPLATE_SCALE: bare question-only variant (no context, no output
# format instructions); filled with str.format() using {question}.
PROMPT_TEMPLATE_SCALE = """
{question}
"""

class BaseModel:
    """Shared interface for the vLLM-backed model wrappers in this script.

    The constructor only records the model name and prints a loading banner;
    each subclass builds its own engine and must override
    :meth:`generate_with_logprobs`.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        banner = (f"Loading model: {model_name}", "This may take a few minutes...")
        for line in banner:
            print(line)

    def generate_with_logprobs(
        self,
        prompts: Union[str, List[str]],
        temperature=0.8,
        max_tokens=1024,
        top_p=0.95,
    ) -> List[Dict[str, Any]]:
        """Sample one completion per prompt; concrete subclasses provide this."""
        raise NotImplementedError("Must be implemented by subclass")

class LlamaModel(BaseModel):
    """Llama 3.1 8B Instruct served through vLLM.

    Prompts are handed to vLLM as raw text; no chat template is applied here.
    """
    def __init__(self, model_name="meta-llama/Llama-3.1-8B-Instruct"):
        super().__init__(model_name)

        # vLLM configuration for Llama. Uncomment tensor_parallel_size to
        # shard the model across every visible GPU.
        self.llm = LLM(
            model=model_name,
            trust_remote_code=True,
            dtype="float16",
            # tensor_parallel_size=torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1,
            max_model_len=2048,
            gpu_memory_utilization=0.95,
            disable_log_stats=True,
        )

        print("Model loaded successfully!")

    @staticmethod
    def _sampled_logprobs(completion) -> List[float]:
        """Return the log-probability of each *sampled* token as a plain float.

        With ``SamplingParams(logprobs=k)`` vLLM returns, per generated
        position, a mapping ``token_id -> Logprob`` containing the top-k
        tokens *and* the chosen token, so the first entry is not necessarily
        the sampled one. Look the sampled token up by id, fall back to the
        first entry if it is absent, and use 0.0 for positions without a
        mapping.
        """
        per_position = getattr(completion, 'logprobs', None)
        if not per_position:
            return []
        raw_ids = getattr(completion, 'token_ids', None)
        token_ids = list(raw_ids) if raw_ids is not None else []

        values = []
        for position, candidates in enumerate(per_position):
            if not hasattr(candidates, 'get') or not candidates:
                values.append(0.0)
                continue
            chosen = None
            if position < len(token_ids):
                chosen = candidates.get(token_ids[position])
            if chosen is None:
                chosen = next(iter(candidates.values()))
            # vLLM wraps the value in a Logprob object; older versions stored a bare float.
            values.append(float(getattr(chosen, 'logprob', chosen)))
        return values

    def generate_with_logprobs(self, prompts, temperature=0.8, max_tokens=1024, top_p=0.95):
        """Sample one completion per prompt and collect per-token log-probabilities.

        Args:
            prompts: A single prompt string or a list of prompt strings.
            temperature: Sampling temperature forwarded to ``SamplingParams``.
            max_tokens: Maximum number of generated tokens per completion.
            top_p: Nucleus-sampling threshold (``top_k`` is fixed at 50).

        Returns:
            One dict per prompt, in input order, with keys ``"response"``
            (generated text), ``"token_ids"`` (list of generated token ids) and
            ``"logprobs"`` (list of floats: log-probability of each sampled
            token, 0.0 where vLLM returned no mapping).
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            top_k=50,
            max_tokens=max_tokens,
            stop=["<|eot_id|>", "<|end_of_text|>"],
            logprobs=1,
        )

        outputs = self.llm.generate(prompts, sampling_params)

        results = []
        for output in outputs:
            completion = output.outputs[0]  # n=1: one completion per request
            raw_ids = getattr(completion, 'token_ids', None)
            results.append({
                "response": completion.text,
                "token_ids": list(raw_ids) if raw_ids is not None else [],
                "logprobs": self._sampled_logprobs(completion),
            })

        return results

class DeepSeekModel(BaseModel):
    """DeepSeek-R1-Distill (Qwen 32B by default) served through vLLM.

    Prompts are handed to vLLM as raw text; no chat template is applied here.
    """
    def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"):
        super().__init__(model_name)

        # vLLM configuration for DeepSeek. Uncomment tensor_parallel_size to
        # shard the model across every visible GPU.
        self.llm = LLM(
            model=model_name,
            trust_remote_code=True,
            dtype="float16",
            # tensor_parallel_size=torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1,
            max_model_len=2048,
            gpu_memory_utilization=0.98,
            disable_log_stats=True,
        )

        print("Model loaded successfully!")

    @staticmethod
    def _sampled_logprobs(completion) -> List[float]:
        """Return the log-probability of each *sampled* token as a plain float.

        With ``SamplingParams(logprobs=k)`` vLLM returns, per generated
        position, a mapping ``token_id -> Logprob`` containing the top-k
        tokens *and* the chosen token, so the first entry is not necessarily
        the sampled one. Look the sampled token up by id, fall back to the
        first entry if it is absent, and use 0.0 for positions without a
        mapping.
        """
        per_position = getattr(completion, 'logprobs', None)
        if not per_position:
            return []
        raw_ids = getattr(completion, 'token_ids', None)
        token_ids = list(raw_ids) if raw_ids is not None else []

        values = []
        for position, candidates in enumerate(per_position):
            if not hasattr(candidates, 'get') or not candidates:
                values.append(0.0)
                continue
            chosen = None
            if position < len(token_ids):
                chosen = candidates.get(token_ids[position])
            if chosen is None:
                chosen = next(iter(candidates.values()))
            # vLLM wraps the value in a Logprob object; older versions stored a bare float.
            values.append(float(getattr(chosen, 'logprob', chosen)))
        return values

    def generate_with_logprobs(self, prompts, temperature=0.8, max_tokens=1024, top_p=0.95):
        """Sample one completion per prompt and collect per-token log-probabilities.

        Args:
            prompts: A single prompt string or a list of prompt strings.
            temperature: Sampling temperature forwarded to ``SamplingParams``.
            max_tokens: Maximum number of generated tokens per completion.
            top_p: Nucleus-sampling threshold (``top_k`` is fixed at 50).

        Returns:
            One dict per prompt, in input order, with keys ``"response"``
            (generated text), ``"token_ids"`` (list of generated token ids) and
            ``"logprobs"`` (list of floats: log-probability of each sampled
            token, 0.0 where vLLM returned no mapping).
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            top_k=50,
            max_tokens=max_tokens,
            stop=["</s>", "Human:", "User:"],
            logprobs=1,
        )

        outputs = self.llm.generate(prompts, sampling_params)

        results = []
        for output in outputs:
            completion = output.outputs[0]  # n=1: one completion per request
            raw_ids = getattr(completion, 'token_ids', None)
            results.append({
                "response": completion.text,
                "token_ids": list(raw_ids) if raw_ids is not None else [],
                "logprobs": self._sampled_logprobs(completion),
            })

        return results

class Phi4Model(BaseModel):
    """Microsoft Phi-4-reasoning served through vLLM.

    Each prompt is wrapped in Phi-4's ``<|im_start|>``/``<|im_sep|>`` chat
    format with a fixed system turn before generation.
    """
    # System turn prepended to every prompt.
    _SYSTEM_PROMPT = """<|im_start|>system<|im_sep|> You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions.<|im_end|>"""

    def __init__(self, model_name="microsoft/Phi-4-reasoning"):
        super().__init__(model_name)

        # vLLM configuration for Phi-4. Uncomment tensor_parallel_size to
        # shard the model across every visible GPU.
        self.llm = LLM(
            model=model_name,
            trust_remote_code=True,
            dtype="float16",
            # tensor_parallel_size=torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1,
            max_model_len=2048,
            gpu_memory_utilization=0.98,
            disable_log_stats=True,  # consistent with the other model wrappers
        )

        print("Model loaded successfully!")

    @staticmethod
    def _sampled_logprobs(completion) -> List[float]:
        """Return the log-probability of each *sampled* token as a plain float.

        With ``SamplingParams(logprobs=k)`` vLLM returns, per generated
        position, a mapping ``token_id -> Logprob`` containing the top-k
        tokens *and* the chosen token, so the first entry is not necessarily
        the sampled one. Look the sampled token up by id, fall back to the
        first entry if it is absent, and use 0.0 for positions without a
        mapping.
        """
        per_position = getattr(completion, 'logprobs', None)
        if not per_position:
            return []
        raw_ids = getattr(completion, 'token_ids', None)
        token_ids = list(raw_ids) if raw_ids is not None else []

        values = []
        for position, candidates in enumerate(per_position):
            if not hasattr(candidates, 'get') or not candidates:
                values.append(0.0)
                continue
            chosen = None
            if position < len(token_ids):
                chosen = candidates.get(token_ids[position])
            if chosen is None:
                chosen = next(iter(candidates.values()))
            # vLLM wraps the value in a Logprob object; older versions stored a bare float.
            values.append(float(getattr(chosen, 'logprob', chosen)))
        return values

    def generate_with_logprobs(self, prompts, temperature=0.8, max_tokens=2048, top_p=0.95):
        """Wrap prompts in the Phi-4 chat format, sample one completion each, and collect logprobs.

        Args:
            prompts: A single prompt string or a list of prompt strings.
            temperature: Sampling temperature forwarded to ``SamplingParams``.
            max_tokens: Maximum number of generated tokens per completion.
                NOTE(review): this default equals ``max_model_len`` (2048), and
                the prompt shares that window, so the full budget can never be
                used. Long reasoning traces will be cut short; confirm intended.
            top_p: Nucleus-sampling threshold (``top_k`` is fixed at 50).

        Returns:
            One dict per prompt, in input order, with keys ``"response"``
            (generated text), ``"token_ids"`` (list of generated token ids) and
            ``"logprobs"`` (list of floats: log-probability of each sampled
            token, 0.0 where vLLM returned no mapping).
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        system_prompt = self._SYSTEM_PROMPT
        formatted_prompts = [
            f"{system_prompt}\n<|im_start|>user<|im_sep|>{prompt}<|im_end|>\n<|im_start|>assistant<|im_sep|>"
            for prompt in prompts
        ]

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            top_k=50,
            max_tokens=max_tokens,
            stop=["<|im_end|>"],
            logprobs=1,
        )
        outputs = self.llm.generate(formatted_prompts, sampling_params)

        results = []
        for output in outputs:
            completion = output.outputs[0]  # n=1: one completion per request
            raw_ids = getattr(completion, 'token_ids', None)
            results.append({
                "response": completion.text,
                "token_ids": list(raw_ids) if raw_ids is not None else [],
                "logprobs": self._sampled_logprobs(completion),
            })

        return results

def check_gpu_availability():
    """Print CUDA availability and, when present, each device's name and total memory (GiB)."""
    has_cuda = torch.cuda.is_available()
    print(f"GPU Available: {has_cuda}")
    if not has_cuda:
        return

    n_devices = torch.cuda.device_count()
    print(f"Number of GPUs: {n_devices}")
    gib = 1024**3
    for dev in range(n_devices):
        print(f"GPU {dev}: {torch.cuda.get_device_name(dev)}")
        total_bytes = torch.cuda.get_device_properties(dev).total_memory
        print(f"  Memory: {total_bytes / gib:.2f} GB")

def save_results_sync(results: Dict, path: str):
    """Pickle ``results`` to ``path`` while holding the module-level ``save_lock``.

    The parent directory is created only when ``path`` has one:
    ``os.makedirs('')`` raises FileNotFoundError, which previously broke saving
    to a bare filename. The pickle is first written to ``<path>.tmp`` and then
    moved over ``path`` with ``os.replace`` (an atomic rename on POSIX). A crash
    mid-write therefore leaves the previous file intact instead of a truncated
    pickle.

    Raises:
        OSError: if the directory, temp file, or rename fails (the temp file is
            removed on a best-effort basis first).
    """
    directory = os.path.dirname(path)
    tmp_path = f"{path}.tmp"
    with save_lock:
        if directory:
            os.makedirs(directory, exist_ok=True)
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(results, f)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the partial temp file, then re-raise.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

def save_progress_tracker(processed_ids: set, path: str):
    """Pickle the set of processed dataset IDs to ``path`` (used for resuming).

    Holds the module-level ``save_lock`` while writing. The parent directory is
    created only when ``path`` has one (``os.makedirs('')`` raises). The data is
    written to ``<path>.tmp`` first and moved into place with ``os.replace``, so
    an interrupted write cannot leave a truncated tracker behind.

    Raises:
        OSError: if the directory, temp file, or rename fails (the temp file is
            removed on a best-effort basis first).
    """
    directory = os.path.dirname(path)
    tmp_path = f"{path}.tmp"
    with save_lock:
        if directory:
            os.makedirs(directory, exist_ok=True)
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(processed_ids, f)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the partial temp file, then re-raise.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

def load_progress_tracker(path: str) -> set:
    """Return the processed-ID set pickled at ``path``, or an empty set when no such file exists.

    The file is unpickled, so only point this at trackers written by this script.
    """
    if not os.path.exists(path):
        return set()
    with open(path, 'rb') as handle:
        tracker = pickle.load(handle)
    return tracker

def load_existing_results(model_type: str, output_dir: str = ".") -> tuple[Dict, set]:
    """Load the newest readable results checkpoint for ``model_type``.

    Scans ``<output_dir>/temp/`` for files named
    ``{model_type}_morehopqa_responses_checkpoint_<N>.pkl`` and loads the one
    with the largest ``N``. If it cannot be unpickled (for example a truncated
    write), older checkpoints are tried in descending order.

    Args:
        model_type: Model tag used in checkpoint file names.
        output_dir: Root directory that contains ``temp/``. The original source
            left this value blank (a syntax error); ``"."`` is a placeholder.
            NOTE(review): set it to the intended output root.

    Returns:
        ``(results, processed_ids)`` where ``processed_ids == set(results)``.
        IDs that appear only in the progress tracker are NOT returned, because
        no responses exist for them. Returning them would make the caller skip
        those questions and silently drop them from the output. A warning is
        printed instead.

    Note:
        Checkpoints are read with ``pickle``; only load files this script wrote.
    """
    temp_dir = os.path.join(output_dir, 'temp/')
    progress_file = os.path.join(temp_dir, f'{model_type}_processed_ids.pkl')
    tracked_ids = load_progress_tracker(progress_file)

    # Collect (checkpoint number, file name) for every numbered checkpoint.
    prefix = f'{model_type}_morehopqa_responses_checkpoint_'
    suffix = '.pkl'
    checkpoints = []
    if os.path.isdir(temp_dir):
        for name in os.listdir(temp_dir):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            try:
                checkpoints.append((int(name[len(prefix):-len(suffix)]), name))
            except ValueError:
                continue  # not a numbered checkpoint

    # Newest first; fall back to older snapshots if the newest is unreadable.
    for _, name in sorted(checkpoints, reverse=True):
        checkpoint_path = os.path.join(temp_dir, name)
        print(f"Loading checkpoint from: {checkpoint_path}")
        try:
            with open(checkpoint_path, 'rb') as f:
                results = pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as e:
            print(f"WARNING: unreadable checkpoint {checkpoint_path} ({e}); trying an older one")
            continue

        processed_ids = set(results.keys())
        print(f"Found existing results: {len(processed_ids)} questions completed")
        print(f"Processed IDs sample: {list(processed_ids)[:5]}")
        return results, processed_ids

    if tracked_ids:
        print(
            f"WARNING: progress tracker lists {len(tracked_ids)} processed IDs but no "
            "readable checkpoint was found; those questions will be regenerated."
        )
    return {}, set()

def _build_prompt(example: Dict, use_scale_prompt: bool) -> str:
    """Render the prompt text for one dataset record.

    Reads ``example['question']``, plus ``example['context']`` for the full
    prompt. A list-valued context is inlined via its Python ``str()`` repr.
    """
    if use_scale_prompt:
        # Question-only variant: no context and no output-format instructions.
        return PROMPT_TEMPLATE_SCALE.format(question=example['question'])
    context = example['context']
    context_str = str(context) if isinstance(context, list) else context
    return PROMPT_TEMPLATE.format(question=example['question'], context=context_str)


def _save_checkpoint(results: Dict, processed_ids: set, temp_dir: str,
                     model_type: str, total_processed: int) -> None:
    """Write a full results snapshot plus the processed-ID tracker into ``temp_dir``."""
    checkpoint_path = os.path.join(
        temp_dir,
        f'{model_type}_morehopqa_responses_checkpoint_{total_processed}.pkl'
    )
    save_results_sync(results, checkpoint_path)

    progress_file = os.path.join(temp_dir, f'{model_type}_processed_ids.pkl')
    save_progress_tracker(processed_ids, progress_file)

    print(f"\nSaved checkpoint: {total_processed} total questions processed")
    print(f"Checkpoint saved to: {checkpoint_path}")


def process_dataset(data: List[Dict], model: BaseModel, model_type: str, use_scale_prompt: bool = False):
    """Generate ``NUM_RESPONSES`` sampled responses for every not-yet-processed question.

    Resumes from whatever ``load_existing_results(model_type)`` returns, then
    sends the remaining questions to ``model`` ``BATCH_SIZE_QUESTIONS`` at a
    time. A checkpoint (results snapshot plus processed-ID tracker under
    ``./temp/``) is written once at least ``SAVE_EVERY_N_QUESTIONS`` questions
    have been generated since the last save, and after the final batch.

    Args:
        data: Dataset records. Each needs ``'_id'`` and ``'question'``, plus
            ``'context'`` when the full prompt is used.
        model: Object whose ``generate_with_logprobs(prompts)`` returns one dict
            per prompt with ``'response'``, ``'token_ids'`` and ``'logprobs'``.
        model_type: Short model tag used in checkpoint file names.
        use_scale_prompt: Use the question-only ``PROMPT_TEMPLATE_SCALE``
            instead of the full ``PROMPT_TEMPLATE`` with context.

    Returns:
        Dict mapping each dataset ``'_id'`` to a record with ``'id'``,
        ``'question'``, ``'context'``, ``'answer'`` and the parallel lists
        ``'response_ids'``, ``'responses'``, ``'token_ids'``, ``'logprobs'``.

    Notes:
        A batch that raises is reported and skipped. Its questions are not
        marked processed, so a later run retries them.
        NOTE(review): checkpoint names do not encode the prompt variant, so
        resuming with a different ``use_scale_prompt`` reuses the other
        variant's checkpoints.
    """
    import traceback

    # Load existing results if any
    existing_results, processed_ids = load_existing_results(model_type)
    results = existing_results.copy()

    remaining_data = [item for item in data if item['_id'] not in processed_ids]

    print(f"Total questions: {len(data)}")
    print(f"Already processed: {len(processed_ids)}")
    print(f"Remaining to process: {len(remaining_data)}")

    if not remaining_data:
        print("All questions have been processed!")
        return results

    # The original output root was left blank (a syntax error). "." is a
    # placeholder; it must be the same root load_existing_results() reads from.
    output_dir = "."
    temp_dir = os.path.join(output_dir, 'temp/')
    os.makedirs(temp_dir, exist_ok=True)

    total_batches = (len(remaining_data) + BATCH_SIZE_QUESTIONS - 1) // BATCH_SIZE_QUESTIONS
    total_processed = len(processed_ids)  # also numbers the checkpoint files
    unsaved_questions = 0  # generated since the last successful checkpoint

    for batch_idx in tqdm(range(total_batches), desc="Processing batches"):
        start_idx = batch_idx * BATCH_SIZE_QUESTIONS
        batch_data = remaining_data[start_idx:start_idx + BATCH_SIZE_QUESTIONS]

        # NUM_RESPONSES copies of each question's prompt; batch_info[i]
        # describes batch_prompts[i] as (data_id, example, response_idx).
        batch_prompts = []
        batch_info = []
        for example in batch_data:
            prompt = _build_prompt(example, use_scale_prompt)
            for resp_idx in range(NUM_RESPONSES):
                batch_prompts.append(prompt)
                batch_info.append((example['_id'], example, resp_idx))

        try:
            batch_outputs = model.generate_with_logprobs(batch_prompts)
            if len(batch_outputs) != len(batch_prompts):
                raise RuntimeError(
                    f"expected {len(batch_prompts)} outputs, got {len(batch_outputs)}"
                )
            # Pull every field out before touching `results`, so a malformed
            # output aborts the whole batch instead of leaving a question
            # half-filled yet marked processed.
            staged = [
                (info, output['response'], output['token_ids'], output['logprobs'])
                for info, output in zip(batch_info, batch_outputs)
            ]
        except Exception as e:
            # Top-level boundary of a long GPU job: report and keep going.
            print(f"\nError processing batch {batch_idx}: {e}")
            traceback.print_exc()
        else:
            for (data_id, example, resp_idx), response, token_ids, logprobs in staged:
                if data_id not in results:
                    results[data_id] = {
                        'id': data_id,
                        'question': example['question'],
                        # .get: the scale prompt does not require a context field.
                        'context': example.get('context'),
                        'answer': example.get('answer', 'N/A'),
                        'response_ids': [],
                        'responses': [],
                        'token_ids': [],
                        'logprobs': [],
                    }
                    processed_ids.add(data_id)
                entry = results[data_id]
                entry['responses'].append(response)
                entry['token_ids'].append(token_ids)
                entry['logprobs'].append(logprobs)
                entry['response_ids'].append(resp_idx)

            total_processed += len(batch_data)
            unsaved_questions += len(batch_data)

        # Checked outside the try so the final checkpoint is written even
        # when the last batch fails.
        is_last_batch = batch_idx + 1 == total_batches
        if unsaved_questions and (unsaved_questions >= SAVE_EVERY_N_QUESTIONS or is_last_batch):
            try:
                _save_checkpoint(results, processed_ids, temp_dir, model_type, total_processed)
            except OSError as e:
                # Keep the counter so the next batch retries the save.
                print(f"\nWARNING: failed to save checkpoint after batch {batch_idx}: {e}")
            else:
                unsaved_questions = 0

    return results

def main():
    """Command-line entry point: sample MoreHopQA responses with one vLLM model.

    Steps:
    - Parse CLI options, overriding the module-level ``NUM_RESPONSES`` and
      ``BATCH_SIZE_QUESTIONS``.
    - Load the dataset JSON, assigning positional ``_id`` values where missing.
    - Build the selected model wrapper and run ``process_dataset``.
    - Pickle the merged results to
      ``./{model_type}_morehopqa_responses{_scale|_full}.pkl``.
    """
    # Must come before any use of these names in this function; the parser
    # defaults below read them.
    global NUM_RESPONSES, BATCH_SIZE_QUESTIONS

    parser = argparse.ArgumentParser(description='Process MoreHopQA dataset with multiple models')
    parser.add_argument(
        '--model',
        type=str,
        choices=['llama', 'deepseek', 'phi4'],
        required=True,
        help='Choose which model to use: llama (Llama 3.1 8B), deepseek (DeepSeek R1), or phi4 (Phi-4 Reasoning)'
    )
    parser.add_argument(
        '--scale-prompt',
        action='store_true',
        help='Use scale prompt template (question only, no context)'
    )
    parser.add_argument(
        '--num-responses',
        type=int,
        default=NUM_RESPONSES,
        help=f'Number of responses per question (default: {NUM_RESPONSES})'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=BATCH_SIZE_QUESTIONS,
        help=f'Number of questions to process per batch (default: {BATCH_SIZE_QUESTIONS})'
    )
    # The previous default ('') could never be opened, so the path is required.
    parser.add_argument(
        '--dataset-path',
        type=str,
        required=True,
        help='Path to the MoreHopQA dataset JSON file'
    )

    args = parser.parse_args()

    # batch_size 0 would divide by zero when counting batches, and 0 responses
    # would record nothing; reject both with a proper usage error.
    if args.num_responses < 1:
        parser.error('--num-responses must be a positive integer')
    if args.batch_size < 1:
        parser.error('--batch-size must be a positive integer')

    # Update global configs based on arguments
    NUM_RESPONSES = args.num_responses
    BATCH_SIZE_QUESTIONS = args.batch_size

    check_gpu_availability()

    print(f"Loading MoreHopQA dataset from: {args.dataset_path}")
    with open(args.dataset_path, 'r', encoding='utf-8') as f:
        source_data = json.load(f)

    print(f"Loaded {len(source_data)} questions from MoreHopQA dataset")

    # Every record needs an '_id' (it keys results and resume state). Only the
    # records that lack one get a positional ID, so existing IDs are preserved.
    missing_id_indices = [idx for idx, item in enumerate(source_data) if '_id' not in item]
    if missing_id_indices:
        print("WARNING: Dataset doesn't have '_id' field. Using index as ID.")
        for idx in missing_id_indices:
            source_data[idx]['_id'] = f"morehopqa_{idx}"

    sample_ids = [item['_id'] for item in source_data[:5]]
    print(f"Sample IDs from dataset: {sample_ids}")

    # CLI choice -> (wrapper class, tag used in output/checkpoint file names).
    model_registry = {
        'llama': (LlamaModel, 'llama31_8b'),
        'deepseek': (DeepSeekModel, 'deepseek_r1'),
        'phi4': (Phi4Model, 'phi4'),
    }
    model_cls, model_type = model_registry[args.model]
    model = model_cls()

    print(f"Processing dataset with {NUM_RESPONSES} responses per question")
    print(f"Using {model_type} model with vLLM")
    print(f"Prompt type: {'Scale (question only)' if args.scale_prompt else 'Full (with context)'}")
    print(f"Batch size: {BATCH_SIZE_QUESTIONS} questions per batch")

    start_time = time.perf_counter()

    results = process_dataset(source_data, model, model_type, args.scale_prompt)

    # The previous value was '' and os.makedirs('') raises FileNotFoundError;
    # '.' (the working directory) keeps the same destination.
    output_dir = '.'
    os.makedirs(output_dir, exist_ok=True)

    prompt_suffix = '_scale' if args.scale_prompt else '_full'
    final_output_path = os.path.join(output_dir, f'{model_type}_morehopqa_responses{prompt_suffix}.pkl')
    save_results_sync(results, final_output_path)

    total_time = time.perf_counter() - start_time

    print(f"\nProcessing completed in {total_time:.2f} seconds")
    if results:
        # NOTE: the denominator includes questions restored from a checkpoint.
        print(f"Average time per question: {total_time/len(results):.2f} seconds")
    print(f"Saved all results to {final_output_path}")

    total_responses = sum(len(r['responses']) for r in results.values())
    print(f"Total responses generated: {total_responses}")
    print(f"Questions processed: {len(results)}")

    processed_ids = list(results.keys())[:10]
    print(f"Sample of processed IDs: {processed_ids}")

# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()