import jax
import optax
import wandb
import yaml
import jax.numpy as jnp
import flax.linen as nn
import orbax.checkpoint as obcheckpoint
from flax import struct
from functools import partial
from evosax import ParameterReshaper
from flax.training.train_state import TrainState as TrainState_flax


from envs.rollout import rollout
from models.agents import ActorCriticCont as ActorCritic
from models.mirror_map import param_tuner as param_tuner_init, MPO_map
from rl_training.ppo import TrainState
from rlhf_agents.losses import loss_dict
from envs.env_utils import get_env
from utils.logging import init_logger


# Length of each evaluation rollout; also used to scale the mean per-step
# reward back into a total reward over the rollout.
NUM_STEPS = 1_000


@struct.dataclass
class Logits:
    """Per-minibatch training diagnostics recorded when ``TRACKING`` is on.

    Instances may be constructed positionally, so keep the field order
    stable. The four logit fields hold summed sequence log-probs divided by
    the number of non-padded steps in each sequence. When no reference agent
    is used, the ``ref_*`` fields are built from constant ones.
    """

    loss: jnp.ndarray  # minibatch-mean preference loss
    ref_logit_chosen: jnp.ndarray  # reference model, chosen sequences
    ref_logit_rejected: jnp.ndarray  # reference model, rejected sequences
    logit_chosen: jnp.ndarray  # trained policy, chosen sequences
    logit_rejected: jnp.ndarray  # trained policy, rejected sequences


def make_train(config, env, env_params):
    """Build a function that trains an agent on pairwise preference data.

    Args:
        config: Mutable config dict. ``train`` reads ``ACTIVATION``,
            ``NORMALIZE_OBS``, ``DATASET_SIZE``, ``UPDATE_EPOCHS``,
            ``MINIBATCH_SIZE``, ``ANNEAL_LR``, ``LR``, ``LR_END`` (only when
            annealing), ``MAX_GRAD_NORM``, ``STAGES``, ``LOSS_TYPE`` and
            ``TRACKING``, and writes ``TOT_NUM_UPDATES`` back into it.
        env: Environment used for the action/observation space shapes.
        env_params: Parameters passed to ``env``'s space queries.

    Returns:
        ``train(rng, dataset, mmap_state=None, mmap_params=None,
        start_agent=None, norm_stats=None)`` returning
        ``(train_state, metrics)``.
    """

    def train(
        rng,
        dataset,
        mmap_state=None,
        mmap_params=None,
        start_agent=None,
        norm_stats=None,
    ):
        """Train an agent on ``dataset`` and return ``(train_state, metrics)``.

        Args:
            rng: PRNG key driving initialisation and per-epoch shuffling.
            dataset: Pytree whose leaves share a leading "pair" axis. ``obs``
                and ``actions`` are indexed ``[:, 0]`` for the chosen and
                ``[:, 1]`` for the rejected sequence. Padded steps carry
                ``-100`` in ``obs`` (assumes the padding value is present in
                feature 0 — TODO confirm).
            mmap_state: Object with a ``train_state`` attribute; required
                when ``LOSS_TYPE == "mpo"``.
            mmap_params: Parameters swapped into ``mmap_state.train_state``.
            start_agent: Optional ``(params, norm_stats)`` pair to warm-start
                from. When ``None``, a fresh network is initialised.
            norm_stats: Observation-normalisation statistics for a fresh
                network (ignored when ``start_agent`` is given).

        Returns:
            The final train state and per-minibatch metrics stacked as
            ``(epochs, minibatches, ...)``: ``Logits`` when ``TRACKING`` is
            set, otherwise the loss alone.
        """
        network = ActorCritic(
            env.action_space(env_params).shape[0],
            activation=config["ACTIVATION"],
            normalize=config["NORMALIZE_OBS"],
        )
        config["TOT_NUM_UPDATES"] = (
            config["DATASET_SIZE"] * config["UPDATE_EPOCHS"] // config["MINIBATCH_SIZE"]
        )
        if config["ANNEAL_LR"]:
            # Constant LR for the first 10% of updates, then linear
            # interpolation toward LR_END over TOT_NUM_UPDATES steps.
            learning_rate = optax.linear_schedule(
                init_value=config["LR"],
                end_value=config["LR_END"],
                transition_steps=config["TOT_NUM_UPDATES"],
                transition_begin=config["TOT_NUM_UPDATES"] // 10,
            )
        else:
            learning_rate = config["LR"]
        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=learning_rate, eps=1e-5),
        )

        # Compare with None rather than relying on truthiness: a pickled
        # (params, norm_stats) pair loaded as a NumPy object array has no
        # unambiguous truth value.
        if start_agent is not None:
            train_state = TrainState.create(
                apply_fn=network.apply,
                params=start_agent[0],
                norm_stats=start_agent[1],
                tx=tx,
            )
        else:
            rng, _rng = jax.random.split(rng)
            init_x = jnp.zeros(env.observation_space(env_params).shape)
            train_state = TrainState.create(
                apply_fn=network.apply,
                params=network.init(_rng, init_x)["params"],
                norm_stats=norm_stats,
                tx=tx,
            )

        # Per row: sum the step log-probs, zeroing steps whose obs is padding.
        aggregate_fun = jax.vmap(
            lambda logits, obs: jnp.sum(
                jnp.where(obs == -100, jnp.zeros_like(logits), logits)
            )
        )

        def _sequence_logprob(apply_fn, variables, minibatch, idx):
            """Summed action log-prob of sequence ``idx`` (0=chosen, 1=rejected)."""
            pi, _ = apply_fn(variables, minibatch.obs[:, idx])
            # The same feature-0 padding mask is used for both sequences, so
            # chosen/rejected log-probs and lengths stay consistent.
            return aggregate_fun(
                pi.log_prob(minibatch.actions[:, idx]), minibatch.obs[:, idx, :, 0]
            )

        def _sequence_length(minibatch, idx):
            """Count of non-padded steps; the epsilon guards later divisions."""
            return jnp.sum(minibatch.obs[:, idx, :, 0] != -100, axis=-1) + 1e-6

        def _configure_loss(loss):
            """Bind the mirror map (for ``"mpo"``) and ``config`` into ``loss``."""
            if config["LOSS_TYPE"] == "mpo":
                mmap = mmap_state.train_state.replace(params=mmap_params)
                loss = partial(loss, mirror_map=mmap)
            return partial(loss, config=config)

        def _make_update_epoch(objective, with_ref):
            """Return a ``lax.scan`` body that runs one shuffled training epoch.

            The carry is ``(train_state, rng)``, plus an unchanging
            ``ref_agent`` when ``with_ref`` is True. The scanned value is the
            epoch index. Without a reference agent, the reference log-probs
            are constant ones.
            """

            def update_epoch(carry_state, n_epoch):
                if with_ref:
                    train_state, rng, ref_agent = carry_state
                else:
                    train_state, rng = carry_state
                rng, _rng = jax.random.split(rng)

                def _update_minibatch(carry, minibatch):
                    train_state, n_minibatch = carry

                    def loss_fn(params, minibatch):
                        variables = {
                            "params": params,
                            "norm_stats": train_state.norm_stats,
                        }
                        logit_chosen = _sequence_logprob(
                            train_state.apply_fn, variables, minibatch, 0
                        )
                        logit_rejected = _sequence_logprob(
                            train_state.apply_fn, variables, minibatch, 1
                        )
                        if with_ref:
                            # Pass the reference agent its own normalisation
                            # stats, exactly as done for the trained policy.
                            ref_variables = {
                                "params": ref_agent.params,
                                "norm_stats": ref_agent.norm_stats,
                            }
                            ref_logit_chosen = _sequence_logprob(
                                ref_agent.apply_fn, ref_variables, minibatch, 0
                            )
                            ref_logit_rejected = _sequence_logprob(
                                ref_agent.apply_fn, ref_variables, minibatch, 1
                            )
                        else:
                            ref_logit_chosen = jnp.ones_like(logit_chosen)
                            ref_logit_rejected = jnp.ones_like(logit_rejected)
                        # Fraction of TOT_NUM_UPDATES elapsed, for time-aware
                        # losses. NOTE(review): n_epoch restarts at 0 in the
                        # second stage of a two-stage run — confirm intended.
                        time = (
                            n_minibatch
                            + n_epoch * config["DATASET_SIZE"] // config["MINIBATCH_SIZE"]
                        ) / config["TOT_NUM_UPDATES"]
                        length_chosen = _sequence_length(minibatch, 0)
                        length_rejected = _sequence_length(minibatch, 1)
                        loss_value = jax.vmap(
                            objective, in_axes=(0, 0, 0, 0, 0, 0, None)
                        )(
                            logit_chosen,
                            logit_rejected,
                            ref_logit_chosen,
                            ref_logit_rejected,
                            length_chosen,
                            length_rejected,
                            time,
                        ).mean()
                        return loss_value, (
                            logit_chosen / length_chosen,
                            logit_rejected / length_rejected,
                            ref_logit_chosen / length_chosen,
                            ref_logit_rejected / length_rejected,
                        )

                    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
                    (loss_value, info), grad = grad_fn(train_state.params, minibatch)
                    train_state = train_state.apply_gradients(grads=grad)

                    if config["TRACKING"]:
                        metrics = Logits(
                            loss=loss_value,
                            ref_logit_chosen=info[2],
                            ref_logit_rejected=info[3],
                            logit_chosen=info[0],
                            logit_rejected=info[1],
                        )
                    else:
                        metrics = loss_value

                    return (train_state, n_minibatch + 1), metrics

                batch_size = dataset.obs.shape[0]
                minibatch_size = config["MINIBATCH_SIZE"]
                num_minibatches = batch_size // minibatch_size
                # Shuffle, then drop the remainder so that every minibatch
                # holds exactly MINIBATCH_SIZE pairs. Otherwise a plain reshape
                # could fail or yield minibatches of the wrong size.
                permutation = jax.random.permutation(_rng, batch_size)[
                    : num_minibatches * minibatch_size
                ]
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        jnp.take(x, permutation, axis=0),
                        [num_minibatches, minibatch_size] + list(x.shape[1:]),
                    ),
                    dataset,
                )
                (train_state, _), metrics = jax.lax.scan(
                    _update_minibatch, (train_state, 0), minibatches
                )
                if with_ref:
                    return (train_state, rng, ref_agent), metrics
                return (train_state, rng), metrics

            return update_epoch

        if config["STAGES"] == 1:
            update_epoch = _make_update_epoch(
                _configure_loss(loss_dict[config["LOSS_TYPE"]]), with_ref=False
            )
            (train_state, _), metrics = jax.lax.scan(
                update_epoch, (train_state, rng), jnp.arange(config["UPDATE_EPOCHS"])
            )
        else:
            # Stage 1: "sft" loss for all but the last two epochs.
            sft_epoch = _make_update_epoch(loss_dict["sft"], with_ref=False)
            # Keep the advanced rng so that stage 2 does not replay stage 1's
            # shuffles.
            (train_state, rng), metrics1 = jax.lax.scan(
                sft_epoch,
                (train_state, rng),
                jnp.arange(config["UPDATE_EPOCHS"] - 2),
            )
            # Stage 2: the configured loss against a copy of the stage-1
            # agent, which stays unchanged through the scan.
            ref_agent = jax.tree_util.tree_map(lambda x: x.copy(), train_state)
            pref_epoch = _make_update_epoch(
                _configure_loss(loss_dict[config["LOSS_TYPE"]]), with_ref=True
            )
            (train_state, _, _), metrics2 = jax.lax.scan(
                pref_epoch, (train_state, rng, ref_agent), jnp.arange(2)
            )
            metrics = jax.tree_util.tree_map(
                lambda x, y: jnp.concatenate([x, y], axis=0), metrics1, metrics2
            )

        return train_state, metrics

    return train


class MapState(struct.PyTreeNode):
    """Pytree wrapper holding the mirror-map network's flax ``TrainState``.

    For the ``"mpo"`` loss, ``make_train`` reads ``.train_state`` and swaps
    in the parameters it is given via ``.replace(params=...)``.
    """

    train_state: TrainState_flax = struct.field(pytree_node=True)


def train_rlhf_agent(args):
    """Train agents from preference data, evaluate them and report returns.

    Loads the environment, the training config (from YAML or from the tuning
    arguments in ``args``), an optional reference agent, the preference
    dataset and, for ``loss_type == "mpo"``, a mirror map. It then trains and
    evaluates ``n_runs`` agents in a single ``jax.vmap`` call. Each agent
    trains on a with-replacement resample of the dataset. The function prints
    the mean total reward with its standard error and logs both to wandb
    unless ``args.nolog`` is set.

    Args:
        args: Parsed command-line namespace: seed, environment/dataset
            selection, loss and mirror-map options and, when tuning, the
            training hyper-parameters.
    """
    rng = jax.random.PRNGKey(args.seed)

    # --- Initialize env ---
    env, env_params, config_env = get_env(args.env_name)

    # --- Load config ---
    if args.not_wandb_tuning:
        # main_folder_path is concatenated directly, so it is expected to end
        # with a path separator.
        path = f"{args.main_folder_path}atari_rlhf/config.yaml"
        with open(path, "r") as file:
            rlhf_config = yaml.safe_load(file)
    else:
        # NOTE(review): make_train also reads NORMALIZE_OBS, which this dict
        # does not provide — confirm it is supplied elsewhere.
        rlhf_config = {
            "LR": args.LR,
            "LR_END": args.LR_END,
            "UPDATE_EPOCHS": args.UPDATE_EPOCHS,
            "ANNEAL_LR": args.ANNEAL_LR,
            "MAX_GRAD_NORM": args.MAX_GRAD_NORM,
            "BETA": args.BETA,
            "ALPHA": args.ALPHA,
            "MINIBATCH_SIZE": args.MINIBATCH_SIZE,
        }
    # Plain dict: wandb-only kwargs such as allow_val_change must not be
    # passed here, or they become bogus config keys.
    rlhf_config.update(
        {
            "ACTIVATION": config_env["ACTIVATION"],
            "DATASET_SIZE": args.num_data_points,
            "LOSS_TYPE": args.loss_type,
            "TRACKING": args.tracking,
            "STAGES": args.n_stages,
        }
    )
    # Anneal the learning rate unless the config sets ANNEAL_LR explicitly.
    rlhf_config.setdefault("ANNEAL_LR", True)
    rlhf_config["UPDATE_EPOCHS"] = int(
        rlhf_config["UPDATE_EPOCHS"] * args.update_epochs_multiplier
    )

    # Init logger
    if not args.nolog:
        init_logger(args, rlhf_config)

    # Load agent
    if args.reference_agent is None:
        ref_agent = None
    else:
        print("Loading reference agent", args.reference_agent)
        # NOTE(review): allow_pickle=True can execute arbitrary code from the
        # file; only load trusted checkpoints.
        ref_agent = jnp.load(args.reference_agent, allow_pickle=True)

    # Load dataset
    dataset_path = f"{args.dataset_path}/{args.data_type}/{args.env_name}/dataset_{args.reward_agent1}vs{args.reward_agent2}.npz"
    dataset = jnp.load(dataset_path, allow_pickle=True).item()

    # Load mirror map
    if args.loss_type == "mpo":
        mmap_net = MPO_map(
            num_hidden_units=args.mmap_net_width,
            temporally_aware=args.temporally_aware,
            parametrised_reward_model=args.parametrised_reward_model,
            add_logsimoid_bias=args.add_logsimoid_bias,
            add_sft_bias=args.add_sft_bias,
            add_dpo_bias=args.add_dpo_bias,
            sft_term=args.sft_term,
        )
        params_init = mmap_net.init(jax.random.PRNGKey(0), 1)["params"]
        map_train_state = TrainState_flax.create(
            apply_fn=mmap_net.apply, params=params_init, tx=optax.adam(1e-3)
        )
        if args.map_location:
            # The checkpoint stores a flat ES mean vector. It is reshaped
            # using the structure of the freshly initialised params.
            param_reshaper = ParameterReshaper(map_train_state.params, n_devices=1)
            orbax_checkpointer = obcheckpoint.PyTreeCheckpointer()
            flat_mean = orbax_checkpointer.restore(args.map_location)["es_state"]["mean"]
            mmap_params = param_reshaper.reshape_single(flat_mean)
            mmap_params = param_tuner_init(mmap_params, single=True)
            map_train_state = map_train_state.replace(params=mmap_params)
        mirror_map = MapState(train_state=map_train_state)
    else:
        mirror_map = None

    # --- Initialize candidate fitness fun ---
    agent_train_fn = make_train(config=rlhf_config, env=env, env_params=env_params)
    # NOTE(review): unlike the training network, this one is sized by
    # env.num_actions, is built without normalize=..., and receives only
    # params (no norm_stats) — confirm the two are equivalent.
    network = ActorCritic(
        env.num_actions,
        activation=config_env["ACTIVATION"],
    )
    rollout_eval = partial(
        rollout,
        num_envs=args.num_eval_agents,
        num_steps=NUM_STEPS,
        env=env,
        env_params=env_params,
        network=network,
        return_reward=True,
        without_restart=True,
    )

    def sample_dataset(rng, dataset, num_samples=1):
        """Draw ``num_samples`` pairs uniformly with replacement."""
        indices = jax.random.randint(rng, (num_samples,), 0, dataset.obs.shape[0])
        return jax.tree_util.tree_map(lambda x: x[indices], dataset)

    def _compute_candidate_fitness(rng, preferences):
        """Train one agent on a resample of ``preferences`` and evaluate it.

        Returns:
            ``(total_reward, training_metrics, reference_total_reward)``.
            The reference reward is currently a zero placeholder.
        """
        rng_data, rng_train, rng_eval = jax.random.split(rng, 3)

        preferences = sample_dataset(rng_data, preferences, args.num_data_points)

        # Scalar normalisation statistics over every non-padded obs entry.
        valid = preferences.obs != -100
        n_valid_point = jnp.sum(jnp.where(valid, 1, 0)) + 1e-6  # avoid division by zero
        dataset_mean = jnp.sum(jnp.where(valid, preferences.obs, 0)) / n_valid_point
        dataset_var = (
            jnp.sum(jnp.where(valid, (preferences.obs - dataset_mean) ** 2, 0))
            / n_valid_point
        )
        norm_stats = {"mean": dataset_mean, "var": dataset_var, "count": n_valid_point}

        map_params = mirror_map.train_state.params if mirror_map is not None else None
        train_state, metrics = agent_train_fn(
            rng_train,
            preferences,
            mirror_map,
            map_params,
            start_agent=ref_agent,
            norm_stats=norm_stats,
        )

        # Mean per-step reward times NUM_STEPS gives the rollout total.
        _, _, all_rewards, _ = rollout_eval(
            agent_params=train_state.params,
            rng=rng_eval,
        )
        all_rewards_ref = jnp.zeros_like(all_rewards)

        return (
            all_rewards.mean() * NUM_STEPS,
            metrics,
            all_rewards_ref.mean() * NUM_STEPS,
        )

    compute_candidate_fitness = partial(_compute_candidate_fitness, preferences=dataset)

    n_runs = 25
    rngs = jax.random.split(rng, n_runs)
    returns, _metrics, _returns_ref = jax.jit(jax.vmap(compute_candidate_fitness))(
        rngs
    )
    avg_return = returns.mean()
    # Standard error of the mean across runs (logged as "std_tot_reward").
    std_err = returns.std() / jnp.sqrt(n_runs)
    print("Avg_tot_reward: ", avg_return, "$\\pm$", std_err)
    # The logger is only initialised when nolog is unset (presumably via
    # wandb.init inside init_logger), so skip wandb.log otherwise.
    if not args.nolog:
        wandb.log(
            {
                "avg_tot_reward": avg_return,
                "std_tot_reward": std_err,
            }
        )
