import enum
from functools import partial

import jax
import jax.random
import jax.numpy as jnp


class KernelPreconditionedCrankNicolson:
    """Preconditioned Crank-Nicolson style proposal kernel with Gaussian noise.

    A proposal is the linear combination ``drift * param + step_size * noise``
    where ``noise`` is normal with standard deviation ``std``. The class also
    supplies the accept/reject draw and the replica-exchange probability used
    by a parallel tempering driver.
    """

    def __init__(self, std=1):
        # Standard deviation of the noise used for proposals and initial samples.
        self.std = std

    # ``self`` is a static jit argument: the instance is hashed when the
    # function is traced, so the value of ``self.std`` seen at trace time is
    # baked into the compiled function.
    @partial(jax.jit, static_argnames=['self'])
    def propose(self, key, param, step_size, drift=None):
        """Return a proposal ``drift * param + step_size * noise``.

        Args:
            key: PRNG key used to draw the noise.
            param: Current parameter array; the noise takes its shape and dtype.
            step_size: Weight of the noise term.
            drift: Weight of the current parameter. When ``None`` it is set to
                ``sqrt(1 - step_size ** 2)``.

        Returns:
            The proposed parameter array.
        """
        noise = self.sample_noise(key, param.shape, dtype=param.dtype)
        if drift is None:
            drift = jnp.sqrt(1 - jnp.square(step_size))
        return drift * param + step_size * noise

    def sample_initial_param(self, key, shape, dtype=None):
        """Draw an initial parameter; this is a plain noise sample."""
        return self.sample_noise(key, shape, dtype=dtype)

    def sample_noise(self, key, shape, dtype=None):
        """Draw standard normal noise of ``shape`` and scale it by ``self.std``."""
        standard_normal = jax.random.normal(key, shape, dtype)
        return standard_normal * self.std

    @staticmethod
    @jax.jit
    def get_acceptance(key, energy, new_energy):
        """Draw the accept/reject decision for a proposed move.

        The move is accepted with probability ``min(1, exp(energy - new_energy))``.

        Args:
            key: PRNG key; one sub-key is split off for the Bernoulli draw.
            energy: Energy of the current state.
            new_energy: Energy of the proposed state.

        Returns:
            Boolean array, ``True`` where the proposal is accepted.
        """
        key_accept = jax.random.split(key, 1)[0]
        log_ratio = energy - new_energy
        acceptance_probability = jnp.minimum(jnp.exp(log_ratio), 1)
        return jax.random.bernoulli(key_accept, acceptance_probability).astype(jnp.bool)

    @staticmethod
    @jax.jit
    def parallel_tempering_swap_probability(
        energy_hot, inverse_temperature_hot,
        energy_cold, inverse_temperature_cold,
    ):
        """Return the probability of exchanging a hot and a cold replica.

        Computes ``min(1, exp((beta_cold - beta_hot) * (E_cold - E_hot)))``
        from the two energies and the two inverse temperatures.
        """
        log_ratio = (
            (inverse_temperature_cold - inverse_temperature_hot)
            * (energy_cold - energy_hot)
        )
        return jnp.minimum(jnp.exp(log_ratio), 1)


class PTMCMC:
    r"""
    Parallel Tempering Markov Chain Monte Carlo.

    One chain is run per entry of the inverse temperature array. Every
    iteration applies one kernel update to each chain and then sweeps over the
    neighbouring pairs ``(k, k + 1)``, proposing a replica exchange for each.

    Args
    ----
    kernel :
        Object providing ``sample_initial_param``, ``propose``,
        ``get_acceptance`` and ``parallel_tempering_swap_probability``. It is
        passed to jitted functions as a static argument, so it must be hashable.
    potential_fn : Callable
        Computes the potential energy of a data sample. Also passed as a
        static argument, so it must be hashable.
    num_iters : int
        Number of sample updates and chain swaps (replica exchanges).

    Keyword Args
    ------------
    shape : tuple[int]
        Shape forwarded to ``kernel.sample_initial_param`` for each chain.
    temperature : jnp.Array, default None
        One-dimensional array of temperatures :math:`T_k`. If passed, ``inverse_temperature``
        must be ``None``.
    inverse_temperature : jnp.Array, default None
        One-dimensional array of inverse temperatures :math:`\beta_k = 1/T_k`.
        If passed, ``temperature`` must be ``None``.
    step_size : jnp.Array, default None
        One-dimensional array of kernel step sizes :math:`s_k`. If passed, ``step_size_angle``
        must be ``None``.
    step_size_angle : jnp.Array, default None
        One-dimensional array of kernel step size angles :math:`\theta_k` such
        that :math:`s_k = \sin(\theta_k)`. If passed, ``step_size`` must be ``None``.
    """

    def __init__(
        self,
        kernel, potential_fn,
        num_iters,
        *,
        shape=None,
        temperature=None, inverse_temperature=None,
        step_size=None, step_size_angle=None,
    ):
        kwargs = _normalize_common_kwargs(
            temperature=temperature, inverse_temperature=inverse_temperature,
            step_size=step_size, step_size_angle=step_size_angle,
        )
        self.kernel = kernel
        self.potential_fn = potential_fn
        self.num_iters = num_iters
        self.inverse_temperature = jnp.asarray(kwargs['inverse_temperature'])
        # Count chains from the normalised value: the raw ``inverse_temperature``
        # argument is None whenever the caller passed ``temperature`` instead.
        self.num_chains = int(jnp.size(self.inverse_temperature))
        self.step_size = jnp.asarray(kwargs['step_size'])
        # One initial sample per chain; vmapped over a batch of PRNG keys.
        self.sample_initial_param = jax.jit(jax.vmap(partial(kernel.sample_initial_param, shape=shape)))
        self.potential_fn_vmap = jax.jit(jax.vmap(potential_fn))
        # Per-chain update vmapped over (key, inverse_temperature, step_size,
        # param, energy); kernel and potential_fn are shared by all chains.
        self.update_chain_jit = jax.jit(jax.vmap(
            self.update_chain,
            in_axes=(0, None, None, 0, 0, 0, 0),
        ), static_argnames=['kernel', 'potential_fn'])

    def run(
        self,
        key,
    ):
        r"""
        Run the sampler for ``num_iters`` iterations.

        Args
        ----
        key : jax.random.key
            Pseudo-random number generator key.

        Returns
        -------
        The parameters of all chains after the last iteration, stacked along
        the leading (chain) axis.
        """
        (
            key_initial_sample,
            key_loop,
        ) = jax.random.split(key, 2)

        param = self.sample_initial_param(
            jax.random.split(key_initial_sample, self.num_chains),
        )
        energy_with_unit_temperature = self.potential_fn_vmap(param)
        num_swaps = self.num_chains - 1
        for _ in range(self.num_iters):
            (
                key_update_chains,
                key_swap_chains,
                key_loop,
            ) = jax.random.split(key_loop, 3)
            param, energy_with_unit_temperature = self.update_chain_jit(
                jax.random.split(key_update_chains, self.num_chains),
                self.kernel, self.potential_fn,
                self.inverse_temperature, self.step_size, param, energy_with_unit_temperature
            )
            if num_swaps > 0:
                # Every swap proposal gets its own fresh key. Reusing one key
                # would make all swap decisions share the same random draw,
                # across pairs and across iterations.
                swap_keys = jax.random.split(key_swap_chains, num_swaps)
                for k in range(num_swaps):
                    param, energy_with_unit_temperature = self.swap_chain(
                        swap_keys[k], self.kernel,
                        param, energy_with_unit_temperature, self.inverse_temperature, k,
                    )
        return param

    @staticmethod
    @partial(jax.jit, static_argnames=['kernel', 'potential_fn'])
    def update_chain(key, kernel, potential_fn, inverse_temperature, step_size, param, energy_with_unit_temperature):
        """Apply one propose/accept step to a single chain.

        Args:
            key: PRNG key, split into a proposal key and an acceptance key.
            kernel: Kernel providing ``propose`` and ``get_acceptance``.
            potential_fn: Energy function evaluated on the proposal.
            inverse_temperature: Inverse temperature of this chain; it scales
                both energies before the acceptance test.
            step_size: Kernel step size of this chain.
            param: Current parameter of the chain.
            energy_with_unit_temperature: Untempered energy of ``param``.

        Returns:
            Tuple ``(param, energy)`` holding the proposal and its energy where
            it was accepted and the previous values elsewhere.
        """
        (
            key_proposer,
            key_accept,
        ) = jax.random.split(key, 2)
        new_param = kernel.propose(key_proposer, param, step_size)
        new_energy_with_unit_temperature = potential_fn(new_param)
        accept = kernel.get_acceptance(
            key_accept,
            inverse_temperature * energy_with_unit_temperature,
            inverse_temperature * new_energy_with_unit_temperature,
        )
        updated_energy_with_unit_temperature = jnp.where(accept, new_energy_with_unit_temperature, energy_with_unit_temperature)
        # Append singleton axes so the decision broadcasts against ``param``.
        accept = accept.reshape(*accept.shape, *[1 for _ in range(param.ndim - accept.ndim)])
        updated_param = jnp.where(accept, new_param, param)
        return updated_param, updated_energy_with_unit_temperature

    @staticmethod
    @partial(jax.jit, static_argnames=['kernel'])
    def swap_chain(key, kernel, param, energy_with_unit_temperature, inverse_temperature, k):
        """Propose exchanging the states of chains ``k`` and ``k + 1``.

        Chain ``k`` is handed to the kernel as the hot replica and chain
        ``k + 1`` as the cold one.

        Args:
            key: PRNG key for the swap decision.
            kernel: Kernel providing ``parallel_tempering_swap_probability``.
            param: Parameters of all chains, chain axis first.
            energy_with_unit_temperature: Untempered energies of all chains.
            inverse_temperature: Inverse temperatures of all chains.
            k: Index of the first chain of the pair.

        Returns:
            Tuple ``(param, energy)`` with entries ``k`` and ``k + 1``
            exchanged where the swap was accepted.
        """
        swap_probability = kernel.parallel_tempering_swap_probability(
            energy_hot=energy_with_unit_temperature[k], inverse_temperature_hot=inverse_temperature[k],
            energy_cold=energy_with_unit_temperature[k + 1], inverse_temperature_cold=inverse_temperature[k + 1],
        )
        key_chain_swap, _ = jax.random.split(key)
        swap = jax.random.bernoulli(key_chain_swap, swap_probability).astype(jnp.bool)
        energy_temp_1_kp1 = jnp.where(swap, energy_with_unit_temperature[k], energy_with_unit_temperature[k + 1])
        energy_temp_1_k = jnp.where(swap, energy_with_unit_temperature[k + 1], energy_with_unit_temperature[k])
        energy_with_unit_temperature = energy_with_unit_temperature.at[k].set(energy_temp_1_k)
        energy_with_unit_temperature = energy_with_unit_temperature.at[k + 1].set(energy_temp_1_kp1)
        # Append singleton axes so the decision broadcasts against one chain's
        # parameter (``param`` without its leading chain axis).
        swap = swap.reshape(*swap.shape, *[1 for _ in range(param.ndim - 1 - swap.ndim)])
        param_kp1 = jnp.where(swap, param[k], param[k + 1])
        param_k = jnp.where(swap, param[k + 1], param[k])
        param = param.at[k].set(param_k)
        param = param.at[k + 1].set(param_kp1)
        return param, energy_with_unit_temperature


def _check_xor_args(keys, kwargs):
    has_one_key = any(k in kwargs for k in keys)
    if has_one_key:
        if sum(kwargs[k] is not None for k in keys) != 1:
            raise ValueError('Exactly one of the following must be specified: ' + ', '.join(keys))
        return has_one_key, next(k for k in keys if kwargs[k] is not None)
    return has_one_key, ''


def _normalize_common_kwargs(**kwargs):
    """Resolve alternative spellings of common sampler options.

    For each group of mutually exclusive options that appears in ``kwargs``,
    exactly one member must be set; the canonical entry is then filled in:

    * ``temperature`` / ``inverse_temperature`` -> ``inverse_temperature``
      (a temperature is inverted).
    * ``step_size`` / ``step_size_angle`` -> ``step_size``
      (an angle is mapped through ``sin``).

    ``sample_noise_fn``, when present, must not be ``None``.

    Returns:
        The same keyword dictionary with the canonical entries set.

    Raises:
        ValueError: A group is present without exactly one member set, or the
            resulting step size is not strictly between 0 and 1.
    """
    has_one_key, k = _check_xor_args(['temperature', 'inverse_temperature'], kwargs)
    if has_one_key:
        value = jnp.asarray(kwargs[k])
        # beta = 1 / T: a temperature must be inverted, not copied through.
        kwargs['inverse_temperature'] = 1 / value if k == 'temperature' else value

    has_one_key, k = _check_xor_args(['step_size', 'step_size_angle'], kwargs)
    if has_one_key:
        # asarray lets plain Python lists and scalars pass the range check.
        value = jnp.asarray(kwargs[k])
        kwargs['step_size'] = value if k == 'step_size' else jnp.sin(value)
        # The range check is applied to the step size itself, i.e. to
        # sin(angle) when an angle was given.
        if not jnp.all((0 < kwargs['step_size']) & (kwargs['step_size'] < 1)):
            if k == 'step_size':
                raise ValueError('step_size must be in (0, 1)')
            raise ValueError('step_size_angle must be in (0, pi/2)')

    _check_xor_args(['sample_noise_fn'], kwargs)

    return kwargs


class FeynmannKacPotential(enum.StrEnum):
    DIFFERENCE = 'difference'
    MAX = 'max'
    SUM = 'sum'


class FeynmannKac:
    """Feynman-Kac style particle steering driver (incomplete as written).

    NOTE(review): ``run`` is declared as a classmethod but its body reads
    ``self``, ``conf``, ``dt`` and several PRNG key variables that are neither
    parameters nor defined anywhere in the visible part of this file. It looks
    like a partially ported instance method; the comments inside ``run`` mark
    each spot that needs a decision before the method can execute.
    """

    @classmethod
    def run(
        cls,
        key,
        initialize_samples_fn,
        reward_fn,
        potential,
        times,
        batch_size,
        ensemble_size,
        *,
        shape=None,
    ):
        """Steer an ensemble of samples over ``times`` and return the best ones.

        Visible outline: validate ``potential``, draw initial samples, then for
        every consecutive pair ``(t0, t1)`` of ``times`` resample, propose and
        re-evaluate the reward; finally pick, for each leading-axis entry of
        the batch, the element with the largest final potential.

        Args:
            key: PRNG key, split into an initial-sample key and ``key_loop``.
            initialize_samples_fn: Called as ``fn(key, shape=shape)`` under
                ``jax.vmap`` over ``ensemble_size`` keys.
            reward_fn: Called as ``reward_fn(keys, t0, t1, dt, batch)``.
            potential: A ``FeynmannKacPotential`` member or its string value.
            times: Sequence of time points; consecutive pairs form the steps.
            batch_size: Number of keys used for resampling and the first
                dimension of the aggregated reward.
            ensemble_size: Number of initial-sample keys and the second
                dimension of the aggregated reward.
            shape: Forwarded to ``initialize_samples_fn``.

        Raises:
            ValueError: ``potential`` is not a valid ``FeynmannKacPotential``.

        NOTE(review): as written, the first split of ``key_reward_fn`` below
        reads that name before it is ever assigned, so the method cannot get
        past that statement.
        """
        try:
            potential = FeynmannKacPotential(potential)
        except ValueError:
            # NOTE(review): raised without ``from``, so the original lookup
            # error is attached as implicit exception context.
            raise ValueError(f'potential must be one of {[e.value for e in FeynmannKacPotential]}, not {potential!r}.')

        # NOTE(review): ``key_loop`` is never used afterwards; presumably the
        # undefined key streams below were meant to be derived from it.
        (
            key_initial_sample,
            key_loop,
        ) = jax.random.split(key, 2)

        # One call of ``initialize_samples_fn`` per ensemble key.
        batch = jax.vmap(partial(initialize_samples_fn, shape=shape))(
            jax.random.split(key_initial_sample, ensemble_size),
        )
        aggregated_reward = jnp.zeros((batch_size, ensemble_size))
        # FIXME(review): ``key_reward_fn`` is read here before any assignment.
        key_reward, key_reward_fn = jax.random.split(key_reward_fn)
        # FIXME(review): ``self`` and ``conf`` are not defined in this
        # classmethod or imported in the visible file. In addition, ``is``
        # against a freshly built tuple is always False; presumably a
        # membership test (``in``) on the ``potential`` argument was intended.
        if self.cfg.steering.potential is (conf.model.FeynmannKacPotential.DIFFERENCE, conf.model.FeynmannKacPotential.MAX):
            reward = aggregated_reward
        else:
            t0 = times[0]
            t1 = times[-1]
            # FIXME(review): ``dt`` is not defined anywhere in this method.
            reward = reward_fn(jax.random.split(key_reward, self.cfg.steering.ensemble_size), t0, t1, dt, batch)
        # NOTE(review): this rebinds ``potential`` (the validated enum member)
        # to the value returned here. ``feynmann_kac_potential_fn`` is not
        # defined on this class in the visible source.
        potential, aggregated_reward = self.feynmann_kac_potential_fn(reward, aggregated_reward)
        aggregated_potential = 1
        # NOTE(review): ``time_step`` is unused.
        for time_step, (t0, t1) in enumerate(zip(times, times[1:])):
            # Accumulates the reciprocal of every potential used for resampling.
            aggregated_potential = aggregated_potential / potential
            # FIXME(review): ``key_particle_resampling`` is never initialised.
            key_resample, key_particle_resampling = jax.random.split(key_particle_resampling)
            # Resampling probabilities are the potentials normalised along axis 1.
            # FIXME(review): ``self.resampler`` is not defined in the visible source.
            batch = self.resampler(
                jax.random.split(key_resample, batch_size),
                batch,
                p=potential / potential.sum(axis=1, keepdims=True),
            )
            # FIXME(review): ``key_propose_next_intermediate`` is never
            # initialised and ``self.propose_fn`` is not defined in the
            # visible source.
            key_propose, key_propose_next_intermediate = jax.random.split(key_propose_next_intermediate)
            batch = self.propose_fn(jax.random.split(key_propose, self.cfg.steering.ensemble_size), t0, t1, dt, batch)
            key_reward, key_reward_fn = jax.random.split(key_reward_fn)
            reward = reward_fn(jax.random.split(key_reward, self.cfg.steering.ensemble_size), t0, t1, dt, batch)
            # At the last time point the MAX and SUM variants use the tilted
            # exponential of the latest reward directly.
            if t1 == times[-1] and self.cfg.steering.potential in (conf.model.FeynmannKacPotential.MAX, conf.model.FeynmannKacPotential.SUM):
                potential = jnp.exp(self.cfg.steering.tilt * reward)
            else:
                potential, aggregated_reward = self.feynmann_kac_potential_fn(reward, aggregated_reward)
        # For MAX and SUM the last potential is divided by the accumulated
        # reciprocal product built up in the loop.
        if self.cfg.steering.potential in (conf.model.FeynmannKacPotential.MAX, conf.model.FeynmannKacPotential.SUM):
            final_potential = potential / aggregated_potential
        else:
            final_potential = potential
        # For each leading-axis entry of ``batch``, select the element whose
        # final potential is largest along axis 1.
        return jax.vmap(lambda b, i: b[i])(batch, jnp.argmax(final_potential, axis=1))
