from typing import List, Tuple
import torch
from torch.distributions import Categorical

import os
import sys
import multiprocessing as mp
import gc
import time

# from cmd_args import cmd_args
import numpy as np
from tqdm import tqdm
from .preprocess import Query
from models.sg_model import SceneGraphModel
from preprocessing.utilities import get_body_atoms, getArguments, getPredicate
from models.idx2word import Idx2Word
import torch

def get_domains(query: str, idx2word: Idx2Word):
    """Work out which label candidates a query needs predictions for.

    Walks the atoms returned by ``get_body_atoms(query, ":-")`` and dispatches
    on each atom's predicate:

    * ``name`` -- the name candidates become ``idx2word.get_names()``;
    * ``attr`` -- ``arguments[0]`` of the atom is appended (once per atom, so
      repeated attributes are kept);
    * ``rela`` -- the relation candidates become ``idx2word.get_relas()``.

    Atoms with any other predicate are ignored.

    Returns:
        ``(name_choices, attr_choices, rela_choices)``; each stays an empty
        list when no atom with the corresponding predicate appears.
    """
    names, attrs, relas = [], [], []
    for body_atom in get_body_atoms(query, ":-"):
        kind = getPredicate(body_atom)
        args = getArguments(body_atom)
        if kind == "name":
            # Types are mutually exclusive, so every known name is a candidate.
            names = idx2word.get_names()
        elif kind == "attr":
            attrs.append(args[0])
        elif kind == "rela":
            # Relations are scored over the whole vocabulary, not just the query's.
            relas = idx2word.get_relas()
    return names, attrs, relas

def get_quoted_query(query: str) -> str:
    """Quote the constant arguments of every atom in a Datalog-style query.

    Inside an argument list, a token whose first character is uppercase is a
    variable and is copied unchanged; any other non-empty token is a constant
    and is wrapped in double quotes, e.g. ``name(O,dog)`` -> ``name(O,"dog")``.

    * Whitespace around an argument is preserved but kept outside the quotes.
    * Commas between atoms (outside parentheses) are not argument separators.
    * Empty slots such as ``foo()`` are left untouched.
    * A token that already starts with ``"`` is copied unchanged.

    Args:
        query: query text, e.g. ``target(O):-name(O,dog),attr(O,red).``

    Returns:
        The query with every constant argument double-quoted.
    """
    pieces = []
    length = len(query)
    i = 0
    in_args = False
    while i < length:
        char = query[i]
        pieces.append(char)
        i += 1
        if char == "(":
            in_args = True
        elif char == ")":
            in_args = False
            continue
        elif not (char == "," and in_args):
            continue

        # `char` just opened an argument slot: copy leading whitespace verbatim
        # so it cannot be mistaken for the first character of the token.
        while i < length and query[i].isspace():
            pieces.append(query[i])
            i += 1
        # Empty slot, variable (uppercase start) or already-quoted literal.
        if i == length or query[i] in ',)"' or query[i].isupper():
            continue

        # Constant: it runs up to the next ',' or ')' (or the end of the query).
        end = i
        while end < length and query[end] not in ",)":
            end += 1
        literal = query[i:end].rstrip()
        pieces.append(f'"{literal}"')
        # Trailing whitespace before the separator stays outside the quotes.
        pieces.append(query[i + len(literal):end])
        i = end
    return "".join(pieces)

# TODO: control training in CUDA/GPU/MP4
def get_batched_scene_graph_predictions(
    sg_models: SceneGraphModel,
    samples_in_batch,
    scene_graphs_and_features: List,
    queries: List[str],
    idx2word: Idx2Word,
    gpu=-1,
    softmax=True,
    return_dicts=False,
) -> Tuple[List, List]:
    """Score every object and object pair referenced by a batch of proofs.

    All objects, and all (subject, object) pairs, appearing in the proofs of
    the whole batch are each stacked into one feature matrix. Both matrices go
    through ``sg_models.forward`` in a single call. The resulting probabilities
    are then split back per sample using cumulative row offsets.

    Args:
        sg_models: scene-graph model. Its ``forward`` returns
            ``(name_probs, name_logits, attr_probs, attr_logits, rel_probs,
            rel_logits)``, and any ``*_probs`` may be ``None``.
        samples_in_batch: per-sample 5-tuples. Only the last two items are
            used: the object ids and the ``(sub, obj)`` id pairs occurring in
            the sample's proofs.
        scene_graphs_and_features: per-sample dicts providing:
            ``"object_ids"``; ``"object_feature"`` (numpy arrays aligned with
            ``"object_ids"``); and ``["scene_graph"]["bboxes"]`` (indexable by
            ``int(object_id)``).
        queries: per-sample query strings.
        idx2word: vocabulary mapping names/attributes/relations to model
            output columns.
        gpu: CUDA device index. ``-1``, or no CUDA available, runs on CPU.
        softmax: accepted for API compatibility; not used by this function.
        return_dicts: selects the output layout (see Returns).

    Returns:
        ``(batched_tps, query_objects)``.

        ``query_objects[i]`` is ``get_quoted_query(queries[i])``.

        ``batched_tps[i]`` is a dict with keys ``"name"``, ``"attr"`` and
        ``"rela"``:

        * ``return_dicts=False``: each key maps to ``[facts]``, where ``facts``
          is a list of ``(prob, tuple)`` pairs. The tuples are ``(name, obj)``,
          ``(attr, obj)`` or ``(rela, sub, obj)``. The value is ``[[]]`` when
          the model returned no probabilities of that kind.
        * ``return_dicts=True``:
          ``"name"`` maps each object id to its slice of name probabilities.
          ``"rela"`` maps each ``(sub, obj)`` pair to its slice of relation
          probabilities (one entry per choice).
          ``"attr"`` maps each object id to one attribute score.
          Each of these dicts also has a ``"choices"`` key listing this
          sample's labels (for attributes, only those known to ``idx2word``).
          The value is ``[[]]`` when the sample has no facts of that kind.
    """
    device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() and gpu != -1 else "cpu")

    batched_obj_features = []
    batched_relation_features = []

    batched_attr_idxes = []
    batched_rela_idxes = []
    batched_name_tps = []
    batched_attr_tps = []
    batched_rela_tps = []
    batched_obj_ids = []
    batched_obj_id_pairs = []
    # Per-sample label lists. The dict output must report each sample's own
    # choices, not whatever the last query of the batch happened to use.
    batched_name_choices = []
    batched_attr_choices = []
    batched_rela_choices = []

    # Cumulative end offsets of each sample's rows in the flat feature lists.
    split_obj_ids = []
    split_rela_ids = []
    query_objects = []

    name_cls = False
    attr_cls = False
    for (_, _, _, object_ids_in_proofs, object_id_pairs_in_proofs), scene_graph_and_features, query in zip(
        samples_in_batch, scene_graphs_and_features, queries
    ):
        query_objects.append(get_quoted_query(query))
        # Names/relations: every vocabulary entry (mutually exclusive classes).
        # Attributes: only those mentioned in the query.
        name_choices, attr_choices, rela_choices = get_domains(query, idx2word)
        name_cls = name_cls or bool(name_choices)
        attr_cls = attr_cls or bool(attr_choices)

        batched_obj_ids.append(object_ids_in_proofs)
        batched_obj_id_pairs.append(object_id_pairs_in_proofs)

        all_object_ids = scene_graph_and_features["object_ids"]
        object_features = scene_graph_and_features["object_feature"]
        bboxes = scene_graph_and_features["scene_graph"]["bboxes"]
        # Keyed by str(id) so lookups work whatever type the ids arrive as.
        object_feature_dict = {
            str(object_id): feature
            for object_id, feature in zip(all_object_ids, object_features)
        }
        batched_obj_features += [object_feature_dict[str(o)] for o in object_ids_in_proofs]

        # Keep only attributes the vocabulary knows. This keeps every
        # (attr, obj) tuple aligned with the probability column picked by its
        # index; otherwise unknown attributes shift every later label.
        known_attrs = []
        known_attr_idxes = []
        for attr in attr_choices:
            attr_idx = idx2word.attr_to_idx(attr)
            if attr_idx is not None:
                known_attrs.append(attr)
                known_attr_idxes.append(attr_idx)
        batched_attr_idxes.append(known_attr_idxes)
        batched_attr_choices.append(known_attrs)

        # Object-major order matches the row-major flattening of the
        # (objects x labels) probability slices below.
        batched_attr_tps.append(
            [(attr, object_id) for object_id in object_ids_in_proofs for attr in known_attrs]
        )
        batched_name_tps.append(
            [(name, object_id) for object_id in object_ids_in_proofs for name in name_choices]
        )
        batched_name_choices.append(name_choices)

        candidate_rela_idxes = [idx2word.rela_to_idx(rela) for rela in rela_choices]
        batched_rela_idxes.append([x for x in candidate_rela_idxes if x is not None])
        batched_rela_choices.append(rela_choices)

        current_rela_tps = []
        for sub, obj in object_id_pairs_in_proofs:
            # Relation input: subject features, object features, subject bbox,
            # object bbox.
            batched_relation_features.append(
                np.concatenate(
                    [
                        object_feature_dict[str(sub)],
                        object_feature_dict[str(obj)],
                        np.array(bboxes[int(sub)]),
                        np.array(bboxes[int(obj)]),
                    ]
                )
            )
            current_rela_tps.extend((rela, sub, obj) for rela in rela_choices)
        batched_rela_tps.append(current_rela_tps)

        split_obj_ids.append(len(batched_obj_features))
        split_rela_ids.append(len(batched_relation_features))

    X = torch.cat(
        [torch.from_numpy(x).float() for x in batched_obj_features]
    ).reshape(len(batched_obj_features), -1).to(device)
    if batched_relation_features:
        Y = torch.cat(
            [torch.from_numpy(x).float() for x in batched_relation_features]
        ).reshape(len(batched_relation_features), -1).to(device)
    else:
        Y = None

    name_probs, _name_logits, attr_probs, _attr_logits, rel_probs, _rel_logits = sg_models.forward(
        obj_features=X,
        rela_features=Y,
        batch_obj_split=split_obj_ids,
        batch_rela_split=split_rela_ids,
        name_cls=name_cls,
        attr_cls=attr_cls,
    )

    # Bring probabilities back to CPU for the per-sample slicing below.
    name_probs = name_probs.cpu() if name_probs is not None else None
    attr_probs = attr_probs.cpu() if attr_probs is not None else None
    rel_probs = rel_probs.cpu() if rel_probs is not None else None

    batched_tps = []
    current_obj_id = 0
    current_rela_id = 0
    for (
        name_tps,
        attr_tps,
        rela_tps,
        rela_idxes,
        attr_idxes,
        next_obj_id,
        next_rela_id,
        obj_ids,
        pairs_ids,
        name_choices,
        attr_choices,
        rela_choices,
    ) in zip(
        batched_name_tps,
        batched_attr_tps,
        batched_rela_tps,
        batched_rela_idxes,
        batched_attr_idxes,
        split_obj_ids,
        split_rela_ids,
        batched_obj_ids,
        batched_obj_id_pairs,
        batched_name_choices,
        batched_attr_choices,
        batched_rela_choices,
    ):
        tps = {}
        if return_dicts:
            if name_tps:
                # NOTE(review): assumes name_probs has len(name_choices) columns,
                # so each object owns one contiguous slice of the flattened rows.
                width = len(name_choices)
                scores = name_probs[current_obj_id:next_obj_id].reshape(-1)
                tps["name"] = {o: scores[k * width:(k + 1) * width] for k, o in enumerate(obj_ids)}
                tps["name"]["choices"] = name_choices
            else:
                tps["name"] = [[]]

            if rela_tps:
                # NOTE(review): same width assumption for rel_probs/rela_choices.
                width = len(rela_choices)
                scores = rel_probs[current_rela_id:next_rela_id].reshape(-1)
                tps["rela"] = {pair: scores[k * width:(k + 1) * width] for k, pair in enumerate(pairs_ids)}
                tps["rela"]["choices"] = rela_choices
            else:
                tps["rela"] = [[]]

            if attr_tps:
                scores = attr_probs[current_obj_id:next_obj_id][:, attr_idxes].reshape(-1)
                # NOTE(review): `scores` holds len(obj_ids) * len(attr_idxes)
                # values in object-major order. scores[i] is therefore object
                # i's score only when the query mentions exactly one attribute;
                # confirm before relying on multi-attribute queries here.
                tps["attr"] = {o: scores[i] for i, o in enumerate(obj_ids)}
                tps["attr"]["choices"] = attr_choices
            else:
                tps["attr"] = [[]]
        else:
            if name_probs is not None:
                scores = name_probs[current_obj_id:next_obj_id].reshape(-1)
                tps["name"] = [list(zip(scores, name_tps))]
            else:
                tps["name"] = [[]]

            if attr_probs is not None:
                scores = attr_probs[current_obj_id:next_obj_id][:, attr_idxes].reshape(-1)
                tps["attr"] = [list(zip(scores, attr_tps))]
            else:
                tps["attr"] = [[]]

            if rel_probs is not None:
                scores = rel_probs[current_rela_id:next_rela_id][:, rela_idxes].reshape(-1)
                tps["rela"] = [list(zip(scores, rela_tps))]
            else:
                tps["rela"] = [[]]

        batched_tps.append(tps)
        current_obj_id = next_obj_id
        current_rela_id = next_rela_id

    return batched_tps, query_objects
    
