import math
from typing import List, Union

import torch
import torch.nn as nn
from torch_geometric.nn.models.schnet import GaussianSmearing
from torch_geometric.utils import get_laplacian
from torch_scatter import scatter, scatter_add

# from .layers import GatedEquivariantBlock, Scalar
from ..model_utils import ModelOutput
from .layers import PaiNNMessage, PaiNNUpdate
from ..gemnet.layers.radial_basis import RadialBasis
from ..p3m_utils import add_cell_to_data, get_nonpbc_mesh_atom_graph
from ..p3m_utils import FNO3d
from ..p3m_utils import InteractionBlock, MultiheadAttention, EquivariantVector, Scalar
from ..p3m_utils import (get_distances, get_distances_pbc,
                            radius_determinstic, radius_graph_determinstic,
                            radius_graph_pbc, radius_pbc)
from ..phi_module_utils import AlphaNet, laplacian_matvec, block_diag_sparse


class PaiNN_P3M(nn.Module):
    """PaiNN atom network coupled to a regular mesh for long-range interactions.

    Every layer (see ``ShortLongMixLayer``) runs a PaiNN message/update step on
    the atom graph, a long-range operator on the mesh (``FNO3d`` or
    ``MultiheadAttention``, selected by ``config.model.long_type``) and
    exchanges information between atoms and mesh points through two
    ``InteractionBlock`` modules.  The energy is the sum of a per-atom and a
    per-mesh-point scalar readout.

    When ``config.model.use_phi_module`` is set, an ``AlphaNet`` predicts
    coefficients that are projected onto the lowest eigenvectors of the
    atom-graph Laplacian; the resulting ``phi``/``rho`` fields yield a PDE
    residual and an additional energy term.

    Config fields read by this class:
        ``model``: ``use_phi_module``, ``use_pbc``, ``num_layers``,
        ``p3m_num_rbf``, ``p3m_num_filters``, ``hidden_channels``, ``max_z``,
        ``cutoff``, ``max_neighbors``, ``p3m_grid_cutoff``,
        ``p3m_max_a2m_neighbors``, ``p3m_num_grids``, ``long_type``,
        ``direct_forces``, ``p3m_expand_size``.
        ``training``: ``predict_forces``, ``k_eigenvalues``,
        ``net_charge_lambda``.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Raises:
            AssertionError: if ``config.model.use_pbc`` is truthy.
            ValueError: if ``config.model.long_type`` is neither ``'FNO'`` nor
                ``'MHA'``.
        """
        super().__init__()

        self.config = config
        self.use_phi_module = self.config.model.use_phi_module

        assert not self.config.model.use_pbc, 'PBC is not supported for this model yet'

        self.regress_forces = self.config.training.predict_forces
        self.use_pbc = self.config.model.use_pbc
        self.num_layers = self.config.model.num_layers
        self.num_rbf = self.config.model.p3m_num_rbf
        self.num_filters = self.config.model.p3m_num_filters
        self.hidden_channels = self.config.model.hidden_channels
        self.max_z = self.config.model.max_z
        self.atom_cutoff = self.config.model.cutoff
        self.max_a2a_neighbors = self.config.model.max_neighbors
        self.grid_cutoff = self.config.model.p3m_grid_cutoff
        self.max_a2m_neighbors = self.config.model.p3m_max_a2m_neighbors

        num_grids = self.config.model.p3m_num_grids
        long_type = self.config.model.long_type

        # A single int means the same number of mesh points along each axis.
        if isinstance(num_grids, int):
            self.num_grids = [num_grids, num_grids, num_grids]
        else:
            self.num_grids = num_grids

        self.total_num_grids = self.num_grids[0] * self.num_grids[1] * self.num_grids[2]

        self.direct_forces = self.config.model.direct_forces

        self.embedding = nn.Embedding(self.max_z, self.hidden_channels)
        self.a2a_distance_expansion = RadialBasis(self.num_rbf, self.atom_cutoff)
        self.a2m_distance_expansion = GaussianSmearing(0.0, self.grid_cutoff, self.num_rbf)
        self.m2a_distance_expansion = GaussianSmearing(0.0, self.grid_cutoff, self.num_rbf)

        self.sl_block = nn.ModuleList()
        for _ in range(self.num_layers):
            a2m_mp = InteractionBlock(self.hidden_channels, self.num_rbf, self.num_filters, self.grid_cutoff)
            m2a_mp = InteractionBlock(self.hidden_channels, self.num_rbf, self.num_filters, self.grid_cutoff)
            short_mp = PaiNNMessage(self.hidden_channels, self.num_rbf)
            short_update = PaiNNUpdate(self.hidden_channels)
            if long_type == 'FNO':
                long_mp = FNO3d(
                    *self.num_grids,
                    hidden_channels=self.hidden_channels // 2,
                    in_channels=self.hidden_channels,
                    out_channels=self.hidden_channels,
                    n_layers=1,
                    lifting_channels=self.hidden_channels // 2,
                    projection_channels=self.hidden_channels // 2,
                    non_linearity=nn.SiLU(),
                )
            elif long_type == 'MHA':
                long_mp = MultiheadAttention(self.hidden_channels, self.hidden_channels, 8)
            else:
                raise ValueError(f'Unknown long range interaction type: {long_type}')
            self.sl_block.append(
                ShortLongMixLayer(
                    self.hidden_channels,
                    a2m_mp,
                    m2a_mp,
                    self.num_grids,
                    short_mp,
                    short_update,
                    long_mp,
                )
            )
        self.out_a_norm = nn.LayerNorm(self.hidden_channels)
        self.out_m_norm = nn.LayerNorm(self.hidden_channels)
        self.a_output = Scalar(self.hidden_channels)
        self.m_output = Scalar(self.hidden_channels)

        # Direct force heads are only needed when forces are predicted by the
        # network instead of being derived from the energy gradient.
        if self.regress_forces and self.direct_forces:
            self.a_forces_output = EquivariantVector(self.hidden_channels)
            self.m_rbf_proj = nn.Linear(self.num_rbf, self.hidden_channels, bias=False)
            self.m_forces_output = EquivariantVector(self.hidden_channels)

        if self.use_phi_module:
            self.alpha_model = AlphaNet(in_channels=self.hidden_channels, k=self.config.training.k_eigenvalues)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding, all mixing layers and the output heads.

        ``alpha_model`` and the distance expansions are not touched here.
        """
        self.embedding.reset_parameters()
        for layer in self.sl_block:
            layer.reset_parameters()
        self.out_a_norm.reset_parameters()
        self.out_m_norm.reset_parameters()
        self.a_output.reset_parameters()
        self.m_output.reset_parameters()
        if self.regress_forces and self.direct_forces:
            self.a_forces_output.reset_parameters()
            nn.init.xavier_uniform_(self.m_rbf_proj.weight)
            self.m_forces_output.reset_parameters()

    def forward(self, data):
        """Predict the energy (and optionally forces) for ``data``.

        Args:
            data: input graph(s).  ``data.batch`` is used as the atom-to-graph
                assignment when present; otherwise every atom is assigned to
                graph 0.

        Returns:
            ``ModelOutput`` with ``out`` set to the energy, ``pde_residual``
            set to the PDE loss (``None`` when the phi module is disabled) and,
            when ``regress_forces`` is enabled, ``forces``.
        """
        data_cell = add_cell_to_data(data)
        mesh_atom_graph = get_nonpbc_mesh_atom_graph(
            data_cell,
            expand_size=self.config.model.p3m_expand_size,
            num_grids=self.config.model.p3m_num_grids,
        )

        # Gradient-based forces need the atom positions to be part of the graph.
        if self.regress_forces and not self.direct_forces:
            mesh_atom_graph['atom'].pos.requires_grad_(True)

        bs = mesh_atom_graph.num_graphs if hasattr(mesh_atom_graph['atom'], 'batch') else 1

        a_pos = mesh_atom_graph['atom'].pos
        m_pos = mesh_atom_graph['mesh'].pos

        # Atom -> graph assignment.  Un-batched input carries no batch vector,
        # in which case all atoms belong to the single graph 0.
        atom_batch = getattr(data, 'batch', None)
        if atom_batch is None:
            atom_batch = torch.zeros(a_pos.shape[0], dtype=torch.long, device=a_pos.device)

        if not self.use_pbc:
            a2a_edge_index, _ = radius_graph_determinstic(
                mesh_atom_graph['atom'],
                self.atom_cutoff,
                self.max_a2a_neighbors
            )
            a2a_edge_weights, a2a_edge_vectors = get_distances(a2a_edge_index, a_pos, return_distance_vec=True)

            a2m_edge_index = radius_determinstic(
                mesh_atom_graph['atom'],
                mesh_atom_graph['mesh'],
                self.grid_cutoff,
                self.max_a2m_neighbors,
            )
            a2m_edge_weights = get_distances(a2m_edge_index, a_pos, m_pos, return_distance_vec=False)

            # Mesh -> atom edges are the atom -> mesh edges with direction reversed.
            m2a_edge_index = a2m_edge_index.flip(0)
            m2a_edge_weights = get_distances(m2a_edge_index, m_pos, a_pos, return_distance_vec=False)
        else:
            # NOTE: currently unreachable in normal runs because __init__
            # asserts that ``use_pbc`` is disabled.
            cell = mesh_atom_graph['atom'].cell
            a2a_edge_index, a2a_cell_offsets, a2a_neighbors = radius_graph_pbc(
                mesh_atom_graph['atom'],
                self.atom_cutoff,
                self.max_a2a_neighbors,
                symmetrize=True,
            )
            a2a_edge_weights, a2a_edge_vectors = get_distances_pbc(
                a2a_edge_index,
                cell,
                a2a_cell_offsets,
                a2a_neighbors,
                a_pos,
                return_distance_vec=True
            )

            a2m_edge_index, a2m_cell_offset, a2m_neighbors = radius_pbc(
                mesh_atom_graph['atom'],
                mesh_atom_graph['mesh'],
                self.grid_cutoff,
                self.max_a2m_neighbors,
            )

            a2m_edge_weights = get_distances_pbc(
                a2m_edge_index,
                cell,
                a2m_cell_offset,
                a2m_neighbors,
                a_pos,
                m_pos,
                return_distance_vec=False
            )

            # Reversed edges use negated cell offsets.
            m2a_edge_index = a2m_edge_index.flip(0)
            m2a_cell_offset = -1 * a2m_cell_offset
            m2a_neighbors = a2m_neighbors

            m2a_edge_weights = get_distances_pbc(
                m2a_edge_index,
                cell,
                m2a_cell_offset,
                m2a_neighbors,
                m_pos,
                a_pos,
                return_distance_vec=False
            )

        a2a_edge_attr = self.a2a_distance_expansion(a2a_edge_weights)
        a2m_edge_attr = self.a2m_distance_expansion(a2m_edge_weights)
        m2a_edge_attr = self.m2a_distance_expansion(m2a_edge_weights)

        # N_atoms, F
        a_x = self.embedding(mesh_atom_graph['atom'].z)
        # N_atoms, 3, F -- created with the dtype of ``a_x`` so the model keeps
        # working after being cast to a non-default precision.
        a_v = torch.zeros(a_x.shape[0], 3, a_x.shape[1], device=a_x.device, dtype=a_x.dtype)
        # Non-trainable message passing for the initial mesh embedding: each
        # mesh point receives the mean embedding of the atoms connected to it.
        a_x_j = torch.index_select(a_x, 0, a2m_edge_index[0])
        m_x = scatter(a_x_j, a2m_edge_index[1], dim=0, reduce='mean', dim_size=self.total_num_grids * bs)

        if self.use_phi_module:
            # Eigenbasis of the (block-diagonal over graphs) symmetric
            # normalised Laplacian; the smallest ``k_eigenvalues`` pairs are kept.
            edge_index_L, Ls = get_laplacian(edge_index=a2a_edge_index, edge_weight=a2a_edge_weights, normalization='sym')
            sparse_diag_block_L = block_diag_sparse(edge_index_L, Ls, atom_batch)
            evals, U = torch.lobpcg(sparse_diag_block_L, k=self.config.training.k_eigenvalues, largest=False)

        for i in range(self.num_layers):
            a_x, a_v, m_x = self.sl_block[i](
                a_x,
                a_v,
                m_x,
                a2a_edge_index,
                a2m_edge_index,
                m2a_edge_index,
                a2a_edge_vectors,
                a2m_edge_weights,
                m2a_edge_weights,
                a2a_edge_attr,
                a2m_edge_attr,
                m2a_edge_attr,
            )

            if self.use_phi_module:
                # Compute eigenbasis coefficients "alpha"
                alpha, alpha_rho = self.alpha_model(a_x)

                # Spectral projection; potential and charges are accumulated
                # over the layers.
                phi_step = U @ alpha
                rho_step = (U * evals) @ alpha_rho
                if i == 0:
                    phi = phi_step
                    rho = rho_step
                else:
                    phi += phi_step
                    rho += rho_step

        if self.use_phi_module:
            # Compute PDE residual
            L_phi = laplacian_matvec(Ls, phi, edge_index_L)
            pde_res = (L_phi - rho).pow(2).mean()

            # Apply optional constraint on net zero charge
            net_charge = torch.abs(scatter_add(rho, atom_batch, dim=0)).sum()
            pde_res += self.config.training.net_charge_lambda * net_charge
        else:
            pde_res = None

        out_a_x = self.out_a_norm(a_x)
        out_m_x = self.out_m_norm(m_x)

        # Per-graph energy: sum of atom readouts plus sum of mesh readouts.
        output_a_x = self.a_output(out_a_x)
        energy_a = scatter(
            output_a_x,
            atom_batch,
            dim=0,
            reduce='sum'
        )
        output_m_x = self.m_output(out_m_x)

        energy_m = torch.sum(output_m_x.reshape(bs, -1), dim=-1, keepdim=True)
        energy = energy_a + energy_m

        if self.use_phi_module:
            # NOTE(review): this reduces over every atom in the batch and adds
            # the same scalar to each graph's energy -- confirm this is the
            # intended behaviour for batched inputs.
            self.electrostatic_term = 0.5 * (phi * rho).sum()
            energy = energy + self.electrostatic_term

        if self.regress_forces:
            if self.direct_forces:
                a_forces = self.a_forces_output(out_a_x, a_v)
                # Gather mesh features onto atoms, gated by the projected
                # mesh -> atom radial basis, before the mesh force head.
                out_m_x_j = torch.index_select(out_m_x, 0, m2a_edge_index[0])
                out_m_x_j = self.m_rbf_proj(m2a_edge_attr) * out_m_x_j
                out_m_x = scatter(out_m_x_j, m2a_edge_index[1], dim=0, reduce='add', dim_size=a_x.shape[0])
                m_forces = self.m_forces_output(out_m_x, a_v)
                forces = a_forces + m_forces
            else:
                # Forces as the negative energy gradient; ``create_graph=True``
                # keeps the result differentiable for a force loss.
                forces = -1 * (
                    torch.autograd.grad(
                        energy,
                        mesh_atom_graph['atom'].pos,
                        grad_outputs=torch.ones_like(energy),
                        create_graph=True,
                    )[0]
                )
            return ModelOutput(out=energy, forces=forces, pde_residual=pde_res)
        else:
            return ModelOutput(out=energy, pde_residual=pde_res)

class ShortLongMixLayer(nn.Module):
    """One layer mixing short-range atom updates with long-range mesh updates.

    The layer applies, in order: a pre-normalised short-range message step and
    an update step on the atoms, a pre-normalised long-range operator on the
    mesh, and an atom -> mesh / mesh -> atom exchange whose messages are
    layer-normalised and added residually.

    Args:
        hidden_channels: feature width of atom and mesh embeddings.
        a2m_mp: module producing atom -> mesh messages.
        m2a_mp: module producing mesh -> atom messages.
        num_grids: number of mesh points along each axis.
        short_mp: short-range message module, called as
            ``short_mp(a_x, a_v, edge_index, edge_attr, edge_vectors)``.
        short_update: update module, called as ``short_update(a_x, a_v)``.
        long_mp: long-range operator on the mesh.  A ``MultiheadAttention``
            instance receives ``(graphs, mesh_points, F)``; anything else
            receives a channels-first ``(graphs, F, X, Y, Z)`` volume.

    Every sub-module must provide ``reset_parameters()``.
    """

    def __init__(
        self,
        hidden_channels: int,
        a2m_mp: nn.Module,
        m2a_mp: nn.Module,
        num_grids: List[int],
        short_mp: nn.Module,
        short_update: nn.Module,
        long_mp: nn.Module,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.a2m_mp = a2m_mp
        self.m2a_mp = m2a_mp
        self.num_grids = num_grids
        self.short_mp = short_mp
        self.short_update = short_update
        self.long_mp = long_mp
        self.a2m_layernorm = nn.LayerNorm(hidden_channels)
        self.m2a_layernorm = nn.LayerNorm(hidden_channels)
        self.short_layernorm = nn.LayerNorm(hidden_channels)
        self.long_layernorm = nn.LayerNorm(hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every sub-module and layer norm owned by this layer."""
        for module in (
            self.a2m_mp,
            self.m2a_mp,
            self.short_mp,
            self.short_update,
            self.long_mp,
            self.a2m_layernorm,
            self.m2a_layernorm,
            self.short_layernorm,
            self.long_layernorm,
        ):
            module.reset_parameters()

    def forward(
        self,
        a_x,
        a_v,
        m_x,
        a2a_edge_index,
        a2m_edge_index,
        m2a_edge_index,
        a2a_edge_vectors,
        a2m_edge_weights,
        m2a_edge_weights,
        a2a_edge_attr,
        a2m_edge_attr,
        m2a_edge_attr,
    ):
        """Run one short/long mixing step.

        Returns:
            Tuple ``(a_x, a_v, m_x)`` with the updated atom scalar features,
            atom vector features and mesh features.
        """
        # Inputs are kept for the residual connections below.
        a_x_in, a_v_in, m_x_in = a_x, a_v, m_x

        # Short-range branch on the atom graph (pre-norm + residual).
        a_x = self.short_layernorm(a_x_in)
        a_x, a_v = self.short_mp(a_x, a_v_in, a2a_edge_index, a2a_edge_attr, a2a_edge_vectors)
        a_x = a_x_in + a_x
        a_v = a_v_in + a_v

        da_x, da_v = self.short_update(a_x, a_v)

        # Long-range branch on the mesh: N_meshs, F
        m_x = self.long_layernorm(m_x_in)
        if isinstance(self.long_mp, MultiheadAttention):
            # Mesh points per graph as a plain Python int; no tensor needs to
            # be allocated just to compute a reshape size.
            points_per_graph = 1
            for n in self.num_grids:
                points_per_graph *= int(n)
            m_x = m_x.reshape(-1, points_per_graph, self.hidden_channels)
            m_x = self.long_mp(m_x)
            m_x = m_x.reshape(-1, self.hidden_channels)
        else:
            # Channels-last flat layout -> channels-first volume and back.
            m_x = m_x.reshape(
                -1, self.num_grids[0], self.num_grids[1], self.num_grids[2], self.hidden_channels
            ).permute(0, 4, 1, 2, 3)
            m_x = self.long_mp(m_x).permute(0, 2, 3, 4, 1).reshape(-1, self.hidden_channels)

        # Atom -> mesh message built from the atom update: N_meshs, F
        a2m_message = self.a2m_mp(da_x, a2m_edge_index, a2m_edge_weights, a2m_edge_attr, dim_size=m_x.shape[0])
        a2m_message = self.a2m_layernorm(a2m_message)

        # Mesh -> atom message built from the long-range mesh features: N_atoms, F
        m2a_message = self.m2a_mp(m_x, m2a_edge_index, m2a_edge_weights, m2a_edge_attr, dim_size=a_x.shape[0])
        m2a_message = self.m2a_layernorm(m2a_message)

        return a_x + da_x + m2a_message, a_v + da_v, m_x + a2m_message + m_x_in