# forward_forward/training/trainer.py (Enhanced with Multi-Mode Class Grouping)

import os
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import wandb
from omegaconf import OmegaConf
from PIL import Image
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ExponentialLR, PolynomialLR, OneCycleLR
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment

from forward_forward.evaluation.ff_evaluation import evaluate
from forward_forward.models.model_factory import SkipConnection
from forward_forward.data.transforms import ApplyTransform
from forward_forward.models.layers.class_grouping import ClassGroupingMode


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# LR scheduler classes selectable by name through a ``scheduler.type`` config
# entry. Each key is simply the class's own ``__name__``.
SCHEDULER_REGISTRY = {
    scheduler_cls.__name__: scheduler_cls
    for scheduler_cls in (StepLR, CosineAnnealingLR, ExponentialLR, PolynomialLR, OneCycleLR)
}

# Fallback DataLoader settings and initial augmentation strengths used by
# AugmentationManager when the config does not override them.
DEFAULT_AUGMENTATION_CONFIG = dict(
    batch_size=128,
    num_workers=2,
    initial_aug_pad=4,
    initial_cutout=8,
    initial_num_ops=3,
    initial_magnitude=12,
)

# Fallback optimizer hyperparameters; ``betas`` is handed to every optimizer
# this module builds.
DEFAULT_TRAINING_CONFIG = dict(
    conv_lr=0.001,
    conv_weight_decay=0.0001,
    scale_lr=0.001,
    scale_weight_decay=0.0001,
    betas=(0.9, 0.999),
)


class Cutout:
    """Cutout augmentation: black out one square patch at a random position.

    The patch is centred on a uniformly sampled pixel and clipped to the image
    borders, so patches near an edge come out smaller than ``size``.
    """

    def __init__(self, size: int = 16):
        # Nominal side length of the patch in pixels (the patch spans
        # ``2 * (size // 2)`` pixels per side before clipping).
        self.size = size

    def __call__(self, img: Image.Image) -> Image.Image:
        width, height = img.size  # PIL reports (width, height)
        centre_y = np.random.randint(height)
        centre_x = np.random.randint(width)
        half = self.size // 2

        top, bottom = np.clip([centre_y - half, centre_y + half], 0, height)
        left, right = np.clip([centre_x - half, centre_x + half], 0, width)

        pixels = np.array(img)  # np.array copies, so the slice below is writable
        pixels[top:bottom, left:right] = 0
        return Image.fromarray(pixels)


class AugmentationManager:
    """Builds the training DataLoader with the current augmentation settings."""

    def __init__(self, cfg: Any, train_subset: Any):
        self.cfg = cfg
        self.train_subset = train_subset
        defaults = DEFAULT_AUGMENTATION_CONFIG
        # ``aug_pad`` and ``cutout`` are only used by the stronger (currently
        # disabled) pipeline kept as a comment in ``_build_train_transform``.
        self.aug_pad = defaults["initial_aug_pad"]
        self.cutout = defaults["initial_cutout"]
        self.num_ops = defaults["initial_num_ops"]
        self.magnitude = defaults["initial_magnitude"]

    def _build_train_transform(self) -> "T.Compose":
        """Return the augmentation pipeline matching ``cfg.data.dataset``."""
        if self.cfg.data.dataset.lower() in ("mnist", "fashionmnist"):
            # MNIST-style data: light geometric jitter only.
            return T.Compose([
                T.RandomRotation(10),
                T.RandomAffine(0, translate=(0.1, 0.1)),
                T.ToTensor(),
            ])

        # A stronger pipeline that was tried earlier (kept for reference):
        #   T.RandomCrop(32, padding=self.aug_pad, padding_mode="reflect"),
        #   T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
        #   T.RandomHorizontalFlip(p=0.5),
        #   RandAugment(num_ops=self.num_ops, magnitude=self.magnitude),
        #   Cutout(size=self.cutout),
        #   T.RandomGrayscale(p=0.1),
        #   T.ToTensor(),
        return T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            RandAugment(num_ops=self.num_ops, magnitude=self.magnitude),
            T.ToTensor(),
        ])

    def create_train_loader(self) -> DataLoader:
        """Create a shuffled training DataLoader using the current augmentation.

        The batch size is read from ``cfg.consecutive_training.batch_size`` and
        the worker count from ``cfg.num_workers``, each falling back to
        ``DEFAULT_AUGMENTATION_CONFIG``.
        """
        batch_size = self.cfg.get("consecutive_training", {}).get(
            "batch_size", DEFAULT_AUGMENTATION_CONFIG["batch_size"]
        )
        augmented = ApplyTransform(self.train_subset, transform=self._build_train_transform())
        workers = self.cfg.get("num_workers", DEFAULT_AUGMENTATION_CONFIG["num_workers"])

        loader = DataLoader(
            augmented,
            batch_size=batch_size,
            shuffle=True,
            num_workers=workers,
        )

        print(f"🔄 Updated training loader - batch_size: {batch_size}, "
              f"ops={self.num_ops}, mag={self.magnitude}")
        return loader


class WandBManager:
    """Helpers for Weights & Biases session setup."""

    @staticmethod
    def setup(cfg: Any, device: torch.device, block_name: str) -> None:
        """Start a wandb run for ``cfg`` and record which device is used.

        Args:
            cfg: Experiment config. Reads ``wandb_credentials``, ``run_name``,
                ``project`` and ``training.group_name``; the whole config is
                also stored in the run config.
            device: Device whose type (and, on CUDA, GPU name/memory) is logged.
            block_name: Currently unused (it was once appended to the run name).
        """
        creds = cfg.get("wandb_credentials", {})
        # NOTE(review): the entity is only exported when an API key is present
        # too — confirm this nesting is intentional.
        if "api_key" in creds:
            os.environ["WANDB_API_KEY"] = creds["api_key"]
            if "entity" in creds:
                os.environ["WANDB_ENTITY"] = creds["entity"]

        wandb.init(
            project=cfg.get("project", "forward_forward"),
            name=cfg.get("run_name", None),
            group=cfg.training.get("group_name", "experiment"),
            config=OmegaConf.to_container(cfg, resolve=True),
        )

        wandb.config.update({"device": device.type})
        if device.type != "cuda":
            wandb.alert(
                title="Training on CPU",
                text="CUDA was not available. Training is running on CPU.",
            )
            return

        gpu_name = torch.cuda.get_device_name(device)
        total_gib = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
        wandb.config.update({
            "gpu_name": gpu_name,
            "gpu_memory_gb": round(total_gib, 1),
        })


class OptimizerManager:
    """Builds optimizers and learning-rate schedulers for the training regimes.

    Forward-Forward training gets one optimizer per trainable block
    (``create_ff_optimizers``); backpropagation gets a single optimizer over
    the whole model (``create_bp_optimizer``).
    """

    def __init__(self, optimizer_class: Type[torch.optim.Optimizer] = torch.optim.AdamW):
        """
        Args:
            optimizer_class: Optimizer class to instantiate. It must accept a
                ``betas`` keyword, because every optimizer built here receives
                ``DEFAULT_TRAINING_CONFIG["betas"]``.
        """
        self.optimizer_class = optimizer_class
        # Not filled by the methods below (they return their results instead);
        # kept so existing attribute access keeps working.
        self.optimizers = {}
        self.schedulers = {}

    def create_ff_optimizers(
        self,
        model: nn.Module,
        hyperparams: Dict[str, Dict[str, Any]]
    ) -> Tuple[Dict[str, torch.optim.Optimizer], Dict[str, Any]]:
        """Create one optimizer (plus optional scheduler) per trainable FF block.

        Blocks whose parameters are all frozen are skipped. Each optimizer has
        three parameter groups:
        - the block's ``conv`` parameters, using ``conv_lr`` / ``conv_weight_decay``;
        - its ``label_encoder`` parameters, using ``scale_lr`` / ``scale_weight_decay``;
        - its ``feature_threshold`` parameter, using the optimizer class's
          default lr and weight decay.

        Args:
            model: Model exposing ``trainable_names`` and a ``layers`` mapping.
            hyperparams: Per-block settings keyed by block name, e.g. as
                returned by ``ConfigurationManager.extract_hyperparameters``.

        Returns:
            ``(optimizers, schedulers)`` dicts keyed by block name. Blocks
            without a usable scheduler config are absent from ``schedulers``.
        """
        optimizers = {}
        schedulers = {}

        for block_name in model.trainable_names:
            block = model.layers[block_name]

            if not any(p.requires_grad for p in block.parameters()):
                print(f"→ Skipping {block_name} (frozen)")
                continue

            manager = getattr(block.layer, "class_grouping_manager", None)
            if manager is not None:
                print(f"→ {block_name} using class grouping ({manager.mode.value}): "
                      f"{manager.get_num_classes()} effective classes, "
                      f"groups: {manager.get_group_info()}")

            block_hp = hyperparams[block_name]
            optimizer = self.optimizer_class(
                [
                    {
                        'params': block.layer.conv.parameters(),
                        'lr': block_hp["conv_lr"],
                        'weight_decay': block_hp["conv_weight_decay"],
                    },
                    {
                        'params': block.layer.label_encoder.parameters(),
                        'lr': block_hp["scale_lr"],
                        'weight_decay': block_hp["scale_weight_decay"],
                    },
                    # No explicit lr/weight_decay: optimizer-wide defaults apply.
                    {'params': [block.layer.feature_threshold]},
                ],
                betas=DEFAULT_TRAINING_CONFIG["betas"]
            )
            optimizers[block_name] = optimizer

            scheduler = self._create_scheduler(optimizer, block_hp)
            if scheduler is not None:
                schedulers[block_name] = scheduler

        return optimizers, schedulers

    def create_bp_optimizer(
        self,
        model: nn.Module,
        cfg: Any
    ) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        """Create a single optimizer (and optional scheduler) for backprop training.

        ``cfg.training.conv_lr`` / ``conv_weight_decay`` are used when present,
        otherwise ``DEFAULT_TRAINING_CONFIG``. A scheduler is built from
        ``cfg.training.scheduler`` when that entry is set.
        """
        train_config = cfg.training
        learning_rate = getattr(train_config, "conv_lr", DEFAULT_TRAINING_CONFIG["conv_lr"])
        weight_decay = getattr(train_config, "conv_weight_decay", DEFAULT_TRAINING_CONFIG["conv_weight_decay"])

        optimizer = self.optimizer_class(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=DEFAULT_TRAINING_CONFIG["betas"]
        )

        scheduler_cfg = getattr(train_config, "scheduler", None)
        scheduler = self._create_scheduler(optimizer, scheduler_cfg) if scheduler_cfg else None
        return optimizer, scheduler

    def _create_scheduler(self, optimizer: torch.optim.Optimizer, config: Dict[str, Any]) -> Optional[Any]:
        """Create a learning-rate scheduler from configuration.

        ``config`` may be either a per-block hyperparameter mapping with a
        nested ``"scheduler"`` entry, or a scheduler config itself. Any
        ``Mapping`` is accepted, including an OmegaConf ``DictConfig`` (which
        is not a ``dict`` subclass and was previously rejected).

        The ``"type"`` entry selects a class from ``SCHEDULER_REGISTRY``. All
        other non-``None`` entries are passed as constructor keyword arguments.
        ``None`` values are dropped so the constructor's own defaults apply.

        Returns:
            The scheduler, or ``None`` if the config is missing, not a mapping,
            or names an unregistered type.
        """
        if not isinstance(config, Mapping):
            return None

        sched_conf = config.get("scheduler", config)  # nested or direct config
        if not sched_conf or not isinstance(sched_conf, Mapping):
            return None

        scheduler_cls = SCHEDULER_REGISTRY.get(sched_conf.get("type", None))
        if scheduler_cls is None:
            return None

        scheduler_params = {
            key: value
            for key, value in sched_conf.items()
            if key != "type" and value is not None
        }
        return scheduler_cls(optimizer, **scheduler_params)


class ConfigurationManager:
    """Extracts per-block training hyperparameters from an experiment config."""

    # Optimizer hyperparameters a block may set itself; any it omits are read
    # from ``cfg.training`` instead.
    _HYPERPARAM_KEYS = ("conv_lr", "conv_weight_decay", "scale_lr", "scale_weight_decay")

    # Sentinel distinguishing "attribute absent" from an explicit ``None``.
    _MISSING = object()

    @staticmethod
    def extract_hyperparameters(cfg: Any) -> Dict[str, Dict[str, Any]]:
        """Collect learning rates, weight decays and scheduler settings per block.

        Args:
            cfg: Config exposing ``cfg.model.architecture`` (an iterable of
                block configs) and ``cfg.training`` (fallback hyperparameters).

        Returns:
            Mapping from block name to a dict holding ``conv_lr``,
            ``conv_weight_decay``, ``scale_lr`` and ``scale_weight_decay``.
            When the block configures a scheduler with a ``type``, the dict
            also holds a ``"scheduler"`` entry with all of that scheduler's
            settings (``None`` values dropped). Blocks without a ``name`` are
            skipped.

        Raises:
            AttributeError (or a subclass): if neither the block nor
                ``cfg.training`` defines one of the hyperparameters.
        """
        result: Dict[str, Dict[str, Any]] = {}
        training_defaults = cfg.training

        for block in cfg.model.architecture:
            name = getattr(block, "name", None)
            if name is None:
                continue

            block_info: Dict[str, Any] = {}
            for key in ConfigurationManager._HYPERPARAM_KEYS:
                value = getattr(block, key, ConfigurationManager._MISSING)
                if value is ConfigurationManager._MISSING:
                    # Resolved lazily so cfg.training only needs the keys that
                    # some block actually leaves unset.
                    value = getattr(training_defaults, key)
                block_info[key] = value

            scheduler = ConfigurationManager._extract_scheduler(getattr(block, "scheduler", None))
            if scheduler is not None:
                block_info["scheduler"] = scheduler

            result[name] = block_info

        return result

    @staticmethod
    def _extract_scheduler(scheduler: Any) -> Optional[Dict[str, Any]]:
        """Return a plain-dict copy of a block's scheduler config, or ``None``.

        All keys are kept, so every scheduler type in ``SCHEDULER_REGISTRY``
        receives its constructor arguments. Previously only PolynomialLR and
        cosine-warmup configs survived, and other types were silently
        dropped. Entries whose value is ``None`` are removed so that a missing
        setting falls back to the constructor default instead of passing
        ``None``.
        """
        if not scheduler:
            return None
        if isinstance(scheduler, dict):
            settings = dict(scheduler)
        else:
            # Presumably an OmegaConf DictConfig: convert so nested lists reach
            # the scheduler constructor as plain Python containers.
            settings = OmegaConf.to_container(scheduler, resolve=True)
        if settings.get("type") is None:
            return None
        return {key: value for key, value in settings.items() if value is not None}


class BaseTrainer:
    """Base trainer holding the model, config, output paths and shared helpers.

    Subclasses implement the actual training loop; this class provides
    checkpointing and a generic evaluation pass.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: Any,
        cfg: Any,
        input_shape: Tuple[int, ...],
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.AdamW,  # Optimizer class is defined here
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        wandb_enabled: bool = False,
        run_name: str = "forward_forward",
        dataset: str = "dataset",
        output_dir: str = "experiments",
        num_classes: int = 10,
        **kwargs  # For backward compatibility
    ):
        """
        Args:
            model: Model to train; moved to ``device``.
            train_dataset: Training data handed to the AugmentationManager.
            cfg: Experiment config. If a tuple is given, its first element is used.
            input_shape: Shape of a single input sample.
            optimizer_class: Optimizer class used by the OptimizerManager.
            device: Device string for the model and batches.
            wandb_enabled: Whether epoch metrics and checkpoints go to wandb.
            run_name: Prefix for checkpoint file names.
            dataset: Dataset name embedded in checkpoint file names.
            output_dir: Directory for checkpoints (created if missing).
            num_classes: Number of original (ungrouped) classes.
            **kwargs: Legacy ``block_loss_types`` / ``block_margins`` dicts.
        """
        self.model = model.to(device)
        self.train_dataset = train_dataset
        self.cfg = cfg if not isinstance(cfg, tuple) else cfg[0]
        self.device = torch.device(device)
        self.wandb_enabled = wandb_enabled
        self.run_name = run_name
        self.dataset = dataset
        self.num_classes = num_classes
        self.input_shape = input_shape

        # Setup output directory
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"✅ Trainer output directory: {self.output_dir}")

        # Initialize managers
        self.optimizer_manager = OptimizerManager(optimizer_class)
        self.augmentation_manager = AugmentationManager(self.cfg, train_dataset)

        # Create optimized inputs directory
        self.optimized_inputs_dir = os.path.join(output_dir, "optimized_inputs")
        os.makedirs(self.optimized_inputs_dir, exist_ok=True)

        # For backward compatibility - handle legacy parameters
        self.block_loss_types = kwargs.get('block_loss_types', {})
        self.block_margins = kwargs.get('block_margins', {})

    def _save_checkpoint(
        self,
        epoch: int,
        accuracy: float,
        block_name: str = "",
        additional_data: Optional[Dict[str, Any]] = None
    ) -> str:
        """Save model weights, epoch, accuracy and config to ``output_dir``.

        The file is named ``{run_name}[_{block_name}]_{dataset}_{accuracy:.4f}.pth``.
        Entries from ``additional_data`` are merged into the saved dict. The
        file is also uploaded to wandb when enabled.

        Returns:
            Path of the written checkpoint.
        """
        suffix = f"_{block_name}" if block_name else ""
        ckpt_path = os.path.join(
            self.output_dir,
            f"{self.run_name}{suffix}_{self.dataset}_{accuracy:.4f}.pth"
        )

        checkpoint_data = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'accuracy': accuracy,
            'config': self.cfg
        }

        if additional_data:
            checkpoint_data.update(additional_data)

        torch.save(checkpoint_data, ckpt_path)
        print(f"Model saved at {ckpt_path}")

        # Save in WandB if enabled
        if self.wandb_enabled:
            wandb.save(ckpt_path)

        return ckpt_path

    def _evaluate_epoch(self, dataloader: DataLoader, criterion: nn.Module) -> Tuple[float, float]:
        """Evaluate the model on ``dataloader`` without tracking gradients.

        Puts the model in eval mode (it is not switched back).

        Args:
            dataloader: Iterable of ``(inputs, targets)`` batches.
            criterion: Loss applied as ``criterion(model(inputs), targets)``.

        Returns:
            ``(avg_loss, accuracy)``:
            - ``avg_loss`` is the mean of the per-batch losses.
            - ``accuracy`` is the fraction of samples whose arg-max prediction
              equals the target.
            Both are ``0.0`` when the dataloader yields no samples, instead of
            raising ZeroDivisionError.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        # Counted while iterating so iterables without __len__ also work.
        num_batches = 0

        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)

                outputs = self.model(x)
                loss = criterion(outputs, y)

                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = correct / total if total else 0.0

        return avg_loss, accuracy


class Trainer(BaseTrainer):
    """Layer-wise greedy Forward-Forward trainer with class-grouping support.

    Each minibatch is pushed through ``model.layers`` once. Every block that
    has an optimizer is updated immediately from its own Forward-Forward loss.
    Its output is detached before being fed onward, so gradients never cross
    block boundaries.
    """

    def train(
        self,
        val_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None,
        total_epochs: int = 200
    ) -> None:
        """Train all trainable blocks consecutively on every minibatch.

        Args:
            val_dataloader: Evaluated at the end of every epoch when not ``None``.
            test_dataloader: Evaluated once after training when not ``None``;
                a checkpoint named after the best block is then saved.
            total_epochs: Number of passes over the training loader.
        """
        hyperparams = ConfigurationManager.extract_hyperparameters(self.cfg)

        # One optimizer (and optional scheduler) per trainable block.
        optimizers, schedulers = self.optimizer_manager.create_ff_optimizers(
            self.model, hyperparams
        )

        # Close any wandb run that is still active before (maybe) starting one.
        if wandb.run is not None:
            wandb.finish()
        if self.cfg.training.get("wandb", False):
            WandBManager.setup(self.cfg, self.device, "consecutive_training")
        if self.wandb_enabled:
            wandb.define_metric("epoch")

        print(f"\nStarting consecutive minibatch training for {total_epochs} epochs")
        print(f"Training blocks: {list(optimizers.keys())}")

        self._log_class_grouping_info(list(optimizers.keys()))

        train_loader = self.augmentation_manager.create_train_loader()

        for epoch in range(total_epochs):
            # Re-enter training mode every epoch: end-of-epoch evaluation runs
            # in between and may leave blocks in eval mode
            # (NOTE(review): confirm whether evaluate() calls model.eval()).
            for block_name in optimizers:
                self.model.layers[block_name].train()

            epoch_losses = {block_name: 0.0 for block_name in optimizers}
            total_correct_predictions = {block_name: 0.0 for block_name in optimizers}
            num_batches = 0

            for x, y in train_loader:
                x, y = x.to(self.device), y.to(self.device)
                num_batches += 1

                # Forward pass and training for each block consecutively
                epoch_losses, total_correct_predictions = self._train_batch_consecutive(
                    x, y, optimizers, epoch_losses, total_correct_predictions
                )

            # To monitor the test split every epoch (e.g. when training on
            # train + val), pass test_dataloader here instead of val_dataloader.
            self._process_end_of_epoch(
                epoch, total_epochs, epoch_losses, num_batches,
                optimizers, schedulers, total_correct_predictions, val_dataloader
            )

        self._final_evaluation(test_dataloader, optimizers, total_epochs)

    @staticmethod
    def _get_grouping_manager(block: nn.Module) -> Any:
        """Return ``block.layer.class_grouping_manager``, or ``None`` if absent/unset."""
        return getattr(block.layer, "class_grouping_manager", None)

    def _log_class_grouping_info(self, trainable_block_names: List[str]) -> None:
        """Print the class-grouping configuration of each trainable block."""
        print("\n📊 Class Grouping Configuration:")
        for block_name in trainable_block_names:
            manager = self._get_grouping_manager(self.model.layers[block_name])
            if manager is None:
                print(f"  {block_name}: No class grouping")
                continue

            print(f"  {block_name}:")
            print(f"    Mode: {manager.mode.value}")
            print(f"    Effective classes: {manager.get_num_classes()}")
            print(f"    Groups: {manager.get_group_info()}")

            if manager.is_group_aware_negative_mode():
                # Illustrate the negative-sampling constraints using class 0.
                example_class = 0
                valid_negatives = manager.get_valid_negative_classes(example_class)
                same_group = manager.get_same_group_classes(example_class)
                print(f"    Example - Class {example_class} same group: {same_group}")
                print(f"    Example - Class {example_class} valid negatives: {valid_negatives}")
        print()

    @staticmethod
    def _forward_layer(
        layer: nn.Module,
        curr_x: torch.Tensor,
        outputs: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Apply ``layer``; skip connections also get the output of ``layer.skip_from``."""
        if isinstance(layer, SkipConnection):
            return layer(curr_x, outputs[layer.skip_from])
        return layer(curr_x)

    def _train_batch_consecutive(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        optimizers: Dict[str, torch.optim.Optimizer],
        epoch_losses: Dict[str, float],
        total_correct_predictions: Dict[str, float],
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        """Run one minibatch through every layer, updating trainable blocks in turn.

        Layers with an entry in ``optimizers`` take one optimizer step on their
        BCE Forward-Forward loss. All other layers run forward-only under
        ``torch.no_grad()``.

        Returns:
            The updated ``(epoch_losses, total_correct_predictions)`` dicts
            (they are also mutated in place).
        """
        outputs = {"input": x}
        # Work on a copy; presumably guards outputs["input"] (read by skip
        # connections) against in-place ops inside layers.
        curr_x = x.clone()
        for layer_name in self.model.layers:
            layer = self.model.layers[layer_name]

            if layer_name not in optimizers:
                # Non-trainable layer - just forward pass
                with torch.no_grad():
                    curr_x = self._forward_layer(layer, curr_x, outputs)
                outputs[layer_name] = curr_x
                continue

            feats = self._forward_layer(layer, curr_x, outputs)

            # The loss type is currently fixed to BCE for every block.
            loss, correct_predictions = self._compute_ff_loss(layer, feats, y, "bce")
            total_correct_predictions[layer_name] += correct_predictions

            optimizer = optimizers[layer_name]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_losses[layer_name] += loss.item()

            # Detach so the next block treats this output as a constant input.
            curr_x = feats.detach()
            outputs[layer_name] = curr_x

        return epoch_losses, total_correct_predictions

    def custom_margin_ranking_loss(self, x1, x2, y, margin):
        """Margin ranking loss with a per-sample margin.

        Computes ``mean(max(0, margin - y * (x1 - x2)))``, where ``y`` holds
        +1/-1 targets and ``margin`` may be a tensor broadcastable to ``x1``.
        """
        diff = x1 - x2
        losses = torch.clamp(margin - y * diff, min=0)
        return losses.mean()

    def ova_hinge_loss_with_margins(self, class_scores, y, margins):
        """One-vs-all hinge loss with a separate margin per class.

        Args:
            class_scores: 2-D score tensor ``(B, C)``.
            y: Integer class labels ``(B,)``.
            margins: Per-class margins ``(C,)``, broadcast over the batch.

        Returns:
            Mean over batch and classes of ``max(0, margin_c - t_bc * s_bc)``,
            where ``t`` is +1 for the true class and -1 otherwise.
        """
        batch_size, _ = class_scores.shape
        rows = torch.arange(batch_size, device=class_scores.device)

        targets = -torch.ones_like(class_scores)
        targets[rows, y] = 1

        losses = torch.clamp(margins.unsqueeze(0) - targets * class_scores, min=0)
        return losses.mean()

    def pairwise_margin_loss(self, class_scores, y, margin_matrix):
        """Hinge loss between the true-class score and every other class score.

        Args:
            class_scores: 2-D score tensor ``(B, C)``.
            y: Integer class labels ``(B,)``.
            margin_matrix: ``(C, C)`` margins; row ``y_i`` gives the margin the
                true class must hold over each other class.

        Returns:
            Sum of per-pair hinge losses divided by the batch size; the
            true-class entry of each row is excluded.
        """
        batch_size, num_classes = class_scores.shape
        rows = torch.arange(batch_size, device=class_scores.device)

        pos_logits = class_scores[rows, y].unsqueeze(1).expand(-1, num_classes)
        margins = margin_matrix[y]  # (B, C): margin row of each sample's class

        losses = torch.clamp(margins - (pos_logits - class_scores), min=0)
        # Ignore the positive class itself (j == y_i).
        losses[rows, y] = 0

        return losses.sum() / batch_size

    def update_margin_matrix_from_confusion(self, confusion, alpha=1.0, beta=0.1, ema_margin=None, gamma=0.1):
        """Derive a class-pair margin matrix from a ``(C, C)`` confusion matrix.

        Rows are normalised to sum to one, scaled by ``alpha`` and shifted by
        ``beta``. The diagonal is zeroed. When ``ema_margin`` is given, the
        result is blended as ``(1 - gamma) * ema_margin + gamma * new``.
        """
        if not confusion.is_floating_point():
            # Integer counts: promote so normalisation and clamping are in floats.
            confusion = confusion.float()
        row_sums = confusion.sum(dim=1, keepdim=True).clamp(min=1e-6)
        conf_norm = confusion / row_sums

        margin_matrix = alpha * conf_norm + beta
        margin_matrix.fill_diagonal_(0)

        if ema_margin is not None:
            margin_matrix = (1 - gamma) * ema_margin + gamma * margin_matrix

        return margin_matrix

    def _compute_ff_loss(
        self,
        layer: nn.Module,
        feats: torch.Tensor,
        y: torch.Tensor,
        loss_type: str,
    ) -> Tuple[torch.Tensor, Any]:
        """Compute the Forward-Forward loss for one block.

        Args:
            layer: Block whose ``layer`` attribute provides ``image_encoder`` and
                the alignment/loss helpers used below.
            feats: Block output for the current batch.
            y: Ground-truth labels ``(batch_size,)``.
            loss_type: ``'bce'``, ``'mrl'``, ``'scaled_mrl'`` or ``'infoNCE'``
                (``'margin'`` is reserved but not implemented).

        Returns:
            ``(loss, correct_predictions)``; the second value is whatever the
            layer's negative-alignment call reports for training accuracy.

        Raises:
            NotImplementedError: for ``loss_type='margin'``.
            ValueError: for any other unknown ``loss_type``.
        """
        goodness = layer.layer.image_encoder(feats)

        if loss_type == 'margin':
            # The previous branch returned an undefined ``loss``; it needs a
            # per-layer margin matrix that this trainer does not maintain.
            raise NotImplementedError(
                "loss_type='margin' is not implemented: it requires a per-layer "
                "margin matrix that this trainer does not maintain"
            )

        if loss_type == 'bce':
            # Positive pairs labelled 1, sampled negatives labelled 0.
            pos_logits = layer.layer.embedding_alignment(goodness, y, "positive")
            neg_logits, correct_predictions = layer.layer.embedding_alignment(goodness, y, "negative")
            total_loss = F.binary_cross_entropy_with_logits(
                torch.cat([pos_logits, neg_logits]),
                torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)]),
            )
            return total_loss, correct_predictions

        if loss_type == 'mrl':
            # Margin ranking: positive logits should exceed negatives by 0.5.
            pos_logits = layer.layer.embedding_alignment(goodness, y, "positive")
            neg_logits, correct_predictions = layer.layer.embedding_alignment(goodness, y, "negative")
            total_loss = F.margin_ranking_loss(
                pos_logits, neg_logits, target=torch.ones_like(pos_logits), margin=0.5
            )
            return total_loss, correct_predictions

        if loss_type == 'scaled_mrl':
            pos_logits = layer.layer.embedding_alignment2(goodness, y, "positive")
            neg_logits, correct_predictions, class_scores, hard_label = layer.layer.embedding_alignment2(
                goodness, y, "negative"
            )
            rows = torch.arange(class_scores.shape[0], device=class_scores.device)

            # Score of the sampled hard negative class vs. the true class.
            wrong_class = torch.argmax(hard_label, dim=1)
            neg_score = class_scores[rows, wrong_class]
            true_score = class_scores[rows, y]
            # NOTE(review): divides by the raw true-class score; a zero score can
            # give inf/NaN and a negative one flips the sign of ``difficulty``.
            difficulty = (neg_score - true_score) / true_score

            # Adaptive margin in (base_margin, base_margin + scale).
            base_margin = 0.5
            scale = 1.0
            adaptive_margin = base_margin + scale * difficulty.clamp(min=-2, max=2).sigmoid()

            loss = torch.clamp(adaptive_margin - (pos_logits - neg_logits), min=0).mean()
            return loss, correct_predictions

        if loss_type == 'infoNCE':
            # Negative alignment is only evaluated for the accuracy count.
            _, correct_predictions = layer.layer.embedding_alignment(goodness, y, "negative")
            return layer.layer.info_nce_loss(goodness, y), correct_predictions

        raise ValueError(f"Unsupported loss_type: {loss_type}")

    def _process_end_of_epoch(
        self,
        epoch: int,
        total_epochs: int,
        epoch_losses: Dict[str, float],
        num_batches: int,
        optimizers: Dict[str, torch.optim.Optimizer],
        schedulers: Dict[str, Any],
        total_correct_predictions: Dict[str, float],
        val_dataloader: Optional[DataLoader]
    ) -> None:
        """Print/log epoch metrics, step schedulers and run validation.

        Validation is skipped when ``val_dataloader`` is ``None``. Per-batch
        sums are averaged over ``num_batches`` (treated as 1 when the loader
        was empty, avoiding ZeroDivisionError).
        """
        denom = max(num_batches, 1)
        avg_losses = {name: loss / denom for name, loss in epoch_losses.items()}

        print(f"Epoch {epoch + 1}/{total_epochs}")
        for block_name, avg_loss in avg_losses.items():
            manager = self._get_grouping_manager(self.model.layers[block_name])
            grouping_info = f" [{manager.mode.value}]" if manager is not None else ""
            print(f"  {block_name}{grouping_info}: loss={avg_loss:.6f}")

        for scheduler in schedulers.values():
            scheduler.step()

        if val_dataloader is not None:
            evaluate(
                model=self.model,
                dataloader=val_dataloader,
                device=self.device,
                split="val",
                step=epoch,
                num_epochs=total_epochs,
                wandb_enabled=self.wandb_enabled,
                original_num_classes=self.num_classes,
            )

        # Mean of the per-batch accuracy values accumulated during the epoch.
        train_accuracy = {name: acc / denom for name, acc in total_correct_predictions.items()}

        if self.wandb_enabled:
            log_dict = {"epoch": epoch}
            for block_name, avg_loss in avg_losses.items():
                log_dict[f"{block_name}/train/epoch_loss"] = avg_loss
                if block_name in optimizers:
                    log_dict[f"{block_name}/train/learning_rate"] = optimizers[block_name].param_groups[0]["lr"]
                if block_name in train_accuracy:
                    log_dict[f"{block_name}/train/accuracy"] = train_accuracy[block_name]

                manager = self._get_grouping_manager(self.model.layers[block_name])
                if manager is not None:
                    log_dict[f"{block_name}/train/grouping_mode"] = manager.mode.value

            wandb.log(log_dict, step=epoch)

    def _final_evaluation(
        self,
        test_dataloader: Optional[DataLoader],
        optimizers: Dict[str, torch.optim.Optimizer],
        total_epochs: int
    ) -> None:
        """Evaluate on the test set and checkpoint under the best block's name.

        Does nothing when ``test_dataloader`` is ``None``. ``optimizers`` is
        unused and kept for signature compatibility.
        """
        if test_dataloader is None:
            return

        result = evaluate(
            model=self.model,
            dataloader=test_dataloader,
            device=self.device,
            split="test",
            step=total_epochs - 1,
            num_epochs=total_epochs,
            # Was hard-coded to True; only request wandb logging when this
            # trainer has it enabled or a run is actually active.
            wandb_enabled=self.wandb_enabled or wandb.run is not None,
            original_num_classes=self.num_classes,
        )

        if not result:
            return

        best_block_name, best_block_metrics = max(
            result.items(), key=lambda item: item[1]["accuracy"]
        )
        best_accuracy = best_block_metrics["accuracy"]

        grouping_mode = best_block_metrics.get("grouping_mode", "none")
        print(f"[test] {best_block_name} ({grouping_mode}) | Test accuracy: {best_accuracy:.4f}")

        self._save_checkpoint(
            epoch=total_epochs - 1,
            accuracy=best_accuracy,
            block_name=f"consecutive_{best_block_name}",
            additional_data={"grouping_modes": {
                name: metrics.get("grouping_mode", "none")
                for name, metrics in result.items()
            }}
        )

