from pathlib import Path
from omegaconf import DictConfig, OmegaConf
import cv2
from pytorch_lightning.loggers import TensorBoardLogger 
from pytorch_lightning import Trainer, seed_everything
# from pytorch_lightning.plugins import DDPPlugin
from pl_model import DiffusionModel

def multiple_gpu_setup(cfg: DictConfig) -> None:
    """Turn off OpenCV threading and OpenCL when more than one device is used.

    The device specification is read from ``cfg.devices``; when that key is
    missing, unreadable or ``None``, ``cfg.get('gpus')`` is used instead.
    Nothing is changed when no specification is found, when it cannot be
    turned into a device count (e.g. ``"auto"``), or when the count is not
    greater than one.

    Args:
        cfg: Trainer section of the config. Must support attribute access
            for ``devices`` and a dict-style ``get`` for ``gpus``.
    """
    try:
        n_devices = cfg.devices
    except Exception:
        # Best-effort lookup: any config error falls through to the ``gpus``
        # key, but KeyboardInterrupt/SystemExit are no longer swallowed.
        n_devices = None
    if n_devices is None:
        n_devices = cfg.get('gpus', None)
    if n_devices is None:
        return

    if isinstance(n_devices, str):
        if ',' in n_devices:
            # Comma-separated device indices such as "0,1": count the entries.
            n_devices = sum(1 for part in n_devices.split(',') if part.strip())
        else:
            try:
                n_devices = int(n_devices)
            except ValueError:
                # Symbolic values such as "auto" give no device count here.
                return
    elif isinstance(n_devices, (list, tuple)) or OmegaConf.is_list(n_devices):
        # An explicit list of device indices: its length is the device count.
        n_devices = len(n_devices)

    if n_devices > 1:
        msg = f"Number of devices = {n_devices} > 1. "
        msg += "Setting n_threads = 0 and disabling opencl "
        msg += "to prevent deadlocks."
        print(msg)

        cv2.setNumThreads(0)
        cv2.ocl.setUseOpenCL(False)


def get_exp_name(cfg):
    """Build an experiment name from the values in ``cfg.pipeline``.

    Returns the fixed name ``'original_image'`` when ``loss_thr`` exceeds
    100. Otherwise the name is a run of ``<key>_<value>_`` fragments, one per
    listed pipeline key in order, followed by the ``add_tag`` value.
    """
    pipeline = cfg.pipeline
    if pipeline['loss_thr'] > 100:
        return 'original_image'

    # Keys that take part in the name, in output order.
    # ('early_stop_iter_num' is currently left out.)
    name_keys = (
        'eps',
        'lr',
        'wm_loss_w',
        'lpips_w',
        'mse_w',
        'grad_thr',
        'thr_num_iter',
        'max_num_iter',
        'post_swap',
    )
    fragments = ''.join(f'{key}_{pipeline[key]}_' for key in name_keys)
    return fragments + pipeline['add_tag']

def validate(cfg: DictConfig):
    """Run a validation-only pass of ``DiffusionModel`` with a Lightning Trainer.

    Args:
        cfg: Full experiment config. ``cfg.trainer`` is forwarded to the
            ``Trainer`` constructor as keyword arguments, ``cfg.seed`` (when
            not ``None``) seeds the run, and the whole config is handed to
            ``DiffusionModel``.
    """
    multiple_gpu_setup(cfg.trainer)

    if cfg.seed is not None:
        seed_everything(cfg.seed, workers=True)

    net = DiffusionModel(cfg)

    # No logger object is handed to the Trainer. A TensorBoardLogger used to
    # be built here and was immediately overwritten with None; to bring it
    # back, pass
    #   TensorBoardLogger(save_dir=Path(cfg.experiment_path, "logs"),
    #                     name='swap-watermarking-1000',
    #                     version='different_keys')
    # instead.
    logger = None

    # Convert the trainer section through the public OmegaConf API (plain
    # dict/list values, interpolations resolved) rather than reaching into
    # the config object's private ``_content`` storage.
    trainer_kwargs = OmegaConf.to_container(cfg.trainer, resolve=True)
    trainer = Trainer(
        callbacks=None,
        logger=logger,
        **trainer_kwargs,
        num_sanity_val_steps=0,  # skip the sanity check
        # NOTE(review): presumably the model needs autograd during
        # validation, which torch.inference_mode would forbid - confirm.
        inference_mode=False,
    )

    trainer.validate(net)
    