#!/usr/bin/env python3
"""
Ray-integrated VLLM Inference with Multiple Sampling

This script integrates Ray with VLLM to generate multiple samples at different temperatures
for a given dataset.
"""

import os
import glob
import re
import sys
import argparse
import torch
import shelve
import time
import uuid
import json
from typing import List, Dict, Any, Optional
from datasets import Dataset
from transformers import AutoTokenizer

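# Select the legacy vLLM V0 engine. This is set before the `vllm` import below
# so the setting is already in place when the library is loaded.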
os.environ['VLLM_USE_V1'] = '0'

from vllm import SamplingParams, EngineArgs, LLMEngine, RequestOutput
import ray
from tqdm import tqdm

# dtype passed to vLLM's EngineArgs when a worker builds its engine.
DTYPE = torch.bfloat16

def parse_args():
    """Declare the script's command-line options and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option below.
    """
    # (flag, add_argument keyword arguments), listed in --help order.
    option_table = [
        ("--model_name", dict(type=str, required=True, help="Model to use")),
        ("--tokenizer_name", dict(type=str, default=None, help="Tokenizer to use")),
        ("--output_db", dict(type=str, required=True, help="Output db name")),
        ("--dataset_name", dict(type=str, required=True, help="Dataset name")),
        ("--dataset_idx", dict(type=int, default=None, help="Dataset idx to subset to")),
        ("--field_name", dict(type=str, required=True, help="Problem field name")),
        ("--temperatures", dict(type=str, required=True, help="Comma-separated list of temperatures")),
        ("--samples_per_temp", dict(type=int, default=1, help="Number of samples to generate per temperature")),
        ("--max_tokens", dict(type=int, default=256, help="Max tokens")),
        ("--system_prompt", dict(type=str, default="You are a helpful and harmless assistant. You should think step-by-step.")),
        ("--batch_size", dict(type=int, default=10, help="Number of requests to process per batch")),
        ("--tensor_parallel_size", dict(type=int, default=1, help="vLLM flag, default is 1 GPU per worker, so model has to fit")),
        ("--num_gpus_per_worker", dict(type=float, default=1, help="GPUs per worker")),
        ("--num_workers", dict(type=int, default=None, help="Number of workers (auto-determined if not specified)")),
        ("--ray_address", dict(type=str, default=None, help="Ray cluster address (None starts a local cluster)")),
        ("--local", dict(action="store_true", help="Force local Ray initialization (ignores ray_address)")),
        ("--seed", dict(type=int, default=42, help="Base seed for sampling (each sample will have a different seed)")),
    ]

    cli = argparse.ArgumentParser(description="Run VLLM inference with multiple sampling via Ray")
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

@ray.remote(num_gpus=1)
class SamplingWorker:
    """Ray actor that runs VLLM inference with multiple sampling configurations.

    Each actor reserves one GPU (``num_gpus=1`` on the decorator), builds its
    own tokenizer and ``LLMEngine``, and stores the generations it produces in
    a per-worker ``shelve`` database.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str],
        system_prompt: str,
        max_tokens: int,
        tensor_parallel_size: int,
        output_db_path: str,
        worker_id: int = 0
    ):
        """Load the tokenizer and engine and create this worker's database.

        Args:
            model_name: Model handed to vLLM's ``EngineArgs``.
            tokenizer_name: Tokenizer to load; falls back to ``model_name``
                when empty or ``None``.
            system_prompt: System message prepended to every chat; an empty
                string disables it.
            max_tokens: Generation limit per request; also used as the
                engine's ``max_model_len``.
            tensor_parallel_size: Forwarded to ``EngineArgs``.
            output_db_path: Base path; the worker database is created at
                ``<output_db_path>_worker<worker_id>``.
            worker_id: Index used in log messages and database names.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.worker_id = worker_id
        self.output_db_path = output_db_path

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name or self.model_name,
            trust_remote_code=True,
            padding_side="left",
            truncation_side="left"
        )

        # Set padding token if needed
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # NOTE(review): max_model_len is given the same value as the per-request
        # max_tokens. If vLLM counts prompt tokens against max_model_len, long
        # prompts leave little or no room for generation -- confirm this is intended.
        engine_args = EngineArgs(
            model=self.model_name,
            dtype=DTYPE,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=self.max_tokens,
            gpu_memory_utilization=0.95,
            trust_remote_code=True,
            enable_chunked_prefill=True
        )

        # Initialize the LLMEngine
        print(f"Worker {worker_id}: Initializing LLMEngine...")
        self.engine = LLMEngine.from_engine_args(engine_args)

        # Create output DB for this worker
        print(f"Worker {worker_id}: Opening DB...")
        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        db_dir = os.path.dirname(self.output_db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        # flag="n" always starts from a new, empty database.
        with shelve.open(f"{self.output_db_path}_worker{self.worker_id}", flag="n"):
            pass

        print(f"Worker {self.worker_id} initialized with model {model_name}")

    def format_prompts(self, prompts: List[str]) -> List[str]:
        """Format prompts with system prompt and chat template.

        Args:
            prompts: Raw user messages, one per chat.

        Returns:
            The chat-template output for each prompt (untokenized, with the
            generation prompt appended).
        """
        if len(self.system_prompt) > 0:
            chat_init = [{'content': self.system_prompt, 'role': 'system'}]
        else:
            chat_init = []

        chats = [chat_init + [{'content': prompt, 'role': 'user'}] for prompt in prompts]
        formatted_prompts = self.tokenizer.apply_chat_template(
            chats,
            tokenize=False,
            add_generation_prompt=True,
            padding=False
        )

        return formatted_prompts

    def process_sampling_batch(
        self,
        sampling_batch: List[Dict[str, Any]],
        prompts: List[str],
        output_db_path: str,
        batch_size: int = 10,
    ) -> Dict[str, Any]:
        """
        Process a batch of sampling configurations.

        Every configuration is run against every prompt. Results are written
        to ``<output_db_path>_worker<worker_id>`` and also returned.

        Args:
            sampling_batch: List of dicts with sampling info (``temperature``,
                ``sample_idx`` and optionally ``seed``, default 42)
            prompts: List of prompts to use for generation
            output_db_path: Path to save results
            batch_size: Number of sampling configurations submitted to the
                engine at a time

        Returns:
            Dictionary mapping ``temp_<temperature>_sample_<sample_idx>`` to a
            list of generated texts, one slot per prompt (``None`` for a
            prompt whose request never reported as finished)
        """
        start_time = time.time()
        print(f"Worker {self.worker_id} starting batch with {len(sampling_batch)} sampling configurations")

        # Process sampling configurations in batches
        config_batches = [sampling_batch[i:i + batch_size] for i in range(0, len(sampling_batch), batch_size)]

        results = {}

        # Open local database for results
        with shelve.open(f"{output_db_path}_worker{self.worker_id}") as db:
            for batch_idx, batch in enumerate(config_batches):
                batch_start_time = time.time()
                print(f"Worker {self.worker_id} processing batch {batch_idx+1}/{len(config_batches)} with {len(batch)} configs")

                # Drain any requests still queued in the engine so their
                # outputs cannot mix with this batch.
                while self.engine.has_unfinished_requests():
                    self.engine.step()

                # request id -> (result key, prompt index) for result tracking
                request_slots = {}

                # Add each sampling configuration request to the engine
                for config in batch:
                    temperature = config["temperature"]
                    sample_idx = config["sample_idx"]
                    seed = config.get("seed", 42)

                    sampling_params = SamplingParams(
                        temperature=temperature,
                        top_p=.95,
                        max_tokens=self.max_tokens,
                        seed=seed + sample_idx  # Offset by sample index so repeated samples differ
                    )
                    result_key = f"temp_{temperature}_sample_{sample_idx}"

                    # Process each prompt with this sampling configuration
                    for prompt_idx, prompt in enumerate(prompts):
                        # The random suffix keeps ids unique across calls.
                        request_id = f"temp{temperature}_sample{sample_idx}_prompt{prompt_idx}_{uuid.uuid4().hex[:6]}"
                        self.engine.add_request(
                            request_id,
                            prompt,
                            sampling_params
                        )
                        request_slots[request_id] = (result_key, prompt_idx)

                # Step the engine until every request of this batch is done
                batch_results = {}
                while self.engine.has_unfinished_requests():
                    for output in self.engine.step():
                        if not output.finished:
                            continue
                        slot = request_slots.get(output.request_id)
                        if slot is None:
                            continue
                        result_key, prompt_idx = slot

                        # Initialize the per-prompt list on first use
                        if result_key not in batch_results:
                            batch_results[result_key] = [None] * len(prompts)

                        # Store result at the correct prompt index
                        batch_results[result_key][prompt_idx] = output.outputs[0].text

                # Update overall results with this batch's results
                for key, outputs in batch_results.items():
                    results[key] = outputs
                    db[key] = outputs

                batch_time = time.time() - batch_start_time
                print(f"Worker {self.worker_id} completed batch {batch_idx+1} in {batch_time:.2f} seconds")

        total_time = time.time() - start_time
        print(f"Worker {self.worker_id} completed all {len(sampling_batch)} sampling configurations in {total_time:.2f} seconds")

        return results

def run_inference_with_ray(config):
    """Main function to run distributed inference with Ray.

    For every row of the dataset, each (temperature, sample index) pair is
    generated by a pool of ``SamplingWorker`` actors, and the results for
    that row are written to the shelve database ``<output_db>_obs<row>``.

    Args:
        config: Dict whose ``"args"`` entry is the parsed command-line
            namespace (see ``parse_args``).

    Raises:
        RuntimeError: If there are observations to process but the worker
            count resolves to less than one.
    """
    args = config["args"]

    # Parse temperatures
    temperatures = [float(temp) for temp in args.temperatures.split(",")]
    samples_per_temp = args.samples_per_temp
    print(f"Using temperatures: {temperatures} with {samples_per_temp} samples per temperature")

    # Initialize Ray if not already started
    if not ray.is_initialized():
        if args.ray_address == "auto" or args.ray_address is None:
            # Start a local Ray instance
            ray.init(include_dashboard=True)
            print("Started a new local Ray instance")
        else:
            # Connect to existing Ray cluster
            ray.init(address=args.ray_address)
            print(f"Connected to Ray cluster at {args.ray_address}")
    print(f"Ray cluster resources: {ray.cluster_resources()}")

    # Load dataset
    dataset = Dataset.load_from_disk(args.dataset_name)
    if args.dataset_idx is not None:
        dataset = dataset.select([args.dataset_idx])

    # Determine observation indices to process
    obs_indices = list(range(len(dataset)))

    # Determine number of workers if not specified
    num_workers = args.num_workers or min(
        int(ray.cluster_resources().get("GPU", 1)),
        len(obs_indices)
    )
    if obs_indices and num_workers < 1:
        # Without this, the failure would surface later as an opaque
        # IndexError / ZeroDivisionError.
        raise RuntimeError(
            f"Cannot run inference with {num_workers} workers; "
            "check --num_workers and the GPUs visible to Ray"
        )

    print(f"Using {num_workers} workers for {len(obs_indices)} observation(s)")

    # Create workers
    workers = [
        SamplingWorker.remote(
            model_name=args.model_name,
            tokenizer_name=args.tokenizer_name or args.model_name,
            system_prompt=args.system_prompt,
            max_tokens=args.max_tokens,
            tensor_parallel_size=args.tensor_parallel_size,
            worker_id=i,
            output_db_path=args.output_db
        )
        for i in range(num_workers)
    ]

    # The sampling configurations do not depend on the observation, so build
    # them (and their round-robin assignment to workers) once.
    sampling_configs = [
        {
            "temperature": temperature,
            "sample_idx": sample_idx,
            "seed": args.seed + (temp_idx * samples_per_temp)
        }
        for temp_idx, temperature in enumerate(temperatures)
        for sample_idx in range(samples_per_temp)
    ]
    worker_batches = [[] for _ in range(num_workers)]
    for i, sampling_config in enumerate(sampling_configs):
        worker_batches[i % num_workers].append(sampling_config)

    # File-name suffixes the dbm backends behind shelve may add to a path.
    # Several backends add none at all, so checking only ".db" misses them.
    shelve_suffixes = ("", ".db", ".dat", ".dir", ".pag")

    # Process each observation
    for obs_idx in obs_indices:
        print(f"\nProcessing observation {obs_idx}")

        # Read just this row rather than the whole column on every iteration
        prompts = [dataset[obs_idx][args.field_name]]

        # Format prompts once (same for all workers)
        formatted_prompts_ref = workers[0].format_prompts.remote(prompts)
        formatted_prompts = ray.get(formatted_prompts_ref)

        output_db_path = f"{args.output_db}_obs{obs_idx}"

        # Distribute batches to workers
        tasks = []
        for worker_idx, batch in enumerate(worker_batches):
            if not batch:  # Skip empty batches
                continue

            # Further split the batch based on batch_size
            for i in range(0, len(batch), args.batch_size):
                sub_batch = batch[i:i + args.batch_size]
                if sub_batch:  # Skip empty sub-batches
                    tasks.append(
                        workers[worker_idx].process_sampling_batch.remote(
                            sampling_batch=sub_batch,
                            prompts=formatted_prompts,
                            output_db_path=output_db_path,
                            batch_size=args.batch_size
                        )
                    )

        # Track progress and keep what each task returns
        obs_results = {}
        with tqdm(total=len(tasks)) as pbar:
            while tasks:
                # Use ray.wait to process results as they complete
                done_ids, tasks = ray.wait(tasks, num_returns=1)
                obs_results.update(ray.get(done_ids[0]))
                pbar.update(1)

        # Combine results from worker databases for this observation
        print(f"Combining results for observation {obs_idx} from {num_workers} workers...")

        with shelve.open(output_db_path, flag="n") as combined_db:
            for worker_id in range(num_workers):
                worker_db_path = f"{output_db_path}_worker{worker_id}"
                if any(os.path.exists(worker_db_path + suffix) for suffix in shelve_suffixes):
                    with shelve.open(worker_db_path) as worker_db:
                        for key, value in worker_db.items():
                            combined_db[key] = value
            # Worker databases live on whichever node ran the actor; the
            # values returned by the tasks are always available here, so
            # write them too.
            for key, value in obs_results.items():
                combined_db[key] = value
            print(f"Combined results saved to {output_db_path}")

    print("All inference tasks completed!")

def concatenate_shelve_dbs(path=".", output_db_name="results_db", base_pattern="results_db_obs"):
    """Merge per-observation shelve databases into one consolidated database.

    Databases named ``<base_pattern><N>`` (``N`` being digits only) inside
    ``path`` are read in increasing order of ``N``. For every key, the values
    from all databases are concatenated into one list, which is stored under
    the same key in ``<path>/<output_db_name>``.

    Args:
        path: Directory holding the input databases and receiving the output.
        output_db_name: File name (without extension) of the consolidated db.
        base_pattern: Name prefix of the per-observation databases.

    Read and write errors are reported on stdout and do not raise.
    """
    # Construct the full pattern with the provided path
    full_base_pattern = os.path.join(path, base_pattern)
    pattern_basename = os.path.basename(base_pattern)
    # Escaped so the prefix is matched literally, whatever characters it has.
    index_regex = re.compile(re.escape(pattern_basename) + r"([0-9]+)")

    # Find all shelve db files matching the exact pattern
    base_files = set()

    # A shelve db may live on disk with or without one of these extensions,
    # depending on the dbm backend that created it.
    for extension in ['', '.db', '.dat', '.bak', '.dir']:
        pattern = f"{glob.escape(full_base_pattern)}[0-9]*{extension}"
        for file in glob.glob(pattern):
            # Strip only the extension we globbed for. Splitting on the last
            # "." of the whole path would cut inside a directory name
            # (including the default path ".") for extension-less files.
            base_name = file[:len(file) - len(extension)]

            # Only add if it matches our exact pattern (base_pattern followed by digits only)
            if index_regex.fullmatch(os.path.basename(base_name)):
                base_files.add(base_name)

    # Sort the base files by their index number
    base_files = sorted(
        base_files,
        key=lambda name: int(index_regex.fullmatch(os.path.basename(name)).group(1)),
    )

    print(f"Found {len(base_files)} database files to process:")
    for file in base_files:
        print(f"  - {file}")

    # Create the output db in the specified path
    output_db_path = os.path.join(path, output_db_name)

    # First pass: collect all keys and their values from all dbs
    consolidated_data = {}

    for base_file in base_files:
        print(f"Reading {base_file}...")
        try:
            with shelve.open(base_file, flag='r') as db:
                keys = list(db.keys())
                print(f"Database {base_file} has {len(keys)} keys")
                for i, key in enumerate(keys):
                    # Each db[key] access unpickles the value, so read it once.
                    value = db[key]
                    if i < 2:  # Print details for first 2 keys for debugging
                        print(f"  Sample key: {key}, Value type: {type(value)}, Length: {len(value) if isinstance(value, list) else 'not a list'}")

                    merged = consolidated_data.setdefault(key, [])
                    if isinstance(value, list):
                        merged.extend(value)
                    else:
                        print(f"Warning: Value for key '{key}' in {base_file} is not a list. Type: {type(value)}")
                        # Handle non-list values by appending them as a single item
                        merged.append(value)
        except Exception as e:
            # Best effort: one unreadable database must not abort the merge.
            print(f"Error processing {base_file}: {str(e)}")

    # Now write the consolidated data to the output db
    print(f"Writing consolidated data to {output_db_path}...")
    try:
        with shelve.open(output_db_path, 'c') as results_db:
            for key, values in consolidated_data.items():
                results_db[key] = values

        # Verify the results
        with shelve.open(output_db_path, flag='r') as results_db:
            stored_keys = list(results_db.keys())
            print(f"Total keys in consolidated db: {len(stored_keys)}")
            for key in stored_keys[:3]:  # Show first 3 keys
                value = results_db[key]
                print(f"Key: {key}, Type: {type(value)}, Number of responses: {len(value) if isinstance(value, list) else 'not a list'}")
                if isinstance(value, list) and len(value) > 0:
                    print(f"  First response type: {type(value[0])}")
    except Exception as e:
        print(f"Error writing or verifying output database: {str(e)}")

    print(f"Consolidation complete. Saved to '{output_db_path}'")


def main():
    """Script entry point: parse options, run inference, consolidate results.

    Writes ``metadata.json`` (the parsed options) next to the output
    database, runs the Ray inference, then merges the per-observation
    databases ``<output_db>_obs<N>`` into a single database at
    ``--output_db``.
    """
    # Parse command-line arguments
    args = parse_args()

    # Set up config
    config = {
        "args": args
    }

    # Handle local mode option
    if args.local:
        print("Forcing local Ray initialization")
        ray.init(include_dashboard=True, ignore_reinit_error=True)

    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so fall back to the current directory.
    output_dir = os.path.dirname(args.output_db) or "."
    output_name = os.path.basename(args.output_db)

    # Save metadata
    metadata = vars(args)
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "metadata.json"), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Execute inference with Ray
    run_inference_with_ray(config)

    # Create one single db for all results. The per-observation databases are
    # named after --output_db, so the names are derived from it here rather
    # than relying on the "results_db" defaults.
    concatenate_shelve_dbs(
        path=output_dir,
        output_db_name=output_name,
        base_pattern=f"{output_name}_obs",
    )

    # Shutdown Ray when done
    if ray.is_initialized():
        ray.shutdown()
        print("Ray has been shut down")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
