import chex
import fiddle as fdl
import jax
import jax.numpy as jnp
import optax

from tabular_mvdrl import kernels
from tabular_mvdrl.agents.ewp_td import EWPTDTrainer
from tabular_mvdrl.envs import mrp as mrp_envs
from tabular_mvdrl.utils.discrete_distributions import (
    SquaredMMDMetric,
    SupremalMetric,
    Wasserstein2Metric,
)


def dirichlet_prior(key: chex.PRNGKey, n: int, alpha: float) -> chex.Array:
    """Draw one sample from a symmetric Dirichlet distribution.

    Args:
        key: PRNG key handed to ``jax.random.dirichlet``.
        n: Length of the concentration vector.
        alpha: Value shared by every entry of the concentration vector.

    Returns:
        The result of ``jax.random.dirichlet`` for the concentration
        vector ``alpha * ones(n)``.
    """
    concentration = alpha * jnp.ones(n)
    return jax.random.dirichlet(key, concentration)


def base(
    seed: int = 0,
    num_states: int = 10,
    reward_dim: int = 2,
    dirichlet_alpha: float = 1.0,
    env_seed: int | None = None,
) -> fdl.Config[EWPTDTrainer]:
    """Assemble the default ``EWPTDTrainer`` configuration.

    The environment is configured through
    ``MarkovRewardProcess.from_independent_priors`` with a Dirichlet
    transition prior and a uniform cumulant prior on ``[-1, 1]``.

    Args:
        seed: Passed to the trainer as ``seed``; also used to derive the
            default environment seed.
        num_states: Number of states given to the environment factory and
            used as the size ``n`` of the Dirichlet transition prior.
        reward_dim: Dimension of the cumulant; the uniform prior is sampled
            with ``shape=(reward_dim,)``.
        dirichlet_alpha: ``alpha`` forwarded to ``dirichlet_prior``.
        env_seed: First positional argument of the environment factory.
            When ``None`` it defaults to ``seed + 1``. An explicit ``0`` is
            honoured.

    Returns:
        An unbuilt ``fdl.Config`` for ``EWPTDTrainer``.
    """
    # Compare against None explicitly: `env_seed or seed + 1` would silently
    # discard a legitimate env_seed of 0 because 0 is falsy.
    if env_seed is None:
        env_seed = seed + 1
    return fdl.Config(
        EWPTDTrainer,
        env=fdl.Config(
            mrp_envs.MarkovRewardProcess.from_independent_priors,
            env_seed,
            num_states=num_states,
            transition_prior=fdl.Partial(
                dirichlet_prior, n=num_states, alpha=dirichlet_alpha
            ),
            cumulant_prior=fdl.Partial(
                jax.random.uniform, minval=-1.0, maxval=1.0, shape=(reward_dim,)
            ),
        ),
        num_steps=5_000,
        seed=seed,
        write_metrics_interval_steps=100,
        eval_interval_steps=1000,
        num_atoms=256,
        optim=fdl.Config(
            optax.sgd,
            learning_rate=fdl.Config(
                optax.schedules.polynomial_schedule,
                # init_value=1e-2,
                init_value=1e-1,
                end_value=0.0,
                power=0.6,
                transition_steps=30_000,
                transition_begin=1_000,
            ),
        ),
        kernel=fdl.Partial(kernels.energy_distance, alpha=1.0),
        discount=0.9,
        return_metric=fdl.Config(
            SupremalMetric, base_metric=fdl.Config(Wasserstein2Metric, epsilon=1e-4)
        ),
        dsf_metric=fdl.Config(SquaredMMDMetric, fdl.Partial(kernels.energy_distance)),
    )


def rowland(
    seed: int = 0,
) -> fdl.Config[EWPTDTrainer]:
    """Return the base config with its env swapped for ``RowlandTwoStateMRP``.

    Args:
        seed: Forwarded to ``base`` and passed as the single positional
            argument of ``RowlandTwoStateMRP``.

    Returns:
        The config from ``base(seed=seed, num_states=2, reward_dim=1)`` with
        ``env`` replaced.
    """
    trainer_cfg = base(seed=seed, num_states=2, reward_dim=1)
    rowland_env = fdl.Config(mrp_envs.RowlandTwoStateMRP, seed)
    trainer_cfg.env = rowland_env
    return trainer_cfg


def rowland_multivariate(
    seed: int = 0, reward_dim: int = 2
) -> fdl.Config[EWPTDTrainer]:
    """Return the base config using ``RowlandTwoStateMRPMultivariateNaive``.

    Args:
        seed: Forwarded to ``base``.
        reward_dim: Forwarded to ``base`` and passed as the single positional
            argument of ``RowlandTwoStateMRPMultivariateNaive``.

    Returns:
        The config from ``base(seed=seed, num_states=2, reward_dim=reward_dim)``
        with ``env`` replaced.

    NOTE(review): an earlier revision built a ``(2, reward_dim)`` cumulant
    table here (row 0 filled with 2.0, row 1 with -1.0) and never used it.
    That dead computation has been removed; if the environment was meant to
    receive those cumulants, pass them explicitly.
    """
    cfg = base(seed=seed, num_states=2, reward_dim=reward_dim)
    cfg.env = fdl.Config(mrp_envs.RowlandTwoStateMRPMultivariateNaive, reward_dim)
    return cfg


# def extrapolated_rowland() -> fdl.Config[EWPTDTrainer]:
#     cfg = base(seed=0, num_states=4, reward_dim=2)
#     cfg.env = fdl.Config(mrp_envs.ExtrapolatedRowlandMRP)
#     cfg.num_atoms = 64
#     cfg.optim.learning_rate = 30.0
#     return cfg


### FIDDLERS: most functions below take an existing config and modify it in place.


def extrapolated_rowland(cfg: fdl.Config[EWPTDTrainer]):
    """Switch ``cfg`` to ``ExtrapolatedRowlandMRP`` with 64 atoms, in place.

    Args:
        cfg: Trainer config to mutate; its ``env`` and ``num_atoms`` fields
            are overwritten. Nothing is returned.
    """
    # An override of cfg.optim.learning_rate to 30.0 is currently disabled.
    replacement_env = fdl.Config(mrp_envs.ExtrapolatedRowlandMRP)
    cfg.env = replacement_env
    cfg.num_atoms = 64


def finite_horizon(cfg: fdl.Config[EWPTDTrainer], horizon=4):
    """Wrap ``cfg.env`` in a ``FiniteHorizonTerminalRewardMRP`` config, in place.

    Args:
        cfg: Trainer config to mutate; its current ``env`` value becomes the
            first positional argument of the wrapper config.
        horizon: Second positional argument of the wrapper config.
    """
    inner_env = cfg.env
    cfg.env = fdl.Config(
        mrp_envs.FiniteHorizonTerminalRewardMRP,
        inner_env,
        horizon,
    )


def terminal_reward(reward_dim=2, num_states=30, num_terminal=10):
    """Return the base config with its env wrapped in ``TerminalRewardMRP``.

    Args:
        reward_dim: Forwarded to ``base``.
        num_states: Forwarded to ``base``.
        num_terminal: Second positional argument of ``TerminalRewardMRP``.

    Returns:
        The config from ``base`` with ``env`` replaced by a
        ``TerminalRewardMRP`` config.
    """
    trainer_cfg = base(reward_dim=reward_dim, num_states=num_states)
    # The inner env is built immediately, so the wrapper config holds a
    # constructed env object rather than a nested, unbuilt config.
    built_env = fdl.build(trainer_cfg.env)
    wrapper = fdl.Config(mrp_envs.TerminalRewardMRP, built_env, num_terminal)
    trainer_cfg.env = wrapper
    return trainer_cfg


def l1_kernel(cfg: fdl.Config[EWPTDTrainer]):
    """Set the ``kernel`` field of ``cfg`` to ``kernels.l1``, in place.

    Args:
        cfg: Trainer config to mutate. Nothing is returned.
    """
    l1 = kernels.l1
    cfg.kernel = l1
