"""
UMAP of per-layer activations
Age sweep 1–100 with three subject variants
"""
import math, torch, umap, matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig
from functools import partial
import numpy as np
import argparse
import random
import os
import socket


from dosageAct import DosageAct
from ageAct import AgeAct
from genderAct import GenderAct
from diseaseAct import DiseaseAct
from drugAct import DrugAct
from top100DrugAct import Top100DrugAct
from symptomAct import SymptomAct
from progressionAct import DiseaseProgressionAct
from maps import main_map, main_map_continuous, metrics_table_all, plot_selected_metrics_per_layer, plot_lesioning_metrics_per_layer
# Run-wide settings handed to every *Act analysis object created further down.
PRINT_REPLIES: bool = True   # echo each model reply to stdout
MIN_AGE: int = 1             # lower bound of the age sweep (passed to AgeAct)
MAX_AGE: int = 100           # upper bound of the age sweep (passed to AgeAct)
TEMPERATURE: float = 0.7     # sampling temperature passed to every *Act object

# ═══════════ Argparse: model directory, analyses to run, attributes to analyse ════
parser = argparse.ArgumentParser(description="UMAP of per-layer activations for age sweep")
parser.add_argument(
    "--model-dir", required=True,
    help="Path to pretrained model directory"
)
parser.add_argument("--local_rank", type=int, default=4, help="Used for distributed training")

# Backend selection. Accelerate is on by default. `--use_accelerate` is a
# store_true flag that already defaults to True, so it cannot turn Accelerate
# off by itself. `--no_accelerate` writes False into the same destination and
# is the only way to reach the (deprecated) DeepSpeed path.
parser.add_argument("--deepspeed", action='store_true', default=False, help="Whether to use DeepSpeed (deprecated; requires --no_accelerate)")
parser.add_argument("--use_accelerate", action='store_true', default=True, help="Whether to use Accelerate (enabled by default; see --no_accelerate)")
parser.add_argument("--no_accelerate", dest="use_accelerate", action='store_false', help="Disable Accelerate (e.g. to use --deepspeed)")

# add umap and saliency arguments
parser.add_argument("--do_umap", action='store_true', help="Whether to run UMAP")
parser.add_argument("--do_saliency", action='store_true', help="Whether to run saliency")
parser.add_argument("--do_maps", action='store_true', help="Whether to plot LLM maps")
parser.add_argument("--do_lesioning", action='store_true', help="Whether to run layer lesioning analysis")
parser.add_argument("--do_lesioning_finegrained", action='store_true', help="Whether to run fine-grained layer lesioning analysis")
parser.add_argument("--do_activation_patching", action='store_true', help="Whether to run activation patching analysis")
parser.add_argument("--no_cache", action='store_true', help="Whether to not use the cached results. By default, the cached results are used.")
parser.add_argument("--do_activation_patching_finegrained", action='store_true', help="Whether to run fine-grained activation patching analysis")
parser.add_argument("--metrics_table", action='store_true', help="Generate LaTeX metrics table across models and exit")
# Attribute selectors (one per analysis family), named without the do_ prefix
parser.add_argument("--age", action='store_true', help="Whether to run age analysis")
parser.add_argument("--gender", action='store_true', help="Whether to run gender analysis")
parser.add_argument("--disease", action='store_true', help="Whether to run disease analysis")
parser.add_argument("--symptom", action='store_true', help="Whether to run symptom analysis")
parser.add_argument("--drug", action='store_true', help="Whether to run drug analysis")
parser.add_argument("--prog", action='store_true', help="Whether to run progression analysis")
parser.add_argument("--dosage", action='store_true', help="Whether to run dosage analysis")

os.makedirs("results", exist_ok=True)

args = parser.parse_args()
MODEL_DIR = args.model_dir
MODEL_NAME = MODEL_DIR.rstrip("/").split("/")[-1]
do_umap = args.do_umap
do_saliency = args.do_saliency
do_maps = args.do_maps
do_lesioning = args.do_lesioning
do_lesioning_finegrained = args.do_lesioning_finegrained
do_activation_patching = args.do_activation_patching
do_activation_patching_finegrained = args.do_activation_patching_finegrained
load_cached = not args.no_cache
do_metrics_table = args.metrics_table
do_age = args.age
do_gender = args.gender
do_disease = args.disease
do_symptom = args.symptom
do_drug = args.drug
do_prog = args.prog
do_dosage = args.dosage
print('load_cached', load_cached)
if do_metrics_table:
    # Fast path: generate tables/plots from cached results without loading any model
    print("\nGenerating metrics table and selected per-layer metrics grid from caches…")
    # metrics_table_all()
    # plot_selected_metrics_per_layer()
    plot_lesioning_metrics_per_layer()
    import sys
    sys.exit(0)

# Use specific GPUs for all models.
# Set the CUDA_VISIBLE_DEVICES environment variable to choose which GPUs to use.
# Example: CUDA_VISIBLE_DEVICES=0,1 python runActSomeGPUs.py --model-dir /path/to/model
cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
# Entries may be integer indices or GPU/MIG UUIDs. Blank entries (an empty
# variable or a trailing comma) are ignored instead of crashing int().
gpu_list = [
    int(entry) if entry.isdigit() else entry
    for entry in (part.strip() for part in cuda_visible.split(','))
    if entry
]
if not gpu_list:
    # Downstream device maps use `nr_gpu - 1` and `num_layers // nr_gpu`.
    raise SystemExit("CUDA_VISIBLE_DEVICES lists no GPUs; at least one visible GPU is required.")
nr_gpu = len(gpu_list)
print(f"Using GPUs: {gpu_list}")

# Per-device memory cap for from_pretrained(max_memory=...), keyed by position
# in the visible list (0..nr_gpu-1) rather than by the IDs listed above.
max_mem = {i: "20GB" for i in range(nr_gpu)}  # Allocate 20GB per GPU

tp_size = nr_gpu

# ═══════════ 1. Load model ════════════════════════════════════════════════
# Every branch below sets the module-level names used later in the script:
# `tokenizer`, `processor` (None unless the model ships an AutoProcessor),
# `config` and `model`. Branches that build an explicit device map also set
# `num_layers` and `device_map`.


def _num_layers_from_config(config, fallback):
    """Return the number of decoder layers declared by a transformers *config*.

    Looks for ``num_hidden_layers``, ``num_layers`` or ``n_layer`` (in that
    order) on *config* itself. For multimodal configs that keep the language
    model's settings in a nested ``text_config``, it then checks that
    sub-config. If nothing is found, prints a warning and returns *fallback*.
    """
    candidates = [config]
    text_config = getattr(config, "text_config", None)
    if text_config is not None:
        candidates.append(text_config)
    for cfg in candidates:
        for attr in ("num_hidden_layers", "num_layers", "n_layer"):
            if hasattr(cfg, attr):
                return getattr(cfg, attr)
    print("Warning: Could not determine number of layers from config, using fallback")
    return fallback


def _layer_device_map(prefix, num_layers, n_gpus):
    """Map ``f"{prefix}.{i}"`` for ``i in range(num_layers)`` to GPU indices.

    Layers are placed in contiguous runs of ``max(1, num_layers // n_gpus)``
    and any remainder goes to the last GPU. With ``n_gpus == 1`` every layer
    maps to GPU 0.
    """
    layers_per_gpu = max(1, num_layers // n_gpus)
    return {
        f"{prefix}.{i}": min(i // layers_per_gpu, n_gpus - 1)
        for i in range(num_layers)
    }


if "Llama-4" in MODEL_NAME:
    # → use the new Chat-capable Llama-4 loader
    from transformers import AutoProcessor, Llama4ForConditionalGeneration

    processor = AutoProcessor.from_pretrained(MODEL_DIR)
    tokenizer = processor.tokenizer

    # Explicit device map: embeddings on the first GPU, norm/head on the last,
    # decoder layers spread across all visible GPUs.
    # NOTE(review): confirm these keys match the module names of the installed
    # transformers version; the Gemma-3 branch uses a nested
    # `model.language_model.*` layout.
    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _num_layers_from_config(config, fallback=32)
    device_map = {
        "model.embed_tokens": 0,
        "model.norm": nr_gpu - 1,
        "lm_head": nr_gpu - 1,
        **_layer_device_map("model.layers", num_layers, nr_gpu),
    }

    model = Llama4ForConditionalGeneration.from_pretrained(
        MODEL_DIR,
        attn_implementation="flex_attention",
        device_map=device_map,
        max_memory=max_mem,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).eval()
    print(f"Loaded Llama-4 model: {MODEL_NAME} with explicit device map")
elif "Gemma-3" in MODEL_NAME:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    config = AutoConfig.from_pretrained(MODEL_DIR)
    # Number of text (language-model) layers
    num_layers = _num_layers_from_config(config, fallback=62)

    # Vision components on GPU 0, language components distributed
    device_map = {
        "model.vision_tower": 0,
        "model.multi_modal_projector": 0,
        "model.language_model.embed_tokens": 0,
        "lm_head": nr_gpu - 1,
        "model.language_model.norm": nr_gpu - 1,
        **_layer_device_map("model.language_model.layers", num_layers, nr_gpu),
    }

    print(f"Distributing {num_layers} Gemma-3 text layers across {nr_gpu} GPUs")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        offload_folder=None,
        low_cpu_mem_usage=False,
        trust_remote_code=True,
        max_memory=max_mem,
    ).eval()
    processor = None
    print(f"Loaded Gemma-3 model: {MODEL_NAME} with explicit device map (vision on GPU0, text sharded)")
elif "Gemma" in MODEL_NAME and "MedGemma" not in MODEL_NAME:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    config = AutoConfig.from_pretrained(MODEL_DIR)

    # "sequential" fills the visible GPUs in order rather than using an explicit map
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map="sequential",
        offload_folder=None,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    processor = None
    print(f"Loaded Gemma model: {MODEL_NAME} with auto device map")
elif "gpt-oss" in MODEL_NAME:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

    # Explicit device map to prevent disk offloading
    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _num_layers_from_config(config, fallback=36)  # Default for GPT-OSS-120B
    device_map = {
        "model.embed_tokens": 0,
        "lm_head": nr_gpu - 1,
        "model.norm": nr_gpu - 1,
        **_layer_device_map("model.layers", num_layers, nr_gpu),
    }

    print(f"Distributing {num_layers} GPT-OSS layers across {nr_gpu} GPUs")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
        offload_folder=None,  # Disable disk offloading
        low_cpu_mem_usage=False,  # Disable low CPU memory usage to prevent offloading
        max_memory=max_mem,
    ).eval()
    processor = None
    print(f"Loaded GPT-OSS model: {MODEL_NAME} with explicit device map and no disk offloading")
elif "70B" in MODEL_NAME:  # Llama-3.3-70B-Instruct
    # → fall back to the Llama-3 / standard LM API.
    # Llama-4 names never reach this branch, so the fast tokenizer is always used.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )

    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _num_layers_from_config(config, fallback=32)
    device_map = {
        "model.embed_tokens": 0,
        "model.norm": nr_gpu - 1,
        "lm_head": nr_gpu - 1,
        **_layer_device_map("model.layers", num_layers, nr_gpu),
    }

    print(f"Distributing {num_layers} layers across {nr_gpu} GPUs")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map=device_map,
        trust_remote_code=True,
        max_memory=max_mem,
    ).eval()
    processor = None
    print(f"Loaded Model: {MODEL_NAME} with explicit device map")
elif "Qwen" in MODEL_NAME or "MedGemma" in MODEL_NAME:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )

    # Explicit single-GPU device map to prevent disk offloading
    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _num_layers_from_config(config, fallback=32)
    device_map = {
        # All components on GPU 0
        "model.embed_tokens": 0,
        "lm_head": 0,
        "model.norm": 0,
        **_layer_device_map("model.layers", num_layers, 1),
    }

    print(f"Loading {num_layers} layers on GPU 0")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,  # Use bfloat16 to save memory
        device_map=device_map,
        trust_remote_code=True,
        offload_folder=None,  # Disable disk offloading to prevent issues during lesioning
        low_cpu_mem_usage=False,  # Disable low CPU memory usage to prevent offloading
        max_memory=max_mem,
    ).eval()
    processor = None
    print(f"Loaded {MODEL_NAME} with explicit device map and no disk offloading")
elif "DeepSeek" in MODEL_NAME:  # DeepSeek models (including DeepSeek-R1)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )

    # Explicit single-GPU device map for DeepSeek models
    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _num_layers_from_config(config, fallback=61)  # Default for DeepSeek-R1
    device_map = {
        "model.embed_tokens": 0,
        "model.norm": 0,
        "lm_head": 0,
        **_layer_device_map("model.layers", num_layers, 1),
    }

    print(f"Loading {num_layers} layers on GPU 0")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
        max_memory=max_mem,
        offload_folder=None,  # Disable disk offloading to prevent issues during lesioning
        low_cpu_mem_usage=False,  # Disable low CPU memory usage to prevent offloading
    ).eval()
    processor = None
    print(f"Loaded DeepSeek model: {MODEL_NAME} with explicit device map")
elif "grok" in MODEL_NAME.lower():  # Grok models
    try:
        # Try loading with trust_remote_code and let transformers handle the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_DIR, trust_remote_code=True
        )
    except Exception as e:
        print(f"Failed to load tokenizer with AutoTokenizer: {e}")
        print("Trying alternative tokenizer loading...")
        try:
            # Try without trust_remote_code
            tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        except Exception as e2:
            print(f"Failed to load tokenizer with alternative method: {e2}")
            print("Using GPT-2 tokenizer as fallback for Grok model...")
            # NOTE(review): last-resort fallback. The GPT-2 vocabulary is
            # presumably not Grok's, so token IDs (and any results) are likely
            # meaningless if this path is taken — confirm before relying on it.
            from transformers import GPT2Tokenizer
            tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            # Add padding token if not present
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

    # Explicit single-GPU device map for Grok models
    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _num_layers_from_config(config, fallback=64)  # Default for Grok-2
    device_map = {
        "model.embed_tokens": 0,
        "model.norm": 0,
        "lm_head": 0,
        **_layer_device_map("model.layers", num_layers, 1),
    }

    print(f"Loading {num_layers} layers on GPU 0")

    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            trust_remote_code=True,
            max_memory=max_mem,
        ).eval()
    except OSError as e:
        print(f"Failed to load model with standard method: {e}")
        print("Trying alternative loading method for Grok model...")
        # Retry without a device_map. Nothing in this branch moves the model to
        # a GPU afterwards; any placement is left to the later backend setup.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            max_memory=max_mem,
        ).eval()
        print("Model loaded successfully with alternative method")
    processor = None
    print(f"Loaded Grok model: {MODEL_NAME} with explicit device map")
elif "FP8" in MODEL_NAME:  # FP8 quantized models (like Llama-405B-FP8)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )

    # Explicit single-GPU device map for FP8 models
    config = AutoConfig.from_pretrained(MODEL_DIR)
    num_layers = _num_layers_from_config(config, fallback=126)  # Default for Llama-405B
    device_map = {
        "model.embed_tokens": 0,
        "model.norm": 0,
        "lm_head": 0,
        **_layer_device_map("model.layers", num_layers, 1),
    }

    print(f"Loading {num_layers} layers on GPU 0")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,  # compute dtype for the FP8 checkpoint
        device_map=device_map,
        trust_remote_code=True,
        max_memory=max_mem,
        offload_folder=None,  # Disable disk offloading to prevent issues during lesioning
        low_cpu_mem_usage=False,  # Disable low CPU memory usage to prevent offloading
    ).eval()
    processor = None
    print(f"Loaded FP8 model: {MODEL_NAME} with explicit device map")
else:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True
    )
    config = AutoConfig.from_pretrained(MODEL_DIR)

    # Let transformers/accelerate place the model, capped by max_mem per GPU
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True,
        max_memory=max_mem,
        offload_folder=None,  # Disable disk offloading to prevent issues during lesioning
    ).eval()
    processor = None


if args.use_accelerate:
    print(f"Using Accelerate for {nr_gpu} GPUs...")
    from accelerate import Accelerator

    # Accelerate's mixed_precision setting selects the autocast dtype applied to
    # the prepared model's forward pass. Match it to the dtype the weights were
    # loaded in: bfloat16 weights forced through fp16 autocast lose bf16's
    # exponent range and can overflow activations.
    accel_precision = "bf16" if getattr(model, "dtype", None) == torch.bfloat16 else "fp16"

    # Initialize accelerator for multiple GPUs
    accelerator = Accelerator(
        mixed_precision=accel_precision,
        device_placement=True,
    )

    # Prepare model (Accelerate handles device placement)
    model = accelerator.prepare(model)

    print(f"Model prepared with Accelerate on {accelerator.device}")

elif args.deepspeed:
    print(f"Using DeepSpeed for {nr_gpu} GPUs...")
    import deepspeed

    # Process-group settings. Values exported by a distributed launcher (one
    # process per GPU) take precedence; otherwise fall back to a single-process
    # group. Hard-coding WORLD_SIZE=nr_gpu with RANK=0 in a lone process makes
    # init_distributed() wait for ranks that never start.
    # CUDA_VISIBLE_DEVICES is not re-exported: CUDA was already initialised by
    # model loading above, so changing it here has no effect.
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("LOCAL_RANK", "0")
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    world_size = int(os.environ["WORLD_SIZE"])

    deepspeed.init_distributed()
    ds_inference_config = {
        "tensor_parallel": {
            "tp_size": world_size  # one tensor-parallel shard per process
        },
        "dtype": "float16",
        "replace_with_kernel_inject": False
    }

    model = deepspeed.init_inference(
        model,
        config=ds_inference_config
    )

print(model)
print('do_umap', do_umap)
print('do_saliency', do_saliency)
print('do_maps', do_maps)
print('do_lesioning', do_lesioning)

# Device distribution: a few sample weight-bearing modules, then the set of all
# devices holding at least one module weight (the original stopped after one).
print("\n=== Device Distribution ===")
MAX_SAMPLE_MODULES = 5  # number of "<module>: <device>" lines to show
_sampled = 0
_weight_devices = set()
for name, module in model.named_modules():
    weight = getattr(module, 'weight', None)
    device = getattr(weight, 'device', None)
    if device is None:
        continue
    _weight_devices.add(str(device))
    if _sampled < MAX_SAMPLE_MODULES:
        print(f"{name}: {device}")
        _sampled += 1
print(f"Devices holding weights: {sorted(_weight_devices)}")
print("=== End Device Distribution ===\n")

# ═══════════ Build the analysis objects and run the requested analyses ═══════
# All *Act classes take the same leading positional arguments. AgeAct also
# takes the sweep bounds just before the temperature. Construction order is
# kept as-is.
_common_ctor_args = (tokenizer, model, processor, MODEL_NAME, PRINT_REPLIES)
_accel_kw = {"USE_ACCELERATE": args.use_accelerate}

dosageAct = DosageAct(*_common_ctor_args, TEMPERATURE, **_accel_kw)
top100DrugAct = Top100DrugAct(*_common_ctor_args, TEMPERATURE, **_accel_kw)
genderAct = GenderAct(*_common_ctor_args, TEMPERATURE, **_accel_kw)
ageAct = AgeAct(*_common_ctor_args, MIN_AGE, MAX_AGE, TEMPERATURE, **_accel_kw)
diseaseAct = DiseaseAct(*_common_ctor_args, TEMPERATURE, **_accel_kw)
symptomAct = SymptomAct(*_common_ctor_args, TEMPERATURE, **_accel_kw)
progressionAct = DiseaseProgressionAct(*_common_ctor_args, TEMPERATURE, **_accel_kw)
drugAct = DrugAct(*_common_ctor_args, TEMPERATURE, **_accel_kw)

# Keyword arguments for analyses whose run() takes the full option set.
# top100DrugAct is not run here; it is only consumed by the integrated maps.
_full_run_kw = dict(
    do_lesioning=do_lesioning,
    do_lesioning_finegrained=do_lesioning_finegrained,
    do_activation_patching=do_activation_patching,
    do_activation_patching_finegrained=do_activation_patching_finegrained,
    load_cached=load_cached,
)

# Runs execute in the same order as before: age, disease, symptom, dosage,
# gender, progression, drug.
for _enabled, _act in (
    (do_age, ageAct),
    (do_disease, diseaseAct),
    (do_symptom, symptomAct),
    (do_dosage, dosageAct),
):
    if _enabled:
        _act.run(do_umap, do_saliency, do_maps, **_full_run_kw)
if do_gender:
    genderAct.run(do_umap, do_saliency, do_maps, load_cached=load_cached)
if do_prog:
    progressionAct.run(do_umap, do_saliency=False, do_maps=False, load_cached=load_cached)
if do_drug:
    drugAct.run(do_umap, do_saliency, do_maps, **_full_run_kw)


# ═══════════ Integrated LLM map across analyses ══════════════════════════════
if do_maps:
    _banner = "=" * 50
    print("\n" + _banner)
    print("Creating main integrated LLM map...")
    print(_banner)

    # Both map variants consume the same set of analysis objects.
    _map_inputs = dict(
        ageAct=ageAct,
        diseaseAct=diseaseAct,
        top100DrugAct=top100DrugAct,
        symptomAct=symptomAct,
        dosageAct=dosageAct,
    )
    main_map(**_map_inputs)
    # Continuous variant, rendered with its default parameters
    main_map_continuous(**_map_inputs)
    
