import torch
import numpy as np
import os
import pickle
import argparse
import matplotlib.pyplot as plt
from copy import deepcopy
from tqdm import tqdm
from einops import rearrange
import sys
from std_msgs.msg import String, Float32MultiArray

# Root of the SutureBot checkout, read from the environment (None when the variable is unset).
path_to_suturebot = os.environ.get("PATH_TO_SUTUREBOT")

# Expose <root>/src on the import path before the project-local imports that follow.
if path_to_suturebot:
    sys.path.append(os.path.join(path_to_suturebot, "src"))

import rospy
from rostopics import ros_topics
from policy import ACTPolicy
from utils import set_seed # helper functions
from sklearn.preprocessing import normalize
from sensor_msgs.msg import Image
from std_msgs.msg import Bool, Float32, Int16

import cv2
import crtk
from mpl_toolkits import mplot3d
from scipy.spatial.transform import Rotation as R
from pytransform3d import rotations, batch_rotations, transformations, trajectories
from dvrk_scripts.dvrk_control import example_application
from dvrk_scripts.constants_dvrk import TASK_CONFIGS
from aloha_pro.aloha_scripts.utils import initialize_model_and_tokenizer, encode_text, is_multi_gpu_checkpoint
import time
import IPython
from cv_bridge import CvBridge, CvBridgeError

# Shorthand for opening an interactive IPython shell while debugging.
e = IPython.embed
# Seed once, with a fixed value, at import time.
# NOTE(review): presumably seeds the torch/numpy RNGs - confirm in utils.set_seed.
set_seed(0)

class LowLevelPolicy:

    ## ----------------- initializations ----------------
    def __init__(self, args):
        """Build the low-level policy runner.

        Args:
            args: parsed command-line namespace; this class reads
                ``task_name``, ``policy_class``, ``ckpt_dir``,
                ``use_language`` and ``num_epochs`` from it.

        All plain state is initialised *before* ``initialize_ros`` registers
        the subscribers: their callbacks write ``language_instruction``,
        ``sketch_points`` and ``use_sketch``, so resetting those attributes
        after subscribing could silently discard a message that arrived
        while the checkpoint was still loading.
        """
        self.args = args
        self.temporal_agg = False

        self.initialize_parameters()

        # State written by subscriber callbacks must exist before subscribing.
        self.language_embedding = None
        self.language_instruction = None
        self.sketch_points = None
        self.use_sketch = False
        self.left_img = None

        # Fallback command used when no instruction has been received.
        self.avail_commands = ["needle pickup", "needle throw", "knot tying"  ]
        command_idx = 0
        self.command = self.avail_commands[command_idx] ## can change
        self.debugging = False

        self.initialize_ros()

        self.setup_policy()
        self.setup_language_model()
        
    def initialize_parameters(self):
        """Set the default runtime parameters and control flags on the instance."""
        # Inference loop bound and execution pacing.
        self.num_inferences = 4000          # upper bound on policy queries in run()
        self.action_execution_horizon = 20  # actions executed from each predicted chunk
        self.sleep_rate = 0.1               # seconds slept after each commanded step
        self.fps = 30
        self.iter = 0

        # Model-related constants.
        self.language_encoder = "distilbert"
        self.max_timesteps = 400
        self.state_dim = 16

        # Flags toggled by the ROS callbacks.
        self.pause = False
        self.is_correction = False
        self.use_preprogrammed_correction = False

        # No correction has been issued yet.
        self.correction = self.user_correction = self.user_correction_start_t = None

            
    def initialize_ros(self):
        """Create the ROS-side handles and register every topic subscription.

        ``self.rt`` is where camera frames and arm poses are read from
        elsewhere in this class; ``self.ral`` is used to execute pose goals
        through the two PSM application objects.
        """
        self.rt = ros_topics()
        self.ral = crtk.ral('dvrk_arm_test')
        self.bridge = CvBridge()
        self.psm1_app = example_application(self.ral, "PSM1", 1)
        self.psm2_app = example_application(self.ral, "PSM2", 1)

        # Subscriptions whose handles are kept on the instance.
        self.instruction_sub = rospy.Subscriber(
            "/instructor_prediction",
            String,
            self.language_instruction_callback,
            queue_size=10,
        )
        self.pause_sub = rospy.Subscriber(
            "/pause_robot",
            Bool,
            self.pause_robot_callback,
            queue_size=10,
        )
        self.action_horizon_sub = rospy.Subscriber(
            "/action_horizon",
            Int16,
            self.action_horizon_callback,
            queue_size=10,
        )

        # Subscriptions registered without keeping a handle.
        rospy.Subscriber(
            '/direction_instruction_user',
            String,
            self.user_correction_callback,
            queue_size=1,
        )
        rospy.Subscriber(
            '/sketch_points',
            Float32MultiArray,
            self.sketch_point_callback,
            queue_size=1,
        )

        
    def setup_policy(self):
        """Read the task configuration, build the policy and load its checkpoint.

        Sets the normalisation statistics (``mean``/``std``/``max_``/``min_``),
        ``action_mode``, ``goal_condition_style``, ``chunk_size`` and
        ``num_queries``, then constructs the policy, loads the weights from
        ``args.ckpt_dir`` and moves the model to the GPU in eval mode.

        Raises:
            AttributeError: if no chunk size can be resolved.
            NotImplementedError: if ``args.policy_class`` is not ``"ACT"``.
        """
        self.task_config = TASK_CONFIGS[self.args.task_name]
        action_stats = self.task_config['action_mode'][1]
        self.mean = action_stats['mean']
        self.std = action_stats['std']
        self.max_ = action_stats['max_']
        self.min_ = action_stats['min_']
        self.action_mode = self.task_config['action_mode'][0]

        # Stored on the instance because get_image_dvrk branches on it.
        self.goal_condition_style = self.task_config.get('goal_condition_style')
        self.use_sketch_insertion = self.goal_condition_style == 'plot'
        print("\n\nuse sketch insertion: ", self.use_sketch_insertion)

        # chunk_size was read here without ever being assigned. Resolve it from
        # the places it could plausibly live.
        # NOTE(review): neither the 'chunk_size' config key nor an
        # ``args.chunk_size`` option is confirmed to exist - verify, and make
        # sure the value matches the one the checkpoint was trained with.
        chunk_size = getattr(self, "chunk_size", None)
        if chunk_size is None:
            chunk_size = self.task_config.get("chunk_size")
        if chunk_size is None:
            chunk_size = getattr(self.args, "chunk_size", None)
        if chunk_size is None:
            raise AttributeError(
                "chunk_size is not configured: set self.chunk_size, add 'chunk_size' "
                "to the task config, or provide args.chunk_size"
            )
        self.chunk_size = int(chunk_size)
        self.num_queries = self.chunk_size

        self.left_img = None
        # Keep sketch points a callback may already have delivered.
        self.sketch_points = getattr(self, "sketch_points", None)

        if self.args.policy_class != "ACT":
            # Previously this fell through and failed later on a missing self.policy.
            raise NotImplementedError(
                f"unsupported policy class: {self.args.policy_class!r}"
            )

        policy_config = {
            'lr': 1e-5,
            'num_queries': self.chunk_size,
            'action_dim': 20,
            'kl_weight': 10,
            'hidden_dim': 512,
            'dim_feedforward': 3200,
            'lr_backbone': 1e-5,
            'backbone': 'efficientnet_b3' if not self.args.use_language else "efficientnet_b3film",
            'enc_layers': 4,
            'num_epochs': self.args.num_epochs,
            'dec_layers': 7,
            'nheads': 8,
            'camera_names': self.task_config['camera_names'],
            "multi_gpu": None,
        }
        self.policy = ACTPolicy(policy_config)

        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        model_state_dict = torch.load(self.args.ckpt_dir)["model_state_dict"]
        if is_multi_gpu_checkpoint(model_state_dict):
            print("The checkpoint was trained on multiple GPUs.")
            # Strip the leading "module." prefix from every key.
            model_state_dict = {
                k.replace("module.", "", 1): v for k, v in model_state_dict.items()
            }
        loading_status = self.policy.deserialize(model_state_dict)
        print(loading_status)
        self.policy.cuda()
        self.policy.eval()
        
        # print(f"Loaded: {self.args.ckpt_dir}")
    
    def setup_language_model(self):
        if self.args.use_language:
            self.tokenizer, self.model = initialize_model_and_tokenizer("distilbert")
            assert self.tokenizer is not None and self.model is not None
            print("language model and tokenizer set up completed")

    ## ------------ helper functions for action -------------

 
    def unnormalize_action(self, naction, norm_scheme):
        action = None
        if norm_scheme == "min_max":
            action = (naction + 1) / 2 * (self.max_ - self.min_) + self.min_
            action[:, 3:9] = naction[:, 3:9]
            action[:, 13:19] = naction[:, 13:19]
        elif norm_scheme == "std":
            action = self.unnormalize_positions_only_std(naction)
        else:
            raise NotImplementedError
        return action

    def unnormalize_positions_only_std(self, diffs):
        unnormalized = diffs * self.std + self.mean
        unnormalized[:, 3:9] = diffs[:, 3:9]
        unnormalized[:, 13:19] = diffs[:, 13:19]
        return unnormalized

    def convert_delta_6d_to_taskspace_quat(self, all_actions, all_actions_converted, qpos):
        '''
        Turn per-step 6D delta rotations into task-space quaternions.

        Columns 3:6 and 6:9 of ``all_actions`` are orthonormalised into the
        first two columns of a rotation matrix (the third is their cross
        product). Each delta is composed with the orientation in ``qpos[3:7]``
        as ``current * delta`` and written, as a quaternion, into columns 3:7
        of ``all_actions_converted``, which is modified in place and returned.
        '''
        # Gram-Schmidt: unit first axis, then the second axis with its
        # component along the first removed.
        first_axis = normalize(all_actions[:, 3:6], axis = 1) # t x 3
        raw_second = all_actions[:, 6:9] # t x 3
        along_first = np.sum(first_axis * raw_second, axis = 1).reshape(-1, 1)
        second_axis = normalize(raw_second - along_first * first_axis, axis = 1)
        third_axis = np.cross(first_axis, second_axis)

        # dstack places the three axes as the columns of each 3 x 3 matrix.
        delta_rot = R.from_matrix(np.dstack((first_axis, second_axis, third_axis))) # t x 3 x 3
        current_rot = R.from_quat(qpos[3:7])
        all_actions_converted[:, 3:7] = (current_rot * delta_rot).as_quat()
        return all_actions_converted
    
    

    
    ## --------------------- callbacks -----------------------
    def sketch_point_callback(self, msg):
        self.sketch_points = np.array(msg.data).reshape(-1, 2)
        if self.sketch_points.shape[0] > 0:
            self.use_sketch = True
        else:
            self.use_sketch = False

    def language_instruction_callback(self, msg):
        self.language_instruction = msg.data

    
    def pause_robot_callback(self, msg):
        self.pause = msg.data
        
        if self.pause:
            print("Robot paused. Waiting for the robot to be unpaused...")
        else:
            print("Robot unpaused. Resuming the low level policy...")
        
    def action_horizon_callback(self, msg):
        self.action_execution_horizon = msg.data
        print("action horizon changed to: ", self.action_execution_horizon)
        
    def user_correction_callback(self, msg):
        self.user_correction = msg.data
        self.user_correction_start_t = time.time()
        self.correction = self.user_correction
        self.is_correction = True  # Set the correction flag immediately when user issues a correction
        print("User correction issued: ", self.correction)

    ## ---------------------- helpers ------------------------
    
    # NOTE: ``process_image`` is defined once, later in this class, and returns an (H, W, C)
    # RGB array; get_image_dvrk does the channel-first rearrange itself. A channel-first
    # definition that stood here was shadowed by that later one and never ran, so it was removed.


    def create_offset_map_with_gradient(self, image_shape, insert_point, exit_point, normalize_size=224.0, device='cpu', eps=1e-6):
        """
        Returns a 3-channel offset map:
        - Channel 0: dx to insertion point
        - Channel 1: dy to insertion point
        - Channel 2: scalar heatmap (1 at insertion, 0 at exit)

        Args:
            image_shape: (H, W)
            insert_point: (x, y)
            exit_point: (x, y)
            normalize_size: reference image size for normalization
            device: 'cpu' or 'cuda'
        """
        H, W = image_shape
        normalizing_constant = 250.0 * (min(H, W) / normalize_size)

        y_coords = torch.arange(H, device=device)
        x_coords = torch.arange(W, device=device)
        y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')

        # Offsets to insertion point (dy, dx)
        dx = (x_grid - insert_point[0]) / normalizing_constant
        dy = (y_grid - insert_point[1]) / normalizing_constant

        # Gradient heatmap: insertion → 1.0, exit → 0.0
        d_insert = torch.sqrt((x_grid - insert_point[0]) ** 2 + (y_grid - insert_point[1]) ** 2)
        d_exit = torch.sqrt((x_grid - exit_point[0]) ** 2 + (y_grid - exit_point[1]) ** 2)
        heat = d_exit / (d_insert + d_exit + eps)  # in [0, 1]

        # Stack to shape (3, H, W)
        offset_map = torch.stack([dx, dy, heat], dim=0)
        return offset_map.clamp(-1.0, 1.0)  # Optional clamp


    def offset_map_to_rgb_visual(self, offset_map):
        """
        Converts a (3, H, W) offset map (dx, dy, heat) to a uint8 RGB image for visualization.
        - Red = dx
        - Green = dy
        - Blue = heat
        """
        if torch.is_tensor(offset_map):
            offset_map = offset_map.detach().cpu().numpy()

        # Normalize each channel to [0, 1]
        def normalize(x):
            x = x - np.min(x)
            x = x / (np.max(x) + 1e-6)
            return x

        dx_norm = normalize(offset_map[0])
        dy_norm = normalize(offset_map[1])
        heat_norm = normalize(offset_map[2])

        rgb_image = np.stack([
            dx_norm,     # R
            dy_norm,     # G
            heat_norm    # B
        ], axis=-1)  # (H, W, 3)

        rgb_uint8 = (rgb_image * 255).astype(np.uint8)
        return rgb_uint8

    def process_image(self, image_data):
        """Decode an encoded image buffer into a 480x360 RGB array laid out as (H, W, C)."""
        encoded = np.frombuffer(image_data, np.uint8)
        bgr = cv2.resize(cv2.imdecode(encoded, cv2.IMREAD_COLOR), (480, 360))
        # OpenCV decodes to BGR; return RGB.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def get_image_dvrk(self):
        """Assemble the camera observation tensor that is fed to the policy.

        The left endoscope frame is decoded and resized to 960x540. The sketch
        points are applied at that resolution according to the goal-condition
        style, and only when the current instruction is "needle throw" and
        exactly two points (insertion, exit) are available:

        - "dot":  both points are blended into the left image;
        - "mask": the points are drawn on a separate black image;
        - "map":  an offset/heat map is rendered as a separate image.

        For "mask" and "map" an all-zero image is used when no goal applies.
        Every frame ends up as a channel-first 480x360 RGB array.

        Returns:
            Float CUDA tensor of shape (1, N, 3, 360, 480) scaled to [0, 1],
            stacked as [left, psm2 wrist, (goal image,) psm1 wrist]; N is 4
            for the "mask"/"map" styles and 3 otherwise.
        """
        # self.goal_condition_style was read here but never assigned in this
        # class; fall back to the task configuration, which carries that key.
        goal_style = getattr(self, "goal_condition_style", None)
        if goal_style is None:
            goal_style = self.task_config.get('goal_condition_style')

        # np.frombuffer replaces the deprecated binary-mode np.fromstring.
        encoded = np.frombuffer(self.rt.usb_image_left.data, np.uint8)
        self.left_img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        self.left_img = cv2.resize(self.left_img, (960, 540))
        height, width = self.left_img.shape[:2]

        # Snapshot the points once: the subscriber callback may replace them
        # between the check and the use.
        points = self.sketch_points
        has_goal = (
            self.language_instruction == "needle throw"
            and points is not None
            and points.shape[0] == 2
        )
        if has_goal:
            insert_point = (int(points[0][0]), int(points[0][1]))
            exit_point = (int(points[1][0]), int(points[1][1]))

        mask_img = None
        if goal_style == "dot":
            if has_goal:
                print("insert point: ", insert_point[0], insert_point[1])
                print("exit point: ", exit_point[0], exit_point[1])
                marks = self._goal_points_mask(height, width, insert_point, exit_point)
                # Blend only where a circle was drawn.
                drawn = np.any(marks != 0, axis=-1)
                blended = cv2.addWeighted(self.left_img, 0.5, marks, 0.5, 0)
                overlay = self.left_img.copy()
                overlay[drawn] = blended[drawn]
                self.left_img = overlay

        elif goal_style == "mask":
            if has_goal:
                mask_img = self._goal_points_mask(height, width, insert_point, exit_point)
            else:
                mask_img = np.zeros_like(self.left_img)

        elif goal_style == "map":
            if has_goal:
                print("sketch points: ", points)
                offset_map = self.create_offset_map_with_gradient(
                    image_shape=(height, width),
                    insert_point=insert_point,
                    exit_point=exit_point,
                    device='cpu'
                )
                mask_img = self.offset_map_to_rgb_visual(offset_map)
            else:
                mask_img = np.zeros_like(self.left_img)

        self.left_img = cv2.resize(self.left_img, (480, 360))
        self.left_img = cv2.cvtColor(self.left_img, cv2.COLOR_BGR2RGB)

        # process_image returns (H, W, C) RGB arrays.
        lw_img = self.process_image(self.rt.endo_cam_psm2.data)
        rw_img = self.process_image(self.rt.endo_cam_psm1.data)

        self.left_img = rearrange(self.left_img, 'h w c -> c h w')
        lw_img = rearrange(lw_img, 'h w c -> c h w')
        rw_img = rearrange(rw_img, 'h w c -> c h w')

        if goal_style == "mask" or goal_style == "map":
            mask_img = cv2.resize(mask_img, (480, 360))
            # Same channel swap the camera frames go through.
            mask_img = cv2.cvtColor(mask_img, cv2.COLOR_BGR2RGB)
            mask_img = rearrange(mask_img, 'h w c -> c h w')
            frames = [self.left_img, lw_img, mask_img, rw_img]
        else:
            frames = [self.left_img, lw_img, rw_img]

        curr_image = np.stack(frames, axis=0)
        return torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)

    def _goal_points_mask(self, height, width, insert_point, exit_point):
        """Draw the insertion and exit points as filled circles on a black uint8 image."""
        mask = np.zeros((height, width, 3), dtype=np.uint8)
        # NOTE(review): (255, 0, 0) is blue in OpenCV's BGR channel order, although
        # earlier comments called it red. Colours are left unchanged here; confirm
        # against the training data pipeline.
        cv2.circle(mask, insert_point, radius=10, color=(255, 0, 0), thickness=-1)
        cv2.circle(mask, exit_point, radius=10, color=(0, 255, 0), thickness=-1)
        return mask
    
    
    
    def generate_command_embedding(self, command, correction):
        """Encode the active language command into a CUDA tensor.

        The instruction received from the high level policy topic takes
        precedence when language is enabled and one has arrived; otherwise
        the ``command`` argument is encoded. ``correction`` is accepted but
        not used.
        """
        use_high_level = self.args.use_language and self.language_instruction is not None
        if use_high_level:
            ## use language instructions from high level policy rostopic
            command = self.language_instruction

        encoded = encode_text(command, self.language_encoder, self.tokenizer, self.model)
        command_embedding = torch.tensor(encoded).cuda()

        if use_high_level:
            print(f"\n---------------------------------------\nusing high level policy command: {command}\n---------------------------------------\n")
        else:
            ## language command set in the low level policy
            print(f"\n---------------------------------------\nusing command: {command}\n---------------------------------------\n")
        return command_embedding
    
    def plot_actions(self, qpos_psm1, qpos_psm2, actions_psm1, actions_psm2):
        """Show a 3D scatter of both generated trajectories and the current positions.

        Positions are multiplied by 1000 before plotting (the axes are
        labelled in mm). Blocks until the plot window is closed.
        """
        scale = 1000
        plt.figure()
        axes3d = plt.axes(projection='3d')

        # Generated trajectories for both arms (only the second carries a legend label).
        axes3d.scatter(*(actions_psm1[:, k] * scale for k in range(3)), c ='r')
        axes3d.scatter(*(actions_psm2[:, k] * scale for k in range(3)), c ='r', label = 'Generated trajectory')

        # Current end-effector positions.
        axes3d.scatter(*(qpos_psm1[k] * scale for k in range(3)), c = 'g')
        axes3d.scatter(*(qpos_psm2[k] * scale for k in range(3)), c = 'b', label = 'Current end-effector position')

        axes3d.set_xlabel('X (mm)')
        axes3d.set_ylabel('Y (mm)')
        axes3d.set_zlabel('Z (mm)')
        axes3d.legend()

        # Cap the tick count on every axis.
        n_bins = 7
        for axis in (axes3d.xaxis, axes3d.yaxis, axes3d.zaxis):
            axis.set_major_locator(plt.MaxNLocator(n_bins))
        plt.show()

    def execute_actions(self, actions_psm1, actions_psm2):

        for jj in range(self.action_execution_horizon):
            # print("actions_psm2: ", actions_psm2[jj], "\nactions_psm2_temp: ", actions_psm2_temp[jj])
            # if not self.pause and not self.is_correction:
            if not self.pause:
                if self.use_preprogrammed_correction and self.is_correction:
                    break

                self.ral.spin_and_execute(self.psm1_app.run_full_pose_goal, actions_psm1[jj])
                self.ral.spin_and_execute(self.psm2_app.run_full_pose_goal, actions_psm2[jj])
                time.sleep(self.sleep_rate)
            else:
                break
                

    ## --------------------- main loop -----------------------

    def run(self):
        """Run the closed-loop predict/execute cycle.

        Each iteration grabs the camera images, queries the policy for an
        action chunk, un-normalises it, converts the per-arm deltas into
        absolute pose goals relative to the current arm poses, and executes
        the leading steps. The loop ends after ``self.num_inferences``
        iterations, on ROS shutdown, on Ctrl-C, or (in debugging mode) when
        the operator enters "q". Paused / correction-pending periods are
        waited out without counting as iterations.

        Raises:
            NotImplementedError: if the configured action mode is not 'hybrid'.
        """
        print("-------------starting low level policy------------------\n")
        time.sleep(1)

        # Only the hybrid conversion exists; any other mode used to crash on
        # unbound action variables after the first inference.
        if self.action_mode != 'hybrid':
            raise NotImplementedError(
                f"action mode {self.action_mode!r} is not supported; only 'hybrid' is implemented"
            )

        # NOTE: temporal aggregation is not implemented in this loop; the buffer
        # formerly allocated for it was never read, so it is no longer created.

        with torch.inference_mode():
            t = 0

            while t < self.num_inferences:
                try:
                    if rospy.is_shutdown():
                        print("ROS shutdown signal received. Exiting...")
                        break

                    if self.pause or (self.use_preprogrammed_correction and self.is_correction):
                        try:
                            time.sleep(0.2)
                        except KeyboardInterrupt:
                            print("Exiting...")
                            break
                        continue

                    if self.args.use_language:
                        command_embedding = self.generate_command_embedding(self.command, self.correction)
                    else:
                        command_embedding = None
                    # The policy is queried with an all-zero state vector.
                    qpos_zero = torch.zeros(1, 20).float().cuda()

                    curr_image = self.get_image_dvrk()

                    action = self.policy(qpos_zero, curr_image, command_embedding=command_embedding).cpu().numpy().squeeze()
                    action = self.unnormalize_action(action, self.task_config['norm_scheme'])

                    ## remove offset by subtracting the first action
                    action[:, 0:3] = action[:, 0:3] - action[0, 0:3]
                    action[:, 10:13] = action[:, 10:13] - action[0, 10:13]

                    qpos_psm1 = self._read_arm_state(self.rt.psm1_pose, self.rt.psm1_jaw)
                    qpos_psm2 = self._read_arm_state(self.rt.psm2_pose, self.rt.psm2_jaw)

                    # Columns 0:10 drive PSM1, columns 10: drive PSM2.
                    actions_psm1 = self._arm_targets_from_chunk(action[:, 0:10], qpos_psm1)
                    actions_psm2 = self._arm_targets_from_chunk(action[:, 10:], qpos_psm2)

                    self.execute_actions(actions_psm1, actions_psm2)

                    if self.debugging:
                        key = input("press enter to continue...")
                        if key == "q":
                            # The bare name `exit` was a no-op; actually leave the loop.
                            break
                    t += 1

                except KeyboardInterrupt:
                    print("low level policy interrupted")
                    break

    def _read_arm_state(self, pose, jaw):
        """Pack one arm's pose and jaw value as [x, y, z, qx, qy, qz, qw, jaw]."""
        # `pose` is read through a single reference so every field comes from
        # the same message.
        position, orientation = pose.position, pose.orientation
        return np.array((position.x, position.y, position.z,
                         orientation.x, orientation.y, orientation.z, orientation.w,
                         jaw))

    def _arm_targets_from_chunk(self, arm_action, qpos):
        """Convert one arm's 10-column delta actions into absolute pose goals (pos, quat, jaw)."""
        # Size the output from the prediction itself instead of a separately
        # configured chunk length.
        targets = np.zeros((arm_action.shape[0], 8)) # pos, quat, jaw
        targets[:, 0:3] = qpos[0:3] + arm_action[:, 0:3] # convert to current translation
        targets = self.convert_delta_6d_to_taskspace_quat(arm_action, targets, qpos)
        targets[:, 7] = np.clip(arm_action[:, 9], -0.698, 0.698)  # copy over gripper angles
        return targets


## --------------------- main function -----------------------

if __name__ == "__main__":
    # NOTE: PYTHONWARNINGS is read at interpreter start-up, so setting it here
    # only affects child processes started from this one.
    os.environ["PYTHONWARNINGS"] = "ignore"

    parser = argparse.ArgumentParser()
    # Mandatory options, registered in their original order.
    # NOTE(review): --policy_class/--task_name/--seed look like they are partly
    # required by a downstream (detr) parser; --seed is not read anywhere in
    # this file - confirm.
    for flag, arg_type, help_text in (
        ('--ckpt_dir', str, 'specify ckpt file path'),
        ('--policy_class', str, 'policy_class, capitalize'),
        ('--task_name', str, 'task_name'),
        ('--seed', int, 'seed'),
    ):
        parser.add_argument(flag, action='store', type=arg_type, help=help_text, required=True)
    parser.add_argument('--use_language', action='store_true')
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
    args = parser.parse_args()

    system = LowLevelPolicy(args)
    system.run()
