import logging
import time
from typing import Any

import jax
import jax.numpy as jnp
import optax
import tqdm.auto as tqdm

from neural_pfaffian.logger import Logger, make_trace_controller
from neural_pfaffian.pretraining import Pretraining
from neural_pfaffian.systems import Systems, SystemsWithPretrainTarget
from neural_pfaffian.utils import batch
from neural_pfaffian.vmc import VMC, VMCState

# Keys of the per-step training log data that `_to_postfix_str` keeps when
# building the progress-bar postfix; every other key is dropped.
TRAIN_POSTFIX_KEYS = ['E', 'E_std', 'excited/overlap_loss', 'grad', 'overlap_grad']


def thermalize(
    key: jax.Array,
    vmc: VMC,
    state: VMCState,
    systems: Systems,
    n_epochs: int,
    batch_size: int,
    logger: Logger,
):
    """Run MCMC-only sweeps over the systems and return them merged.

    The systems are split with ``batch`` into groups, each group is merged and
    initialised through ``vmc.init_systems`` with its own PRNG key. Afterwards
    ``vmc.mcmc_step`` is applied to every batch once per epoch and the
    auxiliary data of each step is logged under ``thermalize/mcmc``.

    Args:
        key: PRNG key; split once for initialisation and once per MCMC step.
        vmc: Provides ``init_systems`` and ``mcmc_step``.
        state: State whose ``sharded`` view is passed to every MCMC step.
        systems: Systems to thermalize.
        n_epochs: Number of sweeps over all batches.
        batch_size: Forwarded to ``batch`` when grouping the systems.
        logger: Receives the per-step MCMC auxiliary data.

    Returns:
        The result of ``Systems.merge`` on the updated batches.
    """
    merged_groups = [Systems.merge(group) for group in batch(systems, batch_size)]

    # One independent key per batch for the initialisation.
    key, init_key = jax.random.split(key)
    per_batch_keys = jax.random.split(init_key, len(merged_groups))
    batches = [
        vmc.init_systems(batch_key, group)
        for batch_key, group in zip(per_batch_keys, merged_groups)
    ]

    for _ in tqdm.trange(n_epochs):
        for idx, current in enumerate(batches):
            key, step_key = jax.random.split(key)
            batches[idx], mcmc_aux = vmc.mcmc_step(
                step_key,
                state.sharded,
                current.sharded,
            )
            # Pull every leaf to a Python scalar before handing it to the logger.
            scalars = jax.tree.map(lambda leaf: leaf.item(), mcmc_aux)
            logger.log(scalars, prefix='thermalize/mcmc')

    return Systems.merge(batches)


def pretrain(
    key: jax.Array,
    vmc: VMC,
    state: VMCState,
    systems: Systems,
    optimizer: optax.GradientTransformation,
    reparam_loss_scale: float,
    epochs: int,
    batch_size: int,
    basis: str,
    hf_config: dict,
    logger: Logger,
):
    """Run the pretraining loop and return the resulting VMC state and systems.

    A ``Pretraining`` object is built from ``vmc``, ``optimizer`` and
    ``reparam_loss_scale``. The systems are grouped with ``batch``, merged,
    extended via ``with_hf(basis, **hf_config)`` and initialised with
    ``pretrainer.init_systems``. Every epoch then calls ``pretrainer.step`` once
    per batch and logs the step data under the ``pretrain`` prefix, together
    with the wall-clock time since the previous step (``time_step``) and a
    running step counter (``step``).

    Args:
        key: PRNG key; split once per batch initialisation and once per step.
        vmc: Passed to ``Pretraining``.
        state: State from which the pretraining state is initialised.
        systems: Systems to pretrain on.
        optimizer: Passed to ``Pretraining``.
        reparam_loss_scale: Passed to ``Pretraining``.
        epochs: Number of sweeps over all batches.
        batch_size: Forwarded to ``batch`` when grouping the systems.
        basis: First argument of ``with_hf``.
        hf_config: Keyword arguments forwarded to ``with_hf``.
        logger: Receives the per-step log data.

    Returns:
        Tuple ``(pre_state.vmc_state, systems)`` where ``systems`` is the
        ``to_systems`` view of the merged pretraining batches.
    """
    pretrainer = Pretraining(vmc, optimizer, reparam_loss_scale)
    pre_state = pretrainer.init(state)

    # Build the pretraining batches, each with its own initialisation key.
    batches: list[SystemsWithPretrainTarget] = []
    for group in batch(systems, batch_size):
        merged = Systems.merge(group)
        key, init_key = jax.random.split(key)
        with_targets = merged.with_hf(basis, **hf_config)
        batches.append(pretrainer.init_systems(init_key, with_targets))

    step = 0
    tick = time.perf_counter()
    for _epoch in tqdm.trange(epochs):
        for idx, current in enumerate(batches):
            key, step_key = jax.random.split(key)
            pre_state, batches[idx], raw_log = pretrainer.step(
                step_key,
                pre_state.sharded,
                current.sharded,
            )

            # Convert to Python scalars, then attach timing and step counter.
            log_data = jax.tree.map(lambda leaf: leaf.item(), raw_log)
            log_data['time_step'] = time.perf_counter() - tick
            log_data['step'] = step
            step += 1
            logger.log(log_data, prefix='pretrain')

            # Restart the timer after logging.
            tick = time.perf_counter()

    merged_batches = SystemsWithPretrainTarget.merge(batches)
    return pre_state.vmc_state, merged_batches.to_systems


def train(
    key: jax.Array,
    vmc: VMC,
    state: VMCState,
    systems: Systems,
    epochs: int,
    batch_size: int,
    logger: Logger,
    max_consecutive_fails: int,
    max_total_rollbacks: int | None,
    *,
    continue_training: bool,
    profiling_config: dict[str, Any] | None = None,
):
    """Run the main training loop with checkpointing and rollback on failures.

    Every epoch calls ``vmc.step`` once per batch of systems, logs the step
    data under the ``train`` prefix and then advances ``state.step`` by one on
    the host. An epoch counts as failed when any of its steps reports
    ``kept_update`` as false. Once the number of consecutive failed epochs
    reaches ``max_consecutive_fails``, state and systems are restored through
    ``logger.rollback`` and the epoch counter is reset to the restored
    ``state.epoch``.

    Args:
        key: PRNG key; split for system initialisation, the batch-size update
            and every training step.
        vmc: Provides ``init_systems`` and ``step``.
        state: Initial training state; replaced by the loaded checkpoint when
            ``continue_training`` is true.
        systems: Systems to train on.
        epochs: Training stops once the epoch counter reaches this value.
        batch_size: Forwarded to ``Systems.safe_batch`` when building batches.
        logger: Used for logging, checkpointing, rollback and the reschedule
            hook.
        max_consecutive_fails: Number of consecutive failed epochs that
            triggers a rollback.
        max_total_rollbacks: Upper bound on the number of rollbacks; ``None``
            disables the limit.
        continue_training: If true, state and systems are loaded from the
            logger's checkpoint before training.
        profiling_config: Forwarded to ``make_trace_controller``.

    Returns:
        Tuple ``(state, systems)``. If ``state.epoch`` is already at or beyond
        ``epochs`` the (possibly checkpoint-loaded) inputs are returned
        unchanged; otherwise ``systems`` is the merge of the trained batches.

    Raises:
        RuntimeError: If the number of rollbacks exceeds
            ``max_total_rollbacks``.
    """
    trace_controller = make_trace_controller(profiling_config, logger.log_directories())
    # The bars are created further down; they are tracked from here so that the
    # ``finally`` clause can close whichever of them exists on every exit path.
    epoch_bar = None
    step_bar = None
    try:
        # Init systems
        key, subkey = jax.random.split(key)
        systems = vmc.init_systems(subkey, systems)
        num_walker_per_mol = systems.electrons.shape[0]
        if continue_training:
            logging.info('Loading checkpoint and resuming training')
            state, systems = logger.load_checkpoint(state, systems)

        # In case of restarting from a pretrained state after an OOM
        # we may need to update the systems batch size
        key, subkey = jax.random.split(key)
        systems = systems.update_batch_size(subkey, num_walker_per_mol)

        epoch = int(state.epoch)
        if epoch >= epochs:
            logging.info('Training already done, skipping')
            return state, systems

        def _make_batches(_systems):
            # Loop variable deliberately not called ``batch`` so it does not
            # shadow the module-level ``batch`` helper.
            return [
                Systems.merge(group) for group in Systems.safe_batch(_systems, batch_size)
            ]

        batches = _make_batches(systems)

        last_time = time.perf_counter()
        epoch_bar = tqdm.tqdm(total=epochs, initial=epoch)
        # A per-step bar is only shown when an epoch has more than one batch.
        step_bar = tqdm.tqdm(total=len(batches)) if len(batches) > 1 else None
        consecutive_fails = 0
        total_rollbacks = 0
        while epoch < epochs:
            postfix_str = ''
            epoch_had_failure = False

            for i in range(len(batches)):
                key, subkey = jax.random.split(key)
                host_step = int(state.step.item())
                with trace_controller.trace(host_step):
                    state, batches[i], raw_log_data, kept_update = vmc.step(
                        subkey,
                        state.sharded,
                        batches[i].sharded,
                    )
                    # The host conversions happen inside the trace context.
                    log_data = jax.tree.map(lambda x: x.item(), raw_log_data)
                    kept_update_value = bool(kept_update.item())
                    current_step_value = int(state.step.item())

                log_data['time_step'] = time.perf_counter() - last_time
                log_data['epoch'] = epoch
                log_data['step'] = current_step_value + 1
                logger.log(log_data, prefix='train')

                postfix_str = _to_postfix_str(log_data)
                if step_bar is not None:
                    step_bar.set_postfix_str(postfix_str)
                    step_bar.update(1)

                # Advance the step counter on the host after logging.
                state = state.replace(
                    step=jnp.array(state.step + 1, dtype=state.step.dtype),
                )

                if not kept_update_value:
                    epoch_had_failure = True

                last_time = time.perf_counter()

            if step_bar is not None:
                step_bar.reset()
            else:
                # Without a step bar the last step's postfix goes on the epoch bar.
                epoch_bar.set_postfix_str(postfix_str)

            consecutive_fails = consecutive_fails + 1 if epoch_had_failure else 0
            if consecutive_fails >= max_consecutive_fails:
                logging.warning(
                    f'Maximum number of consecutive fails {max_consecutive_fails} reached, '
                    'rolling back to last checkpoint and restarting epoch',
                )
                state, systems = logger.rollback(state, systems)
                total_rollbacks += 1
                if (
                    max_total_rollbacks is not None
                    and total_rollbacks > max_total_rollbacks
                ):
                    msg = (
                        'Maximum number of rollbacks '
                        f'({max_total_rollbacks}) exceeded after {total_rollbacks} rollbacks; '
                        'aborting training.'
                    )
                    logging.error(msg)
                    raise RuntimeError(msg)
                batches = _make_batches(systems)
                logging.warning(
                    f'Rolled back to epoch {int(state.epoch)}',
                )
                epoch = int(state.epoch)
                epoch_bar.n = epoch
                epoch_bar.refresh()
                # NOTE(review): consecutive_fails is not reset after a rollback,
                # so a failure in the very next epoch triggers another rollback
                # right away - confirm this is intended.

            # Only checkpoint after an epoch without failures.
            if logger.should_save_checkpoint(epoch) and consecutive_fails == 0:
                logger.checkpoint(state, Systems.merge(batches))

            logger.reschedule_hook(state, batches)
            epoch += 1
            epoch_bar.update(1)
            state = state.replace(epoch=jnp.array(epoch, dtype=state.epoch.dtype))

        return state, Systems.merge(batches)
    finally:
        # Close the progress bars on every exit path (normal return, early
        # return, rollback abort, unexpected error) without ever skipping the
        # trace controller shutdown.
        try:
            for bar in (step_bar, epoch_bar):
                if bar is not None:
                    bar.close()
        finally:
            trace_controller.close()


def _to_postfix_str(log_data) -> str:
    """Format the entries of ``log_data`` listed in ``TRAIN_POSTFIX_KEYS``.

    Each kept entry is rendered as ``key=value`` with the value in ``.4g``
    format, right-aligned to a width of ``len(key) + 8`` characters. Entries
    appear in the iteration order of ``log_data`` and are joined with ``', '``.
    """
    pieces = []
    for name, value in log_data.items():
        if name not in TRAIN_POSTFIX_KEYS:
            continue
        rendered = f'{name}={value:.4g}'
        # Pad on the left so the columns stay stable between steps.
        pieces.append(rendered.rjust(len(name) + 8))
    return ', '.join(pieces)
