import numpy as np
import sys
sys.path.append(".")
from pytorch_lightning import Trainer
from Data.GM12878_DataModule import GM12878Module
from Models.VEHiCLE_Module import GAN_Model

def train(batch_size: int = 1,
          piece_size: int = 256,
          max_epochs: int = 50,
          accelerator: str = 'gpu',
          devices: int = 1) -> None:
    """Fit a ``GAN_Model`` on the ``GM12878Module`` data with a Lightning Trainer.

    All arguments default to the values this script originally hard-coded,
    so calling ``train()`` with no arguments behaves as before.

    Args:
        batch_size: Forwarded to ``GM12878Module(batch_size=...)``.
        piece_size: Forwarded to ``GM12878Module(piece_size=...)``.
        max_epochs: Forwarded to ``Trainer(max_epochs=...)``.
        accelerator: Forwarded to ``Trainer(accelerator=...)``.
        devices: Forwarded to ``Trainer(devices=...)``.
    """
    dm = GM12878Module(batch_size=batch_size, piece_size=piece_size)
    # NOTE(review): Lightning's Trainer.fit normally invokes prepare_data()
    # and setup() on the datamodule itself; the explicit calls are kept to
    # preserve the existing behavior -- confirm whether they can be dropped.
    dm.prepare_data()
    dm.setup(stage='fit')

    print("Build model")
    model = GAN_Model()

    print("Train model")
    trainer = Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=max_epochs,
    )
    trainer.fit(model, dm)
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()