import pdb
import sys
sys.path.append(".")
from pytorch_lightning import Trainer
from Data.GM12878_DataModule import GM12878Module
from Models.VAE_Module import VAE_Model
import os
def train():
    """Train the VAE on the GM12878 data module.

    Builds a ``GM12878Module`` with ``piece_size=244``, runs its
    ``prepare_data`` and ``setup(stage='fit')`` hooks, constructs a
    ``VAE_Model`` from the hard-coded hyperparameters below, and fits it
    with a ``pytorch_lightning.Trainer`` on one GPU device for at most
    50 epochs. Takes no arguments and returns nothing.
    """
    data_module = GM12878Module(piece_size=244)
    data_module.prepare_data()
    data_module.setup(stage='fit')

    # Hyperparameters forwarded unchanged to VAE_Model as keyword arguments.
    # Original author's note on pre_latent:
    #   4608 for 257 x 257
    #   2048 for 244 x 244
    hparams = dict(
        batch_size=512,
        condensed_latent=2,
        gamma=1.0,
        kld_weight=1e-6,
        kld_weight_inc=0.0,
        latent_dim=200,
        lr=1e-5,
        pre_latent=2048,
    )

    vae = VAE_Model(**hparams)
    fitter = Trainer(accelerator='gpu', devices=1, max_epochs=50)
    fitter.fit(vae, data_module)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
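# Alternative hyperparameter set kept for reference only: the bare string
# literal below is never read by train().
'''
pargs = {'batch_size': 512,
        'condensed_latent': 3,
        'gamma': 1.0, 
        'kld_weight': .0001,
        'kld_weight_inc': 0.000,
        'latent_dim': 110,
        'lr': 0.00001,
        'pre_latent': 4608}
'''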
