from pathlib import Path

import torch
import pytorch_lightning as pl


def _indexed_checkpoint_path(path, index):
    """Return *path* with *index* inserted just before the file extension.

    Only the final path component is modified, so a directory whose name
    happens to contain ``.ckpt`` is left untouched, and a file name without
    a ``.ckpt`` extension still receives the index.
    """
    path = Path(path)
    name = path.name
    if name.endswith('.ckpt'):
        # Split on the literal extension so names ending in '.ckpt' map to
        # exactly what the previous string replacement produced.
        base, ext = name[:-len('.ckpt')], '.ckpt'
    else:
        base, ext = path.stem, path.suffix
    return path.with_name(f'{base}{index}{ext}')


def pl_train(cfg, pl_module_cls):
    """Train a Lightning module built from ``cfg`` and optionally save it.

    Args:
        cfg: Configuration object. This function reads ``cfg.seed``,
            ``cfg.smoke_test``, ``cfg.model``, ``cfg.dataset`` and
            ``cfg.train`` (``gradient_clip_val``, ``epochs``,
            ``limit_train_batches`` and, optionally,
            ``save_checkpoint_path``). It must support both attribute access
            and ``in`` membership tests (presumably an OmegaConf/Hydra
            config; confirm against callers).
        pl_module_cls: Callable invoked as
            ``pl_module_cls(cfg.model, cfg.dataset, cfg.train)`` to build the
            model handed to ``trainer.fit``.

    Returns:
        A ``(trainer, model)`` tuple.
    """
    if cfg.seed is not None:
        pl.seed_everything(cfg.seed)
    model = pl_module_cls(cfg.model, cfg.dataset, cfg.train)
    trainer = pl.Trainer(
        # NOTE: the GPU count is hard-coded; it is not read from cfg.
        gpus=1,
        gradient_clip_val=cfg.train.gradient_clip_val,
        # A smoke test runs a single epoch regardless of cfg.train.epochs.
        max_epochs=1 if cfg.smoke_test else cfg.train.epochs,
        early_stop_callback=False,
        progress_bar_refresh_rate=1,
        limit_train_batches=cfg.train.limit_train_batches,
        checkpoint_callback=False,  # Disable checkpointing to save disk space
    )
    trainer.fit(model)
    if 'save_checkpoint_path' in cfg.train:
        path = Path(cfg.train.save_checkpoint_path)
        if 'dataset' in cfg and 'crossfit_index' in cfg.dataset:
            # Give each cross-fit split its own file so that splits do not
            # overwrite one another's checkpoint.
            path = _indexed_checkpoint_path(path, cfg.dataset.crossfit_index)
        path.parent.mkdir(parents=True, exist_ok=True)
        trainer.save_checkpoint(str(path))
    return trainer, model
