import argparse
import os
import logging
import random
import numpy as np
import torch
import torch.multiprocessing as mp
from typing import List, Tuple
import pickle

from datasets import load_dataset, load_from_disk

from bestofn.estimators.uhead import UHeadEstimator
from bestofn.stat_calculators.load_stat_calculators import load_relevant_stat_calculators
from lm_polygraph import WhiteboxModel
from lm_polygraph.estimators import MaximumSequenceProbability, MeanTokenEntropy, Perplexity
from bestofn.bestofn_utils import bestofn
from configs.load_qwen import load_model as load_qwen_model, load_tokenizer as load_qwen_tokenizer

# Module-level logger shared by the driver and by every worker process.
log = logging.getLogger("bestofn_eval_multigpu")
# Configure the root logger so INFO-level progress messages are emitted.
logging.basicConfig(level=logging.INFO)


def get_parser():
    """Build the command-line parser for the multi-GPU best-of-n evaluation.

    Returns:
        An ``argparse.ArgumentParser`` with all supported options registered.
        ``--dataset-path``, ``--prompt-file`` and ``--save-path`` are required;
        every other option has a default.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``). The order
    # is the order in which options are registered (and shown in ``--help``).
    option_specs = (
        ("--dataset-path", dict(type=str, required=True,
                                help="Dataset to evaluate on (HuggingFace name or local path)")),
        ("--dataset-split", dict(type=str, default="train", help="Dataset split (e.g., test)")),
        ('--prompt-file', dict(type=str, required=True, help="Path to prompt template file")),
        ("--save-path", dict(type=str, required=True, help="Path to save the output .torch")),
        ("--save-frequency", dict(type=int, default=1, help="Save every n iterations")),
        ("--hf-cache", dict(type=str, default=None, help="Path to HuggingFace cache directory")),
        ("--model-path", dict(type=str, default="Qwen/Qwen3-1.7B",
                              help="Model name or path for generation")),
        ("--uhead-path", dict(type=str, default="user/uhead_Qwen3-1.7B_gsm8k", help="UHead HF path")),
        ("--device", dict(type=str, default="auto",
                          help="Device to use. Options: 'auto' (use all GPUs), '0,1,2' (specific GPUs), 'cuda:0' (single GPU)")),
        ("--n", dict(type=int, default=10, help="Number of completions to generate per input")),
        ("--temperature", dict(type=float, default=1.0, help="Generation temperature")),
        ("--max-new-tokens", dict(type=int, default=256, help="Generation max_new_tokens")),
        ("--subset", dict(type=int, default=None, help="Only process first N samples from dataset")),
        ("--seed", dict(type=int, default=42, help="Random seed")),
        # The next two flags write to the same destination, giving an on/off pair.
        ("--batch-processing", dict(action="store_true", default=True,
                                    help="Enable batch processing with adaptive batch sizing (default: True)")),
        ("--no-batch-processing", dict(dest="batch_processing", action="store_false",
                                       help="Disable batch processing, use single batch mode")),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


def parse_device_arg(device_arg):
    """Parse the ``--device`` argument and return the list of GPU IDs to use.

    Accepted forms:
      * ``"auto"``   - every GPU reported by ``torch.cuda.device_count()``
      * ``"cuda:N"`` - the single GPU ``N``
      * ``"N"``      - the single GPU ``N``
      * ``"A,B,C"``  - several GPUs; whitespace around entries is ignored

    Args:
        device_arg: Raw device string from the command line.

    Returns:
        List of non-negative integer GPU IDs, in the order given and without
        duplicates.

    Raises:
        ValueError: If ``"auto"`` is requested but no CUDA device is available,
            or if the argument is not one of the accepted forms.
    """
    spec = device_arg.strip()

    if spec == "auto":
        # Use all available GPUs
        num_gpus = torch.cuda.device_count()
        if num_gpus == 0:
            raise ValueError("No CUDA devices available. Please use CPU mode or check your GPU setup.")
        return list(range(num_gpus))

    if spec.startswith("cuda:"):
        # Single GPU specified as "cuda:N"
        tokens = [spec[len("cuda:"):]]
    else:
        # Either a single ID ("2") or a comma-separated list ("0,1,2")
        tokens = spec.split(",")

    invalid = ValueError(f"Invalid device argument: {device_arg}. Use 'auto', 'cuda:0', '0,1,2', etc.")

    gpu_ids = []
    for token in tokens:
        try:
            gpu_id = int(token.strip())
        except ValueError:
            # Same friendly message for every malformed form ("cuda:", "0,,1", "gpu", ...)
            raise invalid from None
        if gpu_id < 0:
            raise invalid
        # A repeated ID would start two workers that share the same GPU index
        # (and everything derived from it), so keep only the first occurrence.
        if gpu_id not in gpu_ids:
            gpu_ids.append(gpu_id)
    return gpu_ids


def load_model(model_path, device):
    """Load the tokenizer and model for ``model_path`` and wrap them for lm_polygraph.

    Args:
        model_path: Model name or path, passed to both Qwen loader helpers.
        device: Device specifier passed to the model loader.

    Returns:
        A ``WhiteboxModel`` built from the loaded model and tokenizer.
    """
    qwen_tokenizer = load_qwen_tokenizer(model_path)
    qwen_model = load_qwen_model(model_path, device)
    # This script only runs evaluation, so call eval() on the model before
    # wrapping it (presumably a torch module; eval() switches it to inference mode).
    qwen_model.eval()
    wrapped = WhiteboxModel(qwen_model, qwen_tokenizer)
    return wrapped


def split_dataset(dataset, num_splits):
    """Split a dataset into contiguous, balanced parts.

    The samples are divided as evenly as possible: part sizes differ by at
    most one, with the larger parts first. Whenever the dataset has at least
    ``num_splits`` samples, exactly ``num_splits`` parts are returned, so no
    consumer is left without work. Empty parts are never returned, so a
    dataset with fewer samples than ``num_splits`` yields one single-sample
    part per sample.

    Args:
        dataset: Object supporting ``len()`` and ``select(range)``.
        num_splits: Desired number of parts; must be at least 1.

    Returns:
        List of sub-datasets, in original sample order.

    Raises:
        ValueError: If ``num_splits`` is smaller than 1.
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be a positive integer, got {num_splits}")

    n_samples = len(dataset)
    # Every part gets base_size samples; the first `remainder` parts get one extra.
    base_size, remainder = divmod(n_samples, num_splits)

    splits = []
    start_idx = 0
    for part in range(num_splits):
        size = base_size + (1 if part < remainder else 0)
        if size == 0:
            # Fewer samples than requested parts: all remaining parts are empty.
            break
        splits.append(dataset.select(range(start_idx, start_idx + size)))
        start_idx += size

    return splits


def worker_process(gpu_id: int, dataset_split, args, start_idx: int, result_queue: mp.Queue):
    """
    Worker process that runs the complete bestofn pipeline on a GPU.
    Each worker is completely independent: it loads its own model, estimators
    and stat calculators, runs ``bestofn`` on its dataset split, and reports
    back through ``result_queue``.

    Note: gpu_id here is the logical GPU index after CUDA_VISIBLE_DEVICES remapping.

    Args:
        gpu_id: GPU index this worker runs on; also used to offset the random
            seeds and to name the worker's temporary results file.
        dataset_split: The part of the dataset this worker evaluates.
        args: Parsed command-line arguments (see ``get_parser``).
        start_idx: Index of the split's first sample in the full dataset;
            echoed back so the parent can order the merged results.
        result_queue: Queue on which exactly one tuple
            ``(gpu_id, start_idx, results)`` is put. ``results`` is ``None``
            when the worker failed or produced no results file.
    """
    # Configure logging for this worker with GPU rank prefix
    class GPURankFormatter(logging.Formatter):
        """Formatter that prefixes every formatted record with ``[GPU <id>]``."""

        def __init__(self, gpu_id, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.gpu_id = gpu_id

        def format(self, record):
            # Add GPU rank to the beginning of each log message
            original_msg = super().format(record)
            return f"[GPU {self.gpu_id}] {original_msg}"

    # Update all handlers to use GPU-aware formatter
    for handler in logging.root.handlers:
        original_formatter = handler.formatter
        if original_formatter:
            # Preserve the original format but add GPU rank
            gpu_formatter = GPURankFormatter(
                gpu_id,
                fmt=getattr(original_formatter, '_fmt', None),
                datefmt=original_formatter.datefmt
            )
        else:
            # Use default format with GPU rank
            gpu_formatter = GPURankFormatter(gpu_id)
        handler.setFormatter(gpu_formatter)

    try:
        # Set device for this process
        # gpu_id is already the logical index (0, 1, 2...) not the physical GPU ID
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(gpu_id)

        # Set random seeds with GPU offset so workers do not share a random stream
        worker_seed = args.seed + gpu_id
        random.seed(worker_seed)
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        torch.cuda.manual_seed(worker_seed)

        log.info(f"Starting with {len(dataset_split)} samples (indices {start_idx}-{start_idx + len(dataset_split) - 1})")

        # Load model on this GPU
        log.info("Loading model")
        model = load_model(args.model_path, device)

        # Load estimators
        log.info("Loading estimators")
        estimators = [
            MaximumSequenceProbability(),
            MeanTokenEntropy(),
            Perplexity(),
            UHeadEstimator(reduction='min'),
        ]

        # Load stat calculators (regular single-GPU version)
        log.info("Loading stat calculators")
        stat_calculators = load_relevant_stat_calculators(
            estimators, model, args.uhead_path,
            prompt_path=args.prompt_file,
            hf_cache=args.hf_cache,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            batch_processing=args.batch_processing,
        )

        # Create temporary save path for this worker by inserting a per-GPU
        # suffix before the file extension. Splitting on the extension (rather
        # than replacing the substring '.torch') keeps the path distinct from
        # the final save path and from other workers' paths whatever the
        # extension is, and never touches directory names.
        save_root, save_ext = os.path.splitext(args.save_path)
        worker_save_path = f"{save_root}_gpu{gpu_id}_temp{save_ext}"

        # Run bestofn on this GPU's dataset split
        log.info("Running bestofn evaluation")
        bestofn(
            dataset_split, model, estimators, stat_calculators,
            worker_save_path, args.save_frequency,
            args.n, args.max_new_tokens,
        )

        # Load the results and send back via queue
        if os.path.exists(worker_save_path):
            # weights_only=False unpickles arbitrary objects; acceptable here
            # because the file was just produced by this worker itself.
            results = torch.load(worker_save_path, weights_only=False)
            result_queue.put((gpu_id, start_idx, results))
            os.remove(worker_save_path)  # Clean up temporary file
            log.info("Completed successfully")
        else:
            log.error("No results file found")
            result_queue.put((gpu_id, start_idx, None))

    except Exception as e:
        # Top-level boundary of the worker: log with the full traceback so the
        # failure can be diagnosed, then report the failure to the parent.
        log.exception(f"Error during processing: {e}")
        result_queue.put((gpu_id, start_idx, None))


def main(args):
    """Run the best-of-n evaluation across several GPUs and merge the results.

    Steps: seed the random generators, resolve the GPU list, load (and
    optionally truncate) the dataset, split it across GPUs, start one
    ``worker_process`` per split, collect each worker's results, and save the
    merged list to ``args.save_path`` in original sample order.

    Args:
        args: Parsed command-line arguments (see ``get_parser``).

    Raises:
        ValueError: Propagated from ``parse_device_arg`` for an invalid
            ``--device`` value or when no CUDA device is available.
    """
    import queue  # local import: only needed for the queue.Empty exception

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    log.info(f"Set random seed to {args.seed}")

    # Parse GPU configuration
    physical_gpu_ids = parse_device_arg(args.device)
    log.info(f"Requested GPUs from command line: {physical_gpu_ids}")

    # When CUDA_VISIBLE_DEVICES is set, we need to use logical indices
    # The physical GPU IDs are handled by CUDA_VISIBLE_DEVICES
    if os.environ.get('CUDA_VISIBLE_DEVICES'):
        log.info(f"CUDA_VISIBLE_DEVICES is set to: {os.environ['CUDA_VISIBLE_DEVICES']}")
        # Use logical indices: 0, 1, 2, ... based on number of requested GPUs
        gpu_ids = list(range(len(physical_gpu_ids)))
        log.info(f"Using logical GPU indices: {gpu_ids}")
    else:
        # No CUDA_VISIBLE_DEVICES set, use physical GPU IDs directly
        gpu_ids = physical_gpu_ids
        log.info(f"Using physical GPU IDs: {gpu_ids}")

    # A bare file name has an empty dirname, and os.makedirs('') raises, so
    # only create the directory when the save path actually has one.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    log.info(f"Loading dataset: {args.dataset_path} ({args.dataset_split})")

    # Check if it's a local dataset saved with save_to_disk() (e.g., from shard_dataset.py)
    is_saved_dataset = os.path.isdir(args.dataset_path) and os.path.exists(
        os.path.join(args.dataset_path, "dataset_info.json")
    )
    if is_saved_dataset:
        log.info("Loading from local directory (saved with save_to_disk)...")
        dataset = load_from_disk(args.dataset_path)
    else:
        # HuggingFace Hub dataset, or a local directory that is not a saved dataset
        dataset = load_dataset(args.dataset_path, split=args.dataset_split, cache_dir=args.hf_cache)

    if args.subset is not None:
        log.info(f"Using subset: first {args.subset} samples")
        dataset = dataset.select(range(min(args.subset, len(dataset))))

    # Split dataset across GPUs
    dataset_splits = split_dataset(dataset, len(gpu_ids))
    log.info(f"Split {len(dataset)} samples across {len(gpu_ids)} GPUs")

    # Set up multiprocessing
    mp.set_start_method('spawn', force=True)
    result_queue = mp.Queue()

    # Start worker processes. Workers are tracked by the start index of their
    # split, which is unique per worker and is echoed back with the results.
    workers = {}  # start_idx -> (gpu_id, process)
    cumulative_idx = 0
    for gpu_id, split in zip(gpu_ids, dataset_splits):
        p = mp.Process(
            target=worker_process,
            args=(gpu_id, split, args, cumulative_idx, result_queue)
        )
        p.start()
        workers[cumulative_idx] = (gpu_id, p)
        cumulative_idx += len(split)

    # Collect results from all workers. The queue is read before joining so
    # that workers are never blocked on an unread queue. Polling with a
    # timeout lets us notice a worker that died (e.g. was killed) without
    # posting a result, instead of blocking forever.
    poll_seconds = 5.0
    results_by_start = {}  # start_idx -> results (None on failure)
    pending = set(workers)
    seen_dead = set()
    while pending:
        try:
            _, start_idx, results = result_queue.get(timeout=poll_seconds)
        except queue.Empty:
            for idx in sorted(pending):
                worker_gpu, proc = workers[idx]
                if proc.is_alive():
                    continue
                if idx in seen_dead:
                    # Already dead at the previous poll and nothing arrived
                    # during a whole further poll interval: it never reported.
                    log.error(f"Worker on GPU {worker_gpu} exited with code {proc.exitcode} without reporting results")
                    results_by_start[idx] = None
                    pending.discard(idx)
                else:
                    # Give a result that was posted just before exit one more
                    # poll interval to arrive before declaring failure.
                    seen_dead.add(idx)
            continue
        results_by_start[start_idx] = results
        pending.discard(start_idx)

    # Wait for all processes to complete
    for _, proc in workers.values():
        proc.join()

    # Merge results in correct order
    log.info("Merging results from all GPUs")
    all_results = []

    # Iterate by start index to maintain the original sample order
    for start_idx in sorted(results_by_start):
        results = results_by_start[start_idx]
        if results is not None:
            all_results.extend(results)
        else:
            log.error(f"Missing results from GPU with start_idx {start_idx}")

    # Save merged results
    if all_results:
        torch.save(all_results, args.save_path)
        log.info(f"Saved {len(all_results)} total results to {args.save_path}")
    else:
        log.error("No results to save!")

    log.info("Multi-GPU evaluation complete!")


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(get_parser().parse_args())