import os

import json
import random
import cv2
from argparse import ArgumentParser
import math
from torch import nn, optim
from tqdm import tqdm
import torch
import scallopy
import sys
import pickle
import gc

import re
import numpy as np
import torch.distributed.autograd as dist_autograd
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef
import torch.multiprocessing as mp

# Resolve the `models` directory located next to the directory that contains
# this file (".." twice from the file path) and fail fast if it is missing.
model_path = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../models'))
assert os.path.exists(model_path)

# Put the models directory on the import path before the project imports
# below; presumably `open_pvsg_predicate_model` lives there -- TODO confirm.
sys.path.append(model_path)
from open_pvsg_predicate_model import PredicateModel
# NOTE(review): the star imports supply names used further down without
# qualification (e.g. `non_prob_gpt_preds`, `location_consts`,
# `construct_batched_scl_tps`, `process_batched_facts`); which module defines
# which name cannot be determined from this file.
from dataset import *
from utils import *


class Trainer():
    def __init__(self, train_loader, test_loader, device,
                 caption2scl, common_scl_path,
                 provenance="difftopkproofs", k=3, save_scl_dir=None,
                 model_dir=None, model_name=None, learning_rate=None,
                 latent_dim=64, model_layer=2, load_model=False, save_model=True,
                 save_video=False, video_save_dir=None, violation_weight=0.1,
                 with_violation=True, use_contrast=False, world_size=4, parallel=False,
                 use_neg_sampling=True, num_top_pairs=20):
        """Set up the Scallop reasoning context, the predicate model and the optimizer.

        Args:
            train_loader: iterable of training batches consumed by ``train_epoch``.
            test_loader: iterable of evaluation batches.
            device: device the predicate model (and loss tensors) are placed on.
            caption2scl: mapping from caption to its time-stamp ids / facts.
            common_scl_path: path of the shared Scallop program; it is both read
                as text and imported into the Scallop context.
            provenance: Scallop provenance name.
            k: ``k`` passed to the Scallop context.
            save_scl_dir: directory for dumped ``.scl`` files, or ``None``.
            model_dir: directory used to load / save model checkpoints.
            model_name: file-name stem of the checkpoints.
            learning_rate: learning rate handed to Adam.
            latent_dim: hidden dimension of a freshly created predicate model.
            model_layer: accepted for interface compatibility; unused here.
            load_model: load ``<model_name>.2.model`` instead of creating a model.
            save_model: whether ``test_epoch`` writes checkpoints.
            save_video: stored on the instance; not used in this class.
            video_save_dir: stored on the instance; not used in this class.
            violation_weight: stored on the instance.
            with_violation: selects which Scallop output relations are requested.
            use_contrast: enable the contrastive (negative caption) loss.
            world_size: world size for distributed training when ``parallel``.
            parallel: wrap the model in ``DistributedDataParallel``.
            use_neg_sampling: enable the negative-sampling loss.
            num_top_pairs: number of top pairs, forwarded to the predicate model.
        """
        self.save_video = save_video
        self.video_save_dir = video_save_dir

        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.caption2scl = caption2scl
        self.template2action = {}
        # Read through a context manager so the file handle is closed promptly.
        with open(common_scl_path) as common_scl_file:
            self.common_scl = common_scl_file.read()
        self.save_model = save_model
        self.with_violation = with_violation
        self.use_contrast = use_contrast
        self.use_neg_sampling = use_neg_sampling

        self.scallop_ctx = scallopy.ScallopContext(provenance=provenance, k=k)
        self.scallop_ctx.import_file(common_scl_path)
        self.scallop_ctx.set_non_probabilistic(non_prob_gpt_preds)

        if with_violation:
            self.reason = self.scallop_ctx.forward_function(output_mappings={
                "temporal_diff": None,
                "violation": (),
            }, retain_graph=True)
        else:
            self.reason = self.scallop_ctx.forward_function(output_mappings={
                "aligned_t1": None,
                "aligned_t2": None,
                "aligned_t3": None,
            }, retain_graph=True)
            # NOTE(review): the reasoning module is only moved to the device
            # in this branch; confirm whether the other branch needs it too.
            self.reason.to(self.device)

        if load_model:
            # NOTE(review): the checkpoint index ".2" is hard-coded.
            predicate_model = torch.load(os.path.join(model_dir, model_name + '.2.model'))
        else:
            predicate_model = PredicateModel(hidden_dim = latent_dim, num_top_pairs=num_top_pairs, device=device).to(device)

        # Also applies to a loaded checkpoint, which may carry an older value.
        predicate_model.num_top_pairs = num_top_pairs

        if parallel:
            torch.distributed.init_process_group(backend='nccl', world_size=world_size, init_method='...')
            self.predicate_model = DDP(predicate_model)
        else:
            self.predicate_model = predicate_model

        self.optimizer = optim.Adam(self.predicate_model.parameters(), lr=learning_rate)

        self.model_dir = model_dir
        self.model_name = model_name
        self.min_loss = 10000000000
        self.violation_weight = violation_weight

        # Always define the attribute (possibly as None) so that
        # ``save_scl_file`` can test it without raising AttributeError.
        self.save_scl_dir = save_scl_dir

        # Per-element losses; callers apply their own weighting / reduction.
        self.loss_fn = nn.BCELoss(reduction='none')
        

    def get_valid_result(self, frame_ids, probs):
        """Pair probabilities with frame ids, drop ``-1`` entries, sort by frame.

        Args:
            frame_ids: sequence of length-1 containers, each holding a frame id.
            probs: sequence of values aligned with ``frame_ids``.

        Returns:
            List of ``(prob, frame_id)`` tuples in ascending frame-id order;
            entries whose frame id is ``-1`` are left out.
        """
        kept = []
        for id_holder, prob in zip(frame_ids, probs):
            assert len(id_holder) == 1
            fid = id_holder[0]
            if fid != -1:
                kept.append((prob, fid))

        # list.sort is stable, so equal frame ids keep their input order.
        kept.sort(key=lambda entry: entry[1])
        return kept
    
    def baseline_eval_batch(self, batch):
        """Run the predicate model on one batch and score it against ground truth.

        The batch is a 13-tuple; this method uses the ids, ground-truth boxes,
        object pairs, object ids, video splits, raw videos, ground-truth labels
        and ground-truth object relations from it.

        Args:
            batch: one element produced by the test loader.

        Returns:
            ``(result_unary, result_binary)`` as produced by ``accu``; two empty
            dicts when the batch contains no object ids.
        """
        (batched_ids, _captions, batched_gt_bboxes, _gt_masks,
         batched_obj_pairs, batched_object_ids, batched_video_splits,
         batched_reshaped_raw_videos, batched_gt_labels,
         batched_gt_object_rels, _gpt_specs, _, _) = batch

        if len(batched_object_ids) == 0:
            return {}, {}

        batch_size = len(batched_ids)

        # One independent empty list per video. ``[[]] * batch_size`` would
        # repeat a single shared list, so a mutation for one video would leak
        # into all the others.
        batched_unary_kws = [[] for _ in range(batch_size)]
        batched_binary_kws = []
        cleaned_batched_object_pairs = []

        # video id -> {object id -> ground-truth label}. The video id is taken
        # from the object-id triple, the label from the aligned label triple.
        batched_obj_labels = {}
        for (_, _, label), (vid, _, oid) in zip(batched_gt_labels, batched_object_ids):
            video_labels = batched_obj_labels.setdefault(vid, {})
            if oid in video_labels:
                # An object must keep the same label across frames.
                assert label == video_labels[oid]
            video_labels[oid] = label

        # Candidate category names per video: the distinct labels seen in it.
        batched_obj_names = [
            list(set(vid_info.values())) for vid_info in batched_obj_labels.values()
        ]

        # Keep only annotated relations whose (video, frame, pair) is among the
        # available object pairs, and collect the relation keywords per video.
        for vid, gt_object_rels in enumerate(batched_gt_object_rels):
            binary_kws = set()
            for fid, obj_pairs in enumerate(gt_object_rels):
                for (sub, obj, binary_kw) in obj_pairs:
                    if (vid, fid, (sub, obj)) in batched_obj_pairs:
                        cleaned_batched_object_pairs.append((vid, fid, (sub, obj)))
                        binary_kws.add(binary_kw)

            batched_binary_kws.append(list(binary_kws))

        batched_image_unary_probs, batched_image_binary_probs, selected_pairs = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                 batched_bboxes=batched_gt_bboxes,
                                 batched_names=batched_obj_names,
                                 batched_object_ids=batched_object_ids,
                                 batched_unary_kws=batched_unary_kws,
                                 batched_binary_kws=batched_binary_kws,
                                 batched_obj_pairs=cleaned_batched_object_pairs,
                                 batched_video_splits=batched_video_splits)

        assert len(selected_pairs) == len(cleaned_batched_object_pairs)

        # Category predictions: average each (object, category) probability
        # over all the frames in which the object appears.
        batched_cate_pred_scl = []
        for vid, (image_cate_probs, _) in enumerate(batched_image_unary_probs):
            cate_pred_scl = {}

            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            assert image_cate_probs[0].shape[0] == 0 or len(object_ids) == image_cate_probs[0].shape[1]

            for cate_name, probs in zip(image_cate_probs[1], image_cate_probs[0]):
                for prob, (fid, oid) in zip(probs, object_ids):
                    cate_pred_scl.setdefault(oid, {}).setdefault(cate_name, []).append(prob)

            new_cate_pred_scl = []
            for oid, object_cate_info in cate_pred_scl.items():
                for cate_name, prob in object_cate_info.items():
                    new_cate_pred_scl.append((torch.mean(torch.stack(prob)), (oid, cate_name)))

            batched_cate_pred_scl.append(new_cate_pred_scl)

        # Binary predictions: one (prob, (relation, frame, sub, obj)) per
        # relation keyword and selected pair.
        batched_binary_pred_scl = []
        for vid, image_binary_probs in enumerate(batched_image_binary_probs):
            binary_pred_scl = []

            if len(image_binary_probs) == 0:
                batched_binary_pred_scl.append([])
                continue

            object_pairs = [(fid, pair) for (ovid, fid, pair) in selected_pairs if vid == ovid]
            assert len(object_pairs) == image_binary_probs[0].shape[1]

            for binary_pred_name, probs in zip(image_binary_probs[1], image_binary_probs[0]):
                for prob, (fid, pair) in zip(probs, object_pairs):
                    binary_pred_scl.append((prob, (binary_pred_name, fid, pair[0], pair[1])))

            batched_binary_pred_scl.append(binary_pred_scl)

        result_unary, result_binary = self.accu(batched_obj_labels, batched_gt_object_rels, batched_cate_pred_scl, batched_binary_pred_scl, selected_pairs)

        return result_unary, result_binary
    
    def accu(self, batched_obj_labels, batched_gt_object_rels, batched_image_cate_probs, batched_image_binary_probs, selected_pairs):
        """Collect per-name ground-truth / prediction indicator lists.

        Args:
            batched_obj_labels: mapping video index -> {object id -> label}.
            batched_gt_object_rels: per video, per frame, a collection of
                ``(sub, obj, relation)`` triples.
            batched_image_cate_probs: per video, a list of
                ``(prob, (object id, category name))`` entries.
            batched_image_binary_probs: per video, a list of
                ``(prob, (relation name, frame id, sub, obj))`` entries.
            selected_pairs: accepted for interface compatibility; not read.

        Returns:
            ``(result_unary, result_binary)``; each maps a name to
            ``{'gt': [...], 'pred': [...]}`` lists of 0/1 indicators.
        """
        result_unary = {}
        result_binary = {}

        per_video = zip(batched_image_cate_probs, batched_image_binary_probs, batched_gt_object_rels)
        for vid, (cate_entries, relation_entries, rels_by_frame) in enumerate(per_video):
            obj_labels = batched_obj_labels[vid]

            # Relations: ground truth is membership in the frame's annotated
            # triples, prediction is a 0.5 threshold on the probability.
            for rela_prob, (rela_name, fid, sub, obj) in relation_entries:
                bucket = result_binary.setdefault(rela_name, {'gt': [], 'pred': []})
                bucket['gt'].append(1 if (sub, obj, rela_name) in rels_by_frame[fid] else 0)
                bucket['pred'].append(1 if rela_prob > 0.5 else 0)

            # Categories: keep the highest-probability name for every object
            # (the first one seen wins ties).
            top_by_object = {}
            for cate_prob, (oid, obj_name) in cate_entries:
                current = top_by_object.get(oid)
                if current is None or cate_prob > current[0]:
                    top_by_object[oid] = (cate_prob, obj_name)

            # Recall@1 bookkeeping, filed under the predicted name. 'pred' is
            # compared against the very name it was read from.
            for oid, (_, top_name) in top_by_object.items():
                bucket = result_unary.setdefault(top_name, {'gt': [], 'pred': []})
                bucket['gt'].append(1 if obj_labels[oid] == top_name else 0)
                bucket['pred'].append(1 if top_by_object[oid][1] == top_name else 0)

        return result_unary, result_binary
    
    def neg_sample_loss(self, batched_unary_pred, batched_binary_pred, batched_neg_examples):
        batched_loss = []
        for unary_pred, binary_pred, neg_sample in zip(batched_unary_pred, batched_binary_pred, batched_neg_examples):
            ((image_cate_probs, cates), (image_unary_probs, unary_kws)) = unary_pred
            image_binary_probs, binary_kws =  binary_pred
            neg_cate = neg_sample['neg_entity']
            neg_binary = neg_sample['neg_binary']
        
            cate_indexes = [ 1 if cate in neg_cate else 0 for cate in cates]
            neg_cate_probs = image_cate_probs[cate_indexes, :]
            target_cate_probs = torch.zeros(neg_cate_probs.shape).to(self.device)
            
            binary_indexes = [ 1 if binary_kw in binary_kw else 0 for binary_kw in binary_kws]
            neg_binary_probs = image_binary_probs[binary_indexes, :]
            target_binary_probs = torch.zeros(neg_binary_probs.shape).to(self.device)
            
            loss = torch.sum(self.loss_fn(neg_cate_probs, target_cate_probs)) + torch.sum(self.loss_fn(neg_binary_probs, target_binary_probs))
            batched_loss.append(loss)
        return batched_loss


    def loss(self, t1_frame_ids, batched_t1probs, t2_frame_ids, batched_t2probs, t3_frame_ids, batched_t3probs, 
             batched_action_specs, batched_ys, batched_video_splits, encourage_prop = 0.3, eps = 1e-15):
        batched_loss = []
        batched_video_length = []
        
        current_vid_id = 0
        for video_splits in batched_video_splits:
            batched_video_length.append(video_splits - current_vid_id)
            current_vid_id = video_splits
            
        for t1probs, t2probs, t3probs, y, action_spec, video_length in \
            zip(batched_t1probs, batched_t2probs, batched_t3probs, batched_ys, batched_action_specs, batched_video_length):

            t1_result = self.get_valid_result(t1_frame_ids, t1probs)
            t2_result = self.get_valid_result(t2_frame_ids, t2probs)
            t3_result = self.get_valid_result(t3_frame_ids, t3probs)
            
            
            if len(t1_result) == 0 and len(t2_result) == 0 and len(t3_result) == 0:
                continue

            results = [t1_result, t2_result, t3_result]
            locations = action_spec['video location']
            assert len(locations) <= 3 and len(locations) > 0
            
            current_loss = []
            for result, location in zip(results, locations):
                if not location in location_consts:
                    continue
                # assert location in location_consts
                
                encourage_len = math.ceil(video_length * encourage_prop)
                score_for_dist = 1 / encourage_len
                
                
                # encourage the first part of the total len
                if location == "early":
                    start = 0
                    end = start + encourage_len
                    
                    # TODO: optimize
                    weights = []
                    for i in range(video_length):
                        if i > end:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (end - i + 1)
                            weights.append(weight)
                            
                elif location == "mid":
                    mid = math.ceil(video_length / 2)
                    dis =  math.ceil(encourage_len / 2)
                    start = mid - dis
                    end = mid + dis
                    
                    weights = []
                    for i in range(video_length):
                        if i > end:
                            weights.append(0)
                        elif i < start:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (dis - abs(mid - i) + 1)
                            weights.append(weight)
                            
                else:
                    end = video_length
                    start = end - encourage_len
                    
                    weights = []
                    for i in range(video_length):
                        if i < start:
                            weights.append(0)
                        else:
                            weight = score_for_dist * (i - start + 1)
                            weights.append(weight)
                
                valid_probs = []
                valid_weights = []
                target_y = []
                for prob, frame_id in result:
                    if frame_id >= video_length:
                        continue
                    weight = weights[frame_id]
                    if not weight == 0:
                        valid_probs.append(prob)
                        valid_weights.append(weight)
                        target_y.append(y)
                        
                valid_weights = torch.tensor(valid_weights, device=self.device)
                if len(valid_probs) == 0:
                    continue
                loss = self.loss_fn(torch.stack(valid_probs), torch.tensor(target_y, dtype=torch.float32, device=self.device))
                loss = loss * valid_weights
                loss = (loss.sum() / valid_weights.sum())
                current_loss.append(loss)
            
            if len(current_loss) == 0:
                continue
            else:
                batched_loss.append(torch.mean(torch.stack(current_loss)))
            # For smaller window, it has lower likelihood of actually capturing the operation
            # We thus assign a weight function for the

        return batched_loss


    def forward(self, batch):
        """Compute the list of training losses for one batch.

        Steps, in order: unpack the 13-element batch, choose the keyword /
        constant vocabularies according to ``use_contrast`` /
        ``use_neg_sampling``, run the predicate model, convert its outputs
        into Scallop input facts (for the caption and for the negative
        caption), run the Scallop reasoning module, and score the aligned
        frame probabilities with ``self.loss``.

        Args:
            batch: one element produced by a data loader (13-tuple).

        Returns:
            A list of loss tensors: the alignment losses, extended by the
            contrastive losses (``use_contrast``) or the negative-sampling
            losses (``use_neg_sampling``). An empty list when the batch has no
            object ids.

        NOTE(review): the outputs read below are ``aligned_t1/2/3``, which
        ``__init__`` only requests when ``with_violation`` is False -- confirm
        the ``with_violation=True`` configuration is supported here.
        """
        correct_ls = []
        loss_ls = []
        missing_name = []

        batched_ids, batched_captions, batched_gt_bboxes, batched_gt_masks, \
            batched_obj_pairs, batched_object_ids, batched_video_splits, \
            batched_reshaped_raw_videos, _, \
            batched_gt_object_rels, batched_gpt_specs, batched_neg_gpt_specs, \
            batched_negative_examples = batch
        
        if len(batched_object_ids) == 0:
            return []
                
        batch_size = len(batched_ids)
        
        batched_unary_kws = []
        batched_binary_kws = []
        batched_consts = []
        
        # Vocabulary per video: the caption spec's keywords/constants, merged
        # with the negative spec (contrast) or with the sampled negative
        # relations/entities (negative sampling). Duplicates are removed.
        if self.use_contrast:
            for spec, neg_spec in zip(batched_gpt_specs, batched_neg_gpt_specs):
                batched_unary_kws.append(list(set(spec['unary_kws'] + neg_spec['unary_kws'])))
                batched_binary_kws.append(list(set(spec['binary_kws'] + neg_spec['binary_kws'])))
                batched_consts.append(list(set(spec['consts'] + neg_spec['consts'])))
                
        elif self.use_neg_sampling:
            for spec, negative_examples in zip(batched_gpt_specs, batched_negative_examples):
                batched_unary_kws.append(list(set(spec['unary_kws'])))
                batched_binary_kws.append(list(set(spec['binary_kws'] + negative_examples['neg_binary'])))
                batched_consts.append(list(set(spec['consts'] + negative_examples['neg_entity'])))
            
        else: 
            for spec, negative_examples in zip(batched_gpt_specs, batched_negative_examples):
                batched_unary_kws.append(list(set(spec['unary_kws'])))
                batched_binary_kws.append(list(set(spec['binary_kws'])))
                batched_consts.append(list(set(spec['consts'])))
            
        batch_size = len(batched_ids)
        
        # Smallest and largest frame id that carries an object, per video.
        batched_start_id = {}
        batched_end_id = {}
        for vid, fid, _ in batched_object_ids:
            if not vid in batched_start_id:
                batched_start_id[vid] = fid
            if fid < batched_start_id[vid]:
                batched_start_id[vid] = fid
                
            if not vid in batched_end_id:
                batched_end_id[vid] = fid
            if fid > batched_end_id[vid]:
                batched_end_id[vid] = fid
        
        # Only object pairs in the first or last such frame of a video are
        # handed to the model.
        cleaned_batched_object_pairs = []
        for vid, fid, pair in batched_obj_pairs:
            if batched_start_id[vid] == fid or batched_end_id[vid] == fid:
                cleaned_batched_object_pairs.append((vid, fid, pair))
        
        batched_image_unary_probs, batched_image_binary_probs, selected_pairs  = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                batched_bboxes=batched_gt_bboxes, 
                                batched_names=batched_consts,
                                batched_object_ids = batched_object_ids,
                                batched_unary_kws=batched_unary_kws,
                                batched_binary_kws=batched_binary_kws,
                                batched_obj_pairs=cleaned_batched_object_pairs, 
                                batched_video_splits=batched_video_splits)

        # From here on, the pairs are the ones the model reports back.
        batched_object_pairs = selected_pairs
        
        consts = [e for c in batched_consts for e in c]

        # Each constant (plus its upper/lower-case spelling) maps to the
        # negated position in the flattened list; a constant that occurs more
        # than once keeps the id of its last occurrence.
        const_lookup = {}
        cids = []
        for k, v in enumerate(consts):
            const_lookup[v] = -k
            const_lookup[v.upper()] = -k
            const_lookup[v.lower()] = -k

            cids.append(-k)
        
        # batched_object_tps = get_object_tps(batched_object_names, batched_object_ids, const_lookup, batch_size)
        batched_scl_tps = construct_batched_scl_tps(batched_consts, batched_object_ids)
        
        # Process unary predicates
        batched_unary_pred_scl = []
        batched_cate_pred_scl = []
        for vid, (image_cate_probs, image_unary_probs) in enumerate(batched_image_unary_probs):
            
            unary_pred_scl = []
            cate_pred_scl = {}
            
            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            assert image_cate_probs[0].shape[0] == 0 or len(object_ids) == image_cate_probs[0].shape[1]
            # if not len(image_unary_probs[1]) == 0:
            #     assert len(object_ids) == image_unary_probs[0].shape[1]
                
            # Gather, per object and category, the probabilities of all frames.
            for cate_name, probs in zip(image_cate_probs[1], image_cate_probs[0]):
                
                for prob, (fid, oid) in zip(probs, object_ids):
                    if not oid in cate_pred_scl:
                        cate_pred_scl[oid] = {}
                    if not cate_name in cate_pred_scl[oid]:
                        cate_pred_scl[oid][cate_name] = []
                    cate_pred_scl[oid][cate_name].append(prob)
                    
            # Average over frames; the fact uses the constant's id shifted by
            # -1 (presumably to match ids from construct_batched_scl_tps --
            # TODO confirm).
            new_cate_pred_scl = []
            for oid, object_cate_info in cate_pred_scl.items():
                for cate_name, prob in object_cate_info.items():
                    new_cate_pred_scl.append((torch.mean(torch.stack(prob)), (oid, const_lookup[cate_name] - 1)))
                    
            for unary_pred_name, probs in zip(image_unary_probs[1], image_unary_probs[0]):
                
                for prob, (fid, oid) in zip(probs, object_ids):
                    unary_pred_scl.append((prob, (unary_pred_name, fid, oid)))
            
            batched_cate_pred_scl.append(new_cate_pred_scl)
            batched_unary_pred_scl.append(unary_pred_scl)
            
        # Process binary predicates
        batched_binary_pred_scl = []

        for vid, image_binary_probs in enumerate(batched_image_binary_probs):
            binary_pred_scl = []
            
            # A video without binary predictions carries an empty entry.
            if len(image_binary_probs) == 0:
                batched_binary_pred_scl.append([])
                continue
            
            
            object_pairs = [ (fid, pair) for (ovid, fid, pair) in batched_object_pairs if vid == ovid]
            assert len(object_pairs) == image_binary_probs[0].shape[1]
            
            for binary_pred_name, probs in zip(image_binary_probs[1], image_binary_probs[0]):
 
                for prob, (fid, pair) in zip(probs, object_pairs):
                    binary_pred_scl.append((prob, (binary_pred_name, fid, pair[0], pair[1])))

            batched_binary_pred_scl.append(binary_pred_scl)

        batched_outputs =[]
        batched_ys = []
        batched_formatted_ys = []
        batched_scl_input_facts = []
        batched_neg_scl_input_facts = []
        all_single_gt = []
        all_single_pred = []

        # Build two fact sets per video that share the same perception facts:
        # one with the caption's time-stamp facts, one with the negative
        # caption's.
        for vid, (data_id, caption, scl_tp, cate_pred_tp, unary_pred_tp, binary_pred_tp, gpt_spec, neg_gpt_spec) \
            in enumerate(zip(batched_ids, batched_captions, batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_gpt_specs, batched_neg_gpt_specs)):

            # Give an ID to all required placeholders and object names
            scl_input_facts = {}
            neg_scl_input_facts = {}
            neg_caption = neg_gpt_spec['caption']

            scl_input_facts.update(scl_tp)
            scl_input_facts['name'] = (cate_pred_tp)
            scl_input_facts['sg_unary_atom'] = (unary_pred_tp)
            scl_input_facts['sg_binary_atom'] = (binary_pred_tp)
            # Variables are numbered from 1, one per spec argument.
            scl_input_facts['variable'] = [tuple([i + 1]) for i in range(len(gpt_spec['args']))]
            # scl_input_facts['name'] = [(vid, const_lookup[p]) for vid, p in zip(var2vid.values(), gpt_spec['args'])]
            time_stamp_ids, time_stamp_fact_ls = self.caption2scl[caption]['time_stamp_ids'], self.caption2scl[caption]['time_stamp_facts']
            
            neg_scl_input_facts.update(scl_tp)
            neg_scl_input_facts['name'] = (cate_pred_tp)
            neg_scl_input_facts['sg_unary_atom'] = (unary_pred_tp)
            neg_scl_input_facts['sg_binary_atom'] = (binary_pred_tp)
            neg_scl_input_facts['variable'] = [tuple([i + 1]) for i in range(len(neg_gpt_spec['args']))]
            # scl_input_facts['name'] = [(vid, const_lookup[p]) for vid, p in zip(var2vid.values(), gpt_spec['args'])]
            neg_time_stamp_ids, neg_time_stamp_fact_ls = self.caption2scl[neg_caption]['time_stamp_ids'], self.caption2scl[neg_caption]['time_stamp_facts']
            
            
            # Time stamps are numbered from 1 in the order they are listed.
            scl_input_facts['time_stamp'] = []
            for tid, teid in enumerate(time_stamp_ids):
                scl_input_facts['time_stamp'].append((tid + 1, teid))
            
            scl_input_facts['time_stamp_ct'] = [tuple([len(time_stamp_ids)])]
            for time_stamp_facts in time_stamp_fact_ls:
                for k, vs in time_stamp_facts.items():
                    if k not in scl_input_facts:
                        scl_input_facts[k] = []
                    scl_input_facts[k] += [tuple(v) for v in vs]

            batched_scl_input_facts.append(scl_input_facts)
            
            neg_scl_input_facts['time_stamp'] = []
            for tid, teid in enumerate(neg_time_stamp_ids):
                neg_scl_input_facts['time_stamp'].append((tid + 1, teid))
            
            neg_scl_input_facts['time_stamp_ct'] = [tuple([len(neg_time_stamp_ids)])]
            for neg_time_stamp_facts in neg_time_stamp_fact_ls:
                for k, vs in neg_time_stamp_facts.items():
                    if k not in neg_scl_input_facts:
                        neg_scl_input_facts[k] = []
                    neg_scl_input_facts[k] += [tuple(v) for v in vs]

            batched_neg_scl_input_facts.append(neg_scl_input_facts)

        # Targets: 1 for the real caption, 0 for the negative caption.
        batched_ys = [1] * batch_size
        neg_batched_ys = [0] * batch_size

        formatted_batched_scl_input_facts = process_batched_facts(batched_scl_input_facts)
        formatted_batched_neg_scl_input_facts = process_batched_facts(batched_neg_scl_input_facts)
        
        output = self.reason(**formatted_batched_scl_input_facts)
    
        t1s, t1probs = output['aligned_t1']
        t2s, t2probs = output['aligned_t2']
        t3s, t3probs = output['aligned_t3']

        align_loss = self.loss(t1s, t1probs, t2s, t2probs, t3s, t3probs, batched_gpt_specs, batched_ys, batched_video_splits)
        
        # The loss values are Python lists, so `+` below concatenates them.
        if self.use_contrast:
            output = self.reason(**formatted_batched_neg_scl_input_facts)
        
            t1s, t1probs = output['aligned_t1']
            t2s, t2probs = output['aligned_t2']
            t3s, t3probs = output['aligned_t3']

            contrast_loss = self.loss(t1s, t1probs, t2s, t2probs, t3s, t3probs, batched_neg_gpt_specs, neg_batched_ys, batched_video_splits)
            loss = contrast_loss + align_loss
            
        elif self.use_neg_sampling:
            neg_sample_loss = self.neg_sample_loss(batched_image_unary_probs, batched_image_binary_probs, batched_negative_examples)            
            loss = neg_sample_loss + align_loss
        else:
            loss = align_loss
            
        return loss

    def baseline_eval(self):
        """Evaluate the predicate model on the test loader and print statistics.

        Runs ``baseline_eval_batch`` on every test batch without gradient
        tracking, merges the per-batch result dictionaries and prints the
        unary and binary statistics together with their accuracy.
        """
        self.predicate_model.eval()

        total_results_unary = []
        total_results_binary = []
        with torch.no_grad():
            for dp_list in tqdm(self.test_loader):
                result_unary, result_binary = self.baseline_eval_batch(dp_list)
                total_results_unary.append(result_unary)
                total_results_binary.append(result_binary)

            merge_unary_results = combine_baseline_pred_dict_ls(total_results_unary)
            merge_binary_results = combine_baseline_pred_dict_ls(total_results_binary)
            unary_accu, unary_stats = obtain_stats(merge_unary_results)
            binary_accu, binary_stats = obtain_stats(merge_binary_results)

        print("unary stats:")
        pretty_print(unary_stats)
        print(f"Accuracy: {unary_accu}")
        # Header previously read "binary stass" (typo, and no colon).
        print("binary stats:")
        pretty_print(binary_stats)
        print(f"Accuracy: {binary_accu}")
    
    def train_epoch(self, n):
        """Train the predicate model for one pass over the training loader.

        Args:
            n: epoch number, used only in the progress-bar description.

        Returns:
            Running mean of all individual loss values seen in this epoch, or
            ``nan`` when no batch produced a loss (empty loader, or every
            batch was skipped).
        """
        self.predicate_model.train()

        all_losses = []
        # Defined up front so the final ``return`` cannot hit an unbound name
        # when no batch contributes a loss.
        avg_loss = float("nan")

        iterator = tqdm(self.train_loader)
        for dp_list in iterator:
            self.optimizer.zero_grad()
            loss_ls = self.forward(dp_list)

            # ``forward`` returns an empty list when there is nothing to
            # learn from; same check as in ``test_epoch``.
            if len(loss_ls) == 0:
                continue

            loss = sum(loss_ls)
            loss.backward(retain_graph=True)
            self.optimizer.step()
            self.optimizer.zero_grad()

            all_losses += [single_loss.item() for single_loss in loss_ls]
            avg_loss = sum(all_losses) / len(all_losses)
            iterator.set_description(f'[Train {n}] Loss: {avg_loss}')

            # Drop references to the batch and its graph before asking the
            # allocator to release cached memory.
            del loss_ls
            del loss
            del dp_list

            gc.collect()
            torch.cuda.empty_cache()

        return avg_loss

    def test_epoch(self, n):
        """Evaluate on the test loader and write checkpoints.

        Args:
            n: epoch number; used in the progress description and in the
                per-epoch checkpoint file name.

        Returns:
            Running mean of all individual loss values, or the current
            ``self.min_loss`` when no batch produced a loss.

        When ``self.save_model`` is set, ``<model_name>.<n>.model`` is always
        written, and ``<model_name>.best.model`` is written whenever the
        average loss improves on ``self.min_loss``.
        """
        self.predicate_model.eval()
        all_losses = []
        avg_loss = self.min_loss

        with torch.no_grad():
            iterator = tqdm(self.test_loader)
            for dp_list in iterator:
                # ``forward`` already covers the contrastive case through
                # ``self.use_contrast``; the class has no separate
                # ``forward_contrast`` method to dispatch to.
                loss_ls = self.forward(dp_list)

                if len(loss_ls) == 0:
                    continue

                all_losses += [single_loss.item() for single_loss in loss_ls]
                avg_loss = sum(all_losses) / len(all_losses)
                iterator.set_description(f'[Test {n}] Loss: {avg_loss}')
                del loss_ls
                del dp_list

                gc.collect()
                torch.cuda.empty_cache()

        # Save model
        if avg_loss < self.min_loss and self.save_model:
            self.min_loss = avg_loss
            torch.save(self.predicate_model, os.path.join(self.model_dir, f"{self.model_name}.best.model"))
        if self.save_model:
            torch.save(self.predicate_model, os.path.join(self.model_dir, f"{self.model_name}.{n}.model"))

        return avg_loss

    def test(self):
        self.test_epoch(0)

    def train(self, num_epochs):
        """Alternate one training epoch and one test epoch, ``num_epochs`` times.

        Args:
            num_epochs: Number of train/test rounds; epochs are numbered
                starting from 1.
        """
        for completed in range(num_epochs):
            epoch = completed + 1  # epochs are reported 1-based
            self.train_epoch(epoch)
            self.test_epoch(epoch)

    def save_scl_file(self, datapoint, object_tps, current_constraint):
        """Write the Scallop program generated for one datapoint to disk.

        The program text is produced by ``obtain_scl_file`` from the object
        tuples, the constraint and the common Scallop source. Nothing is
        written when ``self.save_scl_dir`` is ``None``.

        Args:
            datapoint: Mapping whose ``'id'`` entry names the output file.
            object_tps: Passed through to ``obtain_scl_file``.
            current_constraint: Passed through to ``obtain_scl_file``.
        """
        program_text = obtain_scl_file(object_tps, current_constraint, self.common_scl)
        file_name = datapoint['id'] + '.scl'

        if self.save_scl_dir is None:
            return

        target_path = os.path.join(self.save_scl_dir, file_name)
        with open(target_path, 'w') as out_file:
            out_file.write(program_text)

def parse_bool_flag(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` in argparse calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the usual textual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"`` or ``"no"``.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If the token is not a recognised spelling. argparse
            reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":

    dataset = "open_pvsg"
    cache_file_name = "gpt_specs_scl.json"
    data_file_name = 'pvsg.json'

    # Locations that are fixed by the repository layout.
    script_path = os.path.abspath(__file__)
    default_data_dir = os.path.abspath(os.path.join(script_path, f"../../../data/{dataset}"))
    scl_dir = os.path.abspath(os.path.join(script_path, '../scl'))
    assert os.path.exists(scl_dir)
    common_scl_path = os.path.join(scl_dir, f'{dataset}_ltl.scl')
    assert os.path.exists(common_scl_path)

    # Boolean options accept an explicit value ("--use-contrast False") or
    # can be given bare ("--use-contrast"), which means True.
    bool_flag = dict(type=parse_bool_flag, nargs="?", const=True)

    parser = ArgumentParser(dataset)
    parser.add_argument("--phase", type=str, default='train', choices=["train", "test", "label"])
    parser.add_argument("--n-epochs", type=int, default=50)
    parser.add_argument("--load_model", default=False, **bool_flag)
    parser.add_argument("--save_model", default=True, **bool_flag)
    # Defaults to <data-dir>/pred_video, resolved after parsing.
    parser.add_argument("--video_save_dir", type=str, default=None)
    parser.add_argument("--model_type", type=str, default="contrast")
    parser.add_argument("--negative-sampling", default=False, **bool_flag)
    parser.add_argument("--violation-weight", type=float, default=0.05)
    parser.add_argument("--with-violation", action='store_true')
    parser.add_argument("--use-contrast", default=True, **bool_flag)
    parser.add_argument("--use-neg-sampling", default=False, **bool_flag)
    parser.add_argument("--parallel", action='store_true')
    parser.add_argument("--num_top_pairs", type=int, default=300)
    parser.add_argument("--max_video_len", type=int, default=15)

    # Dataset size / split options
    parser.add_argument("--train_num", type=int, default=5000)
    parser.add_argument("--val_num", type=int, default=1000)
    parser.add_argument("--training_percentage", type=int, default=100)
    parser.add_argument("--test_percentage", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.000001)
    # A layer dimension must be an integer.
    parser.add_argument("--latent-dim", type=int, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model-name", type=str, default=None)
    # Defaults to <data-dir>/model, resolved after parsing.
    parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--data-dir", type=str, default=default_data_dir)
    # Kept for backward compatibility: CUDA is used whenever it is available.
    parser.add_argument("--use-cuda", action="store_true")
    parser.add_argument("--gpu", type=int, default=0)

    args = parser.parse_args()

    # Seed every RNG this file imports so runs are repeatable.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Resolve all data-dependent paths from the (possibly overridden) data dir.
    data_dir = os.path.abspath(args.data_dir)
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(f"data directory does not exist: {data_dir}")
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    os.makedirs(data_nl_dir, exist_ok=True)
    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)

    if args.video_save_dir is None:
        args.video_save_dir = os.path.join(data_dir, 'pred_video')
    if args.model_dir is None:
        args.model_dir = os.path.join(data_dir, 'model')
    if args.save_model:
        # Checkpoints are written into this directory; make sure it exists.
        os.makedirs(args.model_dir, exist_ok=True)

    model_name = f"laser_{dataset}_training_{args.training_percentage}_seed_{args.seed}_batch_size_{args.batch_size}_lr_{args.learning_rate}_prov_{args.provenance}_tpk_{args.train_top_k}_contrast_{args.use_contrast}_ns_{args.use_neg_sampling}"
    if args.model_name is None:
        args.model_name = model_name

    device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"

    with open(data_path, 'r') as f:
        anno = json.load(f)

    with open(cache_path, 'r') as f:
        caption2scl = json.load(f)

    # See video id in anno['split'].
    data = {data_dict['video_id']: data_dict for data_dict in anno['data']}
    train_dataset, valid_dataset, train_loader, test_loader = open_pvsg_loader(
            cache_path=cache_path,
            dataset_dir=data_dir,
            dataset_name=data_file_name,
            batch_size=args.batch_size,
            device=device,
            training_percentage=args.training_percentage,
            testing_percentage=args.test_percentage,
            max_video_len=args.max_video_len)

    # The top-k options were previously parsed but never reached the
    # Scallop context; pick the one that matches the requested phase.
    top_k = args.train_top_k if args.phase == "train" else args.test_top_k

    trainer = Trainer(train_loader=train_loader, test_loader=test_loader, device=device, caption2scl=caption2scl,
                      save_scl_dir=scl_dir, common_scl_path=common_scl_path,
                      latent_dim=args.latent_dim, model_layer=args.model_layer,
                      model_dir=args.model_dir, model_name=args.model_name,
                      learning_rate=args.learning_rate, load_model=args.load_model,
                      provenance=args.provenance, k=top_k,
                      save_model=args.save_model,
                      video_save_dir=args.video_save_dir,
                      violation_weight=args.violation_weight,
                      with_violation=args.with_violation,
                      use_contrast=args.use_contrast,
                      parallel=args.parallel,
                      num_top_pairs=args.num_top_pairs,
                      use_neg_sampling=args.use_neg_sampling)

    print(args.model_name)
    if args.phase == "train":
        print("train")
        trainer.train(args.n_epochs)
    elif args.phase == "test":
        print("baseline eval")
        trainer.baseline_eval()
    elif args.phase == "label":
        print("label")
        # Create the directory the trainer was actually configured with.
        os.makedirs(args.video_save_dir, exist_ok=True)
        trainer.label(60)

    print("end")