# Standard Library Imports
import os
import sys

# Third-Party Library Imports
import numpy as np
import torch
from torch.utils.data import DataLoader

# Local Imports
root_dir = "../"
sys.path.append(root_dir)
import utils
import sim_utils
import sim_dataset

"""
PREREQUISTES
"""
# Empty cache
torch.cuda.empty_cache()

# Set seed
np.random.seed(123)
torch.manual_seed(456)

# Set torch device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Argument Parser
args = sim_utils.scrambler_parse_args()

# Announce which scrambler configuration is about to be trained
print(f"Training a alpha={args.alpha} scrambler for a {args.exp_name} configuration")

# Pre-trained classifier (the model the scrambler/explainer is built for).
# NOTE(review): a directory, not a file, is passed as state_dict_path —
# confirm that load_classifier resolves the checkpoint inside it.
clf_dir = os.path.join(args.exp_name, "models")
clf = sim_utils.load_classifier(state_dict_path=clf_dir, device=device)

# Explainer built around the classifier; state_dict_path=None presumably means
# "start from fresh weights" — confirm in sim_utils.load_scrambler.
scram = sim_utils.load_scrambler(device, clf, state_dict_path=None)

# Positive / negative simulation configs living under <exp_name>/data
_data_dir = os.path.join(args.exp_name, "data")
sim_pos_config_path = os.path.join(_data_dir, "spi1_ctcf_exp_config.json")
sim_neg_config_path = os.path.join(_data_dir, "spi1_ctcf_exp_neg_config.json")

batch_size = args.batch_size
num_train_batches = 500
num_val_batches = 20

# Train and validation loaders share both configs and the batch size; only the
# number of batches differs. The train loader is created first.
train_loader, val_loader = (
    sim_dataset.create_data_loader(
        sim_pos_config_path,
        neg_motif_config_path=sim_neg_config_path,
        batch_size=batch_size,
        num_batches=n_batches,
    )
    for n_batches in (num_train_batches, num_val_batches)
)

# Run each split through utils.filter_dataset with the classifier; judging by
# the "Correct number" printout below, this presumably keeps only samples the
# classifier gets right — confirm in utils.filter_dataset.
mex_train_data, mex_val_data = (
    utils.filter_dataset(loader, clf) for loader in (train_loader, val_loader)
)

# Re-batch the filtered splits; only the training split is shuffled
dataloaders = {
    split: DataLoader(subset, batch_size=batch_size, shuffle=(split == "train"))
    for split, subset in (("train", mex_train_data), ("val", mex_val_data))
}

# Report how many training samples survived the filtering
print(f"Original number of training samples =  {num_train_batches * batch_size}")
print(f"Correct number of training samples =  {len(mex_train_data)}")

# Optimizer over the scrambler's parameters
optimizer = torch.optim.Adam(scram.parameters(), lr=args.lr)

# Keyword arguments forwarded to scram.train_and_validate. "save_path" starts
# as None and is filled in below once the results folder is known.
params = {
    'dataloaders': dataloaders,
    'optimizer': optimizer,
    'save_path': None,
    'alpha': args.alpha,
    'kl_mult': args.kl_mult, 
    't_bits': args.t_bits, 
    'num_epochs': args.num_epochs, 
    'num_bkgd_samples': args.num_bkgd_samples,
    'device': device,
    # NOTE(review): hard-coded training options; their meaning is defined
    # by scram.train_and_validate.
    'log_frac': 0.5,
    'wandb_log': False,
}

# Pick the results folder: alpha == 1 -> <exp_name>/results/suff,
# alpha == 0 -> <exp_name>/results/nec. No folder is defined for any other
# alpha, so reject it explicitly rather than leaving save_path unbound.
if params["alpha"] == 1:
    save_path = os.path.join(args.exp_name, "results", "suff")
elif params["alpha"] == 0:
    save_path = os.path.join(args.exp_name, "results", "nec")
else:
    raise ValueError(
        f"Unsupported alpha={params['alpha']!r}: expected 1 (results/suff) "
        "or 0 (results/nec)"
    )

# exist_ok=True avoids a crash if the folder appears between the check and
# the makedirs call (e.g. two runs started at once).
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
    print(f"Folder {save_path} created")

# Update path to save results
params["save_path"] = save_path

"""
MAIN FUNCTION
"""
if __name__ == "__main__":

    # Make dictionary of parameters to save
    param_dict = {
        "lr": optimizer.param_groups[0]['lr'],
        "alpha": params["alpha"],
        "kl_mult": params["kl_mult"],
        "t_bits": params["t_bits"],
        "num_epochs": params["num_epochs"],
        "num_bkgd_samples": params["num_bkgd_samples"]
    }

    # Save params
    utils.save_dict(param_dict, params["save_path"], "scrambler_params.txt")

    scram.train_and_validate(**params)