import argparse
import os
from collections import defaultdict
import numpy as np
from algos import utils
from logger import logger, setup_logger
import matplotlib.pyplot as plt
from matplotlib import cm
import algos.algos_vae as algos
import gym, multitask_env, d4rl


def eval_policy(
    policy, env, eval_episodes=10, plot=False, num_actions_to_sample: int = 1
):
    """Roll out ``policy`` in ``env`` for several episodes and summarise the results.

    Args:
        policy: Object whose ``select_action(state, deterministic=...,
            num_actions_to_sample=...)`` must return ``(action, diagnostics)``.
            ``diagnostics`` is a mapping whose values are convertible with
            ``float()``.
        env: Environment exposing ``reset()``, ``step(action)`` (unpacked as a
            4-tuple ``(state, reward, done, info)``) and
            ``get_normalized_score(avg_return)``.
        eval_episodes: Number of episodes to run; must be positive.
        plot: If True, scatter the first two coordinates of every visited
            state (one colour per episode) and the start states (red), then
            save the figure to ``./eval_fig``.
        num_actions_to_sample: Forwarded to ``select_action``. Values > 1
            switch to stochastic action selection; ``None`` is treated as 1
            (deterministic).

    Returns:
        dict with keys ``<mode>_AverageReturn``, ``<mode>_NormReturn`` and
        ``<mode>_AverageSuccess``. ``<mode>`` is ``"deterministic"`` or
        ``"stoch_<k>_actions"``. AverageSuccess is the fraction of episodes in
        which at least one step returned a reward of exactly 1.

    Raises:
        ValueError: if ``eval_episodes`` is not positive.
        RuntimeError: if ``select_action`` does not return exactly two values.

    Side effect: the per-step action diagnostics, averaged over every step of
    every episode, are recorded with ``logger.record_dict``.
    """
    if eval_episodes <= 0:
        raise ValueError(f"eval_episodes must be positive, got {eval_episodes}")
    # The CLI option feeding this argument has no default, so accept None.
    if num_actions_to_sample is None:
        num_actions_to_sample = 1
    deterministic = num_actions_to_sample <= 1

    avg_reward = 0.0
    avg_success = 0
    start_states = []
    if plot:
        plt.clf()
        # Two spare colours so the last episode is not drawn at the extreme end of the map.
        color_list = cm.rainbow(np.linspace(0, 1, eval_episodes + 2))

    diagnostic_sums = defaultdict(float)
    num_steps = 0
    for i in range(eval_episodes):
        state, done = env.reset(), False
        states_list = []
        start_states.append(state)
        success = 0
        while not done:
            outputs = policy.select_action(
                np.array(state),
                deterministic=deterministic,
                num_actions_to_sample=num_actions_to_sample,
            )
            if len(outputs) != 2:
                print("STUFF: ", outputs)
                raise RuntimeError(
                    "policy.select_action must return (action, diagnostics); "
                    f"got {len(outputs)} values"
                )
            action, action_diagnostics = outputs
            for key, value in action_diagnostics.items():
                diagnostic_sums[key] += float(value)
            num_steps += 1
            state, reward, done, _ = env.step(action)
            if reward == 1:
                success = 1
            avg_reward += reward
            states_list.append(state)

        avg_success += success

        if plot:
            states_arr = np.array(states_list)
            plt.scatter(
                states_arr[:, 0], states_arr[:, 1], color=color_list[i], alpha=0.1
            )
            # Invisible (white) markers; presumably pin the axis extent.
            plt.scatter(8, 10, color="white", alpha=0.1)
            plt.scatter(2, 0, color="white", alpha=0.1)

    # Every episode takes at least one step, so num_steps >= 1 here.
    actions_diagnostics = {
        key: total / num_steps for key, total in diagnostic_sums.items()
    }
    # NOTE(review): these keys are not prefixed by the eval mode, so a
    # deterministic and a stochastic eval in the same logging row record the
    # same keys twice — confirm how the logger handles duplicates.
    logger.record_dict(actions_diagnostics)
    if plot:
        start_arr = np.array(start_states)
        plt.scatter(start_arr[:, 0], start_arr[:, 1], color="red")
        plt.savefig("./eval_fig")

    avg_reward /= eval_episodes
    avg_success /= eval_episodes

    normalized_score = env.get_normalized_score(avg_reward)

    deterministic_name = (
        "deterministic" if deterministic else f"stoch_{num_actions_to_sample}_actions"
    )
    info = {
        f"{deterministic_name}_AverageReturn": avg_reward,
        f"{deterministic_name}_NormReturn": normalized_score,
        f"{deterministic_name}_AverageSuccess": avg_success,
    }
    print("---------------------------------------")
    print(
        f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}, {normalized_score:.3f}, {avg_success:.3f}"
    )
    print("---------------------------------------")
    return info


def parse_bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Return the command-line parser for this training script.

    Iteration-count options use integer defaults. argparse does not apply
    ``type`` to non-string defaults, so a literal like ``5e3`` would otherwise
    stay a float and leak into checkpoint names.
    """
    parser = argparse.ArgumentParser()
    # Additional parameters
    parser.add_argument("--ExpID", default=9999, type=int, help="Experiment ID")
    parser.add_argument("--exp_name", type=str)
    parser.add_argument(
        "--log_dir", default="./results/", type=str, help="Logging directory"
    )
    parser.add_argument(
        "--load_model",
        default=0,
        type=float,
        help="Load model and optimizer parameters saved at this iteration "
        "(0 = train from scratch)",
    )
    parser.add_argument(
        "--save_model",
        default=True,
        type=parse_bool,
        help="Save model and optimizer parameters",
    )
    parser.add_argument(
        "--save_freq",
        default=500_000,
        type=int,
        help="How often (training iterations) the model is saved",
    )
    parser.add_argument(
        "--env_name", default="walker2d-medium-v2", help="OpenAI gym environment name"
    )
    parser.add_argument(
        "--seed",
        default=0,
        type=int,
        help="Seed (seeding is currently disabled; only used in the log path)",
    )
    parser.add_argument(
        "--eval_freq",
        default=5_000,
        type=int,
        help="How often (training iterations) we evaluate",
    )
    parser.add_argument(
        "--max_timesteps",
        default=1_000_000,
        type=int,
        help="Total number of training iterations",
    )
    parser.add_argument("--batch_size", default=512, type=int)
    parser.add_argument(
        "--vae_lr", default=2e-4, type=float, help="action policy (VAE) learning rate"
    )
    parser.add_argument(
        "--actor_lr", default=2e-4, type=float, help="latent policy learning rate"
    )
    parser.add_argument(
        "--critic_lr", default=2e-4, type=float, help="critic learning rate"
    )
    parser.add_argument(
        "--tau", default=0.005, type=float, help="delayed learning rate"
    )
    parser.add_argument("--discount", default=0.99, type=float, help="discount factor")

    parser.add_argument(
        "--expectile",
        default=0.9,
        type=float,
        help="expectile to compute weight for samples",
    )
    parser.add_argument(
        "--kl_beta", default=1, type=float, help="weight for kl loss to train CVAE"
    )
    parser.add_argument(
        "--max_latent_action",
        default=2.0,
        type=float,
        help="maximum value for the latent policy",
    )
    parser.add_argument(
        "--doubleq_min", default=1, type=float, help="weight for the minimum Q value"
    )
    parser.add_argument(
        "--no_piz",
        action="store_true",
        help="using the latent policy or the action policy",
    )
    parser.add_argument(
        "--no_noise",
        action="store_true",
        help="adding noise to the latent policy or not",
    )

    parser.add_argument("--plot", action="store_true")
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--deterministic_actions", type=int)
    parser.add_argument("--std_architecture", type=str)
    parser.add_argument(
        "--num_actions_to_sample",
        type=int,
        help="For evaluation only; values > 1 add a stochastic evaluation",
    )
    parser.add_argument("--num_steps_from_dataset", type=int, default=-1)
    parser.add_argument("--log_sig_max", type=float, default=2)
    parser.add_argument("--log_sig_min", type=float, default=-20)
    parser.add_argument("--constant_std_init", type=float, default=None)
    parser.add_argument("--gradient_clipping", type=float, default=None)
    parser.add_argument("--num_q_functions", type=int)
    parser.add_argument("--not_train_v_func", action="store_true")
    parser.add_argument("--critic_training_iterations", type=int, default=1)
    return parser


def main():
    """Train the latent policy on a D4RL dataset, evaluating and checkpointing periodically."""
    import platform  # portable replacement for os.uname()[1]

    parser = build_arg_parser()
    args = parser.parse_args()
    if args.eval_freq <= 0:
        parser.error("--eval_freq must be a positive integer")

    # Setup Logging
    file_name = f"Exp{args.ExpID:04d}/{args.env_name}/{args.exp_name}/seed_{args.seed}"
    folder_name = os.path.join(args.log_dir, file_name)
    os.makedirs(folder_name, exist_ok=True)

    progress_path = os.path.join(folder_name, "progress.csv")
    if args.load_model == 0 and os.path.exists(progress_path):
        # Refuse to overwrite the logs of an existing run unless resuming.
        raise FileExistsError(f"exp file already exist: {progress_path}")

    variant = vars(args)
    variant.update(node=platform.node())
    setup_logger(os.path.basename(folder_name), variant=variant, log_dir=folder_name)

    # Setup Environment
    env = gym.make(args.env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    print(action_dim)

    # NOTE: seeding is currently disabled, so --seed only affects the log path.
    # env.seed(args.seed)
    # env.action_space.seed(args.seed)
    # torch.manual_seed(args.seed)
    # np.random.seed(args.seed)

    # Load Dataset
    dataset = d4rl.qlearning_dataset(env)
    if "antmaze" in args.env_name:
        # antmaze: rewards scaled by 100 and value bounds fixed to [0, 100].
        dataset["rewards"] = dataset["rewards"] * 100
        min_v = 0  # -np.inf
        max_v = 1 * 100
    else:
        # Normalise rewards by their maximum. The value bounds are then
        # r_min / (1 - gamma) and r_max / (1 - gamma).
        dataset["rewards"] = dataset["rewards"] / dataset["rewards"].max()
        min_v = dataset["rewards"].min() / (1 - args.discount)
        max_v = dataset["rewards"].max() / (1 - args.discount)

    # Optionally get a subset of the dataset
    if args.num_steps_from_dataset > 0:
        print("------------ Getting subset of dataset ------------")
        print(f"Getting {args.num_steps_from_dataset} steps from the dataset")
        for dataset_key in dataset.keys():
            dataset[dataset_key] = dataset[dataset_key][: args.num_steps_from_dataset]

    replay_buffer = utils.ReplayBuffer(
        state_dim, action_dim, args.device, max_size=len(dataset["rewards"])
    )
    replay_buffer.load(dataset)
    if (
        args.num_steps_from_dataset > 0
        and replay_buffer.size != args.num_steps_from_dataset
    ):
        raise ValueError(
            f"requested {args.num_steps_from_dataset} dataset steps but the "
            f"replay buffer holds {replay_buffer.size}"
        )

    latent_dim = action_dim * 2
    policy = algos.Latent(
        state_dim,
        action_dim,
        latent_dim,
        max_action,
        min_v,
        max_v,
        replay_buffer=replay_buffer,
        device=args.device,
        discount=args.discount,
        tau=args.tau,
        vae_lr=args.vae_lr,
        actor_lr=args.actor_lr,
        critic_lr=args.critic_lr,
        max_latent_action=args.max_latent_action,
        expectile=args.expectile,
        kl_beta=args.kl_beta,
        no_piz=args.no_piz,
        no_noise=args.no_noise,
        doubleq_min=args.doubleq_min,
        deterministic_actions=bool(args.deterministic_actions),
        std_architecture=args.std_architecture,
        log_sig_max=args.log_sig_max,
        log_sig_min=args.log_sig_min,
        constant_std_init=args.constant_std_init,
        gradient_clipping=args.gradient_clipping,
        num_q_functions=args.num_q_functions,
        not_train_v_func=args.not_train_v_func,
    )

    if args.load_model != 0:
        training_iters = int(args.load_model)
        try:
            policy.load(f"model_{training_iters}", folder_name)
        except FileNotFoundError:
            # Older runs used float iteration counts and saved e.g. "model_5000.0".
            # Assumes policy.load reports a missing checkpoint as
            # FileNotFoundError (e.g. via torch.load) — TODO confirm.
            policy.load(f"model_{float(training_iters)}", folder_name)
    else:
        training_iters = 0

    eval_freq = args.eval_freq
    save_freq = args.save_freq
    # Only the stochastic mode adds information; with 1 (or unset) it would
    # repeat the deterministic evaluation under the same keys.
    run_stochastic_eval = (
        args.num_actions_to_sample is not None and args.num_actions_to_sample > 1
    )

    while training_iters < args.max_timesteps:
        # Train
        policy.train(
            iterations=eval_freq,
            batch_size=args.batch_size,
            critic_training_iterations=args.critic_training_iterations,
        )
        previous_iters = training_iters
        training_iters += eval_freq
        print("Training iterations: " + str(training_iters))
        logger.record_tabular("Training Epochs", training_iters // eval_freq)

        # Save Model whenever this round crossed a multiple of save_freq, so a
        # save_freq that is not a multiple of eval_freq is still honoured.
        if (
            args.save_model
            and save_freq > 0
            and training_iters // save_freq > previous_iters // save_freq
        ):
            policy.save(f"model_{training_iters}", folder_name)

        # Eval
        info = eval_policy(policy, env, plot=args.plot)
        for k, v in info.items():
            logger.record_tabular(k, v)
        if run_stochastic_eval:
            stochastic_info = eval_policy(
                policy,
                env,
                plot=args.plot,
                num_actions_to_sample=args.num_actions_to_sample,
            )
            for k, v in stochastic_info.items():
                logger.record_tabular(k, v)

        logger.dump_tabular()

    if args.save_model:
        policy.save(f"model_{training_iters}", folder_name)


if __name__ == "__main__":
    main()
