from typing import Dict, Union, Optional
import math
import torch
import torch.nn as nn
from torch_geometric.data import Data

from matching.sinkhorn import LogSinkhorn, argSinkhorn
from matching.sinkhorn_padded import LogSinkhornPadded, LogSinkhornPaddedRect, argSinkhornPadded
from matching.sinkhorn_approx_diag import (LogNystromSinkhornBPdiag, LogSparseSinkhornBPdiag,
                                           LogSparseNystromSinkhornBPdiag)
from matching.uniform import Uniform
from matching.chamfer import calc_chamfer
from utils.distance_matrix import compute_distmatrix
from utils import aggregation, sinkhorn_normalization, distances


class GTN(nn.Module):
    """Graph-pair model comparing two graphs through their node embeddings.

    A shared GNN embeds the nodes of both graphs. A pairwise cost structure is
    built from the embeddings by ``compute_distmatrix`` and reduced with the
    configured matching method (``'sinkhorn'``, ``'chamfer'`` or
    ``'uniform'``) to either a distance per graph pair or, if
    ``return_matching`` is set, a node matching.

    Args:
        gnn: Node-embedding network. Must expose ``aggregate``, ``emb_size``
            and ``layers``; it is called with a ``torch_geometric`` ``Data``
            object.
        emb_dist_scale: Scale factor. Multiplies ``sinkhorn_reg`` and, divided
            by ``sqrt(match_emb_size)``, gives the embedding scale used when
            ``scale_embeddings`` is set.
        sparse_batching: If True, node features are fed to the GNN as given
            and the sparse-batched Sinkhorn implementations are used. If
            False, features are treated as padded ``(batch, nodes, ...)``
            tensors and the padded implementations are used.
        distance: Matching method: ``'sinkhorn'``, ``'chamfer'`` or
            ``'uniform'`` (the latter two only without sparse batching and
            only when a distance is returned).
        device: Device the module and the regularization tensor are placed on.
        p_norm: Order of the p-norm used as ground distance; ``-1`` selects
            the learned ``distances.MLP`` instead.
        sinkhorn_reg: Base entropy regularization (scaled by
            ``emb_dist_scale``).
        sinkhorn_niter: Number of Sinkhorn iterations.
        sinkhorn_reg_stepval: Divisor applied to the regularization by
            :meth:`step_sinkhorn_reg`; falsy disables stepping.
        return_matching: If True, ``forward`` returns a matching instead of a
            distance.
        bp_dist_matrix: Forwarded to ``compute_distmatrix``; also selects
            ``LogSinkhornPadded`` over ``LogSinkhornPaddedRect``.
        nystrom: Optional settings dict for the Nystrom approximation. It is
            copied and its ``'centroids'`` entry is reset to ``None``.
        sparse: Optional settings dict for the sparse approximation. It is
            copied, ``'centroids'`` is reset to ``None`` and ``'nhashes'``
            defaults to 1.
        extensive: If False, outputs are divided by the distance-matrix
            length of each graph pair.
        matching_size: ``0`` leaves the output untouched, ``1`` applies a
            scalar affine layer, larger values compute that many matchings on
            linearly transformed embeddings and combine them with a small MLP.
        scale_embeddings: If True, node embeddings are multiplied by
            ``emb_scale``.
        output_sim: If True, return ``exp(-output)`` instead of the output.
    """

    def __init__(
            self,
            gnn: nn.Module,
            emb_dist_scale: float,
            sparse_batching: bool,
            distance: str,
            device: torch.device,
            p_norm: int = 2,
            sinkhorn_reg: float = 0.1,
            sinkhorn_niter: int = 50,
            sinkhorn_reg_stepval: Optional[float] = None,
            return_matching: bool = False,
            bp_dist_matrix: bool = False,
            nystrom: Optional[dict] = None,
            sparse: Optional[dict] = None,
            extensive: bool = True,
            matching_size: int = 0,
            scale_embeddings: bool = False,
            output_sim: bool = False):
        super().__init__()
        self.gnn = gnn
        self.sparse_batching = sparse_batching
        self.distance = distance
        self.device = device
        # Kept as a plain tensor attribute (not a registered buffer), so it is
        # not part of the state dict and is not moved by ``Module.to``.
        self.sinkhorn_reg = torch.tensor(sinkhorn_reg * emb_dist_scale, dtype=torch.float32, device=self.device)
        self.sinkhorn_reg_stepval = sinkhorn_reg_stepval
        self.return_matching = return_matching
        self.bp_dist_matrix = bp_dist_matrix
        # Copy the option dicts so the caller's dictionaries are not mutated.
        self.nystrom = nystrom.copy() if nystrom else None
        self.sparse = sparse.copy() if sparse else None
        self.extensive = extensive
        self.matching_size = matching_size
        self.scale_embeddings = scale_embeddings
        self.output_sim = output_sim

        if self.nystrom:
            self.nystrom['centroids'] = None
        if self.sparse:
            self.sparse['centroids'] = None
            self.sparse.setdefault('nhashes', 1)

        # Some aggregations emit a single embedding of size ``emb_size``; the
        # others are treated as emitting one embedding per layer plus the input.
        if type(gnn.aggregate) in [aggregation.Last, aggregation.FullProjection, aggregation.MLP]:
            gnn_output_size = gnn.emb_size
        else:
            gnn_output_size = gnn.emb_size * (len(gnn.layers) + 1)
        self.match_emb_size = gnn_output_size

        self.emb_scale = emb_dist_scale / math.sqrt(self.match_emb_size)

        if p_norm == -1:
            self.dist = distances.MLP(self.match_emb_size).to(self.device)
        else:
            self.dist = distances.PNorm(p_norm)

        # Alpha for determining cost of insertion/deletion
        # (allocated uninitialized here; filled in by reset_parameters()).
        self.alpha = nn.Parameter(torch.empty(self.match_emb_size))

        if self.matching_size == 0:
            self.output_layer = nn.Identity()
        elif self.matching_size == 1:
            self.output_layer = nn.Linear(1, 1, bias=True)
        else:
            # One linear map producing ``matching_size`` embeddings per node,
            # concatenated along the feature axis.
            self.dist_transform = nn.Linear(self.match_emb_size, self.matching_size * self.match_emb_size, bias=False)
            self.output_layer = nn.Sequential(
                nn.Linear(self.matching_size, gnn.emb_size, bias=True),
                nn.LeakyReLU(),
                nn.Linear(gnn.emb_size, 1, bias=True),
                )

        self.sinkhorn_niter = sinkhorn_niter

        self.reset_parameters()
        self.to(device)

    def step_sinkhorn_reg(self):
        """Divide the Sinkhorn regularization by the step value, floored at 0.04.

        Does nothing when ``sinkhorn_reg_stepval`` is unset (or zero). The
        regularization always stays a tensor: the built-in ``max`` would
        replace it with the Python float ``0.04`` once the floor is reached.
        """
        if self.sinkhorn_reg_stepval:
            stepped = torch.as_tensor(self.sinkhorn_reg) / self.sinkhorn_reg_stepval
            self.sinkhorn_reg = torch.clamp(stepped, min=0.04)

    def output_transform(self, x, dist_mat_len):
        """Post-process the raw matching output.

        Args:
            x: Raw output of :meth:`_compute_matching`, or ``None``.
            dist_mat_len: Per-pair distance-matrix length, used for
                normalization (when ``extensive`` is False) and for masking.

        Returns:
            ``None`` if ``x`` is ``None``. In matching mode the (optionally
            combined and re-normalized) matching; otherwise the output of
            ``output_layer`` with the last axis squeezed, passed through
            ``exp(-out)`` if ``output_sim`` is set.
        """
        if x is None:
            return None

        if not self.extensive:
            if x.dim() == 1:
                x = x / dist_mat_len
            else:
                x = x / dist_mat_len[:, None]

        if self.return_matching:
            if self.matching_size <= 0:
                return x

            # Combine the stacked matchings into one matrix per pair.
            dist_matrix = self.output_layer(x.transpose(1, -1)).squeeze(-1)
            dist_matrix = dist_matrix ** 2  # ensure entries are positive
            batch_size, max_nodes, _ = dist_matrix.shape
            # 1.0 for positions at or beyond the pair's length, 0.0 otherwise.
            positions = torch.arange(max_nodes, dtype=torch.float32, device=dist_matrix.device)
            mask_vector = (positions.expand(batch_size, max_nodes)
                           >= dist_mat_len[:, None]).float()
            # NOTE(review): the 50 is hard-coded rather than taken from
            # self.sinkhorn_niter; presumably an iteration count - confirm
            # against sinkhorn_normalization's signature.
            u, v = sinkhorn_normalization(dist_matrix, 50, False, mask_vector)
            return torch.diag_embed(u) @ dist_matrix @ torch.diag_embed(v)

        if x.dim() == 1:
            # Give the output layer a trailing feature axis of size 1.
            x = x[:, None]
        out = self.output_layer(x).squeeze(-1)
        if self.output_sim:
            return torch.exp(-out)
        return out

    def reset_parameters(self):
        """Initialize ``alpha`` and, for ``matching_size == 1``, the output layer.

        ``alpha`` is set to zeros with sparse batching and to ones otherwise.
        The scalar output layer starts as the identity map (weight 1, bias 0).
        """
        if self.sparse_batching:
            nn.init.zeros_(self.alpha)
        else:
            nn.init.ones_(self.alpha)
        if self.matching_size == 1:
            nn.init.ones_(self.output_layer.weight)
            nn.init.zeros_(self.output_layer.bias)

    def _compute_node_embeddings(self, graph1, graph2):
        """Embed the nodes of both graphs with the shared GNN.

        Each graph dict must provide ``'attr_matrix'``, ``'adj_idx'`` and
        ``'edge_attr_matrix'``.

        Returns:
            A two-element list with the node embeddings of ``graph1`` and
            ``graph2``, multiplied by ``emb_scale`` if ``scale_embeddings``
            is set.
        """
        node_embeddings_raw = []
        if self.sparse_batching:
            for graph in [graph1, graph2]:
                data = Data(x=graph['attr_matrix'], edge_index=graph['adj_idx'], edge_attr=graph['edge_attr_matrix'])
                node_embeddings_raw.append(self.gnn(data))
        else:
            # Both graphs are reshaped with the leading two dimensions of
            # graph1's attribute matrix.
            b_x_nnodes = graph1['attr_matrix'].shape[:2]
            nfeat = graph1['attr_matrix'].shape[2:]
            for graph in [graph1, graph2]:
                # Flatten the two leading axes for the GNN, then restore them.
                x = graph['attr_matrix'].view((-1, *nfeat))
                data = Data(x=x, edge_index=graph['adj_idx'], edge_attr=graph['edge_attr_matrix'])
                node_embeddings_raw.append(self.gnn(data).view(*b_x_nnodes, -1))

        if self.scale_embeddings:
            return [embs * self.emb_scale for embs in node_embeddings_raw]
        return node_embeddings_raw

    def _compute_matching(self, node_embeddings, dist_mat_obj, return_distance=True, sparse=None, nystrom=None):
        """Run the configured matching method on a distance-matrix object.

        Args:
            node_embeddings: Unused; kept for interface compatibility.
            dist_mat_obj: Object returned by ``compute_distmatrix``, exposing
                ``m``, ``dist_idx``, ``dist_mat_len``, ``num_nodes``,
                ``norms1``, ``norms2`` and ``sinkhorn_reg``.
            return_distance: If True compute the distance, otherwise the
                matching (Sinkhorn only).
            sparse: Truthy to use the sparse approximation.
            nystrom: Truthy to use the Nystrom approximation.

        Raises:
            NotImplementedError: If ``self.distance`` is not supported in the
                selected mode.
        """
        dist_matrix, dist_idx, dist_mat_len, num_nodes, norms1, norms2, sinkhorn_reg = (
                dist_mat_obj.m, dist_mat_obj.dist_idx, dist_mat_obj.dist_mat_len, dist_mat_obj.num_nodes,
                dist_mat_obj.norms1, dist_mat_obj.norms2, dist_mat_obj.sinkhorn_reg)

        if self.sparse_batching:
            if self.distance == 'sinkhorn':
                # Batch index for each of the dist_mat_len**2 entries of every pair.
                batch_idx = torch.repeat_interleave(
                        torch.arange(dist_mat_len.shape[0], device=dist_mat_len.device),
                        dist_mat_len**2)
                if return_distance:
                    output = LogSinkhorn.apply(dist_matrix, dist_idx, dist_mat_len, sinkhorn_reg, batch_idx, self.sinkhorn_niter)
                else:
                    output = argSinkhorn(dist_matrix, dist_idx, dist_mat_len, sinkhorn_reg, batch_idx, self.sinkhorn_niter)
            else:
                raise NotImplementedError(f"Invalid distance method '{self.distance}'.")
        else:
            if return_distance:
                if self.distance == 'chamfer':
                    output = calc_chamfer(dist_matrix, dist_mat_len)
                elif self.distance == 'uniform':
                    output = Uniform.apply(dist_matrix, dist_mat_len)
                elif self.distance == 'sinkhorn':
                    # In the approximate variants ``dist_matrix`` (and, for the
                    # purely sparse one, the norms) are sequences that are
                    # unpacked into positional arguments.
                    if sparse and nystrom:
                        output = LogSparseNystromSinkhornBPdiag.apply(*dist_matrix, norms1, norms2,
                                                                          num_nodes, sinkhorn_reg, self.sinkhorn_niter)
                    elif sparse:
                        output = LogSparseSinkhornBPdiag.apply(*dist_matrix, *norms1, *norms2,
                                                                   num_nodes, sinkhorn_reg, self.sinkhorn_niter)
                    elif nystrom:
                        output = LogNystromSinkhornBPdiag.apply(*dist_matrix, norms1, norms2, num_nodes,
                                                                sinkhorn_reg, self.sinkhorn_niter)
                    else:
                        if self.bp_dist_matrix:
                            output = LogSinkhornPadded.apply(dist_matrix, dist_mat_len, sinkhorn_reg, self.sinkhorn_niter, False)
                        else:
                            output = LogSinkhornPaddedRect.apply(dist_matrix, num_nodes, sinkhorn_reg, self.sinkhorn_niter)
                else:
                    raise NotImplementedError(f"Invalid distance method '{self.distance}'.")
            else:
                if self.distance == 'sinkhorn':
                    output = argSinkhornPadded(dist_matrix, dist_mat_len, sinkhorn_reg, self.sinkhorn_niter)
                else:
                    raise NotImplementedError(f"Invalid distance method '{self.distance}'.")
        return output

    def forward(self,
                graph1: Dict[str, Union[torch.Tensor, torch.sparse.FloatTensor]],
                graph2: Dict[str, Union[torch.Tensor, torch.sparse.FloatTensor]]):
        """Compute the distance (or matching) between two batches of graphs.

        Args:
            graph1: Dict with ``'num_nodes'``, ``'attr_matrix'``,
                ``'adj_idx'`` and ``'edge_attr_matrix'`` for the first graphs.
            graph2: Same structure for the second graphs.

        Returns:
            A tuple ``(None, output)``. The first element is always ``None``
            (it is ``output_transform(None, ...)``); the second is the
            transformed distance or matching.
        """
        num_nodes = torch.stack((graph1['num_nodes'], graph2['num_nodes']))
        if self.bp_dist_matrix or self.nystrom or self.sparse:
            dist_mat_len = num_nodes.sum(0)
        else:
            dist_mat_len = num_nodes.max(0).values

        # No shape is unpacked here: the result was unused, and with sparse
        # batching the embeddings are not reshaped to (batch, nodes, emb).
        node_embeddings = self._compute_node_embeddings(graph1, graph2)

        if self.matching_size > 1:
            # Split the transformed embeddings into ``matching_size`` chunks
            # and stack them along the batch axis, so every chunk is matched
            # as an independent graph pair.
            node_match_embs = [
                torch.cat(self.dist_transform(embs).split(self.match_emb_size, dim=-1), dim=0)
                for embs in node_embeddings[:2]]
            dist_mat_len_rep = dist_mat_len.repeat(self.matching_size)
            num_nodes_rep = num_nodes.repeat(1, self.matching_size)
        else:
            node_match_embs = node_embeddings
            dist_mat_len_rep = dist_mat_len
            num_nodes_rep = num_nodes

        # Regularization per pair: base value divided by log(min(n1, n2) + 1).
        reg_scaled = self.sinkhorn_reg / torch.log(num_nodes_rep.min(0).values.float() + 1)

        # NOTE(review): the second return value is not used here.
        dist_matrix, nll_mask = compute_distmatrix(
                node_match_embs, num_nodes_rep, dist_mat_len_rep,
                sparse=self.sparse, nystrom=self.nystrom,
                reg_scaled=reg_scaled, sinkhorn_niter=self.sinkhorn_niter,
                dist=self.dist, alpha=self.alpha,
                bp_dist_matrix=self.bp_dist_matrix,
                sparse_batching=self.sparse_batching)

        # Uncomment for visualizing node embeddings with tensorboard
        # self.node_embeddings = node_embeddings

        output = self._compute_matching(
                node_match_embs, dist_matrix,
                return_distance=(not self.return_matching),
                sparse=self.sparse, nystrom=self.nystrom,)

        if self.matching_size > 1:
            # Undo the batch-axis stacking: (matching_size * batch, ...) ->
            # (batch, matching_size, ...).
            output = output.reshape(self.matching_size, num_nodes.shape[1], *output.shape[1:]).transpose(0, 1)

        return (self.output_transform(None, dist_mat_len),
                self.output_transform(output, dist_mat_len))
