from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, Protocol, Tuple

import jax.numpy as jnp
import numpy as onp
from jax import random, tree_util
from optax import OptState, join_schedules, schedules

from egxc.training.utils import get_adam_momentum, scale_adam_momentum, set_adam_momentum
from egxc.training.utils.ema import EMA
from egxc.utils.typing import NnParams, PRNGKey, PyTree


def reflect_momentum(
    current_momentum: 'PyTree',
    rejected_momentum: 'PyTree',
    scalar: float,
) -> 'PyTree':
    """
    Partially reflects/absorbs current momentum based on the delta with rejected momentum.

    Both PyTrees are treated as single flat vectors v1 (current) and v2 (rejected).
    The delta vector d = v2 - v1 defines a "plane" (its orthogonal complement). This
    function decomposes v1 into components parallel and perpendicular to d, then scales
    only the parallel component by ``scalar``:

        v_new = v1_perp + scalar * v1_parallel

    Args:
        current_momentum: Momentum to transform (PyTree of parameter arrays).
        rejected_momentum: Momentum of the rejected update; must have the same tree
            structure and leaf shapes as ``current_momentum``.
        scalar: Controls reflection/absorption (-1 to 1)
                -1: full reflection (parallel component flipped)
                 0: projection onto plane (parallel component removed)
                 1: no change (identity)

    Returns:
        New momentum PyTree with the same structure, leaf shapes and leaf dtypes as
        ``current_momentum``. If both momenta are identical (d == 0) the values are
        left unchanged. An empty PyTree is returned as-is.

    Raises:
        ValueError: If the two PyTrees differ in tree structure or leaf shapes.
    """
    leaves, treedef = tree_util.tree_flatten(current_momentum)
    rejected_leaves, rejected_treedef = tree_util.tree_flatten(rejected_momentum)
    if treedef != rejected_treedef or [jnp.shape(x) for x in leaves] != [
        jnp.shape(x) for x in rejected_leaves
    ]:
        raise ValueError(
            'current_momentum and rejected_momentum must have the same tree structure '
            'and leaf shapes'
        )
    if not leaves:
        # Empty PyTree: nothing to transform (jnp.concatenate rejects an empty list).
        return current_momentum

    leaves = [jnp.asarray(x) for x in leaves]

    # Flatten PyTrees to 1D vectors
    v1 = jnp.concatenate([x.ravel() for x in leaves])
    v2 = jnp.concatenate([jnp.ravel(x) for x in rejected_leaves])

    # The delta vector is the normal of the reflection "plane"
    delta = v2 - v1
    delta_norm_sq = jnp.dot(delta, delta)

    # Project v1 onto delta: proj = (v1 · delta / |delta|²) * delta.
    # The floor keeps the coefficient finite when delta == 0; then the parallel
    # component is zero and v1 is left unchanged.
    v1_parallel_coeff = jnp.dot(v1, delta) / jnp.maximum(delta_norm_sq, 1e-12)
    v1_parallel = v1_parallel_coeff * delta
    v1_perp = v1 - v1_parallel

    # New momentum: keep perpendicular, scale parallel by scalar
    v_new = v1_perp + scalar * v1_parallel

    # Unflatten: slice at the cumulative leaf sizes (computed on the host with NumPy),
    # restore each leaf's shape and cast it back to its own dtype, because the
    # concatenation above promoted all leaves to one common dtype.
    offsets = onp.cumsum([0] + [x.size for x in leaves]).tolist()
    new_leaves = [
        v_new[start:end].reshape(leaf.shape).astype(leaf.dtype)
        for start, end, leaf in zip(offsets[:-1], offsets[1:], leaves)
    ]
    return tree_util.tree_unflatten(treedef, new_leaves)


class CoolingFn(Protocol):
    """
    Acceptance-probability function used by the Metropolis training stabilizer.

    Implementations return the probability of accepting a proposed update whose loss
    changed by ``delta``. A decrease (``delta < 0``) should always be accepted, i.e.
    return 1.0. As ``step`` grows, the probability of accepting a loss increase should
    vanish, while at small ``step`` it should remain greater than zero.
    """

    def __call__(self, step: int, delta: float, std_dev: float) -> float:
        """
        Args:
            step: Proposal counter (the stabilizer passes its running total).
            delta: Loss change of the proposal (the stabilizer passes the relative change).
            std_dev: Typical scale of past loss changes; implementations may ignore it.

        Returns:
            Probability of accepting the proposal.
        """


def get_boltzmann_cooling_fn(T_0: float, time_scale: int) -> 'CoolingFn':
    """
    Build a Boltzmann (simulated-annealing) acceptance function.

    The temperature decays hyperbolically, T(step) = T_0 / (1 + step / time_scale).
    A loss decrease (delta < 0) is always accepted. An increase is accepted with
    probability exp(-delta / T). A NaN delta (e.g. from a NaN loss) is always rejected.

    Args:
        T_0: Initial temperature, in the same units as ``delta``.
        time_scale: Number of steps after which the temperature has halved.

    Returns:
        A ``CoolingFn``; its third argument (the standard deviation) is ignored.
    """

    def cooling_fn(step: int, delta: float, _) -> float:
        if onp.isnan(delta):
            # NaN fails every comparison, so without this guard it would not be
            # treated as an increase and would be accepted.
            return 0.0
        if delta < 0:
            return 1.0
        T = T_0 / (1 + step / time_scale)
        return float(onp.exp(-delta / T))

    return cooling_fn  # type: ignore


def adaptive_outlier_reject(
    warmup_steps: int,
    sigma_tol: float,
    min_sigma: float,
    max_sigma: float,
) -> CoolingFn:
    """
    Build a cooling function that rejects loss increases which are outliers relative
    to the typical size of past relative loss changes.

    ``delta`` is expressed in units of ``std_dev``, which is clipped to
    ``[min_sigma, max_sigma]``. With the current tolerance ``tol``, the scaled change
    is handled as follows:

    * below ``tol / 2``: always accepted (probability 1);
    * between ``tol / 2`` and ``tol``: the acceptance probability falls from 1 to 0
      along a half cosine;
    * at or above ``tol``: rejected (probability 0).

    The tolerance follows a schedule:

    * 1 for the first ``warmup_steps // 2`` steps;
    * then linearly moved from 1 to ``sigma_tol`` over the next
      ``warmup_steps // 2`` steps;
    * ``sigma_tol`` from ``warmup_steps`` on.

    If the tolerance is exactly 0, or the scaled change is not finite, only strict
    decreases (``delta < 0``) are accepted.
    """
    half_warmup = warmup_steps // 2
    sigma_tol_schedule = join_schedules(
        (
            schedules.constant_schedule(1),
            schedules.linear_schedule(1, sigma_tol, half_warmup, 0),
            schedules.constant_schedule(sigma_tol),
        ),
        boundaries=(half_warmup, warmup_steps),
    )

    def half_cosine(x: float, width: float) -> float:
        # 1 at x <= 0, 0 at x >= width, smooth cosine fall-off in between
        x = onp.clip(x, 0.0, width)
        return 0.5 * (1 + onp.cos(onp.pi * x / width))  # type: ignore

    def cooling_fn(step: int, delta: float, std_dev: float) -> float:
        sigma = onp.clip(std_dev, min_sigma, max_sigma)
        z = delta / sigma
        tol = sigma_tol_schedule(step)
        if tol == 0.0 or not onp.isfinite(z):
            return float(delta < 0.0)
        half_tol = tol / 2.0
        if z >= half_tol:
            return half_cosine(z - half_tol, half_tol)  # type: ignore
        return 1.0

    return cooling_fn


class RelativeLossChangeStatistics:
    """
    Dynamic / online estimation of the mean and standard deviation of the relative loss change.
    Where the relative loss change is defined as:
        delta_rel = (L_new - L_old) / L_old
    where L_new and L_old are the loss values at successive updates.

    Both moments are exponential moving averages (decay ``beta``) started at zero. They
    are reported with the bias correction ``1 / (1 - beta**n)`` after ``n`` updates, so
    that early estimates are not dragged towards zero.
    """

    epsilon = 1e-12  # avoid division by zero

    def __init__(self, beta: float):
        """
        Args:
            beta: EMA decay factor in [0, 1]. Larger values average over more updates;
                beta == 1 keeps both averages at zero.

        Raises:
            ValueError: If ``beta`` lies outside [0, 1] (or is NaN).
        """
        if not 0.0 <= beta <= 1.0:
            raise ValueError(f'beta must lie in [0, 1], got {beta}')
        self.__beta = beta
        self.__mean: float = 0.0  # EMA of delta_rel (not bias-corrected)
        self.__mean_sq: float = 0.0  # EMA of delta_rel^2 (not bias-corrected)
        self.__n_samples: int = 0

    def update(self, delta_rel_loss: float):
        """
        Fold one accepted relative loss change into the running averages.

        Raises:
            ValueError: If ``delta_rel_loss`` is NaN or infinite. Such a value would
                poison both averages permanently. An explicit raise is used instead of
                ``assert`` so the check also runs under ``python -O``.
        """
        if not onp.isfinite(delta_rel_loss):
            raise ValueError(
                'delta_rel_loss of an accepted update should be finite and not NaN or Inf'
            )
        self.__n_samples += 1
        self.__mean = self.__beta * self.__mean + (1 - self.__beta) * delta_rel_loss
        self.__mean_sq = (
            self.__beta * self.__mean_sq + (1 - self.__beta) * delta_rel_loss**2
        )

    def _debias(self, ema_value: float) -> float:
        """Remove the zero-initialisation bias of an EMA after ``n_samples`` updates."""
        correction = 1 - self.__beta**self.__n_samples
        return ema_value / (correction + self.epsilon)

    @property
    def mean(self) -> float:
        """Bias-corrected mean relative loss change (0.0 before the first update)."""
        return self._debias(self.__mean)

    @property
    def uncentered_standard_deviation(self) -> float:
        """
        Uncentered standard deviation is the standard deviation of the loss,
        but without subtracting the mean. This is useful as it captures the
        typical scale of the relative loss change from epoch to epoch.
        """
        return onp.sqrt(self._debias(self.__mean_sq))  # uncentered: sqrt(E[x^2])


@dataclass
class MetropolisTrainingStabilizerConfig:
    """
    Settings for ``MetropolisTrainingStabilizer``.

    ``turn_off()`` returns a configuration with method 'off', a single tryout and a
    momentum scaling factor of 1.
    """

    # Acceptance rule: 'off' (always accept), 'outlier' or 'boltzmann'.
    method: Literal['off', 'outlier', 'boltzmann']
    # Keyword arguments forwarded to the cooling-function factory selected by `method`.
    method_kwargs: Dict[str, Any]
    # Consecutive rejections from which the stored momentum is scaled on each further
    # rejection; three times this count triggers a fallback to the backup state.
    consecutive_rejections_threshold: int
    # Number of tryout proposals before regular Metropolis steps; must be > 0.
    initial_tries: int
    # Factor used to scale the momentum once the rejection threshold is reached.
    momentum_scaling_on_consecutive_reject: float
    # EMA decay of the relative loss-change statistics.
    loss_statistics_beta: float
    # Re-initialise parameters between tryouts (otherwise every tryout restarts from
    # the initial parameters).
    reinit_during_tryouts: bool

    @classmethod
    def turn_off(cls) -> 'MetropolisTrainingStabilizerConfig':
        """Return a configuration that effectively disables the stabilizer."""
        settings: Dict[str, Any] = dict(
            method='off',
            method_kwargs={},
            consecutive_rejections_threshold=1,
            initial_tries=1,
            momentum_scaling_on_consecutive_reject=1,
            loss_statistics_beta=1,
            reinit_during_tryouts=False,
        )
        return cls(**settings)

    def __post_init__(self):
        """Validate the settings right after dataclass construction."""
        assert self.initial_tries > 0, 'initial_tries must be greater than 0'


class MetropolisTrainingStabilizer:
    """
    Metropolis-Hastings algorithm for training based on the epoch mean training loss.
    Note that despite returning the same parameters and optimizer state on reject,
    the PRNG key is split, so the next update acceptance will be different as well as the
    order in which the molecules are processed.
    """

    def __init__(
        self,
        params: NnParams,
        opt_state: Tuple[OptState, OptState, EMA],
        params_init_fn: Callable[[int], NnParams],
        cooling_fn: CoolingFn,
        config: MetropolisTrainingStabilizerConfig,
    ):
        self.__total_counter = 0
        self.__consecutive_reject_count = 0

        self.__initial_tries = config.initial_tries
        self.__initial_tries_counter = 0
        self.__params_init_fn = params_init_fn
        self.__initial_opt_state = deepcopy(opt_state)

        self.__cooling_fn = cooling_fn
        self.__previous_loss = float('inf')
        self.__previous_params = deepcopy(params)
        self.__previous_opt_state = deepcopy(opt_state)
        self.__rel_loss_change_stats = RelativeLossChangeStatistics(
            config.loss_statistics_beta
        )
        self.__consecutive_reject_thresh = config.consecutive_rejections_threshold
        self.__momentum_scaling_on_consecutive_reject = (
            config.momentum_scaling_on_consecutive_reject
        )

    @classmethod
    def create_from_dict(
        cls,
        params: NnParams,
        params_init_fn: Callable[[int], NnParams],
        opt_state: Tuple[OptState, OptState, EMA],
        config: MetropolisTrainingStabilizerConfig,
    ) -> 'MetropolisTrainingStabilizer':
        """Instantiate a stabilizer from config."""
        match config.method:
            case 'off':
                cooling_fn = lambda step, delta, std_dev: 1.0  # accept always
                return cls(params, opt_state, params_init_fn, cooling_fn, config)
            case 'outlier':
                cooling_fn = adaptive_outlier_reject(**config.method_kwargs)
            case 'boltzmann':
                cooling_fn = get_boltzmann_cooling_fn(**config.method_kwargs)
            case _:
                raise ValueError(f'Unknown Metropolis method: {config.method}')

        if not config.reinit_during_tryouts:
            init_params = deepcopy(params)
            params_init_fn = lambda _: init_params

        return MetropolisTrainingStabilizer(
            params,
            opt_state,
            cooling_fn=cooling_fn,
            params_init_fn=params_init_fn,
            config=config,
        )

    def __accept(
        self,
        epoch_train_loss: float,
        params: NnParams,
        opt_state: Tuple[OptState, OptState, EMA],
    ) -> Tuple[bool, NnParams, Tuple[OptState, OptState, EMA]]:
        self.__consecutive_reject_count = 0
        self.__previous_loss = epoch_train_loss
        self.__backup_params = deepcopy(self.__previous_params)
        self.__backup_opt_state = deepcopy(self.__previous_opt_state)
        self.__previous_params = deepcopy(params)
        self.__previous_opt_state = deepcopy(opt_state)
        return True, params, opt_state

    def __reject(self) -> Tuple[bool, NnParams, Tuple[OptState, OptState, EMA]]:
        self.__consecutive_reject_count += 1
        if self.__consecutive_reject_count >= self.__consecutive_reject_thresh:
            print('### Scaling momentum')
            opt1, opt2, ema = self.__previous_opt_state
            opt1 = scale_adam_momentum(
                opt1, self.__momentum_scaling_on_consecutive_reject
            )
            self.__previous_opt_state = (opt1, opt2, ema)
        if self.__consecutive_reject_count >= self.__consecutive_reject_thresh * 3:
            print('### Resetting to backup params and opt state')
            return False, self.__backup_params, self.__backup_opt_state
        return False, self.__previous_params, self.__previous_opt_state

    def __reflect_momentum(
        self, rejected_opt_state: Tuple[OptState, OptState, EMA]
    ) -> Tuple[OptState, OptState, EMA]:
        """Reflect the momentum of the previous optimizer state based on the rejected optimizer state."""
        opt1, opt2, ema = self.__previous_opt_state
        current_mom = get_adam_momentum(opt1)
        rejected_mom = get_adam_momentum(rejected_opt_state[0])
        new_momentum = reflect_momentum(
            current_mom, rejected_mom, self.__momentum_scaling_on_consecutive_reject
        )
        opt1 = set_adam_momentum(opt1, new_momentum)
        return (opt1, opt2, ema)

    def __initial_tryouts(
        self,
        epoch_train_loss: float,
        params: NnParams,
        opt_state: Tuple[OptState, OptState, EMA],
    ) -> Tuple[bool, NnParams, Tuple[OptState, OptState, EMA]]:
        self.__initial_tries_counter += 1
        print('############### Initial tryout ###############')
        print(f'Initial tryout {self.__initial_tries_counter} of {self.__initial_tries}')
        print(f'Loss current try: {epoch_train_loss:.6f}')
        print(f'Best so far: {self.__previous_loss:.6f}')
        if epoch_train_loss < self.__previous_loss:
            self.__previous_loss = epoch_train_loss
            self.__previous_params = deepcopy(params)
            self.__previous_opt_state = deepcopy(opt_state)
        if self.__initial_tries_counter == self.__initial_tries:
            return self.__accept(
                self.__previous_loss,
                self.__previous_params,
                self.__previous_opt_state,
            )
        else:
            return (
                False,
                self.__params_init_fn(self.__initial_tries_counter),
                self.__initial_opt_state,
            )

    def print_state(
        self,
        std_dev: float,
        epoch_train_loss: float,
        delta: float,
        relative_delta: float,
        prob: float,
    ) -> None:
        scaled_delta = relative_delta / std_dev
        print(f'Mean relative loss change: {self.__rel_loss_change_stats.mean:.3f}')
        print(f'Standard deviation: {std_dev:.3f}')
        print(
            f'Loss of proposed update {self.__previous_loss:.6f} -> {epoch_train_loss:.6f}'
        )
        print(f'Absolute loss change: {delta:.6f}')
        print(f'Relative loss change: {relative_delta:.6f}')
        print(f'Relative loss change in units of standard deviation: {scaled_delta:.3f}')
        print(f'Accept probability: {prob:.3f}', flush=True)

    def propose_update(
        self,
        epoch_train_loss: float,
        params: NnParams,
        prng_key: PRNGKey,
        opt_state: Tuple[OptState, OptState, EMA],
    ) -> Tuple[bool, NnParams, Tuple[OptState, OptState, EMA], PRNGKey]:
        self.__total_counter += 1

        if self.__initial_tries_counter < self.__initial_tries:
            return *self.__initial_tryouts(epoch_train_loss, params, opt_state), prng_key
        else:
            delta = epoch_train_loss - self.__previous_loss
            relative_delta = delta / self.__previous_loss
            std_dev = self.__rel_loss_change_stats.uncentered_standard_deviation
            prob = self.__cooling_fn(self.__total_counter, relative_delta, std_dev)
            prng_key, split_key = random.split(prng_key)
            if random.uniform(split_key) < prob:
                self.__rel_loss_change_stats.update(relative_delta)
                return *self.__accept(epoch_train_loss, params, opt_state), prng_key
            else:
                print('##### Metropolis training stabilizer state on rejection #####')
                self.print_state(std_dev, epoch_train_loss, delta, relative_delta, prob)
                return *self.__reject(), prng_key
