"""
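UMAP of per-layer activations
Age sweep 1–100 with three subject variants

Command-line driver: loads the model found under --model-dir, builds the
per-topic analysis objects (age, gender, disease, symptom, drug, dosage,
progression) and runs whichever analyses the command-line flags select.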
"""
import math, torch, umap, matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig
from functools import partial
import numpy as np
import argparse
import random
import os
import socket


from dosageAct import DosageAct
from ageAct import AgeAct
from genderAct import GenderAct
from diseaseAct import DiseaseAct
from drugAct import DrugAct
from top100DrugAct import Top100DrugAct
from symptomAct import SymptomAct
from progressionAct import DiseaseProgressionAct
from maps import main_map, main_map_continuous, metrics_table_all, plot_selected_metrics_per_layer, plot_lesioning_metrics_per_layer
# Temperature value passed to every analysis object's constructor.
TEMPERATURE = 0.7
# Age range limits passed to AgeAct.
MIN_AGE, MAX_AGE = 1, 100
# Set to True to print model replies.
PRINT_REPLIES = True

# ═══════════ Argparse to get model directory ══════════════════════════════════
parser = argparse.ArgumentParser(description="UMAP of per-layer activations for age sweep")
parser.add_argument(
    "--model-dir", required=True,
    help="Path to pretrained model directory"
)
parser.add_argument("--local_rank", type=int, default=4, help="Used for distributed training")

parser.add_argument("--deepspeed", action='store_true', default=False, help="Whether to use DeepSpeed (deprecated)")
# A plain ``store_true`` flag whose default is already True can never be
# turned off, which made the DeepSpeed path unreachable. BooleanOptionalAction
# keeps ``--use_accelerate`` (still on by default) and adds
# ``--no-use_accelerate`` to disable it.
parser.add_argument(
    "--use_accelerate", action=argparse.BooleanOptionalAction, default=True,
    help="Whether to use Accelerate + FSDP for distributed training; disable with --no-use_accelerate",
)

# Plain on/off switches, registered in the order they appear in --help.
# First group: which analyses to run; second group: which subject areas.
_SWITCHES = (
    ("--do_umap", "Whether to run UMAP"),
    ("--do_saliency", "Whether to run saliency"),
    ("--do_maps", "Whether to plot LLM maps"),
    ("--do_heatmap", "Whether to plot heatmap"),
    ("--do_lesioning", "Whether to run layer lesioning analysis"),
    ("--do_lesioning_finegrained", "Whether to run fine-grained layer lesioning analysis"),
    ("--do_activation_patching", "Whether to run activation patching analysis"),
    ("--no_cache", "Whether to not use the cached results. By default, the cached results are used."),
    ("--do_activation_patching_finegrained", "Whether to run fine-grained activation patching analysis"),
    ("--metrics_table", "Generate LaTeX metrics table across models and exit"),
    ("--age", "Whether to run age analysis"),
    ("--gender", "Whether to run gender analysis"),
    ("--disease", "Whether to run disease analysis"),
    ("--symptom", "Whether to run symptom analysis"),
    ("--drug", "Whether to run drug analysis"),
    ("--prog", "Whether to run progression analysis"),
    ("--dosage", "Whether to run dosage analysis"),
)
for _flag, _help in _SWITCHES:
    parser.add_argument(_flag, action='store_true', help=_help)

# The output directory is created before the command line is parsed, so it
# exists even when parsing exits early.
os.makedirs("results", exist_ok=True)

args = parser.parse_args()
MODEL_DIR = args.model_dir
# Last path component of the model directory, ignoring trailing slashes.
MODEL_NAME = MODEL_DIR.rstrip("/").rsplit("/", 1)[-1]

# Which analyses were requested.
do_umap, do_saliency, do_maps, do_heatmap = (
    args.do_umap, args.do_saliency, args.do_maps, args.do_heatmap,
)
do_lesioning, do_lesioning_finegrained = (
    args.do_lesioning, args.do_lesioning_finegrained,
)
do_activation_patching, do_activation_patching_finegrained = (
    args.do_activation_patching, args.do_activation_patching_finegrained,
)
do_metrics_table = args.metrics_table

# Which subject areas were requested.
do_age, do_gender, do_disease = args.age, args.gender, args.disease
do_symptom, do_drug, do_prog, do_dosage = (
    args.symptom, args.drug, args.prog, args.dosage,
)

# Cached results are reused unless --no_cache was given.
load_cached = not args.no_cache
print('load_cached', load_cached)

if do_metrics_table:
    # Fast path: build the tables/plots from cached results and stop before
    # any model is loaded.
    print("\nGenerating metrics table and selected per-layer metrics grid from caches…")
    for _render in (
        metrics_table_all,
        plot_selected_metrics_per_layer,
        plot_lesioning_metrics_per_layer,
    ):
        _render()
    raise SystemExit(0)

# GPU budget is chosen by hostname: the two hosts named below get 4 GPUs at
# 8GB each (reduced from 10GB), every other host gets 8 GPUs at 20GB each.
if socket.gethostname() not in ("dl", "a100-4gpu-east5"):
    print("Using 8 GPUs")
    nr_gpu = 8
    _per_gpu_budget = "20GB"
else:
    print("Using 4 GPUs")
    nr_gpu = 4
    _per_gpu_budget = "8GB"

# Tensor-parallel degree always equals the GPU count.
tp_size = nr_gpu
# Per-GPU memory cap keyed by GPU index, in the form `max_memory` expects.
max_mem = {_gpu_index: _per_gpu_budget for _gpu_index in range(nr_gpu)}

# ═══════════ 1. Load model ════════════════════════════════════════════════
#
# Each branch of the dispatch below binds `tokenizer`, `processor`, `config`
# and `model` at module level. Branches that shard the model by hand also bind
# `num_layers`, `layers_per_gpu` and `device_map`.


def _resolve_num_layers(config, fallback):
    """Return the number of transformer layers declared by *config*.

    ``num_hidden_layers``, ``num_layers`` and ``n_layer`` are tried in that
    order, first on *config* itself and then on its nested ``text_config``
    when one is present. If none of them yields a value, a warning is printed
    and *fallback* is returned.
    """
    nested = getattr(config, "text_config", None)
    for candidate in (config, nested):
        if candidate is None:
            continue
        for attr in ("num_hidden_layers", "num_layers", "n_layer"):
            value = getattr(candidate, attr, None)
            if value is not None:
                return value
    print("Warning: Could not determine number of layers from config, using fallback")
    return fallback


def _build_device_map(num_layers, n_gpus, *, lm_head_gpu, first_layer_gpu=0, prefix="model"):
    """Build an explicit module-to-GPU map for a decoder stack.

    Args:
        num_layers: Number of ``{prefix}.layers.{i}`` entries to place.
        n_gpus: GPUs are numbered ``0 .. n_gpus - 1``.
        lm_head_gpu: GPU index for ``lm_head``.
        first_layer_gpu: First GPU that receives decoder layers; the token
            embeddings are placed there as well.
        prefix: Module-name prefix of the embeddings, layers and final norm.

    Returns:
        ``(device_map, layers_per_gpu)``.
    """
    last_gpu = n_gpus - 1
    usable_gpus = max(1, n_gpus - first_layer_gpu)
    # Ceiling division spreads the remainder across GPUs instead of piling it
    # onto the last one; max(1, ...) keeps the divisor non-zero when there are
    # fewer layers than GPUs.
    layers_per_gpu = max(1, math.ceil(num_layers / usable_gpus))

    device_map = {
        f"{prefix}.embed_tokens": first_layer_gpu,
        "lm_head": lm_head_gpu,
        # The final norm always sits on the last GPU.
        f"{prefix}.norm": last_gpu,
    }
    for layer_idx in range(num_layers):
        gpu = min(first_layer_gpu + layer_idx // layers_per_gpu, last_gpu)
        device_map[f"{prefix}.layers.{layer_idx}"] = gpu
    return device_map, layers_per_gpu


def _report_layer_distribution(device_map, layer_prefix, header, unit="layers"):
    """Print *header*, then how many *layer_prefix* entries sit on each GPU."""
    print(header)
    per_gpu = {}
    for key, gpu in device_map.items():
        if key.startswith(layer_prefix):
            per_gpu[gpu] = per_gpu.get(gpu, 0) + 1
    for gpu in sorted(per_gpu):
        print(f"  GPU {gpu}: {per_gpu[gpu]} {unit}")


def _plan_standard_sharding(fallback_layers, *, head_with_embeddings, label=""):
    """Read the config from MODEL_DIR and shard ``model.layers.*`` over all GPUs.

    Args:
        fallback_layers: Layer count to use when the config declares none.
        head_with_embeddings: When True, ``lm_head`` shares GPU 0 with the
            embeddings (keeps tied weights on one device); otherwise it goes
            on the last GPU next to the final norm.
        label: Extra text inserted before "layers" in the printed header.

    Returns:
        ``(config, num_layers, layers_per_gpu, device_map)``.
    """
    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _resolve_num_layers(config, fallback_layers)
    head_gpu = 0 if head_with_embeddings else nr_gpu - 1
    device_map, layers_per_gpu = _build_device_map(num_layers, nr_gpu, lm_head_gpu=head_gpu)
    _report_layer_distribution(
        device_map,
        "model.layers.",
        f"Distributing {num_layers} {label}layers across {nr_gpu} GPUs ({layers_per_gpu} layers per GPU)",
    )
    return config, num_layers, layers_per_gpu, device_map


if "Llama-4" in MODEL_NAME:
    # → use the new Chat-capable Llama-4 loader
    from transformers import AutoProcessor, Llama4ForConditionalGeneration

    processor = AutoProcessor.from_pretrained(MODEL_DIR)
    tokenizer = processor.tokenizer

    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _resolve_num_layers(config, 32)
    # This branch shards over GPUs 0-3 only, independent of nr_gpu.
    device_map, layers_per_gpu = _build_device_map(num_layers, 4, lm_head_gpu=3)

    model = Llama4ForConditionalGeneration.from_pretrained(
        MODEL_DIR,
        attn_implementation="flex_attention",
        device_map=device_map,
        max_memory=max_mem,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).eval()
    print(f"Loaded Llama-4 model: {MODEL_NAME} with explicit device map")
elif "Gemma-3" in MODEL_NAME:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _resolve_num_layers(config, 62)

    # GPU 0 is kept for the vision stack; embeddings and lm_head share GPU 1
    # to avoid cross-device weight tying; text layers use GPUs 1..nr_gpu-1.
    _text_map, layers_per_gpu = _build_device_map(
        num_layers, nr_gpu,
        lm_head_gpu=1, first_layer_gpu=1, prefix="model.language_model",
    )
    device_map = {
        "model.vision_tower": 0,
        "model.multi_modal_projector": 0,
        **_text_map,
    }
    _report_layer_distribution(
        device_map,
        "model.language_model.layers.",
        f"Distributing {num_layers} Gemma-3 text layers across GPUs 1..{nr_gpu-1} ({layers_per_gpu} layers per GPU)",
        unit="text layers",
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        offload_folder=None,
        low_cpu_mem_usage=False,
        trust_remote_code=True,
        max_memory=max_mem,
    )
    processor = None
    print(f"Loaded Gemma-3 model: {MODEL_NAME} with explicit device map (vision on GPU0, text sharded)")
elif "Gemma" in MODEL_NAME and "MedGemma" not in MODEL_NAME:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    config = AutoConfig.from_pretrained(MODEL_DIR)

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map="sequential",
        offload_folder=None,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    processor = None
    # The message names the device map that is actually requested above.
    print(f"Loaded Gemma model: {MODEL_NAME} with sequential device map")
elif "gpt-oss" in MODEL_NAME:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

    # Explicit device map to prevent disk offloading (36 = GPT-OSS-120B).
    config, num_layers, layers_per_gpu, device_map = _plan_standard_sharding(
        36, head_with_embeddings=True, label="GPT-OSS ",
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
        offload_folder=None,  # Disable disk offloading
        low_cpu_mem_usage=False,  # Disable low CPU memory usage to prevent offloading
        max_memory=max_mem,
    ).eval()
    processor = None
    print(f"Loaded GPT-OSS model: {MODEL_NAME} with explicit device map and no disk offloading")
elif "70B" in MODEL_NAME:  # Llama-3.3-70B-Instruct
    # → fall back to the Llama-3 / standard LM API.
    # Llama-4 names never reach this branch, so the fast tokenizer is always used.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )

    config, num_layers, layers_per_gpu, device_map = _plan_standard_sharding(
        32, head_with_embeddings=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map=device_map,
        trust_remote_code=True,
        max_memory=max_mem,
    ).eval()
    processor = None
    print(f"Loaded Model: {MODEL_NAME} with explicit device map")
elif "Qwen" in MODEL_NAME or "MedGemma" in MODEL_NAME:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )

    # Explicit device map to prevent disk offloading.
    config, num_layers, layers_per_gpu, device_map = _plan_standard_sharding(
        32, head_with_embeddings=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,  # Use bfloat16 to save memory
        device_map=device_map,
        trust_remote_code=True,
        offload_folder=None,  # Disable disk offloading to prevent issues during lesioning
        low_cpu_mem_usage=False,  # Disable low CPU memory usage to prevent offloading
        max_memory=max_mem,
    ).eval()
    processor = None
    print(f"Loaded {MODEL_NAME} with explicit device map and no disk offloading")
elif "DeepSeek" in MODEL_NAME:  # DeepSeek models (including DeepSeek-R1)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )

    # 61 = default for DeepSeek-R1
    config, num_layers, layers_per_gpu, device_map = _plan_standard_sharding(
        61, head_with_embeddings=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
        max_memory=max_mem,
        offload_folder=None,  # Disable disk offloading to prevent issues during lesioning
        low_cpu_mem_usage=False,  # Disable low CPU memory usage to prevent offloading
    ).eval()
    processor = None
    print(f"Loaded DeepSeek model: {MODEL_NAME} with explicit device map")
elif "grok" in MODEL_NAME.lower():  # Grok models
    # Best-effort tokenizer loading: remote-code tokenizer first, then the
    # plain AutoTokenizer, and finally the GPT-2 tokenizer as a stand-in.
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_DIR, trust_remote_code=True
        )
    except Exception as e:
        print(f"Failed to load tokenizer with AutoTokenizer: {e}")
        print("Trying alternative tokenizer loading...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        except Exception as e2:
            print(f"Failed to load tokenizer with alternative method: {e2}")
            print("Using GPT-2 tokenizer as fallback for Grok model...")
            from transformers import GPT2Tokenizer
            tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            # Add padding token if not present
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

    # 64 = default for Grok-2
    config, num_layers, layers_per_gpu, device_map = _plan_standard_sharding(
        64, head_with_embeddings=False,
    )

    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            trust_remote_code=True,
            max_memory=max_mem,
        ).eval()
    except OSError as e:
        print(f"Failed to load model with standard method: {e}")
        print("Trying alternative loading method for Grok model...")
        # Retry without the explicit device map.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            max_memory=max_mem,
        ).eval()
        print("Model loaded successfully with alternative method")
    processor = None
    print(f"Loaded Grok model: {MODEL_NAME} with explicit device map")
elif "FP8" in MODEL_NAME:  # FP8 quantized models (like Llama-405B-FP8)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )

    # 126 = default for Llama-405B
    config, num_layers, layers_per_gpu, device_map = _plan_standard_sharding(
        126, head_with_embeddings=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,  # FP8 models require bfloat16
        device_map=device_map,
        trust_remote_code=True,
        max_memory=max_mem,
        offload_folder=None,  # Disable disk offloading to prevent issues during lesioning
        low_cpu_mem_usage=False,  # Disable low CPU memory usage to prevent offloading
    ).eval()
    processor = None
    print(f"Loaded FP8 model: {MODEL_NAME} with explicit device map")
else:
    # Any other model: let the library balance the modules across GPUs.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )
    config = AutoConfig.from_pretrained(MODEL_DIR)

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="balanced",
        trust_remote_code=True,
        max_memory=max_mem,
        offload_folder=None,  # Disable disk offloading to prevent issues during lesioning
    ).eval()
    processor = None


# Optional wrapping of the loaded model: Accelerate takes precedence; the
# DeepSpeed path is only considered when Accelerate is not selected.
if args.use_accelerate:
    print("Using Accelerate for model parallelism...")
    from accelerate import Accelerator

    # Gemma, Qwen and gpt-oss models are not handed to Accelerate, to avoid
    # offloading issues during lesioning. The printed message names only Gemma
    # even though all three families take this path.
    _keeps_native_placement = any(
        tag in MODEL_NAME for tag in ("Gemma", "Qwen", "gpt-oss")
    )
    if not _keeps_native_placement:
        # Single machine, multi-GPU.
        accelerator = Accelerator(mixed_precision="fp16", device_placement=True)
        model = accelerator.prepare(model)
        print(f"Model prepared with Accelerate on {accelerator.device}")
    else:
        print("Skipping Accelerate for Gemma model to avoid offloading issues during lesioning")
        print("Model will use native device placement")

elif args.deepspeed:
    print("Using DeepSpeed ...")
    import deepspeed
    import os

    # Environment expected by deepspeed.init_distributed(); the values are
    # fixed to a four-GPU, single-rank layout.
    os.environ.update({
        "CUDA_VISIBLE_DEVICES": "0,1,2,3",
        "WORLD_SIZE": "4",
        "RANK": "0",
        "LOCAL_RANK": "0",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "29500",
    })

    deepspeed.init_distributed()
    ds_inference_config = dict(
        tensor_parallel=dict(tp_size=tp_size),
        dtype="float16",
        replace_with_kernel_inject=False,
    )
    model = deepspeed.init_inference(model, config=ds_inference_config)

# Dump the model structure and the main analysis switches.
print(model)
print('do_umap', do_umap)
print('do_saliency', do_saliency)
print('do_maps', do_maps)
print('do_lesioning', do_lesioning)

# Print device distribution information: one line per distinct device, naming
# the first weight-bearing module found on it. This keeps the output short
# while still showing every device the model's weights occupy, rather than
# stopping after the very first module.
print("\n=== Device Distribution ===")
_reported_devices = set()
for name, module in model.named_modules():
    _weight = getattr(module, "weight", None)
    if _weight is None:
        continue
    _device = getattr(_weight, "device", None)
    if _device is None or _device in _reported_devices:
        continue
    _reported_devices.add(_device)
    print(f"{name}: {_device}")
print("=== End Device Distribution ===\n")

# Create analysis objects. Every constructor receives the same leading
# positional arguments; AgeAct additionally takes the age limits before the
# temperature. Construction order is kept as-is.
_shared_args = (tokenizer, model, processor, MODEL_NAME, PRINT_REPLIES)
_accelerate_on = args.use_accelerate

dosageAct = DosageAct(*_shared_args, TEMPERATURE, USE_ACCELERATE=_accelerate_on)
top100DrugAct = Top100DrugAct(*_shared_args, TEMPERATURE, USE_ACCELERATE=_accelerate_on)
genderAct = GenderAct(*_shared_args, TEMPERATURE, USE_ACCELERATE=_accelerate_on)
ageAct = AgeAct(*_shared_args, MIN_AGE, MAX_AGE, TEMPERATURE, USE_ACCELERATE=_accelerate_on)
diseaseAct = DiseaseAct(*_shared_args, TEMPERATURE, USE_ACCELERATE=_accelerate_on)
symptomAct = SymptomAct(*_shared_args, TEMPERATURE, USE_ACCELERATE=_accelerate_on)
progressionAct = DiseaseProgressionAct(*_shared_args, TEMPERATURE, USE_ACCELERATE=_accelerate_on)
drugAct = DrugAct(*_shared_args, TEMPERATURE, USE_ACCELERATE=_accelerate_on)

# Run the requested analyses, one subject area after another.
# NOTE: top100DrugAct.run() is currently not invoked from here.
_base_flags = (do_umap, do_saliency, do_maps)
_full_options = dict(
    do_lesioning=do_lesioning,
    do_lesioning_finegrained=do_lesioning_finegrained,
    do_activation_patching=do_activation_patching,
    do_activation_patching_finegrained=do_activation_patching_finegrained,
    do_heatmap=do_heatmap,
    load_cached=load_cached,
)

# Age, disease, symptom and dosage take the full option set.
for _requested, _analysis in (
    (do_age, ageAct),
    (do_disease, diseaseAct),
    (do_symptom, symptomAct),
    (do_dosage, dosageAct),
):
    if _requested:
        _analysis.run(*_base_flags, **_full_options)

# Gender only receives the three base flags plus the cache switch.
if do_gender:
    genderAct.run(*_base_flags, load_cached=load_cached)

# Progression always runs with saliency and maps turned off.
if do_prog:
    progressionAct.run(do_umap, do_saliency=False, do_maps=False, load_cached=load_cached)

if do_drug:
    drugAct.run(*_base_flags, **_full_options)


# Create main integrated map
if do_maps:
    _rule = "=" * 50
    print("\n" + _rule)
    print("Creating main integrated LLM map...")
    print(_rule)

    # Both map variants are fed the same five analysis objects.
    _map_inputs = dict(
        ageAct=ageAct,
        diseaseAct=diseaseAct,
        top100DrugAct=top100DrugAct,
        symptomAct=symptomAct,
        dosageAct=dosageAct,
    )
    main_map(**_map_inputs)
    # Continuous variant, rendered with its default parameters.
    main_map_continuous(**_map_inputs)
    
