from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential, Embedding, ModuleList, BatchNorm1d, ReLU, LayerNorm

from torch_geometric.nn.attention import PerformerAttention
from torch_geometric.nn import GPSConv, GINEConv, global_add_pool

class RedrawProjection:
    """Periodically redraw the projection matrices of Performer attention.

    The helper keeps a call counter. Each call to :meth:`redraw_projections`
    made while ``model`` is in training mode either bumps the counter or,
    once the counter has reached ``redraw_interval``, redraws every
    ``PerformerAttention`` projection found in ``model`` and resets the
    counter. A redraw therefore happens on every
    ``(redraw_interval + 1)``-th training-mode call.

    Args:
        model: Module whose sub-modules are searched for
            ``PerformerAttention`` instances.
        redraw_interval: Counter value that triggers a redraw. ``None``
            disables redrawing entirely.
    """

    def __init__(self, model: torch.nn.Module,
                redraw_interval: Optional[int] = None):
        self.model = model
        self.redraw_interval = redraw_interval
        # Number of training-mode calls seen since the last redraw.
        self.num_last_redraw = 0

    def redraw_projections(self):
        """Advance the counter, redrawing all projections when it is due.

        Does nothing when the model is in eval mode or when no interval
        was configured.
        """
        enabled = self.model.training and self.redraw_interval is not None
        if not enabled:
            return

        if self.num_last_redraw >= self.redraw_interval:
            # Collect the attention modules up front, then redraw each one.
            performer_layers = list(filter(
                lambda submodule: isinstance(submodule, PerformerAttention),
                self.model.modules(),
            ))
            for layer in performer_layers:
                layer.redraw_projection_matrix()
            self.num_last_redraw = 0
        else:
            self.num_last_redraw += 1

class GPS(torch.nn.Module):
    """Graph encoder made of stacked ``GPSConv`` layers, with a contrastive loss.

    Node inputs are the concatenation of an atom-type embedding
    (``channels - pe_dim`` wide) and a linearly projected, batch-normalised
    positional encoding (``pe_dim`` wide). After ``num_layers`` ``GPSConv``
    layers (each wrapping a ``GINEConv``), node states are sum-pooled per
    graph and passed through an MLP head.

    Args:
        channels: Hidden width used throughout the network.
        pe_dim: Width of the projected positional encoding.
        num_layers: Number of ``GPSConv`` layers.
        num_atom_type: Number of rows in the atom-type embedding table.
        attn_type: Attention type forwarded to ``GPSConv``. When it equals
            ``'performer'`` the ``redraw_projection`` helper is configured
            with an interval of 1000; otherwise redrawing is disabled.
        tau: Temperature that divides the similarity scores in
            :meth:`cl_loss`.
        attn_kwargs: Extra keyword arguments forwarded to ``GPSConv``.
        pe_in_dim: Feature width of the raw positional encoding passed to
            :meth:`forward`. Defaults to 20, the value that was previously
            hard-coded (presumably the random-walk length used to build
            ``random_walk_pe`` -- confirm against the data pipeline).

    Note:
        ``redraw_projection`` is only constructed here; nothing in this class
        calls ``redraw_projections()``, so the training loop is expected to
        do so.
    """

    def __init__(self, channels: int, pe_dim: int, num_layers: int, num_atom_type: int,
                 attn_type: str = 'multihead', tau: float = 0.1,
                 attn_kwargs: Optional[Dict[str, Any]] = None,
                 pe_in_dim: int = 20):
        super().__init__()

        # The two halves below are concatenated in forward(), giving `channels`.
        self.node_emb = Embedding(num_atom_type, channels - pe_dim)
        self.pe_lin = Linear(pe_in_dim, pe_dim)
        self.pe_norm = BatchNorm1d(pe_in_dim)
        # Maps 2 * channels edge features down to `channels`.
        self.edge_mlp = Sequential(
            Linear(channels * 2, channels),
            ReLU(),
            Linear(channels, channels),
        )

        self.convs = ModuleList()
        for _ in range(num_layers):
            gine_mlp = Sequential(
                Linear(channels, channels),
                ReLU(),
                LayerNorm(channels),
                Linear(channels, channels),
            )
            conv = GPSConv(channels, GINEConv(gine_mlp), heads=4, norm='graph_norm',
                           attn_type=attn_type, attn_kwargs=attn_kwargs)
            self.convs.append(conv)

        # Head applied to the pooled, graph-level representation.
        self.mlp = Sequential(
            Linear(channels, channels),
            ReLU(),
            LayerNorm(channels),
            Linear(channels, channels),
        )
        self.redraw_projection = RedrawProjection(
            self.convs,
            redraw_interval=1000 if attn_type == 'performer' else None)

        self.tau = tau

    def forward(self, x: Tensor, pe: Tensor, edge_index: Tensor,
                edge_attr: Optional[Tensor], batch: Tensor) -> Tensor:
        """Encode a batch of graphs into one vector per graph.

        Args:
            x: Atom-type indices; the trailing dimension is squeezed before
                the embedding lookup.
            pe: Raw positional encodings with ``pe_in_dim`` features.
            edge_index: Edge list, unpacked as ``(row, col)``.
            edge_attr: Optional edge features. They are passed through
                ``edge_mlp``, so they need ``2 * channels`` features. When
                ``None``, edge features are built by concatenating the
                embedded features of each edge's two endpoint nodes.
            batch: Node-to-graph assignment used for pooling.

        Returns:
            Output of the MLP head on the sum-pooled node states.
        """
        pe = self.pe_norm(pe)
        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(pe)), 1)

        if edge_attr is None:
            row, col = edge_index
            edge_attr = torch.cat([x[row], x[col]], dim=-1)
        edge_attr = self.edge_mlp(edge_attr)

        for conv in self.convs:
            x = conv(x, edge_index, batch, edge_attr=edge_attr)
        x = global_add_pool(x, batch)
        return self.mlp(x)

    def cl_loss(self, graph1, graph2) -> Tensor:
        """Contrastive loss between two batches of graph views.

        Row ``i`` of ``graph1`` and row ``i`` of ``graph2`` form the positive
        pair; every other row of ``graph2`` is a negative for row ``i``.
        The loss is ``-(pos - logsumexp(negatives))`` averaged over rows,
        where scores are cosine similarities divided by ``tau``. The positive
        score is *not* part of the log-sum-exp term.

        Args:
            graph1: Batch object exposing ``x``, ``random_walk_pe``,
                ``edge_index``, ``edge_attr`` and ``batch``.
            graph2: Second batch with the same attributes and the same number
                of graphs.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If the two batches encode to a different number of
                graphs, or if there are fewer than two graphs (no negatives).
        """
        x = self(graph1.x, graph1.random_walk_pe, graph1.edge_index,
                 graph1.edge_attr, graph1.batch)
        y = self(graph2.x, graph2.random_walk_pe, graph2.edge_index,
                 graph2.edge_attr, graph2.batch)

        n = x.size(0)
        if y.size(0) != n:
            raise ValueError(
                f"cl_loss needs equally sized batches, got {n} and {y.size(0)} graphs")
        if n < 2:
            # With a single graph there are no negatives and the loss is -inf.
            raise ValueError("cl_loss needs at least two graphs per batch")

        # After L2 normalisation a plain matrix product is the pairwise
        # cosine similarity, without building an (n, n, d) intermediate.
        x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
        score = (x @ y.t()) / self.tau

        is_positive = torch.eye(n, dtype=torch.bool, device=score.device)
        positive_score = score.diagonal()
        # Masking the diagonal with -inf removes it from the log-sum-exp.
        negative_score = score.masked_fill(is_positive, float('-inf'))

        return -(positive_score - negative_score.logsumexp(dim=-1)).mean()