import os

import json
import random
import cv2
from argparse import ArgumentParser
import math
from torch import nn, optim
from tqdm import tqdm
import torch
import scallopy
import sys
import pickle
import gc
import heapq

# Resolve the `models` directory that sits one level above this file's directory
# and fail at import time if it does not exist.
model_path = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../models'))
assert os.path.exists(model_path)

# Put that directory on the module search path so modules stored there can be imported.
sys.path.append(model_path)
from open_pvsg_predicate_model_gt_rel import PredicateModel
from openpvsg_dataset import *
from utils import *

class Trainer():
    def __init__(self, train_loader, test_loader, device,
                 caption2scl, common_scl_path,
                 latent_dim = 64,
                 provenance="difftopkproofs", k=3, save_scl_dir=None,
                 use_neg_spec=False, use_neg_kws=False,
                 model_dir=None, model_name=None, 
                 learning_rate=None,
                 load_model=False, save_model=True,
                 video_save_dir=None,
                 train_num_top_pairs=100, 
                 test_num_top_pairs=300, 
                 report_dir=None,
                 neg_spec_weight=0.1,
                 neg_entity_kw_cate_weight=0,
                 neg_entity_kw_binary_weight=0.1,
                 siglip_model_name="google/siglip-base-patch16-224", 
                 use_half=True):
        """Set up data loaders, the Scallop reasoning function, the predicate model and the optimizer.

        Args:
            train_loader: Iterable of training batches (consumed by `train_epoch`).
            test_loader: Iterable of evaluation batches (consumed by `baseline_eval`).
            device: Torch device the model and the reasoning function are moved to.
            caption2scl: Stored on the instance as-is; not used in this constructor.
            common_scl_path: Path of the Scallop program; its text is kept in
                `self.common_scl` and the file is imported into the Scallop context.
            latent_dim: Passed to `PredicateModel` as `hidden_dim`.
            provenance: Scallop provenance name.
            k: Scallop top-k parameter.
            save_scl_dir: Optional directory used by `save_scl_file`.
            use_neg_spec: Enable the negative-specification loss term in `forward`.
            use_neg_kws: Enable the negative-keyword loss term in `forward`.
            model_dir: Directory holding `<model_name>.<id>.model` checkpoints.
            model_name: Checkpoint/report file-name prefix.
            learning_rate: Adam learning rate; when None, Adam's own default is used.
            load_model: Load the most recent checkpoint from `model_dir` if one exists.
            save_model: Whether `test_epoch` saves the model.
            video_save_dir: Accepted for interface compatibility; unused here.
            train_num_top_pairs: `num_top_pairs` used by the model while training.
            test_num_top_pairs: `num_top_pairs` used by the model while evaluating.
            report_dir: Optional directory for evaluation reports.
            neg_spec_weight: Weight of the negative-specification loss.
            neg_entity_kw_cate_weight: Weight of the negative entity (category) loss.
            neg_entity_kw_binary_weight: Weight of the negative binary-keyword loss.
            siglip_model_name: Passed to `PredicateModel` as `model_name`.
            use_half: Selects `BCEWithLogitsLoss` (True) or `BCELoss` (False).
        """

        # Dataset and scallop file setup
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.caption2scl = caption2scl
        self.template2action = {}
        # Read through a context manager so the file handle is closed right away.
        with open(common_scl_path) as common_scl_file:
            self.common_scl = common_scl_file.read()
        self.save_model = save_model
        self.report_dir = report_dir
        self.model_dir = model_dir
        self.model_name = model_name

        # Contrastive learning type
        self.use_neg_spec = use_neg_spec
        self.use_neg_kws = use_neg_kws
        self.neg_spec_weight = neg_spec_weight
        self.neg_entity_kw_cate_weight = neg_entity_kw_cate_weight
        self.neg_entity_kw_binary_weight = neg_entity_kw_binary_weight

        # Hyperparameters controlling the number of binary pairs to consider for efficiency
        self.train_num_top_pairs = train_num_top_pairs
        self.test_num_top_pairs = test_num_top_pairs

        # Scallop context and forwarding setup
        self.scallop_ctx = scallopy.ScallopContext(provenance=provenance, k=k)
        self.scallop_ctx.import_file(common_scl_path)
        self.scallop_ctx.set_non_probabilistic(non_prob_gpt_prog_str_preds)

        self.reason = self.scallop_ctx.forward_function(output_mappings={
            "aligned_t1": None,
            "aligned_t2": None,
            "aligned_t3": None,
        }, retain_graph=True)
        self.reason.to(self.device)

        # Training continuation setup: stays 0 unless a numbered checkpoint is loaded.
        self.epoch_ct = 0

        # Setting up the STSG model
        predicate_model = PredicateModel(hidden_dim=latent_dim, num_top_pairs=train_num_top_pairs,
                                         device=device, model_name=siglip_model_name).to(device)

        latest_model_id = None
        if load_model and model_dir is not None and os.path.exists(model_dir):
            latest_model_id = self._latest_model_id(os.listdir(model_dir), model_name)

        if latest_model_id is not None:
            checkpoint_name = f'{model_name}.{latest_model_id}.model'
            checkpoint = torch.load(os.path.join(model_dir, checkpoint_name), map_location=self.device)
            # `test_epoch` pickles the whole module, so accept either a module or a state dict.
            if isinstance(checkpoint, nn.Module):
                checkpoint = checkpoint.state_dict()
            predicate_model.load_state_dict(checkpoint)
            predicate_model.device = self.device
            print(f"Loading: {checkpoint_name}")
            if isinstance(latest_model_id, int):
                self.epoch_ct = latest_model_id

        predicate_model.num_top_pairs = self.train_num_top_pairs
        self.predicate_model = predicate_model

        # Setting up learning parameters
        # Only forward `lr` when one was given; passing None would make Adam fail.
        optimizer_kwargs = {} if learning_rate is None else {'lr': learning_rate}
        self.optimizer = optim.Adam(self.predicate_model.parameters(), **optimizer_kwargs)
        self.min_loss = 10000000000
        self.use_half = use_half

        if use_half:
            self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        else:
            self.loss_fn = nn.BCELoss(reduction='none')

        # Debugging utils. Always defined so `save_scl_file` can test it for None.
        self.save_scl_dir = save_scl_dir

    @staticmethod
    def _latest_model_id(file_names, model_name, suffix='.model'):
        """Pick the checkpoint id to resume from among `<model_name>.<id><suffix>` file names.

        Returns the largest numeric id as an int, else the string 'latest' when only a
        `latest` checkpoint exists, else None when nothing matches.
        """
        if model_name is None:
            return None

        prefix = f'{model_name}.'
        numeric_ids = []
        has_latest = False
        for file_name in file_names:
            if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
                continue
            model_id = file_name[len(prefix):len(file_name) - len(suffix)]
            if model_id.isdigit():
                numeric_ids.append(int(model_id))
            elif model_id == 'latest':
                has_latest = True

        if numeric_ids:
            return max(numeric_ids)
        return 'latest' if has_latest else None
            
    def get_valid_result(self, frame_ids, probs):
        """Pair each probability with its frame id, drop the -1 placeholders, sort by frame id.

        Args:
            frame_ids: Sequence of one-element sequences, each wrapping a frame id.
            probs: Sequence of probabilities parallel to `frame_ids`.

        Returns:
            List of `(prob, frame_id)` tuples in ascending frame-id order.
        """
        kept = []
        for wrapped_id, prob in zip(frame_ids, probs):
            assert len(wrapped_id) == 1
            fid = wrapped_id[0]

            # A frame id of -1 marks an entry without a usable frame.
            if fid == -1:
                continue

            kept.append((prob, fid))

        kept.sort(key=lambda entry: entry[1])
        return kept
    
    def baseline_eval_batch(self, batch):
        """Run the predicate model on one batch and score it against the ground-truth scene graph.

        The candidate category names and binary keywords offered to the model are
        taken from the ground truth of each video.

        Args:
            batch: Dict of batched fields; this method reads `batched_ids`,
                `batched_gt_bboxes`, `batched_obj_pairs`, `batched_object_ids`,
                `batched_video_splits`, `batched_reshaped_raw_videos`,
                `batched_gt_obj_names` and `batched_gt_object_rels`.

        Returns:
            `(result_unary, result_binary)` as produced by `accu`, or `({}, {})`
            when the batch contains no objects.
        """
        batched_ids = batch['batched_ids']
        batched_gt_bboxes = batch['batched_gt_bboxes']
        batched_obj_pairs = batch['batched_obj_pairs']
        batched_object_ids = batch['batched_object_ids']
        batched_video_splits = batch['batched_video_splits']
        batched_reshaped_raw_videos = batch['batched_reshaped_raw_videos']
        batched_gt_labels = batch['batched_gt_obj_names']
        batched_gt_object_rels = batch['batched_gt_object_rels']

        if len(batched_object_ids) == 0:
            return {}, {}

        # Set the required keywords corresponding to the ground truth scene graph
        batch_size = len(batched_ids)
        # One independent list per video; `[[]] * n` would alias a single list n times.
        batched_unary_kws = [[] for _ in range(batch_size)]

        # vid -> {oid -> label}; a given object must carry the same label in every frame.
        batched_obj_labels = {}
        for (vid, fid, label), (vid, fid, oid) in zip(batched_gt_labels, batched_object_ids):
            video_labels = batched_obj_labels.setdefault(vid, {})
            if oid in video_labels:
                assert label == video_labels[oid]
            video_labels[oid] = label

        # Index by video position so the names stay aligned with the other batched
        # inputs even when a video contributes no objects. `dict.fromkeys` removes
        # duplicates while keeping a deterministic order.
        batched_obj_names = [
            list(dict.fromkeys(batched_obj_labels.get(vid, {}).values()))
            for vid in range(batch_size)
        ]

        batched_binary_kws = []
        for vid, gt_object_rels in enumerate(batched_gt_object_rels):
            binary_kws = {}
            for fid, obj_pairs in enumerate(gt_object_rels):
                for (sub, obj, binary_kw) in obj_pairs:
                    if (vid, fid, (sub, obj)) in batched_obj_pairs:
                        binary_kws[binary_kw] = None
            batched_binary_kws.append(list(binary_kws))

        # Obtain the predictions with pretrained model
        batched_image_unary_probs, batched_image_binary_probs, selected_pairs = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                 batched_bboxes=batched_gt_bboxes,
                                 batched_names=batched_obj_names,
                                 batched_object_ids=batched_object_ids,
                                 batched_unary_kws=batched_unary_kws,
                                 batched_binary_kws=batched_binary_kws,
                                 batched_gt_obj_pairs=batched_gt_object_rels,
                                 batched_video_splits=batched_video_splits)

        # Only categories live in the ground truth scene graph
        # Process the categories as predicates
        batched_cate_pred_scl = []
        for vid, (image_cate_probs, _) in enumerate(batched_image_unary_probs):
            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            assert image_cate_probs[0].shape[0] == 0 or len(object_ids) == image_cate_probs[0].shape[1]

            # oid -> {category name -> per-frame probabilities}
            cate_pred_scl = {}
            for cate_name, probs in zip(image_cate_probs[1], image_cate_probs[0]):
                for prob, (fid, oid) in zip(probs, object_ids):
                    cate_pred_scl.setdefault(oid, {}).setdefault(cate_name, []).append(prob)

            # Average each object's category probability over the frames it appears in.
            batched_cate_pred_scl.append([
                (torch.mean(torch.stack(frame_probs)), (oid, cate_name))
                for oid, object_cate_info in cate_pred_scl.items()
                for cate_name, frame_probs in object_cate_info.items()
            ])

        # Process binary predicates
        batched_binary_pred_scl = []
        for vid, image_binary_probs in enumerate(batched_image_binary_probs):
            if len(image_binary_probs) == 0:
                batched_binary_pred_scl.append([])
                continue

            object_pairs = [(fid, pair) for (ovid, fid, pair) in selected_pairs if vid == ovid]
            assert (len(object_pairs) == 0 and image_binary_probs[0].shape[0] == 0) or len(object_pairs) == image_binary_probs[0].shape[1]

            binary_pred_scl = []
            for binary_pred_name, probs in zip(image_binary_probs[1], image_binary_probs[0]):
                for prob, (fid, pair) in zip(probs, object_pairs):
                    binary_pred_scl.append((prob, (binary_pred_name, fid, pair[0], pair[1])))
            batched_binary_pred_scl.append(binary_pred_scl)

        # Test the accuracy
        return self.accu(batched_obj_labels, batched_gt_object_rels, batched_cate_pred_scl, batched_binary_pred_scl)
    
    # Accuracy Metric
    def accu(self, batched_obj_labels, batched_gt_object_rels, batched_image_cate_probs, batched_image_binary_probs, top_pair_num=100):
        """Collect per-class ground-truth / prediction indicator lists for a batch.

        Args:
            batched_obj_labels: Mapping `vid -> {oid -> ground-truth name}`.
            batched_gt_object_rels: Per video, per frame, a collection of
                `(sub, obj, relation)` ground-truth triples.
            batched_image_cate_probs: Per video, a list of `(prob, (oid, name))`.
            batched_image_binary_probs: Per video, a list of
                `(prob, (relation, fid, sub, obj))`.
            top_pair_num: Number of highest-probability binary predictions scored
                per video.

        Returns:
            `(result_unary, result_binary)`, each mapping a class name to
            `{'gt': [...], 'pred': [...]}` lists of 0/1 flags.
        """
        result_unary = {}
        result_binary = {}

        per_video = zip(batched_image_cate_probs, batched_image_binary_probs, batched_gt_object_rels)
        for vid, (image_cate_probs, binary_pred, gt_object_rels) in enumerate(per_video):
            # A video without labelled objects simply contributes no unary entries.
            try:
                obj_labels = batched_obj_labels[vid]
            except (KeyError, IndexError):
                obj_labels = {}

            # Rank only this video's predictions: a candidate pool shared across videos
            # would re-score earlier videos' predictions against this video's ground truth.
            top_binary_preds = heapq.nlargest(top_pair_num, binary_pred, key=lambda pred: pred[0])

            for rela_prob, (rela_name, fid, sub, obj) in top_binary_preds:
                entry = result_binary.setdefault(rela_name, {'gt': [], 'pred': []})
                entry['gt'].append(1 if (sub, obj, rela_name) in gt_object_rels[fid] else 0)
                entry['pred'].append(1 if rela_prob > 0.5 else 0)

            # Highest-probability category per object (first one wins on ties).
            obj_pred = {}
            for cate_prob, (oid, obj_name) in image_cate_probs:
                if oid not in obj_pred or cate_prob > obj_pred[oid][0]:
                    obj_pred[oid] = (cate_prob, obj_name)

            for oid, obj_name in obj_labels.items():
                entry = result_unary.setdefault(obj_name, {'gt': [], 'pred': []})
                # Entries are filed under the object's own ground-truth class, so gt is always 1.
                entry['gt'].append(1)
                # An object with no category prediction counts as a miss instead of raising.
                predicted = obj_pred.get(oid)
                entry['pred'].append(1 if predicted is not None and predicted[1] == obj_name else 0)

        return result_unary, result_binary
    
    def neg_sample_loss(self, batched_unary_pred, batched_binary_pred, batched_neg_examples):
        batched_neg_sample_loss = []
        for unary_pred, binary_pred, neg_sample in zip(batched_unary_pred, batched_binary_pred, batched_neg_examples):
            ((image_cate_probs, cates), (image_unary_probs, unary_kws)) = unary_pred
            image_binary_probs, binary_kws =  binary_pred
            neg_cate = neg_sample['neg_entity']
            neg_binary = neg_sample['neg_binary']
        
            cate_indexes = [cate_id for cate_id, cate in enumerate(cates) if cate in neg_cate]
            if not len(cate_indexes) == 0:
                neg_cate_probs = image_cate_probs[cate_indexes, :]
            else:
                neg_cate_probs = torch.tensor([]).to(self.device)
            target_cate_probs = torch.zeros(neg_cate_probs.shape).to(self.device)
            
            binary_indexes = [ binary_id for binary_id, binary_kw in enumerate(binary_kws) if binary_kw in neg_binary]
            if not len(binary_indexes) == 0:
                neg_binary_probs = image_binary_probs[binary_indexes, :]
            else:
                neg_binary_probs = torch.tensor([]).to(self.device)
            target_binary_probs = torch.zeros(neg_binary_probs.shape).to(self.device)
            
            cate_loss, binary_loss = torch.sum(self.loss_fn(neg_cate_probs, target_cate_probs)), torch.sum(self.loss_fn(neg_binary_probs, target_binary_probs))
            batched_neg_sample_loss.append((cate_loss, binary_loss))
        return batched_neg_sample_loss
    
    # Loss function
    def loss(self, t1_frame_ids, batched_t1probs, t2_frame_ids, batched_t2probs, t3_frame_ids, batched_t3probs, 
             batched_action_specs, batched_ys, batched_video_splits, encourage_prop = 0.3, eps = 1e-15):
        batched_loss = []
        batched_video_length = []
        
        current_vid_id = 0
        for video_splits in batched_video_splits:
            batched_video_length.append(video_splits - current_vid_id)
            current_vid_id = video_splits
            
        for t1probs, t2probs, t3probs, y, action_spec, video_length in \
            zip(batched_t1probs, batched_t2probs, batched_t3probs, batched_ys, batched_action_specs, batched_video_length):

            t1_result = self.get_valid_result(t1_frame_ids, t1probs)
            t2_result = self.get_valid_result(t2_frame_ids, t2probs)
            t3_result = self.get_valid_result(t3_frame_ids, t3probs)
            
            if len(t1_result) == 0 and len(t2_result) == 0 and len(t3_result) == 0:
                continue

            results = [t1_result, t2_result, t3_result]
            locations = action_spec['video location']
            assert len(locations) <= 3 and len(locations) > 0
            
            current_loss = []
            for result, location in zip(results, locations):
                if not location in location_consts:
                    continue
                
                encourage_len = math.ceil(video_length * encourage_prop)
                score_for_dist = 1 / encourage_len
                
                # encourage the first part of the total len
                if location == "early":
                    start = 0
                    end = start + encourage_len
                    
                    # TODO: optimize
                    weights = []
                    for i in range(video_length):
                        if i > end:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (end - i + 1)
                            weights.append(weight)
                            
                elif location == "mid":
                    mid = math.ceil(video_length / 2)
                    dis =  math.ceil(encourage_len / 2)
                    start = mid - dis
                    end = mid + dis
                    
                    weights = []
                    for i in range(video_length):
                        if i > end:
                            weights.append(0)
                        elif i < start:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (dis - abs(mid - i) + 1)
                            weights.append(weight)
                            
                else:
                    end = video_length
                    start = end - encourage_len
                    
                    weights = []
                    for i in range(video_length):
                        if i < start:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (i - start + 1)
                            weights.append(weight)
                
                valid_probs = []
                valid_weights = []
                target_y = []
                for prob, frame_id in result:
                    if frame_id >= video_length:
                        continue
                    weight = weights[frame_id]
                    if not weight == 0:
                        valid_probs.append(prob)
                        valid_weights.append(weight)
                        target_y.append(y)
                        
                valid_weights = torch.tensor(valid_weights, device=self.device)
                if len(valid_probs) == 0:
                    continue
                
                if self.use_half:
                    valid_probs = torch.stack(valid_probs)
                    valid_prob_logits = torch.log( valid_probs / (1 - valid_probs))
                    loss = self.loss_fn(valid_prob_logits, torch.tensor(target_y, dtype=valid_probs[0].dtype, device=self.device))
                else:
                    loss = self.loss_fn(torch.stack(valid_probs), torch.tensor(target_y, dtype=torch.float32, device=self.device))
                
                loss = loss * valid_weights
                loss = (loss.sum() / valid_weights.sum())
                current_loss.append(loss)
            
            if len(current_loss) == 0:
                continue
            else:
                batched_loss.append(torch.mean(torch.stack(current_loss)))
            # For smaller window, it has lower likelihood of actually capturing the operation
            # We thus assign a weight function for the

        return batched_loss

    def forward(self, batch):
        """Compute the per-video training losses for one batch.

        The predicate model scores categories, unary keywords and binary keywords;
        the scores are formatted as Scallop facts, run through `self.reason`, and
        the aligned outputs are scored by `self.loss`. Optional negative-spec and
        negative-keyword terms are added on top.

        Args:
            batch: Dict of batched fields; this method reads `batched_ids`,
                `batched_gt_bboxes`, `batched_object_ids`, `batched_video_splits`,
                `batched_reshaped_raw_videos`, `batched_gt_object_rels`,
                `batched_gpt_specs`, plus `batched_neg_gpt_specs` /
                `batched_neg_kws` when the matching option is enabled.

        Returns:
            List of scalar loss tensors (empty when the batch has no objects).
        """
        # Load batch info
        batched_ids = batch['batched_ids']
        batched_gt_bboxes = batch['batched_gt_bboxes']
        batched_object_ids = batch['batched_object_ids']
        batched_video_splits = batch['batched_video_splits']
        batched_reshaped_raw_videos = batch['batched_reshaped_raw_videos']
        batched_gt_object_rels = batch['batched_gt_object_rels']
        batched_gpt_specs = batch['batched_gpt_specs']

        if len(batched_object_ids) == 0:
            return []

        batch_size = len(batched_ids)

        # Fetch constants. Every list is copied: the negatives appended below must not
        # leak into the spec dicts owned by the dataset, nor into the positive constants.
        batched_unary_kws = [list(spec['unary_kws']) for spec in batched_gpt_specs]
        batched_binary_kws = [list(spec['binary_kws']) for spec in batched_gpt_specs]
        batched_consts = [list(spec['consts']) for spec in batched_gpt_specs]
        batched_pos_consts = [list(spec['consts']) for spec in batched_gpt_specs]
        batched_neg_consts = []

        # Contrastive Learning Setup
        if self.use_neg_spec:
            # TODO: fix this
            batched_neg_gpt_specs = batch['batched_neg_gpt_specs']
            for batch_id, (_, neg_spec) in enumerate(zip(batched_gpt_specs, batched_neg_gpt_specs)):
                self._extend_with_new(batched_unary_kws[batch_id], neg_spec['unary_kws'])
                self._extend_with_new(batched_binary_kws[batch_id], neg_spec['binary_kws'])
                self._extend_with_new(batched_consts[batch_id], neg_spec['consts'])
                batched_neg_consts.append(neg_spec['consts'])

        if self.use_neg_kws:
            batched_neg_kws = batch['batched_neg_kws']
            for batch_id, (_, negative_examples) in enumerate(zip(batched_gpt_specs, batched_neg_kws)):
                self._extend_with_new(batched_binary_kws[batch_id], negative_examples['neg_binary'])
                self._extend_with_new(batched_consts[batch_id], negative_examples['neg_entity'])

        # Get probabilities
        batched_image_unary_probs, batched_image_binary_probs, selected_pairs = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                 batched_bboxes=batched_gt_bboxes,
                                 batched_names=batched_consts,
                                 batched_object_ids=batched_object_ids,
                                 batched_unary_kws=batched_unary_kws,
                                 batched_binary_kws=batched_binary_kws,
                                 batched_gt_obj_pairs=batched_gt_object_rels,
                                 batched_video_splits=batched_video_splits)

        const_lookup = self._build_const_lookup([e for c in batched_pos_consts for e in c])
        if self.use_neg_spec:
            neg_const_lookup = self._build_const_lookup([e for c in batched_neg_consts for e in c])

        batched_scl_tps = construct_batched_scl_tps(batched_object_ids)

        # Process unary predicates
        batched_unary_pred_scl = []
        batched_cate_pred_scl = []
        if self.use_neg_spec:
            batched_neg_cate_pred_scl = []

        for vid, (image_cate_probs, image_unary_probs) in enumerate(batched_image_unary_probs):
            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            assert image_cate_probs[0].shape[0] == 0 or len(object_ids) == image_cate_probs[0].shape[1]

            # oid -> {category name -> per-frame probabilities}
            cate_pred_scl = {}
            for cate_name, probs in zip(image_cate_probs[1], image_cate_probs[0]):
                for prob, (fid, oid) in zip(probs, object_ids):
                    cate_pred_scl.setdefault(oid, {}).setdefault(cate_name, []).append(prob)

            # Category facts use the frame-averaged probability and the constant id shifted by -1.
            new_cate_pred_scl = []
            new_neg_cate_pred_scl = []
            for oid, object_cate_info in cate_pred_scl.items():
                for cate_name, frame_probs in object_cate_info.items():
                    if cate_name in const_lookup:
                        new_cate_pred_scl.append(
                            (torch.mean(torch.stack(frame_probs)), (oid, const_lookup[cate_name] - 1)))
                    if self.use_neg_spec and cate_name in neg_const_lookup:
                        new_neg_cate_pred_scl.append(
                            (torch.mean(torch.stack(frame_probs)), (oid, neg_const_lookup[cate_name] - 1)))

            unary_pred_scl = []
            for unary_pred_name, probs in zip(image_unary_probs[1], image_unary_probs[0]):
                for prob, (fid, oid) in zip(probs, object_ids):
                    unary_pred_scl.append((prob, (unary_pred_name, fid, oid)))

            batched_cate_pred_scl.append(new_cate_pred_scl)
            batched_unary_pred_scl.append(unary_pred_scl)
            if self.use_neg_spec:
                batched_neg_cate_pred_scl.append(new_neg_cate_pred_scl)

        # Process binary predicates
        batched_binary_pred_scl = []
        for vid, image_binary_probs in enumerate(batched_image_binary_probs):
            if len(image_binary_probs) == 0:
                batched_binary_pred_scl.append([])
                continue

            object_pairs = [(fid, pair) for (ovid, fid, pair) in selected_pairs if vid == ovid]
            assert (len(object_pairs) == 0 and image_binary_probs[0].shape[0] == 0) or len(object_pairs) == image_binary_probs[0].shape[1]

            binary_pred_scl = []
            for binary_pred_name, probs in zip(image_binary_probs[1], image_binary_probs[0]):
                for prob, (fid, pair) in zip(probs, object_pairs):
                    binary_pred_scl.append((prob, (binary_pred_name, fid, pair[0], pair[1])))
            batched_binary_pred_scl.append(binary_pred_scl)

        formatted_batched_scl_input_facts = format_batched_facts(batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_gpt_specs)

        # Positive specifications are trained towards 1.
        # TODO: Update this for contrastive setup
        batched_ys = [1] * batch_size

        output = self.reason(**formatted_batched_scl_input_facts)

        t1s, t1probs = output['aligned_t1']
        t2s, t2probs = output['aligned_t2']
        t3s, t3probs = output['aligned_t3']

        # A single frame-id tuple starting with -1 means the reasoning produced no alignment.
        has_no_answer = (len(t1s) == 1 and t1s[0][0] == -1)
        if has_no_answer:
            print(f'Warning: No anwer: {batched_ids}')

        batched_loss = self.loss(t1s, t1probs, t2s, t2probs, t3s, t3probs, batched_gpt_specs, batched_ys, batched_video_splits)

        if self.use_neg_spec:
            formatted_batched_neg_scl_input_facts = format_batched_facts(batched_scl_tps, batched_neg_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_neg_gpt_specs)
            output = self.reason(**formatted_batched_neg_scl_input_facts)

            t1s, t1probs = output['aligned_t1']
            t2s, t2probs = output['aligned_t2']
            t3s, t3probs = output['aligned_t3']
            # Negative specifications are trained towards 0.
            neg_batched_ys = [0] * batch_size

            batched_neg_spec_loss = self.loss(t1s, t1probs, t2s, t2probs, t3s, t3probs, batched_neg_gpt_specs, neg_batched_ys, batched_video_splits)
            self._add_weighted_losses(
                batched_loss,
                [self.neg_spec_weight * neg_spec_loss for neg_spec_loss in batched_neg_spec_loss])

        if self.use_neg_kws:
            batched_neg_sample_loss = self.neg_sample_loss(batched_image_unary_probs, batched_image_binary_probs, batched_neg_kws)
            self._add_weighted_losses(
                batched_loss,
                [self.neg_entity_kw_cate_weight * cate_loss + self.neg_entity_kw_binary_weight * binary_loss
                 for (cate_loss, binary_loss) in batched_neg_sample_loss])

        return batched_loss

    @staticmethod
    def _extend_with_new(existing, candidates):
        """Append to `existing`, in order and once each, the candidates it does not contain yet."""
        seen = set(existing)
        for item in candidates:
            if item not in seen:
                seen.add(item)
                existing.append(item)

    @staticmethod
    def _build_const_lookup(consts):
        """Map each constant (as given, upper-cased and lower-cased) to minus its list position."""
        lookup = {}
        for k, v in enumerate(consts):
            lookup[v] = -k
            lookup[v.upper()] = -k
            lookup[v.lower()] = -k
        return lookup

    @staticmethod
    def _add_weighted_losses(batched_loss, extra_losses):
        """Add `extra_losses` position-wise onto `batched_loss`, appending where it is shorter."""
        for batch_id, extra in enumerate(extra_losses):
            if batch_id < len(batched_loss):
                # Out-of-place add: never modifies a tensor that autograd still needs.
                batched_loss[batch_id] = batched_loss[batch_id] + extra
            else:
                batched_loss.append(extra)

    def baseline_eval(self):
        """Evaluate the predicate model on the test loader and build a text report.

        The model is put in eval mode and temporarily uses `test_num_top_pairs`;
        the previous `num_top_pairs` is restored afterwards so that a following
        training epoch is not silently run with the evaluation setting.

        Returns:
            `(report_str, unary_accu, binary_accu)` where `report_str` is the
            list of report lines. When `self.report_dir` is set the report is
            also written to `<report_dir>/<model_name>.<epoch_ct>.report.txt`.
        """
        self.predicate_model.eval()
        previous_num_top_pairs = self.predicate_model.num_top_pairs
        self.predicate_model.num_top_pairs = self.test_num_top_pairs

        total_results_unary = []
        total_results_binary = []
        try:
            with torch.no_grad():
                for dp_list in tqdm(self.test_loader):
                    result_unary, result_binary = self.baseline_eval_batch(dp_list)
                    total_results_unary.append(result_unary)
                    total_results_binary.append(result_binary)

                merge_unary_results = combine_baseline_pred_dict_ls(total_results_unary)
                merge_binary_results = combine_baseline_pred_dict_ls(total_results_binary)
                unary_accu, unary_stats = obtain_stats(merge_unary_results)
                binary_accu, binary_stats = obtain_stats(merge_binary_results)
        finally:
            self.predicate_model.num_top_pairs = previous_num_top_pairs

        report_str = ["unary stats:"]
        report_str += get_report(unary_stats)
        report_str.append(f"Accuracy: {unary_accu}")

        report_str += ["binary stats"]
        report_str += get_report(binary_stats)
        report_str.append(f"Accuracy: {binary_accu}")
        report = "\n".join(report_str)

        if self.report_dir is not None:
            # Create the directory on demand so a finished evaluation is never lost.
            os.makedirs(self.report_dir, exist_ok=True)
            report_path = os.path.join(self.report_dir, self.model_name + f'.{self.epoch_ct}.report.txt')
            with open(report_path, 'w') as file:
                file.write(report)

        return (report_str, unary_accu, binary_accu)
        
    def train_epoch(self, n):
        """Run one optimisation pass over the training loader.

        Batches for which `forward` returns no loss are skipped. Batches that run
        out of CUDA memory are recorded and skipped as well.

        Args:
            n: Epoch number, used only in the progress-bar description.

        Returns:
            The running mean of all per-sample losses seen in this epoch, or
            `float('nan')` when no batch produced a loss.
        """
        self.predicate_model.train()
        # Evaluation switches the model to `test_num_top_pairs`; make sure training
        # always runs with the training value.
        self.predicate_model.num_top_pairs = self.train_num_top_pairs

        all_losses = []
        process_failures = []
        all_dps = 0
        avg_loss = float('nan')

        iterator = tqdm(self.train_loader)
        for dp_list in iterator:
            all_dps += 1
            self.optimizer.zero_grad()
            try:
                loss_ls = self.forward(dp_list)
                if len(loss_ls) == 0:
                    # Nothing to optimise for this batch.
                    continue

                loss = sum(loss_ls)
                loss.backward(retain_graph=True)
                self.optimizer.step()
                self.optimizer.zero_grad()

                all_losses += [sample_loss.item() for sample_loss in loss_ls]
                avg_loss = sum(all_losses) / len(all_losses)
                iterator.set_description(f'[Train {n}] Loss: {avg_loss}')
                del loss_ls
                del loss

            except torch.cuda.OutOfMemoryError:
                batched_ids = dp_list['batched_ids']
                batched_captions = dp_list['batched_captions']
                process_failures.append((batched_ids, batched_captions))

                print(f"current out of memory ct: {len(process_failures)} out of {all_dps}")
                print()

            finally:
                # Release the batch and cached GPU memory whichever path was taken.
                del dp_list
                gc.collect()
                torch.cuda.empty_cache()

        return avg_loss

    def test_epoch(self):
        report_str, unary_accu, binary_accu = self.baseline_eval()

        if self.save_model: 
            torch.save(self.predicate_model, os.path.join(self.model_dir, f"{self.model_name}.{self.epoch_ct}.model"))

        return

    def test(self):
        """Run a single evaluation pass by delegating to ``test_epoch``."""
        self.test_epoch()
        
    def train(self, num_epochs):
        start_ct = 1
        if not self.epoch_ct == 0:
            start_ct = self.epoch_ct + 1
            
        # self.test_epoch()
        for i in range(start_ct, num_epochs + 1):
            self.epoch_ct = i
            self.train_epoch(i)
            self.test_epoch()
            
    def save_scl_file(self, datapoint, object_tps, current_constraint):
        """Write the Scallop program for one datapoint to ``self.save_scl_dir``.

        The program text comes from ``obtain_scl_file`` and the file is named
        ``<datapoint['id']>.scl``. Nothing is written when
        ``self.save_scl_dir`` is ``None``; the text and file name are still
        computed first in that case.
        """
        program_text = obtain_scl_file(object_tps, current_constraint, self.common_scl)
        file_name = datapoint['id'] + '.scl'

        if self.save_scl_dir is None:
            return

        target_path = os.path.join(self.save_scl_dir, file_name)
        with open(target_path, 'w') as out_file:
            out_file.write(program_text)

if __name__ == "__main__":

    def _parse_bool(value):
        """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

        ``type=bool`` treats every non-empty string (including ``"False"``)
        as ``True``; this parser interprets the text instead. Raising
        ``ValueError`` makes argparse report a normal usage error.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string stays falsy, as it was under ``type=bool``.
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    # Set up data directories and paths
    dataset = "open_pvsg"
    # cache_file_name = f"gpt_specs_videollamav2_prog_str.json"
    cache_file_name = "gpt_specs_prog_str.json"
    data_file_name = 'pvsg.json'

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    assert os.path.exists(data_dir)
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    os.makedirs(data_nl_dir, exist_ok=True)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)
    # Use a context manager so the cache file handle is closed promptly.
    with open(cache_path, 'r') as cache_file:
        caption2scl = json.load(cache_file)

    scl_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../scl'))
    assert os.path.exists(scl_dir)
    common_scl_path = os.path.join(scl_dir, 'ltl_disj.scl')
    assert os.path.exists(common_scl_path)

    video_save_dir = os.path.join(data_dir, 'pred_video')
    model_dir = os.path.join(data_dir, 'model')

    # Setup argument parser
    parser = ArgumentParser(dataset)
    parser.add_argument("--phase", type=str, default='test')
    parser.add_argument("--n-epochs", type=int, default=50)
    # Boolean options take an explicit value, e.g. ``--load-model True``.
    parser.add_argument("--load-model", type=_parse_bool, default=False)
    parser.add_argument("--save-model", type=_parse_bool, default=True)
    parser.add_argument("--video_save_dir", type=str, default=video_save_dir)
    parser.add_argument("--model_type", type=str, default="contrast")
    parser.add_argument("--use-neg-spec", type=_parse_bool, default=True)
    parser.add_argument("--use-neg-kws", type=_parse_bool, default=True)
    parser.add_argument("--neg-example-ct", type=int, default=2)
    parser.add_argument("--neg-spec-weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_binary_weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_cate_weight", type=float, default=0)
    parser.add_argument("--siglip-model-name", type=str, default="google/siglip-base-patch16-224")

    parser.add_argument("--parallel", action='store_true')
    parser.add_argument("--report-dir", type=str, default="")

    parser.add_argument("--train-num-top-pairs", type=int, default=10)
    parser.add_argument("--test-num-top-pairs", type=int, default=30)
    parser.add_argument("--max-video-len", type=int, default=12)

    # setup question path
    parser.add_argument("--train-num", type=int, default=5000)
    parser.add_argument("--val-num", type=int, default=1000)
    parser.add_argument("--train-percentage", type=int, default=100)
    parser.add_argument("--test-percentage", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.000001)
    # A latent dimension is a size, so parse it as an integer.
    parser.add_argument("--latent-dim", type=int, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--model-dir", type=str, default=model_dir)
    parser.add_argument("--data-dir", type=str, default=data_dir)
    # NOTE(review): --use-cuda and --gpu are parsed but not read anywhere in
    # this script; the device is chosen from torch.cuda.is_available() below.
    parser.add_argument("--use-cuda", action="store_false")
    parser.add_argument("--use-half", action="store_true")
    parser.add_argument("--gpu", type=int, default=-1)

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # model_name = f"laser_vllama_{dataset}_training_{args.train_percentage}_seed_{args.seed}_batch_size_{args.batch_size}_lr_{args.learning_rate}_prov_{args.provenance}_tpk_{args.train_top_k}_negspec_{args.use_neg_spec}_negkw_{args.use_neg_kws}_specweight_{args.neg_spec_weight}_cateweight{args.neg_entity_kw_cate_weight}_binweight{args.neg_entity_kw_binary_weight}"
    # Default run name encodes the main hyperparameters.
    record_siglip_model_name = args.siglip_model_name.split('/')[-1]
    model_name = f"laser_{dataset}" + \
                  f"_training_{args.train_percentage}" +\
                  f"_lr_{args.learning_rate}" + \
                  f"_model_{record_siglip_model_name}" + \
                  f"_negspec_{args.use_neg_spec}" + \
                  f"_kwweight_{args.neg_spec_weight}" + \
                  f"_negkw_{args.use_neg_kws}" + \
                  f"_negcate_{args.neg_entity_kw_cate_weight}" + \
                  f"_negbin_{args.neg_entity_kw_binary_weight}" + \
                  f"_mvl_{args.max_video_len}" + \
                  f"_seed_{args.seed}" + \
                  f"_batch_size_{args.batch_size}" + \
                  f"_prov_{args.provenance}" + \
                  f"_tpk_{args.train_top_k}"

    # model_name = "laser_vllama_violet_open_pvsg_training_100_seed_1234_batch_size_1_lr_1e-06_prov_difftopkproofs_tpk_3_negspec_True_negkw_True_kwweight_0.1_mvl_8"
    if args.model_name is None:
        args.model_name = model_name

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    with open(data_path, 'r') as f:
        anno = json.load(f)

    # See video id in anno['split'].
    # NOTE(review): `data` is not referenced again in this script.
    data = {data_dict['video_id']: data_dict for data_dict in anno['data']}
    train_dataset, valid_dataset, train_loader, test_loader = open_pvsg_loader(
            cache_path=cache_path,
            dataset_dir=data_dir,
            dataset_name=data_file_name,
            batch_size=args.batch_size,
            device=device,
            training_percentage=args.train_percentage,
            testing_percentage=args.test_percentage,
            max_video_len=args.max_video_len,
            neg_kws=args.use_neg_kws,
            neg_spec=args.use_neg_spec,
            neg_example_ct=args.neg_example_ct,
            neg_example_file_name="neg_examples.json",
            backbone_model="siglip"
            )

    trainer = Trainer(train_loader=train_loader,
                      test_loader=test_loader,
                      device=device, caption2scl=caption2scl,
                      save_scl_dir=scl_dir, common_scl_path=common_scl_path,
                      latent_dim=args.latent_dim,
                      model_dir=args.model_dir, model_name=args.model_name,
                      learning_rate=args.learning_rate, load_model=args.load_model,
                      provenance=args.provenance,
                      save_model=args.save_model,
                      video_save_dir=args.video_save_dir,
                      train_num_top_pairs=args.train_num_top_pairs,
                      test_num_top_pairs=args.test_num_top_pairs,
                      report_dir=args.report_dir,
                      use_neg_kws=args.use_neg_kws,
                      use_neg_spec=args.use_neg_spec,
                      neg_spec_weight=args.neg_spec_weight,
                      neg_entity_kw_cate_weight=args.neg_entity_kw_cate_weight,
                      neg_entity_kw_binary_weight=args.neg_entity_kw_binary_weight,
                      siglip_model_name=args.siglip_model_name,
                      use_half=args.use_half,
                      )

    print(args.model_name)
    if args.phase == "train":
        print("train")
        trainer.train(args.n_epochs)
    elif args.phase == "test":
        print("baseline eval")
        trainer.baseline_eval()

    print("end")