import os
import json
import threading
import hashlib
import time
from pathlib import Path
from datetime import datetime

from vllm import LLM, SamplingParams
import torch
from torch.utils.data import DataLoader
from transformers import AutoProcessor
try:
    from transformers import Qwen3VLForConditionalGeneration
except ImportError:
    from transformers import Qwen2VLForConditionalGeneration as Qwen3VLForConditionalGeneration
from peft import PeftModel
from tqdm import tqdm

from utils import parse_args, parse_llm_args, back_envs, set_seed, save_data_to_cache
from dataset import dataset_dict, collate_fn


def merge_lora_if_needed(model_path, lora_path):
    """Merge LoRA weights into the base model and cache the merged checkpoint.

    Args:
        model_path: Base model path or hub id, passed to ``from_pretrained``.
        lora_path: Path of the LoRA adapter to merge, or ``None`` to skip.

    Returns:
        ``model_path`` unchanged when ``lora_path`` is ``None``; otherwise the
        directory under ``<torch hub dir>/MMRB2/`` that holds the merged model
        weights and the base model's processor.

    The merged model is first saved into a uniquely named sibling directory
    and renamed into place only after model and processor are both written.
    An interrupted merge therefore never leaves a half-written directory that
    later runs would mistake for a valid cache entry, and concurrent callers
    cannot interleave writes into the same directory.
    """
    import shutil
    import uuid

    if lora_path is None:
        return model_path

    root_dir = torch.hub.get_dir()  # default: ~/.cache/torch/hub
    # normpath drops trailing separators, so "ckpt/adapter/" still yields the
    # readable name "adapter" instead of an empty basename.
    model_name = os.path.basename(os.path.normpath(model_path))
    lora_filename = os.path.splitext(os.path.basename(os.path.normpath(lora_path)))[0]
    # The hash separates adapters that share a basename. It is computed over
    # the raw path string so cache directories from earlier runs stay valid.
    lora_hash = hashlib.md5(lora_path.encode()).hexdigest()[:8]
    lora_identifier = f"{lora_filename}_{lora_hash}"
    cache_dir = os.path.join(root_dir, "MMRB2", f"{model_name}_merged_lora_{lora_identifier}")

    if os.path.exists(cache_dir):
        print(f"Using cached merged model from {cache_dir}")
        return cache_dir

    print(f"Merging LoRA from {lora_path} to {model_path}...")
    print(f"Saving merged model to {cache_dir}")
    start_time = time.time()

    # Unique per call, so parallel ranks never share a staging directory.
    tmp_dir = f"{cache_dir}.tmp-{uuid.uuid4().hex}"
    try:
        # Load the base model on CPU and fold the adapter weights into it.
        model = Qwen3VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cpu"
        )
        model = PeftModel.from_pretrained(model, lora_path)
        model = model.merge_and_unload()
        model.save_pretrained(tmp_dir)

        # The processor comes from the base model; saving it alongside lets
        # the cache directory be loaded as a self-contained checkpoint.
        processor = AutoProcessor.from_pretrained(model_path)
        processor.save_pretrained(tmp_dir)
    except BaseException:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    try:
        os.rename(tmp_dir, cache_dir)
    except OSError:
        # The rename fails if cache_dir appeared meanwhile (another process
        # finished first). That result is complete, so discard ours.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        if not os.path.isdir(cache_dir):
            raise
        print(f"Another process already cached the merged model at {cache_dir}")

    print(f"LoRA merging completed in {time.time() - start_time:.2f} seconds")
    return cache_dir


def initialize_vllm(args):
    """Create the vLLM engine plus the sampling configuration for generation.

    Side effect: ``args.model`` is replaced by the path returned from
    ``merge_lora_if_needed``. That is the merged-LoRA directory when
    ``args.lora_path`` is set, and the original value otherwise. The engine
    kwargs are built from ``args`` afterwards, so they see that path.

    Returns:
        A ``(LLM, SamplingParams)`` pair.
    """
    # Both call arguments are evaluated before the assignment, so the merge
    # sees the original model path.
    args.model = merge_lora_if_needed(args.model, args.lora_path)

    engine = LLM(**parse_llm_args(args))
    params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.max_tokens,
    )
    return engine, params


def set_cuda_visible_devices(local_rank, tensor_parallel_size, offset=0):
    """Restrict this process to the GPUs of its tensor-parallel group.

    Local rank ``r`` receives ``tensor_parallel_size`` devices, starting at
    position ``r * tensor_parallel_size``. The result is written back to
    ``CUDA_VISIBLE_DEVICES``.

    Args:
        local_rank: Rank of this process on the node.
        tensor_parallel_size: Number of GPUs assigned to each process.
        offset: First physical GPU id used when ``CUDA_VISIBLE_DEVICES`` is
            unset. It is ignored when the variable is already set.

    Raises:
        ValueError: If ``CUDA_VISIBLE_DEVICES`` is set but lists no devices.
            Also raised if the list is too short for this rank and its first
            entry is not an integer id.
    """
    start = local_rank * tensor_parallel_size
    stop = start + tensor_parallel_size
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)

    if visible is None:
        devices = [str(offset + i) for i in range(start, stop)]
    else:
        ids = [tok.strip() for tok in visible.split(',') if tok.strip()]
        if not ids:
            raise ValueError("CUDA_VISIBLE_DEVICES is set but lists no devices")
        if len(ids) >= stop:
            # Index into the inherited list so non-contiguous selections such
            # as "2,5,7,8" (or GPU UUIDs) are honoured exactly.
            devices = ids[start:stop]
        else:
            # Legacy behaviour for lists too short for this rank: treat the
            # first entry as a starting GPU id and count up contiguously.
            base = int(ids[0])
            devices = [str(base + i) for i in range(start, stop)]

    os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(devices)
    print(f"local rank {local_rank}, CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")


# Upper bound (seconds) on how long rank 0 waits for the other ranks to
# publish their final caches before merging whatever is on disk. Arbitrary;
# raise it if shards can take much longer than rank 0's own shard.
_RANK_WAIT_TIMEOUT_S = 2 * 60 * 60
# Interval (seconds) between checks for other ranks' done markers.
_RANK_WAIT_POLL_S = 10


def _rank_cache_path(output_path, rank):
    """Return the per-rank cache file path derived from the final output path."""
    return f"{output_path}_rank{rank}_cache.json"


def _rank_done_path(output_path, rank):
    """Return the marker file a rank creates once its cache file is final."""
    return _rank_cache_path(output_path, rank) + ".done"


def _remove_if_exists(path):
    """Delete ``path``; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _timestamped_output_path(output_path):
    """Insert a ``_YYYYmmdd_HHMMSS`` timestamp before a trailing ``.json``."""
    # NOTE(review): each rank evaluates this independently. Ranks that reach
    # it in different seconds get different output paths, and rank 0 will not
    # find their caches. Exporting MMRB2_TIMESTAMP once in the launcher gives
    # every rank the same value. Without it, multi-rank timestamped runs can
    # still diverge (TODO).
    timestamp = os.environ.get("MMRB2_TIMESTAMP") or datetime.now().strftime("%Y%m%d_%H%M%S")
    if output_path.endswith('.json'):
        return f"{output_path[:-5]}_{timestamp}.json"
    return f"{output_path}_{timestamp}"


def _wait_for_ranks(output_path, world_size, timeout_s=_RANK_WAIT_TIMEOUT_S, poll_s=_RANK_WAIT_POLL_S):
    """Poll until ranks ``1..world_size-1`` have written their done markers.

    Returns:
        ``(finished, unfinished)``: sorted rank lists whose marker did / did
        not appear before ``timeout_s`` elapsed.
    """
    pending = set(range(1, world_size))
    deadline = time.monotonic() + timeout_s
    while True:
        pending = {r for r in pending if not os.path.exists(_rank_done_path(output_path, r))}
        if not pending or time.monotonic() >= deadline:
            break
        time.sleep(poll_s)
    if pending:
        print(f"WARNING: ranks {sorted(pending)} did not finish within {timeout_s}s; "
              f"merging their partial caches if present.")
    finished = sorted(set(range(1, world_size)) - pending)
    return finished, sorted(pending)


def _load_json(path):
    """Read a UTF-8 JSON file."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _merge_rank_results(output_path, world_size, data_dict):
    """Rank 0 only: merge other ranks' caches into ``data_dict`` and write the output.

    Caches of ranks that finished (and rank 0's own) are deleted afterwards.
    Caches of ranks that timed out are merged best-effort and left on disk.
    """
    from concurrent.futures import ThreadPoolExecutor

    finished, unfinished = _wait_for_ranks(output_path, world_size)
    unfinished_set = set(unfinished)

    tot_data = data_dict
    rank_files = [(r, _rank_cache_path(output_path, r)) for r in range(1, world_size)]
    rank_files = [(r, p) for r, p in rank_files if os.path.exists(p)]

    # Load rank files in parallel; merge in rank order.
    if rank_files:
        with ThreadPoolExecutor(max_workers=min(8, len(rank_files))) as executor:
            futures = [(r, executor.submit(_load_json, p)) for r, p in rank_files]
            for r, future in futures:
                try:
                    rank_data = future.result()
                except (OSError, json.JSONDecodeError) as err:
                    # A finished rank's cache must be readable. A still-running
                    # rank's file may be mid-write, so skip it with a warning.
                    if r not in unfinished_set:
                        raise
                    print(f"WARNING: could not read partial cache of rank {r}: {err}")
                    continue
                tot_data.update(rank_data)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(tot_data, f, ensure_ascii=False, indent=2)

    # Only delete caches whose contents are known to be final.
    for r in [0, *finished]:
        _remove_if_exists(_rank_cache_path(output_path, r))
        _remove_if_exists(_rank_done_path(output_path, r))
    if unfinished:
        print(f"Keeping caches of unfinished ranks {unfinished} for inspection/resume.")


def main():
    """Run sharded MMRB2 edit inference with vLLM and merge per-rank results.

    Rank information comes from the ``RANK`` / ``LOCAL_RANK`` / ``WORLD_SIZE``
    environment variables (defaults ``0`` / ``0`` / ``1``). Each rank
    processes its dataset shard and periodically checkpoints results to
    ``<output>_rank<r>_cache.json``. When it finishes, every rank except 0
    drops a ``.done`` marker next to its cache. Rank 0 waits for those
    markers, merges all caches into ``<output>``, and deletes the caches it
    merged.
    """
    args = parse_args()

    rank = int(os.getenv('RANK', '0'))
    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))

    set_seed(args.seed)

    current_output_path = args.output_path
    if args.add_timestamp:
        current_output_path = _timestamped_output_path(current_output_path)
        print(f"\n{'='*60}")
        print(f"Using timestamped output: {current_output_path}")
        print(f"{'='*60}\n")

    # A marker left behind by an earlier run on the same output path would
    # let rank 0 merge before this run's cache is final.
    _remove_if_exists(_rank_done_path(current_output_path, rank))

    processor = AutoProcessor.from_pretrained(args.model, max_pixels=args.max_pixels, min_pixels=args.min_pixels)
    dataset = dataset_dict["mmrb2_edit"](
        args.data_path,
        current_output_path,
        rank,
        world_size,
        processor=processor,
        with_region=args.with_region,
        interleaved=args.interleaved,
        score_aggregation=args.score_aggregation,
        weighted_power_params=args.weighted_power_params,
        single_image_only=args.single_image_only,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )

    set_cuda_visible_devices(local_rank, args.tensor_parallel_size)
    # NOTE(review): presumably backs up/clears torchrun env vars so vLLM does
    # not see them; the return value is not used here — confirm in utils.
    back_envs()
    llm, sampling_params = initialize_vllm(args)
    sampling_params.seed = args.seed

    cache_path = _rank_cache_path(current_output_path, rank)
    data_lock = threading.Lock()
    data_dict = dataset.cache_dict
    save_thread = None
    save_interval = 10  # batches between background cache saves

    for step, (prompts, metadata) in enumerate(tqdm(dataloader, disable=bool(rank), desc="MMRB2 Inference")):
        outputs = llm.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=False,
        )

        with data_lock:
            data_dict = dataset.post_process(metadata, outputs, data_dict)

        # Periodic background save; at most one save thread runs at a time.
        if step > 0 and step % save_interval == 0:
            if save_thread is not None:
                save_thread.join()
            save_thread = threading.Thread(
                target=save_data_to_cache,
                args=(data_dict, cache_path, data_lock),
            )
            save_thread.start()

    if save_thread is not None:
        save_thread.join()
    save_data_to_cache(data_dict, cache_path, data_lock)

    if rank != 0:
        # The cache above is written synchronously, so it is complete by now.
        Path(_rank_done_path(current_output_path, rank)).touch()
        return

    print(f"\nMerging results from {world_size} ranks...")
    _merge_rank_results(current_output_path, world_size, data_dict)
    print(f"\nInference completed. Results saved to {current_output_path}")


if __name__ == "__main__":
    # Start inference only when this file is executed as a script; importing
    # the module leaves main() uncalled.
    main()

