#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Llama-2-7b Dense to MoE Conversion and Supervised Fine-Tuning

This script performs the following steps:
1. Converts a dense Llama-2-7b model to a Mixture-of-Experts (MoE) architecture
2. Fine-tunes the MoE model on a CSV dataset with instruction, input, and output columns
3. Saves checkpoints after each epoch
4. Provides inference functionality with the trained model
5. Integrates TensorBoard for monitoring training progress

Usage:
    python llama2_dense_to_moe_sft.py --data_path path/to/dataset.csv --output_dir path/to/output
"""
# Hard-coded, environment-specific paths.
# NOTE(review): presumably consumed by the evaluation/inference part of this
# script — verify, and consider exposing them as command-line arguments.
# NOTE(review): the input CSV lives under `mimic3` while the outputs live under
# `mimic4-llama7` — confirm this mix is intended.
CSV_FILE_PATH = '/hy-tmp/code/data/mimic3/mimic3_full_sft_test.csv'
OUTPUT_FILE_PATH = '/hy-tmp/code/output/mimic4-llama7/evaluation.csv'
FINAL_CHECKPOINT_DIR = "/hy-tmp/code/output/mimic4-llama7/moe/checkpoint-final"


# Pickle files (judging by the extension); presumably a medication->index map
# and a drug-drug-interaction adjacency matrix — TODO confirm against the loader.
MED_TO_IDX_PATH = "/hy-tmp/code/data/mimic4/med_to_idx.pkl"
DDI_ADJ_PATH = '/hy-tmp/code/data/mimic4/ddi_A_final.pkl'


# Stages executed by train_moe_model() when `training_stage` is 0 ("all stages").
DEFAULT_TRAINING_STAGES = [1,2,3] 



import os
import gc
import math
import json
import logging
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    set_seed,
    HfArgumentParser,
    # QWenConfig,  
    # DeepseekConfig, 
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraLayer
import bitsandbytes as bnb
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List, Tuple, Union, Any
from tqdm import tqdm
import copy
import warnings
import argparse
from datetime import datetime
# Import custom modules for MoE conversion
from camelidae.configuration_camelidae import CamelidaeConfig
from camelidae.modeling_camelidae_path import LlamaForCausalLM

# Monkey patch transformers to handle MoE models
from transformers_utils import get_keys_to_not_convert, _load_pretrained_model
import transformers.integrations
import transformers.modeling_utils
import pandas as pd
import re
import numpy as np
import csv
from tqdm import tqdm
import copy
import warnings
import argparse
import pickle
import time
import GPUtil
import psutil
import os

# Replace two transformers internals with the project-local versions from
# `transformers_utils`. These assignments must run before any model is loaded,
# because they change how every subsequent `from_pretrained` call behaves.
transformers.integrations.get_keys_to_not_convert = get_keys_to_not_convert
transformers.modeling_utils.PreTrainedModel._load_pretrained_model = _load_pretrained_model
import warnings
# Silence the "Some weights of the model checkpoint ... were not used" UserWarning
# raised from the `transformers_utils` module.
warnings.filterwarnings("ignore", category=UserWarning, module="transformers_utils", message="Some weights of the model checkpoint.*were not used")

# Set up logging: timestamped INFO-level records, one module-level logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Constants
# Label value used to mask prompt tokens and label padding; matches the default
# `ignore_index` of torch.nn.CrossEntropyLoss.
IGNORE_INDEX = -100
# Pad token string; presumably added to the tokenizer when it has none — TODO confirm.
DEFAULT_PAD_TOKEN = "[PAD]"

# Create custom Trainer class to solve dispatch_batches compatibility issue
class CustomTrainer(Trainer):
    """Trainer variant that forwards only a whitelisted set of batch keys to the model.

    It also pre-sets the DeepSpeed-related attributes before the parent
    constructor runs and lazily creates a default ``Accelerator`` when none
    is present at model-wrapping time.
    """

    # Batch entries that are allowed to reach the model's forward call.
    _MODEL_INPUT_KEYS = (
        'input_ids',
        'attention_mask',
        'labels',
        'past_key_values',
        'cluster_labels',
    )

    def __init__(self, *args, **kwargs):
        # These attributes are assigned before delegating to the parent
        # constructor so they exist while it runs.
        self.is_deepspeed_enabled = False
        self.deepspeed = None

        super().__init__(*args, **kwargs)

    def _wrap_model(self, model, training=True, dataloader=None):
        """Make sure an accelerator exists, then defer to the parent implementation."""
        if self.accelerator is None:
            # Built without the dispatch_batches parameter.
            from accelerate import Accelerator

            self.accelerator = Accelerator()

        return super()._wrap_model(model, training, dataloader)

    def _prepare_inputs(self, inputs):
        """
        Keep only the whitelisted batch entries and move tensor values to the
        training device; non-tensor values are passed through untouched.
        """
        target_device = self.args.device
        model_inputs = {
            key: (value.to(target_device) if isinstance(value, torch.Tensor) else value)
            for key, value in inputs.items()
            if key in self._MODEL_INPUT_KEYS
        }

        # One-off debugging aid: report which fields actually reach the model.
        first_step = self.state.global_step == 0
        if self.args.logging_first_step and first_step:
            logger.info(f"Model inputs: {list(model_inputs.keys())}")

        return model_inputs

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.

    Both fields are read by train_moe_model() when the base model is loaded.
    """
    # Passed to `from_pretrained` for both the config and the model weights.
    model_name_or_path: str = field(
        default="meta-llama/Llama-2-7b-hf",
        metadata={"help": "Path to pretrained model"}
        # metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Forwarded as the `token=` keyword of `from_pretrained`.
    # NOTE(review): this is a bool, not a token string — confirm the installed
    # transformers version accepts a boolean here.
    use_auth_token: bool = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    All fields are forwarded to CSVDataset by make_supervised_data_module().
    """
    # CSV file read with pandas.read_csv. The default is None although the
    # annotation is `str`, so callers must supply a value before use.
    data_path: str = field(
        default=None, metadata={"help": "Path to the CSV training data."}
    )
    # Column holding the instruction part of the prompt.
    instruction_column: str = field(
        default="instruction", metadata={"help": "Column name for instructions."}
    )
    # Column holding the optional input part of the prompt; CSVDataset falls
    # back to an empty string when the column is absent.
    input_column: str = field(
        default="input", metadata={"help": "Column name for inputs."}
    )
    # Column holding the target text the model is trained to produce.
    output_column: str = field(
        default="output", metadata={"help": "Column name for outputs."}
    )
    # When set, CSVDataset draws a random sample of at most this many rows
    # (fixed random_state=42) rather than taking the first rows.
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of training examples."}
    )


@dataclass
class MoEArguments:
    """
    Arguments pertaining to MoE configuration.

    NOTE(review): the default `num_experts` (1) is smaller than the default
    `top_k` (2) — confirm the routing code tolerates this or that callers
    always override one of them.
    """
    # Also used as the number of clusters when features are clustered per expert.
    num_experts: int = field(
        default=1, metadata={"help": "Number of experts in the MoE model."}
    )
    top_k: int = field(
        default=2, metadata={"help": "Number of experts to route to for each token."}
    )
    adapter_dim: int = field(
        default=64, metadata={"help": "Dimension of the adapter."}
    )
    lora_r: int = field(
        default=64, metadata={"help": "Rank of the LoRA adapter."}
    )
    lora_alpha: int = field(
        default=16, metadata={"help": "Alpha parameter for LoRA scaling."}
    )
    moe_scaling: float = field(
        default=1.0, metadata={"help": "Scaling factor for MoE outputs."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """
    Arguments for training configuration.

    Extends `transformers.TrainingArguments` with new defaults for several
    inherited options and with project-specific fields for staged training.
    Note that this class shadows the `TrainingArguments` name imported from
    `transformers` at the top of the file.
    """
    report_to: str = field(default="tensorboard")
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(
        default="paged_adamw_32bit",
        metadata={"help": "Optimizer to use for training."}
    )
    lr_scheduler_type: str = field(
        default="constant_with_warmup",
        metadata={"help": "Learning rate scheduler type."}
    )
    # Project-specific field (not an inherited Trainer option).
    # presumably copied onto the tokenizer's `model_max_length` — verify against caller.
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}
    )
    save_strategy: str = field(
        default="epoch",
        metadata={"help": "The checkpoint save strategy to use."}
    )
    save_total_limit: Optional[int] = field(
        default=3,
        metadata={"help": "Limit the total amount of checkpoints."}
    )
    per_device_train_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per GPU for training."}
    )
    gradient_accumulation_steps: int = field(
        default=4,
        metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}
    )
    # Add dispatch_batches attribute for compatibility with older versions of accelerate
    dispatch_batches: Optional[bool] = field(
        default=None,
        metadata={"help": "Whether to dispatch batches across devices in distributed training"}
    )
    # Add stage training control parameters
    # train_moe_model() runs only this stage when the value is > 0 and falls
    # back to DEFAULT_TRAINING_STAGES when it is 0.
    training_stage: int = field(
        default=0, 
        metadata={"help": "Training stage: 0=all stages, 1=warmup, 2=extract features, 3=MoE training"}
    )
    warmup_subset_ratio: float = field(
        default=0.3, 
        metadata={"help": "Ratio of data to use during warmup stage (0-1)"}
    )
    # Joined onto `output_dir` by train_moe_model(), i.e. a sub-directory name.
    features_dir: str = field(
        default="features",
        metadata={"help": "Directory to save extracted features"}
    )

from torch.nn.utils.rnn import pad_sequence
def _random_cluster_labels(num_experts: int, count: int = 30) -> List[int]:
    """Return ``count`` uniformly random expert ids in ``[0, num_experts - 1]``.

    Last-resort fallback so that a clustering failure does not abort the
    pipeline. The default of 30 is the placeholder sample count used when the
    real number of samples is unknown.
    """
    import random

    return [random.randint(0, num_experts - 1) for _ in range(count)]


def cluster_features(feature_dir, num_experts: int) -> List[int]:
    """Cluster the saved feature files of ``feature_dir`` into ``num_experts`` groups.

    Every ``*.pt`` file in ``feature_dir`` is loaded (either a dict with a
    ``"features"`` tensor or a bare tensor), standardized per file, and
    concatenated in sorted-filename order. The result is projected onto at
    most 128 principal components and clustered with K-means. Labels and
    cluster centers are written to ``<feature_dir>/cluster_results.pt``.

    Args:
        feature_dir: Directory containing the feature files.
        num_experts: Number of clusters to produce.

    Returns:
        One cluster id per feature row, in sorted-filename order. Callers that
        need labels aligned with dataset rows must make sure the file names
        sort in extraction order. On failure a list of random labels is
        returned instead (30 entries when the sample count is unknown).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Get all feature files, sorted by filename. The results file written at
    # the end of this function also ends in ".pt"; skip it so a re-run does not
    # try to load previous cluster output as features.
    feature_files = sorted(
        f for f in os.listdir(feature_dir)
        if f.endswith('.pt') and f != "cluster_results.pt"
    )

    if not feature_files:
        logger.error(f"No feature files found in {feature_dir}")
        # Return random labels to avoid program crash
        logger.warning("Returning random generated cluster labels")
        return _random_cluster_labels(num_experts)

    logger.info(f"Found {len(feature_files)} feature files")

    standardized_features = []

    # First pass: load every file and standardize its features
    for file_idx, f in enumerate(feature_files):
        try:
            file_path = os.path.join(feature_dir, f)
            logger.info(f"Processing file {file_idx+1}/{len(feature_files)}: {f}")

            # NOTE: torch.load unpickles the file; only point this at feature
            # directories produced by a trusted run.
            data = torch.load(file_path, map_location=device)
            if isinstance(data, dict) and "features" in data:
                features = data["features"]
            else:
                # Try directly loading
                features = data

            if not isinstance(features, torch.Tensor):
                logger.warning(f"File {f} does not contain valid features, skipping")
                continue

            # Ensure features are two-dimensional
            if len(features.shape) > 2:
                features = features.mean(dim=1)  # Average over dimension 1
            elif len(features.shape) == 1:
                features = features.unsqueeze(0)  # Single feature to two-dimensional

            features = features.float()

            # Standardization. The unbiased std of a single row is NaN, which
            # would make every later step fail, so treat it as zero spread.
            mean = features.mean(dim=0, keepdim=True)
            if features.shape[0] > 1:
                std = features.std(dim=0, keepdim=True)
            else:
                std = torch.zeros_like(mean)
            features = (features - mean) / (std + 1e-6)

            standardized_features.append(features.cpu())

        except Exception as e:
            # Best effort: a broken file is skipped, not fatal.
            logger.error(f"Error processing file {f}: {e}")
            continue

    # Check if there are valid features
    if not standardized_features:
        logger.error("No valid features available for clustering")
        logger.warning("Returning random generated cluster labels")
        return _random_cluster_labels(num_experts)

    # Merge all features
    try:
        combined = torch.cat(standardized_features, dim=0).to(device)
        logger.info(f"Merged feature shape: {combined.shape}")
    except Exception as e:
        logger.error(f"Error merging features: {e}")
        shapes = [f.shape for f in standardized_features]
        logger.error(f"Feature shapes: {shapes}")
        return _random_cluster_labels(num_experts)

    # Bound outside the PCA try-block so the K-means fallback below can always
    # rely on it.
    n_samples = combined.shape[0]

    # Perform PCA dimensionality reduction
    try:
        _, n_features = combined.shape
        final_components = min(128, n_features)

        logger.info(f"Performing PCA dimensionality reduction to {final_components} dimensions")
        # Covariance matrix; guard the denominator for the single-sample case.
        cov = torch.mm(combined.T, combined) / max(n_samples - 1, 1)
        # eigh returns eigenvalues in ascending order, so the trailing columns
        # are the directions of largest variance.
        _, eigenvectors = torch.linalg.eigh(cov)
        projection = eigenvectors[:, -final_components:]
        final_features = torch.mm(combined, projection)

        logger.info(f"PCA dimensionality reduction shape: {final_features.shape}")
    except Exception as e:
        logger.error(f"Error performing PCA dimensionality reduction: {e}")
        # Use original features
        final_features = combined
        logger.warning("Using original features for clustering")

    # Perform K-means clustering
    try:
        logger.info(f"Performing K-means clustering, expert count: {num_experts}")
        from sklearn.cluster import KMeans

        # Convert to numpy for clustering
        features_np = final_features.cpu().numpy()
        kmeans = KMeans(n_clusters=num_experts, random_state=42, n_init=10)
        labels = kmeans.fit_predict(features_np)

        logger.info(f"Clustering completed, label shape: {labels.shape}")

        # Save clustering results
        result_path = os.path.join(feature_dir, "cluster_results.pt")
        torch.save({
            'cluster_labels': labels,
            'cluster_centers': kmeans.cluster_centers_
        }, result_path)
        logger.info(f"Cluster results saved to {result_path}")

        return labels.tolist()
    except Exception as e:
        logger.error(f"Error performing K-means clustering: {e}")
        # Return one random label per merged feature row
        random_labels = _random_cluster_labels(num_experts, n_samples)
        logger.warning(f"Returning random generated {len(random_labels)} cluster labels")
        return random_labels


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string on its own and report its non-pad token count.

    Returns a dict in which ``input_ids`` and ``labels`` refer to the same
    list of 1-D id tensors, and ``input_ids_lens`` and ``labels_lens`` refer
    to the same list of lengths.
    """
    encodings = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        encodings.append(encoded)

    # Each encoding holds a batch of one; keep its single row.
    token_ids = [encoded.input_ids[0] for encoded in encodings]

    # Length = number of positions that are not the pad token.
    lengths = []
    for encoded in encodings:
        not_pad = encoded.input_ids.ne(tokenizer.pad_token_id)
        lengths.append(not_pad.sum().item())

    return {
        "input_ids": token_ids,
        "labels": token_ids,
        "input_ids_lens": lengths,
        "labels_lens": lengths,
    }


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize prompt+answer pairs and mask the prompt part of the labels.

    The labels start as a deep copy of the full-sequence ids; the first
    ``len(prompt)`` positions of every row are then overwritten with
    ``IGNORE_INDEX``.
    """
    full_texts = [prompt + answer for prompt, answer in zip(sources, targets)]

    tokenized_full = _tokenize_fn(full_texts, tokenizer)
    tokenized_prompts = _tokenize_fn(sources, tokenizer)

    input_ids = tokenized_full["input_ids"]
    labels = copy.deepcopy(input_ids)

    prompt_lengths = tokenized_prompts["input_ids_lens"]
    for row, prompt_len in zip(labels, prompt_lengths):
        row[:prompt_len] = IGNORE_INDEX

    return {"input_ids": input_ids, "labels": labels}


class CSVDataset(Dataset):
    """Supervised fine-tuning dataset backed by a CSV file.

    Prompts and targets are built once in the constructor; tokenization
    happens lazily per item in ``__getitem__``.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer,
                 instruction_column: str, input_column: str, output_column: str,
                 max_samples: Optional[int] = None, cluster_labels_path=None):
        super().__init__()

        logger.info(f"Loading data from {data_path}")
        frame = pd.read_csv(data_path)
        if max_samples is not None:
            # Random (but seeded) sample, capped at the number of rows.
            frame = frame.sample(min(max_samples, len(frame)), random_state=42)

        logger.info("Processing data")
        self.tokenizer = tokenizer
        self.sources = []
        self.targets = []

        self.cluster_labels = self._read_cluster_labels(cluster_labels_path)

        # Labels are matched to rows by position, so the counts must agree.
        if self.cluster_labels is not None and len(self.cluster_labels) != len(frame):
            logger.warning(f"Cluster labels count ({len(self.cluster_labels)}) doesn't match data count ({len(frame)})")
            self.cluster_labels = None

        for _, record in tqdm(frame.iterrows(), total=len(frame)):
            instruction = record[instruction_column]
            input_text = record[input_column] if input_column in record else ""
            output_text = record[output_column]

            instruction = self._text_or_empty(instruction)
            input_text = self._text_or_empty(input_text)
            output_text = self._text_or_empty(output_text)

            self.sources.append(f"### Human:\n{instruction}\n{input_text}\n### Assistant:\n")
            self.targets.append(f"{output_text}")

    @staticmethod
    def _text_or_empty(value):
        """Return ``value`` when it is a string, otherwise an empty string."""
        return value if isinstance(value, str) else ""

    @staticmethod
    def _read_cluster_labels(cluster_labels_path):
        """Read the ``cluster_label`` column of a ``.csv`` file, or return None."""
        if not cluster_labels_path:
            return None

        logger.info(f"Loading cluster labels from {cluster_labels_path}")
        if not cluster_labels_path.endswith('.csv'):
            # Only CSV label files are handled.
            return None

        try:
            cluster_df = pd.read_csv(cluster_labels_path)
            if 'cluster_label' not in cluster_df.columns:
                logger.warning(f"'cluster_label' column not found in {cluster_labels_path}")
                return None
            loaded = cluster_df['cluster_label'].tolist()
            logger.info(f"Loaded {len(loaded)} cluster labels from CSV")
            return loaded
        except Exception as e:
            logger.error(f"Error loading cluster labels from CSV: {e}")
            return None

    def __len__(self):
        return len(self.sources)

    def __getitem__(self, i):
        encoded = preprocess([self.sources[i]], [self.targets[i]], self.tokenizer)

        item = {
            "input_ids": encoded["input_ids"][0],
            "labels": encoded["labels"][0],
        }
        # Attach the cluster label only when labels were loaded successfully.
        if self.cluster_labels is not None:
            item["cluster_labels"] = self.cluster_labels[i]

        return item

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Pad a list of dataset items into one batch for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id

        id_seqs = [example["input_ids"] for example in instances]
        label_seqs = [example["labels"] for example in instances]

        # Ids are padded with the pad token, labels with IGNORE_INDEX.
        padded_ids = torch.nn.utils.rnn.pad_sequence(
            id_seqs, batch_first=True, padding_value=pad_id
        )
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            label_seqs, batch_first=True, padding_value=IGNORE_INDEX
        )

        # The attention mask marks every position that is not the pad token.
        batch = {
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(pad_id),
        }

        # Cluster labels are included only when every item carries one.
        every_item_labelled = all("cluster_labels" in example for example in instances)
        if every_item_labelled:
            batch["cluster_labels"] = torch.tensor(
                [example["cluster_labels"] for example in instances],
                dtype=torch.long,
            )

        return batch

class SavePeftModelCallback(transformers.TrainerCallback):
    """Callback that stores the PEFT adapter and the MoE adapter weights.

    Runs on every Trainer save event and once more when training ends.
    """

    def save_model(self, args, state, kwargs):
        """Write the adapter and a ``moe_model.bin`` file for the current checkpoint.

        Args:
            args: Training arguments; ``output_dir`` is used for the checkpoint path.
            state: Trainer state; ``best_model_checkpoint`` and ``global_step`` are read.
            kwargs: Callback keyword arguments; must contain ``"model"``.
        """
        if state.best_model_checkpoint is not None:
            # NOTE(review): together with the join below this yields
            # ".../adapter_model/adapter_model"; kept as-is because existing
            # loaders may depend on that layout — confirm.
            checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
        else:
            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        model = kwargs["model"]
        model.save_pretrained(peft_model_path)

        # Build the state dict once; calling model.state_dict() for every key
        # re-collects the whole model on each iteration.
        full_state = model.state_dict()
        moe_state = {
            name: tensor for name, tensor in full_state.items() if "adapter" in name
        }

        # Make sure the target folder exists regardless of what
        # save_pretrained created.
        os.makedirs(checkpoint_folder, exist_ok=True)
        moe_model_path = os.path.join(checkpoint_folder, "moe_model.bin")
        torch.save(moe_state, moe_model_path)

        # Remove pytorch_model.bin to save space
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

    def on_save(self, args, state, control, **kwargs):
        self.save_model(args, state, kwargs)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        def touch(fname, times=None):
            # Create the file if needed and update its timestamps.
            with open(fname, "a"):
                os.utime(fname, times)

        # Marker file signalling that training ran to completion.
        touch(os.path.join(args.output_dir, "completed"))
        self.save_model(args, state, kwargs)

class GradientCheckCallback(transformers.TrainerCallback):
    """Callback that reports, once, whether trainable parameters hold gradients."""

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Inspect parameter gradients when ``global_step`` is 1 and log a summary.

        NOTE(review): depending on the transformers version, the Trainer may
        clear gradients before this event fires, in which case the warning
        branch is taken even for a healthy run — confirm against the
        installed version.
        """
        # Check only on the first step, and only when a model was supplied.
        if state.global_step != 1 or model is None:
            return control

        squared_norm = 0.0
        params_with_grad = 0
        missing_grad_params = []

        # Single pass: accumulate the norm for parameters that have a
        # gradient and remember the names of those that do not.
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if param.grad is None:
                missing_grad_params.append(name)
                continue
            params_with_grad += 1
            squared_norm += param.grad.data.norm(2).item() ** 2

        if params_with_grad:
            grad_norm = squared_norm ** 0.5
            logger.info(f"Gradients are being computed! Gradient norm: {grad_norm:.4f}")
            # Count only parameters that actually hold a gradient, so the
            # number matches what the message says.
            logger.info(f"Trainable parameters with gradients: {params_with_grad}")
        else:
            logger.warning("NO GRADIENTS FOUND on trainable parameters! Check model configuration.")

            if missing_grad_params:
                logger.warning(f"Parameters missing gradients: {missing_grad_params[:5]}...")
                if len(missing_grad_params) > 5:
                    logger.warning(f"... and {len(missing_grad_params) - 5} more")

        return control

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, cluster_labels_path=None, subset_ratio=None) -> Dict:
    """Build the training dataset and its collator.

    When ``subset_ratio`` lies strictly between 0 and 1, the dataset is
    replaced by a random subset of that proportion. Returns a dict with the
    keys ``train_dataset``, ``eval_dataset`` (always None) and ``data_collator``.
    """
    logger.info(f"Creating dataset from {data_args.data_path}")
    dataset = CSVDataset(
        data_path=data_args.data_path,
        tokenizer=tokenizer,
        instruction_column=data_args.instruction_column,
        input_column=data_args.input_column,
        output_column=data_args.output_column,
        max_samples=data_args.max_train_samples,
        cluster_labels_path=cluster_labels_path,
    )

    wants_subset = subset_ratio is not None and 0 < subset_ratio < 1
    if wants_subset:
        logger.info(f"Using {subset_ratio*100:.1f}% of data for training")
        total_samples = len(dataset)
        keep_count = int(total_samples * subset_ratio)
        # Random permutation, truncated to the requested size.
        chosen = torch.randperm(total_samples)[:keep_count].tolist()
        dataset = torch.utils.data.Subset(dataset, chosen)

    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        "train_dataset": dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }


def find_all_linear_names(model, bits=4):
    """Return the distinct leaf names of the model's linear layers, minus ``lm_head``.

    The layer class searched for depends on ``bits``: the bitsandbytes 4-bit
    class for 4, the 8-bit class for 8, and ``torch.nn.Linear`` otherwise.
    """
    if bits == 4:
        linear_cls = bnb.nn.Linear4bit
    elif bits == 8:
        linear_cls = bnb.nn.Linear8bitLt
    else:
        linear_cls = torch.nn.Linear

    # Only the last dotted component of each qualified module name is kept.
    leaf_names = {
        qualified_name.rsplit(".", 1)[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, linear_cls)
    }

    # needed for 16-bit
    leaf_names.discard("lm_head")
    return list(leaf_names)


def print_trainable_parameters(model):
    """Log how many of the model's parameters are trainable.

    Logs the trainable / total element counts, the percentage, and the set of
    final name components of the trainable parameters. Warns when nothing is
    trainable or when less than 1% is. A model without any parameters is
    reported as 0% instead of raising ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    trainable_suffixes = set()

    for name, param in model.named_parameters():
        element_count = param.numel()
        all_param += element_count
        if param.requires_grad:
            trainable_params += element_count
            # The last dotted component of a parameter name, e.g. "weight"
            # for "layer.weight".
            trainable_suffixes.add(name.split(".")[-1])

    # Guard the empty-model case so the report itself never crashes.
    percentage = 100 * trainable_params / all_param if all_param else 0.0
    logger.info(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percentage:.2f}%"
    )

    logger.info(f"Trainable modules: {', '.join(sorted(trainable_suffixes))}")

    # Check if there are enough trainable parameters
    if trainable_params == 0:
        logger.warning("NO TRAINABLE PARAMETERS FOUND! Training will not work.")
    elif trainable_params / all_param < 0.01:
        logger.warning("Very small percentage of trainable parameters (<1%). This might cause training issues.")

import torch.nn as nn

import torch
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.preprocessing import StandardScaler
import joblib

from torch.utils.data import Subset
def get_top50_subset(dataset):
    """Return a ``Subset`` holding the first 50% of ``dataset`` in original order.

    The subset size is ``int(0.5 * len(dataset))``, so a dataset with fewer
    than two items yields an empty subset.
    """
    split_idx = int(0.5 * len(dataset))

    # Subset expects index positions, not the data items themselves.
    return Subset(dataset, list(range(split_idx)))

def extract_features(model, data_loader, output_dir, device):
    """Extract one feature vector per sample and save them in chunked files.

    Every sample is split into segments of at most 512 tokens. Each segment is
    run through the model on its own, and the last-layer hidden states of its
    first ``valid_length`` positions are averaged (this assumes right padding,
    which is what the collator in this file produces). The segment vectors are
    then averaged into the sample's feature.

    Features are written through ``save_features_batch`` with zero-padded
    sequential ids (``features_batch_000000.pt``, ``features_batch_000001.pt``,
    ...) so that sorting the file names reproduces extraction order.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``output_hidden_states=True``.
        data_loader: Iterable of batches with ``input_ids`` and ``attention_mask``.
        output_dir: Directory for the feature files (created if missing).
            Files left over from earlier runs are not removed.
        device: Device the batch tensors are moved to.
    """
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Extracting features to {output_dir}")

    # Set model to evaluation mode
    model.eval()

    # NOTE(review): the message below says base_model will be used, but the
    # forward pass further down always calls `model` itself — confirm intent.
    if hasattr(model, "base_model"):
        logger.info("Detected PEFT model, will use base_model for feature extraction")

    max_features_batch = 50  # Flush to disk once this many features are pending
    max_chunk_size = 512  # Maximum length of each segment

    pending_features = []
    saved_chunks = 0

    def flush_pending():
        """Write pending features to the next numbered file, if there are any."""
        nonlocal pending_features, saved_chunks
        if not pending_features:
            return
        save_features_batch(pending_features, output_dir, f"{saved_chunks:06d}")
        saved_chunks += 1
        pending_features = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(data_loader, desc="Extracting features")):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            seq_length = input_ids.size(1)

            # Process each sample
            for i in range(input_ids.size(0)):
                sample_features = []

                # Process segments of long sequences
                for chunk_start in range(0, seq_length, max_chunk_size):
                    chunk_end = min(chunk_start + max_chunk_size, seq_length)

                    chunk_input_ids = input_ids[i:i+1, chunk_start:chunk_end]
                    chunk_attention_mask = attention_mask[i:i+1, chunk_start:chunk_end]

                    # If segment length is 0, skip
                    if chunk_input_ids.size(1) == 0:
                        continue

                    try:
                        outputs = model(
                            input_ids=chunk_input_ids,
                            attention_mask=chunk_attention_mask,
                            output_hidden_states=True
                        )

                        # Get features from last layer
                        last_hidden_state = outputs.hidden_states[-1]

                        # Segments consisting only of padding contribute nothing.
                        valid_length = int(chunk_attention_mask.sum().item())
                        if valid_length > 0:
                            chunk_feature = last_hidden_state[0, :valid_length].mean(dim=0)
                            sample_features.append(chunk_feature)
                    except Exception as e:
                        # Best effort: a failing segment is logged and skipped.
                        logger.error(f"Error extracting features: {e}")
                        logger.error(f"Batch: {batch_idx}, Sample: {i}, Segment: {chunk_start}-{chunk_end}")
                        continue

                # Merge all segment features
                if sample_features:
                    pending_features.append(torch.stack(sample_features).mean(dim=0))
                else:
                    logger.warning(
                        f"No features extracted for batch {batch_idx}, sample {i}; "
                        "saved features will not align one-to-one with the dataset"
                    )

                # Periodically save features and release memory
                if len(pending_features) >= max_features_batch:
                    flush_pending()
                    torch.cuda.empty_cache()

            # Clean up cache after each batch
            torch.cuda.empty_cache()

            # Periodically save to prevent memory overflow
            if (batch_idx + 1) % 10 == 0:
                flush_pending()

    # Save remaining features
    flush_pending()

    logger.info("Feature extraction completed")

def save_features_batch(features_list, output_dir, batch_id):
    """Stack the given feature tensors and write them to one ``.pt`` file.

    The file is ``<output_dir>/features_batch_<batch_id>.pt`` and holds a dict
    with the stacked tensor under the key ``"features"``.
    """
    stacked = torch.stack(features_list)
    target_path = os.path.join(output_dir, f"features_batch_{batch_id}.pt")
    payload = {"features": stacked}
    torch.save(payload, target_path)
    logger.info(f"Saved {len(features_list)} features to {target_path}")

def train_moe_model(model_args, data_args, moe_args, training_args):
    """Run the staged dense-to-MoE training pipeline.

    Which stages run is decided by ``training_args.training_stage``: a value
    greater than zero runs only that stage, anything else runs
    ``DEFAULT_TRAINING_STAGES``.

    1. Warm-up: LoRA fine-tuning of the 4-bit dense model (MoE disabled) on
       ``training_args.warmup_subset_ratio`` of the data. Adapter, config and
       tokenizer are written to ``<output_dir>/warmup/final``.
    2. Feature extraction and clustering: the warm-up model is reloaded, one
       feature per sample is written to the features directory, the features
       are clustered into ``moe_args.num_experts`` groups and the labels are
       saved to ``<output_dir>/cluster_labels.csv``.
    3. MoE training: the MoE-enabled model is trained with LoRA on the base
       projections, router and adapter layers; output goes to
       ``<output_dir>/moe``.

    Args:
        model_args: Provides ``model_name_or_path`` and ``use_auth_token``.
        data_args: Dataset options forwarded to ``make_supervised_data_module``
            (``data_path`` is also read directly in stage 2).
        moe_args: MoE/LoRA hyper-parameters (``lora_r``, ``lora_alpha``,
            ``adapter_dim``, ``top_k``, ``moe_scaling``, ``num_experts``).
        training_args: Trainer arguments plus the pipeline-specific fields
            used below (``features_dir``, ``training_stage``,
            ``warmup_subset_ratio``, ...).

    Returns:
        ``(model, tokenizer)``. ``model`` is the trained MoE model when stage 3
        ran and ``None`` otherwise (earlier stages free their model to release
        GPU memory); ``tokenizer`` is the tokenizer of the last executed stage,
        or ``None`` if no stage ran.
    """
    set_seed(42)

    # Memory optimization configuration
    torch.cuda.empty_cache()

    # Defined up-front so the final ``return`` is valid for every stage
    # selection (previously it raised UnboundLocalError unless stage 3 ran).
    model = None
    tokenizer = None

    # Half-precision dtype shared by quantization compute and model loading.
    compute_dtype = torch.bfloat16 if training_args.bf16 else torch.float16

    # Ensure feature directory exists
    features_dir = os.path.join(training_args.output_dir, training_args.features_dir)
    os.makedirs(features_dir, exist_ok=True)

    # Run the single requested stage, or the default stage list
    stages_to_run = [training_args.training_stage] if training_args.training_stage > 0 else DEFAULT_TRAINING_STAGES

    # ==== Stage 1: Warm-up Training ====
    if 1 in stages_to_run:
        logger.info("=== Stage 1: Warm-up Training ===")
        # Configure model (disable MoE)
        model_config = CamelidaeConfig.from_pretrained(model_args.model_name_or_path)
        model_config.pretraining_tp = 1
        model_config.use_moe = False  # Disable MoE modules

        # 4-bit NF4 quantization with double quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_storage=compute_dtype,
        )

        # Load model
        model = LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=model_config,
            quantization_config=bnb_config,
            torch_dtype=compute_dtype,
            device_map="auto",
            token=model_args.use_auth_token,
            output_loading_info=False,
        )

        # Prepare model for quantization training
        model = prepare_model_for_kbit_training(model)

        # Enable gradient checkpointing to reduce memory usage
        if training_args.gradient_checkpointing:
            model.gradient_checkpointing_enable()
            model.config.use_cache = False

        # Ensure at least some parameters are trainable
        for param in model.lm_head.parameters():
            param.requires_grad = True

        # Fixed target_modules list instead of dynamic lookup
        lora_target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
            "lm_head"
        ]

        config = LoraConfig(
            r=moe_args.lora_r,
            lora_alpha=moe_args.lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )

        # Apply LoRA
        model = get_peft_model(model, config)

        # Ensure all LoRA parameters are trainable
        for name, param in model.named_parameters():
            if 'lora' in name:
                param.requires_grad = True

        model.config.use_cache = False
        print_trainable_parameters(model)

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
            token=model_args.use_auth_token,
            trust_remote_code=True,
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token_id = 0  # unk token

        # Warm up on the configured fraction of the data
        data_module = make_supervised_data_module(
            tokenizer=tokenizer,
            data_args=data_args,
            subset_ratio=training_args.warmup_subset_ratio
        )

        # Create trainer writing into <output_dir>/warmup
        warmup_args = copy.deepcopy(training_args)
        warmup_args.output_dir = os.path.join(training_args.output_dir, "warmup")
        os.makedirs(warmup_args.output_dir, exist_ok=True)

        trainer = CustomTrainer(
            model=model,
            tokenizer=tokenizer,
            args=warmup_args,
            **data_module
        )

        trainer.add_callback(SavePeftModelCallback)
        trainer.add_callback(GradientCheckCallback)  # Add gradient check callback

        # Train model
        logger.info("Starting warmup training...")
        trainer.train()

        # Save warmup model
        warmup_model_path = os.path.join(warmup_args.output_dir, "final")
        os.makedirs(warmup_model_path, exist_ok=True)

        # Save PEFT/LoRA adapter
        model.save_pretrained(warmup_model_path)

        # Ensure original model configuration is saved
        model.config.save_pretrained(warmup_model_path)

        # Save tokenizer
        tokenizer.save_pretrained(warmup_model_path)

        logger.info(f"Warmup model saved to {warmup_model_path}")

        # Drop every reference to the model before asking CUDA to release memory
        del trainer
        model = None
        gc.collect()
        torch.cuda.empty_cache()

    # ==== Stage 2: Feature Extraction and Clustering ====
    if 2 in stages_to_run:
        logger.info("=== Stage 2: Feature Extraction and Clustering ===")

        # Clean up GPU memory
        torch.cuda.empty_cache()

        # Load warmup model
        warmup_model_path = os.path.join(training_args.output_dir, "warmup/final")

        # Check if config file exists
        config_path = os.path.join(warmup_model_path, "config.json")
        if not os.path.exists(config_path):
            logger.warning(f"Config file not found: {config_path}, using original model configuration")
            # Load config directly from original model
            model_config = CamelidaeConfig.from_pretrained(model_args.model_name_or_path)
        else:
            model_config = CamelidaeConfig.from_pretrained(warmup_model_path)

        # Set model configuration
        model_config.use_moe = False  # Ensure MoE disabled
        model_config.output_hidden_states = True  # Ensure output hidden states

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
        )

        # Locate the warm-up adapter. Checkpoint folders keep it in an
        # "adapter_model" sub-directory, whereas stage 1 saves the final
        # adapter directly into warmup/final (next to adapter_config.json).
        adapter_subdir = os.path.join(warmup_model_path, "adapter_model")
        if os.path.exists(adapter_subdir):
            adapter_path = adapter_subdir
        elif os.path.exists(os.path.join(warmup_model_path, "adapter_config.json")):
            adapter_path = warmup_model_path
        else:
            adapter_path = None
        peft_model = adapter_path is not None

        # If warmup model cannot be found, use original model
        if not os.path.exists(warmup_model_path):
            logger.warning(f"Warmup model path does not exist: {warmup_model_path}, using original model")
            base_model_path = model_args.model_name_or_path
            peft_model = False
        else:
            # If PEFT model, need to load original model first, then PEFT model
            if peft_model:
                logger.info(f"Detected PEFT model, will load original model and adapter")
                base_model_path = model_args.model_name_or_path
            else:
                base_model_path = warmup_model_path

        # Load base model
        model = LlamaForCausalLM.from_pretrained(
            base_model_path,
            config=model_config,
            quantization_config=bnb_config,
            torch_dtype=compute_dtype,
            device_map="auto",
        )

        # If PEFT model, load adapter
        if peft_model:
            from peft import PeftModel
            logger.info(f"Loading PEFT adapter: {adapter_path}")
            try:
                model = PeftModel.from_pretrained(model, adapter_path)
                logger.info("PEFT adapter loaded successfully")
            except Exception as e:
                # Best effort: continue with the base model if the adapter is unusable
                logger.error(f"PEFT adapter load failed: {e}")

        # Prepare model for feature extraction
        model.eval()  # Set to evaluation mode

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,  # Use tokenizer from original model
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token  # Use eos_token as pad_token

        # Create data loader (use full data)
        data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
        data_loader = torch.utils.data.DataLoader(
            data_module["train_dataset"],
            batch_size=8,
            collate_fn=data_module["data_collator"],
            shuffle=False,  # Keep dataset order so labels line up with CSV rows
        )

        # Extract features
        extract_features(model, data_loader, features_dir, model.device)

        # Perform clustering
        logger.info(f"Clustering features into {moe_args.num_experts} experts")
        cluster_labels = cluster_features(features_dir, moe_args.num_experts)

        # Save clustering results next to the original rows.
        # NOTE(review): this assumes one label per CSV row; confirm that
        # extract_features never drops a sample, otherwise the assignment fails.
        cluster_labels_file = os.path.join(training_args.output_dir, "cluster_labels.csv")
        df = pd.read_csv(data_args.data_path)
        df['cluster_label'] = cluster_labels
        df.to_csv(cluster_labels_file, index=False)
        logger.info(f"Cluster labels saved to {cluster_labels_file}")

        # Drop every reference to the model before asking CUDA to release memory
        del data_loader
        model = None
        gc.collect()
        torch.cuda.empty_cache()

    # ==== Stage 3: MoE Training ====
    if 3 in stages_to_run:
        logger.info("=== Stage 3: MoE Training ===")

        # Clean up GPU memory
        torch.cuda.empty_cache()

        # Check if clustering labels exist
        cluster_labels_file = os.path.join(training_args.output_dir, "cluster_labels.csv")
        if not os.path.exists(cluster_labels_file):
            logger.warning(f"Cluster labels not found at {cluster_labels_file}, proceeding without expert assignments")
            cluster_labels_path = None
        else:
            cluster_labels_path = cluster_labels_file

        # Configure MoE model
        model_config = CamelidaeConfig.from_pretrained(model_args.model_name_or_path)
        model_config.pretraining_tp = 1
        # Enable MoE. This was previously set to False, which contradicted the
        # stage's purpose and left the router/adapter LoRA targets unmatched.
        model_config.use_moe = True

        # Set MoE configuration
        model_config.moe_dtype = "bfloat16"
        model_config.lora_r = moe_args.lora_r
        model_config.lora_alpha = moe_args.lora_alpha
        model_config.adapter_dim = moe_args.adapter_dim
        model_config.topk = moe_args.top_k
        model_config.moe_scaling = moe_args.moe_scaling
        model_config.num_experts = moe_args.num_experts
        model_config.output_router_logits = True

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
        )

        # Load model
        model = LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=model_config,
            quantization_config=bnb_config,
            torch_dtype=compute_dtype,
            device_map="auto",
            token=model_args.use_auth_token,
            output_loading_info=False,
        )

        # Prepare model for quantization training
        model = prepare_model_for_kbit_training(model)

        # Enable gradient checkpointing to reduce memory usage
        if training_args.gradient_checkpointing:
            model.gradient_checkpointing_enable()
            model.config.use_cache = False

        # Set LoRA for router and adapter
        lora_target_modules = [
            # Original linear layers
            "q_proj", "k_proj", "v_proj",
            "o_proj", "up_proj", "gate_proj", "down_proj",
            "lm_head",

            # Linear submodules in router
            "router.query",  # Query projection of AttentionRouter
            "router.key",    # Key projection of AttentionRouter
            "router.value",  # Value projection of AttentionRouter

            # Adapter layers
            "adapter_down",
            "adapter_up"
        ]

        config = LoraConfig(
            r=model_config.lora_r,
            lora_alpha=model_config.lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )

        model = get_peft_model(model, config)

        # Ensure all LoRA parameters and router/adapter parameters are trainable
        for name, param in model.named_parameters():
            if any(x in name for x in ['lora', 'router', 'adapter']):
                param.requires_grad = True

        model.config.use_cache = False
        print_trainable_parameters(model)

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
            token=model_args.use_auth_token,
            trust_remote_code=True,
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token_id = 0

        # Create dataset with clustering labels
        data_module = make_supervised_data_module(
            tokenizer=tokenizer,
            data_args=data_args,
            cluster_labels_path=cluster_labels_path
        )

        # Trainer arguments for the MoE stage. Kept under a distinct name so
        # the ``moe_args`` parameter is not shadowed.
        moe_training_args = copy.deepcopy(training_args)
        moe_training_args.output_dir = os.path.join(training_args.output_dir, "moe")
        os.makedirs(moe_training_args.output_dir, exist_ok=True)

        trainer = CustomTrainer(
            model=model,
            tokenizer=tokenizer,
            args=moe_training_args,
            **data_module
        )

        trainer.add_callback(SavePeftModelCallback)
        trainer.add_callback(GradientCheckCallback)  # Add gradient check callback

        # Train model
        logger.info("Starting MoE training...")
        trainer.train()

        # Save final model
        final_model_path = os.path.join(moe_training_args.output_dir, "final")
        os.makedirs(final_model_path, exist_ok=True)
        model.save_pretrained(final_model_path)
        tokenizer.save_pretrained(final_model_path)
        logger.info(f"Final MoE model saved to {final_model_path}")

    return model, tokenizer


def merge_moe_lora(model_path, peft_path, moe_path, save_path):
    """Fold MoE weights and a LoRA adapter into the base model and save it.

    The base model is loaded on CPU in bfloat16, the tensors stored in the
    ``moe_path`` file are loaded non-strictly (after stripping the
    ``base_model.model.`` key prefix), the PEFT adapter at ``peft_path`` is
    attached and merged, and the result is written together with the
    tokenizer to ``save_path``.

    Returns:
        The merged model and the tokenizer as a ``(model, tokenizer)`` tuple.
    """
    print(f"Loading tokenizer from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

    print(f"Loading model config from {model_path}")
    model_config = CamelidaeConfig.from_pretrained(model_path)
    # No tensor-parallel slicing
    model_config.pretraining_tp = 1
    # auto_map lets the Auto* classes resolve the custom implementation
    model_config.auto_map = dict(
        AutoConfig="configuration_camelidae.CamelidaeConfig",
        AutoModelForCausalLM="modeling_camelidae.LlamaForCausalLM",
    )

    print(f"Loading base model from {model_path}")
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        config=model_config,
        torch_dtype=torch.bfloat16,
        device_map={'': 'cpu'}
    )

    print(f"Loading MoE weights from {moe_path}")
    moe_weights = torch.load(moe_path, map_location=torch.device("cpu"))
    # Strip the PEFT wrapper prefix so the keys match the bare model
    weights_dict = {
        key.replace("base_model.model.", ""): tensor
        for key, tensor in moe_weights.items()
    }

    if not weights_dict:
        print("Warning: No MoE weights found, using base model weights only")
    else:
        print(f"Loading {len(weights_dict)} MoE weight tensors")
        model.load_state_dict(weights_dict, strict=False)

    # Attach the LoRA adapter on top of the MoE-augmented base model
    print(f"Loading PEFT model from {peft_path}")
    from peft import PeftModel
    model = PeftModel.from_pretrained(
        model,
        peft_path,
        torch_dtype=torch.bfloat16,
        device_map={'': 'cpu'}
    )

    print("Merging weights...")
    model = model.merge_and_unload()

    # Persist tokenizer first, then the merged weights
    print(f"Saving merged model to {save_path}")
    os.makedirs(save_path, exist_ok=True)
    tokenizer.save_pretrained(save_path)
    model.save_pretrained(save_path)

    return model, tokenizer

# Tokenizer/model pairs already loaded by ``inference``, keyed by model path.
_INFERENCE_MODEL_CACHE = {}


def _load_inference_model(model_path):
    """Return ``(tokenizer, model)`` for ``model_path``, loading them only once."""
    cached = _INFERENCE_MODEL_CACHE.get(model_path)
    if cached is not None:
        return cached

    config_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_path):
        print(f"Warning: config.json not found at {config_path}, using default config")
        config = CamelidaeConfig()
    else:
        config = CamelidaeConfig.from_pretrained(model_path)

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        trust_remote_code=False
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = LlamaForCausalLM.from_pretrained(
        model_path,
        config=config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=False
    ).eval()

    _INFERENCE_MODEL_CACHE[model_path] = (tokenizer, model)
    return tokenizer, model


def inference(model_path, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
    """Generate the assistant response for ``prompt`` with the trained model.

    The tokenizer and model are loaded on the first call for a given
    ``model_path`` and reused afterwards, so repeated calls (for example one
    per evaluation row) do not reload the checkpoint each time.

    Args:
        model_path: Directory holding the model, tokenizer and (optionally)
            ``config.json``; a default ``CamelidaeConfig`` is used without it.
        prompt: User text. It is wrapped in the ``### Human`` /
            ``### Assistant`` template unless it already starts with
            ``### Human:``.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        top_p: Nucleus sampling threshold (only used when sampling).

    Returns:
        The decoded text after the last ``### Assistant:`` marker, stripped;
        the full decoded text if the marker is absent.
    """
    start_gpu = get_gpu_usage()
    tokenizer, model = _load_inference_model(model_path)

    # Format the prompt if it doesn't already have the expected format
    if not prompt.startswith("### Human:"):
        prompt = f"### Human:\n{prompt}\n### Assistant:\n"

    inference_start = time.time()
    inputs = tokenizer(prompt, return_tensors='pt', max_length=512, truncation=True)
    inputs = inputs.to(model.device)

    # Sampling parameters are only meaningful (and only accepted without
    # warnings/errors) when sampling is enabled.
    do_sample = temperature > 0
    generation_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    outputs = model.generate(**inputs, **generation_kwargs)
    inference_time = time.time() - inference_start
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded sequence contains the prompt; keep only the assistant's part
    if "### Assistant:" in response:
        response = response.split("### Assistant:")[-1].strip()

    end_gpu = get_gpu_usage()

    logger.info(f"Inference time: {inference_time:.4f} seconds")

    if start_gpu and end_gpu:
        gpu_memory_used = end_gpu["memory_used"] - start_gpu["memory_used"]
        logger.info(f" usage of GPU: {end_gpu['gpu_name']}")
        logger.info(f" increase in memory usage: {gpu_memory_used:.2f} MB")
        logger.info(f" final memory usage rate: {end_gpu['memory_percent']:.2f}%")
        logger.info(f" final GPU load: {end_gpu['gpu_load']:.2f}%")
    return response

        
def _as_item_set(items):
    """Return ``items`` as a set; a string is split on ``', '`` into its items."""
    if isinstance(items, str):
        return {part for part in items.split(', ') if part}
    return set(items)


def evaluate_model(model_output, ground_truth, model_probs=None):
    """Compute set-overlap metrics between predicted and reference items.

    Args:
        model_output: Predicted items, either an iterable of hashable items
            or a ``', '``-separated string.
        ground_truth: Reference items in the same form.
        model_probs: Optional mapping from item to predicted probability.
            When given, ``'AUPRC'`` is added to the result; items without an
            entry are scored 0.

    Returns:
        Dict with ``'Hit Rate'``, ``'Accuracy'``, ``'Jaccard'``, ``'Recall'``,
        ``'Precision'`` and ``'F1 Score'`` (plus ``'AUPRC'`` when
        ``model_probs`` is given). ``'Hit Rate'`` equals recall and
        ``'Accuracy'`` equals precision. Every ratio with an empty
        denominator is reported as 0.
    """
    model_set = _as_item_set(model_output)
    ground_truth_set = _as_item_set(ground_truth)

    intersection = model_set & ground_truth_set
    union = model_set | ground_truth_set

    recall = len(intersection) / len(ground_truth_set) if ground_truth_set else 0
    precision = len(intersection) / len(model_set) if model_set else 0
    jaccard = len(intersection) / len(union) if union else 0

    if precision + recall > 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
    else:
        f1_score = 0

    metrics = {
        'Hit Rate': recall,
        'Accuracy': precision,
        'Jaccard': jaccard,
        'Recall': recall,
        'Precision': precision,
        'F1 Score': f1_score
    }

    if model_probs is not None:
        # Rank all candidate items by descending probability. The inner sort
        # makes tie order deterministic (alphabetical), the outer is stable.
        ranked = sorted(
            sorted(union),
            key=lambda item: model_probs.get(item, 0),
            reverse=True,
        )
        total_positives = len(ground_truth_set)

        # Average precision: sum over ranks of (recall gain) * precision,
        # starting from recall 0 so the top-ranked item contributes too.
        auprc = 0
        true_positives = 0
        previous_recall = 0
        for rank, item in enumerate(ranked, start=1):
            if item in ground_truth_set:
                true_positives += 1
            recall_at_rank = true_positives / total_positives if total_positives > 0 else 0
            auprc += (recall_at_rank - previous_recall) * (true_positives / rank)
            previous_recall = recall_at_rank
        metrics['AUPRC'] = auprc

    return metrics

def code(text):
    """Extract every ``<LETTER><2 digits><LETTER>`` token (e.g. ``A01B``) from ``text``.

    Returns the matches as a list of strings in order of appearance.
    """
    return [hit.group(0) for hit in re.finditer(r"[A-Z]\d{2}[A-Z]", text)]

def read_pickle(file_path):
    """Unpickle and return the object stored in the file at ``file_path``."""
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)
def load_ddi_adjacency_matrix(path=DDI_ADJ_PATH):
    """Load the drug-drug-interaction adjacency structure from a pickle file.

    Args:
        path: Pickle file to read; defaults to ``DDI_ADJ_PATH``.

    Returns:
        The unpickled object. If the file cannot be read or unpickled, the
        error is logged and an empty dict is returned so that evaluation can
        continue (every pair then counts as non-interacting).
    """
    try:
        ddi_adj = read_pickle(path)
    except Exception as e:
        # Deliberate best effort: evaluation proceeds without DDI information
        logger.error(f"Failed to load DDI adjacency matrix: {e}")
        logger.warning("Using empty DDI adjacency matrix for demonstration")
        return {}

    logger.info(f"Loaded DDI adjacency matrix from {path}")
    # The preview is purely informational. It is kept outside the loading
    # ``try`` so that an object that does not support 2-D slicing is still
    # returned instead of being replaced by the empty fallback.
    try:
        print(f"DDI adjacency matrix: {ddi_adj[:5,:5]}")
    except (TypeError, IndexError, KeyError):
        logger.info(f"DDI adjacency preview unavailable for type {type(ddi_adj).__name__}")
    return ddi_adj

def calculate_ddi_metrics(drug_list, ddi_adj):
    """Count interacting pairs among the drugs in ``drug_list``.

    Every unordered pair of positions in ``drug_list`` is checked against
    ``ddi_adj``. For a dict, a pair interacts when either ordered tuple is a
    key; otherwise ``ddi_adj`` is indexed as ``ddi_adj[a, b]`` and a pair
    interacts when either direction equals 1. Pairs whose lookup raises
    ``TypeError``, ``IndexError`` or ``KeyError`` are logged and still counted
    in the pair total.

    Returns:
        ``(ddi_rate, interacting_pairs, total_pairs)``; ``(0.0, 0, 0)`` when
        fewer than two drugs are given.
    """
    if not drug_list or len(drug_list) < 2:
        return 0.0, 0, 0

    keyed_by_pair = isinstance(ddi_adj, dict)
    total_pairs = 0
    interacting_pairs = 0

    for position in range(len(drug_list)):
        first = drug_list[position]
        for second in drug_list[position + 1:]:
            total_pairs += 1
            try:
                if keyed_by_pair:
                    found = (first, second) in ddi_adj or (second, first) in ddi_adj
                else:
                    found = ddi_adj[first, second] == 1 or ddi_adj[second, first] == 1

                if found:
                    interacting_pairs += 1
            except (TypeError, IndexError, KeyError) as e:
                logger.warning(f"Error checking DDI for pair ({first}, {second}): {e}")
                continue

    if total_pairs > 0:
        ddi_rate = interacting_pairs / total_pairs
    else:
        ddi_rate = 0.0
    print(ddi_rate)

    return ddi_rate, interacting_pairs, total_pairs

def map_codes_to_indices(med_codes, med_to_idx):
    """Translate medication codes to their indices via ``med_to_idx``.

    Codes missing from ``med_to_idx`` are reported on stdout and skipped, so
    the returned list can be shorter than ``med_codes``. Order is preserved.
    """
    resolved = []
    for med_code in med_codes:
        if med_code not in med_to_idx:
            print(f"warning: the {med_code} is not in the list")
            continue
        resolved.append(med_to_idx[med_code])
    return resolved

def process_csv_and_save_results(csv_file_path, output_file_path, checkpoint_dir=None):
    """Run the model on every CSV row and write per-row and mean metrics.

    For each row the prompt is built from the instruction and ``input``
    columns, the model response and the ``output`` column are reduced to
    medication codes with ``code``, and overlap metrics plus DDI statistics
    are computed. One CSV line per row is written to ``output_file_path``,
    followed by a ``Mean`` line.

    Args:
        csv_file_path: Evaluation CSV. The instruction is read from the
            ``instruct`` column, falling back to ``instruction``.
        output_file_path: Destination of the metrics CSV (overwritten).
        checkpoint_dir: Model directory passed to ``inference``; defaults to
            ``FINAL_CHECKPOINT_DIR``.

    In the ``Mean`` line the ratio metrics and ``Model Drug Count`` are
    averaged over rows, while ``DDI Count`` and ``Total Pairs`` are totals.
    """
    import ast  # function-scope import: only needed for the optional model_probs column

    model_dir = checkpoint_dir or FINAL_CHECKPOINT_DIR
    ddi_adj = load_ddi_adjacency_matrix()
    med_to_idx = read_pickle(MED_TO_IDX_PATH)

    averaged_metrics = (
        'Hit Rate', 'Accuracy', 'Jaccard', 'Recall', 'Precision',
        'F1 Score', 'DDI Rate', 'Model Drug Count',
    )
    summed_metrics = ('DDI Count', 'Total Pairs')
    totals = {name: 0 for name in averaged_metrics + summed_metrics}
    auprc_total = 0.0
    auprc_rows = 0
    row_count = 0

    with open(csv_file_path, 'r', newline='') as csvfile, \
            open(output_file_path, 'w') as output_file:
        reader = csv.DictReader(csvfile)
        # fieldnames is None for an empty file
        has_probs_column = 'model_probs' in (reader.fieldnames or ())

        header = "Row,Hit Rate,Accuracy,Jaccard,Recall,Precision,F1 Score,DDI Rate,DDI Count,Total Pairs,Model Drug Count"
        if has_probs_column:
            header += ",AUPRC"
        output_file.write(header + "\n")

        for row in reader:
            instruct = row.get('instruct') or row.get('instruction') or ''
            input_text = row.get('input') or ''
            prompt = f"{instruct} {input_text}".strip()
            generated = inference(model_dir, prompt)

            predicted_codes = code(generated)
            truth_codes = code(row.get('output') or '')
            model_drug_count = len(predicted_codes)

            # Parse the optional probability mapping without executing
            # arbitrary expressions from the CSV.
            model_probs = None
            raw_probs = row.get('model_probs')
            if raw_probs:
                try:
                    parsed_probs = ast.literal_eval(raw_probs)
                except (ValueError, SyntaxError, TypeError) as e:
                    logger.warning(f"Row {row_count + 1}: could not parse model_probs: {e}")
                else:
                    if isinstance(parsed_probs, dict):
                        model_probs = parsed_probs

            metrics = evaluate_model(predicted_codes, truth_codes, model_probs)

            med_indices = map_codes_to_indices(predicted_codes, med_to_idx)
            ddi_rate, ddi_count, total_pairs = calculate_ddi_metrics(med_indices, ddi_adj)
            metrics['DDI Rate'] = ddi_rate
            metrics['DDI Count'] = ddi_count
            metrics['Total Pairs'] = total_pairs
            metrics['Model Drug Count'] = model_drug_count

            output_str = f"{row_count + 1},{metrics['Hit Rate']:.4f},{metrics['Accuracy']:.4f},{metrics['Jaccard']:.4f},{metrics['Recall']:.4f},{metrics['Precision']:.4f},{metrics['F1 Score']:.4f},{ddi_rate:.4f},{ddi_count},{total_pairs},{metrics['Model Drug Count']:.4f}"
            if 'AUPRC' in metrics:
                output_str += f",{metrics['AUPRC']:.4f}"
                auprc_total += metrics['AUPRC']
                auprc_rows += 1
            elif has_probs_column:
                # Keep the column count aligned with the header
                output_str += ","
            print(output_str)
            output_file.write(output_str + "\n")
            # Evaluation is slow; make finished rows durable right away
            output_file.flush()

            # Each metric is accumulated exactly once per row
            for name in totals:
                totals[name] += metrics[name]
            row_count += 1

        divisor = row_count if row_count > 0 else 1  # avoids division by zero on an empty CSV
        means = {name: totals[name] / divisor for name in averaged_metrics}

        output_str = f"Mean,{means['Hit Rate']:.4f},{means['Accuracy']:.4f},{means['Jaccard']:.4f},{means['Recall']:.4f},{means['Precision']:.4f},{means['F1 Score']:.4f},{means['DDI Rate']:.4f},{totals['DDI Count']:.0f},{totals['Total Pairs']:.0f},{means['Model Drug Count']:.2f}"
        if has_probs_column:
            if auprc_rows > 0:
                output_str += f",{auprc_total / auprc_rows:.4f}"
            else:
                output_str += ","
        output_file.write(output_str + "\n")

def get_gpu_usage():
    """Snapshot usage statistics of the first GPU reported by GPUtil.

    Returns:
        ``None`` when GPUtil reports no GPUs; otherwise a dict with the
        device's ``gpu_id``, ``gpu_name``, ``memory_used``, ``memory_total``,
        ``memory_percent`` (``memoryUtil`` scaled by 100) and ``gpu_load``
        (``load`` scaled by 100).
    """
    detected = GPUtil.getGPUs()
    if not detected:
        return None

    first = detected[0]
    snapshot = {}
    snapshot["gpu_id"] = first.id
    snapshot["gpu_name"] = first.name
    snapshot["memory_used"] = first.memoryUsed
    snapshot["memory_total"] = first.memoryTotal
    snapshot["memory_percent"] = first.memoryUtil * 100
    snapshot["gpu_load"] = first.load * 100
    return snapshot
            
def _checkpoint_sort_key(name):
    """Order checkpoint directory names by numeric step; other suffixes sort last."""
    suffix = name[len(PREFIX_CHECKPOINT_DIR):].lstrip("-")
    if suffix.isdigit():
        return (0, int(suffix), name)
    return (1, 0, name)


def _find_last_valid_checkpoint(moe_path):
    """Return ``(last_valid_checkpoint_path_or_None, checkpoint_names)`` for ``moe_path``.

    A checkpoint is valid when it contains both ``adapter_model`` and
    ``moe_model.bin``; the valid one that sorts last wins.
    """
    if not os.path.isdir(moe_path):
        raise FileNotFoundError(f"MoE output directory not found: {moe_path}")

    last_checkpoint = None
    available_checkpoints = []

    print("Scanning for checkpoints...")
    candidates = [
        entry for entry in os.listdir(moe_path)
        if entry.startswith(PREFIX_CHECKPOINT_DIR)
    ]
    # Numeric ordering: a plain string sort would rank checkpoint-900 after
    # checkpoint-1000 and pick a stale checkpoint.
    for checkpoint_dir in sorted(candidates, key=_checkpoint_sort_key):
        full_path = os.path.join(moe_path, checkpoint_dir)
        available_checkpoints.append(checkpoint_dir)
        print(f"Found checkpoint: {checkpoint_dir}")

        has_adapter = os.path.exists(os.path.join(full_path, "adapter_model"))
        has_moe_weights = os.path.exists(os.path.join(full_path, "moe_model.bin"))

        print(f"  - adapter_model exists: {has_adapter}")
        print(f"  - moe_model.bin exists: {has_moe_weights}")

        if has_adapter and has_moe_weights:
            last_checkpoint = full_path
            print(f"  - Using this checkpoint")

    return last_checkpoint, available_checkpoints


def main():
    """Script entry point: locate the latest MoE checkpoint and evaluate.

    Parses the model/data/MoE/training dataclass arguments, ensures the
    output directories exist, finds the most recent checkpoint under
    ``<output_dir>/moe`` that holds both an adapter and MoE weights, and runs
    the CSV evaluation configured by ``CSV_FILE_PATH`` / ``OUTPUT_FILE_PATH``.
    Training and merging are currently disabled (see the commented calls).

    Raises:
        FileNotFoundError: If the MoE directory or a valid checkpoint is missing.
    """
    # Configure PyTorch memory allocator, reduce memory fragmentation
    if torch.cuda.is_available():
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

        # Force release GPU memory
        torch.cuda.empty_cache()

        # Report GPU memory
        device = torch.cuda.current_device()
        logger.info(f"Total GPU memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.2f} GB")
        logger.info(f"Available GPU memory: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB reserved")

    parser = HfArgumentParser((ModelArguments, DataArguments, MoEArguments, TrainingArguments))
    model_args, data_args, moe_args, training_args = parser.parse_args_into_dataclasses()

    # Create output directory
    os.makedirs(training_args.output_dir, exist_ok=True)

    # Execute 3-stage training process (currently disabled)
    # model, tokenizer = train_moe_model(model_args, data_args, moe_args, training_args)

    # Directory reserved for the final results
    final_dir = os.path.join(training_args.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)

    moe_path = f"{training_args.output_dir}/moe"
    last_checkpoint, available_checkpoints = _find_last_valid_checkpoint(moe_path)

    if last_checkpoint is None:
        raise FileNotFoundError("No valid checkpoint found in {}. Available checkpoints: {}".format(
            moe_path, available_checkpoints))

    peft_path = os.path.join(last_checkpoint, "adapter_model")
    moe = os.path.join(last_checkpoint, "moe_model.bin")
    final_checkpoint_dir = os.path.join(moe_path, "checkpoint-final")

    print(f"Using checkpoint: {last_checkpoint}")
    print(f"PEFT path: {peft_path}")
    print(f"MoE path: {moe}")
    print(f"Final save path: {final_checkpoint_dir}")

    # Merging is currently disabled:
    # try:
    #      merge_moe_lora(
    #         model_path=model_args.model_name_or_path,
    #          peft_path=peft_path,
    #          moe_path=moe,
    #          save_path=final_checkpoint_dir
    #      )
    #      print(f"✅ Successfully merged model saved to {final_checkpoint_dir}")
    # except Exception as e:
    #      print(f"❌ Error during merging: {e}")
    #      import traceback
    #      traceback.print_exc()
    # logger.info(f"Training complete. Final model saved to {final_dir}")

    # Metrics
    csv_file_path = CSV_FILE_PATH
    output_file_path = OUTPUT_FILE_PATH

    process_csv_and_save_results(csv_file_path, output_file_path)

    print(f"Your model evaluation results have been saved in {output_file_path}")


def load_features(feature_dir):
    """Load per-layer feature tensors from ``feature_dir`` and join them column-wise.

    Every ``.pt`` file in the directory is expected to be named like
    ``layer_<N>.pt`` and to contain a single tensor. Files are processed in
    ascending ``N`` and the arrays are concatenated along axis 1, giving an
    array of shape ``(num_samples, total_features)``.

    NOTE(review): files written by ``save_features_batch``
    (``features_batch_<id>.pt`` holding a dict) do not follow this layout and
    are not handled here - confirm which producer this loader is meant for.

    Raises:
        ValueError: If the directory contains no ``.pt`` files, or a file
            name does not carry an integer after the first underscore.
    """
    # Sort numerically by the layer index embedded in the file name
    feature_files = sorted(
        [f for f in os.listdir(feature_dir) if f.endswith('.pt')],
        key=lambda x: int(x.split('_')[1].split('.')[0])
    )
    if not feature_files:
        raise ValueError(f"No .pt feature files found in {feature_dir}")

    features = []
    for file in feature_files:
        path = os.path.join(feature_dir, file)
        # Load onto the CPU so tensors saved from a GPU can be converted
        layer_tensor = torch.load(path, map_location="cpu").detach()
        if layer_tensor.dtype == torch.bfloat16:
            layer_tensor = layer_tensor.float()  # NumPy has no bfloat16 dtype
        features.append(layer_tensor.numpy())

    # Concatenate features along dimension (num_samples, total_features)
    return np.concatenate(features, axis=1)


# Entry point kept below the definitions above so that everything in this
# module exists by the time main() runs.
if __name__ == "__main__":
    main()