import copy
from PIL import Image
from functools import partial
import json
import logging
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import open3d
import yaml
from furniture_bench.envs.furniture_sim_env import ASSET_ROOT

# import uuid

# isort: off
import cv2
import furniture_bench
from furniture_bench.sim_config import sim_config
import torch

from furniture_bench_api.EnvState import EnvState
from furniture_bench_api.EnvStates import EnvStates
from furniture_bench_api.api.api_schema import Point3d, RigidBody, RobotArm, StateModel, Table
from furniture_bench_api.utils.image_utils import get_image_with_labels
from furniture_bench_api.utils.other_utils import image_to_base64


# isort: on

# NOTE(review): this prints the furniture_bench package name every time the module is imported.
# It appears to exist only so the isort-protected `import furniture_bench` above counts as "used".
# Confirm it is not relied on before replacing it with logging or removing it.
print(furniture_bench.__name__)
import gymnasium as gym
from isaacgym import gymtorch
from python_utils.transformations import absolute_to_relative, affine_transform, pose_to_affine, quaternion_to_euler_zyx

from furniture_bench_api.utils.pose_utils import (
    get_part_poses_from_obs,
    get_pose_from_obs,
    normalize_pose_to_z,
    relative_to_base,
    transform_pose_in_local_coords,
    transform_pose_in_world_coords,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class FurnitureBenchEnvironment:
    """Wrapper around a single FurnitureSim-v0 (furniture_bench) environment.

    Exposes part poses, part bounding boxes, gripper and table state. It can also snapshot and
    restore simulator state under string hashes. Per-part pose transforms are read from
    ``scripts/teaching-configs/<furniture>.yaml``.

    Part names used by this class are fully qualified: ``"<furniture>_<key>"``, where ``<key>`` is
    an entry of ``config["parts"]``.
    """

    def __init__(self, furniture: str = "lamp"):
        """Create the simulator, load the teaching config and the part meshes.

        Args:
            furniture: FurnitureSim furniture name. It also selects the YAML config
                ``scripts/teaching-configs/<furniture>.yaml``, resolved relative to the current
                working directory.
        """
        self.sampling_interval = sim_config["sim_params"].dt  # seconds per step
        self.device = torch.device("cuda")
        # Pose transform that get_current_pose(at_flange=False) applies to the end-effector pose.
        # It starts as [0, 0, 0, 1, 0, 0, 0] and reset_env() restores that value.
        self.T_ee_to_tool = torch.as_tensor([0, 0, 0, 1, 0, 0, 0], device=self.device, dtype=torch.float32)
        self.furniture = furniture
        self.config = yaml.safe_load(Path(f"scripts/teaching-configs/{furniture}.yaml").read_text(encoding="utf-8"))

        # StateModel samples appended by add_sample_to_data().
        self.data = []
        # Snapshot store, created by reset_env(). Initialised here so that using it before the
        # first reset fails on the explicit assertions instead of an AttributeError.
        self.env_state: Optional[EnvStates] = None

        obs_keys = [
            "robot_state/ee_pos",
            "robot_state/ee_quat",
            "robot_state/ee_pos_vel",
            "robot_state/ee_ori_vel",
            "robot_state/gripper_width",
            "robot_state/joint_positions",
            "color_image1",
            "color_image2",
            "parts_poses",
        ]
        self.env = gym.make(
            "FurnitureSim-v0",
            furniture=furniture,  # Specifies the type of furniture [lamp | square_table | desk | drawer | cabinet | round_table | stool | chair | one_leg].
            num_envs=1,  # Number of parallel environments.
            obs_keys=obs_keys,  # List of observations.
            concat_robot_state=False,  # Whether to return robot_state in a vector or dictionary.
            use_april_tag_coords=False,  # Whether to use AprilTag coordinates for parts
            resize_img=False,  # If true, images are resized to 224 x 224.
            headless=True,  # If true, simulation runs without GUI.
            compute_device_id=0,  # GPU device ID for simulation.
            graphics_device_id=0,  # GPU device ID for rendering.
            init_assembled=False,  # If true, the environment is initialized with assembled furniture.
            np_step_out=False,  # If true, env.step() returns Numpy arrays.
            channel_first=False,  # If true, images are returned in channel first format.
            randomness="low",  # Level of randomness in the environment [low | med | high].
            high_random_idx=-1,  # Index of the high randomness level (range: [0-2]). Default -1 will randomly select the index within the range.
            save_camera_input=False,  # If true, the initial camera inputs are saved.
            record=True,  # If true, videos of the wrist and front cameras' RGB inputs are recorded.
            max_env_steps=1000,  # Maximum number of steps per episode.
            act_rot_repr="quat",  # Representation of rotation for action space. Options are 'quat' and 'axis'.
        )

        # Part meshes keyed by fully-qualified part name,
        # loaded from ASSET_ROOT/furniture/mesh/<furniture>/<part>.obj.
        self.objects = {
            part: open3d.io.read_triangle_mesh(
                (Path(ASSET_ROOT) / "furniture" / "mesh" / furniture / f"{part}.obj").as_posix()
            )
            for part in self.get_parts()
        }

        self.gripper_closed = False

        cache_dir = Path("cache")
        cache_dir.mkdir(exist_ok=True)
        # add_state() mirrors the snapshot store into this JSON file.
        self.cache_file = cache_dir / "states.json"

    # ------------------------------------------------------------------
    # Part-name / config helpers
    # ------------------------------------------------------------------

    def _part_key(self, part: str) -> str:
        """Return the ``config["parts"]`` key for a fully-qualified part name.

        The key is everything after the first occurrence of ``"<furniture>_"``.

        Raises:
            ValueError: if ``part`` does not contain ``"<furniture>_"``.
        """
        prefix = f"{self.furniture}_"
        _, sep, key = part.partition(prefix)
        if not sep:
            raise ValueError(f"part {part!r} does not contain the prefix {prefix!r}")
        return key

    def _physical_part(self, part: str) -> str:
        """Return the fully-qualified name of the part whose pose and mesh back ``part``.

        A config entry may alias another part through a ``"part"`` key. Otherwise the entry refers
        to itself.
        """
        key = self._part_key(part)
        part_info = self.config["parts"][key]
        target = part_info["part"] if "part" in part_info else key
        return f"{self.furniture}_{target}"

    def get_parts(self) -> List[str]:
        """Return the fully-qualified names of all parts listed in the teaching config."""
        return [f"{self.furniture}_{key}" for key in self.config["parts"]]

    # ------------------------------------------------------------------
    # Poses and transforms
    # ------------------------------------------------------------------

    def get_transformed_pose(self, part: str, pose: str) -> torch.Tensor:
        """Return the pose of ``part`` after applying its configured ``pose`` transform.

        The raw pose is read from the physical part backing ``part`` (aliases are resolved through
        the config's ``"part"`` key). The ``pose`` transform is then looked up for ``part`` itself
        and applied.
        """
        raw_part_pose = self.get_object_origin(object_name=self._physical_part(part))
        return self.get_transform(part=part, pose=pose)(raw_part_pose)

    def get_transform(self, part: str, pose: str):
        """Build a callable that applies the ``pose`` transform configured for ``part``.

        ``config["parts"][<key>][pose]`` is either ``None`` (identity) or a list of operations
        applied in order:

        * ``{"parent": "<pose>"}``: apply this part's ``<pose>`` transform.
        * ``{"parent": "<other_key>:<pose>"}``: apply another part's ``<pose>`` transform.
        * ``{"op": "tf_local" | "tf_world", "args": {"tfs": [kwargs, ...]}}``: call
          ``transform_pose_in_local_coords`` / ``transform_pose_in_world_coords`` once per kwargs
          dict.
        * ``{"op": "normalize_to_z"}``: call ``normalize_pose_to_z``.

        Unknown ``op`` names are skipped with a warning.

        Returns:
            A callable that takes a pose tensor and returns the transformed pose.

        Raises:
            KeyError: if the part key or ``pose`` is missing from the config.
        """
        operations = self.config["parts"][self._part_key(part)][pose]

        def trans(pose: torch.Tensor, part: str, config: List[Dict[str, Any]]):
            if config is None:
                return pose

            for operation in config:
                if "parent" in operation:
                    # Delegate to another named transform. A separate local is used so that a
                    # cross-part reference ("other:pose") does not redirect later same-part
                    # parent lookups in this list to the other part.
                    parent_part, parent_pose = part, operation["parent"]
                    if ":" in parent_pose:
                        other_key, parent_pose = parent_pose.split(":")
                        parent_part = f"{self.furniture}_{other_key}"
                    pose = self.get_transform(part=parent_part, pose=parent_pose)(pose)
                    continue

                op_name = operation["op"]
                op_args = operation.get("args", {})
                assert isinstance(op_args, dict), "args must be a dict"
                if op_name == "tf_local":
                    for tf in op_args["tfs"]:
                        pose = transform_pose_in_local_coords(pose, **tf)
                elif op_name == "tf_world":
                    for tf in op_args["tfs"]:
                        pose = transform_pose_in_world_coords(pose, **tf)
                elif op_name == "normalize_to_z":
                    pose = normalize_pose_to_z(pose)
                else:
                    # Unknown ops used to be ignored silently. Keep them non-fatal, but make a
                    # config typo visible.
                    logger.warning("ignoring unknown transform op %r for part %s", op_name, part)
            return pose

        return partial(trans, part=part, config=operations)

    def get_current_pose(self, *, at_flange: bool = False) -> torch.Tensor:
        """Return the end-effector pose from the latest observation.

        The first 7 entries of ``get_pose_from_obs`` are the pose. When ``at_flange`` is False,
        they are passed through ``affine_transform`` with ``self.T_ee_to_tool``. Any remaining
        entries are appended unchanged.
        """
        current_pose = get_pose_from_obs(self.get_observation())
        if at_flange:
            tool_pose = current_pose[:7]
        else:
            tool_pose = affine_transform(current_pose[:7], self.T_ee_to_tool)
        return torch.concat((tool_pose, current_pose[7:]))

    def get_object_origin(self, object_name: str):
        """Return the current pose of ``object_name`` in the first (only) environment.

        An exact key match in the part-pose dict is preferred. Otherwise the first key containing
        ``object_name`` as a substring is used, as before.

        Raises:
            RuntimeError: if no part pose matches ``object_name``.
        """
        part_poses = get_part_poses_from_obs(obs=self.get_observation(), env=self)

        if object_name in part_poses:
            return part_poses[object_name][0]
        # Fallback: substring match. Checking the exact key first stops a name that is a
        # substring of an earlier key from returning that other part's pose.
        for name, poses in part_poses.items():
            if object_name in name:
                return poses[0]
        raise RuntimeError(f"object {object_name} unknown")

    def get_object_bounding_box(self, object_name: str):
        """Return the axis-aligned bounding box of a part's mesh at the part's current pose.

        The mesh of the physical part backing ``object_name`` is copied, moved by that part's
        current pose (``pose_to_affine``), and its axis-aligned bounding box is taken.

        Returns:
            ``(min_bound, max_bound)``, each a list of three coordinates.
        """
        object_name = self._physical_part(object_name)
        # deepcopy: TriangleMesh.transform mutates the mesh, and the cached one must stay intact.
        mesh = copy.deepcopy(self.objects[object_name])
        obj_center_affine = pose_to_affine(self.get_object_origin(object_name=object_name))
        mesh = mesh.transform(obj_center_affine.cpu().numpy())
        aabb = mesh.get_axis_aligned_bounding_box()
        return list(aabb.min_bound), list(aabb.max_bound)

    def grasps_object(self, obj_name: str, xy_tolerance: float = 3e-2) -> bool:
        """Return True if the gripper is closed and the flange is near ``obj_name``'s grasp center.

        Args:
            obj_name: fully-qualified part name.
            xy_tolerance: maximum norm of the translational part of ``absolute_to_relative``
                between the part's ``"center"`` pose and the flange pose. The default is 3e-2
                (presumably metres, i.e. 3 cm). Despite the name, all three components are used.
        """
        ee_pose = self.get_current_pose(at_flange=True)[:7]
        object_pose = self.get_transformed_pose(part=obj_name, pose="center")

        rel_transform = absolute_to_relative(object_pose, ee_pose)
        within_xyz = rel_transform[:3].norm() < xy_tolerance

        return bool(within_xyz.item()) and self.gripper_closed

    # ------------------------------------------------------------------
    # State models for the API
    # ------------------------------------------------------------------

    def robot_arm(self) -> RobotArm:
        """Return the gripper position (flange, no tool transform) and the gripper-closed flag."""
        x, y, z = self.get_current_pose(at_flange=True)[:3].tolist()
        return RobotArm(
            gripper_position=Point3d(x=x, y=y, z=z),
            gripper_closed=self.gripper_closed,
        )

    def table(self) -> Table:
        """Return the table surface height after conversion with ``relative_to_base``.

        The surface pose is hard-coded as ``[0, 0, 0.4 + 0.014, 1, 0, 0, 0]``. Its z component after
        ``relative_to_base`` (presumably the robot-base frame) is returned.
        """
        surface = torch.as_tensor([0, 0, 0.4 + 0.014, 1, 0, 0, 0])
        surface = relative_to_base(surface, self).flatten()
        return Table(surface_z=surface[2].item())

    def rigid_bodies(self) -> Dict[str, RigidBody]:
        """Describe every simulated part: bounding box, ``"center"`` grasp point and orientation.

        The orientation is ``quaternion_to_euler_zyx`` of the ``"center"`` pose's quaternion,
        stored as ``Point3d(x=e[0], y=e[1], z=e[2])``.
        """
        bodies = {}
        for object_name in self.get_objects():
            lo, hi = self.get_object_bounding_box(object_name=object_name)
            center_pose = self.get_transformed_pose(part=object_name, pose="center")
            euler = quaternion_to_euler_zyx(center_pose[3:7]).tolist()
            pos = center_pose[:3].tolist()
            bodies[object_name] = RigidBody(
                min_bound=Point3d(x=lo[0], y=lo[1], z=lo[2]),
                max_bound=Point3d(x=hi[0], y=hi[1], z=hi[2]),
                grasp_center=Point3d(x=pos[0], y=pos[1], z=pos[2]),
                orientation=Point3d(x=euler[0], y=euler[1], z=euler[2]),
            )
        return bodies

    def add_sample_to_data(self):
        """Append a StateModel of the current scene (including a labelled base64 image) to ``self.data``."""
        image = get_image_with_labels(self)
        self.data.append(
            StateModel(
                robot_arm=self.robot_arm(),
                rigid_bodies=self.rigid_bodies(),
                table=self.table(),
                image=image_to_base64(image),
            )
        )

    # ------------------------------------------------------------------
    # Underlying simulator access
    # ------------------------------------------------------------------

    def get_observation(self) -> Dict:
        """Return the observation dict of the FurnitureSim env (two wrapper levels down)."""
        return self.env.env.env.get_observation()

    def get_objects(self) -> List[str]:
        """Return the names of the parts in the simulator's furniture model."""
        furniture_env = self.env.env.env  # type: FurnitureSimEnv
        return [p.name for p in furniture_env.furniture.parts]

    def stop_recording(self):
        """Release the video writer if the env records video."""
        env = self.env.env.env
        if env.record:
            env.video_writer.release()

    def restart_recording(self):
        """Close the current video and start a new one under ``sim_record/<timestamp>/video.mp4``."""
        self.stop_recording()
        env = self.env.env.env
        if env.record:
            self.record_dir = Path("sim_record") / datetime.now().strftime("%Y%m%d-%H%M%S")
            self.record_dir.mkdir(parents=True, exist_ok=True)
            env.video_writer = cv2.VideoWriter(
                str(self.record_dir / "video.mp4"),
                cv2.VideoWriter_fourcc(*"MP4V"),
                30,
                (env.img_size[0] * 2, env.img_size[1]),  # Wrist and front cameras.
            )

    def solve_scripted(self):
        """Step the env with its scripted assembly actions until ``step`` reports completion."""
        fb_env = self.env.env.env
        done = False
        while not done:
            action, _ = fb_env.get_assembly_action()
            _, _, done, _ = self.env.step(action)

    # ------------------------------------------------------------------
    # Snapshots keyed by hash
    # ------------------------------------------------------------------

    def _gen_hash(self) -> str:
        """Return a new random UUID4 string used as a snapshot key."""
        return str(uuid.uuid4())

    def get_hash(self) -> str:
        """Return the hash of the snapshot currently considered active."""
        return self.env_state.curr_hash

    def set_hash(self, hash: str, *, force: bool = False):
        """Restore the snapshot stored under ``hash`` and mark it as current.

        Nothing is done when ``hash`` is already current, unless ``force`` is True.
        """
        if hash == self.env_state.curr_hash and not force:
            return None

        state = self.env_state.get_state(hash=hash)
        assert state is not None, "state not present"

        self.set_state(new_state=state)
        self.env_state.curr_hash = hash

    def add_state(self, hash: Optional[str] = None) -> str:
        """Snapshot the current simulator state and write the whole store to the cache file.

        Returns:
            The hash returned by ``EnvStates.set_state`` for the new snapshot.
        """
        n_hash = self.env_state.set_state(self.get_state(), hash=hash)
        self.cache_file.write_text(json.dumps(self.env_state.to_dict()))
        return n_hash

    def reset_env(self, *, seed: Optional[int] = None, new_hash: bool = True, new_seed: bool = True):
        """Reset the simulator, start a new snapshot store and record the initial state.

        Args:
            seed: seed passed to ``env.reset``. It is ignored when ``new_seed`` is False.
            new_hash: if True, label the initial state with a fresh UUID. Otherwise reuse the
                previous store's ``init_hash``.
            new_seed: if False, reuse the previous store's seed.
        """
        if not new_seed:
            assert self.env_state is not None, "reset_env(new_seed=False) requires a previous reset"
            seed = self.env_state.seed
        self.env.reset(seed=seed)
        f_env = self.env.env.env
        f_env.isaac_gym.refresh_net_contact_force_tensor(f_env.sim)
        # NOTE(review): the wrapped tensor is discarded. The call is kept because acquiring it may
        # matter to Isaac Gym after a reset; confirm before removing.
        gymtorch.wrap_tensor(f_env.isaac_gym.acquire_net_contact_force_tensor(f_env.sim))

        self.T_ee_to_tool = torch.as_tensor([0, 0, 0, 1, 0, 0, 0], device=self.device, dtype=torch.float32)

        if new_hash:
            init_hash = self._gen_hash()
        else:
            assert self.env_state is not None, "reset_env(new_hash=False) requires a previous reset"
            init_hash = self.env_state.init_hash
        self.env_state = EnvStates(seed=seed, init_hash=init_hash, curr_hash=init_hash, states={})

        # Load previously cached snapshots, if any.
        if self.cache_file.is_file():
            self.env_state.from_dict(json.loads(self.cache_file.read_text()))

        self.gripper_closed = False
        f_env.last_grasp[:] = -1

        self.add_state()

    def get_state(self):
        """Capture the restorable simulator state as an EnvState.

        The EnvState holds clones of the actor root-state and DOF-state tensors, the gripper flag,
        ``last_grasp`` and ``T_ee_to_tool``. It also holds a deep copy of every supported
        predicate's state.
        """
        fb_env = self.env.env.env  # type: FurnitureSimEnv
        sim = fb_env.sim
        isaac_gym = fb_env.isaac_gym

        # https://docs.robotsfan.com/isaacgym/programming/tensors.html#actor-root-state-tensor
        isaac_gym.refresh_actor_root_state_tensor(sim)
        rb_states = gymtorch.wrap_tensor(isaac_gym.acquire_actor_root_state_tensor(sim)).clone()

        isaac_gym.refresh_dof_state_tensor(sim)
        dof_states = gymtorch.wrap_tensor(isaac_gym.acquire_dof_state_tensor(sim)).clone()

        # Imported here rather than at module level (possibly to avoid an import cycle; confirm).
        from furniture_bench_api.api.api_predicates_validator import supported_predicates

        predicates_state = {name: copy.deepcopy(pred.get_state()) for name, pred in supported_predicates.items()}

        return EnvState(
            rb=rb_states,
            dof=dof_states,
            gripper_closed=self.gripper_closed,
            last_grasp=fb_env.last_grasp.clone(),
            T_ee_to_tool=self.T_ee_to_tool.clone(),
            predicates=predicates_state,
        )

    def set_state(self, new_state: EnvState, *, debug_image_path: Optional[str] = None):
        """Restore a snapshot produced by ``get_state`` and refresh the simulator.

        The actor root-state and DOF-state tensors are written back to the sim. The gripper flag,
        ``last_grasp``, ``T_ee_to_tool`` (moved to ``self.device``) and the predicate states are
        restored, and ``fb_env.refresh()`` is called.

        Args:
            new_state: snapshot to restore.
            debug_image_path: if given, the first env's ``color_image2`` frame is saved there after
                the refresh. Pass ``"image2.png"`` to reproduce the old always-on debug dump.
        """
        # https://docs.robotsfan.com/isaacgym/programming/tensors.html#actor-root-state-tensor
        fb_env = self.env.env.env  # type: FurnitureSimEnv
        sim = fb_env.sim
        isaac_gym = fb_env.isaac_gym

        isaac_gym.refresh_actor_root_state_tensor(sim)
        _rb_states = isaac_gym.acquire_actor_root_state_tensor(sim)
        gymtorch.wrap_tensor(_rb_states)[:] = new_state.rb[:]
        isaac_gym.set_actor_root_state_tensor(sim, _rb_states)

        isaac_gym.refresh_dof_state_tensor(sim)
        _dof_states = isaac_gym.acquire_dof_state_tensor(sim)
        gymtorch.wrap_tensor(_dof_states)[:] = new_state.dof[:]
        isaac_gym.set_dof_state_tensor(sim, _dof_states)

        self.gripper_closed = new_state.gripper_closed
        fb_env.last_grasp = new_state.last_grasp.clone()

        self.T_ee_to_tool = new_state.T_ee_to_tool.to(self.device).clone()

        from furniture_bench_api.api.api_predicates_validator import supported_predicates

        for pred_name, pred_state in new_state.predicates.items():
            supported_predicates[pred_name].set_state(pred_state)

        fb_env.refresh()

        # Debug dump is opt-in: it costs a camera render, a GPU->CPU copy and a disk write.
        if debug_image_path is not None:
            frame = self.get_observation()["color_image2"][0].cpu().numpy()
            Image.fromarray(frame).save(debug_image_path)