import pytorch_lightning as pl
from data import LitDataModule, SSTDataset
from disentangle import MechanisticLitModule, TIMechanisticLitModule
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

# When False, the model is built with a single view (n_views=1) instead of
# ``n_views``.
is_training = True

param_dim = 20  # forwarded to TIMechanisticLitModule as ``param_dim``
n_views = 2  # number of views used when ``is_training`` is True
base_dir = "logs"  # root directory for CSVLogger output


def select_trainer_devices(cuda_available: bool):
    """Return the ``devices`` argument for the Lightning ``Trainer``.

    When CUDA is available this pins GPU 0, as the script always did.
    Otherwise it returns ``"auto"``. With ``accelerator="auto"`` Lightning then
    falls back to another accelerator (e.g. CPU), and Lightning's CPU
    accelerator rejects a list of device indices such as ``[0]`` (it requires
    an int > 0). The previously hard-coded ``[0]`` therefore crashed on
    machines without CUDA.

    Args:
        cuda_available: whether a CUDA device can be used.

    Returns:
        ``[0]`` when ``cuda_available`` is true, else ``"auto"``.
    """
    return [0] if cuda_available else "auto"


def main():
    """Train ``TIMechanisticLitModule`` on ``SSTDataset``.

    Seeds the RNGs, builds the data module, constructs the model from the
    module-level hyperparameters, and fits it with a Lightning ``Trainer``
    that logs CSV metrics under ``base_dir`` (experiment name
    ``"TI_mnn_sst"``).
    """
    import torch  # local: only used to decide which devices to train on

    # NOTE(review): seeding uses the ``pytorch_lightning`` package while the
    # Trainer/CSVLogger come from ``lightning.pytorch``. Lightning's Trainer
    # only accepts LightningModules from its own package; confirm that
    # ``disentangle`` builds on ``lightning.pytorch``.
    pl.seed_everything(412)

    batch_size = 360 * 18
    datamodule = LitDataModule(SSTDataset, batch_size=batch_size, chunk_size=52 * 4)

    # NOTE(review): the fetched batch is never used. The call is kept in case
    # building the train dataloader is what creates ``datamodule.train_set``
    # (read below for ``n_step``). Confirm this, and drop the call if not.
    next(iter(datamodule.train_dataloader()))

    method = TIMechanisticLitModule(
        learning_rate=1e-5,
        batch_size=batch_size,
        n_views=n_views if is_training else 1,
        order=2,
        state_dim=1,
        param_dim=param_dim,
        # Read back from the dataset so the model's step count matches the
        # chunk length the dataset actually uses.
        n_step=datamodule.train_set.chunk_size,
        n_iv_steps=10,
        mlp_enc=True,
        dct_layer=True,
        freq_frac_to_keep=0.25,
        eval_metrics=[],
        hidden_dim=1024,
        code_sharing=None,
        alignment_reg=0.0,
    )

    trainer = Trainer(
        # max_steps=15000,
        max_epochs=500,
        accelerator="auto",
        devices=select_trainer_devices(torch.cuda.is_available()),
        log_every_n_steps=500,
        check_val_every_n_epoch=20,
        # if you want to run without validation
        # limit_val_batches=0,
        # num_sanity_val_steps=0,
        logger=CSVLogger(base_dir, name="TI_mnn_sst"),
    )

    trainer.fit(method, datamodule=datamodule)


if __name__ == "__main__":
    main()
