import argparse
import os
from collections import defaultdict
import numpy as np
from algos import utils_finetune as utils
from logger import logger, setup_logger
import matplotlib.pyplot as plt
from matplotlib import cm
import algos.algos_vae_finetune as algos
import gym
import multitask_env
import d4rl

# from d3rlpy.torch_utility import hard_sync
import torch
from torch import nn


def hard_sync(targ_model: nn.Module, model: nn.Module) -> None:
    """Overwrite the parameters of ``targ_model`` with those of ``model``.

    Parameters are paired positionally in the order yielded by
    ``parameters()``; each target tensor is updated in place with a copy of
    the corresponding source values. Pairing stops at the shorter of the two
    parameter sequences.

    Args:
        targ_model: Module whose parameters are overwritten.
        model: Module whose parameter values are copied.
    """
    source_iter = model.parameters()
    target_iter = targ_model.parameters()
    # The copy must not be tracked by autograd.
    with torch.no_grad():
        for source_param, target_param in zip(source_iter, target_iter):
            target_param.data.copy_(source_param.data)


# Runs the policy for `eval_episodes` episodes and returns a dict with the
# average return, the normalized return and the success rate.
# NOTE: the eval environment is not seeded inside this function.
def eval_policy(
    policy,
    env,
    eval_episodes=10,
    plot=False,
    num_actions_to_sample: int = 1,
    max_std_dev: float = -1,
    latent_noise: float = 0,
    action_space_noise: float = 0,
):
    """Roll out ``policy`` in ``env`` and summarize return and success.

    Each episode starts from ``env.reset()`` and runs until ``env.step``
    reports ``done``. An episode counts as a success if any step returned a
    reward equal to 1.

    Args:
        policy: Object exposing ``select_action(state, deterministic=...,
            num_actions_to_sample=..., max_std_dev=..., latent_noise=...,
            action_space_noise=...)`` and returning a 2-element
            ``(action, diagnostics)`` result.
        env: Environment exposing ``reset()``, ``step(action)`` (returning
            ``(state, reward, done, info)``) and ``get_normalized_score``.
        eval_episodes: Number of episodes to run.
        plot: If true, scatter the first two state dimensions of every
            visited state (one colour per episode) plus the start states in
            red, and save the figure to ``./eval_finetune_fig``.
        num_actions_to_sample: Forwarded to ``select_action``. Actions are
            requested non-deterministically only when this is > 1 and both
            noise arguments are 0.
        max_std_dev: Forwarded to ``select_action``.
        latent_noise: Forwarded to ``select_action``.
        action_space_noise: Forwarded to ``select_action``.

    Returns:
        Dict with keys ``<prefix>_AverageReturn``, ``<prefix>_NormReturn`` and
        ``<prefix>_AverageSuccess``. ``<prefix>`` is
        ``stoch_fixed_noise_<n>_actions`` when either noise argument is
        non-zero, ``deterministic`` for noise-free deterministic evaluation,
        and ``stoch_<n>_actions`` otherwise.

    Raises:
        RuntimeError: If ``select_action`` does not return exactly two items.
        ZeroDivisionError: If ``eval_episodes`` is 0.
    """
    avg_reward = 0.0
    avg_success = 0

    uses_fixed_noise = latent_noise != 0 or action_space_noise != 0
    # Sampling several actions without injected noise is the only
    # non-deterministic mode; every other combination asks for deterministic
    # actions.
    deterministic = not (num_actions_to_sample > 1 and not uses_fixed_noise)

    # Only touch matplotlib's global state when a figure is requested.
    if plot:
        plt.clf()
        color_list = cm.rainbow(np.linspace(0, 1, eval_episodes + 2))
    start_states = []

    for i in range(eval_episodes):
        state, done = env.reset(), False
        states_list = []
        start_states.append(state)
        success = 0
        while not done:
            result = policy.select_action(
                np.array(state),
                deterministic=deterministic,
                num_actions_to_sample=num_actions_to_sample,
                max_std_dev=max_std_dev,
                latent_noise=latent_noise,
                action_space_noise=action_space_noise,
            )
            if len(result) != 2:
                raise RuntimeError(
                    "select_action must return (action, diagnostics), "
                    f"got: {result!r}"
                )
            action, _ = result
            state, reward, done, _ = env.step(action)
            if reward == 1:
                success = 1
            avg_reward += reward
            states_list.append(state)

        avg_success += success

        if plot:
            visited = np.array(states_list)
            plt.scatter(visited[:, 0], visited[:, 1], color=color_list[i], alpha=0.1)
            # Invisible anchor points; presumably they pin the axis extents
            # so figures are comparable across evaluations -- TODO confirm.
            plt.scatter(8, 10, color="white", alpha=0.1)
            plt.scatter(2, 0, color="white", alpha=0.1)
    if plot:
        start_states = np.array(start_states)
        plt.scatter(start_states[:, 0], start_states[:, 1], color="red")
        plt.savefig("./eval_finetune_fig")

    avg_reward /= eval_episodes
    avg_success /= eval_episodes

    normalized_score = env.get_normalized_score(avg_reward)

    # The fixed-noise name applies only when noise is actually injected.
    # The original compared `action_space_noise == 0`, which mislabelled the
    # noise-free runs.
    if uses_fixed_noise:
        deterministic_name = f"stoch_fixed_noise_{num_actions_to_sample}_actions"
    elif deterministic:
        deterministic_name = "deterministic"
    else:
        deterministic_name = f"stoch_{num_actions_to_sample}_actions"

    info = {
        f"{deterministic_name}_AverageReturn": avg_reward,
        f"{deterministic_name}_NormReturn": normalized_score,
        f"{deterministic_name}_AverageSuccess": avg_success,
    }
    print("---------------------------------------")
    print(
        f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}, {normalized_score:.3f}, {avg_success:.3f}"
    )
    print("---------------------------------------")
    return info


if __name__ == "__main__":
    # Online fine-tuning entry point: parse options, load the D4RL dataset
    # into a replay buffer, load a pretrained policy, then alternate between
    # collecting environment samples and training, with periodic evaluation
    # and checkpointing.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would treat every non-empty string (including
        "False") as True, so the text is interpreted explicitly instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()
    # Additional parameters
    parser.add_argument("--ExpID", default=9999, type=int)  # Experiment ID
    parser.add_argument("--exp_name", type=str)
    parser.add_argument(
        "--log_dir", default="./results/finetune/", type=str
    )  # Logging directory
    parser.add_argument(
        "--load_model", default=0, type=float
    )  # Checkpoint step to resume from in the log folder (0 = do not resume)
    parser.add_argument(
        "--save_model", default=True, type=_str2bool
    )  # Save periodic checkpoints of the model and replay buffer
    parser.add_argument(
        "--save_freq", default=int(5e5), type=int
    )  # How often (training iterations) the model is saved
    parser.add_argument(
        "--env_name", default="walker2d-medium-v2"
    )  # OpenAI gym environment name
    parser.add_argument(
        "--seed", default=0, type=int
    )  # Used in the log folder name; seeding itself is currently disabled below
    parser.add_argument(
        "--eval_freq", default=int(5e3), type=int
    )  # How often (training iterations) we evaluate
    parser.add_argument(
        "--max_timesteps", default=int(5e5), type=int
    )  # Number of training iterations to run
    parser.add_argument("--batch_size", default=512, type=int)
    parser.add_argument("--vae_lr", default=2e-4, type=float)  # VAE learning rate
    parser.add_argument("--actor_lr", default=2e-4, type=float)  # actor learning rate
    parser.add_argument(
        "--critic_lr", default=2e-4, type=float
    )  # critic learning rate
    parser.add_argument(
        "--tau", default=0.005, type=float
    )  # passed to the policy as `tau`
    parser.add_argument("--discount", default=0.99, type=float)  # discount factor

    parser.add_argument(
        "--expectile", default=0.9, type=float
    )  # expectile to compute weight for samples
    parser.add_argument(
        "--kl_beta", default=1, type=float
    )  # weight for kl loss to train CVAE
    parser.add_argument(
        "--max_latent_action", default=2.0, type=float
    )  # maximum value for the latent policy
    parser.add_argument("--doubleq_min", default=1, type=float)
    parser.add_argument("--no_piz", action="store_true")
    parser.add_argument("--no_noise", action="store_true")

    parser.add_argument("--plot", action="store_true")
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--deterministic_model", type=int)
    parser.add_argument("--deterministic_actions", type=int)
    parser.add_argument("--std_architecture", type=str)
    # NOTE(review): parsed but not forwarded; the policy below is built with
    # entropy_term_weight=0 -- confirm this is intended for fine-tuning.
    parser.add_argument("--entropy_term_weight", type=float)
    parser.add_argument("--max_entropy_training_steps", type=int)
    parser.add_argument("--load_path", type=str)  # Folder of the pretrained model
    parser.add_argument("--dataset_noise", action="store_true")
    # Default of 1 matches eval_policy's own default; the previous implicit
    # None made the first evaluation fail on `None > 1`.
    parser.add_argument("--num_actions_to_sample", type=int, default=1)
    parser.add_argument("--add_action_noise", action="store_true")
    parser.add_argument("--max_std_dev", type=float)
    parser.add_argument("--log_sig_max", type=float, default=2)
    parser.add_argument("--log_sig_min", type=float, default=-20)
    parser.add_argument("--latent_noise", type=float, default=0)
    parser.add_argument("--action_space_noise", type=float, default=0)
    parser.add_argument("--num_q_functions", type=int, default=2)
    parser.add_argument("--optimism_parameter", type=float, default=0)
    parser.add_argument("--gradient_clipping", type=float, default=None)
    parser.add_argument("--not_train_v_func", action="store_true")
    parser.add_argument("--critic_training_iterations", type=int, default=1)

    args = parser.parse_args()

    # Setup Logging
    file_name = f"Exp{args.ExpID:04d}/{args.env_name}/{args.exp_name}/seed_{args.seed}"
    folder_name = os.path.join(args.log_dir, file_name)
    os.makedirs(folder_name, exist_ok=True)

    # Only a warning: an existing run in the same folder is not treated as
    # fatal.
    if args.load_model == 0 and os.path.exists(
        os.path.join(folder_name, "progress.csv")
    ):
        print("exp file already exist")

    # vars() returns the namespace's own dict, so `node` is also added to
    # `args`.
    variant = vars(args)
    variant.update(node=os.uname()[1])
    setup_logger(os.path.basename(folder_name), variant=variant, log_dir=folder_name)

    # Setup Environment: `env` is used for evaluation, `env_train` for
    # collecting new samples.
    env = gym.make(args.env_name)
    env_train = gym.make(args.env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    print(action_dim)

    # NOTE: seeding is intentionally left disabled here (as in the original
    # script); --seed only affects the log folder name.
    # env.seed(args.seed)
    # env.action_space.seed(args.seed)
    # torch.manual_seed(args.seed)
    # np.random.seed(args.seed)

    # Load Dataset
    dataset = d4rl.qlearning_dataset(env)  # Load d4rl dataset
    if "antmaze" in args.env_name:
        # Rewards are scaled by 100 and values bounded to [0, 100].
        dataset["rewards"] = dataset["rewards"] * 100
        min_v = 0  # -np.inf
        max_v = 100
    else:
        # Rewards are divided by the dataset maximum. `max_reward` is reused
        # below to scale online rewards the same way.
        max_reward = dataset["rewards"].max()
        dataset["rewards"] = dataset["rewards"] / max_reward
        # Value bounds: min/max scaled reward divided by (1 - discount).
        min_v = dataset["rewards"].min() / (1 - args.discount)
        max_v = dataset["rewards"].max() / (1 - args.discount)

    print("min max r", dataset["rewards"].min(), dataset["rewards"].max())
    replay_buffer = utils.ReplayBuffer(
        state_dim,
        action_dim,
        args.device,
        max_size=max(int(2e6), len(dataset["rewards"])),
    )
    replay_buffer.load(dataset)

    buffer_size = len(dataset["rewards"])

    latent_dim = action_dim * 2  # int(action_dim * 0.5) + 1
    policy = algos.Latent(
        state_dim,
        action_dim,
        latent_dim,
        max_action,
        min_v,
        max_v,
        replay_buffer=replay_buffer,
        device=args.device,
        discount=args.discount,
        tau=args.tau,
        vae_lr=args.vae_lr,
        actor_lr=args.actor_lr,
        critic_lr=args.critic_lr,
        max_latent_action=args.max_latent_action,
        expectile=args.expectile,
        kl_beta=args.kl_beta,
        no_piz=args.no_piz,
        no_noise=args.no_noise,
        doubleq_min=args.doubleq_min,
        deterministic_actions=bool(args.deterministic_model),
        std_architecture=args.std_architecture,
        num_q_functions=args.num_q_functions,
        entropy_term_weight=0,
        dataset_noise=args.dataset_noise,
        log_sig_max=args.log_sig_max,
        log_sig_min=args.log_sig_min,
        optimism_parameter=args.optimism_parameter,
        gradient_clipping=args.gradient_clipping,
        not_train_v_func=args.not_train_v_func,
    )

    if args.load_model != 0:
        policy.load("model_" + str(args.load_model), folder_name)
        training_iters = int(args.load_model)
    else:
        training_iters = 0

    # Load the pretrained model (checkpoint name "model_1000000.0") from
    # --load_path.
    # NOTE(review): this runs even when resuming via --load_model, so it
    # presumably replaces the weights loaded just above -- confirm intended.
    policy.load("model_" + str(1000000.0), args.load_path)

    state, done = env_train.reset(), False
    ep_count, ep_reward = 0, 0
    environment_steps = 0
    # Each round collects `add_sample` environment steps and then runs
    # `train_iteras` training iterations.
    train_iteras, add_sample = 100, 100
    while training_iters < args.max_timesteps:
        # Eval
        if training_iters % args.eval_freq == 0:
            print("Training iterations: " + str(training_iters))
            logger.record_tabular(
                "Training Epochs", int(training_iters // int(args.eval_freq))
            )
            # First a default evaluation, then one using the configured
            # sampling and noise settings.
            info = eval_policy(
                policy, env, plot=args.plot, max_std_dev=args.max_std_dev
            )
            for k, v in info.items():
                logger.record_tabular(k, v)
            stochastic_info = eval_policy(
                policy,
                env,
                plot=args.plot,
                num_actions_to_sample=args.num_actions_to_sample,
                max_std_dev=args.max_std_dev,
                latent_noise=args.latent_noise,
                action_space_noise=args.action_space_noise,
            )
            for k, v in stochastic_info.items():
                logger.record_tabular(k, v)

            logger.dump_tabular()
            print(
                args.env_name,
                "policy min_v, max_v",
                replay_buffer.size,
                policy.min_v,
                policy.max_v,
            )

        # Collect sample
        if training_iters % train_iteras == 0:
            count = 0
            # Running sums of the scalar diagnostics, averaged after the
            # loop.
            actions_diagnostics = defaultdict(float)
            while count < add_sample:
                action, action_diagnostics = policy.select_action(
                    np.array(state),
                    deterministic=bool(args.deterministic_actions),
                    num_actions_to_sample=args.num_actions_to_sample,
                    max_std_dev=args.max_std_dev,
                    latent_noise=args.latent_noise,
                    action_space_noise=args.action_space_noise,
                )
                if args.add_action_noise:
                    # Gaussian noise (std 0.1) clipped to [-0.3, 0.3].
                    action = action + (np.random.randn(*action.shape) * 0.1).clip(
                        -0.3, 0.3
                    )
                for key, value in action_diagnostics.items():
                    if key != "action_samples":
                        actions_diagnostics[key] += float(value)
                next_state, reward, done, _ = env_train.step(action)
                environment_steps += 1
                count += 1
                ep_count += 1
                ep_reward += reward
                # Scale the online reward the same way as the dataset
                # rewards.
                if "antmaze" in args.env_name:
                    reward = reward * 100
                elif "kitchen" in args.env_name:
                    # Kitchen stores the scaled cumulative episode reward.
                    reward = ep_reward / max_reward
                else:
                    reward = reward / max_reward

                replay_buffer.add(state, action, next_state, reward, done)
                state = next_state.copy()

                # Reset one step before the episode step limit, or when the
                # environment reports done.
                if ep_count == env_train._max_episode_steps - 1 or done:
                    state, done = env_train.reset(), False
                    ep_count, ep_reward = 0, 0
            for key, value in actions_diagnostics.items():
                actions_diagnostics[key] = value / count
            logger.record_dict(actions_diagnostics)
            logger.record_tabular(
                "replay_buffer_state_std",
                replay_buffer.storage["state"].std(axis=0).mean(),
            )
            logger.record_tabular(
                "replay_buffer_action_std",
                replay_buffer.storage["action"].std(axis=0).mean(),
            )
            logger.record_tabular("environment_step", environment_steps)

            # Refresh the policy's value bounds from the rewards now in the
            # buffer.
            if "antmaze" in args.env_name:
                policy.min_v = 0
                policy.max_v = 100
            else:
                min_r = replay_buffer.min_r
                max_r = replay_buffer.max_r
                policy.min_v = min_r / (1 - args.discount)
                policy.max_v = max_r / (1 - args.discount)

        policy.train(
            iterations=int(train_iteras),
            step=training_iters,
            batch_size=args.batch_size,
            critic_training_iterations=args.critic_training_iterations,
        )

        training_iters += train_iteras

        # Save Model
        if training_iters % args.save_freq == 0 and args.save_model:
            policy.save("model_" + str(training_iters), folder_name)
            replay_buffer.save(f"replay_buffer_{str(training_iters)}", folder_name)

    # The final checkpoint is always written, regardless of --save_model.
    policy.save("model_" + str(training_iters), folder_name)
    replay_buffer.save(f"replay_buffer_{str(training_iters)}", folder_name)
