from trainer_combined import LitModuleCombined

import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping
from pytorch_lightning import loggers as pl_loggers

from itertools import product

import os

import torch
from torch.utils.tensorboard import SummaryWriter

# Root folder under which the TensorBoard logger writes its runs.
tensorboard_root = "tensorboard_logs/"

# Pick the Lightning accelerator string and the matching torch device in one
# step: CUDA when it is available, CPU otherwise.
# NOTE: `device` is not referenced in the visible part of this script.
accelerator, device = (
    ("gpu", torch.device("cuda"))
    if torch.cuda.is_available()
    else ("cpu", torch.device("cpu"))
)

# Run configuration.
embed_size = 512
batch_size = 128
data_type = "AMES"  # one of: "AMES", "HIV", "Tox21", "ClinTox"
data_split = "scaffold"  # one of: "scaffold", "fp", "random"
checkpoint_dir = "lightning_logs"  # not referenced in the visible part of this script
ood_factor = 0.01

# Settings handed to LitModuleCombined; how each key is interpreted is
# defined by that module, not here.
params = {
    "batch_size": batch_size,
    "max_epochs": 50,
    "embed_size": embed_size,
    "data_type": data_type,
    "split_type": data_split,
    "ood_factor": ood_factor,
    "ood_head": True,
    "num_quantiles": 2,
}

# Training callbacks.
# Early stopping watches "val_loss" (lower is better) with a patience of 5.
earlystopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)
pbar = TQDMProgressBar()

# Dataset name used as the leading tag of the checkpoint file name.
file_suffix = params["data_type"]

# Keep only the single best checkpoint as ranked by "val_loss".
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    # This must be an f-string so `file_suffix` is substituted right here.
    # The doubled braces leave a literal "{val_loss:.2f}" placeholder in the
    # template for Lightning to fill in from the logged metrics at save time.
    # (With a plain string, "{file_suffix}" was handed to Lightning as if it
    # were a metric name, so the dataset tag never made it into the file name.)
    filename=f"{file_suffix},Val_loss:{{val_loss:.2f}}",
)

# TensorBoard logger: rooted at <tensorboard_root>/<data_type>, with the
# split type as the run name.
tb_save_dir = os.path.join(tensorboard_root, params["data_type"])
tb_name = params["split_type"]
tensorboard = pl_loggers.TensorBoardLogger(name=tb_name, save_dir=tb_save_dir)

# The Lightning module receives the whole settings dict.
model = LitModuleCombined(params)

# Progress bar, early stopping and best-checkpoint saving, in that order.
trainer_callbacks = [pbar, earlystopping, checkpoint_callback]

trainer = pl.Trainer(
    logger=tensorboard,
    callbacks=trainer_callbacks,
    accelerator=accelerator,
    devices=1,
    max_epochs=params["max_epochs"],
)

# Train, then run the test loop with ckpt_path="best".
trainer.fit(model)
trainer.test(ckpt_path="best")
