from collections import defaultdict
import enum
import logging
from re import L, S
from select import select
from time import time
import torch
import json
import torch.nn as nn
import torch.nn.functional as F
import scallopy
import os

# type video_mugen_action(usize, String)
# type video_mugen_horizontal_dir(usize, String)
# type video_mugen_vertical_dir(usize, String)
# type video_mugen_is_dead(usize, bool)
# type video_mugen_kill_monster(usize, String)
# type video_mugen_on_ladder(usize, bool)
# type video_mugen_has_shield(usize, bool)

# // video
# type video_mugen_collect_item(usize, String)
# type video_mugen_jump_toward(usize, String)

# // text
# type text_mugen_action (usize, String)
# type text_mugen_direction (usize, String)
# type text_mugen_jump_direction (usize, String)
# type text_mugen_killed_by(usize, String)
# type text_mugen_kill_monster(usize, String)
# type text_mugen_collected_items (usize, String)

import torch
from torch import nn
import torch.nn.functional as F


# Module logger. It also carries ad-hoc float counters: `logger.stats` is a
# defaultdict(float) that code in this module adds to (e.g. "T_Reason"), and
# `logger.reset_stats()` empties it.
logger = logging.getLogger("dolphin.alignmodule")
logger.stats = defaultdict(float)


def _reset_logger_stats():
    """Remove every accumulated counter from ``logger.stats``."""
    logger.stats.clear()


logger.reset_stats = _reset_logger_stats

from modules import VideoEncoder, ProjectionHead, Projection, MLPClassifier

# Symbolic label vocabularies decoded from video/text embeddings. "none" is
# the fallback label used when a text mentions nothing from a vocabulary.
action_list = [
    "crouch", "stand", "walk", "climb", "jump", "collect", "die", "kill",
]
horizontal_directions = ["left", "right", "none"]
vertical_directions = ["up", "down", "none"]
mugen_is_dead = [True, False]

monsters = [
    "gear", "barnacle", "face", "slime", "mouse",
    "snail", "ladybug", "worm", "frog", "bee", "none",
]
collectables = ["coin", "gem", "none"]

# (relation suffix, label list) pairs shared by the text_ and video_ relations.
_RELATION_LABELS = (
    ("mugen_action", action_list),
    ("mugen_horizontal_dir", horizontal_directions),
    ("mugen_vertical_dir", vertical_directions),
    ("mugen_kill_monster", monsters),
    ("mugen_kill_by_monster", monsters),
    ("mugen_collect_item", collectables),
)

# Scallop relation name -> label list (the very same list objects as above),
# all text_ relations first, then all video_ relations.
actions = {
    modality + "_" + suffix: labels
    for modality in ("text", "video")
    for suffix, labels in _RELATION_LABELS
}


def split_n_per_list(l, n):
    """Yield consecutive slices of ``l`` holding ``n`` items each.

    The final slice is shorter when ``len(l)`` is not a multiple of ``n``.
    Works for any sliceable sequence (lists, strings, tensors).
    """
    yield from (l[offset:offset + n] for offset in range(0, len(l), n))

class CLIPModel(nn.Module):
    """Encoders that map batch inputs to L2-normalised embeddings.

    Only the video path is constructed here (when ``video_enc`` is true):
    ``visual_encoder`` followed by ``image_projection``. The audio and text
    methods read ``audial_encoder``/``audio_projection`` and
    ``text_encoder``/``text_projection``, which ``__init__`` never creates;
    unless those attributes are attached elsewhere, the methods raise
    AttributeError.
    """

    def __init__(self, video_enc=False, pretrained=False, trainable=False,):
        super().__init__()
        self.video_enc = video_enc
        if not self.video_enc:
            return
        self.visual_encoder = VideoEncoder(pretrained=pretrained, trainable=trainable)
        self.image_projection = Projection(self.visual_encoder.embedding_dim)

    def get_video_embedding(self, batch):
        """Encode ``batch["video"]``, project it, and L2-normalise the last dim."""
        features = self.visual_encoder(batch["video"])
        return F.normalize(self.image_projection(features), dim=-1)

    def get_audio_embedding(self, batch):
        """Encode ``batch["audio"]``, project it, and L2-normalise the last dim."""
        features = self.audial_encoder(batch["audio"])
        return F.normalize(self.audio_projection(features), dim=-1)

    def get_text_embedding(self, batch):
        """Encode ``batch['text']`` and project it into the shared embedding space."""
        features = self.text_encoder(batch['text'])
        return F.normalize(self.text_projection(features), dim=-1)

def get_video_scl_tuples(prediction, action_list, idxes):
    """Convert per-frame class scores into probabilistic Scallop facts.

    Args:
        prediction: Row-indexable scores; row ``r`` holds one score per label.
        action_list: Label names, aligned with the score columns.
        idxes: ``(start, end)`` row ranges, one per data point.

    Returns:
        One list per data point of ``(prob, (frame_id, label))`` facts, where
        ``frame_id`` restarts at 0 inside every range.
    """
    return [
        [
            (prob, (frame_id, action_list[label_idx]))
            for frame_id, frame_scores in enumerate(prediction[start:end])
            for label_idx, prob in enumerate(frame_scores)
        ]
        for start, end in idxes
    ]

def get_text_scl_tuples(prediction, action_list, idxes, multi_text):
    """Convert per-text class scores into probabilistic Scallop facts.

    Args:
        prediction: Row-indexable scores; one row per text, one column per label.
        action_list: Label names, aligned with the score columns.
        idxes: ``(start, end)`` row ranges, one per data point.
        multi_text: When true, facts carry the data-point id and every data
            point receives the same flattened fact list (the identical list
            object is repeated ``len(idxes)`` times).

    Returns:
        A list, per data point, of ``(prob, (frame_id, label))`` facts, or
        ``(prob, (data_id, frame_id, label))`` facts in multi-text mode.
    """
    count = len(idxes)
    per_datapoint = []
    for data_id, (start, end) in enumerate(idxes):
        facts = []
        for frame_id, frame_scores in enumerate(prediction[start:end]):
            for label_idx, prob in enumerate(frame_scores):
                label = action_list[label_idx]
                key = (data_id, frame_id, label) if multi_text else (frame_id, label)
                facts.append((prob, key))
        per_datapoint.append(facts)

    if not multi_text:
        return per_datapoint
    flattened = [fact for facts in per_datapoint for fact in facts]
    return [flattened] * count

def get_gt_text_scl_tuples(text_gt, idxes, multi_text):
    """Turn ground-truth text annotations into non-probabilistic Scallop facts.

    Args:
        text_gt: Mapping ``text_id -> list of {relation_name: label}`` dicts
            (as produced by ``AlignModule.get_gt_text``).
        idxes: ``(start, end)`` text-id ranges, one per data point.
        multi_text: Must be false; multi-text layout is not supported for
            ground-truth facts.

    Returns:
        ``{relation_name: [facts_for_dp0, facts_for_dp1, ...]}`` where each
        fact is ``(frame_id, label)`` and ``frame_id`` enumerates the
        annotation dicts of that data point's texts in order. Every relation
        list has exactly one entry per data point (an empty list when that
        data point has no fact for the relation), so positions stay aligned
        with the batch when the lists are later indexed by data-point id.
    """
    assert not multi_text

    per_datapoint = []
    for data_id, (start_idx, end_idx) in enumerate(idxes):
        gt_result_ls = []
        for i in range(start_idx, end_idx):
            gt_result_ls += text_gt[i]

        facts = {}
        for frame_id, gt_result in enumerate(gt_result_ls):
            for rel_name, gt_action in gt_result.items():
                # The multi-text form is only reachable when asserts are disabled (-O).
                fact = (data_id, frame_id, gt_action) if multi_text else (frame_id, gt_action)
                facts.setdefault(rel_name, []).append(fact)
        per_datapoint.append(facts)

    # Relation names in first-seen order (same key order as before).
    rel_names = dict.fromkeys(rel for facts in per_datapoint for rel in facts)
    return {rel: [facts.get(rel, []) for facts in per_datapoint] for rel in rel_names}

def to_scl_string(result, batch_id=3):
    """Render one batch element of ``result`` as Scallop ``rel`` declarations.

    Debugging helper.

    Args:
        result: ``{relation_name: per_batch_facts}``; ``per_batch_facts[batch_id]``
            is a list of facts. Facts whose first element is a ``torch.Tensor``
            are treated as ``(prob, tuple)`` pairs and written as ``prob::tuple``.
        batch_id: Which batch element to render (defaults to 3, the index
            previously hard-coded here).

    Returns:
        Newline-joined declarations such as ``rel r={0.5::(0, "walk")}``.
        Relations with no facts for ``batch_id`` yield a ``// name: no facts``
        comment line instead of failing.
    """
    def scl_tuple(tp):
        # Python reprs use single quotes; Scallop string literals need double quotes.
        return str(tp).replace("'", '"')

    scl_strings = []
    for rel_name, batched_tuples in result.items():
        tuples = batched_tuples[batch_id]
        if not tuples:
            scl_strings.append('// ' + rel_name + ': no facts')
            continue
        if isinstance(tuples[0][0], torch.Tensor):
            facts = [str(prob.item()) + '::' + scl_tuple(tp) for prob, tp in tuples]
        else:
            facts = [scl_tuple(tp) for tp in tuples]
        scl_strings.append('rel ' + rel_name + '={' + ', '.join(facts) + '}')
    return '\n'.join(scl_strings)

def obtain_prediction(result, text_idxes, top_k=3, label_names=None):
    """Collect the top-k labels of every prediction row, grouped per data point.

    Args:
        result: ``{relation_name: scores}`` with 2-D ``scores`` (rows x classes).
            Must contain ``"video_mugen_action"``, whose row count fixes how many
            video rows belong to each data point.
        text_idxes: ``(start, end)`` row ranges of each data point's texts; its
            length is the batch size.
        top_k: Number of labels kept per row (clamped to the number of classes).
        label_names: ``{relation_name: label list}``; defaults to the module-level
            ``actions`` map.

    Returns:
        ``{relation_name: [rows_for_dp0, rows_for_dp1, ...]}`` where each row is a
        list of ``(row_index, prob, label)`` in descending score order. Relations
        whose name contains ``"text"`` are split by ``text_idxes``; the others are
        split into equal consecutive chunks of video rows.
    """
    if label_names is None:
        label_names = actions

    predictions = {}
    batch_size = len(text_idxes)
    video_counts = result["video_mugen_action"].shape[0] // batch_size

    for rel_name, preds in result.items():
        # torch.topk raises if k exceeds the class dimension.
        k = min(top_k, preds.shape[1])
        selected_values, selected_indices = torch.topk(preds, k=k, dim=1)
        labels = label_names[rel_name]
        k_selected = [
            [(row, prob.item(), labels[int(index)]) for prob, index in zip(k_probs, k_indexes)]
            for row, (k_probs, k_indexes) in enumerate(zip(selected_values, selected_indices))
        ]

        if "text" in rel_name:
            per_dp = predictions.setdefault(rel_name, [])
            per_dp.extend(k_selected[start_idx:end_idx] for start_idx, end_idx in text_idxes)
        else:
            predictions[rel_name] = [
                k_selected[video_counts * i: video_counts * (i + 1)] for i in range(batch_size)
            ]

    return predictions

def combine_text_and_video(text_results, video_results):
    """Build one Scallop query per (video, text) pair of the batch.

    Args:
        text_results: ``{relation_name: per_text_facts}``; the length of the
            first entry is taken as the batch size.
        video_results: ``{relation_name: per_video_facts}``.

    Returns:
        ``(all_pos, queries)``: ``all_pos`` lists ``(vid, tid)`` pairs with the
        video index varying slowest, and ``queries[name]`` holds, for each pair
        in that order, the text facts of ``tid`` or the video facts of ``vid``.
    """
    first_text_rel = list(text_results.values())[0]
    batch_size = len(first_text_rel)

    combined = {name: [] for name in text_results}
    combined.update({name: [] for name in video_results})

    pairs = [(vid, tid) for vid in range(batch_size) for tid in range(batch_size)]
    for vid, tid in pairs:
        for name, per_text in text_results.items():
            combined[name].append(per_text[tid])
        for name, per_video in video_results.items():
            combined[name].append(per_video[vid])

    return pairs, combined

class AlignModule(nn.Module):
    """Score MUGEN video/text pairs with a Scallop program over decoded symbols.

    Video embeddings from ``CLIPModel`` (and, in ``forward_both``, text
    embeddings as well) are decoded by MLP heads into per-frame scores over
    the module-level vocabularies in ``actions``. Those scores become
    probabilistic Scallop facts. The program loaded from
    ``scl/<scl_filename>`` then derives ``text_video_match`` and
    ``too_many_consecutive_text``, which are reshaped into
    ``(batch_size, batch_size)`` matrices indexed video x text.
    """

    def __init__(self, batch_size, video_enc=False, audio_enc=False, text_enc=False, pretrained=False, trainable=False,
                 text_embedding=768, video_decoder_layers=2, text_decoder_layers=2, dropout_rate=0.3, constraint_weight=0.1,
                 provenance="diffaddmultprob", scl_filename=None, train_top_k=5, test_top_k=5, debug=True, multi_text=True,
                 alternative_train_freq=10, load_path=None, gt_text=False, pred_save_dir=None, constraint_violation=False):
        """Build (or load) the networks and compile the Scallop reasoning function.

        Args:
            batch_size: Data points per batch; sizes the Scallop output mapping
                when ``multi_text`` is true.
            video_enc, pretrained, trainable: Forwarded to ``CLIPModel``.
            audio_enc, text_enc, text_embedding: Accepted for interface
                compatibility; not used by this constructor.
            video_decoder_layers, text_decoder_layers, dropout_rate: Depth and
                dropout of the ``MLPClassifier`` heads.
            constraint_weight: Stored on the instance; not used in this class.
            provenance, train_top_k, test_top_k: Forwarded to
                ``scallopy.ScallopContext``.
            scl_filename: Scallop program resolved inside the ``scl`` directory
                next to this source file. Required.
            debug: Compile the reasoner with ``dispatch="single"`` and
                ``debug_provenance=True``.
            multi_text: Use the multi-text fact layout (text facts carry a data
                point id and are shared across the batch) in ``forward_both``
                instead of pairing each video with each text. ``forward`` and
                ``predict`` require it to be false.
            alternative_train_freq: Number of forward-style calls between
                advances of ``current_training_id``.
            load_path: If given, restore the CLIP model and video decoders from
                a checkpoint written by ``save`` (text decoders are then not
                created).
            gt_text: Mark the text relations as non-probabilistic.
            pred_save_dir: Stored on the instance; not used in this class.
            constraint_violation: Return ``too_many_consecutive_text`` instead
                of a zero matrix as the second output.

        Raises:
            ValueError: If ``scl_filename`` is not given.
        """
        super().__init__()

        if scl_filename is None:
            raise ValueError("scl_filename is required: name of the Scallop program inside the scl/ directory")

        if load_path is not None:
            self.load(load_path)
        else:
            self.clip_model = CLIPModel(video_enc=video_enc, pretrained=pretrained,
                                        trainable=trainable)
            self._build_decoders(video_decoder_layers, text_decoder_layers, dropout_rate)

        self.debug = debug
        self.pred_save_dir = pred_save_dir
        self.multi_text = multi_text
        self.alternative_train_freq = alternative_train_freq
        self.constraint_weight = constraint_weight
        self.constraint_violation = constraint_violation

        self.scallop_ctx = scallopy.ScallopContext(provenance=provenance, train_k=train_top_k, test_k=test_top_k)
        # abspath normalises ".../<this file>/../scl/<name>" to ".../scl/<name>".
        self.scl_file_name = os.path.abspath(os.path.join(os.path.abspath(__file__), "../scl/" + scl_filename))
        self.scallop_ctx.import_file(self.scl_file_name)
        self.scallop_ctx.set_non_probabilistic(["text_start", "text_end", "video_start", "video_end"])

        def output_mapping():
            # multi_text: one output slot per data point; otherwise a single 0-arity fact.
            return list(range(batch_size)) if self.multi_text else [()]

        reason_kwargs = {"retain_graph": True}
        if self.debug:
            reason_kwargs.update(dispatch="single", debug_provenance=True)
        self.reason = self.scallop_ctx.forward_function(
            output_mappings={"text_video_match": output_mapping(),
                             "too_many_consecutive_text": output_mapping()},
            **reason_kwargs)

        if gt_text:
            self.scallop_ctx.set_non_probabilistic(["text_mugen_action",
            "text_mugen_horizontal_dir", "text_mugen_vertical_dir", "text_mugen_kill_monster",
            "text_mugen_kill_by_monster", "text_mugen_collect_item"])
        self.processed_batch = 0
        self.current_training_id = 0

        # Modality name -> modules that are switched to train mode together.
        self.modal2models = {
            "video":
            [self.clip_model.visual_encoder,
             self.video_action_decoder,
             self.video_direction_decoder,
             self.video_jump_direction_decoder,
             self.video_is_dead_decoder,
             self.video_killed_monster_decoder,
             self.video_killed_by_monster_decoder,
             self.video_collects_item_decoder],
            }

    def _build_decoders(self, video_decoder_layers, text_decoder_layers, dropout_rate):
        """Create one MLPClassifier head (input_dim=256, latent_dim=64) per attribute and modality."""
        def head(labels, n_layers):
            return MLPClassifier(input_dim=256, latent_dim=64, output_dim=len(labels),
                                 n_layers=n_layers, dropout_rate=dropout_rate)

        # Creation order is kept fixed so parameter initialisation and
        # state_dict ordering match earlier runs/checkpoints.
        self.video_action_decoder = head(action_list, video_decoder_layers)
        self.video_direction_decoder = head(horizontal_directions, video_decoder_layers)
        self.video_jump_direction_decoder = head(vertical_directions, video_decoder_layers)
        self.video_is_dead_decoder = head(mugen_is_dead, video_decoder_layers)
        self.video_killed_monster_decoder = head(monsters, video_decoder_layers)
        self.video_killed_by_monster_decoder = head(monsters, video_decoder_layers)
        self.video_collects_item_decoder = head(collectables, video_decoder_layers)

        self.text_action_decoder = head(action_list, text_decoder_layers)
        self.text_direction_decoder = head(horizontal_directions, text_decoder_layers)
        self.text_jump_direction_decoder = head(vertical_directions, text_decoder_layers)
        self.text_is_dead_decoder = head(mugen_is_dead, text_decoder_layers)
        self.text_killed_monster_decoder = head(monsters, text_decoder_layers)
        self.text_killed_by_monster_decoder = head(monsters, text_decoder_layers)
        self.text_collects_item_decoder = head(collectables, text_decoder_layers)

    def save(self, save_path):
        """Pickle the CLIP model and the video decoders to ``save_path`` with ``torch.save``."""
        nn_info = {"clip_model": self.clip_model,
        "video_action_decoder": self.video_action_decoder,
        "video_direction_decoder": self.video_direction_decoder,
        "video_jump_direction_decoder": self.video_jump_direction_decoder,
        "video_is_dead_decoder": self.video_is_dead_decoder,
        "video_killed_monster_decoder": self.video_killed_monster_decoder,
        "video_killed_by_monster_decoder": self.video_killed_by_monster_decoder,
        "video_collects_item_decoder": self.video_collects_item_decoder,
        }

        torch.save(nn_info, save_path)

    def load(self, save_path):
        """Restore the CLIP model and video decoders written by ``save``.

        The checkpoint holds pickled ``nn.Module`` objects, so only load files
        from a trusted source. Recent PyTorch releases default ``torch.load``
        to ``weights_only=True``, which cannot unpickle whole modules, so full
        unpickling is requested explicitly.
        """
        try:
            nn_info = torch.load(save_path, weights_only=False)
        except TypeError:
            # PyTorch releases that predate the ``weights_only`` keyword.
            nn_info = torch.load(save_path)
        self.clip_model = nn_info["clip_model"]
        self.video_action_decoder = nn_info["video_action_decoder"]
        self.video_direction_decoder = nn_info["video_direction_decoder"]
        self.video_jump_direction_decoder = nn_info["video_jump_direction_decoder"]
        self.video_is_dead_decoder = nn_info["video_is_dead_decoder"]
        self.video_killed_monster_decoder = nn_info["video_killed_monster_decoder"]
        self.video_killed_by_monster_decoder = nn_info["video_killed_by_monster_decoder"]
        self.video_collects_item_decoder = nn_info["video_collects_item_decoder"]

    def set_train(self, model_ls, is_train):
        """Put every module in ``model_ls`` into train mode (or eval mode)."""
        for model in model_ls:
            if is_train:
                model.train()
            else:
                model.eval()

    def toggle_training_model(self):
        """Train the modality selected by ``current_training_id``; eval all others."""
        current_train_model = list(self.modal2models)[self.current_training_id]
        for modal, models in self.modal2models.items():
            self.set_train(models, modal == current_train_model)

    # ----- shared helpers for forward_both / forward / predict -----

    def _video_layout(self, batch):
        """Return ``(batch_size, video_split, video_idxes)`` for ``batch``.

        ``batch['video']`` is split into ``batch_size`` equal consecutive chunks
        of ``video_split`` rows along dim 0. ``video_idxes`` holds each chunk's
        ``[start, end)`` range.
        """
        batch_size = len(batch['text_idx'])
        video_split = batch['video'].shape[0] // batch_size
        video_idxes = [(video_split * i, video_split * (i + 1)) for i in range(batch_size)]
        return batch_size, video_split, video_idxes

    def _boundary_facts(self, text_idx, batch_size, video_split):
        """Build the non-probabilistic text_start/text_end/video_start/video_end facts."""
        text_starts, text_ends, video_starts, video_ends = [], [], [], []
        for data_id, (text_start, text_end) in enumerate(text_idx):
            if self.multi_text:
                text_starts.append((data_id, 0))
                text_ends.append((data_id, text_end - text_start - 1))
            else:
                text_starts.append([(0,)])
                text_ends.append([(text_end - text_start - 1,)])
            video_starts.append([(0,)])
            video_ends.append([(video_split - 1,)])
        if self.multi_text:
            # Every data point sees the boundaries of all texts in the batch.
            text_starts = [text_starts] * batch_size
            text_ends = [text_ends] * batch_size
        return text_starts, text_ends, video_starts, video_ends

    def _decode_video(self, video_embedding):
        """Apply the video heads (all except ``video_is_dead_decoder``); keys are relation names."""
        return {
            'video_mugen_action': self.video_action_decoder(video_embedding),
            'video_mugen_horizontal_dir': self.video_direction_decoder(video_embedding),
            'video_mugen_vertical_dir': self.video_jump_direction_decoder(video_embedding),
            'video_mugen_kill_monster': self.video_killed_monster_decoder(video_embedding),
            'video_mugen_kill_by_monster': self.video_killed_by_monster_decoder(video_embedding),
            'video_mugen_collect_item': self.video_collects_item_decoder(video_embedding),
        }

    def _decode_text(self, text_embedding):
        """Apply the text heads (all except ``text_is_dead_decoder``); keys are relation names."""
        return {
            'text_mugen_action': self.text_action_decoder(text_embedding),
            'text_mugen_horizontal_dir': self.text_direction_decoder(text_embedding),
            'text_mugen_vertical_dir': self.text_jump_direction_decoder(text_embedding),
            'text_mugen_kill_monster': self.text_killed_monster_decoder(text_embedding),
            'text_mugen_kill_by_monster': self.text_killed_by_monster_decoder(text_embedding),
            'text_mugen_collect_item': self.text_collects_item_decoder(text_embedding),
        }

    def _video_results(self, video_logits, video_idxes, video_starts, video_ends):
        """Turn video scores into per-data-point Scallop facts plus boundary facts."""
        results = {rel: get_video_scl_tuples(logits, actions[rel], video_idxes)
                   for rel, logits in video_logits.items()}
        results['video_start'] = video_starts
        results['video_end'] = video_ends
        return results

    def _read_reasoning(self, pred, batch_size):
        """Reshape the reasoner outputs into (video x text) matrices."""
        pred_match = pred['text_video_match'].reshape(batch_size, batch_size)
        if self.constraint_violation:
            pred_constraint_violation = pred['too_many_consecutive_text'].reshape(batch_size, batch_size)
        else:
            pred_constraint_violation = torch.zeros_like(pred_match)
        return pred_match, pred_constraint_violation

    def _advance_schedule(self):
        """Count one processed batch; switch modality every ``alternative_train_freq`` batches."""
        self.processed_batch = (self.processed_batch + 1) % self.alternative_train_freq
        if self.processed_batch == 0:
            self.current_training_id = (self.current_training_id + 1) % len(self.modal2models)

    def forward_both(self, batch):
        """Align predicted video facts with predicted (not ground-truth) text facts.

        Requires ``clip_model`` to provide text embeddings (see ``CLIPModel``).

        Returns:
            ``(pred_match, pred_constraint_violation)``, each ``(batch_size, batch_size)``.
        """
        # TODO: maybe use a window rather than hard split?
        text_embedding = self.clip_model.get_text_embedding(batch)
        video_embedding = self.clip_model.get_video_embedding(batch)

        video_logits = self._decode_video(video_embedding)
        text_logits = self._decode_text(text_embedding)

        batch_size, video_split, video_idxes = self._video_layout(batch)
        text_starts, text_ends, video_starts, video_ends = self._boundary_facts(
            batch['text_idx'], batch_size, video_split)

        text_results = {rel: get_text_scl_tuples(logits, actions[rel], batch['text_idx'], self.multi_text)
                        for rel, logits in text_logits.items()}
        text_results['text_start'] = text_starts
        text_results['text_end'] = text_ends
        video_results = self._video_results(video_logits, video_idxes, video_starts, video_ends)

        if self.multi_text:
            queries = {**text_results, **video_results}
        else:
            _, queries = combine_text_and_video(text_results, video_results)

        pred = self.reason(**queries)
        pred_match, pred_constraint_violation = self._read_reasoning(pred, batch_size)
        self._advance_schedule()
        return pred_match, pred_constraint_violation

    def devide_batch(self, batch):
        """Split a collated batch into one dict per data point.

        ``text`` becomes a one-element list holding the data point's slice of
        ``batch['text']``, and ``text_idx`` is re-based to start at 0.
        ``video`` is cut into consecutive chunks of ``video_split`` rows. Any
        other key is indexed per data point.
        """
        batch_size = len(batch['text_idx'])
        single_dps = [{} for _ in range(batch_size)]
        video_split = int(batch['video'].shape[0] / batch_size)

        for dp_ct, (start_idx, end_idx) in enumerate(batch['text_idx']):
            single_dps[dp_ct]['text_idx'] = [(0, end_idx - start_idx)]
            single_dps[dp_ct]['text'] = [batch['text'][start_idx: end_idx]]

        for k, v in batch.items():
            if k == 'text' or k == 'text_idx':
                continue
            if k == 'video':
                for dp_ct, single_video in enumerate(split_n_per_list(v, video_split)):
                    single_dps[dp_ct][k] = single_video
                continue
            for dp_ct, small_batch_values in enumerate(v):
                single_dps[dp_ct][k] = small_batch_values

        return single_dps

    def _align_with_gt_text(self, batch, record_time):
        """Reason over predicted video facts against ground-truth text facts.

        When ``record_time`` is true, the time spent pairing and reasoning is
        added to ``logger.stats['T_Reason']``.
        """
        text_gt = self.get_gt_text(batch)
        text_results = get_gt_text_scl_tuples(text_gt, batch['text_idx'], self.multi_text)

        # TODO: maybe use a window rather than hard split?
        video_embedding = self.clip_model.get_video_embedding(batch)
        video_logits = self._decode_video(video_embedding)

        batch_size, video_split, video_idxes = self._video_layout(batch)
        text_starts, text_ends, video_starts, video_ends = self._boundary_facts(
            batch['text_idx'], batch_size, video_split)
        text_results['text_start'] = text_starts
        text_results['text_end'] = text_ends
        video_results = self._video_results(video_logits, video_idxes, video_starts, video_ends)

        t = time()
        _, queries = combine_text_and_video(text_results, video_results)
        pred = self.reason(**queries)
        if record_time:
            logger.stats['T_Reason'] += time() - t

        pred_match, pred_constraint_violation = self._read_reasoning(pred, batch_size)
        self._advance_schedule()
        return pred_match, pred_constraint_violation

    # Just calculate the probability, no need to backprob
    def predict(self, batch, n=1):
        """Score ``batch`` like ``forward`` but without recording reasoning time.

        ``n`` is accepted for interface compatibility and unused. Gradients are
        not disabled here; wrap the call in ``torch.no_grad()`` if they are not
        needed.

        Returns:
            ``(pred_match, pred_constraint_violation)``, each ``(batch_size, batch_size)``.
        """
        return self._align_with_gt_text(batch, record_time=False)

    def get_one_aspect(self, text, action_ls, default="none"):
        """Return every label of ``action_ls`` found inside a word of ``text``.

        Matching is substring-based per space-separated word (e.g. "walks"
        matches "walk"). "killed by" / "killed_by" is first rewritten to "die",
        so that being killed is read as the "die" action rather than "kill".
        Returns ``[default]`` when nothing matches.
        """
        processed_text = text.replace("killed by", "die").replace("killed_by", "die")
        mentioned_elements = [action for word in processed_text.split(' ')
                              for action in action_ls if action in word]
        return mentioned_elements or [default]

    def get_gt_text(self, batch):
        """Parse every caption in ``batch['text']`` into ground-truth relation labels.

        Returns:
            ``{text_id: [description, ...]}``, with one description dict per
            mentioned collectable (at least one, ``"none"`` if no collectable is
            mentioned). Keys are the ``text_mugen_*`` relation names.
        """
        batched_texts = {}
        for text_id, text in enumerate(batch['text']):

            action = self.get_one_aspect(text, action_list)
            assert(not action == ["none"] and len(action) == 1)
            hori_dir = self.get_one_aspect(text, horizontal_directions)
            verti_dir = self.get_one_aspect(text, vertical_directions)
            collectable_ls = self.get_one_aspect(text, collectables)
            monster = self.get_one_aspect(text, monsters)

            # Only collectable can be more than 1
            assert(len(hori_dir) == 1 and len(verti_dir) == 1)

            # With several monsters mentioned, only the first is used.
            monster = monster[0]
            if "killed by" in text or "killed_by" in text:
                # Mugen was killed: the monster is the killer, nothing was killed.
                kill_by, kill = monster, "none"
            elif "killed" in text:
                kill_by, kill = "none", monster
            else:
                kill_by, kill = "none", "none"

            batched_texts[text_id] = [
                {
                    'text_mugen_action': action[0],
                    'text_mugen_horizontal_dir': hori_dir[0],
                    'text_mugen_vertical_dir': verti_dir[0],
                    'text_mugen_collect_item': collectable,
                    'text_mugen_kill_by_monster': kill_by,
                    'text_mugen_kill_monster': kill,
                }
                for collectable in collectable_ls
            ]
        return batched_texts

    # Forward train video only
    def forward(self, batch):
        """Train the video side against ground-truth text facts parsed from captions.

        Requires ``multi_text`` to be false. Reasoning time is accumulated in
        ``logger.stats['T_Reason']``.

        Returns:
            ``(pred_match, pred_constraint_violation)``, each
            ``(batch_size, batch_size)`` with rows indexing videos and columns
            indexing texts.
        """
        return self._align_with_gt_text(batch, record_time=True)
