# Copyright 2024 the Llamole team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.nn import MessagePassing
import json

class GraphCLIP(nn.Module):
    """Molecule graph encoder followed by a projection head.

    ``forward`` encodes a batch of graphs with ``GNNEncoder``, projects the
    pooled graph features with ``ProjectionHead`` and rescales each embedding
    to unit L2 norm along the last dimension.

    Args:
        graph_num_layer: Number of GNN layers passed to ``GNNEncoder``.
        graph_hidden_size: Hidden width used for both the encoder and the
            projection head (input and output of the head share this size).
        dropout: Dropout ratio forwarded to the encoder and projection head.
        model_config: Configuration object stored on the module and written
            by ``save_pretrained`` via ``json.dump`` (so it has to be
            JSON-serialisable for saving to succeed).
    """

    # Checkpoint file names shared by ``save_pretrained`` and ``init_model``.
    _ENCODER_FILE = 'model.pt'
    _PROJECTION_FILE = 'model_proj.pt'
    _CONFIG_FILE = 'model_config.json'

    def __init__(
        self,
        graph_num_layer,
        graph_hidden_size,
        dropout,
        model_config,
    ):
        super().__init__()
        self.model_config = model_config
        self.hidden_size = graph_hidden_size
        self.molecule_encoder = GNNEncoder(num_layer=graph_num_layer, hidden_size=graph_hidden_size, drop_ratio=dropout)
        self.molecule_projection = ProjectionHead(embedding_dim=graph_hidden_size, projection_dim=graph_hidden_size, dropout=dropout)

    @staticmethod
    def _weight_paths(directory):
        """Return ``(encoder_path, projection_path)`` inside ``directory``.

        The paths are built from the directory and fixed file names. Deriving
        the projection path by string-replacing 'model' in the encoder path
        would also rewrite any 'model' substring in the directory itself.
        """
        encoder_path = os.path.join(directory, GraphCLIP._ENCODER_FILE)
        projection_path = os.path.join(directory, GraphCLIP._PROJECTION_FILE)
        return encoder_path, projection_path

    def forward(self, x, edge_index, edge_attr, batch):
        """Return L2-normalised projected embeddings, one per graph in the batch.

        The four arguments are passed unchanged to the molecule encoder.
        """
        molecule_features = self.molecule_encoder(x, edge_index, edge_attr, batch)
        molecule_embeddings = self.molecule_projection(molecule_features)
        molecule_embeddings = molecule_embeddings / molecule_embeddings.norm(dim=-1, keepdim=True)
        return molecule_embeddings

    def save_pretrained(self, output_dir):
        """
        Save the molecule encoder, projection models, and model_config to the output directory.

        Writes ``model.pt`` (encoder state dict), ``model_proj.pt`` (projection
        state dict) and ``model_config.json`` into ``output_dir``, creating the
        directory if it does not exist yet.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

        molecule_path, proj_path = self._weight_paths(output_dir)
        config_path = os.path.join(output_dir, self._CONFIG_FILE)

        torch.save(self.molecule_encoder.state_dict(), molecule_path)
        torch.save(self.molecule_projection.state_dict(), proj_path)

        # Save model_config to JSON file
        with open(config_path, 'w') as f:
            json.dump(self.model_config, f, indent=2)

    def disable_grads(self):
        """
        Disable gradients for all parameters in the model.
        """
        for param in self.parameters():
            param.requires_grad = False

    def init_model(self, model_path, verbose=True):
        """Load encoder and projection weights from the directory ``model_path``.

        Args:
            model_path: Directory containing ``model.pt`` and ``model_proj.pt``
                as written by ``save_pretrained``.
            verbose: When true, print the loaded sub-modules.

        Raises:
            FileNotFoundError: If either checkpoint file is missing. The
                encoder is checked (and loaded) before the projection head.
        """
        molecule_path, proj_path = self._weight_paths(model_path)

        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        if not os.path.exists(molecule_path):
            raise FileNotFoundError(f"Molecule encoder file not found: {molecule_path}")
        self.molecule_encoder.load_state_dict(torch.load(molecule_path, map_location='cpu'))

        if not os.path.exists(proj_path):
            raise FileNotFoundError(f"Molecule projection file not found: {proj_path}")
        self.molecule_projection.load_state_dict(torch.load(proj_path, map_location='cpu'))

        if verbose:
            print('GraphCLIP Models initialized.')
            print('Molecule model:\n', self.molecule_encoder)
            print('Molecule projection:\n', self.molecule_projection)

class GNNEncoder(nn.Module):
    """Stack of ``GINConv`` layers with a per-graph virtual node.

    Node indices are embedded through a 118-row table, refined by
    ``num_layer`` residual GIN layers (each followed by ``LayerNorm``), and
    finally sum-pooled into one vector per graph. Between layers a virtual
    node embedding per graph is updated from a max-pool of the node states.
    """

    def __init__(self, num_layer, hidden_size, drop_ratio):
        """Build the embedding tables, GIN layers, norms and virtual-node MLPs.

        Args:
            num_layer: Number of GIN layers; must be at least 2.
            hidden_size: Width of node, edge and virtual-node embeddings.
            drop_ratio: Dropout probability used inside and between layers.

        Raises:
            ValueError: If ``num_layer`` is smaller than 2.
        """
        super().__init__()

        if num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio

        self.atom_encoder = nn.Embedding(118, hidden_size)

        # Single-row table for the virtual node; it starts out as all zeros.
        self.virtualnode_embedding = nn.Embedding(1, hidden_size)
        nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.mlp_virtualnode_list = nn.ModuleList()

        expanded_size = 4 * hidden_size
        final_layer = num_layer - 1
        for depth in range(num_layer):
            self.convs.append(GINConv(hidden_size, drop_ratio))
            self.norms.append(nn.LayerNorm(hidden_size, elementwise_affine=True))
            if depth == final_layer:
                # The last layer never updates the virtual node, so no MLP.
                continue
            virtualnode_mlp = nn.Sequential(
                nn.Linear(hidden_size, expanded_size),
                nn.LayerNorm(expanded_size),
                nn.GELU(),
                nn.Dropout(drop_ratio),
                nn.Linear(expanded_size, hidden_size),
            )
            self.mlp_virtualnode_list.append(virtualnode_mlp)

    def initialize_weights(self):
        """Re-initialise every ``nn.Linear`` in this module.

        Weights get Xavier-uniform values and biases (when present) are set
        to zero. ``__init__`` does not call this; invoke it explicitly.
        """
        def _reset_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_reset_linear)

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode a batch of graphs into one pooled vector per graph.

        Args:
            x: Node indices looked up in the atom embedding table.
            edge_index: Edge connectivity passed to each ``GINConv``; its
                dtype and device are reused for the virtual-node lookup.
            edge_attr: Edge indices passed to each ``GINConv``.
            batch: Graph id of every node. The graph count is taken from the
                last entry (assumes ids are sorted ascending -- TODO confirm).

        Returns:
            The sum over nodes of the final node states, per graph.
        """
        # One all-zero index per graph selects the single virtual-node row.
        graph_count = batch[-1].item() + 1
        virtualnode_index = torch.zeros(
            graph_count, dtype=edge_index.dtype, device=edge_index.device
        )
        virtualnode_state = self.virtualnode_embedding(virtualnode_index)

        node_state = self.atom_encoder(x)
        final_layer = self.num_layer - 1

        for depth in range(self.num_layer):
            is_final = depth == final_layer

            # Broadcast each graph's virtual node onto its own nodes.
            layer_input = node_state + virtualnode_state[batch]

            # Message passing among graph nodes, then normalisation.
            update = self.convs[depth](layer_input, edge_index, edge_attr)
            update = self.norms[depth](update)
            if not is_final:
                update = F.gelu(update)
                update = F.dropout(update, self.drop_ratio, training=self.training)

            # Residual connection around the whole layer.
            node_state = update + layer_input

            if not is_final:
                # The virtual node is refreshed from the layer *input* states.
                pooled = global_max_pool(layer_input, batch)
                virtualnode_delta = self.mlp_virtualnode_list[depth](pooled)
                virtualnode_state = virtualnode_state + F.dropout(
                    virtualnode_delta, self.drop_ratio, training=self.training
                )

        return global_add_pool(node_state, batch)
    
class GINConv(MessagePassing):
    """GIN-style convolution with learned edge embeddings and sum aggregation.

    Each node combines ``(1 + eps)`` times its own state with the summed
    messages from its neighbours and feeds the result through an MLP.
    """

    def __init__(self, hidden_size, drop_ratio):
        """Create the update MLP, the learnable ``eps`` and the bond embedding.

        Args:
            hidden_size (int): Width of node and edge embeddings.
            drop_ratio: Dropout probability inside the MLP.
        """
        super().__init__(aggr="add")

        expanded_size = 4 * hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, expanded_size),
            nn.LayerNorm(expanded_size),
            nn.GELU(),
            nn.Dropout(drop_ratio),
            nn.Linear(expanded_size, hidden_size),
        )
        # Learnable self-loop weight, initialised to zero.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        # Edge indices are looked up in a 5-row embedding table.
        self.bond_encoder = nn.Embedding(5, hidden_size)

    def forward(self, x, edge_index, edge_attr):
        """Return updated node states for the given graph connectivity."""
        bond_features = self.bond_encoder(edge_attr)
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_attr=bond_features)
        return self.mlp(self_term + neighbour_term)

    def message(self, x_j, edge_attr):
        """Build the per-edge message: GELU of source state plus edge embedding."""
        combined = x_j + edge_attr
        return F.gelu(combined)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out

class ProjectionHead(nn.Module):
    """Two-layer MLP head: linear, layer norm, activation, dropout, linear."""

    def __init__(
        self,
        embedding_dim,
        projection_dim,
        dropout,
        act_layer=nn.GELU,
        hidden_features=None,
        bias=True
    ):
        """Set up the projection layers.

        Args:
            embedding_dim: Size of the input features.
            projection_dim: Size of the output; a falsy value falls back to
                ``embedding_dim``.
            dropout: Dropout probability applied after the activation.
            act_layer: Zero-argument factory for the activation module.
            hidden_features: Width of the hidden layer; a falsy value falls
                back to ``embedding_dim``.
            bias: Whether both linear layers carry a bias term.
        """
        super().__init__()
        if not projection_dim:
            projection_dim = embedding_dim
        if not hidden_features:
            hidden_features = embedding_dim

        self.fc1 = nn.Linear(embedding_dim, hidden_features, bias=bias)
        self.norm1 = nn.LayerNorm(hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, projection_dim, bias=bias)

    def forward(self, x):
        """Run ``x`` through the head's stages in order and return the result."""
        stages = (self.fc1, self.norm1, self.act, self.drop1, self.fc2)
        for stage in stages:
            x = stage(x)
        return x
