import chex
import fiddle as fdl
import jax
import jax.numpy as jnp
import optax

from tabular_mvdrl import kernels
from tabular_mvdrl.agents.cat_projected_td import CatProjectedTDTrainer
from tabular_mvdrl.envs import mrp as mrp_envs
from tabular_mvdrl.utils import support_init
from tabular_mvdrl.utils.discrete_distributions import (
    SquaredMMDMetric,
    SupremalMetric,
    Wasserstein2Metric,
)


def dirichlet_prior(key: chex.PRNGKey, n: int, alpha: float) -> chex.Array:
    """Draw one sample from a symmetric Dirichlet distribution.

    Args:
        key: JAX PRNG key used for the draw.
        n: Length of the concentration vector (and of the sample).
        alpha: Concentration value shared by all ``n`` components.

    Returns:
        The array returned by ``jax.random.dirichlet`` for the concentration
        vector ``alpha * ones(n)``.
    """
    concentration = alpha * jnp.ones(n)
    return jax.random.dirichlet(key, concentration)


def base(
    seed: int = 0,
    num_states: int = 100,
    reward_dim: int = 2,
    dirichlet_alpha: float = 1.0,
    env_seed: int | None = None,
) -> fdl.Config[CatProjectedTDTrainer]:
    """Build the baseline ``CatProjectedTDTrainer`` config on a random MRP.

    The environment is configured through
    ``MarkovRewardProcess.from_independent_priors`` with a symmetric Dirichlet
    transition prior and a uniform cumulant prior on ``[-1, 1]`` of shape
    ``(reward_dim,)``.

    Args:
        seed: Seed forwarded to the trainer config.
        num_states: Number of states of the MRP; also the repetition count
            handed to ``support_init.repeated_map``.
        reward_dim: Dimension of the cumulants, of the support lattice and of
            the kernel's ``zero`` vector.
        dirichlet_alpha: Concentration passed to ``dirichlet_prior``.
        env_seed: Seed passed to the environment constructor. When ``None``
            it defaults to ``seed + 1``. An explicit ``0`` is honoured.

    Returns:
        An unbuilt ``fdl.Config`` for ``CatProjectedTDTrainer``.
    """
    # Use an identity check rather than truthiness so that ``env_seed=0`` is
    # treated as a real seed instead of being replaced by the default.
    if env_seed is None:
        env_seed = seed + 1
    return fdl.Config(
        CatProjectedTDTrainer,
        env=fdl.Config(
            mrp_envs.MarkovRewardProcess.from_independent_priors,
            env_seed,
            num_states=num_states,
            transition_prior=fdl.Partial(
                dirichlet_prior, n=num_states, alpha=dirichlet_alpha
            ),
            cumulant_prior=fdl.Partial(
                jax.random.uniform, minval=-1.0, maxval=1.0, shape=(reward_dim,)
            ),
        ),
        support_map_initializer=fdl.Config(
            support_init.repeated_map,
            fdl.Config(
                support_init.uniform_lattice,
                reward_dim,
                bins_per_dim=16,
                maxval=10.0,
                minval=0.0,
                # minval=-10.0,
            ),
            num_states,
        ),
        # num_steps=3000,
        num_steps=5000,
        seed=seed,
        write_metrics_interval_steps=100,
        eval_interval_steps=1000,
        optim=fdl.Config(
            optax.sgd,
            # NOTE: the decay window configured below (30_000 steps starting
            # at step 1_000) extends well past num_steps=5000.
            learning_rate=fdl.Config(
                optax.schedules.polynomial_schedule,
                # init_value=1e-2,
                init_value=1e-1,
                end_value=0.0,
                power=0.6,
                transition_steps=30_000,
                transition_begin=1_000,
            ),
        ),
        kernel=fdl.Partial(
            kernels.energy_distance, alpha=1.0, zero=jnp.zeros(reward_dim)
        ),
        discount=0.9,
        return_metric=fdl.Config(
            SupremalMetric, base_metric=fdl.Config(Wasserstein2Metric, epsilon=1e-4)
        ),
        dsf_metric=fdl.Config(SquaredMMDMetric, fdl.Partial(kernels.energy_distance)),
        ewp_steps=-1,
    )


def rowland(seed: int = 0) -> fdl.Config[CatProjectedTDTrainer]:
    """Config for the two-state ``RowlandTwoStateMRP`` with scalar rewards.

    Starts from ``base(seed=seed, num_states=2, reward_dim=1)`` and swaps the
    environment for a ``RowlandTwoStateMRP`` config that receives ``seed`` as
    its only positional argument.

    Args:
        seed: Seed used for both the trainer and the environment config.

    Returns:
        The modified trainer config.
    """
    trainer_cfg = base(seed=seed, num_states=2, reward_dim=1)
    trainer_cfg.env = fdl.Config(mrp_envs.RowlandTwoStateMRP, seed)
    return trainer_cfg


def rowland_multivariate(
    seed: int = 0, reward_dim: int = 2
) -> fdl.Config[CatProjectedTDTrainer]:
    """Config for the two-state ``RowlandTwoStateMRPMultivariateNaive`` MRP.

    Starts from ``base(seed=seed, num_states=2, reward_dim=reward_dim)`` and
    swaps the environment for a ``RowlandTwoStateMRPMultivariateNaive`` config
    built from ``reward_dim`` alone.

    Args:
        seed: Seed forwarded to ``base``. It is not passed to the environment
            config.
        reward_dim: Cumulant dimension, used by both ``base`` and the
            environment config.

    Returns:
        The modified trainer config.
    """
    cfg = base(seed=seed, num_states=2, reward_dim=reward_dim)
    cfg.env = fdl.Config(mrp_envs.RowlandTwoStateMRPMultivariateNaive, reward_dim)
    return cfg


def extrapolated_rowland() -> fdl.Config[CatProjectedTDTrainer]:
    """Config for ``ExtrapolatedRowlandMRP`` with four states and 2-D rewards.

    Starts from ``base(seed=0, num_states=4, reward_dim=2)`` and swaps the
    environment for an argument-free ``ExtrapolatedRowlandMRP`` config.

    Returns:
        The modified trainer config.
    """
    trainer_cfg = base(seed=0, num_states=4, reward_dim=2)
    trainer_cfg.env = fdl.Config(mrp_envs.ExtrapolatedRowlandMRP)
    # trainer_cfg.optim.learning_rate = 0.3
    return trainer_cfg


### FIDDLERS


def finite_horizon(cfg: fdl.Config[CatProjectedTDTrainer], horizon=4):
    """Fiddler: wrap the environment in ``FiniteHorizonTerminalRewardMRP``.

    Mutates ``cfg`` in place and returns nothing. The current environment
    config is built once to read its ``cumulants`` and ``num_states``. It is
    then wrapped (as a config, not as the built object) together with
    ``horizon``. The support map initializer is replaced by an explicit
    support equal to ``cumulants * discount ** (horizon - 1) / (1 - discount)``,
    repeated ``horizon * num_states`` times.

    Args:
        cfg: Trainer config to modify.
        horizon: Horizon passed to the wrapper environment.
    """
    inner_env_cfg = cfg.env
    inner_env = fdl.build(inner_env_cfg)
    cfg.env = fdl.Config(
        mrp_envs.FiniteHorizonTerminalRewardMRP, inner_env_cfg, horizon
    )

    gamma = cfg.discount
    # Keep the multiply-then-divide order so the values match exactly.
    discounted_cumulants = inner_env.cumulants * gamma ** (horizon - 1)
    atoms = discounted_cumulants / (1 - gamma)

    explicit_support_cfg = fdl.Config(support_init.explicit_support, atoms)
    cfg.support_map_initializer = fdl.Config(
        support_init.repeated_map,
        explicit_support_cfg,
        horizon * inner_env.num_states,
    )
    # cfg.optim.learning_rate /= 3


def terminal_reward(reward_dim=2, num_states=30, num_terminal=10):
    """Config for a ``TerminalRewardMRP`` wrapped around the random base MRP.

    The base environment from ``base(reward_dim=..., num_states=...)`` is
    built and handed, as a built object, to a ``TerminalRewardMRP`` config
    together with ``num_terminal``. That wrapper is built once more to read
    its ``cumulants`` and ``num_states``, which size the support lattice.

    Args:
        reward_dim: Cumulant dimension; also the dimension of the support
            lattice.
        num_states: Number of states of the base MRP.
        num_terminal: Second positional argument of ``TerminalRewardMRP``
            (presumably the number of terminal states; confirm against the
            environment class).

    Returns:
        The trainer config with the wrapped environment and a uniform lattice
        bounded by the min/max of the wrapper's cumulants along axis 0.
    """
    cfg = base(reward_dim=reward_dim, num_states=num_states)
    env = fdl.build(cfg.env)
    cfg.env = fdl.Config(mrp_envs.TerminalRewardMRP, env, num_terminal)
    term_rew_env = fdl.build(cfg.env)
    # Bounds taken along axis 0; assumes cumulants is laid out as
    # (state, reward_dim) -- TODO confirm.
    rmin = jnp.min(term_rew_env.cumulants, axis=0)
    rmax = jnp.max(term_rew_env.cumulants, axis=0)
    cfg.support_map_initializer = fdl.Config(
        support_init.repeated_map,
        support_init=fdl.Config(
            support_init.uniform_lattice,
            # Follow reward_dim rather than a hard-coded 2 so the lattice
            # dimension agrees with the bounds for any reward_dim.
            d=reward_dim,
            bins_per_dim=8,
            maxval=rmax,
            minval=rmin,
        ),
        n=term_rew_env.num_states,
    )
    # cfg.optim.learning_rate /= 3
    return cfg


def l1_kernel(cfg: fdl.Config[CatProjectedTDTrainer]):
    """Fiddler: set the trainer's kernel to ``kernels.l1``.

    Mutates ``cfg`` in place and returns nothing. The kernel is assigned as
    the function itself rather than wrapped in ``fdl.Partial``.

    Args:
        cfg: Trainer config to modify.
    """
    cfg.kernel = kernels.l1
