import os
from config import cfg
import torch
import numpy as np
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import random
from pytorch_lightning.loggers import TensorBoardLogger

from models import CL_Encoder, SP_Encoder

# Enable autograd anomaly detection globally so a backward pass that produces
# NaN raises an error with a traceback of the offending forward op. PyTorch
# documents this mode as a debugging aid that slows execution; consider
# turning it off for long production training runs.
torch.autograd.set_detect_anomaly(mode=True)

def create_dataloader(dataset, config, shuffle=True):
    """Wrap ``dataset`` in a ``DataLoader`` configured from ``config``.

    Args:
        dataset: Any map-style dataset accepted by ``torch.utils.data.DataLoader``.
        config: Nested config dict; reads ``config['optim_params']['batch_size']``
            and ``config['experiment_params']['data_loader_workers']``.
        shuffle: Whether to shuffle samples each epoch. The same flag also
            controls ``drop_last``, so shuffled loaders discard a trailing
            partial batch and unshuffled loaders keep it.

    Returns:
        The constructed ``DataLoader`` (memory pinning always requested).
    """
    batch_size = config['optim_params']['batch_size']
    workers = config['experiment_params']['data_loader_workers']
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=True,
        # Tied to shuffle: only the shuffled loader drops the incomplete last batch.
        drop_last=shuffle,
        num_workers=workers,
    )


def seed_everything(seed):
    """Seed Python's ``random``, PyTorch (CPU and all CUDA devices) and NumPy.

    Also exports ``PYTHONHASHSEED``. Setting that variable at runtime does not
    change hash randomisation of the interpreter that is already running; it
    only affects child processes started afterwards that inherit the environment.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

def get_config(cfg):
    """Instantiate the model class selected by ``cfg['loss_params']['name']``.

    Despite its name, this returns a model instance, not a config.

    Args:
        cfg: Nested experiment configuration. It is also passed unchanged to
            the selected model constructor.

    Returns:
        ``CL_Encoder(cfg)`` when the loss name is ``'triplet'``, or
        ``SP_Encoder(cfg)`` when it is ``'nll'``.

    Raises:
        KeyError: if ``cfg`` lacks ``loss_params``/``name``, or if the loss
            name is not one of the supported keys (the message lists them).
    """
    model_dict = {
        'triplet': CL_Encoder,
        'nll': SP_Encoder,
    }
    loss_name = cfg['loss_params']['name']
    try:
        model_cls = model_dict[loss_name]
    except KeyError:
        # Same exception type as before, but name the bad value and the options
        # instead of surfacing a bare KeyError('<name>').
        raise KeyError(
            f"Unsupported loss_params.name {loss_name!r}; "
            f"expected one of {sorted(model_dict)}"
        ) from None
    return model_cls(cfg)

def run():
    """Train, checkpoint and evaluate the model described by the module-level ``cfg``.

    Steps: seed all RNGs, build the model with ``get_config``, fit it with a
    Lightning ``Trainer`` on GPU 0, write a final ``<exp_name>-<latent_dim>.ckpt``
    plus ``model.save`` output, then run ``trainer.test``. TensorBoard logs,
    the best checkpoint and the final artifacts all live under
    ``cfg['experiment_params']['exp_dir']``.
    """
    exp_params = cfg['experiment_params']
    exp_dir = exp_params['exp_dir']
    exp_name = exp_params['exp_name']

    seed_everything(exp_params['seed'])

    # Root TensorBoard logs at exp_dir, alongside the checkpoints and the
    # Trainer's default_root_dir, rather than in a working-directory-relative
    # folder named after the experiment. Logs land in <exp_dir>/<exp_name>/version_N.
    logger = TensorBoardLogger(exp_dir, name=exp_name)
    # Keep only the single best checkpoint by 'eval_loss'. Assumes the model
    # logs a metric with exactly that name; TODO confirm in models.
    ckpt_callback = pl.callbacks.ModelCheckpoint(
        monitor='eval_loss',
        dirpath=exp_dir,
        save_top_k=1,
        every_n_epochs=exp_params['checkpoint_epochs'],
    )

    model = get_config(cfg)
    trainer = pl.Trainer(
        default_root_dir=exp_dir,
        accelerator="gpu",
        # NOTE(review): DDP strategy on a single device. Presumably
        # find_unused_parameters is needed because some parameters receive no
        # gradient; confirm before switching to a plain single-device strategy.
        strategy='ddp_find_unused_parameters_true',
        devices=[0],
        max_epochs=int(exp_params['num_epochs']),
        logger=logger,
        callbacks=[ckpt_callback],
    )

    # No dataloaders are passed, so the model is presumably expected to define
    # its own train/val/test dataloader hooks.
    trainer.fit(model)

    # Save the final (post-training) state.
    final_ckpt_path = os.path.join(
        exp_dir, f"{exp_name}-{cfg['model_params']['latent_dim']}.ckpt"
    )
    trainer.save_checkpoint(final_ckpt_path)
    model.save(directory=exp_dir)

    # Evaluation. Per Lightning's docs, passing the model without ckpt_path
    # tests its current (last-epoch) weights, not the best checkpoint.
    trainer.test(model)


if __name__ == "__main__":
    # Script entry point: train, checkpoint and evaluate using the imported cfg.
    run()
