"""
This module contains optimization utils used to train the models.

Source: https://github.com/bunnech/jkonet

Functions
---------
- ``get_schedule``
    Returns an Optax learning-rate schedule based on the provided configuration.

- ``get_optimizer``
    Returns an Optax optimizer object based on the provided configuration.

- ``create_train_state``
    Creates an initial `TrainState` for the given model and optimizer.

- ``create_train_state_from_params``
    Creates a `TrainState` from existing model parameters.

- ``global_norm``
    Computes the global norm of gradients across a nested structure of tensors.

- ``clip_weights_icnn``
    Clips the weights of an Input Convex Neural Network (ICNN) to keep it convex.

- ``penalize_weights_icnn``
    Computes a penalty on the negative weights of an ICNN.
"""

from typing import Any, Callable, Dict, Sequence, Union

import chex
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.core import FrozenDict, freeze, unfreeze
from flax.training import train_state


def get_schedule(config: Dict[str, Any]) -> Callable[[int], float]:
    """
    Build a learning-rate schedule (a callable mapping step -> learning rate) from ``config``.

    ``config["lr"]`` is always required and is the base learning rate. ``config["scheduler"]``
    selects the schedule shape; when it is absent or None the learning rate is constant.
    Supported schedulers and the extra keys each one reads:

    - 'cosine_decay': 'total_steps'.
    - 'linear_warmup': 'warmup_steps', 'total_steps', optional 'end_value' (default 0.0).
      Linear ramp 0 -> lr over the warm-up, then linear lr -> end_value.
    - 'exponential_decay': 'decay_rate', 'transition_steps'.
    - 'cosine_warmup': 'warmup_steps', 'total_steps'. Linear ramp 0 -> lr, then cosine decay
      over ``total_steps - warmup_steps`` steps.
    - 'polynomial': 'total_steps', optional 'end_value' (default 0.0) and 'power' (default 1.0).
    - 'composite': optional 'decay_rate' (default 0.9), 'transition_steps' (default 1000) and
      'boundaries' (default [1000]). Constant lr first, then exponential decay.
    - 'piecewise_constant_schedule': optional 'boundaries_and_scales'
      (default {1000: 0.5, 2000: 0.1}).

    Raises
    ------
    KeyError
        If a required key is missing from ``config``.
    NotImplementedError
        If the scheduler name is not recognised.
    """
    base_lr = config["lr"]
    kind = config.get("scheduler", None)

    if kind is None:
        # Constant learning rate: the step counter is ignored.
        return lambda step: base_lr

    if kind == "cosine_decay":
        return optax.cosine_decay_schedule(init_value=base_lr, decay_steps=config["total_steps"])

    if kind == "linear_warmup":
        n_warmup = config["warmup_steps"]
        n_total = config["total_steps"]
        final_lr = config.get("end_value", 0.0)
        ramp_up = optax.linear_schedule(init_value=0.0, end_value=base_lr, transition_steps=n_warmup)
        ramp_down = optax.linear_schedule(
            init_value=base_lr, end_value=final_lr, transition_steps=n_total - n_warmup
        )
        return optax.join_schedules(schedules=[ramp_up, ramp_down], boundaries=[n_warmup])

    if kind == "exponential_decay":
        rate = config["decay_rate"]
        period = config["transition_steps"]
        return optax.exponential_decay(init_value=base_lr, transition_steps=period, decay_rate=rate)

    if kind == "cosine_warmup":
        n_warmup = config["warmup_steps"]
        n_total = config["total_steps"]
        decay_phase = optax.cosine_decay_schedule(init_value=base_lr, decay_steps=n_total - n_warmup)
        warmup_phase = optax.linear_schedule(init_value=0.0, end_value=base_lr, transition_steps=n_warmup)
        return optax.join_schedules(schedules=[warmup_phase, decay_phase], boundaries=[n_warmup])

    if kind == "polynomial":
        final_lr = config.get("end_value", 0.0)
        exponent = config.get("power", 1.0)
        horizon = config["total_steps"]
        return optax.polynomial_schedule(
            init_value=base_lr, end_value=final_lr, power=exponent, transition_steps=horizon
        )

    if kind == "composite":
        rate = config.get("decay_rate", 0.9)
        period = config.get("transition_steps", 1000)
        switch_points = config.get("boundaries", [1000])
        flat_phase = optax.constant_schedule(base_lr)
        decay_phase = optax.exponential_decay(base_lr, transition_steps=period, decay_rate=rate)
        return optax.join_schedules(schedules=[flat_phase, decay_phase], boundaries=switch_points)

    if kind == "piecewise_constant_schedule":
        scales = config.get("boundaries_and_scales", {1000: 0.5, 2000: 0.1})
        return optax.piecewise_constant_schedule(init_value=base_lr, boundaries_and_scales=scales)

    raise NotImplementedError(f"Scheduler '{kind}' not supported.")


def get_optimizer(config: Dict[str, Any]) -> optax.GradientTransformation:
    """
    Returns an Optax optimizer object based on the provided configuration.

    Parameters
    ----------
    config : Dict[str, Any]
        Dictionary containing optimizer configuration. Recognised keys are:

        - 'optimizer': Name of the optimizer, case-insensitive ('adam', 'adamw', 'sgd' or 'lamb').
        - 'lr': Learning rate (required, positive float). It is also the base value of the
          learning-rate schedule built by ``get_schedule``.
        - 'scheduler' plus its schedule-specific keys: see ``get_schedule``.
        - 'beta1', 'beta2', 'eps': Optional parameters for Adam, AdamW and LAMB
          (defaults 0.9, 0.999 and 1e-8; LAMB's default eps is 1e-6). If supplied they must
          be positive floats.
        - 'weight_decay': Optional weight decay for AdamW (default 1e-2) and LAMB (default 0.0).
        - 'momentum', 'nesterov': Optional SGD settings (defaults 0.0 and False).
        - 'grad_clip': Optional maximum global norm for gradient clipping. If it is omitted or
          None, no clipping is applied.

    Returns
    -------
    optax.GradientTransformation
        The configured Optax optimizer object, preceded by global-norm clipping when
        'grad_clip' is set.

    Raises
    ------
    KeyError
        If 'lr' or 'optimizer' is missing.
    AssertionError
        Raised by chex if 'lr', or any supplied 'beta1', 'beta2', 'eps' or 'grad_clip', is not a
        positive float.
    NotImplementedError
        If the optimizer name (or the scheduler, via ``get_schedule``) is not supported.
    """
    chex.assert_type(config["lr"], float)
    chex.assert_scalar_positive(config["lr"])

    # These hyper-parameters have defaults below (and are unused by SGD), so they are only
    # validated when the caller actually supplies them.
    for key in ("beta1", "beta2", "eps"):
        if key in config:
            chex.assert_type(config[key], float)
            chex.assert_scalar_positive(config[key])

    # 'grad_clip' is optional; a missing key and an explicit None both disable clipping.
    grad_clip = config.get("grad_clip")
    if grad_clip is not None:
        chex.assert_type(grad_clip, float)
        chex.assert_scalar_positive(grad_clip)

    optimizer_name = config["optimizer"].lower()
    lr_schedule = get_schedule(config)
    if optimizer_name == "adam":
        optimizer = optax.adam(
            learning_rate=lr_schedule,
            b1=config.get("beta1", 0.9),
            b2=config.get("beta2", 0.999),
            eps=config.get("eps", 1e-8),
        )

    elif optimizer_name == "adamw":
        optimizer = optax.adamw(
            learning_rate=lr_schedule,
            b1=config.get("beta1", 0.9),
            b2=config.get("beta2", 0.999),
            eps=config.get("eps", 1e-8),
            weight_decay=config.get("weight_decay", 1e-2),
        )

    elif optimizer_name == "sgd":
        optimizer = optax.sgd(
            learning_rate=lr_schedule, momentum=config.get("momentum", 0.0), nesterov=config.get("nesterov", False)
        )

    elif optimizer_name == "lamb":
        optimizer = optax.lamb(
            learning_rate=lr_schedule,
            b1=config.get("beta1", 0.9),
            b2=config.get("beta2", 0.999),
            eps=config.get("eps", 1e-6),
            weight_decay=config.get("weight_decay", 0.0),
        )

    else:
        raise NotImplementedError(f"Optimizer '{optimizer_name}' is not supported.")

    # Clip before the optimizer so that its update statistics see the clipped gradients.
    if grad_clip is not None:
        optimizer = optax.chain(optax.clip_by_global_norm(grad_clip), optimizer)
    return optimizer


def create_train_state(
    rng: jax.random.PRNGKey,
    model: nn.Module,
    optimizer: optax.GradientTransformation,
    input_shape: Union[int, Sequence[int]],
) -> train_state.TrainState:
    """
    Creates an initial `TrainState` for the given model and optimizer.

    The model is initialised on a dummy all-ones input of shape ``input_shape``. Only the
    ``"params"`` collection of the returned variables is kept in the train state.

    Parameters
    ----------
    rng : jax.random.PRNGKey
        Random key used for initializing the model parameters.
    model : nn.Module
        Flax model used for creating the initial state.
    optimizer : optax.GradientTransformation
        Optimizer object used for updating the model parameters.
    input_shape : int or Sequence[int]
        Shape of the dummy input used to initialize the model, passed straight to ``jnp.ones``.
        An int ``d`` gives a 1-D input of length ``d``; a tuple such as ``(batch, d)`` gives a
        batched input.

    Returns
    -------
    train_state.TrainState
        The initialized train state containing model parameters and optimizer.
    """
    dummy_input = jnp.ones(input_shape)
    params = model.init(rng, dummy_input)["params"]
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


def create_train_state_from_params(
    model: nn.Module, params: Dict[str, Any], optimizer: optax.GradientTransformation
) -> train_state.TrainState:
    """
    Wrap already-existing model parameters in a fresh `TrainState`.

    No initialisation of the model is performed: ``params`` is used as-is. The state's
    ``apply_fn`` is ``model.apply``, and ``optimizer`` is its gradient transformation.

    Parameters
    ----------
    model : nn.Module
        Flax model whose ``apply`` method becomes the state's ``apply_fn``.
    params : Dict[str, Any]
        Parameter tree to place in the state (e.g. loaded from a checkpoint).
    optimizer : optax.GradientTransformation
        Optimizer that will update ``params``.

    Returns
    -------
    train_state.TrainState
        Train state holding ``params``, ``model.apply`` and ``optimizer``.
    """
    apply_fn = model.apply
    state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=optimizer)
    return state


def global_norm(updates: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """
    Computes the global norm of gradients across a nested structure of tensors.

    The result is ``sqrt(sum_i sum(|x_i|**2))`` over every leaf ``x_i`` of the pytree. This is
    the Euclidean norm of all leaves flattened and concatenated. Complex leaves contribute
    their squared modulus, so the result is always real.

    Parameters
    ----------
    updates : Dict[str, jnp.ndarray]
        Pytree (e.g. a possibly nested dict) whose leaves are arrays, such as gradients.

    Returns
    -------
    jnp.ndarray
        Scalar global norm. An empty tree yields 0.
    """
    # |x|**2 == Re(x * conj(x)): identical to x**2 for real arrays, and correct for complex
    # ones, where jnp.square would return the complex value x**2 instead.
    squared_sums = (jnp.sum(jnp.real(x * jnp.conj(x))) for x in jax.tree_util.tree_leaves(updates))
    return jnp.sqrt(sum(squared_sums))


def clip_weights_icnn(params: Union[FrozenDict, Dict[str, Any]]) -> Union[FrozenDict, Dict[str, Any]]:
    """
    Clip the weights of an Input Convex Neural Network (ICNN).

    For every top-level entry whose name starts with 'Wz', the 'kernel' is replaced by
    ``max(kernel, 0)``, so those weights are non-negative. This is necessary to maintain the
    convexity property of the ICNN. All other entries are carried over unchanged.

    Parameters
    ----------
    params : FrozenDict or dict
        The ICNN parameters. Both ``FrozenDict`` and plain ``dict`` trees are accepted, since
        Flax may return either from ``Module.init`` depending on version and configuration.

    Returns
    -------
    FrozenDict or dict
        A new parameter tree with the same container type as ``params`` (a ``FrozenDict`` input
        gives a ``FrozenDict`` result, a ``dict`` input gives a ``dict``), with the relevant
        weights clipped to be non-negative. The input tree is not modified.
    """
    was_frozen = isinstance(params, FrozenDict)
    # ``unfreeze`` returns a fresh mutable dict tree for both FrozenDict and dict inputs, so
    # the assignments below never write into the caller's parameters.
    clipped = unfreeze(params)
    for name, layer in clipped.items():
        if name.startswith("Wz"):
            # Equivalent to jnp.clip(kernel, 0) without the deprecated ``a_min`` keyword.
            layer["kernel"] = jnp.maximum(layer["kernel"], 0)
    return freeze(clipped) if was_frozen else clipped


def penalize_weights_icnn(params: FrozenDict) -> jnp.ndarray:
    """
    Compute a penalty for negative weights in an ICNN.

    For every top-level entry whose name starts with 'Wz', the negative part of its 'kernel'
    (``relu(-kernel)``) is taken and its norm is computed with ``jnp.linalg.norm`` (the
    Frobenius norm when the kernel is 2-D). The penalty is the sum of these norms. Adding it to
    the training loss pushes those weights towards non-negative values, which the ICNN's
    convexity relies on.

    Parameters
    ----------
    params : FrozenDict
        A frozen dictionary containing the parameters of the ICNN.

    Returns
    -------
    jnp.ndarray
        Scalar penalty. If no entry name starts with 'Wz', the Python integer 0 is returned.
    """
    negative_part_norms = (
        jnp.linalg.norm(jax.nn.relu(-layer["kernel"]))
        for name, layer in params.items()
        if name.startswith("Wz")
    )
    return sum(negative_part_norms)
