import os 
import hydra
from models.ctrl_sim_diffusion import CtRLSimDiffusion
from datamodules.waymo_rl_datamodule import RLWaymoDataModule 

import torch
torch.set_float32_matmul_precision('medium')
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import WandbLogger

def _find_resume_checkpoint(save_dir):
    """Return the most recently modified "last" checkpoint in ``save_dir``.

    Matches files ending in ``.ckpt`` whose name contains ``last`` (e.g.
    ``last.ckpt`` or a versioned ``last-v1.ckpt``). Choosing by modification
    time keeps the choice deterministic when several files match, since
    ``os.listdir`` order is arbitrary.

    Args:
        save_dir: Directory to scan; must already exist.

    Returns:
        Full path of the newest matching checkpoint, or ``None`` if none match.
    """
    candidates = [
        os.path.join(save_dir, name)
        for name in os.listdir(save_dir)
        if name.endswith('.ckpt') and 'last' in name
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


@hydra.main(version_base=None, config_path="/home/ctrl-sim-dev/cfgs/", config_name="config")
def main(cfg):
    """Train CtRLSimDiffusion with PyTorch Lightning, resuming when possible.

    ``cfg`` is supplied by the ``@hydra.main`` decorator (config "config" under
    /home/ctrl-sim-dev/cfgs/). Reads ``cfg.train_diffusion``: seed, run_name,
    track, accelerator, devices, max_steps, check_val_every_n_epoch, precision,
    limit_train_batches, limit_val_batches, gradient_clip_val.

    Checkpoints are written to /home/wandb/<run_name>: the 15 best by
    ``state_mse`` (lower is better) plus a "last" checkpoint. If a "last"
    checkpoint already exists there, training resumes from the newest one.
    Weights & Biases logging (and learning-rate monitoring) is enabled only
    when ``cfg.train_diffusion.track`` is true.
    """
    pl.seed_everything(cfg.train_diffusion.seed, workers=True)

    model = CtRLSimDiffusion(cfg)
    datamodule = RLWaymoDataModule(cfg, use_diffusion=True)

    save_dir = f'/home/wandb/{cfg.train_diffusion.run_name}'
    # exist_ok=True already tolerates an existing directory.
    os.makedirs(save_dir, exist_ok=True)

    # NOTE(review): assumes the model logs a metric named 'state_mse'; confirm.
    model_checkpoint = ModelCheckpoint(monitor='state_mse', save_top_k=15, save_last=True, mode='min', dirpath=save_dir)
    model_summary = ModelSummary(max_depth=-1)
    callbacks = [model_summary, model_checkpoint]

    if cfg.train_diffusion.track:
        # Only construct the W&B logger when it is actually going to be used.
        logger = WandbLogger(
            project='decepticons_diffusion',
            name=cfg.train_diffusion.run_name,
            entity='swish',
            log_model=False,
            save_dir=save_dir
        )
        # LearningRateMonitor raises when the Trainer has no logger, so it is
        # attached only alongside an explicit logger.
        callbacks.append(LearningRateMonitor(logging_interval='step'))
    else:
        # Intended as "no experiment tracking" for this run.
        logger = None

    ckpt_path = _find_resume_checkpoint(save_dir)
    if ckpt_path is not None:
        print("Resuming from checkpoint: ", ckpt_path)

    trainer = pl.Trainer(accelerator=cfg.train_diffusion.accelerator, # set to cpu when debugging layer shapes
                         devices=cfg.train_diffusion.devices,
                         strategy=DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True),
                         callbacks=callbacks,
                         max_steps=cfg.train_diffusion.max_steps,
                         check_val_every_n_epoch=cfg.train_diffusion.check_val_every_n_epoch,
                         precision=cfg.train_diffusion.precision,
                         limit_train_batches=cfg.train_diffusion.limit_train_batches, # train on smaller dataset
                         limit_val_batches=cfg.train_diffusion.limit_val_batches,
                         gradient_clip_val=cfg.train_diffusion.gradient_clip_val,
                         logger=logger
                        )
    # ckpt_path=None starts a fresh run; otherwise Lightning restores state from it.
    trainer.fit(model, datamodule, ckpt_path=ckpt_path)


if __name__ == '__main__':
    # cfg is injected by the @hydra.main decorator, so main() is called without arguments.
    main()


