import sys
sys.path.append('..')

import math
import logging
log = logging.getLogger(__name__)

import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from src.utils import math_utils
import copy
from src.neuralnet.layers import XEyTransformerLayer

from src.utils import graph

from torch_geometric.utils.sparse import dense_to_sparse
from torch_geometric.utils import (
    get_laplacian,
    to_scipy_sparse_matrix,
)
from scipy.sparse.linalg import eigsh
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SinusoidalPosEmb(torch.nn.Module):
    """
    Maps a batch of scalars to fixed sinusoidal features.

    Each scalar is multiplied by 1000 and then by dim // 2 geometrically spaced
    frequencies (from 1 down to 1/10000); the sines and cosines of the results are
    concatenated along the last axis.

    dim : int -- requested feature width. Must be at least 4, because the frequency
                 spacing divides by (dim // 2 - 1). The output width is 2 * (dim // 2),
                 i.e. dim - 1 when dim is odd.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # Input: (bs, ) -- singleton axes are squeezed away, so (bs, 1) works as well
        # Output: (bs, 2 * (dim // 2))
        # Can use the batch size dimension to store other stuff as well
        x = x.squeeze() * 1000
        if x.dim() == 0:
            # squeeze() also removes the batch axis when bs == 1; put it back so that a
            # single-element batch is embedded instead of failing the check below.
            x = x.unsqueeze(0)
        assert len(x.shape) == 1
        half_dim = self.dim // 2
        # Log-spaced frequencies: exp(-j * log(10000) / (half_dim - 1)) for j = 0..half_dim-1
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        # Match the dtype and device of the input before combining
        emb = emb.type_as(x)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

def laplacian_eigenvectors_scipy(E, k):
    """
    Laplacian eigenvector features of a graph, computed with SciPy's sparse symmetric eigensolver.

    Parameters:
    E (torch.Tensor): A dense adjacency matrix of shape (n, n) where n is the number of nodes in the graph
        (it is handed to torch_geometric's dense_to_sparse). Only the positions of the non-zero entries
        are used; their values are not forwarded to the Laplacian.
    k (int): The number of eigenvectors to return. k + 1 eigenpairs are requested from the solver.

    Returns:
    torch.Tensor: A CPU tensor of shape (n, k) holding the eigenvectors of the symmetrically normalised
        Laplacian with the smallest eigenvalues, skipping the very first one. Every column is multiplied
        by a randomly drawn sign (+1 or -1).
    """
    n_nodes = E.shape[-1]

    # The edge values returned alongside the indices are intentionally dropped.
    adjacency_index, _ = dense_to_sparse(E)
    lap_index, lap_weight = get_laplacian(adjacency_index, normalization='sym', num_nodes=n_nodes)
    laplacian = to_scipy_sparse_matrix(lap_index, lap_weight, n_nodes)

    # Number of Lanczos vectors handed to the solver, capped at the matrix size.
    lanczos_vectors = min(E.shape[0], max(20 * k + 1, 40))
    eig_vals, eig_vecs = eigsh(laplacian, k=k + 1, which='SA', return_eigenvectors=True, ncv=lanczos_vectors)

    # Order the columns by ascending eigenvalue, then skip the first one.
    ascending = eig_vals.argsort()
    eig_vecs = np.real(eig_vecs[:, ascending])
    pe = torch.from_numpy(eig_vecs[:, 1:k + 1])

    # Eigenvectors are only defined up to sign, so each column gets a random one.
    flips = 2 * torch.randint(0, 2, (k, )) - 1
    pe *= flips
    return pe

def laplacian_eigenvectors_gpu(E, k):
    """
    Laplacian eigenvector features of a graph, computed with a dense PyTorch eigendecomposition
    on whatever device E lives on.

    Parameters:
    E (torch.Tensor): A dense adjacency matrix of shape (n, n) where n is the number of nodes in the graph.
        Only the positions of the non-zero entries are used.
    k (int): The number of eigenvectors to return.

    Returns:
    torch.Tensor: A tensor of shape (n, k) on E's device holding the eigenvectors of the symmetrically
        normalised Laplacian with the smallest eigenvalues, skipping the very first one. Every column
        is multiplied by a randomly drawn sign (+1 or -1).
    """
    n_nodes = E.shape[-1]

    # Every non-zero entry of E counts as an edge.
    nonzero_entries = torch.nonzero(E, as_tuple=False).t()
    lap_index, lap_weight = get_laplacian(nonzero_entries, normalization='sym', num_nodes=n_nodes)

    # Materialise the sparse Laplacian as a dense matrix for torch.linalg.eigh.
    laplacian = torch.zeros((n_nodes, n_nodes), device=E.device)
    rows, cols = lap_index[0], lap_index[1]
    laplacian[rows, cols] = lap_weight

    eig_vals, eig_vecs = torch.linalg.eigh(laplacian)

    # Keep the columns ranked 1..k by eigenvalue (rank 0 is skipped).
    order = torch.argsort(eig_vals)
    keep = order[1:k + 1]
    pe = eig_vecs[:, keep]

    # Eigenvectors are only defined up to sign, so each column gets a random one.
    flips = 2 * torch.randint(0, 2, (k, ), device=E.device) - 1
    pe *= flips

    return pe

class PositionalEmbedding(nn.Module):
    """
    Laplacian positional encodings that are computed on one molecule of each sample (the one with
    the largest molecule-assignment index, referred to as the product) and then handed to every
    node carrying the same atom-map number.

    dim : int -- width of the returned encodings; eigenvectors are zero-padded up to this width
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def matched_positional_encodings_laplacian(self, E, atom_map_numbers, molecule_assignments, max_eigenvectors):
        """
        Atom-map-matched Laplacian encodings, with the eigenvectors computed by SciPy on the CPU.

        E : (bs, n, n) adjacency tensor, not one-hot. It is cloned, the caller's tensor is not modified.
        atom_map_numbers : (bs, n) integer tensor, 0 meaning "not atom mapped"
        molecule_assignments : (bs, n) tensor of molecule indices per node
        max_eigenvectors : int -- upper bound on the number of eigenvectors used per sample

        Returns a float32 tensor of shape (bs, n, dim) on E's device. Nodes whose atom-map number
        is 0 or does not occur in the product get an all-zero encoding.
        """
        return self._matched_positional_encodings(E, atom_map_numbers, molecule_assignments, max_eigenvectors,
                                                  laplacian_eigenvectors_scipy)

    def matched_positional_encodings_laplacian_gpu(self, E, atom_map_numbers, molecule_assignments, max_eigenvectors):
        """
        Same contract as matched_positional_encodings_laplacian, but the eigenvectors are computed
        with a dense torch eigendecomposition on E's device, and the non-zero atom-map numbers are
        randomly relabelled before the lookup is built.
        """
        return self._matched_positional_encodings(E, atom_map_numbers, molecule_assignments, max_eigenvectors,
                                                  laplacian_eigenvectors_gpu, relabel_atom_maps=True)

    def _matched_positional_encodings(self, E, atom_map_numbers, molecule_assignments, max_eigenvectors,
                                      eigenvector_fn, relabel_atom_maps=False):
        # Shared implementation: isolate the product graph per sample, compute its eigenvectors
        # with eigenvector_fn(E_i, k) and distribute them to all nodes through the atom-map numbers.
        assert molecule_assignments is not None, 'Need to provide molecule assigments'
        assert len(E.shape) == 3 and E.shape[1] == E.shape[2], 'E should not be in one-hot format'
        bs, n = atom_map_numbers.shape[0], atom_map_numbers.shape[1]
        E = E.clone() # TODO: Could also just change the different bond types all to 1
        # The index bookkeeping below is done on the CPU
        molecule_assignments = molecule_assignments.cpu()
        atom_map_numbers = atom_map_numbers.cpu()

        if relabel_atom_maps:
            atom_map_numbers = self._relabel_atom_maps(atom_map_numbers)

        product_indices = molecule_assignments.max(-1).values
        pos_embeddings = torch.zeros((bs, n, self.dim), dtype=torch.float32, device=E.device)

        for i in range(bs):
            is_product = molecule_assignments[i] == product_indices[i]
            # Cut every edge that touches a non-product node
            E[i, ~is_product, :] = 0
            E[i, :, ~is_product] = 0

            # Also capped by self.dim: more eigenvectors than output columns cannot be stored
            # and would make the zero-padding width negative.
            k = min(max_eigenvectors, self.dim, int(is_product.sum()))
            pe = eigenvector_fn(E[i], k) # Shape (n, k) <- can this be made faster by only calculating for the product?
            pos_embeddings[i] = self._spread_by_atom_map(pe, atom_map_numbers[i], is_product)

        return pos_embeddings

    @staticmethod
    def _relabel_atom_maps(atom_map_numbers):
        # Randomly permute the labels 1..max while keeping 0 (not atom mapped) fixed.
        # NOTE(review): the relabelling is applied consistently to all nodes, so it does not look
        # like it changes the final encodings; kept as in the original implementation.
        perm = torch.arange(atom_map_numbers.max().item()+1)[1:]
        perm = perm[torch.randperm(len(perm))]
        perm = torch.cat([torch.zeros(1, dtype=torch.long), perm])
        return perm[atom_map_numbers]

    def _spread_by_atom_map(self, pe, sample_atom_maps, is_product):
        # Zero-pad the (n, k) eigenvectors to self.dim columns and return, for every node of the
        # sample, the row of the product node that shares its atom-map number.
        padding = torch.zeros((pe.shape[0], self.dim - pe.shape[1]), dtype=pe.dtype, device=pe.device)
        pe = torch.cat([pe, padding], dim=-1)

        # Row 0 is the all-zero encoding, rows 1.. are the product nodes in order. The zero row is
        # created on the device of pe (not on a global default device) so the concatenation works
        # wherever the eigenvectors were computed.
        zero_row = torch.zeros((1, self.dim), dtype=pe.dtype, device=pe.device)
        pos_embs_prod = torch.cat([zero_row, pe[is_product]], 0)

        # The atom map numbers of the product
        atom_map_nums_product = torch.cat([torch.tensor([0]), sample_atom_maps[is_product]]) # The zero is for the non-atom mapped atoms
        max_atom_map = int(sample_atom_maps.max()) # Need this because sometimes some of the atom maps are missing in the data
        atom_map_to_idx = torch.zeros(max_atom_map + 1, dtype=torch.long)
        atom_map_to_idx.scatter_(dim=0, index=atom_map_nums_product, src=torch.arange(len(atom_map_nums_product)))
        # A product node without atom map also writes to slot 0, and scatter_ gives no guarantee
        # which duplicate wins. Pin slot 0 to the zero row so that non-atom mapped nodes always
        # get the zero encoding.
        atom_map_to_idx[0] = 0

        return pos_embs_prod[atom_map_to_idx[sample_atom_maps]]

def get_X_prods_and_E_prods_aligned(mol_assignments, atom_map_numbers, orig_X, orig_E, out_dim_E, out_dim_X, device):
    """
    Extracts the atom-mapped product nodes/edges of every sample and permutes them with the
    matrix returned by math_utils.create_permutation_matrix_torch(product atom maps, reactant atom maps).

    The product of a sample is the molecule with the largest entry in mol_assignments; all other
    nodes are treated as the reactant side. Only nodes with atom_map_numbers > 0 take part.

    mol_assignments : (bs, n) molecule index per node
    atom_map_numbers : (bs, n) atom-map number per node (not modified, copies are made)
    orig_X : (bs, n, >= out_dim_X) node features, only the first out_dim_X are used
    orig_E : (bs, n, n, >= out_dim_E) edge features, only the first out_dim_E are used
    device : device for the index tensors and the permutation matrices

    Returns (X_list, E_list, atom_map_numbers_rct) where, with N_i the number of atom-mapped
    product nodes of sample i:
      X_list[i] has shape (1, N_i, out_dim_X) and equals P^T @ X_product
      E_list[i] has shape (1, N_i, N_i, out_dim_E) and equals P^T @ E_product @ P per feature channel
      atom_map_numbers_rct is atom_map_numbers with the product entries set to 0
    NOTE(review): presumably P reorders product atoms into reactant order -- confirm against math_utils.
    """
    bs = orig_X.shape[0]
    prod_assignment = mol_assignments.max(-1).values[:, None]

    # Zero out the atom maps of the "other side" in two separate copies
    atom_map_numbers_prod = atom_map_numbers.clone()
    atom_map_numbers_prod[mol_assignments < prod_assignment] = 0
    atom_map_numbers_rct = atom_map_numbers.clone()
    atom_map_numbers_rct[mol_assignments == prod_assignment] = 0

    # Positions of the atom-mapped nodes; the lists are ragged across the batch
    node_positions = torch.arange(atom_map_numbers.shape[-1], device=device)
    prod_idxs = [node_positions[atom_map_numbers_prod[i] > 0] for i in range(bs)]
    rct_idxs = [node_positions[atom_map_numbers_rct[i] > 0] for i in range(bs)]

    # Sub-adjacency restricted to the atom-mapped product nodes, with a leading batch axis of 1
    E_prods_atom_mapped = []
    for i in range(bs):
        idx = prod_idxs[i]
        E_prods_atom_mapped.append(orig_E[i, :, :, :out_dim_E][idx][:, idx].unsqueeze(0))

    # Both sides must contain the same number of atom-mapped nodes
    assert all(len(prod_idxs[i]) == len(rct_idxs[i]) for i in range(bs))

    X_prods_am_permuted, E_prods_am_permuted = [], []
    for i in range(bs):
        P = math_utils.create_permutation_matrix_torch(atom_map_numbers_prod[i][prod_idxs[i]],
                                                       atom_map_numbers_rct[i][rct_idxs[i]]).float().to(device)
        P = P.unsqueeze(0) # The unsqueeze will be unnecessary with proper batching here
        P_t = P.transpose(dim0=1, dim1=2)

        # Edges: move the feature axis in front of the two node axes so that the matrix products
        # act on the (N, N) slices, then move it back to the end.
        E_channels_first = torch.movedim(E_prods_atom_mapped[i].float(), -1, 1)
        E_prods_am_permuted.append(torch.movedim(P_t @ E_channels_first @ P, 1, -1))

        # Nodes: (1, N, N) @ (1, N, out_dim_X) -> (1, N, out_dim_X)
        X_selected = orig_X[i, :, :out_dim_X][prod_idxs[i]].unsqueeze(0)
        X_prods_am_permuted.append(P_t @ X_selected)

    return X_prods_am_permuted, E_prods_am_permuted, atom_map_numbers_rct

class GraphTransformerWithY(nn.Module):
    """
    Graph transformer acting jointly on node features X, edge features E and global features y.

    n_layers : int -- number of XEyTransformerLayer blocks
    input_dims, output_dims : dict -- feature widths under the keys 'X', 'E', 'y'
    hidden_mlp_dims : dict -- hidden widths of the input/output MLPs under the keys 'X', 'E', 'y'
    hidden_dims : dict -- transformer widths under 'dx', 'de', 'dy', 'n_head', 'dim_ffX', 'dim_ffE'
    act_fn_in, act_fn_out : activation modules used in the input / output MLPs
    dropout : float -- forwarded to every transformer layer
    """
    def __init__(self, n_layers: int, input_dims: dict, hidden_mlp_dims: dict, hidden_dims: dict,
                 output_dims: dict, act_fn_in: nn.ReLU(), act_fn_out: nn.ReLU(),
                 dropout=0.1):
        super().__init__()
        self.n_layers = n_layers
        self.out_dim_X, self.out_dim_E, self.out_dim_y = output_dims['X'], output_dims['E'], output_dims['y']

        def encoder(key, width):
            # Linear -> act -> Linear -> act, from the input width to the transformer width
            return nn.Sequential(nn.Linear(input_dims[key], hidden_mlp_dims[key]), act_fn_in,
                                 nn.Linear(hidden_mlp_dims[key], width), act_fn_in)

        def decoder(key, width):
            # Linear -> act -> Linear, from the transformer width to the output width
            return nn.Sequential(nn.Linear(width, hidden_mlp_dims[key]), act_fn_out,
                                 nn.Linear(hidden_mlp_dims[key], output_dims[key]))

        self.mlp_in_X = encoder('X', hidden_dims['dx'])
        self.mlp_in_E = encoder('E', hidden_dims['de'])
        self.mlp_in_y = encoder('y', hidden_dims['dy'])

        self.tf_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.tf_layers.append(XEyTransformerLayer(dx=hidden_dims['dx'],
                                                      de=hidden_dims['de'],
                                                      dy=hidden_dims['dy'],
                                                      n_head=hidden_dims['n_head'],
                                                      dim_ffX=hidden_dims['dim_ffX'],
                                                      dim_ffE=hidden_dims['dim_ffE'], dropout=dropout))

        self.mlp_out_X = decoder('X', hidden_dims['dx'])
        self.mlp_out_E = decoder('E', hidden_dims['de'])
        # Registered (and therefore part of the state dict) although forward does not call it
        self.mlp_out_y = decoder('y', hidden_dims['dy'])

    def forward(self, X, E, y, node_mask):
        """
        X : (bs, n, input X width), E : (bs, n, n, input E width), y : (bs, input y width)
        node_mask : handed to PlaceHolder.mask and to every transformer layer

        Returns (X, E, y, node_mask). The outputs are residual: the first out_dim_* input features
        are added back. E is symmetric over its two node axes and zero on the diagonal.
        """
        bs, n = X.shape[0], X.shape[1]

        # True everywhere except on the diagonal, broadcastable against (bs, n, n, features)
        off_diagonal = ~torch.eye(n, dtype=torch.bool, device=E.device)
        off_diagonal = off_diagonal[None, :, :, None].expand(bs, -1, -1, -1)

        # Inputs for the residual connections, taken before the features are embedded
        X_residual = X[..., :self.out_dim_X]
        E_residual = E[..., :self.out_dim_E]
        y_residual = y[..., :self.out_dim_y]

        # Embed the edges and symmetrise them over the two node axes
        E_hidden = self.mlp_in_E(E)
        E_hidden = (E_hidden + E_hidden.transpose(1, 2)) / 2

        # all mask padding nodes (with node_mask)
        embedded = graph.PlaceHolder(X=self.mlp_in_X(X), E=E_hidden, y=self.mlp_in_y(y)).mask(node_mask)
        X, E, y = embedded.X, embedded.E, embedded.y

        for tf_layer in self.tf_layers:
            X, E, y = tf_layer(X, E, y, node_mask)

        # y is truncated to the output width rather than passed through mlp_out_y
        X = self.mlp_out_X(X) + X_residual
        E = (self.mlp_out_E(E) + E_residual) * off_diagonal
        y = y[..., :self.out_dim_y] + y_residual

        # Symmetrise the edge output once more
        E = 0.5 * (E + E.transpose(1, 2))

        return X, E, y, node_mask

class GraphTransformerWithYAtomMapPosEmb(nn.Module):
    """
    Neural net for all the aligned models where the reactant and product graphs are joined together into a larger graph with adjacency matrix of size (N_x+N_y, N_x+N_y)

    n_layers : int -- number of layers
    input_dims, output_dims : dict -- feature widths under the keys 'X', 'E', 'y'
    hidden_mlp_dims : dict -- hidden widths of the input/output MLPs under the keys 'X', 'E', 'y'
    hidden_dims : dict -- transformer widths under 'dx', 'de', 'dy', 'n_head', 'dim_ffX', 'dim_ffE'
    act_fn_in, act_fn_out : activation modules used in the input / output MLPs
    pos_emb_permutations : int -- stored on the module, not used inside this class
    dropout : float -- forwarded to every transformer layer
    p_to_r_skip_connection : bool -- add the aligned product features to the output, scaled by learnable parameters
    p_to_r_init : float -- initial value of the skip-connection scales
    input_alignment : bool -- concatenate the aligned product features to the input features
                              (the input MLPs are widened by output_dims['X'] / output_dims['E'])
    """
    def __init__(self, n_layers: int, input_dims: dict, hidden_mlp_dims: dict, hidden_dims: dict,
                 output_dims: dict, act_fn_in: nn.Module, act_fn_out: nn.Module, pos_emb_permutations: int = 0,
                 dropout=0.1, p_to_r_skip_connection=False, p_to_r_init=10., input_alignment=False):
        super().__init__()
        self.n_layers = n_layers
        self.out_dim_X = output_dims['X']
        self.out_dim_E = output_dims['E']
        self.out_dim_y = output_dims['y']
        self.pos_emb_permutations = pos_emb_permutations
        self.p_to_r_skip_connection = p_to_r_skip_connection
        self.input_alignment = input_alignment
        if input_alignment:
            # Work on a copy so the caller's dict is left untouched
            input_dims = copy.deepcopy(input_dims)
            original_data_feature_dim_X = output_dims['X']
            original_data_feature_dim_E = output_dims['E']
            input_dims['X'] += original_data_feature_dim_X # make the input feature dimensionality larger to include the aligned & concatenated product conditioning
            input_dims['E'] += original_data_feature_dim_E

        self.mlp_in_X = nn.Sequential(nn.Linear(input_dims['X'], hidden_mlp_dims['X']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['X'], hidden_dims['dx']), act_fn_in)

        self.mlp_in_E = nn.Sequential(nn.Linear(input_dims['E'], hidden_mlp_dims['E']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['E'], hidden_dims['de']), act_fn_in)

        self.mlp_in_y = nn.Sequential(nn.Linear(input_dims['y'], hidden_mlp_dims['y']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['y'], hidden_dims['dy']), act_fn_in)

        self.tf_layers = nn.ModuleList([XEyTransformerLayer(dx=hidden_dims['dx'],
                                                            de=hidden_dims['de'],
                                                            dy=hidden_dims['dy'],
                                                            n_head=hidden_dims['n_head'],
                                                            dim_ffX=hidden_dims['dim_ffX'],
                                                            dim_ffE=hidden_dims['dim_ffE'], dropout=dropout)
                                        for _ in range(n_layers)])

        self.mlp_out_X = nn.Sequential(nn.Linear(hidden_dims['dx'], hidden_mlp_dims['X']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['X'], output_dims['X']))

        self.mlp_out_E = nn.Sequential(nn.Linear(hidden_dims['de'], hidden_mlp_dims['E']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['E'], output_dims['E']))

        # Registered (and therefore part of the state dict) although forward does not call it
        self.mlp_out_y = nn.Sequential(nn.Linear(hidden_dims['dy'], hidden_mlp_dims['y']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['y'], output_dims['y']))

        if self.p_to_r_skip_connection:
            self.skip_scaling = nn.Parameter(torch.tensor([p_to_r_init], dtype=torch.float))
            self.skip_scaling_2 = nn.Parameter(torch.tensor([p_to_r_init], dtype=torch.float))
            # Not read in forward; kept so that the set of parameters stays the same
            self.skip_scaling_3 = nn.Parameter(torch.tensor([1.], dtype=torch.float))

    def forward(self, X, E, y, node_mask, atom_map_numbers, pos_encodings, mol_assignments):
        """
        X : (bs, n, input X width), E : (bs, n, n, input E width), y : (bs, input y width)
        node_mask : handed to PlaceHolder.mask and to every transformer layer
        atom_map_numbers : (bs, n), must not be None; 0 marks nodes without atom map
        pos_encodings : added to the embedded node features, so it has to broadcast against (bs, n, dx)
        mol_assignments : (bs, n) molecule index per node

        Returns (X, E, y, node_mask). The outputs are residual: the first out_dim_* input features
        are added back, and E is symmetrised and zeroed on the diagonal before the optional
        product-to-reactant skip connection is added.
        """
        assert atom_map_numbers is not None
        bs, n = X.shape[0], X.shape[1]
        device = X.device

        orig_E = E.clone()
        orig_X = X.clone()

        # Potential edge input alignment from the product side to reactant side
        if self.input_alignment:
            X_prods_am_permuted, E_prods_am_permuted, atom_map_numbers_rct = get_X_prods_and_E_prods_aligned(mol_assignments, atom_map_numbers, orig_X, orig_E, self.out_dim_E, self.out_dim_X, device)
            X_to_concatenate = torch.zeros(X.shape[0], X.shape[1], self.out_dim_X, device=device)
            E_to_concatenate = torch.zeros(E.shape[0], E.shape[1], E.shape[2], self.out_dim_E, device=device)
            # One-hot of class 0, used for every node pair without an aligned product edge. Its width
            # has to be that of E_to_concatenate (out_dim_E); the raw input E may carry more features.
            no_edge = F.one_hot(torch.tensor([0], dtype=torch.long, device=device), self.out_dim_E).float()
            for i in range(bs):
                # The following is used for choosing which parts to change in the output
                am_rct_selection = (atom_map_numbers_rct[i] > 0)
                pair_selection = am_rct_selection[:,None] * am_rct_selection[None,:]
                # (1, N, N, out_dim_E) -> (N*N, out_dim_E), matching the row-major order of the mask
                E_to_concatenate[i, pair_selection] += E_prods_am_permuted[i].flatten(0, 2).float()
                E_to_concatenate[i, ~pair_selection] += no_edge
                X_to_concatenate[i, am_rct_selection] += X_prods_am_permuted[i].squeeze(0).float()
            X = torch.cat([X, X_to_concatenate], dim=-1)
            E = torch.cat([E, E_to_concatenate], dim=-1)

        # True everywhere except on the diagonal, broadcastable against (bs, n, n, features)
        diag_mask = ~torch.eye(n, dtype=torch.bool, device=E.device)
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)

        X_to_out = X[..., :self.out_dim_X]
        E_to_out = E[..., :self.out_dim_E]
        y_to_out = y[..., :self.out_dim_y]

        new_E = self.mlp_in_E(E)
        new_E = (new_E + new_E.transpose(1, 2)) / 2

        # all mask padding nodes (with node_mask)
        after_in = graph.PlaceHolder(X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y)).mask(node_mask)
        X, E, y = after_in.X, after_in.E, after_in.y

        # Add the positional encoding to X. X shape is now (bs, n, dx)
        X = X + pos_encodings

        for layer in self.tf_layers:
            X, E, y = layer(X, E, y, node_mask)

        X = self.mlp_out_X(X)
        E = self.mlp_out_E(E)
        y = y[..., :self.out_dim_y] # TODO: Changed

        X = (X + X_to_out)
        E = (E + E_to_out) * diag_mask
        y = y + y_to_out

        E = 1/2 * (E + torch.transpose(E, 1, 2))

        # Potential edge-skip connection from product side to reactant side
        if self.p_to_r_skip_connection:
            if not self.input_alignment: # if this stuff wasn't done already
                X_prods_am_permuted, E_prods_am_permuted, atom_map_numbers_rct = get_X_prods_and_E_prods_aligned(mol_assignments, atom_map_numbers, orig_X, orig_E, self.out_dim_E, self.out_dim_X, device)
                no_edge = F.one_hot(torch.tensor([0], dtype=torch.long, device=device), self.out_dim_E).float()
                for i in range(bs):
                    # The following is used for choosing which parts to change in the output
                    am_rct_selection = (atom_map_numbers_rct[i] > 0)
                    pair_selection = am_rct_selection[:,None] * am_rct_selection[None,:]
                    E[i, pair_selection] += E_prods_am_permuted[i].flatten(0, 2).float() * self.skip_scaling
                    E[i, ~pair_selection] += no_edge * self.skip_scaling_2
                    X[i, am_rct_selection] += X_prods_am_permuted[i].squeeze(0).float() * self.skip_scaling
            else: # reuse the previous calculations
                # NOTE(review): here all of E is scaled by skip_scaling_2, while the branch above scales
                # the aligned edges by skip_scaling -- left as is, confirm whether that is intended.
                X += X_to_concatenate * self.skip_scaling
                E += E_to_concatenate * self.skip_scaling_2

        return X, E, y, node_mask
    