from temporal.inference import ACTION_CLASSES
from datasets.cater.enums import actions_order_dataset, MAX_FRAMES
from math import floor
import torch

def rule_examples():
    """Build one synthetic action timeline per temporal-ordering rule.

    The rules come from ``actions_order_dataset(unique=False)``. Each rule is
    unpacked as ``((subject, obj), relation)``, and only ``relation[0]`` is
    inspected.

    Returns:
        A tensor of zeros with shape
        ``(len(rules), len(ACTION_CLASSES), MAX_FRAMES)``. For each rule, the
        rows of its subject and object actions are set to 1 over these
        frame spans:

        * ``'before'``: subject in the first third, object in the last third.
        * ``'after'``: subject in the last third, object in the first third.
        * any other value: both subject and object in the middle third.

    Raises:
        ValueError: from ``ACTION_CLASSES.index`` if a rule names an action
            that is not in ``ACTION_CLASSES``.
    """
    rules = actions_order_dataset(unique=False)
    examples = torch.zeros(len(rules), len(ACTION_CLASSES), MAX_FRAMES)

    # Split the timeline into three (start, stop) frame spans. ``third`` is
    # floored, so any remainder frames end up in the last span.
    third = floor(MAX_FRAMES / 3)
    head = (0, third)
    centre = (third, 2 * third)
    tail = (2 * third, MAX_FRAMES)

    for row, ((subject, obj), relation) in enumerate(rules):
        kind = relation[0]
        if kind == 'before':
            subj_span, obj_span = head, tail
        elif kind == 'after':
            subj_span, obj_span = tail, head
        else:
            # Any relation other than before/after places both actions in the middle.
            subj_span = obj_span = centre

        # Resolve both action rows before writing anything.
        subject_row = ACTION_CLASSES.index(subject)
        object_row = ACTION_CLASSES.index(obj)

        subj_lo, subj_hi = subj_span
        obj_lo, obj_hi = obj_span
        examples[row, subject_row, subj_lo:subj_hi] = 1
        examples[row, object_row, obj_lo:obj_hi] = 1

    return examples

if __name__ == '__main__':
    # Smoke run: build the example tensor once; the result is not used.
    _examples = rule_examples()