import functools as ft
import pathlib
from abc import ABC, abstractmethod, abstractproperty
from typing import Callable, NamedTuple, Optional, Tuple

import cv2
import ipdb
import jax
import jax.lax as lax
import jax.numpy as jnp
import mediapy as media
import mujoco
import numpy as np
import tqdm
from mujoco import mjx

from cmarl.trainer.data import Rollout
from cmarl.trainer.utils import rollout
from cmarl.utils.graph import EdgeBlock, GetGraph, GraphsTuple
from cmarl.utils.typing import Action, Array, Cost, Done, Info, PRNGKey, Reward, State
from cmarl.utils.utils import jax2np, jax_jit_np, tree_concat_at_front, tree_index, tree_stack

from ..base import MultiAgentEnv, StepResult
from ..utils import get_node_goal_rng
from .reverse_transport import MuJoCoEnvState


class ReverseTransportCPU:
    """CPU (plain MuJoCo) stepper for the reverse-transport scene.

    Loads ``assets/reverse_transport.xml`` located next to this file and
    advances a single ``MjData`` instance in place.
    """

    # Number of ``mj_step`` calls performed per call to :meth:`step`.
    N_SUBSTEPS = 4
    # Clipped actions are divided by this value before being written to ctrl.
    ACTION_SCALE = 3.0
    # Geom ids tested for contact, in the order they appear in the returned
    # contact flags.
    # NOTE(review): hard-coded against the XML asset -- confirm they still
    # match the model if the asset changes.
    AGENT_GEOM_IDS = (11, 12, 13)
    # A contact counts only when its reported distance is at most this value.
    CONTACT_DIST_THRESH = 4e-3

    def __init__(self):
        xml_path = pathlib.Path(__file__).parent / "assets" / "reverse_transport.xml"
        self.model_mj: mujoco.MjModel = mujoco.MjModel.from_xml_path(filename=str(xml_path.absolute()))
        self.data_mj = mujoco.MjData(self.model_mj)

    def set_initial_state(self, qpos, qvel):
        """Overwrite the simulator's ``qpos`` and ``qvel`` in place.

        Args:
            qpos: Values assigned element-wise to ``data_mj.qpos``.
            qvel: Values assigned element-wise to ``data_mj.qvel``.
        """
        self.data_mj.qpos[:] = qpos
        self.data_mj.qvel[:] = qvel

    def step(self, action: Action):
        """Apply ``action`` and advance the simulation by ``N_SUBSTEPS`` steps.

        The action is clipped to ``[-1, 1]``, flattened, divided by
        ``ACTION_SCALE`` and written to ``data_mj.ctrl``.

        Args:
            action: Array-like control input; after flattening it must be
                assignable to ``data_mj.ctrl``.

        Returns:
            A ``MuJoCoEnvState`` built from snapshots of ``qpos`` and ``qvel``
            and a boolean array with one contact flag per entry of
            ``AGENT_GEOM_IDS``.
        """
        self.data_mj.ctrl[:] = action.clip(-1, 1).flatten() / self.ACTION_SCALE

        for _ in range(self.N_SUBSTEPS):
            mujoco.mj_step(self.model_mj, self.data_mj)

        # NOTE(review): only ``geom1`` is compared, so a contact that lists the
        # agent geom as ``geom2`` is not counted -- confirm this is intended.
        a_incontact = np.zeros(len(self.AGENT_GEOM_IDS), dtype=bool)
        for contact in self.data_mj.contact:
            for idx, geom_id in enumerate(self.AGENT_GEOM_IDS):
                if contact.geom1 == geom_id and contact.dist <= self.CONTACT_DIST_THRESH:
                    a_incontact[idx] = True

        # ``MjData`` exposes qpos/qvel as views of simulator memory; copy them
        # so the returned state is not overwritten by later calls to ``step``
        # or ``set_initial_state``.
        qpos = np.array(self.data_mj.qpos, copy=True)
        qvel = np.array(self.data_mj.qvel, copy=True)
        return MuJoCoEnvState(qpos, qvel, a_incontact)
