import os

import json
import random
import cv2
from argparse import ArgumentParser
import math
from torch import nn, optim
from tqdm import tqdm
import torch
import scallopy
import sys
import pickle
import gc
import heapq
import transformers

# torch.set_default_dtype(torch.float32)

# The VIOLET model code lives in a `violet2_model` directory resolved relative to
# this file; it must be importable before the project imports below.
model_path = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../violet2_model'))
if not os.path.exists(model_path):
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f"VIOLET model directory not found: {model_path}")

sys.path.append(model_path)
from violet_model_stsg_gt_rel import VIOLET_STSG
from openpvsg_dataset import *
from utils import *
    
class Trainer:
    """Fine-tune a VIOLET_STSG predicate model through a Scallop alignment program.

    The predicate model scores categories, unary keywords and binary keywords
    for the objects of each video. Those scores become probabilistic Scallop
    facts, are reasoned over with the common Scallop program to obtain the
    ``aligned_t1`` / ``aligned_t2`` / ``aligned_t3`` frame alignments, and are
    supervised with a location-weighted binary cross-entropy (see ``loss``).
    """

    def __init__(self,
                 train_loader, test_loader,
                 device,
                 caption2scl, common_scl_path,
                 tokenizer_base,
                 violet_config_dir,
                 ckpt_path,
                 latent_dim = 64,
                 provenance="difftopkproofs", k=3, save_scl_dir=None,
                 use_neg_spec=False, use_neg_kws=False,
                 model_dir=None,
                 model_name=None,
                 learning_rate=None,
                 load_model=False, save_model=True,
                 video_save_dir=None,
                 train_num_top_pairs=100,
                 test_num_top_pairs=300,
                 report_dir=None,
                 neg_spec_weight=0.1,
                 neg_entity_kw_cate_weight=0,
                 neg_entity_kw_binary_weight=0.1,
                 max_size_frame=8,
                 max_grad_norm=1.0,
                 log_path=None
                 ):
        """Set up the data, the Scallop reasoner, the predicate model and the optimizer.

        Args:
            train_loader, test_loader: iterables of collated batch dicts (the keys
                read are listed in ``forward`` / ``baseline_eval_batch``).
            device: device for the Scallop forward function and the model.
            caption2scl: stored as ``self.caption2scl``; not otherwise used here.
            common_scl_path: Scallop program imported into the reasoning context.
            tokenizer_base: name/path for ``transformers.AutoTokenizer.from_pretrained``.
            violet_config_dir, max_size_frame: forwarded to ``VIOLET_STSG``.
            ckpt_path: pretrained weights loaded with ``load_ckpt`` when no saved
                checkpoint is resumed.
            latent_dim, video_save_dir: accepted for compatibility; unused here.
            provenance, k: Scallop provenance and its top-k parameter.
            save_scl_dir: stored as ``self.save_scl_dir`` for debugging.
            use_neg_spec, use_neg_kws: enable the negative-spec / negative-keyword losses.
            model_dir, model_name: checkpoints are read and written as
                ``{model_dir}/{model_name}.{epoch}.model``.
            learning_rate: Adam learning rate.
            load_model: resume from the newest checkpoint in ``model_dir`` if any.
            save_model: save a checkpoint after every test epoch.
            train_num_top_pairs, test_num_top_pairs: ``num_top_pairs`` set on the
                model for training and evaluation respectively.
            report_dir: directory for evaluation reports (skipped when None).
            neg_spec_weight, neg_entity_kw_cate_weight, neg_entity_kw_binary_weight:
                weights of the auxiliary negative losses.
            max_grad_norm: gradient-clipping norm; ``<= 0`` disables clipping.
            log_path: file that batches hitting CUDA OOM are appended to (optional).
        """
        # Dataset and scallop file setup
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.caption2scl = caption2scl
        self.template2action = {}
        with open(common_scl_path) as scl_file:
            self.common_scl = scl_file.read()
        self.save_model = save_model
        self.report_dir = report_dir
        self.model_dir = model_dir
        self.model_name = model_name
        self.max_grad_norm = max_grad_norm
        self.log_path = log_path

        # Contrastive learning type
        self.use_neg_spec = use_neg_spec
        self.use_neg_kws = use_neg_kws
        self.neg_spec_weight = neg_spec_weight
        self.neg_entity_kw_cate_weight = neg_entity_kw_cate_weight
        self.neg_entity_kw_binary_weight = neg_entity_kw_binary_weight

        # Number of candidate object pairs the model keeps (bounds memory/compute).
        self.train_num_top_pairs = train_num_top_pairs
        self.test_num_top_pairs = test_num_top_pairs

        # Scallop context and forwarding setup
        self.scallop_ctx = scallopy.ScallopContext(provenance=provenance, k=k)
        self.scallop_ctx.import_file(common_scl_path)
        self.scallop_ctx.set_non_probabilistic(non_prob_gpt_prog_str_preds)

        self.reason = self.scallop_ctx.forward_function(output_mappings={
            "aligned_t1": None,
            "aligned_t2": None,
            "aligned_t3": None,
        }, retain_graph=True)
        self.reason.to(self.device)

        # Epoch counter; overwritten when a numbered checkpoint is resumed.
        self.epoch_ct = 0

        # Build the STSG model, then fill it from a saved checkpoint or the pretrained ckpt.
        tokzr = transformers.AutoTokenizer.from_pretrained(tokenizer_base)
        predicate_model = VIOLET_STSG(config_dir=violet_config_dir, tokzr=tokzr, max_size_frame=max_size_frame)

        latest_model_id = None
        if load_model and model_dir is not None and model_name is not None and os.path.isdir(model_dir):
            latest_model_id = self._find_latest_model_id(os.listdir(model_dir), model_name)

        if latest_model_id is not None:
            checkpoint_name = f"{model_name}.{latest_model_id}.model"
            state_dict = torch.load(os.path.join(model_dir, checkpoint_name), map_location=self.device)
            predicate_model.load_state_dict(state_dict)
            print(f"Loading: {checkpoint_name}")
            if isinstance(latest_model_id, int):
                self.epoch_ct = latest_model_id
        else:
            # No checkpoint to resume: start from the pretrained weights.
            predicate_model.load_ckpt(ckpt_path)

        predicate_model = predicate_model.to(self.device)
        predicate_model.num_top_pairs = self.train_num_top_pairs
        self.predicate_model = predicate_model

        # Setting up learning parameters
        self.optimizer = optim.Adam(self.predicate_model.parameters(), lr=learning_rate)
        self.scaler = torch.amp.GradScaler()

        self.min_loss = 10000000000

        # Logits-based BCE (probabilities are converted with `_probs_to_logits`).
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')

        # Debugging utils (always defined so attribute access cannot fail).
        self.save_scl_dir = save_scl_dir

    @staticmethod
    def _find_latest_model_id(file_names, model_name):
        """Return the newest checkpoint id for ``model_name`` among ``file_names``.

        Only names of the form ``{model_name}.{id}.model`` (the format written by
        ``test_epoch``) are considered. The largest numeric id wins; otherwise
        ``'latest'`` is returned when such a checkpoint exists, else ``None``.
        """
        prefix, suffix = f"{model_name}.", ".model"
        model_ids = [
            file_name[len(prefix):-len(suffix)]
            for file_name in file_names
            if file_name.startswith(prefix) and file_name.endswith(suffix)
        ]
        numeric_ids = [int(model_id) for model_id in model_ids if model_id.isdigit()]
        if numeric_ids:
            return max(numeric_ids)
        if 'latest' in model_ids:
            return 'latest'
        return None

    @staticmethod
    def _probs_to_logits(probs, eps=1e-15):
        """Convert probabilities to logits, clamping them away from 0 and 1.

        ``eps`` is raised to at least the dtype's machine epsilon so that
        ``1 - eps`` is representable and the resulting logits stay finite.
        """
        eps = max(eps, torch.finfo(probs.dtype).eps)
        return torch.logit(probs, eps=eps)

    @staticmethod
    def _build_const_lookup(consts):
        """Map each constant name (and its upper/lower-case forms) to ``-index``."""
        lookup = {}
        for index, name in enumerate(consts):
            for key in (name, name.upper(), name.lower()):
                lookup[key] = -index
        return lookup

    @staticmethod
    def _binary_facts(batched_image_binary_probs):
        """Turn per-video binary scores into ``(prob, (name, fid, sub, obj))`` facts.

        Each non-empty per-video entry is read as ``(probs, keys)`` where every
        key is ``(vid, fid, from_id, to_id, binary_pred_name)``; empty entries
        yield an empty fact list.
        """
        batched_binary_pred_scl = []
        for image_binary_probs in batched_image_binary_probs:
            if len(image_binary_probs) == 0:
                batched_binary_pred_scl.append([])
                continue
            batched_binary_pred_scl.append([
                (prob, (binary_pred_name, fid, from_id, to_id))
                for (_, fid, from_id, to_id, binary_pred_name), prob
                in zip(image_binary_probs[1], image_binary_probs[0])
            ])
        return batched_binary_pred_scl

    @staticmethod
    def _add_aux_losses(batched_loss, aux_losses):
        """Add ``aux_losses[i]`` onto ``batched_loss[i]``; append entries past its end.

        Only the sum of the list is back-propagated, so alignment between the
        two lists does not change the gradient; this just avoids an IndexError
        when ``batched_loss`` is shorter. Mutates and returns ``batched_loss``.
        """
        for batch_id, aux_loss in enumerate(aux_losses):
            if batch_id < len(batched_loss):
                batched_loss[batch_id] = batched_loss[batch_id] + aux_loss
            else:
                batched_loss.append(aux_loss)
        return batched_loss

    @staticmethod
    def _location_weights(location, video_length, encourage_len):
        """Per-frame weights for a ``video location`` over ``video_length`` frames.

        ``"early"`` ramps down from frame 0 to frame ``encourage_len``; ``"mid"``
        peaks at the middle frame within ``ceil(encourage_len / 2)`` frames;
        anything else (presumably ``"late"``) ramps up over the last
        ``encourage_len`` frames. Frames outside the window get weight 0.
        """
        score_for_dist = 1 / encourage_len
        if location == "early":
            end = encourage_len
            return [score_for_dist * (end - i + 1) if i <= end else 0
                    for i in range(video_length)]
        if location == "mid":
            mid = math.ceil(video_length / 2)
            dis = math.ceil(encourage_len / 2)
            start, end = mid - dis, mid + dis
            return [score_for_dist * (dis - abs(mid - i) + 1) if start <= i <= end else 0
                    for i in range(video_length)]
        start = video_length - encourage_len
        return [score_for_dist * (i - start + 1) if i >= start else 0
                for i in range(video_length)]

    def get_valid_result(self, frame_ids, probs):
        """Pair probabilities with frame ids, drop frame ``-1``, and sort by frame id.

        Each element of ``frame_ids`` must be a 1-tuple; returns ``[(prob, frame_id), ...]``.
        """
        result = []
        for frame_id, prob in zip(frame_ids, probs):
            assert len(frame_id) == 1
            frame_id = frame_id[0]
            # -1 marks "no frame" in the Scallop output.
            if frame_id == -1:
                continue
            result.append((prob, frame_id))
        return sorted(result, key=lambda x: x[1])

    def baseline_eval_batch(self, batch):
        """Score one test batch against its ground-truth scene graph.

        The model is only asked about the ground-truth category names and the
        ground-truth binary keywords of the selected object pairs; its scores
        are then evaluated with ``accu``.

        Returns:
            ``(result_unary, result_binary)`` from ``accu``; ``({}, {})`` when the
            batch has no objects.
        """
        batched_ids = batch['batched_ids']
        batched_gt_bboxes = batch['batched_gt_bboxes']
        batched_obj_pairs = batch['batched_obj_pairs']
        batched_object_ids = batch['batched_object_ids']
        batched_video_splits = batch['batched_video_splits']
        batched_reshaped_raw_videos = batch['batched_videos']
        batched_gt_labels = batch['batched_gt_obj_names']
        batched_gt_object_rels = batch['batched_gt_object_rels']

        if len(batched_object_ids) == 0:
            return {}, {}

        batch_size = len(batched_ids)
        # One independent empty list per video (``[[]] * n`` would alias them).
        batched_unary_kws = [[] for _ in range(batch_size)]

        # Ground-truth label per object, keyed by video id then object id.
        batched_obj_labels = {}
        for (_, _, label), (vid, _, oid) in zip(batched_gt_labels, batched_object_ids):
            vid_labels = batched_obj_labels.setdefault(vid, {})
            if oid in vid_labels:
                assert label == vid_labels[oid]
            vid_labels[oid] = label

        # Candidate category names per video: its distinct ground-truth labels.
        batched_obj_names = [list(set(vid_labels.values())) for vid_labels in batched_obj_labels.values()]

        # Ground-truth binary keywords restricted to the selected object pairs.
        batched_binary_kws = []
        for vid, gt_object_rels in enumerate(batched_gt_object_rels):
            binary_kws = set()
            for fid, obj_pairs in enumerate(gt_object_rels):
                for sub, obj, binary_kw in obj_pairs:
                    if (vid, fid, (sub, obj)) in batched_obj_pairs:
                        binary_kws.add(binary_kw)
            batched_binary_kws.append(list(binary_kws))

        batched_image_unary_probs, batched_image_binary_probs, _ = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                 batched_bboxes=batched_gt_bboxes,
                                 batched_names=batched_obj_names,
                                 batched_object_ids=batched_object_ids,
                                 batched_unary_kws=batched_unary_kws,
                                 batched_binary_kws=batched_binary_kws,
                                 batched_obj_pairs=batched_gt_object_rels,
                                 batched_video_splits=batched_video_splits)

        # Only categories live in the ground-truth scene graph: average every
        # score the model produced for the same (object, category).
        batched_cate_pred_scl = []
        for image_cate_probs, _ in batched_image_unary_probs:
            cate_pred_scl = {}
            for (oid, cate_name), prob in zip(image_cate_probs[1], image_cate_probs[0]):
                cate_pred_scl.setdefault(oid, {}).setdefault(cate_name, []).append(prob)
            batched_cate_pred_scl.append([
                (torch.mean(torch.stack(probs)), (oid, cate_name))
                for oid, object_cate_info in cate_pred_scl.items()
                for cate_name, probs in object_cate_info.items()
            ])

        batched_binary_pred_scl = self._binary_facts(batched_image_binary_probs)

        return self.accu(batched_obj_labels, batched_gt_object_rels, batched_cate_pred_scl, batched_binary_pred_scl)

    def accu(self, batched_obj_labels, batched_gt_object_rels, batched_image_cate_probs, batched_image_binary_probs, top_pair_num=100):
        """Collect per-class ``gt`` / ``pred`` 0-1 lists for objects and relations.

        Args:
            batched_obj_labels: ``{video_index: {object_id: label}}``.
            batched_gt_object_rels: per video, per frame, ``(sub, obj, relation)`` entries.
            batched_image_cate_probs: per video, ``(prob, (object_id, category))`` facts.
            batched_image_binary_probs: per video, ``(prob, (relation, fid, sub, obj))`` facts.
            top_pair_num: number of most confident relation facts scored per video.

        Returns:
            ``(result_unary, result_binary)``, each ``{name: {'gt': [...], 'pred': [...]}}``.
            A relation prediction counts as positive when its probability is > 0.5;
            an object prediction is its highest-scoring category.
        """
        result_unary = {}
        result_binary = {}

        for vid, (image_cate_probs, binary_pred, gt_object_rels) in enumerate(zip(batched_image_cate_probs, batched_image_binary_probs, batched_gt_object_rels)):
            obj_labels = batched_obj_labels[vid]

            # Rank only THIS video's relation facts; they are checked against this video's gt.
            top_binary_preds = heapq.nlargest(top_pair_num, binary_pred, key=lambda pred: pred[0])
            for rela_prob, (rela_name, fid, sub, obj) in top_binary_preds:
                stats = result_binary.setdefault(rela_name, {'gt': [], 'pred': []})
                stats['gt'].append(1 if (sub, obj, rela_name) in gt_object_rels[fid] else 0)
                stats['pred'].append(1 if rela_prob > 0.5 else 0)

            # Arg-max category per object (the first one wins on ties).
            obj_pred = {}
            for cate_prob, (oid, obj_name) in image_cate_probs:
                if oid not in obj_pred or cate_prob > obj_pred[oid][0]:
                    obj_pred[oid] = (cate_prob, obj_name)

            for oid, obj_name in obj_labels.items():
                stats = result_unary.setdefault(obj_name, {'gt': [], 'pred': []})
                # Each labelled object is a positive for its own ground-truth class,
                # so ``gt`` is always 1 here.
                stats['gt'].append(1)
                # Objects the model produced no category score for count as misses.
                predicted = obj_pred.get(oid)
                stats['pred'].append(1 if predicted is not None and predicted[1] == obj_name else 0)

        return result_unary, result_binary

    def _neg_keyword_bce(self, probs, names, neg_names):
        """Summed BCE pushing the rows of ``probs`` whose name is in ``neg_names`` to 0."""
        indexes = [index for index, name in enumerate(names) if name in neg_names]
        if indexes:
            neg_probs = probs[indexes, :]
        else:
            neg_probs = torch.tensor([]).to(self.device)
        target = torch.zeros(neg_probs.shape).to(self.device)
        # loss_fn expects logits, while the model outputs probabilities.
        return torch.sum(self.loss_fn(self._probs_to_logits(neg_probs), target))

    def neg_sample_loss(self, batched_unary_pred, batched_binary_pred, batched_neg_examples):
        """Per-video ``(category_loss, binary_loss)`` for the negative keywords.

        Every score of a category in ``neg_entity`` or a binary keyword in
        ``neg_binary`` is pushed towards probability 0.

        NOTE(review): rows are selected with ``probs[indexes, :]`` (assumes 2-D
        score tensors) and names are compared directly to the negative keyword
        lists, whereas ``forward`` reads ``image_cate_probs[1]`` as
        ``(object_id, category)`` pairs — confirm against the model output format.
        """
        batched_neg_sample_loss = []
        for unary_pred, binary_pred, neg_sample in zip(batched_unary_pred, batched_binary_pred, batched_neg_examples):
            (image_cate_probs, cates), (_, _) = unary_pred
            if len(binary_pred) != 0:
                image_binary_probs, binary_kws = binary_pred
            else:
                image_binary_probs, binary_kws = [], []

            cate_loss = self._neg_keyword_bce(image_cate_probs, cates, neg_sample['neg_entity'])
            binary_loss = self._neg_keyword_bce(image_binary_probs, binary_kws, neg_sample['neg_binary'])
            batched_neg_sample_loss.append((cate_loss, binary_loss))
        return batched_neg_sample_loss

    def loss(self, t1_frame_ids, batched_t1probs, t2_frame_ids, batched_t2probs, t3_frame_ids, batched_t3probs,
             batched_action_specs, batched_ys, batched_video_splits, encourage_prop = 0.3, eps = 1e-15):
        """Location-weighted BCE between Scallop alignments and per-video targets.

        Args:
            t{1,2,3}_frame_ids: Scallop output tuples (1-tuples of frame ids) for
                ``aligned_t1`` / ``aligned_t2`` / ``aligned_t3``.
            batched_t{1,2,3}probs: per-video probabilities aligned with those tuples.
            batched_action_specs: per-video specs; only ``'video location'`` is read
                (1-3 entries, matched in order with t1, t2, t3).
            batched_ys: per-video target (1 for the video's own spec, 0 for negatives).
            batched_video_splits: cumulative end offsets; consecutive differences
                give each video's frame count.
            encourage_prop: fraction of the video length used as the weighting window.
            eps: lower bound for clamping probabilities before the logit transform.

        Returns:
            list of scalar losses, one per video that produced any weighted frame.
        """
        batched_loss = []

        batched_video_length = []
        current_vid_id = 0
        for video_split in batched_video_splits:
            batched_video_length.append(video_split - current_vid_id)
            current_vid_id = video_split

        for t1probs, t2probs, t3probs, y, action_spec, video_length in \
                zip(batched_t1probs, batched_t2probs, batched_t3probs, batched_ys, batched_action_specs, batched_video_length):

            results = [
                self.get_valid_result(t1_frame_ids, t1probs),
                self.get_valid_result(t2_frame_ids, t2probs),
                self.get_valid_result(t3_frame_ids, t3probs),
            ]
            if all(len(result) == 0 for result in results):
                continue

            locations = action_spec['video location']
            assert len(locations) <= 3 and len(locations) > 0

            encourage_len = math.ceil(video_length * encourage_prop)
            if encourage_len <= 0:
                # Empty video (or encourage_prop == 0): no frame can carry weight.
                continue

            current_loss = []
            for result, location in zip(results, locations):
                if location not in location_consts:
                    continue

                weights = self._location_weights(location, video_length, encourage_len)

                valid_probs = []
                valid_weights = []
                target_y = []
                for prob, frame_id in result:
                    if frame_id >= video_length:
                        continue
                    weight = weights[frame_id]
                    if weight != 0:
                        valid_probs.append(prob)
                        valid_weights.append(weight)
                        target_y.append(y)

                if len(valid_probs) == 0:
                    continue

                valid_weights = torch.tensor(valid_weights, device=self.device)
                valid_probs = torch.stack(valid_probs)
                # Clamp before the logit so probabilities of exactly 0/1 stay finite.
                valid_prob_logits = self._probs_to_logits(valid_probs, eps)
                targets = torch.tensor(target_y, dtype=valid_probs.dtype, device=self.device)
                weighted = self.loss_fn(valid_prob_logits, targets) * valid_weights
                current_loss.append(weighted.sum() / valid_weights.sum())

            if current_loss:
                batched_loss.append(torch.mean(torch.stack(current_loss)))

        return batched_loss

    def forward(self, batch):
        """Compute the per-video training losses for one collated batch.

        The model scores the specs' keywords/entities, the scores are turned into
        Scallop facts and reasoned over, and the alignments are supervised by
        ``loss``. Optional negative-spec and negative-keyword terms are added.

        Args:
            batch: collated batch dict; ``batched_neg_gpt_specs`` / ``batched_neg_kws``
                are read only when the matching ``use_neg_*`` flag is set.

        Returns:
            list of scalar loss tensors (one per video that produced a loss);
            empty when the batch has no objects.
        """
        batched_ids = batch['batched_ids']
        batched_gt_bboxes = batch['batched_gt_bboxes']
        batched_object_ids = batch['batched_object_ids']
        batched_video_splits = batch['batched_video_splits']
        batched_reshaped_raw_videos = batch['batched_reshaped_raw_videos']
        batched_gt_object_rels = batch['batched_gt_object_rels']
        batched_gpt_specs = batch['batched_gpt_specs']

        if len(batched_object_ids) == 0:
            return []

        batch_size = len(batched_ids)

        batched_unary_kws = []
        batched_binary_kws = []
        batched_consts = []
        batched_pos_consts = []
        batched_neg_consts = []

        # NOTE(review): these are the specs' own list objects, so the in-place
        # ``+=`` below also extends ``spec[...]`` and makes ``batched_pos_consts``
        # contain the negative entities. ``const_lookup`` (built from
        # ``batched_pos_consts``) is indexed with every category the model scores
        # for ``batched_consts``, which relies on that aliasing; confirm the effect
        # on ``format_batched_facts`` before copying these lists.
        for spec in batched_gpt_specs:
            batched_unary_kws.append(spec['unary_kws'])
            batched_binary_kws.append(spec['binary_kws'])
            batched_consts.append(spec['consts'])
            batched_pos_consts.append(spec['consts'])

        if self.use_neg_spec:
            # Ask the model about the negative spec's keywords/entities as well.
            batched_neg_gpt_specs = batch['batched_neg_gpt_specs']
            for batch_id, (_, neg_spec) in enumerate(zip(batched_gpt_specs, batched_neg_gpt_specs)):
                deduped_neg_binary = list(set(neg_spec['binary_kws']) - set(batched_binary_kws[batch_id]))
                deduped_neg_entity = list(set(neg_spec['consts']) - set(batched_consts[batch_id]))
                deduped_neg_unary = list(set(neg_spec['unary_kws']) - set(batched_unary_kws[batch_id]))

                batched_unary_kws[batch_id] += deduped_neg_unary
                batched_binary_kws[batch_id] += deduped_neg_binary
                batched_consts[batch_id] += deduped_neg_entity
                batched_neg_consts.append(neg_spec['consts'])

        if self.use_neg_kws:
            batched_neg_kws = batch['batched_neg_kws']
            for batch_id, (_, negative_examples) in enumerate(zip(batched_gpt_specs, batched_neg_kws)):
                deduped_neg_binary = list(set(negative_examples['neg_binary']) - set(batched_binary_kws[batch_id]))
                deduped_neg_entity = list(set(negative_examples['neg_entity']) - set(batched_consts[batch_id]))

                batched_binary_kws[batch_id] += deduped_neg_binary
                batched_consts[batch_id] += deduped_neg_entity

        batched_image_unary_probs, batched_image_binary_probs, _ = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                 batched_bboxes=batched_gt_bboxes,
                                 batched_names=batched_consts,
                                 batched_object_ids=batched_object_ids,
                                 batched_unary_kws=batched_unary_kws,
                                 batched_binary_kws=batched_binary_kws,
                                 batched_obj_pairs=batched_gt_object_rels,
                                 batched_video_splits=batched_video_splits)

        const_lookup = self._build_const_lookup([e for c in batched_pos_consts for e in c])
        if self.use_neg_spec:
            neg_const_lookup = self._build_const_lookup([e for c in batched_neg_consts for e in c])

        batched_scl_tps = construct_batched_scl_tps(batched_object_ids)

        # Category and unary facts per video (plus negative-spec category facts).
        batched_unary_pred_scl = []
        batched_cate_pred_scl = []
        batched_neg_cate_pred_scl = []
        for image_cate_probs, image_unary_probs in batched_image_unary_probs:
            cate_scores = list(zip(image_cate_probs[1], image_cate_probs[0]))
            batched_cate_pred_scl.append([
                (prob, (object_id, const_lookup[cate_name] - 1))
                for (object_id, cate_name), prob in cate_scores
            ])
            batched_unary_pred_scl.append([
                (prob, (unary_name, frame_id, object_id))
                for (unary_name, object_id, frame_id), prob in zip(image_unary_probs[1], image_unary_probs[0])
            ])
            if self.use_neg_spec:
                # Facts for the negative spec's entities, in the negative constant ids.
                batched_neg_cate_pred_scl.append([
                    (prob, (object_id, neg_const_lookup[cate_name] - 1))
                    for (object_id, cate_name), prob in cate_scores
                    if cate_name in neg_const_lookup
                ])

        batched_binary_pred_scl = self._binary_facts(batched_image_binary_probs)

        formatted_batched_scl_input_facts = format_batched_facts(batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_gpt_specs)

        # Each video is scored against its own spec, so the target is 1.
        batched_ys = [1] * batch_size

        output = self.reason(**formatted_batched_scl_input_facts)

        t1s, t1probs = output['aligned_t1']
        t2s, t2probs = output['aligned_t2']
        t3s, t3probs = output['aligned_t3']

        has_no_answer = (len(output['aligned_t1'][0]) == 1 and output['aligned_t1'][0][0][0] == -1)
        if has_no_answer:
            print(f'Warning: No anwer: {batched_ids}')

        batched_loss = self.loss(t1s, t1probs, t2s, t2probs, t3s, t3probs, batched_gpt_specs, batched_ys, batched_video_splits)

        if self.use_neg_spec:
            formatted_batched_neg_scl_input_facts = format_batched_facts(batched_scl_tps, batched_neg_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_neg_gpt_specs)
            output = self.reason(**formatted_batched_neg_scl_input_facts)

            t1s, t1probs = output['aligned_t1']
            t2s, t2probs = output['aligned_t2']
            t3s, t3probs = output['aligned_t3']
            # Negative specs should NOT align with the video.
            neg_batched_ys = [0] * batch_size

            batched_neg_spec_loss = self.loss(t1s, t1probs, t2s, t2probs, t3s, t3probs, batched_neg_gpt_specs, neg_batched_ys, batched_video_splits)
            self._add_aux_losses(batched_loss, [self.neg_spec_weight * neg_spec_loss for neg_spec_loss in batched_neg_spec_loss])

        if self.use_neg_kws:
            batched_neg_sample_loss = self.neg_sample_loss(batched_image_unary_probs, batched_image_binary_probs, batched_neg_kws)
            self._add_aux_losses(batched_loss, [
                self.neg_entity_kw_cate_weight * cate_loss + self.neg_entity_kw_binary_weight * binary_loss
                for cate_loss, binary_loss in batched_neg_sample_loss
            ])

        return batched_loss

    def baseline_eval(self):
        """Evaluate on ``test_loader`` and return ``(report_lines, unary_accu, binary_accu)``.

        The report is also written to ``{report_dir}/{model_name}.{epoch_ct}.report.txt``
        when ``report_dir`` is set.
        """
        self.predicate_model.eval()
        self.predicate_model.num_top_pairs = self.test_num_top_pairs

        total_results_unary = []
        total_results_binary = []
        with torch.no_grad():
            for dp_list in tqdm(self.test_loader):
                result_unary, result_binary = self.baseline_eval_batch(dp_list)
                total_results_unary.append(result_unary)
                total_results_binary.append(result_binary)

            merge_unary_results = combine_baseline_pred_dict_ls(total_results_unary)
            merge_binary_results = combine_baseline_pred_dict_ls(total_results_binary)
            unary_accu, unary_stats = obtain_stats(merge_unary_results)
            binary_accu, binary_stats = obtain_stats(merge_binary_results)

        report_str = ["unary stats:"]
        report_str += get_report(unary_stats)
        report_str.append(f"Accuracy: {unary_accu}")

        report_str += ["binary stats"]
        report_str += get_report(binary_stats)
        report_str.append(f"Accuracy: {binary_accu}")
        report = "\n".join(report_str)

        if self.report_dir is not None:
            report_path = os.path.join(self.report_dir, self.model_name + f'.{self.epoch_ct}.report.txt')
            with open(report_path, 'w') as file:
                file.write(report)

        return (report_str, unary_accu, binary_accu)

    def train_epoch(self, n):
        """Run one training epoch (``n`` is only used in the progress label).

        Batches that hit CUDA OOM are skipped and, if ``log_path`` is set, logged.

        Returns:
            running mean of the per-video losses over the epoch, or ``nan`` when
            no batch produced a differentiable loss.
        """
        self.predicate_model.train()
        # baseline_eval switches the model to the test pair budget; restore it.
        self.predicate_model.num_top_pairs = self.train_num_top_pairs

        all_losses = []
        process_failures = []
        all_dps = 0
        avg_loss = float('nan')

        iterator = tqdm(self.train_loader)
        for dp_list in iterator:
            all_dps += 1
            self.optimizer.zero_grad()
            try:
                # Autocast covers only the forward pass and loss computation;
                # PyTorch recommends running backward/step outside of it.
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    loss_ls = self.forward(dp_list)
                    loss = sum(loss_ls)

                if isinstance(loss, int) or not loss.requires_grad:
                    # Nothing differentiable to back-propagate for this batch.
                    del loss_ls, loss, dp_list
                    gc.collect()
                    torch.cuda.empty_cache()
                    continue

                # retain_graph kept as in the original (the Scallop forward
                # function is also created with retain_graph=True).
                self.scaler.scale(loss).backward(retain_graph=True)
                if self.max_grad_norm > 0:
                    # Gradients must be unscaled before clipping to a real norm.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.predicate_model.parameters(), self.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                all_losses += [batch_loss.item() for batch_loss in loss_ls]
                avg_loss = sum(all_losses) / len(all_losses)
                iterator.set_description(f'[Train {n}] Loss: {avg_loss}')

            except torch.cuda.OutOfMemoryError:
                # Release any partially built graph before clearing the cache.
                loss_ls = loss = None
                batched_ids = dp_list['batched_ids']
                batched_captions = dp_list['batched_captions']
                process_failures.append((batched_ids, batched_captions))
                if self.log_path is not None:
                    with open(self.log_path, 'a') as f:
                        f.write(str((batched_ids, batched_captions)) + '\n')

                print(f"current out of memory ct: {len(process_failures)} out of {all_dps}")

                del dp_list
                gc.collect()
                torch.cuda.empty_cache()
                print()
                continue

            del loss_ls, loss, dp_list
            gc.collect()
            torch.cuda.empty_cache()

        return avg_loss

    def test_epoch(self):
        """Evaluate, then save ``{model_name}.{epoch_ct}.model`` in ``model_dir`` if ``save_model``."""
        self.baseline_eval()

        if self.save_model:
            torch.save(self.predicate_model.state_dict(), os.path.join(self.model_dir, f"{self.model_name}.{self.epoch_ct}.model"))

    def test(self):
        """Run a single evaluation (and checkpoint save) pass."""
        self.test_epoch()

    def train(self, num_epochs):
        """Train and evaluate up to ``num_epochs``, resuming after ``epoch_ct``."""
        # epoch_ct is 0 for a fresh run, or the resumed checkpoint's epoch.
        for i in range(self.epoch_ct + 1, num_epochs + 1):
            self.epoch_ct = i
            self.train_epoch(i)
            self.test_epoch()
            

    def save_scl_file(self, datapoint, object_tps, current_constraint):
        scl_file_content = obtain_scl_file(object_tps, current_constraint, self.common_scl)
        scl_file_name = datapoint['id'] + '.scl'
        if not self.save_scl_dir is None:
            scl_path = os.path.join(self.save_scl_dir, scl_file_name)
            with open(scl_path, 'w') as scl_file:
                scl_file.write(scl_file_content)


def parse_bool_flag(value):
    """Parse a boolean command-line option value.

    ``argparse`` with no ``type`` (or with ``type=bool``) treats every
    non-empty string as true. That includes ``"False"``, so
    ``--use-neg-spec False`` used to enable the option. This converter
    accepts the usual spellings instead.

    Args:
        value: A bool (returned unchanged) or a case-insensitive string:
            ``true/t/yes/y/1`` or ``false/f/no/n/0``.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling.
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Script entry point. It resolves data and checkpoint paths, parses CLI
    # options, builds the OpenPVSG loaders and a Trainer, and then trains or
    # evaluates depending on --phase.

    # Set up data directories and paths
    dataset = "open_pvsg"
    cache_file_name = "gpt_specs_prog_str.json"
    # cache_file_name = f"gpt_specs_videollamav2_prog_str.json"
    data_file_name = 'pvsg.json'
    checkpoint_file_name = 'ckpt_pretrain_3d-init_webvid-cc_2d_rm-bm_0.3_ep10.pt'

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    ckpt_path = os.path.abspath(os.path.join(data_dir, "../violet2_checkpoints", checkpoint_file_name))

    assert os.path.exists(data_dir)
    assert os.path.exists(ckpt_path)

    # Directory that holds the caption -> Scallop spec cache.
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    os.makedirs(data_nl_dir, exist_ok=True)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    with open(cache_path, 'r') as cache_file:
        caption2scl = json.load(cache_file)

    scl_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../scl'))
    assert os.path.exists(scl_dir)
    common_scl_path = os.path.join(scl_dir, 'ltl_disj.scl')
    assert os.path.exists(common_scl_path)

    video_save_dir = os.path.join(data_dir, 'pred_video')
    model_dir = os.path.join(data_dir, 'model')
    assert os.path.exists(model_dir)

    # Setup argument parser
    parser = ArgumentParser(dataset)
    parser.add_argument("--phase", type=str, default='test', choices=["train", "test"])
    parser.add_argument("--n-epochs", type=int, default=51)
    # Boolean options go through parse_bool_flag, so "False" really means
    # False. Without it, any non-empty string would be truthy.
    parser.add_argument("--load-model", type=parse_bool_flag, default=True)
    parser.add_argument("--save-model", type=parse_bool_flag, default=False)
    parser.add_argument("--video_save_dir", type=str, default=video_save_dir)
    parser.add_argument("--model_type", type=str, default="contrast")
    parser.add_argument("--use-neg-spec", type=parse_bool_flag, default=False)
    parser.add_argument("--use-neg-kws", type=parse_bool_flag, default=False)
    parser.add_argument("--neg-example-ct", type=int, default=2)
    parser.add_argument("--neg-spec-weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_binary_weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_cate_weight", type=float, default=0)

    parser.add_argument("--parallel", action='store_true')
    parser.add_argument("--report-dir", type=str, default="")

    parser.add_argument("--train-num-top-pairs", type=int, default=5)
    parser.add_argument("--test-num-top-pairs", type=int, default=30)
    parser.add_argument("--max-video-len", type=int, default=8)

    # setup question path
    parser.add_argument("--train-num", type=int, default=5000)
    parser.add_argument("--val-num", type=int, default=1000)
    parser.add_argument("--train-percentage", type=int, default=10)
    parser.add_argument("--test-percentage", type=int, default=10)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.00001)
    # A layer width, so parse it as an int. A float such as 64.0 would be
    # passed straight to the model.
    parser.add_argument("--latent-dim", type=int, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--norm-x", type=int, default=224)
    parser.add_argument("--norm-y", type=int, default=224)

    parser.add_argument("--ckpt-path", type=str, default=ckpt_path)
    parser.add_argument("--tokenizer", type=str, default="bert-base-uncased")
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--model-dir", type=str, default=model_dir)
    parser.add_argument("--data-dir", type=str, default=data_dir)
    parser.add_argument("--use-cuda", action="store_false")
    parser.add_argument("--gpu", type=int, default=-1)

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    config_dir = os.path.join(model_path, "visbackbone")
    # Descriptive run name built from the hyperparameters.
    # NOTE(review): this is not used. The hard-coded name below is the
    # default model name (it selects the checkpoint that is loaded).
    hparam_model_name = f"test_laser_vllama_violet_{dataset}" + \
                        f"_training_{args.train_percentage}" + \
                        f"_lr_{args.learning_rate}" + \
                        f"_negspec_{args.use_neg_spec}" + \
                        f"_kwweight_{args.neg_spec_weight}" + \
                        f"_negkw_{args.use_neg_kws}" + \
                        f"_negcate_{args.neg_entity_kw_cate_weight}" + \
                        f"_negbin_{args.neg_entity_kw_binary_weight}" + \
                        f"_mvl_{args.max_video_len}" + \
                        f"_seed_{args.seed}" + \
                        f"_batch_size_{args.batch_size}" + \
                        f"_prov_{args.provenance}" + \
                        f"_tpk_{args.train_top_k}"
    default_model_name = "laser_vllama_violet_open_pvsg_training_100_seed_1234_batch_size_1_lr_1e-06_prov_difftopkproofs_tpk_3_negspec_True_negkw_True_kwweight_0.1_mvl_3"
    if args.model_name is None:
        args.model_name = default_model_name

    # Per-run log of batches that ran out of memory. It is named after the
    # model actually in use, so --model-name does not overwrite the default
    # model's log. Any existing log is cleared at startup.
    log_path = os.path.join(data_dir, f"{args.model_name}.log")
    if os.path.exists(log_path):
        os.remove(log_path)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    train_dataset, valid_dataset, train_loader, test_loader = open_pvsg_loader(
            cache_path=cache_path,
            dataset_dir=data_dir,
            dataset_name=data_file_name,
            batch_size=args.batch_size,
            device=device,
            training_percentage=args.train_percentage,
            testing_percentage=args.test_percentage,
            max_video_len=args.max_video_len,
            neg_kws=args.use_neg_kws,
            neg_spec=args.use_neg_spec,
            neg_example_ct=args.neg_example_ct,
            neg_example_file_name="neg_examples.json",
            set_norm_x=args.norm_x,
            set_norm_y=args.norm_y,
            )

    trainer = Trainer(train_loader=train_loader,
                      # Evaluate on the held-out split. Previously the
                      # training loader was passed here by mistake.
                      test_loader=test_loader,
                      device=device,
                      caption2scl=caption2scl,
                      save_scl_dir=scl_dir,
                      common_scl_path=common_scl_path,
                      tokenizer_base=args.tokenizer,
                      violet_config_dir=config_dir,
                      ckpt_path=args.ckpt_path,
                      latent_dim=args.latent_dim,
                      model_dir=args.model_dir,
                      model_name=args.model_name,
                      learning_rate=args.learning_rate,
                      load_model=args.load_model,
                      provenance=args.provenance,
                      save_model=args.save_model,
                      video_save_dir=args.video_save_dir,
                      train_num_top_pairs=args.train_num_top_pairs,
                      test_num_top_pairs=args.test_num_top_pairs,
                      report_dir=args.report_dir,
                      use_neg_kws=args.use_neg_kws,
                      use_neg_spec=args.use_neg_spec,
                      neg_spec_weight=args.neg_spec_weight,
                      neg_entity_kw_cate_weight=args.neg_entity_kw_cate_weight,
                      neg_entity_kw_binary_weight=args.neg_entity_kw_binary_weight,
                      max_size_frame=args.max_video_len,
                      log_path=log_path,
                      )

    print(args.model_name)
    if args.phase == "train":
        print("train")
        trainer.train(args.n_epochs)
    elif args.phase == "test":
        print("baseline eval")
        trainer.baseline_eval()

    print("end")