import torch
import argparse
import os
from dig.threedgraph.dataset import QM93D
from dig.threedgraph.method import SphereNet, SchNet
from dig.threedgraph.evaluation import ThreeDEvaluator
from dig.threedgraph.method import run_pretrain_holimol
from tqdm import tqdm
from molecule_graphfrag_randomaug_3D_2 import Molecule3DDatasetFragRandomaug3d_2
import numpy as np
from models import GNN
from pretrain_JOAO import graphcl, graphcl3d


# Command-line configuration.
# Only --device, --epochs, --batch_size, --vt_batch_size, --lr, --choose,
# --aug_mode, --aug_strength and --root_2d are read by this script.
# The SphereNet hyper-parameters are still accepted so existing command lines
# keep working, but nothing in this script reads them. In particular, the 3D
# encoder is a SchNet built with a fixed width, not --hidden_channels.
parser = argparse.ArgumentParser(
    description='Pretrain 2D (GIN) and 3D (SchNet) molecular encoders '
                'with run_pretrain_holimol')
parser.add_argument('--device', type=int, default=0,
                    help='CUDA device index (CPU is used if CUDA is unavailable)')

sphere_opts = parser.add_argument_group(
    'SphereNet options',
    'accepted for backward compatibility; not read by this script')
sphere_opts.add_argument('--cutoff', type=float, default=5.0)
sphere_opts.add_argument('--num_layers', type=int, default=4)
sphere_opts.add_argument('--hidden_channels', type=int, default=128)
sphere_opts.add_argument('--out_channels', type=int, default=1)
sphere_opts.add_argument('--int_emb_size', type=int, default=64)
sphere_opts.add_argument('--basis_emb_size_dist', type=int, default=8)
sphere_opts.add_argument('--basis_emb_size_angle', type=int, default=8)
sphere_opts.add_argument('--basis_emb_size_torsion', type=int, default=8)
sphere_opts.add_argument('--out_emb_channels', type=int, default=256)
sphere_opts.add_argument('--num_spherical', type=int, default=3)
sphere_opts.add_argument('--num_radial', type=int, default=6)

train_opts = parser.add_argument_group('training')
train_opts.add_argument('--epochs', type=int, default=100,
                        help='number of pretraining epochs')
train_opts.add_argument('--batch_size', type=int, default=256,
                        help='training batch size passed to run_pretrain_holimol.run')
train_opts.add_argument('--vt_batch_size', type=int, default=256,
                        help='validation/test batch size passed to run_pretrain_holimol.run')
train_opts.add_argument('--lr', type=float, default=0.001,
                        help='learning rate passed to run_pretrain_holimol.run')

data_opts = parser.add_argument_group('dataset / augmentation')
data_opts.add_argument('--choose', type=int, default=0,
                       help='forwarded as choose= to the dataset constructor')
data_opts.add_argument('--aug_mode', type=str, default='choosetwo',
                       help='augmentation mode passed to dataset.set_augMode')
data_opts.add_argument('--aug_strength', type=float, default=0.1,
                       help='augmentation strength passed to dataset.set_augStrength')
data_opts.add_argument('--root_2d', type=str, default='',
                       help='dataset root; SMILES are read from <root_2d>/processed/smiles.csv')

args = parser.parse_args()
print(args)

# Select the compute device: the CUDA device chosen with --device when CUDA
# is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.device}')
else:
    device = torch.device('cpu')
print('device', device)






# Number of molecules the dataset is constructed with.
# NOTE(review): presumably the size of the processed dataset under root_2d;
# confirm before changing the data source.
N_MOL = 304466
# Length of the augmentation probability vector handed to set_augProb.
# NOTE(review): assumed to equal the number of augmentation choices the
# dataset samples from in --aug_mode; confirm against the dataset class.
NUM_AUG_CHOICES = 25

root_2d = args.root_2d
# Build the SMILES path with os.path.join. The previous
# '%s/processed/smiles.csv' format turned the default root '' into the
# absolute path '/processed/smiles.csv' (and doubled the slash when root_2d
# ended in '/'). With an empty root this path is now relative to the current
# working directory, like the root itself.
smiles_path = os.path.join(root_2d, 'processed', 'smiles.csv')
dataset = Molecule3DDatasetFragRandomaug3d_2(root=root_2d, n_mol=N_MOL, choose=args.choose,
                                             smiles_copy_from_3D_file=smiles_path)

# Configure augmentation after construction. Choices are sampled uniformly.
dataset.set_augMode(args.aug_mode)
dataset.set_augStrength(args.aug_strength)
aug_prob = np.ones(NUM_AUG_CHOICES) / NUM_AUG_CHOICES
dataset.set_augProb(aug_prob)


# Width shared by both encoders: the 2D GNN's second positional argument and
# SchNet's hidden_channels. Note that --hidden_channels is not used here.
emb_dim = 300

# 2D encoder. The positional arguments are forwarded unchanged to models.GNN.
# NOTE(review): presumably (num_layers, emb_dim, JK mode, dropout, GNN type);
# confirm against the GNN signature.
# Build the 2D model before the 3D one: parameter initialisation draws from
# the global RNG, so the construction order affects the initial weights.
gnn = GNN(5, emb_dim, "last", 0, "gin")
model_2D = graphcl(gnn).to(device)

# 3D encoder, wrapped by graphcl3d and moved to the same device.
gnn_3D = SchNet(hidden_channels=emb_dim)
model_3D = graphcl3d(gnn_3D).to(device)



# Loss object handed to the pretraining runner.
# NOTE(review): how run_pretrain_holimol uses it is not visible here.
loss_func = torch.nn.L1Loss()

# Train and evaluate: pass both encoders, the augmented dataset and the
# CLI-configured schedule to the holimol pretraining runner.
train_kwargs = {
    'epochs': args.epochs,
    'batch_size': args.batch_size,
    'vt_batch_size': args.vt_batch_size,
    'lr': args.lr,
}
run3d = run_pretrain_holimol()
run3d.run(device, dataset, model_2D, model_3D, loss_func, **train_kwargs)
