from omegaconf import DictConfig, ListConfig
from tqdm import tqdm
import torch.nn.functional as F

import clip.clip as clip
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast

from .utils import get_class_ids_per_task, get_class_names, batch, merge_we_router, wise_we, moving_avg, l2_loss, \
    virtual_vocab, distillation
import copy

from .cc import conceptual_captions

from . import utils
import os
import random

from .dynamic_dataset import DynamicDataset


class DFAModule(nn.Module):
    """Three-expert fusion with hierarchical routing.

    - Expert1 (E1): generalization expert (AFFA-like), trained first per task.
    - Expert2 (E2_0 .. E2_{E-1}): task-specialized experts, trained in Stage B.
    - Router_top (``router``): mixes E1 vs the fused task experts.
    - Router_e2: mixes the task experts, either moe-w style noisy top-k gating
      (``e2_router_mode == 'moe_w'``) or a softmax router (``router_e2``).
    - With ``single_router_3way=True`` a single 3-way router (``router3``) mixes
      [E1, E2_0, E2_1] and no E2-specific router is built.

    Every expert is a residual adapter ``x + adapter_scalar * expert(x)`` whose output
    projection is zero-initialised, so an untrained expert is the identity. Routing and
    fusion operate over the last dimension of ``x``; any leading dimensions are treated
    as batch dimensions.
    """
    def __init__(self, embed_dim: int, r1: int = 16, r2: int = 8, router_hidden: int = 0, beta: float = 1.0,
                 e2_router_mode: str = 'moe_w', e2_top_k: int = 1, single_router_3way: bool = False,
                 adapter_dropout: float = 0.1, adapter_scalar: float = 0.1, num_task_experts: int = 2):
        super().__init__()
        self.embed_dim = int(embed_dim)
        self.beta = float(beta)
        self.e2_router_mode = str(e2_router_mode).lower()
        self.e2_top_k = max(1, int(e2_top_k))
        self.single_router_3way = bool(single_router_3way)
        self.adapter_dropout = float(adapter_dropout)
        self.adapter_scalar = float(adapter_scalar)
        self.num_task_experts = max(1, int(num_task_experts))
        # When True, Expert1 acts as identity (e1 = x). Used when Stage A is skipped to match original CLIP.
        self.e1_identity: bool = False
        # Expert1 and the variable-size list of task-specific experts (bottleneck adapters).
        self.e1 = self._make_adapter(self.embed_dim, r1, self.adapter_dropout)
        self.e2_list = nn.ModuleList([
            self._make_adapter(self.embed_dim, r2, self.adapter_dropout)
            for _ in range(self.num_task_experts)
        ])
        # Zero-init every output projection so untrained experts are identity residuals: e(x) = x + 0
        with torch.no_grad():
            for expert in (self.e1, *self.e2_list):
                if isinstance(expert[-1], nn.Linear):
                    nn.init.zeros_(expert[-1].weight)
                    nn.init.zeros_(expert[-1].bias)
        # Routers (creation order kept stable so parameter init / state_dict keys do not change)
        if self.single_router_3way:
            # Single 3-way router over [E1, E2_0, E2_1]; the E2-specific router is disabled.
            self.router3 = self._make_router(self.embed_dim, router_hidden, 3)
        else:
            # Top router: E1 vs the fused task experts
            self.router = self._make_router(self.embed_dim, router_hidden, 2)
            if self.e2_router_mode == 'moe_w':
                # moe-w style: learnable linear gates with optional noise, top-k selection
                self.w_gate_e2 = nn.Linear(self.embed_dim, self.num_task_experts, bias=False)
                self.w_noise_e2 = nn.Linear(self.embed_dim, self.num_task_experts, bias=False)
            else:
                self.router_e2 = self._make_router(self.embed_dim, router_hidden, self.num_task_experts)

    @staticmethod
    def _make_adapter(embed_dim: int, rank: int, dropout: float) -> nn.Sequential:
        """Bottleneck adapter: LayerNorm -> Linear(D, rank) -> ReLU -> Dropout -> Linear(rank, D)."""
        return nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, rank),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(rank, embed_dim),
        )

    @staticmethod
    def _make_router(embed_dim: int, router_hidden, out_dim: int) -> nn.Sequential:
        """LayerNorm + linear head; adds a ReLU hidden layer when ``router_hidden`` is a positive int."""
        if router_hidden and int(router_hidden) > 0:
            hidden = int(router_hidden)
            return nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, hidden),
                nn.ReLU(inplace=True),
                nn.Linear(hidden, out_dim),
            )
        return nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, out_dim),
        )

    @staticmethod
    def _as_route_weights(w, num: int, x: torch.Tensor) -> torch.Tensor:
        """Normalise manual mixing weights to a tensor usable against ``x``.

        ``None`` becomes uniform weights of shape ``x.shape[:-1] + (num,)``. Non-tensor
        sequences are converted with ``x``'s dtype. A 1-D tensor of length ``num`` is
        broadcast to every sample. Anything else is only moved to ``x.device`` and is
        otherwise used as given.
        """
        if w is None:
            return torch.full((*x.shape[:-1], num), 1.0 / float(num), device=x.device, dtype=x.dtype)
        if torch.is_tensor(w):
            w = w.to(device=x.device)
        else:
            w = torch.as_tensor(w, dtype=x.dtype, device=x.device)
        if w.dim() == 1:
            w = w.view(*([1] * (x.dim() - 1)), num).expand(*x.shape[:-1], num)
        return w

    def _noisy_topk_gating_e2(self, x: torch.Tensor, train: bool = True, noise_epsilon: float = 1e-2) -> torch.Tensor:
        """moe-w style Noisy Top-k gating over variable number of E2 experts.

        Gating runs over the last dimension, so ``x`` may carry any number of leading batch
        dimensions. Returns gates of shape ``x.shape[:-1] + (E,)``. They are zero outside
        the top-k experts, and the kept entries are a softmax over the top-k logits.
        """
        clean_logits = self.w_gate_e2(x)
        if train:
            raw_noise_stddev = self.w_noise_e2(x)
            noise_std = F.softplus(raw_noise_stddev) + float(noise_epsilon)
            logits = clean_logits + torch.randn_like(clean_logits) * noise_std
        else:
            logits = clean_logits
        k = min(int(self.e2_top_k), int(logits.size(-1)))
        top_logits, top_idx = logits.topk(k, dim=-1)
        # Align dtypes to avoid scatter() dtype mismatch under AMP
        top_w = torch.softmax(top_logits, dim=-1).to(dtype=logits.dtype)
        gates = torch.zeros_like(logits)
        gates.scatter_(-1, top_idx, top_w)
        return gates

    def forward(self, x: torch.Tensor, route: bool = True, weights=None) -> torch.Tensor:
        """Fuse the residual experts for ``x`` and scale the result by ``beta``.

        Args:
            x: features whose last dimension is ``embed_dim``; leading dims are batch dims.
            route: if True, mixing weights come from the learned routers. Otherwise they
                come from ``weights``.
            weights: manual weights used when ``route`` is False. This is either a dict
                with key ``'w3'`` (single_router_3way) or ``'top'``/``'spec'``, or a single
                object used as ``w3`` / ``top`` (``spec`` then defaults to uniform). Each
                entry may be a tensor broadcastable to ``x.shape[:-1] + (k,)``, a 1-D
                tensor or sequence of length ``k`` applied to every sample, or None for
                uniform weights.

        Returns:
            ``beta * fused`` with the same shape as ``x``.
        """
        single = getattr(self, 'single_router_3way', False)
        # residual adapters
        if getattr(self, 'e1_identity', False):
            e1 = x
        else:
            e1 = x + self.adapter_scalar * self.e1(x)
        e2_outs = [x + self.adapter_scalar * m(x) for m in self.e2_list]

        w3 = w_top = w_spec = None
        if route:
            # Routers run in fp32 for numerical stability under AMP
            x_f = x.float()
            if single:
                w3 = torch.softmax(self.router3(x_f), dim=-1)  # [..., 3]
            else:
                w_top = torch.softmax(self.router(x_f), dim=-1)  # [..., 2]
                if self.e2_router_mode == 'moe_w':
                    w_spec = self._noisy_topk_gating_e2(x_f, train=self.training)
                else:
                    w_spec = torch.softmax(self.router_e2(x_f), dim=-1)
        else:
            # manual weights
            if isinstance(weights, dict):
                if single:
                    w3 = weights.get('w3', None)
                else:
                    w_top = weights.get('top', None)
                    w_spec = weights.get('spec', None)
            elif single:
                w3 = weights
            else:
                w_top = weights
            if single:
                w3 = self._as_route_weights(w3, 3, x)
            else:
                w_top = self._as_route_weights(w_top, 2, x)
                w_spec = self._as_route_weights(w_spec, max(1, self.num_task_experts), x)

        if single:
            # 3-way fusion over [E1, E2_0, E2_1]; missing task experts contribute zeros
            e2a = e2_outs[0] if len(e2_outs) > 0 else (0.0 * e1)
            e2b = e2_outs[1] if len(e2_outs) > 1 else (0.0 * e1)
            fused = w3[..., 0:1] * e1 + w3[..., 1:2] * e2a + w3[..., 2:3] * e2b
        else:
            # Weighted fusion over all task experts
            e_stack = torch.stack(e2_outs, dim=-1)  # [..., D, E]
            e_spec = (e_stack * w_spec.unsqueeze(-2)).sum(dim=-1)
            fused = w_top[..., 0:1] * e1 + w_top[..., 1:2] * e_spec
        return self.beta * fused


class DFABlock(nn.Module):
    """Residual mixture of task-specific adapters with a moe-w style gate (no E1 / top router).

    Holds ``num_task_experts`` bottleneck adapters
    (LayerNorm -> Linear(D, r2) -> ReLU -> Dropout -> Linear(r2, D)), whose output
    projection starts at zero. It also holds the linear gate/noise projections
    ``w_gate_e2`` / ``w_noise_e2``. ``forward(y)`` returns
    ``y + adapter_scalar * sum_e gate_e(y) * expert_e(y)``. The gate is computed from
    ``y`` itself over its last dimension.
    """
    def __init__(self, embed_dim: int, r2: int = 64, adapter_dropout: float = 0.1, adapter_scalar: float = 0.1, e2_top_k: int = 2, num_task_experts: int = 2):
        super().__init__()
        self.embed_dim = int(embed_dim)
        self.adapter_scalar = float(adapter_scalar)
        self.e2_top_k = max(1, int(e2_top_k))
        self.num_task_experts = max(1, int(num_task_experts))
        drop_p = float(adapter_dropout)
        dim = self.embed_dim
        experts = []
        for _ in range(self.num_task_experts):
            experts.append(nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, r2),
                nn.ReLU(inplace=True),
                nn.Dropout(p=drop_p),
                nn.Linear(r2, dim),
            ))
        self.e2_list = nn.ModuleList(experts)
        # Start every expert as a zero residual by clearing its output projection.
        with torch.no_grad():
            for expert in self.e2_list:
                out_proj = expert[-1]
                if not isinstance(out_proj, nn.Linear):
                    continue
                out_proj.weight.zero_()
                out_proj.bias.zero_()
        # Gate logits and per-expert noise scale (moe-w)
        self.w_gate_e2 = nn.Linear(dim, self.num_task_experts, bias=False)
        self.w_noise_e2 = nn.Linear(dim, self.num_task_experts, bias=False)

    def _noisy_topk(self, x: torch.Tensor, train: bool = True, noise_epsilon: float = 1e-2) -> torch.Tensor:
        """Return ``[..., E]`` gates: softmax over the top-k (optionally noised) logits, zero elsewhere."""
        logits = self.w_gate_e2(x)
        if train:
            sigma = F.softplus(self.w_noise_e2(x)) + float(noise_epsilon)
            logits = logits + torch.randn_like(logits) * sigma
        k = min(self.e2_top_k, logits.size(-1))
        kept_vals, kept_idx = logits.topk(k, dim=-1)
        # Cast back so scatter sees matching dtypes under AMP
        kept_w = torch.softmax(kept_vals, dim=-1).to(dtype=logits.dtype)
        return torch.zeros_like(logits).scatter(-1, kept_idx, kept_w)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        """Add the gated mixture of expert outputs to ``y`` (gates computed per position of ``y``)."""
        gates = self._noisy_topk(y.float(), train=self.training)  # [..., E]
        stacked = torch.stack([expert(y) for expert in self.e2_list], dim=-1)  # [..., D, E]
        mixed = torch.sum(stacked * gates.unsqueeze(-2), dim=-1)
        return y + self.adapter_scalar * mixed


class DFA3Block(nn.Module):
    """Per-block three-expert adapters with hierarchical routing (top + E2-spec).

    Experts: E1 (general) and ``num_task_experts`` specialised adapters E2_*. Each is a
    bottleneck adapter whose output projection is zero-initialised, so an untrained
    block contributes a zero residual.

    Routing weights are computed from the CLS token (sequence position 0 of an
    ``[L, N, D]`` input) and broadcast to every token:

    * default: ``router_top`` mixes E1 vs the E2 mixture, and a moe-w style noisy
      top-k gate (``w_gate_e2`` / ``w_noise_e2``) mixes the E2 experts;
    * ``single_router_3way``: ``router3`` mixes [E1, E2_0, E2_1] directly
      (``router_top`` is an ``nn.Identity`` kept for attribute compatibility).

    ``forward`` returns only the residual, already scaled by ``adapter_scalar``. The
    caller adds it to the hidden state.

    Optional attributes read by ``forward`` if set from outside:
        e1_only: when truthy, return only the E1 residual (no routing).
        debug_w_top: when not None, replaces the top-router weights. A 1-D tensor or
            sequence of length 2 is broadcast to every sample; otherwise it is used
            as ``[N, 2]``.
    """
    def __init__(self, embed_dim: int, r1: int = 16, r2: int = 64, adapter_dropout: float = 0.1,
                 adapter_scalar: float = 0.1, e2_top_k: int = 2, router_hidden: int = 0,
                 single_router_3way: bool = False, num_task_experts: int = 2):
        super().__init__()
        self.embed_dim = int(embed_dim)
        self.adapter_scalar = float(adapter_scalar)
        self.e2_top_k = max(1, int(e2_top_k))
        self.single_router_3way = bool(single_router_3way)
        self.num_task_experts = max(1, int(num_task_experts))
        # E1 general adapter
        self.e1 = self._make_adapter(self.embed_dim, int(r1), float(adapter_dropout))
        # E2 adapters (variable count)
        self.e2_list = nn.ModuleList([
            self._make_adapter(self.embed_dim, int(r2), float(adapter_dropout))
            for _ in range(self.num_task_experts)
        ])
        with torch.no_grad():
            # zero-init last proj to start as identity residual
            for mod in (self.e1, *self.e2_list):
                if isinstance(mod[-1], nn.Linear):
                    nn.init.zeros_(mod[-1].weight)
                    nn.init.zeros_(mod[-1].bias)
        # Treat any falsy router_hidden (0, None) as "no hidden layer", matching DFAModule
        hidden = int(router_hidden) if router_hidden else 0
        # Routers (creation order unchanged so parameter init / state_dict keys stay stable)
        if self.single_router_3way:
            self.router3 = self._make_router(self.embed_dim, hidden, 3)
            # still define attributes for compatibility
            self.router_top = nn.Identity()
        else:
            self.router_top = self._make_router(self.embed_dim, hidden, 2)
        self.w_gate_e2 = nn.Linear(self.embed_dim, self.num_task_experts, bias=False)
        self.w_noise_e2 = nn.Linear(self.embed_dim, self.num_task_experts, bias=False)

    @staticmethod
    def _make_adapter(embed_dim: int, rank: int, dropout: float) -> nn.Sequential:
        """Bottleneck adapter: LayerNorm -> Linear(D, rank) -> ReLU -> Dropout -> Linear(rank, D)."""
        return nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, rank),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(rank, embed_dim),
        )

    @staticmethod
    def _make_router(embed_dim: int, hidden: int, out_dim: int) -> nn.Sequential:
        """LayerNorm + linear head, with a ReLU hidden layer of width ``hidden`` when it is > 0."""
        if hidden > 0:
            return nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, hidden),
                nn.ReLU(inplace=True),
                nn.Linear(hidden, out_dim),
            )
        return nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, out_dim),
        )

    def _noisy_topk_cls(self, cls_x: torch.Tensor, train: bool = True, noise_epsilon: float = 1e-2) -> torch.Tensor:
        """moe-w noisy top-k gates ``[N, E]``: softmax over the top-k logits, zero elsewhere."""
        clean = self.w_gate_e2(cls_x)
        if train:
            raw = self.w_noise_e2(cls_x)
            noise_std = F.softplus(raw) + float(noise_epsilon)
            logits = clean + torch.randn_like(clean) * noise_std
        else:
            logits = clean
        E = logits.size(-1)
        k = min(int(self.e2_top_k), int(E))
        top_logits, top_idx = logits.topk(k, dim=-1)
        top_w = torch.softmax(top_logits, dim=-1)
        if top_w.dtype != logits.dtype:
            top_w = top_w.to(dtype=logits.dtype)
        gates = torch.zeros_like(logits, dtype=logits.dtype)
        gates.scatter_(-1, top_idx, top_w)
        return gates  # [N,E]

    def _top_router_weights(self, cls_feat: torch.Tensor) -> torch.Tensor:
        """Top-router weights ``[N, 2]``: the ``debug_w_top`` override if set, else the learned router."""
        override = getattr(self, 'debug_w_top', None)
        if override is None:
            return torch.softmax(self.router_top(cls_feat.float()), dim=-1)
        wt = torch.as_tensor(override)
        if wt.dim() == 1:
            # [2] -> [N,2]
            wt = wt.view(1, 2).expand(cls_feat.size(0), 2)
        # ensure dtype/device match
        return wt.to(device=cls_feat.device, dtype=cls_feat.dtype)

    def forward(self, pre_ffn: torch.Tensor) -> torch.Tensor:
        """Return the routed adapter residual for ``pre_ffn`` of shape ``[L, N, D]``."""
        # Stage-A behavior: only E1 residual, no routing
        if getattr(self, 'e1_only', False):
            return self.adapter_scalar * self.e1(pre_ffn)
        # CLS token (sequence position 0) of every sample: [N, D]
        cls_feat = pre_ffn[0]
        e1_res = self.adapter_scalar * self.e1(pre_ffn)
        e2_res_list = [self.adapter_scalar * m(pre_ffn) for m in self.e2_list]
        if self.single_router_3way:
            w3 = torch.softmax(self.router3(cls_feat.float()), dim=-1)  # [N,3]
            w3_b = w3.unsqueeze(0).unsqueeze(-1)  # [1,N,3,1]
            # For single_router_3way, only the first two E2 experts take part
            e2a_res = e2_res_list[0] if len(e2_res_list) > 0 else 0.0 * e1_res
            e2b_res = e2_res_list[1] if len(e2_res_list) > 1 else 0.0 * e1_res
            res_stack = torch.stack([e1_res, e2a_res, e2b_res], dim=2)  # [L,N,3,D]
            return (res_stack * w3_b).sum(dim=2)
        w_top = self._top_router_weights(cls_feat)  # [N,2]
        w_spec = self._noisy_topk_cls(cls_feat.float(), train=self.training)  # [N,E]
        # Store router weights for statistics collection
        self._last_w_top = w_top.detach()  # [N, 2]
        self._last_w_spec = w_spec.detach()  # [N, E]
        w_top_b = w_top.unsqueeze(0).unsqueeze(-1)  # [1,N,2,1]
        w_spec_b = w_spec.unsqueeze(0).unsqueeze(-1)  # [1,N,E,1]
        e2_res = (torch.stack(e2_res_list, dim=2) * w_spec_b).sum(dim=2)  # [L,N,D]
        res_stack = torch.stack([e1_res, e2_res], dim=2)  # [L,N,2,D]
        return (res_stack * w_top_b).sum(dim=2)


class VisualBlockWithDFA(nn.Module):
    """Chain a visual transformer block with a DFA block: ``dfa_block(base_block(x))``.

    No residual is added here. Whatever ``dfa_block`` returns is the output, e.g.
    ``DFABlock.forward`` already returns ``y + adapter_scalar * ...``.
    """
    def __init__(self, base_block: nn.Module, dfa_block: nn.Module):
        super().__init__()
        self.base = base_block
        self.dfa = dfa_block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dfa(self.base(x))


class ResidualAttentionBlockWithDFA(nn.Module):
    """Wrap a CLIP ResidualAttentionBlock and add a DFA residual at the FFN segment.

    Computes (aligning with moe-w)::

        x = x + Attn(ln_1(x))
        x = x + MLP(ln_2(x)) + DFA(x_pre_ffn)

    Two kinds of ``dfa_block`` are supported:

    * one exposing ``router_top`` and ``e1`` (DFA3Block-like): it is called on the
      pre-FFN tensor and must return the residual itself;
    * otherwise (legacy DFABlock-like): gates come from ``dfa._noisy_topk`` applied to
      the CLS token (first sequence position), are broadcast over all tokens, and
      weight the outputs of ``dfa.e2_list``, scaled by ``dfa.adapter_scalar``.
    """
    def __init__(self, base_block: nn.Module, dfa_block: nn.Module):
        super().__init__()
        self.base = base_block
        self.dfa = dfa_block

    def _legacy_adapter_residual(self, pre_ffn: torch.Tensor) -> torch.Tensor:
        """E2-only residual with CLS-token gating for ``pre_ffn`` of shape ``[L, N, D]``."""
        adapter = self.dfa
        cls_tokens = pre_ffn[0]  # [N, D]
        gates = adapter._noisy_topk(cls_tokens.float(), train=self.training)  # [N, E]
        expert_out = torch.stack([expert(pre_ffn) for expert in adapter.e2_list], dim=-1)  # [L, N, D, E]
        mixed = (expert_out * gates.unsqueeze(-2)).sum(dim=-1)
        return adapter.adapter_scalar * mixed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = self.base
        # Attention residual; the result is the pre-FFN representation
        x = x + base.attention(base.ln_1(x))
        has_e1_router = hasattr(self.dfa, 'router_top') and hasattr(self.dfa, 'e1')
        if has_e1_router:
            # DFA3Block path: hierarchical E1/E2 routing happens inside the block
            adapter_out = self.dfa(x)
        else:
            adapter_out = self._legacy_adapter_residual(x)
        # FFN residual of the original block plus the adapter residual
        return x + base.mlp(base.ln_2(x)) + adapter_out


class TextBlockWithDFA(nn.Module):
    """Chain a text transformer block with a DFA block: ``dfa_block(base_block(x))``.

    The DFA block's return value is used as the block output unchanged; no extra
    residual connection is added by this wrapper.
    """
    def __init__(self, base_block: nn.Module, dfa_block: nn.Module):
        super().__init__()
        self.base = base_block
        self.dfa = dfa_block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.base(x)
        return self.dfa(hidden)


class ClassIncremental(nn.Module):
    def __init__(self, cfg, device, jit=False):
        """Load CLIP and set up the DFA head, routing/debug config and replay/queue state.

        Args:
            cfg: experiment config read via attribute access (list-valued entries may be
                omegaconf ``ListConfig``). Always read: ``prompt_template``, ``model_name``.
                Read with ``getattr`` defaults: ``class_order``, ``input_resolution``,
                ``dfa_r1``, ``dfa_r2``, ``dfa_router_hidden``, ``dfa_beta``,
                ``e2_router_mode``, ``e2_top_k``, ``single_router_3way``, ``adapter_dropout``,
                ``adapter_scalar``, ``num_task_experts``, ``gate_mode``, ``debug_task_only``,
                ``debug_seen_weights``, ``dfa_inject_mode``, ``debug_zs_weights``,
                ``replay_per_class``, ``moco_queue_size``, ``moco_momentum``.
            device: passed to ``clip.load``; also used for the DFA module, the dummy
                probe image and the debug zero-shot weight matrix.
            jit: forwarded to ``clip.load``.
        """
        super().__init__()
        self.prompt_template = cfg.prompt_template
        self.device = device
        # Not assigned anywhere in this method; adaptation() passes it to get_class_names,
        # so presumably the caller sets it before the first task — confirm.
        self.classes_names = None
        self.model, self.transforms, _ = clip.load(cfg.model_name, device=device, jit=jit)
        self.ref_model = None
        # Without cfg.class_order, per-task class ids are discovered from the data in adaptation()
        if getattr(cfg, 'class_order', None) is not None:
            self.class_ids_per_task = list(get_class_ids_per_task(cfg))
        else:
            self.class_ids_per_task = None
        # Cumulative class names; adaptation() tokenizes one prompt per name into text_tokens
        self.current_class_names = []
        # Cumulative absolute class ids in the same order as current_class_names/text_tokens
        self.seen_class_ids = []
        self.text_tokens = None
        self.dynamic_dataset = DynamicDataset(cfg)
        # DFA fusion module (E1 + task experts + routers) sized to the image-feature dimension.
        # Embedding dim: visual.output_dim if present, else probe encode_image with a zero
        # image, else model.embed_dim (falling back to 1024) if the probe raises.
        visual_embed_dim = getattr(getattr(self.model, 'visual', None), 'output_dim', None)
        if visual_embed_dim is None:
            try:
                dummy = torch.zeros(1, 3, getattr(cfg, 'input_resolution', 224), getattr(cfg, 'input_resolution', 224), device=device)
                with torch.no_grad():
                    visual_embed_dim = self.model.encode_image(dummy).shape[-1]
            except Exception:
                visual_embed_dim = getattr(self.model, 'embed_dim', 1024)
        r1 = int(getattr(cfg, 'dfa_r1', 16))
        r2 = int(getattr(cfg, 'dfa_r2', 8))
        router_hidden = int(getattr(cfg, 'dfa_router_hidden', 0))
        beta = float(getattr(cfg, 'dfa_beta', 1.0))
        self.dfa = DFAModule(
            visual_embed_dim,
            r1=r1, r2=r2, router_hidden=router_hidden, beta=beta,
            e2_router_mode=str(getattr(cfg, 'e2_router_mode', 'moe_w')).lower(),
            e2_top_k=int(getattr(cfg, 'e2_top_k', 2)),
            single_router_3way=bool(getattr(cfg, 'single_router_3way', False)),
            adapter_dropout=float(getattr(cfg, 'adapter_dropout', 0.1)),
            adapter_scalar=float(getattr(cfg, 'adapter_scalar', 0.1)),
            num_task_experts=int(getattr(cfg, 'num_task_experts', 2)),
        ).to(self.device)
        # Gate mode: 'ood_confidence' (default) or 'router' or 'debug'
        self.gate_mode = str(getattr(cfg, 'gate_mode', 'ood_confidence')).lower()
        # Debug mode: manually specify weights for seen/unseen
        self.debug_mode = (self.gate_mode == 'debug')
        # If True, in debug mode force all mass to task experts (w_top=[0,1]) to mimic moe-w behavior
        self.debug_task_only = bool(getattr(cfg, 'debug_task_only', False))
        if self.debug_mode:
            # Seen-task weights [w1, w2] (default [0.3, 0.7]), normalised to sum to 1;
            # [0.5, 0.5] if they sum to 0.
            # NOTE(review): debug_seen_weights is only defined when gate_mode == 'debug'.
            _seen_w = getattr(cfg, 'debug_seen_weights', [0.3, 0.7])
            if isinstance(_seen_w, ListConfig):
                _seen_w = list(_seen_w)
            self.debug_seen_weights = [float(_seen_w[0]), float(_seen_w[1])] if isinstance(_seen_w, (list, tuple)) and len(_seen_w) >= 2 else [0.3, 0.7]
            s_sum = float(self.debug_seen_weights[0] + self.debug_seen_weights[1])
            if s_sum == 0:
                self.debug_seen_weights = [0.5, 0.5]
            else:
                self.debug_seen_weights = [self.debug_seen_weights[0] / s_sum, self.debug_seen_weights[1] / s_sum]

        # Injection mode: 'head' (default) or 'block' (inject into transformer like moe-w)
        self.dfa_inject_mode = str(getattr(cfg, 'dfa_inject_mode', 'head')).lower()
        # Per-block DFA modules; presumably populated by _inject_dfa_into_visual_blocks
        # (defined outside this view) in 'block' mode — confirm.
        self.dfa_blocks = []
        if self.dfa_inject_mode == 'block':
            self._inject_dfa_into_visual_blocks(cfg)
            # Optionally inject into the last K text transformer blocks (A1)
            # NOTE(review): text_dfa_blocks is only defined in 'block' mode.
            self.text_dfa_blocks = []
            self._inject_dfa_into_text_blocks(cfg)

            # Zero-shot debug weights: either a single pair [w1,w2] or an Nx2 matrix (e.g., 25x2 for MTIL).
            # Each row/pair is normalised to sum to 1 ([0.5, 0.5] when it sums to 0).
            # NOTE(review): this parsing lives inside the 'block' branch, so in 'head' mode neither
            # debug_zs_weights nor debug_zs_weights_matrix is defined; and when a matrix is parsed,
            # debug_zs_weights is not set. If every matrix row is invalid, the pair fallback
            # below calls float() on a row. Confirm callers guard for these cases.
            _zs_w = getattr(cfg, 'debug_zs_weights', [0.7, 0.3])
            if isinstance(_zs_w, ListConfig):
                _zs_w = list(_zs_w)
            self.debug_zs_weights_matrix = None
            # Detect matrix: first element is a (list-like) pair
            is_matrix = (
                isinstance(_zs_w, (list, tuple)) and len(_zs_w) > 0 and (
                    isinstance(_zs_w[0], (list, tuple)) or isinstance(_zs_w[0], ListConfig)
                )
            )
            if is_matrix:
                try:
                    rows = []
                    for row in _zs_w:
                        if isinstance(row, ListConfig):
                            row = list(row)
                        # Rows with fewer than two entries are silently skipped
                        if not isinstance(row, (list, tuple)) or len(row) < 2:
                            continue
                        rs = float(row[0])
                        re = float(row[1])
                        s = rs + re
                        if s == 0:
                            rows.append([0.5, 0.5])
                        else:
                            rows.append([rs / s, re / s])
                    if rows:
                        self.debug_zs_weights_matrix = torch.tensor(rows, dtype=torch.float32, device=self.device)
                except Exception:
                    self.debug_zs_weights_matrix = None
            # Fallback to single pair
            if self.debug_zs_weights_matrix is None:
                if isinstance(_zs_w, (list, tuple)) and len(_zs_w) >= 2:
                    zs_pair = [float(_zs_w[0]), float(_zs_w[1])]
                else:
                    zs_pair = [0.7, 0.3]
                z_sum = float(zs_pair[0] + zs_pair[1])
                if z_sum == 0:
                    self.debug_zs_weights = [0.5, 0.5]
                else:
                    self.debug_zs_weights = [zs_pair[0] / z_sum, zs_pair[1] / z_sum]
        # Lightweight replay buffer: abs_class_id -> list[tensor(C,H,W)]
        self.replay_buffer = {}
        self.replay_per_class = int(getattr(cfg, 'replay_per_class', 5))
        # Momentum queues for Stage A contrastive learning: random columns of shape
        # [visual_embed_dim, queue_size], L2-normalised along dim 0 below.
        self.queue_size = int(getattr(cfg, 'moco_queue_size', 4096))
        self.queue_momentum = float(getattr(cfg, 'moco_momentum', 0.999))
        self.register_buffer('queue_img', torch.randn(visual_embed_dim, self.queue_size))
        self.register_buffer('queue_txt', torch.randn(visual_embed_dim, self.queue_size))
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
        # Assigning to a registered buffer name replaces that buffer's tensor
        self.queue_img = nn.functional.normalize(self.queue_img, dim=0)
        self.queue_txt = nn.functional.normalize(self.queue_txt, dim=0)

    def forward(self, image, taskid):
        """Return softmax probabilities of ``image`` against the current text prompts.

        Calls ``self.model(image, self.text_tokens, 0, is_train=False)`` without gradient
        tracking and softmaxes the first returned tensor (the per-image logits) over its
        last dimension. ``text_tokens`` holds one prompt per seen class (built in
        ``adaptation``). ``taskid`` is accepted for interface compatibility but unused.
        """
        with torch.no_grad():
            per_image_logits, _ = self.model(image, self.text_tokens, 0, is_train=False)
            return torch.softmax(per_image_logits, dim=-1)

    def adaptation(self, task_id, cfg, train_dataset, train_classes_names):
        """Register the classes of task ``task_id`` and, unless zero-shot, train on it.

        The task's absolute class ids are taken from the targets found in
        ``train_dataset[task_id:task_id + 1]``, not from ``cfg``. They are stored in
        ``class_ids_per_task[task_id]``, padding any missing earlier slots with empty
        lists. They are also appended to ``seen_class_ids`` and, through
        ``get_class_names``, to ``current_class_names``. The text tokens are then
        rebuilt for every class seen so far. ``last_task_real_ids`` and
        ``last_task_start_index`` record this task's ids and where they begin in the
        cumulative ordering. ``self.train(...)`` runs unless ``cfg.method`` is
        ``"zeroshot"``.
        """
        # Scan the task's data once to collect the labels that actually occur
        label_loader = DataLoader(train_dataset[task_id:task_id + 1], batch_size=256, shuffle=False, num_workers=2)
        observed = set()
        for _, targets, _ in label_loader:
            observed.update(targets.tolist())
        real_ids = sorted(int(label) for label in observed)
        # Position of this task's first class in the cumulative ordering (for external mapping)
        first_new_index = len(self.current_class_names)

        if self.class_ids_per_task is None:
            self.class_ids_per_task = []
        per_task = self.class_ids_per_task
        while len(per_task) < task_id:
            per_task.append([])
        if len(per_task) > task_id:
            per_task[task_id] = real_ids
        else:
            per_task.append(real_ids)

        # Keep names, absolute ids and text tokens in the same cumulative order
        self.current_class_names.extend(get_class_names(self.classes_names, real_ids))
        self.seen_class_ids.extend(real_ids)
        prompts = [self.prompt_template.format(name) for name in self.current_class_names]
        self.text_tokens = clip.tokenize(prompts).to(self.device)

        self.last_task_real_ids = list(real_ids)
        self.last_task_start_index = int(first_new_index)

        if cfg.method == "zeroshot":
            return
        self.train(task_id, cfg, train_dataset, train_classes_names)

    def train(self, task_id, cfg, train_dataset, train_classes_names):
        # Parameter count output before Task 0 training starts
        if int(task_id) == 0:
            total_params = sum(p.numel() for p in self.parameters())
            trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
            print(f"[Params] Total: {total_params:,} | Trainable: {trainable_params:,}")
        # 1) Data
        train_loader = DataLoader(
            train_dataset[task_id:task_id + 1],
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=min(8, os.cpu_count() or 8),
            pin_memory=True,
        )

        # Stage-specific training lengths
        epochs_global = int(getattr(cfg, 'epochs', 1))
        epochs_a = int(getattr(cfg, 'epochs_a', epochs_global))
        epochs_b = int(getattr(cfg, 'epochs_b', epochs_global))
        total_iterations_a = max(1, epochs_a * len(train_loader))
        total_iterations_b = max(1, epochs_b * len(train_loader))
        use_amp = bool(getattr(cfg, 'use_amp', True) and torch.cuda.is_available())
        # Initialize scaler here for Stage B (needed even if Stage A is skipped)
        scaler = GradScaler(enabled=use_amp)
        # Map abs labels to local 0..C-1
        task_class_ids = [int(c) for c in self.class_ids_per_task[task_id]]
        local_C = len(task_class_ids)
        max_cid = max(task_class_ids)
        map_table_cpu = torch.full((max_cid + 1,), -1, dtype=torch.long)
        for i, cid in enumerate(task_class_ids):
            map_table_cpu[cid] = i

        # Text tokens
        # Task-specific texts (positives)
        classnames = get_class_names(self.classes_names, self.class_ids_per_task[task_id])
        texts_task = clip.tokenize([self.prompt_template.format(c) for c in classnames]).to(self.device)
        # Alias for Stage B code paths that reference `texts`
        texts = texts_task
        # Broad pool for negatives: all seen classes + generic templates
        all_seen_names = []
        for tid in range(task_id + 1):
            all_seen_names.extend(get_class_names(self.classes_names, self.class_ids_per_task[tid]))
        generic_templates = getattr(cfg, 'e1_generic_templates', [
            "a photo of an object",
            "a photo of something",
            "an image of a thing",
            "a picture",
        ])
        pool_names = list(dict.fromkeys(all_seen_names)) + list(generic_templates)
        texts_pool = clip.tokenize([self.prompt_template.format(c) for c in pool_names]).to(self.device)

        # Freeze CLIP backbone
        for p in self.model.parameters():
            p.requires_grad = False

        # -------- Stage A: train Expert1 only (AFFA-like, InfoNCE with MoCo queue) --------
        skip_stage_a = bool(getattr(cfg, 'skip_stage_a', False))
        unified_stage = bool(getattr(cfg, 'unified_stage', False))
        if unified_stage:
            # Unified single-stage training: train E1, E2a, E2b together; routers as in Stage B
            # Optim setup similar to Stage B (block mode): include E1 params as adapters
            self.model.eval()
            weight_decay = float(getattr(cfg, 'weight_decay', 0.0))
            lr_e2 = float(getattr(cfg, 'lr_e2', getattr(cfg, 'lr', 1e-3)))
            lr_e2_text = float(getattr(cfg, 'text_lr_e2', max(lr_e2 * 0.1, 1e-6)))
            lr_e2_router = float(getattr(cfg, 'lr_e2_router', lr_e2))
            lr_top_router = float(getattr(cfg, 'lr_top_router', 5.0e-6))
            # Collect params
            img_adapt_params, txt_adapt_params = [], []
            e2_router_params, top_router_params = [], []
            # Freeze head routers; train per-block routers
            if hasattr(self.dfa, 'router'):
                for p in self.dfa.router.parameters():
                    p.requires_grad = False
            if hasattr(self.dfa, 'router_e2'):
                for p in self.dfa.router_e2.parameters():
                    p.requires_grad = False
            # Image blocks
            if hasattr(self, 'dfa_blocks') and self.dfa_blocks:
                for m in self.dfa_blocks:
                    # adapters: train E1 + all E2 together
                    if hasattr(m, 'e1'):
                        for p in m.e1.parameters(): p.requires_grad = True
                        img_adapt_params += list(m.e1.parameters())
                    for e2 in getattr(m, 'e2_list', []):
                        for p in e2.parameters(): p.requires_grad = True
                        img_adapt_params += list(e2.parameters())
                    # routers
                    if getattr(m, 'single_router_3way', False) and hasattr(m, 'router3'):
                        for p in m.w_gate_e2.parameters(): p.requires_grad = False
                        for p in m.w_noise_e2.parameters(): p.requires_grad = False
                        for p in m.router3.parameters(): p.requires_grad = True
                        top_router_params += list(m.router3.parameters())
                    else:
                        for p in m.w_gate_e2.parameters(): p.requires_grad = True
                        for p in m.w_noise_e2.parameters(): p.requires_grad = True
                        e2_router_params += list(m.w_gate_e2.parameters()) + list(m.w_noise_e2.parameters())
                        if hasattr(m, 'router_top'):
                            for p in m.router_top.parameters(): p.requires_grad = True
                            top_router_params += list(m.router_top.parameters())
            # Text blocks (optional)
            if hasattr(self, 'text_dfa_blocks') and self.text_dfa_blocks:
                for m in self.text_dfa_blocks:
                    if hasattr(m, 'e1'):
                        for p in m.e1.parameters(): p.requires_grad = True
                        txt_adapt_params += list(m.e1.parameters())
                    for e2 in getattr(m, 'e2_list', []):
                        for p in e2.parameters(): p.requires_grad = True
                        txt_adapt_params += list(e2.parameters())
                    if getattr(m, 'single_router_3way', False) and hasattr(m, 'router3'):
                        for p in m.w_gate_e2.parameters(): p.requires_grad = False
                        for p in m.w_noise_e2.parameters(): p.requires_grad = False
                        for p in m.router3.parameters(): p.requires_grad = True
                        top_router_params += list(m.router3.parameters())
                    else:
                        for p in m.w_gate_e2.parameters(): p.requires_grad = True
                        for p in m.w_noise_e2.parameters(): p.requires_grad = True
                        e2_router_params += list(m.w_gate_e2.parameters()) + list(m.w_noise_e2.parameters())
                        if hasattr(m, 'router_top'):
                            for p in m.router_top.parameters(): p.requires_grad = True
                            top_router_params += list(m.router_top.parameters())
            # Build optimizer
            param_groups = []
            base_lrs = []
            if img_adapt_params:
                param_groups.append({"params": img_adapt_params, "lr": lr_e2, "weight_decay": weight_decay})
                base_lrs.append(lr_e2)
            if txt_adapt_params:
                param_groups.append({"params": txt_adapt_params, "lr": lr_e2_text, "weight_decay": weight_decay})
                base_lrs.append(lr_e2_text)
            if e2_router_params:
                param_groups.append({"params": e2_router_params, "lr": lr_e2_router, "weight_decay": weight_decay})
                base_lrs.append(lr_e2_router)
            if top_router_params:
                param_groups.append({"params": top_router_params, "lr": lr_top_router, "weight_decay": weight_decay})
                base_lrs.append(lr_top_router)
            if not param_groups:
                raise RuntimeError("No parameters collected for Unified-stage optimization")
            opt_u = torch.optim.AdamW(param_groups)
            total_iterations_u = max(1, epochs_b * len(train_loader))
            sched_u = utils.cosine_lr(opt_u, base_lrs if len(base_lrs) > 1 else base_lrs[0], 30, total_iterations_u)
            # KD teacher (optional, though lambda_kd is usually 0 now)
            lambda_kd = float(getattr(cfg, 'lambda_kd', 0.0))
            tau_kd = float(getattr(cfg, 'tau_kd', 2.0))
            teacher = None
            if lambda_kd > 0.0:
                try:
                    teacher = copy.deepcopy(self.dfa).eval().to(self.device)
                    for p in teacher.parameters(): p.requires_grad = False
                except Exception:
                    teacher = None
            # MoCo buffering
            use_moco = bool(getattr(cfg, 'use_moco_queue', True))
            mq_img_buf, mq_txt_buf = [], []

            self.dfa.train()
            it = 0
            for epoch in range(epochs_b):
                for inputs, targets_abs, _ in tqdm(train_loader, desc=f"Task {task_id} Unified"):
                    sched_u(it); it += 1
                    inputs = inputs.to(self.device, non_blocking=True)
                    targets_abs = targets_abs.to(self.device, non_blocking=True)
                    map_table = map_table_cpu.to(targets_abs.device)
                    targets = map_table[targets_abs]
                    if (targets < 0).any():
                        raise RuntimeError("Label mapping failed in Unified stage")
                    # texts and scale
                    txt = self.model.encode_text(texts_task)
                    txt = txt / txt.norm(dim=-1, keepdim=True)
                    scale = self.model.logit_scale.exp()
                    # forward (block path routes inside encode_image)
                    opt_u.zero_grad(set_to_none=True)
                    with autocast(enabled=use_amp):
                        img = self.model.encode_image(inputs)
                        fused = img / img.norm(dim=-1, keepdim=True)
                        logits = scale * fused @ txt.t()
                        local_C = int(txt.size(0))
                        loss = F.cross_entropy(logits[:, :local_C], targets, label_smoothing=float(getattr(cfg, 'ls', 0.0)))
                        # L2 reg on E2a/E2b
                        lambda_reg = float(getattr(cfg, 'lambda_reg', 1e-4))
                        if lambda_reg > 0.0:
                            reg = 0.0
                            for m in getattr(self, 'dfa_blocks', []):
                                for e2 in getattr(m, 'e2_list', []):
                                    for p in e2.parameters(): reg = reg + p.pow(2).sum()
                            if hasattr(self, 'text_dfa_blocks') and self.text_dfa_blocks:
                                for m in self.text_dfa_blocks:
                                    for e2 in getattr(m, 'e2_list', []):
                                        for p in e2.parameters(): reg = reg + p.pow(2).sum()
                            loss = loss + lambda_reg * reg
                        # Optional KD (if enabled)
                        if lambda_kd > 0.0 and teacher is not None:
                            with torch.no_grad():
                                t_feat = teacher(img, route=True)
                                t_feat = t_feat / t_feat.norm(dim=-1, keepdim=True)
                                t_logits = scale * t_feat @ txt.t()
                            kd = F.kl_div(
                                F.log_softmax(logits / tau_kd, dim=-1),
                                F.softmax(t_logits / tau_kd, dim=-1),
                                reduction='batchmean'
                            ) * (tau_kd ** 2)
                            loss = loss + lambda_kd * kd
                        # Queue-based distribution regularizers (already implemented use fused/txt)
                        # Image/Text feature KL and global logits KL are added below in existing code paths
                    scaler.scale(loss).backward()
                    scaler.step(opt_u)
                    scaler.update()
                    # Buffer MoCo features (per-sample class text)
                    if use_moco and self.queue_size > 0:
                        try:
                            with torch.no_grad():
                                txt_task = self.model.encode_text(texts_task)
                                txt_task = txt_task / txt_task.norm(dim=-1, keepdim=True)
                                txt_pos_mean = txt_task[targets]
                            mq_img_buf.append(fused.detach().cpu())
                            mq_txt_buf.append(txt_pos_mean.detach().cpu())
                        except Exception:
                            pass
            # Post-enqueue buffered features once per task
            pre_q = int(getattr(self, 'queue_count', 0)) if hasattr(self, 'queue_size') else 0
            if use_moco and self.queue_size > 0 and len(mq_img_buf) > 0:
                try:
                    img_cat = torch.cat(mq_img_buf, dim=0)
                    txt_cat = torch.cat(mq_txt_buf, dim=0)
                    mode = str(getattr(cfg, 'moco_post_enqueue_mode', 'subsample')).lower()
                    if mode not in ('fifo', 'quota', 'subsample'):
                        mode = 'subsample'
                    # fullness-aware quota
                    if mode == 'subsample':
                        if img_cat.size(0) > self.queue_size:
                            perm = torch.randperm(img_cat.size(0))[: self.queue_size]
                            img_cat = img_cat[perm]
                            txt_cat = txt_cat[perm]
                    elif mode == 'quota':
                        quota = int(getattr(cfg, 'moco_post_enqueue_quota', max(0, self.queue_size // 4)))
                        quota = max(0, quota)
                        cur = int(getattr(self, 'queue_count', 0))
                        cap = int(getattr(self, 'queue_size', 0))
                        want = int(img_cat.size(0))
                        n_add = want if (cur + want) <= cap else min(quota, want)
                        img_cat = img_cat[:n_add]
                        txt_cat = txt_cat[:n_add]
                    # From task 1 onward, cap enqueue size to slow queue refresh
                    if int(task_id) >= 1:
                        task_quota = int(getattr(cfg, 'moco_task_enqueue_quota', 32))
                        task_quota = max(0, min(task_quota, int(self.queue_size)))
                        if task_quota > 0:
                            img_cat = img_cat[:task_quota]
                            txt_cat = txt_cat[:task_quota]
                    # enqueue
                    if img_cat.numel() > 0:
                        img_cat = img_cat.to(self.device, non_blocking=True)
                        txt_cat = txt_cat.to(self.device, non_blocking=True)
                        chunk = min(1024, self.queue_size)
                        for s in range(0, img_cat.size(0), chunk):
                            self._dequeue_and_enqueue(img_cat[s:s+chunk], txt_cat[s:s+chunk])
                except Exception:
                    pass
            post_q = int(getattr(self, 'queue_count', 0)) if hasattr(self, 'queue_size') else pre_q
            try:
                print(f"[MoCo][Task {task_id}] Queue before: {pre_q}/{self.queue_size} | after: {post_q}/{self.queue_size} | added: {max(0, post_q - pre_q)}")
            except Exception:
                pass
            return

        if skip_stage_a:
            # Skip Expert1 training and force Expert1 to be identity to match original CLIP features
            self.dfa.e1_identity = True
        else:
            # Ensure Expert1 uses its residual adapter when Stage A is enabled
            self.dfa.e1_identity = False
            # Head-mode vs block-mode Stage A parameterization
            if getattr(self, 'dfa_inject_mode', 'head') == 'block':
                # Train per-block E1 only; freeze routers and E2 in blocks; freeze head DFA
                for p in self.dfa.e1.parameters():
                    p.requires_grad = False
                for e2 in getattr(self.dfa, 'e2_list', []):
                    for p in e2.parameters(): p.requires_grad = False
                if hasattr(self.dfa, 'router'):
                    for p in self.dfa.router.parameters():
                        p.requires_grad = False
                if hasattr(self.dfa, 'router_e2'):
                    for p in self.dfa.router_e2.parameters():
                        p.requires_grad = False
                # Enable block E1 only behavior and grads
                for m in getattr(self, 'dfa_blocks', []):
                    m.e1_only = True
                    for p in m.e1.parameters():
                        p.requires_grad = True
                    # freeze routers and E2 during Stage A
                    for e2 in getattr(m, 'e2_list', []):
                        for p in e2.parameters(): p.requires_grad = False
                    for p in m.w_gate_e2.parameters(): p.requires_grad = False
                    for p in m.w_noise_e2.parameters(): p.requires_grad = False
                    # freeze both router_top and router3 if present
                    if hasattr(m, 'router_top'):
                        for p in m.router_top.parameters(): p.requires_grad = False
                    if hasattr(m, 'router3'):
                        for p in m.router3.parameters(): p.requires_grad = False
                # Do not touch text blocks here (image-side general expert only)
                if hasattr(self, 'text_dfa_blocks') and self.text_dfa_blocks:
                    for m in self.text_dfa_blocks:
                        m.e1_only = True
                        for p in m.e1.parameters():
                            p.requires_grad = True
                        for e2 in getattr(m, 'e2_list', []):
                            for p in e2.parameters(): p.requires_grad = False
                        for p in m.w_gate_e2.parameters(): p.requires_grad = False
                        for p in m.w_noise_e2.parameters(): p.requires_grad = False
                        if hasattr(m, 'router_top'):
                            for p in m.router_top.parameters(): p.requires_grad = False
                        if hasattr(m, 'router3'):
                            for p in m.router3.parameters(): p.requires_grad = False
            else:
                for p in self.dfa.e1.parameters():
                    p.requires_grad = True
                for e2 in getattr(self.dfa, 'e2_list', []):
                    for p in e2.parameters(): p.requires_grad = False
                # Freeze top router if present (not in single_router_3way)
                if hasattr(self.dfa, 'router'):
                    for p in self.dfa.router.parameters():
                        p.requires_grad = False
                # In moe_w mode, router_e2 does not exist; guard accordingly
                if hasattr(self.dfa, 'router_e2'):
                    for p in self.dfa.router_e2.parameters():
                        p.requires_grad = False

            # Print Stage A trainable vs total params
            total_params = sum(p.numel() for p in self.parameters())
            trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
            print(f"[Params][Stage A] Total: {total_params:,} | Trainable: {trainable_params:,}")

            lr_a = float(getattr(cfg, 'lr_e1', getattr(cfg, 'lr', 1e-4)))
            if getattr(self, 'dfa_inject_mode', 'head') == 'block':
                # Optimize all block E1 parameters
                params_a = []
                for m in getattr(self, 'dfa_blocks', []):
                    params_a += list(m.e1.parameters())
                if hasattr(self, 'text_dfa_blocks') and self.text_dfa_blocks:
                    for m in self.text_dfa_blocks:
                        params_a += list(m.e1.parameters())
                opt_a = torch.optim.AdamW(params_a, lr=lr_a, weight_decay=float(getattr(cfg, 'weight_decay', 0.0)))
            else:
                opt_a = torch.optim.AdamW(self.dfa.e1.parameters(), lr=lr_a, weight_decay=float(getattr(cfg, 'weight_decay', 0.0)))
            sched_a = utils.cosine_lr(opt_a, lr_a, 30, total_iterations_a)
            # Epoch-phased LR for E1 (Stage A)
            lr_e1_ini = float(getattr(cfg, 'lr_e1_ini', lr_a))
            lr_e1_next = float(getattr(cfg, 'lr_e1_next', lr_a))
            lr_e1_final = float(getattr(cfg, 'lr_e1_final', lr_a))
            iters_per_epoch = len(train_loader)
            half_ep2 = max(1, iters_per_epoch // 2)
            tau_con = float(getattr(cfg, 'tau_con', 0.05))
            use_moco = bool(getattr(cfg, 'use_moco_queue', True))
            # Buffer current-task features; enqueue only after Stage A completes
            mq_img_buf: list = []
            mq_txt_buf: list = []

            self.model.eval()
            self.dfa.train()
            it = 0
            for epoch in range(epochs_a):
                for inputs, targets_abs, _ in tqdm(train_loader, desc=f"Task {task_id} A (E1)"):
                    sched_a(it)
                    if epoch == 0:
                        lr_curr = lr_e1_ini
                    elif epoch == 1:
                        pos_in_ep = it - iters_per_epoch
                        lr_curr = lr_e1_next if pos_in_ep < half_ep2 else lr_e1_final
                    else:
                        lr_curr = lr_e1_final
                    for g in opt_a.param_groups:
                        g['lr'] = lr_curr
                    it += 1
                    inputs = inputs.to(self.device, non_blocking=True)
                    targets_abs = targets_abs.to(self.device, non_blocking=True)
                    # map to local ids
                    map_table = map_table_cpu.to(targets_abs.device)
                    targets = map_table[targets_abs]
                    if (targets < 0).any():
                        raise RuntimeError("Label mapping failed in Stage A")

                    with torch.no_grad():
                        # In block mode, we will re-encode with grad after anchors are prepared
                        img0 = self.model.encode_image(inputs)
                        img0 = img0 / img0.norm(dim=-1, keepdim=True)
                        # Build multi-template texts for Option C (SupCon-like, multi-positive)
                        templates_cfg = getattr(cfg, 'e1_generic_templates', [
                            "a photo of a {}",
                            "an image of a {}",
                        ])
                        # Dedup and prepend the base prompt template
                        all_templates = []
                        base_t = str(getattr(cfg, 'prompt_template', self.prompt_template))
                        if base_t not in all_templates:
                            all_templates.append(base_t)
                        for t in templates_cfg:
                            if t not in all_templates:
                                all_templates.append(t)
                        Tmpl = len(all_templates)

                        # Present classes in this batch (local ids)
                        present_local = torch.unique(targets).tolist()
                        present_local.sort()
                        present_names = [classnames[i] for i in present_local]

                        # Build present-class multi-template texts as anchors (positives for corresponding class)
                        texts_present = [tmpl.format(c) for c in present_names for tmpl in all_templates]

                        # Negatives from other seen classes (exclude present classes)
                        other_seen = [n for n in dict.fromkeys(all_seen_names) if n not in set(present_names)]
                        texts_neg = [tmpl.format(c) for c in other_seen for tmpl in all_templates] if other_seen else None
                        scale = self.model.logit_scale.exp()

                    # Encode text anchors/negatives with grad enabled so text-side E1 (if injected) can learn
                    train_text_e1 = bool(getattr(self, 'text_dfa_blocks', None)) and bool(getattr(self, 'text_dfa_blocks', []))
                    if train_text_e1:
                        txt_present_anchors = self.model.encode_text(clip.tokenize(texts_present).to(self.device))
                        txt_present_anchors = txt_present_anchors / txt_present_anchors.norm(dim=-1, keepdim=True)
                        if texts_neg:
                            txt_neg = self.model.encode_text(clip.tokenize(texts_neg).to(self.device))
                            txt_neg = txt_neg / txt_neg.norm(dim=-1, keepdim=True)
                        else:
                            txt_neg = None
                    else:
                        with torch.no_grad():
                            txt_present_anchors = self.model.encode_text(clip.tokenize(texts_present).to(self.device))
                            txt_present_anchors = txt_present_anchors / txt_present_anchors.norm(dim=-1, keepdim=True)
                            if texts_neg:
                                txt_neg = self.model.encode_text(clip.tokenize(texts_neg).to(self.device))
                                txt_neg = txt_neg / txt_neg.norm(dim=-1, keepdim=True)
                            else:
                                txt_neg = None
                    # Expert1 branch residual
                    if getattr(self, 'dfa_inject_mode', 'head') == 'block':
                        # Re-encode with grad so block E1 receives gradients
                        img0_grad = self.model.encode_image(inputs)
                        feats = img0_grad
                    else:
                        feats = img0 + self.dfa.e1(img0)
                    feats = feats / feats.norm(dim=-1, keepdim=True)
                    # Map from local class id -> index among present classes
                    present_index_map = {int(lid): idx for idx, lid in enumerate(present_local)}

                    # Build per-sample class-mean text (for queue update)
                    B = feats.size(0)
                    Tmpl = max(1, Tmpl)
                    txt_pos_mean = []
                    for i in range(B):
                        cls_lid = int(targets[i].item())
                        pidx = present_index_map[cls_lid]
                        start = pidx * Tmpl
                        end = start + Tmpl
                        txt_block = txt_present_anchors[start:end]
                        txt_pos_mean.append(txt_block.mean(dim=0, keepdim=True))
                    txt_pos_mean = torch.cat(txt_pos_mean, dim=0)  # [B, D]

                    opt_a.zero_grad(set_to_none=True)
                    with autocast(enabled=use_amp):
                        # Build candidate texts: anchors (present classes, multi-template) + negatives + optional queue
                        if use_moco and self.queue_size > 0:
                            queue_txt_batch = self.queue_txt.clone().detach().t().to(feats.device)  # [K, D]
                            cand_txt = torch.cat(
                                [
                                    txt_present_anchors,
                                    txt_neg if txt_neg is not None else txt_present_anchors.new_zeros((0, txt_present_anchors.size(1))),
                                    queue_txt_batch,
                                ],
                                dim=0,
                            )
                            queue_img_batch = self.queue_img.clone().detach().t().to(feats.device)  # [K, D]
                            cand_img = torch.cat([feats, queue_img_batch], dim=0)  # [B+K, D]
                        else:
                            cand_txt = torch.cat(
                                [
                                    txt_present_anchors,
                                    txt_neg if txt_neg is not None else txt_present_anchors.new_zeros((0, txt_present_anchors.size(1))),
                                ],
                                dim=0,
                            )
                            cand_img = feats

                        # Image->Text (multi-positive): each image has T positives for its class
                        N_txt = cand_txt.size(0)
                        pos_mask_i2t = torch.zeros(B, N_txt, dtype=torch.bool, device=feats.device)
                        for i in range(B):
                            cls_lid = int(targets[i].item())
                            pidx = present_index_map[cls_lid]
                            start = pidx * Tmpl
                            end = start + Tmpl
                            pos_mask_i2t[i, start:end] = True

                        logits_i2t = (scale * feats @ cand_txt.t()) / max(1e-6, tau_con)
                        # numer = logsumexp over positives; denom = logsumexp over all
                        pos_logits = logits_i2t.masked_fill(~pos_mask_i2t, float('-inf'))
                        numer = torch.logsumexp(pos_logits, dim=1)
                        denom = torch.logsumexp(logits_i2t, dim=1)
                        loss_i2t = -(numer - denom).mean()

                        # Text->Image (multi-positive): each text anchor (class,template) matches all images of that class in batch
                        A = txt_present_anchors.size(0)  # n_present * Tmpl
                        N_img = cand_img.size(0)
                        pos_mask_t2i = torch.zeros(A, N_img, dtype=torch.bool, device=feats.device)
                        # Build mapping from class index in present_local for each anchor row
                        for a in range(A):
                            cls_local = present_local[a // Tmpl]
                            # positives are batch images of this class (indices 0..B-1 only)
                            match_idx = (targets == int(cls_local)).nonzero(as_tuple=False).squeeze(1)
                            if match_idx.numel() > 0:
                                pos_mask_t2i[a, match_idx] = True

                        logits_t2i = (scale * txt_present_anchors @ cand_img.t()) / max(1e-6, tau_con)
                        pos_logits_t2i = logits_t2i.masked_fill(~pos_mask_t2i, float('-inf'))
                        numer_t2i = torch.logsumexp(pos_logits_t2i, dim=1)
                        denom_t2i = torch.logsumexp(logits_t2i, dim=1)
                        # Only average over anchors that have at least one positive in batch
                        valid_rows = pos_mask_t2i.any(dim=1)
                        if valid_rows.any():
                            loss_t2i = -((numer_t2i[valid_rows] - denom_t2i[valid_rows]).mean())
                        else:
                            loss_t2i = torch.zeros((), device=feats.device, dtype=feats.dtype)

                        loss_a = (loss_i2t + loss_t2i) * 0.5
                    scaler.scale(loss_a).backward()
                    scaler.step(opt_a)
                    scaler.update()
                    # Collect for post-task enqueue only (do not update queue during training)
                    if use_moco and self.queue_size > 0:
                        try:
                            mq_img_buf.append(feats.detach().cpu())
                            mq_txt_buf.append(txt_pos_mean.detach().cpu())
                        except Exception:
                            pass
                # Disable block e1_only after Stage A
            if getattr(self, 'dfa_inject_mode', 'head') == 'block':
                for m in getattr(self, 'dfa_blocks', []):
                    m.e1_only = False
                if hasattr(self, 'text_dfa_blocks') and self.text_dfa_blocks:
                    for m in self.text_dfa_blocks:
                        m.e1_only = False
                # Note: defer MoCo queue enqueue to AFTER Stage B completes
                # Store a few exemplars into replay buffer
                try:
                    for i in range(min(inputs.size(0), 4)):
                        abs_y = int(targets_abs[i].item())
                        if abs_y not in self.replay_buffer:
                            self.replay_buffer[abs_y] = []
                        if len(self.replay_buffer[abs_y]) < self.replay_per_class:
                            self.replay_buffer[abs_y].append(inputs[i].detach().cpu())
                except Exception:
                    pass

        # -------- Stage B: train Expert2a/2b + routers (classification + optional KD) --------
        for p in self.dfa.e1.parameters():
            p.requires_grad = False
        for e2 in getattr(self.dfa, 'e2_list', []):
            for p in e2.parameters(): p.requires_grad = True
        if getattr(self, 'dfa_inject_mode', 'head') == 'block':
            # LRs
            lr_e2 = float(getattr(cfg, 'lr_e2', getattr(cfg, 'lr', 1e-3)))
            lr_e2_text = float(getattr(cfg, 'text_lr_e2', max(lr_e2 * 0.1, 1e-6)))
            lr_e2_router = float(getattr(cfg, 'lr_e2_router', lr_e2))
            lr_top_router = float(getattr(cfg, 'lr_top_router', 5.0e-6))
            weight_decay = float(getattr(cfg, 'weight_decay', 0.0))

            # Collect params: adapters vs E2 routers
            img_adapt_params, txt_adapt_params = [], []
            e2_router_params = []
            for m in self.dfa_blocks:
                # adapters (freeze E1 in Stage B; train all E2 only)
                if hasattr(m, 'e1'):
                    for p in m.e1.parameters(): p.requires_grad = False
                for e2 in getattr(m, 'e2_list', []):
                    for p in e2.parameters(): p.requires_grad = True
                    img_adapt_params += list(e2.parameters())
                # routers
                if getattr(m, 'single_router_3way', False) and hasattr(m, 'router3'):
                    # single 3-way router: train router3 only
                    for p in m.w_gate_e2.parameters(): p.requires_grad = False
                    for p in m.w_noise_e2.parameters(): p.requires_grad = False
                    for p in m.router3.parameters(): p.requires_grad = True
                else:
                    # two-router: train E2 gates; top router handled below
                    for p in m.w_gate_e2.parameters(): p.requires_grad = True
                    for p in m.w_noise_e2.parameters(): p.requires_grad = True
                    e2_router_params += list(m.w_gate_e2.parameters()) + list(m.w_noise_e2.parameters())
            if hasattr(self, 'text_dfa_blocks') and self.text_dfa_blocks:
                for m in self.text_dfa_blocks:
                    if hasattr(m, 'e1'):
                        for p in m.e1.parameters(): p.requires_grad = False
                    for e2 in getattr(m, 'e2_list', []):
                        for p in e2.parameters(): p.requires_grad = True
                        txt_adapt_params += list(e2.parameters())
                    if getattr(m, 'single_router_3way', False) and hasattr(m, 'router3'):
                        for p in m.w_gate_e2.parameters(): p.requires_grad = False
                        for p in m.w_noise_e2.parameters(): p.requires_grad = False
                        for p in m.router3.parameters(): p.requires_grad = True
                    else:
                        for p in m.w_gate_e2.parameters(): p.requires_grad = True
                        for p in m.w_noise_e2.parameters(): p.requires_grad = True
                        e2_router_params += list(m.w_gate_e2.parameters()) + list(m.w_noise_e2.parameters())

            # Top router requires_grad: use per-block router_top; freeze head router in block mode
            top_router_params = []
            if hasattr(self.dfa, 'router'):
                for p in self.dfa.router.parameters():
                    p.requires_grad = False
            if hasattr(self, 'dfa_blocks') and self.dfa_blocks:
                for m in self.dfa_blocks:
                    if getattr(m, 'single_router_3way', False) and hasattr(m, 'router3'):
                        for p in m.router3.parameters(): p.requires_grad = True
                        top_router_params += list(m.router3.parameters())
                    else:
                        for p in m.router_top.parameters(): p.requires_grad = True
                        top_router_params += list(m.router_top.parameters())
            if hasattr(self, 'text_dfa_blocks') and self.text_dfa_blocks:
                for m in self.text_dfa_blocks:
                    if getattr(m, 'single_router_3way', False) and hasattr(m, 'router3'):
                        for p in m.router3.parameters(): p.requires_grad = True
                        top_router_params += list(m.router3.parameters())
                    else:
                        for p in m.router_top.parameters(): p.requires_grad = True
                        top_router_params += list(m.router_top.parameters())
            # router_e2 in head module not used in block moe_w path
            if hasattr(self.dfa, 'router_e2'):
                for p in self.dfa.router_e2.parameters():
                    p.requires_grad = False

            # Build optimizer param groups
            param_groups = []
            base_lrs = []
            if img_adapt_params:
                param_groups.append({"params": img_adapt_params, "lr": lr_e2, "weight_decay": weight_decay})
                base_lrs.append(lr_e2)
            if txt_adapt_params:
                param_groups.append({"params": txt_adapt_params, "lr": lr_e2_text, "weight_decay": weight_decay})
                base_lrs.append(lr_e2_text)
            if e2_router_params:
                param_groups.append({"params": e2_router_params, "lr": lr_e2_router, "weight_decay": weight_decay})
                base_lrs.append(lr_e2_router)
            if top_router_params:
                param_groups.append({"params": top_router_params, "lr": lr_top_router, "weight_decay": weight_decay})
                base_lrs.append(lr_top_router)

            # Fallback if nothing collected (shouldn't happen)
            if not param_groups:
                raise RuntimeError("No parameters collected for Stage B optimization in block mode")
            opt_b = torch.optim.AdamW(param_groups)
            sched_b = utils.cosine_lr(opt_b, base_lrs if len(base_lrs) > 1 else base_lrs[0], 30, total_iterations_b)
        elif getattr(self.dfa, 'single_router_3way', False):
            # Enable only the 3-way router; freeze others if exist
            if hasattr(self.dfa, 'router3'):
                for p in self.dfa.router3.parameters():
                    p.requires_grad = True
            if hasattr(self.dfa, 'router'):
                for p in self.dfa.router.parameters():
                    p.requires_grad = False
            if hasattr(self.dfa, 'w_gate_e2'):
                for p in self.dfa.w_gate_e2.parameters():
                    p.requires_grad = False
            if hasattr(self.dfa, 'w_noise_e2'):
                for p in self.dfa.w_noise_e2.parameters():
                    p.requires_grad = False
            if hasattr(self.dfa, 'router_e2'):
                for p in self.dfa.router_e2.parameters():
                    p.requires_grad = False
            params_b = []
            for e2 in getattr(self.dfa, 'e2_list', []):
                params_b += list(e2.parameters())
            if hasattr(self.dfa, 'router3'):
                params_b += list(self.dfa.router3.parameters())
        else:
            # head mode: two-level routing path (top + E2) within DFAModule
            joint_routers = bool(getattr(cfg, 'joint_train_routers', False))
            lr_e2 = float(getattr(cfg, 'lr_e2', getattr(cfg, 'lr', 1e-3)))
            lr_e2_router = float(getattr(cfg, 'lr_e2_router', lr_e2))
            lr_top_router = float(getattr(cfg, 'lr_top_router', 5.0e-6))
            weight_decay = float(getattr(cfg, 'weight_decay', 0.0))

            # E2 adapters
            adapters = []
            for e2 in getattr(self.dfa, 'e2_list', []):
                adapters += list(e2.parameters())
            for p in adapters: p.requires_grad = True

            # E2 router (moe_w default uses w_gate_e2/w_noise_e2)
            e2_router = []
            if getattr(self.dfa, 'e2_router_mode', 'moe_w') == 'moe_w':
                for p in self.dfa.w_gate_e2.parameters(): p.requires_grad = True
                for p in self.dfa.w_noise_e2.parameters(): p.requires_grad = True
                e2_router = list(self.dfa.w_gate_e2.parameters()) + list(self.dfa.w_noise_e2.parameters())
                if hasattr(self.dfa, 'router_e2'):
                    for p in self.dfa.router_e2.parameters(): p.requires_grad = False
            else:
                for p in self.dfa.router_e2.parameters(): p.requires_grad = True
                e2_router = list(self.dfa.router_e2.parameters())

            # Top router
            top_router = []
            for p in self.dfa.router.parameters():
                p.requires_grad = (self.gate_mode == 'router') and (joint_routers or True)
            if self.gate_mode == 'router':
                top_router = list(self.dfa.router.parameters())

            # Build param groups
            param_groups = []
            base_lrs = []
            if adapters:
                param_groups.append({"params": adapters, "lr": lr_e2, "weight_decay": weight_decay})
                base_lrs.append(lr_e2)
            if e2_router:
                lr_e2r = lr_e2_router if joint_routers else lr_e2
                param_groups.append({"params": e2_router, "lr": lr_e2r, "weight_decay": weight_decay})
                base_lrs.append(lr_e2r)
            if joint_routers and top_router:
                param_groups.append({"params": top_router, "lr": lr_top_router, "weight_decay": weight_decay})
                base_lrs.append(lr_top_router)
            elif (not joint_routers) and self.gate_mode == 'router':
                # keep previous behavior: include top router in same group
                param_groups.append({"params": top_router, "lr": lr_e2, "weight_decay": weight_decay})
                base_lrs.append(lr_e2)
            opt_b = torch.optim.AdamW(param_groups)
            sched_b = utils.cosine_lr(opt_b, base_lrs if len(base_lrs) > 1 else base_lrs[0], 30, total_iterations_b)
        if getattr(self, 'dfa_inject_mode', 'head') != 'block':
            lr_e2 = float(getattr(cfg, 'lr_e2', getattr(cfg, 'lr', 1e-3)))
            opt_b = torch.optim.AdamW(params_b, lr=lr_e2, weight_decay=float(getattr(cfg, 'weight_decay', 0.0)))
            sched_b = utils.cosine_lr(opt_b, lr_e2, 30, total_iterations_b)

        # Print Stage B trainable vs total params (after requires_grad setup)
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Params][Stage B] Total: {total_params:,} | Trainable: {trainable_params:,}")

        # Optional KD teacher snapshot
        lambda_kd = float(getattr(cfg, 'lambda_kd', 0.0))
        tau_kd = float(getattr(cfg, 'tau_kd', 2.0))
        teacher = None
        if lambda_kd > 0.0:
            try:
                teacher = copy.deepcopy(self.dfa).eval().to(self.device)
                for p in teacher.parameters():
                    p.requires_grad = False
            except Exception:
                teacher = None

        self.dfa.train()
        it = 0
        for epoch in range(epochs_b):
            for inputs, targets_abs, _ in tqdm(train_loader, desc=f"Task {task_id} B (E2+Router)"):
                sched_b(it)
                it += 1
                inputs = inputs.to(self.device, non_blocking=True)
                targets_abs = targets_abs.to(self.device, non_blocking=True)
                map_table = map_table_cpu.to(targets_abs.device)
                targets = map_table[targets_abs]
                if (targets < 0).any():
                    raise RuntimeError("Label mapping failed in Stage B")

                if getattr(self, 'dfa_inject_mode', 'head') == 'block':
                    # In block mode, we need gradients for per-block adapters.
                    # If text DFABlocks are present, compute text with grad; otherwise keep it under no_grad.
                    if hasattr(self, 'text_dfa_blocks') and self.text_dfa_blocks:
                        # Optional debug override for text-side top router in block mode
                        if self.gate_mode == 'debug':
                            w_pair = getattr(self, 'debug_seen_weights', [0.7, 0.3])
                            w_top = torch.tensor(w_pair, device=self.device, dtype=torch.float32)
                            for m in self.text_dfa_blocks:
                                m.debug_w_top = w_top
                        txt = self.model.encode_text(texts_task)
                        txt = txt / txt.norm(dim=-1, keepdim=True)
                        with torch.no_grad():
                            scale = self.model.logit_scale.exp()
                        # clear text-side debug override after forward
                        if self.gate_mode == 'debug':
                            for m in self.text_dfa_blocks:
                                if hasattr(m, 'debug_w_top'):
                                    m.debug_w_top = None
                    else:
                        with torch.no_grad():
                            txt = self.model.encode_text(texts_task)
                            txt = txt / txt.norm(dim=-1, keepdim=True)
                            scale = self.model.logit_scale.exp()
                else:
                    # Head mode: precompute both image and text without grad
                    with torch.no_grad():
                        img = self.model.encode_image(inputs)
                        img = img / img.norm(dim=-1, keepdim=True)
                        txt = self.model.encode_text(texts_task)
                        txt = txt / txt.norm(dim=-1, keepdim=True)
                        scale = self.model.logit_scale.exp()

                opt_b.zero_grad(set_to_none=True)
                with autocast(enabled=use_amp):
                    if getattr(self, 'dfa_inject_mode', 'head') == 'block':
                        # Compute image with grad; DFA is injected into blocks; optionally override top router in debug mode
                        if self.gate_mode == 'debug':
                            if getattr(self, 'debug_task_only', False):
                                w_pair = [0.0, 1.0]
                            else:
                                w_pair = getattr(self, 'debug_seen_weights', [0.7, 0.3])
                            w_top = torch.tensor(w_pair, device=self.device, dtype=torch.float32)
                            for m in getattr(self, 'dfa_blocks', []):
                                m.debug_w_top = w_top
                        img = self.model.encode_image(inputs)
                        img = img / img.norm(dim=-1, keepdim=True)
                        # clear debug override after forward
                        if self.gate_mode == 'debug':
                            for m in getattr(self, 'dfa_blocks', []):
                                if hasattr(m, 'debug_w_top'):
                                    m.debug_w_top = None
                        fused = img
                    elif getattr(self.dfa, 'single_router_3way', False):
                        # Single 3-way router: ignore debug/ood gating and route end-to-end
                        fused = self.dfa(img, route=True)
                    elif self.gate_mode == 'debug':
                        # Top weights: force all to task experts if requested, else use configured pair
                        if getattr(self, 'debug_task_only', False):
                            w_top = torch.tensor([0.0, 1.0], device=img.device, dtype=img.dtype).view(1, 2).expand(img.size(0), 2)
                        else:
                            w_top_list = self.debug_seen_weights
                            w_top = torch.tensor(w_top_list, device=img.device, dtype=img.dtype).view(1, 2).expand(img.size(0), 2)
                        if getattr(self.dfa, 'e2_router_mode', 'moe_w') == 'moe_w':
                            w_spec = self.dfa._noisy_topk_gating_e2(img.float(), train=False)
                        else:
                            w_spec = torch.softmax(self.dfa.router_e2(img.float()), dim=-1)
                        fused = self.dfa(img, route=False, weights={ 'top': w_top, 'spec': w_spec })
                    elif self.gate_mode == 'ood_confidence':
                        # Top weights from seen-confidence; specific split learned by E2 router (moe-w or mlp)
                        with torch.no_grad():
                            seen_tokens = self.text_tokens if self.text_tokens is not None else texts_task
                            t_seen = self.model.encode_text(seen_tokens)
                            t_seen = t_seen / t_seen.norm(dim=-1, keepdim=True)
                            conf = torch.softmax(scale * (img @ t_seen.t()), dim=-1).max(dim=1).values  # [B]
                            w_top = torch.stack([1.0 - conf, conf], dim=1)
                        if getattr(self.dfa, 'e2_router_mode', 'moe_w') == 'moe_w':
                            w_spec = self.dfa._noisy_topk_gating_e2(img.float(), train=False)
                        else:
                            w_spec = torch.softmax(self.dfa.router_e2(img.float()), dim=-1)
                        fused = self.dfa(img, route=False, weights={ 'top': w_top, 'spec': w_spec })
                    else:
                        # Both routers trained end-to-end
                        fused = self.dfa(img, route=True)
                    fused = fused / fused.norm(dim=-1, keepdim=True)
                    logits = scale * fused @ txt.t()
                    loss_ce = F.cross_entropy(logits[:, :local_C], targets, label_smoothing=float(getattr(cfg, 'ls', 0.0)))
                    loss = loss_ce
                    # KD term (teacher on previous A stage) not use now
                    if lambda_kd > 0.0 and teacher is not None:
                        with torch.no_grad():
                            t_feat = teacher(img, route=True)
                            t_feat = t_feat / t_feat.norm(dim=-1, keepdim=True)
                            t_logits = scale * t_feat @ txt.t()
                        kd = F.kl_div(
                            F.log_softmax(logits / tau_kd, dim=-1),
                            F.softmax(t_logits / tau_kd, dim=-1),
                            reduction='batchmean'
                        ) * (tau_kd ** 2)
                        loss = loss + lambda_kd * kd
                    # Replay KD: sample small batch from buffer, KD only (no labels)，not use now
                    replay_bs = int(getattr(cfg, 'replay_batch', 0))
                    if replay_bs > 0 and self.replay_buffer:
                        try:
                            import random as _rnd
                            all_items = []
                            for cls_id, lst in self.replay_buffer.items():
                                for t in lst:
                                    all_items.append(t)
                            if all_items:
                                _rnd.shuffle(all_items)
                                rep_imgs = torch.stack(all_items[:min(replay_bs, len(all_items))], dim=0).to(self.device)
                                with torch.no_grad():
                                    rep_img0 = self.model.encode_image(rep_imgs)
                                    rep_img0 = rep_img0 / rep_img0.norm(dim=-1, keepdim=True)
                                    # use all seen text tokens up to now
                                    rep_txt_all = self.model.encode_text(self.text_tokens)
                                    rep_txt_all = rep_txt_all / rep_txt_all.norm(dim=-1, keepdim=True)
                                    rep_scale = self.model.logit_scale.exp()
                                    t_feat2 = teacher(rep_img0, route=True) if teacher is not None else rep_img0
                                    t_feat2 = t_feat2 / t_feat2.norm(dim=-1, keepdim=True)
                                    t_logits2 = rep_scale * t_feat2 @ rep_txt_all.t()
                                s_feat2 = self.dfa(rep_img0, route=True)
                                s_feat2 = s_feat2 / s_feat2.norm(dim=-1, keepdim=True)
                                s_logits2 = rep_scale * s_feat2 @ rep_txt_all.t()
                                kd_rep = F.kl_div(
                                    F.log_softmax(s_logits2 / tau_kd, dim=-1),
                                    F.softmax(t_logits2 / tau_kd, dim=-1),
                                    reduction='batchmean'
                                ) * (tau_kd ** 2)
                                loss = loss + lambda_kd * kd_rep
                        except Exception:
                            pass
                    # L2 reg on E2a/E2b
                    lambda_reg = float(getattr(cfg, 'lambda_reg', 1e-4))
                    if lambda_reg > 0.0:
                        reg = 0.0
                        if getattr(self, 'dfa_inject_mode', 'head') == 'block':
                            for m in self.dfa_blocks:
                                for e2 in getattr(m, 'e2_list', []):
                                    for p in e2.parameters(): reg = reg + p.pow(2).sum()
                        else:
                            for e2 in getattr(self.dfa, 'e2_list', []):
                                for p in e2.parameters(): reg = reg + p.pow(2).sum()
                        loss = loss + lambda_reg * reg
                    loss_con_term = torch.zeros((), device=fused.device, dtype=fused.dtype)
                    # -------- Stage B Contrastive Loss (InfoNCE with queue as negatives) --------
                    lambda_b_con = float(getattr(cfg, 'lambda_b_con', 0.0))
                    if lambda_b_con > 0.0 and local_C > 1 and hasattr(self, 'queue_img') and hasattr(self, 'queue_txt'):
                        tau_b_con = float(getattr(cfg, 'tau_b_con', getattr(cfg, 'tau_con', 0.05)))
                        qcnt_con = int(getattr(self, 'queue_count', 0))
                        B_con = fused.size(0)
                        # txt is [local_C, D], already normalized; fused is [B, D], already normalized
                        # Build candidate texts: current task texts + queue_txt (as negatives)
                        if qcnt_con > 0 and self.queue_size > 0:
                            k_valid_con = min(qcnt_con, int(self.queue_size))
                            queue_txt_neg = self.queue_txt[:, :k_valid_con].t().to(fused.device)  # [K, D]
                            queue_txt_neg = queue_txt_neg / (queue_txt_neg.norm(dim=-1, keepdim=True) + 1e-12)
                            cand_txt_b = torch.cat([txt, queue_txt_neg], dim=0)  # [local_C + K, D]
                            queue_img_neg = self.queue_img[:, :k_valid_con].t().to(fused.device)  # [K, D]
                            queue_img_neg = queue_img_neg / (queue_img_neg.norm(dim=-1, keepdim=True) + 1e-12)
                            cand_img_b = torch.cat([fused, queue_img_neg], dim=0)  # [B + K, D]
                        else:
                            cand_txt_b = txt
                            cand_img_b = fused
                        # Image->Text InfoNCE: positive is the text of sample's class (index = targets[i])
                        N_txt_b = cand_txt_b.size(0)
                        logits_i2t_b = (scale * fused @ cand_txt_b.t()) / max(1e-6, tau_b_con)  # [B, N_txt]
                        # Build positive mask: each image i has positive at index targets[i] (0..local_C-1)
                        pos_mask_i2t_b = torch.zeros(B_con, N_txt_b, dtype=torch.bool, device=fused.device)
                        for i in range(B_con):
                            pos_idx = int(targets[i].item())
                            if 0 <= pos_idx < local_C:
                                pos_mask_i2t_b[i, pos_idx] = True
                        pos_logits_i2t_b = logits_i2t_b.masked_fill(~pos_mask_i2t_b, float('-inf'))
                        numer_i2t_b = torch.logsumexp(pos_logits_i2t_b, dim=1)
                        denom_i2t_b = torch.logsumexp(logits_i2t_b, dim=1)
                        loss_i2t_b = -(numer_i2t_b - denom_i2t_b).mean()
                        # Text->Image InfoNCE: each task text's positives are all images of that class in batch
                        N_img_b = cand_img_b.size(0)
                        logits_t2i_b = (scale * txt @ cand_img_b.t()) / max(1e-6, tau_b_con)  # [local_C, N_img]
                        pos_mask_t2i_b = torch.zeros(local_C, N_img_b, dtype=torch.bool, device=fused.device)
                        for c in range(local_C):
                            match_idx = (targets == c).nonzero(as_tuple=False).squeeze(1)
                            if match_idx.numel() > 0:
                                pos_mask_t2i_b[c, match_idx] = True
                        pos_logits_t2i_b = logits_t2i_b.masked_fill(~pos_mask_t2i_b, float('-inf'))
                        numer_t2i_b = torch.logsumexp(pos_logits_t2i_b, dim=1)
                        denom_t2i_b = torch.logsumexp(logits_t2i_b, dim=1)
                        valid_rows_b = pos_mask_t2i_b.any(dim=1)
                        if valid_rows_b.any():
                            loss_t2i_b = -((numer_t2i_b[valid_rows_b] - denom_t2i_b[valid_rows_b]).mean())
                        else:
                            loss_t2i_b = torch.zeros((), device=fused.device, dtype=fused.dtype)
                        loss_b_con = (loss_i2t_b + loss_t2i_b) * 0.5
                        loss_con_term = lambda_b_con * loss_b_con
                        loss = loss + loss_con_term
                    # Queue-based feature distribution KL (old queue vs current batch), image side
                    lambda_qfd = float(getattr(cfg, 'lambda_queue_fd', 0.0))
                    if lambda_qfd > 0.0 and hasattr(self, 'queue_img') and getattr(self, 'queue_size', 0) > 0:
                        qcnt = int(getattr(self, 'queue_count', 0))
                        if qcnt > 0:
                            try:
                                samp_cap = int(max(1, getattr(cfg, 'queue_fd_sample', min(2048, qcnt))))
                                eps = float(getattr(cfg, 'queue_fd_eps', 1.0e-5))
                                k_valid = min(qcnt, int(self.queue_size))
                                old_all = self.queue_img[:, :k_valid].t().to(fused.device)  # [k_valid, D]
                                if old_all.size(0) > samp_cap:
                                    perm = torch.randperm(old_all.size(0), device=old_all.device)[:samp_cap]
                                    old = old_all.index_select(0, perm)
                                else:
                                    old = old_all
                                new = fused  # [B, D]
                                # Means/vars over batch dimension
                                mu_old = old.mean(dim=0)
                                var_old = old.var(dim=0, unbiased=False)
                                mu_new = new.mean(dim=0)
                                var_new = new.var(dim=0, unbiased=False)
                                var_old = torch.clamp(var_old, min=eps)
                                var_new = torch.clamp(var_new, min=eps)
                                # KL(N_old || N_new) diagonal
                                kl_terms = torch.log(var_new / var_old) + (var_old + (mu_old - mu_new).pow(2)) / var_new - 1.0
                                loss_qfd = 0.5 * kl_terms.sum()
                                loss = loss + lambda_qfd * loss_qfd
                            except Exception:
                                pass
                    # Optional: print Stage B loss components (CE, contrastive, total)
                    if bool(getattr(cfg, 'print_stage_b_losses', True)):
                        interval = int(getattr(cfg, 'stage_b_log_interval', 1))
                        if interval <= 1 or (it % interval == 0):
                            try:
                                print(
                                    f"[Stage B][Task {task_id}][It {it}] "
                                    f"loss_ce={loss_ce.item():.4f} "
                                    f"loss_con={loss_con_term.item():.4f} "
                                    f"loss_total={loss.item():.4f}"
                                )
                            except Exception:
                                pass
                    # Optional: text-side distribution KL vs queue_txt
                    if lambda_qfd > 0.0 and bool(getattr(cfg, 'queue_fd_on_text', False)) and hasattr(self, 'queue_txt') and getattr(self, 'queue_size', 0) > 0:
                        qcnt_t = int(getattr(self, 'queue_count', 0))
                        if qcnt_t > 0:
                            try:
                                samp_cap = int(max(1, getattr(cfg, 'queue_fd_sample', min(2048, qcnt_t))))
                                eps = float(getattr(cfg, 'queue_fd_eps', 1.0e-5))
                                k_valid = min(qcnt_t, int(self.queue_size))
                                old_all_t = self.queue_txt[:, :k_valid].t().to(self.device)  # [k_valid, D]
                                # Normalize to unit for stability
                                old_all_t = old_all_t / (old_all_t.norm(dim=-1, keepdim=True) + 1.0e-12)
                                if old_all_t.size(0) > samp_cap:
                                    perm = torch.randperm(old_all_t.size(0), device=old_all_t.device)[:samp_cap]
                                    old_t = old_all_t.index_select(0, perm)
                                else:
                                    old_t = old_all_t
                                # Current text features over seen tokens (or current task tokens as fallback)
                                seen_tokens = self.text_tokens if getattr(self, 'text_tokens', None) is not None else texts_task
                                new_t = self.model.encode_text(seen_tokens)
                                new_t = new_t / (new_t.norm(dim=-1, keepdim=True) + 1.0e-12)
                                mu_old_t = old_t.mean(dim=0)
                                var_old_t = old_t.var(dim=0, unbiased=False)
                                mu_new_t = new_t.mean(dim=0)
                                var_new_t = new_t.var(dim=0, unbiased=False)
                                var_old_t = torch.clamp(var_old_t, min=eps)
                                var_new_t = torch.clamp(var_new_t, min=eps)
                                kl_terms_t = torch.log(var_new_t / var_old_t) + (var_old_t + (mu_old_t - mu_new_t).pow(2)) / var_new_t - 1.0
                                loss_qfd_t = 0.5 * kl_terms_t.sum()
                                loss = loss + lambda_qfd * loss_qfd_t
                            except Exception:
                                pass
                scaler.scale(loss).backward()
                scaler.step(opt_b)
                scaler.update()

        # After Stage B completes, enqueue buffered MoCo features from Stage A once per task
        pre_q = int(getattr(self, 'queue_count', 0)) if hasattr(self, 'queue_size') else 0
        if (not skip_stage_a) and use_moco and self.queue_size > 0:
            if 'mq_img_buf' in locals() and len(mq_img_buf) > 0:
                try:
                    img_cat = torch.cat(mq_img_buf, dim=0)
                    txt_cat = torch.cat(mq_txt_buf, dim=0)
                    mode = str(getattr(cfg, 'moco_post_enqueue_mode', 'subsample')).lower()
                    if mode not in ('fifo', 'quota', 'subsample'):
                        mode = 'subsample'
                    is_full = bool(getattr(self, 'queue_count', 0) >= getattr(self, 'queue_size', 0))
                    if mode == 'subsample':
                        if img_cat.size(0) > self.queue_size:
                            perm = torch.randperm(img_cat.size(0))[: self.queue_size]
                            img_cat = img_cat[perm]
                            txt_cat = txt_cat[perm]
                    elif mode == 'quota':
                        quota = int(getattr(cfg, 'moco_post_enqueue_quota', max(0, self.queue_size // 4)))
                        quota = max(0, quota)
                        cur = int(getattr(self, 'queue_count', 0))
                        cap = int(getattr(self, 'queue_size', 0))
                        want = int(img_cat.size(0))
                        if cur + want > cap:
                            n_add = min(quota, want)
                        else:
                            n_add = want
                        img_cat = img_cat[:n_add]
                        txt_cat = txt_cat[:n_add]
                    # From task 1 onward, cap enqueue size to slow queue refresh
                    if int(task_id) >= 1:
                        task_quota = int(getattr(cfg, 'moco_task_enqueue_quota', 32))
                        task_quota = max(0, min(task_quota, int(self.queue_size)))
                        if task_quota > 0:
                            img_cat = img_cat[:task_quota]
                            txt_cat = txt_cat[:task_quota]
                    # fifo: keep order
                    if img_cat.numel() > 0:
                        img_cat = img_cat.to(self.device, non_blocking=True)
                        txt_cat = txt_cat.to(self.device, non_blocking=True)
                        # chunk must not exceed queue_size to avoid indexing overflow
                        chunk = min(1024, self.queue_size)
                        for s in range(0, img_cat.size(0), chunk):
                            self._dequeue_and_enqueue(img_cat[s:s+chunk], txt_cat[s:s+chunk])
                except Exception:
                    pass
        post_q = int(getattr(self, 'queue_count', 0)) if hasattr(self, 'queue_size') else pre_q
        try:
            print(f"[MoCo][Task {task_id}] Queue before: {pre_q}/{self.queue_size} | after: {post_q}/{self.queue_size} | added: {max(0, post_q - pre_q)}")
        except Exception:
            pass

    def compute_logits(self, inputs, text_tokens, is_zeroshot: bool = False):
        """Return image-text similarity logits over the provided text tokens.

        Encodes ``inputs`` and ``text_tokens`` with ``self.model`` in eval mode under
        ``torch.no_grad()`` and returns ``logit_scale.exp() * fused @ txt.T`` where both
        sides are L2-normalized.

        Image-side fusion depends on ``dfa_inject_mode`` / ``gate_mode``:
          * ``'block'``: the per-block DFA adapters run inside ``encode_image``; in debug
            gate mode a fixed top-router pair is installed on every block (image and, if
            present, text blocks) for the duration of the call.
          * otherwise (head mode): ``self.dfa`` fuses the image features, either with its
            own routers, with a debug override pair, or with OOD-confidence weights.

        Args:
            inputs: image batch forwarded to ``self.model.encode_image``.
            text_tokens: tokenized prompts forwarded to ``self.model.encode_text``.
            is_zeroshot: in debug gate mode, pick the zero-shot override (row of
                ``debug_zs_weights_matrix`` / ``debug_zs_weights``) instead of
                ``debug_seen_weights``.

        Returns:
            Logits tensor with one row per image and one column per text-token row.

        Training flags of ``self.model``, ``self.dfa`` and every DFA block, as well as
        any ``debug_w_top`` overrides installed here, are restored in ``finally`` so an
        exception cannot leave the model in eval mode or with debug routing applied.
        """
        dfa = getattr(self, 'dfa', None)
        visual_blocks = list(getattr(self, 'dfa_blocks', None) or [])
        text_blocks = list(getattr(self, 'text_dfa_blocks', None) or [])
        block_mode = getattr(self, 'dfa_inject_mode', 'head') == 'block'
        gate_mode = getattr(self, 'gate_mode', None)
        debug_mode = gate_mode == 'debug'

        def _debug_top_pair():
            # [E1, E2] top-router override used in debug gate mode.
            if getattr(self, 'debug_task_only', False):
                return [0.0, 1.0]
            zs_matrix = getattr(self, 'debug_zs_weights_matrix', None)
            if is_zeroshot and zs_matrix is not None:
                idx = int(getattr(self, 'zs_dataset_index_for_debug', -1))
                if 0 <= idx < zs_matrix.size(0):
                    row = zs_matrix[idx]
                    return [float(row[0]), float(row[1])]
                return getattr(self, 'debug_zs_weights', [0.7, 0.3])
            return getattr(self, 'debug_seen_weights', [0.7, 0.3])

        prev_model = self.model.training
        prev_dfa = dfa.training if dfa is not None else False
        prev_visual = [m.training for m in visual_blocks]
        prev_text = [m.training for m in text_blocks]
        self.model.eval()
        if dfa is not None:
            dfa.eval()
        for m in visual_blocks + text_blocks:
            m.eval()

        overridden = []  # blocks whose debug_w_top was set here and must be cleared
        try:
            with torch.no_grad():
                if block_mode:
                    # Adapters/routers live inside encode_image, so one forward pass is
                    # enough; debug overrides must be installed before it runs.
                    if debug_mode:
                        w_top_vec = torch.tensor(_debug_top_pair(), device=inputs.device, dtype=torch.float32)
                        for m in visual_blocks:
                            m.debug_w_top = w_top_vec
                            overridden.append(m)
                    img = self.model.encode_image(inputs)
                    img = img / img.norm(dim=-1, keepdim=True)
                    fused = img
                else:
                    img = self.model.encode_image(inputs)
                    img = img / img.norm(dim=-1, keepdim=True)
                    if dfa is not None and getattr(dfa, 'single_router_3way', False):
                        fused = dfa(img, route=True)
                    elif debug_mode or gate_mode == 'ood_confidence':
                        batch_size = img.size(0)
                        if debug_mode:
                            w_top = torch.tensor(_debug_top_pair(), device=img.device, dtype=img.dtype)
                            w_top = w_top.view(1, 2).expand(batch_size, 2)
                        else:
                            # Confidence on seen classes drives the E2 (task) share.
                            seen_tokens = self.text_tokens if self.text_tokens is not None else text_tokens
                            t_seen = self.model.encode_text(seen_tokens)
                            t_seen = t_seen / t_seen.norm(dim=-1, keepdim=True)
                            seen_scale = self.model.logit_scale.exp()
                            conf = torch.softmax(seen_scale * (img @ t_seen.t()), dim=-1).max(dim=1).values
                            w_top = torch.stack([1.0 - conf, conf], dim=1)
                        if getattr(self.dfa, 'e2_router_mode', 'moe_w') == 'moe_w':
                            w_spec = self.dfa._noisy_topk_gating_e2(img.float(), train=False)
                        else:
                            w_spec = torch.softmax(self.dfa.router_e2(img.float()), dim=-1)
                        fused = self.dfa(img, route=False, weights={'top': w_top, 'spec': w_spec})
                    else:
                        fused = self.dfa(img, route=True)
                fused = fused / fused.norm(dim=-1, keepdim=True)

                # Text-side debug override (block mode only).
                if block_mode and debug_mode and text_blocks:
                    w_top_txt = torch.tensor(_debug_top_pair(), device=img.device, dtype=torch.float32)
                    for m in text_blocks:
                        m.debug_w_top = w_top_txt
                        overridden.append(m)
                txt = self.model.encode_text(text_tokens)
                txt = txt / txt.norm(dim=-1, keepdim=True)
                scale = self.model.logit_scale.exp()
                logits = scale * fused @ txt.t()
        finally:
            for m in overridden:
                m.debug_w_top = None
            # nn.Module.train() is recursive: restore the model first, then the finer-grained
            # modules (DFA blocks may be children of the model), otherwise their flags are lost.
            self.model.train(prev_model)
            if dfa is not None:
                dfa.train(prev_dfa)
            for m, was_training in zip(visual_blocks, prev_visual):
                m.train(was_training)
            for m, was_training in zip(text_blocks, prev_text):
                m.train(was_training)
        return logits

    @torch.no_grad()
    def get_router_weights(self, inputs: torch.Tensor, is_zeroshot: bool = False) -> torch.Tensor:
        """Return top-router soft weights ``[B, 2]`` (E1 vs E2) for a batch of images.

        Args:
            inputs: image batch forwarded to ``self.model.encode_image``.
            is_zeroshot: only used in debug gate mode. When True and
                ``debug_zs_weights_matrix`` is set, the row at
                ``zs_dataset_index_for_debug`` is returned (``debug_zs_weights`` if the
                index is out of range). Without a matrix, ``debug_seen_weights`` is used,
                mirroring ``compute_logits``.

        Returns:
            ``[B, 2]`` weights: debug overrides (``[0, 1]`` when ``debug_task_only``),
            ``[1 - conf, conf]`` from seen-class confidence in ``'ood_confidence'`` mode
            (0.5/0.5 when ``self.text_tokens`` is None), otherwise
            ``softmax(self.dfa.router(img))``.

        ``self.dfa`` is switched to eval for the call and its training flag is restored
        afterwards.
        """
        dfa = getattr(self, 'dfa', None)
        dfa_was_training = dfa.training if dfa is not None else False
        if dfa is not None:
            dfa.eval()
        try:
            img = self.model.encode_image(inputs)
            img = img / img.norm(dim=-1, keepdim=True)
            b = img.size(0)
            if self.gate_mode == 'debug':
                # Same selection as compute_logits, so the reported weights match the applied ones.
                if getattr(self, 'debug_task_only', False):
                    w_list = [0.0, 1.0]
                else:
                    zs_matrix = getattr(self, 'debug_zs_weights_matrix', None)
                    if is_zeroshot and zs_matrix is not None:
                        idx = int(getattr(self, 'zs_dataset_index_for_debug', -1))
                        if 0 <= idx < zs_matrix.size(0):
                            row = zs_matrix[idx].to(device=img.device, dtype=img.dtype)
                            return row.view(1, 2).expand(b, 2)
                        w_list = getattr(self, 'debug_zs_weights', [0.7, 0.3])
                    else:
                        w_list = getattr(self, 'debug_seen_weights', [0.7, 0.3])
                return torch.tensor(w_list, device=img.device, dtype=img.dtype).view(1, 2).expand(b, 2)
            if self.gate_mode == 'ood_confidence':
                seen_tokens = self.text_tokens
                if seen_tokens is None:
                    # Fallback to equal weights when seen tokens are unavailable.
                    return torch.full((b, 2), 0.5, device=img.device, dtype=img.dtype)
                t_seen = self.model.encode_text(seen_tokens)
                t_seen = t_seen / t_seen.norm(dim=-1, keepdim=True)
                scale = self.model.logit_scale.exp()
                conf = torch.softmax(scale * (img @ t_seen.t()), dim=-1).max(dim=1).values
                return torch.stack([1.0 - conf, conf], dim=1)
            return torch.softmax(self.dfa.router(img.float()), dim=-1)
        finally:
            if dfa is not None:
                dfa.train(dfa_was_training)

    @torch.no_grad()
    def get_block_router_weights(self, inputs: torch.Tensor, is_zeroshot: bool = False) -> dict:
        """Return per-block router weights recorded during one image forward (block mode).

        Args:
            inputs: image batch forwarded to ``self.model.encode_image``.
            is_zeroshot: in debug gate mode, use the zero-shot override (row of
                ``debug_zs_weights_matrix`` or ``debug_zs_weights``); otherwise
                ``[0, 1]`` if ``debug_task_only`` else ``debug_seen_weights``.

        Returns:
            ``{}`` unless ``dfa_inject_mode == 'block'`` and DFA blocks exist; otherwise a
            dict with ``'w_top'`` and ``'w_spec'``, each a list (one entry per block) of
            the block's ``_last_w_top`` / ``_last_w_spec`` moved to CPU, or None when the
            block did not record it.

        Blocks are put in eval mode and given the debug override only for the duration
        of the forward; their training flags are restored and ``debug_w_top`` is cleared
        afterwards so later forwards use the learned routers.
        """
        if self.dfa_inject_mode != 'block' or not self.dfa_blocks:
            return {}

        blocks = list(self.dfa_blocks)
        was_training = [blk.training for blk in blocks]

        w_top_vec = None  # None => blocks use their learned routers
        if self.gate_mode == 'debug':
            if is_zeroshot:
                zs_matrix = getattr(self, 'debug_zs_weights_matrix', None)
                w_pair = getattr(self, 'debug_zs_weights', [0.7, 0.3])
                if zs_matrix is not None:
                    idx = int(getattr(self, 'zs_dataset_index_for_debug', -1))
                    if 0 <= idx < zs_matrix.size(0):
                        w_pair = zs_matrix[idx].tolist()
            elif getattr(self, 'debug_task_only', False):
                w_pair = [0.0, 1.0]
            else:
                # NOTE(review): compute_logits falls back to [0.7, 0.3] here; confirm intended default.
                w_pair = getattr(self, 'debug_seen_weights', [0.3, 0.7])
            w_top_vec = torch.tensor(w_pair, device=inputs.device, dtype=torch.float32)

        def _cpu_or_none(t):
            return None if t is None else t.cpu()

        try:
            for blk in blocks:
                blk.eval()
                blk.debug_w_top = w_top_vec
            # Forward pass populates _last_w_top / _last_w_spec in each block.
            self.model.encode_image(inputs)
            w_top_list = [_cpu_or_none(getattr(blk, '_last_w_top', None)) for blk in blocks]
            w_spec_list = [_cpu_or_none(getattr(blk, '_last_w_spec', None)) for blk in blocks]
        finally:
            for blk, mode in zip(blocks, was_training):
                blk.debug_w_top = None
                blk.train(mode)

        return {
            'w_top': w_top_list,    # per block: [B, 2] (E1 vs E2) or None
            'w_spec': w_spec_list,  # per block: [B, E] (E2 experts) or None
        }

    @torch.no_grad()
    def _dequeue_and_enqueue(self, img_feats: torch.Tensor, txt_feats: torch.Tensor):
        """Write features into the circular momentum queue (FIFO), overwriting the oldest.

        Rows of ``img_feats`` / ``txt_feats`` are stored as columns of ``self.queue_img`` /
        ``self.queue_txt`` starting at ``self.queue_ptr[0]``, wrapping around at
        ``self.queue_size``. ``self.queue_count`` tracks how many slots hold valid
        entries (capped at ``queue_size``).

        A batch larger than the queue keeps only its newest ``queue_size`` rows; the
        pointer is advanced as if the batch had been enqueued in queue-sized chunks.
        """
        queue_size = int(self.queue_size)
        batch_size = int(img_feats.size(0))
        ptr = int(self.queue_ptr)
        n = batch_size
        if n > queue_size:
            # Older rows would be overwritten within this same call; skip them and
            # advance the pointer accordingly so the final state matches chunked writes.
            dropped = n - queue_size
            ptr = (ptr + dropped) % queue_size
            img_feats = img_feats[dropped:]
            txt_feats = txt_feats[dropped:]
            n = queue_size
        end = ptr + n
        if end <= queue_size:
            self.queue_img[:, ptr:end] = img_feats.t()
            self.queue_txt[:, ptr:end] = txt_feats.t()
            ptr = end % queue_size
        else:
            # Wrap around: fill the tail, then the head.
            remain = queue_size - ptr
            overflow = n - remain
            self.queue_img[:, ptr:] = img_feats[:remain].t()
            self.queue_txt[:, ptr:] = txt_feats[:remain].t()
            self.queue_img[:, :overflow] = img_feats[remain:].t()
            self.queue_txt[:, :overflow] = txt_feats[remain:].t()
            ptr = overflow
        self.queue_ptr[0] = ptr
        # Track how many valid entries the queue currently holds.
        self.queue_count = min(queue_size, int(getattr(self, 'queue_count', 0)) + batch_size)

    def _inject_dfa_into_visual_blocks(self, cfg):
        """Wrap every visual transformer resblock with a DFA3Block residual (moe-w style).

        Adapter hyper-parameters come from ``cfg`` (``dfa_r1``, ``dfa_r2``,
        ``adapter_dropout``, ``adapter_scalar``, ``e2_top_k``, ``dfa_router_hidden``,
        ``block_single_router_3way``, ``num_task_experts``). The created adapters are
        collected in ``self.dfa_blocks`` and ``visual.transformer.resblocks`` is replaced
        by an ``nn.Sequential`` of ``ResidualAttentionBlockWithDFA`` wrappers. Nothing is
        changed when the model has no ``visual.transformer.resblocks``.
        """
        visual = getattr(self.model, 'visual', None)
        transformer = getattr(visual, 'transformer', None)
        resblocks = getattr(transformer, 'resblocks', None)
        if any(part is None for part in (visual, transformer, resblocks)):
            return

        self.dfa_blocks = []
        # Every block receives the same adapter configuration.
        adapter_kwargs = dict(
            r1=int(getattr(cfg, 'dfa_r1', 16)),
            r2=int(getattr(cfg, 'dfa_r2', 64)),
            adapter_dropout=float(getattr(cfg, 'adapter_dropout', 0.1)),
            adapter_scalar=float(getattr(cfg, 'adapter_scalar', 0.1)),
            e2_top_k=int(getattr(cfg, 'e2_top_k', 2)),
            router_hidden=int(getattr(cfg, 'dfa_router_hidden', 0)),
            single_router_3way=bool(getattr(cfg, 'block_single_router_3way', False)),
            num_task_experts=int(getattr(cfg, 'num_task_experts', 2)),
        )

        def _width_of(block):
            # Width is taken from the block's ln_1; fall back to visual.width (or 768).
            try:
                return int(getattr(block, 'ln_1').weight.shape[0])
            except Exception:
                return int(getattr(visual, 'width', 768))

        wrapped = []
        for block in resblocks:
            adapter = DFA3Block(_width_of(block), **adapter_kwargs).to(self.device)
            wrapped.append(ResidualAttentionBlockWithDFA(block, adapter))
            self.dfa_blocks.append(adapter)
        # encode_image calls transformer.resblocks(x), so keep a Sequential container.
        transformer.resblocks = nn.Sequential(*wrapped)

    def _inject_dfa_into_text_blocks(self, cfg):
        """Optionally wrap text transformer resblocks with DFA3Block residuals (A1).

        With ``cfg.text_dfa_full`` every block is wrapped; otherwise only the last
        ``cfg.text_dfa_last_k`` blocks are. When neither applies (not full and K <= 0) the
        resblocks are rebuilt as an unchanged ``nn.Sequential`` and no adapter is created.
        Text-specific ``text_*`` config keys override the shared visual ones. Created
        adapters are collected in ``self.text_dfa_blocks``. Returns early, leaving the
        model untouched, when there is no ``model.transformer.resblocks``.
        """
        last_k = int(getattr(cfg, 'text_dfa_last_k', 0))
        wrap_all = bool(getattr(cfg, 'text_dfa_full', False))
        text_tf = getattr(self.model, 'transformer', None)
        resblocks = getattr(text_tf, 'resblocks', None)
        if text_tf is None or resblocks is None:
            return

        # Text-side settings fall back to the shared (visual) settings when absent.
        adapter_kwargs = dict(
            r1=int(getattr(cfg, 'text_dfa_r1', getattr(cfg, 'dfa_r1', 16))),
            r2=int(getattr(cfg, 'text_dfa_r2', 8)),
            adapter_dropout=float(getattr(cfg, 'text_adapter_dropout', getattr(cfg, 'adapter_dropout', 0.1))),
            adapter_scalar=float(getattr(cfg, 'text_adapter_scalar', getattr(cfg, 'adapter_scalar', 0.1))),
            e2_top_k=int(getattr(cfg, 'text_e2_top_k', getattr(cfg, 'e2_top_k', 2))),
            router_hidden=int(getattr(cfg, 'text_dfa_router_hidden', getattr(cfg, 'dfa_router_hidden', 0))),
        )

        self.text_dfa_blocks = []
        originals = list(resblocks)
        if not wrap_all and last_k <= 0:
            # Nothing to inject.
            text_tf.resblocks = nn.Sequential(*originals)
            return
        first_wrapped = 0 if wrap_all else max(0, len(originals) - last_k)

        adapter_kwargs['single_router_3way'] = bool(
            getattr(cfg, 'text_block_single_router_3way', getattr(cfg, 'block_single_router_3way', False))
        )
        adapter_kwargs['num_task_experts'] = int(getattr(cfg, 'num_task_experts', 2))

        def _width_of(block):
            # Width is taken from the block's ln_1; fall back to transformer.width (or 512).
            try:
                return int(getattr(block, 'ln_1').weight.shape[0])
            except Exception:
                return int(getattr(text_tf, 'width', 512))

        rebuilt = originals[:first_wrapped]
        for block in originals[first_wrapped:]:
            adapter = DFA3Block(_width_of(block), **adapter_kwargs).to(self.device)
            rebuilt.append(ResidualAttentionBlockWithDFA(block, adapter))
            self.text_dfa_blocks.append(adapter)
        text_tf.resblocks = nn.Sequential(*rebuilt)

class DomainIncremental(nn.Module):
    """Placeholder for the domain-incremental scenario; no behavior is implemented yet.

    NOTE(review): no ``__init__`` is defined, so ``DomainIncremental(cfg, device)``
    goes straight to ``nn.Module.__init__``, which presumably rejects the extra
    positional arguments — confirm before relying on ``load_model(scenario='domain')``.
    """


class TaskAgnostic(nn.Module):
    """Placeholder for the task-agnostic scenario; no behavior is implemented yet.

    NOTE(review): no ``__init__`` is defined, so ``TaskAgnostic(cfg, device)`` goes
    straight to ``nn.Module.__init__``, which presumably rejects the extra positional
    arguments — confirm before relying on the task-agnostic scenario in ``load_model``.
    """


def load_model(cfg: DictConfig, device: torch.device) -> nn.Module:
    r"""Load a CLIP model in different continual scenarios.

    Arguments:
        cfg (DictConfig): Experiment configurations; ``cfg.scenario`` selects the
            scenario ('class', 'domain' or 'task-agnostic').
        device (torch.device): Device to train (or) evaluate the model on.

    Returns:
        nn.Module: Return scenario specific CLIP model.

    Raises:
        ValueError: if ``cfg.scenario`` is not a supported scenario.
    """
    scenario = cfg.scenario
    if scenario == "class":
        return ClassIncremental(cfg, device)
    if scenario == "domain":
        return DomainIncremental(cfg, device)
    # "task-aganostic" is the historical misspelled key; keep accepting it.
    if scenario in ("task-agnostic", "task-aganostic"):
        return TaskAgnostic(cfg, device)
    raise ValueError(
        f"`{scenario}` is not a valid scenario. "
        "Please choose from ['class', 'domain', 'task-agnostic']"
    )
