import chex
import fiddle as fdl
import jax
import jax.numpy as jnp
import optax

from tabular_mvdrl import kernels
from tabular_mvdrl.agents.cat_td import CatTDTrainer
from tabular_mvdrl.envs import mrp as mrp_envs
from tabular_mvdrl.utils import support_init
from tabular_mvdrl.utils.discrete_distributions import (
    SquaredMMDMetric,
    SupremalMetric,
    Wasserstein2Metric,
)


def dirichlet_prior(key: chex.PRNGKey, n: int, alpha: float) -> chex.Array:
    """Sample a probability vector of length ``n`` from a symmetric Dirichlet.

    Args:
        key: JAX PRNG key used for the draw.
        n: Number of categories (length of the returned vector).
        alpha: Shared concentration value used for every category.

    Returns:
        A length-``n`` sample from Dirichlet(alpha, ..., alpha).
    """
    concentration = jnp.ones(n) * alpha
    return jax.random.dirichlet(key, concentration)


def base(
    seed: int = 0,
    num_states: int = 10,
    reward_dim: int = 2,
    dirichlet_alpha: float = 1.0,
    env_seed: int | None = None,
) -> fdl.Config[CatTDTrainer]:
    """Default ``CatTDTrainer`` config on a randomly generated reward process.

    The environment is ``MarkovRewardProcess.from_independent_priors`` with:
    - a symmetric Dirichlet(``dirichlet_alpha``) transition prior over
      ``num_states`` outcomes;
    - cumulants drawn uniformly from [-1, 1] with shape ``(reward_dim,)``.

    The support initializer repeats (over ``num_states``) a uniform lattice
    with 8 bins per dimension spanning [-15, 15] in each of the
    ``reward_dim`` dimensions.

    Args:
        seed: Trainer seed. Also used to derive ``env_seed`` when that is None.
        num_states: Number of states of the random reward process.
        reward_dim: Dimensionality of the cumulants and of the support lattice.
        dirichlet_alpha: Concentration of the symmetric Dirichlet transition
            prior.
        env_seed: Seed for environment generation. Defaults to ``seed + 1``
            when None. An explicit value, including 0, is used as given.

    Returns:
        An unbuilt ``fdl.Config`` for ``CatTDTrainer``.
    """
    # Test against None explicitly: ``env_seed or seed + 1`` would silently
    # discard an explicitly requested ``env_seed=0``.
    if env_seed is None:
        env_seed = seed + 1
    return fdl.Config(
        CatTDTrainer,
        env=fdl.Config(
            mrp_envs.MarkovRewardProcess.from_independent_priors,
            env_seed,
            num_states=num_states,
            transition_prior=fdl.Partial(
                dirichlet_prior, n=num_states, alpha=dirichlet_alpha
            ),
            cumulant_prior=fdl.Partial(
                jax.random.uniform, minval=-1.0, maxval=1.0, shape=(reward_dim,)
            ),
        ),
        signed=False,
        support_map_initializer=fdl.Config(
            support_init.repeated_map,
            fdl.Config(
                support_init.uniform_lattice,
                reward_dim,
                bins_per_dim=8,
                minval=-15.0,
                maxval=15.0,
            ),
            num_states,
        ),
        num_steps=10_000,
        seed=seed,
        write_metrics_interval_steps=100,
        eval_interval_steps=1000,
        # Alternative optimizer: SGD with a polynomially decaying step size.
        # optim=fdl.Config(
        #     optax.sgd,
        #     learning_rate=fdl.Config(
        #         optax.schedules.polynomial_schedule,
        #         init_value=1,
        #         end_value=0.0,
        #         power=0.3,
        #         transition_steps=30_000,
        #         transition_begin=1_000,
        #     ),
        # ),
        optim=fdl.Config(optax.sgd, learning_rate=0.01),
        kernel=fdl.Partial(kernels.energy_distance, alpha=1.0),
        discount=0.9,
        return_metric=fdl.Config(
            SupremalMetric, base_metric=fdl.Config(Wasserstein2Metric, epsilon=1e-4)
        ),
        dsf_metric=fdl.Config(SquaredMMDMetric, fdl.Partial(kernels.energy_distance)),
    )


def rowland(
    seed: int = 0,
) -> fdl.Config[CatTDTrainer]:
    """Two-state, scalar-reward config using ``RowlandTwoStateMRP``.

    Starts from ``base`` with 2 states and 1-dimensional rewards, then swaps
    the environment for ``RowlandTwoStateMRP`` constructed with ``seed``.

    Args:
        seed: Used both for the trainer (via ``base``) and for the environment.

    Returns:
        An unbuilt ``fdl.Config`` for ``CatTDTrainer``.
    """
    two_state_env = fdl.Config(mrp_envs.RowlandTwoStateMRP, seed)
    config = base(reward_dim=1, num_states=2, seed=seed)
    config.env = two_state_env
    return config


def rowland_multivariate(
    seed: int = 0, reward_dim: int = 2
) -> fdl.Config[CatTDTrainer]:
    """Two-state config using ``RowlandTwoStateMRPMultivariateNaive``.

    Starts from ``base`` with 2 states and ``reward_dim``-dimensional rewards,
    then replaces the environment with ``RowlandTwoStateMRPMultivariateNaive``.

    Args:
        seed: Forwarded to ``base`` (trainer seed). NOTE(review): the
            environment is constructed from ``reward_dim`` only and does not
            receive this seed.
        reward_dim: Reward dimensionality, used by both ``base`` and the
            environment.

    Returns:
        An unbuilt ``fdl.Config`` for ``CatTDTrainer``.
    """
    # A previously computed per-state cumulant array was never used here;
    # the environment is fully specified by ``reward_dim``.
    cfg = base(seed=seed, num_states=2, reward_dim=reward_dim)
    cfg.env = fdl.Config(mrp_envs.RowlandTwoStateMRPMultivariateNaive, reward_dim)
    return cfg


# def extrapolated_rowland() -> fdl.Config[CatTDTrainer]:
#     # cfg = base(seed=0, num_states=4, reward_dim=2)
#     cfg = base()
#     cfg.env = fdl.Config(mrp_envs.ExtrapolatedRowlandMRP)
#     cfg.optim.learning_rate = 0.3
#     cfg.num_steps = 20_000
#     cfg.support_map_initializer = (
#         fdl.Config(
#             support_init.repeated_map,
#             support_init=fdl.Config(
#                 support_init.uniform_lattice,
#                 d=2,
#                 bins_per_dim=8,
#                 maxval=40.0,
#                 minval=-5.0,
#             ),
#             n=4,
#         ),
#     )
#     return cfg


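### FIDDLERS: functions that take an existing config and modify it in place.
### (`terminal_reward` is an exception: it builds and returns a new config.)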


def extrapolated_rowland(cfg: fdl.Config[CatTDTrainer]):
    """Fiddler: switch ``cfg`` to the ``ExtrapolatedRowlandMRP`` environment.

    Mutates ``cfg`` in place:
    - replaces the environment;
    - sets the learning rate to 0.3;
    - sets ``num_steps`` to 20_000;
    - installs a support map that repeats, over 4 states, a 2-D uniform
      lattice with 8 bins per dimension on [-5, 40].

    The hard-coded sizes match ``base(seed=0, num_states=4, reward_dim=2)``.
    NOTE(review): confirm they agree with ``ExtrapolatedRowlandMRP``.

    Args:
        cfg: Trainer config to modify. Its ``optim`` must expose a
            ``learning_rate`` attribute.
    """
    cfg.env = fdl.Config(mrp_envs.ExtrapolatedRowlandMRP)
    cfg.optim.learning_rate = 0.3
    cfg.num_steps = 20_000
    # Assign the Config itself. Wrapping it in ``( ..., )`` would produce a
    # one-element tuple, which fdl.build would pass to the trainer as-is.
    cfg.support_map_initializer = fdl.Config(
        support_init.repeated_map,
        support_init=fdl.Config(
            support_init.uniform_lattice,
            d=2,
            bins_per_dim=8,
            maxval=40.0,
            minval=-5.0,
        ),
        n=4,
    )


def finite_horizon(cfg: fdl.Config[CatTDTrainer], horizon=4):
    """Fiddler: wrap the configured environment in a finite-horizon MRP.

    Builds the current environment once to read its cumulants and state count.
    Then, in place on ``cfg``:
    - replaces ``cfg.env`` with a ``FiniteHorizonTerminalRewardMRP`` config
      that wraps the previous environment config;
    - installs an explicit support built from the scaled cumulants, repeated
      over ``horizon * num_states`` entries;
    - divides the learning rate by 3.

    Args:
        cfg: Trainer config to modify. Reads ``cfg.discount`` and
            ``cfg.optim.learning_rate``.
        horizon: Horizon passed to ``FiniteHorizonTerminalRewardMRP``.
    """
    inner_env_cfg = cfg.env
    inner_env = fdl.build(inner_env_cfg)
    cfg.env = fdl.Config(
        mrp_envs.FiniteHorizonTerminalRewardMRP, inner_env_cfg, horizon
    )
    # Support atoms: the inner environment's cumulants scaled by
    # discount ** (horizon - 1) / (1 - discount).
    atoms = inner_env.cumulants * cfg.discount ** (horizon - 1) / (1 - cfg.discount)
    # Presumably the wrapped process has one state per (step, inner state).
    repeat_count = horizon * inner_env.num_states
    cfg.support_map_initializer = fdl.Config(
        support_init.repeated_map,
        fdl.Config(support_init.explicit_support, atoms),
        repeat_count,
    )
    cfg.optim.learning_rate /= 3


def terminal_reward(
    num_states=30, num_terminal=10, seed: int = 0, reward_dim: int = 2
) -> fdl.Config[CatTDTrainer]:
    """Config for a ``TerminalRewardMRP`` built on top of ``base``'s random MRP.

    Unlike the fiddlers around it, this builds and returns a new config. The
    random environment from ``base`` is built eagerly and wrapped in
    ``TerminalRewardMRP``. That wrapped environment is then built once more,
    so the support lattice can span the per-dimension range of its cumulants.

    Args:
        num_states: Number of states of the underlying random MRP.
        num_terminal: Second argument to ``TerminalRewardMRP``.
            NOTE(review): presumably the number of terminal states; confirm.
        seed: Forwarded to ``base``.
        reward_dim: Reward dimensionality. Forwarded to ``base`` and used as
            the lattice dimension ``d``. The default of 2 matches the
            previously hard-coded value.

    Returns:
        An unbuilt ``fdl.Config`` for ``CatTDTrainer``.
    """
    cfg = base(seed=seed, num_states=num_states, reward_dim=reward_dim)
    env = fdl.build(cfg.env)
    cfg.env = fdl.Config(mrp_envs.TerminalRewardMRP, env, num_terminal)
    term_rew_env = fdl.build(cfg.env)
    # Per-dimension bounds of the cumulants, reduced over axis 0
    # (assumed to be the state axis).
    rmin = jnp.min(term_rew_env.cumulants, axis=0)
    rmax = jnp.max(term_rew_env.cumulants, axis=0)
    cfg.support_map_initializer = fdl.Config(
        support_init.repeated_map,
        support_init=fdl.Config(
            support_init.uniform_lattice,
            d=reward_dim,
            bins_per_dim=8,
            maxval=rmax,
            minval=rmin,
        ),
        n=term_rew_env.num_states,
    )
    return cfg


def l1_kernel(cfg: fdl.Config[CatTDTrainer]):
    """Fiddler: set the trainer's ``kernel`` to ``kernels.l1`` (in place).

    The function object itself is assigned, not an ``fdl.Partial``. Only
    ``cfg.kernel`` changes; ``cfg.dsf_metric`` keeps its own kernel.
    """
    replacement = kernels.l1
    cfg.kernel = replacement
