"""
Function to evaluate classification performance of adapted model and OOD detection performance of internal OnlineRDS
"""

import os
import math
import numpy as np
import torch
from sklearn import metrics
import matplotlib.pyplot as plt

from .core import OnlineRDS 
from .feature_extraction import extract_features_by_layer, extract_features 
from .thresholds import find_threshold_otsu 

from metrics.baseline_ood import initialize_vim, initialize_knn
from metrics.baseline_ood import compute_vim_score, compute_energy_score, compute_knn_score, compute_gradnorm_score
from metrics.baseline_ood import initialize_mahalanobis, initialize_mahalanobis_ensemble, learn_mahalanobis_layer_weights
from metrics.baseline_ood import compute_mahalanobis_score, compute_mahalanobis_ensemble_score, compute_odin_score
from data_tinyimagenet import load_tiny_imagenet_train
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

# Minimal transform for loading the raw training split: only ToTensor,
# with no resizing and no normalization applied at this stage.
transform_none = transforms.Compose([transforms.ToTensor()])

def compute_energy_score(logits, temperature=1.0):
    """
    Calculate the (negated) Energy-based OOD score.

    The free energy of a sample is ``E(x) = -T * logsumexp(logits / T)``. This
    function returns ``-E(x) = T * logsumexp(logits / T)``, so that higher values
    indicate in-distribution samples. That matches the AUROC direction used in
    this module, where ID samples are labelled 1.

    Note: this definition shadows the ``compute_energy_score`` imported from
    ``metrics.baseline_ood`` at the top of this file.

    Args:
        logits (torch.Tensor): model output logits; reduced over dim 1
            (presumably shape (N, num_classes))
        temperature (float): temperature parameter T (must be > 0)

    Returns:
        np.ndarray: energy score per sample (higher means ID, lower means OOD)
    """
    # The leading T keeps the score on the logit scale for any temperature.
    # At T == 1 it is a no-op.
    neg_energy = temperature * torch.logsumexp(logits / temperature, dim=1)
    return neg_energy.detach().cpu().numpy()

def get_oscr(score_ind, score_ood, pred, y_ind):
    """
    Calculate OSCR (Open Set Classification Rate).

    OSCR is the area under the curve of Correct Classification Rate (CCR)
    against False Positive Rate (FPR), obtained by sweeping a rejection
    threshold ``t`` over every observed score (higher score = more ID):

      * CCR(t) = (# ID samples with score > t that are correctly classified) / (# ID samples)
      * FPR(t) = (# OOD samples with score >= t) / (# OOD samples)

    The curve is anchored at (0, 0) and (1, 1) and integrated with the
    trapezoidal rule.

    Args:
        score_ind (np.ndarray): Scores of ID samples
        score_ood (np.ndarray): Scores of OOD samples
        pred (np.ndarray): Predicted classes for ID samples (same length as score_ind)
        y_ind (np.ndarray): True class labels of ID samples (same length as score_ind)

    Returns:
        float: OSCR score, or NaN if either score set is empty
        (the rates are undefined in that case).
    """
    score_ind = np.asarray(score_ind).ravel()
    score_ood = np.asarray(score_ood).ravel()
    n_ind, n_ood = len(score_ind), len(score_ood)
    if n_ind == 0 or n_ood == 0:
        return float('nan')

    # Every observed score is a candidate threshold, swept from high to low.
    thresholds = -np.sort(-np.concatenate((score_ind, score_ood)))

    # Count with binary search instead of rescanning per threshold (O(n log n)).
    # searchsorted(side='left') counts elements < t, so the remainder are >= t.
    sorted_ood = np.sort(score_ood)
    ood_at_or_above = n_ood - np.searchsorted(sorted_ood, thresholds, side='left')
    # searchsorted(side='right') counts elements <= t, so the remainder are > t.
    correct_scores = np.sort(score_ind[np.asarray(pred) == np.asarray(y_ind)])
    correct_above = len(correct_scores) - np.searchsorted(correct_scores, thresholds, side='right')

    # Both rates are non-decreasing as the threshold drops, so this order is
    # already sorted along the curve; anchor it at (0, 0) and (1, 1).
    fpr = np.concatenate(([0.0], ood_at_or_above / n_ood, [1.0]))
    ccr = np.concatenate(([0.0], correct_above / n_ind, [1.0]))

    # Trapezoidal rule over consecutive curve points.
    return float(np.sum(np.diff(fpr) * (ccr[:-1] + ccr[1:]) / 2.0))

# Methods evaluated when ``target_methods`` is not given (also the set of valid names).
_DEFAULT_TARGET_METHODS = (
    'RDS', 'Energy', 'MSP', 'Max_logit', 'Entropy', 'GradNorm', 'ViM', 'KNN',
    'Mahalanobis_single', 'Mahalanobis_ensemble', 'ODIN',
)
# Methods whose initialization needs features/logits of the ID training split.
_TRAIN_FEATURE_METHODS = ('KNN', 'Mahalanobis_single', 'Mahalanobis_ensemble', 'ViM')
# Architectures whose inputs are upsampled to 224x224 before the forward pass.
_RESIZE_224_ARCHS = ('vit-tiny', 'swin-tiny')


def _fpr_at_95_tpr(labels, scores):
    """Return the FPR at the ROC point whose TPR is closest to 0.95 (label 1 = ID = positive)."""
    fpr, tpr, _ = metrics.roc_curve(labels, scores)
    return fpr[np.argmin(np.abs(tpr - 0.95))]


def _load_train_split(args, dataset):
    """Load the raw (not resized) ID training images and labels as an (images, labels) pair."""
    if dataset == 'cifar100':
        trainset = torchvision.datasets.CIFAR100(root=args.data_dir, train=True, download=True, transform=transform_none)
        # One batch spanning the whole split yields a single stacked image tensor.
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=len(trainset), shuffle=False, pin_memory=True, num_workers=2)
        return next(iter(train_loader))
    if dataset == 'tiny_imagenet':
        return load_tiny_imagenet_train(data_dir=args.data_dir)
    raise ValueError(f"Unknown dataset: {dataset}")


def _resize_for_arch(x, arch, chunk_size=50):
    """Upsample images to 224x224 for transformer backbones; return others unchanged.

    Interpolation runs chunk by chunk to bound peak memory.
    """
    if arch not in _RESIZE_224_ARCHS:
        return x
    num_chunks = math.ceil(x.size(0) / chunk_size)
    resized = []
    for i in tqdm(range(num_chunks), desc="Resizing training images"):
        chunk = x[i * chunk_size:(i + 1) * chunk_size]
        resized.append(torch.nn.functional.interpolate(chunk, size=(224, 224), mode='bilinear'))
        if i % 10 == 0:
            torch.cuda.empty_cache()
    return torch.cat(resized, dim=0)


def evaluate_ood_scores(args, model, x_ind_all, x_ood_all, y_ind_all=None, dataset='cifar100', 
                        rds_confidence_threshold=0.9, rds_iqr_factor=1.5, rds_ema_alpha=0.9,
                        auto_correction=False, init_methods=None, target_methods=None, 
                        layer_list=None, temperature=5.0, flip_weight=1.2, batch_size=100, num_batches=10, save_dir='./ood_scores_results'):
    """
    Calculate and compare multiple OOD detection scores on mixed ID/OOD batches.

    Supported methods: RDS, Energy, MSP, Max_logit, Entropy, GradNorm, ViM, KNN,
    Mahalanobis_single, Mahalanobis_ensemble and ODIN.

    Each batch concatenates an ID slice and an OOD slice. The batch is scored
    and then evaluated with AUROC and FPR@95TPR, treating ID as the positive
    class. When ``y_ind_all`` is given, top-1 accuracy and OSCR are also
    reported, with OSCR using the Energy score for rejection.

    Training-split features and detector parameters are cached under
    ``{args.data_dir}/assets``. File names include ``dataset`` and, for
    model-dependent files, ``args.arch``.

    Args:
        args: run configuration; ``args.data_dir`` and ``args.arch`` are read
        model: model to evaluate (put into eval mode here)
        x_ind_all (torch.Tensor): all ID data
        x_ood_all (torch.Tensor): all OOD data
        y_ind_all (torch.Tensor, optional): labels for ID data (enables accuracy/OSCR)
        dataset (str): dataset name ('cifar100' or 'tiny_imagenet')
        rds_confidence_threshold (float): confidence threshold for RDS
        rds_iqr_factor (float): IQR outlier removal factor for RDS
        rds_ema_alpha (float): EMA alpha value for RDS
        auto_correction (bool): whether RDS uses automatic correction
        init_methods (list, optional): RDS initialization methods (default ['msp'])
        target_methods (list, optional): methods to evaluate (default: all supported)
        layer_list (list, optional): layers used by RDS and the Mahalanobis ensemble
            (default: the architecture's penultimate layer)
        temperature (float): temperature passed to OnlineRDS
        flip_weight (float): flip weight passed to OnlineRDS
        batch_size (int): maximum number of ID (and of OOD) samples per batch
        num_batches (int): number of batches to evaluate
        save_dir (str): directory where summary_results.txt is written

    Returns:
        dict: ``{'ood_detection': {method: {...}}}``, plus 'classification' and
        'oscr' entries when ``y_ind_all`` is provided.

    Raises:
        ValueError: unknown dataset, architecture or method, or no batch
            contained both ID and OOD samples.
    """
    # Fresh lists per call (no shared mutable defaults).
    if init_methods is None:
        init_methods = ['msp']
    if target_methods is None:
        target_methods = list(_DEFAULT_TARGET_METHODS)
    # Fail fast on unknown names instead of after the expensive initialization.
    for method in target_methods:
        if method not in _DEFAULT_TARGET_METHODS:
            raise ValueError(f"Unknown method: {method}")

    # Create assets directory for storing cache files
    assets_dir = f'{args.data_dir}/assets'
    if not os.path.exists(assets_dir):
        os.makedirs(assets_dir)
        print(f"Created directory: {assets_dir} for caching OOD detector parameters")

    # Architecture-specific penultimate layer used by the single-layer baselines.
    if dataset == 'cifar100':
        if args.arch == 'Hendrycks2020AugMix_WRN':
            penultimate_layer = 'block3'
        elif args.arch == 'vit-tiny':
            penultimate_layer = 'vit.encoder.layer.11'
        elif args.arch == 'swin-tiny':
            penultimate_layer = 'swin.encoder.layers.3.blocks.1'
        else:
            raise ValueError(f"Unknown architecture: {args.arch}")
        num_classes = 100
    elif dataset == 'tiny_imagenet':
        penultimate_layer = 'layer4'
        num_classes = 200
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    # The default is the penultimate layer: 'block3' for the WRN and 'layer4'
    # for tiny_imagenet. This is also valid for the ViT/Swin layer names.
    if layer_list is None:
        layer_list = [penultimate_layer]
    # Layers extracted per batch: layer_list, plus the penultimate layer
    # (needed by GradNorm/ViM/KNN/Mahalanobis_single) when it is not listed.
    feature_layers = list(layer_list)
    if penultimate_layer not in feature_layers:
        feature_layers.append(penultimate_layer)
    # The ensemble drops the last listed layer unless only one is given.
    mahalanobis_layer_list = layer_list[:-1] if len(layer_list) > 1 else list(layer_list)

    os.makedirs(save_dir, exist_ok=True)

    # Setup for batch evaluation: samples per batch for each side.
    n_ind = x_ind_all.size(0)
    n_ood = x_ood_all.size(0)
    batch_ind = min(math.ceil(n_ind / num_batches), batch_size)
    batch_ood = min(math.ceil(n_ood / num_batches), batch_size)

    results = {
        method: {'batch_aurocs': [], 'batch_fprs': [], 'all_scores': []}
        for method in target_methods
    }
    all_ood_labels = []
    accuracies = []
    all_predictions = []
    all_true_labels = []
    oscr_values = []

    model.eval()
    # =====================================
    # 1. Initialize each method (performed only once for the entire dataset)
    # =====================================
    print("Initializing OOD detection methods...")
    if 'RDS' in target_methods:
        online_rds_layer_dict = {}
        for layer in layer_list:
            online_rds_layer_dict[layer] = OnlineRDS(
                layer_name=layer,
                confidence_threshold=rds_confidence_threshold,
                iqr_factor=rds_iqr_factor,
                ema_alpha=rds_ema_alpha,
                space='feature',
                auto_correction=auto_correction,
                init_methods=init_methods,
                temperature=temperature,
                flip_weight=flip_weight
            )

    # Raw training images are loaded lazily: this is expensive, and unnecessary
    # when cached features and layer weights exist.
    x_train_raw = None
    y_ind_train = None
    if any(m in target_methods for m in _TRAIN_FEATURE_METHODS):
        features_path = f'{assets_dir}/train_features_{dataset}_{args.arch}.pkl'
        logits_path = f'{assets_dir}/train_logits_{dataset}_{args.arch}.pkl'
        labels_path = f'{assets_dir}/train_labels_{dataset}.pkl'
        train_features_by_layer = None
        if all(os.path.exists(p) for p in (features_path, logits_path, labels_path)):
            print(f"Loading precomputed features and logits...")
            train_features_by_layer = torch.load(features_path)
            missing_layers = [layer for layer in feature_layers if layer not in train_features_by_layer]
            if missing_layers:
                # The cache key does not encode the layer list, so re-extract if it is stale.
                print(f"Cached training features lack layers {missing_layers}; re-extracting...")
                train_features_by_layer = None
            else:
                train_logits = torch.load(logits_path)
                y_ind_train = torch.load(labels_path)
        if train_features_by_layer is None:
            print(f"Extracting features from training samples for initialization...")
            x_train_raw, y_ind_train = _load_train_split(args, dataset)
            x_ind_train = _resize_for_arch(x_train_raw, args.arch)
            train_features_by_layer = {}
            for layer in tqdm(feature_layers, desc="Extracting features"):
                train_features_by_layer[layer], train_logits = extract_features_by_layer(
                    model, x_ind_train, layer, batch_size=500, dataset=dataset
                )
            del x_ind_train  # drop the (possibly 224x224) resized copy early
            torch.save(train_features_by_layer, features_path)
            torch.save(train_logits, logits_path)
            torch.save(y_ind_train, labels_path)

    if 'ViM' in target_methods:
        vim_cache_path = f'{assets_dir}/vim_params_{dataset}_{args.arch}.pkl'
        if os.path.exists(vim_cache_path):
            print(f"Loading precomputed ViM parameters from {vim_cache_path}...")
            vim_params = torch.load(vim_cache_path)
        else:
            print("Initializing ViM...")
            vim_params = initialize_vim(model, dataset, train_features_by_layer[penultimate_layer], train_logits)
            torch.save(vim_params, vim_cache_path)
            print(f"ViM parameters saved to {vim_cache_path}")

    # The Faiss index cannot be pickled, so KNN is always re-initialized from the features.
    if 'KNN' in target_methods:
        print("Initializing KNN...")
        knn_params = initialize_knn(train_features_by_layer[penultimate_layer])

    if 'Mahalanobis_single' in target_methods:
        mahalanobis_cache_path = f'{assets_dir}/mahalanobis_params_{dataset}_{args.arch}.pkl'
        if os.path.exists(mahalanobis_cache_path):
            print(f"Loading precomputed single-layer Mahalanobis parameters from {mahalanobis_cache_path}...")
            mahalanobis_params = torch.load(mahalanobis_cache_path)
        else:
            print("Initializing single-layer Mahalanobis method...")
            mahalanobis_params = initialize_mahalanobis(
                model, dataset, train_features_by_layer[penultimate_layer], y_ind_train, num_classes
            )
            torch.save(mahalanobis_params, mahalanobis_cache_path)
            print(f"Single-layer Mahalanobis parameters saved to {mahalanobis_cache_path}")

    if 'Mahalanobis_ensemble' in target_methods:
        # NOTE(review): this cache (and the layer-weight cache) is keyed by dataset/arch
        # only; delete it when changing layer_list.
        mahalanobis_ensemble_cache_path = f'{assets_dir}/mahalanobis_ensemble_params_{dataset}_{args.arch}.pkl'
        if os.path.exists(mahalanobis_ensemble_cache_path):
            print(f"Loading precomputed multi-layer Mahalanobis ensemble parameters from {mahalanobis_ensemble_cache_path}...")
            mahalanobis_ensemble_params = torch.load(mahalanobis_ensemble_cache_path)
        else:
            print("Initializing multi-layer Mahalanobis ensemble method...")
            mahalanobis_ensemble_params = initialize_mahalanobis_ensemble(
                model, dataset, {layer: train_features_by_layer[layer] for layer in layer_list},
                y_ind_train, num_classes,
                learn_weights=False  # weights are learned separately below
            )
            torch.save(mahalanobis_ensemble_params, mahalanobis_ensemble_cache_path)
            print(f"Multi-layer Mahalanobis ensemble parameters saved to {mahalanobis_ensemble_cache_path}")

        print("Generating perturbed features for Mahalanobis weight learning...")
        # Size of the training subset used as ID validation data.
        val_sample_size = 2000 if dataset == 'tiny_imagenet' else 1000
        # Drawn unconditionally so the global NumPy RNG state afterwards does not
        # depend on whether the weights were cached.
        indices = np.random.choice(len(y_ind_train), val_sample_size, replace=False)

        layer_weights_cache_path = f'{assets_dir}/mahalanobis_layer_weights_{dataset}_{args.arch}.pkl'
        if os.path.exists(layer_weights_cache_path):
            print(f"Loading precomputed Mahalanobis layer weights from {layer_weights_cache_path}...")
            layer_weights = torch.load(layer_weights_cache_path)
        else:
            from metrics.baseline_ood import generate_perturbed_features
            val_features_by_layer = {
                layer: train_features_by_layer[layer][indices] for layer in mahalanobis_layer_list
            }
            if x_train_raw is None:
                x_train_raw, _ = _load_train_split(args, dataset)
            # Only the sampled subset needs resizing, not the whole split.
            val_inputs = _resize_for_arch(x_train_raw[indices], args.arch).cuda()
            print("Generating perturbed features using FGSM...")
            perturbed_features_by_layer = generate_perturbed_features(
                model=model,
                inputs=val_inputs,
                mahalanobis_params=mahalanobis_ensemble_params,
                layer_names=mahalanobis_layer_list,
                magnitude=0.001,  # perturbation magnitude
                dataset=dataset  # passed through for dataset-specific normalization
            )
            print("Learning Mahalanobis layer weights...")
            layer_weights = learn_mahalanobis_layer_weights(
                val_features_by_layer,  # ID features
                perturbed_features_by_layer,  # synthetic OOD features
                mahalanobis_ensemble_params,
                C=1.0,  # regularization strength
            )
            torch.save(layer_weights, layer_weights_cache_path)
            print(f"Mahalanobis layer weights saved to {layer_weights_cache_path}")

        print(f"mahalanobis layer weights: {layer_weights}")
        mahalanobis_ensemble_params['layer_weights'] = layer_weights

    print("Methods initialization completed.")
    # =====================================
    # 2. Batch evaluation
    # =====================================
    # Softmax/MSP feed MSP itself, Entropy, and RDS (as its flip-detection score).
    need_softmax = any(m in target_methods for m in ('MSP', 'Entropy', 'RDS'))
    for batch_idx in range(num_batches):
        print(f"Processing batch {batch_idx+1}/{num_batches}...")

        start_ind = batch_idx * batch_ind
        end_ind = min(start_ind + batch_ind, n_ind)
        start_ood = batch_idx * batch_ood
        end_ood = min(start_ood + batch_ood, n_ood)
        n_ind_batch = max(end_ind - start_ind, 0)
        n_ood_batch = max(end_ood - start_ood, 0)
        # AUROC/FPR are undefined unless both classes are present.
        if n_ind_batch == 0 or n_ood_batch == 0:
            print(f"Skipping batch {batch_idx+1}/{num_batches}: "
                  f"{n_ind_batch} ID / {n_ood_batch} OOD samples")
            continue

        x_combined = torch.cat([x_ind_all[start_ind:end_ind], x_ood_all[start_ood:end_ood]], dim=0)
        # ID = 1 (positive class), OOD = 0.
        true_ood_labels = np.concatenate([np.ones(n_ind_batch), np.zeros(n_ood_batch)])

        features_by_layer = {}
        for layer in feature_layers:
            features_by_layer[layer], logits = extract_features_by_layer(
                model, x_combined, layer, batch_size=x_combined.size(0), dataset=dataset
            )
        features = features_by_layer[penultimate_layer]

        batch_scores = {}
        # 1. Energy: always computed because OSCR uses it as the rejection score.
        energy_scores = compute_energy_score(logits)
        batch_scores['Energy'] = energy_scores

        # 2. MSP
        if need_softmax:
            softmax_probs = torch.softmax(logits, dim=1)
            msp_scores = torch.max(softmax_probs, dim=1)[0].cpu().numpy()
            batch_scores['MSP'] = msp_scores

        # 3. Entropy, sign inverted so that higher means ID
        if 'Entropy' in target_methods:
            log_softmax = torch.log_softmax(logits, dim=1)
            entropy = -torch.sum(softmax_probs * log_softmax, dim=1).cpu().numpy()
            batch_scores['Entropy'] = -entropy

        if 'Max_logit' in target_methods:
            batch_scores['Max_logit'] = torch.max(logits, dim=1)[0].cpu().numpy()

        # 4. GradNorm
        if 'GradNorm' in target_methods:
            gradnorm_scores = compute_gradnorm_score(model, dataset, features)
            if torch.is_tensor(gradnorm_scores):
                gradnorm_scores = gradnorm_scores.detach().cpu().numpy()
            batch_scores['GradNorm'] = gradnorm_scores

        # 5. ViM
        if 'ViM' in target_methods:
            batch_scores['ViM'] = compute_vim_score(features, logits, vim_params)

        # 6. KNN, with a larger k for tiny_imagenet
        if 'KNN' in target_methods:
            k = 50 if 'cifar' in dataset else 200
            batch_scores['KNN'] = compute_knn_score(features, knn_params, k)

        # 7. Mahalanobis (single layer)
        if 'Mahalanobis_single' in target_methods:
            batch_scores['Mahalanobis_single'] = compute_mahalanobis_score(
                features,
                mahalanobis_params['class_means'],
                mahalanobis_params['precision_matrix'],
                mahalanobis_params['num_classes']
            )

        # 8. Mahalanobis ensemble (multi-layer) with the learned layer weights
        if 'Mahalanobis_ensemble' in target_methods:
            print("Using learned Mahalanobis layer weights")
            batch_scores['Mahalanobis_ensemble'] = compute_mahalanobis_ensemble_score(
                {layer: features_by_layer[layer] for layer in mahalanobis_layer_list},
                mahalanobis_ensemble_params,
                layer_weights=mahalanobis_ensemble_params['layer_weights']
            )

        # 9. ODIN
        if 'ODIN' in target_methods:
            batch_scores['ODIN'] = compute_odin_score(model, x_combined, dataset=dataset, temperature=1000)

        # 10. RDS: mean over layers of the per-layer OnlineRDS scores.
        # update=True presumably updates each detector's online statistics.
        if 'RDS' in target_methods:
            layer_scores_list = []
            for layer in layer_list:
                _, layer_scores, _, _ = online_rds_layer_dict[layer](
                    features_by_layer[layer], logits=logits, update=True, scores_flipDet=msp_scores
                )
                layer_scores_list.append(layer_scores)
            # [num_layers, batch] -> [batch]
            batch_scores['RDS'] = np.mean(np.array(layer_scores_list), axis=0)

        for method_name in target_methods:
            scores = batch_scores[method_name]
            results[method_name]['batch_aurocs'].append(metrics.roc_auc_score(true_ood_labels, scores))
            results[method_name]['batch_fprs'].append(_fpr_at_95_tpr(true_ood_labels, scores))
            results[method_name]['all_scores'].append(scores)
        all_ood_labels.append(true_ood_labels)

        # Classification performance on the ID part of the batch
        if y_ind_all is not None:
            id_predictions = torch.argmax(logits[:n_ind_batch], dim=1).cpu().numpy()
            id_true_labels = y_ind_all[start_ind:end_ind].cpu().numpy()
            batch_accuracy = (id_predictions == id_true_labels).mean()
            accuracies.append(batch_accuracy)
            all_predictions.append(id_predictions)
            all_true_labels.append(id_true_labels)
            batch_oscr = get_oscr(energy_scores[:n_ind_batch], energy_scores[n_ind_batch:],
                                  id_predictions, id_true_labels)
            oscr_values.append(batch_oscr)

        print(f"Batch {batch_idx+1}/{num_batches}:")
        for method_name in results:
            print(f"  {method_name}: AUROC={results[method_name]['batch_aurocs'][-1]:.4f}, "
                  f"FPR@95TPR={results[method_name]['batch_fprs'][-1]:.4f}")
        if y_ind_all is not None:
            print(f"  Accuracy={batch_accuracy:.4f}, OSCR={batch_oscr:.4f}")

    # =====================================
    # 3. Aggregate and save final results
    # =====================================
    if not all_ood_labels:
        raise ValueError("No batch contained both ID and OOD samples; nothing to evaluate.")
    all_labels = np.concatenate(all_ood_labels)

    final_results = {}
    for method_name, method_results in results.items():
        all_scores = np.concatenate(method_results['all_scores'])
        final_results[method_name] = {
            'batch_aurocs': method_results['batch_aurocs'],
            'batch_fprs': method_results['batch_fprs'],
            'final_auroc': metrics.roc_auc_score(all_labels, all_scores),
            'final_fpr': _fpr_at_95_tpr(all_labels, all_scores),
            'avg_auroc': np.mean(method_results['batch_aurocs']),
            'avg_fpr': np.mean(method_results['batch_fprs'])
        }

    if y_ind_all is not None:
        all_preds_concat = np.concatenate(all_predictions)
        all_true_concat = np.concatenate(all_true_labels)
        final_results['classification'] = {
            'batch_accuracies': accuracies,
            'final_accuracy': (all_preds_concat == all_true_concat).mean(),
            'avg_accuracy': np.mean(accuracies)
        }
        final_results['oscr'] = {
            'batch_values': oscr_values,
            'final_oscr': np.mean(oscr_values),  # simple average of per-batch OSCR
            'avg_oscr': np.mean(oscr_values)
        }

    with open(os.path.join(save_dir, 'summary_results.txt'), 'w') as f:
        f.write("OOD Detection Methods Summary Results\n")
        f.write(f"Dataset: {dataset}, Corruption: {os.path.basename(save_dir)}\n\n")
        for method_name in sorted(results.keys()):
            f.write(f"{method_name}:\n")
            f.write(f"  Final AUROC: {final_results[method_name]['final_auroc']:.4f}\n")
            f.write(f"  Final FPR@95TPR: {final_results[method_name]['final_fpr']:.4f}\n")
            f.write(f"  Average batch AUROC: {final_results[method_name]['avg_auroc']:.4f}\n")
            f.write(f"  Average batch FPR@95TPR: {final_results[method_name]['avg_fpr']:.4f}\n\n")
        if y_ind_all is not None:
            f.write("Classification Performance:\n")
            f.write(f"  Final Accuracy: {final_results['classification']['final_accuracy']:.4f}\n")
            f.write(f"  Average batch Accuracy: {final_results['classification']['avg_accuracy']:.4f}\n\n")
            f.write("OSCR Performance:\n")
            f.write(f"  Final OSCR: {final_results['oscr']['final_oscr']:.4f}\n")
            f.write(f"  Average batch OSCR: {final_results['oscr']['avg_oscr']:.4f}\n")

    # Structure final results to match evaluate_adapted_model format
    final_total_results = {
        'ood_detection': {
            method_name: {
                'batch_aurocs': results[method_name]['batch_aurocs'],
                'batch_fprs': results[method_name]['batch_fprs'],
                'all_scores': results[method_name]['all_scores'],
                'final_auroc': final_results[method_name]['final_auroc'],
                'final_fpr': final_results[method_name]['final_fpr'],
                'avg_auroc': final_results[method_name]['avg_auroc'],
                'avg_fpr': final_results[method_name]['avg_fpr']
            }
            for method_name in results
        }
    }
    if y_ind_all is not None:
        final_total_results['classification'] = final_results['classification']
        final_total_results['oscr'] = final_results['oscr']

    print(f"\nEvaluation completed. Results saved to {save_dir}")

    return final_total_results
