import math

import torch
import torch.nn as nn

from torch_geometric.nn import global_add_pool, global_mean_pool
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.models import PNA, GAT
from torch_geometric.utils import remove_self_loops

from pytorch3d.transforms import axis_angle_to_matrix

from torch_scatter import scatter_min, scatter

from .baselines.comenet import angle_emb, torsion_emb, SimpleInteractionBlock, EdgeGraphConv, TwoLayerLinear

from register import MODEL_REGISTRY, ACT_REGISTRY

import ipdb

class SALCWrapper(nn.Module):
    """Config-driven entry point for the SALC model.

    Builds a ``SALConv`` from the ``model_config`` mapping and adapts a batched
    graph data object to ``SALConv.forward``'s positional signature.
    """

    def __init__(self, model_config):
        super(SALCWrapper, self).__init__()
        # (SALConv keyword, config key) pairs, read in this order. Only ``deg``
        # is stored under a different config key.
        arg_to_config_key = (
            ('dim_embedding', 'dim_embedding'),
            ('dim_out', 'dim_out'),
            ('act', 'act'),
            ('num_blocks', 'num_blocks'),
            ('num_layers', 'num_layers'),
            ('num_atoms', 'num_atoms'),
            ('num_radial', 'num_radial'),
            ('num_spherical', 'num_spherical'),
            ('deg', 'degrees_hist'),
            ('scalers', 'scalers'),
            ('aggregators', 'aggregators'),
            ('output_type', 'output_type'),
        )
        salc_kwargs = {arg: model_config[key] for arg, key in arg_to_config_key}
        self.model = SALConv(**salc_kwargs)

    def forward(self, data_obj):
        """Unpack ``data_obj`` and run the wrapped ``SALConv``.

        Reads ``z``, ``x``, ``asu_edge_index`` (self loops stripped),
        ``symmetric_edge_index``, ``project_edge_index``, ``align_pos`` (used as
        ``pos``), ``batch``, ``frame_R`` and ``frame_t`` from ``data_obj``.
        """
        atom_types = data_obj.z
        node_feats = data_obj.x
        asu_edges, _edge_attr = remove_self_loops(data_obj.asu_edge_index)[0], None
        symm_edges = data_obj.symmetric_edge_index
        proj_edges = data_obj.project_edge_index
        # Aligned coordinates supply the invariant positional information.
        aligned_pos = data_obj.align_pos
        graph_ids = data_obj.batch
        rot_frame = data_obj.frame_R
        trans_frame = data_obj.frame_t
        return self.model(atom_types, node_feats, asu_edges, symm_edges, proj_edges,
                          aligned_pos, graph_ids, rot_frame, trans_frame)

MODEL_REGISTRY.register('salc', SALCWrapper)

class SALConv(nn.Module):
    """Symmetry Adapted Linear Combination network.

    Node features start as an atom-type embedding concatenated with the xyz
    position. Each of the ``num_blocks`` blocks then:

      1. runs a ``SymmElemGNN`` over the symmetric-element graph to obtain a
         per-node basis ``inv_basis``;
      2. rotates every graph's positions with a learned ``OptimAlignment``;
      3. runs ``ASUConv`` on the rotated positions over the ASU edges to obtain
         per-node coefficients;
      4. gates the basis with ``tanh(coefficients)`` and applies LayerNorm.

    The final node features go through ``OutputBlock``.
    """
    def __init__(self,
                 dim_embedding: int,
                 dim_out: int,
                 act: str='relu',
                 num_blocks: int=4,
                 num_layers:int=4,
                 num_radial:int=3,
                 num_spherical:int=3,
                 num_atoms:int=15,
                 deg=None,
                 scalers=None,
                 aggregators=None,
                 output_type='graph'):

        super(SALConv, self).__init__()
        self.act_fn = ACT_REGISTRY.get(act)  # resolved here but not used in forward()
        # Reserve 3 channels: forward() appends the xyz position to the atom embedding.
        self.atom_emb = Embedding(num_atoms, dim_embedding - 3)
        self.num_blocks = num_blocks

        # Containers are registered in this fixed order; parameters() and
        # state_dict() iterate submodules in registration order.
        self.equiv_class_gnns = nn.ModuleList()
        self.equiv_repr_gnns = nn.ModuleList()
        self.align_optims = nn.ModuleList()
        self.attn_gnns = nn.ModuleList()  # never populated
        self.layer_norms = nn.ModuleList()

        symm_gnn_kwargs = dict(dim_embedding=dim_embedding,
                               act=act,
                               num_layers=num_layers,
                               num_radial=num_radial,
                               num_spherical=num_spherical,
                               deg=deg,
                               scalers=scalers,
                               aggregators=aggregators)
        asu_conv_kwargs = dict(num_radial=num_radial,
                               num_spherical=num_spherical,
                               dim_embedding=dim_embedding,
                               num_layers=num_layers)

        # Per-block construction order is kept fixed so random initialisation
        # consumes the RNG in the same sequence.
        for _block in range(num_blocks):
            self.layer_norms.append(nn.LayerNorm(dim_embedding))
            self.align_optims.append(OptimAlignment(dim_embedding))
            self.equiv_class_gnns.append(SymmElemGNN(**symm_gnn_kwargs))
            self.equiv_repr_gnns.append(ASUConv(**asu_conv_kwargs))

        self.output = OutputBlock(dim_embedding, dim_out, output_type=output_type)

    def forward(self, z, x, asu_edge_index, symmetric_edge_index, projection_edge_index, pos, batch, frame_R, frame_t):
        # Sets requires_grad on the caller's tensor (the 'forces' head differentiates w.r.t. positions).
        pos.requires_grad = True
        # Nodes that are sources of at least one ASU edge.
        asu_elems = asu_edge_index[0].unique()
        # The incoming `x` is ignored: features are rebuilt from atom type + position.
        h = torch.cat([self.atom_emb(z), pos], dim=-1)

        blocks = zip(self.equiv_class_gnns, self.align_optims, self.equiv_repr_gnns, self.layer_norms)
        for class_gnn, aligner, repr_gnn, norm in blocks:
            inv_basis = class_gnn(h, symmetric_edge_index, pos)      # symmetry-adapted basis
            pos = aligner(inv_basis, asu_elems, pos, batch)          # per-graph learned rotation
            irrep_coeff = repr_gnn(pos, asu_edge_index, pos, batch)  # coefficients from rotated positions
            h = norm(torch.tanh(irrep_coeff) * inv_basis)            # gated combination + LayerNorm

        return self.output(h, symmetric_edge_index, asu_elems, projection_edge_index, pos, batch, frame_R, frame_t)


class OutputBlock(nn.Module):
    """Readout head that turns final node embeddings into predictions.

    ``output_type`` selects the head:

    * ``'graph'``: sum-pool the embeddings of the asymmetric-unit nodes
      (``asu_elems``) per graph, then ``Linear(dim_embedding, dim_out)``.
    * ``'graph-all'``: sum-pool all node embeddings per graph, then
      ``Linear(dim_embedding, dim_out)``.
    * ``'eqv-node'``: per-node ``Linear(dim_embedding, 3)`` (``dim_out`` unused).
    * ``'inv-node'``: per-node ``Linear(dim_embedding, dim_out)``; rows indexed
      by ``projection_edge_index[1]`` are overwritten with the rows indexed by
      ``projection_edge_index[0]``.
    * ``'forces'``: per-node scalar energy summed per graph; returns
      ``-d(energy)/d(pos)`` flattened to 1-D.

    Raises:
        ValueError: if ``output_type`` is not one of the above.
    """
    def __init__(self, dim_embedding, dim_out, output_type='graph'):
        super(OutputBlock, self).__init__()
        self.output_type = output_type
        if self.output_type in ['graph', 'graph-all']:
            self.fc = nn.Linear(dim_embedding, dim_out)
        elif self.output_type == 'eqv-node':
            self.fc = nn.Linear(dim_embedding, 3)
        elif self.output_type == 'inv-node':
            self.fc = nn.Linear(dim_embedding, dim_out)
        elif self.output_type == 'forces':
            self.fc = nn.Linear(dim_embedding, 1)
        else:
            raise ValueError(f'Output type {self.output_type} not recognized')

    def forward(self, x, symmetric_edge_index, asu_elems, projection_edge_index, pos, batch, frame_R, frame_t):
        """Apply the configured head.

        ``symmetric_edge_index``, ``frame_R`` and ``frame_t`` are accepted for
        interface compatibility but not used by any head.
        """
        if self.output_type == 'graph':
            # Fix the pool size to the number of graphs in the batch. Inferring
            # it from batch[asu_elems] would drop trailing graphs that have no
            # ASU node, giving fewer output rows than graphs.
            num_graphs = int(batch.max()) + 1
            graph_out = global_add_pool(x[asu_elems], batch[asu_elems], size=num_graphs)
            return self.fc(graph_out)
        elif self.output_type == 'graph-all':
            graph_out = global_add_pool(x, batch)
            return self.fc(graph_out)
        elif self.output_type == 'eqv-node':
            # NOTE(review): the output stays in the frame of `x`; an earlier,
            # disabled variant mapped it back using frame_R / frame_t.
            return self.fc(x)
        elif self.output_type == 'inv-node':
            x = self.fc(x)
            # Copy predictions from projection sources onto their targets.
            x[projection_edge_index[1]] = x[projection_edge_index[0]]
            return x
        elif self.output_type == 'forces':
            energy_contributions = self.fc(x)
            total_energy = global_add_pool(energy_contributions, batch)  # (num_graphs, 1)
            # total_energy is not a scalar when the batch holds several graphs,
            # so autograd needs explicit grad_outputs. Ones is equivalent to
            # differentiating the summed energy.
            forces = -torch.autograd.grad(
                total_energy,
                pos,
                grad_outputs=torch.ones_like(total_energy),
                create_graph=True,
            )[0]
            return forces.view(-1)


class Embedding(nn.Module):
    """Thin wrapper that maps integer ids to learned dense vectors via ``nn.Embedding``."""

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        # Attribute name kept as `embedding` so state_dict keys are 'embedding.weight'.
        self.embedding = nn.Embedding(num_embeddings=num_embeddings,
                                      embedding_dim=embedding_dim)

    def forward(self, x):
        table = self.embedding
        vectors = table(x)
        return vectors

class SymmElemGNN(nn.Module):
    """PNA message passing over the symmetric-element graph with geometric edge features.

    Edge features are ``tanh(cat([torsion_emb(dist, theta, phi),
    angle_emb(dist, tau), pos[col] - pos[row]]))``. The geometry comes from
    ``extract_geometry`` with ``cutoff``.
    """
    def __init__(self, dim_embedding, act, num_layers, num_radial, num_spherical, deg, scalers, aggregators, cutoff=8.0):
        super(SymmElemGNN, self).__init__()
        # NOTE(review): `act` is accepted but the PNA activation is fixed to 'tanh'.
        # towers=2 presumably requires dim_embedding to be divisible by 2 — confirm with PNA.
        self.gnn = PNA(in_channels=dim_embedding,
                       deg=deg,
                       num_layers=num_layers,
                       towers=2,
                       hidden_channels=dim_embedding,
                       out_channels=dim_embedding,
                       # torsion features + angle features + 3 raw displacement components
                       edge_dim=num_radial*num_spherical**2 + num_radial*num_spherical + 3,
                       scalers=scalers,
                       aggregators=aggregators,
                       act='tanh')
        self.feature1 = torsion_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff)
        self.feature2 = angle_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff)
        self.cutoff = cutoff
        self.act = nn.SiLU()  # NOTE(review): not used in forward()

    def forward(self, x, symm_elem_edge_index, pos):
        row, col = symm_elem_edge_index
        # Tensor-level reductions: builtin any() would iterate element by
        # element in Python (and sync the device per element on GPU).
        assert not torch.isnan(x).any(), f'NaN in x, {x}'
        dist, theta, phi, tau = extract_geometry(x, symm_elem_edge_index, pos, self.cutoff)
        geometry_nan = torch.isnan(dist) | torch.isnan(theta) | torch.isnan(phi) | torch.isnan(tau)
        assert not geometry_nan.any(), 'NaN in geometry'

        feature1 = self.feature1(dist, theta, phi)
        feature2 = self.feature2(dist, tau)
        feature3 = pos[col] - pos[row]
        edge_attr = torch.tanh(torch.cat([feature1, feature2, feature3], dim=-1))
        return self.gnn(x, symm_elem_edge_index, edge_attr=edge_attr)

class ASUConv(nn.Module):
    """ Asymmetric Unit Convolution

    Lifts a 3-feature per-node input to ``dim_embedding`` channels with a
    ``TwoLayerLinear``. It then runs one ComENet ``SimpleInteractionBlock`` over
    ``asu_edge_index``, using torsion/angle embeddings of the edge geometry
    computed by ``extract_geometry`` with ``cutoff``.
    """
    def __init__(self, num_radial, num_spherical, dim_embedding, num_layers, cutoff=8.0):
        super(ASUConv, self).__init__()
        self.cutoff = cutoff

        # Input to forward() must have 3 features per node.
        self.two_layer_linear = TwoLayerLinear(3, dim_embedding, dim_embedding)

        self.feature1 = torsion_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff)
        self.feature2 = angle_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff)
        self.interaction_blocks = nn.ModuleList([
            SimpleInteractionBlock(hidden_channels=dim_embedding,
                                   middle_channels=dim_embedding,
                                   num_radial=num_radial,
                                   num_spherical=num_spherical,
                                   num_layers=num_layers,
                                   output_channels=dim_embedding)
            for _ in range(1)
        ])

    def forward(self, x, asu_edge_index, pos, batch):
        x = self.two_layer_linear(x)

        dist, theta, phi, tau = extract_geometry(x, asu_edge_index, pos, self.cutoff)
        # Tensor-level reduction instead of a per-element Python loop via builtin any().
        geometry_nan = torch.isnan(dist) | torch.isnan(theta) | torch.isnan(phi) | torch.isnan(tau)
        assert not geometry_nan.any(), 'NaN in geometry'

        feature1 = self.feature1(dist, theta, phi)
        feature2 = self.feature2(dist, tau)

        for interaction_block in self.interaction_blocks:
            x = interaction_block(x, feature1, feature2, asu_edge_index, batch)
        return x


class OptimAlignment(nn.Module):
    """Predict one rotation per graph and apply it to every node position.

    The rotation's axis-angle vector is a linear map of ``tanh`` of the mean
    embedding over each graph's asymmetric-unit nodes (``asu_elems``).
    """
    def __init__(self, dim_embedding):
        super(OptimAlignment, self).__init__()
        self.fc = nn.Linear(dim_embedding, 3)

    def forward(self, x, asu_elems, pos, batch):
        # Pool over the full number of graphs. If the size were inferred from
        # batch[asu_elems], a trailing graph without ASU nodes would get no row
        # and R[batch] below would index out of range. Such a graph now pools to
        # zeros and receives the rotation given by the fc bias.
        num_graphs = int(batch.max()) + 1
        graph_out = global_mean_pool(x[asu_elems], batch[asu_elems], size=num_graphs)
        axis_angle = self.fc(torch.tanh(graph_out))
        R = axis_angle_to_matrix(axis_angle)  # (num_graphs, 3, 3)
        # Rotate each node with its graph's matrix: pos_n <- R[batch_n] @ pos_n.
        batch_R = R[batch]
        return torch.bmm(batch_R, pos.unsqueeze(-1)).squeeze(-1)

def construct_rotational_frame(x):
    """Build a batch of rotation matrices from ZYZ Euler angles.

    Args:
        x: tensor of shape ``(N, >=3)``; columns 0, 1, 2 are read as
            ``theta_z``, ``theta_y``, ``phi_z`` (radians).

    Returns:
        ``(N, 3, 3)`` tensor ``Rz(theta_z) @ Ry(theta_y) @ Rz(phi_z)``. It is
        built directly from ``cos``/``sin`` of the inputs, so it lives on
        ``x``'s device with the dtype of ``torch.cos(x)``. The previous version
        allocated on CPU in the default dtype, which downcast float64 input.
    """
    theta_z, theta_y, phi_z = x[:, 0], x[:, 1], x[:, 2]

    def _rot_z(angle):
        # [[c, -s, 0], [s, c, 0], [0, 0, 1]]
        c, s = torch.cos(angle), torch.sin(angle)
        zero, one = torch.zeros_like(c), torch.ones_like(c)
        return torch.stack([
            torch.stack([c, -s, zero], dim=-1),
            torch.stack([s, c, zero], dim=-1),
            torch.stack([zero, zero, one], dim=-1),
        ], dim=-2)

    def _rot_y(angle):
        # [[c, 0, s], [0, 1, 0], [-s, 0, c]]
        c, s = torch.cos(angle), torch.sin(angle)
        zero, one = torch.zeros_like(c), torch.ones_like(c)
        return torch.stack([
            torch.stack([c, zero, s], dim=-1),
            torch.stack([zero, one, zero], dim=-1),
            torch.stack([-s, zero, c], dim=-1),
        ], dim=-2)

    return torch.bmm(_rot_z(theta_z), torch.bmm(_rot_y(theta_y), _rot_z(phi_z)))


def _nearest_two_edges(dist, center, neighbor, num_nodes, cutoff):
    """Find, for every node, its closest and second-closest incident edge.

    Edges are grouped by ``center`` (one entry per edge); ``neighbor`` holds
    each edge's other endpoint.

    Returns:
        ``(argmin0, n0, argmin1, n1)``: per-node (length ``num_nodes``) edge
        indices of the nearest and second-nearest edge, and the neighbor node
        at the far end of each. Nodes with no incident edge fall back to edge 0.
        A node with a single edge gets ``argmin1 == argmin0``.
    """
    _, argmin0 = scatter_min(dist, center, dim_size=num_nodes)
    # Empty groups come back with an out-of-range edge index; clamp to edge 0.
    argmin0[argmin0 >= len(center)] = 0
    n0 = neighbor[argmin0]

    # Push each node's nearest edge out by `cutoff` and search again for the runner-up.
    add = torch.zeros_like(dist)
    add[argmin0] = cutoff
    _, argmin1 = scatter_min(dist + add, center, dim_size=num_nodes)
    argmin1[argmin1 >= len(center)] = 0
    n1 = neighbor[argmin1]
    return argmin0, n0, argmin1, n1


def extract_geometry(x, edge_index, pos, cutoff):
    """Compute ComENet-style per-edge geometry: distance, angle and two torsions.

    Args:
        x: node tensor; only ``x.size(0)`` (the number of nodes) is used.
        edge_index: ``(2, E)`` tensor, unpacked as ``j, i``; edge vectors are
            ``pos[j] - pos[i]``.
        pos: per-node coordinates whose last dimension has size 3.
        cutoff: offset added to each node's nearest-edge distance when
            searching for its second-nearest edge.

    Returns:
        ``(dist, theta, phi, tau)``, one value per edge. ``theta``, ``phi`` and
        ``tau`` are mapped into ``[0, pi]``.
    """
    j, i = edge_index
    vecs = pos[j] - pos[i]
    dist = vecs.norm(dim=-1)
    num_nodes = x.size(0)

    # Nearest / second-nearest edges seen from the i side and from the j side.
    argmin0, n0, argmin1, _ = _nearest_two_edges(dist, i, j, num_nodes, cutoff)
    argmin0_j, n0_j, argmin1_j, _ = _nearest_two_edges(dist, j, i, num_nodes, cutoff)

    # Broadcast per-node nearest neighbors to per-edge values.
    n0 = n0[i]
    n0_j = n0_j[j]

    # tau uses (iref, i, j, jref). If the nearest neighbor of i is j itself (or
    # that of j is i), the torsion degenerates to zero, so the second-nearest
    # edge is used as the reference instead.
    mask_iref = n0 == j
    idx_iref = argmin0[i]
    idx_iref[mask_iref] = argmin1[i][mask_iref]

    mask_jref = n0_j == i
    idx_jref = argmin0_j[j]
    idx_jref[mask_jref] = argmin1_j[j][mask_jref]

    pos_ji = vecs
    pos_in0 = vecs[argmin0][i]
    pos_in1 = vecs[argmin1][i]
    pos_iref = vecs[idx_iref]
    pos_jref_j = vecs[idx_jref]

    epsilon = 1e-6
    # Small shift of the edge vectors, presumably to keep degenerate
    # configurations away from NaN. `dist` above is computed before the shift.
    pos_ji = pos_ji + epsilon

    # dim=-1 is explicit: without it torch.cross picks the FIRST dimension of
    # size 3, which is the edge dimension whenever there are exactly 3 edges.

    # Calculate angles.
    a = ((-pos_ji) * pos_in0).sum(dim=-1)
    b = torch.cross(-pos_ji, pos_in0, dim=-1).norm(dim=-1)
    theta = torch.atan2(b, a + epsilon)
    theta[theta < 0] = theta[theta < 0] + math.pi

    # Calculate torsions.
    dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()
    plane1 = torch.cross(-pos_ji, pos_in0, dim=-1)
    plane2 = torch.cross(-pos_ji, pos_in1, dim=-1)
    a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
    b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / (dist_ji + 1e-16)
    phi = torch.atan2(b, a + epsilon)
    phi[phi < 0] = phi[phi < 0] + math.pi

    # Calculate right torsions.
    plane1 = torch.cross(pos_ji, pos_jref_j, dim=-1)
    plane2 = torch.cross(pos_ji, pos_iref, dim=-1)
    a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
    b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / (dist_ji + 1e-16)
    tau = torch.atan2(b, a + epsilon)
    tau[tau < 0] = tau[tau < 0] + math.pi

    return dist, theta, phi, tau
