from collections import OrderedDict
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from typing import Any, Callable, Optional, Mapping
from jax import Array, config
from flax.training import train_state
from flax.core.frozen_dict import FrozenDict



def batch_mul(a: Array, b: Array) -> Array:
    """Multiply ``a`` and ``b`` slice by slice along their leading axis.

    ``jax.vmap`` maps over axis 0 of both inputs. Each pair ``a[i]``, ``b[i]``
    is multiplied with ordinary broadcasting. For example, a ``(B, D)`` array
    times a ``(B,)`` array scales row ``i`` by ``b[i]``.
    """
    def _slice_product(lhs, rhs):
        return lhs * rhs

    return jax.vmap(_slice_product)(a, b)



def loss_fn(
    params: FrozenDict[str, Any],
    apply_fn: Callable,
    x: jnp.ndarray,
    x_init: jnp.ndarray,
    features: Optional[jnp.ndarray],
    t: jnp.ndarray,
) -> jnp.ndarray:
    """Flow-matching loss along the straight path from ``x_init`` to ``x``.

    The model is evaluated at ``xt = (1 - t) * x_init + t * x``. Its output
    ``pred`` is scored against the path velocity ``ut = x - x_init`` with
    ``0.5 * pred**2 - ut * pred``. That expression equals
    ``0.5 * (pred - ut)**2`` minus a term independent of ``params``, so it
    yields the same gradients as the squared error.

    Args:
        params: Variables passed straight through as the first argument of
            ``apply_fn``.
        apply_fn: Model apply function.
        x: Path end point (``t = 1``).
        x_init: Path start point (``t = 0``).
        features: Optional conditioning array. When it is ``None`` or has no
            elements, the model is called as ``apply_fn(params, xt, t)``.
            Otherwise it is called as
            ``apply_fn(params, xt, features, t, training=True)``.
        t: Interpolation time. It must broadcast against ``x`` (e.g.
            ``(batch, 1)``; assumption, confirm with callers).

    Returns:
        Scalar loss, summed over every element (not averaged over the batch).
    """
    # NOTE(review): training=True is passed without any ``rngs``; if the
    # conditioned model uses dropout this call will need a dropout key.
    start, end = x_init, x
    velocity = end - start
    x_t = (1 - t) * start + t * end

    has_features = features is not None and getattr(features, "size", 0) != 0
    if not has_features:
        pred = apply_fn(params, x_t, t)
    else:
        pred = apply_fn(params, x_t, features, t, training=True)

    return jnp.sum(0.5 * pred ** 2 - batch_mul(velocity, pred))



@jax.jit
def train_step(
    state: train_state.TrainState,
    x: jnp.ndarray,
    x_init: jnp.ndarray,
    t: jnp.ndarray,
    features: Optional[jnp.ndarray] = None,
) -> tuple[jnp.ndarray, train_state.TrainState]:
    """Run one jit-compiled optimisation step on ``loss_fn``.

    Gradients are taken with respect to ``state.params`` only. The remaining
    arguments are forwarded to ``loss_fn`` unchanged.

    Returns:
        ``(loss, new_state)``. ``loss`` is the value before the update, and
        ``new_state`` is ``state`` after ``apply_gradients``.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    loss_value, param_grads = grad_fn(
        state.params, state.apply_fn, x, x_init, features, t
    )
    return loss_value, state.apply_gradients(grads=param_grads)



def get_lr_optimizer(
    num_steps: int,
    learning_rate: float,
    min_learning_rate: Optional[float] = None,
    clip: Optional[float] = None,
    schedule: str = "constant",
    weight_decay: float = 0.0,
) -> optax.GradientTransformation:
    """Build an AdamW optimizer with an optional LR schedule and value clipping.

    Args:
        num_steps: Decay horizon of the cosine schedule. Ignored for
            ``"constant"``.
        learning_rate: Constant learning rate, or the initial value of the
            cosine schedule.
        min_learning_rate: Final learning rate of the cosine schedule.
            Required when ``schedule == "cosine"``.
        clip: If given, each gradient element is clipped to ``[-clip, clip]``
            by ``optax.clip`` before the AdamW update.
        schedule: ``"constant"`` or ``"cosine"``.
        weight_decay: Decoupled weight-decay coefficient forwarded to
            ``optax.adamw`` for every schedule.

    Returns:
        ``optax.adamw(...)``, or ``optax.chain(optax.clip(...), optax.adamw(...))``
        when ``clip`` is set.

    Raises:
        ValueError: If ``schedule`` is unknown, or if ``"cosine"`` is requested
            without ``min_learning_rate``.
    """
    if schedule == "constant":
        lr = learning_rate
    elif schedule == "cosine":
        if min_learning_rate is None:
            raise ValueError("min_learning_rate must be provided for cosine schedule")
        # optax takes the floor as a fraction (alpha) of the initial value.
        lr = optax.cosine_decay_schedule(
            learning_rate, num_steps, min_learning_rate / learning_rate
        )
    else:
        raise ValueError(f"Unknown schedule: {schedule}")

    optimizer = optax.adamw(lr, weight_decay=weight_decay)
    if clip is None:
        return optimizer
    return optax.chain(optax.clip(max_delta=clip), optimizer)



def _make_tx_from_config(config: OrderedDict) -> optax.GradientTransformation:
    """Create the optimizer described by ``config``.

    Required entries:
        ``config["dataset"]["num_samples"]`` plus ``config["trainer"]`` keys
        ``batch``, ``learning_rate`` and ``epochs``.

    Optional ``config["trainer"]`` entries fall back to the defaults of
    ``get_lr_optimizer`` when absent. Those entries are ``min_learning_rate``
    (``None``), ``clip`` (``None``), ``schedule`` (``"constant"``) and
    ``weight_decay`` (``0.0``).

    The schedule length is ``epochs * (num_samples // batch)``, which counts
    only full batches per epoch.
    """
    trainer_cfg = config["trainer"]
    num_samples = config["dataset"]["num_samples"]
    batch_size = trainer_cfg["batch"]
    learning_rate = trainer_cfg["learning_rate"]
    epochs = trainer_cfg["epochs"]

    # NOTE(review): floor division assumes the data loader drops the final
    # partial batch; confirm against the training loop.
    num_steps = epochs * (num_samples // batch_size)
    return get_lr_optimizer(
        num_steps=num_steps,
        learning_rate=learning_rate,
        min_learning_rate=trainer_cfg.get("min_learning_rate"),
        clip=trainer_cfg.get("clip"),
        schedule=trainer_cfg.get("schedule", "constant"),
        weight_decay=trainer_cfg.get("weight_decay", 0.0),
    )



def create_mb_train_state(
    rng: Any,
    model: nn.Module,
    config: OrderedDict,
) -> train_state.TrainState:
    """Create TrainState for task 'mb'.

    ``model`` is initialised on dummy all-ones inputs:
    - ``x`` has shape ``(batch, 1)`` when ``config["general"]["cg_level"]``
      is ``"high"``, and ``(batch, 2)`` otherwise.
    - ``t`` has shape ``(batch, 1)``.

    The full variable dict returned by ``model.init`` is stored as
    ``params``. The optimizer comes from ``_make_tx_from_config(config)``.
    """
    cg_level = config["general"]["cg_level"]
    n_batch = config["trainer"]["batch"]

    state_dim = 2
    if cg_level == "high":
        state_dim = 1

    dummy_x = jnp.ones((n_batch, state_dim))
    dummy_t = jnp.ones((n_batch, 1))
    variables = model.init({"params": rng}, dummy_x, dummy_t)

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=_make_tx_from_config(config),
    )



def create_ala2_train_state(
    rng: Any,
    model: nn.Module,
    example: dict,
    config: OrderedDict,
) -> train_state.TrainState:
    """Create TrainState for task 'ala2'.

    ``example`` supplies the keyword arguments used to trace ``model.init``.
    The full variable dict it returns is stored as ``params``. The optimizer
    comes from ``_make_tx_from_config(config)``.
    """
    variables = model.init({"params": rng}, **example)
    optimizer = _make_tx_from_config(config)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=variables, tx=optimizer
    )