import torch
import argparse
import os
from dig.threedgraph.dataset import QM93D
from dig.threedgraph.method import SphereNet, SchNet
from dig.threedgraph.evaluation import ThreeDEvaluator
from dig.threedgraph.method import run
from tqdm import tqdm
from models import GNN

from torch_geometric.nn import (MessagePassing, global_add_pool,
                                global_max_pool, global_mean_pool)
import torch.nn as nn


class GNN_graphpred(nn.Module):
    """Graph-level prediction head on top of a node-level molecule encoder.

    The wrapped ``molecule_model`` is called as
    ``molecule_model(x, edge_index, edge_attr)`` and its per-node output is
    mean-pooled per graph (``global_mean_pool``). The pooled vector is then fed
    through a two-layer MLP ``emb_dim -> emb_dim // 2 -> num_tasks``.

    Args:
        molecule_model (nn.Module): node encoder. Its output width must equal
            ``emb_dim`` — NOTE(review): presumably the embedding size the
            encoder was built with; confirm for new encoders.
        emb_dim (int): width of the node embeddings produced by
            ``molecule_model``. Defaults to 300, the previously hard-coded size.
        num_tasks (int): number of outputs per graph. Defaults to 1.
    """

    def __init__(self, molecule_model, emb_dim=300, num_tasks=1):
        super(GNN_graphpred, self).__init__()

        self.molecule_model = molecule_model

        # Graph readout: average the node embeddings belonging to each graph.
        self.pool = global_mean_pool

        # Prediction MLP. With the defaults this is 300 -> 150 -> 1, identical
        # to the original hard-coded layers (same order, same state-dict keys
        # ``graph_pred_linear.0.*`` and ``graph_pred_linear.2.*``).
        hidden_dim = emb_dim // 2
        self.graph_pred_linear = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_tasks),
        )

    def from_pretrained(self, model_file):
        """Load encoder weights stored at ``model_file`` into ``molecule_model``.

        Tensors are deserialized onto the CPU first so a checkpoint saved on a
        GPU also loads on a CPU-only machine; ``load_state_dict`` then copies
        them into the encoder's existing parameters.
        """
        self.molecule_model.load_state_dict(
            torch.load(model_file, map_location="cpu"))

    def get_graph_representation(self, *argv):
        """Return ``(graph_representation, pred)`` for one batch.

        Accepts either a single batch object exposing ``x``, ``edge_index``,
        ``edge_attr`` and ``batch`` attributes, or those four values
        positionally (in that order).

        Returns:
            graph_representation: pooled embeddings, one row per graph.
            pred: output of the prediction MLP on those embeddings.

        Raises:
            ValueError: if called with any number of arguments other than 1 or 4.
        """
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = (
                data.x, data.edge_index, data.edge_attr, data.batch)
        else:
            raise ValueError("unmatched number of arguments.")

        # Call the encoder exactly as ``forward`` always did. This previously
        # used ``self.molecule_model.gnn(...)``, which only works if
        # ``molecule_model`` is a wrapper exposing a ``gnn`` attribute.
        node_representation = self.molecule_model(x, edge_index, edge_attr)
        graph_representation = self.pool(node_representation, batch)
        pred = self.graph_pred_linear(graph_representation)

        return graph_representation, pred

    def forward(self, data):
        """Return the prediction-MLP output for the batch ``data``.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
        """
        _, output = self.get_graph_representation(data)
        return output




parser = argparse.ArgumentParser(description='QM9 SphereNet')
parser.add_argument('--device', type=int, default=0)

# NOTE(review): the SphereNet hyper-parameters below are not read anywhere in
# this script (the model is the 2D GNN encoder built further down); they are
# kept so existing command lines continue to parse.
parser.add_argument('--cutoff', type=float, default=5.0)
parser.add_argument('--num_layers', type=int, default=4)
parser.add_argument('--hidden_channels', type=int, default=128)
parser.add_argument('--out_channels', type=int, default=1)
parser.add_argument('--int_emb_size', type=int, default=64)
parser.add_argument('--basis_emb_size_dist', type=int, default=8)
parser.add_argument('--basis_emb_size_angle', type=int, default=8)
parser.add_argument('--basis_emb_size_torsion', type=int, default=8)
parser.add_argument('--out_emb_channels', type=int, default=256)
parser.add_argument('--num_spherical', type=int, default=3)
parser.add_argument('--num_radial', type=int, default=6)

# Optimisation schedule, forwarded to DIG's `run` helper.
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--vt_batch_size', type=int, default=32)
parser.add_argument('--lr', type=float, default=0.0005)
parser.add_argument('--lr_decay_factor', type=float, default=0.5)
parser.add_argument('--lr_decay_step_size', type=int, default=100)

# The task name is used as a key into the dataset's stored fields and in the
# log/save directory names. The old default '' is not a valid key and only
# failed after the dataset had loaded, so the argument is now mandatory.
parser.add_argument('--task', type=str, required=True,
                    help='name of the QM9 target property to regress')

args = parser.parse_args()
print(args)

# Run on the requested CUDA device when CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.device}')
else:
    device = torch.device('cpu')
print('device', device)


# Load QM9 via DIG's QM93D and expose the selected property as the label `y`.
dataset = QM93D(root='dataset/')
target = args.task
dataset.data.y = dataset.data[target]

# Seeded split: 110000 train / 10000 valid; the test split is presumably the
# remainder (as produced by DIG's `get_idx_split`) — confirm.
split_idx = dataset.get_idx_split(
    len(dataset.data.y), train_size=110000, valid_size=10000, seed=42
)
train_dataset, valid_dataset, test_dataset = (
    dataset[split_idx[part]] for part in ('train', 'valid', 'test')
)


# 2D encoder. NOTE(review): positional args are presumably
# (num_layer, emb_dim, JK, drop_ratio, gnn_type); the 300 must match the
# 300-wide input of GNN_graphpred's prediction head — confirm.
model_2D = GNN(5, 300, "last", 0, "gin")

# Deserialize onto the CPU so a checkpoint saved on a GPU also loads on a
# CPU-only machine; `load_state_dict` below copies the tensors into the
# encoder's own parameters regardless of where they were loaded.
state_dict = torch.load("/your/checkpoint/path", map_location="cpu")["model_2D_state_dict"]

from collections import OrderedDict

# Keep only the encoder's weights: entries under the "gnn." prefix, with the
# prefix stripped so the names line up with `model_2D`'s parameters (the
# strict `load_state_dict` below requires an exact match). Every other entry
# in the checkpoint is dropped.
gnn_prefix = "gnn."
new_state_dict = OrderedDict(
    (key[len(gnn_prefix):], v)
    for key, v in state_dict.items()
    if key.startswith(gnn_prefix)
)

model = GNN_graphpred(model_2D)
model.molecule_model.load_state_dict(new_state_dict)

# Mean absolute error objective and DIG's 3D-graph evaluator.
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

# Train and evaluate with DIG's `run` helper.
# NOTE(review): the log/save paths are plain string concatenations, e.g.
# "/your/log/dir" + "homo" + "_log" -> "/your/log/dirhomo_log"; end the
# placeholder with "/" if a sub-directory is intended.
run3d = run()
run3d.run(
    device,
    train_dataset,
    valid_dataset,
    test_dataset,
    model,
    loss_func,
    evaluation,
    epochs=args.epochs,
    batch_size=args.batch_size,
    vt_batch_size=args.vt_batch_size,
    lr=args.lr,
    lr_decay_factor=args.lr_decay_factor,
    lr_decay_step_size=args.lr_decay_step_size,
    log_dir=f"/your/log/dir{args.task}_log",
    save_dir=f"/your/save/dir{args.task}_save",
)

