import os, sys
import pathlib
from argparse import ArgumentParser
sys.path.insert(0, os.path.dirname(pathlib.Path(__file__).parent.absolute())   )

import pytorch_lightning as pl
from pl_modules.image_data_module import ImageNetDataModule
from  pl_modules.ncsn_module import NCSN_Module

from data.data_transforms import ImageDataTransform

# Imports for logging and other utility
import yaml
import torch.distributed
from utils import load_config_from_yaml

def cli_main(args):
    """Build the model, data module and trainer from ``args`` and run training.

    Steps, in order: seed all RNGs, load the experiment configuration
    (operator and noise sections), construct the ``NCSN_Module``, the data
    transforms and the ``ImageNetDataModule``, create the logger and the
    ``pl.Trainer``, dump the hyperparameters to the log directory on the
    global-zero process, and finally call ``trainer.fit``.

    Args:
        args: Parsed command-line namespace, as returned by ``build_args``.
            Must also carry ``checkpoint_callback`` when
            ``args.save_checkpoints`` is set.

    Raises:
        ValueError: If ``args.logger_type`` is neither ``'tb'`` nor
            ``'wandb'``.
    """
    if args.verbose:
        print(args.__dict__)
        print('pytorch-lightning version: {}'.format(pl.__version__))

    pl.seed_everything(args.seed)

    # Set up schedules
    exp_config = load_config_from_yaml(args.experiment_config_file)
    operator_config = exp_config['operator']
    noise_config = exp_config['noise']
    print('Loaded operator: ', operator_config)
    print('Loaded noise schedule: ', noise_config)

    # Step sizes are the reciprocal of the requested number of steps; the
    # validation step size is optional and passed as None when unset.
    dt = 1.0 / args.num_steps
    val_dt = None if args.num_val_steps is None else 1.0 / args.num_val_steps

    # ------------
    # model
    # ------------
    model = NCSN_Module(
        dt=dt,
        val_dt=val_dt,
        operator_config=operator_config,
        noise_config=noise_config,
        loss_type=args.loss_type,
        lr=args.lr,
        lr_step_size=args.lr_step_size,
        lr_gamma=args.lr_gamma,
        residual_prediction=(not args.no_residual_prediction),
        weight_decay=args.weight_decay,
        logger_type=args.logger_type,
        max_epochs=args.max_epochs,
        full_val_only_last_epoch=args.full_val_only_last_epoch,
        num_log_images=args.num_log_images,
    )

    # ------------
    # data
    # ------------
    # NOTE(review): the validation and test transforms are built with the
    # training ``dt`` rather than ``val_dt`` - confirm this is intended.
    train_transform = ImageDataTransform(is_train=True, operator_config=operator_config, noise_config=noise_config, dt=dt)
    val_transform = ImageDataTransform(is_train=False, operator_config=operator_config, noise_config=noise_config, dt=dt)
    test_transform = ImageDataTransform(is_train=False, operator_config=operator_config, noise_config=noise_config, dt=dt)

    # ptl data module - this handles data loaders
    data_module = ImageNetDataModule(
        data_path=args.data_path,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        sample_rate_dict={'train': args.sample_rates[0], 'val': args.sample_rates[1], 'test': args.sample_rates[2]},
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=True,
    )

    # ------------
    # trainer
    # ------------
    # set up logger
    if args.logger_type == 'tb':
        # ``True`` asks the Trainer to use its default logger.
        logger = True
    elif args.logger_type == 'wandb':
        logger = pl.loggers.WandbLogger(project=args.experiment_name)
    else:
        raise ValueError('Unknown logger type.')

    callbacks = []
    if args.save_checkpoints:
        callbacks.append(args.checkpoint_callback)

    trainer = pl.Trainer.from_argparse_args(
        args,
        enable_checkpointing=args.save_checkpoints,
        callbacks=callbacks,
        logger=logger,
        strategy="ddp",
    )

    # Save all hyperparameters to .yaml file in the current log dir, from the
    # global-zero process only. The Trainer's own rank flag is used because a
    # check based on torch.distributed saved nothing at all whenever the
    # distributed package was available but the process group had not been
    # initialized yet at this point.
    if trainer.is_global_zero:
        save_all_hparams(trainer, args)

    # ------------
    # run
    # ------------
    trainer.fit(model, datamodule=data_module)

def save_all_hparams(trainer, args):
    """Dump the command-line arguments to ``hparams.yaml`` in the log dir.

    The target directory is ``trainer.logger.log_dir``; it is created if it
    does not exist yet. The ``checkpoint_callback`` entry is left out of the
    dump.

    Args:
        trainer: Trainer whose ``logger.log_dir`` receives the file.
        args: Argument namespace to serialize. It is not modified.
    """
    log_dir = trainer.logger.log_dir
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(log_dir, exist_ok=True)

    # Work on a shallow copy so the callback is not removed from the caller's
    # namespace, and tolerate the key being absent (e.g. on a repeated call).
    save_dict = dict(vars(args))
    save_dict.pop('checkpoint_callback', None)

    with open(os.path.join(log_dir, 'hparams.yaml'), 'w') as f:
        yaml.dump(save_dict, f)
    
def build_args():
    """Parse the command line and return the argument namespace.

    The parser combines the script's own options with those registered by
    ``ImageNetDataModule.add_data_specific_args``,
    ``NCSN_Module.add_model_specific_args`` and
    ``pl.Trainer.add_argparse_args``.

    Returns:
        The parsed ``argparse.Namespace``, extended with a
        ``checkpoint_callback`` attribute holding a ``ModelCheckpoint`` that
        monitors ``val/weighted_mse_loss`` (minimized), keeps the best
        checkpoint and also saves the last one.
    """
    parser = ArgumentParser()

    # client arguments
    parser.add_argument(
        '--experiment_config_file',
        type=pathlib.Path,
        help='Experiment configuration will be loaded from this file.',
    )
    parser.add_argument(
        '--verbose',
        default=False,
        action='store_true',
        help='If set, print all command line arguments at startup.',
    )
    parser.add_argument(
        '--logger_type',
        default='tb',
        type=str,
        help='Set Pytorch Lightning training logger. Options "tb" - Tensorboard (default), "wandb" - Weights and Biases',
    )
    parser.add_argument(
        '--experiment_name',
        default='test-exp',
        type=str,
        help='Used with wandb logger to define the project name.',
    )
    parser.add_argument(
        '--full_val_only_last_epoch',
        default=False,
        action='store_true',
    )
    parser.add_argument(
        '--save_checkpoints',
        default=False,
        action='store_true',
    )

    # data config
    parser = ImageNetDataModule.add_data_specific_args(parser)
    parser.set_defaults(
        test_path=None,  # path for test split, overwrites data_path
    )

    # module config
    parser = NCSN_Module.add_model_specific_args(parser)

    # trainer config
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        accelerator='gpu',  # default accelerator type for the Trainer
        seed=42,  # random seed
    )

    args = parser.parse_args()

    # The callback is attached to the namespace so the caller can hand it to
    # the Trainer when checkpointing is enabled.
    args.checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor="val/weighted_mse_loss",
        mode="min",
        filename='epoch{epoch}-val-loss{val/weighted_mse_loss:.4f}',
        auto_insert_metric_name=False,
        save_last=True
    )

    return args


def run_cli():
    """Parse the command-line arguments and start training with them."""
    # ---------------------
    # RUN TRAINING
    # ---------------------
    cli_main(build_args())


# Script entry point: only runs when the file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    run_cli()