import os
import traceback

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger

from puupl.experiments import Experiment
from puupl.lib.utils import ConfigurationException, flatten_dict, get_config


# RuntimeError message fragments that are treated as the fault of a bad random
# hyperband configuration rather than a genuine bug.
_HYPERBAND_TOLERATED_ERRORS = (
    'CUDA out of memory',
    'cuDNN error: CUDNN_STATUS_NOT_SUPPORTED',
    'CUDA error: an illegal memory access was encountered',
    'loss is not finite',
)


def _write_hyperband_nan_result(hb_outdir: str) -> None:
    """Write the literal ``nan`` to the ``result`` file inside *hb_outdir*."""
    with open(os.path.join(hb_outdir, 'result'), 'w') as f:
        f.write('nan')


def main() -> None:
    """Build an experiment from the loaded configuration, then train and test it.

    The configuration returned by ``get_config()`` is read as a nested mapping:

    * ``logging_params``: ``manual_seed``, ``run_name``, ``experiment_name``
    * ``exp_params``: optional ``hyperband_output_dir``, ``pseudolabel_every``
      and ``dataloader_workers``
    * ``trainer_params``: optional ``gpus`` and ``pl_iterations``; everything
      left after ``pl_iterations`` is removed is forwarded to ``Trainer``

    When ``exp_params.hyperband_output_dir`` is set, an invalid configuration or
    one of the tolerated runtime errors is reported by writing ``nan`` to
    ``<hyperband_output_dir>/result`` instead of propagating.

    Raises:
        ValueError: if ``pl_iterations`` is given without ``pseudolabel_every``.
        RuntimeError: any runtime error from training/testing that is not
            tolerated, or any runtime error at all outside of hyperband.
    """
    torch.set_default_dtype(torch.float64)
    torch.multiprocessing.set_start_method("spawn")

    config = get_config()

    # for reproducibility
    seed = config['logging_params']['manual_seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.deterministic = True
    # benchmark mode auto-tunes the cuDNN algorithm per run, so the selection can
    # differ between runs; it has to be off for reproducible results
    cudnn.benchmark = False

    try:
        experiment = Experiment(
            config,
            run_name=config['logging_params']['run_name'],
            experiment_name=config['logging_params']['experiment_name']
        )

    # do not stop hyperband in case we catch errors from bad random values
    except ConfigurationException:
        print('invalid configuration detected')
        traceback.print_exc()

        hb_outdir = config['exp_params'].get('hyperband_output_dir')
        if hb_outdir is not None:
            _write_hyperband_nan_result(hb_outdir)
        return

    # fall back to CPU when GPUs were requested but CUDA is not available
    if config['trainer_params'].get('gpus', 0) != 0 and not torch.cuda.is_available():
        config['trainer_params']['gpus'] = 0

    # compute max epochs from number of requested PL iterations
    # (pl_iterations is popped because trainer_params is later splatted into Trainer)
    pl_iterations = config['trainer_params'].pop('pl_iterations', None)
    if pl_iterations is not None:
        pl_every = config['exp_params'].get('pseudolabel_every')
        # explicit check instead of `assert`, which is stripped under `python -O`
        if pl_every is None:
            raise ValueError(
                "'pseudolabel_every' must be set in exp_params when "
                "'pl_iterations' is given in trainer_params"
            )
        max_epochs = (pl_iterations + 1) * pl_every - 1
        config['trainer_params']['max_epochs'] = max_epochs
        print(f'Set max_epochs to {max_epochs}')

    # fix problems with parallel dataloaders
    # https://github.com/pytorch/pytorch/issues/37377
    if config['exp_params'].get('dataloader_workers', 0) != 0:
        os.environ['MKL_THREADING_LAYER'] = 'GNU'

    mlflow_logger = MLFlowLogger(
        experiment_name=config['logging_params']['experiment_name']
    )
    flat_cfg = flatten_dict(config)

    # log actual configuration
    mlflow_logger.log_hyperparams(params=flat_cfg)
    # right-align keys; default guards against an empty flattened configuration
    width = max(map(len, flat_cfg), default=0) + 2
    print('===  Actual Configuration')
    for k, v in flat_cfg.items():
        print(f'{k:>{width}s} : {v}')
    print('===')

    try:
        trainer = Trainer(
            min_epochs=1,
            checkpoint_callback=True,
            logger=mlflow_logger,
            check_val_every_n_epoch=1,
            num_sanity_val_steps=5,
            fast_dev_run=False,
            **config['trainer_params']
        )

        trainer.fit(experiment)
        trainer.test()

    # handle certain errors stemming from bad random configurations during hyperband
    except RuntimeError as exc:
        hb_outdir = config['exp_params'].get('hyperband_output_dir')
        # coerce to str so a non-string first argument cannot raise TypeError
        # here and mask the original error
        message = str(exc.args[0]) if exc.args else ''
        tolerated = any(fragment in message for fragment in _HYPERBAND_TOLERATED_ERRORS)
        if hb_outdir is None or not tolerated:
            raise

        print('swallowing allowed exception during hyperband')
        _write_hyperband_nan_result(hb_outdir)
        traceback.print_exc()


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if "__main__" == __name__:
    main()
