import psutil

# Report system memory usage at start-up, before the heavier imports below.
# One snapshot is taken so that both figures describe the same instant, and
# the named fields replace the positional indices 2 (percent) and 3 (used).
_mem_snapshot = psutil.virtual_memory()
print('RAM memory % used:', _mem_snapshot.percent)
print('RAM Used (GB):', _mem_snapshot.used/1000000000)

import torch
from graph_reconstruction.buildingblocks import *
from graph_reconstruction.build_molecule import *
import neptune
import yaml
import sys
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
import torch.optim as optim
from utils_normalizations import *
import torch.nn.utils as utils
from tqdm import tqdm
from argparse import ArgumentParser

# ---------------------------------------------------------------------------
# Script setup: command-line parsing, configuration, model, optimizer,
# optional Neptune run and the dataloader used by the loops further down.
# ---------------------------------------------------------------------------

# `os.environ` is read below. Import `os` explicitly (a repeated import is
# harmless) instead of relying on it leaking in through a wildcard import.
import os

parser = ArgumentParser()
# Optional Neptune project name; logging is disabled when it is omitted.
parser.add_argument('--neptune', type=str, default=None, required=False)
# Path to the YAML configuration file. Positional arguments are always
# required, and argparse raises
# "TypeError: 'required' is an invalid argument for positionals"
# when the `required` keyword is passed, so it must not be given here.
parser.add_argument('config_path', type=str)
args = parser.parse_args()

# Gradient-norm ceiling passed to `clip_grad_norm_` in the training loop.
max_grad_norm = 0.001

# `safe_load` only builds plain Python objects from the YAML file.
with open(args.config_path, 'r') as file:
    config = yaml.safe_load(file)
model_args = config['model_args']

# reduction="none" keeps the per-logit losses; the loops average them manually.
criterion = torch.nn.BCEWithLogitsLoss(reduction="none").cuda()
feature_onehot_encoding = True
learning_rate = 0.00001
num_epochs = 100
model = get_model(model_args, feat_dim=140, num_cats=2)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Extracting the models
gcn_model = model.gcn
relu = torch.nn.ReLU()

use_neptune = args.neptune is not None

# Label recorded in the Neptune run under "optimizer type".
mode = 'adam'

if use_neptune:
    # Raises KeyError if NEPTUNE_API_KEY is not set in the environment.
    run = neptune.init_run(
        project=args.neptune,
        api_token=os.environ['NEPTUNE_API_KEY'],
        name="sincere-oxpecker",
    )
    run["feature_onehot_encoding"] = feature_onehot_encoding
    run["num_epochs"] = num_epochs
    run["optimizer type"] = mode

# Loading the dataset: one molecule per batch, in stored order.
dataloader_clintox = DataLoader(
    load_from_disk(config['data_args']['dataset']).with_format("torch")['train'],
    batch_size=1,
    shuffle=False,
)

# Fraction of the dataloader (by index) used for training; the remainder is
# what the evaluation loop iterates over.
frac_train = 1
# Number of dataloader indices between optimizer steps.
grad_accum_steps = 10

# Best epoch-average loss seen so far. Starting from +inf guarantees the
# first completed epoch is checkpointed, whatever its loss.
lowest_loss = float('inf')

# Running count of processed training samples.
oh = 0

# ---------------------------------------------------------------------------
# Training loop.
#
# Each dataloader item holds one molecule (batch_size=1). Gradients are
# accumulated over `grad_accum_steps` dataloader indices before the optimizer
# steps. The state_dict of the epoch with the lowest average loss is written
# to './models/model_one_hot_new.pth'.
# ---------------------------------------------------------------------------
for epoch in tqdm(range(num_epochs)):
    # Per-sample losses as plain floats. Storing the loss tensors themselves
    # would keep every sample's autograd graph alive for the whole epoch.
    losses = []

    for ind, batch in enumerate(dataloader_clintox):

        # The lone '[Se]' entry is excluded from training.
        if batch['smiles'] == ['[Se]']:
            continue

        # Only the leading `frac_train` fraction of the loader is trained on.
        if ind >= frac_train * len(dataloader_clintox):
            break

        gt_mol = Chem.MolFromSmiles(batch['smiles'][0])
        label = batch['target']

        # Two-element target: index 1 is set for label 1, index 0 otherwise.
        if label == 1:
            gt_ls = torch.tensor([0., 1.]).cuda()
        else:
            gt_ls = torch.tensor([1., 0.]).cuda()

        gt_ams = normalize_adjacency(torch.tensor(get_A(gt_mol))).cuda()

        gt_fms = normalize_features(
            torch.tensor(get_X(gt_mol, feature_onehot_encoding=feature_onehot_encoding)),
            dataset='clintox',
        )
        gt_fms.requires_grad_(True)
        gcn_output = model.gcn(gt_fms, gt_ams)
        logits = model.readout(gt_fms, gcn_output).cuda()
        gcn_output.requires_grad_(True)

        # criterion uses reduction="none": one loss value per logit.
        loss = criterion(logits, gt_ls)
        assert loss.numel() == 2
        loss = loss.sum() / loss.numel()

        loss_value = loss.item()
        losses.append(loss_value)

        # `run` only exists when Neptune logging was requested.
        if use_neptune:
            run["loss current"].append(loss_value)

        oh += 1

        loss.backward()  # Accumulate gradients for this sample

        # NOTE(review): clipping runs after every backward pass, i.e. on the
        # gradients accumulated so far, not once per optimizer step. Confirm
        # this is intended.
        utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        if (ind + 1) % grad_accum_steps == 0:
            optimizer.step()  # Update weights using Adam optimizer
            optimizer.zero_grad()  # Clear accumulated gradients
            print(f'{ind+1} Step taken, weights updated')

    # Nothing was processed (empty loader or every sample skipped): there is
    # no average to report and no gradient to apply.
    if not losses:
        print(f'Epoch: {epoch}, no training samples processed', flush=True)
        continue

    # Debug output: gradients still pending after the last sample.
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"Gradient for {name} at step {ind}: {param.grad}")

    avg_loss = sum(losses) / len(losses)
    print(f'Epoch: {epoch}, average loss: {avg_loss}', flush=True)

    # NOTE(review): the checkpoint is written before the trailing optimizer
    # step below, so it does not include that last partial update.
    if lowest_loss > avg_loss:
        torch.save(model.state_dict(), './models/model_one_hot_new.pth')
        lowest_loss = avg_loss

    # Apply whatever is left when the loop did not end on a step boundary.
    if (ind + 1) % grad_accum_steps != 0:
        optimizer.step()  # Update weights using Adam optimizer
        optimizer.zero_grad()  # Clear accumulated gradients

    sys.stdout.flush()

    if use_neptune:
        run["average loss"].append(avg_loss)

# ---------------------------------------------------------------------------
# Held-out evaluation.
#
# Iterates over the dataloader items that the training loop does not use
# (index >= frac_train * len(dataloader)) and records the mean loss.
# ---------------------------------------------------------------------------
test_loss = []

# NOTE(review): training saves a state_dict to
# './models/model_one_hot_new.pth', whereas this loads a different file and
# uses the result directly as a module (`.gcn`, `.readout`). Confirm that
# './models/model.pth' is the intended checkpoint and holds a full model.
model = torch.load('./models/model.pth')
# Evaluation mode, so layers such as dropout or batch norm (if the model has
# any) behave deterministically.
model.eval()

for ind, batch in enumerate(dataloader_clintox):

    # Skip everything that belongs to the training fraction.
    if ind < frac_train * len(dataloader_clintox):
        continue

    # Same exclusion as the training loop.
    if batch['smiles'] == ['[Se]']:
        continue

    gt_mol = Chem.MolFromSmiles(batch['smiles'][0])
    label = batch['target']

    # Targets and inputs are prepared exactly as in the training loop
    # (normalization and device placement), so the reported loss is
    # comparable with the training loss.
    if label == 1:
        gt_ls = torch.tensor([0., 1.]).cuda()
    else:
        gt_ls = torch.tensor([1., 0.]).cuda()
    gt_ams = normalize_adjacency(torch.tensor(get_A(gt_mol))).cuda()
    gt_fms = normalize_features(
        torch.tensor(get_X(gt_mol, feature_onehot_encoding=feature_onehot_encoding)),
        dataset='clintox',
    )

    with torch.no_grad():  # Disable gradient calculation during testing
        gcn_output = model.gcn(gt_fms, gt_ams)
        logits = model.readout(gt_fms, gcn_output).cuda()
        loss = criterion(logits, gt_ls)
        assert loss.numel() == 2
        loss = loss.sum() / loss.numel()
        test_loss.append(loss.item())


if use_neptune:
    # With frac_train == 1 there are no held-out samples. Avoid dividing by
    # zero and leave the metric unset instead.
    if test_loss:
        run["average test loss"] = sum(test_loss) / len(test_loss)
    else:
        print('No held-out samples were evaluated; "average test loss" not logged.', flush=True)

sys.stdout.flush()
