import contextlib
import os

os.environ['JAX_DEFAULT_DTYPE_BITS'] = '32'
# os.environ['JAX_COMPILATION_CACHE_DIR'] = '/tmp/jax_cache'
# FIXME: Don't understand why we get empty CUDA graphs, but this fixes things.
# os.environ['XLA_FLAGS'] = (
#     '--xla_gpu_enable_command_buffer="" '
# '--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_latency_hiding_scheduler=True '
# '--xla_gpu_memory_limit_slop_factor=50'
# )
import logging
from copy import deepcopy
from pathlib import Path

import jax
import numpy as np
import rich.syntax
import seml
import yaml
from seml.utils.yaml import YamlDumper

import wandb
from neural_pfaffian import config
from neural_pfaffian.clipping import CLIPPINGS, MASKINGS
from neural_pfaffian.dataset import create_systems
from neural_pfaffian.evaluate import evaluate, parse_epochs_and_walkers_from_config
from neural_pfaffian.logger import Logger
from neural_pfaffian.mcmc import MetropolisHastings
from neural_pfaffian.nn import (
    ANTISYMMETRIZERS,
    EMBEDDINGS,
    ENVELOPES,
    JASTROWS,
    META_NETWORKS,
    GeneralizedWaveFunction,
    WaveFunction,
)
from neural_pfaffian.nn.wave_function import MixtureLogAmplitude
from neural_pfaffian.overlap import OverlapPenalty
from neural_pfaffian.overlap_scaler import OVERLAP_SCALER
from neural_pfaffian.preconditioner import PRECONDITIONER
from neural_pfaffian.sample_reweighting import ReweightingMode
from neural_pfaffian.spin_operator import SpinPenalty
from neural_pfaffian.train import pretrain, thermalize, train
from neural_pfaffian.utils.optim import make_optimizer
from neural_pfaffian.vmc import VMC

jax.config.update('jax_enable_x64', True)
jax.config.update('jax_default_matmul_precision', 'float32')
# Memory / autodiff flags that are not available in every JAX version. The
# AttributeError suppression presumes jax.config.update raises it for options the
# installed JAX does not know. Each flag is applied on its own so that one
# unrecognized option does not silently skip the remaining ones.
for _flag_name, _flag_value in (
    ('jax_memory_fitting_effort', 1),
    ('jax_memory_fitting_level', 'O3'),
    ('jax_vjp3', True),
):
    with contextlib.suppress(AttributeError):
        jax.config.update(_flag_name, _flag_value)
del _flag_name, _flag_value
# Persistent compilation cache: no minimum entry size (-1), but only cache
# programs whose compilation took at least 1 second.
jax.config.update('jax_persistent_cache_min_entry_size_bytes', -1)
jax.config.update('jax_persistent_cache_min_compile_time_secs', 1)
# jax.config.update('jax_explain_cache_misses', True)
ex = seml.Experiment()

# Default configuration shipped next to the `neural_pfaffian.config` package.
root_dir = Path(config.__file__).parent
ex.add_config(str(root_dir.joinpath('default.yaml')))


def main(
    seed,
    vmc_config,
    wave_function_config,
    pretraining_config,
    systems_config,
    logging_config,
    evaluation_config,
):
    """Run one experiment end to end.

    Stages: create the systems, build the wave function and the VMC object,
    pretrain (unless a checkpoint exists or pretraining is disabled), thermalize,
    run VMC training and finally evaluate if the evaluation config yields at
    least one evaluation epoch.

    Args:
        seed: Seed for the JAX PRNG key and for NumPy's global RNG.
        vmc_config: ``vmc`` config section (preconditioner, optimizer, MCMC,
            clipping, masking, optional ``state_overlap`` / ``spin_penalty``,
            epochs, batch size, ...).
        wave_function_config: ``wave_function`` config section (``embedding``,
            ``orbitals``, ``envelope``, ``jastrows``, ``meta_network``).
        pretraining_config: ``pretraining`` config section; a missing or zero
            ``epochs`` disables pretraining.
        systems_config: Keyword arguments forwarded to ``create_systems``.
        logging_config: Forwarded to ``Logger``.
        evaluation_config: ``evaluation`` config section.
    """
    # Deep-copied full config captured from the experiment; used for printing,
    # validation and logging.
    mutable_config = get_config()  # type: ignore

    key = jax.random.key(seed)
    np.random.seed(seed)

    logging.info('Running with config:')
    cfg_str = yaml.dump(
        mutable_config,
        indent=2,
        default_flow_style=None,
        Dumper=YamlDumper,
    )
    rich.print(rich.syntax.Syntax(cfg_str.strip(), 'yaml', background_color='default'))

    _validate_config(mutable_config)

    logging.info('Creating systems')
    key, subkey = jax.random.split(key)
    systems = create_systems(subkey, **systems_config)

    # Initialize the wave function
    logging.info('Initializing wave function')
    wave_function = GeneralizedWaveFunction.create(
        WaveFunction(
            EMBEDDINGS.init(**wave_function_config['embedding']),
            ANTISYMMETRIZERS.init(
                **wave_function_config['orbitals'],
                envelope=ENVELOPES.init(**wave_function_config['envelope']),
            ),
            JASTROWS.init_many(wave_function_config['jastrows']),
        ),
        META_NETWORKS.init_or_none(**wave_function_config['meta_network']),
        systems,
    )

    # Initialize VMC object
    logging.info('Initializing VMC')
    preconditioner = PRECONDITIONER.init(
        **vmc_config['preconditioner'],
        wave_function=wave_function,
    )
    optimizer = make_optimizer(vmc_config['optimizer'])
    mcmc = MetropolisHastings(wave_function, **vmc_config['mcmc'])
    clipping = CLIPPINGS.init(**vmc_config['clipping'])
    masking = MASKINGS.init(**vmc_config['masking'])
    # The overlap penalty is only needed when some system has more than one state.
    overlap_penalty = None
    if systems.max_num_states > 1:
        logging.info('Found excitations, initializing overlap penalty')
        overlap_config = vmc_config['state_overlap']
        overlap_clipping = CLIPPINGS.init(**overlap_config['clipping'])
        overlap_masking = MASKINGS.init(**overlap_config['masking'])
        overlap_penalty = OverlapPenalty(
            wave_function,
            overlap_clipping,
            OVERLAP_SCALER.init(**overlap_config['scaler']),
            penalty_scale=overlap_config['penalty_scale'],
            dtype=overlap_config['dtype'],
            masking=overlap_masking,
        )
    # The spin penalty section is optional: a missing section (or a missing
    # `enabled` flag) disables it instead of failing on `None['enabled']`.
    spin_penalty_config = vmc_config.get('spin_penalty') or {}
    spin_penalty = None
    if spin_penalty_config.get('enabled', False):
        logging.info('Found spin penalty, initializing spin penalty')
        spin_penalty = SpinPenalty.create(
            wave_function=wave_function,
            sample_masking=MASKINGS.init(**spin_penalty_config['masking']),
            ratio_clipping=CLIPPINGS.init(**spin_penalty_config['clipping']),
            penalty_scale=spin_penalty_config['penalty_scale'],
            max_grad_norm=spin_penalty_config['max_grad_norm'],
            penalty_type=spin_penalty_config['penalty_type'],
            spin_ema_decay=spin_penalty_config['decay'],
        )
    vmc = VMC(
        wave_function,
        preconditioner,
        optimizer,
        mcmc,
        clipping,
        masking,
        overlap_penalty,
        spin_penalty,
        reweighting_mode=vmc_config['reweighting_mode'],
        determinant_regularization=vmc_config['determinant_regularization'],
        normalizer_regularization=vmc_config['normalizer_regularization'],
    )

    # init state
    logging.info('Initializing VMC state')
    key, subkey = jax.random.split(key)
    state = vmc.init(subkey, systems)

    # Logger (also the source of existing checkpoints to resume from)
    logging.info('Initializing logger')
    logger = Logger(str(systems), logging_config, experiment=ex)
    logger.update_and_log_config(mutable_config)
    continue_training = logger.has_checkpoint()
    # Pretraining
    if continue_training:
        logging.info('Found checkpoint, skipping pretraining')
    elif pretraining_config.get('epochs', 0) == 0:
        logging.info('Pretraining epochs set to 0, skipping pretraining')
    else:
        logging.info('Pretraining')
        # Copy before popping so the experiment's config object is not mutated.
        mcmc_config = pretraining_config['mcmc'].copy()
        mix_log_amp = MixtureLogAmplitude(
            wave_function,
            mcmc_config.pop('hf_fraction'),
        )
        pre_mcmc = MetropolisHastings(mix_log_amp, **mcmc_config)

        key, subkey = jax.random.split(key)
        state, systems = pretrain(
            subkey,
            vmc.replace(sampler=pre_mcmc),
            state,
            systems,
            make_optimizer(pretraining_config['optimizer']),
            reparam_loss_scale=pretraining_config['reparam_loss_scale'],
            epochs=pretraining_config['epochs'],
            batch_size=pretraining_config['batch_size'],
            basis=pretraining_config['basis'],
            hf_config=pretraining_config['hf_config'],
            logger=logger,
        )

    # Thermalizing (skipped when resuming from a checkpoint)
    if not continue_training:
        logging.info('Thermalizing')
        key, subkey = jax.random.split(key)
        systems = thermalize(
            subkey,
            vmc,
            state,
            systems,
            n_epochs=vmc_config['thermalizing_epochs'],
            batch_size=vmc_config['batch_size'],
            logger=logger,
        )
        logger.checkpoint(state, systems)

    # VMC Training
    logging.info('VMC')
    key, subkey = jax.random.split(key)
    state, systems = train(
        subkey,
        vmc,
        state,
        systems,
        epochs=vmc_config['epochs'],
        batch_size=vmc_config['batch_size'],
        logger=logger,
        continue_training=continue_training,
        max_consecutive_fails=vmc_config['max_consecutive_fails'],
        max_total_rollbacks=vmc_config.get('max_total_rollbacks'),
        profiling_config=vmc_config.get('profiling'),
    )
    logger.checkpoint(state, systems)

    # Evaluation
    eval_epochs, eval_walkers_per_mol = parse_epochs_and_walkers_from_config(
        evaluation_config,
        systems,
    )
    if eval_epochs > 0:
        logging.info(
            f'Preparing systems for evaluation with {eval_walkers_per_mol} walkers per molecule',
        )
        key, subkey = jax.random.split(key)
        systems = systems.update_batch_size(
            subkey,
            eval_walkers_per_mol,
        )
        key, subkey = jax.random.split(key)
        logging.info('Thermalizing walkers for evaluation')
        systems = thermalize(
            subkey,
            vmc,
            state,
            systems,
            n_epochs=evaluation_config['thermalizing_epochs'],
            batch_size=vmc_config['batch_size'],
            logger=logger,
        )

        logging.info(f'Evaluating for {eval_epochs} epochs')
        key, subkey = jax.random.split(key)
        # Evaluation runs without sample reweighting.
        summary = evaluate(
            subkey,
            vmc.replace(reweighting_mode=ReweightingMode.NONE),
            state,
            systems,
            epochs=eval_epochs,
            mcmc_steps=evaluation_config['mcmc_steps'],
            logger=logger,
        )
        wandb.summary.update(summary)
    else:
        logging.info('Skipping evaluation')
    wandb.finish()

    logging.info('Done')


def _validate_config(config):
    """Random checks against the config to ensure some common oopsies are not made."""

    # Check if KFAC is used with a dummy optimizer
    preconditioner_module = config['vmc']['preconditioner']['module']
    optimizer = config['vmc']['optimizer']
    if preconditioner_module == 'kfac':
        assert all(
            [
                len(optimizer) == 1,
                optimizer[0]['transform'] == 'sgd',
                optimizer[0]['learning_rate'] == 1.0,
            ],
        ), (
            'Encountered KFAC preconditioner with invalid optimizer. Remember to use SGD with lr=1.0 '
            'with KFAC preconditioner. '
        )

    evaluation_config = config['evaluation']
    assert (
        evaluation_config.get('epochs', -1) > 0
        or evaluation_config.get('total_samples_per_energy', -1) > 0
    ), (
        'Evaluation config invalid: either epochs or total_samples_per_energy must be positive.'
    )


@ex.capture
def get_config(seed, vmc, wave_function, pretraining, systems, logging, evaluation):
    """Return a deep copy of the captured top-level config entries as a plain dict.

    The experiment fills in every argument from its config; the copy lets the
    caller modify the result without touching the experiment's own config.
    """
    captured = {
        'seed': seed,
        'vmc': vmc,
        'wave_function': wave_function,
        'pretraining': pretraining,
        'systems': systems,
        'logging': logging,
        'evaluation': evaluation,
    }
    return deepcopy(captured)


@ex.automain
def _main(
    seed,
    vmc,
    wave_function,
    pretraining,
    systems,
    logging,
    evaluation,
    xla_flags=None,
):
    """Experiment entry point mapping the short top-level config keys onto ``main``.

    Args:
        xla_flags: Optional XLA flags, either one string or a list of strings
            (joined with spaces). When given, they are written to the
            ``XLA_FLAGS`` environment variable before ``main`` runs.

    Raises:
        ValueError: If ``xla_flags`` is neither a str nor a list.
    """
    if xla_flags is not None:
        if isinstance(xla_flags, list):
            flag_string = ' '.join(xla_flags)
        elif isinstance(xla_flags, str):
            flag_string = xla_flags
        else:
            raise ValueError('xla_flags must be a str or list of str')
        os.environ['XLA_FLAGS'] = flag_string
    # Parameter names mirror the YAML keys; `main` takes the *_config names.
    return main(seed, vmc, wave_function, pretraining, systems, logging, evaluation)


def cli_main():
    """Hand control to the command-line interface of the module-level experiment ``ex``."""
    ex.run_commandline()
