# Standard Library Imports
import os
import sys

# Third-Party Library Imports
import torch

# Local Imports
root_dir = "../"
sys.path.append(root_dir)
from expressive_cnn import classifier
import utils
import sim_dataset

# ---------------------------------------------------------------------------
# Prerequisites: device, model, optimizer, simulated data and run settings.
# ---------------------------------------------------------------------------

# Use the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Experiment name; also used as the root directory for this experiment's
# data configs and saved models (relative to the current working directory).
exp_name = "repression"
print(f"Training a classifier for a {exp_name} configuration")

# Build a fresh, untrained classifier and move it onto the selected device.
clf = classifier()
clf = clf.to(device)
print("Classifier loaded successfully")

# Optimisation settings.
num_epochs = 10
optimizer = torch.optim.Adam(clf.parameters(), lr=1e-3)

# JSON configs for the simulated positive and negative motif sets.
_data_dir = os.path.join(exp_name, "data")
sim_pos_config_path = os.path.join(_data_dir, "spi1_ctcf_exp_config.json")
sim_neg_config_path = os.path.join(_data_dir, "spi1_ctcf_exp_neg_config.json")

# Dataloaders: both splits share the same motif configs and batch size and
# differ only in the number of batches requested (train is built first).
batch_size = 64
num_train_batches = 500
num_val_batches = 20
dataloaders = {
    split: sim_dataset.create_data_loader(
        sim_pos_config_path,
        neg_motif_config_path=sim_neg_config_path,
        batch_size=batch_size,
        num_batches=split_batches,
    )
    for split, split_batches in (
        ("train", num_train_batches),
        ("val", num_val_batches),
    )
}

# Run settings gathered in one place; `save_path` is the experiment's output
# directory (relative to the current working directory).
params = dict(
    dataloaders=dataloaders,
    optimizer=optimizer,
    save_path=os.path.join(exp_name, "models"),
    num_epochs=num_epochs,
    device=device,
)

# ---------------------------------------------------------------------------
# Main: record the hyperparameters, then train and validate the classifier.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    # Hyperparameters persisted alongside the model outputs.
    param_dict = {
        "lr": optimizer.param_groups[0]['lr'],
        "num_epochs": params["num_epochs"],
        "batch size": batch_size,
    }

    # Nothing visible here creates the output directory, and utils.save_dict
    # may not create it either (TODO confirm), so ensure it exists before any
    # file is written into it. exist_ok=True makes reruns a no-op.
    os.makedirs(params["save_path"], exist_ok=True)

    # Save params
    utils.save_dict(param_dict, params["save_path"], "params.txt")

    # Read every training setting from `params` (instead of mixing `params`
    # with module globals) so the logged num_epochs is the one actually used.
    clf.train_and_validate(
        params["dataloaders"],
        params["optimizer"],
        params["save_path"],
        num_epochs=params["num_epochs"],
        device=params["device"],
        real_data=False,
    )

