from enum import Enum
from typing import Any, Generic, NamedTuple, TypeVar

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import optax
from jax import Array

from .utils.tree_utils import tree_add, tree_mul, tree_dot, tree_sub


# Type variable standing for the parameter pytree type; it parameterises
# `ProdigyState` so the per-parameter fields share the params' structure type.
T = TypeVar('T')


class ProdigyState(NamedTuple, Generic[T]):
    """State carried between steps of the `prodigy` gradient transformation.

    Fields typed ``T`` are pytrees built from the parameters passed to the
    transformation's init function; fields typed ``Array`` are scalars.
    """

    # Copy of the parameters taken when the state was initialised.
    init_params: T
    # Running average of the d-scaled updates (initialised to zeros).
    m: T
    # Running average of the d**2-scaled squared updates (initialised to zeros).
    v: T
    # Scalar float32 accumulator; numerator of the distance estimate.
    r: Array
    # Per-parameter accumulator; its summed absolute value is the denominator
    # of the distance estimate.
    s: T
    # Scalar float32 distance estimate, initialised to `d0`.
    d: Array
    # int32 step counter, incremented by one on every update.
    step: Array


def prodigy(
    lr: optax.ScalarOrSchedule = 1.0,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    d0: float = 1e-6,
):
    """Build a Prodigy (Adam-style, distance-adaptive) gradient transformation.

    The transformation keeps Adam-like moment estimates weighted by the current
    distance estimate ``d`` and grows ``d`` from two running statistics: the
    scalar ``r`` and the per-parameter accumulator ``s``.

    Note:
        The returned updates are not negated. Callers are expected to chain a
        negative scale (``make_optimizer`` in this module appends
        ``optax.scale(-1.0)``) before applying them as a descent step.

    Args:
        lr: Step-size multiplier. Either a constant or a callable that is
            evaluated on the current step counter.
        beta1: Decay rate of the first-moment estimate ``m``.
        beta2: Decay rate of the second-moment estimate ``v``. Its square root
            is the decay rate of both ``r`` and ``s``.
        eps: Added (scaled by ``d``) to the denominator of the update.
        d0: Initial value of the distance estimate ``d``.

    Returns:
        An ``optax.GradientTransformation`` whose update function requires the
        current ``params`` and raises ``ValueError`` when they are omitted.
    """
    # Local type variable used only by the annotations of the closures below.
    T = TypeVar('T')

    def init_fn(params: T) -> ProdigyState[T]:
        """Create the initial state: zeroed statistics and ``d = d0``."""
        return ProdigyState(
            init_params=jtu.tree_map(jnp.copy, params),
            m=jtu.tree_map(jnp.zeros_like, params),
            v=jtu.tree_map(jnp.zeros_like, params),
            r=jnp.zeros((), dtype=jnp.float32),
            s=jtu.tree_map(jnp.zeros_like, params),
            d=jnp.array(d0, dtype=jnp.float32),
            step=jnp.array(0, dtype=jnp.int32),
        )

    def update(
        updates: T, state: ProdigyState[T], params: T | None = None
    ) -> tuple[T, ProdigyState[T]]:
        """Advance the statistics by one step and return the scaled updates."""
        if params is None:
            # The distance estimate needs the displacement from the initial
            # parameters, so the current parameters are mandatory.
            raise ValueError('prodigy requires `params` to be passed to update().')

        gamma = lr(state.step) if callable(lr) else lr
        # `r` and `s` share this decay rate so that their ratio stays a
        # consistently weighted average.
        sqrt_beta2 = jnp.sqrt(beta2)
        stat_weight = (1 - sqrt_beta2) * gamma * state.d**2

        m = tree_add(
            tree_mul(state.m, beta1), tree_mul(updates, (1.0 - beta1) * state.d)
        )
        v = tree_add(
            tree_mul(state.v, beta2),
            tree_mul(jtu.tree_map(lambda x: x**2, updates), (1.0 - beta2) * state.d**2),
        )
        r = sqrt_beta2 * state.r + stat_weight * tree_dot(
            updates, tree_sub(state.init_params, params)
        )
        # Fix: decay `s` with sqrt(beta2) (previously beta2) to match `r`.
        s = tree_add(tree_mul(state.s, sqrt_beta2), tree_mul(updates, stat_weight))

        s_l1 = jtu.tree_reduce(jnp.add, jtu.tree_map(lambda x: jnp.abs(x).sum(), s))
        # With a zero learning rate or all-zero updates, `s` is exactly zero
        # and r / s_l1 would be 0 / 0. That NaN would pass through the maximum
        # below and poison `d` permanently, so fall back to 0 (keeps `d`).
        has_signal = s_l1 > 0
        d_hat = jnp.where(has_signal, r / jnp.where(has_signal, s_l1, 1.0), 0.0)
        # The distance estimate never decreases.
        d = jnp.maximum(state.d, d_hat)

        scaled = jtu.tree_map(
            lambda m_i, v_i: gamma * d * m_i / (jnp.sqrt(v_i) + d * eps), m, v
        )
        new_state = ProdigyState(
            init_params=state.init_params,
            m=m,
            v=v,
            r=r,
            s=s,
            d=d,
            step=state.step + 1,
        )
        return scaled, new_state

    return optax.GradientTransformation(init_fn, update)  # type: ignore


def scale_by_trust_ratio_embeddings(
    min_norm: float = 0.0,
    trust_coefficient: float = 1.0,
    eps: float = 0.0,
) -> optax.GradientTransformation:
    """Scale updates by a trust ratio computed over the last axis only.

    Intended for embeddings, where we don't want the norm over all entries of
    a parameter but just over its last dimension, so every slice along that
    axis gets its own ratio.

    Args:
        min_norm: Lower bound passed to ``optax.safe_norm`` for both the
            parameter and the update norms. The default applies no clipping.
        trust_coefficient: Multiplier applied to the trust ratio.
        eps: Added to the update norm in the denominator of the ratio.

    Returns:
        A stateless ``optax.GradientTransformation``. Its update function
        raises ``ValueError`` when ``params`` is not provided.
    """

    def init_fn(params):
        del params
        return optax.ScaleByTrustRatioState()

    def update_fn(updates, state, params=None):
        if params is None:
            raise ValueError(
                'scale_by_trust_ratio_embeddings requires `params` to be passed '
                'to update().'
            )

        def _scale_update(update, param):
            # Clip norms to minimum value, by default no clipping.
            param_norm = optax.safe_norm(param, min_norm, axis=-1, keepdims=True)
            update_norm = optax.safe_norm(update, min_norm, axis=-1, keepdims=True)
            trust_ratio = trust_coefficient * param_norm / (update_norm + eps)

            # If no minimum norm clipping is used, set the trust ratio to 1
            # wherever a norm is zero; otherwise those slices would never be
            # updated (or the ratio would be undefined).
            zero_norm = jnp.logical_or(param_norm == 0.0, update_norm == 0.0)
            safe_trust_ratio = jnp.where(
                zero_norm, jnp.array(1.0, dtype=param.dtype), trust_ratio
            )

            return update * safe_trust_ratio

        updates = jtu.tree_map(_scale_update, updates, params)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)


def decay_to_init(reg: float) -> optax.GradientTransformation:
    """Add a pull towards the initial parameters to every update.

    Each update leaf ``g`` becomes ``g + reg * (p - p0)``, where ``p`` is the
    current parameter leaf and ``p0`` the matching leaf of the parameters
    given to the init function.

    Args:
        reg: Weight of the ``p - p0`` term.

    Returns:
        An ``optax.GradientTransformation`` whose state is the initial
        parameter pytree. Its update function raises ``ValueError`` when
        ``params`` is not provided.
    """

    def init(params):
        # The state is simply the parameters seen at initialisation.
        return params

    def update_fn(updates, state, params=None):
        if params is None:
            raise ValueError(
                'decay_to_init requires `params` to be passed to update().'
            )
        updates = jtu.tree_map(
            lambda g, p, p0: g + reg * (p - p0), updates, params, state
        )
        # The reference point never moves, so the state is returned unchanged.
        return updates, state

    return optax.GradientTransformation(init, update_fn)


class Schedule(Enum):
    """Schedule kinds that `make_schedule` knows how to build.

    Each member's value is also the key under which `make_schedule` looks up
    that schedule's keyword arguments.
    """

    CONSTANT = 'constant'
    LINEAR = 'linear'
    EXPONENTIAL = 'exponential'
    COSINE = 'cosine'
    HYPERBOLIC = 'hyperbolic'


def hyperbolic_schedule(init_value: float, delay: float, decay: float):
    """Return a schedule computing ``init_value / (1 + step / delay) ** decay``.

    Args:
        init_value: Value of the schedule at step 0.
        delay: Divisor applied to the step before adding one.
        decay: Exponent applied to ``1 + step / delay``.

    Returns:
        A callable mapping a step to the scheduled value.
    """

    def schedule_fn(step):
        shrink = (1 + step / delay) ** decay
        return init_value / shrink

    return schedule_fn


def make_schedule(
    schedule: Schedule | str, kwargs: dict[str, dict[str, Any]]
) -> optax.Schedule:
    """Instantiate the schedule selected by ``schedule``.

    Args:
        schedule: A `Schedule` member or its string value.
        kwargs: Maps each schedule's string value to the keyword arguments of
            its factory. Only the entry for the selected schedule is read.

    Returns:
        The schedule produced by the selected factory.

    Raises:
        ValueError: If ``schedule`` does not name a known schedule.
        KeyError: If ``kwargs`` has no entry for the selected schedule.
    """
    schedule = Schedule(schedule)
    # Each entry pairs the key into ``kwargs`` with the factory to call.
    factories = {
        Schedule.CONSTANT: ('constant', optax.constant_schedule),
        Schedule.LINEAR: ('linear', optax.linear_schedule),
        Schedule.EXPONENTIAL: ('exponential', optax.exponential_decay),
        Schedule.COSINE: ('cosine', optax.cosine_decay_schedule),
        Schedule.HYPERBOLIC: ('hyperbolic', hyperbolic_schedule),
    }
    entry = factories.get(schedule)
    if entry is None:
        raise ValueError(f'Unknown schedule: {schedule}')
    key, factory = entry
    return factory(**kwargs[key])


def prodigy_with_schedule(
    schedule: Schedule | str,
    schedule_args: dict[str, dict[str, Any]],
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    d0: float = 1e-6,
):
    """Build `prodigy` with its ``lr`` taken from `make_schedule`.

    Args:
        schedule: Schedule kind, forwarded to `make_schedule`.
        schedule_args: Per-schedule keyword arguments, forwarded to
            `make_schedule`.
        beta1: Forwarded to `prodigy`.
        beta2: Forwarded to `prodigy`.
        eps: Forwarded to `prodigy`.
        d0: Forwarded to `prodigy`.

    Returns:
        Whatever `prodigy` returns for the constructed schedule.
    """
    learning_rate = make_schedule(schedule, schedule_args)
    return prodigy(lr=learning_rate, beta1=beta1, beta2=beta2, eps=eps, d0=d0)


def prodigy_with_cosine(
    init_lr: float,
    num_steps: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    d0: float = 1e-6,
):
    """Build `prodigy` with an ``optax.cosine_decay_schedule`` as its ``lr``.

    Args:
        init_lr: First argument of ``optax.cosine_decay_schedule``.
        num_steps: Second argument of ``optax.cosine_decay_schedule``.
        beta1: Forwarded to `prodigy`.
        beta2: Forwarded to `prodigy`.
        eps: Forwarded to `prodigy`.
        d0: Forwarded to `prodigy`.

    Returns:
        Whatever `prodigy` returns for the cosine schedule.
    """
    cosine_lr = optax.cosine_decay_schedule(init_lr, num_steps)
    return prodigy(lr=cosine_lr, beta1=beta1, beta2=beta2, eps=eps, d0=d0)


def make_optimizer(
    schedule: Schedule | str,
    schedule_args: dict[str, dict[str, Any]],
    transformations: list[tuple[str, tuple, dict]],
) -> optax.GradientTransformation:
    """Chain named transformations, a schedule scaling and a sign flip.

    Args:
        schedule: Schedule kind, forwarded to `make_schedule`.
        schedule_args: Per-schedule keyword arguments, forwarded to
            `make_schedule`.
        transformations: ``(name, args, kwargs)`` triples. ``name`` is looked
            up in this module's globals first and on ``optax`` otherwise; the
            resulting callable is invoked with ``*args, **kwargs``.

    Returns:
        ``optax.chain`` of the listed transformations in order, followed by
        ``optax.scale_by_schedule`` of the built schedule and
        ``optax.scale(-1.0)``.
    """

    def resolve(name):
        # Names defined in this module take precedence over optax attributes.
        module_scope = globals()
        return module_scope[name] if name in module_scope else getattr(optax, name)

    stages = [resolve(name)(*args, **kwargs) for name, args, kwargs in transformations]
    stages.append(optax.scale_by_schedule(make_schedule(schedule, schedule_args)))
    stages.append(optax.scale(-1.0))
    return optax.chain(*stages)
