import jax
import jax.numpy as jnp
import jax.random as jr
from functools import partial

from dataset_gen.judge import judge
from envs.rollout import rollout
from models.agents import ActorCriticCont as ActorCritic
from envs.env_utils import get_env


def gen_dataset(rng, args, agent_params=None, indeces=None):
    """Roll out two agents and return the judge's preferences over paired rollouts.

    Two stacked sets of agent parameters are rolled out in parallel with
    ``jax.vmap``; each rollout's rewards are summed, optionally corrupted and
    shuffled, and the resulting pairs are passed to ``judge``.

    Args:
        rng: JAX PRNG key. It is split into four independent keys for the
            rollouts, the judge, the reward noise and the shuffle.
        args: Configuration object. This function reads ``env_name``,
            ``agent1``, ``agent2`` (only when ``agent_params`` is None),
            ``num_data_points``, ``noise`` and ``shuffle_agents``; the whole
            object is also forwarded to ``judge``.
        agent_params: Optional ``(params, norm_stats)`` pair. When given, the
            same pair is used for both agents, so the two rollouts differ only
            through their PRNG keys. When None, the two agents are loaded from
            ``args.agent1`` and ``args.agent2``.
        indeces: Forwarded unchanged to ``get_env``.

    Returns:
        The output of ``judge`` vmapped over the leading axis of the paired
        observations, actions and total rewards.
    """
    rng_rollout, rng_judge, rng_noise, rng_shuffle = jr.split(rng, 4)

    # Rollout length. The same value scales the mean reward in the log line
    # below, so both uses are tied to one name.
    num_steps = 1000

    # Init env
    env, env_params, config = get_env(
        args.env_name, backend="positional", indeces=indeces
    )

    # Init network
    network = ActorCritic(
        env.num_actions,
        activation=config["ACTIVATION"],
    )

    # Load agents
    if agent_params is None:
        # SECURITY NOTE: allow_pickle=True unpickles the checkpoint and can
        # execute arbitrary code; only load agent files from trusted sources.
        agent_params1, norm_stats1 = jnp.load(args.agent1, allow_pickle=True)
        agent_params2, norm_stats2 = jnp.load(args.agent2, allow_pickle=True)
    else:
        # The same agent is used twice.
        agent_params1, norm_stats1 = agent_params
        agent_params2, norm_stats2 = agent_params

    def _stack_pair(first, second):
        # Add a leading axis of size 2 so the pair can be vmapped over.
        return jnp.stack([first, second])

    agents_params = jax.tree_util.tree_map(
        _stack_pair, agent_params1, agent_params2
    )
    norm_stats = jax.tree_util.tree_map(_stack_pair, norm_stats1, norm_stats2)

    # Rollouts
    rollout_init = partial(
        rollout,
        num_envs=args.num_data_points,
        num_steps=num_steps,
        env=env,
        env_params=env_params,
        network=network,
        return_reward=True,
        without_restart=True,
    )

    # jr.split defaults to two keys: one per stacked agent.
    all_obs, all_actions, all_rewards, _ = jax.vmap(rollout_init)(
        agent_params=agents_params, norm_stats=norm_stats, rng=jr.split(rng_rollout)
    )

    # Mean per-step reward scaled by the rollout length.
    print("tot reward", all_rewards.mean() * num_steps)

    all_obs1, all_obs2 = all_obs[0], all_obs[1]
    all_actions1, all_actions2 = all_actions[0], all_actions[1]
    all_rewards1, all_rewards2 = all_rewards[0], all_rewards[1]

    # Compute total rewards: one scalar per entry along the leading axis.
    all_tot_rewards1 = jax.vmap(jnp.sum)(all_rewards1)
    all_tot_rewards2 = jax.vmap(jnp.sum)(all_rewards2)

    if args.noise > 0:
        # With probability args.noise, swap the two total rewards of a pair.
        # Observations and actions are left untouched, so only the return
        # values handed to the judge are exchanged.
        flips = jr.bernoulli(rng_noise, args.noise, all_tot_rewards1.shape)
        all_tot_rewards1, all_tot_rewards2 = (
            jnp.where(flips, all_tot_rewards2, all_tot_rewards1),
            jnp.where(flips, all_tot_rewards1, all_tot_rewards2),
        )

    if args.shuffle_agents == 1:
        print("Shuffling agents")
        # Pool both agents' data, apply one shared permutation so that
        # obs/actions/returns stay aligned, then cut the pool back into two
        # halves. Each half may now contain rollouts from either agent.
        all_obs = jnp.concatenate([all_obs1, all_obs2], axis=0)
        all_actions = jnp.concatenate([all_actions1, all_actions2], axis=0)
        all_tot_rewards = jnp.concatenate([all_tot_rewards1, all_tot_rewards2], axis=0)
        perm = jr.permutation(rng_shuffle, all_obs.shape[0])
        all_obs1, all_obs2 = jnp.split(all_obs[perm], 2)
        all_actions1, all_actions2 = jnp.split(all_actions[perm], 2)
        all_tot_rewards1, all_tot_rewards2 = jnp.split(all_tot_rewards[perm], 2)

    # Judge preferences: one PRNG key per pair; args is shared across the batch.
    rngs = jr.split(rng_judge, all_obs1.shape[0])
    preferences = jax.vmap(judge, in_axes=(0, 0, 0, 0, 0, 0, 0, None))(
        rngs,
        all_obs1,
        all_actions1,
        all_tot_rewards1,
        all_obs2,
        all_actions2,
        all_tot_rewards2,
        args,
    )
    return preferences
