import os 
import hydra
from models.ctrl_sim_finetuning import CtRLSim
from datamodules.waymo_rl_datamodule_finetuning import RLWaymoDataModuleFineTuning

import torch
# Set process-wide float32 matmul precision to 'medium' before any model is
# built; this setting trades some numerical precision for faster matmuls on
# hardware that supports reduced-precision kernels.
torch.set_float32_matmul_precision('medium')
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import WandbLogger

@hydra.main(version_base=None, config_path="/home/ctrl-sim-dev/cfgs/", config_name="config")
def main(cfg) -> None:
    """Fine-tune a previously trained CtRLSim model.

    Seeds all RNGs, builds the fine-tuning datamodule, restores the model from
    ``/home/wandb/<run_name>/model.ckpt`` and runs ``Trainer.fit`` with DDP,
    writing a checkpoint after every epoch into the same run directory.

    Args:
        cfg: Hydra config. This function reads ``cfg.train.seed``,
            ``run_name``, ``track``, ``accelerator``, ``devices``,
            ``check_val_every_n_epoch``, ``precision``,
            ``limit_train_batches`` and ``gradient_clip_val`` from
            ``cfg.train``, plus ``cfg.train_finetuning.max_steps``. The whole
            config is also forwarded to the datamodule and the model.
    """
    pl.seed_everything(cfg.train.seed, workers=True)

    datamodule = RLWaymoDataModuleFineTuning(cfg)
    save_dir = f'/home/wandb/{cfg.train.run_name}'
    # The starting weights must already be present in the run directory;
    # loading fails here if 'model.ckpt' is missing.
    model = CtRLSim.load_from_checkpoint(os.path.join(save_dir, 'model.ckpt'), cfg=cfg, data_module=datamodule)
    # exist_ok=True already makes this a no-op for an existing directory, so
    # no separate (race-prone) existence check is needed.
    os.makedirs(save_dir, exist_ok=True)

    # save all models. Make sure disk space can handle this!
    # NOTE(review): the filename is fixed, so keeping one file per epoch
    # presumably relies on Lightning's automatic version suffixing -- confirm.
    model_checkpoint = ModelCheckpoint(monitor=None, save_top_k=-1, every_n_epochs=1, dirpath=save_dir, filename='model_finetuning')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    model_summary = ModelSummary(max_depth=-1)

    # Only construct the W&B logger when tracking is requested: building it
    # requires the wandb package, and some Lightning versions start the run
    # inside the constructor, which is unwanted when tracking is disabled.
    logger = None
    if cfg.train.track:
        logger = WandbLogger(
            project='decepticons',
            name=cfg.train.run_name,
            entity='swish',
            log_model=False,
            save_dir=save_dir
        )

    trainer = pl.Trainer(accelerator=cfg.train.accelerator, # set to cpu when debugging layer shapes
                         devices=cfg.train.devices,
                         strategy=DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True),
                         callbacks=[model_summary, model_checkpoint, lr_monitor],
                         max_steps=cfg.train_finetuning.max_steps,
                         check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
                         precision=cfg.train.precision,
                         limit_train_batches=cfg.train.limit_train_batches, # train on smaller dataset
                         limit_val_batches=0.01, # we really don't care much about val
                         gradient_clip_val=cfg.train.gradient_clip_val,
                         logger=logger
                        )
    trainer.fit(model, datamodule)

# Script entry point: Hydra's decorator supplies `cfg`, so main() takes no
# explicit arguments here.
if __name__ == "__main__":
    main()


