import argparse
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np

# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
from envs.rollout import rollout
from models.agents import ActorCriticCont as ActorCritic
from rl_training.ppo import make_train
from envs.env_utils import get_env
from utils.utils import save_to_file


def train_ref_agents(rng_seed, env_name, save_dir, indeces):
    """Train an agent in stages and save checkpoints at distinct reward levels.

    A first agent is produced by a training run limited to a single update and
    saved as ``agent_init.npy``. Training then continues from that agent in 50
    stages. After every stage the current parameters are evaluated with
    ``rollout``; the first time the mean summed reward falls into a reward
    class (lower edges 0, 10, ..., 200) the parameters are saved as
    ``agent_<lower edge>.npy``. Each class is saved at most once.

    Args:
        rng_seed: Integer seed from which all PRNG keys are derived.
        env_name: Environment name forwarded to ``get_env``.
        save_dir: Directory forwarded to ``save_to_file`` for every checkpoint.
        indeces: Forwarded unchanged to ``get_env`` as its ``indeces`` argument
            (``None`` or a list of ints when called from the CLI).
    """
    num_stages = 50
    num_eval_agents = 1000

    # Create environment
    env, env_params, config = get_env(env_name, backend="positional", indeces=indeces, normalize=True)

    # Every consumer gets its own key split off the root key. Passing one key
    # to several calls would make them draw identical random numbers.
    rng = jr.PRNGKey(rng_seed)
    rng, init_key = jr.split(rng)

    # Init agent: a run restricted to one update. Remember the configured
    # update count first, because the override below would otherwise destroy it.
    configured_updates = config.get("NUM_UPDATES")
    config["NUM_UPDATES"] = 1
    train_fn = make_train(
        config,
        env,
        env_params,
        start_from_prev=False,
    )
    expert_train_out = jax.jit(train_fn)(init_key)

    save_to_file(
        expert_train_out[0].params,
        save_dir,
        "agent_init.npy",
    )

    # Train agents
    # NOTE(review): if make_train rewrote NUM_UPDATES we keep its value (the
    # original behaviour); if the temporary override of 1 is still in place we
    # restore the configured count, since 1 // 50 would give zero updates.
    if config["NUM_UPDATES"] == 1 and configured_updates is not None:
        config["NUM_UPDATES"] = configured_updates
    # Never run a stage with zero updates; that would make every stage a no-op.
    config["NUM_UPDATES"] = max(1, config["NUM_UPDATES"] // num_stages)
    train_fn_from_prev = make_train(
        config,
        env,
        env_params,
        start_from_prev=True,
    )
    train_fn_from_prev_jit = jax.jit(train_fn_from_prev)

    # Lower edges of the reward classes; each class index is saved at most once.
    reward_classes = jnp.linspace(0, 200, 21)
    reward_classes_to_save = set(range(len(reward_classes)))

    # Evaluate net
    network = ActorCritic(
        env.action_space(env_params).shape[0],
        activation=config["ACTIVATION"],
        normalize=config["NORMALIZE_OBS"],
    )

    for _ in range(num_stages):
        rng, train_key, rollout_key = jr.split(rng, 3)
        expert_train_out = train_fn_from_prev_jit(train_key, expert_train_out)

        # The second positional argument (1000) is passed as in the original;
        # presumably the rollout length - TODO confirm against envs.rollout.
        _, _, all_rewards, _ = rollout(
            num_eval_agents,
            1000,
            env,
            env_params,
            expert_train_out[0].params,
            rollout_key,
            network,
            return_reward=True,
            without_restart=True,
        )

        # Index of the highest class edge not exceeding the mean summed reward;
        # -1 when the mean is below the lowest edge, in which case nothing is saved.
        mean_return = all_rewards.sum(axis=1).mean()
        class_idx = int(jnp.sum(mean_return >= reward_classes)) - 1
        if class_idx in reward_classes_to_save:
            reward_classes_to_save.remove(class_idx)
            save_to_file(
                expert_train_out[0].params,
                save_dir,
                f"agent_{reward_classes[class_idx]}.npy",
            )


def build_arg_parser():
    """Build the command-line parser for the reference-agent training script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--env_name``, ``--save_dir``
        (required), ``--seed`` and ``--indeces``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--env_name",
        type=str,
        default="hopper",
        help="Environment name forwarded to train_ref_agents.",
    )
    # Required: the output path is built by string-formatting this value, so a
    # missing argument (None) would otherwise only fail later with a TypeError.
    arg_parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Root directory under which checkpoints are written.",
    )
    arg_parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Integer seed forwarded to train_ref_agents.",
    )
    arg_parser.add_argument(
        "--indeces",
        type=int,
        nargs="+",
        default=None,
        help="Optional integer indices forwarded to train_ref_agents; "
        "when given, outputs go under 'constrained/'.",
    )
    return arg_parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    print(args.indeces)
    # Any non-empty index list selects the "constrained" output sub-directory.
    temp = "constrained" if args.indeces else "unconstrained"
    print(temp)
    train_ref_agents(
        args.seed,
        args.env_name,
        f"{args.save_dir}/{temp}/{args.env_name}/",
        args.indeces,
    )
