# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    set_seed, AutoTokenizer, AutoModel, AutoConfig,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset
from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import random
import time
import os
import re
#import math
import matplotlib.pyplot as plt


# ==============================================================================
# 1. CONFIGURATION & EXPERIMENT SETUP
# ==============================================================================

# Base random seed (default seed of set_all_seeds and GLUEDataset below).
GLOBAL_SEED = 7
# HuggingFace checkpoint name used for the tokenizer, config and encoder.
MODEL_NAME = 'bert-base-uncased'
# Side length of the per-head attention masks; default max sequence length.
MAX_SEQ_LENGTH = 256
NUM_EPOCHS = 30
BATCH_SIZE = 64
LEARNING_RATE = 2e-5
K_FOLDS = 3
# NOTE(review): presumably the fraction of each dataset that is kept
# (0.01 = 1%); the code consuming it is outside this view - confirm.
SAMPLE_FRACTION = 0.01
# Tasks to run the benchmark on.
TASKS = ['sst2'] # other keys mentioned here: imdb, ths = Twitter Hate Speech, dt = Disaster Tweets
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MASKS_DIR = "./Masks"

# Sparse attention training parameters
# NOTE(review): presumably the fraction of the training epochs after which the
# attention masks are fixed; the training loop is outside this view - confirm.
MASK_FIX_EPOCH_RATIO = 0.5

# --- BENCHMARK CONFIGS ---
"""
CFCP_BENCHMARK_CONFIGS: List of sparse attention configurations to benchmark.

Each configuration is a dictionary with:
  - 'proj': Projection type ('CFCP' or 'BigBird')
  - 'l': Sparsity level parameter for CFCP (controls Hoyer index)
  - 'window_size': Local attention window size for BigBird

To run experiments:
1. Uncomment the desired configurations below
2. Multiple configs will run sequentially on all tasks in TASKS
3. Results are saved to 'CFCP_benchmark_results_{task}.csv'
"""
CFCP_BENCHMARK_CONFIGS = [
    # === CFCP: one entry per sparsity level 'l' === #
    {"proj": "CFCP", "l": 1},
    # {"proj": "CFCP", "l": 2},
    # {"proj": "CFCP", "l": 4},
    # {"proj": "CFCP", "l": 8},
    # {"proj": "CFCP", "l": 12},
    # {"proj": "CFCP", "l": 16},
    # {"proj": "CFCP", "l": 24},


    # === BigBird: one entry per local window size === #
    {"proj": "BigBird", "window_size": 1},
    # {"proj": "BigBird", "window_size": 2},
    # {"proj": "BigBird", "window_size": 4},
    # {"proj": "BigBird", "window_size": 8},
    # {"proj": "BigBird", "window_size": 12},
    # {"proj": "BigBird", "window_size": 16},
    # {"proj": "BigBird", "window_size": 24},


]

# ==============================================================================
# 2. PROJECTION & SPARSE UTILITY FUNCTIONS
# ==============================================================================

def get_config_name(config):
    """Build a short display label for a benchmark configuration.

    The label is "<proj> <value>", where <value> is config['l'] when that key
    is present and otherwise config['window_size'] when that key is present.
    With neither key, only the projection name is returned. A missing 'proj'
    entry is rendered as 'Unknown'.
    """
    label = config.get('proj', 'Unknown')
    # 'l' takes precedence over 'window_size' when both are supplied.
    for key in ('l', 'window_size'):
        if key in config:
            return f"{label} {config[key]}"
    return label

def sparsity(M, tol=1.0e-3):
    """Return the L0 sparsity of M: the fraction of entries that count as zero.

    An entry counts as zero when its absolute value is below `tol` (or when it
    is exactly zero).

    Args:
        M: a tensor, or anything `torch.as_tensor` accepts.
        tol: magnitude threshold below which an entry is treated as zero.

    Returns:
        A Python float in [0, 1]. An empty input has no entries to prune and
        yields 0.0 instead of raising ZeroDivisionError.
    """
    if not isinstance(M, torch.Tensor):
        M = torch.as_tensor(M)

    total = M.numel()
    if total == 0:
        return 0.0

    # Zero out every entry whose magnitude is below the tolerance ...
    thresholded = torch.where(torch.abs(M.float()) < tol, torch.zeros_like(M), M)
    # ... and count what survives.
    nb_nonzero = int((thresholded != 0).sum().item())
    return 1.0 - nb_nonzero / total

def reshapeDiag(m):
    """Re-lay out the last two dims of `m` so that diagonals line up as columns.

    `m` (..., R, C) is copied into the top-left corner of a zero canvas of
    shape (..., R + C + 1, R + C). The canvas is then reinterpreted as
    (..., R + C, R + C + 1). Because the row length grows by one, every
    original row is shifted one position relative to the previous one, so
    entries with the same offset (column - row) share a column.

    Returns the first R rows of the reinterpreted canvas, shape
    (..., R, R + C + 1). The canvas is allocated on m's device with torch's
    default dtype.
    """
    n_rows, n_cols = m.shape[-2], m.shape[-1]
    batch_dims = tuple(m.shape[:-2])
    width = n_rows + n_cols
    height = width + 1

    canvas = torch.zeros(*batch_dims, height, width, device=m.device)
    canvas[..., :n_rows, :n_cols] = m

    # Same memory, rows one element longer: this is what skews the layout.
    skewed = canvas.reshape((*batch_dims, width, height))
    return skewed[..., :n_rows, :]

def sparsityHoyer(M):
    """Compute a Hoyer-style density ratio from the column L2 norms of M.

    With V the vector of per-column L2 norms (norm taken over dim 0), the
    value returned is (||V||_1)^2 / (||V||_2^2 * len(V)). It equals 1 when all
    columns have the same norm and 1 / len(V) when a single column is
    non-zero, so smaller values mean a sparser column structure.

    Args:
        M: a 2D tensor, or anything `torch.as_tensor` accepts. Integer/bool
           input is converted to the default float dtype first, because
           `torch.linalg.norm` only accepts floating point or complex tensors.

    Returns:
        A 0-dim tensor with the ratio, or the Python float 1.0 when every
        column norm is zero.
    """
    if not isinstance(M, torch.Tensor):
        M = torch.as_tensor(M)
    if not (M.is_floating_point() or M.is_complex()):
        M = M.to(torch.get_default_dtype())

    # L2 norm of each column.
    V = torch.linalg.norm(M, ord=2, dim=0)

    l1 = V.sum()
    l22 = V.dot(V)

    # All-zero input: avoid 0 / 0 and report the conventional value 1.0.
    if l1 == 0:
        return 1.0

    # Squared L1/L2 ratio, normalised by the number of columns.
    return l1 * l1 / l22 / V.numel()

def CFCP_with_loops(y: torch.Tensor,
                                              l: float,
                                              eps: float = 1e-12,
                                              max_iter: int = 50):
    """
    One-shot CFCP projection for a 1D tensor.

    The routine works on |y|: it repeatedly thresholds the magnitudes until
    the set of surviving (active) entries stops changing, blends the survivors
    with their mean, restores the signs and finally rescales the result by
    <x, y> / <x, x>.

    Args:
        y: 1D input tensor (asserted).
        l: level parameter of the projection; it enters the threshold through
           `num` and `den` below.
        eps: small constant guarding square roots and divisions.
        max_iter: upper bound on the number of thresholding passes.

    Returns (x, loops) where loops is the number of active-set stabilizing while-loops.
    An all-zero tensor is returned as soon as no entry remains active.
    """
    assert y.dim() == 1, "Expected 1D tensor"
    n = y.numel()
    # Work on magnitudes; the signs are put back at the very end.
    sgn = torch.sign(y)
    x = y.abs().clone()

    # nu is deliberately initialised different from nu_prev so that the loop
    # body runs at least once.
    nu_prev = int((x > 0).sum().item())
    nu = nu_prev + 1
    loops = 0
    alpha = torch.tensor(0.0, device=y.device, dtype=torch.get_default_dtype())
    # Iterate until the active-set size is unchanged between two passes
    # (or max_iter passes were made).
    while nu != nu_prev and loops < max_iter:
        loops += 1
        nu_prev = nu
        mask = (x > 0)
        nu = int(mask.sum().item())
        if nu == 0:
            return torch.zeros_like(y), loops
        x_active = x[mask]
        l1_active = x_active.sum()
        l2_active = torch.sqrt((x_active * x_active).sum().clamp_min(eps))
        # Squared L1/L2 ratio of the active entries, clamped to [1, n].
        Hx = (l1_active / l2_active).pow(2).clamp_min(1.0).clamp_max(float(n))
        num = l * (nu - Hx)
        den = Hx * (nu - l + eps)
        # Negative ratios are clamped to 0 so the square root stays real.
        frac = (num / den).clamp_min(0.0)
        root = torch.sqrt(frac)
        # Threshold: mean active magnitude scaled by (1 - root).
        alpha = (l1_active / max(nu, 1)) * (1.0 - root)
        # Entries below the threshold leave the active set.
        x = torch.where(x >= alpha, x, torch.zeros_like(x))

    # Recompute the active set after the final thresholding pass.
    mask = (x > 0)
    nu = int(mask.sum().item())
    if nu == 0:
        return torch.zeros_like(y), loops
    x_active = x[mask]
    l1_active = x_active.sum()
    denom = (1.0 - (alpha * nu) / (l1_active + eps))
    # Fall back to lam = 1 (x left unchanged by the blend) when denom is ~0.
    lam = 1.0 / denom if abs(float(denom)) > eps else torch.tensor(1.0, device=y.device, dtype=torch.get_default_dtype())
    # d holds the mean active magnitude on the active positions, 0 elsewhere.
    d = torch.zeros_like(x)
    d_val = (l1_active / max(nu, 1))
    d[mask] = d_val
    # Affine combination of the thresholded vector and its active mean.
    x = lam * x + (1.0 - lam) * d

    # Restore the signs, then rescale by <x, y> / <x, x>.
    x = x * sgn
    xy = (x * y).sum()
    xx = (x * x).sum().clamp_min(eps)
    # NOTE(review): with a positive eps, xx >= eps after the clamp, so the
    # else branch looks unreachable.
    scale = (xy / xx) if xx > 0 else torch.tensor(0.0, device=y.device, dtype=torch.get_default_dtype())
    x = x * scale
    return x, loops

def proj_Hoyer(w0, level, device="cpu"):
    """Apply the CFCP projection to an input of any shape.

    The input is converted to a tensor of the default dtype on `device`,
    flattened to 1D, projected with `CFCP_with_loops` and reshaped back.

    Args:
        w0: tensor or array-like to project.
        level: projection level, forwarded to `CFCP_with_loops` as a float.
        device: device on which the computation runs.

    Returns:
        A detached tensor shaped like the input when `w0` is a tensor,
        otherwise a NumPy array of the same shape.
    """
    w = torch.as_tensor(w0, dtype=torch.get_default_dtype(), device=device)
    init_shape = w.size()

    if w.dim() > 1:
        w = w.reshape(-1)

    x, _loops = CFCP_with_loops(w, float(level))

    Q = x.reshape(init_shape).clone().detach()
    if not torch.is_tensor(w0):
        # Tensor.numpy() only works for CPU tensors, so move the result to the
        # host first; this is a no-op when it already lives on the CPU.
        Q = Q.cpu().numpy()
    return Q

def proj_l2ball(w0, l, device="cpu", eps=1e-12):
    """Project `w0` onto the L2 ball of radius `l`.

    The input is converted to a tensor of the default dtype on `device`. When
    its 2-norm is already at most `l` the converted tensor is returned
    untouched; otherwise it is scaled by l / (norm + eps).
    """
    vec = torch.as_tensor(w0, dtype=torch.get_default_dtype(), device=device)
    norm = torch.linalg.norm(vec, ord=2)

    inside_ball = norm <= l
    if inside_ball:
        return vec

    shrink = l / (norm + eps)
    return vec * shrink


def bilevel_proj_HoyerL2ball(w2, level, device="cpu", eps=1e-12):
    """Two-level projection: Hoyer-project the column norms, then shrink columns.

    For a 1D input this is simply `proj_Hoyer`. Otherwise:
      1. the L2 norm of every column (norm over dim 0) is computed;
      2. that vector of norms is projected with `proj_Hoyer` (NaNs in the
         result are replaced by 0);
      3. every column is scaled by min(1, projected_norm / (norm + eps)), so
         each column is shrunk to its projected norm and never enlarged.

    A tensor input keeps its own device (the `device` argument is then
    ignored). The multi-dimensional path returns a fresh tensor with
    requires_grad set to True.
    """
    if isinstance(w2, torch.Tensor):
        device = w2.device
    else:
        device = torch.device(device)
    mat = torch.as_tensor(w2, dtype=torch.get_default_dtype(), device=device)

    # Vectors have no column structure: fall back to the plain projection.
    if mat.dim() == 1:
        return proj_Hoyer(mat, level, device=device)

    norms = torch.linalg.norm(mat, ord=2, dim=0)

    # Outer level: projection of the norm vector.
    targets = proj_Hoyer(norms, level, device=device)
    targets = torch.where(torch.isnan(targets), torch.zeros_like(targets), targets)

    # Inner level: per-column L2-ball projection, done for all columns at once.
    ratio = targets / (norms + eps)
    scales = torch.minimum(torch.ones_like(norms), ratio)
    shrunk = mat * scales.unsqueeze(0)

    return shrunk.clone().detach().requires_grad_(True)


def horizontal_bilevel_proj_HoyerL2ball(w, level, device="cpu"):
    """
    Row-oriented variant of `bilevel_proj_HoyerL2ball`.

    The last two dimensions of the input are swapped, the column-wise bilevel
    projection is applied to the swapped tensor, and the result is swapped
    back, so rows of `w` play the role that columns play in
    `bilevel_proj_HoyerL2ball`.

    1D inputs are handed directly to `proj_Hoyer`. A tensor input keeps its
    own device; otherwise `device` is used.

    NOTE(review): inputs with more than two dimensions are accepted, but the
    underlying projection takes its norms over dim 0 - confirm this matches
    the intended per-matrix behaviour before relying on batched inputs.
    """
    device = w.device if isinstance(w, torch.Tensor) else torch.device(device)
    w = torch.as_tensor(w, dtype=torch.get_default_dtype(), device=device)

    # Vectors: nothing to transpose, use the plain projection.
    if w.dim() == 1:
        return proj_Hoyer(w, level, device=device)

    # rows -> columns, project, columns -> rows
    projected_t = bilevel_proj_HoyerL2ball(w.transpose(-2, -1), level, device=device)
    return projected_t.transpose(-2, -1)

def diag_bilevel_proj_HoyerL2ball(w, l):
    """
    Bilevel Hoyer projection applied along the diagonals of `w`.

    `w` (..., R, C) is embedded in a zero canvas of shape
    (..., R + C + 1, R + C), which is reinterpreted as (..., R + C, R + C + 1)
    so that entries sharing a diagonal offset fall into the same column. The
    column-wise `bilevel_proj_HoyerL2ball` is applied in that layout, the
    layout change is undone, and the top-left (R, C) corner is returned.
    """
    n_rows, n_cols = w.shape[-2], w.shape[-1]
    lead = tuple(w.shape[:-2])
    width = n_rows + n_cols
    height = width + 1

    # Step 1: embed w in the zero canvas and skew it.
    canvas = torch.zeros(*lead, height, width, device=w.device, dtype=torch.get_default_dtype())
    canvas[..., :n_rows, :n_cols] = w
    skewed = canvas.reshape((*lead, width, height))

    # Step 2: column-wise projection in the skewed layout.
    projected = bilevel_proj_HoyerL2ball(skewed, l, device=w.device)

    # Step 3: back to the canvas layout, then crop to the original size.
    restored = projected.reshape((*lead, height, width))
    return restored[..., :n_rows, :n_cols]


# ==============================================================================
# 3. MODEL COMPONENTS & SPARSE ATTENTION MODULES
# ==============================================================================

class WeightedAttentionMask(nn.Module):
    """
    Multi-head attention with one learnable additive (seq_len x seq_len) mask per head.

    Behaviour is driven by the shared `config` dict, which is read on every
    forward pass:
      - config['fixed'] falsy: the learnable masks are added to the attention
        scores (unless config['use_weighted_masks'] is False).
      - config['fixed'] truthy: on first use each head's mask is turned into a
        binary sparsity pattern (see `_compute_fixed_mask_for_head`); pruned
        positions are set to -inf before the softmax and only kept positions
        of the learnable mask are added to the scores.

    Attributes:
        weighted_masks: nn.Parameter of shape (num_heads, seq_len, seq_len).
        initial_pretrained_masks: detached copy of the masks at construction.
        fixed_masks: per-head projected mask values (None until fixed).
        fixed_binary_masks: per-head 0/1 pattern, 1 = keep, 0 = pruned.
    """

    def __init__(self, embed_dim, num_heads, config, orig_self=None):
        """
        Args:
            embed_dim: model hidden size; must be divisible by `num_heads`.
            num_heads: number of attention heads.
            config: sparse-attention settings; must contain 'seq_len'.
            orig_self: optional module exposing `query`, `key` and `value`
                linear layers (HuggingFace BertSelfAttention naming) whose
                weights are copied into the new projections.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        if orig_self is not None:
            self._copy_pretrained_projections(orig_self)

        seq_len = config['seq_len']
        # Initialize masks with small random positive values so projections
        # produce sparsity (projecting a uniform matrix of ones returns ones).
        self.weighted_masks = nn.Parameter(torch.randn((num_heads, seq_len, seq_len)).abs() * 0.1 + 0.1)

        # Snapshot of the initial mask values, used when a head gets fixed.
        self.initial_pretrained_masks = self.weighted_masks.data.clone().detach()

        self.fixed_masks = [None] * num_heads
        self.fixed_binary_masks = [None] * num_heads
        self.config = config

    def _copy_pretrained_projections(self, orig_self):
        """Copy query/key/value weights from `orig_self`; warn instead of failing."""
        import warnings

        pairs = (('query', self.q_proj), ('key', self.k_proj), ('value', self.v_proj))
        try:
            for attr_name, target in pairs:
                if not hasattr(orig_self, attr_name):
                    continue
                source = getattr(orig_self, attr_name)
                target.weight.data.copy_(source.weight.data)
                if source.bias is not None:
                    target.bias.data.copy_(source.bias.data)
        except Exception as exc:
            # Best effort by design: keep going, but make the fallback visible
            # because it silently changes what the dense phase starts from.
            warnings.warn(
                f"Could not copy pretrained Q/K/V weights ({exc!r}); "
                "projections that were not copied keep their random initialisation.",
                RuntimeWarning,
            )

    def measureSparsity(self, eps=1e-6):
        """Compute average L0 and Hoyer sparsity across all heads.

        For every head the fixed mask is measured when it exists, otherwise the
        learnable mask. `eps` is accepted for backward compatibility and unused.

        Returns:
            (mean L0 sparsity, mean Hoyer measure) over the heads.
        """
        l0_sparsities = []
        hoyer_sparsities = []
        # The Hoyer measure is taken in the diagonal layout only for
        # projections whose name contains 'diag'.
        use_diag_layout = "diag" in self.config["proj"]

        with torch.no_grad():
            for head_idx in range(self.num_heads):
                obs_mask = self.fixed_masks[head_idx]
                if obs_mask is None:
                    obs_mask = self.weighted_masks[head_idx]

                hoyer_input = reshapeDiag(obs_mask) if use_diag_layout else obs_mask

                # float() handles Python floats and 0-dim tensors on any device.
                l0_sparsities.append(float(sparsity(obs_mask)))
                hoyer_sparsities.append(float(sparsityHoyer(hoyer_input)))

        return np.mean(l0_sparsities), np.mean(hoyer_sparsities)

    @staticmethod
    def _normalize_padding_mask(key_padding_mask):
        """Return the padding mask as (batch, seq_len) with 1 = token, 0 = pad.

        Accepts None, a 2D mask already in that convention, or a 4D additive
        mask of shape (b, 1, 1, seq_len) / (b, 1, seq_len, seq_len) whose
        entries are 0 for real tokens and non-zero (large negative) for padding.
        """
        if key_padding_mask is None:
            return None
        if key_padding_mask.dim() == 2:
            return key_padding_mask
        if key_padding_mask.dim() == 4:
            if key_padding_mask.size(1) != 1:
                raise ValueError(f"Unexpected 4D attention mask shape: {key_padding_mask.shape}")
            # The first query row carries the key padding information in both
            # supported 4D layouts.
            return (key_padding_mask[:, 0, 0, :] == 0).long()
        raise ValueError(f"Unexpected attention mask shape: {key_padding_mask.shape}")

    def _split_heads(self, projected, batch_size):
        """(b, seq, embed) -> (b, heads, seq, head_dim)."""
        return projected.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, key_padding_mask=None):
        """Run masked multi-head attention.

        Args:
            query, key, value: tensors of shape (batch, seq, embed_dim). The
                sequence length has to match the side length of the masks
                whenever the masks are applied.
            key_padding_mask: None, a (batch, seq) mask with 1 = token and
                0 = pad, or a 4D additive mask (see `_normalize_padding_mask`).

        Returns:
            (output, attention_weights) with shapes (batch, seq, embed_dim)
            and (batch, num_heads, seq, seq).

        Raises:
            ValueError: for padding masks of an unsupported shape.
        """
        batch_size = query.size(0)
        key_padding_mask = self._normalize_padding_mask(key_padding_mask)

        Q = self._split_heads(self.q_proj(query), batch_size)
        K = self._split_heads(self.k_proj(key), batch_size)
        V = self._split_heads(self.v_proj(value), batch_size)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32, device=Q.device)
        )

        if key_padding_mask is not None:
            # (b, seq_k) -> (b, 1, 1, seq_k); broadcast over heads and queries.
            pad_mask = (key_padding_mask == 0)[:, None, None, :]
            scores = scores.masked_fill(pad_mask, float('-inf'))

        if self.config.get('fixed', False):
            # Lazily derive the sparsity pattern of every head on first use.
            for head_idx in range(self.num_heads):
                if self.fixed_masks[head_idx] is None:
                    self._compute_fixed_mask_for_head(head_idx)
            keep = torch.stack(self.fixed_binary_masks, dim=0)  # (heads, seq, seq)
            scores = scores + (self.weighted_masks * keep).unsqueeze(0)
            scores = scores.masked_fill((keep == 0).unsqueeze(0), float('-inf'))
        elif self.config.get('use_weighted_masks', True):
            scores = scores + self.weighted_masks.unsqueeze(0)

        attention_weights = F.softmax(scores, dim=-1)
        # Rows where every key is masked give NaN after the softmax; replace
        # them with a uniform distribution.
        attention_weights = torch.where(
            attention_weights.isnan(),
            torch.full_like(attention_weights, 1.0 / attention_weights.shape[-1]),
            attention_weights
        )

        output = torch.matmul(attention_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return output, attention_weights

    def _compute_fixed_mask_for_head(self, head_idx):
        """Derive and store the fixed sparsity pattern of one head.

        'BigBird' keeps the first row, the first column and a band of
        half-width config['window_size'] around the diagonal. 'CFCP' applies
        `diag_bilevel_proj_HoyerL2ball` with level config['l']. In both cases
        the first row and column are always kept. The head's learnable mask is
        then reset to its initial values on the kept positions and 0 elsewhere.

        Raises:
            ValueError: if config['proj'] is neither 'BigBird' nor 'CFCP'.
        """
        with torch.no_grad():
            proj_name = self.config['proj']
            l = self.config.get('l', None)
            window_size = self.config.get('window_size', None)
            head_mask = self.weighted_masks[head_idx]
            tol = 1e-6  # magnitude above which an entry counts as kept
            device = head_mask.device

            if proj_name == 'BigBird':
                seq_len = head_mask.shape[0]
                pattern = torch.zeros_like(head_mask)

                # Global attention: first row and first column.
                pattern[0, :] = 1.0
                pattern[:, 0] = 1.0

                # Local attention: |i - j| <= window_size.
                if window_size is not None:
                    positions = torch.arange(seq_len, device=device)
                    band = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs() <= window_size
                    pattern[band] = 1.0

                # Keep the mask values on the pattern, 0 elsewhere.
                combined = pattern * head_mask

            elif proj_name == 'CFCP':
                combined = diag_bilevel_proj_HoyerL2ball(head_mask, l)

            else:
                raise ValueError(f"Unknown proj_name: {proj_name}")

            combined = combined.to(device)

            # Preserve first row and first column.
            combined[0, :] = head_mask[0, :]
            combined[:, 0] = head_mask[:, 0]
            # Binary pattern (1 = keep, 0 = prune).
            binary_mask = (combined.abs() > tol).float()

            # Restart the kept entries from their initial values.
            pretrained_head_mask = self.initial_pretrained_masks[head_idx].to(device)
            self.weighted_masks.data[head_idx] = pretrained_head_mask * binary_mask

            self.fixed_masks[head_idx] = combined.detach()
            self.fixed_binary_masks[head_idx] = binary_mask.detach()

    def weightsPenality(self):
        """Return the L1 norm of all learnable masks (sparsity penalty term)."""
        return self.weighted_masks.abs().sum()

class SparseBERTSelfAttention(nn.Module):
    """Wrapper to integrate WeightedAttentionMask into the BERT layer structure.

    The wrapped attention works on sequences of exactly `max_seq_len` tokens
    (the side length of its masks). Shorter inputs are zero-padded and the
    output is cropped back; longer inputs are truncated to `max_seq_len`.
    """

    def __init__(self, config, sparse_config, max_seq_len=MAX_SEQ_LENGTH, orig_self=None):
        """
        Args:
            config: model config providing `hidden_size` and
                `num_attention_heads`.
            sparse_config: sparse-attention settings dict. It is modified in
                place ('seq_len' is set) and shared with the wrapped attention
                module rather than copied.
            max_seq_len: side length of the attention masks.
            orig_self: optional original self-attention module whose Q/K/V
                weights are copied.
        """
        super().__init__()
        sparse_config['seq_len'] = max_seq_len
        self.max_seq_len = max_seq_len
        self.max_seq_length = max_seq_len
        self.attention = WeightedAttentionMask(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            config=sparse_config,
            orig_self=orig_self
        )

        # Store initial Q/K/V projection weights for a potential reset.
        self.initial_q_weight = self.attention.q_proj.weight.data.clone().detach()
        self.initial_k_weight = self.attention.k_proj.weight.data.clone().detach()
        self.initial_v_weight = self.attention.v_proj.weight.data.clone().detach()
        if self.attention.q_proj.bias is not None:
            self.initial_q_bias = self.attention.q_proj.bias.data.clone().detach()
            self.initial_k_bias = self.attention.k_proj.bias.data.clone().detach()
            self.initial_v_bias = self.attention.v_proj.bias.data.clone().detach()

    @staticmethod
    def _pad_attention_mask(attention_mask, pad_size):
        """Extend the key dimension of `attention_mask` by `pad_size` masked positions.

        2D masks use 1 = token / 0 = pad, so they are padded with 0. 4D masks
        are additive (0 = token, non-zero = pad - this is how the wrapped
        attention interprets them), so padding them with 0 would mark the
        padded positions as real tokens; they are padded with the most
        negative representable value instead (or 1 for non-float masks).
        """
        if attention_mask.dim() == 4:
            if attention_mask.is_floating_point():
                fill = torch.finfo(attention_mask.dtype).min
            else:
                fill = 1
            return F.pad(attention_mask, (0, pad_size), value=fill)
        return F.pad(attention_mask, (0, pad_size), value=0)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        *args,
        **kwargs
    ):
        """Match HF BertSelfAttention signature; ignore cross-attention fields.

        Args:
            hidden_states: (batch, seq_len, hidden) input.
            attention_mask: None (all tokens real), a (batch, seq_len) 1/0
                mask, or a 4D additive mask.
            **kwargs: only 'head_mask' and 'output_attentions' are read; extra
                positional and keyword arguments are ignored.

        Returns:
            (attention_output, attention_weights) when output_attentions is
            truthy, else (attention_output, None).
            NOTE: when the input was padded, attention_weights keep the padded
            (max_seq_len) size; when it was truncated, both outputs have
            max_seq_len positions.
        """
        batch_size, seq_len, _ = hidden_states.shape

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=hidden_states.device, dtype=torch.long)

        head_mask = kwargs.get('head_mask', None)
        output_attentions = kwargs.get('output_attentions', False)

        target_len = self.max_seq_length
        needs_padding = seq_len < target_len

        if needs_padding:
            pad_size = target_len - seq_len
            attn_input = F.pad(hidden_states, (0, 0, 0, pad_size, 0, 0))
            attn_mask = self._pad_attention_mask(attention_mask, pad_size)
        elif seq_len > target_len:
            attn_input = hidden_states[:, :target_len, :]
            # Slice the key (last) dimension so 2D and 4D masks both work.
            attn_mask = attention_mask[..., :target_len]
        else:
            attn_input, attn_mask = hidden_states, attention_mask

        attention_output, attention_weights = self.attention(
            attn_input, attn_input, attn_input, key_padding_mask=attn_mask
        )

        if needs_padding:
            # Drop the rows that belong to the padding we added.
            attention_output = attention_output[:, :seq_len, :]

        # Apply head_mask if provided (shape: [num_heads] or [batch, num_heads, 1, 1])
        if head_mask is not None and attention_weights is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.view(1, -1, 1, 1)
            attention_weights = attention_weights * head_mask

        if output_attentions:
            return attention_output, attention_weights
        return attention_output, None

    def measureSparsity(self):
        """Return (L0 sparsity, Hoyer measure) of the wrapped attention's masks."""
        l0, hoyer = self.attention.measureSparsity()
        return l0, hoyer

    def weightsPenality(self):
        """Return the L1 penalty of the wrapped attention's masks."""
        return self.attention.weightsPenality()

    def reset_to_pretrained(self):
        """Reset Q, K, V projection weights (and biases) to their values at construction."""
        with torch.no_grad():
            self.attention.q_proj.weight.data.copy_(self.initial_q_weight)
            self.attention.k_proj.weight.data.copy_(self.initial_k_weight)
            self.attention.v_proj.weight.data.copy_(self.initial_v_weight)
            if self.attention.q_proj.bias is not None:
                self.attention.q_proj.bias.data.copy_(self.initial_q_bias)
                self.attention.k_proj.bias.data.copy_(self.initial_k_bias)
                self.attention.v_proj.bias.data.copy_(self.initial_v_bias)

class SparseBERTModel(nn.Module):
    """
    Sequence classifier: a pretrained encoder whose self-attention modules are
    swapped for `SparseBERTSelfAttention`, followed by dropout and a linear
    head on the pooled output.

    When `sparse_config` is None the encoder is left untouched and the
    sparsity helpers return zeros.
    """

    def __init__(self, model_name=MODEL_NAME, sparse_config=None, num_labels=2, max_seq_len=MAX_SEQ_LENGTH):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.config.num_labels = num_labels
        self.max_seq_len = max_seq_len

        self.bert = AutoModel.from_pretrained(model_name)

        # Plain Python list; the modules themselves are registered through
        # their place inside self.bert.
        self.sparse_attentions = []
        self.sparse_config = sparse_config
        if sparse_config is not None:
            for encoder_layer in self.bert.encoder.layer:
                replacement = SparseBERTSelfAttention(
                    self.config,
                    sparse_config,
                    self.max_seq_len,
                    orig_self=getattr(encoder_layer.attention, 'self', None),
                )
                encoder_layer.attention.self = replacement
                self.sparse_attentions.append(replacement)

        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        # Snapshot of the classifier head as initialised, for later resets.
        self.initial_classifier_weight = self.classifier.weight.data.clone().detach()
        if self.classifier.bias is not None:
            self.initial_classifier_bias = self.classifier.bias.data.clone().detach()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return classification logits; token_type_ids carry segment ids for sentence pairs."""
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.classifier(pooled)

    def measureSparsity(self):
        """Return (L0 sparsity, Hoyer measure) averaged over the sparse layers.

        Yields (0.0, 0.0) when the model has no sparse attention layers.
        """
        if self.sparse_config is None or not self.sparse_attentions:
            return 0.0, 0.0

        l0_sum = 0.0
        hoyer_sum = 0.0
        for sparse_layer in self.sparse_attentions:
            layer_l0, layer_hoyer = sparse_layer.measureSparsity()
            l0_sum = l0_sum + layer_l0
            hoyer_sum = hoyer_sum + layer_hoyer

        layer_count = len(self.sparse_attentions)
        return l0_sum / layer_count, hoyer_sum / layer_count

    def weightsPenality(self):
        """Return the summed L1 mask penalty of all sparse layers (0.0 without sparse config)."""
        if self.sparse_config is None:
            return 0.0

        accumulated = 0
        for sparse_layer in self.sparse_attentions:
            accumulated = accumulated + sparse_layer.weightsPenality()
        return accumulated

    def reset_to_pretrained(self):
        """Restore the Q/K/V projections of every sparse layer and the classifier
        head to the values captured at construction time."""
        print("Resetting model to pretrained weights...")
        with torch.no_grad():
            # Sparse attention layers first (Q, K, V projections) ...
            for sparse_layer in self.sparse_attentions:
                sparse_layer.reset_to_pretrained()

            # ... then the classification head.
            head = self.classifier
            head.weight.data.copy_(self.initial_classifier_weight)
            if head.bias is not None:
                head.bias.data.copy_(self.initial_classifier_bias)

        print("Model reset to pretrained weights complete.")

# ==============================================================================
# 4. DATASET & UTILITY FUNCTIONS
# ==============================================================================

def set_all_seeds(seed=GLOBAL_SEED):
    """Seed every random number source used by the script.

    Covers Python's `random`, NumPy, PyTorch (CPU and all CUDA devices) and the
    transformers `set_seed` helper, puts cuDNN into deterministic mode with
    benchmarking turned off, and sets PYTHONHASHSEED in the process
    environment.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        set_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

def reset_seed(seed, fold_num):
    """Re-seed all RNGs for a cross-validation fold.

    Each fold gets its own seed, `seed + 3 * fold_num`, so folds are
    reproducible yet do not share a random stream.
    """
    offset = 3 * fold_num
    set_all_seeds(seed + offset)

def slugify(text):
    """Turn `text` into a file-name friendly token.

    Every run of non-word characters becomes a single underscore, and
    underscores at either end of the result are removed.
    """
    pieces = re.split(r'\W+', text)
    return '_'.join(pieces).strip('_')

class SubsetDataset(Dataset):
    """View onto selected items of another dataset.

    The selection is stored as a list of plain Python ints, so index
    containers holding NumPy integers (e.g. fold splits) can be passed in
    directly.
    """

    def __init__(self, original_dataset, indices):
        self.original_dataset = original_dataset
        self.indices = list(map(int, indices))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the subset position into a position in the wrapped dataset.
        source_index = self.indices[idx]
        return self.original_dataset[source_index]

class GLUEDataset(Dataset):
    """Tokenized text-classification dataset for the benchmark tasks.

    Supported ``task_name`` values (case-insensitive):

    * GLUE tasks loaded through ``load_dataset('glue', ...)``: ``sst2``
      (single sentence), ``qnli``, ``mrpc``, ``qqp`` and ``mnli`` (sentence
      pairs).
    * ``imdb``: read from ``IMDB Dataset.csv`` next to this file.
    * ``ths``: ``tdavidson/hate_speech_offensive``, binarised so that the
      original classes 0 and 1 become label 1 and class 2 becomes label 0.
    * ``dt``: read from ``disaster_tweets.csv`` next to this file.

    Args:
        task_name: One of the task identifiers above.
        split: Dataset split for the Hugging Face tasks. The CSV-backed tasks
            (``imdb``, ``dt``) read the whole file and ignore it.
        max_length: Every example is padded/truncated to this token length.
        sample_fraction: When below 1.0, keep only that fraction of the data.
        seed: Seed used for the sub-sampling.

    Raises:
        ValueError: For an unsupported task or a CSV with an unexpected schema.

    Each item is a dict with ``input_ids``, ``attention_mask``, ``labels`` and,
    when the tokenizer returns them, ``token_type_ids``.
    """

    # GLUE task -> (text column, or (first, second) column pair, number of labels).
    _GLUE_TASKS = {
        'sst2': ('sentence', 2),
        'qnli': (('question', 'sentence'), 2),
        'mrpc': (('sentence1', 'sentence2'), 2),
        'qqp': (('question1', 'question2'), 2),
        'mnli': (('premise', 'hypothesis'), 3),
    }

    def __init__(self, task_name='sst2', split='train', max_length=MAX_SEQ_LENGTH,
                 sample_fraction=1.0, seed=GLOBAL_SEED):
        self.task_name = task_name.lower()
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        # Every loader sets text_col / label_col / num_labels and returns the
        # (already sub-sampled) records.
        if self.task_name in self._GLUE_TASKS:
            dataset = self._load_glue(split, sample_fraction, seed, task_name)
        elif self.task_name == 'imdb':
            dataset = self._load_imdb(sample_fraction, seed)
        elif self.task_name == 'ths':
            dataset = self._load_ths(split, sample_fraction, seed)
        elif self.task_name == 'dt':
            dataset = self._load_disaster_tweets(sample_fraction, seed)
        else:
            raise ValueError(f"Task {task_name} not supported by this focused pipeline.")

        self.data = dataset
        print(f"Loaded {task_name} {split}: {len(dataset)} samples")

    def _load_glue(self, split, sample_fraction, seed, display_name):
        """Load one GLUE task split and optionally sub-sample it."""
        text_col, num_labels = self._GLUE_TASKS[self.task_name]
        dataset = load_dataset('glue', self.task_name)[split]
        self.text_col = text_col
        self.label_col = 'label'
        self.num_labels = num_labels

        if sample_fraction < 1.0:
            sample_size = int(len(dataset) * sample_fraction)
            dataset = dataset.shuffle(seed=seed).select(range(sample_size))
            print(f"Sampled {sample_fraction*100}% of {display_name} dataset ({len(dataset)} samples)")
        return dataset

    def _load_imdb(self, sample_fraction, seed):
        """Load the IMDB sentiment CSV into a list of ``{review, label}`` records."""
        csv_path = os.path.join(os.path.dirname(__file__), 'IMDB Dataset.csv')
        df = pd.read_csv(csv_path)
        df.columns = [c.strip().lower() for c in df.columns]
        if 'review' not in df.columns or 'sentiment' not in df.columns:
            raise ValueError("IMDB CSV must have 'review' and 'sentiment' columns")

        # Map sentiment strings to integer labels; anything else becomes NaN.
        df['label'] = df['sentiment'].str.lower().map({'positive': 1, 'negative': 0})
        if df['label'].isna().any():
            raise ValueError("IMDB sentiment labels must be 'positive' or 'negative'")

        # Replace HTML line breaks with spaces before tokenization.
        df['review'] = df['review'].astype(str).str.replace('<br />', ' ', regex=False)

        if sample_fraction < 1.0:
            df = df.sample(frac=sample_fraction, random_state=seed).reset_index(drop=True)
            print(f"Sampled {sample_fraction*100}% of IMDB dataset ({len(df)} samples)")

        self.text_col = 'review'
        self.label_col = 'label'
        self.num_labels = 2
        return df[[self.text_col, self.label_col]].to_dict(orient='records')

    def _load_ths(self, split, sample_fraction, seed):
        """Load the Twitter hate-speech dataset and binarise its ``class`` column."""
        dataset = load_dataset('tdavidson/hate_speech_offensive', split=split)

        # Original classes: 0 = hate speech, 1 = offensive language, 2 = neither.
        # Binarised: 0 or 1 -> 1 (hate/offensive), 2 -> 0 (neither).
        def map_label(example):
            example['class'] = 1 if example['class'] in [0, 1] else 0
            return example

        dataset = dataset.map(map_label)

        if sample_fraction < 1.0:
            sample_size = int(len(dataset) * sample_fraction)
            indices = list(range(len(dataset)))
            # A private RandomState yields the same permutation as seeding the
            # global NumPy generator, without clobbering global RNG state.
            np.random.RandomState(seed).shuffle(indices)
            dataset = dataset.select(indices[:sample_size])
            print(f"Sampled {sample_fraction*100}% of THS dataset ({len(dataset)} samples)")

        self.text_col = 'tweet'
        self.label_col = 'class'
        self.num_labels = 2

        # Convert to a list of dicts, like the CSV-backed tasks.
        records = [{'tweet': item['tweet'], 'class': item['class']} for item in dataset]
        hate_count = sum(1 for record in records if record['class'] == 1)
        not_hate_count = sum(1 for record in records if record['class'] == 0)
        print(f"Label distribution: Class 1 (hate): {hate_count}, Class 0 (not hate): {not_hate_count}")
        return records

    def _load_disaster_tweets(self, sample_fraction, seed):
        """Load the Disaster Tweets CSV into a list of ``{text, target}`` records."""
        csv_path = os.path.join(os.path.dirname(__file__), 'disaster_tweets.csv')
        df = pd.read_csv(csv_path)
        df.columns = [c.strip().lower() for c in df.columns]
        if 'text' not in df.columns or 'target' not in df.columns:
            raise ValueError("disaster_tweets.csv must have 'text' and 'target' columns")

        df = df[['text', 'target']].dropna().copy()
        df['text'] = df['text'].astype(str)
        # Force labels into {0, 1}; values that cannot be parsed become 0.
        try:
            df['target'] = df['target'].astype(int)
        except (ValueError, TypeError):
            df['target'] = pd.to_numeric(df['target'], errors='coerce').fillna(0).astype(int)
        df['target'] = df['target'].clip(lower=0, upper=1)

        if sample_fraction < 1.0:
            df = df.sample(frac=sample_fraction, random_state=seed).reset_index(drop=True)
            print(f"Sampled {sample_fraction*100}% of Disaster Tweets dataset ({len(df)} samples)")

        self.text_col = 'text'
        self.label_col = 'target'
        self.num_labels = 2

        pos = int((df['target'] == 1).sum())
        neg = int((df['target'] == 0).sum())
        print(f"Label distribution: target=1: {pos}, target=0: {neg}")
        return df[[self.text_col, self.label_col]].to_dict(orient='records')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Accept Python/NumPy ints directly; anything else (e.g. a 0-d tensor)
        # must expose ``.item()``.
        idx = int(idx) if isinstance(idx, (np.integer, int)) else idx.item()
        item = self.data[idx]

        # A string text_col means a single sentence; a tuple means a pair that
        # is passed to the tokenizer as (text, text_pair).
        if isinstance(self.text_col, str):
            texts = (str(item[self.text_col]),)
        else:
            texts = (str(item[self.text_col[0]]), str(item[self.text_col[1]]))

        encoding = self.tokenizer(
            *texts,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        label = int(item[self.label_col])

        # squeeze(0) drops only the leading batch dimension, so a sequence of
        # length 1 is not collapsed to a 0-d tensor.
        result = {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

        # Include token_type_ids when the tokenizer returns them.
        if 'token_type_ids' in encoding:
            result['token_type_ids'] = encoding['token_type_ids'].squeeze(0)

        return result

def compute_metrics(eval_pred):
    """Compute accuracy, F1, precision and recall for one evaluation pass.

    Args:
        eval_pred: ``(predictions, labels)`` pair. ``predictions`` holds either
            class ids (1-D) or per-class scores (2-D, reduced with argmax over
            axis 1). Array-likes are accepted for both.

    Returns:
        Dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``.

    Binary averaging (positive class 1) is used when the labels contain exactly
    two distinct values and every observed label/prediction is 0 or 1. All
    other cases use support-weighted averaging; this includes two-valued label
    sets such as {0, 2}, for which ``average='binary'`` would raise inside
    scikit-learn.
    """
    predictions, labels = eval_pred
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    if predictions.ndim == 2:
        predictions = np.argmax(predictions, axis=1)

    label_values = np.unique(labels)
    observed_values = set(label_values.tolist()) | set(np.unique(predictions).tolist())
    # Both conditions are required: two distinct labels alone does not
    # guarantee that the positive class 1 exists or that predictions are binary.
    is_binary = len(label_values) == 2 and observed_values <= {0, 1}
    average = 'binary' if is_binary else 'weighted'

    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1': f1_score(labels, predictions, average=average, zero_division=0),
        'precision': precision_score(labels, predictions, average=average, zero_division=0),
        'recall': recall_score(labels, predictions, average=average, zero_division=0)
    }


def save_attention_masks(model, l, epoch, base_dir=MASKS_DIR, task_name=None, fold_num=None, proj='na'):
    """Render the first sparse-attention layer's per-head masks as PNG files.

    Only the first entry of ``model.sparse_attentions`` is rendered. Its
    ``attention.weighted_masks`` tensor is thresholded (``abs(value) > 1e-20``)
    and drawn as a two-colour RGB image, one file per head:

        {out_root}/eta_{l}/mask_head{H}_epoch{E}.png

    where ``out_root`` is ``{base_dir}_{task_name}_{proj}/`` when ``task_name``
    is given and ``base_dir`` otherwise, and ``E`` is ``epoch + 1``.

    Saving is best effort: any failure is reported with a printed warning and
    never raised to the caller.

    Args:
        model: Object expected to expose a ``sparse_attentions`` sequence;
            nothing happens if it is missing or empty.
        l: Configuration value used in the ``eta_{l}`` directory name and in
            the plot title.
        epoch: Zero-based epoch index.
        base_dir: Root (prefix) of the output directory.
        task_name: Task identifier used in the directory name and plot title.
        fold_num: Accepted for API compatibility but currently unused; the
            file name has no fold component, so later folds overwrite the
            files written by earlier ones.
        proj: Projection name used in the directory name.
    """
    # Nothing to render (and no directory to create) without sparse layers.
    if not hasattr(model, 'sparse_attentions') or not model.sparse_attentions:
        return

    first_layer = next(iter(model.sparse_attentions))
    try:
        masks = first_layer.attention.weighted_masks.detach().cpu().numpy()
    except Exception as exc:
        print(f"Warning: could not read weighted masks for saving: {exc}")
        return

    eta_str = str(l)
    if task_name:
        base_dir = base_dir + f"_{task_name}_{proj}/"
    out_dir = os.path.join(base_dir, f"eta_{eta_str}")
    os.makedirs(out_dir, exist_ok=True)

    tol = 1e-20  # Entries with |value| <= tol count as "off".
    bg_color = np.array([0x3B, 0x4C, 0xC0], dtype=np.uint8)  # "off" entries
    fg_color = np.array([0xF7, 0xB5, 0x98], dtype=np.uint8)  # "on" entries

    for head_idx in range(masks.shape[0]):
        active = np.abs(masks[head_idx]) > tol
        # Broadcast the 2-D boolean mask over the RGB channel axis.
        image = np.where(active[..., np.newaxis], fg_color, bg_color)
        fname = os.path.join(out_dir, f"mask_head{head_idx}_epoch{epoch+1}.png")

        fig = None
        try:
            fig = plt.figure(figsize=(4, 4))
            # The image is already RGB, so no colormap or value range is needed.
            plt.imshow(image, interpolation='nearest')
            plt.title(f"{task_name} Mask, l={eta_str}")
            plt.savefig(fname, bbox_inches='tight', pad_inches=0.1)
        except Exception as exc:
            print(f"Warning: failed to save attention mask {fname}: {exc}")
        finally:
            # Close only the figure opened here, leaving other figures alone.
            if fig is not None:
                plt.close(fig)

def _make_optimizer_and_scheduler(model, learning_rate, num_steps):
    """Create a fresh AdamW optimizer with a linear warmup/decay schedule over ``num_steps``."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0.1 * num_steps,
        num_training_steps=num_steps
    )
    return optimizer, scheduler


def _forward_on_device(model, batch):
    """Move one collated batch's inputs to DEVICE and return the model's logits."""
    input_ids = batch['input_ids'].to(DEVICE)
    attention_mask = batch['attention_mask'].to(DEVICE)
    # token_type_ids are only present when the tokenizer produced them.
    token_type_ids = batch['token_type_ids'].to(DEVICE) if 'token_type_ids' in batch else None
    return model(input_ids, attention_mask, token_type_ids=token_type_ids)


def _collect_predictions(model, loader):
    """Run ``model`` in eval mode over ``loader``; return ``(predictions, labels)`` arrays."""
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for batch in loader:
            logits = _forward_on_device(model, batch)
            predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            targets.extend(batch['labels'].cpu().numpy())
    return np.array(predictions), np.array(targets)


def _fix_masks_and_reset_weights(model, sparse_config, epoch):
    """Switch ``sparse_config`` to fixed masks, precompute them, and reset non-BigBird weights."""
    # The dict is mutated in place: it is the same object that was handed to
    # the model at construction time.
    sparse_config['fixed'] = True
    sparse_config['use_weighted_masks'] = True
    print(f"Fixed sparse masks (Projection) at epoch {epoch+1}")

    # Force computation of fixed masks for all heads/layers so they can be saved.
    try:
        with torch.no_grad():
            for attn in model.sparse_attentions:
                wam = attn.attention
                for head_idx in range(wam.num_heads):
                    wam._compute_fixed_mask_for_head(head_idx)
    except Exception as e:
        # Best effort only: training continues if the eager computation fails.
        print(f"Warning: eager fixed-mask computation failed: {e}")

    proj_type = sparse_config.get('proj', '')
    if 'BigBird' not in proj_type:
        model.reset_to_pretrained()
        print("Model reset to pretrained weights after mask fixing")
    else:
        print(f"Skipping weight reset for {proj_type} - preserving learned weights")


def train_and_evaluate_bert(config, task_name, k_folds, num_epochs, batch_size, learning_rate, sample_fraction, seed):
    """Train and evaluate a sparse-attention BERT with K-fold cross-validation.

    For every fold a fresh ``SparseBERTModel`` is trained for ``num_epochs``
    epochs. At epoch ``int(num_epochs * MASK_FIX_EPOCH_RATIO)`` the masks are
    fixed, weights are reset to their pretrained values (except for BigBird
    configurations), and the optimizer and scheduler are rebuilt for the
    remaining epochs. From then on a mask penalty, weighted by 0.01, is added
    to the cross-entropy loss.

    Args:
        config: Sparse-attention configuration dict; it is copied per fold and
            the caller's dict is not modified.
        task_name: Task identifier understood by ``GLUEDataset``.
        k_folds: Number of cross-validation folds.
        num_epochs: Training epochs per fold.
        batch_size: Batch size for the training and validation loaders.
        learning_rate: AdamW learning rate.
        sample_fraction: Fraction of the dataset to use.
        seed: Base random seed; each fold derives its own seed from it.

    Returns:
        Tuple ``(avg_metrics, avg_l0_sparsity, avg_hoyer_sparsity)`` with the
        validation metrics dict and the two values returned by
        ``model.measureSparsity()``, each averaged over the folds.
    """
    set_all_seeds(seed)

    dataset = GLUEDataset(task_name=task_name, split='train', sample_fraction=sample_fraction, seed=seed)

    print(f"Using device: {DEVICE}")

    if task_name in ['mrpc', 'ths', 'dt', 'qqp', 'mnli']:
        # Use stratified splits to preserve label ratio across folds (reduces variance)
        kf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=seed)
    else:
        kf = KFold(n_splits=k_folds, shuffle=True, random_state=seed)

    # Labels for stratification, read straight from the underlying records so
    # that no example has to be tokenized.
    fold_targets = [row[dataset.label_col] for row in dataset.data]

    mask_fix_epoch = int(num_epochs * MASK_FIX_EPOCH_RATIO)
    mask_penalty_weight = 0.01

    fold_results = []
    fold_sparsities_l0 = []
    fold_sparsities_hoyer = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(range(len(dataset)), fold_targets)):
        print(f"\nFold {fold+1}/{k_folds}")
        reset_seed(seed, fold)
        start_time = time.time()

        train_loader = DataLoader(SubsetDataset(dataset, train_idx), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(SubsetDataset(dataset, val_idx), batch_size=batch_size, shuffle=False)

        sparse_config_copy = config.copy()
        sparse_config_copy['seq_len'] = MAX_SEQ_LENGTH
        # Weighted masks stay off until the mask-fixing epoch below.
        sparse_config_copy['use_weighted_masks'] = False

        model = SparseBERTModel(
            sparse_config=sparse_config_copy,
            num_labels=dataset.num_labels,
            max_seq_len=MAX_SEQ_LENGTH
        ).to(DEVICE)

        optimizer, scheduler = _make_optimizer_and_scheduler(
            model, learning_rate, len(train_loader) * num_epochs
        )
        loss_fn = nn.CrossEntropyLoss()

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            epoch_loss = 0.0
            num_batches = 0

            if epoch == mask_fix_epoch:
                _fix_masks_and_reset_weights(model, sparse_config_copy, epoch)

                # Start a new optimizer/schedule for the rest of training.
                optimizer, scheduler = _make_optimizer_and_scheduler(
                    model, learning_rate, len(train_loader) * (num_epochs - epoch)
                )
                print(f"Optimizer and scheduler reset for remaining {num_epochs - epoch} epochs")

                # Save masks immediately after fixing (best effort).
                try:
                    save_attention_masks(
                        model,
                        sparse_config_copy.get('l', sparse_config_copy.get('window_size', 'na')),
                        epoch, base_dir=MASKS_DIR, task_name=task_name, fold_num=fold+1,
                        proj=sparse_config_copy.get('proj', 'na')
                    )
                except Exception as exc:
                    print(f"Warning: could not save attention masks: {exc}")

            for batch in train_loader:
                optimizer.zero_grad()

                logits = _forward_on_device(model, batch)
                batch_labels = batch['labels'].to(DEVICE)
                loss_without_penalty = loss_fn(logits, batch_labels)

                # The mask penalty only applies once the masks have been fixed.
                total_loss = loss_without_penalty
                if sparse_config_copy.get('use_weighted_masks', False):
                    total_loss = total_loss + mask_penalty_weight * model.weightsPenality()

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                # Report the task loss only, without the penalty term.
                epoch_loss += loss_without_penalty.item()
                num_batches += 1

            # Guard against an empty training loader.
            avg_loss = epoch_loss / num_batches if num_batches else float('nan')

            # Per-epoch validation check
            metrics_epoch = compute_metrics(_collect_predictions(model, val_loader))
            val_acc = metrics_epoch['accuracy']
            f1_epoch = metrics_epoch.get('f1', 0.0)

            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {f1_epoch:.4f}")

        # Final evaluation for the fold (also covers num_epochs == 0).
        metrics = compute_metrics(_collect_predictions(model, val_loader))
        l0_spar, hoyer_spar = model.measureSparsity()

        fold_results.append(metrics)
        fold_sparsities_l0.append(l0_spar)
        fold_sparsities_hoyer.append(hoyer_spar)

        fold_time = time.time() - start_time
        print(f"Fold {fold+1} Finished (Time: {fold_time:.2f}s): Acc={metrics['accuracy']:.4f}, L0 Sparsity={l0_spar:.4f}, Hoyer={hoyer_spar:.6f}")

    # Aggregate results
    avg_metrics = {key: np.mean([result[key] for result in fold_results]) for key in fold_results[0].keys()}
    avg_l0_sparsity = np.mean(fold_sparsities_l0)
    avg_hoyer_sparsity = np.mean(fold_sparsities_hoyer)

    return avg_metrics, avg_l0_sparsity, avg_hoyer_sparsity

    
# ==============================================================================
# 5. MAIN EXECUTION PIPELINE
# ==============================================================================

def run_CFCP_benchmark(seed=GLOBAL_SEED):
    """Run every configuration in CFCP_BENCHMARK_CONFIGS on every task in TASKS.

    Each (configuration, task) pair is trained and evaluated with
    ``train_and_evaluate_bert``. A run that raises is reported with its
    traceback and recorded with all scores set to 0.0, so the remaining runs
    still execute. All rows are written to a single CSV named after the first
    entry of TASKS.

    Args:
        seed: Base random seed forwarded to every run.

    Returns:
        pandas DataFrame with one row per (configuration, task) pair.
    """
    set_all_seeds(seed)

    header_lines = [
        f"Starting Benchmark on {TASKS}...",
        f"Model: {MODEL_NAME}",
        "_"*60,
        f"Total Configurations: {len(CFCP_BENCHMARK_CONFIGS)}",
        f"Sequence Length: {MAX_SEQ_LENGTH}",
        f"Batch Size: {BATCH_SIZE}",
        f"Epochs: {NUM_EPOCHS}",
        f"K-Folds: {K_FOLDS}",
        f"Learning Rate: {LEARNING_RATE}\n",
    ]
    for line in header_lines:
        print(line)

    rows = []
    for config in CFCP_BENCHMARK_CONFIGS:
        for task in TASKS:
            config_name = get_config_name(config)
            print(f"\n{'='*60}")
            if 'cfcp' in config['proj'].lower():
                print(f"Experiment: {config_name} on {task} (l={config.get('l', 'na')})")
            else:
                print(f"Experiment: {config_name} on {task} (window_size={config.get('window_size', 'na')})")
            print(f"{'-'*60}")

            # Identification columns are the same for successful and failed runs.
            row = {
                'Model': config_name,
                'Task': task,
                'l': config.get('l', config.get('window_size', 'na')),
            }

            try:
                metrics, l0_sparsity, hoyer_sparsity = train_and_evaluate_bert(
                    config=config,
                    task_name=task,
                    k_folds=K_FOLDS,
                    num_epochs=NUM_EPOCHS,
                    batch_size=BATCH_SIZE,
                    learning_rate=LEARNING_RATE,
                    sample_fraction=SAMPLE_FRACTION,
                    seed=seed
                )
                scores = {
                    'Accuracy': metrics['accuracy'],
                    'F1': metrics['f1'],
                    'L0_Sparsity': l0_sparsity,
                    'Hoyer_Sparsity': hoyer_sparsity,
                }
            except Exception as e:
                print(f"Error running {config_name} on {task}: {e}")
                import traceback
                traceback.print_exc()
                # Failed runs are recorded with zeroed scores.
                scores = {
                    'Accuracy': 0.0,
                    'F1': 0.0,
                    'L0_Sparsity': 0.0,
                    'Hoyer_Sparsity': 0.0,
                }

            rows.append({**row, **scores})

    results_df = pd.DataFrame(rows)
    output_filename = f'CFCP_benchmark_results_{TASKS[0].lower()}.csv'
    results_df.to_csv(output_filename, index=False)

    rule = f"{'='*80}"
    print(f"\n{'='*80}")
    print(f"Benchmark finished. Results saved to {output_filename}")

    print("\nCFCP Benchmark Results Summary:")
    print(results_df.round(4).to_string())
    print(rule)

    return results_df

# ==============================================================================
# 6. EXECUTION
# ==============================================================================

# Script entry point: run the full benchmark with the module-level seed when
# this file is executed directly (not when it is imported).
if __name__ == "__main__":
    run_CFCP_benchmark(seed=GLOBAL_SEED)
