import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter

from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.utils import to_dense_batch


from typing import List

def get_graph_pooling(name):
    """Return the PyG global pooling function registered under ``name``.

    Args:
        name: one of ``"sum"``, ``"mean"`` or ``"max"``.

    Returns:
        ``global_add_pool``, ``global_mean_pool`` or ``global_max_pool``.

    Raises:
        ValueError: if ``name`` is not one of the supported readouts.
    """
    if name == "sum":
        return global_add_pool
    if name == "mean":
        return global_mean_pool
    if name == "max":
        return global_max_pool
    raise ValueError(f"graph pooling - {name} is unsupported at this time!")

def reset_sequential_parameters(seq: torch.nn.Sequential) -> None:
    """Re-initialise, in place, every direct child of ``seq`` that can reset itself.

    Children without a callable ``reset_parameters`` method (activations such as
    ReLU/GELU, Dropout, Identity, Tanh, ...) are skipped rather than raising
    ``AttributeError``. Only direct children are visited; nested containers are
    not recursed into.

    Args:
        seq: the container whose children are reset.
    """
    for module in seq:
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()

class MLP(torch.nn.Module):
    """Stack of Linear layers with optional normalisation, ReLU and dropout.

    ``hidden_dims`` lists the feature sizes ``[in, h1, ..., out]``; one Linear
    layer is built for each consecutive pair. Every layer except the last is
    followed by an optional BatchNorm1d/LayerNorm, then ReLU and Dropout. The
    last Linear layer has no bias and is followed by ReLU + Dropout only when
    ``activate_last`` is set.

    Args:
        hidden_dims: feature sizes of the layer boundaries.
        batch_norm: insert BatchNorm1d after each hidden Linear layer.
        layer_norm: insert LayerNorm after each hidden Linear layer.
        dropout: dropout probability used after every ReLU.
        activate_last: also apply ReLU + Dropout after the final layer.

    Raises:
        ValueError: if ``batch_norm`` and ``layer_norm`` are both enabled.
    """

    def __init__(self, hidden_dims: List,
                 batch_norm: bool = False,
                 layer_norm: bool = False,
                 dropout: float = 0.5,
                 activate_last: bool = False):
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if batch_norm and layer_norm:
            raise ValueError('batch_norm and layer_norm cannot both be enabled.')

        num_layers = len(hidden_dims) - 1
        modules = []
        for i in range(num_layers):
            is_last = i == num_layers - 1
            out_features = hidden_dims[i + 1]
            # Hidden layers carry a bias; the output projection does not.
            modules.append(torch.nn.Linear(hidden_dims[i], out_features, bias=not is_last))
            if not is_last:
                if batch_norm:
                    modules.append(torch.nn.BatchNorm1d(out_features))
                if layer_norm:
                    modules.append(torch.nn.LayerNorm(out_features))
            if not is_last or activate_last:
                modules.append(torch.nn.ReLU())
                modules.append(torch.nn.Dropout(p=dropout))

        self.mlp = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.mlp(x)

    def reset_parameters(self):
        """Re-initialise all resettable layers of the stack in place."""
        reset_sequential_parameters(self.mlp)

class primphormer_LM(nn.Module):
    """Multi-head virtual-node layer that also returns an auxiliary loss.

    Node features are projected to per-head queries and keys. Both are L2-normalised
    and scored against ``num_vn`` virtual nodes through the learned tensors ``We``
    (for queries) and ``Wr`` (for keys). The two score tensors are concatenated and
    mapped back to ``d_keys`` features per head. ``forward`` also returns a scalar
    auxiliary loss (``loss_ksvd``) built from the scores, ``Lambda`` and
    ``We``/``Wr``.

    Args:
        in_dim: input node feature size.
        out_dim: output feature size; must be divisible by ``num_heads``.
        num_heads: number of heads.
        use_bias: whether the query/key projections have a bias term.
        num_vn: number of virtual nodes (must be > 0).
        graph_pooling: per-graph readout passed to ``get_graph_pooling``
            ("sum", "mean" or "max").
        low_rank: width of the virtual-node representation.

    Raises:
        ValueError: if ``out_dim`` is not divisible by ``num_heads`` or
            ``num_vn <= 0``.
    """

    def __init__(self, in_dim: int, out_dim: int, num_heads: int, use_bias: bool, num_vn: int = 5, graph_pooling: str = "mean", low_rank: int = 30):
        super().__init__()

        if out_dim % num_heads != 0:
            raise ValueError('hidden dimension is not dividable by the number of heads.')
        if num_vn <= 0:
            raise ValueError('Num of VNs should be larger than 0.')
        self.d_keys = out_dim // num_heads
        self.n_heads = num_heads
        # use_bias / low_rank are stored but not read elsewhere in this class.
        self.use_bias = use_bias
        self.low_rank = low_rank

        self.query_projection = nn.Linear(in_dim, self.d_keys * num_heads, bias=use_bias)
        self.key_projection = nn.Linear(in_dim, self.d_keys * num_heads, bias=use_bias)
        self.pool = get_graph_pooling(graph_pooling)
        # Projects node features to the low_rank virtual-node space before pooling.
        self.vn_projection = nn.Linear(in_dim, low_rank, bias=True)

        # (num_vn, n_heads, d_keys) virtual-node weights for queries (We) and keys (Wr),
        # initialised with nn.init.orthogonal_.
        self.We = nn.Parameter(nn.init.orthogonal_(torch.Tensor(num_vn, self.n_heads, self.d_keys)))
        self.Wr = nn.Parameter(nn.init.orthogonal_(torch.Tensor(num_vn, self.n_heads, self.d_keys)))
        # (n_heads, low_rank) per-head weights; used only in the auxiliary loss.
        self.Lambda = nn.Parameter(nn.init.uniform_(torch.Tensor(self.n_heads, low_rank)))
        # Maps the concatenated (escore, rscore) of width 2*low_rank to d_keys per head.
        self.concate_weight = nn.Linear(2*low_rank, self.d_keys)

    def norm(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalise ``x`` along its last dimension."""
        return F.normalize(x, p=2, dim=-1)

    def propagate_vn2(self, batch, h):
        """Compute the virtual-node states for this layer.

        Projects node features ``h`` to ``low_rank`` dims, pools them per graph
        with ``self.pool`` (giving ``(B, 1, low_rank)``) and adds the result to
        ``batch.virt_h``.

        NOTE(review): assumes ``batch.virt_h`` is ``(B, num_vn, low_rank)``. It is
        not initialised in this class, so confirm where it is set before the first
        layer runs.
        """
        h = self.vn_projection(h)
        h_vn = self.pool(h, batch.batch).unsqueeze(1)
        h_vn = h_vn + batch.virt_h
        return h_vn

    def forward(self, batch):
        """Run the layer on a batch of graphs.

        Args:
            batch: presumably a PyG ``Batch`` exposing ``x`` (node features,
                ``(N, in_dim)``), ``batch`` (node-to-graph index) and ``virt_h``.
                ``batch.virt_h`` is overwritten with the new virtual-node states
                (side effect).

        Returns:
            ``(out, loss_ksvd)``: ``out`` has shape ``(N, n_heads * d_keys)`` (one
            row per real node) and ``loss_ksvd`` is a scalar tensor.
        """
        x = batch.x
        # x_dense: (B, M, in_dim) padded per graph; mask: (B, M), True for real nodes.
        x_dense, mask = to_dense_batch(x, batch.batch)
        B, M = mask.shape

        # Update of the virtual nodes.
        # NOTE(review): the original note here read "Not full, loss ksvd explode";
        # presumably a different variant made loss_ksvd diverge — confirm with author.

        # virtual_node: (B, num_vn, low_rank), assuming batch.virt_h has that shape.
        virtual_node = self.propagate_vn2(batch, x)
        # Contract over the virtual-node axis -> (B, low_rank, n_heads, d_keys).
        We_X = torch.einsum('bdv,vhe->bdhe', virtual_node.transpose(2, 1), self.We)
        Wr_X = torch.einsum('bdv,vhe->bdhe', virtual_node.transpose(2, 1), self.Wr)


        # (B, M, n_heads, d_keys)
        queries = self.query_projection(x_dense).view(B, M, self.n_heads, -1)
        keys = self.key_projection(x_dense).view(B, M, self.n_heads, -1)

        # Unit-normalise each head's query/key vector before scoring.
        queries = self.norm(queries)
        keys = self.norm(keys)

        # Per-head scores against the virtual-node basis: (B, M, n_heads, low_rank),
        # then [mask] drops padding -> (N, n_heads, low_rank).
        escore = torch.einsum('bmhd,bhde->bmhe', queries, We_X.permute(0, 2, 3, 1))[mask]
        rscore = torch.einsum('bmhd,bhde->bmhe', keys, Wr_X.permute(0, 2, 3, 1))[mask]

        # (N, n_heads, 2*low_rank) -> (N, n_heads, d_keys) -> (N, n_heads * d_keys)
        score = torch.cat((escore, rscore), dim=-1)
        out = self.concate_weight(score).contiguous()
        out = out.view(-1, self.n_heads * self.d_keys)
        # Stored without detach, so gradients flow through it; presumably consumed
        # by the next layer's propagate_vn2.
        batch.virt_h = virtual_node

        # An earlier variant summed over nodes and divided by the number of graphs
        # (batch.batch[-1] + 1) instead of averaging over nodes and heads.

        # Half the mean (over nodes and heads) squared L2 norm of the Lambda-weighted scores.
        loss_escore = (torch.einsum('nhd,hd->nhd', escore, self.Lambda).norm(dim=-1, p=2)**2).mean()/2
        loss_rscore = (torch.einsum('nhd,hd->nhd', rscore, self.Lambda).norm(dim=-1, p=2)**2).mean()/2
        # Trace of the head-averaged We^T Wr, i.e. mean over heads h of
        # sum_{v,d} We[v, h, d] * Wr[v, h, d].
        loss_trace = torch.einsum('dhe,ehk->dhk', self.We.permute(2, 1, 0), self.Wr).mean(dim=1).trace()

        loss_ksvd = (loss_escore + loss_rscore - loss_trace) ** 2

        return out, loss_ksvd

# Register the layer in GraphGym's layer registry under the key 'Primphormer_lm'.
register_layer('Primphormer_lm', primphormer_LM)

if __name__ == '__main__':
    # Scratch check of a per-head subsampling reshape. Split the feature axis into
    # `num_heads` groups, then keep `num_samples` evenly spaced sequence positions.
    x = torch.arange(240).view(5, 6, 8).float()
    num_samples = 4  # named so the builtin `len` is not rebound at module scope
    B, L, d = x.shape
    num_heads = 2
    indices = torch.linspace(0, L - 1, num_samples, dtype=torch.long)
    # (B, L, d) -> (B, d, L) -> (B, num_heads, d // num_heads, L)
    x = x.transpose(-2, -1).reshape(B, num_heads, d // num_heads, L)
    # Keep the sampled positions -> (B, d // num_heads, num_heads, num_samples)
    x = x[:, :, :, indices].transpose(1, 2)
    print(tuple(x.shape))
