import sys
sys.path.append('..')

import math
import logging
log = logging.getLogger(__name__)

import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from src.utils import math_utils, graph

from src.neuralnet.layers import XEyTransformerLayer

from torch_geometric.utils.sparse import dense_to_sparse
from torch_geometric.utils import (
    get_laplacian,
    to_scipy_sparse_matrix,
)
from torch.cuda.amp import autocast

from scipy.sparse.linalg import eigsh
# Module-level default device: CUDA when it is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PositionalEmbedding(nn.Module):

    def __init__(self, dim, pos_emb_permutations):
        super().__init__()
        self.dim = dim
        self.pos_emb_permutations = pos_emb_permutations

    def matched_positional_encodings_laplacian_scipy(self, E, atom_map_numbers, molecule_assignments, max_eigenvectors):
        assert molecule_assignments != None, 'Need to provide molecule assigments'
        assert len(E.shape) == 3 and E.shape[1] == E.shape[2], 'E should not be in one-hot format'
        bs, n = atom_map_numbers.shape[0], atom_map_numbers.shape[1]
        E = E.clone() 
        molecule_assignments = molecule_assignments.cpu()
        atom_map_numbers = atom_map_numbers.cpu()
        
        product_indices = molecule_assignments.max(-1).values
        pos_embeddings = torch.zeros((bs, n, self.dim), dtype=torch.float32, device=E.device)
        
        for i in range(bs):
            E[i, product_indices[i] != molecule_assignments[i], :] = 0
            E[i, :, product_indices[i] != molecule_assignments[i]] = 0

            k = min(max_eigenvectors, (product_indices[i] == molecule_assignments[i]).sum().item())
            pe = laplacian_eigenvectors_scipy(E[i], k) # Shape (bs, n, k)
            pe = torch.cat([pe, torch.zeros((n, self.dim - k), dtype=torch.float32, device=pe.device)], dim=-1)
            
            # Create a mapping from atom map number to the interesting positional encodings (zero pos enc for non-atom mapped atoms)
            pos_embs_prod = torch.cat([torch.zeros(1, self.dim), pe[molecule_assignments[i] == product_indices[i]]], 0)
            
            indices_of_product = (molecule_assignments[i] == product_indices[i]).nonzero().squeeze(1)
            max_atom_map = atom_map_numbers[i, :].max() # Need this because sometimes some of the atom maps are missing in the data
            atom_map_nums_product = torch.cat([torch.tensor([0]), atom_map_numbers[i, indices_of_product]]) # The zero is for the non-atom mapped atoms
            atom_map_to_idx = torch.zeros(max_atom_map + 1, dtype=torch.long)
            atom_map_to_idx.scatter_(dim=0, index=atom_map_nums_product, src=torch.arange(len(atom_map_nums_product)))

            pos_embeddings[i] = pos_embs_prod[atom_map_to_idx[atom_map_numbers[i]]]
        return pos_embeddings

    def matched_positional_encodings_laplacian(self, E, atom_map_numbers, molecule_assignments, max_eigenvectors):
        assert molecule_assignments != None, 'Need to provide molecule assigments'
        assert len(E.shape) == 3 and E.shape[1] == E.shape[2], 'E should not be in one-hot format'
        bs, n = atom_map_numbers.shape[0], atom_map_numbers.shape[1]
        E = E.clone() 
        molecule_assignments = molecule_assignments.cpu()
        atom_map_numbers = atom_map_numbers.cpu()
        
        perm = torch.arange(atom_map_numbers.max().item()+1)[1:]
        perm = perm[torch.randperm(len(perm))]
        perm = torch.cat([torch.zeros(1, dtype=torch.long), perm])
        atom_map_numbers = perm[atom_map_numbers]

        product_indices = molecule_assignments.max(-1).values
        
        pos_embeddings = torch.zeros((bs, n, self.dim), dtype=torch.float32, device=E.device)
        
        for i in range(bs):
            E[i, product_indices[i] != molecule_assignments[i], :] = 0
            E[i, :, product_indices[i] != molecule_assignments[i]] = 0

            k = min(max_eigenvectors, (product_indices[i] == molecule_assignments[i]).sum().item())
            pe = laplacian_eigenvectors(E[i], k) # Shape (bs, n, k)
            pe = torch.cat([pe, torch.zeros((n, self.dim - k), dtype=torch.float32, device=pe.device)], dim=-1)
            
            pos_embs_prod = torch.cat([torch.zeros(1, self.dim, device=device), pe[molecule_assignments[i] == product_indices[i]]], 0)
            

            indices_of_product = (molecule_assignments[i] == product_indices[i]).nonzero().squeeze(1)
            max_atom_map = atom_map_numbers[i, :].max() # Need this because sometimes some of the atom maps are missing in the data
            # The atom map numbers of the product
            atom_map_nums_product = torch.cat([torch.tensor([0]), atom_map_numbers[i, indices_of_product]]) # The zero is for the non-atom mapped atoms
            atom_map_to_idx = torch.zeros(max_atom_map + 1, dtype=torch.long)
            atom_map_to_idx.scatter_(dim=0, index=atom_map_nums_product, src=torch.arange(len(atom_map_nums_product)))
            
            pos_embeddings[i] = pos_embs_prod[atom_map_to_idx[atom_map_numbers[i]]]

        return pos_embeddings

def laplacian_eigenvectors_scipy(E, k):
    """
    Laplacian eigenvectors of a graph, computed with SciPy's sparse symmetric eigensolver.

    The Laplacian is built with `get_laplacian(..., normalization='sym')` from the
    non-zero pattern of `E`; the values stored in `E` are not used as edge weights.

    Parameters:
    E (torch.Tensor): A dense adjacency matrix of shape (n, n) where n is the number of nodes in the graph.
    k (int): The number of eigenvectors to return.

    Returns:
    torch.Tensor: A CPU tensor of shape (n, k) holding the eigenvectors for the 2nd to (k+1)-th
    smallest eigenvalues (the first one is dropped), each column multiplied by a random sign.
    """
    n_nodes = E.shape[-1]
    edge_index, _ = dense_to_sparse(E)

    lap_index, lap_weight = get_laplacian(edge_index, normalization='sym', num_nodes=n_nodes)
    laplacian = to_scipy_sparse_matrix(lap_index, lap_weight, n_nodes)

    # Number of Lanczos vectors handed to the solver, capped at the matrix size.
    lanczos_vectors = min(E.shape[0], max(20 * k + 1, 40))
    values, vectors = eigsh(
        laplacian,
        k=k + 1,
        which='SA',
        return_eigenvectors=True,
        ncv=lanczos_vectors,
    )

    # Order columns by ascending eigenvalue, then skip the first one.
    ascending = values.argsort()
    vectors = np.real(vectors[:, ascending])
    pe = torch.from_numpy(vectors[:, 1:k + 1])

    # Eigenvectors are only defined up to sign, so flip each column at random.
    flips = 2 * torch.randint(0, 2, (k, )) - 1
    pe *= flips
    return pe

def laplacian_eigenvectors(E, k):
    """
    Laplacian eigenvectors of a graph, computed densely with `torch.linalg.eigh` on `E`'s device.

    The Laplacian is built with `get_laplacian(..., normalization='sym')` from the
    non-zero pattern of `E`; the values stored in `E` are not used as edge weights.

    Parameters:
    E (torch.Tensor): A dense adjacency matrix of shape (n, n) where n is the number of nodes in the graph.
    k (int): The number of eigenvectors to return.

    Returns:
    torch.Tensor: A tensor of shape (n, k) holding the eigenvectors for the 2nd to (k+1)-th
    smallest eigenvalues (the first one is dropped), each column multiplied by a random sign.
    """
    n_nodes = E.shape[-1]

    # Every non-zero entry of E is treated as an (unweighted) edge.
    src_dst = torch.nonzero(E, as_tuple=False).t()
    lap_index, lap_weight = get_laplacian(src_dst, normalization='sym', num_nodes=n_nodes)

    # Scatter the sparse Laplacian into a dense matrix for the dense eigensolver.
    laplacian = torch.zeros((n_nodes, n_nodes), device=E.device)
    rows, cols = lap_index[0], lap_index[1]
    laplacian[rows, cols] = lap_weight

    values, vectors = torch.linalg.eigh(laplacian)

    # Order columns by ascending eigenvalue and skip the first one.
    ascending = torch.argsort(values)
    kept = ascending[1:k + 1]
    pe = vectors[:, kept]

    # Eigenvectors are only defined up to sign, so flip each column at random.
    flips = 2 * torch.randint(0, 2, (k, ), device=E.device) - 1
    pe *= flips

    return pe

class GraphTransformerWithYStacked(nn.Module):
    """
    Graph transformer that operates on the reactant side of a reaction graph, with the
    atom-map-aligned product features and positional encodings stacked onto the inputs.

    n_layers : int -- number of transformer layers
    input_dims : dict -- input feature sizes, keys 'X', 'E', 'y'
    hidden_mlp_dims : dict -- hidden sizes of the input/output MLPs, keys 'X', 'E', 'y'
    hidden_dims : dict -- transformer sizes, keys 'dx', 'de', 'dy', 'n_head', 'dim_ffX', 'dim_ffE'
    output_dims : dict -- output feature sizes, keys 'X', 'E', 'y'
    act_fn_in / act_fn_out : nn.Module -- activation used inside the input / output MLPs
    pos_emb_permutations : int -- stored on the instance
    dropout : float -- dropout passed to every transformer layer
    p_to_r_skip_connection : bool -- add a learnable product-to-reactant skip connection to the outputs
    p_to_r_init : float -- initial value of the skip-connection scales
    """
    def __init__(self, n_layers: int, input_dims: dict, hidden_mlp_dims: dict, hidden_dims: dict,
                 output_dims: dict, act_fn_in: nn.Module, act_fn_out: nn.Module, pos_emb_permutations: int = 0,
                 dropout=0.1, p_to_r_skip_connection=False, p_to_r_init=10.):
        super().__init__()
        self.n_layers = n_layers
        self.out_dim_X = output_dims['X']
        self.out_dim_E = output_dims['E']
        self.out_dim_y = output_dims['y']
        self.input_dim_X = input_dims['X']
        self.pos_emb_permutations = pos_emb_permutations
        self.p_to_r_skip_connection = p_to_r_skip_connection

        # NOTE(review): the positional-embedding module is built with pos_emb_permutations=-1
        # rather than the constructor argument -- confirm this is intended.
        self.pos_emb_module = PositionalEmbedding(dim=input_dims['X'], pos_emb_permutations=-1)

        # Node input is [reactant features, aligned product features, positional encodings], hence *3.
        self.mlp_in_X = nn.Sequential(nn.Linear(input_dims['X']*3, hidden_mlp_dims['X']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['X'], hidden_dims['dx']), act_fn_in)

        # Edge input is [reactant edges, aligned product edges], hence *2.
        self.mlp_in_E = nn.Sequential(nn.Linear(input_dims['E']*2, hidden_mlp_dims['E']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['E'], hidden_dims['de']), act_fn_in)

        self.mlp_in_y = nn.Sequential(nn.Linear(input_dims['y'], hidden_mlp_dims['y']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['y'], hidden_dims['dy']), act_fn_in)

        self.tf_layers = nn.ModuleList([XEyTransformerLayer(dx=hidden_dims['dx'],
                                                            de=hidden_dims['de'],
                                                            dy=hidden_dims['dy'],
                                                            n_head=hidden_dims['n_head'],
                                                            dim_ffX=hidden_dims['dim_ffX'],
                                                            dim_ffE=hidden_dims['dim_ffE'], dropout=dropout)
                                        for i in range(n_layers)])

        self.mlp_out_X = nn.Sequential(nn.Linear(hidden_dims['dx'], hidden_mlp_dims['X']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['X'], output_dims['X']))

        self.mlp_out_E = nn.Sequential(nn.Linear(hidden_dims['de'], hidden_mlp_dims['E']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['E'], output_dims['E']))

        # NOTE(review): mlp_out_y is created but forward() does not apply it -- confirm.
        self.mlp_out_y = nn.Sequential(nn.Linear(hidden_dims['dy'], hidden_mlp_dims['y']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['y'], output_dims['y']))

        if self.p_to_r_skip_connection:
            self.skip_scaling = nn.Parameter(torch.tensor([p_to_r_init], dtype=torch.float))
            self.skip_scaling_2 = nn.Parameter(torch.tensor([p_to_r_init], dtype=torch.float))
            # skip_scaling_3 is not read anywhere in this class; kept so the parameter set is unchanged.
            self.skip_scaling_3 = nn.Parameter(torch.tensor([1.], dtype=torch.float))

    def cut_reaction_reactant_part_X_only(self, X, reaction_side_separation_index):
        """
        Keep only the reactant-side nodes of a node-feature tensor.

        Node j of sample i is on the reactant side when j < reaction_side_separation_index[i].

        Parameters:
        X (torch.Tensor): Node features of shape (bs, n, dx).
        reaction_side_separation_index (torch.Tensor): Integer tensor of shape (bs,).

        Returns:
        (X_cut, node_mask_cut): X_cut has shape (bs, m, dx) with m the largest separation index in
        the batch, zero-filled past each sample's reactant nodes; node_mask_cut is a (bs, m) bool
        tensor that is True on reactant nodes.
        """
        device = X.device
        bs, dx = X.shape[0], X.shape[2]
        rct_side = torch.arange(X.shape[1], device=device)[None, :] < reaction_side_separation_index[:, None]

        biggest_reactant_set_size = int(reaction_side_separation_index.max())
        node_mask_cut = torch.zeros(bs, biggest_reactant_set_size, device=device, dtype=torch.bool)
        X_cut = torch.zeros(bs, biggest_reactant_set_size, dx, device=device)
        for i in range(bs):
            cut_mask = rct_side[i][:biggest_reactant_set_size]
            X_cut[i, cut_mask] = X[i, rct_side[i]]
            node_mask_cut[i, cut_mask] = True
        return X_cut, node_mask_cut

    def cut_reaction_reactant_part(self, X, E, reaction_side_separation_index):
        """
        Takes the graph specified by X, E and returns the subset that only has the reactants.

        Parameters:
        X (torch.Tensor): Node features of shape (bs, n, dx).
        E (torch.Tensor): Edge features of shape (bs, n, n, de).
        reaction_side_separation_index (torch.Tensor): Integer tensor of shape (bs,); nodes with a
            smaller index are reactant nodes.

        Returns:
        (X_cut, E_cut, node_mask_cut): shapes (bs, m, dx), (bs, m, m, de) and (bs, m), with m the
        largest separation index in the batch; entries outside a sample's reactants are zero / False.
        """
        device = X.device
        bs, de = X.shape[0], E.shape[3]
        rct_side = torch.arange(X.shape[1], device=device)[None, :] < reaction_side_separation_index[:, None]
        biggest_reactant_set_size = int(reaction_side_separation_index.max())

        # The node part is exactly what the X-only variant computes.
        X_cut, node_mask_cut = self.cut_reaction_reactant_part_X_only(X, reaction_side_separation_index)

        E_cut = torch.zeros(bs, biggest_reactant_set_size, biggest_reactant_set_size, de, device=device)
        for i in range(bs):
            full_mask = rct_side[i]
            cut_mask = full_mask[:biggest_reactant_set_size]
            # Both 2D masks select the same reactant-reactant pairs, in the same row-major order.
            E_cut[i][cut_mask[:, None] & cut_mask[None, :]] = E[i][full_mask[:, None] & full_mask[None, :]]

        return X_cut, E_cut, node_mask_cut

    def get_X_E_product_aligned_with_reactants(self, X, E, atom_map_numbers, reaction_side_separation_index):
        """
        Split a reaction graph into its reactant part and a copy of the product that has been
        moved onto the reactant node positions via the atom-map numbers.

        Parameters:
        X (torch.Tensor): Node features of shape (bs, n, dx).
        E (torch.Tensor): Edge features of shape (bs, n, n, de).
        atom_map_numbers (torch.Tensor): Integer tensor of shape (bs, n); values > 0 are atom-mapped.
        reaction_side_separation_index (torch.Tensor): Integer tensor of shape (bs,); smaller node
            indices are reactant-side, larger ones product-side.

        Returns:
        (X_cut, X_prod_cut, E_cut, E_prod_cut, node_mask_cut): reactant node/edge features, the
        aligned product node/edge features cut to the same size, and the reactant node mask.

        Raises:
        AssertionError: if a sample has a different number of atom-mapped nodes on the two sides.
        """
        bs = X.shape[0]
        device = X.device
        # First we need to split X and E up, find the product indices, then align the product with the reactants.
        positions = torch.arange(X.shape[1], device=device)[None, :]
        rct_side = positions < reaction_side_separation_index[:, None]
        prod_side = positions > reaction_side_separation_index[:, None]

        # Atom-map numbers restricted to one side each (the other side is zeroed out).
        atom_map_numbers_prod, atom_map_numbers_rct = atom_map_numbers.clone(), atom_map_numbers.clone()
        atom_map_numbers_prod[rct_side] = 0
        atom_map_numbers_rct[prod_side] = 0

        X_cut, E_cut, node_mask_cut = self.cut_reaction_reactant_part(X, E, reaction_side_separation_index)

        X_prod, E_prod = torch.zeros_like(X), torch.zeros_like(E)

        node_positions = torch.arange(atom_map_numbers.shape[-1], device=device)
        atom_map_numbers_prod_idxs = [node_positions[atom_map_numbers_prod[i] > 0] for i in range(bs)]
        atom_map_numbers_rct_idxs = [node_positions[atom_map_numbers_rct[i] > 0] for i in range(bs)]
        # Edge features between the atom-mapped product nodes, each with a leading batch dim of 1.
        E_prods_atom_mapped = [
            E[i, atom_map_numbers_prod_idxs[i]][:, atom_map_numbers_prod_idxs[i]].unsqueeze(0)
            for i in range(bs)]
        assert all(len(atom_map_numbers_prod_idxs[i]) == len(atom_map_numbers_rct_idxs[i]) for i in range(bs))
        Ps = [math_utils.create_permutation_matrix_torch(atom_map_numbers_prod[i][atom_map_numbers_prod_idxs[i]],
                                                         atom_map_numbers_rct[i][atom_map_numbers_rct_idxs[i]]).float().to(device)
              for i in range(bs)]
        P_expanded = [P.unsqueeze(0) for P in Ps] # The unsqueeze will be unnecessary with proper batching here
        # Permute the edges obtained from the product, per edge channel: P^T @ E @ P
        E_prods_am_permuted = [torch.movedim(P_expanded[i].transpose(dim0=1, dim1=2) @ torch.movedim(E_prods_atom_mapped[i].float(), -1, 1) @ P_expanded[i], 1, -1) for i in range(bs)]

        # One-hot vector for class 0, written to every pair that is not mapped-to-mapped.
        no_edge = F.one_hot(torch.tensor([0], dtype=torch.long, device=device), E.shape[-1]).float()
        for i in range(bs):
            am_rct_selection = (atom_map_numbers_rct[i] > 0)
            pair_selection = am_rct_selection[:, None] & am_rct_selection[None, :]
            permuted = E_prods_am_permuted[i]
            E_prod[i, pair_selection] += permuted.reshape(permuted.shape[1] * permuted.shape[2],
                                                          permuted.shape[3]).float()
            E_prod[i, ~pair_selection] += no_edge

        X_prods_atom_mapped = [X[i][atom_map_numbers_prod_idxs[i]].unsqueeze(0) for i in range(bs)]
        X_prods_am_permuted = [P_expanded[i].transpose(dim0=1, dim1=2) @ X_prods_atom_mapped[i] for i in range(bs)]
        for i in range(bs):
            am_rct_selection = (atom_map_numbers_rct[i] > 0)
            X_prod[i, am_rct_selection] += X_prods_am_permuted[i].squeeze(0).float()

        X_prod_cut, E_prod_cut, node_mask_cut_2 = self.cut_reaction_reactant_part(X_prod, E_prod, reaction_side_separation_index)
        assert torch.equal(node_mask_cut, node_mask_cut_2)
        return X_cut, X_prod_cut, E_cut, E_prod_cut, node_mask_cut

    def expand_to_full_size(self, X, E, n_nodes):
        """
        Zero-pad X (bs, n, dx) and E (bs, n, n, de) along the node dimensions up to n_nodes.

        Returns:
        (X_, E_): tensors of shape (bs, n_nodes, dx) and (bs, n_nodes, n_nodes, de) on X's device.
        """
        bs, n, dx = X.shape[0], X.shape[1], X.shape[2]
        de = E.shape[3]
        X_ = torch.zeros(bs, n_nodes, dx, device=X.device)
        E_ = torch.zeros(bs, n_nodes, n_nodes, de, device=X.device)
        X_[:, :n] = X
        E_[:, :n, :n] = E
        return X_, E_

    def choose_pos_enc(self, X, E, reaction_side_separation_index, mol_assignments, atom_map_numbers, pos_encoding_type, num_lap_eig_vectors):
        """
        Compute positional encodings for the reactant-side nodes.

        Parameters:
        X, E (torch.Tensor): Node features (bs, n, dx) and edge features (bs, n, n, de); E is
            reduced with argmax over its last axis before the Laplacian is built.
        reaction_side_separation_index (torch.Tensor): Integer tensor of shape (bs,).
        mol_assignments, atom_map_numbers (torch.Tensor): Integer tensors of shape (bs, n).
        pos_encoding_type (str): 'laplacian_pos_enc_gpu' (dense PyTorch solver),
            'laplacian_pos_enc' (SciPy solver); anything else yields all-zero encodings.
        num_lap_eig_vectors (int): Maximum number of Laplacian eigenvectors to use.

        Returns:
        torch.Tensor: Encodings cut to the reactant side, last dimension `self.input_dim_X`.
        """
        if pos_encoding_type == 'laplacian_pos_enc_gpu':
            pos_encodings = self.pos_emb_module.matched_positional_encodings_laplacian(E.argmax(-1), atom_map_numbers, mol_assignments, num_lap_eig_vectors)
        elif pos_encoding_type == 'laplacian_pos_enc':
            pos_encodings = self.pos_emb_module.matched_positional_encodings_laplacian_scipy(E.argmax(-1), atom_map_numbers, mol_assignments, num_lap_eig_vectors)
        else:
            pos_encodings = torch.zeros(X.shape[0], X.shape[1], self.input_dim_X, device=X.device)
        # Every variant is cut down to the reactant side in the same way.
        pos_encodings, _ = self.cut_reaction_reactant_part_X_only(pos_encodings, reaction_side_separation_index)
        return pos_encodings

    def forward(self, X, E, y, node_mask, atom_map_numbers, pos_encodings, mol_assignments, use_pos_encoding_if_applicable, pos_encoding_type, num_lap_eig_vectors, atom_types):
        """
        Run the network on the reactant side of a batch of reaction graphs.

        Parameters:
        X (torch.Tensor): Node features of shape (bs, n, dx).
        E (torch.Tensor): Edge features of shape (bs, n, n, de).
        y (torch.Tensor): Global features; the last dimension is fed to `mlp_in_y`.
        node_mask (torch.Tensor): Returned unchanged; the reactant-side mask is used internally.
        atom_map_numbers (torch.Tensor): Integer tensor of shape (bs, n); must not be None.
        pos_encodings (torch.Tensor or None): Reactant-side positional encodings; computed with
            `choose_pos_enc` when None. A supplied tensor is not modified.
        mol_assignments (torch.Tensor): Integer tensor of shape (bs, n); the per-row maximum marks the product.
        use_pos_encoding_if_applicable (torch.Tensor): Shape (bs,); multiplied onto each sample's encodings.
        pos_encoding_type (str): See `choose_pos_enc`.
        num_lap_eig_vectors (int): Maximum number of Laplacian eigenvectors to use.
        atom_types: Not used by this method.

        Returns:
        (X, E, y, node_mask): X and E are zero-padded back to the original number of nodes n;
        node_mask is the mask that was passed in.
        """
        assert atom_map_numbers is not None

        prod_assignment = mol_assignments.max(-1).values
        # First node index assigned to the product, minus one.
        reaction_side_separation_index = (mol_assignments == prod_assignment[:,None]).to(torch.int).argmax(-1) - 1 # -1 because the supernode doesn't belong to product according to mol_assignments

        if pos_encodings is None:
            with autocast(enabled=False):
                pos_encodings = self.choose_pos_enc(X, E, reaction_side_separation_index, mol_assignments, atom_map_numbers, pos_encoding_type, num_lap_eig_vectors)
        # Out-of-place so that a caller-supplied tensor is left untouched.
        pos_encodings = pos_encodings * use_pos_encoding_if_applicable[:,None,None].to(pos_encodings.device).float()

        n_nodes_original = X.shape[1]
        orig_node_mask = node_mask

        X_cut, X_prod_aligned, E_cut, E_prod_aligned, node_mask_cut = self.get_X_E_product_aligned_with_reactants(X, E, atom_map_numbers, reaction_side_separation_index)

        # Stack reactant features, aligned product features and positional encodings per node / edge.
        X = torch.cat([X_cut, X_prod_aligned, pos_encodings], dim=-1)
        E = torch.cat([E_cut, E_prod_aligned], dim=-1)
        node_mask = node_mask_cut

        bs, n = X.shape[0], X.shape[1]

        # True everywhere except on the diagonal (self-edges), shaped (bs, n, n, 1).
        diag_mask = ~torch.eye(n, dtype=torch.bool, device=E.device)
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)

        # Residual targets: the leading output-sized slice of each input.
        X_to_out = X[..., :self.out_dim_X]
        E_to_out = E[..., :self.out_dim_E]
        y_to_out = y[..., :self.out_dim_y]

        new_E = self.mlp_in_E(E)
        new_E = (new_E + new_E.transpose(1, 2)) / 2

        # all mask padding nodes (with node_mask)
        after_in = graph.PlaceHolder(X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y)).mask(node_mask)
        X, E, y = after_in.X, after_in.E, after_in.y

        for layer in self.tf_layers:
            X, E, y = layer(X, E, y, node_mask)

        X = self.mlp_out_X(X)
        E = self.mlp_out_E(E)
        # NOTE(review): y is sliced rather than passed through mlp_out_y -- confirm this is intended.
        y = y[..., :self.out_dim_y]

        X = (X + X_to_out)
        E = (E + E_to_out) * diag_mask
        y = y + y_to_out

        # Symmetrise the edge output over the two node axes.
        E = 1/2 * (E + torch.transpose(E, 1, 2))

        # Potential edge-skip connection from product side to reactant side
        if self.p_to_r_skip_connection:
            E += E_prod_aligned[...,:self.out_dim_E] * self.skip_scaling_2
            X += X_prod_aligned[...,:self.out_dim_X] * self.skip_scaling

        X, E = self.expand_to_full_size(X, E, n_nodes_original) # for calculating the loss etc.

        return X, E, y, orig_node_mask
