import random
from re import S

import lark

# from src import visual_genome_utils
import numpy as np
import torch
from deisam_utils import load_neumann, load_neumann_for_sgg_training, load_sam_model
from learning_utils import (
    get_target_selection_rules,
    merge_atoms_list,
    merge_langs,
    translate_atoms_to_sgg_format,
    translate_rules_to_sgg_format,
)
from llm_logic_generator import LLMLogicGenerator
from sam_utils import to_boxes, to_boxes_with_sgg, to_transformed_boxes
from semantic_unifier import SemanticUnifier
from visual_genome_utils import (
    get_init_language_with_sgg,
    scene_graph_to_language,
    scene_graph_to_language_with_sgg,
)

from neumann.torch_utils import softor


class DeiSAM(torch.nn.Module):
    """Deictic segmentation pipeline combining an LLM, a logic reasoner, and SAM.

    For a scene graph, a deictic text prompt, and an image, the module asks an
    LLM for first-order-logic (FOL) rules, reasons over the scene-graph facts
    with a NEUMANN forward reasoner, and lets SAM segment the inferred targets.
    """

    def __init__(self, api_key, device, vg_utils):
        """Create the SAM predictor and the LLM rule generator.

        Args:
            api_key: Key handed to ``LLMLogicGenerator``.
            device: Device handed to ``load_sam_model`` and reused by later steps.
            vg_utils: Scene-graph utility object, kept as ``visual_genome_utils``.
        """
        super().__init__()
        self.device = device
        self.sam_predictor = load_sam_model(device)
        self.llm_logic_generator = LLMLogicGenerator(api_key)
        self.visual_genome_utils = vg_utils

    def generate_rules_by_llm(self, deictic_text, lang):
        """Ask the LLM logic generator for FOL rules describing the prompt.

        Args:
            deictic_text (str): The prompt, e.g. "An object on the table and in front of the cup."
            lang (neumann.fol.language.Language): FOL language used for rule generation.

        Returns:
            list[neumann.fol.logic.Clause]: The generated rules, or an empty list
            when one of the listed lark/Index/Runtime errors is raised.
        """
        try:
            # All generated rules are kept; nothing is truncated to a single rule.
            return self.llm_logic_generator.generate_rules(deictic_text, lang)
        except (
            lark.exceptions.VisitError,
            lark.exceptions.UnexpectedCharacters,
            lark.exceptions.UnexpectedEOF,
            IndexError,
            RuntimeError,
        ):
            print("Failed to parse the LLM response.")
            return []

    def unify_semantics(self, text_lang, scene_graph_lang, llm_rules):
        """Rewrite the LLM rules so that they match the scene-graph vocabulary.

        A new ``SemanticUnifier`` is built on every call and kept on the instance
        as ``semantic_unifier``.

        Args:
            text_lang: Language built for the prompt; passed on as ``lang``.
            scene_graph_lang (neumann.fol.language.Language): Language of the scene graph.
            llm_rules (list[neumann.fol.logic.Clause]): Rules produced by the LLM.

        Returns:
            list[neumann.fol.logic.Clause]: The rewritten rules, or ``llm_rules``
            unchanged when a RuntimeError or IndexError is raised.
        """
        try:
            unifier = SemanticUnifier(scene_graph_lang, self.device)
            self.semantic_unifier = unifier
            return unifier.rewrite_rules(
                rules=llm_rules, lang=text_lang, graph_lang=scene_graph_lang
            )
        except (RuntimeError, IndexError):
            return llm_rules

    def forward_reasoning(self, lang, atoms, rules):
        """Build a NEUMANN reasoner and infer the target atoms.

        The reasoner is stored as ``self.neumann`` and the result as
        ``self.target_atoms`` so that both can be inspected afterwards.

        Args:
            lang (neumann.fol.language.Language): FOL language.
            atoms (list[neumann.fol.logic.Atom]): Ground atoms describing the scene graph.
            rules (list[neumann.fol.logic.Clause]): Rules describing the deictic prompt.

        Returns:
            tuple: ``(target_atoms, v_T, neumann)`` where ``target_atoms`` are the
            atoms from ``neumann.get_top_atoms(v_T[0])`` whose predicate is named
            ``"target"``; ``([], None, None)`` on RuntimeError or IndexError.
        """
        try:
            print("Building NEUMANN reasoner...")
            fact_converter, reasoner = load_neumann(lang, rules, atoms, self.device)
            # keep the reasoner around so it can be accessed later
            self.neumann = reasoner
            initial_valuation = fact_converter(atoms)
            final_valuation = reasoner(initial_valuation)
            inferred = []
            for candidate in reasoner.get_top_atoms(final_valuation[0]):
                if candidate.pred.name == "target":
                    inferred.append(candidate)
        except (RuntimeError, IndexError):
            self.target_atoms = []
            return [], None, None
        self.target_atoms = inferred
        return inferred, final_valuation, reasoner

    def segment_objects_by_sam(self, image_source, target_atoms, data_index):
        """Run SAM with the bounding boxes of the target objects as prompts.

        Args:
            image_source: Image handed to the SAM predictor.
            target_atoms (list[neumann.fol.logic.Atom]): Atoms naming the objects to segment.
            data_index (int): Data index used to look up the boxes.

        Returns:
            The masks returned by ``predict_torch``, or an empty list when SAM
            raises a RuntimeError.
        """
        raw_boxes = to_boxes(target_atoms, data_index, self.visual_genome_utils)
        predictor = self.sam_predictor
        # the image has to be registered before boxes can be transformed/predicted
        predictor.set_image(image_source)
        sam_boxes = to_transformed_boxes(
            raw_boxes, image_source, predictor, self.device
        )
        try:
            masks, _scores, _logits = predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=sam_boxes,
                multimask_output=False,
            )
        except RuntimeError:
            return []
        return masks

    def forward(self, data_index, image_id, scene_graph, text, image_source):
        """Run the whole pipeline: rules from the LLM, reasoning, then SAM.

        Reasoning is first attempted with the raw LLM rules; only when it finds
        no target are the rules semantically unified and reasoning repeated.

        Args:
            data_index (int): Data index of the example.
            image_id: Accepted for a uniform signature; not used by this class.
            scene_graph (graph): Scene graph of the image.
            text (str): Deictic prompt.
            image_source (image): Image handed to SAM.

        Returns:
            tuple: ``(masks, llm_rules, rewritten_rules)``. ``masks`` is an empty
            list when no target was found; ``rewritten_rules`` is ``None`` when
            semantic unification was not needed.
        """
        # derive the logic language from the scene graph and the prompt
        prompt_lang = scene_graph_to_language(
            scene_graph=scene_graph,
            text=text,
            logic_generator=self.llm_logic_generator,
            num_objects=2,
        )
        llm_rules = self.generate_rules_by_llm(text, prompt_lang)

        print("LLM-generated rules:")
        for llm_rule in llm_rules:
            print("   ", llm_rule)

        # scene graph expressed as ground atoms plus the matching language
        facts, facts_lang = self.visual_genome_utils.data_index_to_atoms(
            data_index=data_index,
            lang=prompt_lang,
        )

        # first attempt: reason with the rules exactly as the LLM produced them
        target_atoms, _, _ = self.forward_reasoning(facts_lang, facts, llm_rules)

        rewritten_rules = None
        if not target_atoms:
            # second attempt: align the rule vocabulary with the scene graph
            rewritten_rules = self.unify_semantics(prompt_lang, facts_lang, llm_rules)
            print("Semantically unified rules:")
            for unified_rule in rewritten_rules:
                print("   ", unified_rule)
            target_atoms, _, _ = self.forward_reasoning(
                facts_lang, facts, rewritten_rules
            )

        masks = (
            self.segment_objects_by_sam(image_source, target_atoms, data_index)
            if target_atoms
            else []
        )
        return masks, llm_rules, rewritten_rules


class DeiSAMSGG(torch.nn.Module):
    """DeiSAM variant that works on scene graphs looked up by image id.

    Compared with ``DeiSAM``, the initial language comes from
    ``get_init_language_with_sgg``, the facts come from
    ``visual_genome_utils.image_id_to_atoms``, and boxes are resolved with
    ``to_boxes_with_sgg``. Semantic unification uses a separate base language
    built by ``scene_graph_to_language_with_sgg``.
    """

    def __init__(self, api_key, device, vg_utils, sgg_model):
        """Create the SAM predictor and the LLM rule generator.

        Args:
            api_key: Key handed to ``LLMLogicGenerator``.
            device: Device handed to ``load_sam_model`` and reused by later steps.
            vg_utils: Scene-graph utility object, kept as ``visual_genome_utils``.
            sgg_model: Scene graph generator; only stored, not called in this class.
        """
        super().__init__()
        self.device = device
        self.sam_predictor = load_sam_model(device)
        self.llm_logic_generator = LLMLogicGenerator(api_key)
        self.visual_genome_utils = vg_utils
        self.sgg_model = sgg_model

    def generate_rules_by_llm(self, deictic_text, lang):
        """Generate FOL rules using LLMs given textual prompts.

        Args:
            deictic_text (str): A text of a prompt, e.g. An object on the table and in front of the cup.
            lang (neumann.fol.language.Language): A FOL language.

        Returns:
            list[neumann.fol.logic.Clause]: The generated FOL rules, or an empty
            list when the LLM response cannot be parsed.
        """
        try:
            rules = self.llm_logic_generator.generate_rules(deictic_text, lang)
        except (
            lark.exceptions.VisitError,
            lark.exceptions.UnexpectedCharacters,
            lark.exceptions.UnexpectedEOF,
            IndexError,
            RuntimeError,
        ):
            print("Failed to parse the LLM response.")
            rules = []
        return rules

    def unify_semantics(self, text_lang, scene_graph_lang, llm_rules):
        """Unify semantics of the deictic prompt and the scene graph.

        A new ``SemanticUnifier`` is built on every call and kept on the instance
        as ``semantic_unifier``.

        Args:
            text_lang: The language built for the prompt (``forward`` passes the
                result of ``get_init_language_with_sgg``); forwarded as ``lang``.
            scene_graph_lang (neumann.fol.language.Language): A FOL language for describing a scene graph.
            llm_rules (list[neumann.fol.logic.Clause]): A list of rules generated by LLMs.

        Returns:
            list[neumann.fol.logic.Clause]: Rules rewritten according to the scene
            graph, or ``llm_rules`` unchanged on RuntimeError / IndexError.
        """
        try:
            self.semantic_unifier = SemanticUnifier(scene_graph_lang, self.device)
            return self.semantic_unifier.rewrite_rules(
                rules=llm_rules, lang=text_lang, graph_lang=scene_graph_lang
            )
        except (RuntimeError, IndexError):
            # best effort: fall back to the rules as generated
            return llm_rules

    def forward_reasoning(self, lang, atoms, rules):
        """Perform forward reasoning by building NEUMANN, a GNN-based differentiable forward reasoner.

        The reasoner is stored as ``self.neumann`` and the result as
        ``self.target_atoms``. The initial valuation is printed via
        ``neumann.print_valuation_batch`` before reasoning.

        Args:
            lang (neumann.fol.language.Language): A FOL language.
            atoms (list[neumann.fol.logic.Atom]): A list of ground atoms (facts) that represents a scene graph.
            rules (list[neumann.fol.logic.Clause]): A list of FOL rules that describe a deictic prompt.

        Returns:
            tuple: ``(target_atoms, v_T, neumann)`` where ``target_atoms`` are the
            atoms from ``neumann.get_top_atoms(v_T[0])`` whose predicate is named
            ``"target"``; ``([], None, None)`` on RuntimeError or IndexError.
        """
        try:
            print("Building NEUMANN reasoner...")
            fc, neumann = load_neumann(lang, rules, atoms, self.device)
            # save reasoner for being accessed later
            self.neumann = neumann
            v_0 = fc(atoms)
            neumann.print_valuation_batch(v_0)
            v_T = neumann(v_0)
            target_atoms = [
                atom
                for atom in neumann.get_top_atoms(v_T[0])
                if atom.pred.name == "target"
            ]
            self.target_atoms = target_atoms
            return target_atoms, v_T, neumann
        except (RuntimeError, IndexError):
            self.target_atoms = []
            return [], None, None

    def segment_objects_by_sam(self, image_source, target_atoms, image_id):
        """Segment objects by running SAM with the target objects' boxes as prompts.

        Args:
            image_source (_type_): Image source used by SAM.
            target_atoms (list[neumann.fol.logic.Atom]): A list of atoms that represents target objects to be segmented.
            image_id: Image id handed to ``to_boxes_with_sgg`` to look up the boxes.

        Returns:
            The masks returned by ``predict_torch``, or an empty list when SAM
            raises a RuntimeError.
        """
        boxes = to_boxes_with_sgg(target_atoms, image_id, self.visual_genome_utils)
        # the image has to be registered before boxes can be transformed/predicted
        self.sam_predictor.set_image(image_source)
        transformed_boxes = to_transformed_boxes(
            boxes, image_source, self.sam_predictor, self.device
        )
        try:
            masks, _, _ = self.sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            return masks
        except RuntimeError:
            return []

    def forward(self, data_index, image_id, scene_graph, text, image_source):
        """Forwarding function of DeiSAM. It generates FOL rules using a LLM, build a differentiable reasoner, performs reasoning, and segment by SAM.

        Reasoning is first attempted with the raw LLM rules; only when it finds
        no target are the rules semantically unified and reasoning repeated.

        Args:
            data_index (int): A data index. Kept for a uniform signature; the
                scene graph in this class is addressed by ``image_id``.
            image_id: Image id used to load the atoms and to look up boxes.
            scene_graph (graph): A scene graph.
            text (str): A deictic prompt.
            image_source (image): An image source used by SAM.

        Returns:
            tuple: ``(masks, llm_rules, rewritten_rules)``. ``masks`` is an empty
            list when no target was found; ``rewritten_rules`` is ``None`` when
            semantic unification was not needed.
        """
        # extract a logic language from a scene graph to parse rules with in next step
        lang = get_init_language_with_sgg(
            scene_graph=scene_graph,
            text=text,
            logic_generator=self.llm_logic_generator,
        )
        # generate rules by a LLM
        llm_rules = self.generate_rules_by_llm(text, lang)
        rewritten_rules = None

        print("LLM-generated rules:")
        for rule in llm_rules:
            print("   ", rule)

        # load scene graph in the form of logic by transforming it to the format
        (
            scene_graph_atoms,
            scene_graph_lang,
        ) = self.visual_genome_utils.image_id_to_atoms(
            image_id=image_id,
            lang=lang,
        )

        # base language for semantic unifier
        scene_graph_base_lang = scene_graph_to_language_with_sgg(scene_graph)

        # perform forward reasoning (without semantic unification)
        target_atoms, v_T, neumann = self.forward_reasoning(
            scene_graph_lang, scene_graph_atoms, llm_rules
        )

        # if no targets found, perform semantic unification by rewriting rules according to the scene graph
        if not target_atoms:
            rewritten_rules = self.unify_semantics(
                lang, scene_graph_base_lang, llm_rules
            )
            print("Semantically unified rules:")
            for rule in rewritten_rules:
                print("   ", rule)
            target_atoms, v_T, neumann = self.forward_reasoning(
                scene_graph_lang, scene_graph_atoms, rewritten_rules
            )

        # perform segmentation by SAM
        if target_atoms:
            # Boxes are looked up by image id (see segment_objects_by_sam), matching
            # how the atoms were loaded above; passing data_index here was a mix-up.
            masks = self.segment_objects_by_sam(image_source, target_atoms, image_id)
        else:
            masks = []

        return masks, llm_rules, rewritten_rules


class TrainableDeiSAM(torch.nn.Module):
    """A class of trainable DeiSAM that parameterizes scene graph generators as weighted mixtures.

    ``rule_weights`` is a learnable parameter with one column per scene-graph
    utility in ``vg_utils_list``; its temperature-softmaxed form overwrites the
    clause weights of the NEUMANN reasoner in ``forward_reasoning``.
    """

    def __init__(self, api_key, device, vg_utils_list, sem_uni=False):
        """Create the SAM predictor, the LLM rule generator and the rule weights.

        Args:
            api_key: Key handed to ``LLMLogicGenerator``.
            device: Device handed to ``load_sam_model`` and used for the weights.
            vg_utils_list: Scene-graph utility objects. ``forward`` uses index 0
                with ``data_index_to_atoms`` and index 1 with ``image_id_to_atoms``.
            sem_uni (bool): Whether ``forward`` may fall back to semantic unification.
        """
        super().__init__()
        self.device = device
        self.sam_predictor = load_sam_model(device)
        self.llm_logic_generator = LLMLogicGenerator(api_key)
        self.visual_genome_utils_list = vg_utils_list
        self.rule_weights = self.init_random_weights(
            program_size=1, num_rules=len(vg_utils_list), device=device
        )
        self.sem_uni = sem_uni

    def init_random_weights(self, program_size, num_rules, device):
        """Initialize the clause weights with a random initialization.

        Values are drawn with ``np.random.rand``, so NumPy's seed controls them.

        Args:
            program_size (int): Number of rows of the weight matrix.
            num_rules (int): Number of columns of the weight matrix.
            device: Device the weights are moved to.

        Returns:
            torch.nn.Parameter: Weights of shape ``(program_size, num_rules)``.
        """
        weights = torch.nn.Parameter(
            torch.Tensor(np.random.rand(program_size, num_rules)).to(device)
        )
        return weights

    def init_uniform_weights(self, program_size, num_rules, device):
        """Initialize the clause weights with identical values for every rule.

        Each entry is ``1 / num_rules``; for two rules this is the ``0.5`` the
        method used to hard-code regardless of its arguments.

        Args:
            program_size (int): Number of rows of the weight matrix.
            num_rules (int): Number of columns of the weight matrix; must be positive.
            device: Device the weights are moved to.

        Returns:
            torch.nn.Parameter: Weights of shape ``(program_size, num_rules)``.

        Raises:
            ValueError: If ``num_rules`` is not positive.
        """
        if num_rules < 1:
            raise ValueError("num_rules must be a positive integer.")
        uniform = torch.full((program_size, num_rules), 1.0 / num_rules)
        return torch.nn.Parameter(uniform.to(device))

    def _softmax_clause_weights(self, clause_weights, temp=1e-1):
        """Softmax the clause weights over dim 1 with a temperature, then ``softor`` over dim 0.

        Args:
            clause_weights (torch.Tensor): Raw weights, one row per program slot.
            temp (float): Softmax temperature; the weights are divided by it.

        Returns:
            The result of ``softor(..., dim=0)`` applied to the softmaxed weights.
        """
        clause_weights_sm = torch.softmax(clause_weights / temp, dim=1)
        return softor(clause_weights_sm, dim=0)

    def generate_rules_by_llm(self, deictic_text, lang):
        """Generate FOL rules using LLMs given textual prompts.

        Args:
            deictic_text (str): A text of a prompt, e.g. An object on the table and in front of the cup.
            lang (neumann.fol.language.Language): A FOL language.

        Returns:
            list[neumann.fol.logic.Clause]: The generated FOL rules, or an empty
            list when the LLM response cannot be parsed.
        """
        try:
            # every generated rule is kept; nothing is truncated to a single rule
            rules = self.llm_logic_generator.generate_rules(deictic_text, lang)
        except (
            lark.exceptions.VisitError,
            lark.exceptions.UnexpectedCharacters,
            lark.exceptions.UnexpectedEOF,
            IndexError,
            RuntimeError,
        ):
            print("Failed to parse the LLM response.")
            rules = []
        return rules

    def unify_semantics(self, text_lang, scene_graph_lang, llm_rules):
        """Unify semantics of the deictic prompt and the scene graph.

        A new ``SemanticUnifier`` is built on every call and kept on the instance
        as ``semantic_unifier``.

        Args:
            text_lang: The language built for the prompt (``forward`` passes
                ``langs[0]``); forwarded as ``lang``.
            scene_graph_lang (neumann.fol.language.Language): A FOL language for describing a scene graph.
            llm_rules (list[neumann.fol.logic.Clause]): A list of rules generated by LLMs.

        Returns:
            list[neumann.fol.logic.Clause]: Rules rewritten according to the scene
            graph, or ``llm_rules`` unchanged on RuntimeError / IndexError.
        """
        try:
            self.semantic_unifier = SemanticUnifier(scene_graph_lang, self.device)
            # NOTE(review): an earlier comment here claimed predicate re-writing is
            # off for SGG training, yet rewrite_pred=True is what is passed — confirm
            # which one is intended.
            rewritten_rules = self.semantic_unifier.rewrite_rules(
                rules=llm_rules,
                lang=text_lang,
                graph_lang=scene_graph_lang,
                rewrite_pred=True,
            )
            return rewritten_rules
        except (RuntimeError, IndexError):
            # best effort: fall back to the rules as generated
            return llm_rules

    def forward_reasoning(self, lang, atoms, rules_to_learn, rules_bk):
        """Perform forward reasoning by building NEUMANN, a GNN-based differentiable forward reasoner.

        The reasoner's ``clause_weights`` are overwritten with the softmaxed
        ``self.rule_weights`` so that gradients reach the DeiSAM parameters.
        Unlike the non-trainable classes, exceptions are not caught here.

        Args:
            lang (neumann.fol.language.Language): A FOL language.
            atoms (list[neumann.fol.logic.Atom]): A list of ground atoms (facts) that represents a scene graph.
            rules_to_learn (list[neumann.fol.logic.Clause]): Rules passed to
                ``load_neumann_for_sgg_training`` as the rules to learn.
            rules_bk (list[neumann.fol.logic.Clause]): Rules passed as background knowledge.

        Returns:
            tuple: ``(target_atoms, target_scores, V_T, neumann)``; the first two
            come from ``get_target_atoms_with_scores_to_segment`` with ``th=0.2``.
        """
        print("Building NEUMANN reasoner...")
        fc, neumann = load_neumann_for_sgg_training(
            lang, rules_to_learn, rules_bk, atoms, self.device, infer_step=4
        )
        # overwrite the rule weights in neumann by deisam parameters
        neumann.clause_weights = self._softmax_clause_weights(self.rule_weights)
        neumann.print_program()
        # save reasoner for being accessed later
        self.neumann = neumann
        V_0 = fc(atoms)
        V_T = neumann(V_0)
        target_atoms, target_scores = self.get_target_atoms_with_scores_to_segment(
            V_T[0], neumann.atoms, th=0.2
        )
        self.target_atoms = target_atoms
        return target_atoms, target_scores, V_T, neumann

    def get_target_atoms_with_scores_to_segment(self, v_T, atoms, th=0.2):
        """Select the ``target`` atoms to segment together with their valuations.

        Args:
            v_T: Valuations indexable by atom position (``v_T[i]`` belongs to ``atoms[i]``).
            atoms: Atoms of the reasoner; only those whose predicate is named
                ``"target"`` are considered.
            th (float): Valuation threshold.

        Returns:
            tuple[list, list]: All target atoms with valuation above ``th`` and
            their valuations. If none passes the threshold, the single
            highest-valued target atom (first one on ties). If there is no target
            atom at all, two empty lists — a non-target atom is never returned.
        """
        candidates = [
            (atom, v_T[i])
            for i, atom in enumerate(atoms)
            if atom.pred.name == "target"
        ]
        if not candidates:
            return [], []

        confident = [(atom, score) for atom, score in candidates if score > th]
        if confident:
            return (
                [atom for atom, _ in confident],
                [score for _, score in confident],
            )

        # nothing above the threshold: fall back to the highest-valuation target atom
        best_target_atom, best_target_valuation = candidates[0]
        for atom, score in candidates[1:]:
            if score > best_target_valuation:
                best_target_atom, best_target_valuation = atom, score
        return [best_target_atom], [best_target_valuation]

    def segment_objects_by_sam(self, image_source, target_atoms, data_index, image_id):
        """Segment objects by running SAM with boxes from both scene graphs as prompts.

        Boxes are looked up for the same target atoms in the first utility (by
        data index) and in the second utility (by image id) and concatenated.

        Args:
            image_source (_type_): Image source used by SAM.
            target_atoms (list[neumann.fol.logic.Atom]): A list of atoms that represents target objects to be segmented.
            data_index (int): A data index, used with ``visual_genome_utils_list[0]``.
            image_id: Image id, used with ``visual_genome_utils_list[1]``.

        Returns:
            The masks returned by ``predict_torch``, or ``None`` when SAM raises a
            RuntimeError or AttributeError.
        """
        boxes_vg = to_boxes(target_atoms, data_index, self.visual_genome_utils_list[0])
        boxes_sgg = to_boxes_with_sgg(
            target_atoms, image_id, self.visual_genome_utils_list[1]
        )
        boxes = boxes_vg + boxes_sgg
        # the image has to be registered before boxes can be transformed/predicted
        self.sam_predictor.set_image(image_source)
        transformed_boxes = to_transformed_boxes(
            boxes, image_source, self.sam_predictor, self.device
        )
        try:
            masks, _, _ = self.sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            return masks
        except (RuntimeError, AttributeError):
            print("Found RuntimeError while segmenting.. ")
            return None

    def _translate_and_reason(self, atoms_list, langs, rules_list):
        """Merge graphs and rules into the SGG-learning format and run reasoning.

        Returns ``(target_atoms, target_scores, sgg_rules)``.
        """
        # translate to sgg learning format and merge all graphs and languages
        sgg_graph_atoms, sgg_lang = translate_atoms_to_sgg_format(atoms_list, langs)
        sgg_rules = translate_rules_to_sgg_format(rules_list)
        # update rules and language by adding target selection rules
        sgg_target_rules, sgg_lang = get_target_selection_rules(sgg_lang)

        # the target selection rules are learned; the translated LLM rules are background
        target_atoms, target_scores, _, _ = self.forward_reasoning(
            sgg_lang, sgg_graph_atoms, sgg_target_rules, sgg_rules
        )
        print("Target Atoms: ", target_atoms)
        print("Target Scores: ", target_scores)
        return target_atoms, target_scores, sgg_rules

    def forward(self, data_index, image_id, scene_graphs, text, image_source):
        """Forwarding function of DeiSAM. It generates FOL rules using a LLM, build a differentiable reasoner, performs reasoning, and segment by SAM.

        Args:
            data_index (int): A data index, used for the first scene graph.
            image_id: Image id, used for the second scene graph.
            scene_graphs: Exactly two scene graphs; only the first one is used to
                build the logic language.
            text (str): A deictic prompt.
            image_source (image): An image source used by SAM.

        Returns:
            tuple: ``(masks, target_scores, sgg_rules)``. ``masks`` is ``None``
            when there is nothing to segment or SAM failed.
        """
        assert (
            len(scene_graphs) == 2
        ), "Currently only 2 SGGs ara accepted for learning."
        # extract a logic language from the first scene graph and use the same
        # language for both graphs
        base_lang = scene_graph_to_language(
            scene_graph=scene_graphs[0],
            text=text,
            logic_generator=self.llm_logic_generator,
            num_objects=2,
        )
        langs = [base_lang, base_lang]

        # generate rules by a LLM (one call per language entry)
        llm_rules_list = [self.generate_rules_by_llm(text, lang) for lang in langs]

        scene_graph_atoms_list = []
        scene_graph_langs = []
        for i, lang in enumerate(langs):
            vg_utils = self.visual_genome_utils_list[i]
            if i == 0:
                # for ground truth SGG, call data_index_to_atoms
                scene_graph_atoms, scene_graph_lang = vg_utils.data_index_to_atoms(
                    data_index=data_index,
                    lang=lang,
                )
            else:
                # for other SGGs, call image_id_to_atoms
                scene_graph_atoms, scene_graph_lang = vg_utils.image_id_to_atoms(
                    image_id=image_id,
                    lang=lang,
                )
            scene_graph_atoms_list.append(scene_graph_atoms)
            scene_graph_langs.append(scene_graph_lang)

        # perform forward reasoning (without semantic unification)
        target_atoms, target_scores, sgg_rules = self._translate_and_reason(
            scene_graph_atoms_list, scene_graph_langs, llm_rules_list
        )

        # if no confident target was found, perform semantic unification by
        # rewriting rules according to the scene graph; an empty score list
        # counts as "no confident target"
        print("semantic unifier flag: ", self.sem_uni)
        if self.sem_uni and (
            not target_scores
            or torch.cat([s.unsqueeze(-1) for s in target_scores]).max() < 0.1
        ):
            rewritten_rules = self.unify_semantics(
                langs[0], scene_graph_langs[0], llm_rules_list[0]
            )
            print("Semantically unified rules:")
            for rule in rewritten_rules:
                print("   ", rule)
            # NOTE(review): the rules are unified against graph 0 but stored for
            # graph 1 — kept as is, confirm this is intended.
            llm_rules_list[1] = rewritten_rules
            target_atoms, target_scores, sgg_rules = self._translate_and_reason(
                scene_graph_atoms_list, scene_graph_langs, llm_rules_list
            )

        # perform segmentation by SAM
        if target_atoms:
            masks = self.segment_objects_by_sam(
                image_source, target_atoms, data_index, image_id
            )
        else:
            masks = None
        return masks, target_scores, sgg_rules
