# Unified PyTorch training, evaluation, and result saving
import os
import gc
import json
import time
import numpy as np
import pandas as pd
from collections import deque

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch.amp import autocast, GradScaler
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import balanced_accuracy_score, f1_score, log_loss, roc_auc_score

# Fallback hyperparameter values, one entry per supported setting.
DEFAULT_HP = dict(
    # Optimiser and training loop
    learning_rate=1e-4,
    epochs=100,
    batch_size=128,
    patience=10,
    # Segmental mode: size of the window of past models, and whether to use it
    window_size=5,
    use_attention=True,
    # Augmentation settings
    augmentation=None,
    aug_magnitude=0.0,
    # Mixed precision and gradient clipping (None disables clipping)
    use_amp=False,
    max_grad_norm=None,
    # Weight decay
    use_weight_decay=True,
    weight_decay=1e-5,
    # Further model/training switches
    use_swa=False,
    use_batch_norm=False,
    use_dropout=False,
    dropout_rate=0.0,
    use_skip=False,
    skip_type=None,
)


class Experiment:
    """Train and evaluate one model on one task, segment by segment, and save the metrics.

    With ``mode == "cumulative"`` each segment is fetched with
    ``loader.get_segment(seg, cumulative=True)``; any other mode runs the
    segmental variant (``cumulative=False``), which also keeps a bounded window
    of models from earlier segments. Training and test-set inference of every
    segment are each wrapped in ``monitor.start_monitoring`` /
    ``monitor.stop_monitoring`` so their energy (``"Total_energy_joules"`` from
    ``monitor.get_summary``) and wall-clock time are recorded per segment.

    Two model families are supported: keys in ``_ARRAY_API_MODELS`` are driven
    through ``fit`` / ``predict`` / ``predict_proba`` on NumPy arrays; any other
    model is trained as a torch module with an Adam/AdamW loop and early stopping.
    """

    # Model keys whose wrappers are driven through fit(X_train, y_train, X_val, y_val),
    # predict() and predict_proba() on NumPy arrays instead of the torch training loop.
    _ARRAY_API_MODELS = frozenset({
        "tabm", "tabr", "modernnca", "realmlp", "tabpfnv2", "tabpfn",
        "xgboost", "lightgbm", "catboost", "linearmodel", "knn", "svm",
        "decisiontree", "randomforest", "saint", "resnet", "fttransformer",
        "danet", "stg", "tabnet", "node", "vime",
    })

    # Added to the log loss before inverting it in the log-loss NetScore-T; shared by
    # both modes so their scores are directly comparable.
    _NETSCORE_LOGLOSS_EPS = 1e-7

    def __init__(self, loader, model, monitor, mode, hyperparams, device, model_key, task_id, tuning_applied, seed, ablation_output_dir=None):
        """Store the experiment configuration and move the model to ``device``.

        Args:
            loader: data source exposing ``num_segments``, ``get_test()`` and
                ``get_segment(seg, cumulative=...)``.
            model: array-API wrapper or torch module (see ``_ARRAY_API_MODELS``).
            monitor: energy monitor with ``start_monitoring``, ``stop_monitoring``
                and ``get_summary``.
            mode: ``"cumulative"``; any other value selects segmental training.
            hyperparams: mapping of hyperparameters; keys it omits are taken from
                ``DEFAULT_HP``.
            device: anything accepted by ``torch.device``.
            model_key: model identifier, used for dispatch and output paths.
            task_id: task identifier, used in output paths.
            tuning_applied: whether tuned hyperparameters are used (file suffix).
            seed: random seed, recorded in monitor task ids and result files.
            ablation_output_dir: optional root directory for ablation outputs.
        """
        self.loader = loader
        self.model = model
        self.monitor = monitor
        self.mode = mode
        # Merge over DEFAULT_HP so lookups such as self.hp["batch_size"] cannot raise
        # KeyError when the caller omits a setting.
        self.hp = {**DEFAULT_HP, **(hyperparams or {})}
        self.device = torch.device(device)
        self.model_key = model_key
        self.task_id = task_id
        self.tuning_applied = tuning_applied
        self.seed = seed
        self.tuning_suffix = "_tuned" if tuning_applied else "_default"
        self.ablation_output_dir = ablation_output_dir
        # Placeholder metrics recorded when _train_and_eval skips fitting the model.
        self._skip_metrics = None

        if hasattr(self.model, "to"):
            self.model.to(self.device)

        if ablation_output_dir:
            # Use the caller-supplied values ('NA' when absent), not the merged
            # defaults, so existing directory names stay the same.
            raw_hp = hyperparams or {}
            self.results_dir = os.path.join(
                ablation_output_dir, "results",
                f"ws{raw_hp.get('window_size', 'NA')}_att{raw_hp.get('use_attention', 'NA')}",
                model_key, str(task_id))
        else:
            self.results_dir = os.path.join("results", model_key, str(task_id))

    def run(self):
        """Run the experiment in the configured mode and save the results to JSON."""
        if hasattr(self.model, "to"):
            self.model.to(self.device)
        if self.mode == "cumulative":
            results = self._run_cumulative()
        else:
            results = self._run_segmental()
        self._save_results(results)

    # ------------------------------------------------------------------ run modes

    def _run_cumulative(self):
        """Train on cumulatively grown segments; return per-segment rows plus a summary row."""
        test_loader = self._make_test_loader()
        results = []
        self._print_run_header()

        for seg in range(self.loader.num_segments):
            print(f"\n--- Processing Segment {seg+1}/{self.loader.num_segments} ---")
            data = self.loader.get_segment(seg, cumulative=True)
            train_loader, val_loader = self._make_loaders(data)
            self._print_split_sizes(data)
            results.append(self._run_measured_segment(
                seg, data, train_loader, val_loader, test_loader, is_segmental=False))

        self._summarize_results(results)
        return results

    def _run_segmental(self):
        """Train on each segment alone; return per-segment rows plus a summary row."""
        test_loader = self._make_test_loader()
        # Models from the most recent `window_size` segments.
        buffer = deque(maxlen=self.hp["window_size"])
        results = []
        self._print_run_header()

        for seg in range(self.loader.num_segments):
            print(f"\n--- Processing Segment {seg+1}/{self.loader.num_segments} ---")
            data = self.loader.get_segment(seg, cumulative=False)
            train_loader, val_loader = self._make_loaders(data, include_index=True)
            self._print_split_sizes(data)

            if buffer and self.hp["use_attention"]:
                # NOTE(review): these features are computed but training and inference
                # below are called with precomputed=None, so the attention context is
                # never used. Confirm this is deliberate before wiring it through.
                self._precompute(buffer, train_loader)

            results.append(self._run_measured_segment(
                seg, data, train_loader, val_loader, test_loader, is_segmental=True))

            # NOTE(review): this stores a reference to the same model object on every
            # segment (not a snapshot), so all buffer entries alias the current model.
            buffer.append(self.model)

        self._summarize_results(results)
        return results

    def _make_test_loader(self):
        """Build an unshuffled DataLoader over the loader's held-out test split."""
        X_test, y_test = self.loader.get_test()
        return DataLoader(
            TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test).long()),
            batch_size=self.hp["batch_size"],
        )

    def _print_run_header(self):
        """Print the model, mode, task and hyperparameters of this run."""
        print(f"\n=== Running {self.model_key.upper()} in {self.mode} mode ===")
        print(f"Task ID: {self.task_id}, Number of segments: {self.loader.num_segments}")
        print(f"Using hyperparameters: {self.hp}")

    @staticmethod
    def _print_split_sizes(data):
        """Print the sizes of the train/val/test arrays in a segment dict."""
        print(f"Train set size: {len(data['X_train'])}, "
              f"Val set size: {len(data['X_val'])}, "
              f"Test set size: {len(data['X_test'])}")

    def _measure(self, task_id, fn):
        """Run ``fn()`` under the energy monitor; return ``(result, seconds, joules)``.

        The energy is ``"Total_energy_joules"`` from ``monitor.get_summary`` (NaN when absent).
        """
        started = time.time()
        self.monitor.start_monitoring(task_id=task_id)
        value = fn()
        df = self.monitor.stop_monitoring(save=True)
        elapsed = time.time() - started
        energy = self.monitor.get_summary(df).get("Total_energy_joules", np.nan)
        return value, elapsed, energy

    def _run_measured_segment(self, seg, data, train_loader, val_loader, test_loader,
                              is_segmental, precomputed=None):
        """Train, then test, one segment under separate energy measurements.

        Returns the result row for the segment: its test metrics plus
        training/inference/total energy (J) and time (s).
        """
        energy_task_id = f"{self.tuning_suffix[1:]}_seed{self.seed}_seg{seg}"
        start = time.time()

        # NOTE(review): _train_and_eval also evaluates the test set, so that pass is
        # counted in the training energy as well as in the separate inference phase.
        stats, training_duration, training_energy = self._measure(
            f"{energy_task_id}_training",
            lambda: self._train_and_eval(train_loader, val_loader, test_loader,
                                         is_segmental=is_segmental, precomputed=precomputed))
        test_metrics, inference_duration, inference_energy = self._measure(
            f"{energy_task_id}_inference",
            lambda: self._evaluate_test_only(test_loader, is_segmental=is_segmental,
                                             precomputed=precomputed))

        duration = time.time() - start
        total_energy = training_energy + inference_energy
        result = {
            "segment": seg,
            "instances": len(data['X_train']),
            **test_metrics,
            "training_energy_joules": training_energy,
            "inference_energy_joules": inference_energy,
            "total_energy_joules": total_energy,
            "training_time_seconds": training_duration,
            "inference_time_seconds": inference_duration,
            "total_time_seconds": duration,
        }

        print(f"Segment {seg+1} completed in {duration:.2f}s")
        print(f"Test Accuracy: {stats['test']['balanced_accuracy']:.4f}")
        print(f"Test F1 Weighted: {stats['test']['f1_weighted']:.4f}")
        print(f"Test Log Loss: {stats['test'].get('log_loss', np.nan):.4f}")
        print(f"Test AUC-ROC: {stats['test'].get('auc_roc', np.nan):.4f}")
        print(f"Energy Consumption: {total_energy:.2f} J")
        return result

    def _summarize_results(self, results):
        """Append a run-level summary row to ``results`` (in place) and print it.

        NetScore-T is averaged over every segment row, with ``E`` the segment's
        total energy in joules:
          * balanced variant: ``balanced_accuracy / log10(E + 1)``
          * log-loss variant: ``(1 / (log_loss + eps)) / log10(E + 1)``
        Nothing is appended when ``results`` is empty.
        """
        if not results:
            print(f"\n=== No segments were run for {self.model_key.upper()} ===")
            return

        netscore_balanced, netscore_logloss = [], []
        # Called before the summary row is appended, so every row here is a segment.
        for row in results:
            energy_scale = np.log10(row["total_energy_joules"] + 1)
            netscore_balanced.append(row["balanced_accuracy"] / energy_scale)
            inverse_logloss = 1 / (row["log_loss"] + self._NETSCORE_LOGLOSS_EPS)
            netscore_logloss.append(inverse_logloss / energy_scale)
        avg_netscore_bal = np.mean(netscore_balanced)
        avg_netscore_log = np.mean(netscore_logloss)

        total_time = sum(r["total_time_seconds"] for r in results)
        total_energy_training = sum(r["training_energy_joules"] for r in results)
        total_energy_inference = sum(r["inference_energy_joules"] for r in results)
        total_energy = sum(r["total_energy_joules"] for r in results)
        final_stats = results[-1]

        results.append({
            "Total_time_seconds": total_time,
            "Total_energy_training_joules": total_energy_training,
            "Total_energy_inference_joules": total_energy_inference,
            "Total_energy_joules": total_energy,
            "Final_balanced_acc": final_stats["balanced_accuracy"],
            "Final_log_loss": final_stats["log_loss"],
            "NetScore-T_balanced": avg_netscore_bal,
            "NetScore-T_logloss": avg_netscore_log,
        })

        print(f"\n=== Summary for {self.model_key.upper()} ===")
        print(f"Total time: {total_time:.2f}s")
        print(f"Total energy training: {total_energy_training:.2f}J")
        print(f"Total energy inference: {total_energy_inference:.2f}J")
        print(f"Total energy: {total_energy:.2f}J")
        print(f"Final accuracy: {final_stats['balanced_accuracy']:.4f}")
        print(f"Final log loss: {final_stats['log_loss']:.4f}")
        print(f"Final AUC-ROC: {final_stats['auc_roc']:.4f}")
        print(f"Final F1: {final_stats['f1_weighted']:.4f}")
        print(f"NetScore-T (Balanced): {avg_netscore_bal:.3f}")
        print(f"NetScore-T (Log-Loss): {avg_netscore_log:.3f}")

    # --------------------------------------------------------------- data helpers

    def _make_loaders(self, data, include_index=False):
        """Build a shuffled train loader and an unshuffled val loader from a segment dict.

        With ``include_index=True`` each train batch carries a third tensor with
        the samples' positions in ``data["X_train"]``.
        """
        X_train, y_train = data["X_train"], data["y_train"]
        X_val, y_val = data["X_val"], data["y_val"]

        train_tensors = [torch.from_numpy(X_train).float(), torch.from_numpy(y_train).long()]
        if include_index:
            train_tensors.append(torch.arange(len(X_train)))
        train_ds = TensorDataset(*train_tensors)
        val_ds = TensorDataset(torch.from_numpy(X_val).float(), torch.from_numpy(y_val).long())

        train_loader = DataLoader(train_ds, batch_size=self.hp["batch_size"], shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=self.hp["batch_size"])
        return train_loader, val_loader

    def _precompute(self, buffer, train_loader):
        """Embed the training set with every buffered model.

        Returns one tensor per buffered model, each row L2-normalised. When the
        batches carry dataset indices (third element, as produced by
        ``_make_loaders(include_index=True)``), rows are reordered so row ``i``
        belongs to sample ``i``. ``_train_epoch`` looks features up this way
        (``p[idxs]``), even though ``train_loader`` shuffles.
        """
        feats_list = []
        with torch.no_grad():
            for m in buffer:
                m.eval()
                feats, positions = [], []
                for batch in train_loader:
                    xb = batch[0].to(self.device)
                    _, f = m(xb, None)
                    feats.append(f.cpu())
                    if len(batch) > 2:
                        positions.append(batch[2].cpu())

                stacked_feats = torch.cat(feats, dim=0)
                if positions:
                    # Undo the loader's shuffle: sort rows by their dataset index.
                    stacked_feats = stacked_feats[torch.argsort(torch.cat(positions))]

                normalized_feats = stacked_feats / (torch.norm(stacked_feats, dim=1, keepdim=True) + 1e-8)
                feats_list.append(normalized_feats)
        return feats_list

    def _dataloader_to_numpy(self, loader):
        """Concatenate the features and labels of every batch into two NumPy arrays."""
        X, y = [], []
        for batch in loader:
            X.append(batch[0].cpu().numpy())
            y.append(batch[1].cpu().numpy())
        return np.concatenate(X), np.concatenate(y)

    def _get_dataset_classes(self, loader):
        """Return the sorted unique labels found in ``loader``."""
        all_classes = set()
        for batch in loader:
            all_classes.update(np.unique(batch[1].numpy()))
        return sorted(all_classes)

    # ------------------------------------------------------- training / evaluation

    def _train_and_eval(self, train_loader, val_loader, test_loader,
                        is_segmental, precomputed):
        """Fit the model on ``train_loader`` and evaluate it on ``test_loader``.

        Returns a dict whose ``"test"`` entry holds the test metrics. Torch models
        also return the last epoch's ``"train"`` and ``"val"`` metrics.
        """
        # Cleared on every call so a skip on one segment does not leak into the next.
        self._skip_metrics = None
        if self.model_key in self._ARRAY_API_MODELS:
            return self._train_and_eval_array_model(train_loader, val_loader, test_loader)
        return self._train_and_eval_torch_model(train_loader, val_loader, test_loader,
                                                is_segmental, precomputed)

    @staticmethod
    def _skipped_metrics():
        """Return fresh placeholder metrics for a model skipped on this data."""
        return {
            'accuracy': 0.0,
            'balanced_accuracy': 0.0,
            'f1_weighted': 0.0,
            'log_loss': float('inf'),
            'auc_roc': 0.5,
        }

    def _skip_reason(self, X_train, y_train):
        """Return a message when the model's input limits are exceeded, else None."""
        if self.model_key == "saint":
            num_features = X_train.shape[1]
            if num_features > 100:
                return f"Skipping SAINT: {num_features} features (limit: 100)"
        elif self.model_key == "tabpfn":
            num_classes = len(np.unique(y_train))
            num_features = X_train.shape[1]
            if num_classes > 10 or num_features > 500:  # TabPFN's limits
                return (f"Skipping TabPFN: {num_classes} classes and {num_features} features "
                        f"(limits: 10 classes, 500 features)")
        return None

    @staticmethod
    def _unpack_predictions(y_pred, y_proba):
        """Normalise wrapper outputs into ``(labels array, probabilities or None)``.

        Some wrappers return a tuple from ``predict``; its first element is taken
        as the labels. Its second element is used as probabilities when
        ``predict_proba`` returned None and that element is array-like.
        """
        if isinstance(y_pred, tuple):
            if (y_proba is None and len(y_pred) > 1
                    and isinstance(y_pred[1], (np.ndarray, pd.Series, pd.DataFrame))):
                y_proba = y_pred[1]
            y_pred = y_pred[0]
        if y_proba is not None:
            y_proba = np.asarray(y_proba)
        return np.asarray(y_pred), y_proba

    def _train_and_eval_array_model(self, train_loader, val_loader, test_loader):
        """Fit an array-API model on NumPy copies of the loaders; return its test metrics."""
        X_train, y_train = self._dataloader_to_numpy(train_loader)
        X_val, y_val = self._dataloader_to_numpy(val_loader)
        X_test, y_test = self._dataloader_to_numpy(test_loader)

        reason = self._skip_reason(X_train, y_train)
        if reason is not None:
            print(reason)
            # Remember the placeholder so _evaluate_test_only reports it instead of
            # calling predict() on a model that was not fitted on this segment.
            self._skip_metrics = self._skipped_metrics()
            return {'test': self._skipped_metrics()}

        self._print_gpu_memory()
        self.model.fit(X_train, y_train, X_val, y_val)
        self._print_gpu_memory()
        raw_pred = self.model.predict(X_test)
        raw_proba = self.model.predict_proba(X_test)
        y_pred, y_proba = self._unpack_predictions(raw_pred, raw_proba)

        # Debugging shapes
        print(f"Shape of y_test: {y_test.shape}")
        print(f"Shape of y_pred: {y_pred.shape}")
        print(f"Shape of y_proba: {y_proba.shape if y_proba is not None else 'None'}")

        test_metrics = self._compute_metrics(y_test, y_pred, y_proba)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return {'test': test_metrics}

    def _train_and_eval_torch_model(self, train_loader, val_loader, test_loader,
                                    is_segmental, precomputed):
        """Train a torch model with early stopping on validation balanced accuracy.

        The weights from the best validation epoch are restored before the test
        set is evaluated.
        """
        train_classes = self._get_dataset_classes(train_loader)

        optimizer_cls = AdamW if is_segmental else Adam
        # NOTE(review): when use_weight_decay is False a small decay of 1e-5 is still
        # applied; confirm this fallback is intended.
        weight_decay = self.hp["weight_decay"] if self.hp["use_weight_decay"] else 1e-5
        optim = optimizer_cls(self.model.parameters(),
                              lr=self.hp["learning_rate"],
                              eps=1e-8,
                              weight_decay=weight_decay)
        scheduler = ReduceLROnPlateau(optim, mode="max", factor=0.5, patience=5)
        criterion = torch.nn.CrossEntropyLoss()
        scaler = GradScaler() if self.hp["use_amp"] else None

        best_val = 0.0
        counter = 0
        best_state = None
        # Defined up front so the return value is valid even when epochs == 0.
        train_metrics, val_metrics = {}, {}

        for epoch in range(self.hp["epochs"]):
            train_metrics = self._train_epoch(train_loader, optim, criterion,
                                              scaler, is_segmental, precomputed)
            val_metrics = self._evaluate_loader(val_loader, criterion,
                                                is_segmental, precomputed, train_classes)
            scheduler.step(val_metrics["balanced_accuracy"])

            if val_metrics["balanced_accuracy"] > best_val:
                best_val = val_metrics["balanced_accuracy"]
                best_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
                counter = 0
            else:
                counter += 1
                if counter >= self.hp["patience"]:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        if best_state is not None:
            self.model.load_state_dict(best_state)

        test_metrics = self._evaluate_loader(test_loader, criterion,
                                             is_segmental, precomputed, train_classes)
        return {"train": train_metrics, "val": val_metrics, "test": test_metrics}

    def _evaluate_test_only(self, test_loader, is_segmental, precomputed):
        """Evaluate the already-trained model on the test set only (inference phase).

        For an array-API model that ``_train_and_eval`` skipped, this returns a
        copy of the placeholder metrics instead of predicting.
        """
        if self.model_key in self._ARRAY_API_MODELS:
            skipped = getattr(self, "_skip_metrics", None)
            if skipped is not None:
                return dict(skipped)
            X_test, y_test = self._dataloader_to_numpy(test_loader)
            raw_pred = self.model.predict(X_test)
            raw_proba = self.model.predict_proba(X_test)
            y_pred, y_proba = self._unpack_predictions(raw_pred, raw_proba)
            return self._compute_metrics(y_test, y_pred, y_proba)

        criterion = torch.nn.CrossEntropyLoss()
        return self._evaluate_loader(test_loader, criterion, is_segmental, precomputed, None)

    def _compute_metrics(self, y_true, y_pred, y_proba):
        """Compute balanced accuracy, weighted F1, log loss and AUC-ROC.

        Log loss and AUC-ROC become NaN when they cannot be computed, e.g. when
        ``y_proba`` is None or its columns do not match the classes in ``y_true``.
        AUC uses the positive-class column for binary labels, and otherwise
        one-vs-rest with ``average="weighted"``.
        """
        metrics = {
            'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
            # zero_division=0 gives the same value as sklearn's default, without the warning.
            'f1_weighted': f1_score(y_true, y_pred, average='weighted', zero_division=0),
        }

        try:
            metrics['log_loss'] = log_loss(y_true, y_proba)
        except Exception:
            metrics['log_loss'] = np.nan

        n_classes = len(np.unique(y_true))
        try:
            if n_classes == 2:
                metrics['auc_roc'] = roc_auc_score(y_true, y_proba[:, 1])
            else:
                metrics['auc_roc'] = roc_auc_score(y_true, y_proba,
                                                   multi_class='ovr',
                                                   average='weighted')
        except Exception:
            metrics['auc_roc'] = np.nan

        return metrics

    def _train_epoch(self, loader, optim, criterion, scaler,
                     is_segmental, precomputed):
        """Run one optimisation pass over ``loader``.

        Returns the mean per-batch loss, plus balanced accuracy and weighted F1
        of the predictions made during the pass (NaN if they cannot be computed).
        """
        self.model.train()
        running_loss = 0.0
        all_preds, all_targs = [], []

        for batch in loader:
            xb = batch[0].to(self.device)
            yb = batch[1].to(self.device)
            idxs = batch[2] if is_segmental else None

            optim.zero_grad()
            ctx = None
            if is_segmental and precomputed and self.hp["use_attention"]:
                # Precomputed rows are indexed by dataset position (see _precompute).
                ctx = [p[idxs].to(self.device) for p in precomputed]

            if scaler is not None:
                # torch.autocast takes device_type; torch.cuda.amp.autocast does not.
                with torch.autocast(device_type=self.device.type):
                    out, _ = (self.model(xb, ctx) if is_segmental else (self.model(xb), None))
                    loss = criterion(out, yb)
                scaler.scale(loss).backward()
                if self.hp["max_grad_norm"]:
                    scaler.unscale_(optim)  # clip real gradients, not scaled ones
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hp["max_grad_norm"])
                scaler.step(optim)
                scaler.update()
            else:
                out, _ = (self.model(xb, ctx) if is_segmental else (self.model(xb), None))
                loss = criterion(out, yb)
                loss.backward()
                if self.hp["max_grad_norm"]:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hp["max_grad_norm"])
                optim.step()

            running_loss += loss.item()
            all_preds.append(out.argmax(dim=1).cpu().numpy())
            all_targs.append(yb.cpu().numpy())

        y_pred = np.concatenate(all_preds)
        y_true = np.concatenate(all_targs)

        try:
            bal_acc = balanced_accuracy_score(y_true, y_pred)
        except Exception:
            bal_acc = np.nan

        try:
            f1_w = f1_score(y_true, y_pred, average="weighted", zero_division=0)
        except Exception:
            f1_w = np.nan

        return {
            "loss": running_loss / len(loader),
            "balanced_accuracy": bal_acc,
            "f1_weighted": f1_w,
        }

    def _evaluate_loader(self, loader, criterion, is_segmental, precomputed, train_classes=None):
        """Evaluate the torch model on ``loader`` without gradients.

        The model is always called with ``ctx=None``. ``precomputed`` and
        ``train_classes`` are accepted for call compatibility but not used.
        Returns the mean per-batch loss, balanced accuracy, weighted F1, log
        loss and AUC-ROC. Any metric sklearn cannot compute stays NaN; log loss
        and AUC are only attempted when ``loader`` contains at least two classes.
        """
        self.model.eval()
        running_loss = 0.0
        all_logits = []
        all_targs = []

        with torch.no_grad():
            for batch in loader:
                xb = batch[0].to(self.device)
                yb = batch[1].to(self.device)
                out, _ = (self.model(xb, None) if is_segmental else (self.model(xb), None))
                running_loss += criterion(out, yb).item()
                all_logits.append(out.cpu().numpy())
                all_targs.append(yb.cpu().numpy())

        logits = np.vstack(all_logits)
        y_true = np.concatenate(all_targs)
        test_classes = np.unique(y_true)
        probs = self._softmax(logits)
        y_pred = np.argmax(probs, axis=1)

        metrics = {
            "loss": running_loss / len(loader),
            "balanced_accuracy": np.nan,
            "f1_weighted": np.nan,
            "log_loss": np.nan,
            "auc_roc": np.nan,
        }

        try:
            metrics["balanced_accuracy"] = balanced_accuracy_score(y_true, y_pred)
        except Exception:
            pass

        try:
            metrics["f1_weighted"] = f1_score(y_true, y_pred, average="weighted", zero_division=0)
        except Exception:
            pass

        if len(test_classes) > 1:
            try:
                row_sums = np.sum(probs, axis=1)
                if not np.allclose(row_sums, 1.0, rtol=1e-5, atol=1e-8):
                    probs = probs / row_sums[:, np.newaxis]
                # Column j is class j (predictions are argmax column indices). Passing
                # labels lets log_loss handle splits that lack some classes.
                metrics["log_loss"] = log_loss(y_true, probs, labels=np.arange(probs.shape[1]))
            except Exception:
                pass

            try:
                if probs.shape[1] == 2:
                    metrics["auc_roc"] = roc_auc_score(y_true, probs[:, 1])
                else:
                    # Weighted one-vs-rest, matching _compute_metrics for array-API models.
                    metrics["auc_roc"] = roc_auc_score(y_true, probs, multi_class="ovr",
                                                       average="weighted")
            except Exception:
                pass

        return metrics

    def _softmax(self, x):
        """Row-wise softmax of a 2-D array, shifted by the row max for numerical stability."""
        exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=1, keepdims=True)

    @staticmethod
    def _print_gpu_memory():
        """Print allocated and reserved CUDA memory in MB when CUDA is available."""
        if torch.cuda.is_available():
            print(f"GPU memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
            print(f"GPU memory reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")

    # ------------------------------------------------------------------- output

    @staticmethod
    def _json_default(obj):
        """Convert NumPy scalars/arrays that ``json`` cannot serialise natively."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _save_results(self, results):
        """Write ``results`` and run metadata to a JSON file.

        The file is ``<base>/metrics/<device type>/<model_key>/<task_id>/
        <mode><tuning_suffix>_seed<seed>.json``. Here ``<base>`` is
        ``ablation_output_dir`` when it is set and ``"results"`` otherwise.
        """
        base_dir = self.ablation_output_dir or "results"
        out_dir = os.path.join(base_dir, "metrics", self.device.type,
                               self.model_key, str(self.task_id))
        os.makedirs(out_dir, exist_ok=True)
        fname = f"{self.mode}{self.tuning_suffix}_seed{self.seed}.json"

        results_with_seed = {
            "metadata": {
                "seed": self.seed,
                "model": self.model_key,
                "mode": self.mode,
                "tuning": "tuned" if self.tuning_applied else "default",
                # Window size and attention flag, for ablation tracking.
                "window_size": self.hp.get("window_size"),
                "use_attention": self.hp.get("use_attention"),
            },
            "segments": results,
        }

        path = os.path.join(out_dir, fname)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(results_with_seed, f, indent=2, default=self._json_default)
        print(f"Saved results to {path}")