import functools as ft
import einops as ei
import pathlib
from abc import ABC, abstractmethod, abstractproperty
from typing import Callable, Literal, NamedTuple, Optional, Tuple

import cv2
import equinox as eqx
import ipdb
import jax
import jax.lax as lax
import jax.numpy as jnp
import mediapy as media
import mujoco
import numpy as np
import tqdm
from matplotlib.colors import to_rgb, to_rgba
from mujoco import mjx

from cmarl.trainer.data import Rollout
from cmarl.trainer.utils import rollout
from cmarl.utils.graph import EdgeBlock, GetGraph, GraphsTuple
from cmarl.utils.typing import Action, Array, Cost, Done, Info, PRNGKey, Reward, State
from cmarl.utils.utils import jax2np, jax_jit_np, tree_concat_at_front, tree_index, tree_stack

from ..base import MultiAgentEnv, StepResult
from ..utils import get_node_goal_rng


class MuJoCoEnvState(NamedTuple):
    """Simulator state stored on each graph as ``env_states``."""

    # Generalized positions, laid out as [box_xy (2), goal_xy (2), agent_xy (2 per agent)].
    qpos: Array
    # Generalized velocities; sliced with the same layout as ``qpos``.
    qvel: Array
    # Per-agent boolean flag produced by ``ReverseTransport.get_a_incontact``.
    a_incontact: Array


class ReverseTransport(MultiAgentEnv):
    """Three-agent MuJoCo (MJX) environment loaded from ``assets/reverse_transport.xml``.

    The generalized coordinates are laid out as ``[box_xy, goal_xy, agent_xy * 3]``.
    The reward is based on the box-to-goal distance, and the single cost channel
    penalizes agents that come closer than two agent radii to one another.
    """

    # Node type identifiers. Only AGENT is used by the graph construction in this class.
    AGENT = 0
    GOAL = 1
    BOX = 2

    # Default environment parameters.
    PARAMS = {
        "comm_radius": 0.4,
        "default_area_size": 0.8,
        "dist2goal": 0.01,  # box-to-goal distance threshold used by the reward penalty
        "agent_radius": 0.03,  # twice this value is the minimum agent-agent separation
    }

    def __init__(
        self,
        num_agents: int,
        area_size: Optional[float] = None,
        max_step: int = 64,
        max_travel: Optional[float] = None,
        dt: float = 0.03,
        params: dict = None,
    ):
        """Load the MuJoCo model and precompute per-agent contact row indices.

        Args:
            num_agents: Must be 3; the model is built for exactly three agents.
            area_size: Must be ``0.8`` or ``None``.
            max_step: Forwarded to the base environment.
            max_travel: Forwarded to the base environment.
            dt: Must be ``0.03``.
            params: Optional parameter overrides forwarded to the base environment.

        Raises:
            AssertionError: If ``num_agents``, ``area_size`` or ``dt`` is unsupported.
        """
        assert num_agents == 3, "ReverseTransport only supports 3 agents."
        assert area_size == 0.8 or area_size is None, "ReverseTransport only supports area_size=0.8."
        assert dt == 0.03, "ReverseTransport only supports dt=0.03."
        super(ReverseTransport, self).__init__(3, area_size, max_step, max_travel, dt, params)

        xml_path = pathlib.Path(__file__).parent / "assets" / "reverse_transport.xml"
        self.model_mj: mujoco.MjModel = mujoco.MjModel.from_xml_path(filename=str(xml_path.absolute()))

        self.model_mjx = mjx.put_model(self.model_mj)

        self.frame_skip = 5
        self.dt_s = self.model_mj.opt.timestep * self.frame_skip

        self.camera = self.model_mj.cam("top").name

        # Run a single MJX step so the contact table is populated, then pull the
        # first-geom column to the host once instead of indexing it element by element.
        mjx_data = mjx.make_data(self.model_mj)
        mjx_data = jax.jit(mjx.step)(self.model_mjx, mjx_data)
        contact_geom1 = np.asarray(mjx_data.contact.geom1)

        def contact_rows(geom_id: int) -> np.ndarray:
            # Rows of the contact table whose first geom is ``geom_id``. flatnonzero
            # always yields an integer array, even when there are no matching rows.
            return np.flatnonzero(contact_geom1 == geom_id)

        # Geom ids 11, 12 and 13 are hard-coded here.
        # NOTE(review): presumably these are the three agent geoms in the XML -- confirm.
        self.agent1_contact_idxs = contact_rows(11)
        self.agent2_contact_idxs = contact_rows(12)
        self.agent3_contact_idxs = contact_rows(13)

    @property
    def state_dim(self) -> int:
        """Dimension of a single agent state (2-D position plus 2-D velocity)."""
        return 4

    @property
    def node_dim(self) -> int:
        """Number of node feature columns filled in by ``get_graph``."""
        return 11

    @property
    def edge_dim(self) -> int:
        """Number of edge feature columns (pairwise position and velocity differences)."""
        return 4

    @property
    def action_dim(self) -> int:
        """Per-agent action dimension."""
        return 2

    @property
    def reward_min(self) -> float:
        """Lower bound reported for the reward."""
        return -1

    @property
    def reward_max(self) -> float:
        """Upper bound reported for the reward."""
        return 0.5

    @property
    def n_cost(self) -> int:
        """Number of cost channels returned by ``get_cost``."""
        return 1

    @property
    def cost_min(self) -> float:
        """Lower bound reported for the cost; ``get_cost`` clamps its output at this value."""
        return -1.0

    @property
    def cost_max(self) -> float:
        """Upper bound reported for the cost."""
        return 1.0

    @property
    def cost_components(self) -> Tuple[str, ...]:
        """Human-readable name of each cost channel."""
        return ("agent collisions",)

    def reset(self, key: Array) -> GraphsTuple:
        """Sample a fresh initial configuration and return its graph.

        The box and the goal are drawn uniformly from ``[-0.22, 0.22]^2``. The agents
        are sampled by ``get_node_goal_rng`` and then shifted by ``box_pos - 0.05``.
        All velocities start at zero.

        Args:
            key: PRNG key; split into independent keys for the box, agents and goal.

        Returns:
            The padded graph of the initial environment state.
        """
        # The fourth sub-key is not consumed.
        box_key, agent_key, goal_key, _spare_key = jax.random.split(key, 4)

        lower = jnp.array([-0.22, -0.22])
        upper = jnp.array([0.22, 0.22])
        box_xy = jax.random.uniform(box_key, (2,), minval=lower, maxval=upper)
        goal_xy = jax.random.uniform(goal_key, (2,), minval=lower, maxval=upper)

        min_separation = 2 * self.params["agent_radius"]
        agent_xy, _ = get_node_goal_rng(agent_key, 0.1, 2, self.num_agents, min_separation, None, self.max_travel)
        # NOTE(review): presumably this recentres a 0.1-wide sample on the box -- confirm.
        agent_xy = agent_xy - 0.05 + box_xy

        # qpos layout: [box_xy, goal_xy, agent_xy (flattened)].
        qpos0 = jnp.concatenate([box_xy, goal_xy, agent_xy.flatten()])
        qvel0 = jnp.zeros_like(qpos0)

        mjx_data = mjx.make_data(self.model_mj).replace(qpos=qpos0, qvel=qvel0, ctrl=np.zeros(6))
        initial_state = MuJoCoEnvState(
            qpos=qpos0,
            qvel=qvel0,
            a_incontact=self.get_a_incontact(mjx_data),
        )
        return self.get_graph(initial_state)

    def get_a_incontact(self, mjx_data: mjx.Data):
        """Return a per-agent boolean flag derived from the agent's offset to the box.

        An agent is flagged when the absolute value of either coordinate of its
        position relative to the box exceeds ``0.7 - 1e-3``.

        Args:
            mjx_data: Simulation data; only ``qpos`` is read.

        Returns:
            Boolean array with one entry per agent.
        """
        qpos = mjx_data.qpos
        box_xy = qpos[:2]
        rel_xy = qpos[4:].reshape(self.num_agents, 2) - box_xy

        threshold = 0.7 - 1e-3
        return (jnp.abs(rel_xy) > threshold).any(axis=1)

    def step(
        self, graph: GraphsTuple, action: Action, get_eval_info: bool = False
    ) -> Tuple[GraphsTuple, Reward, Cost, Done, Info]:
        """Advance the simulation by four MJX substeps under a constant control.

        Args:
            graph: Current graph; ``graph.env_states`` supplies ``qpos`` and ``qvel``.
            action: Per-agent actions; clipped, flattened and divided by 3 before use.
            get_eval_info: Accepted for interface compatibility; not used here.

        Returns:
            ``(next_graph, reward, cost, done, info)``. ``reward`` and ``cost`` are
            evaluated on the incoming ``graph``. ``done`` is always ``False`` and
            ``info`` is empty.
        """
        prev_state = graph.env_states
        action = self.clip_action(action).flatten() / 3.0

        sim = mjx.make_data(self.model_mj).replace(qpos=prev_state.qpos, qvel=prev_state.qvel, ctrl=action)
        # NOTE(review): self.frame_skip is 5 but only four substeps are taken here -- confirm.
        for substep in range(4):
            sim = mjx.step(self.model_mjx, sim)
            # Fail loudly as soon as the integrator produces NaN/inf positions.
            sim = eqx.error_if(sim, ~jnp.isfinite(sim.qpos).all(), f"[{substep}] qpos not finite")

        next_env_state = MuJoCoEnvState(sim.qpos, sim.qvel, self.get_a_incontact(sim))

        reward = self.get_reward(graph, action)
        cost = self.get_cost(graph)
        done = jnp.array(False)
        info = {}

        return self.get_graph(next_env_state), reward, cost, done, info

    def get_reward(self, graph: GraphsTuple, action: Action) -> Reward:
        """Compute the reward from the box-to-goal distance.

        The reward is ``-0.01 * distance`` minus a flat ``0.001`` penalty while the
        distance exceeds the ``dist2goal`` parameter.

        Args:
            graph: Graph whose ``env_states.qpos`` holds the box and goal positions.
            action: Not used by this reward.

        Returns:
            Scalar reward.
        """
        qpos = graph.env_states.qpos
        box_xy, goal_xy = qpos[:2], qpos[2:4]
        gap = jnp.linalg.norm(goal_xy - box_xy, axis=-1)

        distance_term = -gap.mean() * 0.01
        not_reached = jnp.where(gap > self._params["dist2goal"], 1.0, 0.0)
        return distance_term - not_reached.mean() * 0.001


    def get_cost(self, graph: GraphsTuple) -> Cost:
        """Compute the per-agent collision cost.

        For each agent the raw cost is ``2 * agent_radius - d_min``, where ``d_min``
        is the distance to its nearest other agent. The value is then pushed away
        from zero by ``0.5`` (down when safe, up when violating) and bounded below
        by ``-1.0``.

        Args:
            graph: Graph whose ``env_states.qpos`` holds the agent positions.

        Returns:
            Array of shape ``(num_agents, n_cost)``; positive entries mark a violation.
        """
        env_state: MuJoCoEnvState = graph.env_states
        agent_pos = env_state.qpos[4:].reshape(self.num_agents, 2)

        # Pairwise distances; the large diagonal keeps an agent from matching itself.
        dist = jnp.linalg.norm(jnp.expand_dims(agent_pos, 1) - jnp.expand_dims(agent_pos, 0), axis=-1)
        dist += jnp.eye(self.num_agents) * 1e6
        min_dist = jnp.min(dist, axis=1)
        cost: Array = self.params["agent_radius"] * 2 - min_dist
        cost = cost[:, None]

        # Open a margin of 2 * eps between the safe and unsafe values.
        eps = 0.5
        cost = jnp.where(cost <= 0.0, cost - eps, cost + eps)
        # Lower bound only. jnp.maximum avoids jnp.clip's deprecated ``a_min``
        # keyword, which newer JAX releases no longer accept.
        cost = jnp.maximum(cost, -1.0)
        assert cost.shape == (self.num_agents, self.n_cost)

        return cost

    def render_video(
        self,
        rollout: Rollout,
        video_path: pathlib.Path,
        Ta_is_unsafe=None,
        viz_opts: dict = None,
        n_goal: int = None,
        **kwargs,
    ) -> None:
        """Render a rollout to an h264 video at ``video_path``.

        Each frame shows the MuJoCo scene from the environment camera, one arrow per
        agent along its action, a sphere marking ``qpos[:2]`` (the box position), and
        a text overlay with the step index and reward.

        Args:
            rollout: Rollout providing ``graph``, ``actions`` and ``rewards`` along time.
            video_path: Destination file.
            Ta_is_unsafe: Accepted for interface compatibility; not used.
            viz_opts: Accepted for interface compatibility; not used.
            n_goal: Accepted for interface compatibility; not used.
            **kwargs: Ignored.
        """
        T_graph = rollout.graph
        anim_T = len(T_graph.n_node)
        T_actions = np.array(rollout.actions)

        # Per-frame constants, computed once instead of on every iteration.
        colors = [np.array(to_rgba(name)) for name in ("C2", "C1", "C5")]
        arr_width = 0.005
        arr_len = 0.1
        center_color = np.array(to_rgba("C0"))
        center_size = np.full(3, 0.01)
        center_mat = np.eye(3).flatten()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1

        T_frame = []
        d = mujoco.MjData(self.model_mj)
        # The context manager releases the renderer's GL resources even if a frame fails.
        with mujoco.Renderer(self.model_mj, height=1080, width=1920) as renderer:
            for kk in range(anim_T):
                graph = tree_index(T_graph, kk)
                a_actions = T_actions[kk]
                assert a_actions.shape == (3, 2)
                env_state = graph.env_states
                d.qpos[:] = env_state.qpos
                d.qvel[:] = env_state.qvel
                mujoco.mj_forward(self.model_mj, d)
                renderer.update_scene(d, camera=self.camera)
                scene = renderer.scene

                # One arrow per agent, starting at the agent and pointing along its action.
                a_pos = env_state.qpos[4:].reshape(3, 2)
                for aa in range(self.num_agents):
                    pos = np.array([*a_pos[aa], 0.05])
                    action_vec = np.array([*a_actions[aa], 0.0])
                    add_connector_geom(
                        scene, colors[aa], mujoco.mjtGeom.mjGEOM_ARROW, arr_width, pos, pos + arr_len * action_vec
                    )

                # Small sphere at the box position.
                pos = np.array([*env_state.qpos[:2], 0.05])
                add_geom(scene, center_color, mujoco.mjtGeom.mjGEOM_SPHERE, center_size, pos, center_mat)

                img = renderer.render()

                rew = rollout.rewards[kk]
                cost_text = f"{kk:4} | {rew:.3f}"
                (label_width, label_height), baseline = cv2.getTextSize(cost_text, font, font_scale, thickness)
                img = cv2.putText(
                    img,
                    cost_text,
                    (5, label_height + baseline + 5),
                    font,
                    font_scale,
                    (0, 0, 0),
                    thickness,
                    cv2.LINE_AA,
                )

                T_frame.append(img)

        h, w, c = T_frame[0].shape
        dtype = T_frame[0].dtype
        codec = "h264"
        fps = 30

        with media.VideoWriter(video_path, shape=(h, w), dtype=dtype, codec=codec, fps=fps) as writer:
            for frame in T_frame:
                writer.add_image(frame)

    def get_graph(self, env_state: MuJoCoEnvState) -> GraphsTuple:
        """Build the padded graph for ``env_state``.

        Every agent node carries the following feature columns:
        ``[agent_xy (0:2), agent_vel (2:4), box_xy (4:6), box_vel (6:8),
        goal_xy - box_xy (8:10), a_incontact (10)]``.

        Args:
            env_state: Simulator state to encode.

        Returns:
            Padded ``GraphsTuple`` with agent nodes and agent-agent edges.
        """
        n_agent = self.num_agents
        qpos, qvel = env_state.qpos, env_state.qvel

        box_xy, box_v = qpos[:2], qvel[:2]
        goal_offset = qpos[2:4] - box_xy
        agent_xy = qpos[4:].reshape(n_agent, 2)
        agent_v = qvel[4:].reshape(n_agent, 2)

        # (column selector, values); box and goal values broadcast across all agents.
        column_blocks = (
            (slice(0, 2), agent_xy),
            (slice(2, 4), agent_v),
            (slice(4, 6), box_xy),
            (slice(6, 8), box_v),
            (slice(8, 10), goal_offset),
            (10, env_state.a_incontact),
        )
        feats = jnp.zeros((n_agent, self.node_dim))
        for cols, values in column_blocks:
            feats = feats.at[:, cols].set(values)

        node_type = jnp.full(n_agent, ReverseTransport.AGENT)
        edges = self.edge_blocks(env_state)
        empty_state_vec = np.zeros((n_agent, 0))

        return GetGraph(feats, node_type, edges, env_state, empty_state_vec).to_padded()

    def edge_blocks(self, env_state: MuJoCoEnvState) -> list[EdgeBlock]:
        """Build the agent-agent edge block.

        Edge features are the pairwise differences of ``[position, velocity]``
        between agents; every pair except self-pairs is connected.

        Args:
            env_state: Simulator state providing ``qpos`` and ``qvel``.

        Returns:
            A single-element list holding the agent-agent ``EdgeBlock``.
        """
        n_agent = self.num_agents
        positions = env_state.qpos[4:].reshape(n_agent, 2)
        velocities = env_state.qvel[4:].reshape(n_agent, 2)
        per_agent = jnp.concatenate([positions, velocities], axis=-1)

        # pairwise_diff[i, j] = per_agent[i] - per_agent[j]
        pairwise_diff = per_agent[:, None, :] - per_agent[None, :, :]
        off_diagonal = ~jnp.eye(n_agent, dtype=bool)
        agent_ids = jnp.arange(n_agent)

        return [EdgeBlock(pairwise_diff, off_diagonal, agent_ids, agent_ids)]

    def state_lim(self, state: Optional[State] = None) -> Tuple[State, State]:
        """Placeholder: state limits are not defined for this environment; returns ``None``."""
        return None

    def action_lim(self) -> Tuple[Action, Action]:
        """Return the per-agent action bounds ``(lower, upper)``, i.e. ``-1`` and ``+1`` per component."""
        upper_lim = jnp.ones(2)
        return -upper_lim, upper_lim

    @ft.partial(jax.jit, static_argnums=(0,))
    def unsafe_mask(self, graph: GraphsTuple) -> Array:
        """Return a per-agent boolean mask that is ``True`` where any cost channel is non-negative."""
        violating = self.get_cost(graph) >= 0.0
        return violating.any(axis=-1)


# Type alias for the MuJoCo geom kinds accepted by the scene helpers below.
# Used only in annotations; nothing is validated at runtime.
GeomType = Literal[
    mujoco.mjtGeom.mjGEOM_ARROW,
    mujoco.mjtGeom.mjGEOM_ARROW1,
    mujoco.mjtGeom.mjGEOM_ARROW2,
    mujoco.mjtGeom.mjGEOM_BOX,
    mujoco.mjtGeom.mjGEOM_CAPSULE,
    mujoco.mjtGeom.mjGEOM_CYLINDER,
    mujoco.mjtGeom.mjGEOM_ELLIPSOID,
    mujoco.mjtGeom.mjGEOM_FLEX,
    mujoco.mjtGeom.mjGEOM_HFIELD,
    mujoco.mjtGeom.mjGEOM_LABEL,
    mujoco.mjtGeom.mjGEOM_LINE,
    mujoco.mjtGeom.mjGEOM_LINEBOX,
    mujoco.mjtGeom.mjGEOM_MESH,
    mujoco.mjtGeom.mjGEOM_SPHERE,
    mujoco.mjtGeom.mjGEOM_TRIANGLE,
]


def add_geom(
    scene: mujoco.MjvScene, rgba: np.ndarray, geom_type: GeomType, size: np.array, pos: np.array, mat: np.array
):
    """Append a decorative geom to ``scene``.

    Args:
        scene: Scene to draw into; its ``ngeom`` counter is incremented.
        rgba: Colour; cast to float32 before being handed to MuJoCo.
        geom_type: MuJoCo geom kind.
        size: Geom size vector.
        pos: Geom position.
        mat: Flattened rotation matrix.

    Returns:
        The newly initialised geom, or ``None`` when the scene is already full.
    """
    scene_is_full = scene.ngeom >= scene.maxgeom
    if scene_is_full:
        return None

    # Claim the next free slot, then initialise it.
    scene.ngeom += 1
    new_geom = scene.geoms[scene.ngeom - 1]
    mujoco.mjv_initGeom(new_geom, geom_type, size, pos, mat, rgba.astype(np.float32))
    return new_geom


def add_connector_geom(
    scene: mujoco.MjvScene, rgba: np.ndarray, geom_type: GeomType, width: float, pt_from: np.ndarray, pt_to: np.ndarray
):
    """Append a connector-style geom (e.g. an arrow) running from ``pt_from`` to ``pt_to``.

    Args:
        scene: Scene to draw into; its ``ngeom`` counter is incremented.
        rgba: Colour; cast to float32 before being handed to MuJoCo.
        geom_type: MuJoCo geom kind used for the connector.
        width: Connector width.
        pt_from: Start point.
        pt_to: End point.

    Returns:
        The newly created geom, or ``None`` when the scene is already full. The
        geom is returned for consistency with ``add_geom``.
    """
    if scene.ngeom >= scene.maxgeom:
        return None
    scene.ngeom += 1
    geom = scene.geoms[scene.ngeom - 1]

    # Initialise with placeholder geometry; mjv_connector then sets the
    # actual shape from the two endpoints.
    mujoco.mjv_initGeom(
        geom,
        geom_type,
        np.zeros(3),
        np.zeros(3),
        np.zeros(9),
        rgba.astype(np.float32),
    )
    mujoco.mjv_connector(geom, geom_type, width, pt_from, pt_to)
    return geom
