from typing import Callable, Optional, Tuple, Union
import copy, math
import torch
import torch.nn as nn
from torch.nn import BatchNorm1d
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import DenseSAGEConv, DenseGCNConv, JumpingKnowledge, GraphConv, global_mean_pool, GCNConv
# from OTCoarsening.src.Sinkhorn import sinkhorn_loss_default
from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK
from topk import SelectStaticTopK
from torch_geometric.typing import OptTensor

def clones(module, N):
    """Build an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Every copy starts with the same parameter values as ``module`` but owns its
    own storage, so the copies train independently afterwards.
    """
    layers = nn.ModuleList()
    for _ in range(N):
        layers.append(copy.deepcopy(module))
    return layers

class OTPooling(torch.nn.Module):
    r"""Soft coarsening pool on a dense, padded batch of graphs.

    Each node gets a score ``sigmoid(gnn(x) ** 2)``. Per graph, only scores at
    or above the k-th largest (``k = ceil(num_nodes * ratio) + 1``) are kept.
    The kept scores weight the columns of the symmetrically normalised
    adjacency; its rows are then L1-normalised into an assignment matrix ``S``.
    ``S`` coarsens the features (``S^T x``) and the adjacency (``S^T A S``).

    NOTE(review): the scoring GNN is called with ``edge_index``, while the
    coarsening uses the dense ``adj`` of shape
    ``(batch_size, num_nodes, num_nodes)``. Every graph is assumed to be padded
    to ``adj.size(1)`` nodes -- confirm with callers.
    """

    def __init__(self,
        in_channels: int,
        ratio: Union[float, int] = 0.5,
        GNN: torch.nn.Module = GraphConv,
        nonlinearity: Union[str, Callable] = 'tanh',
        **kwargs,
    ):
        # Must run before any sub-module is assigned: nn.Module.__setattr__
        # refuses to register ``self.gnn`` otherwise.
        super().__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        # ``nonlinearity`` is accepted for signature parity with the other
        # pooling layers but is not used by this layer.

        # Produces one raw score channel per node.
        self.gnn = GNN(in_channels, 1, **kwargs)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.gnn.reset_parameters()

    def normalize_batch_adj(self, adj):
        r"""Return ``D^{-1/2} (A + I) D^{-1/2}`` for a batched dense adjacency.

        ``adj`` is ``(batch_size, num_nodes, num_nodes)``. Degrees are clamped
        to at least 1. Rows of nodes that have no edges in ``adj`` (e.g.
        padding) are zeroed in the result.
        """
        num_nodes = adj.size(1)
        a_hat = adj + torch.eye(num_nodes, device=adj.device)
        deg_inv_sqrt = a_hat.sum(dim=-1).clamp(min=1).pow(-0.5)

        normed = deg_inv_sqrt.unsqueeze(-1) * a_hat * deg_inv_sqrt.unsqueeze(-2)
        has_edges = (adj.sum(-1) > 0).to(normed.dtype).unsqueeze(-1)
        return has_edges * normed

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        adj: Tensor,
        batch_size: int,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        attn: OptTensor = None,
        batch_num_nodes: OptTensor = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Coarsen a padded batch.

        Args:
            x: node features. Either dense ``(batch_size, num_nodes, F)`` or
                flat ``(batch_size * num_nodes, F)``; both are reshaped to dense.
            edge_index: connectivity passed to the scoring GNN.
            adj: dense batched adjacency ``(batch_size, num_nodes, num_nodes)``.
            batch_size: number of graphs in the batch.
            edge_attr: accepted for interface parity; unused.
            batch: optional node-to-graph assignment. It is only used to count
                nodes per graph when ``batch_num_nodes`` is not given.
            attn: accepted for interface parity; unused.
            batch_num_nodes: optional ``(batch_size,)`` real node counts. If
                neither this nor ``batch`` is given, every graph is assumed to
                have ``adj.size(1)`` nodes.

        Returns:
            ``(embedding_tensor, new_adj, S)`` of shapes
            ``(batch_size, num_nodes, F)``,
            ``(batch_size, num_nodes, num_nodes)`` and
            ``(batch_size, num_nodes, num_nodes)``.
        """
        num_nodes = adj.size(1)

        # Per-graph node counts. Never mutate a caller-supplied tensor.
        if batch_num_nodes is None:
            if batch is not None:
                batch_num_nodes = torch.bincount(batch, minlength=batch_size)
            else:
                batch_num_nodes = torch.full(
                    (batch_size,), num_nodes, dtype=torch.long, device=x.device)

        x_dense = x.reshape(batch_size, num_nodes, x.size(-1))

        # (batch_size, num_nodes) scores in (0.5, 1).
        raw_score = self.gnn(x, edge_index).reshape(batch_size, num_nodes)
        alpha_vec = torch.sigmoid(raw_score.pow(2))

        norm_adj = self.normalize_batch_adj(adj)

        # Threshold per graph = k-th largest score, k = ceil(n * ratio) + 1.
        # k is capped at the padded size so ``topk`` stays valid. Graphs with
        # at most one node keep threshold 0, i.e. keep everything.
        cut_value = alpha_vec.new_zeros(batch_size)
        for j in range(batch_size):
            n_j = int(batch_num_nodes[j])
            if n_j > 1:
                k = min(math.ceil(n_j * self.ratio) + 1, num_nodes)
                cut_value[j] = alpha_vec[j].topk(k, dim=-1).values[-1]

        # Zero out scores below the threshold. The small epsilon keeps the
        # k-th node itself strictly positive after the ReLU.
        cut_alpha_vec = F.relu(alpha_vec + 0.0000001 - cut_value.unsqueeze(-1))

        # Scale column j of the normalised adjacency by node j's kept score,
        # then L1-normalise every row.
        S = norm_adj * cut_alpha_vec.unsqueeze(1)
        S = F.normalize(S, p=1, dim=-1)

        S_t = S.transpose(1, 2)
        embedding_tensor = torch.matmul(S_t, x_dense)
        new_adj = torch.matmul(torch.matmul(S_t, adj), S)

        return embedding_tensor, new_adj, S

class FisherPooling(torch.nn.Module):
    r"""Top-k pooling scored by a squared-gradient ("Fisher information") map.

    Node 0 of each graph (the "image" feature) and the GNN embedding of that
    node are both projected. A symmetric InfoNCE-style contrastive loss is
    computed between them across the batch, with temperature 0.7. Each node's
    score is the squared L2 norm of the loss gradient w.r.t. its features.

    NOTE(review): assumes every graph in the batch has the same number of
    nodes, since features are reshaped with
    ``view(batch_size, -1, in_channels)`` -- confirm with callers.
    ``num_proj`` is stored but not used here.
    """

    def __init__(
        self,
        in_channels: int,
        ratio: Union[float, int] = 0.5,
        GNN: torch.nn.Module = GraphConv,
        min_score: Optional[float] = None,
        multiplier: float = 1.0,
        nonlinearity: Union[str, Callable] = 'tanh',
        num_proj: int = 3,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier
        self.num_proj = num_proj

        self.gnn = GNN(in_channels, in_channels, **kwargs)
        self.select = SelectStaticTopK(1, ratio, min_score, nonlinearity)
        self.connect = FilterEdges()

        self.projector = nn.Sequential(
            Linear(in_channels, in_channels),
            BatchNorm1d(in_channels),
            nn.LeakyReLU(0.1),
            Linear(in_channels, in_channels),
            BatchNorm1d(in_channels),
        )

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.gnn.reset_parameters()
        self.select.reset_parameters()
        # The projector is trainable as well. LeakyReLU has no parameters,
        # so only layers that define reset_parameters are reset.
        for layer in self.projector:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def compute_fisher_information(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Return ``(fisher_info, mutual_info)``.

        ``fisher_info`` has one entry per node: the sum over features of the
        squared gradient of ``mutual_info`` w.r.t. ``x``. ``mutual_info`` is
        the scalar contrastive loss, built with ``create_graph=True`` so it
        stays differentiable.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        batch_size = torch.unique(batch).size(0)

        # autograd.grad needs a graph from x to the loss. Build one even when
        # x is a constant input or we are called under torch.no_grad().
        with torch.enable_grad():
            x_in = x if x.requires_grad else x.detach().requires_grad_(True)

            image_features = x_in.view(batch_size, -1, self.in_channels)[:, 0, :]
            graph_features = self.gnn(x_in, edge_index, edge_attr).view(
                batch_size, -1, self.in_channels)[:, 0, :]

            image_features = self.projector(image_features)
            graph_features = self.projector(graph_features)

            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            graph_features = graph_features / graph_features.norm(dim=-1, keepdim=True)

            # (batch_size, batch_size) similarities, divided by temperature 0.7.
            score = F.cosine_similarity(
                image_features.unsqueeze(1), graph_features.unsqueeze(0), dim=-1)
            score = score / 0.7

            # Matching image/graph pairs sit on the diagonal.
            positive = score.diagonal()
            mutual_info = -(0.5 * (positive - score.logsumexp(dim=-1)).mean()
                            + 0.5 * (positive - score.logsumexp(dim=0)).mean())

            grad = torch.autograd.grad(mutual_info, x_in, create_graph=True)[0]
            fisher_info = (grad * grad).sum(dim=-1)

        return fisher_info, mutual_info

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        attn: OptTensor = None,
    ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor, Tensor]:
        r"""Pool nodes by their Fisher score.

        ``attn`` is ignored: the score is always recomputed from ``x``.

        Returns:
            ``(x, edge_index, edge_attr, batch, perm, score, mutual_info)``.
            Kept node features are rescaled by their selection score.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        attn, mutual_info = self.compute_fisher_information(x, edge_index, edge_attr, batch)

        select_out = self.select(attn, batch)

        perm = select_out.node_index
        score = select_out.weight
        assert score is not None

        x = x[perm] * score.view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        connect_out = self.connect(select_out, edge_index, edge_attr, batch)

        return (x, connect_out.edge_index, connect_out.edge_attr,
                connect_out.batch, perm, score, mutual_info)

class SimPooling(torch.nn.Module):
    r"""Top-k pooling scored by attention from node 0 of each graph.

    Node 0 of each graph is projected as the query; all nodes are projected as
    keys and values. The negated attention weights are used as selection
    scores. When no ``attn`` is supplied, node features are replaced by their
    attention-weighted, output-projected values before pooling.

    NOTE(review): assumes every graph in the batch has the same number of
    nodes, since features are reshaped with ``view(batch_size, -1, F)``.
    """

    def __init__(
        self,
        in_channels: int,
        ratio: Union[float, int] = 0.5,
        min_score: Optional[float] = None,
        multiplier: float = 1.0,
        nonlinearity: Union[str, Callable] = 'tanh',
        **kwargs,
    ):
        super().__init__()
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier
        self.in_channels = in_channels
        # Query, key, value and output projections.
        self.linears = clones(nn.Linear(self.in_channels, self.in_channels), 4)
        self.dropout = nn.Dropout(p=0.1)

        self.select = SelectStaticTopK(1, ratio, min_score, nonlinearity)
        self.connect = FilterEdges()

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        # ``clones`` deep-copies a single Linear, so all four projections
        # start out identical. Re-initialise each one independently.
        for linear in self.linears:
            linear.reset_parameters()
        self.select.reset_parameters()

    def attention(self, query, key, value, dropout=None):
        """Compute 'Scaled Dot Product Attention' with one query per graph.

        As called from ``get_attention_score``, ``query`` is ``(batch, d)``
        and ``key``/``value`` are ``(batch, n, d)``.

        Returns:
            ``(weighted_value, p_attn)``. ``weighted_value`` is ``value`` scaled
            per node by its attention weight, shape ``(batch, n, d)``; it is not
            summed. ``p_attn`` is ``(batch, n)``.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query.unsqueeze(1), key.transpose(-2, -1)) / math.sqrt(d_k)
        # Drop only the singleton query axis: (batch, 1, n) -> (batch, n).
        # A bare squeeze() would also drop a size-1 node axis and make the
        # softmax run across graphs.
        scores = scores.squeeze(1)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn.unsqueeze(-1) * value, p_attn

    def get_attention_score(
        self,
        x: Tensor,
        batch: OptTensor = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Return ``(projected_values, -attention_weights)``.

        ``projected_values`` has shape ``(batch_size, n, F)``.
        ``-attention_weights`` has shape ``(batch_size, n)``.

        NOTE(review): if ``SelectStaticTopK`` keeps the largest scores, the
        negation means the nodes node 0 attends to *least* are kept -- confirm
        this is intended.
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        batch_size = torch.unique(batch).size(0)

        x = x.view(batch_size, -1, x.shape[-1])
        query = self.linears[0](x[:, 0, :])
        key = self.linears[1](x)
        value = self.linears[2](x)

        weighted, attn_score = self.attention(query, key, value, dropout=self.dropout)

        return self.linears[-1](weighted), -attn_score

    def get_similarity_score(
        self,
        x: Tensor,
        batch: OptTensor = None,
    ) -> Tensor:
        r"""Return the flattened cosine similarity of every node to node 0 of its graph."""
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        batch_size = torch.unique(batch).size(0)

        x = x.view(batch_size, -1, x.shape[-1])
        x = x / x.norm(dim=-1, keepdim=True)

        similarity = x @ x.permute(0, 2, 1)

        return similarity[:, 0, :].flatten()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        attn: OptTensor = None,
    ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor]:
        r"""Pool nodes by attention score.

        Returns:
            ``(x, edge_index, edge_attr, batch, perm, score)``.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        if attn is None:
            x, attn = self.get_attention_score(x, batch)
            x = x.view(-1, x.size(-1))
            attn = attn.flatten()

        select_out = self.select(attn, batch)

        perm = select_out.node_index
        score = select_out.weight
        assert score is not None

        # Kept node features are gathered but not rescaled by ``score``.
        x = x[perm]
        x = self.multiplier * x if self.multiplier != 1 else x

        connect_out = self.connect(select_out, edge_index, edge_attr, batch)

        return (x, connect_out.edge_index, connect_out.edge_attr,
                connect_out.batch, perm, score)

class SAGPooling(torch.nn.Module):
    r"""Self-attention graph pooling that always ranks node 0 of each graph first.

    A GNN produces one score per node. Node 0's score in every graph is
    replaced with a value strictly greater than that graph's maximum score,
    and top-k selection is then applied.

    NOTE(review): assumes every graph in the batch has the same number of
    nodes, since scores are reshaped with ``view(batch_size, -1)``.
    """

    def __init__(
        self,
        in_channels: int,
        ratio: Union[float, int] = 0.5,
        GNN: torch.nn.Module = GraphConv,
        min_score: Optional[float] = None,
        multiplier: float = 1.0,
        nonlinearity: Union[str, Callable] = 'tanh',
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier

        self.gnn = GNN(in_channels, 1, **kwargs)
        self.select = SelectStaticTopK(1, ratio, min_score, nonlinearity)
        self.connect = FilterEdges()

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.gnn.reset_parameters()
        self.select.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        attn: OptTensor = None,
    ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor]:
        r"""Pool nodes by GNN score.

        Args:
            attn: optional features to score instead of ``x``. A 1-D tensor is
                treated as a single feature channel.

        Returns:
            ``(x, edge_index, edge_attr, batch, perm, score)``.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        batch_size = torch.unique(batch).size(0)

        attn = x if attn is None else attn
        attn = attn.view(-1, 1) if attn.dim() == 1 else attn
        attn = self.gnn(attn, edge_index).view(batch_size, -1)

        # Lift node 0 strictly above the per-graph maximum. A fixed +1e-8 is
        # lost to float32 rounding once |max| exceeds ~0.1 (leaving a tie).
        # |max| * eps + tiny is always at least one ulp, so the sum is strictly
        # larger. The margin is detached, so gradients flow through ``max``
        # exactly as before. NOTE(review): if the select nonlinearity saturates
        # (e.g. tanh of large scores), the tie may reappear after it.
        attn_max = attn.max(dim=-1, keepdim=True)[0]
        finfo = torch.finfo(attn.dtype)
        margin = attn_max.detach().abs() * finfo.eps + finfo.tiny
        attn = torch.cat([attn_max + margin, attn[:, 1:]], dim=-1).flatten()

        select_out = self.select(attn, batch)

        perm = select_out.node_index
        score = select_out.weight
        assert score is not None

        # Kept node features are gathered but not rescaled by ``score``.
        x = x[perm]
        x = self.multiplier * x if self.multiplier != 1 else x

        connect_out = self.connect(select_out, edge_index, edge_attr, batch)

        return (x, connect_out.edge_index, connect_out.edge_attr,
                connect_out.batch, perm, score)

    def __repr__(self) -> str:
        if self.min_score is None:
            ratio = f'ratio={self.ratio}'
        else:
            ratio = f'min_score={self.min_score}'

        return (f'{self.__class__.__name__}({self.gnn.__class__.__name__}, '
                f'{self.in_channels}, {ratio}, multiplier={self.multiplier})')

