import os
import json
import random
import cv2
from argparse import ArgumentParser
import math
from torch.utils.data import DataLoader
from torch import nn, optim
from tqdm import tqdm
import torch
import scallopy
import spacy
import numpy as np

from pddl_parser import PDDLParser
from model import PredicateModel
from common import static_preds, unary_preds, binary_preds, non_prob_preds, transform, norm_x, norm_y, \
    action_arg_num, num_to_actions, read_video, var2vid, const2cid, eps
import spacy_cleaner
from spacy_cleaner.processing import removers
from data_checker import annotate_pred_video
from sklearn import metrics

# Load the large English spaCy model once at import time; the module-level
# `nlp` object is used by the name-similarity helpers in this file.
nlp = spacy.load("en_core_web_lg")
# Text-cleaning pipeline built on the same model, configured with the
# stop-word-token remover from spacy_cleaner.
pipeline = spacy_cleaner.Pipeline(
    nlp,
    removers.remove_stopword_token,
)

def get_object_tps(batched_object_names, batched_object_ids, const_lookup, batch_size):
    """Group object observations by video index.

    Args:
        batched_object_names: iterable of ``(video_idx, frame_id, name)``.
        batched_object_ids: iterable of ``(video_idx, frame_id, object_id)``,
            aligned element-by-element with ``batched_object_names``.
        const_lookup: mapping from object name to its constant id.
        batch_size: number of videos; one (possibly empty) list is returned
            for every index in ``range(batch_size)``.

    Returns:
        A list with ``batch_size`` entries; entry ``i`` holds the
        ``(frame_id, object_id, const_lookup[name])`` tuples of video ``i``.

    Raises:
        AssertionError: if the two inputs disagree on video or frame index.
        KeyError: if a name is missing from ``const_lookup`` or a video index
            falls outside ``range(batch_size)``.
    """
    per_video = {video_idx: [] for video_idx in range(batch_size)}

    for name_entry, id_entry in zip(batched_object_names, batched_object_ids):
        name_vid, name_fid, obj_name = name_entry
        id_vid, id_fid, obj_id = id_entry
        # Both streams must describe the same observation.
        assert(name_vid == id_vid)
        assert(name_fid == id_fid)
        per_video[name_vid].append((name_fid, obj_id, const_lookup[obj_name]))

    return list(per_video.values())

def to_scl_file(common_scl, rules, tuples, file_path):
    """Assemble a Scallop program, write it to ``file_path`` and return it.

    The program is ``common_scl`` followed by one ``rel <rule>`` line per
    entry of ``rules`` and one ``rel <name> = {...}`` line per entry of
    ``tuples``. Relations listed in ``non_prob_preds`` are rendered as plain
    tuples; every other relation is rendered as ``<prob>::<tuple>`` where the
    probability is taken from ``t[0].item()``. Single quotes in the rendered
    facts are replaced by double quotes.
    """
    pieces = [common_scl]
    pieces.extend('rel ' + rule + '\n' for rule in rules)

    for tuple_name, tps in tuples.items():
        if tuple_name in non_prob_preds:
            rendered = [str(t).replace("'", '"') for t in tps]
        else:
            rendered = [(str(t[0].item()) + '::' + str(t[1])).replace("'", '"') for t in tps]
        pieces.append('rel ' + tuple_name + ' = {' + ', '.join(rendered) + '}' + '\n')

    scl_file = ''.join(pieces)

    with open(file_path, 'w') as out_file:
        out_file.write(scl_file)

    return scl_file

def is_same_name(n1, n2, threshold=0.8):
    """Return True when the spaCy similarity of the two names exceeds ``threshold``."""
    first_doc, second_doc = nlp(n1), nlp(n2)
    return bool(first_doc.similarity(second_doc) > threshold)

def preprocess_dataset(dataset):
    """Align each datapoint's placeholders with its bounding-box category names.

    For every datapoint (mutated in place):
      * the substring 'another' is removed from every label category;
      * each placeholder is compared (spaCy similarity) against every distinct
        category name, and when exactly one name scores above 0.97 the
        placeholder is replaced by that name.

    Returns:
        A new list containing the same (mutated) datapoint objects.

    Raises:
        AssertionError: if a placeholder matches more than one category name.
    """
    def _filtered_doc(text):
        # Token filter reproduced verbatim from the original condition.
        kept = [str(tok) for tok in nlp(text)
                if not tok.is_stop or str(tok) == 'can' or not str(tok) == 'another']
        return nlp(' '.join(kept))

    processed = []

    for datapoint in dataset:
        category_names = set()
        for frame in datapoint['bboxes']:
            for label in frame['labels']:
                cleaned = label['category'].replace('another', '').strip()
                label['category'] = cleaned
                category_names.add(cleaned)

        if 'another' in datapoint['placeholders']:
            print('here')

        placeholder_docs = {ph: _filtered_doc(ph) for ph in datapoint['placeholders']}
        name_docs = {name: _filtered_doc(name) for name in category_names}

        # placeholder -> every category name that is similar enough
        candidates = {}
        for ph, ph_doc in placeholder_docs.items():
            for name, name_doc in name_docs.items():
                if ph_doc.similarity(name_doc) > 0.97:
                    candidates.setdefault(ph, []).append(name)

        resolved = {}
        for ph, names in candidates.items():
            assert len(names) == 1
            resolved[ph] = names[0]

        datapoint['placeholders'] = [resolved[ph] if ph in resolved else ph
                                     for ph in datapoint['placeholders']]
        processed.append(datapoint)

    return processed

class SSDataset():
    """Video dataset whose datapoints are grouped by action-argument count.

    Each datapoint is a dict read from a JSON file; this class uses its
    'id', 'template', 'placeholders' and 'bboxes' entries. Videos are read
    from ``<video_dir>/<id>.webm``.
    """

    def __init__(self, dataset_path, video_dir, device, data_percentage, paired_actions) -> None:
        """Load the JSON dataset, keep the first ``data_percentage`` percent and shuffle.

        Args:
            dataset_path: path to the JSON file holding a list of datapoints.
            video_dir: directory containing the ``<id>.webm`` videos.
            device: device the normalized frame tensors are moved to.
            data_percentage: percentage (0-100) of datapoints to keep, taken
                from the start of the list.
            paired_actions: stored on the instance; not used inside this class.
        """
        # Use a context manager so the file handle is closed deterministically.
        with open(dataset_path, 'r') as dataset_file:
            dataset = json.load(dataset_file)

        dp_count = math.floor(data_percentage / 100 * len(dataset))
        dataset = dataset[:dp_count]

        self.device = device
        self.paired_actions = paired_actions

        # NOTE(review): `template2action` is not defined in this class or in the
        # visible imports; it must exist at module level when this runs.
        all_dps = {}
        for dp in dataset:
            all_dps.setdefault(template2action[dp['template']], []).append(dp)

        # Bucket datapoints by the keys of num_to_actions.
        mixed_dps = {}
        for num in set(num_to_actions.keys()):
            mixed_dps[num] = []
            for action in num_to_actions[num]:
                if action in all_dps:
                    mixed_dps[num] += all_dps[action]

        self.mixed_dps = mixed_dps
        self.dataset = dataset
        self.video_dir = video_dir
        self.shuffle()

    def shuffle(self):
        """Shuffle bucket order and the datapoints inside each bucket.

        Rebuilds ``self.dataset`` and ``self.dataset_id_lookup`` (id -> index).
        Buckets are addressed by ``0 .. len(self.mixed_dps) - 1``, so the keys
        of ``self.mixed_dps`` are expected to be exactly those integers.
        """
        bucket_order = list(range(len(self.mixed_dps)))
        random.shuffle(bucket_order)
        new_dataset = []
        for idx in bucket_order:
            dps = self.mixed_dps[idx]
            random.shuffle(dps)
            new_dataset += dps
        self.dataset = new_dataset
        self.dataset_id_lookup = {dp['id']: did for did, dp in enumerate(self.dataset)}

    def  __len__(self):
        return len(self.dataset)

    def _load_frames(self, datapoint, video_id):
        """Read one video and return ``(resized raw frames, normalized tensors)``."""
        video_path = os.path.join(self.video_dir, f"{video_id}.webm")
        video = read_video(datapoint['bboxes'], video_path)
        reshaped_video = []
        norm_reshaped_video = []

        for frame in video:
            resized = cv2.resize(frame, (norm_x, norm_y))
            reshaped_video.append(resized)
            # Move the last axis to the front before applying `transform`.
            channel_first = np.moveaxis(resized, -1, 0)
            tensor_frame = transform(torch.tensor(channel_first, dtype=torch.float32))
            norm_reshaped_video.append(tensor_frame.to(self.device))

        return reshaped_video, norm_reshaped_video

    def get_item_by_id(self, data_id):
        """Return the collated single-element batch for the datapoint with id ``data_id``."""
        did = self.dataset_id_lookup[data_id]
        datapoint = self.dataset[did]
        reshaped_video, norm_reshaped_video = self._load_frames(datapoint, data_id)
        return self.collate_fn([[datapoint, reshaped_video, norm_reshaped_video]])

    def __getitem__(self, i):
        """Return ``(datapoint, resized raw frames, normalized frame tensors)`` for index ``i``."""
        datapoint = self.dataset[i]
        reshaped_video, norm_reshaped_video = self._load_frames(datapoint, datapoint['id'])
        return datapoint, reshaped_video, norm_reshaped_video

    @staticmethod
    def collate_fn(batch):
        """Flatten a list of ``(datapoint, raw_frames, frame_tensors)`` into batch-level lists.

        Returns a 10-tuple:
            ids, templates, placeholders, stacked frame tensor, bounding boxes,
            ordered object pairs ``(video_idx, frame_id, (oid1, oid2))``,
            object ids ``(video_idx, frame_id, oid)``, cumulative frame counts
            per video, object names ``(video_idx, frame_id, category)``, and
            the per-video lists of resized raw frames.

        Raises:
            AssertionError: if a label's 'gt_annotation' does not start with 'object'.
        """
        batched_videos = []
        batched_bboxes = []
        batched_object_names = []
        batched_obj_pairs = []
        batched_ids = []
        batched_video_splits = []
        batched_object_ids = []
        batched_actions = []
        batched_placeholders = []
        batched_reshaped_raw_videos = []

        frame_ct_in_video = 0
        for data_id, (datapoint, reshaped_raw_video, video) in enumerate(batch):

            batched_reshaped_raw_videos.append(reshaped_raw_video)
            batched_videos.extend(video)
            batched_ids.append(datapoint['id'])
            bounding_box_info = datapoint['bboxes']
            batched_actions.append(datapoint['template'])
            batched_placeholders.append(datapoint['placeholders'])

            # Every object id that appears anywhere in this video.
            all_obj_ids = {label['id'] for frame in bounding_box_info for label in frame['labels']}

            for frame_id, frame in enumerate(bounding_box_info):
                obj_ids_in_frame = set()

                for label in frame['labels']:
                    assert label['gt_annotation'][:6] == 'object'
                    batched_bboxes.append(label['box2d'])
                    batched_object_names.append((data_id, frame_id, label['category']))
                    batched_object_ids.append((data_id, frame_id, label['id']))
                    obj_ids_in_frame.add(label['id'])

                # Ordered pairs of distinct objects present in this frame,
                # enumerated in the iteration order of all_obj_ids.
                present = [oid for oid in all_obj_ids if oid in obj_ids_in_frame]
                for oid1 in present:
                    for oid2 in present:
                        if oid1 != oid2:
                            batched_obj_pairs.append((data_id, frame_id, (oid1, oid2)))

            frame_ct_in_video += len(video)
            batched_video_splits.append(frame_ct_in_video)

        return batched_ids, batched_actions, batched_placeholders, torch.stack(batched_videos), batched_bboxes, \
            batched_obj_pairs, batched_object_ids, batched_video_splits, batched_object_names, batched_reshaped_raw_videos

    def collect_templates(self):
        """Return ``(set of templates, set of placeholder strings)`` over the dataset."""
        all_templates = set()
        all_objects = set()
        for dp in self.dataset:
            all_templates.add(dp['template'])
            all_objects.update(dp['placeholders'])
        return all_templates, all_objects

def ss_loader(train_dataset_path, test_dataset_path, video_dir, batch_size, device, paired_actions, training_percentage=100, testing_percentage=100, ):
    """Build the train/test datasets and their DataLoaders.

    Both loaders use ``SSDataset.collate_fn``, do not shuffle (the dataset
    shuffles itself) and drop the last incomplete batch.

    Returns:
        ``(train_dataset, test_dataset, train_loader, test_loader)``.
    """
    def _dataset_and_loader(dataset_path, percentage):
        dataset = SSDataset(dataset_path, video_dir, device,
                            paired_actions=paired_actions, data_percentage=percentage)
        loader = DataLoader(dataset, batch_size, collate_fn=SSDataset.collate_fn,
                            shuffle=False, drop_last=True)
        return dataset, loader

    train_dataset, train_loader = _dataset_and_loader(train_dataset_path, training_percentage)
    test_dataset, test_loader = _dataset_and_loader(test_dataset_path, testing_percentage)
    return (train_dataset, test_dataset, train_loader, test_loader)

def construct_batched_scl_tps(batched_object_tps):
    """Apply ``construct_scl_tps`` to every per-video tuple list, preserving order."""
    return [construct_scl_tps(video_tps) for video_tps in batched_object_tps]

def construct_scl_tps(object_tps):
    """Derive the fact tuples for one video.

    Args:
        object_tps: iterable of ``(frame_id, object_id, name)`` triples.

    Returns:
        A dict with
          * 'frame': the triples as a list of tuples, in input order;
          * 'all_objects': the distinct ``(object_id,)`` 1-tuples;
          * 'all_frames': the distinct ``(frame_id,)`` 1-tuples.
    """
    frame_tps = [(fid, oid, name) for fid, oid, name in object_tps]
    distinct_objects = {(oid,) for _, oid, _ in frame_tps}
    distinct_frames = {(fid,) for fid, _, _ in frame_tps}

    return {
        'frame': frame_tps,
        'all_objects': list(distinct_objects),
        'all_frames': list(distinct_frames),
    }

def process_batched_facts(fact_dict_ls):
    """Transpose a list of per-sample fact dicts into one dict of per-sample lists.

    Every relation name that occurs in any input dict becomes a key of the
    result; its value has one entry per input dict (an empty list where the
    dict lacks that relation). A non-empty fact list whose first element is
    not exactly a ``tuple`` has each element wrapped into a 1-tuple; otherwise
    the original list object is reused unchanged.
    """
    batched = {}
    for fact_dict in fact_dict_ls:
        for relation in fact_dict:
            batched.setdefault(relation, [])

    for fact_dict in fact_dict_ls:
        for relation, per_sample in batched.items():
            if relation not in fact_dict:
                per_sample.append([])
                continue

            facts = fact_dict[relation]
            if len(facts) > 0 and type(facts[0]) != tuple:
                facts = [tuple([fact]) for fact in facts]
            per_sample.append(facts)

    return batched

def construct_scl_facts(object_tps):
    """Render the facts of one video as Scallop ``rel`` declarations.

    Args:
        object_tps: iterable of ``(frame_id, object_id, name)`` triples.

    Returns:
        A string with four declarations separated by blank lines:
        ``frame``, ``all_objects``, ``all_frames`` and ``all_names``.
    """
    scl_tps = construct_scl_tps(object_tps)
    frame_tps = scl_tps['frame']
    all_objects_tps = scl_tps['all_objects']
    all_frames_tps = scl_tps['all_frames']

    # construct_scl_tps does not produce an 'all_names' entry, so indexing it
    # unconditionally raised KeyError. Fall back to the distinct names found
    # in the frame tuples, in first-seen order.
    all_names_tps = scl_tps.get('all_names')
    if all_names_tps is None:
        all_names_tps = list(dict.fromkeys(name for _, _, name in frame_tps))

    # Frame tuples are (frame_id, object_id, name); str() lets non-string
    # names (e.g. integer constant ids) be rendered as well.
    frame_strs = ['(' + str(fid) + ', ' + str(oid) + ', ' + str(name) + ')'
                  for fid, oid, name in frame_tps]
    object_strs = [str(tp) for tp in all_objects_tps]
    frame_id_strs = [str(tp) for tp in all_frames_tps]
    name_strs = [str(tp) for tp in all_names_tps]

    frame = 'rel frame = {' + ',\n'.join(frame_strs) + '}'
    all_objects = 'rel all_objects = {' + ', '.join(object_strs) + '}'
    all_frames = 'rel all_frames = {' + ', '.join(frame_id_strs) + '}'
    all_names = 'rel all_names = {' + ', '.join(name_strs) + '}'

    return '\n\n'.join([frame, all_objects, all_frames, all_names])

def substitute_context(placeholders, constraints, action):
    """Bind the variables '?a', '?b', '?c' to ``placeholders`` and substitute them.

    Placeholders are paired with the variables positionally (extra
    placeholders beyond three are ignored), and the resulting mapping is
    passed to ``constraints.actions[action].substitute``; its result is
    returned.
    """
    binding = dict(zip(('?a', '?b', '?c'), placeholders))
    return constraints.actions[action].substitute(binding)

def obtain_scl_file(object_tps, constraints, common_scl):
    """Concatenate the common program, the video facts and the constraint program.

    The three sections are joined with a single newline, in that order.
    """
    facts_section = construct_scl_facts(object_tps)
    constraint_section = constraints.to_scl()
    return '\n'.join((common_scl, facts_section, constraint_section))

class Trainer():
    def __init__(self, train_loader, test_loader, device,
                 action2scl, action2template, common_scl_path,
                 paired_actions,
                 provenance="difftopkproofs", k=3, save_scl_dir=None,
                 model_path=None, model_name=None, learning_rate=None,
                 latent_dim=64, model_layer=2, load_model=False, save_model=True,
                 save_video=False, video_save_dir=None, violation_weight=0.1,
                 with_violation=True, use_contrast=False):
        """Set up the Scallop reasoning context, the predicate model and the optimizer.

        Args:
            train_loader, test_loader: data loaders stored on the instance.
            device: device the freshly created predicate model is moved to.
            action2scl: stored on the instance.
            action2template: mapping action -> template; its inverse is
                stored as ``self.template2action``.
            common_scl_path: path of the shared Scallop program; it is read
                into ``self.common_scl`` and imported into the Scallop context.
            paired_actions: stored on the instance.
            provenance, k: forwarded to ``scallopy.ScallopContext``.
            save_scl_dir: stored as ``self.save_scl_dir`` only when not None.
            model_path: currently unused by this constructor.
            model_name: base name of the checkpoint file.
            learning_rate: learning rate passed to Adam.
            latent_dim, model_layer: forwarded to ``PredicateModel``.
            load_model: load ``<model_dir>/<model_name>.best.model`` instead
                of creating a new model.
            save_model, save_video, video_save_dir, violation_weight,
            with_violation, use_contrast: stored on the instance.
        """
        self.save_video = save_video
        self.video_save_dir = video_save_dir

        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.action2scl = action2scl
        self.action2template = action2template
        self.template2action = {}
        # Read through a context manager so the file handle is not leaked.
        with open(common_scl_path) as common_scl_file:
            self.common_scl = common_scl_file.read()
        self.paired_actions = paired_actions
        self.save_model = save_model
        self.with_violation = with_violation
        self.use_contrast = use_contrast

        self.scallop_ctx = scallopy.ScallopContext(provenance=provenance, k=k)
        self.scallop_ctx.import_file(common_scl_path)
        self.scallop_ctx.set_non_probabilistic(non_prob_preds)
        if with_violation:
            # Expose both the constraint relation and the 0-arity violation relation.
            self.reason = self.scallop_ctx.forward_function(output_mappings={
                "constraint": None,
                "violation": (),
            }, retain_graph=True)
        else:
            self.reason = self.scallop_ctx.forward_function(output_mappings={
                "constraint": None,
            })

        # NOTE(review): `model_dir` is not a parameter of this constructor, so it
        # must resolve to a module-level name; the `model_path` argument is never
        # read. Confirm which one is intended.
        if load_model:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            self.predicate_model = torch.load(os.path.join(model_dir, model_name + '.best.model'))
        else:
            self.predicate_model = PredicateModel(latent_dim, model_layer, device).to(device)

        self.optimizer = optim.Adam(self.predicate_model.parameters(), lr=learning_rate)
        self.model_dir = model_dir
        self.model_name = model_name
        self.max_accu = -1
        self.violation_weight = violation_weight

        for action, template in action2template.items():
            self.template2action[template] = action
        if save_scl_dir is not None:
            self.save_scl_dir = save_scl_dir
        # Per-element losses; the weighting/reduction is done in `loss`.
        self.loss_fn = nn.BCELoss(reduction='none')

    def loss(self, tps, batched_preds, batched_violations, batched_ys, sat_rate = 0.9):
        batched_loss = []
        for pred_probs, y, violation in zip(batched_preds, batched_ys, batched_violations):

            distances = []
            for tid, (dist, prob) in enumerate(zip(tps, pred_probs)):
                if prob > eps:
                    distances.append((tid, dist[0]))

            if len(distances) == 0:
                continue

            # For smaller window, it has lower likelihood of actually capturing the operation
            # We thus assign a weight function for the
            max_length = max([d for _, d in distances])
            min_encourage_len = math.ceil(max_length * sat_rate)
            score_for_dist = 1 / (max_length - min_encourage_len + 1)

            weights = []
            valid_probs = []
            target_y = []
            for tid, dist in distances:

                if dist < min_encourage_len:
                    continue

                prob = pred_probs[tid]
                weight = score_for_dist * (dist - min_encourage_len + 1)
                weights.append(weight)
                valid_probs.append(prob)
                target_y.append(y)

            weights = torch.tensor(weights)
            loss = self.loss_fn( torch.stack(valid_probs), torch.tensor(target_y, dtype=torch.float32))
            loss = loss * weights
            loss = (loss.sum() / weights.sum()) + self.violation_weight * violation
            batched_loss.append(loss)

        return batched_loss


    def correct(self, tps, batched_preds, batched_formatted_ys):

        correct_ct = []
        sims = []
        current_preds_id = 0

        for ys in batched_formatted_ys:
            ys = torch.tensor(ys)
            outputs = batched_preds[current_preds_id: current_preds_id + len(ys)]
            current_preds_id += len(ys)
            current_sims = []

            for probs in outputs:
                sim = max(probs)
                if type(sim) == int:
                    sim = torch.tensor(sim)
                else:
                    sim = sim.detach()
                current_sims.append(sim)

            sims.append(current_sims)
            num_of_gt_y = sum(ys)
            if num_of_gt_y == 0:
                print('here')
            else:
                v, i = torch.topk(torch.stack(current_sims), k=num_of_gt_y)
                correct = ys[i].sum() / num_of_gt_y
                correct_ct.append(correct.item())

        return correct_ct

    def pred_eff_likelyhood(self, prec_fids, prec_probs, eff_fids, eff_probs, start_end_fids):
        total_probs = []
        prec_fids = [int(i[0]) for i in prec_fids]
        eff_fids = [int(i[0]) for i in eff_fids]

        for vid, (prec_prob, eff_prob) in enumerate(zip(prec_probs, eff_probs)):
            start_fid = start_end_fids[vid]['start']
            end_fid = start_end_fids[vid]['end']
            if start_fid in prec_fids:
                start_idx = prec_fids.index(start_fid)
                start_prob = prec_prob[start_idx]
                total_probs.append(start_prob.item())
            else:
                total_probs.append(0)

            if end_fid in eff_fids:
                end_idx = eff_fids.index(end_fid)
                end_prob = eff_prob[end_idx]
                total_probs.append(end_prob.item())
            else:
                total_probs.append(0)

        return total_probs

    def label_batch(self, batch):
        """Run the predicate model on one batch and write annotated videos.

        For every video in the batch, the predicted unary, static and binary
        predicate probabilities that are required by the video's action are
        collected per frame and passed to ``annotate_pred_video``, which is
        given the output path ``<self.video_save_dir>/<data_id>.mp4``.

        Args:
            batch: the 10-tuple produced by ``SSDataset.collate_fn``.

        Returns:
            A dict mapping each required binary predicate other than 'in' and
            'touching' to the set of data ids in which it was encountered.
        """

        batched_ids, batched_actions, batched_placeholders, batched_videos, batched_bboxes, \
            batched_object_pairs, batched_object_ids, batched_video_splits, batched_object_names, \
              batched_reshaped_raw_videos = batch
        batch_size = len(batched_ids)
        # Every distinct placeholder and object name seen in this batch.
        consts = list(set([p for ps in batched_placeholders for p in ps] + [tp[2] for tp in batched_object_names]))

        # Known constants keep their id from const2cid (registered under the
        # original, upper-case and lower-case spelling).
        const_lookup = {}
        cids = []
        for k, v in const2cid.items():
            const_lookup[k] = v
            const_lookup[k.upper()] = v
            const_lookup[k.lower()] = v

            if k.lower() in consts:
                consts.remove(k.lower())

            cids.append(v)

        # Remaining (unknown) constants get fresh ids counting down from below
        # the smallest known id.
        current_cid = min(cids) - 1
        for k in consts:
            const_lookup[k] = current_cid
            const_lookup[k.upper()] = current_cid
            const_lookup[k.lower()] = current_cid
            current_cid -= 1

        batched_object_tps = get_object_tps(batched_object_names, batched_object_ids, const_lookup, batch_size)
        # NOTE(review): batched_scl_tps is not read again in this method.
        batched_scl_tps = construct_batched_scl_tps(batched_object_tps)
        batched_occurred_objs = []
        occured_objs = []
        last_vid = 0

        # Collect the distinct object ids per video. A group is flushed each
        # time the video index increases.
        # NOTE(review): a video index without any object produces no entry here,
        # which would shift the later per-video alignment - confirm inputs
        # always contain at least one object per video.
        for vid, fid, oid in batched_object_ids:
            if vid > last_vid:
                batched_occurred_objs.append(list(set(occured_objs)))
                occured_objs = []
                last_vid = vid
            occured_objs.append(oid)
        if not len(occured_objs) == 0:
            batched_occurred_objs.append(list(set(occured_objs)))

        batched_unary_pred_prob, batched_binary_pred_prob, batched_static_pred_prob = \
            self.predicate_model(batched_videos, batched_bboxes, batched_object_ids, batched_object_pairs, batched_occurred_objs, batched_video_splits)

        # TODO: support sampling
        # Process static predicates
        # Rows of batched_static_pred_prob are consumed in the same order in
        # which the occurred objects were passed to the model.
        batched_static_pred_scl = []
        current_obj_ct = 0
        for occured_objs in batched_occurred_objs:
            static_pred_scl = []
            for oid in occured_objs:
                probs = batched_static_pred_prob[current_obj_ct]
                for prob, pred in zip(probs, static_preds):
                    static_pred_scl.append((prob, tuple([pred, oid])))
                current_obj_ct += 1
            batched_static_pred_scl.append(static_pred_scl)

        # Process unary predicates
        current_vid = 0
        batched_unary_pred_scl = {}

        for vid in range(batch_size):
            batched_unary_pred_scl[vid] = []

        # One probability row per (video, frame, object), paired by position.
        for ((vid, fid, obj_id), unary_probs) in zip(batched_object_ids, batched_unary_pred_prob):
            for prob, pred in zip(unary_probs, unary_preds):
                batched_unary_pred_scl[vid].append((prob, (pred, fid, obj_id)))
        batched_unary_pred_scl = list([batched_unary_pred_scl[vid] for vid in range(batch_size)])

        # Process binary predicates
        batched_binary_pred_scl = {}
        binary_pred_scl = []

        for vid in range(batch_size):
            batched_binary_pred_scl[vid] = []

        # One probability row per (video, frame, object pair), paired by position.
        for ((vid, fid, (obj1, obj2)), binary_probs) in zip(batched_object_pairs, batched_binary_pred_prob):
            for prob, pred in zip(binary_probs, binary_preds):
                batched_binary_pred_scl[vid].append((prob, (pred, fid, obj1, obj2)))
        batched_binary_pred_scl = list([batched_binary_pred_scl[vid] for vid in range(batch_size)])

        batched_outputs =[]
        batched_ys = []
        batched_formatted_ys = []
        batched_scl_input_facts = []
        batched_video_pred_labels = []

        # bboxes_info: video -> frame -> object id -> bounding box
        # batched_obj_info: video -> frame -> list of object ids
        bboxes_info = {}
        batched_obj_info = {}
        for bboxes, obj_info in zip(batched_bboxes, batched_object_ids):
            video_id, frame_id, obj_id = obj_info
            if not video_id in bboxes_info:
                bboxes_info[video_id] = {}
                batched_obj_info[video_id] = {}
            if not frame_id in bboxes_info[video_id]:
                bboxes_info[video_id][frame_id] = {}
                batched_obj_info[video_id][frame_id] = []
            bboxes_info[video_id][frame_id][obj_id] = bboxes
            batched_obj_info[video_id][frame_id].append(obj_id)
        batched_obj_info = list(batched_obj_info.values())

        interesting_dp = {}
        for scl_data_id, scl_template, static_pred_scl, unary_pred_scl, binary_pred_scl, obj_info in zip(batched_ids, batched_actions, batched_static_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_obj_info):
            action_name = self.template2action[scl_template]
            current_outputs = []
            current_ys = []
            # frame id -> list of predicate entries relevant to this action
            required_pred_info = {}
            # NOTE(review): `constraints` is not defined in this method; it must
            # resolve to a module-level name.
            required_preds = constraints.actions[action_name].collect_preds()
            all_frames = set()



            # Unary entries are stored as (pred, prob, obj, frame).
            for prob, (pred, frame, obj) in unary_pred_scl:
                if pred in required_preds:
                    if not frame in required_pred_info:
                        required_pred_info[frame] = []
                    required_pred_info[frame].append((pred, prob, obj, frame))
                all_frames.add(frame)

            # object id -> required static predicates with their probability
            static_related_preds = {}
            for prob, (pred, obj) in static_pred_scl:
                if not obj in static_related_preds:
                    static_related_preds[obj] = []
                if pred in required_preds:
                    static_related_preds[obj].append((prob, pred))

            # Static predicates are repeated for every frame in which the
            # object appears.
            for frame_id, obj_ids in obj_info.items():
                for obj_id in obj_ids:
                    # for frame in all_frames:
                    pred_ls = static_related_preds[obj_id]
                    if not frame_id in required_pred_info:
                        required_pred_info[frame_id] = []

                    for prob, pred in pred_ls:
                        required_pred_info[frame_id].append((pred, prob, obj_id, frame_id))

            # Binary entries are stored as (pred, prob, from_obj, to_obj, frame).
            # The frame key is expected to exist already from the loops above.
            for prob, (pred, frame, from_obj, to_obj) in binary_pred_scl:
                if pred in required_preds:
                    if not (pred == 'in' or pred == 'touching'):
                        if not pred in interesting_dp:
                            interesting_dp[pred] = set()
                        interesting_dp[pred].add(scl_data_id)
                    required_pred_info[frame].append((pred, prob, from_obj, to_obj, frame))

            batched_video_pred_labels.append(required_pred_info)

        # Hand every video, with its boxes and collected predicate entries, to
        # annotate_pred_video together with the output path.
        for vid, (data_id, video, pred_labels, scl_template) in enumerate(zip(batched_ids, batched_reshaped_raw_videos, batched_video_pred_labels, batched_actions)):
            action_name = self.template2action[scl_template]
            out_vid_path = os.path.join(self.video_save_dir, data_id + '.mp4')
            annotate_pred_video(video, out_vid_path, bboxes_info[vid], pred_labels, action_name)

        return interesting_dp

    def baseline_metric_eval_batch(self, batch):
        """Evaluate the predicate model on one batch with the baseline metric.

        For every video the predicted unary, static and binary predicate
        probabilities are gathered and handed to ``self.baseline_accu``
        together with the video's first/last labelled frame ids; the
        per-video results are merged by ``combine_baseline_pred_dict_ls``.

        Args:
            batch: the 10-tuple produced by ``SSDataset.collate_fn``.

        Returns:
            Whatever ``combine_baseline_pred_dict_ls`` returns for the list of
            per-video ``baseline_accu`` results.
        """
        batched_ids, batched_actions, batched_placeholders, batched_videos, batched_bboxes, \
            batched_object_pairs, batched_object_ids, batched_video_splits, batched_object_names, \
              batched_reshaped_raw_videos = batch
        batch_size = len(batched_ids)
        # Every distinct placeholder and object name seen in this batch.
        consts = list(set([p for ps in batched_placeholders for p in ps] + [tp[2] for tp in batched_object_names]))

        # Known constants keep their id from const2cid (registered under the
        # original, upper-case and lower-case spelling).
        const_lookup = {}
        cids = []
        for k, v in const2cid.items():
            const_lookup[k] = v
            const_lookup[k.upper()] = v
            const_lookup[k.lower()] = v

            if k.lower() in consts:
                consts.remove(k.lower())

            cids.append(v)

        # Unknown constants get fresh ids counting down from below the smallest
        # known id.
        current_cid = min(cids) - 1
        for k in consts:
            const_lookup[k] = current_cid
            const_lookup[k.upper()] = current_cid
            const_lookup[k.lower()] = current_cid
            current_cid -= 1

        # Kept for their consistency checks (assertions / lookups); the
        # resulting tuples are not used further in this method.
        batched_object_tps = get_object_tps(batched_object_names, batched_object_ids, const_lookup, batch_size)
        construct_batched_scl_tps(batched_object_tps)

        batched_occurred_objs = []
        batched_start_end_oids = {}   # video -> frame -> list of object ids
        batched_start_end_fids = {}   # video -> {'start': first fid, 'end': last fid}
        batched_obj_name_map = {}     # video -> object id -> name

        for (vid, fid, oid), (_, _, oname) in zip(batched_object_ids, batched_object_names):
            if vid not in batched_start_end_oids:
                batched_start_end_oids[vid] = {}
                batched_start_end_fids[vid] = {'start': 10000, 'end': 0}
                batched_obj_name_map[vid] = {}

            batched_obj_name_map[vid][oid] = oname
            batched_start_end_oids[vid].setdefault(fid, []).append(oid)

            # Track min and max independently. The previous `if ... elif ...`
            # never lowered 'start' when the first labelled frame id was > 0,
            # leaving it at the 10000 sentinel.
            bounds = batched_start_end_fids[vid]
            if fid > bounds['end']:
                bounds['end'] = fid
            if fid < bounds['start']:
                bounds['start'] = fid

        for seo_info in batched_start_end_oids.values():
            all_oids = set()
            for ols in seo_info.values():
                all_oids.update(ols)
            batched_occurred_objs.append(list(all_oids))

        batched_unary_pred_prob, batched_binary_pred_prob, batched_static_pred_prob = \
            self.predicate_model(batched_videos, batched_bboxes, batched_object_ids, batched_object_pairs, batched_occurred_objs, batched_video_splits)

        # Static predicates: rows are consumed in the order in which the
        # occurred objects were passed to the model.
        batched_static_pred_scl = []
        current_obj_ct = 0
        for occured_objs in batched_occurred_objs:
            static_pred_scl = []
            for oid in occured_objs:
                probs = batched_static_pred_prob[current_obj_ct]
                for prob, pred in zip(probs, static_preds):
                    static_pred_scl.append((prob, tuple([pred, oid])))
                current_obj_ct += 1
            batched_static_pred_scl.append(static_pred_scl)

        # Unary predicates: one probability row per (video, frame, object).
        unary_by_video = {vid: [] for vid in range(batch_size)}
        for ((vid, fid, obj_id), unary_probs) in zip(batched_object_ids, batched_unary_pred_prob):
            for prob, pred in zip(unary_probs, unary_preds):
                unary_by_video[vid].append((prob, (pred, fid, obj_id)))
        batched_unary_pred_scl = [unary_by_video[vid] for vid in range(batch_size)]

        # Binary predicates: one probability row per (video, frame, object pair).
        binary_by_video = {vid: [] for vid in range(batch_size)}
        for ((vid, fid, (obj1, obj2)), binary_probs) in zip(batched_object_pairs, batched_binary_pred_prob):
            for prob, pred in zip(binary_probs, binary_preds):
                binary_by_video[vid].append((prob, (pred, fid, obj1, obj2)))
        batched_binary_pred_scl = [binary_by_video[vid] for vid in range(batch_size)]

        # video -> frame -> list of object ids. Only its per-video length
        # matters below (it participates in the zip).
        obj_info_by_video = {}
        for video_id, frame_id, obj_id in batched_object_ids:
            obj_info_by_video.setdefault(video_id, {}).setdefault(frame_id, []).append(obj_id)
        batched_obj_info = list(obj_info_by_video.values())

        result = []
        for scl_data_id, scl_template, static_pred_scl, unary_pred_scl, binary_pred_scl, obj_info, place_holder,\
             start_end_oids, start_end_fids, obj_name_map\
             in zip(batched_ids, batched_actions, batched_static_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl,
                    batched_obj_info, batched_placeholders, batched_start_end_oids.values(), batched_start_end_fids.values(),
                    batched_obj_name_map.values()):

            result.append(self.baseline_accu(unary_pred_scl, static_pred_scl, binary_pred_scl, scl_template, \
                start_end_fids, const_lookup, place_holder, obj_name_map, start_end_oids))

        new_result = combine_baseline_pred_dict_ls(result)
        return new_result

    def baseline_accu(self, unary_pred_scl, static_pred_scl, binary_pred_scl, scl_template, \
        start_end_fids, const_lookup, place_holder, obj_name_map, obj_info):
        """Compare one video's predicate probabilities with its action's constraints.

        The predicates required by the action's precondition are looked up in
        the facts of the first frame, those required by its effect in the facts
        of the last frame. For every required (predicate, polarity) the matching
        probabilities are merged with a noisy-or and handed to
        ``get_single_results``.

        Args:
            unary_pred_scl: iterable of ``(prob, (pred, frame, obj))``.
            static_pred_scl: iterable of ``(prob, (pred, obj))``.
            binary_pred_scl: iterable of ``(prob, (pred, frame, from_obj, to_obj))``.
            scl_template: key into ``self.template2action``.
            start_end_fids: mapping with the keys ``'start'`` and ``'end'``
                holding the first and last frame id of the video.
            const_lookup: mapping from constant name to constant id.
            place_holder: placeholder names of the action instance.
            obj_name_map: mapping from object id to object name.
            obj_info: mapping from frame id to the object ids seen in that frame.

        Returns:
            The dictionary produced by ``get_single_results``.

        Raises:
            KeyError: if a frame or object referenced by the facts is missing
                from the lookup tables built here.
        """
        action_name = self.template2action[scl_template]
        action = constraints.actions[action_name]
        required_preds = action.collect_preds()
        prec_required_tuples = action.precondition.collect_tuples()
        eff_required_tuples = action.effect.collect_tuples()

        # frame id -> list of facts relevant to this action.
        # Unary/static facts are 4-tuples, binary facts are 5-tuples.
        required_pred_info = {}

        for prob, (pred, frame, obj) in unary_pred_scl:
            if pred in required_preds:
                if not frame in required_pred_info:
                    required_pred_info[frame] = []
                required_pred_info[frame].append((pred, prob, obj, frame))

        # Static predicates carry no frame id, so they are grouped by object
        # first and then replicated into every frame the object appears in.
        static_related_preds = {}
        for prob, (pred, obj) in static_pred_scl:
            if not obj in static_related_preds:
                static_related_preds[obj] = []
            if pred in required_preds:
                static_related_preds[obj].append((prob, pred))

        for frame_id, obj_ids in obj_info.items():
            for obj_id in obj_ids:
                pred_ls = static_related_preds[obj_id]
                if not frame_id in required_pred_info:
                    required_pred_info[frame_id] = []

                for prob, pred in pred_ls:
                    required_pred_info[frame_id].append((pred, prob, obj_id, frame_id))

        for prob, (pred, frame, from_obj, to_obj) in binary_pred_scl:
            if pred in required_preds:
                required_pred_info[frame].append((pred, prob, from_obj, to_obj, frame))

        start_pred_info = required_pred_info[start_end_fids['start']]
        end_pred_info = required_pred_info[start_end_fids['end']]

        new_prec_required_preds = substitute_args(prec_required_tuples, place_holder, const_lookup)
        new_eff_required_preds = substitute_args(eff_required_tuples, place_holder, const_lookup)

        prec_pred_probs = obtain_pred_probs(new_prec_required_preds, start_pred_info, const_lookup, obj_name_map)
        eff_pred_probs = obtain_pred_probs(new_eff_required_preds, end_pred_info, const_lookup, obj_name_map)

        # Noisy-or per (predicate, polarity). The inner dictionary is built once
        # per predicate so that a predicate required with both polarities keeps
        # both entries instead of only the last one visited.
        comb_prec_pred_probs = {
            pred_name: {
                pos_neg: 1 - math.prod([(1 - p) for p in pred_prob_ls])
                for pos_neg, pred_prob_ls in pred_prob_info.items()
            }
            for pred_name, pred_prob_info in prec_pred_probs.items()
            if pred_prob_info
        }
        comb_eff_pred_probs = {
            pred_name: {
                pos_neg: 1 - math.prod([(1 - p) for p in pred_prob_ls])
                for pos_neg, pred_prob_ls in pred_prob_info.items()
            }
            for pred_name, pred_prob_info in eff_pred_probs.items()
            if pred_prob_info
        }

        return self.get_single_results(comb_prec_pred_probs, comb_eff_pred_probs)

    def get_single_results(self, comb_prec_pred_probs, comb_eff_pred_probs):
        """Turn combined predicate probabilities into ground-truth/prediction lists.

        Both arguments map a predicate name to ``{polarity: probability}``.
        The polarity is recorded as ground truth; the prediction equals the
        polarity when the probability is above 0.5 and its negation otherwise.
        Precondition entries are recorded before effect entries.

        Returns:
            ``{name: {'gt': [...], 'pred': [...]}}``.
        """
        output = {}
        for prob_table in (comb_prec_pred_probs, comb_eff_pred_probs):
            for name, accu_info in prob_table.items():
                record = output.setdefault(name, {'gt': [], 'pred': []})
                for pos_neg, pred_prob in accu_info.items():
                    record['gt'].append(pos_neg)
                    record['pred'].append(pos_neg if pred_prob > 0.5 else not pos_neg)

        return output

    def get_overall_comp(self, single_result):
        """Flatten a per-predicate result dict into two parallel lists.

        Returns:
            ``(all_gt, all_preds)``: the concatenation of every entry's
            ``'gt'`` and ``'pred'`` sequence, in dictionary order.
        """
        gathered_gt, gathered_pred = [], []
        for entry in single_result.values():
            gathered_gt.extend(entry['gt'])
            gathered_pred.extend(entry['pred'])
        return gathered_gt, gathered_pred

    def collect_conditions(self, action_scl, var2placeholder):
        """Unfinished stub: reads the first precondition entry and returns None.

        ``var2placeholder`` is accepted but not used yet.
        """
        # Only the lookup itself happens; the root is not processed further.
        _precondition_root = action_scl['precondition'][0]
        return None

    def forward_stats(self, batch):
        """Count the Scallop input facts generated for every video in a batch.

        Runs the same fact construction as ``forward`` (predicate model
        included) but skips reasoning and the loss; the baseline comparison is
        still evaluated although its result is not part of the return value.

        Args:
            batch: 10-element batch; the last element is ignored here.

        Returns:
            A list with one integer per video: the summed length of all fact
            collections in that video's ``scl_input_facts`` dictionary.
        """
        correct_ls = []
        loss_ls = []
        missing_name = []

        batched_ids, batched_actions, batched_placeholders, batched_videos, batched_bboxes, \
            batched_object_pairs, batched_object_ids, batched_video_splits, batched_object_names, _ = batch
        batch_size = len(batched_ids)
        # Every distinct placeholder and object name occurring in the batch.
        consts = list(set([p for ps in batched_placeholders for p in ps] + [tp[2] for tp in batched_object_names]))

        # Known constants keep their id from const2cid (registered under the
        # original, upper and lower case spelling) and are dropped from `consts`.
        const_lookup = {}
        cids = []
        for k, v in const2cid.items():
            const_lookup[k] = v
            const_lookup[k.upper()] = v
            const_lookup[k.lower()] = v

            if k.lower() in consts:
                consts.remove(k.lower())

            cids.append(v)

        # Remaining names get fresh ids counting down from below the smallest known id.
        current_cid = min(cids) - 1
        for k in consts:
            const_lookup[k] = current_cid
            const_lookup[k.upper()] = current_cid
            const_lookup[k.lower()] = current_cid
            current_cid -= 1

        batched_object_tps = get_object_tps(batched_object_names, batched_object_ids, const_lookup, batch_size)
        batched_scl_tps = construct_batched_scl_tps(batched_object_tps)
        batched_occurred_objs = []
        occured_objs = []
        start_end_fids = {}
        start_end_oids = {}
        obj_name_map = {}
        last_vid = 0
        start_vid = 10000

        # Per video: first/last frame id, object ids per frame, object id -> name,
        # and the set of object ids that occur anywhere in the video.
        for (vid, fid, oid), (_, _, oname) in zip(batched_object_ids, batched_object_names):
            if not vid in start_end_fids:
                start_end_fids[vid] = {}
                start_end_oids[vid] = {}
                start_end_fids[vid]['start'] = 10000
                start_end_fids[vid]['end'] = 0
                obj_name_map[vid] = {}

            obj_name_map[vid][oid] = oname

            if not fid in start_end_oids[vid]:
                start_end_oids[vid][fid] = []

            if fid > start_end_fids[vid]['end']:
                start_end_fids[vid]['end'] = fid
            if fid < start_end_fids[vid]['start']:
                start_end_fids[vid]['start'] = fid

            # The pending objects are flushed only when the video index grows,
            # i.e. this relies on entries being grouped by increasing vid
            # (starting at 0) — TODO confirm against the collate function.
            if vid > last_vid:
                batched_occurred_objs.append(list(set(occured_objs)))
                occured_objs = []
                last_vid = vid

            start_end_oids[vid][fid].append(oid)
            occured_objs.append(oid)

        if not len(occured_objs) == 0:
            batched_occurred_objs.append(list(set(occured_objs)))

        batched_unary_pred_prob, batched_binary_pred_prob, batched_static_pred_prob = \
            self.predicate_model(batched_videos, batched_bboxes, batched_object_ids, batched_object_pairs, batched_occurred_objs, batched_video_splits)

        # TODO: support sampling
        # Process static predicates
        # Static probability rows are consumed in the order videos/objects are
        # listed in batched_occurred_objs.
        batched_static_pred_scl = []
        current_obj_ct = 0
        for occured_objs in batched_occurred_objs:
            static_pred_scl = []
            for oid in occured_objs:
                probs = batched_static_pred_prob[current_obj_ct]
                for prob, pred in zip(probs, static_preds):
                    static_pred_scl.append((prob, tuple([pred, oid])))
                current_obj_ct += 1
            batched_static_pred_scl.append(static_pred_scl)

        # Process unary predicates
        current_vid = 0
        batched_unary_pred_scl = {}

        for vid in range(batch_size):
            batched_unary_pred_scl[vid] = []

        for ((vid, fid, obj_id), unary_probs) in zip(batched_object_ids, batched_unary_pred_prob):
            for prob, pred in zip(unary_probs, unary_preds):
                batched_unary_pred_scl[vid].append((prob, (pred, fid, obj_id)))
        batched_unary_pred_scl = list([batched_unary_pred_scl[vid] for vid in range(batch_size)])


        # Process binary predicates
        batched_binary_pred_scl = {}
        binary_pred_scl = []

        for vid in range(batch_size):
            batched_binary_pred_scl[vid] = []

        for ((vid, fid, (obj1, obj2)), binary_probs) in zip(batched_object_pairs, batched_binary_pred_prob):
            for prob, pred in zip(binary_probs, binary_preds):
                batched_binary_pred_scl[vid].append((prob, (pred, fid, obj1, obj2)))
        batched_binary_pred_scl = list([batched_binary_pred_scl[vid] for vid in range(batch_size)])

        batched_outputs =[]
        batched_ys = []
        batched_formatted_ys = []
        batched_scl_input_facts = []
        all_single_gt = []
        all_single_pred = []


        for vid, (data_id, template, scl_tp, static_pred_tp, unary_pred_tp, binary_pred_tp, place_holder) \
            in enumerate(zip(batched_ids, batched_actions, batched_scl_tps, batched_static_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_placeholders)):

            # Give an ID to all required placeholders and object names
            scl_input_facts = {}
            place_holder_ct = len(place_holder)
            gt_action = template2action[template]

            scl_input_facts.update(scl_tp)
            scl_input_facts['sg_static_atom'] = (static_pred_tp)
            scl_input_facts['sg_unary_atom'] = (unary_pred_tp)
            scl_input_facts['sg_binary_atom'] = (binary_pred_tp)
            scl_input_facts['num_variables'] = [tuple([place_holder_ct])]
            # The comprehension's `vid` is local to the comprehension and does
            # not rebind the enumerate index used below.
            scl_input_facts['variable_name'] = [(vid, const_lookup[p]) for vid, p in zip(var2vid.values(), place_holder)]
            action_scl = action2scl[gt_action]
            scl_input_facts.update(action_scl)

            result = self.baseline_accu(unary_pred_tp, static_pred_tp, binary_pred_tp, template, \
                start_end_fids[vid], const_lookup, place_holder, obj_name_map[vid], start_end_oids[vid])
            gt, pred = self.get_overall_comp(result)
            batched_scl_input_facts.append(scl_input_facts)
            all_single_gt += gt
            all_single_pred += pred

        # batched_ys = [1] * batch_size
        # batched_formatted_ys = [[1]] * batch_size

        # formatted_batched_scl_input_facts = process_batched_facts(batched_scl_input_facts)
        # output = self.reason(**formatted_batched_scl_input_facts)
        
        # One count per video: total number of entries over all fact relations.
        fact_ct = []
        for batched_scl_input_fact in batched_scl_input_facts:
            ct = 0
            for v in batched_scl_input_fact.values():
                ct += len(v)
            fact_ct.append(ct)

        return fact_ct

    def forward(self, batch):
        """Run the non-contrastive forward pass on one batch.

        Every video is paired only with its own ground-truth action
        (``batched_ys`` is all ones). The predicate model's probabilities are
        packed into Scallop input facts, ``self.reason`` is evaluated on them
        and ``self.loss`` is applied to its output.

        Args:
            batch: 10-element batch; the last element is ignored here.

        Returns:
            ``(loss, correct_ls)`` where ``loss`` is whatever ``self.loss``
            returns and ``correct_ls`` holds 1/0 per baseline comparison
            (``baseline_accu`` prediction equal to its ground truth or not).
        """
        correct_ls = []
        loss_ls = []
        missing_name = []

        batched_ids, batched_actions, batched_placeholders, batched_videos, batched_bboxes, \
            batched_object_pairs, batched_object_ids, batched_video_splits, batched_object_names, _ = batch
        batch_size = len(batched_ids)
        # Every distinct placeholder and object name occurring in the batch.
        consts = list(set([p for ps in batched_placeholders for p in ps] + [tp[2] for tp in batched_object_names]))

        # Known constants keep their id from const2cid (registered under the
        # original, upper and lower case spelling) and are dropped from `consts`.
        const_lookup = {}
        cids = []
        for k, v in const2cid.items():
            const_lookup[k] = v
            const_lookup[k.upper()] = v
            const_lookup[k.lower()] = v

            if k.lower() in consts:
                consts.remove(k.lower())

            cids.append(v)

        # Remaining names get fresh ids counting down from below the smallest known id.
        current_cid = min(cids) - 1
        for k in consts:
            const_lookup[k] = current_cid
            const_lookup[k.upper()] = current_cid
            const_lookup[k.lower()] = current_cid
            current_cid -= 1

        batched_object_tps = get_object_tps(batched_object_names, batched_object_ids, const_lookup, batch_size)
        batched_scl_tps = construct_batched_scl_tps(batched_object_tps)
        batched_occurred_objs = []
        occured_objs = []
        start_end_fids = {}
        start_end_oids = {}
        obj_name_map = {}
        last_vid = 0
        start_vid = 10000

        # Per video: first/last frame id, object ids per frame, object id -> name,
        # and the set of object ids that occur anywhere in the video.
        for (vid, fid, oid), (_, _, oname) in zip(batched_object_ids, batched_object_names):
            if not vid in start_end_fids:
                start_end_fids[vid] = {}
                start_end_oids[vid] = {}
                start_end_fids[vid]['start'] = 10000
                start_end_fids[vid]['end'] = 0
                obj_name_map[vid] = {}

            obj_name_map[vid][oid] = oname

            if not fid in start_end_oids[vid]:
                start_end_oids[vid][fid] = []

            if fid > start_end_fids[vid]['end']:
                start_end_fids[vid]['end'] = fid
            if fid < start_end_fids[vid]['start']:
                start_end_fids[vid]['start'] = fid

            # The pending objects are flushed only when the video index grows,
            # i.e. this relies on entries being grouped by increasing vid
            # (starting at 0) — TODO confirm against the collate function.
            if vid > last_vid:
                batched_occurred_objs.append(list(set(occured_objs)))
                occured_objs = []
                last_vid = vid

            start_end_oids[vid][fid].append(oid)
            occured_objs.append(oid)

        if not len(occured_objs) == 0:
            batched_occurred_objs.append(list(set(occured_objs)))

        batched_unary_pred_prob, batched_binary_pred_prob, batched_static_pred_prob = \
            self.predicate_model(batched_videos, batched_bboxes, batched_object_ids, batched_object_pairs, batched_occurred_objs, batched_video_splits)

        # TODO: support sampling
        # Process static predicates
        # Static probability rows are consumed in the order videos/objects are
        # listed in batched_occurred_objs.
        batched_static_pred_scl = []
        current_obj_ct = 0
        for occured_objs in batched_occurred_objs:
            static_pred_scl = []
            for oid in occured_objs:
                probs = batched_static_pred_prob[current_obj_ct]
                for prob, pred in zip(probs, static_preds):
                    static_pred_scl.append((prob, tuple([pred, oid])))
                current_obj_ct += 1
            batched_static_pred_scl.append(static_pred_scl)

        # Process unary predicates
        current_vid = 0
        batched_unary_pred_scl = {}

        for vid in range(batch_size):
            batched_unary_pred_scl[vid] = []

        for ((vid, fid, obj_id), unary_probs) in zip(batched_object_ids, batched_unary_pred_prob):
            for prob, pred in zip(unary_probs, unary_preds):
                batched_unary_pred_scl[vid].append((prob, (pred, fid, obj_id)))
        batched_unary_pred_scl = list([batched_unary_pred_scl[vid] for vid in range(batch_size)])


        # Process binary predicates
        batched_binary_pred_scl = {}
        binary_pred_scl = []

        for vid in range(batch_size):
            batched_binary_pred_scl[vid] = []

        for ((vid, fid, (obj1, obj2)), binary_probs) in zip(batched_object_pairs, batched_binary_pred_prob):
            for prob, pred in zip(binary_probs, binary_preds):
                batched_binary_pred_scl[vid].append((prob, (pred, fid, obj1, obj2)))
        batched_binary_pred_scl = list([batched_binary_pred_scl[vid] for vid in range(batch_size)])

        batched_outputs =[]
        batched_ys = []
        batched_formatted_ys = []
        batched_scl_input_facts = []
        all_single_gt = []
        all_single_pred = []


        for vid, (data_id, template, scl_tp, static_pred_tp, unary_pred_tp, binary_pred_tp, place_holder) \
            in enumerate(zip(batched_ids, batched_actions, batched_scl_tps, batched_static_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_placeholders)):

            # Give an ID to all required placeholders and object names
            scl_input_facts = {}
            place_holder_ct = len(place_holder)
            gt_action = template2action[template]

            scl_input_facts.update(scl_tp)
            scl_input_facts['sg_static_atom'] = (static_pred_tp)
            scl_input_facts['sg_unary_atom'] = (unary_pred_tp)
            scl_input_facts['sg_binary_atom'] = (binary_pred_tp)
            scl_input_facts['num_variables'] = [tuple([place_holder_ct])]
            # The comprehension's `vid` is local to the comprehension and does
            # not rebind the enumerate index used below.
            scl_input_facts['variable_name'] = [(vid, const_lookup[p]) for vid, p in zip(var2vid.values(), place_holder)]
            action_scl = action2scl[gt_action]
            scl_input_facts.update(action_scl)

            # Baseline comparison on the raw predicate probabilities; it feeds
            # correct_ls only and does not influence the loss.
            result = self.baseline_accu(unary_pred_tp, static_pred_tp, binary_pred_tp, template, \
                start_end_fids[vid], const_lookup, place_holder, obj_name_map[vid], start_end_oids[vid])
            gt, pred = self.get_overall_comp(result)
            batched_scl_input_facts.append(scl_input_facts)
            all_single_gt += gt
            all_single_pred += pred

        # Every program is the ground-truth action of its video, hence target 1.
        batched_ys = [1] * batch_size
        batched_formatted_ys = [[1]] * batch_size

        formatted_batched_scl_input_facts = process_batched_facts(batched_scl_input_facts)
        output = self.reason(**formatted_batched_scl_input_facts)
        # With violations enabled the reasoner output is indexed by relation
        # name; otherwise it is the (tuples, probabilities) pair directly.
        if self.with_violation:
            tps, probs = output['constraint']
            violations = output['violation']
        else:
            tps, probs = output
            violations = [0] * batch_size

        loss = self.loss(tps, probs, violations, batched_ys)
        correct_ls = [1 if p == g else 0 for p, g in zip(all_single_pred, all_single_gt)]

        return loss, correct_ls

    def forward_contrast(self, batch):
        """Run the contrastive forward pass on one batch.

        For every sample's action (the "anchor"), each video of the batch whose
        placeholder count matches and whose ground-truth action is either the
        anchor action or one of its ``self.paired_actions`` is turned into a
        Scallop program for the anchor action. The target is 1 when the video's
        ground-truth action equals the anchor action and 0 otherwise.

        Args:
            batch: 10-element batch; the last element is ignored here.

        Returns:
            ``(loss, correct)`` as produced by ``self.loss`` and
            ``self.correct``.
        """
        batched_ids, batched_actions, batched_placeholders, batched_videos, batched_bboxes, \
            batched_object_pairs, batched_object_ids, batched_video_splits, batched_object_names, _ = batch
        batch_size = len(batched_ids)
        # Every distinct placeholder and object name occurring in the batch.
        consts = list(set([p for ps in batched_placeholders for p in ps] + [tp[2] for tp in batched_object_names]))

        # Known constants keep their id from const2cid (registered under the
        # original, upper and lower case spelling) and are dropped from `consts`.
        const_lookup = {}
        cids = []
        for k, v in const2cid.items():
            const_lookup[k] = v
            const_lookup[k.upper()] = v
            const_lookup[k.lower()] = v

            if k.lower() in consts:
                consts.remove(k.lower())

            cids.append(v)

        # Remaining names get fresh ids counting down from below the smallest known id.
        current_cid = min(cids) - 1
        for k in consts:
            const_lookup[k] = current_cid
            const_lookup[k.upper()] = current_cid
            const_lookup[k.lower()] = current_cid
            current_cid -= 1

        batched_object_tps = get_object_tps(batched_object_names, batched_object_ids, const_lookup, batch_size)
        batched_scl_tps = construct_batched_scl_tps(batched_object_tps)
        batched_occurred_objs = []
        occured_objs = []
        last_vid = 0

        # Distinct object ids per video. The pending ids are flushed only when
        # the video index grows, i.e. this relies on entries being grouped by
        # increasing vid (starting at 0) — TODO confirm against the collate function.
        for vid, fid, oid in batched_object_ids:
            if vid > last_vid:
                batched_occurred_objs.append(list(set(occured_objs)))
                occured_objs = []
                last_vid = vid
            occured_objs.append(oid)
        if not len(occured_objs) == 0:
            batched_occurred_objs.append(list(set(occured_objs)))

        batched_unary_pred_prob, batched_binary_pred_prob, batched_static_pred_prob = \
            self.predicate_model(batched_videos, batched_bboxes, batched_object_ids, batched_object_pairs, batched_occurred_objs, batched_video_splits)

        # TODO: support sampling
        # Process static predicates
        # Static probability rows are consumed in the order videos/objects are
        # listed in batched_occurred_objs.
        batched_static_pred_scl = []
        current_obj_ct = 0
        for occured_objs in batched_occurred_objs:
            static_pred_scl = []
            for oid in occured_objs:
                probs = batched_static_pred_prob[current_obj_ct]
                for prob, pred in zip(probs, static_preds):
                    static_pred_scl.append((prob, tuple([pred, oid])))
                current_obj_ct += 1
            batched_static_pred_scl.append(static_pred_scl)

        # Process unary predicates
        batched_unary_pred_scl = {}

        for vid in range(batch_size):
            batched_unary_pred_scl[vid] = []

        for ((vid, fid, obj_id), unary_probs) in zip(batched_object_ids, batched_unary_pred_prob):
            for prob, pred in zip(unary_probs, unary_preds):
                batched_unary_pred_scl[vid].append((prob, (pred, fid, obj_id)))
        batched_unary_pred_scl = list([batched_unary_pred_scl[vid] for vid in range(batch_size)])

        # Process binary predicates
        batched_binary_pred_scl = {}

        for vid in range(batch_size):
            batched_binary_pred_scl[vid] = []

        for ((vid, fid, (obj1, obj2)), binary_probs) in zip(batched_object_pairs, batched_binary_pred_prob):
            for prob, pred in zip(binary_probs, binary_preds):
                batched_binary_pred_scl[vid].append((prob, (pred, fid, obj1, obj2)))
        batched_binary_pred_scl = list([batched_binary_pred_scl[vid] for vid in range(batch_size)])

        batched_ys = []
        batched_formatted_ys = []
        batched_scl_input_facts = []

        for scl_data_id, scl_template in zip(batched_ids, batched_actions):
            # Anchor action: every compatible video is checked against it.
            action = self.template2action[scl_template]
            current_ys = []

            for data_id, template, scl_tp, static_pred_tp, unary_pred_tp, binary_pred_tp, place_holder \
                in zip(batched_ids, batched_actions, batched_scl_tps, batched_static_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_placeholders):

                # Give an ID to all required placeholders and object names
                scl_input_facts = {}

                # Not the same placeholder count, cannot process
                action_args_ct = action_arg_num[action]
                place_holder_ct = len(place_holder)
                if not action_args_ct == place_holder_ct:
                    continue

                gt_action = template2action[template]

                # Only the anchor action itself and its paired actions are contrasted.
                if not gt_action == action and action not in self.paired_actions[gt_action]:
                    continue

                scl_input_facts.update(scl_tp)
                scl_input_facts['sg_static_atom'] = (static_pred_tp)
                scl_input_facts['sg_unary_atom'] = (unary_pred_tp)
                scl_input_facts['sg_binary_atom'] = (binary_pred_tp)
                scl_input_facts['num_variables'] = [tuple([place_holder_ct])]
                scl_input_facts['variable_name'] = [(vid, const_lookup[p]) for vid, p in zip(var2vid.values(), place_holder)]
                action_scl = action2scl[action]
                scl_input_facts.update(action_scl)

                batched_scl_input_facts.append(scl_input_facts)

                if gt_action == action:
                    y = 1
                else:
                    y = 0

                batched_ys.append(y)
                current_ys.append(y)

            if sum(current_ys) == 0:
                print('Expect at least one correct outcome')
                print(scl_data_id)

            batched_formatted_ys.append(current_ys)

        formatted_batched_scl_input_facts = process_batched_facts(batched_scl_input_facts)
        output = self.reason(**formatted_batched_scl_input_facts)
        if self.with_violation:
            tps, probs = output['constraint']
            violations = output['violation']
        else:
            tps, probs = output
            # NOTE(review): the number of programs here is len(batched_ys),
            # which can differ from batch_size in contrast mode — confirm that
            # self.loss tolerates this length.
            violations = [0] * batch_size

        # Debug aid: when the reasoner yields nothing, dump the last program that
        # was built. The original code returned `loss, correct` from this branch
        # before either name was assigned (UnboundLocalError); the dump is kept
        # and control now falls through to the regular computation below.
        if len(tps) == 0 and scl_data_id == data_id:
            to_scl_file(self.common_scl, [], scl_input_facts, os.path.join(self.save_scl_dir, data_id + '.scl'))
            print("wrong prog")

        loss = self.loss(tps, probs, violations, batched_ys)
        correct = self.correct(tps, probs, batched_formatted_ys)

        return loss, correct

    def train_epoch(self, n):
        """Train the predicate model for one pass over ``self.train_loader``.

        Each batch goes through ``forward_contrast`` or ``forward`` depending on
        ``self.use_contrast``; the summed per-sample losses are back-propagated
        and the optimizer is stepped. The progress bar shows the running mean
        of all per-sample losses seen so far.

        Args:
            n: epoch number, used only in the progress description.

        Returns:
            The running mean loss after the last batch.
        """
        self.predicate_model.train()
        seen_losses = []

        progress = tqdm(self.train_loader)
        for dp_list in progress:
            self.optimizer.zero_grad()

            step = self.forward_contrast if self.use_contrast else self.forward
            loss_ls, _ = step(dp_list)

            batch_loss = sum(loss_ls)
            batch_loss.backward(retain_graph=True)
            self.optimizer.step()

            seen_losses.extend([single.item() for single in loss_ls])
            avg_loss = sum(seen_losses) / len(seen_losses)
            progress.set_description(f'[Train {n}] Loss: {avg_loss}')
            # The dataset is reshuffled after every batch.
            self.train_loader.dataset.shuffle()

        return avg_loss

    def test_epoch(self, n):
        """Evaluate on ``self.test_loader`` without gradients and save checkpoints.

        Args:
            n: epoch number, used only in the progress description.

        Returns:
            The mean of all per-sample losses over the test loader.
        """
        self.predicate_model.eval()
        all_losses = []

        with torch.no_grad():
            iterator = tqdm(self.test_loader)
            for dp_list in iterator:
                if self.use_contrast:
                    loss_ls, _ = self.forward_contrast(dp_list)
                else:
                    loss_ls, _ = self.forward(dp_list)

                # Plain ints are kept as they are; anything else is expected to
                # expose .item().
                all_losses += [loss if isinstance(loss, int) else loss.item() for loss in loss_ls]
                avg_loss = sum(all_losses)/len(all_losses)
                iterator.set_description(f'[Test {n}] Loss: {avg_loss}')

        # Save model
        # NOTE(review): the "best" checkpoint is written when the mean *loss*
        # exceeds self.max_accu, which looks like an accuracy criterion applied
        # to a loss — confirm the intended metric and direction.
        if avg_loss > self.max_accu and self.save_model:
            self.max_accu = avg_loss
            torch.save(self.predicate_model, os.path.join(self.model_dir, f"{self.model_name}.best.model"))
        if self.save_model:
            torch.save(self.predicate_model, os.path.join(self.model_dir, f"{self.model_name}.latest.model"))

        return avg_loss
    
    def test_stats(self):
        """Collect the per-video fact counts of ``forward_stats`` over the test set.

        Returns:
            A flat list with the concatenated results of ``forward_stats`` for
            every batch of ``self.test_loader``. Earlier versions gathered the
            same list but dropped it and returned ``None``.
        """
        self.predicate_model.eval()
        all_stats = []

        with torch.no_grad():
            for dp_list in tqdm(self.test_loader):
                all_stats += self.forward_stats(dp_list)

        return all_stats

    def test(self):
        """Run a single evaluation pass, reported as epoch 0; the loss is not returned."""
        self.test_epoch(0)

    def label(self, n):
        """Run ``label_batch`` over the first batches of the test loader.

        The per-predicate id sets returned by ``label_batch`` are merged in a
        local dictionary; nothing is returned. Iteration stops after the batch
        whose index is greater than ``n`` has been processed.
        """
        self.predicate_model.eval()
        merged = {}

        with torch.no_grad():
            for batch_idx, dp_list in enumerate(tqdm(self.test_loader)):
                found = self.label_batch(dp_list)
                for pred, dp_ids in found.items():
                    if pred in merged:
                        merged[pred] = merged[pred].union(dp_ids)
                    else:
                        merged[pred] = dp_ids

                if batch_idx > n:
                    break
        return None

    def label_by_id(self, did, phase="train"):
        """Run ``label_batch`` on the single datapoint with id ``did``.

        The datapoint is fetched from the training dataset when ``phase`` is
        ``"train"`` and from the test dataset for any other value.
        """
        self.predicate_model.eval()
        loader = self.train_loader if phase == "train" else self.test_loader
        data = loader.dataset.get_item_by_id(did)
        with torch.no_grad():
            self.label_batch(data)

        return None

    def baseline_eval(self):
        """Evaluate the baseline metric on the test set and print the statistics.

        Returns:
            The statistics dictionary computed by ``obtain_stats`` from the
            merged per-batch results.
        """
        self.predicate_model.eval()
        with torch.no_grad():
            per_batch = [
                self.baseline_metric_eval_batch(dp_list)
                for dp_list in tqdm(self.test_loader)
            ]

        stats = obtain_stats(combine_baseline_pred_dict_ls(per_batch))
        pretty_print(stats)
        return stats

    def train(self, num_epochs):
        """Alternate one training and one evaluation pass for epochs 1..num_epochs."""
        for finished in range(num_epochs):
            epoch = finished + 1
            self.train_epoch(epoch)
            self.test_epoch(epoch)

    def save_scl_file(self, datapoint, object_tps, current_constraint):
        """Write the Scallop program of ``datapoint`` to ``<save_scl_dir>/<id>.scl``.

        The program text is always generated; it is written only when
        ``self.save_scl_dir`` is not ``None``.
        """
        program_text = obtain_scl_file(object_tps, current_constraint, self.common_scl)
        target_name = datapoint['id'] + '.scl'
        if self.save_scl_dir is None:
            return

        with open(os.path.join(self.save_scl_dir, target_name), 'w') as scl_file:
            scl_file.write(program_text)

def combine_baseline_pred_dict_ls(pred_dict_ls):
    """Merge several ``{name: {'gt': [...], 'pred': [...]}}`` dictionaries.

    The ``'gt'`` and ``'pred'`` sequences of equal predicate names are
    concatenated in input order. The inputs are left untouched.
    """
    merged = {}
    for single_dict in pred_dict_ls:
        for pred_name, pred_info in single_dict.items():
            bucket = merged.setdefault(pred_name, {'gt': [], 'pred': []})
            bucket['gt'].extend(pred_info['gt'])
            bucket['pred'].extend(pred_info['pred'])
    return merged


def obtain_stats(pred_dict):
    """Compute recall, precision, F1 and sample count for every predicate.

    Args:
        pred_dict: ``{name: {'gt': [...], 'pred': [...]}}``.

    Returns:
        ``{name: {'recall', 'precision', 'f1', 'count'}}`` for every predicate
        except ``"touching"``.
    """
    stats = {}
    for pred_name, pred_info in pred_dict.items():
        # By manually inspecting the touching videos, 80% of the heuristic is incorrect
        # the first frame is not reflective of the touching preconditions
        if pred_name == "touching":
            continue

        entry = stats.setdefault(pred_name, {})
        entry['recall'] = metrics.recall_score(pred_info['gt'], pred_info['pred'])
        entry['precision'] = metrics.precision_score(pred_info['gt'], pred_info['pred'])
        entry['f1'] = metrics.f1_score(pred_info['gt'], pred_info['pred'])
        entry['count'] = len(pred_info['gt'])

    return stats

def pretty_print(stats):
    """Print one line per predicate: name, share of all samples, precision, recall, F1."""
    total_number = sum(info['count'] for info in stats.values())

    for name, info in stats.items():
        share = info['count'] / total_number
        print(f"{name}, {share}, {info['precision']}, {info['recall']},{info['f1']}")

def pair_sim_actions(action2template, constraints, threshold = 0.6):
    """Pair every action with the other actions that use a similar predicate set.

    Similarity is the size of the intersection of the two predicate sets
    divided by the size of their union; a pair is kept when it is strictly
    greater than ``threshold``.

    Returns:
        ``{action: [similar actions]}``; actions without a partner are absent.
    """
    preds_of = {}
    for action_name, prog in constraints.actions.items():
        preds_of[action_name] = set(prog.collect_preds())

    candidates = list(action2template.keys())
    sim_actions = {}
    for first in candidates:
        for second in candidates:
            if first != second:
                first_preds = preds_of[first]
                second_preds = preds_of[second]
                shared = first_preds & second_preds
                combined = first_preds | second_preds
                if len(shared) / len(combined) > threshold:
                    sim_actions.setdefault(first, []).append(second)

    return sim_actions

def substitute_args(required_preds, place_holder, const_lookup):
    """Replace the arguments of required predicates with constant ids.

    An argument containing ``'?'`` is a variable: its name (without the first
    character) is mapped through ``var2vid`` to a one-based slot; the
    placeholder in that slot is looked up in ``const_lookup``, and ``"ANY"`` is
    used when the slot lies beyond ``place_holder``. Any other argument is
    looked up in ``const_lookup`` directly.

    Returns:
        A list of ``(pos_neg, pred_name, replaced_args)`` tuples.
    """
    substituted = []
    for pos_neg, pred_name, pred_args in required_preds:
        resolved = []
        for arg in pred_args:
            if '?' not in arg:
                resolved.append(const_lookup[arg])
                continue

            slot = var2vid[arg[1:]] - 1
            if slot >= len(place_holder):
                resolved.append("ANY")
            else:
                resolved.append(const_lookup[place_holder[slot]])
        substituted.append((pos_neg, pred_name, resolved))
    return substituted

def obtain_pred_probs(new_prec_required_preds, start_pred_info, const_lookup, obj_name_map):
    """Collect, per required predicate and polarity, the matching probabilities.

    Args:
        new_prec_required_preds: iterable of ``(pos_neg, pred_name, args)``
            where ``args`` holds the constant ids the predicate must apply to.
        start_pred_info: facts of a single frame. A 5-tuple
            ``(pred, prob, from_obj, to_obj, frame)`` is a binary fact, any
            other tuple is unpacked as ``(pred, prob, obj, frame)``.
        const_lookup: mapping from object name to constant id.
        obj_name_map: mapping from object id to object name.

    Returns:
        ``{pred_name: {pos_neg: [p, ...]}}``. For a positive requirement ``p``
        is the fact's probability, for a negative one ``1 - probability``.
        Every matching fact contributes one entry.
    """
    prec_satis_pos = {}
    for pos_neg, pred_name, args in new_prec_required_preds:
        for tp in start_pred_info:
            # binary
            if len(tp) == 5:
                context_pred_name, context_prob, from_obj_id, to_obj_id, _ = tp
                from_obj_name = const_lookup[obj_name_map[from_obj_id]]
                to_obj_name = const_lookup[obj_name_map[to_obj_id]]
                matched = (pred_name == context_pred_name
                           and from_obj_name == args[0]
                           and to_obj_name == args[1])
            else:
                context_pred_name, context_prob, obj_id, _ = tp
                tp_name = const_lookup[obj_name_map[obj_id]]
                matched = pred_name == context_pred_name and tp_name == args[0]

            if not matched:
                continue

            # The polarity list lives under the predicate's own entry. The
            # earlier membership test looked at the outer dictionary (keyed by
            # predicate name), so the list was re-created on every match and
            # only the last probability survived.
            satisfied = prec_satis_pos.setdefault(pred_name, {}).setdefault(pos_neg, [])
            satisfied.append(context_prob if pos_neg else 1 - context_prob)

    return prec_satis_pos

def pair_non_eq_actions(action2template, constraints):
    """Pair each action with the actions of other categories that take as many variables.

    An action's category is the part of its name before the first ``'-'``.
    Its variable count is the number of entries in
    ``constraints.actions[name].param[0]['name']`` that contain ``'?'``.

    Args:
        action2template: mapping whose keys are the action names to pair up.
        constraints: object exposing ``actions``, a mapping from action name to
            a program with ``collect_preds()`` and ``param``.

    Returns:
        ``{action: [other_action, ...]}`` in the key order of
        ``action2template``; actions without any partner are omitted.
    """
    # The collected predicate sets are not consulted below; the call is kept
    # because it is not visible here whether collect_preds() has side effects.
    action_pred_map = {}
    for action_name, prog in constraints.actions.items():
        action_pred_map[action_name] = set(prog.collect_preds())

    placeholder_counts = {}
    for action_name, prog in constraints.actions.items():
        placeholder_counts[action_name] = sum('?' in arg for arg in prog.param[0]['name'])

    candidates = list(action2template.keys())

    category_of = {}
    for candidate in candidates:
        category_of[candidate] = candidate.split('-')[0]

    # Category -> members index; built but not part of the returned value.
    cat2actions = {}
    for member, category in category_of.items():
        cat2actions.setdefault(category, []).append(member)

    diff_actions = {}
    for first in candidates:
        partners = [
            second
            for second in candidates
            if second != first
            and category_of[first] != category_of[second]
            and placeholder_counts[first] == placeholder_counts[second]
        ]
        if partners:
            diff_actions[first] = partners

    return diff_actions

if __name__ == "__main__":
    # Script entry point: parse the command line, resolve data/model paths,
    # load the PDDL constraints and template mapping, build the data loaders
    # and the Trainer, then run `trainer.test_stats()`.

    # Default locations, resolved relative to the directory holding this file
    # (abspath collapses the "<file>/.." component).
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../data'))
    model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../model'))
    video_save_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../data/pred_video'))

    parser = ArgumentParser("sth_sth")
    parser.add_argument("--phase", type=str, default='train')
    parser.add_argument("--n-epochs", type=int, default=100)
    parser.add_argument("--load_model", action='store_true')
    parser.add_argument("--save_model",  action='store_true')
    parser.add_argument("--video_save_dir", type=str, default=video_save_dir)
    parser.add_argument("--model_type", type=str, default="contrast")

    # setup question path
    parser.add_argument("--train_num", type=int, default=100)
    parser.add_argument("--val_num", type=int, default=100)
    parser.add_argument("--training_percentage", type=int, default=100)
    parser.add_argument("--test_percentage", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=3)
    parser.add_argument("--learning-rate", type=float, default=0.0001)
    # NOTE(review): parsed as float although the default is the integer 64 —
    # confirm what Trainer/PredicateModel expect before changing the type.
    parser.add_argument("--latent-dim", type=float, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    parser.add_argument("--with-violation",  action='store_true')
    parser.add_argument("--violation-weight", type=float, default=0.05)
    parser.add_argument("--use-contrast",  action='store_true')

    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--model-path", type=str, default=model_dir)
    parser.add_argument("--data-dir", type=str, default=data_dir)
    # NOTE(review): --use-cuda and --gpu are parsed but never read in this
    # block; the device below depends only on torch.cuda.is_available().
    parser.add_argument("--use-cuda", action="store_true")
    parser.add_argument("--gpu", type=int, default=0)

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # In the test phase a fixed checkpoint name is used per model type;
    # otherwise the name is derived from the hyperparameters.
    if args.model_type == "plain" and args.phase == "test":
        name = "sth_contrast_False_wv_False_100_seed_1234_batch_size_8_lr_0.0001_prov_difftopkproofs_tpk_3_vw_0.0_b3s2"
    elif args.model_type == "contrast" and args.phase == "test":
        name = "sth_contrast_True_wv_False_100_seed_1234_batch_size_3_lr_0.0001_prov_difftopkproofs_tpk_3_vw_0.0_b3s2"
    elif args.model_type == "weighted" and args.phase == "test":
        name = "sth_contrast_True_wv_True_100_seed_1234_batch_size_3_lr_0.0001_prov_difftopkproofs_tpk_3_vw_0.05_b3s2"
    else:
        name = f"sth_contrast_{args.use_contrast}_wv_{args.with_violation}_{args.training_percentage}_seed_{args.seed}_batch_size_{args.batch_size}_lr_{args.learning_rate}_prov_{args.provenance}_tpk_{args.train_top_k}_vw_{args.violation_weight}_b3s2"

    print(name)

    # An explicit --model-name wins over the derived one.
    if args.model_name is None:
        args.model_name = name

    label_dir = os.path.join(args.data_dir, 'labels')
    video_dir = os.path.join(args.data_dir, '20bn-something-something-v2')
    bbox_dir = os.path.join(args.data_dir, 'bboxes')
    constraint_path = os.path.join(args.data_dir, 'constraints.pddl')
    manual_mapping_path = os.path.join(args.data_dir, 'manual_mapping.json')
    # Context managers so the file handles are closed deterministically.
    with open(manual_mapping_path, 'r') as manual_mapping_file:
        manual_mapping = json.load(manual_mapping_file)
    template_mapping_path = os.path.join(args.data_dir, "template_mapping.json")
    scl_dir = os.path.join(args.data_dir, 'scl')
    if not args.with_violation:
        common_scl_path = os.path.join(args.data_dir, 'scl/eval_actions.scl')
    else:
        common_scl_path = os.path.join(args.data_dir, 'scl/eval_actions_with_vio.scl')

    # `parser` is rebound from the ArgumentParser to the PDDL parser here;
    # the name is kept in case other code in this file reads it as a global.
    parser = PDDLParser()
    with open(constraint_path, 'r') as constraint_file:
        constraints = parser.parse(constraint_file.read())
    action2scl = constraints.to_scl()

    # Use args.data_dir (not the hard-coded default) so that --data-dir also
    # applies to the dataset split files, like every other data path above.
    train_dataset_path = os.path.join(args.data_dir, f"mini_train_{args.train_num}.json")
    valid_dataset_path = os.path.join(args.data_dir, f"mini_valid_{args.val_num}.json")

    if torch.cuda.is_available():
        device = "cuda:0"
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        device = "cpu"
        torch.set_default_tensor_type(torch.FloatTensor)

    with open(template_mapping_path, 'r') as template_mapping_file:
        action2template = json.load(template_mapping_file)
    # Inverse mapping; if several actions share a template the last one wins.
    template2action = {}
    for action, template in action2template.items():
        template2action[template] = action

    paired_actions = pair_non_eq_actions(action2template=action2template, constraints=constraints)

    train_dataset, test_dataset, train_loader, test_loader = \
        ss_loader(train_dataset_path, valid_dataset_path, video_dir, args.batch_size, device, \
            training_percentage=args.training_percentage, testing_percentage=args.test_percentage, paired_actions=paired_actions)

    trainer = Trainer(train_loader=train_loader, test_loader=test_loader, device=device, action2scl=action2scl,
                      action2template=action2template, save_scl_dir=scl_dir, common_scl_path=common_scl_path,
                      latent_dim=args.latent_dim, model_layer=args.model_layer,
                      model_path=args.model_path, model_name=args.model_name,
                      learning_rate=args.learning_rate, load_model=args.load_model,
                      paired_actions=paired_actions,
                      provenance=args.provenance,
                      save_model=args.save_model,
                      video_save_dir= args.video_save_dir,
                      violation_weight=args.violation_weight,
                      with_violation=args.with_violation,
                      use_contrast=args.use_contrast)

    # Phase-based dispatch is currently disabled; test_stats() runs
    # regardless of --phase.
    # if args.phase == "train":
    #     trainer.train(args.n_epochs)
    # elif args.phase == "test":
    #     trainer.baseline_eval()
    # elif args.phase == "label":
    #     if not os.path.exists(video_save_dir):
    #         os.mkdir(video_save_dir)
    #     trainer.label(60)

    # trainer.baseline_eval()
    # trainer.decompose_label(args.n_epochs)
    # for lid in to_label_ids:
    #     trainer.label_by_id(lid, phase="test")
    trainer.test_stats()

    print('finished')
