import os

import json
import random
import cv2
from argparse import ArgumentParser
import math
from torch import nn, optim
from tqdm import tqdm
import torch
import scallopy
import sys
import pickle
import gc
import heapq
from datetime import datetime
import time

model_path = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../models'))
assert os.path.exists(model_path)
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


sys.path.append(model_path)
from open_pvsg_predicate_clip_model import PredicateModel
from openpvsg_dataset import *
from utils import *
import cProfile
# import warnings
# warnings.filterwarnings('always') 

def get_print_hook(name):
    """Build a debugging hook that prints a gradient labelled with `name`.

    Args:
        name: Label printed in front of the gradient.

    Returns:
        A one-argument callable that prints the value it receives and returns
        it unchanged (presumably meant to be registered as a tensor gradient
        hook -- confirm against the call sites).
    """
    def _report(grad):
        print(f"{name}: \n {grad} \n")
        return grad

    return _report
    
def ddp_setup(rank, world_size):
    """
    Join the NCCL process group for this worker and select its CUDA device.

    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    # Rendezvous endpoint that every process connects to.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12357"})
    init_process_group(backend="nccl", world_size=world_size, rank=rank)
    # The process rank doubles as the CUDA device index.
    torch.cuda.set_device(rank)

def filter_grad_thres(instances, names, probs, grads, device, top_k_grads=-1, grad_thres=0):
    """Split (name, instance) entries by the magnitude of their gradient.

    Args:
        instances: Sequence of instances, one per column of `probs`/`grads`.
        names: Sequence of predicate names, one per row of `probs`/`grads`.
        probs: Tensor indexed as [name, instance], an empty tensor, or None.
        grads: Tensor iterated in lockstep with `probs`, or None.
        device: Device used for the default (zero) threshold tensor.
        top_k_grads: When positive, the threshold is the smallest value among
            the `top_k_grads + 1` largest absolute gradients; when <= 0 no
            top-k threshold is applied.
        grad_thres: Absolute lower bound on the gradient magnitude.

    Returns:
        `(above_grad_thres, under_grad_thres)`: two dicts keyed by
        `(name, instance)` with `(prob, grad)` values. An entry is "above"
        when `|grad|` is strictly greater than both thresholds.
    """
    above_grad_thres = {}
    under_grad_thres = {}

    # Shape checks only make sense for a real tensor; running them on None
    # used to raise AttributeError before the None guard was ever reached.
    if probs is not None:
        assert probs.shape == torch.Size([0]) or probs.shape[0] == len(names)
        assert probs.shape == torch.Size([0]) or probs.shape[1] == len(instances)

    if probs is None or grads is None or len(names) == 0:
        return above_grad_thres, under_grad_thres

    num_grad = len(probs.reshape(-1))
    if top_k_grads > 0 and num_grad > 0:
        # Smallest of the (top_k_grads + 1) largest magnitudes: entries must be
        # strictly larger than it to count as "above".
        top_magnitudes = torch.topk(torch.abs(grads.reshape(-1)), min(top_k_grads + 1, num_grad))[0]
        top_grad_thres = torch.min(top_magnitudes)
    else:
        # No top-k filtering (or nothing to rank): fall back to zero.
        top_grad_thres = torch.tensor(0, device=device)

    # Loop invariant, so compute the effective cutoff once.
    cutoff = max(top_grad_thres, grad_thres)

    for name, name_probs, name_grads in zip(names, probs, grads):
        for prob, grad, instance in zip(name_probs, name_grads, instances):
            if grad is not None and torch.abs(grad) > cutoff:
                above_grad_thres[(name, instance)] = prob, grad
            else:
                under_grad_thres[(name, instance)] = prob, grad

    return above_grad_thres, under_grad_thres

class Trainer():
    
    def __init__(self, train_loader, test_loader, device, 
                 args,
                 caption2scl, common_scl_path,
                 latent_dim = 64,
                 provenance="difftopkproofsdebug", k=3, save_scl_dir=None,
                 use_neg_spec=False, use_neg_kws=False,
                 model_dir=None, model_name=None, 
                 learning_rate=None,
                 load_model=False, save_model=True,
                 video_save_dir=None,
                 train_num_top_pairs=100, 
                 test_num_top_pairs=300, 
                 report_dir=None,
                 neg_spec_weight=0.1,
                 neg_entity_kw_cate_weight=0,
                 neg_entity_kw_binary_weight=0.1,
                 clip_model_name="google/clip-base-patch16-224", 
                 use_half=True,
                 top_k_grads=30, 
                 segment_size=5,
                 world_size=1):
        """Set up data loaders, the Scallop reasoning context, the predicate
        model (fresh or restored from a checkpoint) and the optimizer.

        Args:
            train_loader, test_loader: Data loaders, stored as given.
            device: Device the model and reasoning function are moved to. It is
                also formatted into "cuda:<device>" when loading a checkpoint.
            args: Namespace read for `use_ddp` and `use_sparse`.
            caption2scl: Stored as given.
            common_scl_path: Scallop file read into `self.common_scl` and
                imported into the Scallop context.
            latent_dim: Passed to `PredicateModel` as `hidden_dim`.
            provenance, k: Passed to `scallopy.ScallopContext`.
            save_scl_dir: Stored as `self.save_scl_dir` only when not None.
            use_neg_spec, use_neg_kws: Contrastive-learning switches, stored.
            model_dir, model_name: Checkpoint directory and base name.
                Checkpoint files are expected to be named
                "<model_name>.<id>.model", where <id> is an integer or "latest".
            learning_rate: Learning rate handed to Adam.
            load_model: When True and `model_dir` is a non-empty directory, the
                newest matching checkpoint is restored.
            save_model: Stored as given.
            video_save_dir: Accepted but not used by the constructor.
            train_num_top_pairs, test_num_top_pairs: Number of binary pairs to
                consider; the training value is assigned to the model.
            report_dir: Stored as given.
            neg_spec_weight, neg_entity_kw_cate_weight,
            neg_entity_kw_binary_weight: Loss weights, stored as given.
            clip_model_name: Passed to `PredicateModel` as `model_name`.
            use_half: Selects `BCEWithLogitsLoss` (True) or `BCELoss` (False).
            top_k_grads, segment_size, world_size: Stored as given.

        Raises:
            ValueError: If loading is requested but no checkpoint with an
                integer or "latest" id matches `model_name`.
        """

        # Dataset and scallop file setup
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.caption2scl = caption2scl
        self.template2action = {}
        # Read through a context manager so the file handle is closed.
        with open(common_scl_path) as common_scl_file:
            self.common_scl = common_scl_file.read()
        self.save_model = save_model
        self.report_dir = report_dir
        self.model_dir = model_dir
        self.model_name = model_name
        self.top_k_grads = top_k_grads
        self.segment_size = segment_size
        self.args = args
        self.use_ddp = args.use_ddp
        self.use_sparse = args.use_sparse
        self.world_size = world_size

        # Contrastive learning type
        self.use_neg_spec = use_neg_spec
        self.use_neg_kws = use_neg_kws
        self.neg_spec_weight = neg_spec_weight
        self.neg_entity_kw_cate_weight = neg_entity_kw_cate_weight
        self.neg_entity_kw_binary_weight = neg_entity_kw_binary_weight

        # Hyperparameter controlling the number of binary pairs to consider for efficiency
        self.train_num_top_pairs = train_num_top_pairs
        self.test_num_top_pairs = test_num_top_pairs

        # Scallop context and forwarding setup
        self.scallop_ctx = scallopy.ScallopContext(provenance=provenance, k=k)
        self.scallop_ctx.import_file(common_scl_path)
        self.scallop_ctx.set_non_probabilistic(non_prob_gpt_prog_str_preds)

        self.reason = self.scallop_ctx.forward_function(output_mappings={
            "aligned_t1": None,
            "aligned_t2": None,
            "aligned_t3": None,
        })

        self.reason.to(self.device)

        # Training continuation setups
        self.epoch_ct = 0

        # Setting up the STSG model
        if load_model and os.path.exists(model_dir) and len(os.listdir(model_dir)) > 0:

            # Collect the ids of every checkpoint that belongs to this model name.
            current_model_names = [existing_model_name for existing_model_name in os.listdir(model_dir) if model_name in existing_model_name]
            model_ids = [existing_model_name.split('.')[-2] for existing_model_name in current_model_names]
            digital_model_ids = [int(model_id) for model_id in model_ids if model_id.isdigit()]

            # Prefer the highest numbered checkpoint; fall back to "latest".
            # (The fallback previously looked for 'latest' among the integer
            # ids, so it could never trigger and max([]) raised instead.)
            if digital_model_ids:
                latest_model_id = max(digital_model_ids)
            elif 'latest' in model_ids:
                latest_model_id = 'latest'
            else:
                raise ValueError(f"No checkpoint matching '{model_name}' found in {model_dir}")
            checkpoint_name = model_name + f'.{latest_model_id}.model'

            model_info = torch.load(os.path.join(model_dir, checkpoint_name), map_location="cuda:" + str(self.device))

            # A checkpoint may hold a whole model, a DDP wrapper, or a state dict.
            if isinstance(model_info, PredicateModel):
                predicate_model = model_info
            elif isinstance(model_info, DDP):
                predicate_model = model_info.module
            else:
                predicate_model = PredicateModel(hidden_dim = latent_dim, num_top_pairs=train_num_top_pairs, device=device, model_name=clip_model_name, use_sparse=args.use_sparse).to(device)
                predicate_model.load_state_dict(model_info)

            predicate_model.use_sparse = self.use_sparse
            predicate_model.device = self.device
            print(f"Loading: {checkpoint_name}")
            # Only numbered checkpoints carry an epoch count to resume from.
            if isinstance(latest_model_id, int):
                self.epoch_ct = latest_model_id
        else:

            # Initialize a new predicate model
            predicate_model = PredicateModel(hidden_dim = latent_dim, num_top_pairs=train_num_top_pairs, device=device, model_name=clip_model_name, use_sparse=args.use_sparse).to(device)

        predicate_model.num_top_pairs = self.train_num_top_pairs

        if args.use_ddp:
            predicate_model = DDP(predicate_model, device_ids=[device], find_unused_parameters=True, static_graph=True)

        self.predicate_model = predicate_model

        # Setting up learning parameters
        self.optimizer = optim.Adam(self.predicate_model.parameters(), lr=learning_rate)
        self.min_loss = 10000000000
        self.use_half = use_half

        # With use_half the loss consumes logits, otherwise probabilities.
        if use_half:
            self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        else:
            self.loss_fn = nn.BCELoss(reduction='none')

        # Debugging utils (attribute is only defined when a directory is given)
        if save_scl_dir is not None:
            self.save_scl_dir = save_scl_dir
    
    def get_valid_result(self, frame_ids, probs):
        """Pair each probability with its frame id, dropping placeholder frames.

        Args:
            frame_ids: Sequence of single-element containers holding a frame id;
                an id of -1 marks an entry to discard.
            probs: Sequence of probabilities aligned with `frame_ids`.

        Returns:
            List of `(prob, frame_id)` tuples ordered by ascending frame id.
        """
        kept = []
        for frame_entry, prob in zip(frame_ids, probs):
            assert len(frame_entry) == 1
            frame_id = frame_entry[0]
            if frame_id != -1:
                kept.append((prob, frame_id))

        # Stable sort on the frame id only; probabilities are never compared.
        kept.sort(key=lambda pair: pair[1])
        return kept
    
    def baseline_eval_batch(self, batch):
        """Score the predicate model on one batch against the ground truth.

        The keyword lists handed to the model are derived from the ground-truth
        annotations (object names and relation keywords), the model outputs are
        reshaped into per-video prediction lists, and `self.accu` is applied.

        Args:
            batch: Dict holding the `batched_*` entries read below.

        Returns:
            `(result_unary, result_binary)` from `self.accu`, or `({}, {})`
            when the batch contains no objects.
        """
        # Unpack the batch; every key is read up front.
        batched_ids = batch['batched_ids']
        batched_captions = batch['batched_captions']
        batched_gt_bboxes = batch['batched_gt_bboxes']
        batched_obj_pairs = batch['batched_obj_pairs']
        batched_object_ids = batch['batched_object_ids']
        batched_video_splits = batch['batched_video_splits']
        batched_reshaped_raw_videos = batch['batched_reshaped_raw_videos']
        batched_gt_labels = batch['batched_gt_obj_names']
        batched_gt_object_rels = batch['batched_gt_object_rels']
        batched_gpt_specs = batch['batched_gpt_specs']

        if len(batched_object_ids) == 0:
            return {}, {}

        batch_size = len(batched_ids)
        batched_unary_kws = [[] for _ in range(batch_size)]

        # Ground-truth label of every object, grouped by video id.
        batched_obj_labels = {i: {} for i in range(batch_size)}
        for (_, _, label), (vid, _, oid) in zip(batched_gt_labels, batched_object_ids):
            labels_for_video = batched_obj_labels.setdefault(vid, {})
            if oid in labels_for_video:
                # An object must keep the same label across frames.
                assert label == labels_for_video[oid]
            labels_for_video[oid] = label

        # Distinct object names per video: the category vocabulary for the model.
        batched_obj_names = [
            list({o_name for o_name in vid_info.values()})
            for vid_info in batched_obj_labels.values()
        ]

        # Relation keywords per video, restricted to pairs present in the batch.
        batched_binary_kws = []
        for vid, gt_object_rels in enumerate(batched_gt_object_rels):
            binary_kws = {
                binary_kw
                for fid, obj_pairs in enumerate(gt_object_rels)
                for (sub, obj, binary_kw) in obj_pairs
                if (vid, fid, (sub, obj)) in batched_obj_pairs
            }
            batched_binary_kws.append(list(binary_kws))

        # Obtain the predictions with the pretrained model
        batched_image_unary_probs, batched_image_binary_probs, selected_pairs = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                batched_bboxes=batched_gt_bboxes,
                                batched_names=batched_obj_names,
                                batched_object_ids=batched_object_ids,
                                batched_unary_kws=batched_unary_kws,
                                batched_binary_kws=batched_binary_kws,
                                batched_gt_obj_pairs=batched_gt_object_rels,
                                batched_video_splits=batched_video_splits,
                                unary_segment_size=self.segment_size,
                                binary_segment_size=self.segment_size)

        # Category predictions: average each (object, category) over its frames.
        batched_cate_pred_scl = []
        for vid, (image_cate_probs, _) in enumerate(batched_image_unary_probs):
            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            cate_prob_rows, cate_names = image_cate_probs[0], image_cate_probs[1]
            assert cate_prob_rows.shape[0] == 0 or len(object_ids) == cate_prob_rows.shape[1]

            per_object = {}
            for cate_name, probs in zip(cate_names, cate_prob_rows):
                for prob, (_, oid) in zip(probs, object_ids):
                    per_object.setdefault(oid, {}).setdefault(cate_name, []).append(prob)

            batched_cate_pred_scl.append([
                (torch.mean(torch.stack(prob_list)), (oid, cate_name))
                for oid, by_cate in per_object.items()
                for cate_name, prob_list in by_cate.items()
            ])

        # Binary predictions: one entry per (relation, frame, subject, object).
        batched_binary_pred_scl = []
        for vid, image_binary_probs in enumerate(batched_image_binary_probs):
            if len(image_binary_probs) == 0:
                batched_binary_pred_scl.append([])
                continue

            object_pairs = [(fid, pair) for (ovid, fid, pair) in selected_pairs if vid == ovid]
            assert (len(object_pairs) == 0 and image_binary_probs[0].shape[0]== 0) or len(object_pairs) == image_binary_probs[0].shape[1]

            batched_binary_pred_scl.append([
                (prob, (binary_pred_name, fid, pair[0], pair[1]))
                for binary_pred_name, probs in zip(image_binary_probs[1], image_binary_probs[0])
                for prob, (fid, pair) in zip(probs, object_pairs)
            ])

        # Test the accuracy
        result_unary, result_binary = self.accu(batched_obj_labels, batched_gt_object_rels, batched_cate_pred_scl, batched_binary_pred_scl)
        return result_unary, result_binary
    
    # Accuracy Metric
    def accu(self, batched_obj_labels, batched_gt_object_rels, batched_image_cate_probs, batched_image_binary_probs, top_pair_num=100):
        """Collect per-class ground-truth/prediction indicator lists.

        Args:
            batched_obj_labels: Mapping video id -> {object id: label}.
            batched_gt_object_rels: Per video, a container indexed by frame id
                whose entries are tested for `(sub, obj, relation)` membership.
            batched_image_cate_probs: Per video, a list of
                `(prob, (object id, category name))`.
            batched_image_binary_probs: Per video, a list of
                `(prob, (relation name, frame id, sub, obj))`.
            top_pair_num: Only this many highest-ranked binary predictions per
                video are evaluated.

        Returns:
            `(result_unary, result_binary)`, each mapping a class name to
            `{'gt': [...], 'pred': [...]}` lists of 0/1 flags. For relations,
            'gt' says whether the triple is annotated in that frame and 'pred'
            whether its probability exceeds 0.5. For objects, 'gt' is always 1
            (entries are keyed by the object's own label) and 'pred' says
            whether the highest-probability category equals that label.
        """
        result_unary = {}
        result_binary = {}

        for vid, (image_cate_probs, binary_pred, gt_object_rels) in enumerate(zip(batched_image_cate_probs, batched_image_binary_probs, batched_gt_object_rels)):

            obj_labels = batched_obj_labels[vid]

            # nlargest ranks the tuples directly; no manual heap is needed.
            top_binary_preds = heapq.nlargest(top_pair_num, binary_pred)

            for rela_prob, (rela_name, fid, sub, obj) in top_binary_preds:
                entry = result_binary.setdefault(rela_name, {'gt': [], 'pred': []})
                entry['gt'].append(1 if (sub, obj, rela_name) in gt_object_rels[fid] else 0)
                entry['pred'].append(1 if rela_prob > 0.5 else 0)

            # Highest-probability category per object (first one wins on ties).
            obj_pred = {}
            for cate_prob, (oid, obj_name) in image_cate_probs:
                if oid not in obj_pred or cate_prob > obj_pred[oid][0]:
                    obj_pred[oid] = (cate_prob, obj_name)

            for oid, obj_name in obj_labels.items():
                entry = result_unary.setdefault(obj_name, {'gt': [], 'pred': []})
                entry['gt'].append(1)

                # An object without any category prediction counts as a miss
                # instead of raising KeyError.
                best = obj_pred.get(oid)
                entry['pred'].append(1 if best is not None and best[1] == obj_name else 0)

        return result_unary, result_binary
    
    # Recall Metric
    def recall(self, gt_object_dict, unary_pred, gt_object_rels, binary_pred, recall_thres_ls = [1, 5, 10]):
        """Compute recall@k hit lists for object categories and relations.

        Args:
            gt_object_dict: Mapping object id -> ground-truth label.
            unary_pred: Mapping object id -> {category id: score}; category ids
                are converted with `int` before comparison with the label.
            gt_object_rels: Mapping frame id -> iterable of
                `(from_id, to_id, relation)`.
            binary_pred: Mapping `(frame id, from_id, to_id)` -> list of
                `(score, relation)`; the list is used in its given order
                (presumably already ranked -- confirm against callers).
            recall_thres_ls: The k values to evaluate.

        Returns:
            `(result_unary, result_binary)`, each mapping k to a list with one
            0/1 flag per ground-truth item telling whether the ground truth
            appears among the first k predictions.
        """
        result_unary = {thres: [] for thres in recall_thres_ls}
        result_binary = {thres: [] for thres in recall_thres_ls}

        # Object categories: rank candidates by descending score.
        for oid, gt_label in gt_object_dict.items():
            ranked = sorted(((score, cate) for cate, score in unary_pred[oid].items()), reverse=True)
            for thres in recall_thres_ls:
                top_ids = [int(cid) for _, cid in ranked[:thres]]
                result_unary[thres].append(1 if gt_label in top_ids else 0)

        # Relations: one ground-truth relation per (frame, subject, object);
        # a later annotation for the same key overrides an earlier one.
        gt_rel_dict = {
            (fid, from_id, to_id): rel
            for fid, rel_ls in gt_object_rels.items()
            for (from_id, to_id, rel) in rel_ls
        }

        for rel_key, gt_label in gt_rel_dict.items():
            # Pairs without any prediction count as a miss at every k.
            binary_pred_ls = binary_pred[rel_key] if rel_key in binary_pred else []
            for thres in recall_thres_ls:
                top_rels = [bid for _, bid in binary_pred_ls[:thres]]
                result_binary[thres].append(1 if gt_label in top_rels else 0)

        return result_unary, result_binary

    
    def neg_sample_loss(self, batched_unary_pred, batched_binary_pred, batched_neg_examples):
        batched_neg_sample_loss = []
        for unary_pred, binary_pred, neg_sample in zip(batched_unary_pred, batched_binary_pred, batched_neg_examples):
            ((image_cate_probs, cates), (image_unary_probs, unary_kws)) = unary_pred
            image_binary_probs, binary_kws =  binary_pred
            neg_cate = neg_sample['neg_entity']
            neg_binary = neg_sample['neg_binary']
        
            cate_indexes = [cate_id for cate_id, cate in enumerate(cates) if cate in neg_cate]
            if not len(cate_indexes) == 0:
                neg_cate_probs = image_cate_probs[cate_indexes, :]
            else:
                neg_cate_probs = torch.tensor([]).to(self.device)
            target_cate_probs = torch.zeros(neg_cate_probs.shape).to(self.device)
            
            binary_indexes = [ binary_id for binary_id, binary_kw in enumerate(binary_kws) if binary_kw in neg_binary]
            if not len(binary_indexes) == 0 and not len(image_binary_probs) == 0:
                neg_binary_probs = image_binary_probs[binary_indexes, :]
            else:
                neg_binary_probs = torch.tensor([]).to(self.device)
            target_binary_probs = torch.zeros(neg_binary_probs.shape).to(self.device)
            
            cate_loss, binary_loss = torch.sum(self.loss_fn(neg_cate_probs, target_cate_probs)), torch.sum(self.loss_fn(neg_binary_probs, target_binary_probs))
            batched_neg_sample_loss.append((cate_loss, binary_loss))
        return batched_neg_sample_loss
    
    # Loss function
    def loss(self, 
             batched_t1s,
             batched_t2s,
             batched_t3s,
             batched_action_specs, batched_ys, batched_video_splits, encourage_prop = 0.3, eps = 1e-15):
        batched_loss = []
        batched_video_length = []
        
        # batched_t1_frame_ids, batched_t1probs, 
        # batched_t2_frame_ids, batched_t2probs, 
        # batched_t3_frame_ids, batched_t3probs, 
        t1_frame_ids = batched_t1s[0]
        t2_frame_ids = batched_t2s[0]
        t3_frame_ids = batched_t3s[0]
        current_vid_id = 0
        for video_splits in batched_video_splits:
            batched_video_length.append(video_splits - current_vid_id)
            current_vid_id = video_splits
            
        for t1probs, t2probs, t3probs, y, action_spec, video_length in \
            zip(batched_t1s[1], 
                batched_t2s[1], 
                batched_t3s[1], 
                batched_ys, batched_action_specs, batched_video_length):

            t1_result = self.get_valid_result(t1_frame_ids, t1probs)
            t2_result = self.get_valid_result(t2_frame_ids, t2probs)
            t3_result = self.get_valid_result(t3_frame_ids, t3probs)
            
            if len(t1_result) == 0 and len(t2_result) == 0 and len(t3_result) == 0:
                batched_loss.append(torch.tensor(0.0, device=self.device))
                continue

            results = [t1_result, t2_result, t3_result]
            locations = action_spec['video location'][:3]
            
            assert len(locations) <= 3 and len(locations) > 0
            
            current_loss = []
            for result, location in zip(results, locations):
                if not location in location_consts:
                    continue
                
                encourage_len = math.ceil(video_length * encourage_prop)
                score_for_dist = 1 / encourage_len
                
                # encourage the first part of the total len
                if location == "early":
                    start = 0
                    end = start + encourage_len
                    
                    # TODO: optimize
                    weights = []
                    for i in range(video_length):
                        if i > end:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (end - i + 1)
                            weights.append(weight)
                            
                elif location == "mid":
                    mid = math.ceil(video_length / 2)
                    dis =  math.ceil(encourage_len / 2)
                    start = mid - dis
                    end = mid + dis
                    
                    weights = []
                    for i in range(video_length):
                        if i > end:
                            weights.append(0)
                        elif i < start:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (dis - abs(mid - i) + 1)
                            weights.append(weight)
                            
                else:
                    end = video_length
                    start = end - encourage_len
                    
                    weights = []
                    for i in range(video_length):
                        if i < start:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (i - start + 1)
                            weights.append(weight)
                
                valid_probs = []
                valid_weights = []
                target_y = []
                for prob, frame_id in result:
                    if frame_id >= video_length:
                        continue
                    weight = weights[frame_id]
                    if not weight == 0:
                        valid_probs.append(prob)
                        valid_weights.append(weight)
                        target_y.append(y)
                        
                valid_weights = torch.tensor(valid_weights, device=self.device)
                if len(valid_probs) == 0:
                    continue
                
                if self.use_half:
                    valid_probs = torch.stack(valid_probs)
                    valid_prob_logits = torch.log( valid_probs / (1 - valid_probs))
                    loss = self.loss_fn(valid_prob_logits, torch.tensor(target_y, dtype=valid_probs[0].dtype, device=self.device))
                else:
                    loss = self.loss_fn(torch.stack(valid_probs), torch.tensor(target_y, dtype=torch.float32, device=self.device))
                
                loss = loss * valid_weights
                loss = (loss.sum() / valid_weights.sum())
                current_loss.append(loss)
            
            if len(current_loss) == 0:
                batched_loss.append(torch.tensor(0.0, device=self.device))
            else:
                batched_loss.append(torch.mean(torch.stack(current_loss)))
            # For smaller window, it has lower likelihood of actually capturing the operation
            # We thus assign a weight function for the

        return batched_loss

    def forward(self, batch):
        
        # print(f"start forwarding: {self.device}")
        # Load batch info
        batched_ids = batch['batched_ids']
        batched_captions = batch['batched_captions']
        batched_gt_bboxes = batch['batched_gt_bboxes'] 
        # batched_gt_masks = batch['batched_gt_masks']
        batched_obj_pairs = batch['batched_obj_pairs']
        batched_object_ids = batch['batched_object_ids']
        batched_video_splits = batch['batched_video_splits']
        batched_reshaped_raw_videos = batch['batched_reshaped_raw_videos']
        batched_gt_labels = batch['batched_gt_obj_names']
        batched_gt_object_rels = batch['batched_gt_object_rels']
        batched_gpt_specs = batch['batched_gpt_specs']
        
        if len(batched_object_ids) == 0:
            return []
                
        batch_size = len(batched_ids)
        
        # Fetch constants
        batched_unary_kws = []
        batched_binary_kws = []
        batched_consts = []
        batched_pos_consts = []
        batched_neg_consts = []
            
        # Contrastive Learning Setup
        for spec in batched_gpt_specs:
            batched_unary_kws.append(spec['unary_kws'])
            batched_binary_kws.append(spec['binary_kws'])
            batched_consts.append(spec['consts'])
            batched_pos_consts.append(spec['consts'])
        
        if self.use_neg_spec:
            # TODO: fix this
            batched_neg_gpt_specs = batch['batched_neg_gpt_specs']
            for batch_id, (spec, neg_spec) in enumerate(zip(batched_gpt_specs, batched_neg_gpt_specs)):
                deduped_neg_binary = list(set(neg_spec['binary_kws']) - set(batched_binary_kws[batch_id]))
                deduped_neg_entity = list(set(neg_spec['consts']) - set(batched_consts[batch_id]))
                deduped_neg_unary = list(set(neg_spec['unary_kws']) - set(batched_unary_kws[batch_id]))

                batched_unary_kws[batch_id] += (deduped_neg_unary)
                batched_binary_kws[batch_id] += (deduped_neg_binary)
                batched_consts[batch_id] += (deduped_neg_entity)
                batched_neg_consts.append(neg_spec['consts'])
                
        if self.use_neg_kws:
            batched_neg_kws = batch['batched_neg_kws']
            for batch_id, (spec, negative_examples) in enumerate(zip(batched_gpt_specs, batched_neg_kws)):
                deduped_neg_binary = list(set(negative_examples['neg_binary']) - set(batched_binary_kws[batch_id]))
                deduped_neg_entity = list(set(negative_examples['neg_entity']) - set(batched_consts[batch_id]))

                batched_binary_kws[batch_id] += (deduped_neg_binary)
                batched_consts[batch_id] += (deduped_neg_entity)
        
        # print(f"calling predicting model: {self.device}")

        # Get probabilities
        batched_image_unary_probs, batched_image_binary_probs, selected_pairs = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                batched_bboxes=batched_gt_bboxes, 
                                batched_names=batched_consts,
                                batched_object_ids = batched_object_ids,
                                batched_unary_kws=batched_unary_kws,
                                batched_binary_kws=batched_binary_kws,
                                batched_gt_obj_pairs=batched_gt_object_rels, 
                                batched_video_splits=batched_video_splits, 
                                unary_segment_size=self.segment_size,
                                binary_segment_size=self.segment_size,)
        
        # consts = [e for c in batched_pos_consts for e in c]
        const_lookup = [{} for _ in range(batch_size)]
        neg_const_lookup = [{} for _ in range(batch_size)]
        cids = [[] for _ in range(batch_size)]
        batched_loss = [[] for _ in range(batch_size)]
        
        for vid, consts in enumerate(batched_pos_consts):
            for k, v in enumerate(consts):
                const_lookup[vid][v] = -k
                const_lookup[vid][v.upper()] = -k
                const_lookup[vid][v.lower()] = -k
                cids[vid].append(-k)
        
        if self.use_neg_spec:       
            # neg_consts = [e for c in batched_neg_consts for e in c]
            neg_cids = [[] for _ in range(batch_size)]
            for vid, neg_consts in enumerate(batched_neg_consts):
                for k, v in enumerate(neg_consts):
                    neg_const_lookup[vid][v] = -k
                    neg_const_lookup[vid][v.upper()] = -k
                    neg_const_lookup[vid][v.lower()] = -k
                    neg_cids[vid].append(-k)

        # batched_object_tps = get_object_tps(batched_object_names, batched_object_ids, const_lookup, batch_size)
        batched_scl_tps = construct_batched_scl_tps(batched_object_ids)
        
        # Process unary predicates
        batched_unary_pred_scl = []
        batched_cate_pred_scl = []
        if self.use_neg_spec:
            batched_neg_cate_pred_scl = []
            
        for vid, (image_cate_probs, image_unary_probs) in enumerate(batched_image_unary_probs):
            
            unary_pred_scl = []
            cate_pred_scl = {}
            
            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            assert image_cate_probs[0].shape[0] == 0 or len(object_ids) == image_cate_probs[0].shape[1]
                
            for cate_name, probs in zip(image_cate_probs[1], image_cate_probs[0]):
                
                for prob, (fid, oid) in zip(probs, object_ids):
                    if not oid in cate_pred_scl:
                        cate_pred_scl[oid] = {}
                    if not cate_name in cate_pred_scl[oid]:
                        cate_pred_scl[oid][cate_name] = []
                    cate_pred_scl[oid][cate_name].append(prob)
                    
            new_cate_pred_scl = []
            for oid, object_cate_info in cate_pred_scl.items():
                for cate_name, prob in object_cate_info.items():
                    if cate_name in const_lookup[vid]:
                        prob = torch.mean(torch.stack(prob))
                        new_cate_pred_scl.append((prob, (oid, const_lookup[vid][cate_name] - 1)))
            
            for unary_pred_name, probs in zip(image_unary_probs[1], image_unary_probs[0]):
                
                for prob, (fid, oid) in zip(probs, object_ids):
                    unary_pred_scl.append((prob, (unary_pred_name, fid, oid)))
            
            batched_cate_pred_scl.append(new_cate_pred_scl)
            batched_unary_pred_scl.append(unary_pred_scl)
            
            if self.use_neg_spec:
                new_neg_cate_pred_scl = []
                for oid, object_cate_info in cate_pred_scl.items():
                    for cate_name, prob in object_cate_info.items():
                        if cate_name in neg_const_lookup:
                            prob = torch.mean(torch.stack(prob))
                            new_neg_cate_pred_scl.append((prob, (oid, neg_const_lookup[vid][cate_name] - 1)))
                batched_neg_cate_pred_scl.append(new_neg_cate_pred_scl)
           
        # Process binary predicates
        batched_binary_pred_scl = []

        for vid, image_binary_probs in enumerate(batched_image_binary_probs):
            binary_pred_scl = []
            
            if len(image_binary_probs) == 0:
                batched_binary_pred_scl.append([])
                continue
            
            object_pairs = [ (fid, pair) for (ovid, fid, pair) in selected_pairs if vid == ovid]
            assert (type(image_binary_probs[0]) == torch.Tensor and len(image_binary_probs[0]) == 0) or (len(object_pairs) == 0 and image_binary_probs[0].shape[0]== 0) or len(object_pairs) == image_binary_probs[0].shape[1]
            
            for binary_pred_name, probs in zip(image_binary_probs[1], image_binary_probs[0]):
                for prob, (fid, pair) in zip(probs, object_pairs):
                    binary_pred_scl.append((prob, (binary_pred_name, fid, pair[0], pair[1])))

            batched_binary_pred_scl.append(binary_pred_scl)

        formatted_batched_scl_input_facts = format_batched_facts(batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_gpt_specs)
        
        # Ground truth is 1 as the batch size is always 1
        # TODO: Update this for contrastive setup
        batched_ys = [1] * batch_size
        # print(f"calling scallop: {self.device}")

        output = self.reason(**formatted_batched_scl_input_facts)
        batched_t1 = []
        batched_t1probs = []
        batched_loss = []
        batched_t1s = output['aligned_t1']
        batched_t2s = output['aligned_t2']
        batched_t3s = output['aligned_t3']

        has_no_answer = (len(output['aligned_t1'][0]) == 1 and output['aligned_t1'][0][0][0] == -1)
        if has_no_answer:
            print(f'Warning: No anwer: {batched_ids}')

        batched_loss = self.loss(batched_t1s, batched_t2s, batched_t3s, batched_gpt_specs, batched_ys, batched_video_splits)
        formatted_batched_neg_scl_input_facts = [[] for _ in range(batch_size)]
        
        if self.use_neg_spec:
            formatted_batched_neg_scl_input_facts = format_batched_facts(batched_scl_tps, batched_neg_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_neg_gpt_specs)
            output = self.reason(**formatted_batched_neg_scl_input_facts)
            
            batched_t1s = output['aligned_t1']
            batched_t2s = output['aligned_t2']
            batched_t3s = output['aligned_t3']
            
            neg_batched_ys = [0] * batch_size
            
            batched_neg_spec_loss =  self.loss(batched_t1s, batched_t2s, batched_t3s, batched_neg_gpt_specs, neg_batched_ys, batched_video_splits)
            for batch_id, neg_spec_loss in enumerate(batched_neg_spec_loss):
                if len(batched_loss) > 0:
                    batched_loss[batch_id] += self.neg_spec_weight * neg_spec_loss
                else:
                    batched_loss.append(self.neg_spec_weight * neg_spec_loss)
            
        if self.use_neg_kws:

            batched_neg_sample_loss = self.neg_sample_loss(batched_image_unary_probs, batched_image_binary_probs, batched_neg_kws)            
            for batch_id, (cate_loss, binary_loss) in enumerate(batched_neg_sample_loss):
                if len(batched_loss) > 0:
                    batched_loss[batch_id] += self.neg_entity_kw_cate_weight * cate_loss + self.neg_entity_kw_binary_weight * binary_loss
                else:
                    batched_loss.append(self.neg_entity_kw_cate_weight * cate_loss + self.neg_entity_kw_binary_weight * binary_loss)

        # hook_fn = self.get_register_backward_hook_fn(batch, batched_image_unary_probs, batched_image_binary_probs, batched_selected_pairs, 
            # raw_cate_logits_per_text, raw_unary_logits_per_text, raw_binary_logits_per_text, grad_thres)
        # print(f"finished forwarding: {self.device}")
        
        return batched_loss

    def backward_v2(self, loss, batch, batched_image_unary_probs, batched_image_binary_probs, batched_selected_pairs, 
        raw_cate_logits_per_text, raw_unary_logits_per_text, raw_binary_logits_per_text, 
        batched_consts, batched_pos_consts, batched_neg_consts, const_lookup, neg_const_lookup, grad_thres = 0):
        """Gradient-guided second pass that recomputes the loss and backpropagates it.

        Steps performed, in order:
          1. For every video in the batch, take the gradient of ``loss`` with
             respect to the category, unary and binary probability tensors
             (only for tensors that require grad; otherwise the gradient is None).
          2. Split each set of predictions into "above" / "under" groups with
             ``filter_grad_thres`` (presumably relative to a gradient-magnitude
             threshold -- confirm against ``filter_grad_thres``).
          3. Collect the distinct pairs found in the "above" binary group.
          4. Call ``backward_v2`` on the underlying predicate model, which returns
             category / unary / binary / negative-category fact lists.
          5. Re-run ``self.reason`` on those facts, rebuild the per-sample loss
             (plus the optional negative-spec and negative-keyword terms), sum it
             and call ``backward`` on the result.

        Args:
            loss: Loss from the first forward pass; differentiated in step 1.
            batch: Dict of batched inputs; the keys read are listed below.
            batched_image_unary_probs: Per video, a pair
                ``(image_cate_probs, image_unary_probs)`` where index 0 of each
                element is the probability tensor and index 1 the names.
            batched_image_binary_probs: Per video, an indexable whose index 0 is
                the binary probability tensor and index 1 the predicate names.
            batched_selected_pairs: Per-video object pairs, indexed by video id.
            raw_cate_logits_per_text, raw_unary_logits_per_text,
            raw_binary_logits_per_text: Forwarded unchanged to the model.
            batched_consts: Forwarded to the model as ``batched_names``.
            batched_pos_consts, batched_neg_consts: Not used in this method.
            const_lookup, neg_const_lookup: Forwarded unchanged to the model.
            grad_thres: Threshold forwarded to ``filter_grad_thres``.

        Returns:
            None. Returns early, without calling ``backward``, when the summed
            loss is a plain int or does not require grad.
        """
    
        # print(f"start trainer backward: {self.device}")
        # Load batch info
        batched_ids = batch['batched_ids']
        batched_captions = batch['batched_captions']
        batched_gt_bboxes = batch['batched_gt_bboxes'] 
        # batched_gt_masks = batch['batched_gt_masks']
        batched_obj_pairs = batch['batched_obj_pairs']
        batched_object_ids = batch['batched_object_ids']
        batched_video_splits = batch['batched_video_splits']
        batched_reshaped_raw_videos = batch['batched_reshaped_raw_videos']
        batched_gt_labels = batch['batched_gt_obj_names']
        batched_gt_object_rels = batch['batched_gt_object_rels']
        batched_gpt_specs = batch['batched_gpt_specs']
        
        batch_size = len(batched_ids)
        batched_cate_pred = [{} for _ in range(batch_size)]
        batched_unary_pred = [{} for _ in range(batch_size)]
        
        # Per-video gradients of the incoming loss; each slot ends up as a tensor or None.
        batched_d_loss_cate_grads = [[] for _ in range(batch_size)]
        batched_d_loss_unary_grads = [[] for _ in range(batch_size)]
        batched_d_loss_binary_grads = [[] for _ in range(batch_size)]
        
        for vid in range(batch_size):
            
            # retain_graph=True keeps the graph alive because it is differentiated
            # again for the other tensors in this loop.
            if batched_image_unary_probs[vid][0][0].requires_grad:
                batched_d_loss_cate_grads[vid] = torch.autograd.grad(outputs=loss, inputs=batched_image_unary_probs[vid][0][0], allow_unused=True, retain_graph=True)[0]
            else:
                batched_d_loss_cate_grads[vid] = None
                
            if batched_image_unary_probs[vid][1][0].requires_grad:
                batched_d_loss_unary_grads[vid] = torch.autograd.grad(outputs=loss, inputs=batched_image_unary_probs[vid][1][0], allow_unused=True, retain_graph=True)[0]
            else:
                batched_d_loss_unary_grads[vid] = None
                
            if batched_image_binary_probs[vid][0].requires_grad:
                batched_d_loss_binary_grads[vid] = torch.autograd.grad(outputs=loss, inputs=batched_image_binary_probs[vid][0], allow_unused=True, retain_graph=True)[0]
            else:
                batched_d_loss_binary_grads[vid] = None
                
        # Partition category and unary predictions of each video by gradient.
        for vid, (image_cate_probs, image_unary_probs) in enumerate(batched_image_unary_probs):
            
            # (frame id, object id) entries that belong to this video
            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            assert image_cate_probs[0].shape[0] == 0 or len(object_ids) == image_cate_probs[0].shape[1]
            
            above_grad_cates, under_grad_cates = filter_grad_thres(object_ids, image_cate_probs[1], image_cate_probs[0], batched_d_loss_cate_grads[vid], self.device, self.top_k_grads, grad_thres)
            batched_cate_pred[vid] = {"above": above_grad_cates, "under": under_grad_cates}
 
            above_grad_unary, under_grad_unary = filter_grad_thres(object_ids, image_unary_probs[1], image_unary_probs[0], batched_d_loss_unary_grads[vid], self.device, self.top_k_grads, grad_thres)
            batched_unary_pred[vid] = {"above": above_grad_unary, "under": under_grad_unary}

                
        # Process binary predicates
        batched_binary_pred = [{} for _ in range(batch_size)]
        new_batched_selected_pairs = {i: [] for i in range(batch_size)}
        
        for vid, image_binary_probs in enumerate(batched_image_binary_probs):
            
            object_pairs = batched_selected_pairs[vid]
            assert (len(object_pairs) == 0 and image_binary_probs[0].shape[0]== 0) or len(object_pairs) == image_binary_probs[0].shape[1]
            
            above_grad_binary, under_grad_binary = filter_grad_thres(object_pairs, image_binary_probs[1], image_binary_probs[0], batched_d_loss_binary_grads[vid], self.device, self.top_k_grads, grad_thres)
            batched_binary_pred[vid] = {"above": above_grad_binary, "under": under_grad_binary}

        # Keep each pair from the "above" group once, in first-seen order.
        for vid, binary_pred in enumerate(batched_binary_pred):
            for _, pair in binary_pred['above']:
                if not pair in new_batched_selected_pairs[vid]:
                    new_batched_selected_pairs[vid].append(pair)
                            
        # Under DDP the predicate model is wrapped; its own methods live on ``.module``.
        if self.use_ddp:
            model = self.predicate_model.module
        else:
            model = self.predicate_model
            
        batched_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_neg_cate_pred_scl = model.backward_v2(
                batched_data_ids = batched_ids,
                batched_names = batched_consts,
                batched_used_obj_pairs = new_batched_selected_pairs,
                batched_origin_used_obj_pairs = batched_selected_pairs,
                batched_image_unary_probs_orig = batched_image_unary_probs, 
                batched_image_binary_probs_orig = batched_image_binary_probs,
                batched_cate_pred=batched_cate_pred,
                batched_unary_pred=batched_unary_pred,
                batched_binary_pred=batched_binary_pred,
                batched_videos=batched_reshaped_raw_videos,
                batched_bboxes=batched_gt_bboxes, 
                batched_object_ids = batched_object_ids,
                batched_video_splits=batched_video_splits,
                const_lookup=const_lookup,
                neg_const_lookup=neg_const_lookup,
                raw_cate_logits_per_text=raw_cate_logits_per_text, 
                raw_unary_logits_per_text=raw_unary_logits_per_text, 
                raw_binary_logits_per_text=raw_binary_logits_per_text,)
        
        
        batched_scl_tps = construct_batched_scl_tps(batched_object_ids)
        formatted_batched_scl_input_facts = format_batched_facts(batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_gpt_specs)
        
        # Ground truth is 1 as the batch size is always 1
        # TODO: Update this for contrastive setup
        batched_ys = [1] * batch_size
        # print(f"calling scallop: {self.device}")

        output = self.reason(**formatted_batched_scl_input_facts)
        batched_t1 = []
        batched_t1probs = []
        batched_loss = []
        batched_t1s = output['aligned_t1']
        batched_t2s = output['aligned_t2']
        batched_t3s = output['aligned_t3']

        # A single entry whose first element is -1 in the first sample triggers the warning below.
        has_no_answer = (len(output['aligned_t1'][0]) == 1 and output['aligned_t1'][0][0][0] == -1)
        if has_no_answer:
            print(f'Warning: No anwer: {batched_ids}')

        batched_loss = self.loss(batched_t1s, batched_t2s, batched_t3s, batched_gpt_specs, batched_ys, batched_video_splits)
        
        formatted_batched_neg_scl_input_facts = [[] for _ in range(batch_size)]
        if self.use_neg_spec:
            # Negative specifications are scored against target 0 and added with weight neg_spec_weight.
            batched_neg_gpt_specs = batch['batched_neg_gpt_specs']
            formatted_batched_neg_scl_input_facts = format_batched_facts(batched_scl_tps, batched_neg_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_neg_gpt_specs)
            output = self.reason(**formatted_batched_neg_scl_input_facts)
            
            batched_t1s = output['aligned_t1']
            batched_t2s = output['aligned_t2']
            batched_t3s = output['aligned_t3']
            neg_batched_ys = [0] * batch_size
            
            batched_neg_spec_loss =  self.loss(batched_t1s, batched_t2s, batched_t3s, batched_neg_gpt_specs, neg_batched_ys, batched_video_splits)
            for batch_id, neg_spec_loss in enumerate(batched_neg_spec_loss):
                # Accumulate into the existing per-sample loss, or start the list if it is empty.
                if len(batched_loss) > 0:
                    batched_loss[batch_id] += self.neg_spec_weight * neg_spec_loss
                else:
                    batched_loss.append(self.neg_spec_weight * neg_spec_loss)
            
        if self.use_neg_kws:
            # Negative-keyword term, computed on the original (first-pass) probabilities.
            batched_neg_kws = batch['batched_neg_kws']
            batched_neg_sample_loss = self.neg_sample_loss(batched_image_unary_probs, batched_image_binary_probs, batched_neg_kws)            
            for batch_id, (cate_loss, binary_loss) in enumerate(batched_neg_sample_loss):
                if len(batched_loss) > 0:
                    batched_loss[batch_id] += self.neg_entity_kw_cate_weight * cate_loss + self.neg_entity_kw_binary_weight * binary_loss
                else:
                    batched_loss.append(self.neg_entity_kw_cate_weight * cate_loss + self.neg_entity_kw_binary_weight * binary_loss)
        
        # Note: this rebinds the ``loss`` parameter to the recomputed total.
        loss = sum(batched_loss)
        # if loss.requires_grad:
        #     print_hook = get_print_hook("loss")
        #     loss.register_hook(print_hook)
        
        # sum([]) is the int 0; nothing to backpropagate in that case.
        if type(loss) == int or not loss.requires_grad:
            return
        loss.backward(retain_graph=True)
        
    def backward_v1(self, loss, batch, batched_image_unary_probs, batched_image_binary_probs, batched_selected_pairs, 
        raw_cate_logits_per_text, raw_unary_logits_per_text, raw_binary_logits_per_text, 
        batched_consts, batched_pos_consts, batched_neg_consts, const_lookup, neg_const_lookup, grad_thres = 0):
        """Gradient-guided backward pass that delegates to the model's ``backward_v1``.

        The gradient of ``loss`` with respect to every category, unary and binary
        probability tensor that requires grad is obtained in a single
        ``torch.autograd.grad`` call. Predictions are then split into
        "above" / "under" groups by ``filter_grad_thres``, the distinct pairs of
        the "above" binary group are collected, and everything is handed to the
        underlying predicate model's ``backward_v1``.

        Args:
            loss: Loss from the forward pass to differentiate.
            batch: Dict of batched inputs. Keys read: ``batched_ids``,
                ``batched_gt_bboxes``, ``batched_object_ids``,
                ``batched_video_splits``, ``batched_reshaped_raw_videos``.
            batched_image_unary_probs: Per video, a pair
                ``(image_cate_probs, image_unary_probs)`` where index 0 of each
                element is the probability tensor and index 1 the names.
            batched_image_binary_probs: Per video, an indexable whose index 0 is
                the binary probability tensor and index 1 the predicate names.
            batched_selected_pairs: Per-video object pairs, indexed by video id.
            raw_cate_logits_per_text, raw_unary_logits_per_text,
            raw_binary_logits_per_text: Forwarded unchanged to the model.
            batched_consts: Forwarded to the model as ``batched_names``.
            batched_pos_consts, batched_neg_consts: Not used in this method.
            const_lookup, neg_const_lookup: Forwarded unchanged to the model.
            grad_thres: Threshold forwarded to ``filter_grad_thres``.

        Returns:
            None.
        """
        # Only the batch entries that are actually consumed below are read.
        batched_ids = batch['batched_ids']
        batched_gt_bboxes = batch['batched_gt_bboxes']
        batched_object_ids = batch['batched_object_ids']
        batched_video_splits = batch['batched_video_splits']
        batched_reshaped_raw_videos = batch['batched_reshaped_raw_videos']

        batch_size = len(batched_ids)

        # Gradient per (kind, video); stays None when the tensor does not require grad.
        grads_by_kind = {
            "cate": [None] * batch_size,
            "unary": [None] * batch_size,
            "binary": [None] * batch_size,
        }

        # Gather every differentiable tensor so one autograd call covers them all.
        to_back = []
        to_back_indices = []
        for vid in range(batch_size):
            candidates = (
                ("cate", batched_image_unary_probs[vid][0][0]),
                ("unary", batched_image_unary_probs[vid][1][0]),
                ("binary", batched_image_binary_probs[vid][0]),
            )
            for kind, probs in candidates:
                if probs.requires_grad:
                    to_back_indices.append((vid, kind))
                    to_back.append(probs)

        # torch.autograd.grad rejects an empty ``inputs`` sequence, so skip the call
        # when nothing requires grad; all gradients then remain None.
        if to_back:
            to_back_res = torch.autograd.grad(outputs=loss, inputs=to_back, allow_unused=True, retain_graph=True)
            for (vid, kind), grad in zip(to_back_indices, to_back_res):
                grads_by_kind[kind][vid] = grad

        batched_cate_pred = [{} for _ in range(batch_size)]
        batched_unary_pred = [{} for _ in range(batch_size)]

        for vid, (image_cate_probs, image_unary_probs) in enumerate(batched_image_unary_probs):

            # (frame id, object id) entries that belong to this video
            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            assert image_cate_probs[0].shape[0] == 0 or len(object_ids) == image_cate_probs[0].shape[1]

            above_grad_cates, under_grad_cates = filter_grad_thres(object_ids, image_cate_probs[1], image_cate_probs[0], grads_by_kind["cate"][vid], self.device, self.top_k_grads, grad_thres)
            batched_cate_pred[vid] = {"above": above_grad_cates, "under": under_grad_cates}

            above_grad_unary, under_grad_unary = filter_grad_thres(object_ids, image_unary_probs[1], image_unary_probs[0], grads_by_kind["unary"][vid], self.device, self.top_k_grads, grad_thres)
            batched_unary_pred[vid] = {"above": above_grad_unary, "under": under_grad_unary}

        # Process binary predicates
        batched_binary_pred = [{} for _ in range(batch_size)]
        new_batched_selected_pairs = {i: [] for i in range(batch_size)}

        for vid, image_binary_probs in enumerate(batched_image_binary_probs):

            object_pairs = batched_selected_pairs[vid]
            assert (len(object_pairs) == 0 and image_binary_probs[0].shape[0]== 0) or len(image_binary_probs[1]) == 0 or len(object_pairs) == image_binary_probs[0].shape[1]

            above_grad_binary, under_grad_binary = filter_grad_thres(object_pairs, image_binary_probs[1], image_binary_probs[0], grads_by_kind["binary"][vid], self.device, self.top_k_grads, grad_thres)
            batched_binary_pred[vid] = {"above": above_grad_binary, "under": under_grad_binary}

        # Keep each pair from the "above" group once, in first-seen order.
        for vid, binary_pred in enumerate(batched_binary_pred):
            for _, pair in binary_pred['above']:
                if pair not in new_batched_selected_pairs[vid]:
                    new_batched_selected_pairs[vid].append(pair)

        # Under DDP the predicate model is wrapped; its own methods live on ``.module``.
        if self.use_ddp:
            model = self.predicate_model.module
        else:
            model = self.predicate_model

        model.backward_v1(
                batched_data_ids = batched_ids,
                batched_names = batched_consts,
                batched_used_obj_pairs = new_batched_selected_pairs,
                batched_origin_used_obj_pairs = batched_selected_pairs,
                batched_image_unary_probs_orig = batched_image_unary_probs, 
                batched_image_binary_probs_orig = batched_image_binary_probs,
                batched_cate_pred=batched_cate_pred,
                batched_unary_pred=batched_unary_pred,
                batched_binary_pred=batched_binary_pred,
                batched_videos=batched_reshaped_raw_videos,
                batched_bboxes=batched_gt_bboxes, 
                batched_object_ids = batched_object_ids,
                batched_video_splits=batched_video_splits,
                const_lookup=const_lookup,
                neg_const_lookup=neg_const_lookup,
                raw_cate_logits_per_text=raw_cate_logits_per_text, 
                raw_unary_logits_per_text=raw_unary_logits_per_text, 
                raw_binary_logits_per_text=raw_binary_logits_per_text,)
        
    def baseline_eval(self):
        """Evaluate the predicate model on ``self.test_loader`` and write a report.

        Every batch is scored with ``self.baseline_eval_batch`` under
        ``torch.no_grad()``. When a ``torch.distributed`` process group is
        initialized, the per-process results are gathered from all ranks;
        otherwise the local results are used directly, so the method also works
        in a single process where no process group exists.

        On device 0 the results are merged, accuracy statistics are computed for
        unary and binary predicates, and, if ``self.report_dir`` is not None,
        the report is written to
        ``<report_dir>/<model_name>.<epoch_ct>.report.txt``.

        Returns:
            None.
        """
        self.predicate_model.eval()
        self.predicate_model.num_top_pairs = self.test_num_top_pairs

        results_unary = []
        results_binary = []
        with torch.no_grad():
            for dp_list in tqdm(self.test_loader):
                batch_result_unary, batch_result_binary = self.baseline_eval_batch(dp_list)
                results_unary.append(batch_result_unary)
                results_binary.append(batch_result_binary)

            if torch.distributed.is_available() and torch.distributed.is_initialized():
                # all_gather_object fills one slot per rank, so the lists must
                # have exactly world_size entries.
                total_results_unary = [[] for _ in range(self.world_size)]
                total_results_binary = [[] for _ in range(self.world_size)]
                torch.distributed.all_gather_object(total_results_unary, results_unary)
                torch.distributed.all_gather_object(total_results_binary, results_binary)
            else:
                # No process group: this process holds the only results.
                total_results_unary = [results_unary]
                total_results_binary = [results_binary]

        if self.device == 0:
            # Flatten the per-rank lists into one list of per-batch results.
            total_results_unary = [res for rank_results in total_results_unary for res in rank_results]
            total_results_binary = [res for rank_results in total_results_binary for res in rank_results]

            merge_unary_results = combine_baseline_pred_dict_ls(total_results_unary)
            merge_binary_results = combine_baseline_pred_dict_ls(total_results_binary)
            unary_accu, unary_stats = obtain_stats(merge_unary_results)
            binary_accu, binary_stats = obtain_stats(merge_binary_results)

            report_str = ["unary stats:"]
            report_str += get_report(unary_stats)
            report_str.append(f"Accuracy: {unary_accu}")

            report_str += ["binary stats"]
            report_str += get_report(binary_stats)
            report_str.append(f"Accuracy: {binary_accu}")
            report = "\n".join(report_str)

            if self.report_dir is not None:
                report_path = os.path.join(self.report_dir, self.model_name + f'.{self.epoch_ct}.report.txt')
                with open(report_path, 'w') as file:
                    file.write(report)

        return
    
    def train_epoch(self, n):
        """Run one optimization pass over ``self.train_loader``.

        For each batch the gradients are cleared, ``self.forward`` produces a
        list of per-sample losses, their sum is backpropagated and the optimizer
        takes a step. A batch is skipped when the summed loss is a plain int
        (``sum`` of an empty list) or does not require grad. CUDA out-of-memory
        errors are caught: the batch ids and captions are recorded, the cache is
        cleared, and training continues with the next batch.

        Args:
            n: Epoch number; used only in the progress-bar description.

        Returns:
            The running mean of all per-sample losses from batches that were
            optimized, or None when no batch was optimized (every batch was
            skipped or ran out of memory).
        """
        self.predicate_model.train()

        all_losses = []
        process_failures = []
        all_dps = 0
        # Defined up front so the final return cannot hit an unbound local when
        # no batch completes an optimizer step.
        avg_loss = None

        iterator = tqdm(self.train_loader)
        for dp_list in iterator:

            all_dps += 1
            self.optimizer.zero_grad()
            try:
                loss_ls = self.forward(dp_list)
                loss = sum(loss_ls)

                # sum([]) is the int 0; nothing to backpropagate in that case.
                if isinstance(loss, int) or not loss.requires_grad:
                    continue

                loss.backward(retain_graph=True)

                self.optimizer.step()
                self.optimizer.zero_grad()

                all_losses += [sample_loss.item() for sample_loss in loss_ls]
                avg_loss = sum(all_losses)/len(all_losses)
                iterator.set_description(f'[Train {n}] Loss: {avg_loss}')
                del loss_ls
                del loss
                del dp_list

                # Release graph and batch memory eagerly between batches.
                gc.collect()
                torch.cuda.empty_cache()

            except torch.cuda.OutOfMemoryError:
                batched_ids = dp_list['batched_ids']
                batched_captions = dp_list['batched_captions']
                process_failures.append((batched_ids, batched_captions))

                print(f"current out of memory ct: {len(process_failures)} out of {all_dps}")

                del dp_list
                gc.collect()
                torch.cuda.empty_cache()
                print()
                continue

        return avg_loss

    def test_epoch(self):
        """Run the baseline evaluation, then checkpoint the model if requested.

        The model is written with ``torch.save`` to
        ``<model_dir>/<model_name>.<epoch_ct>.model`` only when
        ``self.save_model`` is truthy and ``self.device`` equals 0.
        """
        self.baseline_eval()

        if not (self.save_model and self.device == 0):
            return

        # A bare PredicateModel is saved as-is; any other object is assumed to be
        # a wrapper and its ``.module`` is saved instead.
        if type(self.predicate_model) == PredicateModel:
            model_to_save = self.predicate_model
        else:
            model_to_save = self.predicate_model.module

        checkpoint_path = os.path.join(self.model_dir, f"{self.model_name}.{self.epoch_ct}.model")
        torch.save(model_to_save, checkpoint_path)
        return

    def test(self):
        """Run a single evaluation pass by delegating to ``test_epoch``."""
        self.test_epoch()
        
    def train(self, num_epochs):
        """Alternate training and evaluation for each epoch up to ``num_epochs``.

        Training starts at epoch 1, or at ``self.epoch_ct + 1`` when
        ``self.epoch_ct`` is non-zero. ``self.epoch_ct`` is updated before each
        epoch, and a wall-clock timestamp is printed before every train and
        test phase.

        Args:
            num_epochs: Last epoch number to run (inclusive).
        """
        def clock():
            # Wall-clock time of day, used for the progress messages
            return datetime.now().strftime("%H:%M:%S")

        start_ct = 1 if self.epoch_ct == 0 else self.epoch_ct + 1

        # self.test_epoch()
        for i in range(start_ct, num_epochs + 1):
            current_time = clock()
            print(f"Epoch {i} train: {current_time}")
            self.epoch_ct = i
            self.train_epoch(i)

            current_time = clock()
            print(f"Epoch {i} test: {current_time}")
            self.test_epoch()
            
    def save_scl_file(self, datapoint, object_tps, current_constraint):
        """Write the Scallop program for one datapoint to ``self.save_scl_dir``.

        The program text comes from ``obtain_scl_file`` and the file is named
        ``<datapoint['id']>.scl``. Nothing is written when
        ``self.save_scl_dir`` is None.
        """
        program_text = obtain_scl_file(object_tps, current_constraint, self.common_scl)
        target_name = datapoint['id'] + '.scl'

        if self.save_scl_dir is None:
            return

        with open(os.path.join(self.save_scl_dir, target_name), 'w') as out_file:
            out_file.write(program_text)


def main(rank: int, 
         world_size: int, 
         args):
    """Per-process entry point: build the loaders and trainer, then train or evaluate.

    Args:
        rank: Index of this process; also used as the device passed to the
            loaders and the trainer.
        world_size: Total number of processes.
        args: Parsed command-line arguments. ``args.phase`` selects
            ``"train"`` (runs ``trainer.train(args.n_epochs)``) or ``"test"``
            (runs ``trainer.baseline_eval()``); any other value does nothing.

    When ``args.use_ddp`` is truthy the process group is set up first and is
    torn down again on exit, including when training or evaluation raises.
    """
    dataset = "open_pvsg"
    cache_file_name = "gpt_specs_prog_str.json"
    data_file_name = 'pvsg.json'

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../../data/{dataset}"))
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    cache_path = os.path.join(data_nl_dir, cache_file_name)
    # Context manager so the cache file handle is closed deterministically.
    with open(cache_path, 'r') as cache_file:
        caption2scl = json.load(cache_file)

    scl_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../scl'))
    assert os.path.exists(scl_dir)
    common_scl_path = os.path.join(scl_dir, 'ltl_disj.scl')
    assert os.path.exists(common_scl_path)

    if args.use_ddp:
        ddp_setup(rank, world_size)
        # The sampler class itself (not an instance) is handed to the loader factory.
        sampler = DistributedSampler
    else:
        sampler = None

    try:
        device = rank

        _, _, train_loader, test_loader = open_pvsg_loader(
                cache_path=cache_path,
                dataset_dir=data_dir, 
                dataset_name=data_file_name, 
                batch_size=args.batch_size, 
                device=device, 
                training_percentage=args.train_percentage, 
                testing_percentage=args.test_percentage, 
                max_video_len=args.max_video_len,
                neg_kws=args.use_neg_kws,
                neg_spec=args.use_neg_spec,
                neg_example_ct=args.neg_example_ct,
                neg_example_file_name="neg_examples.json",
                backbone_model="clip",
                sampler=sampler,
                dataloader_worker_ct=args.dataloader_worker_ct,
                )

        trainer = Trainer(train_loader=train_loader,
                          test_loader=test_loader, 
                          device=device, caption2scl=caption2scl, 
                          save_scl_dir=scl_dir, common_scl_path=common_scl_path,
                          latent_dim=args.latent_dim,
                          model_dir=args.model_dir, model_name=args.model_name,
                          learning_rate=args.learning_rate, load_model=args.load_model,
                          provenance=args.provenance,
                          save_model=args.save_model,
                          video_save_dir= args.video_save_dir,
                          train_num_top_pairs=args.train_num_top_pairs,
                          test_num_top_pairs=args.test_num_top_pairs,
                          report_dir=args.report_dir,
                          use_neg_kws=args.use_neg_kws,
                          use_neg_spec=args.use_neg_spec,
                          neg_spec_weight=args.neg_spec_weight,
                          neg_entity_kw_cate_weight=args.neg_entity_kw_cate_weight,
                          neg_entity_kw_binary_weight=args.neg_entity_kw_binary_weight,
                          clip_model_name=args.clip_model_name,
                          use_half=args.use_half,
                          top_k_grads=args.top_k_grads,
                          segment_size=args.segment_size,
                          args = args, 
                          world_size=world_size,
                          )

        print(args.model_name)
        print(f"start train: {rank}")
        if args.phase == "train":
            print("train")
            trainer.train(args.n_epochs)
        elif args.phase == "test":
            print("baseline eval")
            trainer.baseline_eval()
    finally:
        # Always release the process group, even if setup, training or
        # evaluation raised, so a failed rank does not leave it dangling.
        if args.use_ddp:
            destroy_process_group()
    
if __name__ == "__main__":
    
    torch.multiprocessing.set_start_method('spawn', force=True)

    # Set up data directories and paths
    dataset = "open_pvsg"
    # cache_file_name = f"gpt_specs_videollamav2_prog_str.json"
    cache_file_name = f"gpt_specs_prog_str.json"
    data_file_name = 'pvsg.json'

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../../data/{dataset}"))
    assert os.path.exists(data_dir)
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    assert (os.path.exists(data_dir))
    if not os.path.exists(data_nl_dir):
        os.mkdir(data_nl_dir)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)
    caption2scl = json.load(open(cache_path, 'r'))
    
    scl_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../scl'))
    assert os.path.exists(scl_dir)
    common_scl_path = os.path.join(scl_dir, f'ltl_disj.scl')
    assert os.path.exists(common_scl_path)
    
    video_save_dir = os.path.join(data_dir, 'pred_video')
    model_dir = os.path.join(data_dir, 'model_sparse')

    # Setup argument parser
    parser = ArgumentParser(dataset)
    def _str_to_bool(value):
        """Convert a command-line token into a real boolean.

        With ``type=bool`` (or no ``type`` at all) argparse hands back a value
        that is truthy for every non-empty string, so ``--flag False`` would
        still switch the flag on.  Raising ``ValueError`` for an unrecognised
        token makes argparse report a normal "invalid value" usage error.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1", "on"):
            return True
        if token in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    # Run mode, checkpoint handling and loss configuration.
    parser.add_argument("--phase", type=str, default='test')
    parser.add_argument("--n-epochs", type=int, default=50)
    parser.add_argument("--load-model", type=_str_to_bool, default=False)
    parser.add_argument("--save-model", type=_str_to_bool, default=True)
    parser.add_argument("--video_save_dir", type=str, default=video_save_dir)
    parser.add_argument("--model_type", type=str, default="contrast")
    parser.add_argument("--use-neg-spec", type=_str_to_bool, default=True)
    parser.add_argument("--use-neg-kws", type=_str_to_bool, default=True)
    parser.add_argument("--neg-example-ct", type=int, default=2)
    parser.add_argument("--neg-spec-weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_binary_weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_cate_weight", type=float, default=0)
    parser.add_argument("--clip-model-name", type=str, default="openai/clip-vit-base-patch32")

    parser.add_argument("--parallel", action='store_true')
    parser.add_argument("--report-dir", type=str, default="")

    # Video sampling / pair selection limits.
    parser.add_argument("--train-num-top-pairs", type=int, default=10)
    parser.add_argument("--test-num-top-pairs", type=int, default=30)
    parser.add_argument("--max-video-len", type=int, default=12)
    parser.add_argument("--segment-size", type=int, default=3)
    parser.add_argument("--top-k-grads", type=int, default=-1)

    # Dataset split sizes and sub-sampling percentages.
    parser.add_argument("--train-num", type=int, default=5000)
    parser.add_argument("--val-num", type=int, default=1000)
    parser.add_argument("--train-percentage", type=int, default=100)
    parser.add_argument("--test-percentage", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.000001)
    # NOTE(review): parsed as float although the default is the int 64 --
    # confirm whether ``type=int`` was intended before changing it.
    parser.add_argument("--latent-dim", type=float, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--model-dir", type=str, default=model_dir)
    parser.add_argument("--data-dir", type=str, default=data_dir)
    # NOTE(review): ``store_false`` means args.use_cuda defaults to True and
    # passing ``--use-cuda`` turns it OFF; kept as-is for compatibility.
    parser.add_argument("--use-cuda", action="store_false")
    parser.add_argument("--use-half", action="store_true")
    parser.add_argument("--use-sparse", type=_str_to_bool, default=False)
    parser.add_argument("--use-ddp", type=_str_to_bool, default=True)
    parser.add_argument("--gpu", type=int, default=-1)
    
    args = parser.parse_args()

    # Seed both the torch and the stdlib RNGs from the same CLI value.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Worker count is pinned to 0; presumably forwarded to the data loaders'
    # ``num_workers`` -- confirm against the code that reads this attribute.
    args.dataloader_worker_ct = 0

    # Last path component of the CLIP checkpoint id, e.g. "clip-vit-base-patch32".
    record_clip_model_name = args.clip_model_name.split('/')[-1]

    # Second-resolution timestamp rendered as YYYY-MM-DD-HH-MM-SS so it is
    # safe to embed in a file name (no spaces or colons).
    current_time = datetime.now().replace(microsecond=0)
    current_time_str = current_time.strftime("%Y-%m-%d-%H-%M-%S")

    # Default run name: encodes the dataset, the launch time and the main
    # settings.  The final field is labelled "_ddp_" (it previously reused
    # the "_sparse_" label, which made the two fields indistinguishable).
    model_name = (
        f"laser_clip_{dataset}_{current_time_str}"
        f"_training_{args.train_percentage}"
        f"_lr_{args.learning_rate}"
        f"_negspec_{args.use_neg_spec}"
        f"_negkw_{args.use_neg_kws}"
        f"_mvl_{args.max_video_len}"
        f"_bs_{args.batch_size}"
        f"_sparse_{args.use_sparse}"
        f"_tpg_{args.top_k_grads}"
        f"_ddp_{args.use_ddp}"
    )

    # An explicit --model-name always wins over the generated one.
    if args.model_name is None:
        args.model_name = model_name

    # Default device string: first CUDA device when available, otherwise CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Load the annotation JSON; the encoding is explicit so the result does
    # not depend on the platform's default locale.
    with open(data_path, 'r', encoding="utf-8") as f:
        anno = json.load(f)

    # One process per visible CUDA device when running distributed.
    world_size = torch.cuda.device_count()

    if args.use_ddp:
        if world_size < 1:
            # mp.spawn with nprocs=0 starts no workers and returns at once,
            # so the script would report "end" without having done anything.
            raise RuntimeError(
                "--use-ddp requires at least one CUDA device, but "
                "torch.cuda.device_count() returned 0; disable DDP to run "
                "in a single process."
            )
        # mp.spawn passes the process index as the first argument to main.
        mp.spawn(main, args=(world_size, args), nprocs=world_size)
    else:
        main(0, world_size, args)

    print("end")
   