import os
import re
import cv2
import json
import torchvision.transforms as T
from sklearn import metrics

# Normalisation frame size (width, height); usage not visible here — TODO confirm.
norm_x, norm_y = 480, 360

# Per-channel image statistics on a 0-255 scale (they appear to be the common
# ImageNet RGB mean/std multiplied by 255 — presumably; confirm).
mean = [123.675, 116.28, 103.53]
std = [58.395, 57.12, 57.375]

# Coarse temporal-position labels.
location_consts = ["early", "mid", "late"]

# Object-category vocabulary; list position presumably serves as the category
# id elsewhere — TODO confirm before reordering.
# "window" previously lacked its trailing comma and was implicitly concatenated
# with a following duplicate "ceiling" into a bogus "windowceiling" entry.
# "ceiling" already appears earlier in the list, so the stray duplicate is
# dropped: the list keeps its length (126) and every other entry keeps its index.
all_entities = ["adult",
            "baby",
            "bag",
            "ball",
            "ballon",
            "basket",
            "bat",
            "bed",
            "bench",
            "beverage",
            "bike",
            "bird",
            "blanket",
            "board",
            "book",
            "bottle",
            "bowl",
            "box",
            "bread",
            "brush",
            "bucket",
            "cabinet",
            "cake",
            "camera",
            "can",
            "candle",
            "car",
            "card",
            "carpet",
            "cart",
            "cat",
            "cellphone",
            "ceiling",
            "chair",
            "child",
            "chopstick",
            "cloth",
            "computer",
            "condiment",
            "cookie",
            "countertop",
            "cover",
            "cup",
            "curtain",
            "dog",
            "door",
            "drawer",
            "dustbin",
            "egg",
            "fan",
            "faucet",
            "fence",
            "flower",
            "fork",
            "fridge",
            "fruit",
            "gift",
            "glass",
            "glasses",
            "glove",
            "grain",
            "guitar",
            "hat",
            "helmet",
            "horse",
            "iron",
            "knife",
            "light",
            "lighter",
            "mat",
            "meat",
            "microphone",
            "microwave",
            "mop",
            "net",
            "noodle",
            "others",
            "oven",
            "pan",
            "paper",
            "piano",
            "pillow",
            "pizza",
            "plant",
            "plate",
            "pot",
            "powder",
            "rack",
            "racket",
            "rag",
            "ring",
            "scissor",
            "shelf",
            "shoe",
            "simmering",
            "sink",
            "slide",
            "sofa",
            "spatula",
            "sponge",
            "spoon",
            "spray",
            "stairs",
            "stand",
            "stove",
            "switch",
            "table",
            "teapot",
            "towel",
            "toy",
            "tray",
            "tv",
            "vaccum",
            "vegetable",
            "washer",
            "window",
            "floor",
            "grass",
            "ground",
            "rock",
            "sand",
            "sky",
            "snow",
            "tree",
            "wall",
            "water",
        ]

# Vocabulary of binary (two-argument) relation/predicate names, e.g. "holding",
# "next to". List position presumably serves as the predicate id elsewhere —
# TODO confirm before reordering or inserting entries.
all_binary_preds = [
        "beside",
        "biting",
        "blowing",
        "brushing",
        "caressing",
        "carrying",
        "catching",
        "chasing",
        "cleaning",
        "closing",
        "cooking",
        "cutting",
        "drinking from",
        "eating",
        "entering",
        "feeding",
        "grabbing",
        "guiding",
        "hanging from",
        "hitting",
        "holding",
        "hugging",
        "in",
        "in front of",
        "jumping from",
        "jumping over",
        "kicking",
        "kissing",
        "licking",
        "lighting",
        "looking at",
        "lying on",
        "next to",
        "on",
        "opening",
        "over",
        "picking",
        "playing",
        "playing with",
        "pointing to",
        "pulling",
        "pushing",
        "riding",
        "running on",
        "shaking hand with",
        "sitting on",
        "standing on",
        "stepping on",
        "stirring",
        "swinging",
        "talking to",
        "throwing",
        "touching",
        "toward",
        "walking on",
        "watering",
        "wearing"
    ]

# Open-ended vocabulary of binary relation keywords (verb / preposition phrases).
# NOTE(review): the list contains duplicates (e.g. "standing up" appears twice),
# so list positions are not unique per keyword — dedupe only after confirming
# nothing indexes into this list.
all_gpt_binary_kws = ["standing up", 'coming back', "looking ahead", 'walking to the side', 'crying', "pets", "pinching", "come over", "appreciating", "smiling at", "brushing", "dipping", "size", "follow", "setting down", "mount", "reach out to touch", "playing together", "discovering", "chase", "places in middle of", "flicking", "praising", "retreating", "pointing to", "away", "putting under", "closer", "type", "standing next to", "walking with", "together", "patting", "feed", "drawing", "come closer", "throw", "running around", "peeking", "emotion", "stirring", "climbing onto", "slicing", "puts on", "take", "losing", "brings", "with", "passing under", "sealing", "handing", "pulls head out of", "by", "watching", "helps up", "lifting", "take out", "biting", "playing", "supporting", "assisting", "kissing", "number", "laughing together", "waving at", "running to", "shaking hand with", "laying", "resetting", "scanning", "dancing", "swimming", "aligning", "kicking", "tap", "painting", "thinking", "climbing up", "taking photos of self", "sliding", "swing", "flipping through", "switch", "zoom in on", "moving closer", "crawling", "light", "leaning against", "keeping", "bathing", "brush", "going", "greeting", "turning", "igniting", "breaking", "showing affection", "sitting in front of", "getting down", "glance", "nearby", "tapping", "pointing", "walking into", "around", "watch", "taking", "embrace", "blowing", "holding", "hugging", "walking on", "away from", "join", "right side of", "receiving", "focus", "get off", "turning off", "talking to", "picking", "stand on", "give high five", "forward", "pick up", "raising", "sit", "rinse", "increasing temperature", "placing in front", "lighting", "trying", "capture", "play", "open", "below", "color", "on", "prepare", "falling to ground", "temporarily place on left", "folding", "walk past", "observe", "looking at", "in arms", "kiss", "removes", "repositioning", "filming", "sticks head into", "holding hands", "returning", 
"communicating", "rotating", "hit", "continues passing", "running through", "facing", "wiping", "feeding", "examine", "gazing", "concluding contact", "back to", "grab", "removing", "laughing with", "running on", "washing", "talking", "placing in front of", "swimming in", "move away", "mopping", "falls down", "transferring", "wrestling", "jumping on", "approaches", "past", "jumping over", "help", "retrieving", "with force", "dealing", "spinning", "change", "onto lap", "nuzzling", "touch", "inside", "place", "riding on", "inner side of", "jumping", "off", "gazing at", "in hand", "jumping onto", "sitting down on", "behind", "placing", "beside", "leaving", "taking picture", "performing", "stroke", "covering", "taking photo", "into", "sitting", "positioning in front of", "crouching on", "count", "walking up to", "push", "in front of", "walks up to", "using", "adjust sitting position", "shape", "under", "standing on", "cover", "turn on", "conversing", "jumping into", "leaning head closer", "bouncing around with", "shift", "hold", "knocking down", "climbing", "jumping with", "falling", "wash", "crawling on", "next to", "ride", "getting up", "at feet of", "inviting", "showing", "hanging", "let walk", "see", "pulling away", "encouraging", "walk out", "cutting", "jumping from", "reading", "switching", "throwing away", "bring", "fluffing", "running up", "get back up on", "turning attention", "go to", "spreading", "zoom in", "go into", "eating", "reaches out", "helping hold", "in embrace", "panning", "toward", "sets temperature", "falling from", "surrounding", "putting on", "fetching", "standing near", "sharing", "not hitting", "putting down", "separating", "chasing", "bigger", "wishing", "direction", "gluing", "touching", "reaching out", "contains", "examining", "walking", "searching", "chasing after", "tearing open", "putting back", "lowering temperature", "blocking", "running after", "stepping", "rubbing", "playing with", "dancing in circle", "following", "eating from", 
"trying to catch", "using hand", "add", "holding up", "age", "contain", "shuffling", "return to", "smearing", "wringing", "disposing", "caressing", "rolling", "packaging", "scrubbing", "hitting", "to side", "organizing", "finishing", "taking away", "has", "interacting", "prompting", "turning on", "ironing", "exiting", "picking up again", "part", "holding hands with", "water", "tuning", "setting", "heating", "gathering around", "put", "reaching", "serving", "unwrap", "filling", "checking", "looking for", "adjusting", "play with", "leaning on", "not enter", "left hand", "giving kiss", "embracing", "trimming", "tumbling", "stir", "letting", "returning to", "of", "moving aside", "for", "resting on", "flipping", "dismount", "continuing", "shifting", "lying", "dropping", "rocking", "observing", "going to", "guiding", "leading", "swinging", "releasing", "sleeve length", "put back", "running", "glancing at", "approaching", "drying", "blowing out", "drinking from", "look at", "going upstairs", "stopping", "stand", "sing", "standing up from", "standing up", "rolling up", "aiming", "other side", "pointing towards", "point", "far", "taking out", "hair color", "middle of", "to", "grabbing by hand", "adjust", "walking to", "pushing", "speaking to", "looking", "passing with", "bounce off", "sweeping", "cut", "amusing", "playfully", "near", "tearing", "pretending throw", "bringing back", "snuggling up with", "asking", "put down", "knocking", "boiling", "high five", "through", "right of", "stare at", "holding hand", "pressing", "sitting in", "singing", "adjusting position", "throwing through legs", "helping", "running toward", "carrying on back", "turn over", "move", "waiting", "competing", "burying head", "completely", "jumping off", "leaning", "celebrating", "left", "letting go", "putting", "moving forward", "watching together", "giving push", "taking picture of", "moving by hand", "scooping", "paint", "iron", "playful", "setting aside", "collecting", "riding", "running towards", 
"appear", "moving toward", "out of", "puts head back into", "moving", "side", "back into", "crawling toward", "blow out", "control", "lowering", "discarding", "placing in", "other side of", "taking photos", "sound", "crouching behind", "close", "bringing", "photographing", "offering wishes", "spraying", "climbing down", "above", "handling", "closing", "cleaning", "across", "left side", "onto", "clean", "tidying", "switch back", "falling off", "correcting direction of", "joyfully unwrapping", "recording", "miss", "unpack", "pushing in", "placing aside", "speaking", "saut\u00e9ing", "stepping on", "not pass over", "wearing", "walking toward", "use", "outer side of", "stroking", "dragging", "chopping", "place back on", "carrying", "backing up", "sitting next to", "running to fetch", "walking back", "check", "unfolding", "bursting", "teaching", "opening", "passing", "passing by", "continue sitting", "sitting on", "tasting", "pet", "right", "drops", "pulling", "entering", "part of", "wipe", "preparing", "rinsing", "picking up", "unwrapping", "enter", "owned by", "pouring", "teasing", "petting", "adding", "tipping over", "adorned with", "placing on", "in", "sanding", "dismounting", "walking towards", "pick", "turn off", "back", "going into", "failing", "shaking", "pruning", "taking pictures", "grabbing", "climbing on", "lying on", "gently pushes away", "standing", "clapping hands with", "making laugh", "stop", "involves", "putting away", "standing with", "wringing out", "fall", "walk to", "over", "moving away from", "up onto", "dip", "up to", "rolling together", "snatching", "crawling over", "study", "from", "throwing", "at", "repeating", "catching"]
# Matches a parenthesised annotation such as "(1)" or "(a, b)": an opening
# parenthesis, one or more alphanumerics / commas / hyphens / spaces, and a
# closing parenthesis. Written as a raw string: the previous plain string relied
# on invalid escape sequences (\( \, \- \)), which Python 3.12+ reports with a
# SyntaxWarning. The resulting pattern text is identical.
re_num = r"\([0-9a-zA-Z\,\- ]+\)"

# Relation names treated as plain (non-probabilistic) facts; to_scl_file checks
# membership in this list.
non_prob_gpt_preds = [
    "frame", "all_frames", "all_objects", "num_variables", "variable",
    "time_stamp_ct", "time_stamp",
    "positive_unary_atom", "negative_unary_atom",
    "positive_binary_atom", "negative_binary_atom",
    "inequality_constraint", "object", "time",
]
# Relations that construct_scl_facts renders without a probability prefix.
non_prob_gpt_prog_str_preds = ["variable", "spec", "object", "time"]


# Boolean connective tokens.
bool_token = ["and", "or", "not"]

# Keyword -> Python comparison operator, for the positive and the negated form.
kw_preds = {'=': "=="}
not_kw_preds = {'=': "!="}


# Single-letter variable name -> 1-based id: 'a' -> 1, ..., 'z' -> 26.
var2vid = {chr(code): code - 96 for code in range(ord('a'), ord('z') + 1)}

# Named constants -> constant id; both spellings of "hand" map to -1.
const2cid = dict.fromkeys(('HAND', 'hand'), -1)


def get_start_end_frame(start_time, end_time, video_path):
    """Convert a time span in seconds into frame indices of ``video_path``.

    Frame indices are ``int(t * fps)``, using the FPS reported by OpenCV.
    ``start_frame`` is clamped below at 0 and ``end_frame`` above at
    ``total_frames - 1``. ``end_frame`` is then raised to ``start_frame`` if it
    fell below it. ``start_frame`` is not clamped above, so a start time past
    the end of the video yields indices beyond the last frame.

    Args:
        start_time: Span start, in seconds.
        end_time: Span end, in seconds.
        video_path: Path to a video readable by ``cv2.VideoCapture``.

    Returns:
        ``(start_frame, end_frame)``, or ``(None, None)`` if the video cannot be
        opened. An error message is printed in that case.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Error opening video file")
            return None, None
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture on every path: normal exit, the early return
        # above, and any exception raised by cap.get().
        cap.release()

    start_frame = max(0, int(start_time * fps))
    end_frame = min(int(end_time * fps), total_frames - 1)
    # Never return an inverted span.
    return start_frame, max(end_frame, start_frame)


def format_batched_facts(batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_gpt_specs):
    """Build one Scallop input-fact dict per video and batch them.

    NOTE: ``format_batched_facts`` is defined a second time later in this
    module. That later definition rebinds the name at import time, so callers
    get it rather than this one.

    The five arguments are parallel per-video sequences; ``zip`` stops at the
    shortest one. For each video, the fact dict contains:

    * every entry of the ``batched_scl_tps`` item, copied via ``dict.update``;
    * ``'name'``, ``'sg_unary_atom'``, ``'sg_binary_atom'``: the category,
      unary and binary prediction facts, stored as given;
    * ``'variable'``: ``[(1,), ..., (len(gpt_spec['args']),)]``;
    * ``'spec'``: ``[gpt_spec['prog']]``.

    Returns:
        The list of per-video dicts transposed by ``process_batched_facts``.
    """
    batched_scl_input_facts = []

    for scl_tp, cate_pred_tp, unary_pred_tp, binary_pred_tp, gpt_spec in zip(
            batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl,
            batched_binary_pred_scl, batched_gpt_specs):
        scl_input_facts = {}
        scl_input_facts.update(scl_tp)
        scl_input_facts['name'] = cate_pred_tp
        scl_input_facts['sg_unary_atom'] = unary_pred_tp
        scl_input_facts['sg_binary_atom'] = binary_pred_tp
        # Spec variables are numbered from 1, one per spec argument.
        scl_input_facts['variable'] = [(i,) for i in range(1, len(gpt_spec['args']) + 1)]
        scl_input_facts['spec'] = [gpt_spec['prog']]
        batched_scl_input_facts.append(scl_input_facts)

    return process_batched_facts(batched_scl_input_facts)


def resize_bboxes(orig_bboxes, orig_vid_width, orig_vid_height, new_vid_width, new_vid_height):
    """Rescale ``(x1, y1, x2, y2)`` boxes between two frame sizes.

    Each x coordinate is floor-divided by ``new_vid_width / orig_vid_width``,
    each y coordinate by ``new_vid_height / orig_vid_height``, and the result
    is cast to ``int``. This is equivalent to multiplying by orig/new.

    NOTE(review): dividing by new/orig maps coordinates from the *new*
    resolution to the *orig* one. Given the argument names, confirm with
    callers which direction is intended.

    Returns:
        A new list of ``[x1, y1, x2, y2]`` int lists.
    """
    sx = new_vid_width / orig_vid_width
    sy = new_vid_height / orig_vid_height
    return [
        [int(left // sx), int(top // sy), int(right // sx), int(bottom // sy)]
        for left, top, right, bottom in orig_bboxes
    ]


def combine_jsons(path1, path2, annotations=1):
    """Load two JSON files and deep-merge a top-level section of one into the other.

    If ``annotations == 1``, ``json2['database']`` is merged into
    ``json1['annotations']``. Otherwise ``json2['annotations']`` is merged into
    ``json1['database']``. Nested dicts present on both sides are merged
    recursively. For any other key collision, the value from the second file wins.

    Args:
        path1: Path to the first (target) UTF-8 JSON file.
        path2: Path to the second (source) UTF-8 JSON file.
        annotations: Selects the merge direction (see above).

    Returns:
        The merged section: the dict object taken from the first file, updated
        in place.

    Raises:
        KeyError: if the required top-level section is missing.
    """
    # JSON text is UTF-8; be explicit instead of relying on the platform's
    # locale-dependent default encoding.
    with open(path1, 'r', encoding='utf-8') as file1, open(path2, 'r', encoding='utf-8') as file2:
        json1 = json.load(file1)
        json2 = json.load(file2)

    def merge_dicts(dict1, dict2):
        # Recursively copy dict2 into dict1 (mutating dict1) and return dict1.
        for key, value in dict2.items():
            if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict):
                merge_dicts(dict1[key], value)
            else:
                dict1[key] = value
        return dict1

    if annotations == 1:
        return merge_dicts(json1['annotations'], json2['database'])
    return merge_dicts(json1['database'], json2['annotations'])


def check_is_valid(all_bbox_data, video_path, video_id):
    """Return True only if ``video_path`` exists and bbox data exists for ``video_id``.

    Bounding-box availability is decided by ``get_video_data`` returning
    something other than ``[]``.
    """
    if os.path.exists(video_path):
        return get_video_data(all_bbox_data, video_id) != []
    return False

def check_valid_video(video_path, video_id):
    """Return whether ``video_path`` exists on disk.

    ``video_id`` is accepted for interface symmetry with ``check_is_valid`` but
    is not used.
    """
    return os.path.exists(video_path)



def get_video_data(all_bbox_data, video_id):
    """
    Retrieves the resized frame size and bounding-box segments for a video.

    Args:
        all_bbox_data: Mapping from video id to a dict with keys ``'rwidth'``,
            ``'rheight'`` and ``'segments'`` (the latter itself a mapping).
        video_id (str): The unique identifier for the video.

    Returns:
        ``(rwidth, rheight, list_of_segment_values)``, or ``[]`` if any of
        those lookups raises ``KeyError``.
    """
    try:
        entry = all_bbox_data[video_id]
        return entry['rwidth'], entry['rheight'], list(entry['segments'].values())
    except KeyError:
        return []



def combine_baseline_pred_dict_ls(pred_dict_ls):
    """Concatenate per-predicate ``'gt'`` / ``'pred'`` sequences across several dicts.

    Args:
        pred_dict_ls: Iterable of dicts mapping predicate name to
            ``{'gt': [...], 'pred': [...]}``.

    Returns:
        A new dict in the same shape. Its lists are fresh, so the inputs are
        not mutated.
    """
    merged = {}
    for pred_dict in pred_dict_ls:
        for name, info in pred_dict.items():
            slot = merged.setdefault(name, {'gt': [], 'pred': []})
            slot['gt'].extend(info['gt'])
            slot['pred'].extend(info['pred'])
    return merged

def rec_sub_val(ls, val_subst_dict):
    """Return a copy of ``ls`` with values substituted via ``val_subst_dict``.

    Elements whose exact type is ``list`` are processed recursively. Every
    other element is replaced by ``val_subst_dict[element]`` when it is a key,
    otherwise kept as is. Tuples are not descended into.
    """
    def substitute(item):
        if type(item) is list:
            return rec_sub_val(item, val_subst_dict)
        return val_subst_dict[item] if item in val_subst_dict else item

    return [substitute(item) for item in ls]

def get_start_end(caption):
    """Parse ``caption['time']`` into an integer ``(start_time, end_time)`` pair.

    Accepted forms:
        ``"s"``   -> ``(s, s + 1)``
        ``"s-e"`` -> ``(s, e)``
        ``"-e"``  -> ``(0, e)``
        ``"s-"``  -> ``(s, 10000)``

    More than one ``-`` raises ``ValueError`` from the tuple unpacking.
    """
    parts = caption['time'].split('-')
    if len(parts) == 1:
        return int(parts[0]), int(parts[0]) + 1

    raw_start, raw_end = parts
    start_time = int(raw_start) if raw_start else 0
    end_time = int(raw_end) if raw_end else 10000
    return start_time, end_time


def get_start_end_activity_net(timestamp):
    """Convert a ``[start, end]`` timestamp pair into ints, with the end made exclusive (+1)."""
    start, end = int(timestamp[0]), int(timestamp[1]) + 1
    return start, end

def get_pred_mask_paths(mask_dir, start_time, end_time):
    """List ``<mask_dir>/<frame_id>.pkl`` for frames ``[start_time, end_time)``.

    Frames whose file does not exist are represented by ``None``, so the result
    always has one entry per frame.
    """
    candidates = (os.path.join(mask_dir, f'{frame_id}.pkl') for frame_id in range(start_time, end_time))
    return [path if os.path.exists(path) else None for path in candidates]

def get_mask_paths(mask_dir, start_time, end_time):
    """Return the PNG mask paths for frames ``[start_time, end_time)``.

    Masks are expected at ``<mask_dir>/<frame_id zero-padded to 4 digits>.png``.

    Raises:
        AssertionError: if any expected mask file does not exist.
    """
    mask_paths = []
    for frame_id in range(start_time, end_time):
        mask_path = os.path.join(mask_dir, f'{str(frame_id).zfill(4)}.png')
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would let missing masks through silently.
        # AssertionError is kept so existing callers see the same exception type.
        if not os.path.exists(mask_path):
            raise AssertionError(f"mask file not found: {mask_path}")
        mask_paths.append(mask_path)
    return mask_paths


def clean_cap(caption):
    """Remove parenthesised annotations from a caption and tidy its spacing.

    Every match of the module-level ``re_num`` pattern (e.g. ``"(1)"`` or
    ``"(a, b)"``) is deleted. Then runs of spaces are collapsed to one, spaces
    before ``.`` and ``,`` are dropped, and the result is stripped.

    Example: ``"a dog (1)  runs (2) ."`` -> ``"a dog runs."``
    """
    new_cap = re.sub(re_num, '', caption)
    # Collapse every run of spaces. A single str.replace('  ', ' ') only halves
    # longer runs, e.g. those left by removed tokens next to existing spaces.
    new_cap = re.sub(' {2,}', ' ', new_cap)
    new_cap = new_cap.replace(' .', '.')
    new_cap = new_cap.replace(' ,', ',')
    return new_cap.strip()

# Per-channel normalisation transform built from the module-level `mean` / `std`.
# NOTE(review): those statistics are on a 0-255 scale, so inputs presumably need
# to be unscaled (0-255) float tensors rather than 0-1 — confirm at call sites.
transform = T.Normalize(mean=mean, std=std)

def get_overlap(x, y):
    """Intersect two ``(start, end)`` intervals.

    Returns the intersection as a tuple. Returns ``(-1, -1)`` when the
    intervals are disjoint or only touch (zero-length intersection).
    """
    lo = max(x[0], y[0])
    hi = min(x[1], y[1])
    return (lo, hi) if lo < hi else (-1, -1)

def construct_batched_scl_tps(batched_object_ids):
    """Return the per-video relation dicts of ``construct_scl_tps`` as a list, in first-seen video order."""
    return list(construct_scl_tps(batched_object_ids).values())

def construct_scl_tps(batched_object_ids):
    """Group ``(video id, frame id, object id)`` triples into per-video relations.

    Args:
        batched_object_ids: Iterable of ``(vid, fid, oid)`` triples.

    Returns:
        Dict ``vid -> {'object': [(oid,), ...], 'time': [(t,), ...]}``, with
        videos in first-seen order.

        * ``'object'`` holds the distinct object ids seen for that video, in
          unspecified order.
        * ``'time'`` is the same for every video: ``(0,), (1,), ...,
          (max_fid - 1,)`` in ascending order, where ``max_fid`` is the largest
          frame id across the whole batch.
    """
    batchs = {}
    max_time = -1

    for vid, fid, oid in batched_object_ids:
        if fid > max_time:
            max_time = fid
        if vid not in batchs:
            batchs[vid] = {'object': set(), 'time': []}
        batchs[vid]['object'].add((oid,))

    # Every video gets the batch-wide time domain. The previous code filled
    # 'time' in a loop placed after the grouping loop, using the leaked loop
    # variable `vid`, so only the last video received time facts.
    # NOTE(review): range(max_time) stops at max_time - 1, so the largest frame
    # id itself is not a time point. This is kept as before — confirm it is
    # intended.
    for facts in batchs.values():
        facts['object'] = list(facts['object'])
        facts['time'] = [(t,) for t in range(max_time)]

    return batchs

def construct_scl_facts(scl_tuples, *, non_prob_preds=None):
    """Render relation tuples as Scallop ``rel`` declarations.

    Args:
        scl_tuples: Mapping relation name -> one-element list wrapping that
            relation's entries. The single-element nesting is asserted.
        non_prob_preds: Relation names rendered without probabilities.
            Defaults to the module-level ``non_prob_gpt_prog_str_preds``.

    Rendering:
        * relations in ``non_prob_preds``: each entry is a tuple, written as
          ``(a,"b")``;
        * all other relations: each entry is ``(prob, tuple)``, written as
          ``p::(a,"b")`` where ``p = prob.item()``. ``prob`` presumably is a
          tensor/numpy scalar — TODO confirm.

        ``int`` fields are emitted bare; every other field is wrapped in double
        quotes.

    Returns:
        One ``rel <name> = {...}`` line per relation, joined by blank lines.
    """
    if non_prob_preds is None:
        non_prob_preds = non_prob_gpt_prog_str_preds

    def format_tuple(rel_tp):
        return '(' + ','.join(str(i) if type(i) == int else f"\"{i}\"" for i in rel_tp) + ')'

    scl_prog = []
    for rel_name, rel_tps in scl_tuples.items():
        assert len(rel_tps) == 1
        rel_tps = rel_tps[0]

        if rel_name in non_prob_preds:
            # Use the formatted tuples, consistent with the probabilistic branch.
            # Previously this branch joined str(rel_tp), emitting Python reprs
            # such as ('a',) with single quotes and trailing commas.
            tps = [format_tuple(rel_tp) for rel_tp in rel_tps]
        else:
            tps = [f"{prob.item()}::{format_tuple(rel_tp)}" for prob, rel_tp in rel_tps]

        scl_prog.append("rel " + rel_name + " = {" + ', '.join(tps) + "}")

    return '\n\n'.join(scl_prog)


def process_batched_facts(fact_dict_ls):
    """Transpose a list of per-video fact dicts into one dict of per-video fact lists.

    The result has a key for every relation name seen in any input dict, in
    first-seen order. Each key maps to a list aligned with ``fact_dict_ls``
    whose entries are:

    * ``[]`` when that video's dict lacks the key;
    * the value with every element wrapped into a 1-tuple, when its first
      element is not a tuple;
    * otherwise the value itself (the same object, not a copy).
    """
    relation_names = {}
    for facts in fact_dict_ls:
        relation_names.update(dict.fromkeys(facts))

    batched_fact_dict = {name: [] for name in relation_names}
    for facts in fact_dict_ls:
        for name, per_video in batched_fact_dict.items():
            if name not in facts:
                per_video.append([])
                continue
            values = facts[name]
            needs_wrap = len(values) > 0 and type(values[0]) != tuple
            per_video.append([(item,) for item in values] if needs_wrap else values)

    return batched_fact_dict


def format_batched_facts(batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl, batched_binary_pred_scl, batched_gpt_specs):
    """Build one Scallop input-fact dict per video and batch them.

    This definition is the one bound at import time; it overrides an earlier
    definition of the same name in this module.

    The five arguments are parallel per-video sequences; ``zip`` stops at the
    shortest one. For each video, the fact dict contains:

    * every entry of the ``batched_scl_tps`` item, copied via ``dict.update``
      (presumably the relation dicts from ``construct_batched_scl_tps`` —
      confirm with callers);
    * ``'name'``, ``'sg_unary_atom'``, ``'sg_binary_atom'``: the category,
      unary and binary prediction facts, stored as given;
    * ``'variable'``: ``[(1,), ..., (len(gpt_spec['args']),)]``;
    * ``'spec'``: ``[gpt_spec['prog']]``.

    Returns:
        The list of per-video dicts transposed by ``process_batched_facts``.
    """
    batched_scl_input_facts = []

    for scl_tp, cate_pred_tp, unary_pred_tp, binary_pred_tp, gpt_spec in zip(
            batched_scl_tps, batched_cate_pred_scl, batched_unary_pred_scl,
            batched_binary_pred_scl, batched_gpt_specs):
        scl_input_facts = {}
        scl_input_facts.update(scl_tp)
        scl_input_facts['name'] = cate_pred_tp
        scl_input_facts['sg_unary_atom'] = unary_pred_tp
        scl_input_facts['sg_binary_atom'] = binary_pred_tp
        # Spec variables are numbered from 1, one per spec argument.
        scl_input_facts['variable'] = [(i,) for i in range(1, len(gpt_spec['args']) + 1)]
        scl_input_facts['spec'] = [gpt_spec['prog']]
        batched_scl_input_facts.append(scl_input_facts)

    return process_batched_facts(batched_scl_input_facts)


def to_scl_file(common_scl, rules, tuples, file_path):
    """Write a Scallop program (shared prelude, rules, fact relations) to ``file_path``.

    Args:
        common_scl: Text placed verbatim at the start of the program.
        rules: Rule texts. Each is emitted as ``rel <rule>`` on its own line.
        tuples: Mapping relation name -> iterable of tuples. Each relation is
            emitted as ``rel <name> = {<t1>, <t2>, ...}``, where each tuple is
            ``str(t)`` with single quotes replaced by double quotes.
        file_path: Destination path. It is overwritten and written as UTF-8.

    Returns:
        The full program text that was written.
    """
    parts = [common_scl]
    for rule in rules:
        parts.append('rel ' + rule + '\n')

    for tuple_name, tps in tuples.items():
        # The old `tuple_name in non_prob_gpt_preds` test chose between two
        # identical expressions, so every relation is rendered the same way.
        # NOTE(review): probabilistic entries are therefore written as plain
        # tuples such as (prob, (...)) rather than the `prob::(...)` form used
        # by construct_scl_facts — confirm this is intended.
        body = ', '.join(str(t).replace("'", '"') for t in tps)
        parts.append('rel ' + tuple_name + ' = {' + body + '}\n')

    scl_file = ''.join(parts)
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(scl_file)

    return scl_file

def obtain_stats(pred_dict):
    """Compute per-predicate and pooled classification metrics.

    Args:
        pred_dict: Mapping predicate name -> ``{'gt': labels, 'pred': labels}``,
            e.g. the output of ``combine_baseline_pred_dict_ls``.

    Returns:
        ``(all_accu, new_result)``:

        * ``all_accu``: accuracy over all predicates' labels pooled together;
        * ``new_result``: dict name -> ``{'accu', 'recall', 'precision', 'f1',
          'count'}``, computed with ``sklearn.metrics`` on that predicate's
          labels.
    """
    new_result = {}
    all_gt = []
    all_pred = []
    for pred_name, pred_info in pred_dict.items():
        gt = pred_info['gt']
        pred = pred_info['pred']
        all_gt.extend(gt)
        all_pred.extend(pred)
        new_result[pred_name] = {
            'accu': metrics.accuracy_score(gt, pred),
            'recall': metrics.recall_score(gt, pred),
            'precision': metrics.precision_score(gt, pred),
            'f1': metrics.f1_score(gt, pred),
            'count': len(gt),
        }

    # Pass (y_true, y_pred) in sklearn's documented order, consistent with the
    # per-predicate calls. Accuracy is symmetric, so the value is unchanged.
    all_accu = metrics.accuracy_score(all_gt, all_pred)
    return all_accu, new_result

def get_report(stats):
    """Format one CSV-like line per predicate from ``obtain_stats``-style metrics.

    Each line is ``"<name>, <count share>, <precision>, <recall>,<f1>"``. The
    count share is that predicate's count divided by the total count over all
    predicates.
    """
    total_number = sum(info['count'] for info in stats.values())
    return [
        f"{name}, {info['count']/total_number}, {info['precision']}, {info['recall']},{info['f1']}"
        for name, info in stats.items()
    ]

def calculate_iou(span1, span2):
    """Intersection-over-union of two same-shaped indicator arrays/tensors.

    The intersection is ``sum(span1 * span2)``; the union is
    ``sum(span1) + sum(span2) - intersection``. Returns ``0`` when the union
    is not positive.
    """
    overlap = (span1 * span2).sum()
    total = span1.sum() + span2.sum() - overlap
    if total > 0:
        return overlap / total
    return 0