from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from rdkit import Chem

import dgllife


class EnergyModel(nn.Module):
    """AttentiveFP graph encoder followed by a two-layer linear head.

    Pipeline: ``AttentiveFPGNN`` -> ``AttentiveFPReadout`` -> ``Linear`` ->
    ReLU -> ``Linear``.

    Args:
        model_config: Mapping providing the keys ``"num_nodes"`` and
            ``"num_edges"`` (first and second size arguments of the GNN),
            ``"num_feat"`` (graph feature size used by the GNN, the readout
            and the head), ``"n_steps"`` (number of GNN layers) and
            ``"num_outputs"`` (width of the final linear layer).
    """

    def __init__(self, model_config):
        super().__init__()
        node_dim = model_config["num_nodes"]
        edge_dim = model_config["num_edges"]
        graph_dim = model_config["num_feat"]
        self.time_steps = model_config["n_steps"]
        self.num_outputs = model_config["num_outputs"]

        # Graph encoder: ``n_steps`` is used as the number of GNN layers.
        self.gconv = dgllife.model.gnn.attentivefp.AttentiveFPGNN(
            node_dim,
            edge_dim,
            graph_feat_size=graph_dim,
            num_layers=self.time_steps,
        )
        # Only the feature size is passed to the readout; every other
        # readout setting is left at the library default.
        self.readout = dgllife.model.readout.attentivefp_readout.AttentiveFPReadout(
            graph_dim
        )

        # Prediction head: the hidden layer is half as wide as the graph feature.
        hidden_dim = int(graph_dim / 2)
        self.op_1 = nn.Linear(graph_dim, hidden_dim)
        self.op = nn.Linear(hidden_dim, self.num_outputs)

    def forward(self, mol_graph_batch):
        """Run the encoder, readout and head on ``mol_graph_batch``.

        Args:
            mol_graph_batch: Graph object exposing ``ndata["h"]`` and
                ``edata["e"]``; presumably a batched DGL graph of molecules
                (TODO confirm against the data loader).

        Returns:
            The output of the final linear layer (last dimension
            ``num_outputs``), with no activation applied to it.
        """
        atom_inputs = mol_graph_batch.ndata["h"]
        bond_inputs = mol_graph_batch.edata["e"]

        encoded_nodes = self.gconv(mol_graph_batch, atom_inputs, bond_inputs)
        graph_embedding = self.readout(mol_graph_batch, encoded_nodes)

        hidden = F.relu(self.op_1(graph_embedding))
        return self.op(hidden)

class EnergyFeaturesModel(nn.Module):
    """AttentiveFP encoder with a three-layer head and an 8-wide bottleneck.

    Pipeline: ``AttentiveFPGNN`` -> ``AttentiveFPReadout`` -> ``Linear`` ->
    ReLU -> ``Linear`` (8 units) -> ReLU -> ``Linear``.  The 8-unit
    activations are also exposed on their own through :meth:`get_features`.

    Args:
        model_config: Mapping providing the keys ``"num_nodes"`` and
            ``"num_edges"`` (first and second size arguments of the GNN),
            ``"num_feat"`` (graph feature size), ``"n_steps"`` (number of
            GNN layers) and ``"num_outputs"`` (width of the final layer).
    """

    def __init__(self, model_config):
        super().__init__()
        node_dim = model_config["num_nodes"]
        edge_dim = model_config["num_edges"]
        graph_dim = model_config["num_feat"]
        self.time_steps = model_config["n_steps"]
        self.num_outputs = model_config["num_outputs"]

        # Graph encoder: ``n_steps`` is used as the number of GNN layers.
        self.gconv = dgllife.model.gnn.attentivefp.AttentiveFPGNN(
            node_dim,
            edge_dim,
            graph_feat_size=graph_dim,
            num_layers=self.time_steps,
        )
        # Only the feature size is passed to the readout; every other
        # readout setting is left at the library default.
        self.readout = dgllife.model.readout.attentivefp_readout.AttentiveFPReadout(
            graph_dim
        )

        # Head: graph_dim -> graph_dim/2 -> 8 -> num_outputs.
        hidden_dim = int(graph_dim / 2)
        self.op_1 = nn.Linear(graph_dim, hidden_dim)
        self.op_2 = nn.Linear(hidden_dim, 8)
        self.op = nn.Linear(8, self.num_outputs)

    def forward(self, mol_graph_batch):
        """Return the final linear layer's output for ``mol_graph_batch``.

        The input to that layer is exactly what :meth:`get_features`
        returns, so the shared computation lives there.

        Args:
            mol_graph_batch: Graph object exposing ``ndata["h"]`` and
                ``edata["e"]``; presumably a batched DGL graph of molecules
                (TODO confirm against the data loader).
        """
        return self.op(self.get_features(mol_graph_batch))

    def get_features(self, mol_graph_batch):
        """Return the ReLU-activated output of the 8-unit layer ``op_2``.

        Args:
            mol_graph_batch: Same graph object that :meth:`forward` accepts.

        Returns:
            Activations with last dimension 8, i.e. the values the final
            layer ``op`` consumes in :meth:`forward`.
        """
        atom_inputs = mol_graph_batch.ndata["h"]
        bond_inputs = mol_graph_batch.edata["e"]

        encoded_nodes = self.gconv(mol_graph_batch, atom_inputs, bond_inputs)
        graph_embedding = self.readout(mol_graph_batch, encoded_nodes)

        hidden = F.relu(self.op_1(graph_embedding))
        return F.relu(self.op_2(hidden))

class ClassificationFeatureModel(nn.Module):
    """AttentiveFP encoder with a three-layer head and an 8-wide bottleneck.

    Pipeline: ``AttentiveFPGNN`` -> ``AttentiveFPReadout`` -> ``Linear`` ->
    ReLU -> ``Linear`` (8 units) -> ReLU -> ``Linear``.  The final layer's
    output is returned without an activation.

    Args:
        model_config: Mapping providing the keys ``"num_nodes"`` and
            ``"num_edges"`` (first and second size arguments of the GNN),
            ``"num_feat"`` (graph feature size), ``"n_steps"`` (number of
            GNN layers) and ``"num_outputs"`` (width of the final layer).
    """

    def __init__(self, model_config):
        super().__init__()
        node_feature_len = model_config["num_nodes"]
        edge_feature_len = model_config["num_edges"]
        feature_len = model_config["num_feat"]
        self.time_steps = model_config["n_steps"]
        self.num_outputs = model_config["num_outputs"]

        # Graph encoder: ``n_steps`` is used as the number of GNN layers.
        self.gconv = dgllife.model.gnn.attentivefp.AttentiveFPGNN(
            node_feature_len,
            edge_feature_len,
            num_layers=self.time_steps,
            graph_feat_size=feature_len,
        )
        # Only the feature size is passed to the readout; every other
        # readout setting is left at the library default.
        self.readout = dgllife.model.readout.attentivefp_readout.AttentiveFPReadout(
            feature_len
        )

        # Head: feature_len -> feature_len/2 -> 8 -> num_outputs.
        hidden_len = int(feature_len / 2)
        self.op_1 = nn.Linear(feature_len, hidden_len)
        self.op_2 = nn.Linear(hidden_len, 8)
        self.op = nn.Linear(8, self.num_outputs)

    def forward(self, mol_graph_batch):
        """Return the final linear layer's output for ``mol_graph_batch``.

        Args:
            mol_graph_batch: Graph object exposing ``ndata["h"]`` and
                ``edata["e"]``; presumably a batched DGL graph of molecules
                (TODO confirm against the data loader).

        Returns:
            Output of ``op`` (last dimension ``num_outputs``), computed from
            the activations returned by :meth:`get_features`.
        """
        return self.op(self.get_features(mol_graph_batch))

    def get_features(self, mol_graph_batch):
        """Return the 8-wide hidden activations that feed the output layer.

        Both ReLU activations of the head are applied, so the result is the
        exact tensor ``op`` receives in :meth:`forward`.  Skipping them would
        yield a purely linear map of the readout and would feed ``op_2``
        inputs it never sees during the forward pass.

        Args:
            mol_graph_batch: Same graph object that :meth:`forward` accepts.

        Returns:
            Non-negative activations with last dimension 8.
        """
        node_features = self.gconv(
            mol_graph_batch, mol_graph_batch.ndata["h"], mol_graph_batch.edata["e"]
        )
        gfeat = self.readout(mol_graph_batch, node_features)

        op1 = F.relu(self.op_1(gfeat))
        op2 = F.relu(self.op_2(op1))
        return op2

