import os
import torch
import torch.nn as nn
import random
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm   
from copy import deepcopy

from data_ncb.CLEVR_Hans_image_dataset import CLEVRHansDataset
from utils.metrics import get_accuracy
from pixel_space.classifier_pixel import ResNet18
from pixel_space.config_dataclass_pixel import TrainerConfig, DatasetConfig, ModelConfig, BooleanConfig, FeatureSelectorConfig
from pixel_space.feature_selectors_pixel import SFWPixelFeatureSelector, ModelPixelFeatureSelector, SimpleNet

class BasePixelSpaceTrainer:
    def __init__(
        self, 
        trainer_config: TrainerConfig,
        dataset_config: DatasetConfig,
        model_config: ModelConfig, 
        bool_config: BooleanConfig,
        feature_selector_config: FeatureSelectorConfig,
        logger=None
    ):
        """Initialize trainer with complete configuration setup
        
        Args:
            trainer_config: Configuration for training parameters
            dataset_config: Configuration for dataset parameters
            model_config: Configuration for model parameters
            bool_config: Configuration for boolean parameters
            logger: Logger for logging training metrics
        """
        # Setup device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Setup all configurations
        self._setup_trainer_config(trainer_config)
        self._setup_dataset_config(dataset_config)
        self._setup_model_config(model_config)
        self._setup_bool_config(bool_config)
        self._setup_feature_selector_config(feature_selector_config)
        
        # Initialize empty attributes
        self.model = None
        self.optimizer = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None

        # set seed
        self._setup_seed(self.seed)

        # Store logger instance
        self.logger = logger

    def _setup_trainer_config(self, config: TrainerConfig) -> None:
        """Setup trainer configuration parameters"""
        self.epochs = config.epochs
        self.approach = config.approach
        self.learning_rate = config.lr
        self.weight_decay = config.weight_decay
        self.seed = config.seed
        self.res_dir = config.res_dir
        self.use_wandb = config.wandb
        self.early_stopping = config.early_stopping
        self.patience = config.patience
        self.trainer_config = config

    def _setup_dataset_config(self, config: DatasetConfig) -> None:
        """Setup dataset configuration parameters"""
        self.data_dir = config.data_dir
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.dataset_config = config
        self.unconf_split = config.unconf_split
        self.num_classes = config.num_classes

    def _setup_model_config(self, config: ModelConfig) -> None:
        """Setup model configuration parameters"""
        self.model_name = config.model
        self.pretrained_model = config.pretrained_model
        self.pretrained_path = config.pretrained_path
        self.model_config = config

    def _setup_bool_config(self, config: BooleanConfig) -> None:
        """Setup boolean configuration parameters"""
        self.save_model = config.save_model
        self.boolean_config = config

    def _setup_feature_selector_config(self, config: FeatureSelectorConfig) -> None:
        """Setup feature selector configuration parameters"""
        self.feature_selector_config = config
        # If lr_fs is set, use it for both lr_merlin and lr_morgana
        if config.lr_fs is not None:
            print(f"Using lr_fs for both Merlin and Morgana")
            self.lr_merlin = config.lr_fs
            self.lr_morgana = config.lr_fs
        else:
            self.lr_merlin = config.lr_merlin
            self.lr_morgana = config.lr_morgana
        if config.weight_decay_fs is not None:
            print(f"Using weight_decay_fs for both Merlin and Morgana")
            self.weight_decay_merlin = config.weight_decay_fs
            self.weight_decay_morgana = config.weight_decay_fs
        else:
            self.weight_decay_merlin = config.weight_decay_merlin
            self.weight_decay_morgana = config.weight_decay_morgana
        self.gamma = config.gamma
        self.l1_penalty_coefficient = config.l1_penalty_coefficient
        self.l2_penalty_coefficient = config.l2_penalty_coefficient
        self.tv_penalty_coefficient = config.tv_penalty_coefficient
        self.mask_size = config.mask_size
        self.sfw_max_iterations = config.sfw_max_iterations
        self.sfw_patience = config.sfw_patience
        self.unet_steps = config.unet_steps

    def _setup_seed(self, seed: int) -> None:
        """Setup all seeds for full reproducibility
        Args:
            seed: Integer seed value
        """        
        # Python
        random.seed(seed)
        # Numpy
        np.random.seed(seed)
        # PyTorch
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU
        # Deterministic operations
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        print(f"\nSeed set to {seed} for reproducibility!")

    def setup_data(self):
        print(f'Loading images from {self.data_dir}')
        self._load_images()
        self._print_dataset_info(self.train_loader, self.val_loader, self.test_loader)

    def _load_images(self):
        """Load and process image datasets"""

        conf_version = self.data_dir.split(os.path.sep)[-2]

        # Create datasets with wrapper
        dataset_train = CLEVRHansDatasetWrapper(
            self.data_dir, "train", lexi=True, conf_vers=conf_version
        )
        dataset_val = CLEVRHansDatasetWrapper(
            self.data_dir, "val", lexi=True, conf_vers=conf_version
        )
        dataset_test = CLEVRHansDatasetWrapper(
            self.data_dir, "test", lexi=True, conf_vers=conf_version
        )

        if self.unconf_split:
            print('Splitting validation set from train set...')
            # Use val set as test set
            dataset_test = dataset_val
            
            # Calculate how many samples we need per class to match test set size
            test_size = len(dataset_test)
            samples_per_class = test_size // self.num_classes
            
            print(f"Creating validation set with {test_size} samples ({samples_per_class} per class)")
            
            # Create class-wise indices more efficiently
            class_indices = [[] for _ in range(self.num_classes)]
            
            # Use a DataLoader to efficiently iterate through the dataset
            temp_loader = DataLoader(
                dataset_train, 
                batch_size=1000,  # Large batch size for efficiency
                shuffle=False,
                num_workers=self.num_workers
            )
            
            idx = 0
            for batch_imgs, batch_labels in temp_loader:
                for label in batch_labels:
                    class_indices[label.item()].append(idx)
                    idx += 1
            
            # Create indices for validation and training sets
            val_indices = []
            train_indices = []
            
            # Process each class
            for class_idx in range(self.num_classes):
                indices = class_indices[class_idx]
                
                if len(indices) < samples_per_class:
                    print(f"Warning: Only {len(indices)} samples available for class {class_idx}, using all of them")
                    val_samples = len(indices)
                else:
                    val_samples = samples_per_class
                
                # Randomly shuffle the indices
                np.random.shuffle(indices)
                
                # Split indices for val and remaining train
                val_indices.extend(indices[:val_samples])
                train_indices.extend(indices[val_samples:])
            
            # Create subset datasets
            dataset_val = Subset(dataset_train, val_indices)
            dataset_train = Subset(dataset_train, train_indices)
            
            print(f"Split train set into {len(dataset_train)} train and {len(dataset_val)} validation samples")

        # Create DataLoaders
        self.train_loader = DataLoader(
            dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )
        
        self.val_loader = DataLoader(
            dataset_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        self.test_loader = DataLoader(
            dataset_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def _print_dataset_info(self, train_loader, val_loader, test_loader):
        print(f"Train dataset size: {len(train_loader.dataset)}")
        print(f"Validation dataset size: {len(val_loader.dataset)}")
        print(f"Test dataset size: {len(test_loader.dataset)}")

        # Print dimensions of one datapoint
        sample_data, _ = next(iter(train_loader))
        print(f"Input data dimension: {sample_data.shape[1:]}")

    def setup_model(self):
        if self.model_name.lower() == 'resnet18':
            self.model = ResNet18(dim_output=self.num_classes+1)
        else:
            raise ValueError(f"Model {self.model_name} not supported")
        
        # Move model to device
        self.model = self.model.to(self.device)
        
        # Setup optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        # Setup criterion
        self.criterion = nn.CrossEntropyLoss()

        self._print_model_info()

        if self.approach == "sfw":
            # Initialize feature selector (Merlin)
            self.merlin = SFWPixelFeatureSelector(
                mask_size=self.mask_size,
                mode="merlin",
                lr=self.lr_merlin,
                l1_penalty_coefficient=self.l1_penalty_coefficient,
                sfw_max_iterations=self.sfw_max_iterations,
                sfw_patience=self.sfw_patience
            ).to(self.device)

            # Initialize feature selector (Morgana)
            self.morgana = SFWPixelFeatureSelector(
                mask_size=self.mask_size,
                mode="morgana",
                lr=self.lr_morgana,
                l1_penalty_coefficient=self.l1_penalty_coefficient,
                idk_class=self.num_classes,
                sfw_max_iterations=self.sfw_max_iterations,
                sfw_patience=self.sfw_patience
            ).to(self.device)

        elif self.approach == "unet":
            self.merlin_model = SimpleNet(
                n_channels = 3, 
                bilinear = True, 
                apply_sigmoid = True
                )
            
            # Create a deep copy for morgana
            self.morgana_model = deepcopy(self.merlin_model)
            
            self.merlin = ModelPixelFeatureSelector(
                mask_size=self.mask_size,
                mode="merlin",
                idk_class=self.num_classes,
                model=self.merlin_model
            ).to(self.device)

            self.morgana = ModelPixelFeatureSelector(
                mask_size=self.mask_size,
                mode="morgana",
                idk_class=self.num_classes,
                model=self.morgana_model
            ).to(self.device)

            self.merlin_optimizer = torch.optim.Adam(self.merlin.parameters(), lr=self.lr_merlin, weight_decay=self.weight_decay_merlin)
            self.morgana_optimizer = torch.optim.Adam(self.morgana.parameters(), lr=self.lr_morgana, weight_decay=self.weight_decay_morgana)

        if self.pretrained_model:
            try:
                print(f"\nLoading pretrained model from {self.pretrained_path}")
                checkpoint = torch.load(self.pretrained_path, map_location=self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                                
                # Print information about the loaded model
                print(f"Loaded checkpoint from epoch {checkpoint['epoch']+1}")
                if 'best_val_acc' in checkpoint:
                    print(f"Validation accuracy: {checkpoint['best_val_acc']:.2f}%")
                    
                # Load merlin and morgana state dicts if they exist
                if 'merlin_state_dict' in checkpoint:
                    self.merlin.load_state_dict(checkpoint['merlin_state_dict'])
                    print(f"Loaded merlin state dict from epoch {checkpoint['epoch']+1}")
                if 'morgana_state_dict' in checkpoint:
                    self.morgana.load_state_dict(checkpoint['morgana_state_dict'])
                    print(f"Loaded morgana state dict from epoch {checkpoint['epoch']+1}")
            except FileNotFoundError:
                print(f"\nWARNING: Pretrained model not found at {self.pretrained_path}")
                print("Starting with randomly initialized model.")

        if self.approach != 'regular':
            self._print_feature_selector_info()

    def _print_model_info(self):
        print("\nModel setup:")
        print(f"Architecture: {self.model_name}")
        print(f"Number of parameters: {sum(p.numel() for p in self.model.parameters())}")
        print(f"Learning rate: {self.learning_rate}")
        print(f"Device: {self.device}")

    def _print_feature_selector_info(self):
        """Helper method to print feature selector information"""
        print("\nFeature Selector setup:")
        if self.approach == "sfw":
            print(f"Method: Stochastic Frank-Wolfe (SFW)")
            print(f"SFW max iterations: {self.sfw_max_iterations}")
            print(f"SFW patience: {self.sfw_patience}")
        elif self.approach == "unet":
            print(f"Method: U-Net")
            num_params = sum(p.numel() for p in self.merlin.parameters()) # type: ignore
            print(f"Number of parameters: {num_params:,}")  # Formatted with commas
        print(f"Mask size (sparsity): {self.mask_size} features")
        print(f"Gamma: {self.gamma}")
        print(f"Merlin learning rate: {self.lr_merlin}")
        print(f"Morgana learning rate: {self.lr_morgana}")
        print(f"L1 penalty coefficient: {self.l1_penalty_coefficient}")
        print(f"Device: {self.device}")

    def train(self):
        if self.approach == 'regular':
            return self._train_regular()
        elif self.approach == 'sfw':    
            return self._train_sfw()
        elif self.approach == 'unet':
            return self._train_unet()
        else:
            raise ValueError(f"Approach {self.approach} not supported")
        
    def train_epoch(self):
        """Train model for one epoch using current approach
        
        Returns:
            dict: Dictionary containing training metrics
        """
        if self.approach == "regular":
            return self._train_epoch_regular()
        elif self.approach == "sfw":
            return self._train_epoch_sfw()
        elif self.approach == "unet":
            return self._train_epoch_unet()
        else:
            raise ValueError(f"Approach {self.approach} not supported")
    
    def validate(self, loader: DataLoader):
        """Run validation based on specified approach
        
        Returns:
            dict: Dictionary containing validation metrics
        """
        if self.approach == "regular":
            return self._validate_regular(loader)
        elif self.approach == "sfw":
            return self._validate_sfw(loader)
        elif self.approach == "unet":
            return self._validate_unet(loader)
        else:
            raise ValueError(f"Approach {self.approach} not supported")
        
    def _train_regular(self):
        """Regular training approach
        
        Returns:
            dict: Dictionary containing best validation metrics
        """
        print(f"\nStarting regular training for {self.epochs} epochs...")
        best_metrics = {
            'best_val_acc': 0,
            'best_epoch': -1,
            'val_loss': float('inf')
        }
        no_improvement = 0  # Counter for early stopping
        best_model_state = None  # Store the best model state

        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch+1}/{self.epochs}")
            
            # Train and validate
            train_metrics = self.train_epoch()
            val_metrics = self.validate(self.val_loader)
            
            # Log metrics to wandb if enabled
            if self.logger is not None:
                self.logger.log({
                    'epoch': epoch,
                    'train/loss': train_metrics['train_loss'],
                    'train/accuracy': train_metrics['train_acc'],
                    'val/loss': val_metrics['loss'],
                    'val/accuracy': val_metrics['acc'],
                    'learning_rate': self.learning_rate,
                }, step=epoch)

            # Early stopping based on validation loss improvement
            if val_metrics['loss'] < best_metrics['val_loss']:
                best_metrics.update({
                    'best_val_acc': val_metrics['acc'],
                    'best_epoch': epoch,
                    'val_loss': val_metrics['loss'],
                    'train_acc': train_metrics['train_acc'],
                    'train_loss': train_metrics['train_loss']
                })
                no_improvement = 0  # Reset counter on improvement
                
                # Deep copy of the model's state dictionary
                best_model_state = {
                    'epoch': epoch,
                    'model_state_dict': {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
                }

                if self.save_model:
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        **best_metrics
                    }

                    # Base directory name
                    base_dir = f'{self.approach}_{self.model_name}_seed{self.seed}'
                    
                    # If using wandb, append the run name
                    if self.logger is not None:
                        base_dir = f'{base_dir}_{self.logger.run.name}'

                    checkpoint_dir = os.path.join(self.res_dir, base_dir)
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
                    
                    torch.save(checkpoint, checkpoint_path)
                    print(f"Updated best model checkpoint with validation accuracy: {best_metrics['best_val_acc']:.2f}%")
            else:
                no_improvement += 1
                print(f"No improvement for {no_improvement} epoch(s).")

            # Check early stopping criterion
            if self.early_stopping and no_improvement >= self.patience:
                print(f"Early stopping triggered after {no_improvement} epochs with no improvement.")
                break
                    
        print(f"\nTraining completed! Best validation accuracy: {best_metrics['best_val_acc']:.2f}% "
              f"at epoch {best_metrics['best_epoch']+1}\n")
        
        # Load the best model state
        self.model.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                    for k, v in best_model_state['model_state_dict'].items()})

        test_metrics = self.validate(self.test_loader)
        print(f"Test accuracy: {test_metrics['acc']:.2f}%\n")

        if self.logger is not None:
            self.logger.log({
                'val/best_loss': best_metrics['val_loss'],
                'val/best_epoch': best_metrics['best_epoch'],
                'val/best_accuracy': best_metrics['best_val_acc'],
                'test/best_accuracy': test_metrics['acc']
            })

        return best_metrics

    def _train_epoch_regular(self):
        self.model.train()

         # Progress bar for training
        pbar = tqdm(self.train_loader, desc='Training')

        total_loss = 0
        correct = 0
        total = 0

        for inputs, targets in pbar:
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            
            # Backward pass
            loss.backward()
            self.optimizer.step()

            # Update metrics
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f'{total_loss/len(self.train_loader):.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })
        
        # Compute epoch metrics
        metrics = {
            'train_loss': total_loss / len(self.train_loader),
            'train_acc': 100. * correct / total
        }
        
        return metrics

    def _validate_regular(self, loader: DataLoader):
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        
        # Progress bar for validation
        pbar = tqdm(loader, desc='Validating' if loader == self.val_loader else 'Testing')
        
        with torch.no_grad():
            for inputs, targets in pbar:
                # Move data to device
                inputs = inputs.to(self.device)
                targets = targets.to(self.device).long()
                
                # Forward pass
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                
                # Update metrics
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                
                # Update progress bar
                pbar.set_postfix({
                    'loss': f'{total_loss/len(loader):.4f}',
                    'acc': f'{100.*correct/total:.2f}%'
                })
        
        # Compute validation metrics
        metrics = {
            'loss': total_loss / len(loader),
            'acc': 100. * correct / total
        }
        
        return metrics
    
    def _train_sfw(self):
        """Merlin-Arthur training approach using SFW
        
        Returns:
            dict: Dictionary containing best validation metrics
        """
        print(f"\nStarting Merlin-Arthur training with SFW for {self.epochs} epochs...")
            
        best_metrics = {
            'best_combined_metric': 0,
            'best_epoch': -1,
            'val_loss': float('inf'),
            'val_completeness': 0,
            'val_soundness': 0,
            'train_completeness': 0,
            'train_soundness': 0,
            'train_loss': float('inf')
        }
        
        best_model_state = None  # Store the best model state
        
        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch+1}/{self.epochs}")
            
            # Train and validate
            train_metrics = self._train_epoch_sfw()
            val_metrics = self._validate_sfw(self.val_loader)
            
            # Update validation metrics to include completeness and soundness
            if 'val_completeness' not in val_metrics:
                # For backward compatibility, use val_acc as completeness if not already present
                val_metrics['val_completeness'] = val_metrics['acc']
                val_metrics['val_soundness'] = 0  # Default if not present
            
            # Log metrics to wandb if enabled
            if self.logger is not None:
                metrics_dict = {
                    'epoch': epoch,
                    'train/loss': train_metrics['train_loss'],
                    'train/completeness': train_metrics['train_completeness'],
                    'train/soundness': train_metrics['train_soundness'],
                    'val/loss': val_metrics['loss'],
                    'val/completeness': val_metrics['completeness'],
                    'val/soundness': val_metrics['soundness'],
                    'mask/sparsity': self.mask_size,
                }
                self.logger.log(metrics_dict, step=epoch)
            
            # Consolidated Epoch Summary
            print("\n" + "=" * 60)
            print("EPOCH SUMMARY")
            print("=" * 60)
            print(f"Epoch {epoch+1}/{self.epochs}")
            print("-" * 60)
            print("Training Statistics:")
            print(f"   Loss            : {train_metrics['train_loss']:.4f}")
            print(f"   Completeness    : {train_metrics['train_completeness']:.2f}%")
            print(f"   Soundness       : {train_metrics['train_soundness']:.2f}%")
            print("-" * 60)
            print("Validation Statistics:")
            print(f"   Loss            : {val_metrics['loss']:.4f}")
            print(f"   Completeness    : {val_metrics['completeness']:.2f}%")
            print(f"   Soundness       : {val_metrics['soundness']:.2f}%")
            print("-" * 60)

            # Update best metrics and save model
            if val_metrics['soundness'] > 90:
                # If soundness threshold is met, add a large bonus to ensure it's better than sub-90 models
                # but still maintains the comp+sound ordering among models above 90%
                combined_metric = val_metrics['completeness'] + val_metrics['soundness'] + 100
            else:
                combined_metric = val_metrics['completeness'] + val_metrics['soundness']

            if combined_metric > best_metrics['best_combined_metric']:
                best_metrics.update({
                    'best_combined_metric': combined_metric,
                    'best_epoch': epoch,
                    'val_loss': val_metrics['loss'],
                    'val_completeness': val_metrics['completeness'],
                    'val_soundness': val_metrics['soundness'],
                    'train_completeness': train_metrics['train_completeness'],
                    'train_soundness': train_metrics['train_soundness'],
                    'train_loss': train_metrics['train_loss']
                })
                
                # Deep copy of the model's state dictionary
                best_model_state = {
                    'epoch': epoch,
                    'model_state_dict': {k: v.cpu().clone() for k, v in self.model.state_dict().items()},
                    'merlin_state_dict': {k: v.cpu().clone() for k, v in self.merlin.state_dict().items()},
                    'morgana_state_dict': {k: v.cpu().clone() for k, v in self.morgana.state_dict().items()}
                }
                
                if self.save_model:
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'merlin_state_dict': self.merlin.state_dict(),
                        'morgana_state_dict': self.morgana.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        **best_metrics
                    }
                    # Base directory name
                    base_dir = f'{self.approach}_{self.model_name}_seed{self.seed}'
                    
                    # If using wandb, append the run name
                    if self.logger is not None:
                        base_dir = f'{base_dir}_{self.logger.run.name}'
                    
                    checkpoint_dir = os.path.join(self.res_dir, base_dir)
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
                        
                    torch.save(checkpoint, checkpoint_path)
                    print(f"Updated best model checkpoint with combined metric: {combined_metric:.2f}")

        # Print final summary with best metrics
        print(f"\nTraining completed!")
        # Print final metrics with best model
        print("\nFinal evaluation with best model:\n")
        print(f"Best combined metric at epoch {best_metrics['best_epoch']+1}")
        print(f"Best validation completeness: {best_metrics['val_completeness']:.2f}%")
        print(f"Best validation soundness: {best_metrics['val_soundness']:.2f}%")
        
        # Load the best model state
        if best_model_state:
            self.model.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                      for k, v in best_model_state['model_state_dict'].items()})
            self.merlin.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                       for k, v in best_model_state['merlin_state_dict'].items()})
            self.morgana.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                        for k, v in best_model_state['morgana_state_dict'].items()})
            
        test_metrics = self._validate_sfw(self.test_loader)
        print(f"Best test completeness: {test_metrics['completeness']:.2f}%")
        print(f"Best test soundness: {test_metrics['soundness']:.2f}%")

        # Log final metrics to wandb if enabled
        if self.logger is not None:
            self.logger.log({
                'val/best_combined_metric': best_metrics['best_combined_metric'],
                'val/best_loss': best_metrics['val_loss'],
                'val/best_epoch': best_metrics['best_epoch'],
                'val/best_completeness': best_metrics['val_completeness'],
                'val/best_soundness': best_metrics['val_soundness'],
                'test/best_completeness': test_metrics['completeness'],
                'test/best_soundness': test_metrics['soundness']
            })

        return best_metrics

    def _train_epoch_sfw(self):
        """Train the classifier (Arthur) for one epoch using the SFW Merlin-Arthur approach.

        For every batch, Merlin and Morgana compute continuous masks. Each is
        called with the inputs, the targets and the classifier. The masks are
        binarised with each selector's ``get_binary_mask``, and the classifier
        is updated on ``merlin_loss + self.gamma * morgana_loss``, evaluated on
        the masked inputs.

        Returns:
            dict: ``train_loss`` (mean batch loss), ``train_completeness`` and
            ``train_soundness`` (mean batch accuracies, in percent). All three
            are 0 when the loader yields no batches.
        """
        # NOTE(review): the classifier stays in train mode while Merlin/Morgana
        # compute their masks. If that computation runs forward passes through
        # self.model, BatchNorm running statistics are updated by them. The
        # U-Net path switches the classifier to eval for exactly this reason,
        # so confirm this is intended.
        self.model.train()
        total_loss = 0
        total_completeness = 0
        total_soundness = 0
        batch_count = 0

        # Progress bar for training
        pbar = tqdm(self.train_loader, desc='Training with Merlin-Arthur')

        for inputs, targets in pbar:
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Step 1: Optimize masks using SFW
            continuous_mask_merlin = self.merlin(inputs, targets, self.model)
            continuous_mask_morgana = self.morgana(inputs, targets, self.model)

            # Step 2: Convert to binary masks using top-k selection
            binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
            binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)

            # Step 3: Apply mask and compute logits.
            # Classifier gradients are reset only now, after mask optimisation,
            # so gradients that optimisation may have left on the classifier
            # do not leak into this update.
            self.optimizer.zero_grad()
            masked_inputs_merlin = self.merlin.apply_mask(inputs, binary_mask_merlin)
            masked_inputs_morgana = self.morgana.apply_mask(inputs, binary_mask_morgana)
            logits_merlin = self.model(masked_inputs_merlin)
            logits_morgana = self.model(masked_inputs_morgana)

            # Step 4: Calculate loss
            merlin_loss = self.merlin.criterion(logits_merlin, targets)
            morgana_loss = self.morgana.criterion(logits_morgana, targets)
            loss = merlin_loss + self.gamma * morgana_loss

            # Backward pass for classifier
            loss.backward()
            self.optimizer.step()

            # Update metrics
            total_loss += loss.item()

            # Calculate accuracies (completeness and soundness)
            batch_completeness = get_accuracy(logits_merlin, targets, mode="merlin", idk_class=self.num_classes)
            batch_soundness = get_accuracy(logits_morgana, targets, mode="morgana", idk_class=self.num_classes)

            # Accumulate for epoch average
            total_completeness += batch_completeness
            total_soundness += batch_soundness
            batch_count += 1

            # Update progress bar with running loss and per-batch accuracies
            pbar.set_postfix({
                'loss': f'{total_loss/batch_count:.4f}',
                'comp': f'{100.*batch_completeness:.2f}%',
                'sound': f'{100.*batch_soundness:.2f}%'
            })

        # The running totals are still 0 when no batch was seen, so dividing
        # by max(batch_count, 1) yields 0 metrics instead of a ZeroDivisionError.
        denom = max(batch_count, 1)

        # Compute epoch metrics
        metrics = {
            'train_loss': total_loss / denom,
            'train_completeness': 100. * (total_completeness / denom),
            'train_soundness': 100. * (total_soundness / denom)
        }

        return metrics

    def _validate_sfw(self, loader: DataLoader):
        """Evaluate the classifier with SFW-optimised Merlin/Morgana masks.

        The classifier is put in eval mode. Mask optimisation runs under
        ``torch.enable_grad()`` because it needs gradients. Binarisation,
        masking, classification and loss computation run under
        ``torch.no_grad()``.

        Args:
            loader: Data loader to evaluate. The progress-bar label says
                "Validating" when this is ``self.val_loader`` and "Testing"
                otherwise.

        Returns:
            dict: ``loss`` (mean batch loss), ``completeness`` and
            ``soundness`` (mean batch accuracies, in percent), and ``acc``
            (alias of ``completeness``, kept for backward compatibility).
            All are 0 when the loader yields no batches.
        """
        self.model.eval()
        total_loss = 0
        total_completeness = 0
        total_soundness = 0
        batch_count = 0

        # Progress bar for validation. Identity check: we want to know whether
        # this is *the* validation loader object.
        pbar = tqdm(loader, desc='Validating with Merlin-Arthur' if loader is self.val_loader else 'Testing with Merlin-Arthur')

        for inputs, targets in pbar:
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Temporarily enable gradients for mask optimization
            with torch.enable_grad():
                # Optimize masks (in eval mode) - requires gradients
                continuous_mask_merlin = self.merlin(inputs, targets, self.model)
                continuous_mask_morgana = self.morgana(inputs, targets, self.model)

            # Disable gradients for the rest of the validation process
            with torch.no_grad():
                # Convert to binary masks using top-k selection
                binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
                binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)

                # Apply masks and get predictions
                masked_inputs_merlin = self.merlin.apply_mask(inputs, binary_mask_merlin)
                masked_inputs_morgana = self.morgana.apply_mask(inputs, binary_mask_morgana)

                logits_merlin = self.model(masked_inputs_merlin)
                logits_morgana = self.model(masked_inputs_morgana)

                # Calculate loss
                merlin_loss = self.merlin.criterion(logits_merlin, targets)
                morgana_loss = self.morgana.criterion(logits_morgana, targets)
                loss = merlin_loss + self.gamma * morgana_loss

                # Update metrics
                total_loss += loss.item()

                # Calculate accuracies (completeness and soundness)
                batch_completeness = get_accuracy(logits_merlin, targets, mode="merlin", idk_class=self.num_classes)
                batch_soundness = get_accuracy(logits_morgana, targets, mode="morgana", idk_class=self.num_classes)

                # Accumulate for epoch average
                total_completeness += batch_completeness
                total_soundness += batch_soundness
                batch_count += 1

                # Update progress bar
                pbar.set_postfix({
                    'loss': f'{total_loss/batch_count:.4f}',
                    'comp': f'{100.*batch_completeness:.2f}%',
                    'sound': f'{100.*batch_soundness:.2f}%'
                })

        # The running totals are still 0 when no batch was seen, so dividing
        # by max(batch_count, 1) yields 0 metrics instead of a ZeroDivisionError.
        denom = max(batch_count, 1)
        avg_completeness = total_completeness / denom
        avg_soundness = total_soundness / denom

        # Compute validation metrics
        metrics = {
            'loss': total_loss / denom,
            'acc': 100. * avg_completeness,  # Keep val_acc for backward compatibility
            'completeness': 100. * avg_completeness,
            'soundness': 100. * avg_soundness
        }

        return metrics
    
    def _train_unet(self):
        """Train model using U-Net approach
        
        Returns:
            dict: Dictionary containing training metrics
        """
        print(f"\nStarting Merlin-Arthur training with U-Net feature selectors for {self.epochs} epochs...")
            
        best_metrics = {
            'best_combined_metric': 0,
            'best_epoch': -1,
            'val_loss': float('inf'),
            'val_completeness': 0,
            'val_soundness': 0,
            'train_completeness': 0,
            'train_soundness': 0,
            'train_loss': float('inf')
        }

        best_model_state = None  # Store the best model state

        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch+1}/{self.epochs}")
            
            # Train and validate
            train_metrics = self._train_epoch_unet()
            val_metrics = self._validate_unet(self.val_loader)
            
            # Update validation metrics to include completeness and soundness
            if 'val_completeness' not in val_metrics:
                # For backward compatibility, use val_acc as completeness if not already present
                val_metrics['val_completeness'] = val_metrics['acc']
                val_metrics['val_soundness'] = 0  # Default if not present
            
            # Log metrics to wandb if enabled
            if self.logger is not None:
                metrics_dict = {
                    'epoch': epoch,
                    'train/loss': train_metrics['train_loss'],
                    'train/completeness': train_metrics['train_completeness'],
                    'train/soundness': train_metrics['train_soundness'],
                    'val/loss': val_metrics['loss'],
                    'val/completeness': val_metrics['completeness'],
                    'val/soundness': val_metrics['soundness'],
                    'mask/sparsity': self.mask_size,
                }
                self.logger.log(metrics_dict, step=epoch)

            # Consolidated Epoch Summary
            print("\n" + "=" * 60)
            print("EPOCH SUMMARY")
            print("=" * 60)
            print(f"Epoch {epoch+1}/{self.epochs}")
            print("-" * 60)
            print("Training Statistics:")
            print(f"   Loss            : {train_metrics['train_loss']:.4f}")
            print(f"   Completeness    : {train_metrics['train_completeness']:.2f}%")
            print(f"   Soundness       : {train_metrics['train_soundness']:.2f}%")
            print("-" * 60)
            print("Validation Statistics:")
            print(f"   Loss            : {val_metrics['loss']:.4f}")
            print(f"   Completeness    : {val_metrics['completeness']:.2f}%")
            print(f"   Soundness       : {val_metrics['soundness']:.2f}%")
            print("-" * 60)

            # Update best metrics and save model
            if val_metrics['soundness'] > 90:
                # If soundness threshold is met, add a large bonus to ensure it's better than sub-90 models
                # but still maintains the comp+sound ordering among models above 90%
                combined_metric = val_metrics['completeness'] + val_metrics['soundness'] + 100
            else:
                combined_metric = val_metrics['completeness'] + val_metrics['soundness']
                
            if combined_metric > best_metrics['best_combined_metric']:
                best_metrics.update({
                    'best_combined_metric': combined_metric,
                    'best_epoch': epoch,
                    'val_loss': val_metrics['loss'],
                    'val_completeness': val_metrics['completeness'],
                    'val_soundness': val_metrics['soundness'],
                    'train_completeness': train_metrics['train_completeness'],
                    'train_soundness': train_metrics['train_soundness'],
                    'train_loss': train_metrics['train_loss']
                })
                
                # Deep copy of the model's state dictionary
                best_model_state = {
                    'epoch': epoch,
                    'model_state_dict': {k: v.cpu().clone() for k, v in self.model.state_dict().items()},
                    'merlin_state_dict': {k: v.cpu().clone() for k, v in self.merlin.state_dict().items()},
                    'morgana_state_dict': {k: v.cpu().clone() for k, v in self.morgana.state_dict().items()}
                }
                
                if self.save_model:
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'merlin_state_dict': self.merlin.state_dict(),
                        'morgana_state_dict': self.morgana.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'merlin_optimizer_state_dict': self.merlin_optimizer.state_dict(),
                        'morgana_optimizer_state_dict': self.morgana_optimizer.state_dict(),
                        **best_metrics
                    }

                    # Base directory name
                    base_dir = f'{self.approach}_classifier_{self.model_name}_seed{self.seed}'
                    
                    # If using wandb, append the run name
                    if self.logger is not None:
                        base_dir = f'{base_dir}_{self.logger.run.name}'

                    checkpoint_dir = os.path.join(self.res_dir, base_dir)
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
                        
                    torch.save(checkpoint, checkpoint_path)
                    print(f"Updated best model checkpoint with combined metric: {combined_metric:.2f}")

         # Print final summary with best metrics
        print(f"\nTraining completed!")
        # Print final metrics with best model
        print("\nFinal evaluation with best model:\n")
        print(f"Best combined metric at epoch {best_metrics['best_epoch']+1}")
        print(f"Best validation completeness: {best_metrics['val_completeness']:.2f}%")
        print(f"Best validation soundness: {best_metrics['val_soundness']:.2f}%")
        
        # Load the best model state
        if best_model_state:
            self.model.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                      for k, v in best_model_state['model_state_dict'].items()})
            self.merlin.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                       for k, v in best_model_state['merlin_state_dict'].items()})
            self.morgana.load_state_dict({k: v.cuda() if torch.cuda.is_available() else v 
                                        for k, v in best_model_state['morgana_state_dict'].items()})

        test_metrics = self._validate_unet(self.test_loader)
        print(f"Best test completeness: {test_metrics['completeness']:.2f}%")
        print(f"Best test soundness: {test_metrics['soundness']:.2f}%")

        # Log final metrics to wandb if enabled
        if self.logger is not None:
            self.logger.log({
                'val/best_combined_metric': best_metrics['best_combined_metric'],
                'val/best_loss': best_metrics['val_loss'],
                'val/best_epoch': best_metrics['best_epoch'],
                'val/best_completeness': best_metrics['val_completeness'],
                'val/best_soundness': best_metrics['val_soundness'],
                'test/best_completeness': test_metrics['completeness'],
                'test/best_soundness': test_metrics['soundness']
            })
        
        return best_metrics

    def _train_epoch_unet(self):
        """Train the classifier (Arthur) for one epoch using U-Net feature selectors.

        For every batch:

        1. Merlin and Morgana each take ``self.unet_steps`` optimisation steps
           via ``_optimize_unet`` and return their continuous masks.
        2. The masks are binarised.
        3. The classifier is updated on ``merlin_loss + self.gamma * morgana_loss``,
           evaluated on the masked inputs.

        The classifier is kept in eval mode for its forward passes so that
        BatchNorm statistics are not updated. ``_optimize_unet`` also switches
        it to eval. The ``self.model.train()`` call at the top is therefore
        overridden from the first batch on.

        Returns:
            dict: ``train_loss`` (mean batch loss), ``train_completeness`` and
            ``train_soundness`` (mean batch accuracies, in percent). All three
            are 0 when the loader yields no batches.
        """
        self.model.train()
        self.merlin.train()
        self.morgana.train()

        total_loss = 0
        total_completeness = 0
        total_soundness = 0
        batch_count = 0

        # Progress bar for training
        pbar = tqdm(self.train_loader, desc='Training with Merlin-Arthur')

        for inputs, targets in pbar:
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Step 1: Optimize masks using learnable feature selectors
            continuous_mask_merlin = self._optimize_unet(inputs, targets, self.merlin, self.merlin_optimizer, steps=self.unet_steps)
            continuous_mask_morgana = self._optimize_unet(inputs, targets, self.morgana, self.morgana_optimizer, steps=self.unet_steps)

            # Step 2: Convert to binary masks using top-k selection
            binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
            binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)

            # Step 3: Apply mask and compute logits
            masked_inputs_merlin = self.merlin.apply_mask(inputs, binary_mask_merlin)
            masked_inputs_morgana = self.morgana.apply_mask(inputs, binary_mask_morgana)

            self.model.eval()  # NOTE: Need to be in eval mode to prevent batchnorm from updating

            logits_merlin = self.model(masked_inputs_merlin)
            logits_morgana = self.model(masked_inputs_morgana)

            # Step 4: Calculate losses
            merlin_loss = self.merlin.criterion(logits_merlin, targets)
            morgana_loss = self.morgana.criterion(logits_morgana, targets)

            # Combined loss for the classifier
            loss = merlin_loss + self.gamma * morgana_loss

            loss.backward()

            # Update the classifier only. Then clear all three optimizers' gradients
            # so nothing from this backward pass carries over to the next batch's
            # selector updates.
            self.optimizer.step()

            self.optimizer.zero_grad()
            self.merlin_optimizer.zero_grad()
            self.morgana_optimizer.zero_grad()

            # Update metrics
            total_loss += loss.item()

            # Calculate accuracies (completeness and soundness)
            batch_completeness = get_accuracy(logits_merlin, targets, mode="merlin", idk_class=self.num_classes)
            batch_soundness = get_accuracy(logits_morgana, targets, mode="morgana", idk_class=self.num_classes)

            # Accumulate for epoch average
            total_completeness += batch_completeness
            total_soundness += batch_soundness
            batch_count += 1

            # Update progress bar with running loss and per-batch accuracies
            pbar.set_postfix({
                'loss': f'{total_loss/batch_count:.4f}',
                'comp': f'{100.*batch_completeness:.2f}%',
                'sound': f'{100.*batch_soundness:.2f}%'
            })

        # The running totals are still 0 when no batch was seen, so dividing
        # by max(batch_count, 1) yields 0 metrics instead of a ZeroDivisionError.
        denom = max(batch_count, 1)

        # Compute epoch metrics
        metrics = {
            'train_loss': total_loss / denom,
            'train_completeness': 100. * (total_completeness / denom),
            'train_soundness': 100. * (total_soundness / denom)
        }

        return metrics

    def _optimize_unet(self, inputs, targets, feature_selector, optimizer, steps=1):
        """
        Single optimization step for the U-Net (Merlin/Morgana)
        """
        self.model.eval()

        for _ in range(steps):
            continuous_mask = feature_selector(inputs)
            continuous_mask = feature_selector.normalize_l1(continuous_mask, self.mask_size)

            l1_penalty = self.l1_penalty_coefficient * torch.mean(torch.abs(continuous_mask))
            #l2_penalty = self.l2_penalty_coefficient * torch.mean(torch.square(continuous_mask))

            #tv_norm = torch.sum(torch.abs(continuous_mask[:, :, :, :-1] - continuous_mask[:, :, :, 1:]) ** 2) + torch.sum(
            #            torch.abs(continuous_mask[:, :, :-1, :] - continuous_mask[:, :, 1:, :]) ** 2) 
            #tv_norm = tv_norm / (continuous_mask.shape[0])
            #tv_penalty = self.tv_penalty_coefficient * tv_norm
 
            
            masked_inputs = feature_selector.apply_mask(inputs, continuous_mask)
            logits = self.model(masked_inputs)

            if feature_selector.mode == "merlin":
                loss = feature_selector.criterion(logits, targets) + l1_penalty #+ l2_penalty + tv_penalty
            elif feature_selector.mode == "morgana":
                loss = -feature_selector.criterion(logits, targets) + l1_penalty #+ l2_penalty + tv_penalty

        loss.backward()
        optimizer.step()
        self.optimizer.zero_grad() # Arthur optimizer
        optimizer.zero_grad() # Feature selector optimizer

        return continuous_mask

    def _validate_unet(self, loader: DataLoader):
        """Evaluate the classifier with U-Net Merlin/Morgana masks.

        The classifier and both selectors are put in eval mode, and the whole
        evaluation runs under ``torch.no_grad()``. The selectors are not
        optimised here. Each selector's output is normalised with
        ``normalize_l1(mask, self.mask_size)``, binarised and applied to the
        inputs before classification.

        Args:
            loader: Data loader to evaluate. The progress-bar label says
                "Validating" when this is ``self.val_loader`` and "Testing"
                otherwise.

        Returns:
            dict: ``loss`` (mean batch loss), ``completeness`` and
            ``soundness`` (mean batch accuracies, in percent), and ``acc``
            (alias of ``completeness``, kept for backward compatibility).
            All are 0 when the loader yields no batches.
        """
        self.model.eval()
        self.merlin.eval()
        self.morgana.eval()

        total_loss = 0
        total_completeness = 0
        total_soundness = 0
        batch_count = 0

        # Progress bar for validation. Identity check: we want to know whether
        # this is *the* validation loader object.
        pbar = tqdm(loader, desc='Validating with Merlin-Arthur' if loader is self.val_loader else 'Testing with Merlin-Arthur')

        for inputs, targets in pbar:
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device).long()

            # Disable gradients for the validation process
            with torch.no_grad():
                continuous_mask_merlin = self.merlin(inputs)
                continuous_mask_morgana = self.morgana(inputs)

                continuous_mask_merlin = self.merlin.normalize_l1(continuous_mask_merlin, self.mask_size)
                continuous_mask_morgana = self.morgana.normalize_l1(continuous_mask_morgana, self.mask_size)

                # Convert to binary masks using top-k selection
                binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
                binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)

                # Apply masks and get predictions
                masked_inputs_merlin = self.merlin.apply_mask(inputs, binary_mask_merlin)
                masked_inputs_morgana = self.morgana.apply_mask(inputs, binary_mask_morgana)

                logits_merlin = self.model(masked_inputs_merlin)
                logits_morgana = self.model(masked_inputs_morgana)

                # Calculate loss
                merlin_loss = self.merlin.criterion(logits_merlin, targets)
                morgana_loss = self.morgana.criterion(logits_morgana, targets)
                loss = merlin_loss + self.gamma * morgana_loss

                # Update metrics
                total_loss += loss.item()

                # Calculate accuracies (completeness and soundness)
                batch_completeness = get_accuracy(logits_merlin, targets, mode="merlin", idk_class=self.num_classes)
                batch_soundness = get_accuracy(logits_morgana, targets, mode="morgana", idk_class=self.num_classes)

                # Accumulate for epoch average
                total_completeness += batch_completeness
                total_soundness += batch_soundness
                batch_count += 1

                # Update progress bar
                pbar.set_postfix({
                    'loss': f'{total_loss/batch_count:.4f}',
                    'comp': f'{100.*batch_completeness:.2f}%',
                    'sound': f'{100.*batch_soundness:.2f}%'
                })

        # The running totals are still 0 when no batch was seen, so dividing
        # by max(batch_count, 1) yields 0 metrics instead of a ZeroDivisionError.
        denom = max(batch_count, 1)
        avg_completeness = total_completeness / denom
        avg_soundness = total_soundness / denom

        # Compute validation metrics
        metrics = {
            'loss': total_loss / denom,
            'acc': 100. * avg_completeness,  # Keep val_acc for backward compatibility
            'completeness': 100. * avg_completeness,
            'soundness': 100. * avg_soundness
        }

        return metrics
    
class CLEVRHansDatasetWrapper(CLEVRHansDataset):
    """CLEVRHansDataset variant that yields only ``(image, label)`` pairs.

    The parent's ``__getitem__`` result is unpacked as a 4-tuple. Only its
    first (image) and third (label) elements are kept, matching the
    ``(inputs, targets)`` pairs the pixel-space training loops iterate over.
    """

    def __getitem__(self, item):
        img, _second, lbl, _fourth = super().__getitem__(item)
        return (img, lbl)
    
    