import numpy as np
from wandb.wandb_torch import torch
from ip.diffusion import *
from ip.configs.base_config import config
from torch_geometric.data import DataLoader
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
import os
import pickle


if __name__ == '__main__':
    # Training entry point: builds the validation and training data loaders and
    # the GraphDiffusion model from `config`, then runs a Lightning training loop.
    # With `record = True` the run is logged to Weights & Biases and the config
    # and final weights are written under ./runs/<run_name>; with
    # `record = False` nothing is written to disk by this script.
    record = False
    run_name = 'test'
    save_dir = f'./runs/{run_name}' if record else None

    if record:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_dir, exist_ok=True)

    # Expose the recording settings through the shared config that is handed
    # to the model below.
    config['save_dir'] = save_dir
    config['record'] = record

    data_root_val = 'PATH/TO/VAL/DATA'
    data_root = 'PATH/TO/TRAIN/DATA'

    # Validation samples are loaded by consecutive index (data_0.pt, data_1.pt, ...).
    # Count only files following that naming scheme so that unrelated entries in
    # the directory (hidden files, logs, ...) do not inflate the count and make
    # us request a sample file that does not exist.
    num_val_samples = sum(
        1 for fname in os.listdir(data_root_val)
        if fname.startswith('data_') and fname.endswith('.pt')
    )
    data_samples = [torch.load(f'{data_root_val}/data_{idx}.pt') for idx in range(num_val_samples)]
    dataloader_val = DataLoader(data_samples, batch_size=config['batch_size_val'], shuffle=False)

    dset = RunningDataset(data_root, 100000, rand_g_prob=config['randomize_g_prob'])  # 100000 is the buffer size.
    dataloader = DataLoader(dset, batch_size=config['batch_size'], drop_last=True, shuffle=True,
                            num_workers=8, pin_memory=True)
    model = GraphDiffusion(config).to(config['device'])

    if record:
        logger = WandbLogger(project='InstantPolicy',
                             name=run_name,
                             save_dir=save_dir,
                             log_model=False)
        # Dump config to save_dir. The context manager guarantees the file is
        # flushed and closed even if pickling fails part-way through.
        with open(f'{save_dir}/config.pkl', 'wb') as config_file:
            pickle.dump(config, config_file)
    else:
        logger = None

    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = L.Trainer(
        enable_checkpointing=False,  # We save the models manually.
        accelerator=config['device'],
        devices=1,
        max_steps=config['num_iters'],
        enable_progress_bar=True,
        precision='16-mixed',
        # An integer val_check_interval together with check_val_every_n_epoch=None
        # schedules validation purely by training-batch count (every 20000
        # batches) instead of by epoch.
        val_check_interval=20000,
        num_sanity_val_steps=10,
        check_val_every_n_epoch=None,
        logger=logger,
        log_every_n_steps=500,
        gradient_clip_val=1,
        gradient_clip_algorithm='norm',
        callbacks=[lr_monitor],
    )

    trainer.fit(
        model=model,
        train_dataloaders=dataloader,
        val_dataloaders=dataloader_val,
    )

    # Save last:
    if record:
        model.save_model(f'{save_dir}/last.pt')
