import collections
from typing import Optional

import torch
import torch.nn as nn

from .relation_graph import RelationGraph
from .match import Matcher
from .convert_graph import to_networkx

from dark_kg.utils import IngredientModelWrapper


class RelationGraphPredictor(nn.Module):
    """
    Predict classes by matching an instance relation graph against the dark KG.

    Procedure:
        1. Run the ingredient model (``ingredient_wrapper``) on the input,
           without gradient tracking, to obtain ingredients and attention maps.
        2. Build the instance graph (vertices, edges) from those outputs with
           ``relation_graph``.
        3. Match the instance graph against the KG vertex/edge weights held by
           ``relation_graph`` using ``matcher``; the matcher's output is the
           prediction.

    See ``forward`` for the keys of the returned dictionary.
    """
    def __init__(
        self,
        ingredient_wrapper: IngredientModelWrapper,
        relation_graph: RelationGraph,
        matcher: Matcher
    ):
        """
        Args:
            ingredient_wrapper: model producing "ingredients", "attn" and
                "attn_cls" outputs from an input batch.
            relation_graph: builds the instance graph and holds the KG
                vertex/edge weights; also supplies ``num_classes``.
            matcher: compares the instance graph with the KG weights.
        """
        super().__init__()
        self.ingredient_wrapper = ingredient_wrapper
        self.relation_graph = relation_graph
        self.matcher = matcher
        # Mirrors the class count defined by the relation graph.
        self.num_classes = relation_graph.num_classes

    def forward(
        self,
        x: torch.Tensor,
        requires_graph: bool = False,
        task: Optional[int] = None,
    ) -> collections.OrderedDict:
        """
        Run the full prediction pipeline on ``x``.

        Args:
            x: input batch, passed unchanged to ``ingredient_wrapper``.
            requires_graph: if True, also return the instance graph and the
                ingredient-model outputs it was built from.
            task: forwarded unchanged to ``matcher``; its meaning is defined
                by the matcher (presumably a task index — TODO confirm).

        Returns:
            An ``OrderedDict`` with keys:
                "pred": matcher prediction (originally documented as shape
                    [bs, num_classes] — TODO confirm against Matcher);
                "vertex_weights": ``relation_graph.get_vertex_weights()``;
                "edge_weights": ``relation_graph.get_edge_weights()``;
                "info_instance", "info_category": auxiliary outputs returned
                    by ``matcher``;
            and, only when ``requires_graph`` is True:
                "instance_vertices", "instance_edges": the instance graph;
                "ingredients", "attn_cls": the ingredient-model outputs the
                    instance graph was built from.
        """
        ret = collections.OrderedDict()
        # No gradients flow into the ingredient model through this path.
        with torch.no_grad():
            output = self.ingredient_wrapper(x)
        vertices, edges = self.relation_graph(
            ingredients=output["ingredients"],
            attn=output["attn"],
            attn_cls=output["attn_cls"]
        )
        vertex_weights = self.relation_graph.get_vertex_weights()
        edge_weights = self.relation_graph.get_edge_weights()
        pred, info_instance, info_category = self.matcher(
            instance_vertices=vertices,
            instance_edges=edges,
            kg_vertices=vertex_weights,
            kg_edges=edge_weights,
            task=task
        )
        ret["pred"] = pred
        ret["vertex_weights"] = vertex_weights
        ret["edge_weights"] = edge_weights
        ret["info_instance"] = info_instance
        ret["info_category"] = info_category
        if requires_graph:
            ret["instance_vertices"] = vertices
            ret["instance_edges"] = edges
            ret["ingredients"] = output["ingredients"]
            ret["attn_cls"] = output["attn_cls"]
        return ret
