import torch
import numpy as np
from scipy.stats import entropy
import torch.nn.functional as F

def compute_msp_score(logits):
    """Return the Maximum Softmax Probability (MSP) of every sample.

    Args:
        logits: torch tensor whose class axis is dim 1.

    Returns:
        numpy array of the largest softmax probability per sample
        (higher is more likely to be ID).
    """
    # .max(dim=1) yields (values, indices); only the values form the score
    top_prob = torch.softmax(logits, dim=1).max(dim=1).values
    return top_prob.cpu().numpy()

def compute_energy_score(logits, temperature=1.0):
    """Return ``-logsumexp(logits / temperature)`` over dim 1 for every sample.

    This equals the energy ``-T * logsumexp(f / T)`` divided by ``temperature``.

    Args:
        logits: torch tensor whose class axis is dim 1.
        temperature: divisor applied to the logits before the log-sum-exp.

    Returns:
        numpy array of scores (lower is more likely to be ID).
    """
    scaled_logits = logits / temperature
    lse = torch.logsumexp(scaled_logits, dim=1)
    return -(lse.cpu().numpy())

def compute_entropy_score(logits):
    """Return the Shannon entropy (natural log) of each sample's softmax.

    Args:
        logits: torch tensor whose class axis is dim 1.

    Returns:
        numpy array with one entropy per sample (lower is more likely to be ID).
    """
    prob_rows = torch.softmax(logits, dim=1).cpu().numpy()
    # scipy's entropy is applied to one probability row at a time
    return np.array(list(map(entropy, prob_rows)))

def initialize_vim(model, dataset, train_features, train_logits):
    """
    Fit the parameters used by ViM (Virtual-logit Matching) scoring.

    Args:
        model: model whose final linear layer defines the feature-space origin.
            For datasets containing 'cifar' it is looked up as ``model.fc`` or
            ``model.classifier``; for 'tiny_imagenet' as ``model.model.fc``.
        dataset: dataset name used to locate that layer.
        train_features: training features (tensor or array), (num_samples, feature_dim).
        train_logits: training logits (tensor or array), (num_samples, num_classes).

    Returns:
        dict with
            'u': origin ``-pinv(W) @ b`` that features are shifted by,
            'NS': (feature_dim, feature_dim - DIM) orthonormal basis of the
                  lowest-variance directions of the shifted features,
            'alpha': mean max training logit / mean training residual norm,
            'DIM': number of leading principal directions excluded from 'NS'.

    Raises:
        ValueError: if no final linear layer can be found for ``dataset``.
    """
    from numpy.linalg import norm, pinv

    def _as_numpy(x):
        return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

    train_features_np = _as_numpy(train_features)
    train_logits_np = _as_numpy(train_logits)

    # Size of the principal subspace, chosen from the feature dimension.
    feature_dim = train_features_np.shape[-1]
    if feature_dim >= 1500:
        DIM = 1000
    elif feature_dim >= 768:
        DIM = 512
    else:
        DIM = feature_dim // 2

    # Locate the final linear layer; hasattr() does not follow dotted paths,
    # so the nested tiny_imagenet layout is checked level by level.
    dataset_name = dataset or ''
    fc = None
    if 'cifar' in dataset_name:
        if hasattr(model, 'fc'):
            fc = model.fc
        elif hasattr(model, 'classifier'):
            fc = model.classifier
    elif 'tiny_imagenet' in dataset_name:
        inner = getattr(model, 'model', None)
        if inner is not None and hasattr(inner, 'fc'):
            fc = inner.fc
    if fc is None:
        # Without W and b the origin u cannot be computed.
        raise ValueError("No FC layer found in the model")

    w = fc.weight.detach().cpu().numpy()
    bias = getattr(fc, 'bias', None)
    # A bias-less layer is equivalent to b = 0.
    b = bias.detach().cpu().numpy() if bias is not None else np.zeros(w.shape[0], dtype=w.dtype)

    # Origin of the feature space: solves W u + b = 0 in the least-squares sense.
    u = -np.matmul(pinv(w), b)

    # Uncentred covariance about u (identical to sklearn's
    # EmpiricalCovariance(assume_centered=True): X^T X / n).
    centered = train_features_np - u
    cov = centered.T @ centered / centered.shape[0]

    # The covariance is symmetric, so eigh returns real eigenvalues and an
    # orthonormal eigenbasis; eig may return complex or non-orthogonal vectors.
    eig_vals, eig_vecs = np.linalg.eigh(cov)
    residual_idx = np.argsort(-eig_vals)[DIM:]  # drop the DIM largest directions
    NS = np.ascontiguousarray(eig_vecs[:, residual_idx])

    # alpha rescales the residual norm to the magnitude of the max logit.
    vlogit_train = norm(np.matmul(centered, NS), axis=-1)
    alpha = train_logits_np.max(axis=-1).mean() / vlogit_train.mean()

    return {
        'u': u,
        'NS': NS,
        'alpha': alpha,
        'DIM': DIM
    }

def compute_vim_score(features, logits, vim_params):
    """
    Score samples with ViM: log-sum-exp of the logits minus the scaled virtual logit.

    Args:
        features: test features (tensor or array).
        logits: test logits (tensor or array).
        vim_params: dict from initialize_vim with keys 'u', 'NS' and 'alpha'.

    Returns:
        numpy array of ViM scores (higher is more likely to be ID).
    """
    from scipy.special import logsumexp
    from numpy.linalg import norm

    def to_array(value):
        if not torch.is_tensor(value):
            return value
        return value.cpu().detach().numpy()

    feats = to_array(features)
    logit_arr = to_array(logits)

    # Projection of the shifted features onto the residual subspace NS.
    residual = np.matmul(feats - vim_params['u'], vim_params['NS'])
    virtual_logit = norm(residual, axis=-1) * vim_params['alpha']

    return logsumexp(logit_arr, axis=-1) - virtual_logit

def normalize_features(x):
    """L2-normalise ``x`` along its last axis (used by KNN and other distance-based methods).

    Works on torch tensors and numpy arrays; 1e-10 is added to the norm so that
    all-zero rows do not cause a division by zero.
    """
    eps = 1e-10
    if not torch.is_tensor(x):
        magnitude = np.linalg.norm(x, ord=2, axis=-1, keepdims=True)
        return x / (magnitude + eps)
    magnitude = torch.norm(x, p=2, dim=-1, keepdim=True)
    return x / (magnitude + eps)

def initialize_knn(train_features):
    """
    Build a FAISS exact-L2 index over L2-normalised training features.

    Args:
        train_features: training features (torch tensor or numpy array). 2-D input
            is used as (num_samples, dim); input with more than two axes is
            average-pooled over every axis after the second
            (e.g. (N, C, H, W) -> (N, C)).

    Returns:
        dict: {'index': faiss.IndexFlatL2 holding the normalised training features}

    Raises:
        ValueError: if the features are not 2-D after pooling.
    """
    import faiss

    if torch.is_tensor(train_features):
        feats = train_features.detach()
        if feats.dim() > 2:
            # Global average pool over the trailing axes. flatten+mean (rather
            # than pool(...).squeeze()) keeps the batch axis when N == 1.
            feats = feats.flatten(start_dim=2).mean(dim=2)
        feats = feats.cpu().numpy()
    else:
        feats = np.asarray(train_features)
        if feats.ndim > 2:
            feats = feats.reshape(feats.shape[0], feats.shape[1], -1).mean(axis=2)

    if feats.ndim != 2:
        raise ValueError(f"expected 2-D features after pooling, got shape {feats.shape}")

    feats = normalize_features(feats)

    # FAISS requires C-contiguous float32 input.
    feats = np.ascontiguousarray(feats, dtype=np.float32)

    index = faiss.IndexFlatL2(feats.shape[1])
    index.add(feats)

    return {
        'index': index
    }

def compute_knn_score(test_features, knn_params, k=50):
    """
    Score test samples by their distance to the k-th nearest training feature.

    Args:
        test_features: test features (torch tensor or numpy array). 2-D input is
            used as (num_samples, dim); input with more than two axes is
            average-pooled over every axis after the second.
        knn_params: dict returned from initialize_knn (uses 'index').
        k: neighbour rank; clamped to the number of indexed training samples.

    Returns:
        numpy array: negative distance (as reported by the FAISS IndexFlatL2
        search, i.e. squared L2) between each normalised test feature and its
        k-th nearest training feature. Higher is more likely to be ID.

    Raises:
        ValueError: if the features are not 2-D after pooling or the index is empty.
    """
    index = knn_params['index']

    if torch.is_tensor(test_features):
        feats = test_features.detach()
        if feats.dim() > 2:
            # Global average pool; flatten+mean keeps the batch axis when N == 1.
            feats = feats.flatten(start_dim=2).mean(dim=2)
        feats = feats.cpu().numpy()
    else:
        feats = np.asarray(test_features)
        if feats.ndim > 2:
            feats = feats.reshape(feats.shape[0], feats.shape[1], -1).mean(axis=2)

    if feats.ndim != 2:
        raise ValueError(f"expected 2-D features after pooling, got shape {feats.shape}")

    feats = normalize_features(feats)
    # FAISS requires C-contiguous float32 input.
    feats = np.ascontiguousarray(feats, dtype=np.float32)

    if index.ntotal == 0:
        raise ValueError("KNN index is empty")
    # FAISS cannot return more neighbours than it holds (missing slots come back
    # with label -1), so clamp k to the index size.
    k = min(k, index.ntotal)

    distances, _ = index.search(feats, k)

    # Results are sorted nearest-first, so column -1 is the k-th neighbour.
    # A smaller distance means the sample is closer to the training (ID) data.
    knn_scores = -distances[:, -1]
    return knn_scores

def compute_gradnorm_score(model, dataset, features, temperature=1.0):
    """
    GradNorm score: per-sample L1 norm of the final-layer weight gradient.

    Each feature vector goes through the model's final linear layer only. The
    loss ``sum_j -log_softmax(logits / temperature)_j`` (cross-entropy against
    an all-ones target) is back-propagated, and ``sum(|dL/dW|)`` over that
    layer's weight is recorded.

    Args:
        model: model whose final linear layer is used (``model.fc`` or
            ``model.classifier`` for datasets containing 'cifar',
            ``model.model.fc`` for 'tiny_imagenet').
        dataset: dataset name used to locate that layer.
        features: iterable of per-sample feature vectors (e.g. an (N, dim)
            tensor or array). Rows are detached and moved to the layer's
            device and dtype.
        temperature: temperature applied to the logits (default 1.0).

    Returns:
        numpy array (N,) of GradNorm scores (higher is more likely to be ID).

    Raises:
        ValueError: if no FC layer can be found for ``dataset``.
    """
    import torch.nn as nn

    model.eval()

    # Locate the final linear layer. hasattr(model, 'model.fc') is always False
    # because hasattr does not follow dotted paths, so check each level.
    dataset_name = dataset or ''
    fc_layer = None
    if 'cifar' in dataset_name:
        if hasattr(model, 'fc'):
            fc_layer = model.fc
        elif hasattr(model, 'classifier'):
            fc_layer = model.classifier
    elif 'tiny_imagenet' in dataset_name:
        inner = getattr(model, 'model', None)
        if inner is not None and hasattr(inner, 'fc'):
            fc_layer = inner.fc
    if fc_layer is None:
        raise ValueError("No FC layer found in the model")

    # Run wherever the layer lives instead of assuming CUDA.
    device = fc_layer.weight.device
    dtype = fc_layer.weight.dtype
    logsoftmax = nn.LogSoftmax(dim=1)

    gradnorm_scores = []
    # Gradients are required even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        for feat in features:
            model.zero_grad()

            # detach(): only the layer weights need gradients, and backward must
            # not run through whatever graph produced the features.
            x = torch.as_tensor(feat).detach().to(device=device, dtype=dtype).unsqueeze(0)

            logits = fc_layer(x) / temperature
            targets = torch.ones_like(logits)
            loss = torch.mean(torch.sum(-targets * logsoftmax(logits), dim=1))
            loss.backward()

            # L1 norm of the weight gradient.
            grad_norm = torch.sum(torch.abs(fc_layer.weight.grad.detach())).cpu().numpy()
            gradnorm_scores.append(grad_norm)

            # Leave no stale gradients on the layer.
            fc_layer.zero_grad()

    return np.array(gradnorm_scores)

# ODIN method implementation
def compute_odin_score(model, inputs, dataset=None, temperature=1000, epsilon=0.002, batch_size=10):
    """
    Calculate the ODIN OOD detection score.

    For each batch:
      1. Take the predicted class from the unscaled logits.
      2. Back-propagate the temperature-scaled cross-entropy to the input.
      3. Subtract ``epsilon * sign(grad) / channel_scale`` from the input.
      4. Report the maximum temperature-scaled softmax probability on the
         perturbed input.

    Args:
        model: classifier; its output, or ``output.logits`` when present, is used.
        inputs: batch-first input tensor; the per-channel gradient scaling
            assumes at least 3 channels on axis 1.
        dataset: unused; kept for API compatibility.
        temperature: softmax temperature.
        epsilon: input perturbation magnitude (default 0.002).
        batch_size: number of samples per forward/backward pass (default 10).

    Returns:
        numpy array (num_samples,) of ODIN scores (higher is more likely to be ID).
    """
    model.eval()

    # Run on the model's device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else inputs.device

    criterion = torch.nn.CrossEntropyLoss()
    # Per-channel gradient scales hard-coded here, as in the reference ODIN code.
    channel_scale = (63.0 / 255.0, 62.1 / 255.0, 66.7 / 255.0)

    chunks = []
    # Input gradients are needed even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        for start in range(0, inputs.size(0), batch_size):
            batch = inputs[start:start + batch_size].clone().detach().to(device)
            batch.requires_grad_(True)

            outputs = model(batch)
            if hasattr(outputs, 'logits'):
                outputs = outputs.logits

            # Predicted class: argmax of the logits equals argmax of their softmax.
            labels = outputs.detach().argmax(dim=1)

            loss = criterion(outputs / temperature, labels)
            loss.backward()

            # Sign of the input gradient as +/-1, then scaled per channel.
            gradient = (torch.ge(batch.grad, 0).float() - 0.5) * 2
            for channel, scale in enumerate(channel_scale):
                gradient[:, channel] = gradient[:, channel] / scale

            # Step against the gradient (same as the deprecated
            # torch.add(x, -epsilon, gradient) overload).
            perturbed = batch.detach() - epsilon * gradient

            with torch.no_grad():
                perturbed_outputs = model(perturbed)
                if hasattr(perturbed_outputs, 'logits'):
                    perturbed_outputs = perturbed_outputs.logits
                probs = torch.softmax(perturbed_outputs / temperature, dim=1)
                chunks.append(probs.max(dim=1).values.cpu().numpy())

            # Drop gradients produced by this batch's backward pass.
            batch.grad = None
            model.zero_grad()

    return np.concatenate(chunks) if chunks else np.array([])

# single layer Mahalanobis initialization function
def initialize_mahalanobis(model, dataset, train_features, train_labels, num_classes=None):
    """
    Fit class-conditional Gaussians with a shared covariance for a single layer.

    Args:
        model: unused; kept for API compatibility.
        dataset: dataset name; only used to infer ``num_classes`` when it is None.
        train_features: (num_samples, dim) features, tensor or array.
        train_labels: (num_samples,) integer labels, tensor or array.
        num_classes: number of classes. When None it is 100 for 'cifar100',
            200 for 'tiny_imagenet', otherwise max label + 1.

    Returns:
        dict with
            'class_means': list of per-class mean vectors (zeros for classes
                without samples),
            'precision_matrix': inverse of the pooled within-class covariance
                (+1e-6 * I), or None when no label falls in [0, num_classes),
            'num_classes': the number of classes used.
    """
    # Convert to numpy first so the label max works for tensors and arrays alike;
    # detach() allows tensors that are part of an autograd graph.
    if torch.is_tensor(train_features):
        train_features = train_features.detach().cpu().numpy()
    else:
        train_features = np.asarray(train_features)
    if torch.is_tensor(train_labels):
        train_labels = train_labels.detach().cpu().numpy()
    else:
        train_labels = np.asarray(train_labels)

    if num_classes is None:
        dataset_name = dataset or ''
        if 'cifar100' in dataset_name:
            num_classes = 100
        elif 'tiny_imagenet' in dataset_name:
            num_classes = 200
        else:
            num_classes = int(train_labels.max()) + 1

    # One pass per class: record the mean and the class-centred samples.
    class_means = []
    centered_chunks = []
    for c in range(num_classes):
        idx = train_labels == c
        if np.any(idx):
            members = train_features[idx]
            mean = np.mean(members, axis=0)
            centered_chunks.append(members - mean)
        else:
            # No samples for this class.
            mean = np.zeros(train_features.shape[1])
        class_means.append(mean)

    if centered_chunks:
        centered_features = np.vstack(centered_chunks)
        # atleast_2d: np.cov returns a 0-d array for 1-D features.
        cov_matrix = np.atleast_2d(np.cov(centered_features, rowvar=False))
        # Small ridge for numerical stability of the inverse.
        cov_matrix = cov_matrix + np.eye(cov_matrix.shape[0]) * 1e-6
        precision_matrix = np.linalg.inv(cov_matrix)
    else:
        precision_matrix = None

    return {
        'class_means': class_means,
        'precision_matrix': precision_matrix,
        'num_classes': num_classes
    }

# layer weight learning function
def learn_mahalanobis_layer_weights(id_features_by_layer, ood_features_by_layer, mahalanobis_params, C=1.0, max_iter=1000):
    """
    Use logistic regression to learn a weight for each layer's Mahalanobis score.

    Each usable layer contributes one feature column (its Mahalanobis score for
    every sample). ID samples are labelled 1 and OOD samples 0. The absolute
    values of the fitted coefficients, floored at 1e-5 and normalised to sum
    to 1, become the layer weights.

    Args:
        id_features_by_layer: ID data layer-wise features {layer name: features} dictionary
        ood_features_by_layer: OOD data layer-wise features {layer name: features} dictionary
        mahalanobis_params: parameters returned from initialize_mahalanobis_ensemble
        C: logistic regression regularization strength (larger means weaker regularization)
        max_iter: maximum number of logistic regression iterations

    Returns:
        dict: {layer name: weight} for the usable layers. If no layer is usable,
        every entry of ``layer_names`` gets weight 1 / len(layer_names).
    """
    from sklearn.linear_model import LogisticRegressionCV

    class_means = mahalanobis_params['class_means']
    precision_matrices = mahalanobis_params['precision_matrices']
    num_classes = mahalanobis_params['num_classes']
    layer_names = mahalanobis_params['layer_names']

    # A layer is usable when both ID and OOD features exist for it and its
    # precision matrix could be estimated.
    common_layers = [
        layer for layer in layer_names
        if layer in id_features_by_layer
        and layer in ood_features_by_layer
        and precision_matrices[layer] is not None
    ]

    if not common_layers:
        print("No common layers found. Using uniform weights.")
        layer_weights = {layer: 1.0 / len(layer_names) for layer in layer_names}
        return layer_weights

    print(f"Layers used for learning: {common_layers}")

    def _score_matrix(features_by_layer):
        # One column per usable layer: the Mahalanobis score of every sample.
        return np.column_stack([
            compute_mahalanobis_score(
                features_by_layer[layer],
                class_means[layer],
                precision_matrices[layer],
                num_classes
            )
            for layer in common_layers
        ])

    id_scores = _score_matrix(id_features_by_layer)
    ood_scores = _score_matrix(ood_features_by_layer)

    # Labels are sized from the scored matrices themselves, so they always
    # match the rows of X_train: ID = 1, OOD = 0.
    X_train = np.vstack([id_scores, ood_scores])
    y_train = np.concatenate([np.ones(id_scores.shape[0]), np.zeros(ood_scores.shape[0])])

    print(f"Training data shape: {X_train.shape}, label distribution: ID={np.sum(y_train)}, OOD={len(y_train)-np.sum(y_train)}")

    logreg = LogisticRegressionCV(Cs=[C], max_iter=max_iter, n_jobs=-1, cv=5)
    logreg.fit(X_train, y_train)

    # Negative coefficients are folded to their magnitude, and a small floor
    # keeps every layer in the ensemble.
    coefficients = np.abs(logreg.coef_[0])
    coefficients = np.maximum(coefficients, 1e-5)

    total_weight = np.sum(coefficients)
    if total_weight > 0:
        coefficients = coefficients / total_weight
    else:
        coefficients = np.ones(len(common_layers)) / len(common_layers)

    layer_weights = dict(zip(common_layers, coefficients))

    print("learned layer weights:")
    for layer, weight in layer_weights.items():
        print(f"  {layer}: {weight:.4f}")

    return layer_weights

# initialize Mahalanobis for multi-layer
def initialize_mahalanobis_ensemble(model, dataset, train_features_by_layer, train_labels, num_classes=None, learn_weights=True):
    """
    Fit per-layer class-conditional Gaussians (shared covariance) for a Mahalanobis ensemble.

    Args:
        model: unused; kept for API compatibility.
        dataset: dataset name; only used to infer ``num_classes`` when it is None.
        train_features_by_layer: {layer name: (num_samples, dim) features}, tensor or array.
        train_labels: (num_samples,) integer labels, tensor or array.
        num_classes: number of classes. When None it is 100 for 'cifar100',
            200 for 'tiny_imagenet', otherwise max label + 1.
        learn_weights: not used by this function (layer weights are learned separately).

    Returns:
        dict with
            'class_means': {layer: list of per-class mean vectors (zeros for empty classes)},
            'precision_matrices': {layer: inverse pooled within-class covariance
                (+1e-6 * I), or None when no label falls in [0, num_classes)},
            'num_classes': the number of classes used,
            'layer_names': list of layer names in dict order.
    """
    # Convert labels first so the label max works for tensors and arrays alike.
    if torch.is_tensor(train_labels):
        train_labels = train_labels.detach().cpu().numpy()
    else:
        train_labels = np.asarray(train_labels)

    if num_classes is None:
        dataset_name = dataset or ''
        if 'cifar100' in dataset_name:
            num_classes = 100
        elif 'tiny_imagenet' in dataset_name:
            num_classes = 200
        else:
            num_classes = int(train_labels.max()) + 1

    # The class masks depend only on the labels, so build them once for all layers.
    class_masks = [train_labels == c for c in range(num_classes)]

    class_means_by_layer = {}
    precision_matrices_by_layer = {}

    for layer_name, train_features in train_features_by_layer.items():
        # detach() allows tensors that are part of an autograd graph.
        if torch.is_tensor(train_features):
            train_features = train_features.detach().cpu().numpy()
        else:
            train_features = np.asarray(train_features)

        # One pass per class: record the mean and the class-centred samples.
        class_means = []
        centered_chunks = []
        for mask in class_masks:
            if np.any(mask):
                members = train_features[mask]
                mean = np.mean(members, axis=0)
                centered_chunks.append(members - mean)
            else:
                mean = np.zeros(train_features.shape[1])
            class_means.append(mean)

        if centered_chunks:
            centered_features = np.vstack(centered_chunks)
            # atleast_2d: np.cov returns a 0-d array for 1-D features.
            cov_matrix = np.atleast_2d(np.cov(centered_features, rowvar=False))
            # Small ridge for numerical stability of the inverse.
            cov_matrix = cov_matrix + np.eye(cov_matrix.shape[0]) * 1e-6
            precision_matrix = np.linalg.inv(cov_matrix)
        else:
            precision_matrix = None

        class_means_by_layer[layer_name] = class_means
        precision_matrices_by_layer[layer_name] = precision_matrix

    return {
        'class_means': class_means_by_layer,
        'precision_matrices': precision_matrices_by_layer,
        'num_classes': num_classes,
        'layer_names': list(train_features_by_layer.keys())
    }

# single layer Mahalanobis score calculation function
def compute_mahalanobis_score(features, class_means, precision_matrix, num_classes):
    """
    Calculate a Mahalanobis-distance-based OOD score.

    Args:
        features: (num_samples, dim) features; a torch tensor (any device,
            possibly requiring grad) or an array-like.
        class_means: sequence of at least ``num_classes`` mean vectors of length dim.
        precision_matrix: (dim, dim) inverse covariance matrix P.
        num_classes: number of leading entries of ``class_means`` to compare against.

    Returns:
        numpy array (num_samples,): ``-min_c (x - mu_c)^T P (x - mu_c)``, the
        negated smallest squared Mahalanobis distance (higher is closer to ID).
    """
    if torch.is_tensor(features):
        # detach() so tensors that are part of an autograd graph can be converted.
        features = features.detach().cpu().numpy()
    else:
        features = np.asarray(features)

    distances = np.zeros((features.shape[0], num_classes))
    for c in range(num_classes):
        diff = features - class_means[c]
        # Row-wise quadratic form diff_i^T P diff_i.
        distances[:, c] = np.sum((diff @ precision_matrix) * diff, axis=1)

    return -np.min(distances, axis=1)

# Mahalanobis distance-based adversarial perturbation generation function
def generate_perturbed_features(model, inputs, mahalanobis_params, layer_names, magnitude=0.001, dataset=None, batch_size=100):
    """
    Generate synthetic OOD features by adversarially perturbing the inputs.

    For every batch and every layer in ``layer_names``:
      1. Extract the layer's features and pick, per sample, the class with the
         highest Gaussian score ``-0.5 * (f - mu_c)^T P (f - mu_c)``.
      2. Back-propagate the classifier's cross-entropy toward that class to the input.
      3. Subtract ``magnitude * sign(grad) / channel_std`` from the input.
      4. Re-extract that layer's features from the perturbed input.

    Tensors are moved with ``.cuda()``, so a CUDA device is required.

    Args:
        model: original model (``model.eval()`` is called); its output, or
            ``output.logits`` when present, is used as logits.
        inputs: batch-first input tensor; the per-channel gradient scaling
            assumes at least 3 channels on axis 1.
        mahalanobis_params: dict from initialize_mahalanobis_ensemble.
        layer_names: list of layer names to extract features from.
        magnitude: perturbation magnitude.
        dataset: dataset name; selects the per-channel gradient scaling. None is accepted.
        batch_size: inputs processed per pass (default 100).

    Returns:
        dict: {layer name: numpy array of perturbed features}, rows in input order.
    """
    import torch
    import torch.nn as nn
    import numpy as np
    from metrics.online_rds.feature_extraction import extract_features_by_layer

    model.eval()

    class_means = mahalanobis_params['class_means']
    precision_matrices = mahalanobis_params['precision_matrices']
    num_classes = mahalanobis_params['num_classes']

    # Per-channel scale for the sign gradient. `dataset or ''` avoids the
    # TypeError the membership tests would raise for dataset=None. Only
    # CIFAR (non-ImageNet) DenseNet models use the 63.0/62.1/66.7 scales;
    # every other case uses the 0.2023/0.1994/0.2010 values.
    dataset_name = dataset or ''
    if ('imagenet' not in dataset_name and 'cifar' in dataset_name
            and 'densenet' in str(type(model)).lower()):
        channel_std = (63.0 / 255.0, 62.1 / 255.0, 66.7 / 255.0)
    else:
        channel_std = (0.2023, 0.1994, 0.2010)

    criterion = nn.CrossEntropyLoss()

    # Per-layer Gaussian parameters as CUDA tensors, converted once per layer.
    gaussian_cache = {}

    def _gaussian_params(layer_name):
        if layer_name not in gaussian_cache:
            means = np.asarray(class_means[layer_name])[:num_classes]
            gaussian_cache[layer_name] = (
                torch.from_numpy(means).cuda().float(),
                torch.from_numpy(np.asarray(precision_matrices[layer_name])).cuda().float(),
            )
        return gaussian_cache[layer_name]

    # Collect per-batch results and concatenate once at the end.
    chunks = {layer_name: [] for layer_name in layer_names}

    for start in range(0, inputs.size(0), batch_size):
        batch = inputs[start:start + batch_size].clone().detach().cuda()
        batch.requires_grad = True

        for layer_name in layer_names:
            out_features, _ = extract_features_by_layer(model, batch, layer_name, batch_size=batch.size(0), dataset=dataset)
            feats = out_features.cuda().detach()

            # Nearest class under the layer's Gaussian model (row-wise quadratic
            # form, avoiding the B x B matrix of mm(mm(Z, P), Z^T).diag()).
            layer_means, precision = _gaussian_params(layer_name)
            gaussian_score = torch.stack([
                -0.5 * (((feats - mu) @ precision) * (feats - mu)).sum(dim=1)
                for mu in layer_means
            ], dim=1)
            sample_pred = gaussian_score.argmax(dim=1)

            # NOTE(review): the input gradient comes from the classifier's
            # cross-entropy toward the nearest Mahalanobis class, not from the
            # Mahalanobis distance itself, presumably because the features from
            # extract_features_by_layer may not stay connected to `batch`'s
            # autograd graph. Confirm before relying on this.
            new_output = model(batch)
            if hasattr(new_output, 'logits'):
                new_output = new_output.logits
            loss = criterion(new_output, sample_pred)
            loss.backward()

            # Sign of the input gradient as +/-1, scaled per channel.
            gradient = (torch.ge(batch.grad, 0).float() - 0.5) * 2
            for channel, std in enumerate(channel_std):
                gradient[:, channel] = gradient[:, channel] / std

            # Step against the gradient (same as the deprecated
            # torch.add(x, -magnitude, gradient) overload).
            perturbed_batch = batch.detach() - magnitude * gradient

            with torch.no_grad():
                perturbed_features, _ = extract_features_by_layer(model, perturbed_batch, layer_name, batch_size=perturbed_batch.size(0), dataset=dataset)
            chunks[layer_name].append(perturbed_features)

            # Reset the input gradient before the next layer.
            batch.grad = None

    # .cpu() so the conversion also works if the extractor returns CUDA tensors.
    perturbed_features_by_layer = {
        layer_name: torch.cat(parts).cpu().numpy()
        for layer_name, parts in chunks.items()
        if parts
    }

    print("completed adversarial perturbation for OOD features")
    return perturbed_features_by_layer

# multi-layer Mahalanobis ensemble score calculation function
def compute_mahalanobis_ensemble_score(features_by_layer, mahalanobis_params, layer_weights=None):
    """
    Combine per-layer Mahalanobis scores into a weighted ensemble score.

    Args:
        features_by_layer: layer-wise features {layer name: features} dictionary
        mahalanobis_params: parameters returned from initialize_mahalanobis_ensemble
        layer_weights: optional {layer name: weight}; usable layers missing from
            it get weight 1.0. Weights are renormalised over the usable layers,
            and if they sum to zero uniform weights are used instead.

    Returns:
        numpy array (num_samples,) of ensemble scores (higher is closer to ID).
        All zeros when no layer is usable; empty when ``features_by_layer`` is empty.
    """
    class_means = mahalanobis_params['class_means']
    precision_matrices = mahalanobis_params['precision_matrices']
    num_classes = mahalanobis_params['num_classes']
    layer_names = mahalanobis_params['layer_names']

    # Usable layers: present in the features and with a fitted precision matrix.
    common_layers = [
        layer for layer in layer_names
        if layer in features_by_layer and precision_matrices[layer] is not None
    ]
    if not common_layers:
        if not features_by_layer:
            return np.zeros(0)
        return np.zeros(next(iter(features_by_layer.values())).shape[0])

    if layer_weights is None:
        raw_weights = {layer: 1.0 for layer in common_layers}
    else:
        raw_weights = {layer: layer_weights.get(layer, 1.0) for layer in common_layers}

    total_weight = sum(raw_weights.values())
    if total_weight == 0:
        # Avoid dividing by zero; fall back to uniform weights.
        norm_weights = {layer: 1.0 / len(common_layers) for layer in common_layers}
    else:
        norm_weights = {layer: raw_weights[layer] / total_weight for layer in common_layers}

    # .shape[0] works for both torch tensors and numpy arrays.
    num_samples = features_by_layer[common_layers[0]].shape[0]
    ensemble_scores = np.zeros(num_samples)

    for layer in common_layers:
        layer_scores = compute_mahalanobis_score(
            features_by_layer[layer],
            class_means[layer],
            precision_matrices[layer],
            num_classes
        )
        ensemble_scores += norm_weights[layer] * layer_scores

    return ensemble_scores

