import gym
from gym import spaces
import numpy as np
from os import path
from scipy.stats import triang
import jax.numpy as jnp
from functools import partial
import jax
import matplotlib.pyplot as plt
import os
from rsm_utils import (
    triangular,
    make_unsafe_spaces,
    contained_in_any,
    clip_and_filter_spaces,
)
import seaborn as sns


class RoomMetaEnv:
    """Graph of room-to-room sub-tasks over nine rooms (indices 0..8).

    Attributes:
        name: Identifier of this environment family.
        edges: List of ``(a, b)`` room-index pairs. Each pair can be turned
            into a ``RoomEnv`` via :meth:`get_subtask`, where ``a`` selects
            the start room and ``b`` the goal room.
    """

    name = "rooms"

    # Sub-task pairs in their canonical order; copied into a fresh list for
    # every instance so that per-instance mutation stays local.
    _EDGES = (
        (0, 1),
        (0, 3),
        (0, 6),
        (1, 7),
        (6, 7),
        (6, 8),
        (4, 5),
        (5, 8),
        (7, 8),
        (0, 4),
        (4, 8),
        (0, 8),
    )

    def __init__(self):
        self.edges = list(self._EDGES)

    @property
    def num_nodes(self):
        """Number of rooms in the graph (fixed at 9)."""
        return 9

    @property
    def num_edges(self):
        """Number of sub-task pairs currently stored in ``edges``."""
        return len(self.edges)

    def get_subtask(self, a, b):
        """Create the ``RoomEnv`` that goes from room ``a`` to room ``b``."""
        return RoomEnv(a, b)


class RoomEnv(gym.Env):
    def __init__(self, a, b, init_radius=0.1, noise_radius=0.1):
        """Build the sub-task of moving from room ``a`` to room ``b``.

        Room ``i`` occupies the unit cell whose lower corner is
        ``(i // 3, i % 3)``. The observation space is the bounding box of the
        two rooms' cells.

        Args:
            a: Index of the start room, in ``0..8``.
            b: Index of the goal room, in ``0..8``.
            init_radius: Half side length of the square initial region
                centred in room ``a``. The same value is used for the square
                target region centred in room ``b``.
            noise_radius: Value stored for both dimensions in ``self.noise``,
                which scales the sampled noise.
        """
        self.steps = None
        self.state = None
        self.name = f"rooms{a}_{b}"
        assert 0 <= a <= 8
        assert 0 <= b <= 8
        self.a = a
        self.b = b
        ax, ay = a // 3, a % 3
        bx, by = b // 3, b % 3

        def centered_box(cell_x, cell_y, radius):
            # Square of half side length `radius` around the centre of a cell.
            return spaces.Box(
                low=np.array([cell_x + 0.5 - radius, cell_y + 0.5 - radius]),
                high=np.array([cell_x + 0.5 + radius, cell_y + 0.5 + radius]),
                dtype=np.float32,
            )

        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=np.array([min(ax, bx), min(ay, by)]),
            high=np.array([max(ax, bx) + 1, max(ay, by) + 1]),
            dtype=np.float32,
        )

        self.noise = np.array([noise_radius, noise_radius])

        # The target region deliberately reuses init_radius.
        target_radius = init_radius
        self.target_spaces = [centered_box(bx, by, target_radius)]
        self.init_spaces = [centered_box(ax, ay, init_radius)]
        # Training resets sample from the whole observation space.
        self.init_spaces_train = [self.observation_space]

        # Each entry is (y, [(x_start, x_end), ...]): wall segments at a fixed
        # y coordinate, as consumed by the first loop below.
        self._v_lines = [
            (3, [(0, 3)]),
            (2, [(0, 1.1), (1.9, 2.1), (2.9, 3)]),
            (1, [(0, 0.1), (0.9, 2.1), (1.9, 2.1), (2.9, 3)]),
            (0, [(0, 3)]),
        ]
        # Each entry is (x, [(y_start, y_end), ...]): wall segments at a fixed
        # x coordinate, as consumed by the second loop below.
        self._h_lines = [
            (3, [(0, 3)]),
            (2, [(0, 0.1), (0.9, 1.1), (1.9, 2.1), (2.9, 3)]),
            (1, [(0, 0.1), (0.9, 1.1), (1.9, 2.1), (2.9, 3)]),
            (0, [(0, 3)]),
        ]

        # Every wall segment is inflated by `lw` on all sides into a box.
        lw = 0.05
        wall_boxes = [
            spaces.Box(
                low=np.array([x1 - lw, y - lw]),
                high=np.array([x2 + lw, y + lw]),
                dtype=np.float32,
            )
            for y, xlines in self._v_lines
            for x1, x2 in xlines
        ]
        wall_boxes += [
            spaces.Box(
                low=np.array([x - lw, y1 - lw]),
                high=np.array([x + lw, y2 + lw]),
                dtype=np.float32,
            )
            for x, ylines in self._h_lines
            for y1, y2 in ylines
        ]
        self.unsafe_spaces = clip_and_filter_spaces(
            self.observation_space, wall_boxes
        )

        # RNG state consumed by add_noise(); the jitted methods take explicit keys.
        self._jax_rng = jax.random.PRNGKey(777)
        # Batched (vmapped) variants of the single-sample methods.
        self.v_next = jax.vmap(self.next, in_axes=(0, 0), out_axes=0)
        self.v_step = jax.jit(jax.vmap(self.step))
        self.v_reset = jax.jit(jax.vmap(self.reset))
        self.v_reset_init = jax.jit(jax.vmap(self.reset_init))

    @property
    def noise_bounds(self):
        return -self.noise, self.noise

    @partial(jax.jit, static_argnums=(0,))
    def next(self, state, action):
        action = jnp.clip(action, -1, 1)

        next_state = state + action * 0.1
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )
        for interp in [0, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
            for unsafe_space in self.unsafe_spaces:
                interp_state = state * interp + (1 - interp) * next_state
                contain = jnp.all(
                    jnp.logical_and(
                        interp_state >= unsafe_space.low,
                        interp_state <= unsafe_space.high,
                    )
                )
                action = action * (1.0 - jnp.float32(contain))
        next_state = state + action * 0.1
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )

        state = next_state
        return state

    def add_noise(self, state):
        """Perturb ``state`` with sampled noise unless the noise would hit a wall.

        Advances the internal ``_jax_rng`` key. The sample returned by
        ``triangular`` is scaled per dimension by ``self.noise``
        (presumably a triangular draw in [-1, 1] - confirm in ``rsm_utils``).
        If any sampled point between ``state`` and the perturbed state lies
        inside an unsafe box, the noise is zeroed.

        Args:
            state: Position to perturb.

        Returns:
            ``state`` plus the (possibly zeroed) noise. The result is not
            clipped to the observation space.
        """
        self._jax_rng, sample_key = jax.random.split(self._jax_rng, 2)
        n_dims = self.observation_space.shape[0]
        perturbation = triangular(sample_key, (n_dims,))
        perturbation = perturbation * self.noise

        tentative = state + perturbation
        # NOTE(review): the sampled fractions skip 0.4 - confirm whether that
        # is intentional.
        for weight in [0, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
            probe = state * weight + (1 - weight) * tentative
            for wall in self.unsafe_spaces:
                inside = jnp.all(
                    jnp.logical_and(probe >= wall.low, probe <= wall.high)
                )
                perturbation = perturbation * (1.0 - jnp.float32(inside))

        return state + perturbation

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state, action, rng):
        """Advance the packed state by one noisy step and compute the reward.

        Args:
            state: Packed state ``[step_counter, x, y]``.
            action: Requested move, forwarded to :meth:`next`.
            rng: JAX PRNG key used to sample the noise.

        Returns:
            ``(next_packed, next_state, reward, done)`` where ``next_packed``
            is ``[step_counter, x, y]`` with the counter capped at 200,
            ``next_state`` is the new position, and ``done`` is true once the
            incremented counter exceeds 200.
        """
        elapsed = state[0]
        # Position after the deterministic move, before noise is applied.
        moved = self.next(state[1:3], action)

        disturbance = triangular(rng, (self.observation_space.shape[0],))
        disturbance = disturbance * self.noise
        tentative = moved + disturbance
        # NOTE(review): the sampled fractions skip 0.4 - confirm whether that
        # is intentional.
        for weight in [0, 0.1, 0.2, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
            probe = moved * weight + (1 - weight) * tentative
            for wall in self.unsafe_spaces:
                inside = jnp.all(
                    jnp.logical_and(probe >= wall.low, probe <= wall.high)
                )
                # Drop the noise entirely if it would carry us into a wall.
                disturbance = disturbance * (1.0 - jnp.float32(inside))

        next_state = jnp.clip(
            moved + disturbance,
            self.observation_space.low,
            self.observation_space.high,
        )
        next_step = elapsed + 1

        # NOTE(review): both containment terms below use the pre-noise
        # position, while the distance term uses the noisy next_state.
        reward = 0
        for wall in self.unsafe_spaces:
            # Penalty when within 0.1 of an unsafe box.
            near_wall = jnp.all(
                jnp.logical_and(moved >= wall.low - 0.1, moved <= wall.high + 0.1)
            )
            reward += -0.2 * jnp.float32(near_wall)
        for goal in self.target_spaces:
            in_goal = jnp.all(jnp.logical_and(moved >= goal.low, moved <= goal.high))
            goal_center = 0.5 * (goal.low + goal.high)
            # L1 distance to the goal centre, capped at 2.
            l1_dist = jnp.clip(jnp.sum(jnp.abs(goal_center - next_state)), 0, 2)
            reward += 0.1 * (2.0 - l1_dist)
            reward += 0.4 * jnp.float32(in_goal)

        done = next_step > 200
        next_step = jnp.minimum(next_step, 200)
        next_packed = jnp.array([next_step, next_state[0], next_state[1]])
        return next_packed, next_state, reward, done

    @property
    def observation_dim(self):
        return self.observation_space.shape[0]

    @property
    def action_dim(self):
        return self.action_space.shape[0]

    @property
    def lipschitz_constant(self):
        """Fixed value (0.5) reported as this environment's Lipschitz constant.

        NOTE(review): the norm and derivation are not visible in this file.
        """
        constant = 0.5
        return constant

    @property
    def lipschitz_constant_linf(self):
        """Fixed value (0.5) reported as the L-infinity Lipschitz constant.

        NOTE(review): the derivation is not visible in this file.
        """
        constant = 0.5
        return constant

    @property
    def delta(self):
        return 0.1 + self.noise[0]

    def integrate_noise(self, a: list, b: list):
        dims = 2
        pmass = np.ones(a[0].shape[0])
        for i in range(dims):
            loc = self.noise_bounds[0][i]
            scale = self.noise_bounds[1][i] - self.noise_bounds[0][i]
            marginal_pmass = triang.cdf(b[i], c=0.5, loc=loc, scale=scale) - triang.cdf(
                a[i], c=0.5, loc=loc, scale=scale
            )
            pmass *= marginal_pmass
        return pmass

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, rng):
        lowers = jnp.stack([init.low for init in self.init_spaces_train], 0)
        high = jnp.stack([init.high for init in self.init_spaces_train], 0)
        rng1, rng2 = jax.random.split(rng, 2)
        index = jax.random.randint(
            rng1, shape=(), minval=0, maxval=len(self.init_spaces_train)
        )
        obs = jax.random.uniform(
            rng2, shape=(lowers.shape[1],), minval=lowers[index], maxval=high[index]
        )
        state = jnp.array([0, obs[0], obs[1]])
        return state, obs

    @partial(jax.jit, static_argnums=(0,))
    def reset_init(self, rng):
        lowers = jnp.stack([init.low for init in self.init_spaces], 0)
        high = jnp.stack([init.high for init in self.init_spaces], 0)
        rng1, rng2 = jax.random.split(rng, 2)
        index = jax.random.randint(
            rng1, shape=(), minval=0, maxval=len(self.init_spaces)
        )
        obs = jax.random.uniform(
            rng2, shape=(lowers.shape[1],), minval=lowers[index], maxval=high[index]
        )
        state = jnp.array([0, obs[0], obs[1]])
        return state, obs

    def plot(self, filename):
        """Draw the environment layout and save it to ``filename``.

        Initial regions are filled blue, unsafe regions red and target regions
        green; the observation-space boundary is outlined in black and the
        nine room indices are written into their cells.

        Args:
            filename: Path the figure is saved to.
        """

        def outline(box):
            # Closed rectangle path: four corners plus the first one repeated.
            lo, hi = box.low, box.high
            xs = [lo[0], hi[0], hi[0], lo[0], lo[0]]
            ys = [lo[1], lo[1], hi[1], hi[1], lo[1]]
            return xs, ys

        sns.set()
        fig, ax = plt.subplots(figsize=(6, 6))

        filled_regions = (
            (self.init_spaces, "blue"),
            (self.unsafe_spaces, "red"),
            (self.target_spaces, "green"),
        )
        for boxes, color in filled_regions:
            for box in boxes:
                xs, ys = outline(box)
                ax.fill(xs, ys, color=color, alpha=0.5, zorder=7)

        border_x, border_y = outline(self.observation_space)
        ax.plot(border_x, border_y, color="black", alpha=0.9, zorder=1)

        for room in range(9):
            cell_x, cell_y = divmod(room, 3)
            ax.text(cell_x + 0.45, cell_y + 0.45, room)

        ax.set_title(f"({self.a}) -> ({self.b})")
        fig.tight_layout()
        fig.savefig(filename)
        plt.close(fig)

    def plot_fancy(self, filename):
        """Draw a decoration-free rendering of the layout.

        Initial, unsafe and target regions are filled with entries 2, 3 and 4
        of the current seaborn colour palette; ticks, spines and axes are
        hidden.

        Args:
            filename: Path the figure is saved to, or ``None`` to get the
                figure back instead of saving it.

        Returns:
            ``(fig, ax)`` when ``filename`` is ``None`` (the figure is left
            open for the caller); otherwise ``None`` after the figure has
            been saved and closed.
        """

        def outline(box):
            # Closed rectangle path: four corners plus the first one repeated.
            lo, hi = box.low, box.high
            xs = [lo[0], hi[0], hi[0], lo[0], lo[0]]
            ys = [lo[1], lo[1], hi[1], hi[1], lo[1]]
            return xs, ys

        sns.set()
        sns.set_style("whitegrid")
        fig, ax = plt.subplots(figsize=(6, 6))
        palette = sns.color_palette()

        # (boxes, palette index, alpha), drawn in this order.
        filled_regions = (
            (self.init_spaces, 2, 0.8),
            (self.unsafe_spaces, 3, 1),
            (self.target_spaces, 4, 0.8),
        )
        for boxes, color_index, alpha in filled_regions:
            for box in boxes:
                xs, ys = outline(box)
                ax.fill(xs, ys, color=palette[color_index], alpha=alpha, zorder=7)

        ax.set_xticks([])
        ax.set_yticks([])
        sns.despine()
        # Act on this figure's axes rather than pyplot's "current" axes.
        ax.axis("off")

        if filename is None:
            return fig, ax
        # Only the requested file is written; the previous unconditional
        # debug dump to "test.png" has been removed.
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)


if __name__ == "__main__":
    # Render one plot per sub-task edge, then the overall 0 -> 8 task.
    meta_env = RoomMetaEnv()
    os.makedirs("plots_comp", exist_ok=True)
    for i, (a, b) in enumerate(meta_env.edges):
        env = meta_env.get_subtask(a, b)
        env.plot(f"plots_comp/subtask_{i:02d}.png")

    env = RoomEnv(0, 8)
    env.plot_fancy("plots_comp/task.png")

    # Construct the three legs of the 0 -> 1 -> 7 -> 8 route.
    env0 = RoomEnv(0, 1)
    env1 = RoomEnv(1, 7)
    env2 = RoomEnv(7, 8)
