import os
import json
import random
import math
from pathlib import Path
from collections import Counter
from peft import PeftModel
import string
import re
import time
import csv
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    GenerationConfig, BitsAndBytesConfig
)
import transformers
import argparse
import random
from collections import defaultdict
import types
import uuid


# ---------------------- Runtime Configs ---------------------- #

# Run identifier: the SLURM job id when that variable is set, otherwise "manual_run".
JOB_ID = os.getenv("SLURM_JOB_ID") or "manual_run"
# NOTE(review): the three length limits below are all 1, which looks like a
# placeholder that is presumably overridden elsewhere before use -- confirm.
MAX_NEW_TOKENS = 1
MAX_PROMPT_LENGTH = 1 # only applicable to perplexity
MAX_PERPLEXITY_LENGTH = 1 # only applicable to perplexity
# Checkpoint to load. run_experiment() chooses the runner by looking for
# "qwen" / "deepseek" / "mixtral" in the lower-cased name.
MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B" # "deepseek-ai/deepseek-moe-16b-chat" # "deepseek-ai/deepseek-moe-16b-base" # "mistralai/Mixtral-8x7B-Instruct-v0.1"
# The prompt formatters in this file only implement "chat"; "base" raises.
MODEL_TYPE = "chat" # "chat" or "base"
SEED = 42
MODE = "perplexity"  # "qa" or "perplexity"
SELECTION_MODE = "gini"  # "gini" or "threshold"
# A single entry is applied to every layer. With several entries the DeepSeek
# router maps them onto its "layer_ranges_threshold" (one entry per range);
# the Qwen router rejects more than one entry.
THRESHOLD_FACTOR =[0.9]

# Exponent applied to (1 - gini) in compute_dynamic_k(). It is None here, so
# it has to be assigned before that function is used.
BETA = None
SAMPLE_BEFORE_LOAD = False
VECTORIZED = False
# A single entry is applied to every layer. With several entries the DeepSeek
# router maps them onto its "layer_ranges" (one entry per range); the Qwen
# router rejects more than one entry.
SUM_THRESHOLD = [-1]

### Model config per model family
# Placeholders (None here).
# NOTE(review): presumably filled in from the selected MODEL_CFG_* entry at
# startup -- confirm where that happens.
DEFAULT_MIN_K = None
DEFAULT_MAX_K = None 
DEFAULT_NUM_EXPERTS = None # top-k; should equal the number of experts per token
# qwen
MODEL_CFG_QWEN_2_54b = {
    "default_min_k": 8,
    "default_max_k": 64,
    "default_num_experts": 8,
}
# deepseek
MODEL_CFG_DS = {
    "default_min_k": 6,
    "default_max_k": 64,
    "default_num_experts": 6,
    # Layer index ranges used by set_ds_router when several sum_threshold
    # values are given: inclusive on both ends, 0-indexed, one value per range.
    "layer_ranges": [ 
        (0, 10),
        (11, 20),
        (21, 27), 
    ],
    # Same layout, used when several threshold_factor values are given.
    "layer_ranges_threshold": [ 
        (0, 10),
        (11, 20),
        (21, 27), 
    ],
}
# mixtral
MODEL_CFG_MIXTRAL_7b = {
    "default_min_k": 2,
    "default_max_k": 8,
    "default_num_experts": 2,
    "layer_ranges": [ # layer index ranges used when multiple sum_threshold values are given. Ranges are inclusive on both ends and 0-indexed; all layers are counted (even non-MoE layers)
        (0, 10),
        (11, 20),
        (21, 31), 
    ],
    "layer_ranges_threshold": [ # same as above, but for t_fixed
        (0, 10),
        (11, 20),
        (21, 31), 
    ],
}

# Lookup of the per-family configs above by short family key.
MODEL_CFG_MAP = {
    "qwen_54b": MODEL_CFG_QWEN_2_54b,
    "deepseek": MODEL_CFG_DS,
    "mixtral_7b": MODEL_CFG_MIXTRAL_7b,
}


### Dataset config
# Each entry is consumed by load_dataset_from_config(): "path", the optional
# "name" and "split" are passed to datasets.load_dataset, and a truthy
# "shuffle" shuffles the split with SEED. The "path" value is also the key the
# prompt/answer/evaluation dispatchers switch on.
DATASET_CFG = None # placeholder
DATASET_CFG_WIKI = {
    "path": "wikitext",
    "name": "wikitext-103-raw-v1",
    "split": "validation",
    "shuffle": True
}
DATASET_CFG_SQUAD = {
    "path": "rajpurkar/squad",
    "split": "validation",
    "shuffle": True
}
DATASET_CFG_ARC_EASY = {
    "path": "ai2_arc",
    "name": "ARC-Easy",
    "split": "test",
    "shuffle": True,
}
DATASET_CFG_ARC_CHALLENGE = {
    "path": "ai2_arc",
    "name": "ARC-Challenge",
    "split": "test",
    "shuffle": True,
}
DATASET_CFG_GSM8K = {
    "path": "openai/gsm8k",
    "name": "main",
    "split": "test",
    "shuffle": True,
}
DATASET_CFG_MMLU = {
    "path": "cais/mmlu",
    "name": "all",
    "split": "test",
    "shuffle": True,
}
# Lookup of the dataset configs above by short dataset key.
DATASET_CFG_MAP = {
    "wiki": DATASET_CFG_WIKI,
    "squad": DATASET_CFG_SQUAD,
    "arc_easy": DATASET_CFG_ARC_EASY,
    "arc_challenge": DATASET_CFG_ARC_CHALLENGE,
    "gsm8k": DATASET_CFG_GSM8K,
    "mmlu": DATASET_CFG_MMLU,
}


### Output dir
# Per-run path under "logs/", keyed by JOB_ID.
OUTPUTS = Path("logs") / JOB_ID
# Created (if missing) by the experiment runners before a run starts.
OUTPUT_DIR = "OUTPUT/"

### Defaults
BATCH_SIZE = 32
SAMPLE_SIZE = 160



# ---------------------- Routing Utilities ---------------------- #
def compute_gini_coefficient(scores):
    """Compute the Gini coefficient of a 1-D tensor of scores.

    The Gini coefficient ranges from 0 (perfect equality) towards 1 (perfect
    inequality). It is computed from the ascending-sorted scores ``x_i`` as
    ``2 * sum(i * x_i) / (n * sum(x)) - (n + 1) / n`` with 1-based ``i``.

    Args:
        scores: 1-D tensor of non-negative scores.

    Returns:
        A 0-dim tensor with the dtype and device of ``scores``. An empty
        tensor or one whose entries sum to zero is treated as perfectly equal
        and yields 0 instead of a division by zero / NaN.
    """
    sorted_scores = torch.sort(scores)[0]
    n = len(sorted_scores)
    if n == 0:
        return scores.new_zeros(())

    total = torch.sum(sorted_scores)
    if total == 0:
        # Nothing to distribute: every entry holds the same (zero) share.
        return scores.new_zeros(())

    index = torch.arange(1, n + 1, dtype=scores.dtype, device=scores.device)
    weighted_sum = torch.sum(index * sorted_scores)
    return (2 * weighted_sum) / (n * total) - (n + 1) / n

def compute_dynamic_k(scores, min_k=None, max_k=None):
    """Compute dynamic k based on score distribution skewness.

    Higher skewness (Gini coefficient) means smaller k:
    ``k = min_k + (max_k - min_k) * (1 - gini) ** BETA``, truncated to an int
    and kept within ``[min_k, max_k)`` unless the Gini coefficient is
    effectively zero, in which case ``max_k`` is returned.

    Args:
        scores: 1-D tensor of scores.
        min_k: Lower bound for k. ``None`` means "use DEFAULT_MIN_K", looked
            up when the function is called so that a value assigned to the
            module-level placeholder after import is honoured.
        max_k: Upper bound for k. ``None`` means "use DEFAULT_MAX_K", looked
            up when the function is called.

    Returns:
        The number of experts to select, as an int.

    Raises:
        ValueError: if a bound is still unset, or BETA is unset when needed.
    """
    if min_k is None:
        min_k = DEFAULT_MIN_K
    if max_k is None:
        max_k = DEFAULT_MAX_K
    if min_k is None or max_k is None:
        raise ValueError(
            "min_k and max_k must be passed or DEFAULT_MIN_K / DEFAULT_MAX_K must be set"
        )

    gini = compute_gini_coefficient(scores)
    # If Gini is effectively zero, return max_k
    if abs(gini) < 1e-15:
        return max_k

    beta = BETA  # >= 1
    if beta is None:
        raise ValueError("BETA must be set before calling compute_dynamic_k")

    k_float = min_k + (max_k - min_k) * (1 - gini) ** beta
    # Keep k_float strictly below max_k so truncation never yields max_k here
    # (floating-point error or a tiny gini could otherwise push it up).
    if k_float >= (max_k - 1e-9):
        k_float = max_k - 1e-9

    # Truncate to an integer and enforce the minimum.
    return max(int(k_float), min_k)

def compute_dynamic_k_threshold(scores, t=0.9, min_k=None, max_k=None):
    """Compute dynamic k based on a threshold relative to the maximum score.

    Only experts with a score of at least ``t * max(scores)`` are counted; the
    count is then clamped to ``[min_k, max_k]``.

    Args:
        scores: Tensor of scores.
        t: Fraction of the maximum score an expert must reach to be counted.
        min_k: Lower bound for k. ``None`` means "use DEFAULT_MIN_K", looked
            up when the function is called so that a value assigned to the
            module-level placeholder after import is honoured.
        max_k: Upper bound for k. ``None`` means "use DEFAULT_MAX_K", looked
            up when the function is called.

    Returns:
        The number of experts to select, as an int.

    Raises:
        ValueError: if a bound is neither passed nor configured.
    """
    if min_k is None:
        min_k = DEFAULT_MIN_K
    if max_k is None:
        max_k = DEFAULT_MAX_K
    if min_k is None or max_k is None:
        raise ValueError(
            "min_k and max_k must be passed or DEFAULT_MIN_K / DEFAULT_MAX_K must be set"
        )

    threshold = t * torch.max(scores)
    k_val = int((scores >= threshold).sum().item())
    if k_val < min_k:
        k_val = min_k
    if k_val > max_k:
        k_val = max_k
    return k_val

# ----------------------- Imbalanced Routing Utilities ----------------------- #

# --- Per-batch imbalance helpers -------------------------------------------

def _trimmed_mean(x: np.ndarray, proportion_to_cut: float = 0.05) -> float:
    if x.size == 0:
        return float('nan')
    a = np.sort(x)
    k = int(proportion_to_cut * len(a))
    if k == 0 or 2 * k >= len(a):
        return float(a.mean())
    return float(a[k:-k].mean())

def init_per_batch_imbalance(moe_layer_indices):
    """Build the per-layer record store: one fresh empty list per layer index."""
    store = {}
    for layer in moe_layer_indices:
        store[layer] = []
    return store

def record_per_batch_imbalance(moe_gates, moe_layer_indices, per_batch_store, id_key="id"):
    """
    Record expert-load imbalance per (layer, `id_key`) group.

    Every assignment dict found in the gates' `token_assignments` is grouped
    by its "layer_idx" and its `id_key` value; for each group the number of
    times each expert appears in "assigned_experts" is counted and a dict
    {'id': <id>, 'I': max/mean, 'MV': (max-mean)/mean, 'T': total} is
    appended to per_batch_store[layer]. A group without any counted expert
    is recorded as I=1.0, MV=0.0, T=0.

    Raises ValueError if a tracked layer has no gate with that `layer_idx`.
    """
    # Collect the assignment records of all gates into one flat list
    all_records = []
    for gate in moe_gates:
        all_records.extend(getattr(gate, "token_assignments", []))

    # Number of routed experts for each tracked layer (first matching gate wins)
    experts_per_layer = {}
    for layer in moe_layer_indices:
        for gate in moe_gates:
            if gate.layer_idx == layer:
                experts_per_layer[layer] = gate.n_routed_experts
                break
        else:
            raise ValueError(f"No gate found for layer {layer} to determine n_routed_experts.")

    # Bucket the records by (layer, id); a missing id groups under None
    buckets = defaultdict(list)
    for record in all_records:
        layer = record.get("layer_idx")
        if layer in experts_per_layer:  # untracked layers are ignored
            buckets[(layer, record.get(id_key, None))].append(record)

    # One imbalance record per bucket
    for (layer, group_id), members in buckets.items():
        num_experts = experts_per_layer[layer]
        hits = np.zeros(num_experts, dtype=np.int64)
        for record in members:
            for expert in record.get("assigned_experts", []):
                if 0 <= expert < num_experts:
                    hits[expert] += 1

        total = int(hits.sum())
        if total != 0:
            mean_load = total / float(num_experts)
            peak = int(hits.max())
            spread = (peak - mean_load) / (mean_load + 1e-12)
            ratio = peak / (mean_load + 1e-12)
            entry = {"id": group_id, "I": float(ratio), "MV": float(spread), "T": total}
        else:
            entry = {"id": group_id, "I": 1.0, "MV": 0.0, "T": 0}

        per_batch_store[layer].append(entry)

def summarize_and_save_per_batch_imbalance(per_batch_store, out_dir, experiment_name, layer_weights=None):
    """
    Summarize I_{b,L} across records, and build an aggregate I_b^{agg}.

    Args:
        per_batch_store: dict mapping a layer to its list of record dicts
            (each with at least 'I' and 'T'), as filled by
            record_per_batch_imbalance.
        out_dir: base output directory.
        experiment_name: sub-directory of `out_dir` the summary is written to
            (created if missing).
        layer_weights: optional dict layer -> weight for the aggregate. The
            weights are normalized over the layers present; missing layers
            count as 0, and equal weights are used if None or if the weights
            do not sum to a positive value.

    Returns:
        (per_layer_stats, agg_stats). Layers without records are left out of
        per_layer_stats. agg_stats is {} when there are no layers or when any
        layer has no records, because the aggregate only covers the first
        min-over-layers records and would otherwise be computed on nothing
        (NaN statistics, which are not valid JSON).

    The aggregate pairs record b of every layer by position.
    NOTE(review): this assumes the b-th record of each layer belongs to the
    same step/group -- confirm against how the store is filled.

    Side effect: writes <out_dir>/<experiment_name>/per_batch_imbalance_summary.json.
    """
    target_dir = os.path.join(out_dir, experiment_name)
    os.makedirs(target_dir, exist_ok=True)

    # Per-layer summaries
    per_layer_stats = {}
    for L, recs in per_batch_store.items():
        Iv = np.array([r['I'] for r in recs], dtype=float)
        Tv = np.array([r['T'] for r in recs], dtype=float)
        if Iv.size == 0:
            continue
        per_layer_stats[L] = {
            'num_steps': int(Iv.size),
            'mean_I':     float(Iv.mean()),
            'median_I':   float(np.median(Iv)),
            'p95_I':      float(np.quantile(Iv, 0.95)),
            'p99_I':      float(np.quantile(Iv, 0.99)),
            'trimmed_mean_I_5pct': _trimmed_mean(Iv, 0.05),
            'token_weighted_I':    float((Iv * Tv).sum() / max(Tv.sum(), 1.0)),
        }

    # Build aggregate I_b^{agg} = sum_L w_L I_{b,L}
    layers = sorted(per_batch_store.keys())
    num_steps = min((len(per_batch_store[L]) for L in layers), default=0)
    if num_steps == 0:
        # No layers, or at least one layer without records: nothing to aggregate.
        agg_stats = {}
    else:
        # Default equal weights if none provided
        if layer_weights is None:
            w = np.ones(len(layers), dtype=float) / len(layers)
        else:
            # Normalize provided weights over the selected layers
            w = np.array([layer_weights.get(L, 0.0) for L in layers], dtype=float)
            if w.sum() <= 0:
                w = np.ones(len(layers), dtype=float) / len(layers)
            else:
                w = w / w.sum()

        # Accumulate layer by layer so each step's sum is built in layer order.
        I_agg = np.zeros(num_steps, dtype=float)
        for j, L in enumerate(layers):
            layer_I = np.array(
                [float(rec['I']) for rec in per_batch_store[L][:num_steps]], dtype=float
            )
            I_agg += w[j] * layer_I

        agg_stats = {
            'num_steps': int(num_steps),
            'mean_I_agg':   float(I_agg.mean()),
            'median_I_agg': float(np.median(I_agg)),
            'p95_I_agg':    float(np.quantile(I_agg, 0.95)),
            'p99_I_agg':    float(np.quantile(I_agg, 0.99)),
            'weights':      {str(L): float(wj) for wj, L in zip(w, layers)},
        }

    # Save to JSON
    out_path = os.path.join(target_dir, "per_batch_imbalance_summary.json")
    with open(out_path, "w") as f:
        json.dump({'per_layer': per_layer_stats, 'aggregate': agg_stats}, f, indent=2)
    print(f"[per-batch imbalance] wrote {out_path}")
    return per_layer_stats, agg_stats

# ---------------------- Dataset and Evaluation Utilities ---------------------- #

def set_random_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy and PyTorch (plus all CUDA devices when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def format_prompt(cfg, example):
    """Build the prompt for `example` with the formatter that matches cfg["path"].

    Raises an Exception for dataset paths without a formatter.
    """
    dataset_path = cfg["path"]
    if dataset_path == "ai2_arc":
        return format_arc_prompt(example)
    if dataset_path == "openai/gsm8k":
        return format_gsm8k_prompt(example)
    if dataset_path == "cais/mmlu":
        return format_mmlu_prompt(example)
    raise Exception(f"{cfg['path']} not yet supported")

def format_arc_prompt(example):
    """Formats a multiple-choice question into a chat-friendly prompt.

    Args:
        example: dict with a "question" string and a "choices" dict holding
            parallel "text" and "label" lists; each choice is rendered as
            "<label>. <text>".

    Returns:
        The prompt string for a chat model.

    Raises:
        NotImplementedError: if MODEL_TYPE is "base" (no base-model prompt yet).
        ValueError: if MODEL_TYPE is neither "chat" nor "base"; previously
            this fell through to an UnboundLocalError.
    """
    if MODEL_TYPE == "base":
        raise NotImplementedError("not supported yet")
    if MODEL_TYPE != "chat":
        raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE!r}")

    question = example["question"].strip()
    choices = example["choices"]["text"]
    labels = example["choices"]["label"]
    choice_lines = [f"{l}. {c}" for l, c in zip(labels, choices)]

    return (
        f"You are a highly knowledgeable assistant. Please choose the best answer to the following question. "
        f"Respond only with the letter corresponding to the correct choice in the format: ```Answer: <letter>```.\n\n"
        f"Question: {question}\n" + "\n".join(choice_lines) + "\n\n"
    )

def format_gsm8k_prompt(example):
    """Formats a GSM8K word problem into a chat prompt asking for step-by-step work.

    Args:
        example: dict with a "question" string.

    Returns:
        The prompt string; it asks for the final result as ```Answer: <number>```.

    Raises:
        NotImplementedError: if MODEL_TYPE is "base" (no base-model prompt yet).
        ValueError: if MODEL_TYPE is neither "chat" nor "base"; previously
            this fell through to an UnboundLocalError.
    """
    if MODEL_TYPE == "base":
        raise NotImplementedError("not supported yet")
    if MODEL_TYPE != "chat":
        raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE!r}")

    question = example["question"].strip()
    return (
        "You are a helpful assistant. Solve the following problem step-by-step and you must provide the final answer in the end strictly following the format: ```Answer: <number>```.\n\n"
        f"Question: {question}\n\n"
        f"Reminder: only include the final numeric answer on the last line in the format: ```Answer: <number>```."
    )

def format_mmlu_prompt(example):
    """Formats an MMLU multiple-choice question into a chat-friendly prompt.

    Args:
        example: dict with a "question" string and a "choices" list; the
            choices are lettered A-D in order.

    Returns:
        The prompt string for a chat model.

    Raises:
        NotImplementedError: if MODEL_TYPE is "base" (no base-model prompt yet).
        ValueError: if MODEL_TYPE is neither "chat" nor "base"; previously
            this fell through to an UnboundLocalError.
    """
    if MODEL_TYPE == "base":
        raise NotImplementedError("not supported yet")
    if MODEL_TYPE != "chat":
        raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE!r}")

    question = example["question"].strip()
    choices = example["choices"]
    letters = ["A", "B", "C", "D"]
    choice_lines = [f"{l}. {c}" for l, c in zip(letters, choices)]

    return (
        "You are a highly knowledgeable assistant. Please choose the best answer to the following question. "
        "Respond only with the letter corresponding to the correct choice in the format: ```Answer: <letter>```.\n\n"
        f"Question: {question}\n" + "\n".join(choice_lines) + "\n\n"
    )

def format_response(cfg, rsp):
    """Extract the answer from a complete model response, using the extractor for cfg["path"].

    Raises an Exception for dataset paths without an extractor.
    """
    dataset_path = cfg["path"]
    if dataset_path == "ai2_arc":
        return extract_arc_answer(rsp)
    if dataset_path == "openai/gsm8k":
        return extract_gsm8k_answer(rsp)
    if dataset_path == "cais/mmlu":
        return extract_mmlu_answer(rsp)
    raise Exception(f"{cfg['path']} not yet supported")

def extract_arc_answer(text):
    """Extracts the answer label from the model's output.

    The text is upper-cased and searched for "ANSWER" followed by optional
    colons/whitespace and a single label; the last such occurrence wins.

    Accepted labels are A-E and 1-4. The ARC prompt prints each question's
    own choice labels, and ARC is not limited to A-D (some questions have a
    fifth choice "E", some are labelled 1-4), so a narrower pattern would
    mark those questions as unanswerable.

    Returns:
        The label as an upper-case string, or "INVALID" if none is found.
    """
    found = re.findall(r"ANSWER[:\s]*([A-E1-4])\b", text.strip().upper())
    if found:
        return found[-1]
    return "INVALID"

def extract_gsm8k_answer(text: str) -> str:
    """Pull the final numeric answer out of a GSM8K response.

    Finds every case-insensitive "ANSWER" followed by a number-like token
    (optionally preceded by "$" or whitespace) and keeps the last one. "$" and
    thousands separators are stripped, and a value such as "42.0" is reduced
    to "42". Returns "INVALID" when nothing matches.
    """
    found = re.findall(r"(?i)ANSWER[:\s]*([$\s]*-?\d[\d,]*\.?\d*)", text)
    if not found:
        return "INVALID"

    # Last occurrence, without currency sign and thousands separators
    number = found[-1].replace("$", "").replace(",", "").strip()
    # A float that is really an int (e.g. 42.0) is reported without the fraction
    if re.fullmatch(r"-?\d+\.0+", number):
        number = number.split(".")[0]
    return number

def extract_mmlu_answer(text):
    """Return the last A-D letter that follows "ANSWER" in the upper-cased text, or "INVALID"."""
    normalized = text.strip().upper()
    letters = re.findall(r"ANSWER[:\s]*([A-D])\b", normalized)
    return letters[-1] if letters else "INVALID"


def evaluate(cfg, predictions, dataset):
    """Score already-parsed predictions with the evaluator that matches cfg["path"].

    Raises an Exception for dataset paths without an evaluator.
    """
    dataset_path = cfg["path"]
    if dataset_path == "ai2_arc":
        return evaluate_arc(predictions, dataset)
    if dataset_path == "openai/gsm8k":
        return evaluate_gsm8k(predictions, dataset)
    if dataset_path == "cais/mmlu":
        return evaluate_mmlu(predictions, dataset)
    raise Exception(f"{cfg['path']} not yet supported")

def evaluate_arc(predictions, dataset):
    """Score parsed ARC predictions against each example's "answerKey".

    Args:
        predictions: parsed answer labels, in dataset order.
        dataset: iterable of examples, each with an "answerKey".

    Returns:
        {"accuracy": correct / len(predictions)}; 0.0 when there are no
        predictions (instead of dividing by zero).
    """
    total = len(predictions)
    if total == 0:
        return {"accuracy": 0.0}
    correct = sum(int(pred == ex["answerKey"]) for pred, ex in zip(predictions, dataset))
    return {"accuracy": correct / total}

def evaluate_gsm8k(predictions, dataset):
    """Score parsed GSM8K predictions by numeric exact match.

    The gold value is the number after "####" in each example's "answer"
    text. Thousands separators and decimals in the gold value are handled, so
    "#### 1,234" is read as 1234 rather than being cut off at the comma.
    Predictions that cannot be parsed as a number count as wrong.

    Args:
        predictions: parsed answers (strings), in dataset order.
        dataset: sized iterable of examples, each with an "answer" string.

    Returns:
        {"exact_match": correct / len(dataset)}; 0.0 for an empty dataset.

    Raises:
        ValueError: if a gold answer has no "#### <number>" marker.
    """
    total = len(dataset)
    if total == 0:
        return {"exact_match": 0.0}

    correct = 0
    for pred, ex in zip(predictions, dataset):
        gold = ex["answer"]
        # Common format for the final answer in GSM8K is: "#### <number>"
        match = re.search(r"####\s*(-?[\d,]*\.?\d+)", gold)
        if match is None:
            raise ValueError(f"gold answer has no '#### <number>' marker: {gold!r}")
        gold_ans = float(match.group(1).replace(",", ""))

        try:
            pred_ans = float(pred)
        except (TypeError, ValueError):
            # Unparseable prediction (e.g. "INVALID"): counted as incorrect.
            continue

        correct += int(pred_ans == gold_ans)
    return {"exact_match": correct / total}

def evaluate_mmlu(predictions, dataset):
    """Score parsed MMLU predictions against the gold choice letter.

    Each example's "answer" is an index into A-D.

    Args:
        predictions: parsed answer letters, in dataset order.
        dataset: sized iterable of examples, each with an integer "answer".

    Returns:
        {"accuracy": correct / len(dataset)}; 0.0 for an empty dataset
        (instead of dividing by zero).
    """
    total = len(dataset)
    if total == 0:
        return {"accuracy": 0.0}

    choice_letters = ("A", "B", "C", "D")
    correct = sum(
        int(pred == choice_letters[ex["answer"]])
        for pred, ex in zip(predictions, dataset)
    )
    return {"accuracy": correct / total}



def load_dataset_from_config(cfg, sample_size):
    """Load the split described by `cfg` and return its first `sample_size` rows.

    Args:
        cfg: dataset config with "path" and "split", an optional "name"
            (dataset configuration) and an optional "shuffle" flag; when the
            flag is truthy the split is shuffled with SEED before sampling.
        sample_size: number of rows to keep, or -1 for the whole split. A
            value larger than the split is clamped to the split size instead
            of failing on out-of-range indices.

    Returns:
        The selected dataset.
    """
    if "name" in cfg:
        ds = load_dataset(cfg["path"], cfg["name"], split=cfg["split"])
    else:
        ds = load_dataset(cfg["path"], split=cfg["split"])
    if cfg.get("shuffle"):
        ds = ds.shuffle(seed=SEED)

    available = len(ds)
    if sample_size == -1:
        return ds.select(range(available))
    return ds.select(range(min(sample_size, available)))

# ---------------------- Testing Related Utilities ---------------------- #

def compute_perplexity(model: AutoModelForCausalLM, 
                       tokenizer: AutoTokenizer, 
                       texts: list[str], 
                       batch_size: int, 
                       sequence_length: int) -> dict:
    """Computes the perplexity of a language model over a list of input texts.

    Texts are tokenized in batches (left-padded, truncated to
    `sequence_length`); padding positions are excluded from the loss via
    label -100. Each batch's mean loss is weighted by the number of label
    positions it was averaged over, and the exponential of the overall mean
    is returned.

    Args:
        model: causal LM; called with input_ids, attention_mask and labels,
            and expected to return an object with a mean `.loss`.
        tokenizer: tokenizer for `model`; its `padding_side` is set to "left".
        texts: input strings.
        batch_size: number of texts per forward pass.
        sequence_length: maximum number of tokens per text.

    Returns:
        {'perplexity': exp(mean token loss)}.

    Raises:
        ValueError: if no token contributes to the loss (e.g. `texts` is empty
            or every text tokenizes to at most one token).
    """
    model.eval()
    tokenizer.padding_side = "left"  # loop-invariant, so set once
    total_loss = 0.0
    total_tokens = 0

    for batch_start in tqdm(range(0, len(texts), batch_size), desc="Evaluating"):
        batch = texts[batch_start: batch_start + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=sequence_length,
        )

        # Mask labels where padding is present
        labels = inputs["input_ids"].clone()
        labels[inputs["attention_mask"] == 0] = -100

        # Causal-LM heads shift the labels by one position before averaging,
        # so the label in column 0 never contributes to the returned mean.
        # Weight the batch by the number of labels that actually do.
        num_tokens = (labels[:, 1:] != -100).sum().item()
        if num_tokens == 0:
            # Nothing to score (e.g. a batch of empty strings); the model's
            # mean loss would be NaN and poison the running total.
            continue

        inputs["labels"] = labels
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            loss = model(**inputs).loss

        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens

    if total_tokens == 0:
        raise ValueError("compute_perplexity: no tokens to evaluate")

    avg_loss = total_loss / total_tokens

    return {
        'perplexity': math.exp(avg_loss),
    }

# ---------------------- Runner ---------------------- #
def reset_gate_counters(gates):
    """Give every gate a fresh, empty `token_assignments` list."""
    for current in gates:
        setattr(current, "token_assignments", [])

def run_experiment(
    num_experts, num_choices, sample_size, experiment_name,
    selection_method, threshold_factor, mode,
    min_dynamic_k, max_dynamic_k, lora_path,
):  
    """Forward the experiment to the runner of the model family named in MODEL_NAME.

    The family is chosen by substring ("qwen", then "deepseek", then
    "mixtral") of the lower-cased MODEL_NAME; all arguments are passed
    through unchanged as keywords. Raises an Exception if none matches.
    """
    family = MODEL_NAME.lower()
    if "qwen" in family:
        runner = run_experiment_qwen
    elif "deepseek" in family:
        runner = run_experiment_ds
    elif "mixtral" in family:
        runner = run_experiment_mixtral
    else:
        raise Exception(f"{MODEL_NAME} not have runner defined")

    return runner(
        num_experts=num_experts,
        num_choices=num_choices,
        sample_size=sample_size,
        experiment_name=experiment_name,
        selection_method=selection_method,
        threshold_factor=threshold_factor,
        mode=mode,
        min_dynamic_k=min_dynamic_k,
        max_dynamic_k=max_dynamic_k,
        lora_path=lora_path,
    )

# deepseek
def set_ds_router(model, num_experts, num_choices, min_dynamic_k, max_dynamic_k, selection_method, threshold_factor, beta=BETA, sample_before_load=SAMPLE_BEFORE_LOAD, vectorized=VECTORIZED, collect_stats=False, track_assignments=True, training=False, sum_threshold=SUM_THRESHOLD):
    """Replace the gate of every DeepSeek MoE layer with a DynamicKGate.

    For each decoder layer whose `mlp` is an MoE block with a gate (the types
    are taken from layers 0 and 1 of the model), a copy of the old gate's
    config is extended with the routing settings, `forward_ds` is bound as the
    block's `moe_infer`, and a new DynamicKGate carrying the old gate's weight
    is installed on the old gate's device.

    Args:
        model: causal LM whose decoder layers are at `model.model.layers`.
        num_experts, num_choices: passed to DynamicKGate.
        min_dynamic_k, max_dynamic_k: bounds stored on each gate config.
        selection_method: stored as `gating_strategy`; for "threshold" the
            per-layer `threshold_t` is set as well.
        threshold_factor: list with one value for all layers, or one value per
            entry of MODEL_CFG_DS["layer_ranges_threshold"].
        beta, sample_before_load: stored on each gate config.
        vectorized, collect_stats, track_assignments, training: passed to
            DynamicKGate.
        sum_threshold: list with one value for all layers, or one value per
            entry of MODEL_CFG_DS["layer_ranges"].

    Returns:
        (model, moe_gates, moe_layer_indices, moe_blocks), where the layer
        indices are 1-based.
    """
    import copy

    moe_gates = []
    moe_layer_indices = []  # Store the actual (1-based) layer indices of MoE layers
    moe_blocks = []
    DeepseekDecoderLayer = type(model.model.layers[0])
    DeepseekMoE = type(model.model.layers[1].mlp)
    MoEGate = type(model.model.layers[1].mlp.gate)

    layer_num = len(model.model.layers)
    if len(sum_threshold) != 1:
        # expand sum_threshold such that each layer has 1 sum_threshold
        sum_threshold = _expand_per_layer_values(
            sum_threshold, MODEL_CFG_DS["layer_ranges"], layer_num, "sum_threshold"
        )
        print(f"New layer sum thresholds are {sum_threshold}")

    if len(threshold_factor) != 1:
        threshold_factor = _expand_per_layer_values(
            threshold_factor, MODEL_CFG_DS["layer_ranges_threshold"], layer_num, "threshold_factor"
        )
        print(f"New layer threshold factors are {threshold_factor}")

    for layer_idx, layer in enumerate(model.model.layers):
        if not isinstance(layer, DeepseekDecoderLayer):
            continue
        if not (hasattr(layer, "mlp") and isinstance(layer.mlp, DeepseekMoE)):
            continue
        if not (hasattr(layer.mlp, "gate") and isinstance(layer.mlp.gate, MoEGate)):
            continue

        old_gate = layer.mlp.gate
        # Materialize old_gate if it's on meta device
        if any(p.device.type == "meta" for p in old_gate.parameters()):
            print("Materializing old_gate before replacing it")
            old_gate._apply(lambda t: torch.empty_like(t, device=model.device) if t.device.type == "meta" else t)
        device = old_gate.weight.device
        config = copy.deepcopy(old_gate.config)
        config.layer_idx = layer_idx + 1  # Store 1-based layer index
        moe_layer_indices.append(layer_idx + 1)  # Store actual layer index

        # Set gating strategy and dynamic k parameters
        config.gating_strategy = selection_method
        config.min_dynamic_k = min_dynamic_k
        config.max_dynamic_k = max_dynamic_k
        if selection_method == "threshold":
            if len(threshold_factor) == 1:
                config.threshold_t = threshold_factor[0]
            else:
                # Per-layer lists are indexed by the 0-based layer position.
                config.threshold_t = threshold_factor[layer_idx]

        config.sample_before_load = sample_before_load
        config.default_beta = beta
        if len(sum_threshold) == 1:
            config.sum_threshold = sum_threshold[0]
        else:
            config.sum_threshold = sum_threshold[layer_idx]

        layer.mlp.moe_infer = types.MethodType(forward_ds, layer.mlp)
        moe_blocks.append(layer.mlp)

        new_gate = DynamicKGate(config, num_experts, num_choices, vectorized=vectorized, collect_stats=collect_stats, track_assignments=track_assignments, training=training)
        # Only the routing weight is carried over from the old gate.
        with torch.no_grad():
            new_gate.weight.copy_(old_gate.weight)
        new_gate.to(device)
        layer.mlp.gate = new_gate
        moe_gates.append(new_gate)
    return model, moe_gates, moe_layer_indices, moe_blocks


def _expand_per_layer_values(values, ranges, num_layers, label):
    """Expand one value per layer range into one value per layer (-1 where no range applies)."""
    assert len(ranges) == len(values), f"Number of {label} must match number of ranges, {len(ranges)}, {len(values)}"
    per_layer = [-1] * num_layers
    for value, (lower, upper) in zip(values, ranges):
        # Ranges are inclusive on both ends.
        for idx in range(lower, upper + 1):
            per_layer[idx] = value
    return per_layer

def run_experiment_ds(
    num_experts, num_choices, sample_size, experiment_name,
    selection_method, threshold_factor, mode,
    min_dynamic_k, max_dynamic_k, lora_path,
):
    """Run one DeepSeek experiment with fixed parameters.

    Loads MODEL_NAME in bfloat16, installs the dynamic-k router on its MoE
    layers, optionally adds LoRA weights from `lora_path`, and then runs the
    "perplexity" or "qa" scorer selected by `mode`, returning its result. The
    local model reference is dropped and CUDA memory is released afterwards,
    whether or not the scorer succeeded. An unknown `mode` raises ValueError.
    """
    assert num_choices >= num_experts
    assert min_dynamic_k >= num_experts
    assert max_dynamic_k >= min_dynamic_k
    set_random_seed(SEED)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    model.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.config.num_experts_per_tok = num_experts

    model, moe_gates, moe_layer_indices, moe_blocks = set_ds_router(
        model,
        num_experts,
        num_choices,
        min_dynamic_k,
        max_dynamic_k,
        selection_method=selection_method,
        threshold_factor=threshold_factor,
        beta=BETA,
        sample_before_load=SAMPLE_BEFORE_LOAD,
        vectorized=VECTORIZED,
        sum_threshold=SUM_THRESHOLD,
    )
    if lora_path:
        # LoRA weights go on only after the routers are in place
        print(f"Adding LoRA weights from {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)
    config = dict(min_dynamic_k=min_dynamic_k, max_dynamic_k=max_dynamic_k)

    print(f"Using strategy: {selection_method}, dynamic_k: [{min_dynamic_k}, {max_dynamic_k}], layers: {moe_layer_indices}")
    if selection_method == "threshold":
        print(f"Threshold factor: {threshold_factor}")

    try:
        if mode == "perplexity":
            scorer = _run_experiment_with_perplexity_score
        elif mode == "qa":
            scorer = _run_experiment_with_qa_score
        else:
            raise ValueError(f"Unsupported mode: {mode}")
        return scorer(
            num_experts=num_experts,
            num_choices=num_choices,
            sample_size=sample_size,
            experiment_name=experiment_name,
            moe_gates=moe_gates,
            moe_layer_indices=moe_layer_indices,
            tokenizer=tokenizer,
            model=model,
            config=config,
            moe_blocks=moe_blocks,
            selection_method=selection_method,
            threshold_factor=threshold_factor,
        )
    finally:
        # Drop the model so the next run starts from a fresh load
        del model
        torch.cuda.empty_cache()
        import gc
        gc.collect()
        from accelerate.utils import release_memory
        release_memory()
#################################
# Qwen version
class DummyConfig:
    """Plain attribute bag: every key of the given dict becomes an instance attribute."""

    def __init__(self, d):
        vars(self).update(d)


def set_qwen_router(model, num_experts, num_choices, min_dynamic_k, max_dynamic_k, selection_method, threshold_factor, beta=BETA, sample_before_load=SAMPLE_BEFORE_LOAD, vectorized=VECTORIZED, collect_stats=False, track_assignments=True, training=False, sum_threshold=SUM_THRESHOLD):
    """Patch every Qwen MoE block in place so that it routes with `forward_qwen`.

    For each decoder layer whose `mlp` is an MoE block with a gate (the types
    are taken from layers 0 and 1 of the model), the block receives a small
    routing config, its bookkeeping attributes and buffers, and `forward_qwen`
    bound as its `forward`. The block itself acts as the "gate" object for the
    callers, so it is returned in both `moe_gates` and `moe_blocks`.

    Args:
        model: causal LM whose decoder layers are at `model.model.layers`.
        num_experts: stored as `config.num_experts` on each block.
        num_choices: stored on each block.
        min_dynamic_k, max_dynamic_k: bounds stored on each block config.
        selection_method: stored as `gating_strategy`; for "threshold" the
            `threshold_t` value is set as well.
        threshold_factor, sum_threshold: single-element lists; per-layer
            ranges are not configured for Qwen.
        beta, sample_before_load, vectorized: stored on each block config.
        collect_stats, track_assignments, training: stored on each block.

    Returns:
        (model, moe_gates, moe_layer_indices, moe_blocks), where the layer
        indices are 1-based.

    Raises:
        Exception: if `sum_threshold` or `threshold_factor` has more than one
            entry.
    """
    moe_gates = []
    moe_layer_indices = []  # Store the actual (1-based) layer indices of MoE layers
    moe_blocks = []
    QwenDecoderLayer = type(model.model.layers[0])
    QwenMoE = type(model.model.layers[1].mlp)
    MoEGate = type(model.model.layers[1].mlp.gate)

    if len(sum_threshold) != 1:
        raise Exception("sum_threshold layer ranges for qwen not configured yet")

    if len(threshold_factor) != 1:
        raise Exception("threshold_factor layer ranges for qwen not configured yet")

    for layer_idx, layer in enumerate(model.model.layers):
        if not isinstance(layer, QwenDecoderLayer):
            continue
        if not (hasattr(layer, "mlp") and isinstance(layer.mlp, QwenMoE)):
            continue
        if not (hasattr(layer.mlp, "gate") and isinstance(layer.mlp.gate, MoEGate)):
            continue

        moe_layer_indices.append(layer_idx + 1)  # Store actual layer index
        moe_block = layer.mlp

        config = DummyConfig({})
        # Set gating strategy and dynamic k parameters
        config.gating_strategy = selection_method
        config.min_dynamic_k = min_dynamic_k
        config.max_dynamic_k = max_dynamic_k
        if selection_method == "threshold":
            # Exactly one value is guaranteed by the checks above.
            config.threshold_t = threshold_factor[0]

        config.sample_before_load = sample_before_load
        config.default_beta = beta
        config.vectorized = vectorized
        config.sum_threshold = sum_threshold[0]

        moe_block.config = config
        # The gate is a linear router (its in_features is read below), so its
        # output width is the number of routed experts. Falls back to 64 when
        # the gate does not expose out_features.
        moe_block.n_routed_experts = getattr(moe_block.gate, "out_features", 64)

        try:
            dev = next(moe_block.parameters()).device
        except StopIteration:
            dev = torch.device("cpu")
        # Pre-compute static tensors for efficiency
        moe_block.register_buffer('gini_indices', torch.arange(1, moe_block.n_routed_experts + 1, dtype=torch.float32, device=dev))
        # Pre-compute mask indices for top-k operations
        moe_block.register_buffer('mask_indices', torch.arange(moe_block.n_routed_experts, dtype=torch.long, device=dev))

        # counters
        moe_block.layer_idx = layer_idx + 1
        moe_block.token_assignments = []
        moe_block.routing_stats = []
        moe_block.collect_stats = collect_stats
        moe_block.track_assignments = track_assignments
        moe_block.training = training
        moe_block.vectorized = vectorized

        moe_block.num_choices = num_choices
        # moe_block.num_experts is left alone: it would conflict with the
        # block's own attribute of that name, so the value lives on the config.
        moe_block.config.num_experts = num_experts

        moe_block.log_dir = os.path.join(os.getcwd(), "logs")
        os.makedirs(moe_block.log_dir, exist_ok=True)
        moe_block.gating_dim = moe_block.gate.in_features

        moe_block.forward = types.MethodType(forward_qwen, moe_block)

        moe_blocks.append(moe_block)
        moe_gates.append(moe_block)

    print(f"Using strategy: {selection_method}, dynamic_k: [{min_dynamic_k}, {max_dynamic_k}], layers: {moe_layer_indices}")
    if selection_method == "threshold":
        print(f"Threshold factor: {threshold_factor}")

    return model, moe_gates, moe_layer_indices, moe_blocks


def run_experiment_qwen(
    num_experts, num_choices, sample_size, experiment_name,
    selection_method, threshold_factor, mode,
    min_dynamic_k, max_dynamic_k, lora_path,
):
    """Run a single Qwen MoE experiment with fixed parameters.

    Loads ``MODEL_NAME`` in bfloat16, patches its MoE blocks through
    ``set_qwen_router``, optionally attaches LoRA weights and then dispatches
    to the perplexity or QA evaluation routine.

    Args:
        num_experts: Number of experts used per token; also written to
            ``model.config.num_experts_per_tok``.
        num_choices: Candidate pool size handed to the router; must be
            ``>= num_experts``.
        sample_size: Forwarded to the evaluation routine.
        experiment_name: Forwarded to the evaluation routine.
        selection_method: Gating strategy name forwarded to the router.
        threshold_factor: Sequence of threshold factors forwarded to the
            router and to the evaluation routine.
        mode: ``"perplexity"`` or ``"qa"``.
        min_dynamic_k: Lower bound for dynamic k; must be ``>= num_experts``.
        max_dynamic_k: Upper bound for dynamic k; must be ``>= min_dynamic_k``.
        lora_path: If truthy, LoRA weights are loaded from this path after
            the routers have been installed.

    Returns:
        Whatever the selected evaluation routine returns.

    Raises:
        AssertionError: If the k / choice bounds above are violated.
        ValueError: If ``mode`` is neither ``"perplexity"`` nor ``"qa"``.
    """
    assert num_choices >= num_experts
    assert min_dynamic_k >= num_experts
    assert max_dynamic_k >= min_dynamic_k

    # Resolve the evaluation routine first so that an unsupported mode fails
    # before the (expensive) model load instead of after it.
    if mode == "perplexity":
        run_eval = _run_experiment_with_perplexity_score
    elif mode == "qa":
        run_eval = _run_experiment_with_qa_score
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    set_random_seed(SEED)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.padding_side = "left"
    # Same fallback as run_experiment_mixtral: batched tokenization pads, which
    # needs a pad token to be defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    model.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.config.num_experts_per_tok = num_experts

    model, moe_gates, moe_layer_indices, moe_blocks = set_qwen_router(
        model,
        num_experts,
        num_choices,
        min_dynamic_k,
        max_dynamic_k,
        selection_method=selection_method,
        threshold_factor=threshold_factor,
        beta=BETA,
        sample_before_load=SAMPLE_BEFORE_LOAD,
        vectorized=VECTORIZED,
        track_assignments=True,
        collect_stats=False,
        training=False,
        sum_threshold=SUM_THRESHOLD,
    )

    if lora_path:
        # must happen after setting routers
        print(f"Adding LoRA weights from {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)
    config = {
        "min_dynamic_k": min_dynamic_k,
        "max_dynamic_k": max_dynamic_k,
    }
    print(f"Using strategy: {selection_method}, dynamic_k: [{min_dynamic_k}, {max_dynamic_k}], layers: {moe_layer_indices}")
    if selection_method == "threshold":
        print(f"Threshold factor: {threshold_factor}")
    try:
        return run_eval(
            num_experts=num_experts, num_choices=num_choices, sample_size=sample_size, experiment_name=experiment_name,
            moe_gates=moe_gates, moe_layer_indices=moe_layer_indices, tokenizer=tokenizer, model=model, config=config, moe_blocks=moe_blocks,
            selection_method=selection_method, threshold_factor=threshold_factor,
        )
    finally:
        # Drop the local model reference and ask the allocator to release
        # cached memory, whether or not the evaluation succeeded.
        del model
        torch.cuda.empty_cache()
        import gc; gc.collect()
        from accelerate.utils import release_memory
        release_memory()

################
#  Mixtral version


def _expand_range_values(values, ranges, layer_num, label):
    """Expand one value per inclusive layer range into one value per layer.

    Layers that are not covered by any range keep the sentinel ``-1``.
    """
    assert len(ranges) == len(values), f"Number of {label} must match number of ranges, {len(ranges)}, {len(values)}"
    per_layer = [-1] * layer_num
    for value, (lower, upper) in zip(values, ranges):
        for idx in range(lower, upper + 1):
            per_layer[idx] = value
    return per_layer


def set_mixtral_router(model, num_experts, num_choices, min_dynamic_k, max_dynamic_k, selection_method, threshold_factor, beta=BETA, sample_before_load=SAMPLE_BEFORE_LOAD, vectorized=VECTORIZED, collect_stats=False, track_assignments=True, training=False, sum_threshold=SUM_THRESHOLD):
    """Patch every Mixtral MoE block of ``model`` to use ``forward_mixtral``.

    Each ``block_sparse_moe`` module whose type matches the one found in
    layer 0 gets a fresh ``DummyConfig`` carrying the routing settings, a set
    of bookkeeping attributes and pre-computed index buffers, and has its
    ``forward`` rebound to ``forward_mixtral``.

    Args:
        model: Model exposing ``model.model.layers``.
        num_experts: Stored as ``config.num_experts`` on each block.
        num_choices: Stored as ``num_choices`` on each block.
        min_dynamic_k: Stored as ``config.min_dynamic_k``.
        max_dynamic_k: Stored as ``config.max_dynamic_k``.
        selection_method: Stored as ``config.gating_strategy``.
        threshold_factor: Sequence of threshold factors. A single entry is
            applied to every layer; otherwise there must be one entry per
            range in ``MODEL_CFG_MIXTRAL_7b["layer_ranges_threshold"]``.
            Only read when ``selection_method == "threshold"``.
        beta: Stored as ``config.default_beta``.
        sample_before_load: Stored as ``config.sample_before_load``.
        vectorized: Stored on both the config and the block.
        collect_stats: Stored as ``collect_stats`` on each block.
        track_assignments: Stored as ``track_assignments`` on each block.
        training: Written to each block's ``training`` attribute.
        sum_threshold: Sequence of sum thresholds. A single entry is applied
            to every layer; otherwise there must be one entry per range in
            ``MODEL_CFG_MIXTRAL_7b["layer_ranges"]``.

    Returns:
        ``(model, moe_gates, moe_layer_indices, moe_blocks)``.
        ``moe_layer_indices`` holds 1-based layer numbers; ``moe_gates`` and
        ``moe_blocks`` both contain the patched MoE block modules.

    Raises:
        AssertionError: If a multi-valued ``sum_threshold`` or
            ``threshold_factor`` does not match the number of ranges.
    """
    moe_gates = []
    moe_layer_indices = []  # Store the actual layer indices of MoE layers
    moe_blocks = []
    MixtralDecoderLayer = type(model.model.layers[0])
    MixtralMoE = type(model.model.layers[0].block_sparse_moe)
    MoEGate = type(model.model.layers[0].block_sparse_moe.gate)

    layer_num = len(model.model.layers)

    if len(sum_threshold) != 1:
        # expand sum_threshold such that each layer has 1 sum_threshold
        sum_threshold = _expand_range_values(
            sum_threshold, MODEL_CFG_MIXTRAL_7b["layer_ranges"], layer_num, "sum_threshold"
        )
        print(f"New layer sum thresholds are {sum_threshold}")

    if len(threshold_factor) != 1:
        threshold_factor = _expand_range_values(
            threshold_factor, MODEL_CFG_MIXTRAL_7b["layer_ranges_threshold"], layer_num, "threshold_factor"
        )
        # Previously this line reported the values as "sum thresholds".
        print(f"New layer threshold factors are {threshold_factor}")

    for layer_idx, layer in enumerate(model.model.layers):
        if not isinstance(layer, MixtralDecoderLayer):
            continue
        moe_block = getattr(layer, "block_sparse_moe", None)
        if not isinstance(moe_block, MixtralMoE):
            continue
        old_gate = getattr(moe_block, "gate", None)
        if not isinstance(old_gate, MoEGate):
            continue

        # Materialize old_gate if it's on meta device.
        # NOTE(review): torch.empty_like allocates uninitialized storage, so
        # the gate weights are not meaningful after this step - confirm that
        # real weights are loaded afterwards.
        if any(p.device.type == "meta" for p in old_gate.parameters()):
            print("Materializing old_gate before replacing it")
            old_gate._apply(lambda t: torch.empty_like(t, device=model.device) if t.device.type == "meta" else t)

        moe_layer_indices.append(layer_idx + 1)  # Store actual layer index (1-based)

        config = DummyConfig({})
        # Set gating strategy and dynamic k parameters
        config.gating_strategy = selection_method
        config.min_dynamic_k = min_dynamic_k
        config.max_dynamic_k = max_dynamic_k
        if selection_method == "threshold":
            if len(threshold_factor) == 1:
                config.threshold_t = threshold_factor[0]
            else:
                config.threshold_t = threshold_factor[layer_idx]

        config.sample_before_load = sample_before_load
        config.default_beta = beta
        config.vectorized = vectorized

        if len(sum_threshold) == 1:
            config.sum_threshold = sum_threshold[0]
        else:
            config.sum_threshold = sum_threshold[layer_idx]

        moe_block.config = config
        moe_block.n_routed_experts = 8  # default 8 experts total

        try:
            dev = next(moe_block.parameters()).device
        except StopIteration:
            # Block without parameters: fall back to CPU for the buffers.
            dev = torch.device("cpu")
        # Pre-compute static tensors for efficiency
        moe_block.register_buffer('gini_indices', torch.arange(1, moe_block.n_routed_experts + 1, dtype=torch.float32, device=dev))
        # Pre-compute mask indices for top-k operations, default is 8
        moe_block.register_buffer('mask_indices', torch.arange(moe_block.n_routed_experts, dtype=torch.long, device=dev))

        # counters
        moe_block.layer_idx = layer_idx + 1
        moe_block.token_assignments = []
        moe_block.routing_stats = []
        moe_block.collect_stats = collect_stats
        moe_block.track_assignments = track_assignments
        moe_block.training = training
        moe_block.vectorized = vectorized

        moe_block.num_choices = num_choices
        # moe_block.num_experts is deliberately not set: it conflicts with the original attribute
        moe_block.config.num_experts = num_experts

        moe_block.log_dir = os.path.join(os.getcwd(), "logs")
        os.makedirs(moe_block.log_dir, exist_ok=True)
        moe_block.gating_dim = moe_block.gate.in_features

        moe_block.forward = types.MethodType(forward_mixtral, moe_block)

        moe_blocks.append(moe_block)
        moe_gates.append(moe_block)

    print(f"Using strategy: {selection_method}, dynamic_k: [{min_dynamic_k}, {max_dynamic_k}], layers: {moe_layer_indices}")
    if selection_method == "threshold":
        print(f"Threshold factor: {threshold_factor}")

    return model, moe_gates, moe_layer_indices, moe_blocks


def run_experiment_mixtral(
    num_experts, num_choices, sample_size, experiment_name,
    selection_method, threshold_factor, mode,
    min_dynamic_k, max_dynamic_k, lora_path,
):
    """Run one Mixtral experiment end to end for a fixed parameter set.

    The model named by ``MODEL_NAME`` is loaded in bfloat16, its MoE blocks
    are patched by ``set_mixtral_router``, LoRA weights are attached when
    ``lora_path`` is truthy, and the result of the evaluation routine
    selected by ``mode`` (``"perplexity"`` or ``"qa"``) is returned.
    A ``ValueError`` is raised for any other ``mode``.
    """
    assert num_choices >= num_experts
    assert min_dynamic_k >= num_experts
    assert max_dynamic_k >= min_dynamic_k
    set_random_seed(SEED)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    tok.padding_side = "left"
    if tok.pad_token is None:
        # Padding needs a pad token; reuse EOS when none is defined.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    )

    model.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.config.num_experts_per_tok = num_experts

    # Install the custom routing on every MoE block.
    model, moe_gates, moe_layer_indices, moe_blocks = set_mixtral_router(
        model, num_experts, num_choices, min_dynamic_k, max_dynamic_k,
        selection_method=selection_method,
        threshold_factor=threshold_factor,
        beta=BETA,
        sample_before_load=SAMPLE_BEFORE_LOAD,
        vectorized=VECTORIZED,
        track_assignments=True,
        collect_stats=False,
        training=False,
        sum_threshold=SUM_THRESHOLD,
    )

    if lora_path:
        # LoRA has to be attached only once the routers are in place.
        print(f"Adding LoRA weights from {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)

    config = {"min_dynamic_k": min_dynamic_k, "max_dynamic_k": max_dynamic_k}

    print(f"Using strategy: {selection_method}, dynamic_k: [{min_dynamic_k}, {max_dynamic_k}], layers: {moe_layer_indices}")
    if selection_method == "threshold":
        print(f"Threshold factor: {threshold_factor}")

    try:
        if mode == "perplexity":
            evaluate_fn = _run_experiment_with_perplexity_score
        elif mode == "qa":
            evaluate_fn = _run_experiment_with_qa_score
        else:
            raise ValueError(f"Unsupported mode: {mode}")
        return evaluate_fn(
            num_experts=num_experts,
            num_choices=num_choices,
            sample_size=sample_size,
            experiment_name=experiment_name,
            moe_gates=moe_gates,
            moe_layer_indices=moe_layer_indices,
            tokenizer=tok,
            model=model,
            config=config,
            moe_blocks=moe_blocks,
            selection_method=selection_method,
            threshold_factor=threshold_factor,
        )
    finally:
        # Runs on success and on failure alike.
        del model
        torch.cuda.empty_cache()
        import gc
        gc.collect()
        from accelerate.utils import release_memory
        release_memory()


def _run_experiment_with_qa_score(
    num_experts, num_choices, sample_size, experiment_name,
    moe_gates, moe_layer_indices, tokenizer, model, config, moe_blocks,
    selection_method, threshold_factor,
):
    """Generate answers for the configured QA dataset and score them.

    For every batch the gate counters are reset, prompts are rendered through
    the tokenizer's chat template, ``model.generate`` is run under
    ``torch.no_grad`` and the per-token expert assignments recorded on the
    gates are accumulated into per-layer expert counts.

    Args:
        num_experts, num_choices, selection_method, threshold_factor:
            Forwarded to ``_save_layer_metrics``.
        sample_size: Forwarded to ``load_dataset_from_config`` and
            ``_save_layer_metrics``.
        experiment_name: Used for the imbalance summary and metric files.
        moe_gates: Objects exposing ``token_assignments``.
        moe_layer_indices: Layer identifiers used as keys of the count table.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        model: Model providing ``generate``, ``eval`` and ``device``.
        config: Forwarded to ``_save_layer_metrics``.
        moe_blocks: Not used here; kept for interface parity with the callers.

    Returns:
        ``(expert_counts, all_layer_metrics, latency, qa_scores,
        per_layer_stats, agg_stats)``. ``latency`` also carries the raw
        predictions, parsed predictions and prompts.

    Raises:
        NotImplementedError: If ``MODEL_TYPE`` is ``"base"``.
        ValueError: If ``MODEL_TYPE`` is neither ``"chat"`` nor ``"base"``.
    """
    # Fail before any work is done; previously an unknown MODEL_TYPE only
    # surfaced as an unbound ``inputs`` variable inside the loop.
    if MODEL_TYPE == "base":
        raise NotImplementedError("not supported yet")
    if MODEL_TYPE != "chat":
        raise ValueError(f"Unsupported MODEL_TYPE: {MODEL_TYPE}")

    dataset = load_dataset_from_config(DATASET_CFG, sample_size)

    # layer -> expert -> number of tokens assigned
    expert_counts = {layer: {e: 0 for e in range(DEFAULT_MAX_K)} for layer in moe_layer_indices}

    predictions = []
    prompts = []

    total_time, total_tokens = 0, 0
    per_batch_store = init_per_batch_imbalance(moe_layer_indices)

    # Loop-invariant setup.
    tokenizer.padding_side = "left"
    model.eval()

    for i in tqdm(range(0, len(dataset), BATCH_SIZE)):
        reset_gate_counters(moe_gates)
        batch = dataset.select(range(i, min(i + BATCH_SIZE, len(dataset))))

        batch_prompts = [format_prompt(DATASET_CFG, x) for x in batch]
        formatted_batch_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": q}], add_generation_prompt=True, tokenize=False
            )
            for q in batch_prompts
        ]
        prompts.extend(formatted_batch_prompts)
        inputs = tokenizer(
            formatted_batch_prompts,
            return_tensors="pt", padding="longest", truncation=False
        ).to(model.device)

        # The timed region covers generation plus the per-batch imbalance
        # bookkeeping, as before.
        t0 = time.perf_counter()
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
            record_per_batch_imbalance(moe_gates, moe_layer_indices, per_batch_store)
        t1 = time.perf_counter()
        total_time += (t1 - t0)

        # All sequences share the same (left-padded) input length, so the
        # generated part always starts at ``input_length``.
        input_length = inputs['input_ids'].shape[1]
        for row in outputs[:, input_length:]:
            predictions.append(tokenizer.decode(row, skip_special_tokens=True).strip())
            total_tokens += len(row)

        for gate in moe_gates:
            for a in gate.token_assignments:
                layer_counts = expert_counts[a['layer_idx']]
                for e in a['assigned_experts']:
                    layer_counts[e] += 1

    per_layer_stats, agg_stats = summarize_and_save_per_batch_imbalance(
        per_batch_store,
        out_dir=OUTPUT_DIR,
        experiment_name=experiment_name,
        layer_weights=None,
    )

    latency = {
        "total_forward_time_sec": total_time,
        "total_generated_tokens": total_tokens,
        # Guard against an empty dataset / zero generated tokens.
        "avg_latency_per_token_sec": total_time / total_tokens if total_tokens else 0.0,
    }
    parsed_preds = [format_response(DATASET_CFG, p) for p in predictions]
    print(f"Parsed Predictions: {parsed_preds}")
    print(f"Predictions: {predictions}")
    latency["predictions"] = predictions
    latency["parsed_predictions"] = parsed_preds
    latency["prompts"] = prompts

    qa_scores = evaluate(DATASET_CFG, parsed_preds, dataset)
    all_layer_metrics = _save_layer_metrics(expert_counts, moe_layer_indices, sample_size, num_experts, num_choices, config, selection_method, threshold_factor, experiment_name)
    print(f"QA score {qa_scores}")
    return expert_counts, all_layer_metrics, latency, qa_scores, per_layer_stats, agg_stats

def _run_experiment_with_perplexity_score(
    num_experts, num_choices, sample_size, experiment_name,
    moe_gates, moe_layer_indices, tokenizer, model, config, moe_blocks,
    selection_method, threshold_factor,
):
    """Collect routing statistics via generation, then compute perplexity.

    Each batch of dataset texts is tokenized (left padded, truncated to
    ``MAX_PROMPT_LENGTH``) and passed through ``model.generate``; the expert
    assignments and routing stats recorded on the gates are accumulated per
    layer. Perplexity is then computed over all texts by
    ``compute_perplexity``.

    Args:
        num_experts, num_choices, config, selection_method, threshold_factor:
            Forwarded to ``_save_layer_metrics``.
        sample_size: Forwarded to ``load_dataset_from_config`` and
            ``_save_layer_metrics``.
        experiment_name: Used for the metric file names.
        moe_gates: Objects exposing ``token_assignments`` and
            ``routing_stats``.
        moe_layer_indices: Layer identifiers used as keys of the tables.
        tokenizer: Tokenizer used for batching.
        model: Model providing ``generate`` and ``device``.
        moe_blocks: Not used here; kept for interface parity with the callers.

    Returns:
        ``(expert_counts, all_layer_metrics, latency, perplexity)``.

    Raises:
        AssertionError: If ``MODEL_TYPE`` is not ``"base"``.
    """
    assert MODEL_TYPE == "base", f"{MODEL_TYPE} should be base for perplexity experiments"

    dataset = load_dataset_from_config(DATASET_CFG, sample_size)
    texts = [x["text"] for x in dataset]
    # per layer, per expert
    expert_counts = {layer: {e: 0 for e in range(DEFAULT_MAX_K)} for layer in moe_layer_indices}
    token_stats = {layer: [] for layer in moe_layer_indices}

    total_time, total_tokens = 0, 0
    tokenizer.padding_side = "left"  # loop-invariant
    for i in tqdm(range(0, len(dataset), BATCH_SIZE), desc="Generating"):
        reset_gate_counters(moe_gates)
        batch_texts = texts[i: i + BATCH_SIZE]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_PROMPT_LENGTH).to(model.device)

        t0 = time.perf_counter()
        outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
        t1 = time.perf_counter()
        total_time += (t1 - t0)
        # New tokens per sequence times the number of sequences in the batch.
        total_tokens += (outputs.shape[1] - inputs.input_ids.shape[1]) * outputs.shape[0]

        for gate in moe_gates:
            for a in gate.token_assignments:
                layer_counts = expert_counts[a["layer_idx"]]
                for e in a["assigned_experts"]:
                    layer_counts[e] += 1

        for gate in moe_gates:
            for stat in gate.routing_stats:
                token_stats[stat['layer_idx']].append(stat)

    latency = {
        "total_forward_time_sec": total_time,
        "total_generated_tokens": total_tokens,
        # Guard against an empty dataset / zero generated tokens.
        "avg_latency_per_token_sec": total_time / total_tokens if total_tokens else 0.0,
    }
    perplexity = compute_perplexity(model, tokenizer, texts, batch_size=BATCH_SIZE, sequence_length=MAX_PERPLEXITY_LENGTH)
    all_layer_metrics = _save_layer_metrics(expert_counts, moe_layer_indices, sample_size, num_experts, num_choices, config, selection_method, threshold_factor, experiment_name, token_stats)
    print(f"Perplexity score {perplexity}")
    return expert_counts, all_layer_metrics, latency, perplexity


def _save_layer_metrics(expert_counts, layer_indices, sample_size, num_experts, num_choices, config, selection_method, threshold_factor, experiment_name, token_stats=None):
    """Write one JSON metrics file per layer and return the metric dicts.

    For each layer in ``layer_indices`` the expert counts (ordered by expert
    key) are turned into a percentage utilization vector and a
    ``max_violation`` value: the relative excess of the busiest expert over
    the uniform share ``total / DEFAULT_MAX_K``.

    Only ``expert_counts``, ``layer_indices`` and ``experiment_name`` are read
    by this function; the remaining parameters are accepted but unused.

    Returns:
        List of ``{"expert_utilization": [...], "max_violation": ...}`` dicts,
        one per layer, in the order of ``layer_indices``.
    """
    collected = []
    for layer in layer_indices:
        per_expert = expert_counts[layer]
        counts = np.array([per_expert[key] for key in sorted(per_expert)]).T
        total = counts.sum()

        utilization = (counts / total) * 100
        expected = total / DEFAULT_MAX_K  # uniform share per expert
        max_violation = (counts.max() - expected) / expected

        data = {
            "expert_utilization": utilization.tolist(),
            "max_violation": max_violation,
        }

        out_file = f"{OUTPUT_DIR}/{experiment_name}/{experiment_name}_layer_{layer}_metrics.json"
        target_dir = os.path.dirname(out_file)
        os.makedirs(target_dir, exist_ok=True)
        with open(out_file, "w") as handle:
            json.dump(data, handle, indent=2)
        print(f"Saved metrics for layer {layer} to {out_file}")

        collected.append(data)
    return collected


# ---------------------- Gating Module ---------------------- #

@torch.no_grad()
def forward_ds(self, x, flat_expert_indices, flat_expert_weights):
    """Apply the routed experts to ``x`` and return the weighted sum.

    The flattened (token, slot) assignments are sorted by expert id so that
    each expert processes all of its tokens in a single call; the weighted
    expert outputs are then summed back into each token's row.

    Args:
        x: Token activations, indexed along dim 0 by token index.
        flat_expert_indices: 1-D integer tensor of expert ids, one per
            (token, slot) pair, laid out so that entry ``j`` belongs to token
            ``j // self.num_experts_per_tok``.
        flat_expert_weights: Per-assignment weights, indexed like
            ``flat_expert_indices`` and multiplied in place into the expert
            output (presumably shaped ``[N, 1]`` so it broadcasts over the
            hidden dim - confirm against the caller).

    Returns:
        Tensor shaped like ``x`` holding the sum of weighted expert outputs
        for each token.
    """
    expert_cache = torch.zeros_like(x)
    idxs = flat_expert_indices.argsort()
    # Cumulative counts: entry i is the end offset (exclusive) of expert i's
    # slice inside ``idxs``.
    tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
    token_idxs = idxs // self.num_experts_per_tok
    for i, end_idx in enumerate(tokens_per_expert):
        start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
        if start_idx == end_idx:
            # No token was routed to this expert: skip the call instead of
            # pushing an empty batch through it.
            continue
        expert = self.experts[i]
        exp_token_idx = token_idxs[start_idx:end_idx]
        expert_tokens = x[exp_token_idx]
        expert_out = expert(expert_tokens)
        expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
        # A token may be served by several experts, so accumulate with a sum.
        expert_cache.scatter_reduce_(
            0,
            exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
            expert_out,
            reduce='sum'
        )

    return expert_cache
    
class DynamicKGate(nn.Module):
    """
    Overrides Deepseek's MoEGate class to choose experts using the Power of Two
    Choices, with dynamic k selection based on a threshold.
    """

    def __init__(self, config, num_experts, num_choices, log_dir=None, vectorized=False, collect_stats=False, track_assignments=True, training=False):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.num_experts = num_experts
        self.num_choices = num_choices
        self.log_dir = log_dir or os.path.join(os.getcwd(), "logs")
        os.makedirs(self.log_dir, exist_ok=True)

        self.scoring_func = config.scoring_func
        # self.alpha = config.aux_loss_alpha
        self.alpha = 0.01 # from original paper page 9
        self.seq_aux = config.seq_aux

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

        # Pre-compute static tensors for efficiency
        self.register_buffer('gini_indices', torch.arange(1, self.n_routed_experts + 1, dtype=torch.float32))
        # Pre-compute mask indices for top-k operations, default is 64
        self.register_buffer('mask_indices', torch.arange(64, dtype=torch.long))

        # counters
        self.layer_idx = config.layer_idx
        self.token_assignments = []
        self.routing_stats = []

        self.collect_stats = collect_stats
        self.track_assignments = track_assignments 
        self.training = training

        self.vectorized=vectorized

    def reset_parameters(self) -> None:
        """Re-initialize the gate weight in place with Kaiming-uniform values."""
        negative_slope = math.sqrt(5)
        nn.init.kaiming_uniform_(self.weight, a=negative_slope)

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)  # [B*T, H]
        logits = F.linear(hidden_states, self.weight.to(hidden_states.dtype), None)  # [B*T, E]
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        B, T, E = bsz, seq_len, self.n_routed_experts
        assert E == scores.size(-1)
        BT = B * T

        strategy = getattr(self.config, "gating_strategy", "gini")
        beta = getattr(self.config, "default_beta", BETA)
        min_k = getattr(self.config, "min_dynamic_k", DEFAULT_MIN_K)
        max_k = getattr(self.config, "max_dynamic_k", DEFAULT_MAX_K)

        id = str(uuid.uuid4())
        # === LOAD-ONLY ROUTING (ignore scores) ===
        if strategy == "load_only":
            E = self.n_routed_experts
            expert_loads = torch.zeros(E, dtype=torch.long, device=scores.device)
            final_topk_idx = torch.empty((BT, self.top_k), dtype=torch.long, device=scores.device)
            # use ones; will be normalized if self.norm_topk_prob is True
            final_topk_weight = torch.ones((BT, self.top_k), dtype=hidden_states.dtype, device=scores.device)

            for i in range(BT):
                # pick the k least-loaded experts *right now*
                least = torch.topk(-expert_loads, k=self.top_k).indices
                final_topk_idx[i] = least
                expert_loads[least] += 1

                if self.track_assignments:

                    batch_idx = i // T
                    token_idx = i % T
                    self.token_assignments.append({
                        "batch_idx": batch_idx,
                        "layer_idx": self.layer_idx,
                        "token_idx": token_idx,
                        "id": id,
                        "assigned_experts": least.tolist(),
                    })

            # normalize weights later if requested; no aux loss here
            aux_loss = None
            return final_topk_idx, final_topk_weight, aux_loss

        if strategy == "baseline":
            dynamic_k = torch.full((BT,), self.num_experts, device=scores.device)
            topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        else:
            # Shared top-k logic
            topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
            if strategy == "threshold":
                t = getattr(self.config, "threshold_t", 0.9)
                # topk_scores, _ = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
                strict_topk_min, _ = topk_scores.min(dim=-1, keepdim=True) 
                relative_thresh = t * strict_topk_min
                dynamic_k = torch.sum(scores > relative_thresh, dim=-1).clamp(min=min_k, max=max_k).long()
            else:
                # Gini coefficient routing (optimized)
                sorted_scores, _ = torch.sort(scores, dim=-1)
                # Use pre-computed indices and combine operations for efficiency
                index = self.gini_indices.unsqueeze(0)
                total = sorted_scores.sum(dim=1, keepdim=True) + 1e-8  # keepdim for broadcasting
                weighted = torch.sum(index * sorted_scores, dim=1)
                gini = (2 * weighted) / (E * total.squeeze(1)) - (E + 1) / E
                gini = torch.clamp(gini, min=0.0, max=1.0)  # Ensure valid range
                gini = torch.where(torch.isfinite(gini), gini, torch.zeros_like(gini))  # Handle NaN/inf

                k_float = min_k + (max_k - min_k) * (1 - gini) ** beta
                k_float = torch.where(torch.isfinite(k_float), k_float, min_k)
                dynamic_k = torch.clamp(k_float.floor(), min=min_k, max=max_k).long()
                
            # Shared top-k logic
            sum_threshold = getattr(self.config, "sum_threshold", -1)
            if sum_threshold >= 0:
                # sum across the k strict top-k scores (pre-mask)
                topk_mass = topk_scores.sum(dim=-1)  # [BT]
                force_min = topk_mass >= float(sum_threshold)
                if force_min.any():
                    dynamic_k = torch.where(
                        force_min,
                        torch.full_like(dynamic_k, min_k),
                        dynamic_k
                    )

            # create top of max_k, ie sort both scores and indexes. But maintain top-k exactly as before
            ################## implementation that keeps strict top k as before
            # Mask out topk
            rest_mask = torch.ones_like(scores, dtype=torch.bool)
            rest_mask.scatter_(1, topk_indices, False)

            # Get the scores and their global indices for the "rest"
            rest_scores = scores.masked_select(rest_mask).view(scores.size(0), -1)
            rest_indices = torch.arange(scores.size(1), device=scores.device).expand_as(scores).masked_select(rest_mask).view(scores.size(0), -1)

            # Sort the rest by score and concat
            rest_scores_sorted, sort_order = torch.sort(rest_scores, dim=-1, descending=True)
            rest_indices_sorted = torch.gather(rest_indices, 1, sort_order)

            all_topk_scores = torch.cat([topk_scores, rest_scores_sorted], dim=-1)
            all_topk_indices = torch.cat([topk_indices, rest_indices_sorted], dim=-1)
            ####################
            mask = self.mask_indices[:max_k].unsqueeze(0) >= dynamic_k.unsqueeze(1)
            all_topk_scores = all_topk_scores.masked_fill(mask, float('-inf'))
            all_topk_indices = all_topk_indices.masked_fill(mask, -1)

            
        expert_loads = torch.zeros(E, dtype=torch.long, device=scores.device)
        final_topk_idx = torch.empty_like(topk_indices)
        final_topk_weight = torch.empty_like(topk_scores)

        if strategy != "baseline":
            if not self.vectorized:
                for i in range(BT):
                    candidates = all_topk_indices[i]
                    weights = all_topk_scores[i]
                    valid = candidates != -1

                    candidates = candidates[valid]
                    weights = weights[valid]

                    if candidates.size(0) <= self.num_choices:
                        choices = candidates
                    else:
                        if getattr(self.config, "sample_before_load", SAMPLE_BEFORE_LOAD):
                            rand_idx = torch.randperm(candidates.size(0), device=scores.device)[:self.num_choices]
                            choices = candidates[rand_idx]
                            weights = weights[rand_idx]
                        else:
                            choices = candidates[:self.num_choices]
                            weights = weights[:self.num_choices]

                    loads = expert_loads[choices]
                    k_select = min(self.num_experts, choices.size(0))
                    _, least_loaded_idx = torch.topk(-loads, k=k_select)
                    selected = choices[least_loaded_idx]
                    selected_scores = weights[least_loaded_idx]
                    
                    assert selected.size(0) == self.top_k, f"Selected {selected.size(0)} experts, expected {self.top_k}"

                    final_topk_idx[i, :] = selected
                    final_topk_weight[i, :] = selected_scores
                    expert_loads[selected] += 1
            else:
                # Valid positions after dynamic_k masking
                valid_mask = all_topk_indices.ne(-1)   # [BT, max_k]
                selected_mask = torch.zeros(BT, E, dtype=torch.bool, device=scores.device)

                C = self.num_choices
                device = scores.device

                # Choose candidate *positions* in the top-k table
                if getattr(self.config, "sample_before_load", SAMPLE_BEFORE_LOAD):
                    # Uniform sample among valid via random sort
                    rnd = torch.rand(BT, max_k, device=device)
                    rnd = rnd.masked_fill(~valid_mask, float('inf'))                 # invalids → +inf so they sort last
                    cand_pos = torch.argsort(rnd, dim=-1)[:, :C]                     # [BT, C] positions (0..max_k-1)
                    choices = torch.gather(all_topk_indices, 1, cand_pos)                # [BT, C] expert IDs
                    choice_valid = choices.ne(-1)                                    # [BT, C]
                else:
                    # Since topk_indices is sorted desc, just slice columns
                    choices = all_topk_indices[:, :C].contiguous()                       # [BT, C] expert IDs
                    choice_valid = choices.ne(-1)                                    # [BT, C]

                # Select k experts in k rounds (round-wise balancing)
                for j in range(self.top_k):
                    # Get current loads for each candidate expert (invalids temporarily clamped to 0)
                    gathered = expert_loads[choices.clamp_min(0)]            # [BT, C]
                    gathered = gathered.to(torch.float32)
                    # Mask invalid candidate positions to +inf so they never win argmin
                    gathered = torch.where(choice_valid, gathered, torch.full_like(gathered, float('inf')))

                    # Pick the least-loaded candidate per token
                    best_pos = gathered.argmin(dim=1)                        # [BT]
                    sel = torch.gather(choices, 1, best_pos.unsqueeze(1)).squeeze(1)  # [BT], expert ids (may be -1 if all invalid)

                    # Fallback if a row had no valid candidate this round
                    bad = sel.eq(-1)
                    if bad.any():
                        # For only the 'bad' rows, pick the best unselected expert by score
                        # (respect 'selected_mask' so we don't reuse experts for that token)
                        masked_scores = scores[bad].masked_fill(selected_mask[bad], float("-inf"))  # [num_bad, E]
                        alt = masked_scores.argmax(dim=1)                                           # [num_bad]
                        sel = sel.clone()
                        sel[bad] = alt

                    # Write selection + weights
                    final_topk_idx[:, j]    = sel
                    final_topk_weight[:, j] = scores[torch.arange(BT, device=device), sel]

                    # Update global loads and forbid reusing selected experts for the same token
                    expert_loads.scatter_add_(0, sel, torch.ones_like(sel, dtype=expert_loads.dtype))
                    selected_mask[torch.arange(BT, device=device), sel] = True

                    # Don’t pick the same expert again for this token in later rounds
                    same = choices.eq(sel.unsqueeze(1))
                    choice_valid = choice_valid & (~same)
        else:
            # baseline case
            final_topk_weight, final_topk_idx = topk_scores, topk_indices

            
        # Stats collection (only if needed)
        if self.collect_stats:
            strict_topk_all = torch.topk(scores, k=self.num_experts, dim=-1, largest=True, sorted=True)
            for i in range(min(BT, 100)):  # Limit to avoid memory issues
                selected = final_topk_idx[i]
                strict_expert_idx = strict_topk_all.indices[i].cpu().tolist()
                strict_scores = strict_topk_all.values[i].cpu().tolist()
                final_scores = [scores[i][eid].item() for eid in selected]
                overlap = len(set(strict_expert_idx) & set(selected.cpu().tolist())) / len(selected)

                self.routing_stats.append({
                    "token_index": i,
                    "layer_idx": self.layer_idx,
                    "strict_topk_experts": strict_expert_idx,
                    "strict_topk_scores": strict_scores,
                    "chosen_experts": selected.cpu().tolist(),
                    "chosen_scores": final_scores,
                    "score_gap": sum(strict_scores) - sum(final_scores),
                    "overlap_ratio": overlap,
                })

        if self.track_assignments:
            # Vectorized assignment tracking
            for i in range(BT):
                batch_idx = i // T
                token_idx = i % T
                experts = final_topk_idx[i].cpu().tolist()
                self.token_assignments.append({
                    "batch_idx": batch_idx,
                    "layer_idx": self.layer_idx,
                    "token_idx": token_idx,
                    "assigned_experts": experts,
                    "id": id,
                })
        

        # Normalize gate weights
        if self.top_k > 1 and self.norm_topk_prob:
            denom = final_topk_weight.sum(dim=-1, keepdim=True) + 1e-8
            final_topk_weight = final_topk_weight / denom

        # Auxiliary loss (optional)
        if self.training and self.alpha > 0.0:
            aux_topk_idx = final_topk_idx.view(B, -1)
            scores_for_aux = scores.view(B, T, -1)
            ce = torch.zeros(B, E, device=scores.device)
            ce.scatter_add_(1, aux_topk_idx, torch.ones_like(aux_topk_idx, dtype=torch.float, device=scores.device))
            ce = ce.div(T * self.top_k / E)
            aux_loss = (ce * scores_for_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
        else:
            aux_loss = torch.tensor(0.0, device=hidden_states.device)

        return final_topk_idx, final_topk_weight, aux_loss

##################################
def forward_qwen(self, hidden_states: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
    """Sparse-MoE forward pass for a Qwen MoE block using the custom router.

    The stock softmax + top-k router is replaced by ``forward_qwen_gating``;
    the selected experts are then applied token-wise, combined with their
    routing weights, and added to the sigmoid-gated shared-expert output.

    Args:
        self: the MoE block this function is bound to. Reads
            ``norm_topk_prob``, ``num_experts``, ``experts``,
            ``shared_expert`` and ``shared_expert_gate`` (plus whatever
            ``forward_qwen_gating`` reads).
        hidden_states: tensor of shape ``(batch, seq_len, hidden_dim)``.

    Returns:
        ``(final_hidden_states, router_logits)`` where ``final_hidden_states``
        has the input's 3-D shape and ``router_logits`` is whatever the gating
        function returned as its third value.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)

    # Custom gating: returns expert ids, their weights, and the raw gate logits.
    selected_experts, routing_weights, router_logits = forward_qwen_gating(self, hidden_states)
    selected_experts = selected_experts.long()

    if self.norm_topk_prob:
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # Cast back to the input dtype.
    routing_weights = routing_weights.to(hidden_states.dtype)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )

    # One-hot encode the selected experts, then permute so the expert axis
    # comes first: expert_mask[e] marks which (slot, token) pairs use expert e.
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    # Only visit experts that received at least one token.
    expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
    for expert_idx in expert_hitted:
        expert_layer = self.experts[expert_idx]
        # idx: slot within the token's selection; top_x: flattened token index.
        idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

        # Gather the tokens routed to this expert, run the expert, and scale
        # each output row by that token's routing weight for this expert.
        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

        # Accumulate into the output rows of the corresponding tokens.
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

    shared_expert_output = self.shared_expert(hidden_states)
    shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

    final_hidden_states = final_hidden_states + shared_expert_output

    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits

def forward_qwen_gating(self, hidden_states):
    """Select experts per token for the Qwen MoE block, optionally load-aware.

    Gate logits are soft-maxed into per-expert scores. What happens next
    depends on ``self.config.gating_strategy`` (default ``"gini"``):

    * ``"load_only"``: scores are ignored; each token takes the ``top_k``
      currently least-loaded experts, all with weight 1.
    * ``"baseline"``: plain ``torch.topk`` over the scores.
    * ``"threshold"`` or anything else (Gini): a per-token candidate count
      ``dynamic_k`` is derived, the score-ranked expert list (strict top-k
      first, then the rest by descending score) is cut to ``dynamic_k``
      candidates, at most ``self.num_choices`` of them are kept, and the
      least-loaded ones are picked token by token against a load counter
      that starts at zero on every call.

    Args:
        self: the MoE block. Reads ``gate``, ``config``, ``top_k``,
            ``n_routed_experts``, ``num_choices``, ``mask_indices``,
            ``gini_indices``, ``collect_stats``, ``track_assignments`` and
            ``layer_idx``; appends to ``routing_stats`` /
            ``token_assignments`` when the matching flag is set.
        hidden_states: ``(tokens, hidden)`` or ``(batch, seq, hidden)``.
            A 2-D input is treated as a single batch.

    Returns:
        ``(final_topk_idx, final_topk_weight, logits)``: expert ids and
        their (un-normalised) scores, both ``(tokens, top_k)``, and the raw
        gate logits ``(tokens, n_experts)``.

    Raises:
        ValueError: if ``hidden_states`` is neither 2-D nor 3-D.
        NotImplementedError: if the vectorized path is requested.
        AssertionError: if a token ends up with a number of experts other
            than ``top_k``.
    """
    if hidden_states.dim() == 2:
        # already flattened: (bsz * seq_len, hidden_dim)
        bsz = 1
        seq_len = hidden_states.size(0)
        h = hidden_states.size(1)
    elif hidden_states.dim() == 3:
        bsz, seq_len, h = hidden_states.shape
    else:
        raise ValueError(f"Unexpected hidden_states shape: {hidden_states.shape}")

    hidden_states = hidden_states.reshape(-1, h)  # [B*T, H]
    logits = self.gate(hidden_states)

    scores = F.softmax(logits, dim=1, dtype=torch.float)  # [B*T, E]

    B, T, E = bsz, seq_len, self.n_routed_experts
    BT = B * T

    strategy = getattr(self.config, "gating_strategy", "gini")
    beta = getattr(self.config, "default_beta", BETA)
    min_k = getattr(self.config, "min_dynamic_k", DEFAULT_MIN_K)
    max_k = getattr(self.config, "max_dynamic_k", DEFAULT_MAX_K)
    vectorized = getattr(self.config, "vectorized", VECTORIZED)
    sample_before_load = getattr(self.config, "sample_before_load", SAMPLE_BEFORE_LOAD)
    # One identifier per call so tracked assignments can be grouped later.
    call_id = str(uuid.uuid4())

    # === LOAD-ONLY ROUTING (ignore scores) ===
    if strategy == "load_only":
        expert_loads = torch.zeros(E, dtype=torch.long, device=scores.device)
        final_topk_idx = torch.empty((BT, self.top_k), dtype=torch.long, device=scores.device)
        final_topk_weight = torch.ones((BT, self.top_k), dtype=scores.dtype, device=scores.device)
        for i in range(BT):
            least = torch.topk(-expert_loads, k=self.top_k).indices
            final_topk_idx[i] = least
            expert_loads[least] += 1
            if self.track_assignments:
                batch_idx, token_idx = divmod(i, T)
                self.token_assignments.append({
                    "batch_idx": batch_idx,
                    "layer_idx": self.layer_idx,
                    "token_idx": token_idx,
                    "assigned_experts": least.tolist(),
                    "id": call_id,
                })
        # keep returning logits for shape parity with the usual path
        return final_topk_idx, final_topk_weight, logits

    # Strict top-k is needed by every remaining strategy.
    topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

    if strategy != "baseline":
        if strategy == "threshold":
            # Count experts scoring at least t * (smallest strict top-k score).
            t = getattr(self.config, "threshold_t", 0.9)
            strict_topk_min, _ = topk_scores.min(dim=-1, keepdim=True)
            relative_thresh = t * strict_topk_min
            dynamic_k = torch.sum(scores >= relative_thresh, dim=-1).clamp(min=min_k, max=max_k).long()
        else:
            # Gini-coefficient routing over the ascending-sorted scores.
            # NOTE(review): assumes self.gini_indices holds the ranks the
            # formula expects (presumably 1..E) -- confirm where it is built.
            sorted_scores, _ = torch.sort(scores, dim=-1)
            index = self.gini_indices.unsqueeze(0)
            total = sorted_scores.sum(dim=1, keepdim=True) + 1e-8  # keepdim for broadcasting
            weighted = torch.sum(index * sorted_scores, dim=1)
            gini = (2 * weighted) / (E * total.squeeze(1)) - (E + 1) / E
            gini = torch.clamp(gini, min=0.0, max=1.0)  # Ensure valid range
            gini = torch.where(torch.isfinite(gini), gini, torch.zeros_like(gini))  # Handle NaN/inf

            # gini == 1 gives min_k candidates, gini == 0 gives max_k.
            k_float = min_k + (max_k - min_k) * (1 - gini) ** beta
            k_float = torch.where(torch.isfinite(k_float), k_float, min_k)
            dynamic_k = torch.clamp(k_float.floor(), min=min_k, max=max_k).long()

        # Optional override: tokens whose strict top-k already carries at
        # least `sum_threshold` of the score mass fall back to min_k.
        sum_threshold = getattr(self.config, "sum_threshold", -1)
        if sum_threshold >= 0:
            topk_mass = topk_scores.sum(dim=-1)  # [BT]
            force_min = topk_mass >= float(sum_threshold)
            if force_min.any():
                dynamic_k = torch.where(
                    force_min,
                    torch.full_like(dynamic_k, min_k),
                    dynamic_k
                )

        # Build the full ranking while keeping the strict top-k block exactly
        # as torch.topk returned it: [strict top-k | remaining, sorted desc].
        rest_mask = torch.ones_like(scores, dtype=torch.bool)
        rest_mask.scatter_(1, topk_indices, False)

        rest_scores = scores.masked_select(rest_mask).view(scores.size(0), -1)
        rest_indices = torch.arange(scores.size(1), device=scores.device).expand_as(scores).masked_select(rest_mask).view(scores.size(0), -1)

        rest_scores_sorted, sort_order = torch.sort(rest_scores, dim=-1, descending=True)
        rest_indices_sorted = torch.gather(rest_indices, 1, sort_order)

        all_topk_scores = torch.cat([topk_scores, rest_scores_sorted], dim=-1)
        all_topk_indices = torch.cat([topk_indices, rest_indices_sorted], dim=-1)

        # Invalidate every position at or beyond the token's dynamic_k.
        # NOTE(review): self.mask_indices is presumably a 0..n-1 position
        # vector; the mask must broadcast against the [BT, E] tables.
        mask = self.mask_indices[:max_k].unsqueeze(0) >= dynamic_k.unsqueeze(1)
        all_topk_scores = all_topk_scores.masked_fill(mask, float('-inf'))
        all_topk_indices = all_topk_indices.masked_fill(mask, -1)

    if strategy != "baseline":
        if vectorized:
            raise NotImplementedError("Vectorized version not supported yet")

        expert_loads = torch.zeros(E, dtype=torch.long, device=scores.device)
        final_topk_idx = torch.empty_like(topk_indices)
        final_topk_weight = torch.empty_like(topk_scores)

        for i in range(BT):
            valid = all_topk_indices[i] != -1
            candidates = all_topk_indices[i][valid]
            weights = all_topk_scores[i][valid]

            if candidates.size(0) <= self.num_choices:
                choices = candidates
            elif sample_before_load:
                # Uniformly sample num_choices of the valid candidates.
                rand_idx = torch.randperm(candidates.size(0), device=scores.device)[:self.num_choices]
                choices = candidates[rand_idx]
                weights = weights[rand_idx]
            else:
                # Keep the first num_choices entries of the ranking.
                choices = candidates[:self.num_choices]
                weights = weights[:self.num_choices]

            # Among the retained candidates, take the least-loaded ones.
            loads = expert_loads[choices]
            k_select = min(self.config.num_experts, choices.size(0))
            _, least_loaded_idx = torch.topk(-loads, k=k_select)
            selected = choices[least_loaded_idx]
            selected_scores = weights[least_loaded_idx]

            assert selected.size(0) == self.top_k, f"Selected {selected.size(0)} experts, expected {self.top_k}"

            final_topk_idx[i, :] = selected
            # selected_scores is already the 1-D score vector aligned with
            # `selected`; it must not be indexed again by (i, expert id).
            final_topk_weight[i, :] = selected_scores
            expert_loads[selected] += 1
    else:
        # baseline case
        final_topk_weight, final_topk_idx = topk_scores, topk_indices

    # Stats collection (only if needed)
    if self.collect_stats:
        strict_topk_all = torch.topk(scores, k=self.config.num_experts, dim=-1, largest=True, sorted=True)
        for i in range(min(BT, 100)):  # Limit to avoid memory issues
            selected = final_topk_idx[i]
            chosen = selected.cpu().tolist()
            strict_expert_idx = strict_topk_all.indices[i].cpu().tolist()
            strict_scores = strict_topk_all.values[i].cpu().tolist()
            final_scores = scores[i, selected].cpu().tolist()
            overlap = len(set(strict_expert_idx) & set(chosen)) / len(selected)

            self.routing_stats.append({
                "token_index": i,
                "layer_idx": self.layer_idx,
                "strict_topk_experts": strict_expert_idx,
                "strict_topk_scores": strict_scores,
                "chosen_experts": chosen,
                "chosen_scores": final_scores,
                "score_gap": sum(strict_scores) - sum(final_scores),
                "overlap_ratio": overlap,
            })

    if self.track_assignments:
        # Single host transfer, then one record per flattened token.
        for i, experts in enumerate(final_topk_idx.cpu().tolist()):
            batch_idx, token_idx = divmod(i, T)
            self.token_assignments.append({
                "batch_idx": batch_idx,
                "layer_idx": self.layer_idx,
                "token_idx": token_idx,
                "assigned_experts": experts,
                "id": call_id,
            })

    # Weight normalisation is intentionally left to the caller.
    return final_topk_idx, final_topk_weight, logits




##################################
def forward_mixtral(self, hidden_states: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
    """Sparse-MoE forward pass for a Mixtral MoE block using the custom router.

    The stock softmax + top-k router is replaced by
    ``forward_mixtral_gating``; the returned weights are renormalised to sum
    to one per token and the selected experts are applied token-wise.

    Args:
        self: the MoE block this function is bound to. Reads ``training``,
            ``jitter_noise``, ``num_experts`` and ``experts`` (plus whatever
            ``forward_mixtral_gating`` reads).
        hidden_states: tensor of shape ``(batch, seq_len, hidden_dim)``.
            When training with ``jitter_noise > 0`` it is scaled in place.

    Returns:
        ``(final_hidden_states, router_logits)`` where ``final_hidden_states``
        has the input's 3-D shape and ``router_logits`` is the first value
        returned by the gating function.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    if self.training and self.jitter_noise > 0:
        # Multiplicative jitter, applied in place to the caller's tensor.
        hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
    hidden_states = hidden_states.view(-1, hidden_dim)

    # Custom gating: returns gate logits, per-token weights, and expert ids.
    router_logits, routing_weights, selected_experts = forward_mixtral_gating(self, hidden_states=hidden_states)
    selected_experts = selected_experts.long()

    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # Cast back to the input dtype.
    routing_weights = routing_weights.to(hidden_states.dtype)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )

    # One-hot encode the selected experts, then permute so the expert axis
    # comes first: expert_mask[e] marks which (slot, token) pairs use expert e.
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    # Only visit experts that received at least one token.
    expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
    for expert_idx in expert_hit:
        expert_layer = self.experts[expert_idx]
        # idx: slot within the token's selection; top_x: flattened token index.
        idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

        # Gather the tokens routed to this expert, run the expert, and scale
        # each output row by that token's routing weight for this expert.
        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

        # Accumulate into the output rows of the corresponding tokens.
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits

def forward_mixtral_gating(self, hidden_states):
    """Select experts per token for the Mixtral MoE block, optionally load-aware.

    Gate logits are soft-maxed into per-expert scores. What happens next
    depends on ``self.config.gating_strategy`` (default ``"gini"``):

    * ``"load_only"``: scores are ignored; each token takes the ``top_k``
      currently least-loaded experts, all with weight 1.
    * ``"baseline"``: plain ``torch.topk`` over the scores.
    * ``"threshold"`` or anything else (Gini): a per-token candidate count
      ``dynamic_k`` is derived, the score-ranked expert list (strict top-k
      first, then the rest by descending score) is cut to ``dynamic_k``
      candidates, at most ``self.num_choices`` of them are kept, and the
      ``top_k`` least-loaded ones are picked token by token against a load
      counter that starts at zero on every call.

    Args:
        self: the MoE block. Reads ``gate``, ``config``, ``top_k``,
            ``n_routed_experts``, ``num_choices``, ``mask_indices``,
            ``gini_indices``, ``collect_stats``, ``track_assignments`` and
            ``layer_idx``; appends to ``routing_stats`` /
            ``token_assignments`` when the matching flag is set.
        hidden_states: ``(tokens, hidden)`` or ``(batch, seq, hidden)``.
            A 2-D input is treated as a single batch; a 3-D input is
            flattened to ``(batch * seq, hidden)`` before gating.

    Returns:
        ``(logits, final_topk_weight, final_topk_idx)``: raw gate logits
        ``(tokens, n_experts)``, then the (un-normalised) scores and expert
        ids, both ``(tokens, top_k)``.

    Raises:
        ValueError: if ``hidden_states`` is neither 2-D nor 3-D.
        NotImplementedError: if the vectorized path is requested.
        AssertionError: if a token ends up with a number of experts other
            than ``top_k``.
    """
    if hidden_states.dim() == 2:
        # already flattened: (bsz * seq_len, hidden_dim)
        bsz = 1
        seq_len = hidden_states.size(0)
        h = hidden_states.size(1)
    elif hidden_states.dim() == 3:
        bsz, seq_len, h = hidden_states.shape
    else:
        raise ValueError(f"Unexpected hidden_states shape: {hidden_states.shape}")

    # Everything below works on one row per token, so flatten 3-D input
    # first; otherwise softmax(dim=1) would normalise over the sequence axis.
    hidden_states = hidden_states.reshape(-1, h)  # [B*T, H]
    logits = self.gate(hidden_states)

    scores = F.softmax(logits, dim=1, dtype=torch.float)  # [B*T, E]

    B, T, E = bsz, seq_len, self.n_routed_experts
    BT = B * T

    strategy = getattr(self.config, "gating_strategy", "gini")
    beta = getattr(self.config, "default_beta", BETA)
    min_k = getattr(self.config, "min_dynamic_k", DEFAULT_MIN_K)
    max_k = getattr(self.config, "max_dynamic_k", DEFAULT_MAX_K)
    vectorized = getattr(self.config, "vectorized", VECTORIZED)
    sample_before_load = getattr(self.config, "sample_before_load", SAMPLE_BEFORE_LOAD)
    # One identifier per call so tracked assignments can be grouped later.
    call_id = str(uuid.uuid4())

    # === LOAD-ONLY ROUTING (ignore scores) ===
    if strategy == "load_only":
        expert_loads = torch.zeros(E, dtype=torch.long, device=scores.device)
        final_topk_idx = torch.empty((BT, self.top_k), dtype=torch.long, device=scores.device)
        final_topk_weight = torch.ones((BT, self.top_k), dtype=scores.dtype, device=scores.device)
        for i in range(BT):
            least = torch.topk(-expert_loads, k=self.top_k).indices
            final_topk_idx[i] = least
            expert_loads[least] += 1
            if self.track_assignments:
                batch_idx, token_idx = divmod(i, T)
                self.token_assignments.append({
                    "batch_idx": batch_idx,
                    "layer_idx": self.layer_idx,
                    "token_idx": token_idx,
                    "assigned_experts": least.tolist(),
                    "id": call_id,
                })
        # keep returning logits for shape parity with the usual path
        return logits, final_topk_weight, final_topk_idx

    # Strict top-k is needed by every remaining strategy.
    topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

    if strategy != "baseline":
        if strategy == "threshold":
            # Count experts scoring at least t * (smallest strict top-k score).
            t = getattr(self.config, "threshold_t", 0.9)
            strict_topk_min, _ = topk_scores.min(dim=-1, keepdim=True)
            relative_thresh = t * strict_topk_min
            dynamic_k = torch.sum(scores >= relative_thresh, dim=-1).clamp(min=min_k, max=max_k).long()
        else:
            # Gini-coefficient routing over the ascending-sorted scores.
            # NOTE(review): assumes self.gini_indices holds the ranks the
            # formula expects (presumably 1..E) -- confirm where it is built.
            sorted_scores, _ = torch.sort(scores, dim=-1)
            index = self.gini_indices.unsqueeze(0)
            total = sorted_scores.sum(dim=1, keepdim=True) + 1e-8  # keepdim for broadcasting
            weighted = torch.sum(index * sorted_scores, dim=1)
            gini = (2 * weighted) / (E * total.squeeze(1)) - (E + 1) / E
            gini = torch.clamp(gini, min=0.0, max=1.0)  # Ensure valid range
            gini = torch.where(torch.isfinite(gini), gini, torch.zeros_like(gini))  # Handle NaN/inf

            # gini == 1 gives min_k candidates, gini == 0 gives max_k.
            k_float = min_k + (max_k - min_k) * (1 - gini) ** beta
            k_float = torch.where(torch.isfinite(k_float), k_float, min_k)
            dynamic_k = torch.clamp(k_float.floor(), min=min_k, max=max_k).long()

        # Optional override: tokens whose strict top-k already carries at
        # least `sum_threshold` of the score mass fall back to min_k.
        sum_threshold = getattr(self.config, "sum_threshold", -1)
        if sum_threshold >= 0:
            topk_mass = topk_scores.sum(dim=-1)  # [BT]
            force_min = topk_mass >= float(sum_threshold)
            if force_min.any():
                dynamic_k = torch.where(
                    force_min,
                    torch.full_like(dynamic_k, min_k),
                    dynamic_k
                ).long()

        # Build the full ranking while keeping the strict top-k block exactly
        # as torch.topk returned it: [strict top-k | remaining, sorted desc].
        rest_mask = torch.ones_like(scores, dtype=torch.bool)
        rest_mask.scatter_(1, topk_indices, False)

        rest_scores = scores.masked_select(rest_mask).view(scores.size(0), -1)
        rest_indices = torch.arange(scores.size(1), device=scores.device).expand_as(scores).masked_select(rest_mask).view(scores.size(0), -1)

        rest_scores_sorted, sort_order = torch.sort(rest_scores, dim=-1, descending=True)
        rest_indices_sorted = torch.gather(rest_indices, 1, sort_order)

        all_topk_scores = torch.cat([topk_scores, rest_scores_sorted], dim=-1)
        all_topk_indices = torch.cat([topk_indices, rest_indices_sorted], dim=-1)

        # Invalidate every position at or beyond the token's dynamic_k.
        # NOTE(review): self.mask_indices is presumably a 0..E-1 position
        # vector; it must broadcast against the [BT, E] tables.
        mask = self.mask_indices.unsqueeze(0) >= dynamic_k.unsqueeze(1)
        all_topk_scores = all_topk_scores.masked_fill(mask, float('-inf'))
        all_topk_indices = all_topk_indices.masked_fill(mask, -1)

    if strategy != "baseline":
        if vectorized:
            raise NotImplementedError("Vectorized not yet supported")

        expert_loads = torch.zeros(E, dtype=torch.long, device=scores.device)
        final_topk_idx = torch.empty_like(topk_indices)
        final_topk_weight = torch.empty_like(topk_scores)

        for i in range(BT):
            valid = all_topk_indices[i] != -1
            candidates = all_topk_indices[i][valid]
            weights = all_topk_scores[i][valid]

            if candidates.size(0) <= self.num_choices:
                choices = candidates
            elif sample_before_load:
                # Uniformly sample num_choices of the valid candidates.
                rand_idx = torch.randperm(candidates.size(0), device=scores.device)[:self.num_choices]
                choices = candidates[rand_idx]
                weights = weights[rand_idx]
            else:
                # Keep the first num_choices entries of the ranking.
                choices = candidates[:self.num_choices]
                weights = weights[:self.num_choices]

            # Among the retained candidates, take the top_k least-loaded.
            loads = expert_loads[choices]
            least_loaded_idx = torch.argsort(loads)[:self.top_k]
            selected = choices[least_loaded_idx]
            selected_scores = weights[least_loaded_idx]

            assert selected.size(0) == self.top_k, f"Selected {selected.size(0)} experts, expected {self.top_k}"

            final_topk_idx[i, :] = selected
            final_topk_weight[i, :] = selected_scores
            expert_loads[selected] += 1
    else:
        # baseline case
        final_topk_weight, final_topk_idx = topk_scores, topk_indices

    # Stats collection (only if needed)
    if self.collect_stats:
        strict_topk_all = torch.topk(scores, k=self.config.num_experts, dim=-1, largest=True, sorted=True)
        for i in range(min(BT, 100)):  # Limit to avoid memory issues
            selected = final_topk_idx[i]
            chosen = selected.cpu().tolist()
            strict_expert_idx = strict_topk_all.indices[i].cpu().tolist()
            strict_scores = strict_topk_all.values[i].cpu().tolist()
            final_scores = scores[i, selected].cpu().tolist()
            overlap = len(set(strict_expert_idx) & set(chosen)) / len(selected)

            self.routing_stats.append({
                "token_index": i,
                "layer_idx": self.layer_idx,
                "strict_topk_experts": strict_expert_idx,
                "strict_topk_scores": strict_scores,
                "chosen_experts": chosen,
                "chosen_scores": final_scores,
                "score_gap": sum(strict_scores) - sum(final_scores),
                "overlap_ratio": overlap,
            })

    if self.track_assignments:
        # Single host transfer, then one record per flattened token.
        for i, experts in enumerate(final_topk_idx.cpu().tolist()):
            batch_idx, token_idx = divmod(i, T)
            self.token_assignments.append({
                "batch_idx": batch_idx,
                "layer_idx": self.layer_idx,
                "token_idx": token_idx,
                "assigned_experts": experts,
                "id": call_id,
            })

    return logits, final_topk_weight, final_topk_idx
# ---------------------- Runner Entry ---------------------- #

def parse_args():
    """Parse the experiment's command-line options from ``sys.argv``.

    The options are declared as a table of ``(flag, add_argument kwargs)``
    pairs and registered in order. ``--batch_size`` and ``--sample_size``
    take their defaults from the module-level ``BATCH_SIZE`` and
    ``SAMPLE_SIZE``.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args``.
    """
    float_list_help = "One or more floats (space separated)"

    option_table = [
        # General config
        ("--model_name", dict(type=str, required=True)),
        ("--lora_path", dict(type=str, required=False, default="")),
        ("--model_family", dict(type=str, required=True)),
        ("--mode", dict(choices=["qa", "perplexity"], default="perplexity")),
        ("--model_type", dict(type=str, default="chat")),
        ("--selection_method", dict(choices=["gini", "threshold", "baseline", "load_only"], default="gini")),
        # by default not used
        ("--threshold_factor", dict(type=float, nargs="+", default=[0.9], help=float_list_help)),
        ("--seed", dict(type=int, default=42)),
        ("--beta", dict(type=int, default=1)),
        # by default not used
        ("--sum_threshold", dict(type=float, nargs="+", default=[-1], help=float_list_help)),
        ("--sample_before_load", dict(
            action="store_true",
            help="If we sample prior to computing and selecting based on load",
        )),
        ("--vectorized", dict(action="store_true")),
        # Experiment parameters
        ("--num_choices", dict(type=int, default=-1)),
        ("--batch_size", dict(type=int, default=BATCH_SIZE)),
        ("--sample_size", dict(type=int, default=SAMPLE_SIZE)),
        ("--experiment_name", dict(type=str, required=True)),
        # Generation/perplexity limits
        ("--max_new_tokens", dict(type=int, default=1)),
        ("--max_prompt_length", dict(type=int, default=1)),
        ("--max_perplexity_length", dict(type=int, default=1)),
        # Dataset
        ("--dataset_name", dict(type=str, default="wiki")),
        # Output
        ("--output_dir", dict(type=str, default="")),
        ("--result_csv", dict(
            type=str, default="",
            help="Append a one-line summary here (created if missing)",
        )),
        ("--artifact_json", dict(
            type=str, default="",
            help="Write full detailed results here (optional)",
        )),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def _ensure_parent_dir(path: str):
    """Create the directory that would contain ``path`` if it is missing.

    A path with no directory component is left alone.
    """
    parent = os.path.split(path)[0]
    if not parent:
        return
    os.makedirs(parent, exist_ok=True)

def atomic_dump_json(obj, path):
    """Write ``obj`` as indented JSON to ``path`` via a temp file and rename.

    The JSON is first written to a ``.tmp_*.json`` file in the destination
    directory and then moved over ``path`` with ``os.replace``, so an
    existing file at ``path`` is only replaced once the dump has succeeded.
    If serialisation or the rename fails, the temp file is removed and the
    original exception propagates.

    Args:
        obj: any value ``json.dump`` can serialise.
        path: destination file path; its parent directory is created if needed.

    Raises:
        TypeError, ValueError: from ``json.dump`` for unserialisable input.
        OSError: from file creation, writing or renaming.
    """
    _ensure_parent_dir(path)
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".", prefix=".tmp_", suffix=".json")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(obj, f, indent=2)
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a half-written temp file behind; cleanup is best-effort
        # so it can never mask the original error.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise

def append_row_to_csv(csv_path, row: dict):
    """Append *row* as one CSV line to *csv_path*, adding a header if needed.

    The column order is the key order of *row*. A header line is written
    first only when the file does not exist yet or is empty. Missing parent
    directories are created.
    """
    _ensure_parent_dir(csv_path)

    # Decide this before opening: append mode creates the file if missing.
    needs_header = True
    if os.path.exists(csv_path):
        needs_header = os.path.getsize(csv_path) == 0

    columns = list(row)
    with open(csv_path, "a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        if needs_header:
            writer.writeheader()
        writer.writerow(row)

# Script entry point: parse the CLI arguments, rebind the module-level runtime
# configuration from them, run one experiment, and persist the results
# (optional JSON artifact + optional one-line CSV summary).
if __name__ == "__main__":
    args = parse_args()
    # Model
    # The upper-case names below rebind the module-level defaults declared at
    # the top of the file, so they must be set before run_experiment is called.
    MODEL_NAME = args.model_name
    model_family = args.model_family
    MODEL_TYPE = args.model_type
    MODE = args.mode
    SELECTION_MODE = args.selection_method
    THRESHOLD_FACTOR = args.threshold_factor
    SEED = args.seed

    # Per-family settings replace the None placeholders for DEFAULT_*.
    # NOTE(review): assuming MODEL_CFG_MAP is a plain dict, an unknown
    # --model_family raises KeyError here.
    model_family_cfg = MODEL_CFG_MAP[model_family]
    DEFAULT_MIN_K = model_family_cfg["default_min_k"]
    DEFAULT_MAX_K = model_family_cfg["default_max_k"]
    DEFAULT_NUM_EXPERTS = model_family_cfg["default_num_experts"]

    # Experiment setup
    # NOTE(review): BETA, SUM_THRESHOLD, SAMPLE_BEFORE_LOAD and VECTORIZED are
    # not passed to run_experiment below; presumably they are read as module
    # globals elsewhere in the file -- confirm.
    BETA = args.beta
    SUM_THRESHOLD = args.sum_threshold

    SAMPLE_BEFORE_LOAD = args.sample_before_load
    VECTORIZED = args.vectorized

    BATCH_SIZE = args.batch_size
    SAMPLE_SIZE = args.sample_size
    num_experts = DEFAULT_NUM_EXPERTS
    num_choices = args.num_choices
    # --num_choices defaults to -1, which means "use the family's expert count".
    if args.num_choices == -1:
        print(f"Number of choices unspecified, using the number of experts {DEFAULT_NUM_EXPERTS}")
        num_choices = DEFAULT_NUM_EXPERTS
    sample_size = SAMPLE_SIZE
    experiment_name = args.experiment_name
    # Lower-case copies of the same CLI values; these are the ones passed
    # explicitly to run_experiment and recorded in the summary row.
    selection_method = args.selection_method
    threshold_factor = args.threshold_factor
    mode = args.mode
    # NOTE(review): SEED was already assigned from args.seed above; this second
    # assignment is redundant.
    SEED = args.seed

    # Generation/perplexity limits
    MAX_NEW_TOKENS = args.max_new_tokens
    MAX_PROMPT_LENGTH = args.max_prompt_length
    MAX_PERPLEXITY_LENGTH = args.max_perplexity_length

    # Dataset
    # NOTE(review): assuming DATASET_CFG_MAP is a plain dict, an unknown
    # --dataset_name raises KeyError here.
    dataset = args.dataset_name
    DATASET_CFG = DATASET_CFG_MAP[dataset]
    
    # Output
    # NOTE(review): OUTPUT_DIR is assigned here, but the parameter dump below
    # prints OUTPUTS, which is not assigned in this block -- presumably a
    # module-level value; confirm it reflects --output_dir.
    OUTPUT_DIR = args.output_dir


    # Echo the effective configuration so it ends up in the job log.
    print("Script path:", os.path.abspath(__file__))
    print("Parameters:")
    print({
        "NUM_CHOICES": num_choices,
        "EXPERIMENT_NAME": experiment_name,
        "BATCH_SIZE": BATCH_SIZE,
        "JOB_ID": JOB_ID,
        "MAX_NEW_TOKENS": MAX_NEW_TOKENS,
        "MODEL_NAME": MODEL_NAME,
        "DEFAULT_NUM_EXPERTS": DEFAULT_NUM_EXPERTS,
        "SAMPLE_SIZE": SAMPLE_SIZE,
        "SEED": SEED,
        "OUTPUTS": str(OUTPUTS),
        "VECTORIZED": VECTORIZED,
    })
    set_random_seed(SEED)
    
    # run_experiment is expected to return exactly these six values, in this order.
    expert_counts, all_layer_metrics, latency, metrics, per_layer_stats, agg_stats = run_experiment(
        num_experts=num_experts, num_choices=num_choices, sample_size=SAMPLE_SIZE, experiment_name=experiment_name,
        selection_method=selection_method, threshold_factor=threshold_factor, mode=mode,
        min_dynamic_k=DEFAULT_MIN_K, max_dynamic_k=DEFAULT_MAX_K, lora_path=args.lora_path
    )

    # Build a compact, flat summary row (fits in CSV)
    row = {
        "experiment_name": experiment_name,
        "model_name": MODEL_NAME,
        "model_family": model_family,
        "dataset": dataset,
        "mode": mode,
        "selection_method": selection_method,
        "num_experts": DEFAULT_NUM_EXPERTS,     
        "num_choices": num_choices,             
        "sample_size": sample_size,
        "batch_size": BATCH_SIZE,

        # latency
        # latency must support .get(); keys that are absent are recorded as None.
        "avg_latency_per_token_sec": latency.get("avg_latency_per_token_sec", None),
        "total_forward_time_sec": latency.get("total_forward_time_sec", None),
        "total_generated_tokens": latency.get("total_generated_tokens", None),

        # main metric 
        # Each entry is None when metrics is not a dict or lacks that key, so
        # the row always has the same columns.
        "perplexity": metrics.get("perplexity") if isinstance(metrics, dict) and "perplexity" in metrics else None,
        "qa_accuracy": metrics.get("accuracy") if isinstance(metrics, dict) and "accuracy" in metrics else None,
        "qa_exact_match": metrics.get("exact_match") if isinstance(metrics, dict) and "exact_match" in metrics else None,
    }

    # Optional: write the full artifact (nested details)
    # Skipped when --artifact_json is empty (its default). atomic_dump_json
    # serializes with json.dump, so everything placed in the artifact must be
    # JSON-serializable.
    if args.artifact_json:
        artifact = {
            "summary_row": row,                         # mirrors CSV
            "latency": latency,
            "metrics": metrics,
            "expert_counts": expert_counts,
            "per_batch_imbalance": {
                "per_layer": per_layer_stats,
                "aggregate": agg_stats,
            },
            "all_layer_metrics": all_layer_metrics,
            "config": {
                "DEFAULT_MIN_K": DEFAULT_MIN_K,
                "DEFAULT_MAX_K": DEFAULT_MAX_K,
                "DEFAULT_NUM_EXPERTS": DEFAULT_NUM_EXPERTS,
                "SAMPLE_BEFORE_LOAD": SAMPLE_BEFORE_LOAD,
                "THRESHOLD_FACTOR": THRESHOLD_FACTOR,
                "VECTORIZED": VECTORIZED,
                "BETA": BETA,
                "SUM_THRESHOLD": SUM_THRESHOLD,
            },
        }
        atomic_dump_json(artifact, args.artifact_json)

    # Append one line to CSV (create header if missing)
    # Skipped when --result_csv is empty (its default). This runs after the
    # artifact dump, so an exception raised there prevents the row from being
    # written.
    if args.result_csv:
        append_row_to_csv(args.result_csv, row)