import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from typing import List, Tuple, Dict, Any
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import time
from copy import deepcopy

from .base_classifier import BaseMILClassifier, PredictionResult
from .classifier_factory import register_classifier


@register_classifier('homil')
class HOMILClassifier(BaseMILClassifier):
    
    def __init__(self, config: Dict[str, Any]):
        """Read HOMIL hyper-parameters from ``config`` and prepare optional logging.

        Every option is looked up with ``config.get`` and falls back to the
        default shown below when the key is absent. When
        ``enable_visualization`` is truthy, an empty per-epoch history dict is
        created and the output directory is made on disk.

        Args:
            config: Mapping of option names to values; also forwarded to the
                base-class constructor.
        """
        super().__init__(config)
        option = config.get

        # How the covariance matrix becomes a vector and how it is merged
        # with the first-order (mean) representation.
        self.cov_extraction_method = option('cov_extraction_method', 'weighted_rows')
        self.fusion_method = option('fusion_method', 'attention')

        # Newton-Schulz matrix square root settings.
        self.newton_schulz_iters = option('newton_schulz_iters', 5)
        self.newton_schulz_eps = option('newton_schulz_eps', 1e-8)

        # 1-D convolution over covariance rows ('conv_pooling' extraction).
        self.conv_kernel_size = option('conv_kernel_size', 3)
        self.conv_num_kernels = option('conv_num_kernels', 4)

        self.enable_gradient_checkpointing = option('enable_gradient_checkpointing', False)

        # Network size and regularisation.
        self.hidden_dim = option('hidden_dim', 256)
        self.dropout_rate = option('dropout_rate', 0.3)

        # Optimisation settings; the two rates are coerced with float() so
        # string values such as '1e-4' are accepted as well.
        self.learning_rate = float(option('learning_rate', 0.001))
        self.weight_decay = float(option('weight_decay', 1e-4))
        self.batch_size = option('batch_size', 8)
        self.epochs = option('epochs', 100)

        # Training-monitor settings.
        self.enable_visualization = option('enable_visualization', False)
        self.viz_save_dir = option('viz_save_dir', 'homil_visualizations')
        self.viz_update_freq = option('viz_update_freq', 10)  # plot every N epochs

        if not self.enable_visualization:
            return

        # One list per tracked quantity, appended to once per epoch.
        history_keys = (
            'train_loss', 'train_acc', 'train_auc', 'train_f1',
            'val_loss', 'val_acc', 'val_auc', 'val_f1',
            'fusion_weights_mean', 'fusion_weights_std',
            'first_order_weight', 'second_order_weight',
            'first_order_norm', 'second_order_norm',
            'first_order_mean', 'second_order_mean',
            'first_order_std', 'second_order_std',
            'norm_ratio',
        )
        self.training_history = {key: [] for key in history_keys}
        Path(self.viz_save_dir).mkdir(parents=True, exist_ok=True)


        
    def build_model(self, feature_dim: int, n_classes: int) -> nn.Module:
        
        class SimpleCovarianceNet(nn.Module):
            def __init__(self, feature_dim, hidden_dim, n_classes, dropout_rate,
                        cov_extraction_method, fusion_method, newton_schulz_iters, newton_schulz_eps,
                        conv_kernel_size, conv_num_kernels):
                super().__init__()

                self.feature_dim = feature_dim
                self.hidden_dim = hidden_dim
                self.cov_extraction_method = cov_extraction_method
                self.fusion_method = fusion_method
                self.newton_schulz_iters = newton_schulz_iters
                self.newton_schulz_eps = newton_schulz_eps
                self.conv_kernel_size = conv_kernel_size
                self.conv_num_kernels = conv_num_kernels
                
                self.feature_processor = nn.Sequential(
                    nn.Linear(feature_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate)
                )
                
                attention_dim = hidden_dim // 2
                self.patch_attention = nn.Sequential(
                    nn.Linear(hidden_dim, attention_dim),
                    nn.Tanh(),
                    nn.Linear(attention_dim, 1)
                )

                if cov_extraction_method == 'conv_pooling':
                    self.cov_conv1d = nn.Conv1d(
                        in_channels=1,
                        out_channels=conv_num_kernels,
                        kernel_size=conv_kernel_size,
                        padding=0,
                        bias=False
                    )
                    if cov_extraction_method == 'weighted_rows':
                        cov_feature_dim = hidden_dim
                    elif cov_extraction_method == 'upper_triangle':
                        cov_feature_dim = hidden_dim * (hidden_dim + 1) // 2
                    elif cov_extraction_method == 'sqrt_upper_triangle':
                        cov_feature_dim = hidden_dim * (hidden_dim + 1) // 2
                    elif cov_extraction_method == 'conv_pooling':
                        cov_feature_dim = hidden_dim 
                    else:
                        raise ValueError(f"Unknown cov_extraction_method: {cov_extraction_method}")

                    self.cov_processor = nn.Sequential(
                        nn.Linear(cov_feature_dim, hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(dropout_rate),
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(dropout_rate)
                    )

                    if fusion_method == 'attention':
                        self.fusion_attention = nn.Sequential(
                            nn.Linear(hidden_dim, attention_dim),
                            nn.Tanh(),
                            nn.Linear(attention_dim, 1)
                        )
                        classifier_input_dim = hidden_dim
                    elif fusion_method == 'concatenation':
                        classifier_input_dim = hidden_dim * 2
                    else:
                        raise ValueError(f"Unknown fusion_method: {fusion_method}")

                else:
                    classifier_input_dim = hidden_dim

                self.classifier = nn.Sequential(
                    nn.Linear(classifier_input_dim, hidden_dim // 2),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(hidden_dim // 2, n_classes)
                )
                
            def forward(self, patches):
                # patches: [batch_size, n_patches, feature_dim]
                processed_patches = self.feature_processor(patches)  # [batch_size, n_patches, hidden_dim]
                patch_attn_weights = self.patch_attention(processed_patches)  # [batch_size, n_patches, 1]
                patch_attn_weights = torch.softmax(patch_attn_weights, dim=1) 

                wsi_mean_repr = torch.sum(patch_attn_weights * processed_patches, dim=1)  # [batch_size, hidden_dim]

                cov_matrix = self._compute_covariance_matrix(
                    processed_patches, patch_attn_weights, wsi_mean_repr
                )
                cov_features = self._extract_covariance_features(cov_matrix)
                wsi_representation = self._fuse_representations(wsi_mean_repr, cov_features)

                output = self.classifier(wsi_representation)

                return output

            def _compute_covariance_matrix(self, processed_patches, patch_attn_weights, wsi_mean_repr):
                # μ = (Σ wᵢ * xᵢ) / (Σ wᵢ)
                # Cov = (Σ wᵢ * (xᵢ - μ) * (xᵢ - μ)ᵀ) / (Σ wᵢ)
                weighted_mean = wsi_mean_repr.unsqueeze(1)  # [batch_size, 1, hidden_dim]

                centered_features = processed_patches - weighted_mean  # [batch_size, n_patches, hidden_dim]

                # Cov = (X-μ)ᵀ * diag(w) * (X-μ) = Σ wᵢ * (xᵢ-μ) * (xᵢ-μ)ᵀ

                sqrt_weights = torch.sqrt(patch_attn_weights)  # [batch_size, n_patches, 1]
                weighted_centered = sqrt_weights * centered_features  # [batch_size, n_patches, hidden_dim]

                # (√wᵢ*(xᵢ-μ))ᵀ * (√wᵢ*(xᵢ-μ)) = Σ wᵢ*(xᵢ-μ)*(xᵢ-μ)ᵀ
                cov_matrix = torch.bmm(
                    weighted_centered.transpose(1, 2),  # [batch_size, hidden_dim, n_patches]
                    weighted_centered                   # [batch_size, n_patches, hidden_dim]
                )  # [batch_size, hidden_dim, hidden_dim]
                batch_size, hidden_dim, _ = cov_matrix.shape
                eps = 1e-6
                eye = torch.eye(hidden_dim, device=cov_matrix.device).unsqueeze(0).expand(batch_size, -1, -1)
                cov_matrix = cov_matrix + eps * eye

                return cov_matrix

            def _compute_covariance_matrix_fast(self, processed_patches, patch_attn_weights, wsi_mean_repr=None):
                if wsi_mean_repr is None:
                    wsi_mean_repr = torch.sum(patch_attn_weights * processed_patches, dim=1)

                weighted_mean = wsi_mean_repr.unsqueeze(1)
                centered_features = processed_patches - weighted_mean

                sqrt_weights = torch.sqrt(patch_attn_weights)
                weighted_centered = sqrt_weights * centered_features

                cov_matrix = torch.bmm(
                    weighted_centered.transpose(1, 2),
                    weighted_centered
                )

                batch_size, hidden_dim, _ = cov_matrix.shape
                eps = 1e-6
                eye = torch.eye(hidden_dim, device=cov_matrix.device).unsqueeze(0).expand(batch_size, -1, -1)
                cov_matrix = cov_matrix + eps * eye

                return cov_matrix



            def _extract_covariance_features(self, cov_matrix):


                if self.cov_extraction_method == 'weighted_rows':
                    variances = torch.diagonal(cov_matrix, dim1=-2, dim2=-1)  # [batch_size, hidden_dim]
                    weights = torch.softmax(variances, dim=-1)  # [batch_size, hidden_dim]
                    weighted_row_means = torch.sum(
                        cov_matrix * weights.unsqueeze(-1), dim=-1
                    )  # [batch_size, hidden_dim]
                    return weighted_row_means

                elif self.cov_extraction_method == 'upper_triangle':
                    batch_size, dim, _ = cov_matrix.shape
                    triu_indices = torch.triu_indices(dim, dim, device=cov_matrix.device)
                    upper_triangle = cov_matrix[:, triu_indices[0], triu_indices[1]]  # [batch_size, dim*(dim+1)/2]

                    return upper_triangle

                elif self.cov_extraction_method == 'sqrt_upper_triangle':
                    sqrt_matrix = self._newton_schulz_sqrt(
                        cov_matrix,
                        num_iters=self.newton_schulz_iters,
                        eps=self.newton_schulz_eps
                    )

                    batch_size, dim, _ = sqrt_matrix.shape
                    triu_indices = torch.triu_indices(dim, dim, device=sqrt_matrix.device)
                    upper_triangle = sqrt_matrix[:, triu_indices[0], triu_indices[1]]  # [batch_size, dim*(dim+1)/2]

                    return upper_triangle

                elif self.cov_extraction_method == 'conv_pooling':
                    batch_size, dim, _ = cov_matrix.shape
                    rows = cov_matrix.view(batch_size * dim, 1, dim)  # [batch_size * dim, 1, dim]
                    conv_output = self.cov_conv1d(rows)  # [batch_size * dim, num_kernels, conv_length]
                    # conv_length = dim - kernel_size + 1
                    conv_length = dim - self.conv_kernel_size + 1

                    conv_output = conv_output.view(batch_size, dim, self.conv_num_kernels, conv_length)

                    spatial_pooled = torch.max(conv_output, dim=3)[0]  # [batch_size, dim, num_kernels]
                    row_features = torch.max(spatial_pooled, dim=2)[0]  # [batch_size, dim]

                    return row_features

                else:
                    raise ValueError(f"Unknown cov_extraction_method: {self.cov_extraction_method}")

            def _fuse_representations(self, mean_repr, cov_repr):
                if self.fusion_method == 'attention':
                    if mean_repr.shape[1] != cov_repr.shape[1]:
                        cov_repr = self.cov_processor(cov_repr)

                    dual_repr = torch.stack([mean_repr, cov_repr], dim=1)  # [batch_size, 2, hidden_dim]
                    fusion_weights = self.fusion_attention(dual_repr)  # [batch_size, 2, 1]
                    fusion_weights = torch.softmax(fusion_weights, dim=1)  # [batch_size, 2, 1]

                    if self.training:
                        first_norm = torch.norm(mean_repr, dim=1).mean().item()
                        second_norm = torch.norm(cov_repr, dim=1).mean().item()
                        first_mean = mean_repr.mean().item()
                        second_mean = cov_repr.mean().item()
                        first_std = mean_repr.std().item()
                        second_std = cov_repr.std().item()
                        norm_ratio = second_norm / first_norm if first_norm > 0 else 0

                        self.current_vector_stats = {
                            'first_order_norm': first_norm,
                            'second_order_norm': second_norm,
                            'first_order_mean': first_mean,
                            'second_order_mean': second_mean,
                            'first_order_std': first_std,
                            'second_order_std': second_std,
                            'norm_ratio': norm_ratio
                        }



                    fused_repr = torch.sum(fusion_weights * dual_repr, dim=1)  # [batch_size, hidden_dim]
                    return fused_repr

                elif self.fusion_method == 'concatenation':
                    fused_repr = torch.cat([mean_repr, cov_repr], dim=1)
                    return fused_repr

                else:
                    raise ValueError(f"Unknown fusion_method: {self.fusion_method}")

            def _newton_schulz_sqrt(self, A, num_iters=5, eps=1e-8):

                batch_size, dim, _ = A.shape
                device = A.device

                if isinstance(eps, torch.Tensor):
                    eps = eps.item()

                I = torch.eye(dim, device=device).unsqueeze(0).expand(batch_size, -1, -1)
                A_reg = A + float(eps) * I

                Y = A_reg.clone()
                Z = I.clone()

                for i in range(num_iters):
                    # Y_{k+1} = 0.5 * Y_k * (3*I - Z_k * Y_k)
                    # Z_{k+1} = 0.5 * (3*I - Z_k * Y_k) * Z_k

                    ZY = torch.bmm(Z, Y)
                    Y_new = 0.5 * torch.bmm(Y, 3 * I - ZY)
                    Z_new = 0.5 * torch.bmm(3 * I - ZY, Z)

                    Y = Y_new
                    Z = Z_new

                return Y

        return SimpleCovarianceNet(
            feature_dim=feature_dim,
            hidden_dim=self.hidden_dim,
            n_classes=n_classes,
            dropout_rate=self.dropout_rate,
            cov_extraction_method=self.cov_extraction_method,
            fusion_method=self.fusion_method,
            newton_schulz_iters=self.newton_schulz_iters,
            newton_schulz_eps=self.newton_schulz_eps,
            conv_kernel_size=self.conv_kernel_size,
            conv_num_kernels=self.conv_num_kernels
        )
    
    def prepare_data(self, bags: List[Tuple[np.ndarray, Any]],
                    labels: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Stack variable-length bags into one zero-padded float32 tensor.

        Every bag is padded at the end with all-zero rows up to the length of
        the longest bag. No padding mask is returned, so downstream code cannot
        tell padded rows from real patches.

        Args:
            bags: Sequence of ``(features, extra)`` pairs. ``features`` is an
                array whose first axis indexes patches; ``extra`` is ignored.
            labels: One integer label per bag.

        Returns:
            ``(data, labels)`` where ``data`` is a float32 tensor of shape
            ``[n_bags, max_patches, ...]`` and ``labels`` is an int64 tensor.

        Raises:
            ValueError: If ``bags`` is empty.
        """
        if len(bags) == 0:
            raise ValueError("prepare_data() requires at least one bag")

        max_patches = max(len(bag_features) for bag_features, _ in bags)

        # Fill a preallocated float32 buffer instead of padding each bag with a
        # float64 array and re-stacking: same values, far fewer temporary copies.
        trailing_shape = tuple(np.shape(bags[0][0])[1:])
        batch = np.zeros((len(bags), max_patches) + trailing_shape, dtype=np.float32)
        for row, (bag_features, _) in enumerate(bags):
            batch[row, :len(bag_features)] = bag_features

        data_tensor = torch.FloatTensor(batch)
        labels_tensor = torch.LongTensor(labels)

        return data_tensor, labels_tensor

    def _extract_fusion_weights(self, data_tensor: torch.Tensor) -> np.ndarray:
        """Extract fusion weights for visualization"""

        self.model.eval()
        device = next(self.model.parameters()).device
        data_tensor = data_tensor.to(device)

        all_weights = []

        with torch.no_grad():
            for i in range(data_tensor.shape[0]):
                single_bag = data_tensor[i:i+1]

                processed_patches = self.model.feature_processor(single_bag)
                patch_attn_weights = self.model.patch_attention(processed_patches)
                patch_attn_weights = torch.softmax(patch_attn_weights, dim=1)

                wsi_mean_repr = torch.sum(patch_attn_weights * processed_patches, dim=1)

                cov_matrix = self.model._compute_covariance_matrix_fast(
                    processed_patches, patch_attn_weights, wsi_mean_repr
                )

                variances = torch.diagonal(cov_matrix, dim1=-2, dim2=-1)
                weights = torch.softmax(variances, dim=-1)
                weighted_row_means = torch.sum(cov_matrix * weights.unsqueeze(-1), dim=-1)
                wsi_cov_repr = self.model.cov_processor(weighted_row_means)

                dual_repr = torch.stack([wsi_mean_repr, wsi_cov_repr], dim=1)
                fusion_weights = self.model.fusion_attention(dual_repr)
                fusion_weights = torch.softmax(fusion_weights, dim=1)

                weights_np = fusion_weights.squeeze().cpu().numpy()
                all_weights.append(weights_np)

        return np.array(all_weights)

    def _calculate_auc_f1(self, labels: List[int], probs: np.ndarray) -> Tuple[float, float]:
        """Calculate AUC and F1 scores.

        NOTE: a second method with the same name is defined further down in
        this class; being later in the class body, that one is the definition
        bound at runtime.

        AUC and F1 are computed independently so a failure in one metric does
        not zero out the other.

        Args:
            labels: True class index per sample.
            probs: Class probabilities of shape [n_samples, n_classes].

        Returns:
            ``(auc, f1)``. AUC is 0.0 when only one class is present in
            ``labels``; either metric is 0.0 when its computation fails.
        """
        from sklearn.metrics import roc_auc_score, f1_score

        auc = 0.0
        try:
            # AUC is undefined when only one class is present.
            if len(set(labels)) > 1:
                if probs.shape[1] == 2:
                    # Binary case: score with the positive-class probability.
                    auc = roc_auc_score(labels, probs[:, 1])
                else:
                    auc = roc_auc_score(labels, probs, multi_class='ovr')
        except Exception as e:  # deliberate best effort: metrics must not abort training
            print(f"⚠️  Metric calculation failed: {e}")
            auc = 0.0

        try:
            predictions = np.argmax(probs, axis=1)
            f1 = f1_score(labels, predictions, average='weighted')
        except Exception as e:  # deliberate best effort, see above
            print(f"⚠️  Metric calculation failed: {e}")
            f1 = 0.0

        return auc, f1





    def _save_final_analysis(self):
        """Save final weight analysis"""
        if not self.enable_visualization:
            return

        try:
            # Set font to avoid Chinese character display issues
            plt.rcParams['font.family'] = ['DejaVu Sans', 'Liberation Sans']
            plt.rcParams['axes.unicode_minus'] = False

            fig, axes = plt.subplots(2, 2, figsize=(15, 12))
            fig.suptitle('Simple Covariance Final Weight Analysis', fontsize=16, fontweight='bold')

            epochs_range = range(1, len(self.training_history['first_order_weight']) + 1)

            # Weight evolution trend
            axes[0, 0].plot(epochs_range, self.training_history['first_order_weight'], '#2F5597',
                           label='1st Order Weight', linewidth=2, marker='o', markersize=3)
            axes[0, 0].plot(epochs_range, self.training_history['second_order_weight'], '#C55A11',
                           label='2nd Order Weight', linewidth=2, marker='s', markersize=3)
            axes[0, 0].set_xlabel('Epoch', fontsize=12)
            axes[0, 0].set_ylabel('Weight Value', fontsize=12)
            axes[0, 0].set_title('Weight Evolution During Training', fontsize=14, fontweight='bold')
            axes[0, 0].legend(fontsize=10)
            axes[0, 0].tick_params(axis='both', which='major', labelsize=10)
            axes[0, 0].grid(True, alpha=0.3)
            axes[0, 0].set_ylim(0, 1)

            # Final weight distribution
            final_weights = self.training_history['fusion_weights_mean'][-1]
            axes[0, 1].pie(final_weights, labels=['1st Order (Mean)', '2nd Order (Covariance)'],
                          autopct='%1.1f%%', startangle=90, colors=['#2F5597', '#C55A11'], textprops={'fontsize': 10})
            axes[0, 1].set_title(f'Final Weight Distribution\n1st Order: {final_weights[0]:.3f}, 2nd Order: {final_weights[1]:.3f}',
                               fontsize=14, fontweight='bold')

            # Weight stability analysis
            first_order_weights = np.array(self.training_history['first_order_weight'])
            second_order_weights = np.array(self.training_history['second_order_weight'])

            axes[1, 0].hist(first_order_weights, bins=20, alpha=0.7, label='1st Order Weight', color='#2F5597', edgecolor='black')
            axes[1, 0].hist(second_order_weights, bins=20, alpha=0.7, label='2nd Order Weight', color='#C55A11', edgecolor='black')
            axes[1, 0].set_xlabel('Weight Value', fontsize=12)
            axes[1, 0].set_ylabel('Frequency', fontsize=12)
            axes[1, 0].set_title('Weight Distribution Histogram', fontsize=14, fontweight='bold')
            axes[1, 0].legend(fontsize=10)
            axes[1, 0].tick_params(axis='both', which='major', labelsize=10)
            axes[1, 0].grid(True, alpha=0.3)

            # Weight statistics
            axes[1, 1].axis('off')
            stats_text = f"""
Weight Statistics Analysis:

1st Order Weight:
  Final Value: {final_weights[0]:.4f}
  Mean: {np.mean(first_order_weights):.4f}
  Std Dev: {np.std(first_order_weights):.4f}
  Min: {np.min(first_order_weights):.4f}
  Max: {np.max(first_order_weights):.4f}

2nd Order Weight:
  Final Value: {final_weights[1]:.4f}
  Mean: {np.mean(second_order_weights):.4f}
  Std Dev: {np.std(second_order_weights):.4f}
  Min: {np.min(second_order_weights):.4f}
  Max: {np.max(second_order_weights):.4f}

Weight Correlation: {np.corrcoef(first_order_weights, second_order_weights)[0, 1]:.4f}

Training Epochs: {len(first_order_weights)}
"""
            axes[1, 1].text(0.1, 0.9, stats_text, transform=axes[1, 1].transAxes,
                            fontsize=10, verticalalignment='top', fontfamily='monospace',
                            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.5))

            plt.tight_layout()

            # Save final analysis
            final_path = Path(self.viz_save_dir) / 'final_weight_analysis.png'
            plt.savefig(final_path, dpi=300, bbox_inches='tight')
            plt.close()

            print(f"Final weight analysis saved: {final_path}")

        except Exception as e:
            print(f"Final analysis save failed: {e}")

    def _extract_fusion_weights_duplicate(self, data_tensor: torch.Tensor) -> np.ndarray:

        self.model.eval()
        device = next(self.model.parameters()).device
        data_tensor = data_tensor.to(device)

        all_weights = []

        with torch.no_grad():
            for i in range(data_tensor.shape[0]):
                single_bag = data_tensor[i:i+1]  # [1, n_patches, feature_dim]

                processed_patches = self.model.feature_processor(single_bag)
                patch_attn_weights = self.model.patch_attention(processed_patches)
                patch_attn_weights = torch.softmax(patch_attn_weights, dim=1)
                wsi_mean_repr = torch.sum(patch_attn_weights * processed_patches, dim=1)
                cov_matrix = self.model._compute_covariance_matrix(processed_patches, patch_attn_weights, wsi_mean_repr)
                cov_features = self.model._extract_covariance_features(cov_matrix)
                if wsi_mean_repr.shape[1] != cov_features.shape[1]:
                    wsi_cov_repr = self.model.cov_processor(cov_features)
                else:
                    wsi_cov_repr = cov_features
                dual_repr = torch.stack([wsi_mean_repr, wsi_cov_repr], dim=1)
                fusion_weights = self.model.fusion_attention(dual_repr)
                fusion_weights = torch.softmax(fusion_weights, dim=1)

                weights_np = fusion_weights.squeeze().cpu().numpy()
                all_weights.append(weights_np)

        return np.array(all_weights)  # [batch_size, 2]

    def _calculate_auc_f1(self, labels: List[int], probs: np.ndarray) -> Tuple[float, float]:
        """Compute AUC and weighted F1 from class probabilities.

        The two metrics are computed independently: a failure while computing
        AUC no longer discards a perfectly valid F1 score (and vice versa).

        Args:
            labels: True class index per sample.
            probs: Class probabilities of shape [n_samples, n_classes].

        Returns:
            ``(auc, f1)``. AUC is 0.0 when only one class is present in
            ``labels``; either metric is 0.0 when its computation fails, in
            which case the error is printed.
        """
        from sklearn.metrics import roc_auc_score, f1_score

        auc = 0.0
        try:
            # AUC is undefined when only one class is present.
            if len(set(labels)) > 1:
                if probs.shape[1] == 2:
                    # Binary case: score with the positive-class probability.
                    auc = roc_auc_score(labels, probs[:, 1])
                else:
                    auc = roc_auc_score(labels, probs, multi_class='ovr')
        except Exception as e:  # deliberate best effort: metrics must not abort training
            print(f"Metric calculation failed: {e}")
            auc = 0.0

        try:
            predictions = np.argmax(probs, axis=1)
            f1 = f1_score(labels, predictions, average='weighted')
        except Exception as e:  # deliberate best effort, see above
            print(f"Metric calculation failed: {e}")
            f1 = 0.0

        return auc, f1

    def _update_visualization(self, epoch: int, train_metrics: dict, val_metrics: dict,
                            train_data: torch.Tensor, val_data: torch.Tensor = None):
        """Update visualization data"""
        if not self.enable_visualization:
            return

        # Record training metrics
        self.training_history['train_loss'].append(train_metrics.get('loss', 0))
        self.training_history['train_acc'].append(train_metrics.get('acc', 0))
        self.training_history['train_auc'].append(train_metrics.get('auc', 0))
        self.training_history['train_f1'].append(train_metrics.get('f1', 0))

        # Record validation metrics
        self.training_history['val_loss'].append(val_metrics.get('loss', 0))
        self.training_history['val_acc'].append(val_metrics.get('acc', 0))
        self.training_history['val_auc'].append(val_metrics.get('auc', 0))
        self.training_history['val_f1'].append(val_metrics.get('f1', 0))

        # Extract fusion weights
        train_weights = self._extract_fusion_weights(train_data)
        mean_weights = np.mean(train_weights, axis=0)
        std_weights = np.std(train_weights, axis=0)

        self.training_history['first_order_weight'].append(mean_weights[0])
        self.training_history['second_order_weight'].append(mean_weights[1])
        self.training_history['fusion_weights_mean'].append(mean_weights)
        self.training_history['fusion_weights_std'].append(std_weights)

        if hasattr(self.model, 'current_vector_stats'):
            stats = self.model.current_vector_stats
            self.training_history['first_order_norm'].append(stats['first_order_norm'])
            self.training_history['second_order_norm'].append(stats['second_order_norm'])
            self.training_history['first_order_mean'].append(stats['first_order_mean'])
            self.training_history['second_order_mean'].append(stats['second_order_mean'])
            self.training_history['first_order_std'].append(stats['first_order_std'])
            self.training_history['second_order_std'].append(stats['second_order_std'])
            self.training_history['norm_ratio'].append(stats['norm_ratio'])
        else:
            self.training_history['first_order_norm'].append(0.0)
            self.training_history['second_order_norm'].append(0.0)
            self.training_history['first_order_mean'].append(0.0)
            self.training_history['second_order_mean'].append(0.0)
            self.training_history['first_order_std'].append(0.0)
            self.training_history['second_order_std'].append(0.0)
            self.training_history['norm_ratio'].append(0.0)

            print(f"Epoch {epoch+1}: No vector statistics available, using default values 0.0")

        # Generate visualization every N epochs
        if epoch % self.viz_update_freq == 0:
            self._generate_realtime_plots(epoch)

    def _generate_realtime_plots(self, epoch: int):
        """Generate real-time visualization plots"""
        try:
            # Set font to avoid Chinese character display issues
            plt.rcParams['font.family'] = ['DejaVu Sans', 'Liberation Sans']
            plt.rcParams['axes.unicode_minus'] = False

            fig, axes = plt.subplots(3, 3, figsize=(20, 15))
            fig.suptitle(f'HOMIL Training Monitor - Epoch {epoch+1}', fontsize=16, fontweight='bold')

            epochs_range = range(1, len(self.training_history['train_loss']) + 1)

            # 1. Loss curves
            axes[0, 0].plot(epochs_range, self.training_history['train_loss'], '#006666', label='Train Loss', linewidth=2)
            axes[0, 0].plot(epochs_range, self.training_history['val_loss'], '#920000', label='Val Loss', linewidth=2)
            axes[0, 0].set_xlabel('Epoch', fontsize=11)
            axes[0, 0].set_ylabel('Loss', fontsize=11)
            axes[0, 0].set_title('Loss Curves', fontsize=12, fontweight='bold')
            axes[0, 0].legend(fontsize=9)
            axes[0, 0].tick_params(axis='both', which='major', labelsize=9)
            axes[0, 0].grid(True, alpha=0.3)

            # 2. Accuracy curves
            axes[0, 1].plot(epochs_range, self.training_history['train_acc'], '#006666', label='Train Acc', linewidth=2)
            axes[0, 1].plot(epochs_range, self.training_history['val_acc'], '#920000', label='Val Acc', linewidth=2)
            axes[0, 1].set_xlabel('Epoch', fontsize=11)
            axes[0, 1].set_ylabel('Accuracy', fontsize=11)
            axes[0, 1].set_title('Accuracy Curves', fontsize=12, fontweight='bold')
            axes[0, 1].legend(fontsize=9)
            axes[0, 1].tick_params(axis='both', which='major', labelsize=9)
            axes[0, 1].grid(True, alpha=0.3)

            # 3. AUC curves
            axes[0, 2].plot(epochs_range, self.training_history['train_auc'], '#006666', label='Train AUC', linewidth=2)
            axes[0, 2].plot(epochs_range, self.training_history['val_auc'], '#920000', label='Val AUC', linewidth=2)
            axes[0, 2].set_xlabel('Epoch', fontsize=11)
            axes[0, 2].set_ylabel('AUC', fontsize=11)
            axes[0, 2].set_title('AUC Curves', fontsize=12, fontweight='bold')
            axes[0, 2].legend(fontsize=9)
            axes[0, 2].tick_params(axis='both', which='major', labelsize=9)
            axes[0, 2].grid(True, alpha=0.3)

            # 4. Fusion weights evolution
            axes[1, 0].plot(epochs_range, self.training_history['first_order_weight'], '#2F5597',
                           label=r'$v^{(1)}: \alpha$', linewidth=2, markersize=4)
            axes[1, 0].plot(epochs_range, self.training_history['second_order_weight'], '#C55A11',
                           label=r'$v^{(2)}: 1-\alpha$', linewidth=2, markersize=4)
            axes[1, 0].set_xlabel('Epoch', fontsize=11)
            axes[1, 0].set_ylabel('Weight Value', fontsize=11)
            axes[1, 0].set_title('Fusion Weights', fontsize=12, fontweight='bold')
            axes[1, 0].legend(fontsize=9)
            axes[1, 0].tick_params(axis='both', which='major', labelsize=9)
            axes[1, 0].grid(True, alpha=0.3)
            axes[1, 0].set_ylim(0, 1)

            # 5. Current weight distribution pie chart
            current_weights = self.training_history['fusion_weights_mean'][-1]
            axes[1, 1].pie(current_weights, labels=['1st Order (Mean)', '2nd Order (Covariance)'],
                          autopct='%1.1f%%', startangle=90, colors=['#2F5597', '#C55A11'], textprops={'fontsize': 9})
            axes[1, 1].set_title(f'Current Weight Distribution\n1st Order: {current_weights[0]:.3f}, 2nd Order: {current_weights[1]:.3f}',
                               fontsize=12, fontweight='bold')

            # 6. F1 score curves
            axes[1, 2].plot(epochs_range, self.training_history['train_f1'], '#006666', label='Train F1', linewidth=2)
            axes[1, 2].plot(epochs_range, self.training_history['val_f1'], '#920000', label='Val F1', linewidth=2)
            axes[1, 2].set_xlabel('Epoch', fontsize=11)
            axes[1, 2].set_ylabel('F1 Score', fontsize=11)
            axes[1, 2].set_title('F1 Score Curves', fontsize=12, fontweight='bold')
            axes[1, 2].legend(fontsize=9)
            axes[1, 2].tick_params(axis='both', which='major', labelsize=9)
            axes[1, 2].grid(True, alpha=0.3)

            # 7. Vector norms comparison
            if len(self.training_history['first_order_norm']) > 0:
                axes[2, 0].plot(epochs_range, self.training_history['first_order_norm'], '#2F5597',
                               label='1st Order Norm', linewidth=2, marker='o', markersize=3)
                axes[2, 0].plot(epochs_range, self.training_history['second_order_norm'], '#C55A11',
                               label='2nd Order Norm', linewidth=2, marker='s', markersize=3)
                axes[2, 0].set_xlabel('Epoch', fontsize=11)
                axes[2, 0].set_ylabel('Vector Norm', fontsize=11)
                axes[2, 0].set_title('Vector Norms Evolution', fontsize=12, fontweight='bold')
                axes[2, 0].legend(fontsize=9)
                axes[2, 0].tick_params(axis='both', which='major', labelsize=9)
                axes[2, 0].grid(True, alpha=0.3)

            # 8. Norm ratio evolution
            if len(self.training_history['norm_ratio']) > 0:
                axes[2, 1].plot(epochs_range, self.training_history['norm_ratio'], '#006666',
                               label='2nd/1st Norm Ratio', linewidth=2, marker='d', markersize=3)
                axes[2, 1].set_xlabel('Epoch', fontsize=11)
                axes[2, 1].set_ylabel('Norm Ratio', fontsize=11)
                axes[2, 1].set_title('Norm Ratio (2nd Order / 1st Order)', fontsize=12, fontweight='bold')
                axes[2, 1].legend(fontsize=9)
                axes[2, 1].tick_params(axis='both', which='major', labelsize=9)
                axes[2, 1].grid(True, alpha=0.3)
                axes[2, 1].axhline(y=1.0, color='#920000', linestyle='--', alpha=0.5, label='Equal Norm')

            # 9. Vector means comparison
            if len(self.training_history['first_order_mean']) > 0:
                axes[2, 2].plot(epochs_range, self.training_history['first_order_mean'], '#2F5597',
                               label='1st Order Mean', linewidth=2, marker='o', markersize=3)
                axes[2, 2].plot(epochs_range, self.training_history['second_order_mean'], '#C55A11',
                               label='2nd Order Mean', linewidth=2, marker='s', markersize=3)
                axes[2, 2].set_xlabel('Epoch', fontsize=11)
                axes[2, 2].set_ylabel('Vector Mean', fontsize=11)
                axes[2, 2].set_title('Vector Means Evolution', fontsize=12, fontweight='bold')
                axes[2, 2].legend(fontsize=9)
                axes[2, 2].tick_params(axis='both', which='major', labelsize=9)
                axes[2, 2].grid(True, alpha=0.3)

            plt.tight_layout()

            # Save plot
            save_path = Path(self.viz_save_dir) / f'training_progress_epoch_{epoch+1:03d}.png'
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close()

            # Save current plot
            current_path = Path(self.viz_save_dir) / 'current_training_progress.png'
            save_path_copy = Path(self.viz_save_dir) / f'training_progress_epoch_{epoch+1:03d}.png'
            if save_path_copy.exists():
                import shutil
                shutil.copy2(save_path_copy, current_path)

            print(f"Visualization updated: {save_path}")

        except Exception as e:
            print(f"Visualization update failed: {e}")

    def _generate_combined_plots(self, epoch: int, epochs_range):
        """Render and save two summary figures built from ``self.training_history``.

        Figure 1 (1x2 grid): loss curves and fusion-weight curves.
        Figure 2 (2x2 grid): loss curves, accuracy curves, fusion-weight curves
        and a pie chart of the latest ``fusion_weights_mean`` entry.

        Each figure is written twice into ``self.viz_save_dir``: once under an
        epoch-stamped file name and once under a fixed ``current_*`` name.

        This is best-effort: any exception is reported with ``print`` and
        swallowed. Every figure is closed in a ``finally`` block so that a
        failure part-way through does not leave it registered with pyplot.

        Args:
            epoch: Zero-based epoch index; titles and file names use ``epoch + 1``.
            epochs_range: X-axis values shared by every curve.
        """

        def tag_panel(ax, letter, size):
            # Bold panel letter in the top-left corner, in axes coordinates.
            ax.text(0.02, 0.98, letter, transform=ax.transAxes, fontsize=size, fontweight='bold',
                    verticalalignment='top',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

        def finish_panel(ax, ylabel, title, letter, sizes, unit_range=False):
            # Shared labelling/legend/grid styling for a curve panel.
            label_size, title_size, small_size, letter_size = sizes
            ax.set_xlabel('Epoch', fontsize=label_size)
            ax.set_ylabel(ylabel, fontsize=label_size)
            ax.set_title(title, fontsize=title_size, fontweight='bold')
            ax.legend(fontsize=small_size)
            ax.tick_params(axis='both', which='major', labelsize=small_size)
            ax.grid(True, alpha=0.3)
            if unit_range:
                ax.set_ylim(0, 1)
            tag_panel(ax, letter, letter_size)

        try:
            history = self.training_history
            out_dir = Path(self.viz_save_dir)

            # ---- Figure 1: loss + fusion weights -------------------------
            fig1, axes1 = plt.subplots(1, 2, figsize=(12, 5))
            try:
                fig1.suptitle(f'Loss & Fusion Analysis - Epoch {epoch+1}', fontsize=18, fontweight='bold')
                # (axis label, title, legend/tick, panel letter) font sizes
                wide_sizes = (14, 16, 12, 20)

                # a. Loss curves
                axes1[0].plot(epochs_range, history['train_loss'], '#006666', label='Train Loss', linewidth=2)
                axes1[0].plot(epochs_range, history['val_loss'], '#920000', label='Val Loss', linewidth=2)
                finish_panel(axes1[0], 'Loss', 'Loss Curves', 'a', wide_sizes)

                # b. Fusion weights evolution
                axes1[1].plot(epochs_range, history['first_order_weight'], '#2F5597',
                              label=r'$\alpha^{(1)} (v^{(1)})$', linewidth=2)
                axes1[1].plot(epochs_range, history['second_order_weight'], '#C55A11',
                              label=r'$\alpha^{(2)} (v^{(2)})$', linewidth=2)
                finish_panel(axes1[1], 'Weight Value', 'Fusion Weights', 'b', wide_sizes, unit_range=True)

                fig1.tight_layout()

                save_path1 = out_dir / f'loss_fusion_epoch_{epoch+1:03d}.png'
                fig1.savefig(save_path1, dpi=150, bbox_inches='tight')
                fig1.savefig(out_dir / 'current_loss_fusion.png', dpi=150, bbox_inches='tight')
            finally:
                plt.close(fig1)

            # ---- Figure 2: loss, accuracy, fusion weights, pie -----------
            fig2, axes2 = plt.subplots(2, 2, figsize=(12, 10))
            try:
                fig2.suptitle(f'Comprehensive Training Analysis - Epoch {epoch+1}', fontsize=18, fontweight='bold')
                grid_sizes = (12, 14, 10, 16)

                # a. Loss curves
                axes2[0, 0].plot(epochs_range, history['train_loss'], '#006666', label='Train Loss', linewidth=2)
                axes2[0, 0].plot(epochs_range, history['val_loss'], '#920000', label='Val Loss', linewidth=2)
                finish_panel(axes2[0, 0], 'Loss', 'Loss Curves', 'a', grid_sizes)

                # b. Accuracy curves
                axes2[0, 1].plot(epochs_range, history['train_acc'], '#006666', label='Train Acc', linewidth=2)
                axes2[0, 1].plot(epochs_range, history['val_acc'], '#920000', label='Val Acc', linewidth=2)
                finish_panel(axes2[0, 1], 'Accuracy', 'Accuracy Curves', 'b', grid_sizes)

                # c. Fusion weights evolution
                axes2[1, 0].plot(epochs_range, history['first_order_weight'], '#2F5597',
                                 label='1st Order (Mean)', linewidth=2, marker='o', markersize=3)
                axes2[1, 0].plot(epochs_range, history['second_order_weight'], '#C55A11',
                                 label='2nd Order (Cov)', linewidth=2, marker='s', markersize=3)
                finish_panel(axes2[1, 0], 'Weight Value', 'Fusion Weights Evolution', 'c', grid_sizes,
                             unit_range=True)

                # d. Current weight distribution (pie chart) from the latest history entry
                current_weights = history['fusion_weights_mean'][-1]
                pie_ax = axes2[1, 1]
                pie_ax.pie(current_weights, labels=['1st Order (Mean)', '2nd Order (Covariance)'],
                           autopct='%1.1f%%', startangle=90, colors=['#2F5597', '#C55A11'],
                           textprops={'fontsize': 10})
                pie_ax.set_title(f'Current Weight Distribution\n1st Order: {current_weights[0]:.3f}, 2nd Order: {current_weights[1]:.3f}',
                                 fontsize=14, fontweight='bold')
                tag_panel(pie_ax, 'd', 16)

                fig2.tight_layout()

                save_path2 = out_dir / f'comprehensive_analysis_epoch_{epoch+1:03d}.png'
                fig2.savefig(save_path2, dpi=150, bbox_inches='tight')
                fig2.savefig(out_dir / 'current_comprehensive_analysis.png', dpi=150, bbox_inches='tight')
            finally:
                plt.close(fig2)

            print(f"Combined plots saved: {save_path1}, {save_path2}")

        except Exception as e:
            # Plotting must never abort training; report and carry on.
            print(f"Combined plots generation failed: {e}")

    def _save_final_analysis_duplicate(self):
        """Save final weight analysis"""
        if not self.enable_visualization:
            return

        try:
            # Set font to avoid Chinese character display issues
            plt.rcParams['font.family'] = ['DejaVu Sans', 'Liberation Sans']
            plt.rcParams['axes.unicode_minus'] = False

            fig, axes = plt.subplots(2, 2, figsize=(15, 12))
            fig.suptitle('Simple Covariance Final Weight Analysis', fontsize=16, fontweight='bold')

            epochs_range = range(1, len(self.training_history['first_order_weight']) + 1)

            # Weight evolution trend
            axes[0, 0].plot(epochs_range, self.training_history['first_order_weight'], '#2F5597',
                           label='1st Order Weight', linewidth=2, marker='o', markersize=3)
            axes[0, 0].plot(epochs_range, self.training_history['second_order_weight'], '#C55A11',
                           label='2nd Order Weight', linewidth=2, marker='s', markersize=3)
            axes[0, 0].set_xlabel('Epoch')
            axes[0, 0].set_ylabel('Weight Value')
            axes[0, 0].set_title('Weight Evolution During Training')
            axes[0, 0].legend()
            axes[0, 0].grid(True, alpha=0.3)
            axes[0, 0].set_ylim(0, 1)

            # Final weight distribution
            final_weights = self.training_history['fusion_weights_mean'][-1]
            axes[0, 1].pie(final_weights, labels=['1st Order (Mean)', '2nd Order (Covariance)'],
                          autopct='%1.1f%%', startangle=90, colors=['#2F5597', '#C55A11'])
            axes[0, 1].set_title(f'Final Weight Distribution\n1st Order: {final_weights[0]:.3f}, 2nd Order: {final_weights[1]:.3f}')

            # Weight stability analysis
            first_order_weights = np.array(self.training_history['first_order_weight'])
            second_order_weights = np.array(self.training_history['second_order_weight'])

            axes[1, 0].hist(first_order_weights, bins=20, alpha=0.7, label='1st Order Weight', color='#2F5597', edgecolor='black')
            axes[1, 0].hist(second_order_weights, bins=20, alpha=0.7, label='2nd Order Weight', color='#C55A11', edgecolor='black')
            axes[1, 0].set_xlabel('Weight Value')
            axes[1, 0].set_ylabel('Frequency')
            axes[1, 0].set_title('Weight Distribution Histogram')
            axes[1, 0].legend()
            axes[1, 0].grid(True, alpha=0.3)

            # Weight statistics
            axes[1, 1].axis('off')
            stats_text = f"""
Weight Statistics Analysis:

1st Order Weight:
  Final Value: {final_weights[0]:.4f}
  Mean: {np.mean(first_order_weights):.4f}
  Std Dev: {np.std(first_order_weights):.4f}
  Min: {np.min(first_order_weights):.4f}
  Max: {np.max(first_order_weights):.4f}

2nd Order Weight:
  Final Value: {final_weights[1]:.4f}
  Mean: {np.mean(second_order_weights):.4f}
  Std Dev: {np.std(second_order_weights):.4f}
  Min: {np.min(second_order_weights):.4f}
  Max: {np.max(second_order_weights):.4f}

Weight Correlation: {np.corrcoef(first_order_weights, second_order_weights)[0, 1]:.4f}

Training Epochs: {len(first_order_weights)}
"""
            axes[1, 1].text(0.1, 0.9, stats_text, transform=axes[1, 1].transAxes,
                            fontsize=11, verticalalignment='top', fontfamily='monospace',
                            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.5))

            plt.tight_layout()

            # Save final analysis
            final_path = Path(self.viz_save_dir) / 'final_weight_analysis.png'
            plt.savefig(final_path, dpi=300, bbox_inches='tight')
            plt.close()

            print(f"Final weight analysis saved: {final_path}")

        except Exception as e:
            print(f"Final analysis save failed: {e}")

    def train_epoch(self, train_data: torch.Tensor, train_labels: torch.Tensor,
                   val_data: torch.Tensor, val_labels: torch.Tensor,
                   epoch: int) -> Tuple[float, float, float, float]:
        """Train for one epoch, then evaluate on the validation set.

        On the first call (when ``self.optimizer`` does not exist yet) an Adam
        optimizer and a ``CrossEntropyLoss`` criterion are created and stored
        on the instance.

        Training iterates over shuffled mini-batches of ``self.batch_size``.
        Validation runs the whole validation set through the model in a single
        forward pass, in eval mode and without gradients.

        When ``self.enable_visualization`` is true, AUC/F1 are additionally
        computed for the full training set (one eval-mode forward pass) and for
        the validation set (reusing the validation forward pass), and
        ``_update_visualization`` is called. On the last epoch
        (``epoch == self.epochs - 1``) the final plots are generated as well.

        Args:
            train_data: Training inputs; first dimension indexes samples.
            train_labels: Class-index labels aligned with ``train_data``.
            val_data: Validation inputs.
            val_labels: Class-index labels aligned with ``val_data``.
            epoch: Zero-based epoch index.

        Returns:
            ``(train_loss, train_acc, val_loss, val_acc)``. ``train_loss`` is
            the unweighted mean of the per-batch losses and ``train_acc`` is
            accumulated from the training-mode forward passes.
        """
        if not hasattr(self, 'optimizer'):
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.learning_rate,
                weight_decay=self.weight_decay
            )
            self.criterion = nn.CrossEntropyLoss()

        self.model.train()
        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(train_data, train_labels),
            batch_size=self.batch_size, shuffle=True
        )

        batch_losses = []
        train_correct = 0
        total_samples = 0

        for batch_data, batch_labels in loader:
            batch_data = self._to_device(batch_data)
            batch_labels = self._to_device(batch_labels)

            self.optimizer.zero_grad()
            outputs = self.model(batch_data)
            loss = self.criterion(outputs, batch_labels)
            loss.backward()
            self.optimizer.step()

            batch_losses.append(loss.item())
            # detach() instead of the legacy .data attribute; no graph is needed here.
            predicted = outputs.detach().argmax(dim=1)
            train_correct += (predicted == batch_labels).sum().item()
            total_samples += batch_labels.size(0)

        train_loss = float(np.mean(batch_losses))
        train_acc = train_correct / total_samples

        self.model.eval()
        with torch.no_grad():
            if self.enable_visualization:
                # Train-set AUC/F1 need eval-mode outputs, so the full training
                # set is run once more here. It is done before validation so the
                # validation pass stays the last forward pass of the epoch.
                train_outputs = self.model(self._to_device(train_data))
                train_probs = torch.softmax(train_outputs, dim=1).cpu().numpy()
                train_auc, train_f1 = self._calculate_auc_f1(train_labels.cpu().numpy(), train_probs)

            val_data = self._to_device(val_data)
            val_labels = self._to_device(val_labels)

            val_outputs = self.model(val_data)
            val_loss = self.criterion(val_outputs, val_labels).item()
            val_correct = (val_outputs.argmax(dim=1) == val_labels).sum().item()
            val_acc = val_correct / len(val_labels)

            if self.enable_visualization:
                # Reuse the validation logits computed above rather than running
                # an identical second forward pass over the validation set.
                val_probs = torch.softmax(val_outputs, dim=1).cpu().numpy()
                val_auc, val_f1 = self._calculate_auc_f1(val_labels.cpu().numpy(), val_probs)

        if self.enable_visualization:
            train_metrics = {'loss': train_loss, 'acc': train_acc, 'auc': train_auc, 'f1': train_f1}
            val_metrics = {'loss': val_loss, 'acc': val_acc, 'auc': val_auc, 'f1': val_f1}
            # train_data is passed as received; val_data has already been moved via _to_device.
            self._update_visualization(epoch, train_metrics, val_metrics, train_data, val_data)

            if epoch == self.epochs - 1:
                self._generate_realtime_plots(epoch)
                self._save_final_analysis()

                epochs_range = range(1, len(self.training_history['train_loss']) + 1)
                self._generate_combined_plots(epoch, epochs_range)

        return train_loss, train_acc, val_loss, val_acc



    def predict_bags(self, bags: List[Tuple[np.ndarray, Any]]) -> PredictionResult:
        """Run inference on a list of bags and package the outputs.

        The bags are converted to a tensor with ``self.prepare_data``; that
        helper also takes labels, so zeros are supplied as placeholders and the
        label tensor it returns is discarded.

        Args:
            bags: Bags to classify, in the form accepted by ``prepare_data``.

        Returns:
            A ``PredictionResult`` holding the arg-max class per bag, the
            softmax probabilities, the highest probability per bag as
            ``confidence``, and an empty ``bag_names`` list.
        """
        self.model.eval()

        placeholder_labels = [0] * len(bags)
        inputs, _ = self.prepare_data(bags, placeholder_labels)

        with torch.no_grad():
            logits = self.model(self._to_device(inputs))
            class_probs = torch.softmax(logits, dim=1).cpu().numpy()
            class_ids = torch.argmax(logits, dim=1).cpu().numpy()

        # Confidence is the probability assigned to the most likely class.
        top_prob = np.max(class_probs, axis=1)

        return PredictionResult(
            predictions=class_ids,
            probabilities=class_probs,
            confidence=top_prob,
            bag_names=[]
        )

    def _get_wsi_feature_for_bag(self, bag_features: np.ndarray) -> np.ndarray:
        data_tensor = torch.FloatTensor(bag_features).unsqueeze(0)

        self.model.eval()
        with torch.no_grad():
            data_tensor = self._to_device(data_tensor)

            processed_patches = self.model.feature_processor(data_tensor)  # [1, n_patches, hidden_dim]

            patch_attn_weights = self.model.patch_attention(processed_patches)  # [1, n_patches, 1]
            patch_attn_weights = torch.softmax(patch_attn_weights, dim=1)

            wsi_mean_repr = torch.sum(patch_attn_weights * processed_patches, dim=1)  # [1, hidden_dim]

            cov_matrix = self.model._compute_covariance_matrix_fast(
                processed_patches, patch_attn_weights, wsi_mean_repr
            )

            variances = torch.diagonal(cov_matrix, dim1=-2, dim2=-1)  # [1, hidden_dim]

            weights = torch.softmax(variances, dim=-1)  # [1, hidden_dim]

            weighted_row_means = torch.sum(
                cov_matrix * weights.unsqueeze(-1),  # [1, hidden_dim, hidden_dim]
                dim=-1
            )  # [1, hidden_dim]

            wsi_cov_repr = self.model.cov_processor(weighted_row_means)  # [1, hidden_dim]

            dual_repr = torch.stack([wsi_mean_repr, wsi_cov_repr], dim=1)  # [1, 2, hidden_dim]
            fusion_weights = self.model.fusion_attention(dual_repr)  # [1, 2, 1]
            fusion_weights = torch.softmax(fusion_weights, dim=1)

            wsi_representation = torch.sum(fusion_weights * dual_repr, dim=1)  # [1, hidden_dim]

            return wsi_representation.squeeze(0).cpu().numpy()
