import numpy as np
import torch
import gym
import argparse
import os
import gym_platform

from common import ClickPythonLiteralOption
from common.platform_domain import PlatformFlattenedActionWrapper
from common.goal_domain import GoalFlattenedActionWrapper, GoalObservationWrapper
from common.wrappers import ScaledStateWrapper, ScaledParameterisedActionWrapper

from mpc_model.mpc_buffer import ReplayBuffer
# from mpc_model.mpc_model import DynamicModel, RewardModel
from mpc_model.better_mpc_model import nDynamicModel
from mpc_model.hierachical_model import hmodel, hreplayBuffer

from mpc_model.mpc_planning import MPC
import matplotlib.pyplot as plt

from mpc_model.utils import train
from HyAR_RL.utils import ReplayBuffer as embedbuffer
from embedding import ActionRepresentation_vae
from mpc_model.hyar_model import vae_train

import inspect
import wandb
import time

# Human-readable names for three discrete actions — presumably the Platform
# domain's action set (TODO confirm). Not referenced elsewhere in this file.
action_type = ["run", "hop", "leap"]


def save_points(args):
    """Start a Weights & Biases run that records ``args`` as its config.

    The run is created in the ``pamdp-mpc`` project and wandb is told to keep
    its local files under ``../scratch/wandb``. Nothing is returned.

    Args:
        args: Configuration object forwarded unchanged as wandb's ``config``.
    """
    init_options = {
        "project": "pamdp-mpc",
        "config": args,
        "dir": "../scratch/wandb",
    }
    wandb.init(**init_options)


def pad_action(act, act_param, action_parameter_sizes=(2, 1, 1)):
    """Expand one discrete action and its parameters into a full action tuple.

    Builds one zero-filled float32 vector per discrete action and writes
    ``act_param`` into the vector belonging to ``act``; all other vectors stay
    zero.

    Args:
        act: Index of the chosen discrete action.
        act_param: Parameter value(s) for ``act``. Assigned with NumPy
            broadcasting, so a scalar or a sequence matching the slot length
            both work.
        action_parameter_sizes: Number of parameters of each discrete action,
            in action order. Defaults to ``(2, 1, 1)``, the layout this
            function previously hard-coded.

    Returns:
        ``(act, params)`` where ``params`` is a list of float32 arrays, one
        per entry of ``action_parameter_sizes``.
    """
    params = [np.zeros((size,), dtype=np.float32) for size in action_parameter_sizes]
    params[act][:] = act_param
    return (act, params)


def unwrap_action(mpc, dm, rm, cm, obs, training=True, debug=False,
                  action_parameter_sizes=(2, 1, 1),
                  discrete_action_dim=3,
                  parameter_action_dim=4,
                  max_action=1.0,
                  expl_noise=None):
    """Query the planner and convert its output into an environment action.

    Calls ``mpc.act`` without gradient tracking, optionally perturbs the
    result with Gaussian exploration noise, selects the discrete action by
    arg-max and slices out the parameters that belong to it.

    Args:
        mpc: Planner exposing ``act(dm, rm, cm, obs, training, debug)`` that
            returns ``(all_discrete_action, all_parameter_action, pred_s)``.
        dm, rm, cm: Models forwarded untouched to ``mpc.act``.
        obs: Current observation, forwarded to ``mpc.act``.
        training: When true, exploration noise is added and both action
            vectors are clipped to ``[-max_action, max_action]``.
        debug: Forwarded to ``mpc.act``.
        action_parameter_sizes: Number of parameters per discrete action, in
            action order.
        discrete_action_dim: Length of the noise vector added to the discrete
            action vector.
        parameter_action_dim: Length of the noise vector added to the
            parameter vector.
        max_action: Clipping bound and noise scale.
        expl_noise: Standard deviation of the exploration noise relative to
            ``max_action``. When ``None`` the value is read from the
            module-level ``args.expl_noise`` (only defined when this file is
            run as a script).

    Returns:
        ``(all_discrete_action, all_parameter_action, action, pred_s)`` where
        the first two are the (possibly noised) planner outputs, ``action`` is
        the padded tuple produced by ``pad_action`` and ``pred_s`` is passed
        through from ``mpc.act``.
    """
    with torch.no_grad():
        all_discrete_action, all_parameter_action, pred_s = mpc.act(dm, rm, cm, obs, training, debug)

    if training:
        if expl_noise is None:
            # Backward-compatible fallback: the script entry point stores the
            # parsed CLI options in a module-level `args`.
            expl_noise = args.expl_noise
        # float() guards against the option arriving as a string from the CLI.
        noise_std = max_action * float(expl_noise)
        all_discrete_action = (
            all_discrete_action + np.random.normal(0, noise_std, size=discrete_action_dim)
        ).clip(-max_action, max_action)
        all_parameter_action = (
            all_parameter_action + np.random.normal(0, noise_std, size=parameter_action_dim)
        ).clip(-max_action, max_action)

    discrete_action = np.argmax(all_discrete_action)

    # Parameters of all actions are stored back to back; skip those that
    # belong to the actions preceding the selected one.
    offset = int(sum(action_parameter_sizes[:discrete_action]))
    parameter_action = all_parameter_action[offset:offset + action_parameter_sizes[discrete_action]]

    action = pad_action(discrete_action, parameter_action)

    return all_discrete_action, all_parameter_action, action, pred_s


# Runs the planner for a number of episodes and returns the mean episode
# return and the mean episode length.
# NOTE: the environment is not re-seeded here; the caller's env is used as-is.
def evaluate(env, mpc, dm, rm, cm, episodes=100, vis=False):
    """Roll out the planner without exploration noise and report averages.

    Each episode resets ``env`` and ``mpc``, then repeatedly asks
    ``unwrap_action`` (with ``training=False``) for an action until the
    environment reports a terminal step. Per-step rewards and a final summary
    are printed.

    Args:
        env: Environment whose ``reset()`` returns a ``(state, _)`` pair and
            whose ``step(action)`` returns
            ``((state, _), reward, terminal, _)``.
        mpc: Planner; must provide ``reset()`` and whatever ``unwrap_action``
            requires.
        dm, rm, cm: Models forwarded to ``unwrap_action``.
        episodes: Number of evaluation episodes.
        vis: When truthy, render every step and sleep one second after it.

    Returns:
        ``(mean_return, mean_episode_steps)`` over the evaluated episodes.
        Both are NaN when ``episodes`` is 0.
    """
    returns = []
    episode_steps = []

    for epi in range(episodes):
        print(f"test: {epi}")
        state, _ = env.reset()

        if vis:
            env.render()

        terminal = False
        t = 0
        total_reward = 0.
        mpc.reset()
        while not terminal:
            t += 1
            # np.asarray instead of np.array(..., copy=False): the latter
            # raises on NumPy >= 2 whenever a copy or cast is required.
            state = np.asarray(state, dtype=np.float32)
            _, _, action, _ = unwrap_action(mpc, dm, rm, cm, state, False, False)
            (state, _), reward, terminal, _ = env.step(action)
            print(t, reward)

            if vis:
                env.render()
                time.sleep(1)

            total_reward += reward
        episode_steps.append(t)
        returns.append(total_reward)

    mean_return = np.array(returns).mean()
    mean_steps = np.array(episode_steps).mean()
    print("---------------------------------------")
    print(
        f"Evaluation over {episodes} episodes_rewards: {mean_return:.3f} epioside_steps: {mean_steps:.3f}")
    print("---------------------------------------")
    return mean_return, mean_steps


def _make_env(env_name):
    """Create the wrapped environment for ``env_name`` or raise ValueError."""
    if env_name == "Platform-v0":
        env = gym.make(env_name)
        env = ScaledStateWrapper(env)
        env = PlatformFlattenedActionWrapper(env)
        env = ScaledParameterisedActionWrapper(env)
    elif env_name == "Goal-v0":
        env = gym.make('Goal-v0')
        env = GoalObservationWrapper(env)
        env = GoalFlattenedActionWrapper(env)
        env = ScaledParameterisedActionWrapper(env)
        env = ScaledStateWrapper(env)
    else:
        raise ValueError(f"unsupported env {env_name!r}, only Platform-v0 or Goal-v0")
    return env


def _fit_models(args, replay_buffer, model, reward_model, continuous):
    """Sample one dataset per learned model from the buffer and train them."""
    train_models = [['state', model], ['reward', reward_model]]
    # int(): the *_datasetlen defaults may be floats such as 1e5.
    train_data = {
        'state': replay_buffer.sample(int(args.dm_datasetlen)),
        'reward': replay_buffer.sample(int(args.r_datasetlen)),
    }
    if args.use_terminal:
        train_models.append(['terminal', continuous])
        train_data['terminal'] = replay_buffer.sample(int(args.c_datasetlen))

    train(models=train_models,
          data=train_data,
          logger=args.save_points,
          model_type=args.which_model)


def run(args):
    """Train and periodically evaluate the MPC agent for one seed.

    Steps, in order:
      1. Build and seed the environment, derive the state/action dimensions
         and store them on ``args``.
      2. Create the state, reward and "continuous" models plus the replay
         buffer selected by ``args.which_model``, and the planner selected by
         ``args.mpc_type``.
      3. Evaluate once, fill the buffer with ``args.pretrain_episodes``
         episodes of random actions, train the models and evaluate again.
      4. Interact with the environment until ``args.max_timesteps`` steps were
         taken, re-training every ``args.train_every`` steps and evaluating
         every ``args.eval_freq`` steps.
      5. Write the collected reward curves as CSV files below
         ``result/MPC/goal/<args.save_dir>``.

    Args:
        args: Parsed command-line namespace. It is mutated: ``state_dim``,
            ``discrete_action_dim``, ``parameter_action_dim``,
            ``action_parameter_sizes`` and ``max_action`` are added.

    Raises:
        ValueError: For an unknown ``args.env``, ``args.which_model`` or
            ``args.mpc_type``.
        NotImplementedError: For ``args.mpc_type == "Enumeration"``, for which
            no planner is available in this file.
    """
    print("---------------------------------------")
    print(f"Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    if args.save_points:
        save_points(vars(args))

    os.makedirs("./results", exist_ok=True)

    env = _make_env(args.env)

    # Set seeds
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    state_dim = env.observation_space.spaces[0].shape[0]

    discrete_action_dim = env.action_space.spaces[0].n
    action_parameter_sizes = np.array(
        [env.action_space.spaces[i].shape[0] for i in range(1, discrete_action_dim + 1)])
    parameter_action_dim = int(action_parameter_sizes.sum())
    max_action = 1.0
    print("state_dim", state_dim)
    print("discrete_action_dim", discrete_action_dim)
    print("parameter_action_dim", parameter_action_dim)

    args.state_dim = state_dim
    args.discrete_action_dim = discrete_action_dim
    args.parameter_action_dim = parameter_action_dim
    args.action_parameter_sizes = action_parameter_sizes
    args.max_action = max_action

    # Kept from the original flow: one reset before any model is built.
    env.reset()

    if args.which_model == 'normal':
        model = nDynamicModel(args, label='state')
        reward_model = nDynamicModel(args, label='reward')
        continuous = nDynamicModel(args, label='continuous')

        replay_buffer = ReplayBuffer(state_dim=state_dim,
                                     discrete_action_dim=discrete_action_dim,
                                     all_parameter_action_dim=parameter_action_dim,
                                     max_size=int(1e5))

    elif args.which_model == 'h':
        model = hmodel(args, label='state')
        reward_model = hmodel(args, label='reward')
        continuous = hmodel(args, label='continuous')

        replay_buffer = hreplayBuffer(state_dim=state_dim,
                                      discrete_action_dim=discrete_action_dim,
                                      all_parameter_action_dim=parameter_action_dim,
                                      max_size=int(1e5))
    else:
        # Previously an `assert "<string>"`, which can never fail.
        raise ValueError(f"not implemented with {args.which_model}, only h or normal")

    if args.embed:
        discrete_emb_dim = discrete_action_dim * 2
        parameter_emb_dim = parameter_action_dim * 2
        action_rep = ActionRepresentation_vae.Action_representation(state_dim=state_dim,
                                                                    action_dim=discrete_action_dim,
                                                                    parameter_action_dim=1,
                                                                    reduced_action_dim=discrete_emb_dim,
                                                                    reduce_parameter_action_dim=parameter_emb_dim
                                                                    )

    Test_Reward_100 = []

    print('load mpc')
    if args.mpc_type in ("CEM", "Random"):
        mpc = MPC(args)
    elif args.mpc_type == "Enumeration":
        # The original code called an undefined name here.
        raise NotImplementedError("mpc_type 'Enumeration' has no planner implementation")
    else:
        # Raising a plain string is itself a TypeError; use a real exception.
        raise ValueError(f"unknown mpc_type {args.mpc_type!r}, expected CEM, Random or Enumeration")

    Test_Reward, Test_epioside_step = evaluate(env, mpc, model, reward_model, continuous, episodes=5,
                                               vis=args.visualise)
    Test_Reward_100.append(Test_Reward)

    if args.save_points:
        wandb.log({"Test_Reward_100": Test_Reward,
                   "Test_epioside_step_100": Test_epioside_step})

    print("---------- START PRE-TRAIN ----------")
    for _ in range(args.pretrain_episodes):
        obs, _ = env.reset()
        terminal = False
        while not terminal:
            raw_action = env.action_space.sample()
            # NOTE(review): pad_action is called with its default parameter
            # layout; confirm it matches envs other than Goal-v0.
            action = pad_action(raw_action[0], raw_action[1:][raw_action[0]])
            # One-hot encoding of the sampled discrete action.
            da = np.zeros(discrete_action_dim)
            da[action[0]] = 1
            pa = action[1]
            (obs_next, _), reward, terminal, _ = env.step(action)  # terminal=1: episode terminated
            # The buffer stores the state *delta* under `next_state`.
            replay_buffer.add(obs,
                              discrete_action=da,
                              all_parameter_action=np.concatenate(pa),
                              next_state=obs_next - obs,
                              reward=reward,
                              terminal=terminal)

            obs = obs_next

    if args.embed:
        recon_s_loss = []
        VAE_batch_size = 64
        save_dir = "result/platform_model/mix/1.0/0526"
        save_dir = os.path.join(save_dir, "{}".format(str(66)))
        print("save_dir", save_dir)
        os.makedirs(save_dir, exist_ok=True)
        vae_save_model = True

        c_rate, recon_s = vae_train(action_rep=action_rep, train_step=5000,
                                    replay_buffer=replay_buffer,
                                    batch_size=VAE_batch_size,
                                    save_dir=save_dir, vae_save_model=vae_save_model, embed_lr=1e-4)

    _fit_models(args, replay_buffer, model, reward_model, continuous)

    print("---------- DONE PRE-TRAIN ----------")

    Test_Reward, Test_epioside_step = evaluate(env, mpc, model, reward_model, continuous, episodes=5,
                                               vis=args.visualise)
    Test_Reward_100.append(Test_Reward)
    if args.save_points:
        wandb.log({"Test_Reward_100": Test_Reward,
                   "Test_epioside_step_100": Test_epioside_step})

    returns = []
    Reward_100 = []
    max_steps = 250

    Test_epioside_step_100 = []
    total_timesteps = 0
    epi = 0

    # Use the dimensions of the actual environment instead of
    # unwrap_action's Goal-specific defaults.
    act_kwargs = dict(action_parameter_sizes=action_parameter_sizes,
                      discrete_action_dim=discrete_action_dim,
                      parameter_action_dim=parameter_action_dim,
                      max_action=max_action)

    while total_timesteps < args.max_timesteps:
        epi += 1
        state, _ = env.reset()
        # np.asarray: np.array(..., copy=False) raises on NumPy >= 2 when a
        # copy or cast is needed.
        state = np.asarray(state, dtype=np.float32)

        # Reset the planner (mean, var) before planning the episode's first
        # action, as evaluate() does, so no state leaks from the last episode.
        mpc.reset()
        discrete_action, parameter_action, action, _ = unwrap_action(
            mpc, model, reward_model, continuous, state, **act_kwargs)
        episode_reward = 0.

        for _ in range(max_steps):
            total_timesteps += 1
            print(f"total_timesteps: {total_timesteps}, episode: {epi}")
            (next_state, _), reward, terminal, _ = env.step(action)

            if args.change_r:
                reward = -1 if terminal else reward

            next_state = np.asarray(next_state, dtype=np.float32)
            episode_reward += reward
            replay_buffer.add(state,
                              discrete_action=discrete_action,
                              all_parameter_action=parameter_action,
                              next_state=next_state - state,
                              reward=reward, terminal=terminal)

            discrete_action, parameter_action, action, _ = unwrap_action(
                mpc, model, reward_model, continuous, next_state, **act_kwargs)
            state = next_state

            if total_timesteps % args.train_every == 0:
                _fit_models(args, replay_buffer, model, reward_model, continuous)

            if args.embed and (total_timesteps % 10 == 0) and (total_timesteps >= 1000):
                c_rate, recon_s = vae_train(action_rep=action_rep, train_step=1,
                                            replay_buffer=replay_buffer,
                                            batch_size=VAE_batch_size, save_dir=save_dir,
                                            vae_save_model=vae_save_model,
                                            embed_lr=1e-4)

                recon_s_loss.append(recon_s)

            if total_timesteps % args.eval_freq == 0:
                # Finish the running episode with random actions before the
                # environment is handed to evaluate(); these transitions are
                # not stored, and the training episode ends afterwards.
                while not terminal:
                    raw_action = env.action_space.sample()
                    action = pad_action(raw_action[0], raw_action[1:][raw_action[0]])
                    (state, _), reward, terminal, _ = env.step(action)

                # Mean over the last 100 finished training episodes; NaN
                # while none has finished yet.
                r100 = np.array(returns[-100:]).mean() if returns else float("nan")
                Reward_100.append(r100)
                Test_Reward, Test_epioside_step = evaluate(env, mpc, model, reward_model, continuous,
                                                           episodes=5, vis=args.visualise)
                Test_Reward_100.append(Test_Reward)
                Test_epioside_step_100.append(Test_epioside_step)

                print(f"Step: {total_timesteps} ||R100: {r100} || "
                      f"Test_Reward_100: {Test_Reward} || Test_epioside_step_100 {Test_epioside_step}")

                if args.save_points:
                    wandb.log({"Test_Reward_100": Test_Reward,
                               "Test_epioside_step_100": Test_epioside_step})

            if terminal:
                break

        returns.append(episode_reward)

    print("save txt")
    base_dir = "result/MPC/goal"
    redir = os.path.join(base_dir, args.save_dir)
    os.makedirs(redir, exist_ok=True)
    print("redir", redir)
    title2 = "Reward_100_td3_platform_"
    title3 = "Test_Reward_100_td3_platform_"
    title4 = "Test_epioside_step_100_td3_platform_"
    np.savetxt(os.path.join(redir, f"{title2}{args.seed}.csv"), Reward_100, delimiter=',')
    np.savetxt(os.path.join(redir, f"{title3}{args.seed}.csv"), Test_Reward_100, delimiter=',')
    np.savetxt(os.path.join(redir, f"{title4}{args.seed}.csv"), Test_epioside_step_100,
               delimiter=',')


if __name__ == "__main__":
    # NOTE: `args` must remain a module-level name; unwrap_action() falls back
    # to the global `args.expl_noise` for its exploration noise.
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default='Goal-v0')  # Platform-v0 or Goal-v0
    parser.add_argument("--seed", default=0, type=int)  # Overwritten by the seed loop at the bottom

    parser.add_argument("--max_timesteps", default=50_000, type=float)  # Max time steps to run environment for
    parser.add_argument("--eval_freq", default=50, type=int)  # How often (time steps) we evaluate
    # type=float: without it a value given on the command line stays a string.
    parser.add_argument("--expl_noise", default=0.1, type=float)  # Std of Gaussian exploration noise

    parser.add_argument("--train_every", default=50, type=int)  # How often (time steps) the models are re-trained

    parser.add_argument("--pretrain_episodes", default=40, type=int)  # Random-action episodes collected before training

    # Options prefixed dm_ / r_ / c_ configure the state, reward and
    # continuous (terminal) models respectively.
    parser.add_argument("--dm_hard", default=0, type=int)
    parser.add_argument("--dm_epoch", default=50, type=int)  # how many epochs to train the dynamic model
    parser.add_argument("--dm_lr", default=5e-4, type=float)
    parser.add_argument("--dm_batchsize", default=256, type=int)
    parser.add_argument("--dm_saveflag", default=0, type=int)
    parser.add_argument("--dm_savepath", default='mpc_model/models/dm.pth', type=str)
    parser.add_argument("--dm_loadpath", default='mpc_model/models/dm.pth', type=str)
    parser.add_argument("--dm_valflag", default=0, type=int)
    parser.add_argument("--dm_valfreq", default=500, type=int)  # freq
    parser.add_argument("--dm_valrati", default=0., type=float)
    parser.add_argument("--dm_loadmodel", default=0, type=int)
    # nargs="+"/type=int: `type=list` would split the CLI string into single
    # characters. Usage: --dm_layers 64 64
    parser.add_argument("--dm_layers", default=[64, 64], nargs="+", type=int)
    # Integer literal: argparse does not apply `type` to non-string defaults,
    # so 1e5 would reach the code as a float.
    parser.add_argument("--dm_datasetlen", default=100_000, type=int)

    # parser.add_argument("--dm_continue", default=1, type=int)
    parser.add_argument("--dm_onepa", default=1, type=int)
    parser.add_argument("--use_terminal", default=1, type=int)
    parser.add_argument("--change_r", default=0, type=int)

    parser.add_argument("--r_hard", default=0, type=int)
    parser.add_argument("--r_epoch", default=50, type=int)  # how many epochs to train the reward model
    parser.add_argument("--r_lr", default=5e-4, type=float)
    parser.add_argument("--r_batchsize", default=256, type=int)
    parser.add_argument("--r_saveflag", default=0, type=int)
    parser.add_argument("--r_savepath", default='mpc_model/models/r.pth', type=str)
    parser.add_argument("--r_loadpath", default='mpc_model/models/r.pth', type=str)
    parser.add_argument("--r_valflag", default=0, type=int)
    parser.add_argument("--r_valfreq", default=500, type=int)  # freq
    parser.add_argument("--r_valrati", default=0., type=float)
    parser.add_argument("--r_loadmodel", default=0, type=int)
    parser.add_argument("--r_layers", default=[64, 64], nargs="+", type=int)
    parser.add_argument("--r_datasetlen", default=100_000, type=int)

    parser.add_argument("--c_hard", default=0, type=int)
    parser.add_argument("--c_epoch", default=50, type=int)  # how many epochs to train the continuous model
    parser.add_argument("--c_lr", default=5e-4, type=float)
    parser.add_argument("--c_batchsize", default=256, type=int)
    parser.add_argument("--c_saveflag", default=0, type=int)
    parser.add_argument("--c_savepath", default='mpc_model/models/c.pth', type=str)
    parser.add_argument("--c_loadpath", default='mpc_model/models/c.pth', type=str)
    parser.add_argument("--c_valflag", default=0, type=int)
    parser.add_argument("--c_valfreq", default=500, type=int)  # freq
    parser.add_argument("--c_valrati", default=0., type=float)
    parser.add_argument("--c_loadmodel", default=0, type=int)
    parser.add_argument("--c_layers", default=[64, 64], nargs="+", type=int)
    parser.add_argument("--c_datasetlen", default=100_000, type=int)

    parser.add_argument("--mpc_horizon", default=25, type=int)
    parser.add_argument("--mpc_gamma", default=1., type=float)
    parser.add_argument("--mpc_popsize", default=2000, type=int)
    parser.add_argument("--mpc_num_elites", default=40, type=int)
    parser.add_argument("--mpc_patrical", default=1, type=int)
    parser.add_argument("--mpc_init_mean", default=0., type=float)
    parser.add_argument("--mpc_init_var", default=1., type=float)
    parser.add_argument("--mpc_epsilon", default=0.001, type=float)
    parser.add_argument("--mpc_alpha", default=0.1, type=float)
    parser.add_argument("--mpc_max_iters", default=1000, type=int)
    parser.add_argument("--mpc_type", default="Random", type=str)  # CEM, Random, Enumeration
    # parser.add_argument("--mpc_mode", default="hard", type=str)  # hard, dl

    parser.add_argument('--which_model', default="normal", type=str)  # normal, h
    parser.add_argument('--embed', default=0, type=int)  # 1: also train the VAE action representation

    parser.add_argument('--save_dir', default="070901", type=str)
    # parser.add_argument('--save-frames', default=1, type=int)
    parser.add_argument('--visualise', default=1, type=int)
    parser.add_argument("--save_points", default=0, type=int)  # 1: log to Weights & Biases

    args = parser.parse_args()
    # run(args)
    # Repeat the experiment for seeds 0, 1 and 2; this overrides --seed.
    for seed in range(3):
        args.seed = seed
        run(args)
