# train.py

import os
import json
import datetime
from typing import Optional, Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import average_precision_score

# Keep wildcard import for compatibility with provided model file.
from CGMN import *
from dataset import EmoticDataset
from loss import DiscreteLoss, LabelLoss
from prepare_models import prep_models
from util import gen_label_adj


class TrainingMetricsCollector:
    """Accumulate per-epoch training metrics and write them out as a JSON report.

    The collector keeps the full per-epoch history plus two running summaries:
    the best epoch so far (by mAP, ties resolved in favour of the later epoch)
    and the most recently recorded epoch.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Restore every recorded value and strategy label to its initial state."""
        # Per-epoch history.
        self.epoch_metrics: List[dict] = []
        # Best-so-far summary.
        self.best_map: float = 0.0
        self.best_epoch: int = 0
        # Most-recent-epoch summary.
        self.final_map: float = 0.0
        self.final_epoch: int = 0
        self.total_epochs: int = 0
        # Free-form labels describing the run configuration.
        self.fusion_strategy: Optional[str] = None
        self.augmentation_strategy: Optional[str] = None
        self.scheduler_type: Optional[str] = None
        # Per-class AP snapshots (only updated when a breakdown is supplied).
        self.best_detailed_ap: Optional[dict] = None
        self.final_detailed_ap: Optional[dict] = None

    def add_epoch_metrics(
        self,
        epoch: int,
        running_loss: float,
        cat_loss: float,
        label_loss_val: float,
        map_score: float,
        detailed_ap: Optional[dict] = None,
    ) -> None:
        """Record one epoch and refresh the best/final summaries.

        Args:
            epoch: Zero-based epoch index (``total_epochs`` becomes ``epoch + 1``).
            running_loss: Total loss for the epoch.
            cat_loss: Categorical (classification) loss for the epoch.
            label_loss_val: Label-embedding loss for the epoch.
            map_score: Mean average precision measured for the epoch.
            detailed_ap: Optional per-class AP mapping; stored only when given.
        """
        has_breakdown = detailed_ap is not None

        record = {
            "epoch": epoch,
            "running_loss": float(running_loss),
            "categorical_loss": float(cat_loss),
            "label_loss": float(label_loss_val),
            "map_score": float(map_score),
        }
        if has_breakdown:
            record["detailed_ap"] = detailed_ap
        self.epoch_metrics.append(record)

        score = record["map_score"]

        # ">=" means a later epoch that ties the best score replaces it.
        if map_score >= self.best_map:
            self.best_map, self.best_epoch = score, epoch
            if has_breakdown:
                self.best_detailed_ap = detailed_ap

        self.final_map, self.final_epoch = score, epoch
        self.total_epochs = epoch + 1
        if has_breakdown:
            self.final_detailed_ap = detailed_ap

    def set_strategy_info(self, fusion_strategy: str, augmentation_strategy: str, scheduler_type: str) -> None:
        """Store the labels describing fusion, augmentation and LR scheduling."""
        self.fusion_strategy = fusion_strategy
        self.augmentation_strategy = augmentation_strategy
        self.scheduler_type = scheduler_type

    def save_to_json(self, output_dir: str = "./outputs", run_tag: Optional[str] = None) -> str:
        """Write the collected metrics to a timestamped JSON file.

        Args:
            output_dir: Directory for the report; created if it does not exist.
            run_tag: Optional suffix appended to the file name after the timestamp.

        Returns:
            Path of the JSON file that was written.
        """
        os.makedirs(output_dir, exist_ok=True)

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_stem = f"training_metrics_{timestamp}"
        if run_tag:
            file_stem += f"_{run_tag}"
        json_file = os.path.join(output_dir, file_stem + ".json")

        # Human-readable descriptions; labels not listed here map to "unknown".
        fusion_notes = {
            "basic": "Concatenate multimodal features (context/body/depth/object/head)."
        }
        augmentation_notes = {
            "standard": "Standard augmentations (flip, color jitter).",
            "aggressive": "Stronger augmentations (rotation/affine).",
            "emotion_focused": "Mild, stable augmentations tuned for affect recognition.",
            "minimal": "ToTensor only."
        }
        scheduler_notes = {
            "step": "StepLR decay.",
            "cosine": "Cosine annealing.",
            "cosine_restart": "Cosine annealing with warm restarts.",
        }

        summary = {
            "best_map": self.best_map,
            "best_epoch": self.best_epoch,
            "final_map": self.final_map,
            "final_epoch": self.final_epoch,
            "total_epochs": self.total_epochs,
            "training_completed": True,
            "timestamp": timestamp,
            "best_detailed_ap": self.best_detailed_ap,
            "final_detailed_ap": self.final_detailed_ap,
        }
        strategy_info = {
            "fusion_strategy": self.fusion_strategy,
            "augmentation_strategy": self.augmentation_strategy,
            "scheduler_type": self.scheduler_type,
            "fusion_strategy_description": fusion_notes.get(self.fusion_strategy, "unknown"),
            "augmentation_strategy_description": augmentation_notes.get(self.augmentation_strategy, "unknown"),
            "scheduler_type_description": scheduler_notes.get(self.scheduler_type, "unknown"),
        }
        metrics_data = {
            "summary": summary,
            "epoch_details": self.epoch_metrics,
            "strategy_info": strategy_info,
        }

        with open(json_file, "w", encoding="utf-8") as f:
            json.dump(metrics_data, f, indent=2, ensure_ascii=False)

        print(f"\nMetrics saved to: {json_file}")
        return json_file


class DataAugmentationStrategy:
    """String identifiers for the augmentation presets understood by
    `get_data_augmentation_transforms`."""

    # Flip + colour jitter.
    STANDARD = "standard"
    # Flip + stronger colour jitter + rotation + affine.
    AGGRESSIVE = "aggressive"
    # Less frequent flip, light colour jitter, small affine.
    EMOTION_FOCUSED = "emotion_focused"
    # No augmentation at all.
    MINIMAL = "minimal"


def get_data_augmentation_transforms(strategy: str = "standard", mode: str = "train"):
    """Build the torchvision transform pipeline for the given preset.

    Every pipeline starts with ``ToPILImage`` and ends with ``ToTensor``; the
    preset only decides which random augmentations sit in between. No
    augmentation is inserted when ``mode == "test"``, when the preset is
    ``"minimal"``, or when the preset name is not recognised.

    Args:
        strategy: One of the `DataAugmentationStrategy` values.
        mode: ``"test"`` disables augmentation; any other value enables it.

    Returns:
        A ``transforms.Compose`` instance.
    """
    augmentations = []

    if mode != "test":
        if strategy == DataAugmentationStrategy.STANDARD:
            augmentations = [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ]
        elif strategy == DataAugmentationStrategy.AGGRESSIVE:
            augmentations = [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
                transforms.RandomRotation(degrees=10),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            ]
        elif strategy == DataAugmentationStrategy.EMOTION_FOCUSED:
            augmentations = [
                transforms.RandomHorizontalFlip(p=0.3),
                transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15),
                transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
            ]
        # MINIMAL and unrecognised presets leave the list empty.

    return transforms.Compose([transforms.ToPILImage(), *augmentations, transforms.ToTensor()])


def get_learning_rate_scheduler(optimizer, scheduler_type: str = "step", **kwargs):
    """Create a learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        scheduler_type: ``"step"``, ``"cosine"`` or ``"cosine_restart"``.
        **kwargs: Optional overrides, read per scheduler type:
            step -> ``step_size`` (7), ``gamma`` (0.5);
            cosine -> ``T_max`` (50), ``eta_min`` (1e-6);
            cosine_restart -> ``T_0`` (10), ``T_mult`` (2), ``eta_min`` (1e-6).

    Returns:
        The matching scheduler. An unrecognised ``scheduler_type`` yields a
        ``StepLR(step_size=7, gamma=0.5)`` and ignores ``kwargs``.
    """
    option = kwargs.get

    if scheduler_type == "cosine":
        return CosineAnnealingLR(
            optimizer,
            T_max=option("T_max", 50),
            eta_min=option("eta_min", 1e-6),
        )

    if scheduler_type == "cosine_restart":
        return CosineAnnealingWarmRestarts(
            optimizer,
            T_0=option("T_0", 10),
            T_mult=option("T_mult", 2),
            eta_min=option("eta_min", 1e-6),
        )

    if scheduler_type == "step":
        return StepLR(
            optimizer,
            step_size=option("step_size", 7),
            gamma=option("gamma", 0.5),
        )

    # Unknown type: fixed default, caller overrides are deliberately not applied.
    return StepLR(optimizer, step_size=7, gamma=0.5)


def test_scikit_ap(cat_preds: np.ndarray, cat_labels: np.ndarray, return_detailed: bool = False):
    """Compute mean average precision over the 26 emotion categories.

    Both arrays are indexed by category along the first axis: row ``i`` holds
    the scores (or ground-truth labels) of category ``i`` for every sample.

    Args:
        cat_preds: Predicted scores, one row per category.
        cat_labels: Ground-truth labels, one row per category.
        return_detailed: When true, also return the per-category AP values.

    Returns:
        The mean AP, or ``(mean_ap, detailed_ap)`` where ``detailed_ap`` maps
        each category name to its AP when ``return_detailed`` is true.
    """
    emotion_categories = [
        "Anger", "Affection", "Annoyance", "Anticipation", "Aversion", "Confidence",
        "Disapproval", "Disconnection", "Disquietment", "Doubt/Confusion", "Embarrassment",
        "Engagement", "Esteem", "Excitement", "Fatigue", "Fear", "Happiness", "Pain",
        "Peace", "Pleasure", "Sadness", "Sensitivity", "Suffering", "Surprise", "Sympathy", "Yearning",
    ]

    # Per-category AP is held as float32, so reported values carry float32 rounding.
    per_class = np.array(
        [average_precision_score(cat_labels[idx, :], cat_preds[idx, :]) for idx in range(26)],
        dtype=np.float32,
    )
    mean_ap = per_class.mean()

    if not return_detailed:
        return mean_ap

    detailed_ap = {name: float(value) for name, value in zip(emotion_categories, per_class)}
    return mean_ap, detailed_ap


def test(
    models,
    device: torch.device,
    data_loader: DataLoader,
    num_images: int,
    label_emb: torch.Tensor,
    edge_sem: torch.Tensor,
    edge_cooccur: torch.Tensor,
    return_detailed: bool = False,
    ablate: str = "none",
    in_obj: Optional[int] = None,
):
    """
    Evaluate the backbones plus CGMN head on ``data_loader`` and return mAP.

    Args:
        models: Either six models ``(context, body, head, depth, object, caer)``
            or five models without the head entry.
        device: Device the models and batches are moved to.
        data_loader: Yields 8-tuples (with head images) or 7-tuples (without).
        num_images: Number of rows allocated for predictions and labels.
        label_emb, edge_sem, edge_cooccur: Passed straight through to the CGMN head.
        return_detailed: When true, also return the per-category AP mapping.
        ablate: ``"drop_context"``, ``"drop_body"``, ``"drop_depth"``,
            ``"drop_object"`` or ``"drop_head"`` replaces/omits that modality;
            any other value keeps every modality.
        in_obj: Object feature width used for the zero-filled object features
            when the head model has no ``in_obj`` attribute (2048 if also None).

    Returns:
        ``mAP`` or ``(mAP, detailed_ap)`` as produced by ``test_scikit_ap``.

    Raises:
        ValueError: If the CGMN head does not return a tuple/list of 3 or 4 items.

    All models are left in eval mode on ``device`` when the function returns.
    """
    if len(models) == 6:
        model_context, model_body, model_head, model_depth, model_object, model_caer = models
    else:
        model_context, model_body, model_depth, model_object, model_caer = models
        model_head = None

    networks = [model_caer, model_context, model_body, model_depth, model_object]
    if model_head is not None:
        networks.append(model_head)
    for net in networks:
        net.to(device)
    for net in networks:
        net.eval()

    skip_context = ablate == "drop_context"
    skip_body = ablate == "drop_body"
    skip_depth = ablate == "drop_depth"
    skip_object = ablate == "drop_object"
    skip_head = ablate == "drop_head"

    # Widths of the zero tensors that stand in for a dropped modality.
    dim_context = getattr(model_caer, "in_context", 2048)
    dim_body = getattr(model_caer, "in_body", 2048)
    dim_depth = getattr(model_caer, "in_depth", 2048)
    dim_obj = getattr(model_caer, "in_obj", None)
    if dim_obj is None:
        dim_obj = 2048 if in_obj is None else in_obj

    def _to_vec(x):
        # Collapse a 4-D feature map to one vector per sample; pass anything else through.
        if hasattr(x, "dim") and x.dim() == 4:
            return F.adaptive_avg_pool2d(x, 1).flatten(1)
        return x

    cat_preds = np.zeros((num_images, 26))
    cat_labels = np.zeros((num_images, 26))
    cursor = 0

    with torch.no_grad():
        for batch in data_loader:
            if len(batch) == 8:
                (images_context, images_body, images_head, images_depth,
                 images_object, obj_label, obj_dist, labels_cat) = batch
                if model_head is not None and not skip_head:
                    images_head = images_head.to(device)
            else:
                (images_context, images_body, images_depth,
                 images_object, obj_label, obj_dist, labels_cat) = batch
                images_head = None

            labels_cat = labels_cat.to(device)
            obj_label = obj_label.to(device)
            obj_dist = obj_dist.to(device)
            batch_size = labels_cat.size(0)

            # Context features.
            if skip_context:
                f_context = torch.zeros(batch_size, dim_context, device=device, dtype=torch.float32)
            else:
                f_context = _to_vec(model_context(images_context.to(device)))

            # Body features.
            if skip_body:
                f_body = torch.zeros(batch_size, dim_body, device=device, dtype=torch.float32)
            else:
                f_body = _to_vec(model_body(images_body.to(device)))

            # Head features are optional and simply omitted when unavailable.
            head_kwargs = {}
            if (not skip_head) and (model_head is not None) and (images_head is not None):
                head_kwargs["f_head"] = _to_vec(model_head(images_head))

            # Depth features.
            if skip_depth:
                f_depth = torch.zeros(batch_size, dim_depth, device=device, dtype=torch.float32)
            else:
                f_depth = _to_vec(model_depth(images_depth.to(device)))

            # Object features: one backbone pass per object slot (dim 1 of the batch).
            if skip_object:
                f_obj = torch.zeros(batch_size, 4, dim_obj, device=device, dtype=torch.float32)
            else:
                images_object = images_object.to(device)
                per_slot = [
                    _to_vec(model_object(images_object[:, slot, :, :, :]))
                    for slot in range(images_object.shape[1])
                ]
                f_obj = torch.stack(per_slot, dim=1)

            out = model_caer(
                f_context, f_body, f_depth, f_obj, obj_label, obj_dist,
                label_emb, edge_sem, edge_cooccur, **head_kwargs
            )

            if not isinstance(out, (list, tuple)):
                raise ValueError("model_caer must return a tuple/list.")
            if len(out) not in (3, 4):
                raise ValueError(f"model_caer returned {len(out)} outputs, expected 3 or 4.")
            # Only the category predictions are needed for evaluation.
            pred_cat = out[0]

            n_pred = pred_cat.shape[0]
            cat_preds[cursor:(cursor + n_pred), :] = pred_cat.detach().cpu().numpy()
            cat_labels[cursor:(cursor + labels_cat.shape[0]), :] = labels_cat.detach().cpu().numpy()
            cursor += n_pred

    # test_scikit_ap expects one row per category.
    preds_by_class = cat_preds.transpose()
    labels_by_class = cat_labels.transpose()

    if return_detailed:
        mAP, detailed_ap = test_scikit_ap(preds_by_class, labels_by_class, return_detailed=True)
        return mAP, detailed_ap
    return test_scikit_ap(preds_by_class, labels_by_class)


def train(norm, args) -> float:
    """
    Train the backbones and the CGMN head, evaluating on the test split after every epoch.

    `args` is expected to expose:
      - data_path (str): directory containing *_arr_new.npy files.
      - output_dir (str): directory to save metrics; default "./outputs".
      - label_dir (Optional[str]): directory containing label files (emo_def.npy, emo_sim.npy/label_sem.npy, cooccur.npy/label_occur.npy).
      - epochs (int), batch_size (int), gpu (int or str), use_head (bool), ablate (str), loss_ratio (float).
      - augmentation_strategy (str): one of {"standard","aggressive","emotion_focused","minimal"}.
      - scheduler_type (str): one of {"step","cosine","cosine_restart"} and related params below if needed.
      - lr (float), lr_ratio (float), wd (float), step_size (int), gamma (float),
        T_max (int), T_0 (int), T_mult (int), eta_min (float).
      - t_sem (float), t_cooccur (float), p (float): thresholds for label graph.

    Only `data_path`, `epochs` and `batch_size` (plus `gpu` when CUDA is available) are
    required; every other attribute, including `ablate` (default "none"), falls back to
    the default used in the body when it is missing.

    Args:
        norm: Passed through unchanged to `EmoticDataset` for both splits.
        args: Configuration namespace described above.

    Returns:
        The highest test mAP observed across epochs as a built-in float
        (0.0 if no epoch ran). Ties are resolved in favour of the later epoch.

    Raises:
        FileNotFoundError: If a required ``*_arr_new.npy`` array or a label file is missing.
        ValueError: If the CGMN head does not return a tuple/list of 3 or 4 items.
    """
    metrics_collector = TrainingMetricsCollector()

    # Read the optional switches once so every later use agrees on the same value
    # (and a namespace without `ablate` behaves as "none" everywhere).
    ablate = getattr(args, "ablate", "none")
    use_head = getattr(args, "use_head", False)

    # --------------------
    # Data loading
    # --------------------
    def _load_array(split: str, stem: str, label: str) -> np.ndarray:
        """Load ``<split>_<stem>_arr_new.npy`` from args.data_path and report its shape."""
        arr = np.load(os.path.join(args.data_path, f"{split}_{stem}_arr_new.npy"))
        print(f"Loaded {split} {label}: {arr.shape}")
        return arr

    def _load_split(split: str):
        """Load every modality array for one split; the head array is optional."""
        context = _load_array(split, "context", "context")
        body = _load_array(split, "body", "body")
        head = None
        if use_head:
            try:
                head = _load_array(split, "head", "head")
            except FileNotFoundError:
                print(f"{split}_head_arr_new.npy not found. Head modality will be zero-filled.")
        depth = _load_array(split, "depth", "depth")
        obj = _load_array(split, "obj", "object")
        obj_label = _load_array(split, "obj_label", "object labels")
        obj_dist = _load_array(split, "obj_dist", "object distances")
        cat = _load_array(split, "cat", "categories")
        return context, body, head, depth, obj, obj_label, obj_dist, cat

    (train_context, train_body, train_head, train_depth,
     train_object, train_obj_label, train_obj_dist, train_cat) = _load_split("train")
    (test_context, test_body, test_head, test_depth,
     test_object, test_obj_label, test_obj_dist, test_cat) = _load_split("test")

    # --------------------
    # Augmentations
    # --------------------
    augmentation_strategy = getattr(args, "augmentation_strategy", DataAugmentationStrategy.STANDARD)
    print(f"Using augmentation strategy: {augmentation_strategy}")
    train_transform = get_data_augmentation_transforms(strategy=augmentation_strategy, mode="train")
    test_transform = get_data_augmentation_transforms(strategy=augmentation_strategy, mode="test")

    # --------------------
    # Datasets / Loaders
    # --------------------
    train_dataset = EmoticDataset(
        train_context, train_body, train_depth, train_object, train_obj_label, train_obj_dist,
        train_cat, train_transform, norm, x_head=train_head
    )
    test_dataset = EmoticDataset(
        test_context, test_body, test_depth, test_object, test_obj_label, test_obj_dist,
        test_cat, test_transform, norm, x_head=test_head
    )

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=8, pin_memory=True)

    # --------------------
    # Backbones (feature extractors)
    # --------------------
    model_context, model_body, model_depth, model_object, model_head = prep_models()
    if use_head and (train_head is not None):
        in_head = list(model_head.children())[-1].in_features
    else:
        model_head = None
        in_head = 0

    # Feature widths are read from each backbone's last child before it is stripped below.
    in_context = list(model_context.children())[-1].in_features
    in_body = list(model_body.children())[-1].in_features
    in_depth = list(model_depth.children())[-1].in_features
    in_obj = list(model_object.children())[-1].in_features
    print(f"Feature dims - Context: {in_context}, Body: {in_body}, Depth: {in_depth}, Object: {in_obj}, Head: {in_head}")

    # --------------------
    # Label embeddings & adjacency
    # --------------------
    # Search directories for label files: args.label_dir (if provided), args.data_path, current dir, and file dir.
    label_search_dirs = []
    label_dir = getattr(args, "label_dir", None)
    if label_dir:
        label_search_dirs.append(label_dir)
    label_search_dirs.extend([args.data_path, ".", os.path.dirname(__file__)])
    # de-duplicate while preserving order
    label_search_dirs = list(dict.fromkeys(label_search_dirs))

    def _load_from_dirs(filenames: List[str], dirs: List[str]) -> np.ndarray:
        """Return the first existing file, trying every filename in each directory in order."""
        for d in dirs:
            for fn in filenames:
                candidate = os.path.join(d, fn)
                if os.path.isfile(candidate):
                    print(f"Loaded: {candidate}")
                    return np.load(candidate)
        raise FileNotFoundError(f"Files not found in {dirs}: {filenames}")

    label_emb = _load_from_dirs(["emo_def.npy"], label_search_dirs)
    label_sem = _load_from_dirs(["emo_sim.npy", "label_sem.npy"], label_search_dirs)
    label_occur = _load_from_dirs(["cooccur.npy", "label_occur.npy"], label_search_dirs)

    t_sem = getattr(args, "t_sem", 0.8)
    t_cooccur = getattr(args, "t_cooccur", 0.3)
    p = getattr(args, "p", 0.5)
    edge_sem = gen_label_adj(label_sem, t_sem, p).astype(np.float32)
    edge_cooccur = gen_label_adj(label_occur, t_cooccur, p).astype(np.float32)

    # --------------------
    # Model head (ablation-aware)
    # --------------------
    def _make_head(head_cls):
        """Instantiate a CGMN variant with the shared constructor arguments."""
        return head_cls(
            in_context, in_body, in_depth, in_obj,
            d_obj=300, in_label=label_emb.shape[1], in_head=in_head,
        )

    # An if/elif chain (rather than a table of classes) keeps each class name
    # from being looked up unless its ablation is actually selected.
    model_caer = None
    if ablate == "drop_context":
        model_caer = _make_head(CGMN_fgs)
        print("Ablation: drop context")
    elif ablate == "drop_body":
        model_caer = _make_head(CGMN_fs)
        print("Ablation: drop body")
    elif ablate == "drop_object":
        model_caer = _make_head(CGMN_fo)
        print("Ablation: drop object")
    elif ablate == "drop_depth":
        model_caer = _make_head(CGMN_fgd)
        print("Ablation: drop depth")
    elif ablate == "drop_head":
        # No dedicated variant: the baseline head is built below.
        print("Ablation: drop head (handled by zero features)")
    elif ablate == "wodasor":
        model_caer = _make_head(CGMN_darn)
        print("Ablation: w/o DASOR")
    elif ablate == "wohegr":
        model_caer = _make_head(CGMN_hegr)
        print("Ablation: w/o HEGR (MLP only)")
    elif ablate == "womlp":
        model_caer = _make_head(CGMN_mvec)
        print("Ablation: w/o MLP (graph scores only)")
    elif ablate == "wo_sem_1024":
        model_caer = _make_head(CGMN_wo_sem_1024_noloss)
        print("Ablation: w/o label similarity (only co-occur GCN; no label loss)")
    elif ablate == "wo_cooccur_1024":
        model_caer = _make_head(CGMN_wo_cooccur_1024_noloss)
        print("Ablation: w/o label co-occur (only similarity GCN; no label loss)")

    fusion_strategy = "basic"
    print(f"Fusion strategy: {fusion_strategy}")
    if model_caer is None:
        model_caer = _make_head(CGMN_Basic)
        print("Loaded baseline CGMN_Basic head.")

    # Remove backbone classification heads; keep feature extractors only.
    model_context = nn.Sequential(*(list(model_context.children())[:-1]))
    model_body = nn.Sequential(*(list(model_body.children())[:-1]))
    if model_head is not None:
        model_head = nn.Sequential(*(list(model_head.children())[:-1]))
    model_depth = nn.Sequential(*(list(model_depth.children())[:-1]))
    model_object = nn.Sequential(*(list(model_object.children())[:-1]))

    # Backbones in the order their parameters enter the optimizer group.
    backbones = [model_context, model_body, model_depth, model_object]
    if model_head is not None:
        backbones.append(model_head)

    # Everything is fine-tuned: backbones and the CGMN head.
    for net in backbones + [model_caer]:
        for p_ in net.parameters():
            p_.requires_grad = True

    # --------------------
    # Optimizer / Scheduler
    # --------------------
    lr = getattr(args, "lr", 1e-4)
    lr_ratio = getattr(args, "lr_ratio", 1e-2)
    weight_decay = getattr(args, "wd", 1e-4)

    scheduler_type = getattr(args, "scheduler_type", "step")
    step_size = getattr(args, "step_size", 7)
    gamma = getattr(args, "gamma", 0.5)
    T_max = getattr(args, "T_max", max(10, args.epochs // 2))
    T_0 = getattr(args, "T_0", 10)
    T_mult = getattr(args, "T_mult", 2)
    eta_min = getattr(args, "eta_min", 1e-6)

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    resnet_params = [p_ for net in backbones for p_ in net.parameters()]

    # The head trains at `lr`; the backbones at the scaled-down `lr * lr_ratio`.
    opt = optim.Adam(
        [
            {"params": model_caer.parameters(), "lr": lr},
            {"params": resnet_params, "lr": lr * lr_ratio},
        ],
        weight_decay=weight_decay,
    )

    if scheduler_type == "step":
        scheduler = get_learning_rate_scheduler(opt, "step", step_size=step_size, gamma=gamma)
        print(f"Using StepLR: step_size={step_size}, gamma={gamma}")
    elif scheduler_type == "cosine":
        scheduler = get_learning_rate_scheduler(opt, "cosine", T_max=T_max, eta_min=eta_min)
        print(f"Using CosineAnnealingLR: T_max={T_max}, eta_min={eta_min}")
    elif scheduler_type == "cosine_restart":
        scheduler = get_learning_rate_scheduler(opt, "cosine_restart", T_0=T_0, T_mult=T_mult, eta_min=eta_min)
        print(f"Using CosineAnnealingWarmRestarts: T_0={T_0}, T_mult={T_mult}, eta_min={eta_min}")
    else:
        scheduler = StepLR(opt, step_size=7, gamma=0.5)
        print("Using default StepLR.")

    # --------------------
    # Losses & tensors
    # --------------------
    disc_loss = DiscreteLoss(device)
    label_loss_fn = LabelLoss(device)

    for net in [model_caer] + backbones:
        net.to(device)

    label_emb_t = torch.from_numpy(label_emb).to(device)
    edge_sem_t = torch.from_numpy(edge_sem).to(device)
    edge_cooccur_t = torch.from_numpy(edge_cooccur).to(device)
    label_occur_t = torch.from_numpy(label_occur).to(device)
    label_sem_t = torch.from_numpy(label_sem).to(device)

    print("Start training!")

    best_epoch = 0
    best_map = 0.0
    # Guard the per-epoch averages against an empty training loader.
    num_batches = max(len(train_loader), 1)

    def _to_vec(x):
        # Collapse a 4-D feature map to one vector per sample; pass anything else through.
        return F.adaptive_avg_pool2d(x, 1).flatten(1) if (hasattr(x, "dim") and x.dim() == 4) else x

    DROP_CONTEXT = (ablate == "drop_context")
    DROP_BODY = (ablate == "drop_body")
    DROP_DEPTH = (ablate == "drop_depth")
    DROP_OBJECT = (ablate == "drop_object")
    DROP_HEAD = (ablate == "drop_head") or (not use_head)

    # Loop-invariant loss configuration.
    ablate_no_label_loss = ablate in ["wohegr", "wo_sem_1024", "wo_cooccur_1024"]
    loss_ratio = getattr(args, "loss_ratio", 0.8)
    zero_label_loss = torch.tensor(0.0, device=device)

    # Models handed to test(): the head backbone is included only when it is in use.
    if model_head is not None and not DROP_HEAD:
        pack = [model_context, model_body, model_head, model_depth, model_object, model_caer]
    else:
        pack = [model_context, model_body, model_depth, model_object, model_caer]

    for e in range(args.epochs):
        running_loss = 0.0
        running_cat_loss = 0.0
        running_label_loss = 0.0

        # test() leaves the models in eval mode, so switch back every epoch.
        for net in backbones + [model_caer]:
            net.train()

        for batch_data in train_loader:
            if len(batch_data) == 8:
                (images_context, images_body, images_head, images_depth,
                 images_object, obj_label, obj_dist, labels_cat) = batch_data
                if model_head is not None and not DROP_HEAD:
                    images_head = images_head.to(device)
            else:
                (images_context, images_body, images_depth,
                 images_object, obj_label, obj_dist, labels_cat) = batch_data
                images_head = None

            labels_cat = labels_cat.to(device)
            obj_label = obj_label.to(device)
            obj_dist = obj_dist.to(device)

            opt.zero_grad()
            B = labels_cat.size(0)

            # A dropped modality is replaced by zeros and its images never leave the CPU.
            if DROP_CONTEXT:
                f_context = torch.zeros(B, in_context, device=device, dtype=torch.float32)
            else:
                f_context = _to_vec(model_context(images_context.to(device)))

            if DROP_BODY:
                f_body = torch.zeros(B, in_body, device=device, dtype=torch.float32)
            else:
                f_body = _to_vec(model_body(images_body.to(device)))

            head_kwargs = {}
            if (not DROP_HEAD) and (model_head is not None) and (images_head is not None):
                head_kwargs["f_head"] = _to_vec(model_head(images_head))

            if DROP_DEPTH:
                f_depth = torch.zeros(B, in_depth, device=device, dtype=torch.float32)
            else:
                f_depth = _to_vec(model_depth(images_depth.to(device)))

            if DROP_OBJECT:
                f_obj = torch.zeros(B, 4, in_obj, device=device, dtype=torch.float32)
            else:
                images_object = images_object.to(device)
                # One backbone pass per object slot (dim 1 of the batch).
                f_obj = torch.stack(
                    [
                        _to_vec(model_object(images_object[:, i, :, :, :]))
                        for i in range(images_object.shape[1])
                    ],
                    dim=1,
                )

            out = model_caer(
                f_context, f_body, f_depth, f_obj, obj_label, obj_dist,
                label_emb_t, edge_sem_t, edge_cooccur_t, **head_kwargs
            )

            if not isinstance(out, (list, tuple)):
                raise ValueError("model_caer must return a tuple/list.")
            if len(out) not in (3, 4):
                raise ValueError(f"model_caer returned {len(out)} outputs, expected 3 or 4.")
            # A fourth output, when present, is not used for training.
            pred_cat, label_emb_sem, label_emb_cooccur = out[:3]

            cat_loss_batch = disc_loss(pred_cat, labels_cat)

            if (not ablate_no_label_loss) and (label_emb_sem is not None) and (label_emb_cooccur is not None):
                label_embedding_loss_batch = label_loss_fn(label_emb_sem, label_emb_cooccur, label_sem_t, label_occur_t)
            else:
                label_embedding_loss_batch = zero_label_loss

            if ablate_no_label_loss:
                loss = cat_loss_batch
            else:
                loss = loss_ratio * cat_loss_batch + (1.0 - loss_ratio) * label_embedding_loss_batch

            running_loss += loss.item()
            running_cat_loss += cat_loss_batch.item()
            running_label_loss += float(label_embedding_loss_batch)

            loss.backward()
            opt.step()

        scheduler.step()

        avg_running_loss = running_loss / num_batches
        avg_cat_loss = running_cat_loss / num_batches
        avg_label_loss = running_label_loss / num_batches

        # The per-class breakdown is requested on every 10th epoch and on the last one.
        get_detailed = (e == args.epochs - 1) or (e % 10 == 9)
        result = test(
            pack, device, test_loader, len(test_dataset),
            label_emb_t, edge_sem_t, edge_cooccur_t,
            return_detailed=get_detailed, ablate=ablate, in_obj=in_obj
        )
        if get_detailed:
            mAP, detailed_ap = result
        else:
            mAP, detailed_ap = result, None

        metrics_collector.add_epoch_metrics(e, avg_running_loss, avg_cat_loss, avg_label_loss, mAP, detailed_ap)

        print(f"Epoch {e+1}/{args.epochs}:")
        print(f"  total loss: {avg_running_loss:.6f}")
        print(f"  classification loss: {avg_cat_loss:.6f}")
        print(f"  label loss: {avg_label_loss:.6f}")
        print(f"  mAP: {mAP:.6f}")

        if detailed_ap is not None:
            sorted_ap = sorted(detailed_ap.items(), key=lambda x: x[1], reverse=True)
            print("  Top-5 classes by AP:")
            for i, (category, ap_val) in enumerate(sorted_ap[:5]):
                print(f"    {i+1}. {category}: {ap_val:.4f}")
            print("  Bottom-5 classes by AP:")
            for i, (category, ap_val) in enumerate(sorted_ap[-5:]):
                print(f"    {len(sorted_ap)-4+i}. {category}: {ap_val:.4f}")

        if mAP >= best_map:
            best_map = mAP
            best_epoch = e
            print(f"  *** new best mAP: {best_map:.6f} (epoch {best_epoch+1}) ***")
        print("-" * 50)

    print("=" * 60)
    print("Best epoch:", best_epoch + 1, ", Best mAP:", best_map)
    print("=" * 60)

    metrics_collector.set_strategy_info(fusion_strategy, augmentation_strategy, scheduler_type)
    metrics_collector.save_to_json(output_dir=getattr(args, "output_dir", "./outputs"), run_tag=None)

    # Return a built-in float, matching the annotation.
    return float(best_map)
