import os
import json
import threading
import hashlib
import time
from pathlib import Path
from datetime import datetime

from vllm import LLM, SamplingParams
import torch
from torch.utils.data import DataLoader
from transformers import AutoProcessor
try:
    from transformers import Qwen3VLForConditionalGeneration
except ImportError:
    from transformers import Qwen2VLForConditionalGeneration as Qwen3VLForConditionalGeneration
from peft import PeftModel
from tqdm import tqdm

from utils import parse_args, parse_llm_args, back_envs, set_seed, save_data_to_cache
from dataset import dataset_dict, collate_fn


def merge_lora_if_needed(model_path, lora_path):
    """Merge LoRA weights into the base model and cache the merged checkpoint.

    The merged model and its processor are stored under
    ``<torch hub dir>/EditScore/<model>_merged_lora_<adapter>_<hash>``. The
    checkpoint is first written to a staging directory and only renamed to
    its final location once everything has been saved, so an interrupted
    merge (or several ranks merging at the same time) can never leave a
    half-written directory that later runs would mistake for a valid cache.

    Args:
        model_path: Base model identifier or path given to ``from_pretrained``.
        lora_path: Path of the LoRA adapter, or ``None`` to skip merging.

    Returns:
        ``model_path`` unchanged when ``lora_path`` is ``None``; otherwise the
        path of the cache directory containing the merged model.
    """
    if lora_path is None:
        return model_path

    import shutil
    import tempfile

    # normpath drops trailing separators; without it basename("dir/") is ""
    # and the model/adapter name silently disappears from the cache key.
    model_name = os.path.basename(os.path.normpath(model_path))
    lora_filename = os.path.splitext(os.path.basename(os.path.normpath(lora_path)))[0]
    lora_hash = hashlib.md5(lora_path.encode()).hexdigest()[:8]
    lora_identifier = f"{lora_filename}_{lora_hash}"

    root_dir = torch.hub.get_dir()  # default: ~/.cache/torch/hub
    cache_dir = os.path.join(root_dir, "EditScore", f"{model_name}_merged_lora_{lora_identifier}")

    if os.path.exists(cache_dir):
        print(f"Using cached merged model from {cache_dir}")
        return cache_dir

    print(f"Merging LoRA from {lora_path} to {model_path}...")
    print(f"Saving merged model to {cache_dir}")
    start_time = time.time()

    # Load base model on CPU, then fold the adapter weights into it.
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="cpu"
    )
    model = PeftModel.from_pretrained(model, lora_path)
    model = model.merge_and_unload()
    processor = AutoProcessor.from_pretrained(model_path)

    # Stage next to the final location so the rename stays on one filesystem.
    # Note: mkdtemp creates the directory with mode 0o700.
    parent_dir = os.path.dirname(cache_dir)
    os.makedirs(parent_dir, exist_ok=True)
    staging_dir = tempfile.mkdtemp(prefix=os.path.basename(cache_dir) + ".tmp-", dir=parent_dir)
    try:
        model.save_pretrained(staging_dir)
        processor.save_pretrained(staging_dir)
        try:
            os.rename(staging_dir, cache_dir)
        except OSError:
            # Another process published the same merge first; use its copy.
            if not os.path.exists(cache_dir):
                raise
            shutil.rmtree(staging_dir, ignore_errors=True)
    except BaseException:
        # Never leave a partial checkpoint behind, even on Ctrl-C.
        shutil.rmtree(staging_dir, ignore_errors=True)
        raise

    print(f"LoRA merging completed in {time.time() - start_time:.2f} seconds")
    return cache_dir


def initialize_vllm(args):
    """Create the vLLM engine and the sampling parameters described by ``args``.

    ``args.model`` is overwritten in place with the value returned by
    ``merge_lora_if_needed`` (the merged checkpoint when ``args.lora_path``
    is set, otherwise the unchanged model path), so the caller sees the
    updated path afterwards.

    Returns:
        A ``(model, sampling_params)`` tuple.
    """
    # Must happen before parse_llm_args so the engine kwargs see the new path.
    args.model = merge_lora_if_needed(args.model, args.lora_path)

    engine = LLM(**parse_llm_args(args))

    sampling_options = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "max_tokens": args.max_tokens,
    }
    return engine, SamplingParams(**sampling_options)


def set_cuda_visible_devices(local_rank, tensor_parallel_size, offset=0):
    """Restrict this process to its own slice of GPUs via ``CUDA_VISIBLE_DEVICES``.

    Each local rank receives ``tensor_parallel_size`` devices.

    * If ``CUDA_VISIBLE_DEVICES`` is already set and lists enough devices,
      the rank takes its slice *of that list*, so non-contiguous or
      non-numeric (UUID) device lists are honoured.
    * If the inherited list is too short, its first entry is treated as a
      numeric starting offset and consecutive indices are assigned (the
      historical behaviour).
    * If the variable is unset or blank, ``offset`` is used as the start.

    Args:
        local_rank: Index of this process on the node.
        tensor_parallel_size: Number of GPUs each process should see.
        offset: First GPU index to use when nothing is inherited.

    Raises:
        ValueError: If the offset fallback is needed but the first inherited
            entry is not an integer.
    """
    first = local_rank * tensor_parallel_size
    last = first + tensor_parallel_size

    inherited = os.environ.get('CUDA_VISIBLE_DEVICES')
    visible = [dev.strip() for dev in inherited.split(',') if dev.strip()] if inherited else []

    if len(visible) >= last:
        # Enough inherited devices: pick this rank's share of exactly those.
        selected = visible[first:last]
    else:
        if visible:
            offset = int(visible[0])
        selected = [str(i) for i in range(first + offset, last + offset)]

    os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected)
    print(f"local rank {local_rank}, CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")


def _strip_json_suffix(path):
    """Return ``path`` without a trailing ``.json`` extension, if it has one."""
    if path.endswith('.json'):
        return path[:-len('.json')]
    return path


def _resolve_output_path(args, pass_id, current_seed):
    """Return the output file for this pass and print the matching banner."""
    if args.num_pass > 1:
        base_path = _strip_json_suffix(args.output_path)
        current_output_path = f"{base_path}_pass{pass_id + 1}.json"
        print(f"\n{'='*60}")
        print(f"Starting inference pass {pass_id + 1}/{args.num_pass} (seed={current_seed})")
        print(f"Output: {current_output_path}")
        print(f"{'='*60}\n")
        return current_output_path

    current_output_path = args.output_path
    if args.add_timestamp:
        # NOTE(review): every rank computes its own timestamp, so with
        # WORLD_SIZE > 1 ranks can end up with different output/cache paths.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        if current_output_path.endswith('.json'):
            # Insert the timestamp before the .json extension.
            current_output_path = f"{_strip_json_suffix(current_output_path)}_{timestamp}.json"
        else:
            current_output_path = f"{current_output_path}_{timestamp}"
        print(f"\n{'='*60}")
        print(f"Using timestamped output: {current_output_path}")
        print(f"{'='*60}\n")
    return current_output_path


def _load_json_file(filepath):
    """Load and return the JSON document stored at ``filepath``."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def _merge_rank_results(data_dict, current_output_path, world_size):
    """Merge all per-rank caches into the final output file, then delete them.

    ``data_dict`` (rank 0's own results) is updated in place with the content
    of every other rank's cache file, in rank order.
    """
    from concurrent.futures import ThreadPoolExecutor

    print(f"\nMerging results from {world_size} ranks...")
    tot_data = data_dict

    all_cache_paths = [f"{current_output_path}_rank{r}_cache.json" for r in range(world_size)]

    # NOTE(review): no synchronization with the other ranks is visible in
    # this file. Rank 0 merges whatever cache files exist right now and then
    # deletes them, so a slower rank's results can be missed.
    rank_cache_files = []
    for r in range(1, world_size):
        if os.path.exists(all_cache_paths[r]):
            rank_cache_files.append(all_cache_paths[r])
        else:
            print(f"Warning: cache file for rank {r} not found ({all_cache_paths[r]}); "
                  f"merged results may be incomplete")

    if rank_cache_files:
        with ThreadPoolExecutor(max_workers=min(8, len(rank_cache_files))) as executor:
            futures = [executor.submit(_load_json_file, path) for path in rank_cache_files]
            # Consume in submission (rank) order so the merge is deterministic.
            for future in futures:
                tot_data.update(future.result())

    with open(current_output_path, 'w', encoding='utf-8') as f:
        json.dump(tot_data, f, ensure_ascii=False, indent=2)

    # Clean up cache files of every rank, including rank 0's own.
    for path in all_cache_paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass


def main():
    """Run ``args.num_pass`` inference passes and write the results as JSON.

    ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` are read from the environment
    (defaults: 0, 0, 1). Each pass uses seed ``args.seed + pass_id``, builds
    the dataset and dataloader, generates with vLLM, periodically saves the
    accumulated results to a per-rank cache file and, on rank 0, merges the
    rank caches into the pass's output file. The vLLM engine is created once,
    during the first pass.
    """
    args = parse_args()

    # Parse once; using ints everywhere keeps the cache file names written
    # by each rank identical to the ones rank 0 looks for when merging.
    rank = int(os.getenv('RANK', '0'))
    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))

    # Loop through multiple inference passes if num_pass > 1
    for pass_id in range(args.num_pass):
        current_seed = args.seed + pass_id
        set_seed(current_seed)

        current_output_path = _resolve_output_path(args, pass_id, current_seed)

        processor = AutoProcessor.from_pretrained(args.model, max_pixels=args.max_pixels, min_pixels=args.min_pixels)
        dataset = dataset_dict[args.dataset_type](
            args.data_path,
            current_output_path,
            rank,
            world_size,
            processor=processor,
            with_region=args.with_region,
            interleaved=args.interleaved,
            score_aggregation=args.score_aggregation,
            weighted_power_params=args.weighted_power_params,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
        )

        # Only initialize VLLM once
        if pass_id == 0:
            set_cuda_visible_devices(local_rank, args.tensor_parallel_size)
            # Return value (presumably a backup of the env) is never restored here.
            back_envs()
            llm, sampling_params = initialize_vllm(args)

        # Update sampling params with current seed
        sampling_params.seed = current_seed

        cache_path = f"{current_output_path}_rank{rank}_cache.json"
        # The lock is shared with save_data_to_cache; presumably it is held
        # there while the dict is serialized — verify in utils.
        data_lock = threading.Lock()
        data_dict = dataset.cache_dict
        save_thread = None

        save_interval = 10  # snapshot to the cache file every 10 steps

        progress = tqdm(dataloader, disable=rank != 0, desc=f"Pass {pass_id+1}/{args.num_pass}")
        for step, (prompts, metadata) in enumerate(progress):
            outputs = llm.generate(
                prompts,
                sampling_params=sampling_params,
                use_tqdm=False,
            )

            with data_lock:
                data_dict = dataset.post_process(metadata, outputs, data_dict)

            if step % save_interval == 0 and step > 0:
                # At most one background save at a time.
                if save_thread is not None:
                    save_thread.join()
                save_thread = threading.Thread(
                    target=save_data_to_cache,
                    args=(data_dict, cache_path, data_lock)
                )
                save_thread.start()

        # Final synchronous save so the cache reflects every processed batch.
        if save_thread is not None:
            save_thread.join()
        save_data_to_cache(data_dict, cache_path, data_lock)

        if rank == 0:
            _merge_rank_results(data_dict, current_output_path, world_size)
            print(f"\nPass {pass_id + 1} completed. Results saved to {current_output_path}")

    if rank == 0 and args.num_pass > 1:
        base_path = _strip_json_suffix(args.output_path)
        print(f"\n{'='*60}")
        print(f"All {args.num_pass} inference passes completed!")
        print("Result files:")
        for i in range(args.num_pass):
            print(f"  - {base_path}_pass{i + 1}.json")
        print(f"\nTo calculate avg{args.num_pass} statistics, run:")
        print("python calculate_statistics.py \\")
        print(f"  --result_files {base_path}_pass{{1..{args.num_pass}}}.json \\")
        print(f"  --avg_n {args.num_pass} \\")
        print(f"  --benchmark_dir {args.data_path}")
        print(f"{'='*60}")


# Script entry point: run the inference driver only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()

