from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor

from .gnn import GNN


class Matcher(nn.Module):
    """Score instance graphs against knowledge-graph categories.

    Both inputs are encoded by one shared ``GNN``; the two sets of features
    are then compared pairwise with the similarity measure chosen at
    construction time.
    """

    def __init__(
        self,
        similarity: str,
        num_codes: int,
        gnn_cfg: Dict[str, Any],
        inc_embedding: Optional[int] = None,
        vertex_weight: Optional[Tensor] = None,
        debug: bool = False
    ):
        """Build the shared encoder and bind the similarity function.

        Args:
            similarity: Name of the similarity measure; one of ``"cosine"``,
                ``"euclidean"`` or ``"inner_product"``.
            num_codes: Forwarded unchanged to ``GNN``.
            gnn_cfg: Extra keyword arguments unpacked into the ``GNN``
                constructor.
            inc_embedding: Forwarded unchanged to ``GNN``.
            vertex_weight: Forwarded unchanged to ``GNN``.
            debug: Forwarded to ``GNN`` and also stored on ``self.debug``.

        Raises:
            KeyError: If ``similarity`` is not a supported name.
        """
        super().__init__()

        supported_sim = {
            "cosine": self._cosine_sim,
            "euclidean": self._euclidean_sim,
            "inner_product": self._inner_product,
        }
        # Validate before constructing the GNN so that a configuration typo
        # fails immediately instead of after the network has been built.
        # KeyError is kept (rather than ValueError) so callers that already
        # catch the lookup failure keep working.
        if similarity not in supported_sim:
            raise KeyError(
                f"Unsupported similarity {similarity!r}; "
                f"expected one of {sorted(supported_sim)}"
            )

        self.gnn = GNN(
            num_codes=num_codes,
            **gnn_cfg,
            inc_embedding=inc_embedding,
            vertex_weight=vertex_weight,
            debug=debug
        )
        self.similarity = supported_sim[similarity]
        self.debug = debug

    def _cosine_sim(self, feat_1: torch.Tensor, feat_2: torch.Tensor):
        """Return cosine similarity over the last dim, rescaled to [0, 1]."""
        sim = torch.cosine_similarity(feat_1, feat_2, dim=-1)
        # Cosine similarity lies in [-1, 1]; shift and scale it into [0, 1].
        return (sim + 1) / 2

    def _euclidean_sim(self, feat_1: torch.Tensor, feat_2: torch.Tensor):
        """Return ``1 / (1 + d)`` where ``d`` is the L2 distance over the last dim."""
        dist = torch.linalg.vector_norm(feat_1 - feat_2, dim=-1)
        # Zero distance maps to 1; larger distances decay towards 0.
        return 1 / (1 + dist)

    def _inner_product(self, feat_1: torch.Tensor, feat_2: torch.Tensor):
        """Return the raw (unnormalised) dot product over the last dim."""
        sim = (feat_1 * feat_2).sum(-1)
        return sim

    def forward(
        self,
        instance_vertices: torch.Tensor, instance_edges: torch.Tensor,
        kg_vertices: torch.Tensor, kg_edges: torch.Tensor,
        task: Optional[int] = None
    ):
        """Encode both graphs and compare every instance with every category.

        Args:
            instance_vertices: Vertex input of the instance graphs, passed to
                the GNN.
            instance_edges: Edge input of the instance graphs, passed to the
                GNN.
            kg_vertices: Vertex input of the knowledge graph, passed to the
                GNN.
            kg_edges: Edge input of the knowledge graph, passed to the GNN.
            task: Optional value forwarded to both GNN calls as ``task``.

        Returns:
            A tuple ``(sim, info_instance, info_category)``: ``sim`` is the
            pairwise similarity (``[bs, num_classes]`` given the feature
            shapes noted below), and the two ``info_*`` values are the second
            element returned by the GNN for the instance and knowledge-graph
            inputs respectively.
        """
        # [bs, dim]
        feat_instance, info_instance = self.gnn(instance_vertices, instance_edges, task=task)
        # [num_classes, dim]
        feat_kg, info_category = self.gnn(kg_vertices, kg_edges, task=task)

        bs = feat_instance.shape[0]

        # Line both feature sets up as [bs, num_classes, dim] so the
        # similarity, which reduces the last dim, yields one score per
        # (instance, category) pair.
        feat_kg = feat_kg.expand(bs, -1, -1)
        feat_instance = feat_instance.unsqueeze(1).expand_as(feat_kg)

        sim = self.similarity(feat_instance, feat_kg)
        return sim, info_instance, info_category
