from __future__ import annotations

import copy
from matplotlib import pyplot as plt
import rowan
import logging
import numpy as np
import robotic as ry
from typing import List

import vtamp.environments.bridge.manipulation as manip
from vtamp.environments.draw.tasks import isolate_red_shapes_from_rgb
from vtamp.environments.utils import Action, Environment, State, Task
from dataclasses import dataclass, field

log = logging.getLogger(__name__)  # module-level logger
# Public API of this module: the names LLM-generated code is expected to import
# from here (this list also controls what `from <module> import *` exposes).
__all__ = ["Frame", "DrawState"]


@dataclass
class Frame:
    """Pose, size and rotation of one named configuration frame.

    Instances are plain value snapshots (no live link to the config); their
    string form is the compact text shown to the plan-generating LLM.
    """

    name: str
    x_pos: float
    y_pos: float
    z_pos: float
    x_size: float
    y_size: float
    z_size: float
    x_rot: float
    y_rot: float
    z_rot: float

    def __str__(self):
        """Render the frame as ``Frame(name="...", x_pos=..., ..., z_rot=...)``.

        Values are rounded to two decimals. Size fields are intentionally left
        out of the text, as before; the output is now a properly closed call
        expression (previously it ended with a dangling ``", "`` and no ``)``).
        """
        return (
            f'Frame('
            f'name="{self.name}", '
            f'x_pos={round(self.x_pos, 2)}, '
            f'y_pos={round(self.y_pos, 2)}, '
            f'z_pos={round(self.z_pos, 2)}, '
            f'x_rot={round(self.x_rot, 2)}, '
            f'y_rot={round(self.y_rot, 2)}, '
            f'z_rot={round(self.z_rot, 2)})'
        )



@dataclass
class DrawState(State):
    """State of the drawing environment: the live config plus frame snapshots.

    ``frames`` holds one ``Frame`` per relevant frame, in the order they were
    collected; ``config`` is the underlying ``ry.Config`` they were read from.
    """

    config: ry.Config
    frames: List[Frame] = field(default_factory=list)

    def __str__(self):
        """Render as ``DrawState(frames=[...])`` using each frame's ``str``."""
        return "DrawState(frames=[{}])".format(
            ", ".join(str(frame) for frame in self.frames)
        )

    def getFrame(self, name: str) -> Frame | None:
        """Return the first frame whose ``name`` matches, or ``None`` if absent."""
        return next((frame for frame in self.frames if frame.name == name), None)
    


class DrawEnv(Environment):
    """Whiteboard-drawing environment backed by a ``robotic`` (``ry``) config.

    Each ``draw_line`` action adds a thin red cylinder frame named
    ``line_<i>`` on the (tilted) whiteboard and records the line's projection
    into the ``cameraTop`` image: the start pixel in ``starting_points`` and
    the pixel displacement in ``vectors``.
    """

    def __init__(self, task: Task, **kwargs):
        """Build the working config from ``task.setup_env()`` and cache intrinsics.

        ``task.setup_env()`` must return a dict with the keys ``config``,
        ``whiteboard_thickness``, ``whiteboard_tilt``, ``whiteboard_size_x`` and
        ``whiteboard_size_y``; the config must contain a ``cameraTop`` frame.
        ``kwargs`` is accepted for interface compatibility and ignored.
        """
        super().__init__(task)

        state_dict = self.task.setup_env()
        self.C: ry.Config = state_dict["config"]
        self.C.view(False, "Working Config")
        self.initial_state = self.reset()
        self.qHome = self.C.getJointState()
        self.lines = 0

        self.whiteboard_thickness = state_dict["whiteboard_thickness"]
        self.whiteboard_tilt = state_dict["whiteboard_tilt"]
        self.whiteboard_size_x = state_dict["whiteboard_size_x"]
        self.whiteboard_size_y = state_dict["whiteboard_size_y"]

        camera_view = ry.CameraView(self.C)
        camera_view.setCamera(self.C.getFrame("cameraTop"))
        self.fx, self.fy, self.cx, self.cy = camera_view.getFxycxy()

        self.number_lines = task.number_lines

    def step(self, action: Action, vis: bool = True):
        """Execute one action and return the 4-tuple ``(state, False, 0, info)``.

        Only ``draw_line`` is supported. Its ``params`` are ``[x0, y0, x1, y1]``
        in whiteboard coordinates measured from the board's bottom-left corner.
        An endpoint outside ``[0, whiteboard_size_x] x [0, whiteboard_size_y]``
        marks the action infeasible: the line is still drawn, a message is added
        to ``info["constraint_violations"]``, and every later step becomes a
        no-op that only advances ``t``. A near-zero-length line is skipped
        (nothing drawn or recorded) but still returns the normal tuple.

        ``vis`` is currently unused.

        Raises:
            NotImplementedError: for any action other than ``draw_line``.
        """
        info = {"constraint_violations": []}

        # After an infeasible action the episode is frozen: only time advances.
        if not self.feasible:
            self.C.view()
            self.t = self.t + 1
            return self.state, False, 0, info

        self.feasible = False

        if action.name != "draw_line":
            raise NotImplementedError

        p1 = np.array(action.params[:2])
        p2 = np.array(action.params[2:4])

        self.feasible = True
        if not self._on_whiteboard(p1):
            print("Warning: Point 1 is outside the whiteboard.")
            self.feasible = False
        if not self._on_whiteboard(p2):
            print("Warning: Point 2 is outside the whiteboard.")
            self.feasible = False

        world_p0 = self._whiteboard_to_world(p1)
        world_p1 = self._whiteboard_to_world(p2)

        direction = world_p1 - world_p0
        length = np.linalg.norm(direction)
        if length < 1e-6:  # avoid division by zero / degenerate cylinders
            print("Warning: Skipping near-zero length line.")
            # Previously this path returned bare None, breaking callers that
            # unpack the step result; return the regular tuple instead.
            if not self.feasible:
                info["constraint_violations"].append(
                    "Line endpoint outside the whiteboard."
                )
            self.t = self.t + 1
            return self.state, False, 0, info

        quat = self._z_to_direction_quaternion(direction / length)

        # Image-space bookkeeping: start pixel and pixel displacement of the line.
        start_px = self._project_to_pixels(world_p0)
        end_px = self._project_to_pixels(world_p1)
        starting_point = start_px
        vector = end_px - start_px

        # ry cylinders are aligned with their local z-axis and sized
        # [length, radius]; center the marker between the two endpoints.
        self.C.addFrame(f"line_{self.lines}") \
            .setShape(ry.ST.cylinder, [length, 0.005]) \
            .setPosition((world_p0 + world_p1) / 2) \
            .setQuaternion(quat) \
            .setColor([1, 0, 0])  # red
        self.lines += 1

        if not self.feasible:
            info["constraint_violations"].append(
                "Line endpoint outside the whiteboard."
            )

        self.starting_points.append(starting_point)
        self.vectors.append(vector)
        self.t = self.t + 1
        self.state = self.getState()
        return self.state, False, 0, info

    def _on_whiteboard(self, point) -> bool:
        """True if the 2D whiteboard point lies within the board's extents."""
        return (
            0 <= point[0] <= self.whiteboard_size_x
            and 0 <= point[1] <= self.whiteboard_size_y
        )

    def _whiteboard_to_world(self, point) -> np.ndarray:
        """Map a bottom-left-relative whiteboard point to a world position.

        The ``whiteboard`` frame's origin is its center and it is tilted about
        the world X-axis by ``whiteboard_tilt``; the point is lifted slightly
        above the surface so the marker is visible.
        """
        cos_tilt = np.cos(self.whiteboard_tilt)
        sin_tilt = np.sin(self.whiteboard_tilt)
        R_wb = np.array([
            [1, 0,         0],
            [0, cos_tilt, -sin_tilt],
            [0, sin_tilt,  cos_tilt],
        ])
        local = np.array([
            point[0] - self.whiteboard_size_x / 2,
            point[1] - self.whiteboard_size_y / 2,
            self.whiteboard_thickness / 2 + 0.001,
        ])
        return self.C.getFrame("whiteboard").getPosition() + R_wb @ local

    @staticmethod
    def _z_to_direction_quaternion(unit_dir: np.ndarray) -> np.ndarray:
        """Quaternion ([w, x, y, z], rowan order) rotating +Z onto ``unit_dir``."""
        z_axis = np.array([0., 0., 1.])
        axis = np.cross(z_axis, unit_dir)
        # Clip for numerical stability before arccos.
        dot_product = np.clip(np.dot(z_axis, unit_dir), -1.0, 1.0)
        axis_norm = np.linalg.norm(axis)
        if axis_norm > 1e-6:
            return rowan.from_axis_angle(axis / axis_norm, np.arccos(dot_product))
        if dot_product < -0.9999:
            # Anti-parallel: rotate 180 degrees about the X-axis.
            return np.array([0., 1., 0., 0.])
        # Already aligned with +Z: identity rotation.
        return np.array([1., 0., 0., 0.])

    def _project_to_pixels(self, world_point: np.ndarray) -> np.ndarray:
        """Pinhole-project a world point into ``cameraTop`` pixels ``[u, v]``.

        NOTE(review): ``R @ p + t`` is the camera-to-world map; world-to-camera
        would be ``R.T @ (p - t)``. The ry camera sign convention also affects
        the formula below. Kept unchanged because downstream cost code may
        depend on the current values — confirm against a rendered image.
        """
        camera = self.C.getFrame("cameraTop")
        camera_point = (
            rowan.to_matrix(camera.getQuaternion()) @ world_point
            + camera.getPosition()
        )
        u = self.fx * camera_point[0] / camera_point[2] + self.cx
        v = self.fy * camera_point[1] / camera_point[2] + self.cy
        return np.array([u, v])

    @staticmethod
    def sample_twin(real_env: DrawEnv, obs, task: Task, **kwargs) -> DrawEnv:
        """Create a new env whose config is a copy of ``real_env.C``.

        The copied config already contains ``real_env``'s ``line_<i>`` frames,
        so the twin continues that numbering and inherits the matching
        drawn-line bookkeeping; otherwise new lines would reuse existing frame
        names. ``obs`` and ``kwargs`` are accepted for interface compatibility
        and ignored.
        """
        twin = DrawEnv(task)
        twin.C = ry.Config()
        twin.C.addConfigurationCopy(real_env.C)

        twin.lines = real_env.lines
        twin.starting_points = copy.deepcopy(real_env.starting_points)
        twin.vectors = copy.deepcopy(real_env.vectors)
        # The state built in __init__ refers to the discarded config.
        twin.state = twin.getState()

        return twin

    def reset(self):
        """Remove all drawn lines and the ``tmp`` frame, then reset episode counters.

        Returns the freshly computed state.
        """
        q = self.C.getJointState()
        self.C.setJointState(q)

        for name in self.C.getFrameNames():
            if "line_" in name:
                self.C.delFrame(name)

        if self.C.getFrame("tmp", warnIfNotExist=False):
            self.C.delFrame("tmp")

        # Snapshot only after the transient frames are gone so the restored
        # state has exactly one row per remaining frame (previously it was
        # taken before deletion and included the removed line frames).
        frame_state = self.C.getFrameState()

        self.starting_points = []
        self.vectors = []
        self.lines = 0  # line frames were all deleted; restart numbering
        self.C.setFrameState(frame_state)
        self.C.view()
        self.state = self.getState()
        self.t = 0
        self.feasible = True

        return self.state

    def getState(self):
        """Snapshot every frame listed in ``task.relevant_frames`` into a ``DrawState``.

        Rotations are Euler angles from ``rowan.to_euler(..., convention="xyz")``;
        frames whose name contains ``"camera"`` get NaN sizes.
        NOTE(review): assumes ``getSize()`` yields exactly three values for every
        non-camera relevant frame, as ``Frame`` needs ten fields.
        """
        frames = []
        for name in self.task.relevant_frames:
            C_frame = self.C.getFrame(name)

            pos = C_frame.getPosition()
            size = C_frame.getSize()
            if "camera" in name:
                size = [np.nan, np.nan, np.nan]

            rot = rowan.to_euler(C_frame.getQuaternion(), convention="xyz")  # rotations need further testing

            frames.append(Frame(name, *pos, *size, *rot))

        state = DrawState(self.C)
        state.frames = frames
        return state

    def render(self):
        """Render the ``cameraTop`` view and store the isolated red strokes.

        The view is computed on a copy of the config that keeps only ``world``,
        ``table``, ``whiteboard``, ``cameraTop`` and ``line_*`` frames. It is
        shown with ``plt.show()``, then ``isolate_red_shapes_from_rgb`` is
        applied once to the raw image and the result is stored in
        ``self.image_without_background``.
        """
        to_stay = ["world", "table", "whiteboard", "cameraTop"]

        C_copy = ry.Config()
        C_copy.addConfigurationCopy(self.C)
        for frame in C_copy.getFrameNames():
            if frame not in to_stay and "line_" not in frame:
                C_copy.delFrame(frame)

        camera_view = ry.CameraView(C_copy)
        camera_view.setCamera(C_copy.getFrame("cameraTop"))
        image, _ = camera_view.computeImageAndDepth(C_copy)
        plt.imshow(image)
        plt.show()

        # Previously the isolation was applied twice (the second time to its
        # own output); apply it once to the rendered image.
        self.image_without_background = isolate_red_shapes_from_rgb(
            image, background_color=(255, 255, 255)
        )

    def compute_cost(self):
        """Refresh the viewer and return ``self.task.get_cost(self)``."""
        self.C.view()
        cost = self.task.get_cost(self)

        return cost