from math import pi as PI, sqrt

import torch
import torch.nn as nn

from egnn.egnn_new import EGNN, GNN
from equivariant_diffusion.utils import remove_mean, remove_mean_with_mask


class EGNN_decoder(nn.Module):
    '''
    Decoder that maps latent per-node states to node-feature logits and
    pairwise edge-type logits.

    An optional EGNN / GNN trunk refines the latent state; a node head then
    predicts per-node features and an edge head (an MLP, or a transformer
    encoder layer followed by a linear layer) predicts per-pair edge types.
    '''

    def __init__(self, in_node_nf, context_node_nf, out_node_nf,
                 n_dims, hidden_nf=64, device='cpu',
                 act_fn=torch.nn.SiLU(), n_layers=4, attention=False,
                 tanh=False, mode='egnn_dynamics', norm_constant=0,
                 inv_sublayers=2, sin_embedding=False, normalization_factor=100, aggregation_method='sum',
                 include_atomic_numbers=False, num_edge_types=5, use_rbf=False, transformer_edge_head=False):
        '''
        :param in_node_nf: Number of latent invariant features.
        :param context_node_nf: Number of per-node context features. The value
            1024 is treated as fingerprint conditioning: the context is then
            handed to the EGNN instead of being concatenated to ``h``.
        :param out_node_nf: Number of invariant features.
        :param n_dims: Number of leading coordinate columns in ``xh``.
        :param hidden_nf: Width of the trunk output and of the head MLPs.
        :param device: Device the sub-modules are moved to.
        :param n_layers: Number of EGNN / GNN layers; 0 with
            ``mode='egnn_dynamics'`` skips the trunk and uses only the heads.
        :param mode: ``'egnn_dynamics'`` or ``'gnn_dynamics'``.
        :param num_edge_types: Number of edge classes predicted per node pair.
        :param use_rbf: Gate the MLP edge-head input with a radial basis
            expansion of the pairwise distance.
        :param transformer_edge_head: Use the transformer edge head instead of
            the MLP edge head.
        :raises ValueError: if ``mode`` / ``n_layers`` is not a supported
            combination.
        '''
        super().__init__()

        include_atomic_numbers = int(include_atomic_numbers)
        # TODO: remove hard-coded number of formal charges=3
        num_classes = out_node_nf - include_atomic_numbers - 3  # one hot formal charges are always included

        # to differentiate between fingerprint conditioning and other
        self.context_node_nf = context_node_nf
        fingerprint_conditioning = context_node_nf == 1024
        if fingerprint_conditioning:
            # fingerprints are embedded inside the EGNN, not concatenated to h
            context_node_nf = 0

        self.mode = mode
        self.in_node_nf = in_node_nf
        if mode == 'egnn_dynamics' and n_layers > 0:
            self.egnn = EGNN(
                in_node_nf=in_node_nf + context_node_nf, out_node_nf=hidden_nf,
                in_edge_nf=1, hidden_nf=hidden_nf, device=device, act_fn=act_fn,
                n_layers=n_layers, attention=attention, tanh=tanh, norm_constant=norm_constant,
                inv_sublayers=inv_sublayers, sin_embedding=sin_embedding,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                fingerprint_conditioning=fingerprint_conditioning)
            mlp_hidden_nf = hidden_nf
        elif mode == 'egnn_dynamics' and n_layers == 0:
            # no EGNN layers, directly the edge and node heads
            print('Using Decoder with no EGNN layers, just the node and edge heads!')
            self.egnn = None
            # NOTE(review): the 128 -> 256 widening is hard-coded in the
            # original; presumably it matches a specific checkpoint -- confirm.
            mlp_hidden_nf = 256 if hidden_nf == 128 else hidden_nf
            # the heads consume the (context-augmented) latent state directly
            hidden_nf = in_node_nf + context_node_nf
        elif mode == 'gnn_dynamics':
            self.gnn = GNN(
                in_node_nf=in_node_nf + context_node_nf + 3, out_node_nf=hidden_nf + 3,
                in_edge_nf=0, hidden_nf=hidden_nf, device=device,
                act_fn=act_fn, n_layers=n_layers, attention=attention,
                normalization_factor=normalization_factor, aggregation_method=aggregation_method)
            # previously left undefined, which made this mode crash below
            mlp_hidden_nf = hidden_nf
        else:
            raise ValueError(
                f"Unsupported decoder configuration: mode={mode!r}, n_layers={n_layers}")

        edge_in_nf = 3 + 2 * hidden_nf

        self.use_rbf = use_rbf
        if use_rbf:
            print('Using RBF!')
            num_radial = 16
            self.rbf = BesselBasisLayer(num_radial, cutoff=5.).to(device)
            self.lin_rbf = nn.Linear(num_radial, edge_in_nf, bias=False)
            glorot_orthogonal(self.lin_rbf.weight, scale=2.0)
            self.lin_rbf.to(device)

        self.transformer_edge_head = transformer_edge_head
        if not self.transformer_edge_head:
            self.edge_head = nn.Sequential(
                nn.Linear(edge_in_nf, mlp_hidden_nf),
                act_fn,
                nn.Linear(mlp_hidden_nf, num_edge_types)
            ).to(device)
        else:
            print('Using Transformer Edge head')
            self.transformer_head = nn.TransformerEncoderLayer(
                d_model=edge_in_nf, nhead=1,
                dim_feedforward=128, dropout=0.0,
                activation=act_fn, batch_first=True, device=device)
            self.final_lin = nn.Linear(edge_in_nf, num_edge_types).to(device)

        self.h_head = nn.Sequential(
            nn.Linear(hidden_nf, mlp_hidden_nf),
            act_fn,
            nn.Linear(mlp_hidden_nf, out_node_nf)
        ).to(device)

        self.num_classes = num_classes
        self.include_atomic_numbers = include_atomic_numbers
        self.device = device
        self.n_dims = n_dims
        # cache of edge index lists: {n_nodes: {batch_size: [rows, cols]}}
        self._edges_dict = {}
        self.num_edge_types = num_edge_types

    def forward(self, t, xh, node_mask, edge_mask, context=None):
        '''Not implemented; use :meth:`_forward` instead.'''
        raise NotImplementedError

    def wrap_forward(self, node_mask, edge_mask, context):
        '''
        Bind the masks and context and return a ``fwd(time, state)`` callable.

        ``time`` is accepted for interface compatibility but ignored, because
        :meth:`_forward` takes no time argument.
        '''
        def fwd(time, state):
            return self._forward(state, node_mask, edge_mask, context)
        return fwd

    def unwrap_forward(self):
        '''Return the bound :meth:`_forward` method.'''
        return self._forward

    def _forward(self, xh, node_mask, edge_mask, context):
        '''
        Decode a batch of latent states.

        :param xh: Tensor of shape (bs, n_nodes, n_dims + h_dims).
        :param node_mask: Tensor with bs * n_nodes elements; non-zero marks
            real (non-padded) nodes.
        :param edge_mask: Tensor with bs * n_nodes * n_nodes elements.
        :param context: Optional tensor with bs * n_nodes * context_node_nf
            elements, or None.
        :return: ``(adj_pred, h_pred)`` logits of shape
            (bs, n_nodes, n_nodes, num_edge_types) and (bs, n_nodes, out_node_nf).
        '''
        bs, n_nodes, dims = xh.shape
        h_dims = dims - self.n_dims
        edges = [e.to(self.device)
                 for e in self.get_adj_matrix(n_nodes, bs, self.device)]
        node_mask = node_mask.view(bs * n_nodes, 1)
        edge_mask = edge_mask.view(bs * n_nodes * n_nodes, 1)
        xh = xh.view(bs * n_nodes, -1).clone() * node_mask
        x = xh[:, :self.n_dims].clone()
        if h_dims == 0:
            h = torch.ones(bs * n_nodes, 1).to(self.device)
        else:
            h = xh[:, self.n_dims:].clone()

        if self.context_node_nf != 0 and context is not None:
            # We're conditioning, awesome!
            context = context.view(bs * n_nodes, self.context_node_nf)
            # if we're conditioning on fingerprints, we will embed them inside the EGNN
            if self.context_node_nf != 1024:
                h = torch.cat([h, context], dim=1)
                context = None
        else:
            # this means the diffusion model is conditioned on context but not encoder / decoder
            context = None

        if self.mode == 'egnn_dynamics' and self.egnn is not None:
            h_final, x_final = self.egnn(h, x, edges, node_mask=node_mask,
                                         edge_mask=edge_mask, context=context)
            vel = x_final * node_mask  # This masking operation is redundant but just in case
        elif self.mode == 'gnn_dynamics':
            output = self.gnn(torch.cat([x, h], dim=1), edges, node_mask=node_mask)
            vel = output[:, 0:3] * node_mask
            h_final = output[:, 3:]
        elif self.mode == 'egnn_dynamics' and self.egnn is None:
            # no trunk: the heads operate on the (masked) input directly
            h_final, vel = h, x
        else:
            raise Exception("Wrong mode %s" % self.mode)

        vel = vel.view(bs, n_nodes, -1)
        if torch.any(torch.isnan(vel)):
            print('Warning: detected nan in decoder fwd, resetting EGNN output to zero.')
            vel = torch.zeros_like(vel)

        # node_mask is always a tensor here (it was reshaped above), so the
        # unmasked mean-removal path of the original was unreachable.
        vel = remove_mean_with_mask(vel, node_mask.view(bs, n_nodes, 1))

        if torch.any(torch.isnan(h_final)):
            print('Warning: detected nan in decoder fwd, resetting EGNN output to zero.')
            h_final = torch.zeros_like(h_final)

        h_final = (h_final * node_mask).view(bs, n_nodes, -1)

        if not self.transformer_edge_head:
            adj_pred = self.decode_edges_from_xh(vel, h_final, edge_mask)
        else:
            adj_pred = self.decode_edges_from_xh_with_transformer(vel, h_final, node_mask)
        h_pred = self.decode_features_from_h(h_final, node_mask)

        # returns the edge and node predictions as logits (output of linear layer)
        return adj_pred, h_pred

    @staticmethod
    def _pairwise_edge_features(x, h):
        '''Return ``(X, XH)``: squared coordinate differences of shape
        (bs, n, n, 3) and their concatenation with both endpoint features,
        of shape (bs, n, n, 3 + 2 * nf).'''
        n_nodes = x.size(1)
        X = (x.unsqueeze(1) - x.unsqueeze(2)) ** 2
        h_i = h.unsqueeze(2).expand(-1, -1, n_nodes, -1)
        h_j = h.unsqueeze(1).expand(-1, n_nodes, -1, -1)
        return X, torch.cat([X, h_i, h_j], dim=-1)

    @staticmethod
    def _finalize_adjacency(adj_pred, remove_diagonal, enforce_symmetry):
        '''Optionally zero the self-edges and symmetrize over the node axes.'''
        if remove_diagonal:
            n_nodes = adj_pred.size(1)
            off_diag = 1 - torch.eye(n_nodes, dtype=adj_pred.dtype, device=adj_pred.device)
            adj_pred = adj_pred * off_diag.view(1, n_nodes, n_nodes, 1)
        if enforce_symmetry:
            # when using h features, we concatenate the nodes features, which
            # makes the feature of edge (i,j) different from (j,i)
            adj_pred = (adj_pred + adj_pred.transpose(1, 2)) / 2
        return adj_pred

    def decode_edges_from_xh(self, x, h, edge_mask, remove_diagonal=True, enforce_symmetry=True):
        '''
        Predict edge-type logits for every node pair with the MLP edge head.

        :param x: Coordinates, shape (bs, n_nodes, 3).
        :param h: Node features, shape (bs, n_nodes, nf).
        :param edge_mask: Tensor with bs * n_nodes * n_nodes elements; pairs
            whose entry is zero get all-zero logits.
        :param remove_diagonal: Zero the logits of self-edges.
        :param enforce_symmetry: Average the logits of (i, j) and (j, i).
        :return: Tensor of shape (bs, n_nodes, n_nodes, n_edge_types).
        '''
        bs, n_nodes, _ = x.shape
        n_pairs = bs * n_nodes * n_nodes

        X, XH = self._pairwise_edge_features(x, h)
        XH = XH.reshape(n_pairs, -1)  # (total_n_edges, 3+2*nf)

        if self.use_rbf:
            dist = X.reshape(n_pairs, -1).sum(dim=-1).sqrt()
            rbf = self.rbf(dist)  # (total_n_edges, num_radial)
            # the basis evaluates to NaN at zero distance (self-pairs)
            rbf = torch.where(rbf.isnan(), torch.ones_like(rbf), rbf)
            XH = self.lin_rbf(rbf) * XH  # (total_n_edges, 3+2*nf)

        # Only run the head on real pairs; padded pairs keep zero logits.
        # reshape(-1) (rather than squeeze) keeps the mask 1-D even when
        # there is a single pair.
        keep = edge_mask.bool().reshape(-1)
        logits = self.edge_head(XH[keep])  # (n_kept_edges, n_edge_types)
        scattered = logits.new_zeros((n_pairs, logits.size(-1)))
        scattered[keep] = logits

        adj_pred = scattered.view(bs, n_nodes, n_nodes, -1)
        return self._finalize_adjacency(adj_pred, remove_diagonal, enforce_symmetry)

    def decode_features_from_h(self, h, node_mask):
        '''
        Predict node-feature logits with the node head.

        :param h: Node features, shape (bs, n_nodes, nf).
        :param node_mask: Tensor with bs * n_nodes elements; nodes whose entry
            is zero get all-zero logits.
        :return: Tensor of shape (bs, n_nodes, out_node_nf).
        '''
        bs, n_nodes, _ = h.shape

        # Only run the head on real nodes; padded nodes keep zero logits.
        flat_h = h.reshape(bs * n_nodes, -1)
        keep = node_mask.bool().reshape(-1)
        logits = self.h_head(flat_h[keep])
        scattered = logits.new_zeros((bs * n_nodes, logits.size(-1)))
        scattered[keep] = logits
        return scattered.view(bs, n_nodes, -1)

    def get_adj_matrix(self, n_nodes, batch_size, device):
        '''
        Return ``[rows, cols]`` index tensors of the fully connected graph
        (self-loops included) for ``batch_size`` graphs of ``n_nodes`` nodes.

        Node ``i`` of graph ``b`` has index ``i + b * n_nodes``. Results are
        cached per ``(n_nodes, batch_size)``; the cache is not keyed by
        ``device``.
        '''
        per_batch = self._edges_dict.setdefault(n_nodes, {})
        edges = per_batch.get(batch_size)
        if edges is None:
            local = torch.arange(n_nodes, dtype=torch.long)
            offsets = (torch.arange(batch_size, dtype=torch.long) * n_nodes
                       ).repeat_interleave(n_nodes * n_nodes)
            # same ordering as nested loops over (batch, i, j)
            rows = local.repeat_interleave(n_nodes).repeat(batch_size) + offsets
            cols = local.repeat(n_nodes * batch_size) + offsets
            edges = [rows.to(device), cols.to(device)]
            per_batch[batch_size] = edges
        return edges

    def decode_edges_from_xh_with_transformer(self, x, h, node_mask, remove_diagonal=True, enforce_symmetry=True):
        '''
        Predict edge-type logits with the transformer edge head.

        For each sample, the first ``node_mask[i].sum()`` nodes are taken as
        the real nodes; for every such node its edges to the other real nodes
        form one sequence for the transformer layer. Samples with fewer than
        two real nodes have no edges and keep all-zero logits.

        :param x: Coordinates, shape (bs, n_nodes, 3).
        :param h: Node features, shape (bs, n_nodes, nf).
        :param node_mask: Tensor with bs * n_nodes elements.
        :param remove_diagonal: Zero the logits of self-edges.
        :param enforce_symmetry: Average the logits of (i, j) and (j, i).
        :return: Tensor of shape (bs, n_nodes, n_nodes, num_edge_types).
        '''
        bs, n_nodes, _ = x.shape
        node_mask = node_mask.view(bs, n_nodes)

        _, XH = self._pairwise_edge_features(x, h)  # (bs, n_nodes, n_nodes, 3+2*nf)
        feat_dim = XH.size(-1)

        XH_output = XH.new_zeros((bs, n_nodes, n_nodes, self.num_edge_types))
        for i in range(bs):
            n_actual = int(node_mask[i].sum().item())
            if n_actual < 2:
                # nothing to predict; avoids feeding empty sequences to the encoder
                continue
            off_diag = ~torch.eye(n_actual, dtype=torch.bool, device=XH.device)

            # drop self-edges: (n_actual, n_actual-1, 3+2*nf)
            seqs = XH[i, :n_actual, :n_actual][off_diag].view(n_actual, n_actual - 1, feat_dim)
            encoded = self.transformer_head(seqs).reshape(n_actual * (n_actual - 1), feat_dim)
            XH_output[i, :n_actual, :n_actual][off_diag] = self.final_lin(encoded)

        return self._finalize_adjacency(XH_output, remove_diagonal, enforce_symmetry)


class Envelope(torch.nn.Module):
    '''
    Polynomial envelope ``1/x + a*x^(p-1) + b*x^p + c*x^(p+1)`` with
    ``p = exponent + 1``, multiplied by the indicator of ``x < 1``.
    Infinite results (e.g. at ``x == 0``) are replaced by zero.
    '''

    def __init__(self, exponent: int):
        '''
        :param exponent: Envelope exponent; the polynomial coefficients are
            derived from ``exponent + 1``.
        '''
        super().__init__()
        p = exponent + 1
        self.p = p
        self.a = -(p + 1) * (p + 2) / 2
        self.b = p * (p + 2)
        self.c = -p * (p + 1) / 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        :param x: Input tensor.
        :return: Envelope values, same shape as ``x``.
        '''
        # build the three consecutive powers incrementally
        low = x.pow(self.p - 1)
        mid = low * x
        high = mid * x

        poly = 1.0 / x + self.a * low + self.b * mid + self.c * high
        inside = (x < 1.0).to(x.dtype)
        env = poly * inside
        # 1/x diverges at x == 0; report zero there
        env[env.isinf()] = 0.
        return env


class BesselBasisLayer(torch.nn.Module):
    '''
    Radial basis expansion of a distance ``d``:
    ``sqrt(2 / cutoff) * sin(freq_k * d / cutoff) / d`` for each of the
    ``num_radial`` learnable frequencies, initialised to ``k * pi``.
    '''

    def __init__(self, num_radial: int, cutoff: float = 5.0,
                 envelope_exponent: int = 5):
        '''
        :param num_radial: Number of basis functions (frequencies).
        :param cutoff: Length scale used to normalise the distance.
        :param envelope_exponent: Accepted but unused; no envelope is applied
            by this layer.
        '''
        super().__init__()
        self.cutoff = cutoff
        self.freq = torch.nn.Parameter(torch.empty(num_radial))
        self.reset_parameters()

    def reset_parameters(self):
        '''Reset the frequencies to ``pi, 2*pi, ..., num_radial*pi``.'''
        count = self.freq.numel()
        with torch.no_grad():
            multiples = torch.arange(1, count + 1, dtype=self.freq.dtype,
                                     device=self.freq.device)
            self.freq.copy_(multiples * PI)
        self.freq.requires_grad_()

    def forward(self, dist: torch.Tensor) -> torch.Tensor:
        '''
        :param dist: Tensor of distances.
        :return: Tensor with one extra trailing axis of size ``num_radial``.
            The value is NaN where ``dist == 0`` (inf * 0).
        '''
        d = dist.unsqueeze(-1)
        prefactor = sqrt(2. / self.cutoff)
        phase = self.freq * d / self.cutoff
        return prefactor / d * phase.sin()

def glorot_orthogonal(tensor, scale):
    '''
    Initialise ``tensor`` in place with an orthogonal matrix rescaled so that
    its variance equals ``scale / (fan_in + fan_out)``.

    :param tensor: Weight tensor with at least two dimensions, or None (in
        which case nothing happens).
    :param scale: Target variance numerator; a Python number or a tensor.
        It is never modified.
    '''
    if tensor is None:
        return
    torch.nn.init.orthogonal_(tensor.data)
    with torch.no_grad():
        fan_sum = tensor.size(-2) + tensor.size(-1)
        # Out-of-place division: an in-place `scale /= ...` would silently
        # overwrite a tensor passed in by the caller.
        factor = scale / (fan_sum * tensor.var())
        tensor.data.mul_(factor.sqrt())
