from datasets.cater.enums import actions_order_dataset, actions_order_mapping, reverse
from vision.visionutils import load_labels, save_labels
from collections import defaultdict

# classes for one predicate, ie the base CATER task 2
# Indexed by single-predicate label id (see get_labels); judging by get_labels each
# entry looks like ((subject_action, object_action), (relation, ...)) — TODO confirm.
one_pred_classes = actions_order_dataset(n=2)
# classes for two chained predicates r(x, y) ^ r(y, z), i.e. events over three actions
# Used as a mapping (two_pred_classes.get in get_labels) from an event
# ((x, y, z), (r1, r2)) to its label id; meaning of samples=12 is not visible here — TODO confirm.
two_pred_classes = actions_order_dataset(n=3, as_list=False, samples=12)


def build_two_pred(obj_preds, subj_preds):
    """Chain two single-predicate events r(x, y) and r(y, z) into one event.

    Each predicate is indexed as ``pred[0]`` = (subject_action, object_action)
    and ``pred[1]`` = relations, of which only the first is used.

    Args:
        obj_preds: predicate r(x, y), whose object action y is the shared action.
        subj_preds: predicate r(y, z), whose subject action y is the shared action.

    Returns:
        ``((x, y, z), (rel_xy, rel_yz))`` when the two predicates share y,
        otherwise ``None``.
    """
    obj_actions, subj_actions = obj_preds[0], subj_preds[0]
    # The object of the first predicate must be the subject of the second.
    if subj_actions[0] != obj_actions[1]:
        return None
    chained_actions = (obj_actions[0], obj_actions[1], subj_actions[1])
    chained_relations = (obj_preds[1][0], subj_preds[1][0])
    return chained_actions, chained_relations

def get_labels(actions):
    """Derive two-predicate label ids for each sample from its single-predicate labels.

    For every sample, the single-predicate label ids are decoded through
    ``one_pred_classes`` and chained into length-2 events r(x, y) ^ r(y, z):
    every predicate whose object action is y is paired with every predicate
    whose subject action is y. Each chained event is looked up in
    ``two_pred_classes``, falling back to ``reverse(event)`` when the forward
    orientation is absent. Events found in neither orientation are ignored.

    Args:
        actions: iterable of per-sample sequences of single-predicate label ids.

    Returns:
        A list with one entry per sample: the list of unique two-predicate
        label ids for that sample, in first-seen order.
    """
    all_labels = []
    for samp_actions in actions:
        # Decode each label id once; both grouping passes reuse it.
        decoded = [one_pred_classes[action] for action in samp_actions]

        # Group predicates by the shared action y.
        # Slot 0 holds predicates r(x, y), where y is the object.
        # Slot 1 holds predicates r(y, z), where y is the subject.
        # Two separate passes keep the key insertion order (and hence the output
        # label order) identical to a grouping of all objects first, then subjects.
        obj_subj_matches = defaultdict(lambda: ([], []))
        for pred in decoded:
            obj_subj_matches[pred[0][1]][0].append(pred)
        for pred in decoded:
            obj_subj_matches[pred[0][0]][1].append(pred)

        two_pred_labels = []
        for obj_group, subj_group in obj_subj_matches.values():
            for obj_preds in obj_group:
                for subj_preds in subj_group:
                    two_pred_event = build_two_pred(obj_preds, subj_preds)
                    if two_pred_event is None:
                        # Defensive: should not happen given the grouping above,
                        # but never feed None into the class lookup / reverse().
                        continue
                    label_num = two_pred_classes.get(two_pred_event)
                    if label_num is None:
                        # The class table may only store one orientation of an event.
                        label_num = two_pred_classes.get(reverse(two_pred_event))
                    if label_num is not None and label_num not in two_pred_labels:
                        two_pred_labels.append(label_num)
        all_labels.append(two_pred_labels)
    return all_labels


if __name__ == '__main__':
    # For each split: read the single-predicate labels from 'actions_order_uniq',
    # derive two-predicate labels, and write them to 'actions_present_two_pred_1k'.
    data_dir = '/localscratch/cater_dataset/max2action'
    splits = ['train_subsetT', 'train_subsetV', 'val']

    # Slice bound applied to every split; None keeps all samples.
    limit = None
    for split in splits:
        split_files, split_actions = load_labels(data_dir, split=split, folder='actions_order_uniq')
        split_files, split_actions = split_files[:limit], split_actions[:limit]
        save_labels(split_files, get_labels(split_actions), data_dir, split,
                    folder='actions_present_two_pred_1k')
