import os
import torch
import traceback
from typing import Tuple, Any, Optional

def _select_cuda_device(gpu_id: int) -> None:
    """Make ``gpu_id`` the current CUDA device after validating it.

    Also logs the GPU's name/total memory and empties the CUDA cache.

    Raises:
        RuntimeError: If CUDA is unavailable or the current device did not change.
        ValueError: If ``gpu_id`` is not below ``torch.cuda.device_count()``.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    device_count = torch.cuda.device_count()
    if gpu_id >= device_count:
        raise ValueError(f"GPU {gpu_id} not available. Only {device_count} GPUs detected.")

    # Set the current device for subsequent CUDA operations and confirm it took effect.
    torch.cuda.set_device(gpu_id)
    current_device = torch.cuda.current_device()
    if current_device != gpu_id:
        raise RuntimeError(f"Failed to set current device to {gpu_id}, got {current_device}")

    props = torch.cuda.get_device_properties(gpu_id)
    print(f"Using GPU {gpu_id}: {props.name} with {props.total_memory/1024**3:.2f} GB memory")

    torch.cuda.empty_cache()


def _ensure_model_on_gpu(model: Any, device: str, gpu_id: int) -> Any:
    """Return ``model`` with its parameters on ``gpu_id``, moving it if needed.

    Only the first parameter's device is inspected.

    Raises:
        RuntimeError: If the model is still elsewhere after ``model.to(device)``.
    """
    model_device = next(model.parameters()).device
    print(f"Model loaded on device {model_device}")

    if model_device.index != gpu_id:
        print(f"Moving model from {model_device} to {device}")
        model = model.to(device)
        model_device = next(model.parameters()).device
        if model_device.index != gpu_id:
            raise RuntimeError(f"Failed to move model to {device}, still on {model_device}")
        print(f"Model now on device {model_device}")

    return model


def _check_forward_pass(model: Any, device: str, gpu_id: int) -> None:
    """Run a small no-grad forward pass and check the output lands on ``gpu_id``.

    Raises:
        RuntimeError: If the output tensor is on a different device.
    """
    with torch.no_grad():
        # A (1, 10) batch of token id 1: arbitrary input, only used to exercise the model.
        test_input = torch.ones((1, 10), dtype=torch.long, device=device)
        # Assumes the model's default forward returns a tensor (reads `.device`).
        test_output = model(test_input)
        output_device = test_output.device

        if output_device.index != gpu_id:
            raise RuntimeError(f"Model produced output on wrong device! Expected {device}, got {output_device}")

        print(f"Test forward pass successful on device {output_device}")


def _first_token_id(tokenizer: Any, text: str) -> int:
    """Return the first token id of ``text`` encoded without special tokens.

    Raises:
        ValueError: If the tokenizer produces no tokens for ``text``.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if not token_ids:
        raise ValueError(f"Tokenizer produced no tokens for {text!r}")
    return token_ids[0]


def _log_gpu_memory(gpu_id: int) -> None:
    """Print allocated and reserved CUDA memory (GiB) for ``gpu_id``."""
    allocated_memory = torch.cuda.memory_allocated(gpu_id) / 1024**3
    reserved_memory = torch.cuda.memory_reserved(gpu_id) / 1024**3
    print(f"GPU {gpu_id} memory after initialization: {allocated_memory:.2f} GB allocated, {reserved_memory:.2f} GB reserved")


def initialize_model_on_gpu(gpu_id: int, model_name: str, dtype: str = "bfloat16") -> Tuple[Any, Any, Any, Any]:
    """
    Initialize a model on a specific GPU with robust verification.

    Steps: validate and select the CUDA device, load the model onto it, make
    sure its parameters and a test forward pass's output are on that device,
    then build the token runners and log memory usage.

    Args:
        gpu_id: The GPU ID to initialize the model on (must be non-negative)
        model_name: Name of the model to load
        dtype: Data type for the model, forwarded to ``HookedTransformer.from_pretrained``

    Returns:
        Tuple of (model, tokenizer, runner, direct_runner)

    Raises:
        ValueError: If ``gpu_id`` is negative or exceeds the visible GPU count,
            or if "1"/"2" encode to no tokens.
        RuntimeError: If CUDA is unavailable, the device cannot be selected,
            the model/output ends up on the wrong device, or the model has no tokenizer.
        ImportError: If ``transformer_lens`` or ``custom_dreamy`` cannot be imported.

    Every exception is printed (with traceback) and then re-raised.
    """
    try:
        # torch.cuda.set_device() is a no-op for negative ids, so without this check a
        # negative id would surface later as a confusing "Failed to set current device" error.
        if gpu_id < 0:
            raise ValueError(f"GPU id must be non-negative, got {gpu_id}")

        # Import inside function to avoid circular imports
        from transformer_lens import HookedTransformer
        from custom_dreamy.runners import TlensTokenRunner, TlensTokenDiffRunner

        device = f"cuda:{gpu_id}"
        print(f"Initializing model on GPU {gpu_id} (device {device})")
        _select_cuda_device(gpu_id)

        print(f"Loading model {model_name} on device {device}")
        model = HookedTransformer.from_pretrained(
            model_name,
            dtype=dtype,
            device=device,
        )
        model = _ensure_model_on_gpu(model, device, gpu_id)

        tokenizer = model.tokenizer
        if tokenizer is None:
            raise RuntimeError(f"Model {model_name} has no tokenizer; cannot build token runners")

        _check_forward_pass(model, device, gpu_id)

        # NOTE(review): the runners' `token_pos*` kwargs receive vocabulary token ids
        # (from tokenizer.encode), not sequence positions — confirm against the runner API.
        token_id_one = _first_token_id(tokenizer, "1")
        token_id_two = _first_token_id(tokenizer, "2")
        runner = TlensTokenRunner(model, tokenizer, token_pos=token_id_one)
        direct_runner = TlensTokenDiffRunner(
            model,
            tokenizer,
            token_pos_a=token_id_one,
            token_pos_b=token_id_two,
        )

        _log_gpu_memory(gpu_id)

        return model, tokenizer, runner, direct_runner

    except Exception as e:
        print(f"Error initializing model on GPU {gpu_id}: {e}")
        traceback.print_exc()
        raise