import os
import time
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from utils.ema_callback import EMA


if __name__ == '__main__':
    # Entry point: train a VIAE model with PyTorch Lightning.
    #
    # Edit the configuration below (dataset name, architecture, output root and
    # data directory) before launching. Checkpoints, the EMA callback's output
    # and TensorBoard logs are all written under `save_path`.
    torch.set_float32_matmul_precision('high')
    data = 'cifar10'   # one of: 'celeba', 'ffhq', 'cifar10'
    model = 'cnn'      # one of: 'cnn', 'vit'

    save_path = 'your_save_root'

    os.makedirs(save_path, exist_ok=True)

    # Import the selected dataset / architecture lazily so only the chosen
    # backend is loaded. Fail fast on an unknown choice instead of hitting a
    # NameError on `Data` / `VIAE` further down.
    if data in ('celeba', 'ffhq'):
        from utils.dataloader.celeba import CelebA_Loader as Data
    elif data == 'cifar10':
        from utils.dataloader.cifar10 import Cifar10_Loader as Data
    else:
        raise ValueError(
            f"unknown dataset {data!r}; expected 'celeba', 'ffhq' or 'cifar10'")

    if model == 'cnn':
        from module_cnn import VIAE
    elif model == 'vit':
        from module_vit import VIAE
    else:
        raise ValueError(f"unknown model {model!r}; expected 'cnn' or 'vit'")

    # perf_counter is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments during a long training run.
    start_time = time.perf_counter()

    dataset = Data(batch_size=64, num_workers=18, data_dir='your_data_path')
    # Kept under a separate name so the `model` config string above is not
    # shadowed by the LightningModule instance.
    lit_model = VIAE(lr=1e-4,
                     idem_alpha=1.,
                     idem_beta=0.05,
                     max_sigma=2,
                     use_idempotent=True)

    best = ModelCheckpoint(
        monitor='val/loss',
        dirpath=save_path,
        # The monitored metric name contains '/'. With the default
        # auto_insert_metric_name=True, Lightning would expand this template
        # to 'loss_val/loss=<value>', silently creating a 'loss_val/'
        # subdirectory. Disabling auto-insertion yields 'loss_<value>.ckpt'
        # directly in `save_path`.
        filename='loss_{val/loss:.5f}',
        auto_insert_metric_name=False,
        save_top_k=2,
        mode='min',
        save_last=True,
        save_weights_only=False)

    # Project-local EMA callback (utils.ema_callback); presumably maintains an
    # exponential moving average of the weights -- confirm against its source.
    ema = EMA(decay=0.9995,
              dirpath=save_path,
              save_last=True)
    callback = [best, ema]

    tb_l = TensorBoardLogger(save_path)

    # NOTE(review): find_unused_parameters=True is presumably required because
    # some VIAE parameters do not get gradients every step -- confirm; it
    # adds DDP overhead when unnecessary.
    trainer = Trainer(strategy='ddp_find_unused_parameters_true',
                      accelerator='gpu', precision='bf16-mixed',
                      # 4 micro-batches of 64 per device per optimizer step.
                      accumulate_grad_batches=4,
                      gradient_clip_algorithm='norm', gradient_clip_val=1.0,
                      max_steps=800000, callbacks=callback, logger=tb_l,
                      log_every_n_steps=20)

    trainer.fit(lit_model, dataset)

    end_time = time.perf_counter()
    # Under DDP every process runs this script; report the timing only once.
    if trainer.is_global_zero:
        print(f'took: {(end_time - start_time) / 60:.2f} min')
    