from braivest.train.Trainer import Trainer
import wandb
from braivest.utils import load_data
from braivest.train.custom_wandb_callbacks import CustomWandbCallback
import os
import json

def train_config(config=None):
    """Run one wandb-tracked training job for the given configuration.

    Opens a wandb run initialised with ``config`` and downloads the dataset
    artifact named by ``config.data_artifact``. It then builds a ``Trainer``
    and trains with a ``CustomWandbCallback`` on the test split. Finally it
    saves the last model weights as ``model.h5`` in the run directory.

    Args:
        config: Configuration passed to ``wandb.init``. It may be ``None``,
            e.g. when values are supplied by a wandb sweep. Besides whatever
            ``Trainer`` reads, this function reads ``data_artifact`` and
            ``save_best``.
    """
    with wandb.init(config=config) as run:
        # Use wandb's resolved config so sweep overrides are honoured.
        config = wandb.config

        artifact = run.use_artifact(config.data_artifact, type='dataset')
        artifact_dir = artifact.download()

        # The training array is only needed for its feature dimension.
        # No reference is kept, so it can be freed before training starts;
        # the trainer loads the dataset itself via load_dataset below.
        # Assumes train.npy is 2-D (samples, features) — TODO confirm.
        input_dim = load_data(artifact_dir, 'train.npy').shape[1]
        test = load_data(artifact_dir, 'test.npy')
        test_hypno = load_data(artifact_dir, 'test_hypno.npy')

        trainer = Trainer(config, input_dim)
        trainer.load_dataset(artifact_dir)

        custom_callbacks = [CustomWandbCallback(test, test_hypno, plot=False)]
        trainer.train(
            wandb=True,
            save_best_only=config.save_best,
            save_dir=run.dir,
            custom_callbacks=custom_callbacks,
        )

        # Save the final model too, in case the best-only checkpoint
        # was never written.
        trainer.model.save_weights(os.path.join(run.dir, "model.h5"))

def main():
    """Log in to wandb, load ``config.json`` from the CWD, and run training."""
    wandb.login()
    # Context manager closes the file deterministically; explicit encoding
    # avoids depending on the platform's locale default.
    with open('config.json', encoding='utf-8') as config_file:
        config = json.load(config_file)
    train_config(config)

if __name__ == '__main__':
    main()