import os
# Default the CUDA caching-allocator configuration; `setdefault` leaves any value
# already present in the environment untouched.
# NOTE(review): presumably kept above `import torch` so the setting is visible when
# CUDA is initialised -- confirm before reordering imports.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Default the visible GPU to device "5" unless CUDA_VISIBLE_DEVICES is already set.
# The walrus additionally binds the module-level name `cuda_visible_devices` to the
# string "CUDA_VISIBLE_DEVICES".
os.environ.setdefault(cuda_visible_devices := "CUDA_VISIBLE_DEVICES", "5")
import math
import json
import hydra
import logging
try:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
except Exception:
    plt = None
from omegaconf import DictConfig, ListConfig

from tqdm import tqdm

import torch
import numpy as np
import statistics
from torch.utils.data import DataLoader, ConcatDataset
from continuum.metrics import Logger

import clip.clip as clip
from mtil_datasets import get_dataset as get_mtil_dataset
from continual_clip.clip_original import clip as clip_orig
from MTIL_datasets.voc2007 import VOC2007 as MTILVOC2007
from PIL import Image

from continual_clip import utils
from continual_clip.models import load_model
from continual_clip.datasets import build_cl_scenarios, get_dataset

# MTIL index -> canonical dataset name.
# A dataset's MTIL index is its position in the tuple below (0..24); the trailing
# comments spell the index out so entries can be located at a glance.
MTIL_INDEX_TO_NAME = dict(enumerate((
    'FGVCAircraft',         # 0
    'Caltech101',           # 1
    'CIFAR100',             # 2
    'DescribableTextures',  # 3
    'EuroSAT',              # 4
    'OxfordFlowers',        # 5
    'Food101',              # 6
    'MNIST',                # 7
    'OxfordPets',           # 8
    'StanfordCars',         # 9
    'SUN397',               # 10
    'Country211',           # 11
    'SST2',                 # 12
    'HatefulMemes',         # 13
    'GTSRB',                # 14
    'RESISC45',             # 15
    'FER2013',              # 16
    'UCF101',               # 17
    'CIFAR10',              # 18
    'STL10',                # 19
    'VOC2007',              # 20
    'ImageNetR',            # 21
    'KittiDistance',        # 22
    'PCam',                 # 23
    'CLEVRCount',           # 24
)))

# MTIL index -> dataset key understood by our get_dataset(), for all 25 MTIL
# datasets. As above, the index of a key is its position in the tuple (0..24).
TRAIN_INDEX_TO_DATASET_KEY = dict(enumerate((
    'aircraft',        # 0
    'caltech101',      # 1
    'cifar100',        # 2
    'dtd',             # 3
    'eurosat',         # 4
    'oxford_flowers',  # 5
    'food101',         # 6
    'mnist',           # 7
    'oxford_pets',     # 8
    'stanford_cars',   # 9
    'sun397',          # 10
    'country211',      # 11
    'sst2',            # 12
    'hatefulmemes',    # 13
    'gtsrb',           # 14
    'resisc45',        # 15
    'fer2013',         # 16
    'ucf101',          # 17
    'cifar10',         # 18
    'stl10',           # 19
    'voc2007',         # 20
    'imagenet_r',      # 21
    'kitti_distance',  # 22
    'pcam',            # 23
    'clevr_count',     # 24
)))

def _zs_safe_len(obj):
    """Return ``len(obj)`` as an int, or 0 when the object exposes no usable length."""
    try:
        return int(len(obj))
    except (TypeError, NotImplementedError):
        # Only used for statistics/logging, so a missing length must not abort evaluation.
        return 0


def _zs_render_prompt(tmpl, cfg, class_name):
    """Build the text prompt for one class.

    ``tmpl`` may be a callable (called with the class name), a format string
    (``tmpl.format(class_name)``) or anything else, in which case
    ``cfg.prompt_template`` is used. Any failure falls back to the generic
    ``"a photo of a {class_name}."`` prompt.
    """
    try:
        if callable(tmpl):
            return tmpl(class_name)
        if isinstance(tmpl, str):
            return tmpl.format(class_name)
        return cfg.prompt_template.format(class_name)
    except Exception:
        # Deliberately broad: templates are user/dataset supplied and prompt
        # rendering is best-effort.
        return f"a photo of a {class_name}."


def _zs_label_to_array(label):
    """Return ``label`` as a NumPy array if it is tensor/sequence-like, else ``None``."""
    if isinstance(label, torch.Tensor):
        return label.detach().cpu().numpy()
    if isinstance(label, (list, tuple, np.ndarray)):
        return np.asarray(label)
    return None


def _zs_multihot_target(label, num_classes):
    """Coerce one label into a ``[num_classes]`` multi-hot long tensor.

    Vector-like labels of the right length are used as-is. A scalar label is
    one-hot encoded, and a vector of the wrong length is replaced by a one-hot
    at its argmax. If that index cannot be used, the target stays all-zero.
    """
    arr = _zs_label_to_array(label)
    if arr is None:
        # A scalar class id sneaked in: one-hot it.
        arr = np.zeros(num_classes, dtype=np.int64)
        try:
            arr[int(label)] = 1
        except (IndexError, TypeError, ValueError):
            pass
    arr = np.asarray(arr).astype(np.int64).reshape(-1)
    if len(arr) != num_classes:
        coerced = np.zeros(num_classes, dtype=np.int64)
        try:
            coerced[int(np.argmax(arr))] = 1
        except (IndexError, TypeError, ValueError):
            pass
        arr = coerced
    return torch.tensor(arr, dtype=torch.long)


def _zs_class_index(label, num_classes):
    """Coerce one label into a scalar class id.

    Plain scalars and 0-d arrays are converted with ``int``; vector-like
    labels are reduced with ``argmax``. A single-element array on a dataset
    with more than one class is read as a class id: ``argmax`` would always
    return 0 for it and silently discard the label.
    """
    arr = _zs_label_to_array(label)
    if arr is None:
        return int(label)
    if arr.ndim == 0:
        return int(arr.item())
    if arr.size == 1 and num_classes > 1:
        return int(arr.reshape(-1)[0])
    return int(arr.argmax())


def _zs_collate_batch(batch, num_classes, multilabel):
    """Collate ``(image, label)`` samples into ``(x_batch, y_batch)``.

    ``x_batch`` is the images stacked along a new first dimension. ``y_batch``
    is a ``[B, num_classes]`` multi-hot long tensor when ``multilabel`` is
    true, otherwise a ``[B]`` long tensor of class ids.

    Kept at module level (and bound with ``functools.partial`` by the caller)
    rather than as a nested closure, so the collate callable does not capture
    loop variables and stays picklable for DataLoader worker processes.
    """
    images = []
    targets = []
    for image, label in batch:
        images.append(image)
        if multilabel:
            targets.append(_zs_multihot_target(label, num_classes))
        else:
            targets.append(_zs_class_index(label, num_classes))
    x_batch = torch.stack(images, dim=0)
    if multilabel:
        y_batch = torch.stack(targets, dim=0)  # [B, C] multi-hot
    else:
        y_batch = torch.tensor(targets, dtype=torch.long)
    return x_batch, y_batch


def _zs_ap11(y_true_cls: np.ndarray, y_score_cls: np.ndarray) -> float:
    """Return the 11-point interpolated average precision for one class.

    ``y_true_cls`` holds 0/1 ground truth and ``y_score_cls`` the matching
    scores. Precision is interpolated at recall thresholds 0.0, 0.1, ..., 1.0.
    """
    order = np.argsort(-y_score_cls)  # highest score first
    y_true_sorted = y_true_cls[order]
    tp_cum = np.cumsum((y_true_sorted == 1).astype(np.float32))
    fp_cum = np.cumsum((y_true_sorted == 0).astype(np.float32))
    prec = tp_cum / np.maximum(tp_cum + fp_cum, 1e-12)  # avoid division by zero
    # Clamp to 1 so a class without positives yields recall 0 instead of NaN.
    total_pos = max(1.0, float((y_true_cls == 1).sum()))
    rec = tp_cum / total_pos
    ap = 0.0
    for r in np.linspace(0.0, 1.0, 11):
        mask = rec >= r
        ap += np.max(prec[mask]) if np.any(mask) else 0.0
    return ap / 11.0


def _zs_voc_map_percent(y_true_batches, y_score_batches):
    """Return the mean 11-point AP over classes, in percent rounded to 2 decimals.

    Both arguments are lists of per-batch CPU tensors; 0.0 is returned when
    either list is empty.
    """
    if not (y_true_batches and y_score_batches):
        return 0.0
    y_true_all = torch.cat(y_true_batches, dim=0).numpy()
    y_score_all = torch.cat(y_score_batches, dim=0).numpy()
    aps = [
        _zs_ap11(y_true_all[:, ci].astype(np.int64), y_score_all[:, ci].astype(np.float32))
        for ci in range(y_true_all.shape[1])
    ]
    mean_ap = float(np.mean(aps)) if aps else 0.0
    return round(100.0 * mean_ap, 2)


def _zs_select_datasets(cfg, datasets_info, limit_datasets):
    """Pick the ``(dataset, classnames, templates, name)`` entries to evaluate.

    Filters, in order:
      * datasets used for training (``cfg.train_dataset``, or the legacy
        ``cfg.train_one_dataset``) are always skipped; a training index that
        is not in ``MTIL_INDEX_TO_NAME`` is treated as ``'StanfordCars'``;
      * ``cfg.zero_shot_datasets`` (names; list-like or comma-separated
        string), when set, acts as an allowlist;
      * ``cfg.zs_mtil_indices`` (MTIL indices), when set, acts as an
        allowlist; otherwise every non-training dataset is allowed.
    """
    name_filter = getattr(cfg, 'zero_shot_datasets', None)
    if isinstance(name_filter, str):
        name_filter = [s.strip() for s in name_filter.split(',') if s.strip()]

    zs_indices = _parse_int_list(getattr(cfg, 'zs_mtil_indices', []))
    allowed = {MTIL_INDEX_TO_NAME[i] for i in zs_indices if i in MTIL_INDEX_TO_NAME}

    train_indices = _parse_int_list(getattr(cfg, 'train_dataset', []))
    if not train_indices:
        # Backward compatibility: single-dataset option.
        legacy_idx = int(getattr(cfg, 'train_one_dataset', -1))
        if legacy_idx >= 0:
            train_indices = [legacy_idx]
    skip_names = {MTIL_INDEX_TO_NAME.get(i, 'StanfordCars') for i in train_indices}

    if not allowed:
        # No explicit index allowlist: evaluate every remaining dataset.
        allowed = {name for (_, _, _, name) in datasets_info} - skip_names

    selected = []
    for entry in datasets_info:
        name = entry[3]
        if name in skip_names:
            continue
        if name_filter and name not in name_filter:
            continue
        if allowed and name not in allowed:
            continue
        selected.append(entry)

    # NOTE(review): `allowed` is only empty here when every loaded dataset is a
    # training dataset (so `selected` is empty too); the cap below therefore never
    # truncates anything. Kept as-is to preserve behaviour -- confirm whether the
    # limit was meant to apply when no explicit allowlist is configured.
    if (not allowed) and isinstance(limit_datasets, int) and limit_datasets > 0:
        selected = selected[:limit_datasets]
    return selected


def evaluate_zero_shot(model, device, cfg, limit_datasets=None, use_original_clip=False):
    """Evaluate zero-shot performance on the MTIL datasets that are not used for training.

    Args:
        model: Adapted model. Optional attributes read here: ``transforms``
            (test-time preprocessing), ``gate_mode`` (when ``'debug'``, the
            current dataset's MTIL index is stored on
            ``model.zs_dataset_index_for_debug``) and ``compute_logits``. When
            ``compute_logits`` is not callable, ``model.model(x, tokens, 0,
            is_train=False)`` is used and must return ``(logits, _)``.
        device: ``torch.device`` on which inference runs.
        cfg: Configuration object. Reads ``dataset_root`` (required) and,
            optionally, ``seed``, ``use_validation``, ``MTIL_order_2``,
            ``zero_shot_datasets``, ``zs_mtil_indices``, ``train_dataset``,
            ``train_one_dataset``, ``max_zs_samples`` (-1 = all),
            ``zs_batch_size``, ``num_workers``, ``eval_use_amp`` and
            ``prompt_template``; ``model_name`` is read when
            ``use_original_clip`` is true.
        limit_datasets: Intended cap on the number of evaluated datasets. With
            the current filtering rules it never truncates the selection (see
            ``_zs_select_datasets``); a positive int does suppress the
            aggregate ``[ZS][Summary]`` line.
        use_original_clip: Evaluate a freshly loaded original CLIP model
            instead of ``model``.

    Returns:
        ``{dataset_name: score}`` in percent, rounded to two decimals. The
        score is the 11-point mAP for ``'VOC2007'`` and top-1 accuracy for
        every other dataset. An empty dict is returned when the original CLIP
        model or the MTIL datasets cannot be loaded. Datasets whose prompts
        fail to tokenize are omitted.
    """
    from functools import partial  # function-scope import: only needed for the collate binding

    # Minimal attribute bag compatible with mtil_datasets.get_dataset.
    class _ZSCfg:
        pass
    zs_cfg = _ZSCfg()
    zs_cfg.dataset = 'MTIL'
    zs_cfg.dataset_root = cfg.dataset_root
    zs_cfg.seed = getattr(cfg, 'seed', 1)
    zs_cfg.use_validation = getattr(cfg, 'use_validation', False)
    zs_cfg.MTIL_order_2 = getattr(cfg, 'MTIL_order_2', False)
    # Always load all MTIL datasets; training datasets are skipped by name below.
    zs_cfg.train_one_dataset = -1

    # Choose backbone / tokenizer / transforms.
    orig_model = None
    tokenizer = clip.tokenize
    zs_transforms = getattr(model, 'transforms', None)
    if use_original_clip:
        try:
            orig_model, _, zs_transforms = clip_orig.load(cfg.model_name, device=device, jit=False)
            orig_model.eval()
            tokenizer = clip_orig.tokenize
        except Exception as e:
            logging.error(f"Failed to load original CLIP for pre-task ZS: {e}")
            return {}

    try:
        zs_datasets, zs_classnames, zs_templates, zs_names = get_mtil_dataset(
            zs_cfg, split='test', transforms=zs_transforms
        )
    except Exception as e:
        logging.error(f"Zero-shot dataset loading failed: {e}")
        return {}

    max_zs_samples = int(getattr(cfg, 'max_zs_samples', -1))  # -1 means all
    zs_bs = int(getattr(cfg, 'zs_batch_size', 32))
    num_workers = int(getattr(cfg, 'num_workers', 4))
    pin_memory = device.type == 'cuda'
    amp_zs = bool(getattr(cfg, 'eval_use_amp', True) and device.type == 'cuda')

    datasets_info = list(zip(zs_datasets, zs_classnames, zs_templates, zs_names))
    filtered = _zs_select_datasets(cfg, datasets_info, limit_datasets)

    # Dataset name -> MTIL index, used for the per-dataset debug gate.
    name_to_idx = {v: k for k, v in MTIL_INDEX_TO_NAME.items()}
    # Aggregate statistics across all evaluated datasets.
    zs_total_samples_all = 0
    zs_total_classes_all = 0
    zs_total_eval_samples_all = 0

    results = {}
    for ds, classnames, templates, name in filtered:
        is_voc = name == 'VOC2007'  # multi-label dataset scored with mAP
        num_images = _zs_safe_len(ds)
        num_classes = _zs_safe_len(classnames)
        zs_total_samples_all += num_images
        zs_total_classes_all += num_classes

        # For debug gate mode, expose the current dataset's MTIL index on the model.
        try:
            if getattr(model, 'gate_mode', '') == 'debug':
                setattr(model, 'zs_dataset_index_for_debug', int(name_to_idx.get(name, -1)))
        except Exception:
            logging.debug("Could not set zs_dataset_index_for_debug for %s", name, exc_info=True)

        # Use the dataset's first template, if it provides any.
        tmpl = None
        if isinstance(templates, (list, tuple)) and len(templates) > 0:
            tmpl = templates[0]

        # Log which template is in effect (purely informational, best-effort).
        try:
            if isinstance(tmpl, str):
                desc = tmpl
            elif callable(tmpl):
                desc = f"callable:{getattr(tmpl, '__name__', str(tmpl))}"
            else:
                fallback_tmpl = getattr(cfg, 'prompt_template', 'a photo of a {}.')
                desc = f"fallback:{fallback_tmpl}"
            has_classes = isinstance(classnames, (list, tuple)) and len(classnames) > 0
            sample_cls = classnames[0] if has_classes else 'object'
            example = _zs_render_prompt(tmpl, cfg, sample_cls)
            logging.info(f"[ZS][Template] dataset={name} | tmpl={desc} | example='{example}'")
        except Exception:
            logging.debug("Could not log the zero-shot template for %s", name, exc_info=True)

        prompts = [_zs_render_prompt(tmpl, cfg, c) for c in classnames]
        try:
            text_tokens = tokenizer(prompts).to(device)
        except Exception as e:
            logging.error(
                f"Tokenization failed for {name}: {e}. Prompts sample: "
                f"{prompts[:3]}"
            )
            continue

        logging.info(f"ZS: {name} | classes={num_classes} | images={num_images}")
        loader = DataLoader(
            ds,
            batch_size=zs_bs,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=partial(_zs_collate_batch, num_classes=num_classes, multilabel=is_voc),
        )

        correct = 0
        total = 0
        processed = 0
        # VOC2007 only: per-batch scores and multi-hot targets, kept on CPU for mAP.
        voc_y_true = []
        voc_y_score = []
        text_features = None
        scale = 1.0
        with torch.inference_mode(), torch.cuda.amp.autocast(enabled=amp_zs):
            if orig_model is not None:
                # Text features and logit scale are constant for the whole dataset.
                text_features = orig_model.encode_text(text_tokens)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                logit_scale = getattr(orig_model, 'logit_scale', None)
                if logit_scale is not None and hasattr(logit_scale, 'exp'):
                    scale = logit_scale.exp()

            for x, y in tqdm(loader, desc=f"ZS {name}", leave=False):
                x = x.to(device, non_blocking=True)
                if not is_voc:
                    # The collate function already produced a [B] tensor of class ids.
                    y = y.to(device, non_blocking=True).long()

                if orig_model is not None:
                    image_features = orig_model.encode_image(x)
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                    logits = scale * image_features @ text_features.t()
                elif callable(getattr(model, 'compute_logits', None)):
                    # Prefer the unified DFA path when the model provides it.
                    logits = model.compute_logits(x, text_tokens, is_zeroshot=True)
                else:
                    logits, _ = model.model(x, text_tokens, 0, is_train=False)

                if is_voc:
                    # float() keeps the later NumPy conversion valid for reduced-precision logits.
                    voc_y_score.append(logits.detach().float().cpu())
                    voc_y_true.append(y.detach().cpu())
                    processed += x.size(0)
                else:
                    pred = logits.argmax(dim=1)
                    correct += (pred == y).sum().item()
                    bsz = y.size(0)
                    total += bsz
                    processed += bsz

                if max_zs_samples > 0 and processed >= max_zs_samples:
                    break

        # Drop per-dataset tensors so VRAM does not grow across datasets.
        del text_tokens, text_features
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if is_voc:
            results[name] = _zs_voc_map_percent(voc_y_true, voc_y_score)
        else:
            acc = 100.0 * correct / total if total > 0 else 0.0
            results[name] = round(acc, 2)

        zs_total_eval_samples_all += int(processed)

    # Aggregate summary: only for the main evaluation, not the pre-task original-CLIP pass
    # (which is invoked with a positive limit_datasets).
    if (not use_original_clip) and (not isinstance(limit_datasets, int) or limit_datasets <= 0) and zs_total_samples_all > 0:
        print(f"[ZS][Summary] datasets={len(filtered)} | total_samples={zs_total_samples_all} | total_classes={zs_total_classes_all} | eval_samples={zs_total_eval_samples_all}")

    return results


class TaskIdOffsetDataset(torch.utils.data.Dataset):
    """Dataset view that replaces each sample's task id with a fixed global one.

    The wrapped dataset must yield 3-tuples ``(x, y, task_id)``; this wrapper
    returns ``(x, y, offset)``, where ``offset`` is the value given at
    construction, coerced to a plain ``int``.
    """

    def __init__(self, ds, offset: int):
        self.offset = int(offset)
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # The local task id from the wrapped dataset is discarded on purpose.
        x, y, _local_task_id = self.ds[idx]
        return x, y, int(self.offset)


def _parse_int_list(val):
    """Normalise a config value into a list of ints.

    Accepted inputs:
      * ``ListConfig``, ``list`` or ``tuple``: every element is passed through ``int``;
      * ``str``: split on ``,`` (``;`` is treated as ``,``), blank pieces are dropped;
      * ``int``: wrapped in a one-element list.

    Any other value (e.g. ``None``) yields an empty list. Elements or string
    pieces that ``int`` cannot convert raise from ``int``.
    """
    if isinstance(val, (ListConfig, list, tuple)):
        return [int(item) for item in val]
    if isinstance(val, str):
        tokens = (piece.strip() for piece in val.replace(';', ',').split(','))
        return [int(tok) for tok in tokens if tok]
    if isinstance(val, int):
        return [int(val)]
    return []


@hydra.main(config_path=None, config_name=None, version_base="1.1") 
def continual_clip(cfg: DictConfig) -> None:

    cfg.workdir = utils.get_workdir(path=os.getcwd())
    # If dataset_root is absolute, use it as-is; else resolve relative to workdir
    try:
        if not os.path.isabs(str(getattr(cfg, 'dataset_root', ''))):
            cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root)
    except Exception:
        cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root)

    # Global seed for reproducibility
    utils.seed_all(int(getattr(cfg, 'seed', 1)))

    # Parse multi-dataset training plan
    train_indices = _parse_int_list(getattr(cfg, 'train_dataset', []))
    splits_list = _parse_int_list(getattr(cfg, 'cil_splits', []))
    # Backward compatibility: single dataset + single split
    if not train_indices:
        toi = int(getattr(cfg, 'train_one_dataset', -1))
        if toi >= 0:
            train_indices = [toi]
            cs = int(getattr(cfg, 'cil_splits', 0))
            splits_list = [cs] if cs > 0 else []
    if not train_indices:
        raise ValueError("Please provide train_dataset=[...] (indices 0..24), and cil_splits=[...] of the same length.")
    if not splits_list or len(splits_list) != len(train_indices):
        raise ValueError(f"cil_splits must be provided with the same length as train_dataset. Got {len(splits_list)} vs {len(train_indices)}")

    utils.save_config(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load class order only if a file is provided; otherwise use None
    if getattr(cfg, 'class_order', None):
        cfg.class_order = utils.get_class_order(os.path.join(cfg.workdir, cfg.class_order))
    else:
        cfg.class_order = None
    model  = load_model(cfg, device)

    # Build per-dataset scenarios according to the training plan
    train_scenarios = []  # list of Continuum ClassIncremental for train
    eval_scenarios = []   # list of Continuum ClassIncremental for eval
    classes_names_per = []
    templates_first_per = []  # per-phase first template (callable or string)
    templates_list_per = []   # per-phase full template list (list[str/callable])
    dataset_key_per = []      # per-phase dataset key string
    num_tasks_per_phase = []
    inc_per_phase = []
    # Aggregate CIL dataset statistics across all training datasets
    cil_train_samples_total = 0
    cil_total_samples_total = 0
    cil_total_classes_total = 0
    for di, splits in zip(train_indices, splits_list):
        if di not in TRAIN_INDEX_TO_DATASET_KEY:
            raise ValueError(
                f"train_dataset contains invalid index {di}. Supported indices: {list(TRAIN_INDEX_TO_DATASET_KEY.keys())}"
            )
        ds_key = TRAIN_INDEX_TO_DATASET_KEY[di]
        # Determine classes and increments for this dataset
        cfg.dataset = ds_key
        try:
            _tmp_dataset, _tmp_classes = get_dataset(cfg, is_train=True)
            num_classes = len(_tmp_classes)
        except Exception:
            fallback_classes = {
                'cifar100': 100,
                'stanford_cars': 196,
            }
            num_classes = fallback_classes.get(cfg.dataset, 100)
        inc = math.ceil(num_classes / int(splits))
        cfg.initial_increment = inc
        cfg.increment = inc
        # Ensure the scenario builder creates exactly `splits` tasks for this dataset phase
        try:
            cfg.cil_splits = int(splits)
        except Exception:
            cfg.cil_splits = splits
        # Build scenarios for this dataset
        eval_scn, eval_cls = build_cl_scenarios(cfg, is_train=False, transforms=model.transforms)
        train_scn, train_cls = build_cl_scenarios(cfg, is_train=True, transforms=model.transforms)
        eval_scenarios.append(eval_scn)
        train_scenarios.append(train_scn)
        classes_names_per.append(train_cls)
        dataset_key_per.append(ds_key)

        # Per-dataset CIL stats: train samples, total samples (train+eval), num classes
        try:
            train_samples_ds = 0
            for t_idx in range(len(train_scn)):
                try:
                    train_samples_ds += len(train_scn[t_idx])
                except Exception:
                    pass
            eval_samples_ds = 0
            for t_idx in range(len(eval_scn)):
                try:
                    eval_samples_ds += len(eval_scn[t_idx])
                except Exception:
                    pass
            total_samples_ds = train_samples_ds + eval_samples_ds
            num_classes_ds = len(train_cls)
            cil_train_samples_total += train_samples_ds
            cil_total_samples_total += total_samples_ds
            cil_total_classes_total += num_classes_ds
            print(f"[CIL][Dataset] {ds_key}: train_samples={train_samples_ds} | total_samples={total_samples_ds} | num_classes={num_classes_ds}")
        except Exception:
            pass
        # Also record dataset-specific templates for this phase (first + full list)
        try:
            class _TCfg:
                pass
            tcfg = _TCfg()
            tcfg.dataset = 'MTIL'
            tcfg.dataset_root = cfg.dataset_root
            tcfg.seed = getattr(cfg, 'seed', 1)
            tcfg.use_validation = getattr(cfg, 'use_validation', False)
            tcfg.MTIL_order_2 = getattr(cfg, 'MTIL_order_2', False)
            tcfg.train_one_dataset = int(di)
            _ds_tmp, _cls_tmp, _tmpl_tmp, _names_tmp = get_mtil_dataset(tcfg, split='test', transforms=model.transforms)
            first_tmpl = None
            full_list = None
            # _tmpl_tmp is a list per selected dataset; for single dataset, take index 0
            if isinstance(_tmpl_tmp, (list, tuple)) and len(_tmpl_tmp) > 0:
                per_ds_templates = _tmpl_tmp[0]
                # per_ds_templates is usually a list[str/callable]
                if isinstance(per_ds_templates, (list, tuple)) and len(per_ds_templates) > 0:
                    full_list = list(per_ds_templates)
                elif isinstance(per_ds_templates, str) or callable(per_ds_templates):
                    full_list = [per_ds_templates]
                # take its first element
                if isinstance(per_ds_templates, (list, tuple)) and len(per_ds_templates) > 0:
                    first_tmpl = per_ds_templates[0]
                elif isinstance(per_ds_templates, str) or callable(per_ds_templates):
                    first_tmpl = per_ds_templates
            templates_first_per.append(first_tmpl)
            templates_list_per.append(full_list)
        except Exception:
            templates_first_per.append(None)
            templates_list_per.append(None)
        num_tasks_per_phase.append(len(train_scn))
        inc_per_phase.append(inc)

    # Global CIL stats summary across all selected training datasets
    if cil_total_samples_total > 0:
        print(f"[CIL][Summary] train_samples={cil_train_samples_total} | total_samples={cil_total_samples_total} | total_classes={cil_total_classes_total}")

    # Prepare global plan of tasks across phases
    global_plan = []  # list of (phase_idx, local_task_idx)
    for pi, n_tasks in enumerate(num_tasks_per_phase):
        for lt in range(n_tasks):
            global_plan.append((pi, lt))

    with open(cfg.log_path, 'w+') as f: 
        pass

    acc_list = []
    metric_logger = Logger(list_subsets=["test"])  # kept for compatibility, not used for metrics
    # Histories for custom metrics
    best_acc_so_far = {}            # task_id -> best accuracy before current step
    acc_at_learn_time = {}          # task_id -> accuracy at the step it was learned
    pre_acc_before_training = {}    # task_id -> accuracy measured at previous step before learning task_id
    # Track per-global-task output block (start index, size) and absolute class ids order
    block_starts = []
    block_sizes = []
    block_abs_ids = []

    # Pre-task zero-shot evaluation on up to 10 non-training datasets using original CLIP
    if bool(getattr(cfg, 'pre_task_zero_shot_eval', True)) and bool(getattr(cfg, 'zero_shot_eval', True)):
        _skip_names_log = [MTIL_INDEX_TO_NAME.get(i, 'StanfordCars') for i in train_indices]
        logging.info(f"Pre-task zero-shot evaluation with ORIGINAL CLIP (excluding {_skip_names_log})...")
        zs_pre_results = evaluate_zero_shot(model, device, cfg, limit_datasets=10, use_original_clip=True)
        with open(cfg.log_path, 'a+') as f:
            f.write(json.dumps({
                'task': -1,
                'zs_pre': zs_pre_results,
            }) + '\n')

    # Training/evaluation over the flattened multi-dataset plan
    for global_task_id, (phase_idx, local_task_id) in enumerate(global_plan):
        logging.info(f"Evaluation for task {global_task_id} (phase {phase_idx}, local {local_task_id}) has started.")
        # Ensure model has correct class names for this phase
        model.classes_names = classes_names_per[phase_idx]
        # Ensure cfg has the right increment for this dataset phase so adaptation builds correct splits
        cfg.initial_increment = inc_per_phase[phase_idx]
        cfg.increment = inc_per_phase[phase_idx]
        # Force recomputation of class_ids_per_task for a new dataset phase
        model.class_ids_per_task = None
        # Train/adapt on the current local task within its phase
        model.adaptation(local_task_id, cfg, train_scenarios[phase_idx], classes_names_per[phase_idx])
        # Record the output block range and absolute ids appended by this task
        block_starts.append(int(getattr(model, 'last_task_start_index', 0)))
        abs_ids = list(getattr(model, 'last_task_real_ids', []))
        block_abs_ids.append(abs_ids)
        block_sizes.append(len(abs_ids))
        # free cached VRAM after finishing adaptation before evaluation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Strict CIL over a single global class space (across datasets) for all seen tasks
        eval_bs = int(getattr(cfg, 'eval_batch_size', 32))
        amp_eval = bool(getattr(cfg, 'eval_use_amp', True) and torch.cuda.is_available())
        text_chunk = int(getattr(cfg, 'eval_text_chunk', 512))
        # Build global seen class list in fixed order: concatenate each seen task's (phase, class_id)
        global_seen = []  # list of (phase_idx, abs_class_id)
        for g_idx_seen, (p_seen, _l_seen) in enumerate(global_plan[:global_task_id + 1]):
            for cid in block_abs_ids[g_idx_seen]:
                global_seen.append((p_seen, int(cid)))
        # Build global tokens and mapping (phase,cid) -> global index
        def _render_with_phase_template(phase_idx, class_name):
            """Render a single text prompt for ``class_name`` using the template of phase ``phase_idx``.

            Resolution order (a step that raises falls through to the next one):

            1. The phase entry in ``templates_first_per``: a callable is called with
               the class name, a string is used as a ``str.format`` pattern, and a
               list/tuple contributes its first element.
            2. ``cfg.prompt_template`` used as a ``str.format`` pattern.
            3. The literal ``"a photo of a {class_name}."``.

            Args:
                phase_idx: index into ``templates_first_per``; an index past the end
                    means "no phase template".
                class_name: class name substituted into the template.

            Returns:
                The rendered prompt.
            """
            tmpl = templates_first_per[phase_idx] if phase_idx < len(templates_first_per) else None
            # The "[CIL][Template]" log line reports list/tuple templates as
            # "list_first:<first entry>"; unwrap them here so the rendered prompt
            # really uses that entry instead of silently dropping to the fallback.
            if isinstance(tmpl, (list, tuple)):
                tmpl = tmpl[0] if tmpl else None
            if callable(tmpl):
                try:
                    return tmpl(class_name)
                except Exception:
                    # best-effort: a template that cannot render falls through
                    pass
            if isinstance(tmpl, str):
                try:
                    return tmpl.format(class_name)
                except Exception:
                    # e.g. stray braces in the pattern; fall through to the global template
                    pass
            # fallback to global prompt template
            try:
                return cfg.prompt_template.format(class_name)
            except Exception:
                return f"a photo of a {class_name}."
        # Print templates used for each phase included so far
        try:
            phases_in_eval = sorted({p for (p, _cid) in global_seen})
            for p in phases_in_eval:
                ds_key = dataset_key_per[p] if p < len(dataset_key_per) else f"phase_{p}"
                tmpl = templates_first_per[p] if p < len(templates_first_per) else None
                if isinstance(tmpl, (list, tuple)) and len(tmpl) > 0:
                    desc = f"list_first:{tmpl[0]}"
                elif isinstance(tmpl, str):
                    desc = tmpl
                elif callable(tmpl):
                    desc = f"callable:{getattr(tmpl, '__name__', str(tmpl))}"
                else:
                    desc = f"fallback:{getattr(cfg, 'prompt_template', 'a photo of a {}.')}"
                sample_cls = None
                try:
                    sample_cls = classes_names_per[p][0]
                except Exception:
                    sample_cls = 'object'
                example = _render_with_phase_template(p, sample_cls)
                logging.info(f"[CIL][Template] phase={p} dataset={ds_key} | tmpl={desc} | example='{example}'")
        except Exception:
            pass
        # Build tokens using ALL templates per phase; map each token to its global class index
        prompts_all = []
        token_to_class_index = []  # each entry maps token position -> global class index [0..G-1]
        class_template_counts = [0 for _ in range(len(global_seen))]
        for g_idx, (p, cid) in enumerate(global_seen):
            name = classes_names_per[p][cid]
            tlist = templates_list_per[p] if p < len(templates_list_per) and templates_list_per[p] else None
            if not tlist:
                # fallback to single template
                prompts_all.append(_render_with_phase_template(p, name))
                token_to_class_index.append(g_idx)
                class_template_counts[g_idx] += 1
            else:
                for t in tlist:
                    # render with specific template
                    if callable(t):
                        try:
                            s = t(name)
                        except Exception:
                            s = _render_with_phase_template(p, name)
                    elif isinstance(t, str):
                        try:
                            s = t.format(name)
                        except Exception:
                            s = _render_with_phase_template(p, name)
                    else:
                        s = _render_with_phase_template(p, name)
                    prompts_all.append(s)
                    token_to_class_index.append(g_idx)
                    class_template_counts[g_idx] += 1
        tokens_all = clip.tokenize(prompts_all)  # keep on CPU; move chunks to device on demand
        global_index_of = {(p, cid): i for i, (p, cid) in enumerate(global_seen)}
        # Accumulate per-task counts for custom metrics (by global task id index)
        task_correct = {}
        task_total = {}
        for g_idx, (p_i, l_t) in enumerate(global_plan[:global_task_id + 1]):
            ds = eval_scenarios[p_i][l_t]
            loader = DataLoader(TaskIdOffsetDataset(ds, offset=g_idx), batch_size=eval_bs)
            with torch.no_grad():
                # Precompute class counts tensor for averaging
                counts_tensor = torch.tensor(class_template_counts, dtype=torch.float32, device=device).clamp_min(1.0)
                for inputs, targets, _task_ids in tqdm(loader):
                    inputs = inputs.to(device, non_blocking=True)
                    batch_size = inputs.shape[0]
                    # Aggregate logits over all templates per class
                    agg_logits = torch.zeros((batch_size, len(global_seen)), device=device)
                    with torch.cuda.amp.autocast(enabled=amp_eval):
                        for start in range(0, tokens_all.size(0), max(1, text_chunk)):
                            end = min(tokens_all.size(0), start + max(1, text_chunk))
                            chunk = tokens_all[start:end].to(device, non_blocking=True)
                            if hasattr(model, 'compute_logits') and callable(getattr(model, 'compute_logits')):
                                logits_chunk = model.compute_logits(inputs, chunk, is_zeroshot=False)
                            else:
                                logits_chunk, _ = model.model(inputs, chunk, 0, is_train=False)
                            # Map chunk token positions to global class indices and scatter-add
                            idx_chunk = torch.tensor(token_to_class_index[start:end], device=device).view(1, -1).expand(batch_size, -1)
                            if logits_chunk.dtype != agg_logits.dtype:
                                logits_chunk = logits_chunk.to(dtype=agg_logits.dtype)
                            agg_logits.scatter_add_(1, idx_chunk, logits_chunk)
                    # Average across templates per class
                    agg_logits = agg_logits / counts_tensor.view(1, -1)
                    preds_global = agg_logits.detach().cpu().argmax(dim=1).numpy()
                    # Map targets (abs ids from phase p_i) to global indices
                    if isinstance(targets, torch.Tensor):
                        t_np = targets.detach().cpu().numpy()
                    else:
                        t_np = np.asarray(targets)
                    mapped = np.array([global_index_of.get((p_i, int(v)), -1) for v in t_np], dtype=np.int64)
                    valid = mapped >= 0
                    corr = int((preds_global[valid] == mapped[valid]).sum())
                    tot = int(valid.sum())
                    task_correct[g_idx] = task_correct.get(g_idx, 0) + corr
                    task_total[g_idx] = task_total.get(g_idx, 0) + tot
        # free global tokens to avoid growth across steps
        del tokens_all
        torch.cuda.empty_cache()

        # Compute VOC2007 11-point mAP for CIL and use it as per-task accuracy for the VOC task
        voc_mAP = None
        voc_tid_override = None
        try:
            # Identify seen tasks belonging to VOC2007 phase(s)
            voc_tids = []
            for tid_seen, (p_i, _lt) in enumerate(global_plan[:global_task_id + 1]):
                if p_i < len(dataset_key_per) and dataset_key_per[p_i] == 'voc2007':
                    voc_tids.append(tid_seen)
            if voc_tids:
                # Use dataset-specific template for VOC phase
                p_voc = global_plan[voc_tids[-1]][0]
                # Use ALL templates for VOC phase
                tlist_voc = templates_list_per[p_voc] if p_voc < len(templates_list_per) and templates_list_per[p_voc] else None
                def _render_voc_all(cname: str):
                    """Return the list of prompts for ``cname`` built from the VOC phase templates.

                    Every entry of ``tlist_voc`` that is a callable or a ``str.format``
                    pattern contributes one prompt; entries of any other type, and entries
                    that raise while rendering, contribute nothing. When no prompt could be
                    produced, the result is a one-element list rendered from
                    ``cfg.prompt_template`` or, if that fails, the literal
                    ``"a photo of a {cname}."``.
                    """
                    rendered = []
                    for template in (tlist_voc or ()):
                        try:
                            if callable(template):
                                rendered.append(template(cname))
                            elif isinstance(template, str):
                                rendered.append(template.format(cname))
                        except Exception:
                            # a template that cannot render this name is skipped
                            continue
                    if rendered:
                        return rendered
                    # nothing rendered: use the global template, then the hard-coded default
                    try:
                        return [cfg.prompt_template.format(cname)]
                    except Exception:
                        return [f"a photo of a {cname}."]
                # Build VOC tokens over all templates per class (20 classes)
                voc_ds_multi = MTILVOC2007(root=cfg.dataset_root, seed=getattr(cfg, 'seed', 1), single_label=False)
                voc_prompts = []
                voc_token_to_class = []
                for ci, cname in enumerate(voc_ds_multi.classnames):
                    outs = _render_voc_all(cname)
                    voc_prompts.extend(outs)
                    voc_token_to_class.extend([ci] * len(outs))
                voc_tokens = clip.tokenize(voc_prompts).to(device)
                # counts per class for averaging
                voc_counts = torch.zeros(len(voc_ds_multi.classnames), dtype=torch.float32, device=device)
                for ci in voc_token_to_class:
                    voc_counts[ci] += 1.0
                # Iterate VOC test set in batches
                y_true = []
                y_score = []
                batch = []
                def _flush_batch(batch_list):
                    """Score one batch of VOC test items and append the results to ``y_score`` / ``y_true``.

                    Each item is read through its ``impath`` (image file) and ``label``
                    attributes. The image is converted to RGB and passed through
                    ``model.transforms`` when the model has one. Items whose image or label
                    cannot be prepared are skipped with a warning. Per-template logits are
                    summed per class via ``voc_token_to_class`` and divided by
                    ``voc_counts``.

                    Args:
                        batch_list: items of the VOC test split to score together.

                    Returns:
                        None. An empty batch, or a batch in which every item failed,
                        appends nothing.
                    """
                    if not batch_list:
                        return
                    imgs = []
                    ys = []
                    for d in batch_list:
                        try:
                            # Release the file handle as soon as the pixels are decoded.
                            with Image.open(d.impath) as fh:
                                img = fh.convert('RGB')
                            if getattr(model, 'transforms', None) is not None:
                                img = model.transforms(img)
                            # NOTE(review): the caller later indexes columns of the stacked
                            # labels, so d.label is presumably a per-class vector; confirm.
                            label = torch.tensor(d.label, dtype=torch.long)
                        except Exception as exc:
                            logging.warning(
                                f"VOC2007 mAP (CIL): skipping item {getattr(d, 'impath', '?')}: {exc}"
                            )
                            continue
                        # Append only after BOTH image and label succeeded, so the two
                        # lists stay aligned and scores cannot be paired with the wrong labels.
                        imgs.append(img)
                        ys.append(label)
                    if not imgs:
                        return
                    x = torch.stack(imgs, dim=0).to(device, non_blocking=True)
                    with torch.no_grad(), torch.cuda.amp.autocast(enabled=amp_eval):
                        if hasattr(model, 'compute_logits') and callable(getattr(model, 'compute_logits')):
                            logits_full = model.compute_logits(x, voc_tokens, is_zeroshot=False)
                        else:
                            logits_full, _ = model.model(x, voc_tokens, 0, is_train=False)
                    # Aggregate over templates per class
                    B = logits_full.size(0)
                    Gv = len(voc_ds_multi.classnames)
                    agg = torch.zeros((B, Gv), device=logits_full.device)
                    idx_chunk = torch.tensor(voc_token_to_class, device=logits_full.device).view(1, -1).expand(B, -1)
                    # scatter_add_ needs source and destination dtypes to match
                    if logits_full.dtype != agg.dtype:
                        logits_full = logits_full.to(dtype=agg.dtype)
                    agg.scatter_add_(1, idx_chunk, logits_full)
                    # Sum -> mean over the templates of each class
                    agg = agg / voc_counts.view(1, -1)
                    y_score.append(agg.detach().cpu())
                    y_true.append(torch.stack(ys, dim=0))
                bs_local = eval_bs
                for d in voc_ds_multi.test:
                    batch.append(d)
                    if len(batch) >= bs_local:
                        _flush_batch(batch)
                        batch = []
                if batch:
                    _flush_batch(batch)
                if y_true and y_score:
                    y_true_all = torch.cat(y_true, dim=0).numpy()
                    y_score_all = torch.cat(y_score, dim=0).numpy()
                    # 11-point AP per class
                    def _ap11(y_true_cls: np.ndarray, y_score_cls: np.ndarray) -> float:
                        order = np.argsort(-y_score_cls)
                        y_true_sorted = y_true_cls[order]
                        tp = (y_true_sorted == 1).astype(np.float32)
                        fp = (y_true_sorted == 0).astype(np.float32)
                        tp_cum = np.cumsum(tp)
                        fp_cum = np.cumsum(fp)
                        prec = tp_cum / np.maximum(tp_cum + fp_cum, 1e-12)
                        total_pos = max(1.0, float((y_true_cls == 1).sum()))
                        rec = tp_cum / total_pos
                        ap = 0.0
                        for r in np.linspace(0.0, 1.0, 11):
                            mask = rec >= r
                            p_interp = np.max(prec[mask]) if np.any(mask) else 0.0
                            ap += p_interp
                        return ap / 11.0
                    aps = []
                    for ci in range(y_true_all.shape[1]):
                        aps.append(_ap11(y_true_all[:, ci].astype(np.int64), y_score_all[:, ci].astype(np.float32)))
                    voc_mAP = 100.0 * float(np.mean(aps)) if aps else None
                    voc_tid_override = voc_tids[-1]
        except Exception as e:
            logging.error(f"VOC2007 mAP (CIL) failed: {e}")

        # zero-shot evaluation on auxiliary datasets
        zs_results = {}
        if getattr(cfg, 'zero_shot_eval', True):
            zs_results = evaluate_zero_shot(model, device, cfg)
            # Compute zero-shot mean accuracy
            if zs_results:
                zs_mean = round(sum(zs_results.values()) / len(zs_results), 2)
            else:
                zs_mean = 0.0

        # Router behavior statistics: seen group + zero-shot datasets
        # Detect block mode for per-block statistics
        is_block_mode = (str(getattr(cfg, 'dfa_inject_mode', 'head')).lower() == 'block')
        num_blocks = len(getattr(model, 'dfa_blocks', [])) if is_block_mode else 0
        router_stats = {}
        
        # Helper: compute stats from weight tensor
        def _compute_w_stats(w_all):
            """Summarise a batch of two-way router weights.

            Columns 0 and 1 of ``w_all`` are each clamped to ``[0, 1]`` and reduced
            to a mean and a fixed-range histogram. The number of bins comes from
            ``cfg.router_hist_bins`` (default 10, at least 1).

            Args:
                w_all: tensor indexed as ``w_all[:, 0]`` / ``w_all[:, 1]``, or ``None``.

            Returns:
                ``None`` when ``w_all`` is ``None`` or has no elements, otherwise a dict
                with ``w1_mean``, ``w2_mean``, ``w1_hist`` and ``w2_hist``; each
                histogram is a dict with ``bins``, ``range``, ``counts`` and ``n``.
            """
            if w_all is None or w_all.numel() == 0:
                return None
            w_all = w_all.detach()
            # torch.histc is defined for floating-point input and its reduced-precision
            # kernels may not exist on every backend, so promote fp16/bf16/integer
            # weights to float32 once up front (float32/float64 pass through unchanged).
            if not w_all.is_floating_point() or w_all.dtype in (torch.float16, torch.bfloat16):
                w_all = w_all.float()
            # Read the bin count once; histc requires a positive number of bins.
            bins = max(1, int(getattr(cfg, 'router_hist_bins', 10)))
            w1 = w_all[:, 0].clamp(0, 1)
            w2 = w_all[:, 1].clamp(0, 1)

            def _hist(v):
                h = torch.histc(v, bins=bins, min=0.0, max=1.0)
                return {'bins': bins, 'range': [0.0, 1.0], 'counts': h.tolist(), 'n': int(v.numel())}

            return {
                'w1_mean': float(w1.mean().item()),
                'w2_mean': float(w2.mean().item()),
                'w1_hist': _hist(w1),
                'w2_hist': _hist(w2),
            }

        try:
            # Seen group: sample up to 50 across seen eval datasets
            sample_seen = int(getattr(cfg, 'router_stats_samples', 50))
            seen_imgs = []
            per_ds_quota = max(1, sample_seen // max(1, len(global_plan[:global_task_id + 1])))
            for p_i, l_t in global_plan[:global_task_id + 1]:
                ds = eval_scenarios[p_i][l_t]
                cnt = 0
                for bx, _, _ in DataLoader(ds, batch_size=eval_bs, shuffle=True):
                    for i in range(bx.size(0)):
                        seen_imgs.append(bx[i])
                        cnt += 1
                        if cnt >= per_ds_quota or len(seen_imgs) >= sample_seen:
                            break
                    if cnt >= per_ds_quota or len(seen_imgs) >= sample_seen:
                        break
                if len(seen_imgs) >= sample_seen:
                    break
            
            if seen_imgs:
                seen_tensor = torch.stack(seen_imgs, dim=0).to(device, non_blocking=True)
                
                if is_block_mode and num_blocks > 0:
                    # Block mode: collect per-block router weights
                    block_weights = model.get_block_router_weights(seen_tensor, is_zeroshot=False)
                    router_stats['seen'] = {'blocks': []}
                    for blk_idx in range(num_blocks):
                        w_top = block_weights['w_top'][blk_idx]
                        w_spec = block_weights['w_spec'][blk_idx]
                        blk_stats = {
                            'w_top': _compute_w_stats(w_top) if w_top is not None else None,
                            'w_spec': _compute_w_stats(w_spec) if w_spec is not None else None,
                        }
                        router_stats['seen']['blocks'].append(blk_stats)
                else:
                    # Head mode: single router
                    from torch.utils.data import TensorDataset
                    seen_loader = DataLoader(TensorDataset(seen_tensor.cpu()), batch_size=eval_bs)
                    w_list = []
                    for bx in seen_loader:
                        imgs = bx[0].to(device, non_blocking=True)
                        try:
                            w = model.get_router_weights(imgs, is_zeroshot=False)
                            w_list.append(w.detach().cpu())
                        except Exception:
                            continue
                    if w_list:
                        router_stats['seen'] = _compute_w_stats(torch.cat(w_list, dim=0))
        except Exception as e:
            logging.error(f"Router stats (seen) failed: {e}")

        # Zero-shot groups: 24 MTIL datasets filtered as in evaluation
        try:
            class _ZSCfg:
                """Bare attribute holder; its fields are assigned right after instantiation and the object is passed to ``get_mtil_dataset``."""
                pass
            zs_cfg = _ZSCfg()
            zs_cfg.dataset = 'MTIL'
            zs_cfg.dataset_root = cfg.dataset_root
            zs_cfg.seed = getattr(cfg, 'seed', 1)
            zs_cfg.use_validation = getattr(cfg, 'use_validation', False)
            zs_cfg.MTIL_order_2 = getattr(cfg, 'MTIL_order_2', False)
            zs_cfg.train_one_dataset = -1
            zs_transforms = getattr(model, 'transforms', None)
            zs_datasets, zs_classnames, zs_templates, zs_names = get_mtil_dataset(zs_cfg, split='test', transforms=zs_transforms)
            datasets_info = list(zip(zs_datasets, zs_classnames, zs_templates, zs_names))
            # parse training indices to skip
            def _parse_list(val):
                """Coerce a config value into a list of ints.

                Accepts a ``ListConfig``/list/tuple (each element converted with
                ``int``), a string of integers separated by ``,`` or ``;`` (blank
                pieces ignored), or a single int. Any other value yields ``[]``.
                """
                if isinstance(val, (ListConfig, list, tuple)):
                    return [int(item) for item in val]
                if isinstance(val, str):
                    # treat ';' as an alternative separator, then drop empty pieces
                    tokens = (tok.strip() for tok in val.replace(';', ',').split(','))
                    return [int(tok) for tok in tokens if tok]
                if isinstance(val, int):
                    return [int(val)]
                return []
            train_indices = _parse_list(getattr(cfg, 'train_dataset', []))
            if not train_indices:
                toi = int(getattr(cfg, 'train_one_dataset', -1))
                if toi >= 0:
                    train_indices = [toi]
            skip_names = {MTIL_INDEX_TO_NAME.get(i, 'StanfordCars') for i in train_indices}
            # filter
            filtered = [(ds, name) for (ds, _cn, _tm, name) in datasets_info if name not in skip_names]
            # sample 50 per dataset
            sample_zs = int(getattr(cfg, 'router_stats_samples', 50))
            router_stats['zs'] = {}
            name_to_idx = {v: k for k, v in MTIL_INDEX_TO_NAME.items()}
            
            for ds, name in filtered:
                imgs = []
                cnt = 0
                for bx, *_ in DataLoader(ds, batch_size=eval_bs, shuffle=True):
                    imgs.extend([bx[i] for i in range(bx.size(0))])
                    cnt += bx.size(0)
                    if cnt >= sample_zs:
                        break
                if imgs:
                    tensor = torch.stack(imgs[:sample_zs], dim=0).to(device, non_blocking=True)
                    
                    if is_block_mode and num_blocks > 0:
                        # Block mode: set zs_dataset_index for debug mode
                        try:
                            if getattr(model, 'gate_mode', '') == 'debug':
                                setattr(model, 'zs_dataset_index_for_debug', int(name_to_idx.get(name, -1)))
                        except Exception:
                            pass
                        block_weights = model.get_block_router_weights(tensor, is_zeroshot=True)
                        zs_stats = {'blocks': []}
                        for blk_idx in range(num_blocks):
                            w_top = block_weights['w_top'][blk_idx]
                            w_spec = block_weights['w_spec'][blk_idx]
                            blk_stats = {
                                'w_top': _compute_w_stats(w_top) if w_top is not None else None,
                                'w_spec': _compute_w_stats(w_spec) if w_spec is not None else None,
                            }
                            zs_stats['blocks'].append(blk_stats)
                        router_stats['zs'][name] = zs_stats
                    else:
                        # Head mode
                        try:
                            if getattr(model, 'gate_mode', '') == 'debug':
                                setattr(model, 'zs_dataset_index_for_debug', int(name_to_idx.get(name, -1)))
                        except Exception:
                            pass
                        from torch.utils.data import TensorDataset
                        dl = DataLoader(TensorDataset(tensor.cpu()), batch_size=eval_bs)
                        w_list = []
                        for bx in dl:
                            imgs_b = bx[0].to(device, non_blocking=True)
                            try:
                                w = model.get_router_weights(imgs_b, is_zeroshot=True)
                                w_list.append(w.detach().cpu())
                            except Exception:
                                continue
                        if w_list:
                            router_stats['zs'][name] = _compute_w_stats(torch.cat(w_list, dim=0))
        except Exception as e:
            logging.error(f"Router stats (ZS) failed: {e}")

        # Visualization: per-block plots in separate folders for block mode
        try:
            if plt is not None and router_stats:
                base_out_dir = os.path.join(os.path.dirname(cfg.log_path), 'router_plots')
                os.makedirs(base_out_dir, exist_ok=True)
                
                if is_block_mode and num_blocks > 0:
                    # Block mode: create subfolder for each block
                    for blk_idx in range(num_blocks):
                        blk_dir = os.path.join(base_out_dir, f'block_{blk_idx}')
                        os.makedirs(blk_dir, exist_ok=True)
                        
                        # Collect w_top stats for this block across all groups
                        groups = []
                        w1_means = []
                        w2_means = []
                        
                        # Seen group
                        if 'seen' in router_stats and router_stats['seen']:
                            seen_blk = router_stats['seen'].get('blocks', [])
                            if blk_idx < len(seen_blk) and seen_blk[blk_idx]:
                                w_top_stats = seen_blk[blk_idx].get('w_top')
                                if w_top_stats:
                                    groups.append('seen')
                                    w1_means.append(w_top_stats['w1_mean'])
                                    w2_means.append(w_top_stats['w2_mean'])
                        
                        # ZS groups
                        zs_dict = router_stats.get('zs', {}) or {}
                        for name in sorted(zs_dict.keys()):
                            zs_blks = zs_dict[name].get('blocks', []) if zs_dict[name] else []
                            if blk_idx < len(zs_blks) and zs_blks[blk_idx]:
                                w_top_stats = zs_blks[blk_idx].get('w_top')
                                if w_top_stats:
                                    groups.append(name)
                                    w1_means.append(w_top_stats['w1_mean'])
                                    w2_means.append(w_top_stats['w2_mean'])
                        
                        if groups:
                            x = np.arange(len(groups))
                            width = 0.4
                            fig, ax = plt.subplots(figsize=(max(8, len(groups) * 0.4), 4))
                            ax.bar(x - width/2, w1_means, width, label='w1 (E1)')
                            ax.bar(x + width/2, w2_means, width, label='w2 (E2)')
                            ax.set_xticks(x)
                            ax.set_xticklabels(groups, rotation=45, ha='right')
                            ax.set_ylim(0.0, 1.0)
                            ax.set_ylabel('mean weight')
                            ax.set_title(f'Block {blk_idx} router_top @ task {global_task_id}')
                            ax.legend()
                            fig.tight_layout()
                            fig.savefig(os.path.join(blk_dir, f'router_top_task_{global_task_id}.png'))
                            plt.close(fig)
                        
                        # Also plot w_spec (E2a vs E2b) if available
                        groups_spec = []
                        e2a_means = []
                        e2b_means = []
                        
                        if 'seen' in router_stats and router_stats['seen']:
                            seen_blk = router_stats['seen'].get('blocks', [])
                            if blk_idx < len(seen_blk) and seen_blk[blk_idx]:
                                w_spec_stats = seen_blk[blk_idx].get('w_spec')
                                if w_spec_stats:
                                    groups_spec.append('seen')
                                    e2a_means.append(w_spec_stats['w1_mean'])
                                    e2b_means.append(w_spec_stats['w2_mean'])
                        
                        for name in sorted(zs_dict.keys()):
                            zs_blks = zs_dict[name].get('blocks', []) if zs_dict[name] else []
                            if blk_idx < len(zs_blks) and zs_blks[blk_idx]:
                                w_spec_stats = zs_blks[blk_idx].get('w_spec')
                                if w_spec_stats:
                                    groups_spec.append(name)
                                    e2a_means.append(w_spec_stats['w1_mean'])
                                    e2b_means.append(w_spec_stats['w2_mean'])
                        
                        if groups_spec:
                            # Bind the bar width here: the `width` assigned in the router_top
                            # branch above only exists once that plot has been drawn, so relying
                            # on it raised NameError (aborting all plots for the task) when only
                            # w_spec statistics were available.
                            width = 0.4
                            x = np.arange(len(groups_spec))
                            fig, ax = plt.subplots(figsize=(max(8, len(groups_spec) * 0.4), 4))
                            ax.bar(x - width/2, e2a_means, width, label='E2a')
                            ax.bar(x + width/2, e2b_means, width, label='E2b')
                            ax.set_xticks(x)
                            ax.set_xticklabels(groups_spec, rotation=45, ha='right')
                            ax.set_ylim(0.0, 1.0)
                            ax.set_ylabel('mean weight')
                            ax.set_title(f'Block {blk_idx} w_spec (E2a/E2b) @ task {global_task_id}')
                            ax.legend()
                            fig.tight_layout()
                            fig.savefig(os.path.join(blk_dir, f'router_spec_task_{global_task_id}.png'))
                            # close the figure so figures do not accumulate across tasks
                            plt.close(fig)
                    
                    # Also create a summary plot across all blocks
                    if 'seen' in router_stats and router_stats['seen']:
                        seen_blocks = router_stats['seen'].get('blocks', [])
                        if seen_blocks:
                            blk_indices = list(range(len(seen_blocks)))
                            w1_by_blk = [b['w_top']['w1_mean'] if b and b.get('w_top') else 0.0 for b in seen_blocks]
                            w2_by_blk = [b['w_top']['w2_mean'] if b and b.get('w_top') else 0.0 for b in seen_blocks]
                            
                            fig, ax = plt.subplots(figsize=(10, 4))
                            x = np.arange(len(blk_indices))
                            ax.bar(x - 0.2, w1_by_blk, 0.4, label='w1 (E1)')
                            ax.bar(x + 0.2, w2_by_blk, 0.4, label='w2 (E2)')
                            ax.set_xticks(x)
                            ax.set_xticklabels([f'B{i}' for i in blk_indices])
                            ax.set_ylim(0.0, 1.0)
                            ax.set_xlabel('Block')
                            ax.set_ylabel('mean weight')
                            ax.set_title(f'Seen: router_top across blocks @ task {global_task_id}')
                            ax.legend()
                            fig.tight_layout()
                            fig.savefig(os.path.join(base_out_dir, f'seen_all_blocks_task_{global_task_id}.png'))
                            plt.close(fig)
                else:
                    # Head mode: single plot (original behavior)
                    groups = []
                    w1_means = []
                    w2_means = []
                    if 'seen' in router_stats and router_stats['seen']:
                        groups.append('seen')
                        w1_means.append(router_stats['seen']['w1_mean'])
                        w2_means.append(router_stats['seen']['w2_mean'])
                    zs_dict = router_stats.get('zs', {}) or {}
                    for name in sorted(zs_dict.keys()):
                        if zs_dict[name]:
                            groups.append(name)
                            w1_means.append(zs_dict[name]['w1_mean'])
                            w2_means.append(zs_dict[name]['w2_mean'])
                    if groups:
                        x = np.arange(len(groups))
                        width = 0.4
                        fig, ax = plt.subplots(figsize=(max(8, len(groups) * 0.4), 4))
                        ax.bar(x - width/2, w1_means, width, label='w1_mean')
                        ax.bar(x + width/2, w2_means, width, label='w2_mean')
                        ax.set_xticks(x)
                        ax.set_xticklabels(groups, rotation=45, ha='right')
                        ax.set_ylim(0.0, 1.0)
                        ax.set_ylabel('mean weight')
                        ax.set_title(f'Router means @ task {global_task_id}')
                        ax.legend()
                        fig.tight_layout()
                        fig.savefig(os.path.join(base_out_dir, f'router_stats_task_{global_task_id}.png'))
                        plt.close(fig)
        except Exception as e:
            logging.error(f"Router plotting failed: {e}")

        # Compute custom metrics from accumulated counts
        seen_task_ids = list(range(global_task_id + 1))
        acc_per_task = []
        for tid in seen_task_ids:
            tot = task_total.get(tid, 0)
            if voc_tid_override is not None and tid == voc_tid_override and voc_mAP is not None:
                acc = max(0.0, min(1.0, voc_mAP / 100.0))
            else:
                acc = (task_correct.get(tid, 0) / tot) if tot > 0 else 0.0
            acc_per_task.append(acc)
        # Overall accuracy: treat VOC mAP as accuracy weighted by its sample count
        if voc_tid_override is not None and voc_mAP is not None and voc_tid_override in task_total:
            total_samples = max(1, sum(task_total.values()))
            corrected_sum = 0.0
            for tid in seen_task_ids:
                if tid == voc_tid_override:
                    corrected_sum += (voc_mAP / 100.0) * task_total.get(tid, 0)
                else:
                    corrected_sum += task_correct.get(tid, 0)
            overall_acc = 100.0 * (corrected_sum / total_samples)
        else:
            overall_acc = 100.0 * (sum(task_correct.values()) / max(1, sum(task_total.values())))
        acc_list.append(overall_acc)
        # avg acc over seen tasks
        avg_acc_val = round(100.0 * (sum(acc_per_task) / max(1, len(acc_per_task))), 2)
        # forgetting over past tasks
        forgetting_vals = []
        for tid in seen_task_ids[:-1]:
            prev_best = best_acc_so_far.get(tid, 0.0)
            curr = acc_per_task[tid]
            forgetting_vals.append(max(0.0, prev_best - curr))
        forgetting_val = round(100.0 * (sum(forgetting_vals) / max(1, len(forgetting_vals))), 6)
        # update best-so-far after computing forgetting
        for tid in seen_task_ids:
            best_acc_so_far[tid] = max(best_acc_so_far.get(tid, 0.0), acc_per_task[tid])
        # diagonal accuracy at learn time
        if global_task_id not in acc_at_learn_time:
            acc_at_learn_time[global_task_id] = acc_per_task[global_task_id]
        # backward transfer: mean over past tasks of (curr - acc_at_learn_time)
        bwt_vals = []
        for tid in seen_task_ids[:-1]:
            base = acc_at_learn_time.get(tid, acc_per_task[tid])
            bwt_vals.append(acc_per_task[tid] - base)
        bwt_val = round(100.0 * (sum(bwt_vals) / max(1, len(bwt_vals))), 2)
        # forward transfer: evaluate next task pre-accuracy at this step (R_{t-1,t})
        fwt_val = None
        next_idx = global_task_id + 1
        if next_idx < len(global_plan):
            # evaluate only the next task for pre-training accuracy using temporary tokens for its classes
            p_i, l_t = global_plan[next_idx]
            next_ds = eval_scenarios[p_i][l_t]
            # First pass: collect absolute class ids present in the next task
            uniq = set()
            tmp_loader_ids = DataLoader(next_ds, batch_size=eval_bs)
            with torch.no_grad():
                for _x, _y, _t in tmp_loader_ids:
                    if isinstance(_y, torch.Tensor):
                        uniq.update([int(v) for v in _y.tolist()])
                    else:
                        uniq.update([int(v) for v in _y])
            abs_list = sorted(uniq)
            # Build temporary tokens for those classes
            class_names_next = [classes_names_per[p_i][cid] for cid in abs_list]
            tmp_tokens = clip.tokenize([model.prompt_template.format(c) for c in class_names_next]).to(device)
            mapping = {int(cid): i for i, cid in enumerate(abs_list)}
            # Second pass: compute pre-accuracy restricted to next task classes
            next_loader = DataLoader(next_ds, batch_size=eval_bs)
            next_correct, next_total = 0, 0
            with torch.no_grad():
                for inputs, targets, _task_ids in next_loader:
                    inputs = inputs.to(device)
                    with torch.cuda.amp.autocast(enabled=amp_eval):
                        if hasattr(model, 'compute_logits') and callable(getattr(model, 'compute_logits')):
                            logits = model.compute_logits(inputs, tmp_tokens, is_zeroshot=False)
                        else:
                            logits, _ = model.model(inputs, tmp_tokens, 0, is_train=False)
                    preds = logits.detach().cpu().argmax(dim=1).numpy()
                    t_np = targets.detach().cpu().numpy() if isinstance(targets, torch.Tensor) else np.asarray(targets)
                    mapped = np.array([mapping.get(int(v), -1) for v in t_np], dtype=np.int64)
                    valid = mapped >= 0
                    next_correct += int((preds[valid] == mapped[valid]).sum())
                    next_total += int(valid.sum())
            if next_total > 0:
                pre_acc_before_training[next_idx] = next_correct / next_total
            # free next-task tokens
            del tmp_tokens
            torch.cuda.empty_cache()
        # compute mean FWT observed so far (for tasks that we have pre-acc recorded and already learned)
        if global_task_id >= 1:
            fwt_terms = []
            for tid in range(1, global_task_id + 1):
                if tid in pre_acc_before_training:
                    fwt_terms.append(pre_acc_before_training[tid])
            if fwt_terms:
                fwt_val = round(100.0 * (sum(fwt_terms) / len(fwt_terms)), 2)
            else:
                fwt_val = None
        acc_per_task_list = [round(100.0 * a, 2) for a in acc_per_task]
        with open(cfg.log_path, 'a+') as f:
            f.write(json.dumps({
                'task': global_task_id,
                'acc': round(overall_acc, 2),
                'avg_acc': avg_acc_val,
                'forgetting': forgetting_val,
                'acc_per_task': acc_per_task_list,
                'bwt': bwt_val,
                'fwt': fwt_val,
                'zs': zs_results,
                'router_stats': router_stats,
            }) + '\n')
        
        # Write zero-shot mean to separate JSON file
        if getattr(cfg, 'zero_shot_eval', True) and zs_results:
            zs_mean_path = cfg.log_path.replace('.json', '_zs_mean.json')
            with open(zs_mean_path, 'a+') as f:
                f.write(json.dumps({
                    'task': global_task_id,
                    'zs_mean': zs_mean,
                    'zs_details': zs_results,
                }) + '\n')
        # assert 1 == 2
    with open(cfg.log_path, 'a+') as f:
        f.write(json.dumps({
            'last': round(acc_list[-1], 2), 
            'avg': round(statistics.mean(acc_list), 2)
        }) + '\n')

        



# Script entry point: call continual_clip() only when this file is executed
# directly, so importing the module does not start a run.
if __name__ == "__main__":
    # Called with no arguments; presumably the configuration is injected by a
    # Hydra decorator on continual_clip (hydra is imported at the top of the
    # file) — TODO confirm against the function's definition.
    continual_clip()