import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import random
import numpy as np
from typing import Dict
from sklearn.model_selection import ShuffleSplit
import glob
import os
from tqdm import tqdm
from copy import deepcopy

from models.classifier import SetTransformer, LinearClassifier, MLPClassifier
from models.feature_selector_models import MLPFeatureSelector, SetTransformerFeatureSelector
from merlin_arthur_framework.feature_selectors import SFWFeatureSelector, ModelFeatureSelector
from config.config_dataclass import TrainerConfig, DatasetConfig, ModelConfig, BooleanConfig, FeatureSelectorConfig
from utils.metrics import get_accuracy, compute_confusion_matrix, compute_feature_distribution, compute_precision_and_entropy, compute_average_occurrence, FEATURE_INTERPRETATIONS_NCB_SEED_0

class BaseTrainer:
    def __init__(
        self, 
        trainer_config: TrainerConfig,
        dataset_config: DatasetConfig,
        model_config: ModelConfig, 
        bool_config: BooleanConfig,
        feature_selector_config: FeatureSelectorConfig,
        logger=None
    ):
        """Build a trainer from its configuration sections.

        Copies every configuration section onto the instance, declares the
        attributes that ``setup_data()`` / ``setup_model()`` populate later,
        and seeds all random number generators.

        Args:
            trainer_config: Training parameters (epochs, lr, seed, approach, ...).
            dataset_config: Dataset location, batching and split options.
            model_config: Classifier architecture options.
            bool_config: Output switches (model checkpoints, confusion matrices).
            feature_selector_config: Merlin/Morgana feature-selector options.
            logger: Optional metrics logger. Training code calls ``logger.log(...)``
                and reads ``logger.run.name``; ``None`` disables logging.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Apply the configuration sections in a fixed order (the seed used below
        # comes from the trainer section).
        config_steps = (
            (self._setup_trainer_config, trainer_config),
            (self._setup_dataset_config, dataset_config),
            (self._setup_model_config, model_config),
            (self._setup_bool_config, bool_config),
            (self._setup_feature_selector_config, feature_selector_config),
        )
        for apply_section, section in config_steps:
            apply_section(section)

        # Placeholders filled in by setup_data() (loaders, data dimensions)
        # and setup_model() (model, optimizer).
        deferred_attributes = (
            'model', 'optimizer',
            'train_loader', 'val_loader', 'test_loader',
            'num_classes', 'input_dim', 'num_slots', 'num_blocks',
        )
        for attribute in deferred_attributes:
            setattr(self, attribute, None)

        self._setup_seed(self.seed)

        self.logger = logger

    def _setup_trainer_config(self, config: TrainerConfig) -> None:
        """Expose the training section of the configuration as trainer attributes.

        The full section stays available as ``self.trainer_config``; the
        frequently used fields are copied onto flat attributes.
        """
        self.trainer_config = config

        # Optimisation schedule
        self.epochs = config.epochs
        self.learning_rate = config.lr
        self.weight_decay = config.weight_decay
        self.early_stopping = config.early_stopping
        self.patience = config.patience

        # Run identity and bookkeeping
        self.approach = config.approach
        self.seed = config.seed
        self.use_wandb = config.wandb
        self.res_dir = config.res_dir

    def _setup_dataset_config(self, config: DatasetConfig) -> None:
        """Expose the dataset section of the configuration as trainer attributes."""
        self.dataset_config = config

        # Where the precomputed encodings live and which kind they are
        self.data_dir = config.data_dir
        self.enc_type = config.enc_type

        # DataLoader settings
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers

        # Optional re-splitting / augmentation of the train and val sets
        self.unconf_split = config.unconf_split
        self.partial_conf_ratio = config.partial_conf_ratio
        self.partial_conf_dir = config.partial_conf_dir
        
    def _setup_model_config(self, config: ModelConfig) -> None:
        """Expose the classifier section of the configuration as trainer attributes."""
        self.model_config = config
        self.model_name = config.model

        # Optional warm start from a saved checkpoint
        self.pretrained_model = config.pretrained_model
        self.pretrained_path = config.pretrained_path

        # Architecture hyper-parameters (SetTransformer, then MLP)
        self.n_heads = config.n_heads
        self.set_transf_hidden = config.set_transf_hidden
        self.hidden_dim = config.hidden_dim
        self.dropout = config.dropout

    def _setup_bool_config(self, config: BooleanConfig) -> None:
        """Store the output switches: best-model checkpoints and confusion-matrix plots."""
        self.boolean_config = config
        self.save_model, self.save_confusion_matrix = (
            config.save_model,
            config.save_confusion_matrix,
        )

    def _setup_feature_selector_config(self, config: FeatureSelectorConfig) -> None:
        """Expose the Merlin/Morgana feature-selector configuration as attributes.

        A shared ``lr_fs`` / ``weight_decay_fs`` (when not ``None``) overrides
        the separate per-player values for both Merlin and Morgana.
        """
        self.feature_selector_config = config
        self.segmentation_method = config.segmentation_method
        self.mask_size = config.mask_size

        shared_lr = config.lr_fs
        if shared_lr is None:
            self.lr_merlin, self.lr_morgana = config.lr_merlin, config.lr_morgana
        else:
            print("Using lr_fs for both Merlin and Morgana")
            self.lr_merlin = self.lr_morgana = shared_lr

        shared_wd = config.weight_decay_fs
        if shared_wd is None:
            self.weight_decay_merlin = config.weight_decay_merlin
            self.weight_decay_morgana = config.weight_decay_morgana
        else:
            print("Using weight_decay_fs for both Merlin and Morgana")
            self.weight_decay_merlin = self.weight_decay_morgana = shared_wd

        # Objective / SFW optimisation settings
        self.gamma = config.gamma
        self.l1_penalty_coefficient = config.l1_penalty_coefficient
        self.sfw_max_iterations = config.sfw_max_iterations
        self.sfw_patience = config.sfw_patience

        # Learnable feature-selector architecture
        self.fs_model = config.fs_model
        self.fs_hidden_dim = config.fs_hidden_dim
        self.fs_dropout = config.fs_dropout
        self.fs_n_heads = config.fs_n_heads

        # Analysis switches
        self.feature_distribution = config.feature_distribution
        self.feature_interpretations = (
            FEATURE_INTERPRETATIONS_NCB_SEED_0 if config.feat_interp_ncb_s0 else None
        )
        self.compute_prec_and_ent = config.compute_prec_and_ent
        self.compute_avg_occ = config.compute_avg_occ

    def _setup_seed(self, seed: int) -> None:
        """Seed Python, NumPy and PyTorch (CPU + every CUDA device) and pin cuDNN.

        Args:
            seed: Integer seed applied to every random number generator.
        """
        # Seed each RNG the pipeline draws from.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)

        # Prefer deterministic cuDNN kernels over autotuned (benchmark) ones.
        cudnn = torch.backends.cudnn
        cudnn.deterministic, cudnn.benchmark = True, False

        print(f"\nSeed set to {seed} for reproducibility!")

    def setup_data(self):
        """Load the precomputed encodings into loaders and report the split sizes."""
        announcement = f'Loading precomputed {self.enc_type} encodings from {self.data_dir}'
        print(announcement)
        self._load_encodings()
        loaders = (self.train_loader, self.val_loader, self.test_loader)
        self._print_dataset_info(*loaders)

    @staticmethod
    def _find_encoding_file(directory, pattern):
        """Return the lexicographically first file in ``directory`` matching ``pattern``.

        ``glob.glob`` returns matches in filesystem-dependent order, so the
        result is sorted to make the choice deterministic when several files match.

        Raises:
            FileNotFoundError: If no file matches.
        """
        matches = sorted(glob.glob(f'{directory}/{pattern}'))
        if not matches:
            raise FileNotFoundError(f"No file matching '{pattern}' found in {directory}")
        return matches[0]

    def _uses_block_encoding(self):
        """Whether ``self.enc_type`` stores an extra block axis that must be flattened."""
        return self.enc_type in ['retrieval_corpus', 'one_hot_padded']

    def _flatten_block_dims(self, X):
        """Merge the last two axes of ``X`` into one for block-structured encodings."""
        if self._uses_block_encoding():
            return X.reshape(*X.shape[:-2], -1)
        return X

    @staticmethod
    def _shuffle_in_unison(X, y):
        """Apply one random permutation (drawn from ``np.random``) to ``X`` and ``y``."""
        perm = np.random.permutation(len(X))
        return X[perm], y[perm]

    def _carve_val_from_train(self, X_train, y_train, val_size):
        """Split a class-balanced validation set of about ``val_size`` samples off the training set.

        Each class contributes ``val_size // num_classes`` samples, or all of
        its samples if it has fewer. Both resulting sets are shuffled.

        Returns:
            tuple: ``(X_train, y_train, X_val, y_val)``.
        """
        samples_per_class = val_size // self.num_classes
        print(f"Creating validation set with {val_size} samples ({samples_per_class} per class)")

        val_indices = []
        train_indices = []
        for class_idx in range(self.num_classes):
            class_indices = np.where(y_train == class_idx)[0]

            if len(class_indices) < samples_per_class:
                print(f"Warning: Only {len(class_indices)} samples available for class {class_idx}, using all of them")
                val_samples = len(class_indices)
            else:
                val_samples = samples_per_class

            np.random.shuffle(class_indices)
            val_indices.extend(class_indices[:val_samples])
            train_indices.extend(class_indices[val_samples:])

        # Explicit integer index arrays (fancy indexing copies the selected rows).
        val_indices = np.asarray(val_indices, dtype=np.intp)
        train_indices = np.asarray(train_indices, dtype=np.intp)
        X_val, y_val = X_train[val_indices], y_train[val_indices]
        X_train, y_train = X_train[train_indices], y_train[train_indices]

        # Train is permuted before val (keeps the RNG draw order stable).
        X_train, y_train = self._shuffle_in_unison(X_train, y_train)
        X_val, y_val = self._shuffle_in_unison(X_val, y_val)
        return X_train, y_train, X_val, y_val

    def _load_partial_conf_pool(self):
        """Load the extra-sample pool from ``partial_conf_dir``.

        Uses the ``train*`` files there, falling back to the ``test*`` files
        when either train file is missing.

        Returns:
            tuple: ``(X_pool, y_pool)`` with block axes already flattened.
        """
        try:
            encs_path = self._find_encoding_file(self.partial_conf_dir, 'train*encs*.npy')
            labels_path = self._find_encoding_file(self.partial_conf_dir, 'train_labels*.npy')
        except FileNotFoundError:
            encs_path = self._find_encoding_file(self.partial_conf_dir, 'test*encs*.npy')
            labels_path = self._find_encoding_file(self.partial_conf_dir, 'test_labels*.npy')

        X_pool = self._flatten_block_dims(np.load(encs_path))
        y_pool = np.load(labels_path)
        return X_pool, y_pool

    def _add_partial_conf_samples(self, X_train, y_train, X_val, y_val):
        """Append class-balanced samples from ``partial_conf_dir`` to train and val.

        ``int(len(y_train) * partial_conf_ratio)`` samples go to train and
        ``int(len(y_val) * partial_conf_ratio)`` to val, spread evenly over
        classes. A class with too few samples contributes all of them, split
        between train and val in proportion to the requested amounts. Both
        resulting sets are shuffled.

        Returns:
            tuple: ``(X_train, y_train, X_val, y_val)``.
        """
        print(f'Including additional samples equal to {self.partial_conf_ratio * 100}% of train and val set sizes from {self.partial_conf_dir}...')
        X_pool, y_pool = self._load_partial_conf_pool()

        train_samples_to_add = int(len(y_train) * self.partial_conf_ratio)
        val_samples_to_add = int(len(y_val) * self.partial_conf_ratio)
        total_samples_needed = train_samples_to_add + val_samples_to_add
        samples_per_class = total_samples_needed // self.num_classes

        print(f'Adding {train_samples_to_add} samples to train set ({train_samples_to_add//self.num_classes} per class)')
        print(f'Adding {val_samples_to_add} samples to val set ({val_samples_to_add//self.num_classes} per class)')

        train_parts_X, train_parts_y = [X_train], [y_train]
        val_parts_X, val_parts_y = [X_val], [y_val]
        for class_idx in range(self.num_classes):
            class_indices = np.where(y_pool == class_idx)[0]

            if len(class_indices) < samples_per_class:
                print(f"Warning: Only {len(class_indices)} samples available for class {class_idx}, using all of them")
                # Reaching this branch implies samples_per_class > 0, hence
                # total_samples_needed > 0, so the division below is safe.
                train_samples = len(class_indices) * train_samples_to_add // total_samples_needed
                val_samples = len(class_indices) - train_samples
            else:
                train_samples = train_samples_to_add // self.num_classes
                val_samples = val_samples_to_add // self.num_classes

            np.random.shuffle(class_indices)
            train_idx = class_indices[:train_samples]
            val_idx = class_indices[train_samples:train_samples + val_samples]

            train_parts_X.append(X_pool[train_idx])
            train_parts_y.append(y_pool[train_idx])
            val_parts_X.append(X_pool[val_idx])
            val_parts_y.append(y_pool[val_idx])

        X_train = np.concatenate(train_parts_X)
        y_train = np.concatenate(train_parts_y)
        X_val = np.concatenate(val_parts_X)
        y_val = np.concatenate(val_parts_y)

        # Train is permuted before val (keeps the RNG draw order stable).
        X_train, y_train = self._shuffle_in_unison(X_train, y_train)
        X_val, y_val = self._shuffle_in_unison(X_val, y_val)

        print(f'New train set shape: {X_train.shape}')
        print(f'New val set shape: {X_val.shape}')
        return X_train, y_train, X_val, y_val

    def _load_encodings(self):
        """Load precomputed train/val/test encodings and build the three data loaders.

        Reads ``<split>*encs*.npy`` and ``<split>_labels_*.npy`` for each split
        from ``self.data_dir``. Sets ``input_dim`` (last axis), ``num_slots``
        (axis 1), ``num_classes`` (distinct train labels) and, for
        block-structured encodings, ``num_blocks`` (axis -2 before flattening).
        Finally it stores ``train_loader``, ``val_loader`` and ``test_loader``.

        Optional re-splits of the train/val arrays:

        * ``unconf_split``: the original validation files become the test set,
          and a class-balanced validation set of the same size is carved out of
          the training data.
        * ``partial_conf_ratio > 0``: extra samples from ``partial_conf_dir`` are
          appended to train and val.

        Labels are assumed to be integers in ``[0, num_classes)``, because the
        per-class loops iterate ``range(num_classes)``.

        Raises:
            FileNotFoundError: If a required ``.npy`` file is missing.
        """
        # Resolve every path before loading anything, so a missing file fails fast.
        split_paths = {
            split: (
                self._find_encoding_file(self.data_dir, f'{split}*encs*.npy'),
                self._find_encoding_file(self.data_dir, f'{split}_labels_*.npy'),
            )
            for split in ('train', 'val', 'test')
        }
        X_train, y_train = (np.load(path) for path in split_paths['train'])
        X_val, y_val = (np.load(path) for path in split_paths['val'])
        X_test, y_test = (np.load(path) for path in split_paths['test'])

        if self._uses_block_encoding():
            # Record the block count before the block axis is merged away.
            self.num_blocks = X_train.shape[-2]
        X_train = self._flatten_block_dims(X_train)
        X_val = self._flatten_block_dims(X_val)
        X_test = self._flatten_block_dims(X_test)

        self.input_dim = X_train.shape[-1]
        self.num_slots = X_train.shape[1]
        self.num_classes = len(np.unique(y_train))

        if self.unconf_split:
            print('Splitting validation set from train set...')
            # The original validation split is held out as the test set ...
            X_test, y_test = X_val, y_val
            # ... and a validation set of the same size is carved out of train.
            X_train, y_train, X_val, y_val = self._carve_val_from_train(
                X_train, y_train, val_size=len(y_test)
            )
        elif self.partial_conf_ratio > 0:
            X_train, y_train, X_val, y_val = self._add_partial_conf_samples(
                X_train, y_train, X_val, y_val
            )

        dataset_train = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
        dataset_val = TensorDataset(torch.Tensor(X_val), torch.Tensor(y_val))
        dataset_test = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

        self.train_loader, self.val_loader, self.test_loader = self._create_data_loaders(dataset_train, dataset_val, dataset_test)

    def _create_data_loaders(self, dataset_train, dataset_val, dataset_test):
        """Wrap the three datasets in DataLoaders; only the training loader shuffles.

        Returns:
            tuple: ``(train_loader, val_loader, test_loader)``.
        """
        shared_options = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        train_loader = DataLoader(dataset_train, shuffle=True, **shared_options)
        val_loader = DataLoader(dataset_val, shuffle=False, **shared_options)
        test_loader = DataLoader(dataset_test, shuffle=False, **shared_options)
        return train_loader, val_loader, test_loader

    def _print_dataset_info(self, train_loader, val_loader, test_loader):
        """Print the size of each split and the shape of one training sample.

        Args:
            train_loader: Loader whose ``dataset`` supplies the sample shape.
            val_loader: Validation loader (only its dataset size is printed).
            test_loader: Test loader (only its dataset size is printed).
        """
        for split_name, loader in (
            ("Train", train_loader),
            ("Validation", val_loader),
            ("Test", test_loader),
        ):
            print(f"{split_name} dataset size: {len(loader.dataset)}")

        train_dataset = train_loader.dataset
        if len(train_dataset) == 0:
            # Iterating an empty loader would raise StopIteration here.
            print("Input data dimension: unknown (train set is empty)")
            return

        # Print dimensions of one datapoint. Index the dataset directly rather
        # than calling next(iter(train_loader)): creating a loader iterator
        # starts worker processes when num_workers > 0. Per the DataLoader docs,
        # it also draws a base seed from the main-process RNG, so printing a
        # shape would shift every later random draw.
        sample_data, _ = train_dataset[0]
        print(f"Input data dimension: {sample_data.shape}")

    def setup_model(self):
        """Initialize model, criterion and optimizer, plus Merlin/Morgana feature selectors.

        Must run after ``setup_data()``, because the architectures are sized
        from ``input_dim``, ``num_slots``, ``num_blocks`` and ``num_classes``.
        The classifier gets ``num_classes + 1`` outputs; index ``num_classes``
        is the one passed to the feature selectors as ``idk_class``.

        For ``approach == "sfw"`` two ``SFWFeatureSelector`` instances are
        built. For ``approach == "learn_fs"`` one selector network is built and
        deep-copied, so Merlin and Morgana start from identical weights, and
        each gets its own Adam optimizer. If ``pretrained_model`` is set, the
        classifier (and any saved Merlin/Morgana state) is restored from
        ``pretrained_path``.

        Raises:
            ValueError: If the classifier or feature-selector model name is unsupported.
        """
        if self.model_name.lower() == 'linear':
            self.model = LinearClassifier(input_dim=self.num_slots*self.input_dim, num_classes=self.num_classes+1)
        elif self.model_name.lower() == 'settransformer':
            self.model = SetTransformer(dim_input=self.input_dim, dim_hidden=self.set_transf_hidden, 
                                        num_heads=self.n_heads, dim_output=self.num_classes+1, ln=True)
        elif self.model_name.lower() == 'mlp':
            # NOTE(review): _print_model_info reports hidden_dim/dropout for 'mlp',
            # but neither is passed here. Confirm that MLPClassifier's own
            # defaults are intended.
            self.model = MLPClassifier(input_dim=self.num_slots*self.input_dim, num_classes=self.num_classes+1)
        else:
            raise ValueError(f"Model {self.model_name} not supported for encoding classification")
            
        self.model = self.model.to(self.device)

        # Setup optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        # Setup criterion
        self.criterion = nn.CrossEntropyLoss()

        self._print_model_info()

        if self.approach == "sfw":
            # Initialize feature selector (Merlin)
            self.merlin = SFWFeatureSelector(
                mask_size=self.mask_size,
                mode="merlin",
                lr=self.lr_merlin,
                l1_penalty_coefficient=self.l1_penalty_coefficient,
                num_blocks=self.num_blocks,
                sfw_max_iterations=self.sfw_max_iterations,
                sfw_patience=self.sfw_patience,
                enc_type=self.enc_type
            ).to(self.device)

            # Initialize feature selector (Morgana)
            self.morgana = SFWFeatureSelector(
                mask_size=self.mask_size,
                mode="morgana",
                lr=self.lr_morgana,
                l1_penalty_coefficient=self.l1_penalty_coefficient,
                idk_class=self.num_classes,
                num_blocks=self.num_blocks,
                sfw_max_iterations=self.sfw_max_iterations,
                sfw_patience=self.sfw_patience,
                enc_type=self.enc_type
            ).to(self.device)

        elif self.approach == "learn_fs":
            # Create one feature selector and deep copy it for the second one
            if self.fs_model.lower() == "mlp":
                self.merlin_model = MLPFeatureSelector(
                    input_dim=self.num_slots*self.input_dim, 
                    num_slots=self.num_slots, 
                    num_blocks=self.num_blocks, 
                    hidden_dim=self.fs_hidden_dim, 
                    dropout=self.fs_dropout
                )
            elif self.fs_model.lower() == "settransformer":
                self.merlin_model = SetTransformerFeatureSelector(
                    input_dim=self.input_dim, 
                    num_slots=self.num_slots, 
                    num_blocks=self.num_blocks, 
                    dim_hidden=self.fs_hidden_dim, 
                    num_heads=self.fs_n_heads, 
                    ln=True,  # NOTE: LayerNorm is enabled by default, but could lead to degenerate solutions
                    dropout=self.fs_dropout
                )
            else:
                raise ValueError(f"Feature selector model {self.fs_model} not supported")
            
            # Create a deep copy for morgana (same initial weights as merlin)
            self.morgana_model = deepcopy(self.merlin_model)
            
            self.merlin = ModelFeatureSelector(
                mask_size=self.mask_size,
                mode="merlin",
                idk_class=self.num_classes,
                num_blocks=self.num_blocks,
                enc_type=self.enc_type,
                model=self.merlin_model
            ).to(self.device)

            self.morgana = ModelFeatureSelector(
                mask_size=self.mask_size,
                mode="morgana",
                idk_class=self.num_classes,
                num_blocks=self.num_blocks,
                enc_type=self.enc_type,
                model=self.morgana_model
            ).to(self.device)

            self.merlin_optimizer = torch.optim.Adam(self.merlin.parameters(), lr=self.lr_merlin, weight_decay=self.weight_decay_merlin)
            self.morgana_optimizer = torch.optim.Adam(self.morgana.parameters(), lr=self.lr_morgana, weight_decay=self.weight_decay_morgana)

        if self.pretrained_model:
            try:
                print(f"\nLoading pretrained model from {self.pretrained_path}")
                checkpoint = torch.load(self.pretrained_path, map_location=self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                                
                # Print information about the loaded model
                print(f"Loaded checkpoint from epoch {checkpoint['epoch']+1}")
                if 'best_val_acc' in checkpoint:
                    print(f"Validation accuracy: {checkpoint['best_val_acc']:.2f}%")
                    
                # Load merlin and morgana state dicts if the checkpoint has them.
                # The selectors only exist for the 'sfw' and 'learn_fs' approaches,
                # so reusing such a checkpoint under e.g. 'regular' must skip them
                # instead of failing with AttributeError.
                for role in ('merlin', 'morgana'):
                    state_key = f'{role}_state_dict'
                    if state_key not in checkpoint:
                        continue
                    selector = getattr(self, role, None)
                    if selector is None:
                        print(f"Skipping {state_key}: approach '{self.approach}' has no {role} feature selector")
                        continue
                    selector.load_state_dict(checkpoint[state_key])
                    print(f"Loaded {role} state dict from epoch {checkpoint['epoch']+1}")
            except FileNotFoundError:
                print(f"\nWARNING: Pretrained model not found at {self.pretrained_path}")
                print("Starting with randomly initialized model.")
        
        if self.approach != 'regular':
            self._print_feature_selector_info()

    def _print_model_info(self):
        """Print a summary of the classifier architecture and optimisation settings."""
        print("\nModel setup:")
        print(f"Architecture: {self.model_name}")
        print(f"Pretrained: {self.pretrained_model}")

        param_count = 0
        for param in self.model.parameters():  # type: ignore
            param_count += param.numel()
        print(f"Number of parameters: {param_count:,}")  # thousands separators
        print(f"Learning rate: {self.learning_rate}")
        print(f"Device: {self.device}")

        # Architecture-specific hyper-parameters, if any
        architecture_details = {
            'settransformer': (
                ("Hidden dimension", self.set_transf_hidden),
                ("Number of heads", self.n_heads),
            ),
            'mlp': (
                ("Hidden dimension", self.hidden_dim),
                ("Dropout rate", self.dropout),
            ),
        }.get(self.model_name.lower(), ())
        for label, value in architecture_details:
            print(f"{label}: {value}")

    def _print_feature_selector_info(self):
        """Print the Merlin/Morgana feature-selector configuration for the current approach."""
        print("\nFeature Selector setup:")

        if self.approach == "sfw":
            print("Method: Stochastic Frank-Wolfe (SFW)")
            print(f"SFW max iterations: {self.sfw_max_iterations}")
            print(f"SFW patience: {self.sfw_patience}")
        elif self.approach == "learn_fs":
            print("Method: Learnable Feature Selector")
            print(f"Feature Selector model: {self.fs_model}")
            selector_kind = self.fs_model.lower()
            if selector_kind in ("mlp", "settransformer"):
                print(f"Hidden dimension: {self.fs_hidden_dim}")
                if selector_kind == "settransformer":
                    print(f"Number of heads: {self.fs_n_heads}")
                print(f"Dropout: {self.fs_dropout}")
            selector_params = sum(p.numel() for p in self.merlin.parameters())  # type: ignore
            print(f"Number of parameters: {selector_params:,}")  # thousands separators

        # Settings shared by every approach
        shared_settings = (
            ("Mask size (sparsity)", f"{self.mask_size} features"),
            ("Gamma", self.gamma),
            ("Merlin learning rate", self.lr_merlin),
            ("Morgana learning rate", self.lr_morgana),
            ("L1 penalty coefficient", self.l1_penalty_coefficient),
            ("Device", self.device),
        )
        for label, value in shared_settings:
            print(f"{label}: {value}")
        
    def train(self):
        """Run the full training procedure for ``self.approach``.

        Returns:
            Whatever the approach-specific routine returns (for ``'regular'``,
            a dict of the best validation metrics).

        Raises:
            ValueError: If ``self.approach`` is not ``regular``, ``sfw`` or ``learn_fs``.
        """
        # NOTE: new approaches (e.g. unet) need an entry here and matching
        # entries in train_epoch() and validate().
        routines = (
            ('regular', '_train_regular'),
            ('sfw', '_train_sfw'),
            ('learn_fs', '_train_with_learnable_fs'),
        )
        for approach, routine_name in routines:
            if self.approach == approach:
                return getattr(self, routine_name)()
        raise ValueError(f"Approach {self.approach} not supported")
    
    def train_epoch(self):
        """Run one training epoch with the routine that matches ``self.approach``.

        Returns:
            dict: Training metrics produced by the approach-specific routine.

        Raises:
            ValueError: If ``self.approach`` is not ``regular``, ``sfw`` or ``learn_fs``.
        """
        routines = (
            ("regular", "_train_epoch_regular"),
            ("sfw", "_train_epoch_sfw"),
            ("learn_fs", "_train_epoch_with_learnable_fs"),
        )
        for approach, routine_name in routines:
            if self.approach == approach:
                return getattr(self, routine_name)()
        raise ValueError(f"Approach {self.approach} not supported")
    
    def validate(self, loader: DataLoader):
        """Evaluate on ``loader`` with the routine that matches ``self.approach``.

        Args:
            loader: Loader over the split to evaluate (validation or test).

        Returns:
            dict: Evaluation metrics produced by the approach-specific routine.

        Raises:
            ValueError: If ``self.approach`` is not ``regular``, ``sfw`` or ``learn_fs``.
        """
        routines = (
            ("regular", "_validate_regular"),
            ("sfw", "_validate_sfw"),
            ("learn_fs", "_validate_with_learnable_fs"),
        )
        for approach, routine_name in routines:
            if self.approach == approach:
                return getattr(self, routine_name)(loader)
        raise ValueError(f"Approach {self.approach} not supported")

    def _results_run_dir(self) -> str:
        """Return the directory where this run's checkpoints and plots are written.

        Layout: ``<res_dir>/<approach>_<model_name>_on_<enc_type>_seed<seed>``,
        suffixed with ``_<logger.run.name>`` when a logger is attached, so that
        separate logged runs do not overwrite each other.
        """
        base_dir = f'{self.approach}_{self.model_name}_on_{self.enc_type}_seed{self.seed}'
        if self.logger is not None:
            base_dir = f'{base_dir}_{self.logger.run.name}'
        return os.path.join(self.res_dir, base_dir)

    def _train_regular(self):
        """Train the classifier alone, keeping the weights with the lowest validation loss.

        Each epoch trains, validates and (if a logger is set) logs metrics. An
        epoch whose validation loss beats the best so far resets the
        early-stopping counter, snapshots the weights and optionally saves a
        checkpoint. After the loop, the best weights are restored when such a
        snapshot exists. Confusion matrices are then computed for the validation
        and test sets, and the test accuracy is reported.

        Returns:
            dict: ``best_val_acc``, ``best_epoch`` (0-based; -1 if no epoch
            improved), ``val_loss``, and, once an epoch improved, that epoch's
            ``train_acc`` and ``train_loss``.
        """
        print(f"\nStarting regular training for {self.epochs} epochs...")
        best_metrics = {
            'best_val_acc': 0,
            'best_epoch': -1,
            'val_loss': float('inf')
        }
        no_improvement = 0  # epochs since the last validation-loss improvement
        best_model_state = None  # CPU copy of the best weights seen so far

        # Determine save path for confusion matrix plots
        save_path = None
        if self.save_confusion_matrix:
            save_dir = self._results_run_dir()
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, 'confusion_matrix')

        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch+1}/{self.epochs}")
            
            # Train and validate
            train_metrics = self.train_epoch()
            val_metrics = self.validate(self.val_loader)
            
            # Log metrics to wandb if enabled
            if self.logger is not None:
                self.logger.log({
                    'epoch': epoch,
                    'train/loss': train_metrics['train_loss'],
                    'train/accuracy': train_metrics['train_acc'],
                    'val/loss': val_metrics['loss'],
                    'val/accuracy': val_metrics['acc'],
                    'learning_rate': self.learning_rate,
                }, step=epoch)

            # Model selection / early stopping use validation-loss improvement
            if val_metrics['loss'] < best_metrics['val_loss']:
                best_metrics.update({
                    'best_val_acc': val_metrics['acc'],
                    'best_epoch': epoch,
                    'val_loss': val_metrics['loss'],
                    'train_acc': train_metrics['train_acc'],
                    'train_loss': train_metrics['train_loss']
                })
                no_improvement = 0  # Reset counter on improvement
                
                # Detached CPU copy so later optimizer steps cannot mutate it
                best_model_state = {
                    'epoch': epoch,
                    'model_state_dict': {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
                }

                if self.save_model:
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        **best_metrics
                    }

                    checkpoint_dir = self._results_run_dir()
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
                    
                    self.save_checkpoint(checkpoint, checkpoint_path)
                    print(f"Updated best model checkpoint with validation accuracy: {best_metrics['best_val_acc']:.2f}%")
            else:
                no_improvement += 1
                print(f"No improvement for {no_improvement} epoch(s).")

            # Check early stopping criterion
            if self.early_stopping and no_improvement >= self.patience:
                print(f"Early stopping triggered after {no_improvement} epochs with no improvement.")
                break
                    
        print(f"\nTraining completed! Best validation accuracy: {best_metrics['best_val_acc']:.2f}% "
              f"at epoch {best_metrics['best_epoch']+1}\n")
        
        if best_model_state is None:
            # No epoch ran, or the validation loss never dropped below the
            # initial inf (e.g. it was NaN, and nan < inf is False). There is no
            # snapshot to restore, so evaluate the weights as they are instead
            # of failing on best_model_state[...].
            print("WARNING: no improved model state was recorded; evaluating the current model weights.")
        else:
            # Load the best model state back onto the training device
            self.model.load_state_dict({k: v.to(self.device)
                                        for k, v in best_model_state['model_state_dict'].items()})

        compute_confusion_matrix(
            loader=self.val_loader, 
            set_name="Validation",
            approach=self.approach,
            model=self.model,
            device=self.device,
            num_classes=self.num_classes,
            logger=self.logger,
            save_path=None if save_path is None else f"{save_path}_val"
        )
        test_metrics = self.validate(self.test_loader)
        print(f"Test accuracy: {test_metrics['acc']:.2f}%\n")
        compute_confusion_matrix(
            loader=self.test_loader, 
            set_name="Test",
            approach=self.approach,
            model=self.model,
            device=self.device,
            num_classes=self.num_classes,
            logger=self.logger,
            save_path=None if save_path is None else f"{save_path}_test"
        )

        if self.logger is not None:
            self.logger.log({
                'val/best_loss': best_metrics['val_loss'],
                'val/best_epoch': best_metrics['best_epoch'],
                'val/best_accuracy': best_metrics['best_val_acc'],
                'test/best_accuracy': test_metrics['acc']
            })

        return best_metrics

    def _train_epoch_regular(self):
        """Run one epoch of plain classifier training over ``self.train_loader``.

        Returns:
            dict: ``train_loss`` (mean of the per-batch losses) and
            ``train_acc`` (percentage of correctly classified samples).
        """
        self.model.train()

        # Progress bar for training
        pbar = tqdm(self.train_loader, desc='Training')

        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for inputs, targets in pbar:
            # Move data to device; CrossEntropyLoss expects int64 class indices
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            
            # Backward pass
            loss.backward()
            self.optimizer.step()

            # Update metrics
            total_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Show the running mean over the batches processed so far. Dividing
            # by len(train_loader) here under-reports the loss until the epoch ends.
            pbar.set_postfix({
                'loss': f'{total_loss/num_batches:.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })
        
        # Compute epoch metrics
        metrics = {
            'train_loss': total_loss / num_batches,
            'train_acc': 100. * correct / total
        }
        
        return metrics

    def _validate_regular(self, loader: DataLoader):
        """Evaluate the classifier on ``loader`` without updating it.

        Args:
            loader: DataLoader yielding ``(inputs, targets)`` batches. When it is
                ``self.val_loader`` the progress bar is labelled 'Validating',
                otherwise 'Testing'.

        Returns:
            dict: ``loss`` (mean batch loss) and ``acc`` (sample accuracy in
            percent). Both are 0 when the loader yields no batches.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        num_batches = 0

        # Progress bar for validation; identity check decides which split this is
        pbar = tqdm(loader, desc='Validating' if loader is self.val_loader else 'Testing')

        with torch.no_grad():
            for inputs, targets in pbar:
                # Move data to device
                inputs = inputs.to(self.device)
                targets = targets.to(self.device).long()

                # Forward pass
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                # Update metrics
                total_loss += loss.item()
                num_batches += 1
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                # Running averages over the batches/samples seen so far
                pbar.set_postfix({
                    'loss': f'{total_loss/num_batches:.4f}',
                    'acc': f'{100.*correct/total:.2f}%'
                })

        # Compute validation metrics (guarded against an empty loader)
        metrics = {
            'loss': total_loss / num_batches if num_batches > 0 else 0,
            'acc': 100. * correct / total if total > 0 else 0
        }

        return metrics

    def _train_sfw(self):
        """Merlin-Arthur training approach using SFW.

        Each epoch trains the classifier with Merlin/Morgana masks
        (``_train_epoch_sfw``) and evaluates it on ``self.val_loader``
        (``_validate_sfw``). The epoch with the highest combined validation
        metric (completeness + soundness, plus a 100-point bonus when soundness
        exceeds 90%) is kept in memory. If ``self.save_model`` is set, that epoch
        is also written to ``<res_dir>/<run dir>/best_model.pth``.

        After training, the best state is restored into the model, Merlin and
        Morgana. Confusion matrices and test metrics are then computed, along
        with the optional feature-distribution and precision/entropy analyses.

        Returns:
            dict: Best-epoch metrics with keys ``best_combined_metric``,
            ``best_epoch``, ``val_loss``, ``val_completeness``, ``val_soundness``,
            ``train_completeness``, ``train_soundness`` and ``train_loss``.
        """
        print(f"\nStarting Merlin-Arthur training with SFW for {self.epochs} epochs...")

        best_metrics = {
            'best_combined_metric': 0,
            'best_epoch': -1,
            'val_loss': float('inf'),
            'val_completeness': 0,
            'val_soundness': 0,
            'train_completeness': 0,
            'train_soundness': 0,
            'train_loss': float('inf')
        }

        # CPU copies of the model/Merlin/Morgana weights from the best epoch so far
        best_model_state = None

        # Run directory shared by the confusion-matrix plots and the checkpoint;
        # only resolved when something will be written to disk.
        run_dir = None
        if self.save_confusion_matrix or self.save_model:
            base_dir = f'{self.approach}_{self.model_name}_on_{self.enc_type}_seed{self.seed}'
            # If using wandb, append the run name
            if self.logger is not None:
                base_dir = f'{base_dir}_{self.logger.run.name}'
            run_dir = os.path.join(self.res_dir, base_dir)

        # Path prefix for confusion-matrix plots (None disables saving)
        save_path = None
        if self.save_confusion_matrix:
            os.makedirs(run_dir, exist_ok=True)
            save_path = os.path.join(run_dir, 'confusion_matrix')

        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch+1}/{self.epochs}")

            # Train and validate
            train_metrics = self._train_epoch_sfw()
            val_metrics = self._validate_sfw(self.val_loader)

            # Log metrics to wandb if enabled
            if self.logger is not None:
                metrics_dict = {
                    'epoch': epoch,
                    'train/loss': train_metrics['train_loss'],
                    'train/completeness': train_metrics['train_completeness'],
                    'train/soundness': train_metrics['train_soundness'],
                    'val/loss': val_metrics['loss'],
                    'val/completeness': val_metrics['completeness'],
                    'val/soundness': val_metrics['soundness'],
                    'mask/sparsity': self.mask_size,
                }
                self.logger.log(metrics_dict, step=epoch)

            # Consolidated Epoch Summary
            print("\n" + "=" * 60)
            print("EPOCH SUMMARY")
            print("=" * 60)
            print(f"Epoch {epoch+1}/{self.epochs}")
            print("-" * 60)
            print("Training Statistics:")
            print(f"   Loss            : {train_metrics['train_loss']:.4f}")
            print(f"   Completeness    : {train_metrics['train_completeness']:.2f}%")
            print(f"   Soundness       : {train_metrics['train_soundness']:.2f}%")
            print("-" * 60)
            print("Validation Statistics:")
            print(f"   Loss            : {val_metrics['loss']:.4f}")
            print(f"   Completeness    : {val_metrics['completeness']:.2f}%")
            print(f"   Soundness       : {val_metrics['soundness']:.2f}%")
            print("-" * 60)

            # Model-selection score. Epochs with soundness > 90% get a +100 bonus,
            # which ranks them above every sub-threshold epoch while keeping the
            # comp+sound ordering among the epochs that pass the threshold.
            combined_metric = val_metrics['completeness'] + val_metrics['soundness']
            if val_metrics['soundness'] > 90:
                combined_metric += 100

            # The first epoch is always accepted, so a best state (and checkpoint)
            # exists even if the score never rises above its initial value of 0.
            if best_model_state is None or combined_metric > best_metrics['best_combined_metric']:
                best_metrics.update({
                    'best_combined_metric': combined_metric,
                    'best_epoch': epoch,
                    'val_loss': val_metrics['loss'],
                    'val_completeness': val_metrics['completeness'],
                    'val_soundness': val_metrics['soundness'],
                    'train_completeness': train_metrics['train_completeness'],
                    'train_soundness': train_metrics['train_soundness'],
                    'train_loss': train_metrics['train_loss']
                })

                # Detached CPU copies so later training steps cannot modify them
                best_model_state = {
                    'epoch': epoch,
                    'model_state_dict': {k: v.cpu().clone() for k, v in self.model.state_dict().items()},
                    'merlin_state_dict': {k: v.cpu().clone() for k, v in self.merlin.state_dict().items()},
                    'morgana_state_dict': {k: v.cpu().clone() for k, v in self.morgana.state_dict().items()}
                }

                if self.save_model:
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'merlin_state_dict': self.merlin.state_dict(),
                        'morgana_state_dict': self.morgana.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        **best_metrics
                    }

                    os.makedirs(run_dir, exist_ok=True)
                    checkpoint_path = os.path.join(run_dir, 'best_model.pth')

                    self.save_checkpoint(checkpoint, checkpoint_path)
                    print(f"Updated best model checkpoint with combined metric: {combined_metric:.2f}")

        # Print final summary with best metrics
        print("\nTraining completed!")
        print("\nFinal evaluation with best model:\n")
        print(f"Best combined metric at epoch {best_metrics['best_epoch']+1}")
        print(f"Best validation completeness: {best_metrics['val_completeness']:.2f}%")
        print(f"Best validation soundness: {best_metrics['val_soundness']:.2f}%")

        # Restore the best weights onto the training device
        if best_model_state is not None:
            for module, state_key in (
                (self.model, 'model_state_dict'),
                (self.merlin, 'merlin_state_dict'),
                (self.morgana, 'morgana_state_dict'),
            ):
                module.load_state_dict({k: v.to(self.device) for k, v in best_model_state[state_key].items()})

        compute_confusion_matrix(
            loader=self.val_loader,
            set_name="Validation",
            approach=self.approach,
            model=self.model,
            merlin=self.merlin,
            morgana=self.morgana,
            device=self.device,
            num_classes=self.num_classes,
            logger=self.logger,
            save_path=None if save_path is None else f"{save_path}_val"
        )
        test_metrics = self._validate_sfw(self.test_loader)
        print(f"Best test completeness: {test_metrics['completeness']:.2f}%")
        print(f"Best test soundness: {test_metrics['soundness']:.2f}%")

        compute_confusion_matrix(
            loader=self.test_loader,
            set_name="Test",
            approach=self.approach,
            model=self.model,
            merlin=self.merlin,
            morgana=self.morgana,
            device=self.device,
            num_classes=self.num_classes,
            logger=self.logger,
            save_path=None if save_path is None else f"{save_path}_test"
        )

        # Splits for the post-training analyses, always evaluated validation first
        eval_sets = ((self.val_loader, "Validation"), (self.test_loader, "Test"))

        # Log final metrics to wandb if enabled
        if self.logger is not None:
            self.logger.log({
                'val/best_combined_metric': best_metrics['best_combined_metric'],
                'val/best_loss': best_metrics['val_loss'],
                'val/best_epoch': best_metrics['best_epoch'],
                'val/best_completeness': best_metrics['val_completeness'],
                'val/best_soundness': best_metrics['val_soundness'],
                'test/best_completeness': test_metrics['completeness'],
                'test/best_soundness': test_metrics['soundness']
            })

            # Feature-distribution analysis only runs when a logger is configured
            if self.feature_distribution:
                for loader, set_name in eval_sets:
                    compute_feature_distribution(
                        loader=loader,
                        set_name=set_name,
                        model=self.model,
                        merlin=self.merlin,
                        morgana=self.morgana,
                        device=self.device,
                        num_classes=self.num_classes,
                        num_slots=self.num_slots,
                        num_blocks=self.num_blocks,
                        mask_size=self.mask_size,
                        enc_type=self.enc_type,
                        approach=self.approach,
                        logger=self.logger,
                        feature_interpretations=self.feature_interpretations
                    )

        if self.compute_prec_and_ent:
            print("\nComputing additional metrics...")
            for target_class in range(self.num_classes):
                for loader, set_name in eval_sets:
                    compute_precision_and_entropy(
                        loader=loader,
                        target_class=target_class,
                        tolerance=0.000001,
                        set_name=set_name,
                        model=self.model,
                        merlin=self.merlin,
                        morgana=self.morgana,
                        device=self.device,
                        num_classes=self.num_classes,
                        batch_size=self.batch_size,
                        num_workers=self.num_workers,
                        seed=self.seed,
                        approach=self.approach,
                        logger=self.logger
                    )

        return best_metrics

    def _train_epoch_sfw(self):
        """Train model for one epoch using SFW Merlin-Arthur approach.

        For every batch, Merlin and Morgana first produce a continuous mask
        against the current classifier (``self.merlin(inputs, targets, self.model)``).
        The mask is binarized with ``get_binary_mask`` and applied with
        ``apply_mask``. The classifier is then updated on
        ``merlin_loss + self.gamma * morgana_loss``.

        Returns:
            dict: ``train_loss`` (mean batch loss) plus ``train_completeness`` and
            ``train_soundness`` (means of the per-batch ``get_accuracy`` values,
            in percent). All are 0 when the loader yields no batches.
        """
        self.model.train()
        total_loss = 0
        total_completeness = 0
        total_soundness = 0
        batch_count = 0

        # Progress bar for training
        pbar = tqdm(self.train_loader, desc='Training with Merlin-Arthur')

        for inputs, targets in pbar:
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Step 1: Optimize masks using SFW
            continuous_mask_merlin = self.merlin(inputs, targets, self.model)
            continuous_mask_morgana = self.morgana(inputs, targets, self.model)

            # Step 2: Convert to binary masks using top-k selection
            binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
            binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)

            # Step 3: Apply mask and compute logits.
            # zero_grad comes after mask optimization so any gradients the
            # selectors may have left on the classifier are cleared first.
            self.optimizer.zero_grad()
            masked_inputs_merlin = self.merlin.apply_mask(inputs, binary_mask_merlin)
            masked_inputs_morgana = self.morgana.apply_mask(inputs, binary_mask_morgana)
            logits_merlin = self.model(masked_inputs_merlin)
            logits_morgana = self.model(masked_inputs_morgana)

            # Step 4: Calculate loss (Morgana term weighted by gamma)
            merlin_loss = self.merlin.criterion(logits_merlin, targets)
            morgana_loss = self.morgana.criterion(logits_morgana, targets)
            loss = merlin_loss + self.gamma * morgana_loss

            # Backward pass for classifier
            loss.backward()
            self.optimizer.step()

            # Update metrics
            total_loss += loss.item()

            # Calculate accuracies (completeness and soundness)
            batch_completeness = get_accuracy(logits_merlin, targets, mode="merlin", idk_class=self.num_classes)
            batch_soundness = get_accuracy(logits_morgana, targets, mode="morgana", idk_class=self.num_classes)

            # Accumulate for epoch average
            total_completeness += batch_completeness
            total_soundness += batch_soundness
            batch_count += 1

            # Progress bar: running mean loss, current-batch accuracies
            pbar.set_postfix({
                'loss': f'{total_loss/batch_count:.4f}',
                'comp': f'{100.*batch_completeness:.2f}%',
                'sound': f'{100.*batch_soundness:.2f}%'
            })

        # Epoch averages, all guarded against an empty loader
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        avg_completeness = total_completeness / batch_count if batch_count > 0 else 0
        avg_soundness = total_soundness / batch_count if batch_count > 0 else 0

        # Compute epoch metrics
        metrics = {
            'train_loss': avg_loss,
            'train_completeness': 100. * avg_completeness,
            'train_soundness': 100. * avg_soundness
        }

        return metrics
    
    def _validate_sfw(self, loader: DataLoader):
        """Validate model using SFW Merlin-Arthur approach.

        Masks are still optimized per batch, which needs gradients, so that step
        runs under ``torch.enable_grad()``. Everything after it runs under
        ``torch.no_grad()``, and the classifier is not updated.

        Args:
            loader: DataLoader yielding ``(inputs, targets)`` batches. When it is
                ``self.val_loader`` the progress bar says 'Validating ...',
                otherwise 'Testing ...'.

        Returns:
            dict: ``loss`` (mean batch loss), ``completeness`` and ``soundness``
            (means of the per-batch ``get_accuracy`` values, in percent), and
            ``acc`` (same as ``completeness``, kept for backward compatibility).
            All are 0 when the loader yields no batches.
        """
        self.model.eval()
        total_loss = 0
        total_completeness = 0
        total_soundness = 0
        batch_count = 0

        # Progress bar for validation; identity check decides which split this is
        pbar = tqdm(loader, desc='Validating with Merlin-Arthur' if loader is self.val_loader else 'Testing with Merlin-Arthur')

        for inputs, targets in pbar:
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Temporarily enable gradients for mask optimization
            with torch.enable_grad():
                # Optimize masks (in eval mode) - requires gradients
                continuous_mask_merlin = self.merlin(inputs, targets, self.model)
                continuous_mask_morgana = self.morgana(inputs, targets, self.model)

            # Disable gradients for the rest of the validation process
            with torch.no_grad():
                # Convert to binary masks using top-k selection
                binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
                binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)

                # Apply masks and get predictions
                masked_inputs_merlin = self.merlin.apply_mask(inputs, binary_mask_merlin)
                masked_inputs_morgana = self.morgana.apply_mask(inputs, binary_mask_morgana)

                logits_merlin = self.model(masked_inputs_merlin)
                logits_morgana = self.model(masked_inputs_morgana)

                # Calculate loss (same objective as training)
                merlin_loss = self.merlin.criterion(logits_merlin, targets)
                morgana_loss = self.morgana.criterion(logits_morgana, targets)
                loss = merlin_loss + self.gamma * morgana_loss

                # Update metrics
                total_loss += loss.item()

                # Calculate accuracies (completeness and soundness)
                batch_completeness = get_accuracy(logits_merlin, targets, mode="merlin", idk_class=self.num_classes)
                batch_soundness = get_accuracy(logits_morgana, targets, mode="morgana", idk_class=self.num_classes)

                # Accumulate for epoch average
                total_completeness += batch_completeness
                total_soundness += batch_soundness
                batch_count += 1

                # Update progress bar
                pbar.set_postfix({
                    'loss': f'{total_loss/batch_count:.4f}',
                    'comp': f'{100.*batch_completeness:.2f}%',
                    'sound': f'{100.*batch_soundness:.2f}%'
                })

        # Epoch-level averages, all guarded against an empty loader
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        avg_completeness = total_completeness / batch_count if batch_count > 0 else 0
        avg_soundness = total_soundness / batch_count if batch_count > 0 else 0

        # Compute validation metrics
        metrics = {
            'loss': avg_loss,
            'acc': 100. * avg_completeness,  # Keep val_acc for backward compatibility
            'completeness': 100. * avg_completeness,
            'soundness': 100. * avg_soundness
        }

        return metrics
    
    def _train_with_learnable_fs(self):
        """Merlin-Arthur training approach with learnable feature selectors (basically the same as SFW function)
        
        Returns:
            dict: Dictionary containing best validation metrics
        """
        print(f"\nStarting Merlin-Arthur training with learnable feature selectors for {self.epochs} epochs...")
            
        best_metrics = {
            'best_combined_metric': 0,
            'best_epoch': -1,
            'val_loss': float('inf'),
            'val_completeness': 0,
            'val_soundness': 0,
            'train_completeness': 0,
            'train_soundness': 0,
            'train_loss': float('inf')
        }

        best_model_state = None  # Store the best model state

        # Determine save path for confusion matrix plots
        save_path = None
        if self.save_confusion_matrix:
            base_dir = f'{self.approach}_{self.model_name}_on_{self.enc_type}_seed{self.seed}'
            if self.logger is not None:
                base_dir = f'{base_dir}_{self.logger.run.name}'
            save_dir = os.path.join(self.res_dir, base_dir)
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, 'confusion_matrix')

        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch+1}/{self.epochs}")
            
            # Train and validate
            train_metrics = self._train_epoch_with_learnable_fs()
            val_metrics = self._validate_with_learnable_fs(self.val_loader)
            
            # Update validation metrics to include completeness and soundness
            if 'val_completeness' not in val_metrics:
                # For backward compatibility, use val_acc as completeness if not already present
                val_metrics['val_completeness'] = val_metrics['acc']
                val_metrics['val_soundness'] = 0  # Default if not present
            
            # Log metrics to wandb if enabled
            if self.logger is not None:
                metrics_dict = {
                    'epoch': epoch,
                    'train/loss': train_metrics['train_loss'],
                    'train/completeness': train_metrics['train_completeness'],
                    'train/soundness': train_metrics['train_soundness'],
                    'val/loss': val_metrics['loss'],
                    'val/completeness': val_metrics['completeness'],
                    'val/soundness': val_metrics['soundness'],
                    'mask/sparsity': self.mask_size,
                }
                self.logger.log(metrics_dict, step=epoch)

            # Consolidated Epoch Summary
            print("\n" + "=" * 60)
            print("EPOCH SUMMARY")
            print("=" * 60)
            print(f"Epoch {epoch+1}/{self.epochs}")
            print("-" * 60)
            print("Training Statistics:")
            print(f"   Loss            : {train_metrics['train_loss']:.4f}")
            print(f"   Completeness    : {train_metrics['train_completeness']:.2f}%")
            print(f"   Soundness       : {train_metrics['train_soundness']:.2f}%")
            print("-" * 60)
            print("Validation Statistics:")
            print(f"   Loss            : {val_metrics['loss']:.4f}")
            print(f"   Completeness    : {val_metrics['completeness']:.2f}%")
            print(f"   Soundness       : {val_metrics['soundness']:.2f}%")
            print("-" * 60)

            # Update best metrics and save model
            if val_metrics['soundness'] > 90:
                # If soundness threshold is met, add a large bonus to ensure it's better than sub-90 models
                # but still maintains the comp+sound ordering among models above 90%
                combined_metric = val_metrics['completeness'] + val_metrics['soundness'] + 100
            else:
                combined_metric = val_metrics['completeness'] + val_metrics['soundness']
                
            if combined_metric > best_metrics['best_combined_metric']:
                best_metrics.update({
                    'best_combined_metric': combined_metric,
                    'best_epoch': epoch,
                    'val_loss': val_metrics['loss'],
                    'val_completeness': val_metrics['completeness'],
                    'val_soundness': val_metrics['soundness'],
                    'train_completeness': train_metrics['train_completeness'],
                    'train_soundness': train_metrics['train_soundness'],
                    'train_loss': train_metrics['train_loss']
                })
                
                # Deep copy of the model's state dictionary
                best_model_state = {
                    'epoch': epoch,
                    'model_state_dict': {k: v.cpu().clone() for k, v in self.model.state_dict().items()},
                    'merlin_state_dict': {k: v.cpu().clone() for k, v in self.merlin.state_dict().items()},
                    'morgana_state_dict': {k: v.cpu().clone() for k, v in self.morgana.state_dict().items()}
                }
                
                if self.save_model:
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'merlin_state_dict': self.merlin.state_dict(),
                        'morgana_state_dict': self.morgana.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'merlin_optimizer_state_dict': self.merlin_optimizer.state_dict(),
                        'morgana_optimizer_state_dict': self.morgana_optimizer.state_dict(),
                        **best_metrics
                    }

                    # Base directory name
                    base_dir = f'{self.approach}_{self.fs_model}_classifier_{self.model_name}_on_{self.enc_type}_seed{self.seed}'
                    
                    # If using wandb, append the run name
                    if self.logger is not None:
                        base_dir = f'{base_dir}_{self.logger.run.name}'

                    checkpoint_dir = os.path.join(self.res_dir, base_dir)
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
                        
                    self.save_checkpoint(checkpoint, checkpoint_path)
                    print(f"Updated best model checkpoint with combined metric: {combined_metric:.2f}")

         # Print final summary with best metrics
        print(f"\nTraining completed!")
        # Print final metrics with best model
        print("\nFinal evaluation with best model:\n")
        print(f"Best combined metric at epoch {best_metrics['best_epoch']+1}")
        print(f"Best validation completeness: {best_metrics['val_completeness']:.2f}%")
        print(f"Best validation soundness: {best_metrics['val_soundness']:.2f}%")
        
        # Load the best model state
        if best_model_state:
            self.model.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                      for k, v in best_model_state['model_state_dict'].items()})
            self.merlin.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                       for k, v in best_model_state['merlin_state_dict'].items()})
            self.morgana.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                        for k, v in best_model_state['morgana_state_dict'].items()})


        compute_confusion_matrix(
            loader=self.val_loader, 
            set_name="Validation",
            approach=self.approach,
            model=self.model,
            merlin=self.merlin,
            morgana=self.morgana,
            device=self.device,
            num_classes=self.num_classes,
            logger=self.logger,
            save_path=None if save_path is None else f"{save_path}_val"
        )
        test_metrics = self._validate_with_learnable_fs(self.test_loader)
        print(f"Best test completeness: {test_metrics['completeness']:.2f}%")
        print(f"Best test soundness: {test_metrics['soundness']:.2f}%")
        compute_confusion_matrix(
            loader=self.test_loader, 
            set_name="Test",
            approach=self.approach,
            model=self.model,
            merlin=self.merlin,
            morgana=self.morgana,
            device=self.device,
            num_classes=self.num_classes,
            logger=self.logger,
            save_path=None if save_path is None else f"{save_path}_test"
        )

        # Log final metrics to wandb if enabled
        if self.logger is not None:
            self.logger.log({
                'val/best_combined_metric': best_metrics['best_combined_metric'],
                'val/best_loss': best_metrics['val_loss'],
                'val/best_epoch': best_metrics['best_epoch'],
                'val/best_completeness': best_metrics['val_completeness'],
                'val/best_soundness': best_metrics['val_soundness'],
                'test/best_completeness': test_metrics['completeness'],
                'test/best_soundness': test_metrics['soundness']
            })
        
            if self.feature_distribution:
                compute_feature_distribution(
                    loader=self.val_loader, 
                    set_name="Validation",
                    model=self.model,
                    merlin=self.merlin,
                    morgana=self.morgana,
                    device=self.device,
                    num_classes=self.num_classes,
                    num_slots=self.num_slots,
                    num_blocks=self.num_blocks,
                    mask_size=self.mask_size,
                    enc_type=self.enc_type,
                    approach=self.approach,
                    logger=self.logger,
                    feature_interpretations=self.feature_interpretations
                )
                compute_feature_distribution(
                    loader=self.test_loader, 
                    set_name="Test",
                    model=self.model,
                    merlin=self.merlin,
                    morgana=self.morgana,
                    device=self.device,
                    num_classes=self.num_classes,
                    num_slots=self.num_slots,
                    num_blocks=self.num_blocks,
                    mask_size=self.mask_size,
                    enc_type=self.enc_type,
                    approach=self.approach,
                    logger=self.logger,
                    feature_interpretations=self.feature_interpretations
                )
        
        if self.compute_prec_and_ent:
            print("\nComputing additional metrics...")
            for target_class in range(self.num_classes):
                # Compute for validation set
                compute_precision_and_entropy(
                    loader=self.val_loader,
                    target_class=target_class,
                    tolerance=0.000001,
                    set_name="Validation",
                    model=self.model,
                    merlin=self.merlin,
                    morgana=self.morgana,
                    device=self.device,
                    num_classes=self.num_classes,
                    batch_size=self.batch_size,
                    num_workers=self.num_workers,
                    seed=self.seed,
                    approach=self.approach,
                    logger=self.logger
                )
                
                # Compute for test set
                compute_precision_and_entropy(
                    loader=self.test_loader,
                    target_class=target_class,
                    tolerance=0.000001,
                    set_name="Test",
                    model=self.model,
                    merlin=self.merlin,
                    morgana=self.morgana,
                    device=self.device,
                    num_classes=self.num_classes,
                    batch_size=self.batch_size,
                    num_workers=self.num_workers,
                    seed=self.seed,
                    approach=self.approach,
                    logger=self.logger
                )
        
        if self.compute_avg_occ:
            print("Compute average occurrence")
            compute_average_occurrence(
                    loader=self.val_loader, 
                    set_name="Validation",
                    model=self.model,
                    merlin=self.merlin,
                    morgana=self.morgana,
                    device=self.device,
                    num_classes=self.num_classes,
                    num_slots=self.num_slots,
                    num_blocks=self.num_blocks,
                    mask_size=self.mask_size,
                    enc_type=self.enc_type,
                    approach=self.approach,
                    logger=self.logger
                )
            compute_average_occurrence(
                    loader=self.test_loader, 
                    set_name="Test",
                    model=self.model,
                    merlin=self.merlin,
                    morgana=self.morgana,
                    device=self.device,
                    num_classes=self.num_classes,
                    num_slots=self.num_slots,
                    num_blocks=self.num_blocks,
                    mask_size=self.mask_size,
                    enc_type=self.enc_type,
                    approach=self.approach,
                    logger=self.logger
                )

        return best_metrics

    def _train_epoch_with_learnable_fs(self):
        """Train the classifier (Arthur) for one epoch against learnable Merlin/Morgana selectors.

        For every batch:
          1. ``_optimize_learnable_fs`` takes one optimizer step for Merlin and
             one for Morgana and returns the continuous mask each produced;
          2. the masks are binarised with ``get_binary_mask``;
          3. the masked inputs are classified by ``self.model``;
          4. ``self.model`` is updated with ``merlin_loss + self.gamma * morgana_loss``.

        Returns:
            dict: ``train_loss`` (mean of the per-batch losses),
            ``train_completeness`` and ``train_soundness`` (mean of the
            per-batch ``get_accuracy`` values, multiplied by 100). Every entry
            is 0.0 when ``self.train_loader`` yields no batches.
        """
        self.model.train()
        self.merlin.train()
        self.morgana.train()

        total_loss = 0
        total_completeness = 0
        total_soundness = 0
        batch_count = 0

        # Progress bar for training
        pbar = tqdm(self.train_loader, desc='Training with Merlin-Arthur')

        for inputs, targets in pbar:
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Step 1: one optimisation step per learnable feature selector
            continuous_mask_merlin = self._optimize_learnable_fs(inputs, targets, self.merlin, self.merlin_optimizer)
            continuous_mask_morgana = self._optimize_learnable_fs(inputs, targets, self.morgana, self.morgana_optimizer)

            # Step 2: convert to binary masks. Detach first: the masks' autograd
            # graph was already consumed by backward() inside
            # _optimize_learnable_fs, and Arthur's loss below must not try to
            # backpropagate into the feature selectors through it.
            binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin.detach())
            binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana.detach())

            # Step 3: apply masks and compute logits
            masked_inputs_merlin = self.merlin.apply_mask(inputs, binary_mask_merlin)
            masked_inputs_morgana = self.morgana.apply_mask(inputs, binary_mask_morgana)

            # NOTE: eval mode so batchnorm running statistics are not updated
            # (this also disables any dropout in self.model during Arthur's update).
            self.model.eval()

            logits_merlin = self.model(masked_inputs_merlin)
            logits_morgana = self.model(masked_inputs_morgana)

            # Step 4: Arthur's combined loss over Merlin's and Morgana's masks
            merlin_loss = self.merlin.criterion(logits_merlin, targets)
            morgana_loss = self.morgana.criterion(logits_morgana, targets)
            loss = merlin_loss + self.gamma * morgana_loss

            loss.backward()
            self.optimizer.step()

            # Clear every gradient so nothing carries over into the next batch.
            self.optimizer.zero_grad()
            self.merlin_optimizer.zero_grad()
            self.morgana_optimizer.zero_grad()

            total_loss += loss.item()

            # Completeness = accuracy on Merlin's masks, soundness = on Morgana's
            batch_completeness = get_accuracy(logits_merlin, targets, mode="merlin", idk_class=self.num_classes)
            batch_soundness = get_accuracy(logits_morgana, targets, mode="morgana", idk_class=self.num_classes)

            total_completeness += batch_completeness
            total_soundness += batch_soundness
            batch_count += 1

            # Running-mean loss; comp/sound show the current batch only.
            pbar.set_postfix({
                'loss': f'{total_loss/batch_count:.4f}',
                'comp': f'{100.*batch_completeness:.2f}%',
                'sound': f'{100.*batch_soundness:.2f}%'
            })

        if batch_count == 0:
            # Empty loader: avoid ZeroDivisionError and report zeros, matching
            # the zero default used for the accuracy averages.
            return {
                'train_loss': 0.0,
                'train_completeness': 0.0,
                'train_soundness': 0.0
            }

        avg_completeness = total_completeness / batch_count
        avg_soundness = total_soundness / batch_count

        return {
            'train_loss': total_loss / batch_count,
            'train_completeness': 100. * avg_completeness,
            'train_soundness': 100. * avg_soundness
        }


    def _optimize_learnable_fs(self, inputs, targets, feature_selector, optimizer):
        """
        Single optimization step for the feature selector model (Merlin/Morgana)
        """
        self.model.eval()

        continuous_mask = feature_selector(inputs)
        l1_penalty = self.l1_penalty_coefficient * torch.mean(torch.abs(continuous_mask))
    
        masked_inputs = feature_selector.apply_mask(inputs, continuous_mask)
        logits = self.model(masked_inputs)

        if feature_selector.mode == "merlin":
            loss = feature_selector.criterion(logits, targets) + l1_penalty
        elif feature_selector.mode == "morgana":
            loss = -feature_selector.criterion(logits, targets) + l1_penalty

        loss.backward()
        optimizer.step()
        self.optimizer.zero_grad() # Arthur optimizer
        optimizer.zero_grad() # Feature selector optimizer

        return continuous_mask
    

    def _validate_with_learnable_fs(self, loader: DataLoader):
        """Evaluate the classifier (Arthur) against the learnable Merlin/Morgana selectors.

        All three networks are put in eval mode, and every batch is processed
        without gradients. Merlin's and Morgana's continuous masks are binarised
        with ``get_binary_mask``, applied to the inputs, and classified by
        ``self.model``.

        Args:
            loader: Data loader to evaluate. The progress-bar label says
                "Validating" when it is ``self.val_loader`` and "Testing" otherwise.

        Returns:
            dict: ``loss`` (mean of the per-batch ``merlin_loss + gamma *
            morgana_loss``), ``acc`` (same value as ``completeness``, kept for
            backward compatibility), ``completeness`` and ``soundness`` (mean of
            the per-batch ``get_accuracy`` values, multiplied by 100). Every
            entry is 0.0 when ``loader`` yields no batches.
        """
        self.model.eval()
        self.merlin.eval()
        self.morgana.eval()

        total_loss = 0
        total_completeness = 0
        total_soundness = 0
        batch_count = 0

        # Identity check: DataLoader defines no __eq__, so "is" states the intent.
        desc = 'Validating with Merlin-Arthur' if loader is self.val_loader else 'Testing with Merlin-Arthur'
        pbar = tqdm(loader, desc=desc)

        # No gradients are needed anywhere in evaluation.
        with torch.no_grad():
            for inputs, targets in pbar:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device).long()

                continuous_mask_merlin = self.merlin(inputs)
                continuous_mask_morgana = self.morgana(inputs)

                # Convert to binary masks using the selectors' own binarisation
                binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
                binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)

                masked_inputs_merlin = self.merlin.apply_mask(inputs, binary_mask_merlin)
                masked_inputs_morgana = self.morgana.apply_mask(inputs, binary_mask_morgana)

                logits_merlin = self.model(masked_inputs_merlin)
                logits_morgana = self.model(masked_inputs_morgana)

                # Same combined loss as used for training Arthur
                merlin_loss = self.merlin.criterion(logits_merlin, targets)
                morgana_loss = self.morgana.criterion(logits_morgana, targets)
                loss = merlin_loss + self.gamma * morgana_loss

                total_loss += loss.item()

                # Completeness = accuracy on Merlin's masks, soundness = on Morgana's
                batch_completeness = get_accuracy(logits_merlin, targets, mode="merlin", idk_class=self.num_classes)
                batch_soundness = get_accuracy(logits_morgana, targets, mode="morgana", idk_class=self.num_classes)

                total_completeness += batch_completeness
                total_soundness += batch_soundness
                batch_count += 1

                # Running-mean loss; comp/sound show the current batch only.
                pbar.set_postfix({
                    'loss': f'{total_loss/batch_count:.4f}',
                    'comp': f'{100.*batch_completeness:.2f}%',
                    'sound': f'{100.*batch_soundness:.2f}%'
                })

        if batch_count == 0:
            # Empty loader: avoid ZeroDivisionError and report zeros, matching
            # the zero default used for the accuracy averages.
            return {
                'loss': 0.0,
                'acc': 0.0,
                'completeness': 0.0,
                'soundness': 0.0
            }

        avg_completeness = total_completeness / batch_count
        avg_soundness = total_soundness / batch_count

        return {
            'loss': total_loss / batch_count,
            'acc': 100. * avg_completeness,  # alias of completeness for backward compatibility
            'completeness': 100. * avg_completeness,
            'soundness': 100. * avg_soundness
        }
    
    def save_checkpoint(self, state: Dict, filename: str):
        """Save a checkpoint by serializing ``state`` to ``filename`` with ``torch.save``.

        Args:
            state: Dictionary to persist (presumably model/optimizer state dicts
                plus training metadata — confirm against callers).
            filename: Destination path for the serialized checkpoint.
        """
        torch.save(state, filename)