from collections import defaultdict
import os
import random
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.init as init
import sys
from argparse import ArgumentParser
from tqdm import tqdm
import numpy as np
from metrics import compute_metrics_top_k
# Import custom modules
model_path = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../../models'))
sys.path.append(model_path)
from llava_clip_model_v3 import PredicateModel
from vidvrd_dataset import open_vidvrd_loader

# -----------------------
# 1) Distributed setup utilities
# -----------------------
def ddp_setup(rank, world_size):
    """
    Join the NCCL process group used by DistributedDataParallel.

    Points the rendezvous at localhost:12446 (overwriting any MASTER_ADDR /
    MASTER_PORT already present in the environment), registers this process
    as ``rank`` out of ``world_size`` and makes CUDA device ``rank`` current.
    """
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "12446"}
    os.environ.update(rendezvous)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

# -----------------------
# 2) Parse script arguments
# -----------------------
def parse_args(model_name=None, epoch_num=None, argv=None):
    """
    Parse command line arguments for training/finetuning.

    Args:
        model_name: Default value for ``--model-name``.
        epoch_num: Default value for ``--model-epoch``.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use).

    Returns:
        argparse.Namespace holding every option below plus two derived paths:
        ``data_dir`` (``../../../../data_local/<dataset>`` relative to this
        file) and ``log_dir`` (``<finetune_save_dir>/logs``).

    Side effects:
        Seeds ``torch``, ``random`` and ``numpy`` with ``--seed`` and creates
        ``log_dir`` if it does not exist. Invalid arguments (including a
        missing ``--finetune-save-dir``) make argparse exit with status 2.
    """
    parser = ArgumentParser("Finetune/Train Script")

    parser.add_argument(
        "--dataset",
        type=str,
        default="vidvrd-dataset",
        choices=["vidvrd-dataset", "ActionGenome"],
        help="Which dataset to use (default: vidvrd-dataset)."
    )
    parser.add_argument(
        "--phase",
        type=str,
        default='train',
        help="Set to 'train' to run finetuning; 'test' or others for evaluation or custom phases."
    )

    # These switches default to True. A plain ``store_true`` with default=True can
    # never be turned off, so each one also gets a ``--no-...`` counterpart that
    # writes False into the same destination.
    parser.add_argument("--load-model", dest="load_model", action="store_true", default=True,
                        help="Whether to attempt loading a previous model checkpoint.")
    parser.add_argument("--no-load-model", dest="load_model", action="store_false",
                        help="Disable --load-model.")
    parser.add_argument("--save-model", dest="save_model", action="store_true", default=True,
                        help="Whether to save model checkpoints after each epoch.")
    parser.add_argument("--no-save-model", dest="save_model", action="store_false",
                        help="Disable --save-model.")
    parser.add_argument("--clip-model-name", type=str, default="openai/clip-vit-base-patch32",
                        help="Name of the CLIP model to load (huggingface style).")

    parser.add_argument("--test-num-top-pairs", type=int, default=30,
                        help="Max number of bounding-box pairs to consider for relationships.")
    parser.add_argument("--max-video-len", type=int, default=12,
                        help="Max length of a sampled video segment.")

    parser.add_argument("--train-num", type=int, default=5000,
                        help="Number of training samples (if your dataset loading logic supports it).")
    parser.add_argument("--val-num", type=int, default=1000,
                        help="Number of validation samples (if your dataset loading logic supports it).")
    parser.add_argument("--test-percentage", type=int, default=100,
                        help="Percentage of test data to use (if your dataset loading logic supports it).")

    # Basic training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1,
                        help="Batch size for both training and evaluation.")
    parser.add_argument("--seed", type=int, default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--model-name", type=str, default=model_name,
                        help="Custom string to identify the model/checkpoint filenames.")
    parser.add_argument("--model-epoch", type=int, default=epoch_num,
                        help="If set, tries to load that specific epoch from a checkpoint.")
    parser.add_argument("--n_epochs", type=int, default=500,
                        help="Number of epochs to train/finetune for.")

    # Checkpoint and output directories
    parser.add_argument("--model-dir", type=str,
                        help="Path to the directory containing original/pretrained model checkpoints.")
    # Required: log_dir is derived from it below, and os.path.join(None, ...) would fail.
    parser.add_argument("--finetune-save-dir", type=str, required=True,
                        help="Directory where finetuning checkpoints/logs will be saved.")

    parser.add_argument("--use-cuda", dest="use_cuda", action="store_true", default=True,
                        help="Whether to use CUDA.")
    parser.add_argument("--no-use-cuda", dest="use_cuda", action="store_false",
                        help="Disable --use-cuda.")
    parser.add_argument("--use-half", action="store_true",
                        help="Whether to cast model to half precision (fp16).")
    parser.add_argument("--use-ddp", action="store_true",
                        help="Whether to use DistributedDataParallel multi-GPU training.")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="Specific GPU index if not using DDP or if single-GPU usage is desired.")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Learning rate for the optimizer.")

    # Data slicing / partial training
    parser.add_argument("--splice-start", type=int, default=0,
                        help="Start index for a data slice (for large dataset partial usage).")
    parser.add_argument("--splice-size", type=int, default=1,
                        help="Size of a data slice.")
    # argparse %-formats help strings, so a literal percent sign must be written as %%
    # (a bare "10%," made --help crash with ValueError).
    parser.add_argument("--ft_split", type=int, default=20,
                        help="Used to train on a fraction of the data. e.g. 10 => 10%%, 5 => 5%%, etc.")

    args = parser.parse_args(argv)

    # Seed every RNG the script has access to.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Construct data directory from script location
    script_dir = os.path.abspath(os.path.dirname(__file__))
    args.data_dir = os.path.abspath(os.path.join(script_dir, f"../../../../data_local/{args.dataset}"))

    # Log directory inside finetune-save-dir
    args.log_dir = os.path.join(args.finetune_save_dir, "logs")
    os.makedirs(args.log_dir, exist_ok=True)

    return args


import torch
import torch.nn.functional as F

def hinge_ranking_loss_pos_neg(pos_score, neg_score, margin=0.2):
    """
    Hinge ranking loss for one (positive, negative) score pair.

    Computes ``max(0, margin + neg_score - pos_score)``: the loss is zero once
    the positive outscores the negative by at least ``margin`` and grows
    linearly otherwise. Scores are expected to be tensors (``F.relu`` is applied).

    For probability-valued scores in [0, 1], a smaller margin such as 0.1-0.3
    is a sensible choice.
    """
    violation = margin + neg_score - pos_score
    return F.relu(violation)


def compute_object_ranking_loss(
    cate_probs_dict,
    gt_oid_to_label,
    margin=0.2,
    top_neg=5
):
    """
    Object classification ranking loss.

    For each object that has a ground-truth label:
      - pos_score is the predicted probability of the GT label (treated as 0
        when the model produced no score for that label),
      - the ``top_neg`` highest-probability incorrect labels are the negatives,
      - one hinge term ``max(0, margin + neg - pos)`` is taken per negative.

    Arguments:
        cate_probs_dict: dict {(oid, cat_name): prob} from the model.
        gt_oid_to_label: dict {oid: ground_truth_label_str}; objects missing
            from it are skipped.
        margin: margin for the hinge.
        top_neg: how many negative labels to use for each object.

    Returns:
        Scalar tensor: mean of all hinge terms, or 0.0 (on the device of the
        first tensor prediction, CPU if none) when there are no terms.
    """
    device = None

    # Step 1) Group predictions by object id; remember the device of the first tensor seen.
    pred_by_oid = defaultdict(list)
    for (oid, cat_str), prob_val in cate_probs_dict.items():
        if device is None and isinstance(prob_val, torch.Tensor):
            device = prob_val.device
        pred_by_oid[oid].append((cat_str, prob_val))

    # Step 2) Rank the GT label of each object against its hardest negatives.
    all_losses = []
    for oid, preds in pred_by_oid.items():
        if oid not in gt_oid_to_label:
            # skip objects that don't have a GT label
            continue
        gt_label = gt_oid_to_label[oid]

        pos_score = None
        negative_list = []
        for cat_str, prob_val in preds:
            if cat_str == gt_label:
                pos_score = prob_val
            else:
                negative_list.append((cat_str, prob_val))

        if pos_score is None:
            # The GT label was never scored: treat its probability as 0.
            # Without any tensor prediction there is no device to build it on.
            if device is None:
                continue
            pos_score = torch.zeros((), device=device)

        # Hardest negatives = highest predicted probability (stable sort keeps
        # the original prediction order on ties).
        negative_list.sort(key=lambda item: float(item[1]), reverse=True)
        for _, neg_prob in negative_list[:top_neg]:
            all_losses.append(hinge_ranking_loss_pos_neg(pos_score, neg_prob, margin=margin))

    if not all_losses:
        return torch.tensor(0.0, device=device if device is not None else "cpu")
    # Flatten before concatenating: individual terms may be 0-dim or 1-element
    # tensors, and torch.stack rejects a mix of shapes.
    return torch.cat([loss.reshape(-1) for loss in all_losses]).mean()


def compute_relation_ranking_loss(
    bin_probs_dict,
    gt_relations_per_frame,
    gt_oid_to_label,
    rel_label_list,
    margin=0.2,
    top_neg=5
):
    """
    Multi-label hinge ranking loss for relationships.

    For every (frame, subject, object) pair that has ground-truth relations AND
    at least one prediction, each correct relation is ranked against the
    ``top_neg`` highest-scoring relations of ``rel_label_list`` that are not
    ground truth for that pair. A relation the model did not score counts as
    probability 0.

    Arguments:
        bin_probs_dict: dict {(frame_id, (sid, oid), rel_name): prob}.
        gt_relations_per_frame: sequence indexed by frame id; entry f is an
            iterable of (subj_id, obj_id, rel_str) triples for frame f.
        gt_oid_to_label: {oid: object_label_str}. Not used by the computation;
            kept so existing call sites stay valid.
        rel_label_list: full relation vocabulary (e.g. train_loader.dataset.predicates).
        margin: hinge margin.
        top_neg: how many hardest negative relations to use per correct relation.

    Returns:
        Scalar tensor: mean of all hinge terms, or 0.0 (on the device of the
        first tensor prediction, CPU if none) when there are no terms.
    """
    device = None

    # 1) Index predictions by (fid, sid, oid) -> {rel_name: prob}.
    pred_per_pair = defaultdict(dict)
    for (fid, (sid, oid), rel_name), prob_val in bin_probs_dict.items():
        if device is None and isinstance(prob_val, torch.Tensor):
            device = prob_val.device
        pred_per_pair[(fid, sid, oid)][rel_name] = prob_val

    # 2) Collect the GT relation strings of each (fid, sid, oid). Duplicates are
    #    kept, so a relation listed twice contributes its terms twice.
    gt_per_pair = defaultdict(list)
    for fid, rels_in_f in enumerate(gt_relations_per_frame):
        for sid, oid, rel_str in rels_in_f:
            gt_per_pair[(fid, sid, oid)].append(rel_str)

    # Stand-in score for relations the model did not predict (None if no device is known).
    zero = torch.zeros((), device=device) if device is not None else None

    # 3) Hinge terms: each correct relation vs. the hardest negatives of its pair.
    all_losses = []
    for pair_key, gt_rels in gt_per_pair.items():
        pred_rel_dict = pred_per_pair.get(pair_key)
        if pred_rel_dict is None:
            # The model had no predictions for that pair => skip
            continue

        # The negatives depend only on the pair, so rank them once.
        # Stable sort keeps vocabulary order on ties.
        gt_rel_set = set(gt_rels)
        neg_rels = [r for r in rel_label_list if r not in gt_rel_set]
        neg_rels.sort(key=lambda r: float(pred_rel_dict.get(r, 0.0)), reverse=True)
        hardest_negatives = neg_rels[:top_neg]

        for correct_rel in gt_rels:
            pos_score = pred_rel_dict.get(correct_rel, zero)
            if pos_score is None:
                continue
            for nr in hardest_negatives:
                neg_score = pred_rel_dict.get(nr, zero)
                if neg_score is None:
                    continue
                all_losses.append(hinge_ranking_loss_pos_neg(pos_score, neg_score, margin=margin))

    if not all_losses:
        return torch.tensor(0.0, device=device if device is not None else "cpu")
    # Flatten before concatenating: terms may be 0-dim or 1-element tensors,
    # and torch.stack rejects a mix of shapes.
    return torch.cat([loss.reshape(-1) for loss in all_losses]).mean()



# -----------------------
# 3) Trainer class
# -----------------------
class Trainer:
    def __init__(
        self,
        train_loader,
        device,
        dataset,
        learning_rate,
        ft_split,
        model_dir,
        finetune_save_dir,
        model_name,
        model_epoch,
        load_model,
        test_num_top_pairs,
        clip_model_name,
        use_half,
        world_size,
        use_ddp
    ):
        """
        Build or restore the CLIP-based PredicateModel and prepare it for finetuning.

        Checkpoint resolution when ``load_model`` is true:
          1. ``finetune_save_dir``: ``*.model`` files whose name ends in
             ``.<int>.model`` and, when ``model_name`` is set, contains
             ``{model_name}_{model_epoch}_fts{ft_split}_lr{learning_rate}``.
             The highest ``<int>`` is loaded and ``current_ft_epoch`` becomes
             that value + 1.
          2. Otherwise, if ``model_dir`` exists and is non-empty:
             ``{model_name}.{epoch}.model``, where epoch is ``model_epoch`` or
             the highest numeric epoch found, falling back to
             ``{model_name}.latest.model``. ``epoch_ct`` is set accordingly.
          3. Otherwise a fresh PredicateModel is constructed.

        All parameters except those whose name contains ``visual_projection``
        or ``text_projection`` are then frozen, the model is optionally wrapped
        in DistributedDataParallel, and an Adam optimizer over the trainable
        parameters is created.

        Args:
            train_loader: loader whose ``.dataset.objects`` / ``.dataset.predicates``
                provide the object and relation vocabularies.
            device: target device for ``.to`` and DDP. It is also formatted as
                ``'cuda:<device>'`` for ``torch.load``, so presumably an int GPU
                index (TODO confirm).
            dataset: dataset name; stored only.
            learning_rate: Adam learning rate; also part of the finetune checkpoint name filter.
            ft_split: data-fraction tag; part of the finetune checkpoint name filter.
            model_dir: directory of pretrained checkpoints (may be None).
            finetune_save_dir: directory scanned for finetuning checkpoints.
            model_name: checkpoint name prefix / filter (may be None).
            model_epoch: epoch to load from ``model_dir``; initial ``epoch_ct``.
            load_model: whether to look for checkpoints at all.
            test_num_top_pairs: value assigned to the model's ``num_top_pairs``.
            clip_model_name: CLIP backbone name passed to PredicateModel.
            use_half: cast the model to fp16.
            world_size: number of DDP processes; stored only.
            use_ddp: wrap the model in DistributedDataParallel.
        """
        self.dataset = dataset
        self.train_loader = train_loader
        self.device = device
        self.model_dir = model_dir
        self.finetune_save_dir = finetune_save_dir
        self.model_name = model_name
        self.world_size = world_size
        self.use_ddp = use_ddp
        self.ft_split = ft_split
        self.test_num_top_pairs = test_num_top_pairs
        self.epoch_ct = model_epoch
        self.learning_rate = learning_rate

        predicate_model = None
        loaded_from_finetune_ckpt = False
        self.current_ft_epoch = 1
        # NOTE(review): recent torch releases default torch.load to weights_only=True,
        # which rejects pickled nn.Module checkpoints (the PredicateModel / DDP cases
        # handled by _predicate_model_from_checkpoint). torch.load also unpickles
        # arbitrary objects, so only load trusted checkpoints.
        map_location = 'cuda:' + str(self.device)

        if load_model:
            # 1) First look in finetune_save_dir for existing checkpoints.
            #    os.path.exists(None) raises TypeError, hence the explicit None guard.
            if self.finetune_save_dir is not None and os.path.exists(self.finetune_save_dir):
                name_tag = f"{self.model_name}_{self.epoch_ct}_fts{self.ft_split}_lr{self.learning_rate}"
                ft_ckpts = []
                for fn in os.listdir(self.finetune_save_dir):
                    if not fn.endswith(".model"):
                        continue
                    if self.model_name is not None and name_tag not in fn:
                        continue
                    # Epoch = last dot-separated chunk before ".model", e.g. "mymodel.3.model" -> 3.
                    last_chunk = fn[:-len(".model")].rsplit('.', 1)[-1]
                    if last_chunk.isdigit():
                        ft_ckpts.append((int(last_chunk), fn))

                if ft_ckpts:
                    max_epoch, best_ckpt = max(ft_ckpts, key=lambda x: x[0])
                    ckpt_path = os.path.join(self.finetune_save_dir, best_ckpt)
                    print(f"[Info] Found fine-tuning checkpoints in {self.finetune_save_dir}")
                    print(f"[Info] Loading highest epoch {max_epoch} from {ckpt_path}")
                    model_info = torch.load(ckpt_path, map_location=map_location)
                    loaded_from_finetune_ckpt = True
                    self.current_ft_epoch = max_epoch + 1
                    predicate_model = self._predicate_model_from_checkpoint(
                        model_info, test_num_top_pairs, device, clip_model_name
                    )

            # 2) If no fine-tuning checkpoint was loaded, look in model_dir.
            if (
                not loaded_from_finetune_ckpt
                and model_dir is not None
                and os.path.exists(model_dir)
                and os.listdir(model_dir)
            ):
                print(f"[Info] Loading from pretrained directory: {model_dir}")
                current_model_names = [
                    nm for nm in os.listdir(model_dir)
                    if (model_name is not None and model_name in nm)
                ]
                # Epoch id = second-to-last dot-separated chunk ("name.12.model" -> 12).
                # Names without a dot carry no epoch id (indexing [-2] on them raised IndexError).
                digital_model_ids = []
                for nm in current_model_names:
                    parts = nm.split('.')
                    if len(parts) >= 2 and parts[-2].isdigit():
                        digital_model_ids.append(int(parts[-2]))

                # Determine checkpoint epoch or "latest"
                if model_epoch is not None:
                    latest_model_id = model_epoch
                elif digital_model_ids:
                    latest_model_id = max(digital_model_ids)
                else:
                    latest_model_id = 'latest'

                # Construct checkpoint path
                if isinstance(latest_model_id, int):
                    load_fname = f"{model_name}.{latest_model_id}.model"
                else:
                    load_fname = f"{model_name}.latest.model"
                ckpt_path = os.path.join(model_dir, load_fname)
                print(f"[Info] Loading checkpoint: {ckpt_path}")

                model_info = torch.load(ckpt_path, map_location=map_location)
                predicate_model = self._predicate_model_from_checkpoint(
                    model_info, test_num_top_pairs, device, clip_model_name
                )

                self.epoch_ct = latest_model_id if isinstance(latest_model_id, int) else 0

        # If no model was successfully loaded above, create a fresh model
        if predicate_model is None:
            print("[Info] Constructing a fresh model (no checkpoint loaded).")
            # NOTE(review): the checkpoint paths construct PredicateModel with
            # multi_class=True, but it is not passed here even though the training
            # forward call uses multi_class=True — confirm this is intended.
            predicate_model = PredicateModel(
                hidden_dim=0,
                num_top_pairs=test_num_top_pairs,
                device=device,
                model_name=clip_model_name
            ).to(device)

        # Optionally cast to half precision
        if use_half:
            predicate_model = predicate_model.half()

        # Ensure correct number of top pairs
        predicate_model.num_top_pairs = self.test_num_top_pairs
        predicate_model = predicate_model.to(device)

        # Freeze everything except the projection layers BEFORE wrapping in DDP:
        # DDP builds its gradient-reduction buckets from the parameters that
        # require grad at construction time, so freezing afterwards would leave
        # frozen parameters registered with the reducer.
        for name, param in predicate_model.named_parameters():
            trainable = 'visual_projection' in name or 'text_projection' in name
            param.requires_grad = trainable
            if trainable:
                print(f"[Trainable Layer]: {name}")

        self.predicate_model = predicate_model
        if self.use_ddp:
            self.predicate_model = nn.parallel.DistributedDataParallel(
                predicate_model,
                device_ids=[device],  # or [rank] if you're using `device = rank`
                output_device=device,  # ensures output is on the same device
                find_unused_parameters=True
            )

        # Set up optimizer only on trainable parameters
        self.optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.predicate_model.parameters()),
            lr=self.learning_rate
        )

        # Print parameter counts
        total_params = sum(p.numel() for p in self.predicate_model.parameters())
        trainable_params = sum(p.numel() for p in self.predicate_model.parameters() if p.requires_grad)
        print(f"[Model Info] Total parameters: {total_params}")
        print(f"[Model Info] Trainable parameters: {trainable_params}")

        # Label vocabularies and label -> index mappings
        self.obj_labels = self.train_loader.dataset.objects
        self.rel_labels = self.train_loader.dataset.predicates
        self.obj_label_to_idx = {label: i for i, label in enumerate(self.obj_labels)}
        self.rel_label_to_idx = {label: i for i, label in enumerate(self.rel_labels)}

        # Metric computation: thresholds passed to compute_metrics_top_k.
        self.precision_thres_list = [1, 5, 10]
        self.recall_thres_list = [1, 5, 10]

        # Per-batch results of compute_metrics_top_k, appended during training.
        self.epoch_metrics_storage = []

        # Running average of binary precision@1 shown in the tqdm description.
        self.running_bin_prec1_sum = 0.0
        self.running_bin_prec1_count = 0

        print("[Trainer] Initialization complete.")

    @staticmethod
    def _predicate_model_from_checkpoint(model_info, test_num_top_pairs, device, clip_model_name):
        """
        Turn an object returned by ``torch.load`` into a PredicateModel.

        Accepts a full PredicateModel, a DistributedDataParallel wrapper (its
        ``.module`` is returned), or anything else, which is treated as a state
        dict and loaded into a new ``PredicateModel(..., multi_class=True)``.
        """
        if isinstance(model_info, PredicateModel):
            return model_info
        if isinstance(model_info, nn.parallel.distributed.DistributedDataParallel):
            return model_info.module
        predicate_model = PredicateModel(
            hidden_dim=0,
            num_top_pairs=test_num_top_pairs,
            device=device,
            model_name=clip_model_name,
            multi_class=True
        ).to(device)
        predicate_model.load_state_dict(model_info)
        return predicate_model

    # -----------------------
    # 4) Train method
    # -----------------------
    def train(self, n_epochs=5):
        """
        Finetune the model on:
          1) Object classification (per bounding box): cross-entropy for correct category.
          2) Relationship classification (per bounding-box pair): multi-label BCE for predicates.
        Saves checkpoints each epoch if --save-model was used.
        """
        # Construct a base name for saving
        if self.model_name is None:
            base_savename = "finetune_epoch.model"
        else:
            base_savename = f"{self.model_name}_{self.epoch_ct}_fts{self.ft_split}_lr{self.learning_rate}.model"

        # Open log file
        log_file_path = os.path.join(self.finetune_save_dir, "logs", f"{base_savename}.log")
        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

        # Put model in train mode
        self.predicate_model.train()

        with open(log_file_path, "a") as f_log:
            for epoch in range(self.current_ft_epoch, n_epochs + 1):
                # If you're using a DistributedSampler, you must set the epoch
                if self.use_ddp:
                    # self.train_loader.sampler is a DistributedSampler
                    self.train_loader.sampler.set_epoch(epoch)
                epoch_loss = 0.0
                num_batches = 0

                loader_iter = tqdm(
                    self.train_loader,
                    desc=f"Epoch {epoch}/{n_epochs}, Avg Loss: -1",
                    leave=True
                )
                
                failed_bins = []

                for batch_data in loader_iter:
                    # Update progress bar description with running average loss

                    # Zero out gradients
                    self.optimizer.zero_grad()

                    # Unpack the batch data
                    batched_ids = batch_data["batched_ids"]
                    batched_reshaped_raw_videos = batch_data["batched_reshaped_raw_videos"]
                    batched_object_ids = batch_data["batched_object_ids"]
                    batched_gt_masks = batch_data["batched_gt_masks"]
                    batched_gt_bboxes = batch_data["batched_gt_bboxes"]
                    batched_gt_obj_names = batch_data["batched_gt_obj_names"]
                    batched_gt_object_rels = batch_data["batched_gt_object_rels"]
                    batched_video_splits = batch_data["batched_video_splits"]
                    batched_obj_pairs = batch_data["batched_obj_pairs"]

                    # Prepare category & relationship vocab for forward call
                    cate_kw = self.obj_labels
                    binary_kw = self.rel_labels
                    unary_kw = []  # not used, but model signature requires it

                    # For batch_size>1, replicate
                    repeated_cate = [cate_kw for _ in range(len(batched_ids))]
                    repeated_unary = [unary_kw for _ in range(len(batched_ids))]
                    repeated_binary = [binary_kw for _ in range(len(batched_ids))]

                    try:
                        (
                            batched_image_cate_probs,
                            _,
                            batched_image_binary_probs,
                            _
                        ) = self.predicate_model(
                            batched_video_ids=batched_ids,
                            batched_videos=batched_reshaped_raw_videos,
                            batched_masks=batched_gt_masks,
                            batched_bboxes=batched_gt_bboxes,
                            batched_names=repeated_cate,
                            batched_object_ids=batched_object_ids,
                            batched_unary_kws=repeated_unary,
                            batched_binary_kws=repeated_binary,
                            batched_obj_pairs=batched_obj_pairs,
                            batched_video_splits=batched_video_splits,
                            batched_binary_predicates=[None]*len(batched_ids),
                            multi_class=True
                        )
                    except Exception as e:
                        print(f"[Warning] Batch {batched_ids[0]} failed with error: {e}")
                        continue

                    # ------------------------
                    # (1) Object classification loss
                    #     We do a cross-entropy style loss for each bounding box.
                    #     Each bounding box has exactly 1 correct category among self.obj_labels.
                    # ------------------------
                    cate_probs_dict = batched_image_cate_probs[0]  # Dictionary: {(oid, cate_name) : probability}

                    # Build map: OID -> ground truth category name
                    oid_to_gt_label = {}
                    for ((vid_idx, fid, oid), (_, _, label_str)) in zip(batched_object_ids, batched_gt_obj_names):
                        if vid_idx == 0:  # for single-video batch
                            oid_to_gt_label[oid] = label_str

                    cate_loss = 0.0
                    cate_num = 0
                    epsilon = 1e-9
                    for (oid, cat_name), prob_val in cate_probs_dict.items():
                            gt_name = oid_to_gt_label[oid]
                            if gt_name == cat_name:
                                cate_loss += -torch.log(prob_val + epsilon)
                                cate_num += 1

                    # ------------------------
                    # (2) Relationship classification loss (multi-label BCE)
                    #     For each bounding-box pair in a frame, we may have multiple valid relations.
                    #     We sum the BCE across all predicates.
                    # ------------------------
                    bin_probs_dict = batched_image_binary_probs[0]  # {(frame_idx, (subj_oid, obj_oid), rel_name) : prob}
                    pair_to_relset = {}  # (frame_idx, (subj_oid, obj_oid)) -> set_of_relation_strings

                    # Gather GT relationships
                    for fid_idx, rel_list in enumerate(batched_gt_object_rels[0]):
                        for (sid, oid, rel_str) in rel_list:
                            pair_key = (fid_idx, (sid, oid))
                            if pair_key not in pair_to_relset:
                                pair_to_relset[pair_key] = set()
                            pair_to_relset[pair_key].add(rel_str)

                    rel_loss = 0.0
                    rel_num = 0
                    for pair_key, gt_relset in pair_to_relset.items():
                        # pair_key = (frame_idx, (subj_oid, obj_oid))
                        fid_ck, (sid_ck, oid_ck) = pair_key

                        # Build GT multi-hot vector
                        target_vec = torch.zeros(len(self.rel_labels), device=self.device)
                        for rel_str_gt in gt_relset:
    
                            idx = self.rel_label_to_idx[rel_str_gt]
                            target_vec[idx] = 1.0

                        # Gather predicted probabilities for each possible relation
                        pred_vec = torch.zeros(len(self.rel_labels), device=self.device)
                        for i, rel_lab in enumerate(self.rel_labels):
                            key_test = (fid_ck, (sid_ck, oid_ck), rel_lab)
                            pred_vec[i] = bin_probs_dict[key_test]

                        # Binary Cross Entropy for multi-label classification
                        rel_loss += nn.functional.binary_cross_entropy(
                            pred_vec.clamp(1e-9, 1.0 - 1e-9),
                            target_vec,
                            reduction='sum'
                        )
                        rel_num +=1

                    # 2) Ranking losses
                    obj_ranking_loss = compute_object_ranking_loss(
                        cate_probs_dict=cate_probs_dict,         # dict {(oid, cat_name) : prob}
                        gt_oid_to_label=oid_to_gt_label,         # e.g. {oid: label_str}
                        margin=0.2,
                        top_neg=5
                    )

                    rel_ranking_loss = compute_relation_ranking_loss(
                        bin_probs_dict=bin_probs_dict,           # dict {(fid,(sid,oid),rel_str): prob}
                        gt_relations_per_frame=batched_gt_object_rels[0],  # list-of-lists
                        gt_oid_to_label=oid_to_gt_label,
                        rel_label_list=self.rel_labels,          # e.g. train_loader.dataset.predicates
                        margin=0.2,
                        top_neg=5
                    )

                    alpha = 0.1  # how heavily you weight the ranking losses

                    ranking_loss = obj_ranking_loss + rel_ranking_loss
                    total_loss = (cate_loss + rel_loss) + alpha * ranking_loss
                    epoch_loss += total_loss.item()
                    num_batches += 1
                    total_loss.backward()
                    self.optimizer.step()
                    
                    # (2) Build the ground-truth structures for compute_metrics_top_k

                    # A. Build gt_object_dict: list of (video_id, object_id, gt_label_str)
                    #    We'll just take the first video id from batched_ids since batch_size=1
                    video_id = batched_ids[0] if len(batched_ids) > 0 else 0

                    gt_object_dict = []
                    for ((vid_idx, fid, oid), (_, _, label_str)) in zip(batched_object_ids, batched_gt_obj_names):
                        # If you want to use the "video_id" from your loader (rather than vid_idx=0):
                        gt_object_dict.append((video_id, oid, label_str))

                    # B. gt_object_rels:
                    #    For a single batch, batched_gt_object_rels is e.g. [ [ (sid, oid, rel_str), ...], [ ... ], ... ]
                    #    That is: a list per frame. We'll pass it directly.
                    gt_object_rels = batched_gt_object_rels[0]  # (since batch_size=1 => index 0)

                    # C. Convert the model outputs to the required dict structure:
                    #    cate_pred: { (oid, pred_label) : probability }
                    #    binary_pred: { (fid, (subj_id, obj_id), rel_str) : probability }
                    cate_pred = batched_image_cate_probs[0]  # e.g.  {(oid, cat_name): prob}
                    binary_pred = batched_image_binary_probs[0]  # e.g. {(fid, (sid, oid), rel_name): prob}

                    # (3) Compute metrics
                    batch_metrics, _, _ = compute_metrics_top_k(
                        gt_object_dict=gt_object_dict,         # list of (video_id, object_id, gt_label_str)
                        gt_object_rels=gt_object_rels,       # we wrap in a list-of-lists format if needed
                        cate_pred=cate_pred, 
                        binary_pred=binary_pred,
                        precision_thres_ls=self.precision_thres_list,
                        recall_thres_ls=self.recall_thres_list,
                        top_k_classes=1,
                    )

                    # (4) Update the running average for binary precision@1
                    bin_prec1_val = batch_metrics["precision"]["binary"][1]  # float
                    if bin_prec1_val<1:
                        failed_bins.append(video_id)
                    self.running_bin_prec1_sum += bin_prec1_val
                    self.running_bin_prec1_count += 1
                    bin_prec1_so_far = self.running_bin_prec1_sum / self.running_bin_prec1_count

                    # (5) Store batch metrics for epoch averaging
                    self.epoch_metrics_storage.append(batch_metrics)

                    # (6) Update tqdm description to include the running bin@1
                    loader_iter.set_description(
                        f"Epoch {epoch}/{n_epochs}, Avg Loss: {epoch_loss/max(1, num_batches):.4f}, bin@1: {bin_prec1_so_far:.4f}"
                    )

                    # Also log to file
                    f_log.write(f"Epoch {epoch}/{n_epochs}, Batch {num_batches}, "
                                f"AvgLossSoFar: {epoch_loss/max(1, num_batches):.4f}, "
                                f"bin@1_so_far: {bin_prec1_so_far:.4f}\n")
                    f_log.flush()
                
                # End-of-epoch
                avg_loss = epoch_loss / max(1, num_batches)
                
                #failed bins
                print(f"Failed bin@1 for this epoch:\n{failed_bins}")

                # Build accumulators for each threshold
                epoch_agg = {
                    "precision": {
                        "cate": {k: [] for k in self.precision_thres_list},
                        "binary": {k: [] for k in self.precision_thres_list}
                    },
                    "recall": {
                        "cate": {k: [] for k in self.recall_thres_list},
                        "binary": {k: [] for k in self.recall_thres_list}
                    }
                }

                # Accumulate per-batch
                for met in self.epoch_metrics_storage:
                    for metric_type in ["precision", "recall"]:
                        thres_list = self.precision_thres_list if metric_type == "precision" else self.recall_thres_list
                        for thres in thres_list:
                            epoch_agg[metric_type]["cate"][thres].append(met[metric_type]["cate"][thres])
                            epoch_agg[metric_type]["binary"][thres].append(met[metric_type]["binary"][thres])

                # Now compute the means
                final_epoch_metrics = {
                    "precision": {
                        "cate": {},
                        "binary": {}
                    },
                    "recall": {
                        "cate": {},
                        "binary": {}
                    }
                }
                for metric_type in ["precision", "recall"]:
                    thres_list = self.precision_thres_list if metric_type == "precision" else self.recall_thres_list
                    for thres in thres_list:
                        vals_cate = epoch_agg[metric_type]["cate"][thres]
                        vals_bin  = epoch_agg[metric_type]["binary"][thres]
                        mean_cate = sum(vals_cate) / len(vals_cate) if len(vals_cate) else 0.0
                        mean_bin  = sum(vals_bin)  / len(vals_bin)  if len(vals_bin)  else 0.0
                        final_epoch_metrics[metric_type]["cate"][thres]   = mean_cate
                        final_epoch_metrics[metric_type]["binary"][thres] = mean_bin

                # Log them
                print(f"[Epoch {epoch}/{n_epochs}] Avg Loss: {avg_loss:.4f} | "
                    f"Precision@1(cate): {final_epoch_metrics['precision']['cate'][1]:.4f}, "
                    f"Precision@1(binary): {final_epoch_metrics['precision']['binary'][1]:.4f}")

                f_log.write(f"End of Epoch {epoch}/{n_epochs} - Avg Loss: {avg_loss:.4f}\n")
                f_log.write(f"Epoch {epoch} metrics:\n{final_epoch_metrics}\n\n")
                f_log.flush()

                # Reset for next epoch
                self.epoch_metrics_storage.clear()
                self.running_bin_prec1_sum = 0.0
                self.running_bin_prec1_count = 0


                # Save checkpoint each epoch if --save-model is enabled
                # We replace the trailing "model" with e.g. "<epoch>.model"
                # e.g. mymodel.0.fts10.lr5e-4.model => mymodel.0.fts10.lr5e-4.<epoch>.model
                # so that we keep a separate file per epoch
                if num_batches > 0:
                    ckpt_name = base_savename.replace('model', f'{epoch}.model')
                    ckpt_path = os.path.join(self.finetune_save_dir, ckpt_name)
                    if os.path.isdir(self.finetune_save_dir):
                        torch.save(self.predicate_model.state_dict(), ckpt_path)
                        print(f"[Checkpoint Saved] {ckpt_path}")

# -----------------------
# 5) Main entry point
# -----------------------
def main(rank: int, world_size: int, args):
    """
    Per-process entry point: optionally join the DDP process group, build the
    VidVRD data loaders, and run finetuning when ``args.phase == 'train'``.

    Args:
        rank: Process rank. Used as the CUDA device index when ``args.use_ddp``
            is set (``ddp_setup`` also calls ``torch.cuda.set_device(rank)``).
        world_size: Total number of processes; forwarded to ``ddp_setup`` and
            to ``Trainer``.
        args: Namespace produced by ``parse_args``.

    When DDP is enabled, the process group is destroyed in a ``finally``
    clause. It is therefore released even if data loading or training raises.
    """
    if args.use_ddp:
        ddp_setup(rank, world_size)

    try:
        # Under DDP each process drives the GPU matching its rank; otherwise use
        # --gpu, falling back to device 0 for negative values.
        device = rank if args.use_ddp else (args.gpu if args.gpu >= 0 else 0)

        # Only hand a DistributedSampler to the loader in multi-process mode.
        from torch.utils.data.distributed import DistributedSampler
        sampler_class = DistributedSampler if args.use_ddp else None

        data_args = {
            "dataset_dir": args.data_dir,
            "batch_size": args.batch_size,
            "device": device,
            "training_percentage": 1,
            "testing_percentage": args.test_percentage,
            "max_video_len": args.max_video_len,
            "neg_kws": False,
            "neg_spec": False,
            "neg_example_ct": 0,
            "neg_example_file_name": "neg_examples.json",
            "backbone_model": "clip",
            "sampler": sampler_class,
            "splice_start": args.splice_start,
            "splice_size": args.splice_size,
            "ft_split": args.ft_split,
            # Optional list of video IDs to restrict loading to (e.g.
            # 'ILSVRC2015_train_00878001'); presumably an empty list means
            # "all videos" - TODO confirm against open_vidvrd_loader.
            "only_videos": [],
        }

        # The loader returns (train_dataset, valid_dataset, train_loader, test_loader).
        train_dataset, valid_dataset, train_loader, test_loader = open_vidvrd_loader(**data_args)

        if args.phase == "train":
            trainer = Trainer(
                train_loader=train_loader,
                device=device,
                dataset=args.dataset,
                learning_rate=args.learning_rate,
                ft_split=args.ft_split,
                model_dir=args.model_dir,
                finetune_save_dir=args.finetune_save_dir,
                model_name=args.model_name,
                model_epoch=args.model_epoch,
                load_model=args.load_model,
                test_num_top_pairs=args.test_num_top_pairs,
                clip_model_name=args.clip_model_name,
                use_half=args.use_half,
                world_size=world_size,
                use_ddp=args.use_ddp
            )
            trainer.train(n_epochs=args.n_epochs)
        else:
            print("[Info] Phase is not 'train'. Implement evaluation logic here if desired.")
    finally:
        # Always leave the process group, even on failure, so the NCCL
        # communicator is not leaked.
        if args.use_ddp:
            dist.destroy_process_group()

# -----------------------
# 6) Script Runner
# -----------------------
if __name__ == "__main__":
    # CUDA cannot be re-initialized in forked subprocesses, so force 'spawn'.
    torch.multiprocessing.set_start_method('spawn', force=True)

    # Example overrides for model name & epoch
    model_name = "ensemble-2025-02-10-14-57-22"
    epoch_num = 0
    world_size = torch.cuda.device_count()

    # Parse arguments
    args = parse_args(model_name, epoch_num)

    # Run either multi-GPU (one spawned process per device) or in-process.
    # NOTE: main() both joins and destroys the DDP process group itself.
    # Under mp.spawn, this parent process never joins a group. On the
    # in-process path, main() has already destroyed it. The launcher must
    # therefore NOT call dist.destroy_process_group() again: with no
    # initialized default group, that call errors out.
    if args.use_ddp and world_size > 1:
        mp.spawn(main, args=(world_size, args), nprocs=world_size)
    else:
        main(0, world_size, args)

    print("[Done] Finetuning complete.")
