import torch
import torch.nn as nn
from e3nn import o3
from torch.nn import Linear
from torch.nn import functional as F
from torch_cluster import radius, radius_graph
from torch_geometric.nn import BatchNorm
from torch_geometric.nn.aggr import  Set2Set
from torch_geometric.nn.conv import TransformerConv
from torch_scatter import scatter, scatter_mean
from tqdm import tqdm
from functools import partial
from models.layers import TensorProductConvLayer


class GNNE(nn.Module):
    """GNN encoder working on the internal graph representation.

    Node features are embedded with ``NodeEmbedding``, refined by four
    ``TransformerConv`` layers (the first three each followed by a ReLU and a
    ``BatchNorm``, the last one by a ReLU only), pooled per graph with
    ``Set2Set`` and finally mapped by two linear heads to ``mu`` and
    ``logvar``.

    Inspired by https://jcheminf.biomedcentral.com/counter/pdf/10.1186/s13321-019-0396-x.pdf
    Code adapted from https://github.com/deepfindr/gvae
    """

    def __init__(self, feature_size,encoder_embedding_size, edge_embedding_size, latent_embedding_size):
        """Build the encoder.

        Args:
            feature_size: Width of the node embedding and input width of the
                first convolution.
            encoder_embedding_size: Output width of every convolution.
            edge_embedding_size: Width of the edge attributes handed to the
                convolutions (``edge_dim``).
            latent_embedding_size: Width of the returned ``mu`` / ``logvar``.
        """
        super().__init__()
        self.encoder_embedding_size = encoder_embedding_size
        self.edge_embedding_size = edge_embedding_size
        self.latent_embedding_size = latent_embedding_size
        # Two categorical node features with 116 categories each plus three
        # scalar node features.
        self.node_embedding = NodeEmbedding(emb_dim=feature_size, feature_dims=([116, 116], 3))

        hidden = self.encoder_embedding_size

        def make_conv(in_channels):
            # All four convolutions share every hyper-parameter except the
            # input width.
            return TransformerConv(
                in_channels,
                hidden,
                heads=4,
                concat=False,
                beta=True,
                edge_dim=self.edge_embedding_size,
            )

        # Encoder layers (created in the same order as they are applied).
        self.conv1 = make_conv(feature_size)
        self.bn1 = BatchNorm(hidden)
        self.conv2 = make_conv(hidden)
        self.bn2 = BatchNorm(hidden)
        self.conv3 = make_conv(hidden)
        self.bn3 = BatchNorm(hidden)
        self.conv4 = make_conv(hidden)

        # Pooling layer; its output is fed to heads of input width 2 * hidden.
        self.pooling = Set2Set(hidden, processing_steps=4)

        # Latent transform layers
        pooled_size = hidden * 2
        self.mu_transform = Linear(pooled_size, self.latent_embedding_size)
        self.logvar_transform = Linear(pooled_size, self.latent_embedding_size)

    def forward(self,data):
        """Encode a batch of graphs.

        Args:
            data: Batch object providing ``x``, ``edge_attr``, ``edge_index``
                and ``batch``.

        Returns:
            Tuple ``(mu, logvar)`` produced by the two linear heads from the
            pooled graph representation.
        """
        h = self.node_embedding(data.x)
        edge_attr = data.edge_attr
        edge_index = data.edge_index
        batch_index = data.batch

        # conv -> ReLU -> batch norm for the first three layers
        normalised_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in normalised_stages:
            h = norm(conv(h, edge_index, edge_attr).relu())

        # the last convolution is not followed by a normalisation layer
        h = self.conv4(h, edge_index, edge_attr).relu()

        # Pool to global representation
        pooled = self.pooling(h, batch_index)

        # Latent transform layers
        return self.mu_transform(pooled), self.logvar_transform(pooled)
    

class TensorProductEncoder(nn.Module):
    """Encoder built from tensor-product convolutions on a radius graph.

    The forward pass
      1. joins the input graph edges with a radius graph over ``data.pos``,
      2. runs ``num_conv_layers`` ``TensorProductConvLayer`` updates,
      3. pools the node features per graph (Set2Set or mean),
      4. runs one additional convolution from atoms to torsion bonds, reduces
         it to one scalar per bond and averages those scalars per graph,
      5. maps the concatenation of (3) and (4) to ``mu`` and ``logvar``.
    """

    def __init__(self, ns, nv, sh_lmax, num_conv_layers, batch_norm, dropout, latent_embedding_size,
                 in_edge_features, use_set2set_pooling, distance_embed_dim=32) -> None:
        """Build the encoder.

        Args:
            ns: Number of scalar features.
            nv: Number of vector features.
            sh_lmax: Maximum spherical-harmonics degree.
            num_conv_layers: Number of tensor-product convolution layers.
            batch_norm: Forwarded to every ``TensorProductConvLayer``.
            dropout: Dropout rate used in the MLPs and convolution layers.
            latent_embedding_size: Width of the returned ``mu`` / ``logvar``.
            in_edge_features: Number of edge features of the input graph; used
                to size the edge embedding MLP.
            use_set2set_pooling: Use Set2Set aggregation instead of mean
                pooling.
            distance_embed_dim: Number of gaussians used to expand edge
                distances.
        """
        super().__init__()

        self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
        self.ns = ns  # number of scalar features
        self.nv = nv  # number of vector features
        self.sh_lmax = sh_lmax  # max spherical harmonics degree
        # NOTE(review): this attribute (used in build_conv_graph to pad the
        # radius-graph edges) is fixed to 1, while the edge MLP below is sized
        # with the `in_edge_features` argument; the two only agree when the
        # argument is 1.
        self.in_edge_features = 1
        self.max_radius = 5  # max radius of the radius graph constructed for the convolution graph
        self.distance_embed_dim = distance_embed_dim  # number of gaussians for the distance embedding
        self.use_set2set_pooling = use_set2set_pooling  # Set2Set aggregation instead of mean pooling

        # Node embeddings for the convolution graph,
        # feature_dims = ([n_categories_feature_1, n_categories_feature_2, ...], n_scalar_features)
        self.node_embedding = NodeEmbedding(emb_dim=ns, feature_dims=([116], 1))
        # Edge embeddings for the convolution graph
        self.edge_embedding = nn.Sequential(
            nn.Linear(in_edge_features + distance_embed_dim, ns),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ns, ns),
        )
        # Edge distance expansion for radius graph edges
        self.edge_dist_expansion = GaussianSmearing(0.0, self.max_radius, self.distance_embed_dim)

        # Input/output irreps of the successive convolution layers; layers
        # beyond the end of the sequence reuse its last entry.
        irrep_seq = [
            f'{ns}x0e',
            f'{ns}x0e + {nv}x1o',
            f'{ns}x0e + {nv}x1o + {nv}x1e',
            f'{ns}x0e + {nv}x1o + {nv}x1e + {ns}x0o',
        ]
        last = len(irrep_seq) - 1

        # Convolution layers
        self.conv_layers = nn.ModuleList([
            TensorProductConvLayer(
                in_irreps=irrep_seq[min(i, last)],
                sh_irreps=self.sh_irreps,
                out_irreps=irrep_seq[min(i + 1, last)],
                n_edge_features=3 * ns,  # edge embedding + scalar features of both end nodes
                residual=False,
                batch_norm=batch_norm,
                dropout=dropout,
                faster=sh_lmax == 1,
            )
            for i in range(num_conv_layers)
        ])

        # Embedding of the atom <-> bond-centre distances of the torsion graph
        self.final_edge_embedding = nn.Sequential(
            nn.Linear(distance_embed_dim, ns),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ns, ns),
        )

        self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
        self.tor_bond_conv = TensorProductConvLayer(
            in_irreps=self.conv_layers[-1].out_irreps,
            sh_irreps=self.final_tp_tor.irreps_out,
            out_irreps=f'{ns}x0o + {ns}x0e',
            n_edge_features=3 * ns,
            residual=False,
            dropout=dropout,
            batch_norm=batch_norm,
        )

        # Reduces the per-bond convolution output to a single scalar
        self.final_tor_layer = nn.Sequential(
            nn.Linear(o3.Irreps(self.tor_bond_conv.out_irreps).dim, 10, bias=False),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(10, 1),
        )

        node_out_dim = o3.Irreps(self.conv_layers[-1].out_irreps).dim
        # The pooled width is doubled when Set2Set is used; +1 for the pooled
        # scalar produced by the final torsion layer.
        conv_out_size = node_out_dim * (1 + self.use_set2set_pooling) + 1

        # Pooling layers
        if self.use_set2set_pooling:
            self.pooling = Set2Set(node_out_dim, processing_steps=4)
        else:
            self.pooling = partial(scatter_mean, dim=0)

        # Latent transform layers
        self.mu_transform = nn.Sequential(
            Linear(conv_out_size, latent_embedding_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            Linear(latent_embedding_size, latent_embedding_size),
        )

        self.logvar_transform = nn.Sequential(
            Linear(conv_out_size, latent_embedding_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            Linear(latent_embedding_size, latent_embedding_size),
        )

    def forward(self, data):
        """Encode a batch of graphs.

        Args:
            data: Batch object providing ``x``, ``pos``, ``batch``,
                ``edge_index``, ``edge_attr`` and ``torsion_indices``.

        Returns:
            Tuple ``(mu, logvar)`` with one row per graph.
        """
        # get conv graph and compute node and edge embeddings
        node_attr, edge_index, edge_attr, edge_sh, edge_weight = self.build_conv_graph(data)
        node_attr = self.node_embedding(node_attr)
        edge_attr = self.edge_embedding(edge_attr)

        src, dst = edge_index
        for conv in self.conv_layers:
            # message passing
            edge_attr_ = torch.cat([edge_attr, node_attr[src, :self.ns], node_attr[dst, :self.ns]], -1)
            intra_update = conv(node_attr, edge_index, edge_attr_, edge_sh, edge_weight=edge_weight)

            # zero-pad the current features to the (possibly wider) output
            # width so the update can be added residually
            node_attr = F.pad(node_attr, (0, intra_update.shape[-1] - node_attr.shape[-1]))

            # update features
            node_attr = node_attr + intra_update

        # pool nodes (Set2Set or mean, as selected in __init__)
        x = self.pooling(node_attr, data.batch)

        # torsional components
        tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_torsion_conv_graph(data)
        tor_bond_vec = data.pos[tor_bonds[1]] - data.pos[tor_bonds[0]]
        tor_bond_attr = node_attr[tor_bonds[0]] + node_attr[tor_bonds[1]]

        tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
        tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])

        tor_edge_attr = torch.cat([tor_edge_attr, node_attr[tor_edge_index[1], :self.ns],
                                   tor_bond_attr[tor_edge_index[0], :self.ns]], -1)

        tor_pred = self.tor_bond_conv(node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
                                      out_nodes=tor_edge_index[0][-1] + 1, edge_weight=tor_edge_weight)
        tor_pred = self.final_tor_layer(tor_pred)

        # Average the per-bond scalars within each graph. Bond rows are laid
        # out graph after graph with the same number of bonds per graph (see
        # build_torsion_conv_graph), so the graph id of every row follows from
        # the row count instead of a fixed number of bonds per graph.
        num_graphs = int(data.batch[-1]) + 1
        bonds_per_graph = tor_pred.shape[0] // num_graphs
        tor_batch = torch.arange(num_graphs, device=tor_pred.device).repeat_interleave(bonds_per_graph)
        tor_mean = scatter_mean(tor_pred, tor_batch, dim=0, dim_size=num_graphs)
        x = torch.cat([x, tor_mean], dim=1)

        # Latent transform layers
        mu = self.mu_transform(x)
        logvar = self.logvar_transform(x)
        return mu, logvar

    def build_conv_graph(self, data):
        """Build the graph used by the convolution layers.

        The edges of the input graph are concatenated with the edges of a
        radius graph (cut-off ``self.max_radius``) computed on ``data.pos``.

        Returns:
            Tuple ``(node_attr, edge_index, edge_attr, edge_sh, edge_weight)``
            where ``node_attr`` is ``data.x`` unchanged, ``edge_attr`` holds the
            input edge features followed by the gaussian distance expansion,
            ``edge_sh`` the spherical harmonics of the edge vectors and
            ``edge_weight`` the constant ``1.0``.
        """
        # compute edges
        radius_edges = radius_graph(data.pos, self.max_radius, data.batch)

        # concat radius graph edges to original edges
        edge_index = torch.cat([data.edge_index, radius_edges], 1).long()

        # node attributes
        node_attr = data.x

        src, dst = edge_index
        # edge vectors between all nodes (original graph + radius graph)
        edge_vec = data.pos[dst] - data.pos[src]
        # edge distance embedding
        edge_length_emb = self.edge_dist_expansion(edge_vec.norm(dim=-1))

        # edge attributes: original edges keep their feature, radius-graph
        # edges get zeros
        # (assumes data.edge_attr is 1-D, one value per input edge -- TODO confirm)
        edge_attr = torch.cat([
            data.edge_attr.unsqueeze(1),
            torch.zeros(radius_edges.shape[-1], self.in_edge_features, device=data.x.device),
        ], 0)
        edge_attr = torch.cat([edge_attr, edge_length_emb], 1)

        # spherical harmonics of edge vectors, up to degree sh_lmax
        edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')

        # use 1.0 as edge weights for now
        edge_weight = 1.0

        return node_attr, edge_index, edge_attr, edge_sh, edge_weight

    def build_torsion_conv_graph(self, data):
        """Build the graph connecting torsion-bond centres to their atoms.

        Columns 1 and 2 of ``data.torsion_indices`` are taken as the two atoms
        of each bond; every bond is linked to the four atoms listed in its
        ``torsion_indices`` row.

        Returns:
            Tuple ``(bonds, edge_index, edge_attr, edge_sh, edge_weight)``
            where ``edge_index[0]`` is the bond id and ``edge_index[1]`` the
            atom index of each edge.
        """
        bonds = data.torsion_indices[:, 1:3].T.long()
        bond_pos = (data.pos[bonds[0]] + data.pos[bonds[1]]) / 2

        # number of torsions per graph (every graph is assumed to hold the
        # same number of torsions)
        num_graphs = int(data.batch[-1]) + 1
        n_torsions = data.torsion_indices.shape[0] // num_graphs

        # One 2x4 column group per (graph, torsion): row 0 repeats the global
        # bond id, row 1 lists the four atoms that define the torsion.
        # NOTE(review): the atoms are read from data.torsion_indices[i] with
        # i < n_torsions for every graph b -- presumably this relies on all
        # graphs sharing the same torsion definitions; confirm.
        columns = []
        for b in data.batch.unique():
            for i in range(n_torsions):
                bond_id = (torch.ones(4, device=data.x.device) * i) + b * n_torsions
                columns.append(torch.vstack([bond_id, data.torsion_indices[i]]))
        edge_index = torch.hstack(columns).long().to(data.x.device)

        edge_vec = data.pos[edge_index[1]] - bond_pos[edge_index[0]]
        edge_attr = self.edge_dist_expansion(edge_vec.norm(dim=-1))

        edge_attr = self.final_edge_embedding(edge_attr)
        edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
        edge_weight = 1.0

        return bonds, edge_index, edge_attr, edge_sh, edge_weight

class GaussianSmearing(torch.nn.Module):
    """Expand scalar distances on a grid of equally spaced gaussians.

    Used to embed the edge distances. The gaussian centres are
    ``torch.linspace(start, stop, num_gaussians)`` and are stored in the
    ``offset`` buffer; the width is derived from the spacing of the centres.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        spacing = (centers[1] - centers[0]).item()
        # exp(coeff * d^2) with coeff = -1 / (2 * spacing^2)
        self.coeff = -0.5 / spacing ** 2
        self.register_buffer('offset', centers)

    def forward(self, dist):
        """Return a ``[dist.numel(), num_gaussians]`` tensor of gaussian responses."""
        # column of flattened distances minus row of centres
        delta = dist.view(-1, 1) - self.offset.unsqueeze(0)
        return (self.coeff * delta ** 2).exp()
    

class NodeEmbedding(torch.nn.Module):
    """Embed node features made of categorical and scalar columns.

    ``feature_dims`` is a tuple: its first element is a list with the number
    of categories of each categorical feature, its second element is the
    number of scalar features. The default describes one categorical feature
    (atomic number, 116 possible atoms) and one scalar feature (atom mass).
    """

    def __init__(self, emb_dim, feature_dims=([116],1)):
        super().__init__()
        category_sizes, n_scalar = feature_dims
        self.atom_embedding_list = torch.nn.ModuleList()
        self.num_categorical_features = len(category_sizes)
        self.additional_features_dim = n_scalar

        # one xavier-initialised embedding table per categorical feature
        for n_categories in category_sizes:
            table = torch.nn.Embedding(n_categories, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

        # the scalar features are mixed in by a linear layer applied to
        # [summed embeddings | scalar features]
        if n_scalar > 0:
            self.additional_features_embedder = torch.nn.Linear(n_scalar + emb_dim, emb_dim)

    def forward(self, x):
        """Embed ``x`` whose columns are the categorical features followed by the scalar ones."""
        n_cat = self.num_categorical_features
        assert x.shape[1] == n_cat + self.additional_features_dim

        # sum of the embeddings of all categorical columns
        embedded = 0
        for column, table in enumerate(self.atom_embedding_list):
            embedded = embedded + table(x[:, column].long())

        if self.additional_features_dim <= 0:
            return embedded

        scalars = x[:, n_cat:]
        return self.additional_features_embedder(torch.cat([embedded, scalars], dim=1))


class SimpleInternalEncoder(torch.nn.Module):
    """Simple internal encoder that uses a MLP to map the internal coordinates to a latent space.

    The cartesian positions in ``data.pos`` are reshaped to one row per
    molecule, converted by ``coordinate_transform`` into bonds, angles and
    dihedrals, and the concatenation of those is fed to two independent MLP
    heads producing ``mu`` and ``logvar``.
    """

    def __init__(self, coordinate_transform, internal_dim, latent_embedding_size, num_atoms=22):
        """Build the encoder.

        Args:
            coordinate_transform: Object whose ``forward(flat_positions)``
                returns ``(_, (bonds, angles, dihedrals))``.
            internal_dim: Total width of the concatenated bonds, angles and
                dihedrals (input width of both heads).
            latent_embedding_size: Width of the returned ``mu`` / ``logvar``.
            num_atoms: Number of atoms per molecule, used to reshape
                ``data.pos`` to ``[-1, num_atoms * 3]``. Defaults to 22, the
                value that was previously hard-coded.
        """
        super(SimpleInternalEncoder, self).__init__()
        self.coordinate_transform = coordinate_transform
        self.latent_embedding_size = latent_embedding_size
        self.internal_dim = internal_dim
        self.num_atoms = num_atoms

        # The mean head uses no dropout, unlike the log-variance head below.
        self.mu_transform = nn.Sequential(
            nn.Linear(self.internal_dim, self.internal_dim),
            nn.ReLU(),
            nn.Linear(self.internal_dim, 25),
            nn.ReLU(),
            nn.Linear(25, self.latent_embedding_size),
        )

        self.logvar_transform = nn.Sequential(
            nn.Linear(self.internal_dim, self.internal_dim),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(self.internal_dim, 25),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(25, self.latent_embedding_size),
        )

    def forward(self, data):
        """Encode a batch of molecules.

        Args:
            data: Object providing ``pos``, the stacked atom positions of all
                molecules in the batch.

        Returns:
            Tuple ``(mu, logvar)`` with one row per molecule.
        """
        # one row of flattened xyz coordinates per molecule
        flat_pos = data.pos.view(-1, self.num_atoms * 3)
        _, (bonds, angles, dihedrals) = self.coordinate_transform.forward(flat_pos)

        x = torch.cat([bonds, angles, dihedrals], dim=1)
        mu = self.mu_transform(x)
        logvar = self.logvar_transform(x)
        return mu, logvar
    

