from pathlib import Path
# Absolute path of the directory containing this file; used when connecting to an
# existing Ray cluster to locate the ray_config/redis_address and
# ray_config/redis_password files.
project_root = Path(__file__).parent.absolute()

import os
import random
import math
from collections.abc import Sequence
from functools import partial

import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

from munch import Munch
from omegaconf.listconfig import ListConfig

import ray
from ray import tune
from ray.tune import Trainable, Experiment, sample_from, grid_search
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.integration.wandb import WandbLogger

from utils import dictconfig_to_munch, munch_to_dictconfig


# Markers recognised by munchconfig_to_tune_munchconfig: a config value that is a
# non-empty list whose first element is one of these is converted into a Ray Tune
# search-space object (grid search or sampled value).
TUNE_KEYS = ['_grid', '_sample', '_sample_uniform', '_sample_log_uniform']


def munchconfig_to_tune_munchconfig(cfg):
    """Convert config to one compatible with Ray Tune.

    Any value that is a non-empty list whose first element is one of ``TUNE_KEYS`` is
    converted into a Ray Tune search-space object. All other values are kept as-is.
    Nested ``Munch`` sections are converted recursively.

    - "_grid" is converted to ray.tune.grid_search over the remaining elements.
    - "_sample" is converted to ray.tune.sample_from, picking uniformly at random among
      the remaining elements.
    - "_sample_uniform" takes exactly two bounds [min, max]:
        * if both are ints, an int is drawn from the CLOSED range [min, max]
          (``random.randint`` includes max);
        * otherwise a float is drawn with ``random.uniform(min, max)``.
    - "_sample_log_uniform" takes exactly two positive bounds [min, max] and samples
        exp(uniform(log(min), log(max)))

    Examples:
        lr=1e-3 for a specific learning rate
        lr=[_grid, 1e-3, 1e-4, 1e-5] means grid search over those values
        lr=[_sample, 1e-3, 1e-4, 1e-5] means randomly sample from those values
        lr=[_sample_uniform, 1e-4, 3e-4]  means randomly sample from those min/max
        lr=[_sample_log_uniform, 1e-4, 1e-3]  means randomly sample from those min/max but
            distribution is log uniform: exp(uniform(log 1e-4, log 1e-3))

    Args:
        cfg: a (possibly nested) ``Munch`` config.

    Returns:
        A new ``Munch`` with the same keys and converted values.

    Raises:
        ValueError: if a "_grid"/"_sample" entry has no values, if a
            "_sample_uniform"/"_sample_log_uniform" entry does not have exactly two
            bounds, or if a "_sample_log_uniform" bound is not positive. These are
            raised here rather than later inside a remote trial.
    """

    def bounds(v):
        # "_sample_uniform"/"_sample_log_uniform" must look like [key, min, max].
        if len(v) != 3:
            raise ValueError(f'{v[0]} expects exactly two bounds [min, max], got {list(v[1:])}')
        return v[1], v[2]

    def convert_value(v):
        # The type is omegaconf.listconfig.ListConfig and not list, so we test if it's a Sequence
        # In hydra 0.11, more precisely omegaconf 1.4.1, ListConfig isn't an instance of Sequence.
        # So we have to test it directly.
        if not (isinstance(v, (Sequence, ListConfig)) and len(v) > 0 and v[0] in TUNE_KEYS):
            return v
        key = v[0]
        if key in ('_grid', '_sample') and len(v) < 2:
            raise ValueError(f'{key} requires at least one value to choose from')
        if key == '_grid':
            # grid_search requires list for some reason
            return grid_search(list(v[1:]))
        if key == '_sample':
            # Python's lambda doesn't capture the object, it only captures the variable name.
            # That is harmless here: v is a parameter of this call and is never rebound, so
            # each lambda sees its own v. (Binding via a default argument was avoided
            # because ray 1.0 doesn't like that.)
            return sample_from(lambda _: random.choice(v[1:]))
        if key == '_sample_uniform':
            min_, max_ = bounds(v)
            if isinstance(min_, int) and isinstance(max_, int):
                # randint is inclusive of max_: samples from [min_, max_].
                return sample_from(lambda _: random.randint(min_, max_))
            return sample_from(lambda _: random.uniform(min_, max_))
        if key == '_sample_log_uniform':
            min_, max_ = bounds(v)
            if min_ <= 0 or max_ <= 0:
                raise ValueError(f'_sample_log_uniform bounds must be positive, got {min_}, {max_}')
            # Compute the logs once instead of on every sample.
            log_min, log_max = math.log(min_), math.log(max_)
            return sample_from(lambda _: math.exp(random.uniform(log_min, log_max)))
        # Unreachable while every TUNE_KEYS entry is handled above. Raise rather than assert
        # so a newly added key without a handler still fails loudly under `python -O`.
        raise ValueError(f'Unhandled tune key {key!r}')

    def convert(cfg):
        # Recurse into nested Munch sections; every other value goes through convert_value.
        return Munch({k: convert(v) if isinstance(v, Munch) else convert_value(v)
                      for k, v in cfg.items()})

    return convert(cfg)


class TuneReportCallback(Callback):
    """Report per-epoch metrics from the LightningModule to Ray Tune.

    Reads the module's ``_train_results``, optional ``_val_results`` and ``_test_results``
    mappings. NOTE(review): these are assumed to be populated by the module's own
    epoch-end hooks; that code is not visible here.
    """

    # We group train and val reporting into one, otherwise tune thinks there're 2 different epochs.
    def on_train_epoch_end(self, trainer, pl_module):
        # Copy so the module's own _train_results is not mutated with val/derived keys.
        results = dict(pl_module._train_results)
        results.update(getattr(pl_module, '_val_results', None) or {})
        # Prefer the validation loss and fall back to the training loss. An explicit check
        # is used because dict.get(k, results['train_loss']) evaluates the default eagerly
        # and would raise KeyError when only 'val_loss' is present.
        if 'val_loss' in results:
            results['mean_loss'] = results['val_loss']
        else:
            results['mean_loss'] = results['train_loss']
        if 'val_accuracy' in results:
            results['mean_accuracy'] = results['val_accuracy']
        tune.report(**results)

    def on_test_epoch_end(self, trainer, pl_module):
        # Test metrics are reported as-is.
        tune.report(**pl_module._test_results)


class CheckpointCallback(Callback):
    """Write a Lightning checkpoint into a Ray Tune checkpoint directory after every training epoch.

    The file is named ``<LightningModule class name>.ckpt``, which is the name
    ``pl_train_with_tune`` resumes from.
    """

    def on_train_epoch_end(self, trainer, pl_module):
        ckpt_filename = f"{type(pl_module).__name__}.ckpt"
        # Tune allocates (and tracks) a directory tagged with the current global step.
        with tune.checkpoint_dir(step=trainer.global_step) as target_dir:
            target_path = os.path.join(target_dir, ckpt_filename)
            trainer.save_checkpoint(target_path)


def _save_final_checkpoint(trainer, cfg):
    """Save the trained model to ``cfg.train.save_checkpoint_path`` if that key is set.

    If the dataset config has a ``crossfit_index``, it is inserted before the final
    '.ckpt' (e.g. 'out/model.ckpt' -> 'out/model3.ckpt'), presumably so that different
    cross-fitting folds write distinct files. Parent directories are created as needed.
    """
    if 'save_checkpoint_path' not in cfg.train:
        return
    path = cfg.train.save_checkpoint_path
    if 'dataset' in cfg and 'crossfit_index' in cfg.dataset:
        # Rewrite only the LAST '.ckpt'; str.replace would also alter a '.ckpt'
        # appearing earlier in the path, such as in a directory name.
        head, sep, tail = path.rpartition('.ckpt')
        if sep:
            path = f'{head}{cfg.dataset.crossfit_index}{sep}{tail}'
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    trainer.save_checkpoint(str(path))


def pl_train_with_tune(cfg, pl_module_cls, checkpoint_dir=None):
    """Train a LightningModule for one Ray Tune trial.

    Args:
        cfg: trial config dict from Ray Tune, with keys 'model', 'dataset', 'train',
            'seed', 'wandb' and 'gpu' as assembled in ``ray_train``.
        pl_module_cls: LightningModule class, constructed as
            ``pl_module_cls(cfg.model, cfg.dataset, cfg.train)``.
        checkpoint_dir: directory Ray Tune passes when restoring a trial. If given,
            training resumes from ``<checkpoint_dir>/<pl_module_cls.__name__>.ckpt``.
    """
    cfg = munch_to_dictconfig(Munch(cfg))
    if cfg.seed is not None:
        pl.seed_everything(cfg.seed)
    model = pl_module_cls(cfg.model, cfg.dataset, cfg.train)
    # The file name must match what CheckpointCallback writes.
    checkpoint_path = (os.path.join(checkpoint_dir, f"{pl_module_cls.__name__}.ckpt")
                       if checkpoint_dir else None)
    trainer = pl.Trainer(
        gpus=1 if cfg.gpu else None,
        gradient_clip_val=cfg.train.gradient_clip_val,
        max_epochs=cfg.train.epochs,
        early_stop_callback=False,
        progress_bar_refresh_rate=0,
        limit_train_batches=cfg.train.limit_train_batches,
        checkpoint_callback=False,  # Disable pl's checkpointing to save disk space
        resume_from_checkpoint=checkpoint_path,
        callbacks=[TuneReportCallback(), CheckpointCallback()],
    )
    trainer.fit(model)
    _save_final_checkpoint(trainer, cfg)
    # trainer.test(model) is currently disabled.


def _init_ray(cfg):
    """Initialize Ray for ``ray_train``.

    Smoke tests and ``cfg.runner.local`` runs start a local Ray instance. Otherwise the
    following are tried in order:
        1. an existing cluster (``address='auto'``);
        2. the address and password stored in ``project_root/ray_config/redis_address``
           and ``redis_password``;
        3. a single-node local instance, with a warning.
    Only ``Exception`` subclasses trigger a fallback; KeyboardInterrupt/SystemExit propagate.
    """
    if cfg.smoke_test or cfg.runner.local:
        ray.init(num_gpus=torch.cuda.device_count())
        return
    try:
        ray.init(address='auto')
    except Exception:
        try:
            ray_config_dir = project_root / 'ray_config'
            address = (ray_config_dir / 'redis_address').read_text().strip()
            password = (ray_config_dir / 'redis_password').read_text().strip()
            ray.init(address=address, redis_password=password)
        except Exception:
            import warnings
            ray.init(num_gpus=torch.cuda.device_count())
            warnings.warn("Running Ray with just one node")


def ray_train(cfg, pl_module_cls):
    """Run a Ray Tune search that trains ``pl_module_cls`` with ``pl_train_with_tune`` per trial.

    Args:
        cfg: Hydra/omegaconf config. List values tagged with a ``TUNE_KEYS`` marker become
            search spaces (see ``munchconfig_to_tune_munchconfig``).
        pl_module_cls: LightningModule class to train in each trial.

    Returns:
        The object returned by ``ray.tune.run``.
    """
    # Munch is much easier to deal with than DictConfig
    cfg = munchconfig_to_tune_munchconfig(dictconfig_to_munch(cfg))
    ray_config = {
        'model': cfg.model,
        'dataset': cfg.dataset,
        'train': cfg.train,
        'seed': cfg.seed,
        'wandb': cfg.wandb,
        'gpu': cfg.runner.gpu_per_trial != 0.0,
    }
    dataset_str = cfg.dataset._target_.split('.')[-1]
    model_str = cfg.model._target_.split('.')[-1]
    experiment = Experiment(
        name=f'{dataset_str}_{model_str}',
        run=partial(pl_train_with_tune, pl_module_cls=pl_module_cls),
        local_dir=cfg.runner.result_dir,
        num_samples=cfg.runner.ntrials if not cfg.smoke_test else 1,
        resources_per_trial={'cpu': 1 + cfg.dataset.num_workers, 'gpu': cfg.runner.gpu_per_trial},
        # epochs + 1 because calling trainer.test(model) counts as one epoch
        # pl seems to run an extra epoch when resumed from checkpoint so + 2 here.
        stop={"training_iteration": 1 if cfg.smoke_test else cfg.train.epochs + 2},
        config=ray_config,
        loggers=[WandbLogger],
        keep_checkpoints_num=1,  # Save disk space, just need 1 for recovery
        max_failures=-1,
        # We're writing to dfs or efs already, so there's no need to sync explicitly.
        # This needs to be a no-op function, not just False: with False, ray won't
        # restore failed spot instances.
        sync_to_driver=lambda source, target: None,
    )

    _init_ray(cfg)

    if cfg.runner.hyperband:
        scheduler = AsyncHyperBandScheduler(metric='mean_accuracy', mode='max',
                                            max_t=cfg.train.epochs,
                                            grace_period=cfg.runner.grace_period)
    else:
        scheduler = None
    return ray.tune.run(experiment, scheduler=scheduler,
                        raise_on_failed_trial=False, queue_trials=True)
