import matplotlib.pyplot as plt
from gymnasium.spaces import Box
import numpy as np
from gymnasium.envs.registration import register
import argparse
from pathlib import Path
import time
from functools import partial

import jax, jax.numpy as jnp

from relax.algorithm.sac import SAC
from relax.algorithm.dsact import DSACT
from relax.algorithm.dacer import DACER
from relax.algorithm.qsm import QSM
from relax.algorithm.ddsq import DDSQ
from relax.algorithm.dipo import DIPO
from relax.algorithm.qvpo import QVPO
from relax.buffer import TreeBuffer
from relax.network.sac import create_sac_net
from relax.network.dsact import create_dsact_net
from relax.network.dacer import create_dacer_net
from relax.network.qvpo import create_qvpo_net
from relax.network.qsm import create_qsm_net
from relax.network.ddsq import create_ddsq_net
from relax.network.dipo import create_dipo_net
from relax.trainer.train_2d import OffPolicyTrainer
from relax.env import create_env, create_vector_env
from relax.utils.experience import Experience, ObsActionPair
from relax.utils.fs import PROJECT_ROOT
from relax.utils.random_utils import seeding


class MultiGoalEnv:
    """
    Move a 2D point mass to one of the goal positions. Cost is the distance to
    the closest goal.

    State: position.
    Action: velocity.

    Each step's reward is ``-(actuation_cost_coeff * |a|^2 + distance_cost_coeff * min_g |s' - g|^2)``,
    where ``a`` is the clipped action and ``s'`` the clipped next position. Ending a step within
    ``goal_threshold`` of any goal adds ``goal_reward`` and terminates the episode.

    The class exposes a gymnasium-style interface (``reset``/``step``/spaces) by duck typing;
    it does not inherit from ``gymnasium.Env``.
    """

    def __init__(self, goal_reward=10, actuation_cost_coeff=30, distance_cost_coeff=1, init_sigma=0.1, seed=None):
        """
        Args:
            goal_reward: bonus added to the reward on the step a goal is reached.
            actuation_cost_coeff: weight of the squared action norm in the cost.
            distance_cost_coeff: weight of the squared distance to the nearest goal.
            init_sigma: std of the Gaussian noise around the origin used for initial positions.
            seed: optional seed for the initial-state RNG (see ``reset``).
        """
        # Noise-free dynamics (sigma=0): next position = position + velocity.
        self.dynamics = PointDynamics(dim=2, sigma=0)
        self.init_mu = np.zeros(2, dtype=np.float32)
        self.init_sigma = init_sigma
        self.goal_positions = np.array([[5, 0], [-5, 0], [0, 5], [0, -5]], dtype=np.float32)
        self.goal_threshold = 1.0
        self.goal_reward = goal_reward
        self.action_cost_coeff = actuation_cost_coeff
        self.distance_cost_coeff = distance_cost_coeff
        self.xlim = (-7, 7)
        self.ylim = (-7, 7)
        self.vel_bound = 1.0

        # Per-instance RNG for initial states. It is (re)created by every seeded reset and
        # reused by later unseeded resets; before any seeded reset the global np.random is used.
        self._np_random = None
        # Draw an initial state so the env holds a valid observation right after construction
        # (previously the state drawn here was immediately overwritten with None).
        self.observation = None
        self.reset(seed=seed)

        self.reward_range = (-float("inf"), float("inf"))
        self.metadata = {"render.modes": []}
        self.spec = None

        # Matplotlib state for render(); the axes are created lazily on first use.
        self._ax = None
        self._env_lines = []
        self.fixed_plots = None
        self.dynamic_plots = []

        super().__init__()

    def reset(self, seed=None, options=None):
        """
        Sample a new initial position around ``init_mu`` and return it.

        Args:
            seed: if given, reseeds this env's RNG, which later unseeded resets keep using.
            options: accepted for gymnasium API compatibility; ignored.

        Returns:
            ``(observation, info)`` where ``observation`` is a float32 position clipped to the
            observation space and ``info`` is an empty dict.
        """
        if seed is not None:
            self._np_random = np.random.default_rng(seed=seed)
        generator = self._np_random if self._np_random is not None else np.random
        unclipped_observation = self.init_mu + self.init_sigma * generator.normal(size=self.dynamics.s_dim)
        o_lb, o_ub = self.observation_space.low, self.observation_space.high
        self.observation = np.clip(unclipped_observation, o_lb, o_ub)
        return self.observation.astype(np.float32), {}

    @property
    def observation_space(self):
        """Box over positions bounded by ``xlim`` x ``ylim``."""
        return Box(
            low=np.array((self.xlim[0], self.ylim[0])),
            high=np.array((self.xlim[1], self.ylim[1])),
            shape=None,
            dtype=np.float32,
        )

    @property
    def action_space(self):
        """Box over velocities, each component in ``[-vel_bound, vel_bound]``."""
        return Box(low=-self.vel_bound, high=self.vel_bound, shape=(self.dynamics.a_dim,), dtype=np.float32)

    def get_current_obs(self):
        """Return a float32 copy of the current position."""
        return np.copy(self.observation).astype(np.float32)

    def step(self, action):
        """
        Apply a velocity action for one step.

        Args:
            action: array-like velocity; it is flattened and clipped to the action space.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``terminated`` is True once the point
            is within ``goal_threshold`` of a goal. ``truncated`` is always False because this env
            has no internal time limit. ``info["pos"]`` holds the (unrounded) clipped position.
        """
        a_lb, a_ub = self.action_space.low, self.action_space.high
        action = np.clip(np.asarray(action).ravel(), a_lb, a_ub)

        next_obs = self.dynamics.forward(self.observation, action)
        o_lb, o_ub = self.observation_space.low, self.observation_space.high
        next_obs = np.clip(next_obs, o_lb, o_ub)

        # Keep a private copy so the array handed out in ``info`` does not alias env state.
        self.observation = np.copy(next_obs)

        reward = self.compute_reward(self.observation, action)
        dist_to_goal = np.amin([np.linalg.norm(self.observation - goal_position) for goal_position in self.goal_positions])
        terminated = bool(dist_to_goal < self.goal_threshold)
        if terminated:
            reward += self.goal_reward

        # Reaching a goal is a terminal state, not a truncation (e.g. a time limit).
        return next_obs.astype(np.float32), reward, terminated, False, {"pos": next_obs}

    def _init_plot(self):
        """Create the figure and axes used by ``render`` and draw the static cost contours."""
        fig_env = plt.figure(figsize=(7, 7))
        self._ax = fig_env.add_subplot(111)
        self._ax.axis("equal")

        self._env_lines = []
        self._ax.set_xlim(self.xlim)
        self._ax.set_ylim(self.ylim)

        self._ax.set_title("Multigoal Environment")
        self._ax.set_xlabel("x")
        self._ax.set_ylabel("y")

        self._plot_position_cost(self._ax)

    def render(self, paths):
        """
        Draw trajectories over the cost landscape, replacing previously drawn ones.

        Args:
            paths: iterable of dicts; each ``path["env_infos"]`` is a sequence of step ``info``
                dicts whose ``"pos"`` entries are stacked into a polyline.
        """
        if self._ax is None:
            self._init_plot()

        for line in self._env_lines:
            line.remove()
        self._env_lines = []

        for path in paths:
            positions = np.stack([info["pos"] for info in path["env_infos"]])
            xx = positions[:, 0]
            yy = positions[:, 1]
            self._env_lines += self._ax.plot(xx, yy, "b")

        plt.draw()
        plt.pause(0.01)

    def compute_reward(self, observation, action):
        """
        Return the negative cost of being at ``observation`` after applying ``action``.

        The cost is the weighted squared L2 norm of the action plus the weighted squared
        distance to the nearest goal. The goal bonus is added separately in ``step``.
        """
        # Penalize the squared L2 norm of the action (a velocity command).
        action_cost = np.sum(action**2) * self.action_cost_coeff

        # Penalize the squared distance to the closest goal.
        goal_cost = self.distance_cost_coeff * np.amin(
            [np.sum((observation - goal_position) ** 2) for goal_position in self.goal_positions]
        )

        costs = [action_cost, goal_cost]
        return -np.sum(costs)

    def _plot_position_cost(self, ax):
        """Draw contours of the squared distance to the nearest goal, plus goal markers, on ``ax``."""
        delta = 0.01
        # Extend the plotted region 10% beyond the observation bounds.
        x_min, x_max = tuple(1.1 * np.array(self.xlim))
        y_min, y_max = tuple(1.1 * np.array(self.ylim))
        X, Y = np.meshgrid(np.arange(x_min, x_max, delta), np.arange(y_min, y_max, delta))
        goal_costs = np.amin([(X - goal_x) ** 2 + (Y - goal_y) ** 2 for goal_x, goal_y in self.goal_positions], axis=0)
        costs = goal_costs

        contours = ax.contour(X, Y, costs, 20)
        ax.clabel(contours, inline=1, fontsize=10, fmt="%.0f")
        ax.set_xlim([x_min, x_max])
        ax.set_ylim([y_min, y_max])
        goal = ax.plot(self.goal_positions[:, 0], self.goal_positions[:, 1], "ro")
        return [contours, goal]

    def get_param_values(self):
        return None

    def close(self):
        return

    def set_param_values(self, params):
        pass

    def horizon(self):
        return None

    @property
    def unwrapped(self):
        return self


class PointDynamics:
    """
    Point-mass dynamics where the action is added directly to the state.

    State: position.
    Action: velocity.

    ``forward`` returns ``state + action`` perturbed by Gaussian noise of standard
    deviation ``sigma``, drawn from the global ``np.random`` state.
    """

    def __init__(self, dim, sigma):
        self.dim = dim
        self.sigma = sigma
        # State and action live in the same space, so they share one dimensionality.
        self.s_dim = self.a_dim = dim

    def forward(self, state, action):
        """Return the next state after applying velocity ``action`` to ``state``."""
        drift = state + action
        noise = np.random.normal(size=self.s_dim)
        return drift + self.sigma * noise


if __name__ == "__main__":
    multi_goal_env = "MultiGoal-v0"
    # gymnasium resolves this "module:attr" entry point by importing `train_2d`, so this
    # script must be importable under that module name when the env is created.
    register(
        id=multi_goal_env,
        entry_point="train_2d:MultiGoalEnv",
    )

    # Algorithms handled by the dispatch below. Validating them in argparse makes a typo
    # fail before any (vectorized) environment is created.
    algorithm_choices = ("qsm", "ddsq", "ddsq_init", "ddsq_clip", "sac", "dsact", "dacer", "dipo", "qvpo")

    parser = argparse.ArgumentParser()
    parser.add_argument("--alg", type=str, default="dacer", choices=algorithm_choices)
    parser.add_argument("--num_vec_envs", type=int, default=20)
    parser.add_argument("--hidden_num", type=int, default=3)
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--diffusion_steps", type=int, default=20)
    parser.add_argument("--diffusion_hidden_dim", type=int, default=256)
    parser.add_argument("--start_step", type=int, default=int(2e5))  # other envs 3e4
    parser.add_argument("--total_step", type=int, default=int(3e6))
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--seed", type=int, default=100)
    args = parser.parse_args()

    # Derive every component's seed from a single master seed.
    master_seed = args.seed
    master_rng, _ = seeding(master_seed)
    env_seed, env_action_seed, eval_env_seed, buffer_seed, init_network_seed, train_seed = map(
        int, master_rng.integers(0, 2**32 - 1, 6)
    )
    init_network_key = jax.random.key(init_network_seed)
    train_key = jax.random.key(train_seed)
    del init_network_seed, train_seed

    # Use a vectorized env when requested; otherwise fall back to a single env.
    if args.num_vec_envs > 0:
        env, obs_dim, act_dim = create_vector_env(
            multi_goal_env, args.num_vec_envs, env_seed, env_action_seed, mode="futex"
        )
    else:
        env, obs_dim, act_dim = create_env(multi_goal_env, env_seed, env_action_seed)
    eval_env = None

    hidden_sizes = [args.hidden_dim] * args.hidden_num
    diffusion_hidden_sizes = [args.diffusion_hidden_dim] * args.hidden_num

    buffer = TreeBuffer.from_experience(obs_dim, act_dim, size=int(1e6), seed=buffer_seed)

    gelu = partial(jax.nn.gelu, approximate=False)

    def mish(x: jax.Array):
        """Mish activation: x * tanh(softplus(x))."""
        return x * jnp.tanh(jax.nn.softplus(x))

    if args.alg == "qsm":
        agent, params = create_qsm_net(
            init_network_key, obs_dim, act_dim, hidden_sizes, num_timesteps=args.diffusion_steps, num_particles=64
        )
        algorithm = QSM(agent, params, lr=args.lr)
    elif args.alg in ("ddsq", "ddsq_init", "ddsq_clip"):
        # Only "ddsq" enables reflection.
        # NOTE(review): "ddsq_init" and "ddsq_clip" build identical networks — confirm that is intended.
        agent, params = create_ddsq_net(
            init_network_key,
            obs_dim,
            act_dim,
            args.diffusion_steps,
            with_reflect=args.alg == "ddsq",
            hidden_sizes=hidden_sizes,
            activation=mish,
        )
        algorithm = DDSQ(agent, params, lr=args.lr)
    elif args.alg == "sac":
        agent, params = create_sac_net(init_network_key, obs_dim, act_dim, hidden_sizes, gelu)
        algorithm = SAC(agent, params, lr=args.lr)
    elif args.alg == "dsact":
        agent, params = create_dsact_net(init_network_key, obs_dim, act_dim, hidden_sizes, gelu)
        algorithm = DSACT(agent, params, lr=args.lr)
    elif args.alg == "dacer":
        agent, params = create_dacer_net(
            init_network_key,
            obs_dim,
            act_dim,
            hidden_sizes,
            diffusion_hidden_sizes,
            mish,
            num_timesteps=args.diffusion_steps,
        )
        algorithm = DACER(agent, params, lr=args.lr)
    elif args.alg == "dipo":
        # DIPO keeps its own (obs, action) buffer, connected to the main buffer via a
        # projection of each stored experience.
        diffusion_buffer = TreeBuffer.from_example(
            ObsActionPair.create_example(obs_dim, act_dim),
            args.total_step,
            int(master_rng.integers(0, 2**32 - 1)),
            remove_batch_dim=False,
        )
        TreeBuffer.connect(buffer, diffusion_buffer, lambda exp: ObsActionPair(exp.obs, exp.action))

        agent, params = create_dipo_net(init_network_key, obs_dim, act_dim, hidden_sizes, num_timesteps=100)
        algorithm = DIPO(
            agent,
            params,
            diffusion_buffer,
            lr=args.lr,
            action_gradient_steps=30,
            policy_target_delay=2,
            action_grad_norm=0.16,
        )
    elif args.alg == "qvpo":
        agent, params = create_qvpo_net(
            init_network_key,
            obs_dim,
            act_dim,
            hidden_sizes,
            diffusion_hidden_sizes,
            mish,
            num_timesteps=args.diffusion_steps,
            num_particles=4,
            noise_scale=0.05,
        )
        algorithm = QVPO(agent, params, lr=args.lr, alpha_lr=7e-3, delay_alpha_update=250)
    else:
        # Safeguard in case `algorithm_choices` and this dispatch drift apart.
        raise ValueError(f"Invalid algorithm {args.alg}!")

    trainer = OffPolicyTrainer(
        env=env,
        alg_name=args.alg,
        algorithm=algorithm,
        buffer=buffer,
        start_step=args.start_step,
        total_step=args.total_step,
        sample_per_iteration=1,
        evaluate_env=eval_env,
        save_policy_every=30000,
        warmup_with="random",
        log_path=PROJECT_ROOT
        / "logs"
        / multi_goal_env
        / (args.alg + "_" + time.strftime("%Y-%m-%d_%H-%M-%S") + f"_s{args.seed}"),
    )

    trainer.setup(Experience.create_example(obs_dim, act_dim, trainer.batch_size))
    trainer.run(train_key)
