#!/usr/bin/env python3
"""
Ray-integrated VLLM Vector Parallel Inference with Dynamic Multi-LoRA Support

This script integrates Ray with vLLM to run steering-vector inference in parallel. A single
dummy LoRA adapter is written to disk once and its bias is overwritten in memory at inference
time, so requests for many different steering vectors can be batched into the same engine
step without creating one LoRA file per vector.
"""

'''
Example usage:
python in_memory_multilora.py  --model_name "Qwen/Qwen1.5-7B-Chat" --vecs_name "../runs/train/scratch/_V.pt"  --vector_info_json "../runs/train/scratch/_vector_info.json" --output_db "../runs/inference/scratch/results_db"   --lora_dir "/dev/shm/" --dataset_name "../data/bomb"   --dataset_idx 0 --system_prompt "You are a helpful assistant." --field_name "prompt"   --steering_multipliers "100.0" --max_tokens 512 --signs "+"   --source_layer_idx 9 --max_loras 8 --max_cpu_loras 64 --lora_batch_size 8 --batch_size 64 --num_workers 1
'''

import functools
import os
import argparse
import torch
import shelve
import numpy as np
import time
import shutil
import uuid
from datetime import datetime, timedelta
import json
from typing import List, Dict, Any, Optional, Tuple
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Select vLLM's v0 engine. This is set before `vllm` is imported below; presumably
# vLLM reads it at import/engine-construction time, so keep it above that import.
# The worker code uses `engine.model_executor.apply_model`, which is assumed to be
# a v0-engine API (NOTE(review): confirm against the installed vLLM version).
os.environ['VLLM_USE_V1'] = '0'

from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from vllm import SamplingParams, EngineArgs, LLMEngine, RequestOutput
from vllm.lora.request import LoRARequest
import ray
from ray import tune
from ray.air import session
from tqdm import tqdm
import threading
import fcntl
import tempfile

# Precision shared by the dummy LoRA adapter, the loaded steering vectors, the
# vLLM engine and the zero LoRA A/B tensors written at runtime.
DTYPE = torch.bfloat16

# Process-wide: make PyTorch raise on ops that have no deterministic implementation.
# NOTE(review): on CUDA some ops additionally need CUBLAS_WORKSPACE_CONFIG to be set;
# confirm this does not break the vLLM engine running in the Ray workers.
torch.use_deterministic_algorithms(True)

def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options for Ray/vLLM steering inference.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before; passing a
            list allows programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
        On invalid arguments argparse prints usage and exits.
    """
    parser = argparse.ArgumentParser(description="Run VLLM inference with dynamic multi-LoRA support via Ray")
    parser.add_argument("--model_name", type=str, required=True, help="Model to use")
    parser.add_argument("--tokenizer_name", type=str, default=None, help="Tokenizer to use")
    parser.add_argument("--vecs_name", type=str, default=None, help="Vectors filename")
    parser.add_argument("--vecs_dir", type=str, default=None, help="Directory containing per-observation vector files")
    parser.add_argument("--output_db", type=str, required=True, help="Output db name")
    parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name")
    parser.add_argument("--dataset_idx", type=int, default=None, help="Dataset idx to subset to")
    parser.add_argument("--field_name", type=str, required=True, help="Problem field name")
    parser.add_argument("--steering_multipliers", type=str, required=True, help="Comma-separated list of steering multipliers")
    parser.add_argument("--signs", type=str, default="+", help="Comma-separated signs for vectors (+/-)")
    parser.add_argument("--max_tokens", type=int, default=256, help="Max tokens")
    parser.add_argument("--system_prompt", type=str, default="You are a helpful and harmless assistant. You should think step-by-step.")
    parser.add_argument("--source_layer_idx", type=int, default=9, help="Source layer index")
    parser.add_argument("--batch_size", type=int, default=10, help="Number of vectors to process per task")
    parser.add_argument("--lora_batch_size", type=int, default=2, help="Number of LoRAs to process in one batch")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="vLLM flag, default is 1 GPU per worker, so model has to fit")
    parser.add_argument("--num_gpus_per_worker", type=float, default=1, help="GPUs per worker")
    parser.add_argument("--num_workers", type=int, default=None, help="Number of workers (auto-determined if not specified)")
    parser.add_argument("--ray_address", type=str, default=None, help="Ray cluster address (None starts a local cluster)")
    parser.add_argument("--local", action="store_true", help="Force local Ray initialization (ignores ray_address)")
    parser.add_argument("--num_vecs", type=int, default=None, help="Limit inference to first N vectors")
    parser.add_argument("--lora_dir", type=str, default=None, help="Directory to store dummy LoRA adapter (uses temp dir if None)")
    parser.add_argument("--max_lora_rank", type=int, default=8, help="Maximum LoRA rank")
    parser.add_argument("--max_loras", type=int, default=2, help="Maximum number of LoRAs loaded in GPU memory")
    parser.add_argument("--max_cpu_loras", type=int, default=16, help="Maximum number of LoRAs to cache in CPU memory")
    parser.add_argument("--vector_info_json", type=str, default=None, help="Path to vector info JSON file generated by train_dct.py")
    parser.add_argument("--seed", type=int, default=None, help="Seed for sampling")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
    parser.add_argument("--for_each", action="store_true", help="Use per-observation vector files")
    # Optional observation range (only meaningful in for_each mode).
    parser.add_argument("--data_start", type=int, default=None, help="Start index for observations to process (only in for_each mode)")
    parser.add_argument("--data_end", type=int, default=None, help="End index for observations to process, exclusive (only in for_each mode)")

    # argparse treats None as "use sys.argv[1:]", so the default keeps the CLI behavior.
    return parser.parse_args(argv)

# Parse args at the module level so they can inform `@ray.remote` below
# (DynamicLoRAWorker's `num_gpus` is taken from `args.tensor_parallel_size`).
# Side effect: importing this module parses sys.argv, and argparse exits the
# process on missing/invalid options.
args = parse_args()

class DummyLoRAManager:
    """Create a single dummy LoRA adapter on disk for dynamic modification at runtime.

    The adapter is rank 1 and targets only
    ``model.layers.<source_layer_idx>.mlp.down_proj``. Its LoRA bias is
    initialised to zero. It lives in a fresh temporary directory that
    :meth:`cleanup` removes.
    """

    def __init__(
        self,
        model_name: str,
        source_layer_idx: int,
        lora_dir: Optional[str] = None,
        vector_info: Optional[List[str]] = None
    ):
        """
        Args:
            model_name: Hugging Face name or path of the base model.
            source_layer_idx: Decoder layer whose ``mlp.down_proj`` the adapter targets.
            lora_dir: Parent directory under which a private temporary directory
                is created (created itself if missing). ``None`` uses the
                system temporary directory.
            vector_info: Stored as ``self.vector_info``; not otherwise used here.
        """
        self.model_name = model_name
        self.source_layer_idx = source_layer_idx
        self.vector_info = vector_info

        if lora_dir is not None:
            os.makedirs(lora_dir, exist_ok=True)
        # Create the directory *inside* lora_dir. Passing lora_dir as the mkdtemp
        # prefix only worked for absolute paths ending in a separator; otherwise the
        # directory landed beside lora_dir or under the system temp dir.
        self.temp_dir = tempfile.mkdtemp(prefix="vllm_lora_", dir=lora_dir)
        self.lora_dir = self.temp_dir
        self.using_temp_dir = True
        print(f"Created temporary directory for dummy LoRA adapter: {self.lora_dir}")

        # Path for the dummy LoRA adapter
        self.dummy_lora_path = os.path.join(self.lora_dir, "dummy_lora")

    def create_dummy_lora_adapter(self):
        """Build and save the zero-bias dummy adapter and return its directory.

        Loads the full base model on CPU in ``DTYPE`` so PEFT can write a
        correctly shaped adapter. The base model config is saved under
        ``<lora_dir>/base_model_config``, and its ``config.json`` is copied into
        the adapter directory if PEFT did not write one. The model is then freed.

        Returns:
            ``self.dummy_lora_path``.
        """
        print("Creating dummy LoRA adapter...")

        # Load model with DTYPE to match the rest of the code
        print("Loading model to create dummy LoRA adapter...")
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=DTYPE,
            device_map="cpu",
            trust_remote_code=True
        )

        # Save the base model configuration to ensure model_type is properly preserved
        print("Saving base model config...")
        base_model_config_path = os.path.join(self.lora_dir, "base_model_config")
        os.makedirs(base_model_config_path, exist_ok=True)
        model.config.save_pretrained(base_model_config_path)

        target_module = f"model.layers.{self.source_layer_idx}.mlp.down_proj"

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=1,
            lora_alpha=1.0,
            lora_dropout=0.0,
            bias="none",
            lora_bias=True,  # Important to enable bias
            target_modules=[target_module],
            base_model_name_or_path=self.model_name
        )

        peft_model = get_peft_model(model, peft_config)

        # Start from a zero bias; the real steering bias is written at runtime.
        with torch.no_grad():
            lora_b = peft_model.model.model.layers[self.source_layer_idx].mlp.down_proj.lora_B["default"]
            lora_b.bias.data = torch.zeros_like(lora_b.bias.data)

        peft_model.save_pretrained(self.dummy_lora_path)

        # Copy the base model configuration to ensure model_type is preserved
        src_config = os.path.join(base_model_config_path, "config.json")
        dst_config = os.path.join(self.dummy_lora_path, "config.json")
        if not os.path.exists(dst_config):
            shutil.copy(src_config, dst_config)

        print(f"Created dummy LoRA adapter at {self.dummy_lora_path}")

        # Free up memory
        del model
        del peft_model
        torch.cuda.empty_cache()

        return self.dummy_lora_path

    def get_dummy_lora_requests(self, num_requests: int) -> List[Tuple[int, "LoRARequest"]]:
        """Return ``num_requests`` (adapter_id, LoRARequest) pairs.

        Adapter IDs run 1..num_requests and all point at ``self.dummy_lora_path``.
        Each request name gets a random 8-hex-digit suffix.
        """
        requests = []
        for adapter_id in range(1, num_requests + 1):
            lora_id = f"dummy_lora_{adapter_id}_{uuid.uuid4().hex[:8]}"
            lora_request = LoRARequest(lora_id, adapter_id, self.dummy_lora_path)
            requests.append((adapter_id, lora_request))
        return requests

    def cleanup(self):
        """Remove the temporary directory (and the adapter inside it), best-effort."""
        if getattr(self, 'using_temp_dir', False) and hasattr(self, 'temp_dir'):
            try:
                shutil.rmtree(self.temp_dir)
                print(f"Removed temporary directory: {self.temp_dir}")
            except OSError as e:
                # Best-effort: report but do not fail the run over cleanup.
                print(f"Error removing temporary directory {self.temp_dir}: {e}")


class VectorManager:
    """Load steering-vector matrices from disk and split them into index batches.

    Two layouts are supported:

    * shared (``for_each=False``): a single tensor file at ``vectors_path``;
    * per-observation (``for_each=True``): one tensor per observation at
      ``<vectors_dir>/_<obs_idx>/_V.pt``.

    Vector ``i`` is column ``i`` (``vectors[:, i]``). ``num_vecs`` optionally keeps
    only the first N columns. Loaded tensors are cached on the instance.
    """

    def __init__(
        self,
        vectors_path: Optional[str] = None,
        vectors_dir: Optional[str] = None,
        for_each: bool = False,
        num_vecs: Optional[int] = None
    ):
        self.vectors_path = vectors_path
        self.vectors_dir = vectors_dir
        self.for_each = for_each
        self.num_vecs = num_vecs
        self._vectors = None          # cached shared matrix
        self._per_obs_vectors = {}    # obs_idx -> cached per-observation matrix

    def load_vectors(self, obs_idx: Optional[int] = None):
        """
        Load (and cache) the vector matrix for the configured layout.

        Args:
            obs_idx: Observation index; required when ``for_each=True``, ignored otherwise.

        Returns:
            Tensor on CPU cast to ``DTYPE``, truncated to ``num_vecs`` columns if set.

        Raises:
            ValueError: if a required argument/path (``obs_idx``, ``vectors_dir``
                or ``vectors_path``) is missing.
            FileNotFoundError: if the per-observation vector file does not exist.
        """
        if self.for_each:
            if obs_idx is None:
                raise ValueError("obs_idx must be specified when for_each=True")

            if obs_idx in self._per_obs_vectors:
                return self._per_obs_vectors[obs_idx]

            # Fail clearly instead of a TypeError from os.path.join(None, ...).
            if self.vectors_dir is None:
                raise ValueError("vectors_dir must be specified when for_each=True")

            vector_file = os.path.join(self.vectors_dir, f"_{obs_idx}", "_V.pt")
            if not os.path.exists(vector_file):
                raise FileNotFoundError(f"Vector file not found: {vector_file}")

            vectors = torch.load(vector_file, map_location="cpu").to(DTYPE)

            if self.num_vecs is not None and self.num_vecs < vectors.shape[1]:
                vectors = vectors[:, :self.num_vecs]
                print(f"Loaded {self.num_vecs} vectors for observation {obs_idx} from {vector_file}")
            else:
                print(f"Loaded all {vectors.shape[1]} vectors for observation {obs_idx} from {vector_file}")

            self._per_obs_vectors[obs_idx] = vectors
            return vectors

        if self._vectors is None:
            if self.vectors_path is None:
                raise ValueError("vectors_path must be specified when for_each=False")

            full_vectors = torch.load(self.vectors_path, map_location="cpu").to(DTYPE)

            if self.num_vecs is not None and self.num_vecs < full_vectors.shape[1]:
                self._vectors = full_vectors[:, :self.num_vecs]
                print(f"Loaded {self.num_vecs} vectors from matrix with shape {full_vectors.shape}")
            else:
                self._vectors = full_vectors
                print(f"Loaded all {full_vectors.shape[1]} vectors from matrix with shape {full_vectors.shape}")

        return self._vectors

    def get_vector_batches(self, batch_size: int, obs_idx: Optional[int] = None) -> List[List[int]]:
        """
        Split the vector indices ``0..num_columns-1`` into consecutive batches.

        Args:
            batch_size: Maximum number of indices per batch; must be >= 1.
            obs_idx: Observation index (used only when ``for_each=True``).

        Returns:
            List of index lists; the last batch may be shorter.

        Raises:
            ValueError: if ``batch_size < 1`` (previously 0 crashed inside
                ``range`` and negatives silently produced no batches).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        vectors = self.load_vectors(obs_idx=obs_idx if self.for_each else None)
        total_vecs = vectors.shape[1]

        return [
            list(range(start, min(start + batch_size, total_vecs)))
            for start in range(0, total_vecs, batch_size)
        ]

def modify_lora_bias(engine, adapter_id: int, layer_idx: int, bias: torch.Tensor) -> bool:
    """
    Overwrite one adapter's LoRA bias on ``layers.<layer_idx>.mlp.down_proj``.

    The original implementation filtered modules with ``isinstance(module, type)``.
    That is never true for a module instance, so it could not modify anything.

    Args:
        engine: Object exposing the model as ``engine.model``.
            NOTE(review): confirm the engine in use has a ``model`` attribute; the
            worker reaches the model through ``model_executor.apply_model`` instead.
        adapter_id: 1-based adapter ID (IDs are generated starting from 1). The
            adapter is assumed to occupy slot ``adapter_id - 1``, the same index
            ``DynamicLoRAWorker._modify_lora_bias`` passes to ``set_lora``.
        layer_idx: Decoder layer whose ``mlp.down_proj`` is modified.
        bias: New bias; written into the first ``bias.shape[0]`` entries of the slot.

    Returns:
        True if a matching LoRA layer (with ``base_layer`` and ``lora_bias_stacked``)
        was found and updated, False otherwise.

    Raises:
        ValueError: if ``adapter_id < 1`` (slot -1 would silently hit the last slot).
    """
    slot = adapter_id - 1
    if slot < 0:
        raise ValueError(f"adapter_id must be >= 1, got {adapter_id}")

    target_suffix = f"layers.{layer_idx}.mlp.down_proj"
    for name, module in engine.model.named_modules():
        if not name.endswith(target_suffix) or not hasattr(module, 'base_layer'):
            continue
        if not hasattr(module, 'lora_bias_stacked'):
            continue
        with torch.no_grad():
            module.lora_bias_stacked[0][slot, 0, :bias.shape[0]].copy_(
                bias.to(module.device), non_blocking=True
            )
        return True

    return False

@ray.remote(num_gpus=args.tensor_parallel_size)
class DynamicLoRAWorker:
    """Ray actor that runs vLLM inference with dynamic LoRA bias modification.

    Every adapter ID registered with the engine points at the same dummy adapter
    on disk (``dummy_lora_path``). Steering is applied by overwriting each
    adapter's LoRA bias on ``layers.<source_layer_idx>.mlp.down_proj`` inside the
    running model, so one engine run can serve several steering vectors.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: str,
        system_prompt: str,
        max_tokens: int,
        tensor_parallel_size: int,
        output_db_path: str,
        dummy_lora_path: str,
        source_layer_idx: int,
        max_loras: int = 2,
        max_lora_rank: int = 8,
        max_cpu_loras: int = 16,
        worker_id: int = 0
    ):
        """Load the tokenizer and a LoRA-enabled vLLM engine, and reset this worker's DB.

        Args:
            model_name: Base model name/path served by vLLM.
            tokenizer_name: Tokenizer name/path; ``model_name`` is used when falsy.
            system_prompt: System message prepended by :meth:`format_prompts` (empty to omit).
            max_tokens: Per-request generation budget; also passed as ``max_model_len``.
            tensor_parallel_size: vLLM tensor-parallel degree.
            output_db_path: Prefix of this worker's shelve DB (``<prefix>_worker<id>``).
            dummy_lora_path: Directory of the dummy adapter every request points at.
            source_layer_idx: Decoder layer whose ``mlp.down_proj`` LoRA bias is steered.
            max_loras: vLLM ``max_loras`` (adapters resident on GPU).
            max_lora_rank: vLLM ``max_lora_rank``.
            max_cpu_loras: vLLM ``max_cpu_loras``.
            worker_id: Identifier used in log lines and the DB filename.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.worker_id = worker_id
        self.output_db_path = output_db_path
        self.dummy_lora_path = dummy_lora_path
        self.source_layer_idx = source_layer_idx

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name or self.model_name,
            trust_remote_code=True,
            padding_side="left",
            truncation_side="left"
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        engine_args = EngineArgs(
            model=self.model_name,
            dtype=DTYPE,
            tensor_parallel_size=tensor_parallel_size,
            # NOTE(review): max_model_len likely bounds prompt + completion together,
            # which would cap completions below max_tokens by the prompt length.
            # Confirm this is intended.
            max_model_len=self.max_tokens,
            enable_lora=True,
            max_loras=max_loras,
            max_lora_rank=max_lora_rank,
            max_cpu_loras=max_cpu_loras,
            enable_lora_bias=True,  # Critical for our approach
            gpu_memory_utilization=0.95,
            trust_remote_code=True,
            enable_chunked_prefill=True
        )

        print(f"Worker {worker_id}: Initializing LLMEngine...")
        self.engine = LLMEngine.from_engine_args(engine_args)

        print(f"Worker {worker_id}: Opening DB...")
        # os.makedirs("") raises, so only create a parent when the path has one.
        db_dir = os.path.dirname(self.output_db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        # flag="n" always creates a new, empty database (previous results are discarded).
        with shelve.open(f"{self.output_db_path}_worker{self.worker_id}", flag="n"):
            pass

        print(f"Worker {self.worker_id} initialized with model {model_name} and dynamic LoRA support")

    def format_prompts(self, prompts: List[str]) -> List[str]:
        """Wrap each prompt as a user turn (after the optional system prompt) and apply the chat template."""
        if len(self.system_prompt) > 0:
            chat_init = [{'content': self.system_prompt, 'role': 'system'}]
        else:
            chat_init = []

        chats = [chat_init + [{'content': prompt, 'role': 'user'}] for prompt in prompts]
        formatted_prompts = self.tokenizer.apply_chat_template(
            chats,
            tokenize=False,
            add_generation_prompt=True,
            padding=False
        )

        return formatted_prompts

    def _make_lora_requests(self, num_requests: int) -> List[Tuple[int, LoRARequest]]:
        """Return (adapter_id, LoRARequest) pairs for IDs 1..num_requests, all using the dummy adapter.

        Same naming scheme as ``DummyLoRAManager.get_dummy_lora_requests``. It is built
        here directly because constructing a ``DummyLoRAManager`` creates (and leaks)
        a new temporary directory on every call.
        """
        requests = []
        for adapter_id in range(1, num_requests + 1):
            lora_name = f"dummy_lora_{adapter_id}_{uuid.uuid4().hex[:8]}"
            requests.append((adapter_id, LoRARequest(lora_name, adapter_id, self.dummy_lora_path)))
        return requests

    def initialize_lora_pool(self, pool_size):
        """
        Register adapter IDs 1..pool_size with the engine.

        All IDs point at the same dummy adapter file; their biases are later
        overwritten in memory independently.

        Args:
            pool_size: Number of adapter IDs to pre-register.

        Returns:
            True once every adapter has been added.
        """
        print(f"Worker {self.worker_id}: Initializing LoRA pool with {pool_size} adapters")

        self.lora_pool_size = pool_size
        self.lora_pool = []
        self.adapter_id_to_idx = {}

        for adapter_id, lora_request in self._make_lora_requests(pool_size):
            # Pre-load all adapters to avoid loading them during inference
            self.engine.add_lora(lora_request)
            self.adapter_id_to_idx[adapter_id] = len(self.lora_pool)
            self.lora_pool.append(adapter_id)

        print(f"Worker {self.worker_id}: Successfully initialized LoRA pool with {len(self.lora_pool)} adapters")

        return True

    def _modify_lora_bias(self, model, adapter_id: int, bias: torch.Tensor):
        """
        Install ``bias`` as the LoRA bias of ``adapter_id`` on the steered layer.

        Meant to be passed to ``engine.model_executor.apply_model`` so it receives the
        live model. LoRA A/B are written as zeros, so the adapter's A/B path adds nothing
        and only the bias is applied.

        NOTE(review): assumes adapter ``adapter_id`` occupies slot ``adapter_id - 1``,
        which holds if IDs 1..N were activated in order into empty slots with
        N <= max_loras. Confirm against vLLM's slot assignment.

        Args:
            model: The model instance supplied by ``apply_model``.
            adapter_id: 1-based adapter ID.
            bias: New bias tensor.

        Returns:
            True if the layer was found and updated, False otherwise.
        """
        target_module = None
        for name, module in model.named_modules():
            if f"layers.{self.source_layer_idx}.mlp.down_proj" in name:
                if hasattr(module, 'lora_bias_stacked'):
                    target_module = module
                    break

        if target_module is None:
            print(f"Warning: Could not find target module with lora_bias_stacked")
            return False

        lora_a = torch.zeros((model.config.intermediate_size, 1), dtype=DTYPE)
        lora_b = torch.zeros((1, model.config.hidden_size), dtype=DTYPE)

        target_module.set_lora(adapter_id-1, lora_a, lora_b, None, bias)
        # Previously returned False here as well, making success indistinguishable.
        return True

    def _result_key(self, config: Dict[str, Any]) -> str:
        """Return the DB key for a steering config (``vector_info`` based when present)."""
        if "vector_info" in config:
            return f"{config['vector_info']}_sign{config['sign']}"
        scale_idx = config.get("scale_idx", 0)  # Default to 0 if not available
        return f"vec{config['vector_idx']}_scale{scale_idx}_sign{config['sign']}"

    def process_vector_batch(
        self,
        vector_batch: List[Dict[str, Any]],
        vectors: torch.Tensor,
        prompts: List[str],
        output_db_path: str,
        lora_batch_size: int = 2,
        temperature: float = 0.0,
        seed: int = 325
    ) -> Dict[str, Any]:
        """
        Generate completions for every (steering config, prompt) pair.

        Configs are handled in chunks of ``lora_batch_size``. Within a chunk each
        config gets its own adapter ID, whose bias is set to
        ``(+1 if sign == "+" else -1) * scale * vectors[:, vector_idx]``. The chunk's
        requests are then run to completion together.

        Args:
            vector_batch: Config dicts with ``vector_idx``, ``scale`` and ``sign``,
                optionally ``vector_info`` / ``scale_idx`` (used for the result key).
            vectors: Steering-vector matrix; column ``i`` is vector ``i``.
            prompts: Prompts passed to the engine as-is (not re-formatted here).
            output_db_path: Prefix of the shelve DB (``<prefix>_worker<id>``) to write.
            lora_batch_size: Configs per chunk; also the pool size on the first call.
            temperature: Sampling temperature.
            seed: Sampling seed forwarded to ``SamplingParams`` (it was previously ignored).

        Returns:
            Mapping from result key to a list with one completion per prompt
            (``None`` where no finished output was collected).
        """
        start_time = time.time()
        print(f"Worker {self.worker_id} starting batch with {len(vector_batch)} vectors")

        # Lazily create the adapter pool on the first call.
        if not hasattr(self, 'lora_pool'):
            self.initialize_lora_pool(lora_batch_size)

        vector_batches = [vector_batch[i:i+lora_batch_size] for i in range(0, len(vector_batch), lora_batch_size)]

        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=self.max_tokens,
            repetition_penalty=1.05,
            seed=seed
        )

        with shelve.open(f"{output_db_path}_worker{self.worker_id}") as db:
            results = {}

            for batch_idx, batch in enumerate(vector_batches):
                batch_start_time = time.time()
                print(f"Worker {self.worker_id} processing vector batch {batch_idx+1}/{len(vector_batches)} with {len(batch)} vectors")

                # Drain leftover requests before overwriting adapter biases
                # (presumably so running requests are not steered by the new vectors).
                while self.engine.has_unfinished_requests():
                    self.engine.step()

                lora_requests = self._make_lora_requests(len(batch))
                request_id_map = {}

                for (adapter_id, lora_request), config in zip(lora_requests, batch):
                    vec_idx = config["vector_idx"]
                    scale = config["scale"]
                    sign = config["sign"]

                    sign_multiplier = 1.0 if sign == "+" else -1.0
                    scaled_vector = sign_multiplier * scale * vectors[:, vec_idx]

                    # Write this config's steering bias into its adapter slot.
                    modify_fn = functools.partial(
                        self._modify_lora_bias,
                        adapter_id=adapter_id,
                        bias=scaled_vector
                    )
                    self.engine.model_executor.apply_model(modify_fn)

                    for prompt_idx, prompt in enumerate(prompts):
                        request_id = f"req_{vec_idx}_{scale}_{sign}_{prompt_idx}_{uuid.uuid4().hex[:6]}"
                        self.engine.add_request(
                            request_id,
                            prompt,
                            sampling_params,
                            lora_request=lora_request
                        )
                        request_id_map[request_id] = {
                            "vector_idx": vec_idx,
                            "scale": scale,
                            "sign": sign,
                            "prompt_idx": prompt_idx,
                            "config": config  # Store full config for result key generation
                        }

                batch_results = {}
                while self.engine.has_unfinished_requests():
                    for output in self.engine.step():
                        if not output.finished:
                            continue
                        req_info = request_id_map.get(output.request_id)
                        if not req_info:
                            continue
                        result_key = self._result_key(req_info["config"])
                        if result_key not in batch_results:
                            batch_results[result_key] = [None] * len(prompts)
                        batch_results[result_key][req_info["prompt_idx"]] = output.outputs[0].text

                # Persist after every chunk so partial progress survives a crash.
                for key, outputs in batch_results.items():
                    results[key] = outputs
                    db[key] = outputs

                batch_time = time.time() - batch_start_time
                print(f"Worker {self.worker_id} completed vector batch {batch_idx+1} in {batch_time:.2f} seconds")

        total_time = time.time() - start_time
        print(f"Worker {self.worker_id} completed all {len(vector_batch)} vector configurations in {total_time:.2f} seconds")

        return results

def run_inference_with_ray(config):
    """Run steering-vector inference over a dataset on a pool of Ray workers.

    Args:
        config: Dict with two entries:
            ``"args"`` -- the parsed command-line namespace, and
            ``"vector_manager"`` -- an object whose ``load_vectors(obs_idx=...)``
            returns the steering-vector tensor for one observation
            (``obs_idx=None`` in shared-vector mode).

    For each observation (``--for_each``), or once for the whole dataset, one
    configuration is built per (vector, scale, sign) combination. The
    configurations are dealt round-robin to the workers and submitted in
    chunks of ``args.batch_size``. Once all tasks finish, the per-worker
    shelve databases are merged into a single shelve at the output path.

    Raises:
        ValueError: If ``--data_start``/``--data_end`` select an empty range.
        Any exception raised inside a worker task is re-raised here by
        ``ray.get``.
    """
    # Local import: used to detect worker shelve files under any dbm backend.
    import dbm

    args = config["args"]
    vector_manager = config["vector_manager"]

    # Parse scales and signs
    scales = [float(scale) for scale in args.steering_multipliers.split(",")]
    signs = args.signs.split(",")
    print(f"Using scales: {scales} and signs: {signs}")

    # Initialize Ray if not already started
    if not ray.is_initialized():
        if args.ray_address == "auto" or args.ray_address is None:
            # Start a local Ray instance
            ray.init(include_dashboard=True)
            print("Started a new local Ray instance")
        else:
            # Connect to existing Ray cluster
            ray.init(address=args.ray_address)
            print(f"Connected to Ray cluster at {args.ray_address}")
    print(f"Ray cluster resources: {ray.cluster_resources()}")

    # Load dataset (optionally restricted to a single row)
    dataset = Dataset.load_from_disk(args.dataset_name)
    if args.dataset_idx is not None:
        dataset = dataset.select([args.dataset_idx])

    # Determine observation indices to process
    if args.for_each:
        if args.data_start is not None or args.data_end is not None:
            data_start = args.data_start if args.data_start is not None else 0
            data_end = args.data_end if args.data_end is not None else len(dataset)

            # Clamp to the dataset bounds, then reject an empty range
            data_start = max(data_start, 0)
            data_end = min(data_end, len(dataset))
            if data_start >= data_end:
                raise ValueError(f"Invalid data range: data_start ({data_start}) must be less than data_end ({data_end})")

            obs_indices = list(range(data_start, data_end))
            print(f"Processing observations from index {data_start} to {data_end-1} (total: {len(obs_indices)})")
        else:
            # Use all observation indices by default
            obs_indices = list(range(len(dataset)))
            print(f"Processing all {len(obs_indices)} observations")
    else:
        # Only one shared set of vectors (warn if data_start/data_end are set but ignored)
        if args.data_start is not None or args.data_end is not None:
            print("Warning: data_start and data_end are ignored in non-for_each mode")
        obs_indices = [None]

    # Default worker count: one per GPU, but never more workers than
    # observations. max(1, ...) is needed because int() truncates a fractional
    # GPU count to 0, which would leave no workers at all.
    num_workers = args.num_workers or min(
        max(1, int(ray.cluster_resources().get("GPU", 1))),
        len(obs_indices),
    )

    print(f"Using {num_workers} workers for {len(obs_indices)} observation(s)")

    # Create dummy LoRA adapter
    lora_manager = DummyLoRAManager(
        model_name=args.model_name,
        source_layer_idx=args.source_layer_idx,
        lora_dir=args.lora_dir,
        vector_info=None  # We'll handle vector_info separately for each observation
    )
    dummy_lora_path = lora_manager.create_dummy_lora_adapter()

    try:
        # Create workers
        workers = [
            DynamicLoRAWorker.remote(
                model_name=args.model_name,
                tokenizer_name=args.tokenizer_name or args.model_name,
                system_prompt=args.system_prompt,
                max_tokens=args.max_tokens,
                tensor_parallel_size=args.tensor_parallel_size,
                dummy_lora_path=dummy_lora_path,
                source_layer_idx=args.source_layer_idx,
                max_loras=args.max_loras,
                max_lora_rank=args.max_lora_rank,
                max_cpu_loras=args.max_cpu_loras,
                worker_id=i,
                output_db_path=args.output_db
            )
            for i in range(num_workers)
        ]

        # Process each observation
        for obs_idx in obs_indices:
            if obs_idx is not None:
                print(f"\nProcessing observation {obs_idx}")
                # Row access avoids materializing the whole column on every observation.
                prompts = [dataset[obs_idx][args.field_name]]

                # Load vector info for this observation if available (best effort)
                vector_info = None
                if args.vector_info_json and args.for_each:
                    obs_vector_info_path = os.path.join(
                        os.path.dirname(args.vector_info_json),
                        f"_{obs_idx}",
                        "_vector_info.json",
                    )
                    if os.path.exists(obs_vector_info_path):
                        try:
                            with open(obs_vector_info_path, 'r') as f:
                                vector_info = json.load(f)
                            print(f"Loaded vector info for observation {obs_idx} with {len(vector_info)} entries")
                        except Exception as e:
                            print(f"Error loading vector info for observation {obs_idx}: {e}")
            else:
                # Shared vectors mode
                prompts = dataset[args.field_name]

                # Load vector info if available (best effort)
                vector_info = None
                if args.vector_info_json:
                    try:
                        with open(args.vector_info_json, 'r') as f:
                            vector_info = json.load(f)
                        print(f"Loaded vector info with {len(vector_info)} entries")
                    except Exception as e:
                        print(f"Error loading vector info: {e}")

            # Format the prompts once on a worker. The ObjectRef is passed to
            # every task and Ray resolves it to the value, so the prompts are
            # not re-serialized per task.
            formatted_prompts_ref = workers[0].format_prompts.remote(prompts)

            # Load this observation's vectors and place them in the object
            # store once instead of shipping the tensor with every task.
            vectors = vector_manager.load_vectors(obs_idx=obs_idx).to(DTYPE)
            total_vectors = vectors.shape[1]  # assumes vectors are indexed along dim 1 — TODO confirm
            vectors_ref = ray.put(vectors)

            # One configuration per (vector, scale, sign) combination.
            # NOTE(review): when "vector_info" is set, the worker's result key
            # is built from vector_info and sign only (scale omitted), so
            # configurations that differ only in scale may overwrite each
            # other — confirm.
            vector_configs = []
            for vec_idx in range(total_vectors):
                for scale_idx, scale in enumerate(scales):
                    for sign in signs:
                        vec_config = {
                            "vector_idx": vec_idx,
                            "scale": scale,
                            "scale_idx": scale_idx,
                            "sign": sign
                        }
                        # Add observation index for for_each mode
                        if obs_idx is not None:
                            vec_config["obs_idx"] = obs_idx
                        # Add vector info if available
                        if vector_info and vec_idx < len(vector_info):
                            vec_config["vector_info"] = vector_info[vec_idx]
                        vector_configs.append(vec_config)

            # Round-robin assignment: worker w gets configs w, w+num_workers, ...
            worker_batches = [vector_configs[w::num_workers] for w in range(num_workers)]

            obs_db_path = args.output_db if obs_idx is None else f"{args.output_db}_obs{obs_idx}"

            # Submit each worker's configs in chunks of args.batch_size
            tasks = []
            for worker_idx, batch in enumerate(worker_batches):
                for i in range(0, len(batch), args.batch_size):
                    tasks.append(
                        workers[worker_idx].process_vector_batch.remote(
                            vector_batch=batch[i:i + args.batch_size],
                            vectors=vectors_ref,
                            prompts=formatted_prompts_ref,
                            output_db_path=obs_db_path,
                            lora_batch_size=args.lora_batch_size,
                            temperature=args.temperature,
                            seed=args.seed
                        )
                    )

            # Track progress as tasks complete
            with tqdm(total=len(tasks)) as pbar:
                while tasks:
                    done_ids, tasks = ray.wait(tasks, num_returns=1)
                    # ray.get re-raises any exception raised inside the worker task.
                    ray.get(done_ids[0])
                    pbar.update(1)

            # Combine results from worker databases for this observation
            if obs_idx is not None:
                print(f"Combining results for observation {obs_idx} from {num_workers} workers...")
            else:
                print(f"Combining results from {num_workers} workers...")

            with shelve.open(obs_db_path, flag="n") as combined_db:
                for worker_id in range(num_workers):
                    worker_db_path = f"{obs_db_path}_worker{worker_id}"
                    # dbm.whichdb understands every backend's on-disk naming
                    # (no suffix, ".db", ".dat"/".dir", ...) and returns None
                    # when no database exists at this path.
                    if dbm.whichdb(worker_db_path) is None:
                        continue
                    with shelve.open(worker_db_path, flag="r") as worker_db:
                        for key, value in worker_db.items():
                            combined_db[key] = value
                print(f"Combined results saved to {obs_db_path}")
    finally:
        # Always run the LoRA manager's cleanup, even when inference fails.
        # What it removes is decided inside DummyLoRAManager.cleanup().
        lora_manager.cleanup()

    print("All inference tasks completed!")

def _load_completed_vectors_from_log(log_file):
    """Return the set of vector indices recorded as completed in *log_file*.

    Any line containing ``"Completed vectors:"`` is expected to carry a
    bracketed, comma-separated list of integers after that marker, e.g.
    ``"... Completed vectors: [0, 3, 7]"``. Entries that are not integers are
    skipped.

    Args:
        log_file: Path to the log file. A missing file yields an empty set.

    Returns:
        set[int]: Union of all indices found on matching lines.

    This is best-effort. A read or decode error prints a warning, and the
    indices parsed before the error are still returned.
    """
    marker = "Completed vectors:"
    completed = set()

    if not os.path.exists(log_file):
        return completed

    try:
        with open(log_file, 'r') as f:
            for line in f:
                marker_pos = line.find(marker)
                if marker_pos < 0:
                    continue
                # Look for the brackets *after* the marker. Otherwise bracketed
                # prefixes such as "[INFO]" or "Worker [0]" would be mistaken
                # for the index list.
                start_idx = line.find("[", marker_pos + len(marker))
                if start_idx < 0:
                    continue
                end_idx = line.find("]", start_idx + 1)
                if end_idx < 0:
                    continue
                for idx_str in line[start_idx + 1:end_idx].split(","):
                    try:
                        completed.add(int(idx_str.strip()))
                    except ValueError:
                        pass  # skip empty or non-integer entries
    except (OSError, UnicodeDecodeError) as e:
        print(f"Warning: Error reading completed vectors from log file: {e}")

    return completed

def main():
    """Validate CLI arguments, save run metadata, and launch Ray inference.

    Reads the parsed command-line arguments from the module-level name
    ``args``, which is not assigned in this function.
    NOTE(review): presumably set from ``parse_args()`` elsewhere in the file;
    confirm.

    Raises:
        ValueError: If the vector source for the selected mode is missing
            (``--vecs_dir`` with ``--for_each``, otherwise ``--vecs_name``),
            or if the vectors directory does not exist.
    """
    # Validate arguments and build the vector manager for the selected mode
    if args.for_each:
        if args.vecs_dir is None:
            raise ValueError("--vecs_dir must be specified when --for_each is enabled")
        # Check if the vectors directory exists
        if not os.path.exists(args.vecs_dir):
            raise ValueError(f"Vectors directory does not exist: {args.vecs_dir}")
        vector_manager = VectorManager(
            vectors_dir=args.vecs_dir,
            for_each=True,
            num_vecs=args.num_vecs
        )
    else:
        if args.vecs_name is None:
            raise ValueError("--vecs_name must be specified when --for_each is not enabled")
        vector_manager = VectorManager(
            vectors_path=args.vecs_name,
            for_each=False,
            num_vecs=args.num_vecs
        )

    # Set up config
    config = {
        "args": args,
        "vector_manager": vector_manager
    }

    # Handle local mode option
    if args.local:
        print("Forcing local Ray initialization")
        ray.init(include_dashboard=True, ignore_reinit_error=True)

    # Save metadata next to the output DB. dirname() returns "" when
    # --output_db has no directory part, and os.makedirs("") raises, so fall
    # back to the current directory in that case.
    output_dir = os.path.dirname(args.output_db) or "."
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "metadata.json"), 'w') as f:
        json.dump(vars(args), f, indent=2)

    try:
        # Execute inference with Ray
        run_inference_with_ray(config)
    finally:
        # Shut Ray down even if inference raised.
        if ray.is_initialized():
            ray.shutdown()
            print("Ray has been shut down")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()