
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import proj3d
from matplotlib.patches import FancyArrowPatch
import os


import seaborn as sns

def _draw_pose_2d(ax, pose, color, lower_bound, upper_bound, is_ob, color_mask,
                  ob_particle_count):
    """Scatter 2D particles on *ax* and frame the plot with a plain black border."""
    if is_ob:
        # Leading rows are drawn in opaque black, the remainder in the
        # material colour (presumably obstacle vs. material particles --
        # confirm against the caller).
        sns.scatterplot(x=pose[:ob_particle_count, 0], y=pose[:ob_particle_count, 1],
                        color='black', s=50, edgecolor='none', ax=ax)
        sns.scatterplot(x=pose[ob_particle_count:, 0], y=pose[ob_particle_count:, 1],
                        color=color, s=50, alpha=0.7,
                        edgecolor='none', ax=ax)
    elif color_mask:
        # Short colour aliases; any other name is passed to seaborn unchanged.
        color_map = {
            'gold': 'darkorange',
            'blue': 'dodgerblue'
        }
        for mask, color_name in color_mask:
            sns.scatterplot(
                x=pose[mask, 0],
                y=pose[mask, 1],
                color=color_map.get(color_name, color_name),
                s=50,
                alpha=0.7,
                edgecolor='none',
                ax=ax
            )
    else:
        sns.scatterplot(x=pose[:, 0], y=pose[:, 1],
                        color=color, s=50, alpha=0.7,
                        edgecolor='none', ax=ax)

    ax.set_xlim(lower_bound, upper_bound)
    ax.set_ylim(lower_bound, upper_bound)
    ax.set_aspect('equal')

    # Keep a thin black frame around the plot ...
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1.0)
    # ... but hide tick marks and tick labels.
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False
    )


def _draw_pose_3d(ax, pose, color, lower_bound, upper_bound):
    """Scatter 3D particles on *ax* inside a black wireframe cube."""
    ax.set_facecolor('white')

    # Data columns are plotted in (0, 2, 1) order: column 1 is the plot's
    # vertical (Z) axis.
    ax.scatter(pose[:, 0], pose[:, 2], pose[:, 1],
               s=4, c=color, alpha=0.6, depthshade=True)
    ax.set_xlim(lower_bound, upper_bound)
    ax.set_ylim(lower_bound, upper_bound)
    ax.set_zlim(lower_bound, upper_bound)
    ax.set_aspect('equal')
    ax.set_axis_off()

    lo, hi = lower_bound, upper_bound
    loop_a = [lo, hi, hi, lo, lo]
    loop_b = [lo, lo, hi, hi, lo]
    flat_lo = [lo] * 5
    flat_hi = [hi] * 5

    # Same polylines, in the same order, as the original hand-written calls:
    # two z-faces, four vertical edges, two y-faces, two x-faces.
    outlines = [
        (loop_a, loop_b, flat_lo),
        (loop_a, loop_b, flat_hi),
        ([lo, lo], [lo, lo], [lo, hi]),
        ([hi, hi], [lo, lo], [lo, hi]),
        ([hi, hi], [hi, hi], [lo, hi]),
        ([lo, lo], [hi, hi], [lo, hi]),
        (loop_a, flat_lo, loop_b),
        (loop_a, flat_hi, loop_b),
        (flat_lo, loop_a, loop_b),
        (flat_hi, loop_a, loop_b),
    ]
    for xs, ys, zs in outlines:
        ax.plot(xs, ys, zs, color='black', linewidth=2.0)


def visualize_pose(pose, lower_bound, upper_bound, material_type, is_ob, color_mask=None,
                   ob_particle_count=300):
    """Render particle positions to an RGB image array.

    A pose whose last axis has size 2 is drawn as a 2D seaborn scatter plot;
    anything else takes the 3D path (point cloud inside a wireframe cube).

    Args:
        pose: Array of particle positions, indexed as ``pose[:, k]``.
        lower_bound: Lower axis limit, applied to every axis.
        upper_bound: Upper axis limit, applied to every axis.
        material_type: ``'water'`` selects ``'dodgerblue'``; any other value
            selects ``'darkorange'``.
        is_ob: 2D only. When truthy, the first ``ob_particle_count`` rows are
            drawn in black and the remaining rows in the material colour.
        color_mask: 2D only, consulted when ``is_ob`` is falsy. Iterable of
            ``(mask, color_name)`` pairs; ``mask`` indexes rows of ``pose``
            and ``color_name`` is ``'gold'``/``'blue'`` or any colour name
            seaborn accepts.
        ob_particle_count: Number of leading rows drawn in black when
            ``is_ob`` is truthy. Defaults to 300, the previously hard-coded
            value.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        frame.
    """
    # NOTE: these calls change seaborn's global style on every invocation.
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=1.2)

    color = 'dodgerblue' if material_type == 'water' else 'darkorange'
    is_2d = pose.shape[-1] == 2

    if is_2d:
        fig, ax = plt.subplots(figsize=(8, 8), dpi=100, tight_layout=True)
    else:
        fig = plt.figure(figsize=(10, 10), dpi=100)

    try:
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        if is_2d:
            _draw_pose_2d(ax, pose, color, lower_bound, upper_bound,
                          is_ob, color_mask, ob_particle_count)
        else:
            ax = fig.add_subplot(111, projection='3d')
            fig.patch.set_facecolor('white')
            _draw_pose_3d(ax, pose, color, lower_bound, upper_bound)

        # Both paths read the RGBA buffer; the canvas ``tostring_rgb`` method
        # used previously for 2D is deprecated in recent Matplotlib releases.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result does not alias the canvas buffer.
        return rgba[..., :3].copy()
    finally:
        plt.close(fig)
def visualize_control_force(sim, frame, lower_bound, upper_bound):
    """Render the simulator's current particle positions as an RGB frame.

    Despite the name, only particle positions are drawn; the control-force
    quiver overlay is currently disabled (see the commented block below).

    Args:
        sim: Simulator object. Reads ``sim.dim``, ``sim.material_type`` and
            ``sim.x.to_numpy()`` (particle positions indexed as ``pos[:, k]``).
        frame: Value shown in the plot title as ``"Frame {frame}"``.
        lower_bound: Lower axis limit, applied to every axis.
        upper_bound: Upper axis limit, applied to every axis.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        frame.
    """
    pos = sim.x.to_numpy()
    color = 'blue' if sim.material_type == 'water' else 'orange'
    is_2d = sim.dim == 2

    fig = plt.figure(figsize=(8, 8) if is_2d else (10, 10))
    try:
        if is_2d:
            ax = fig.add_subplot(111)
            ax.scatter(pos[:, 0], pos[:, 1], s=5, c=color, alpha=0.5)

            # Plot control forces (disabled)
            # force = sim.control_accel.to_numpy()
            # plt.quiver(pos[:, 0], pos[:, 1],
            #           force[:, 0], force[:, 1],
            #           color='red', scale=1000, width=0.002, headwidth=3)
        else:
            ax = fig.add_subplot(111, projection='3d')
            # Data columns are plotted in (0, 2, 1) order: column 1 is the
            # plot's vertical (Z) axis.
            ax.scatter(pos[:, 0], pos[:, 2], pos[:, 1], s=2, c=color, alpha=0.5)
            ax.set_zlim(lower_bound, upper_bound)
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')

        ax.set_title(f"Frame {frame}")
        ax.set_xlim(lower_bound, upper_bound)
        ax.set_ylim(lower_bound, upper_bound)
        ax.set_aspect('equal')

        # Both branches read the RGBA buffer; the canvas ``tostring_rgb``
        # method used previously for 2D is deprecated in recent Matplotlib
        # releases.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result does not alias the canvas buffer.
        return rgba[..., :3].copy()
    finally:
        plt.close(fig)
    
class Arrow3D(FancyArrowPatch):
    """A ``FancyArrowPatch`` whose two endpoints are given in 3D data coordinates.

    The 3D endpoints are stored on the instance and projected to 2D in
    ``do_3d_projection`` (presumably invoked by the 3D axes while drawing --
    confirm against the Matplotlib version in use).
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        """Store the 3D endpoints; remaining arguments go to ``FancyArrowPatch``.

        Args:
            xs, ys, zs: Two-element sequences holding the tail and tip
                coordinate along each plot axis.
        """
        # The 2D endpoints are placeholders; they are overwritten on projection.
        placeholder = (0, 0)
        super().__init__(placeholder, placeholder, *args, **kwargs)
        self._verts3d = (xs, ys, zs)

    def do_3d_projection(self, renderer=None):
        """Project the stored endpoints with the axes' ``M`` matrix.

        Returns:
            The smallest projected depth value of the two endpoints.
        """
        flat_x, flat_y, depth = proj3d.proj_transform(*self._verts3d, self.axes.M)
        tail = (flat_x[0], flat_y[0])
        tip = (flat_x[1], flat_y[1])
        self.set_positions(tail, tip)
        return min(depth)
    

def create_control_comparison_figure_3d(states, target_pose, initial_pos, lower_bound, upper_bound, output_dir):
    """Save a 3D overview of the control effect with one centre-of-mass arrow.

    Scatters the initial, pre-control and target point sets and draws a red
    arrow from the pre-control centre of mass to the target centre of mass.
    The image is written to ``<output_dir>/control_effect_average.png``.

    Data columns are plotted in (0, 2, 1) order, so data column 2 lies on the
    plot's Y axis and column 1 on its Z axis; the 'X'/'Y'/'Z' labels name the
    plot axes.

    Args:
        states: Sequence whose last element is a mapping with a
            ``'position'`` entry (the pre-control particle positions).
        target_pose: Target particle positions, indexed as ``[:, k]``.
        initial_pos: Initial particle positions, indexed as ``[:, k]``.
        lower_bound: Lower axis limit, applied to every axis.
        upper_bound: Upper axis limit, applied to every axis.
        output_dir: Directory the PNG is written to.
    """
    before_pos = states[-1]['position']
    after_pos = target_pose

    # Centre of mass before/after and the displacement between them.
    com_before = np.mean(before_pos, axis=0)
    com_after = np.mean(after_pos, axis=0)
    avg_displacement = com_after - com_before
    displacement_length = np.linalg.norm(avg_displacement)

    fig = plt.figure(figsize=(18, 6))
    try:
        ax = fig.add_subplot(111, projection='3d')

        ax.scatter(initial_pos[:, 0], initial_pos[:, 2], initial_pos[:, 1], s=2, c='gray', alpha=0.5, label='Initial State')
        ax.scatter(before_pos[:, 0], before_pos[:, 2], before_pos[:, 1], s=2, c='blue', alpha=0.5, label='State Before Control')
        ax.scatter(after_pos[:, 0], after_pos[:, 2], after_pos[:, 1], s=2, c='none', edgecolor='green',
                   linewidths=1, alpha=0.7, label='Target Position')

        arrow_scale = 1.0
        num_segments = 15
        base_min_width = 1
        base_max_width = 5

        # Longer displacements (relative to the domain size) get thicker arrows.
        length_factor = np.clip(displacement_length / (upper_bound - lower_bound), 0.5, 2.0)
        min_linewidth = base_min_width * length_factor
        max_linewidth = base_max_width * length_factor
        width_span = max_linewidth - min_linewidth

        arrow_start = com_before
        arrow_end = com_before + avg_displacement * arrow_scale
        shaft = arrow_end - arrow_start

        # Data column 2 is used as the "depth" that modulates line width.
        min_depth = min(arrow_start[2], arrow_end[2])
        depth_range = max(arrow_start[2], arrow_end[2]) - min_depth

        ratios = np.linspace(0, 1, num_segments + 1)
        # The final segment is left undrawn; the arrow head covers the tip.
        for i in range(num_segments - 1):
            start_point = arrow_start + ratios[i] * shaft
            end_point = arrow_start + ratios[i + 1] * shaft

            current_z = (start_point[2] + end_point[2]) / 2
            if depth_range > 0:
                normalized_z = np.clip((current_z - min_depth) / depth_range, 0, 1)
            else:
                # No depth change along the arrow: the original 0/0 produced
                # NaN widths, so fall back to a uniform mid-range width.
                normalized_z = 0.5

            ax.plot([start_point[0], end_point[0]],
                    [start_point[2], end_point[2]],
                    [start_point[1], end_point[1]],
                    color='r', linewidth=min_linewidth + width_span * normalized_z)

        # Head width follows the tip's depth relative to the domain bounds;
        # clipped so a tip outside the bounds cannot yield a negative or
        # oversized width.
        head_normalized_z = np.clip(
            (arrow_end[2] - lower_bound) / (upper_bound - lower_bound), 0, 1)
        head_linewidth = min_linewidth + width_span * head_normalized_z

        # The head spans the last 10% of the displacement.
        arrow_head = Arrow3D(
            [arrow_end[0] - avg_displacement[0]*0.1, arrow_end[0]],
            [arrow_end[2] - avg_displacement[2]*0.1, arrow_end[2]],
            [arrow_end[1] - avg_displacement[1]*0.1, arrow_end[1]],
            mutation_scale=20 * length_factor,
            arrowstyle='-|>',
            color='r',
            linewidth=head_linewidth
        )
        ax.add_artist(arrow_head)

        ax.set_xlim(lower_bound, upper_bound)
        ax.set_ylim(lower_bound, upper_bound)
        ax.set_zlim(lower_bound, upper_bound)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.legend(loc='upper right')

        output_path = os.path.join(output_dir, "control_effect_average.png")
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def create_control_figure_3d(states, target_pose, lower_bound, upper_bound, output_dir):
    """Save a minimal 3D figure showing only the average control arrow.

    The red arrow runs from the centre of mass of ``states[-1]['position']``
    to the centre of mass of ``target_pose``. Axes decorations are hidden and
    the image is written to ``<output_dir>/control_arrow.png``.

    Data columns are plotted in (0, 2, 1) order, so data column 2 lies on the
    plot's Y axis and column 1 on its Z axis.

    Args:
        states: Sequence whose last element is a mapping with a
            ``'position'`` entry (the pre-control particle positions).
        target_pose: Target particle positions.
        lower_bound: Lower axis limit, applied to every axis.
        upper_bound: Upper axis limit, applied to every axis.
        output_dir: Directory the PNG is written to.
    """
    before_pos = states[-1]['position']
    after_pos = target_pose

    # Centre of mass before/after and the displacement between them.
    com_before = np.mean(before_pos, axis=0)
    com_after = np.mean(after_pos, axis=0)
    avg_displacement = com_after - com_before
    displacement_length = np.linalg.norm(avg_displacement)

    fig = plt.figure(figsize=(5.12, 5.12))
    try:
        ax = fig.add_subplot(111, projection='3d')

        arrow_scale = 1.0
        num_segments = 15
        base_min_width = 1
        base_max_width = 5

        # Longer displacements (relative to the domain size) get thicker arrows.
        length_factor = np.clip(displacement_length / (upper_bound - lower_bound), 0.5, 2.0)
        min_linewidth = base_min_width * length_factor
        max_linewidth = base_max_width * length_factor
        width_span = max_linewidth - min_linewidth

        arrow_start = com_before
        arrow_end = com_before + avg_displacement * arrow_scale
        shaft = arrow_end - arrow_start

        # Data column 2 is used as the "depth" that modulates line width.
        min_depth = min(arrow_start[2], arrow_end[2])
        depth_range = max(arrow_start[2], arrow_end[2]) - min_depth

        ratios = np.linspace(0, 1, num_segments + 1)
        # The final segment is left undrawn; the arrow head covers the tip.
        for i in range(num_segments - 1):
            start_point = arrow_start + ratios[i] * shaft
            end_point = arrow_start + ratios[i + 1] * shaft

            current_z = (start_point[2] + end_point[2]) / 2
            if depth_range > 0:
                normalized_z = np.clip((current_z - min_depth) / depth_range, 0, 1)
            else:
                # No depth change along the arrow: the original 0/0 produced
                # NaN widths, so fall back to a uniform mid-range width.
                normalized_z = 0.5

            ax.plot([start_point[0], end_point[0]],
                    [start_point[2], end_point[2]],
                    [start_point[1], end_point[1]],
                    color='r', linewidth=min_linewidth + width_span * normalized_z)

        # Head width follows the tip's depth relative to the domain bounds;
        # clipped so a tip outside the bounds cannot yield a negative or
        # oversized width.
        head_normalized_z = np.clip(
            (arrow_end[2] - lower_bound) / (upper_bound - lower_bound), 0, 1)
        head_linewidth = min_linewidth + width_span * head_normalized_z

        # The head spans the last 10% of the displacement.
        arrow_head = Arrow3D(
            [arrow_end[0] - avg_displacement[0]*0.1, arrow_end[0]],
            [arrow_end[2] - avg_displacement[2]*0.1, arrow_end[2]],
            [arrow_end[1] - avg_displacement[1]*0.1, arrow_end[1]],
            mutation_scale=20 * length_factor,
            arrowstyle='-|>',
            color='r',
            linewidth=head_linewidth
        )
        ax.add_artist(arrow_head)

        # Set bounds and remove all decorations
        ax.set_xlim(lower_bound, upper_bound)
        ax.set_ylim(lower_bound, upper_bound)
        ax.set_zlim(lower_bound, upper_bound)
        ax.set_aspect('equal')
        ax.axis('off')

        # Tight layout with no padding
        fig.tight_layout(pad=0)

        output_path = os.path.join(output_dir, "control_arrow.png")
        fig.savefig(output_path, dpi=100, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def vis_state_comparison_3d(initial_pos, states, target_pose, lower_bound, upper_bound, output_dir):
    """Save a three-panel 3D comparison of initial, pre-control and target states.

    Each panel scatters one point set with data columns in (0, 2, 1) order and
    shares the same axis limits. The image is written to
    ``<output_dir>/state_comparison.png``.

    Args:
        initial_pos: Initial particle positions, indexed as ``[:, k]``.
        states: Sequence whose last element is a mapping with a
            ``'position'`` entry (shown as the pre-control state).
        target_pose: Target particle positions, indexed as ``[:, k]``.
        lower_bound: Lower axis limit, applied to every axis.
        upper_bound: Upper axis limit, applied to every axis.
        output_dir: Directory the PNG is written to.
    """
    figure = plt.figure(figsize=(15, 8))
    before_pos = states[-1]['position']

    # (points, marker colour, panel title), left to right.
    panels = (
        (initial_pos, 'b', "Initial State"),
        (before_pos, 'b', "State Before Control"),
        (target_pose, 'g', "Final State After Control"),
    )
    limits = (lower_bound, upper_bound)

    for column, (points, colour, heading) in enumerate(panels, start=1):
        axis = figure.add_subplot(1, 3, column, projection='3d')
        axis.scatter(points[:, 0], points[:, 2], points[:, 1], s=2, c=colour)
        axis.set_title(heading)
        axis.set_xlim(*limits)
        axis.set_ylim(*limits)
        axis.set_zlim(*limits)
        axis.set_xlabel('X')
        axis.set_ylabel('Y')
        axis.set_zlabel('Z')

    plt.tight_layout()

    plt.savefig(os.path.join(output_dir, "state_comparison.png"))
    plt.close()

def create_target_control_figure_3d(target_pose, lower_bound, upper_bound, output_dir):
    """Save a minimal 3D figure with one hollow green marker at the target centroid.

    The marker area is ``5000 * (2 * ||std(target_pose)|| * 10)``, so it grows
    with the spread of the target points. Axes are hidden and the image is
    written to ``<output_dir>/target_control.png``.

    Args:
        target_pose: Target particle positions; mean and standard deviation
            are taken along axis 0.
        lower_bound: Lower axis limit, applied to every axis.
        upper_bound: Upper axis limit, applied to every axis.
        output_dir: Directory the PNG is written to.
    """
    figure = plt.figure(figsize=(5.12, 5.12))
    axes3d = figure.add_subplot(111, projection='3d')

    centroid = np.mean(target_pose, axis=0)
    spread = np.std(target_pose, axis=0)
    radius = 2 * np.linalg.norm(spread) * 10
    marker_area = 5000 * radius

    # Data columns are plotted in (0, 2, 1) order.
    axes3d.scatter(
        centroid[0], centroid[2], centroid[1],
        facecolors='none', edgecolors='green',
        alpha=1, s=marker_area, label='Target Points',
    )

    limits = (lower_bound, upper_bound)
    axes3d.set_xlim(*limits)
    axes3d.set_ylim(*limits)
    axes3d.set_zlim(*limits)
    axes3d.set_aspect('equal')
    axes3d.axis('off')  # no axis decorations

    # Strip padding/margins before saving.
    plt.tight_layout(pad=0)

    destination = os.path.join(output_dir, "target_control.png")
    plt.savefig(destination, dpi=100, bbox_inches='tight', pad_inches=0)
    plt.close()

def create_target_control_figure_2d(target_pose, lower_bound, upper_bound, output_dir):
    """Save a minimal 2D figure containing only the target-region circle.

    The circle is centred on the mean of ``target_pose`` and its radius is the
    Euclidean norm of the per-axis standard deviation. No arrow is drawn.
    Axes are hidden and the image is written to
    ``<output_dir>/target_control.png``.

    Args:
        target_pose: Target particle positions; mean and standard deviation
            are taken along axis 0.
        lower_bound: Lower axis limit, applied to both axes.
        upper_bound: Upper axis limit, applied to both axes.
        output_dir: Directory the PNG is written to.
    """
    after_pos = target_pose
    com_after = np.mean(after_pos, axis=0)

    # Radius is the norm of the per-axis standard deviation (not doubled).
    target_std = np.std(after_pos, axis=0)
    target_radius = np.linalg.norm(target_std)

    fig, ax = plt.subplots(figsize=(5.12, 5.12))
    try:
        target_circle = plt.Circle(com_after, target_radius,
                                   color='green', alpha=0.3,
                                   linewidth=2, fill=False, label='Target Region')
        ax.add_patch(target_circle)

        # Set bounds and formatting
        ax.set_xlim(lower_bound, upper_bound)
        ax.set_ylim(lower_bound, upper_bound)
        ax.set_aspect('equal')
        ax.axis('off')  # Hide axes

        # Remove all padding/margins
        fig.tight_layout(pad=0)

        output_path = os.path.join(output_dir, "target_control.png")
        fig.savefig(output_path, dpi=100,
                    bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def create_control_figure_2d(states, target_pose, lower_bound, upper_bound, output_dir):
    """Save a minimal 2D figure showing only the average control arrow.

    The red arrow runs from the centre of mass of ``states[-1]['position']``
    to the centre of mass of ``target_pose``. Axes are hidden and the image is
    written to ``<output_dir>/control_arrow.png``.

    Args:
        states: Sequence whose last element is a mapping with a
            ``'position'`` entry (the pre-control particle positions).
        target_pose: Target particle positions.
        lower_bound: Lower axis limit, applied to both axes.
        upper_bound: Upper axis limit, applied to both axes.
        output_dir: Directory the PNG is written to.
    """
    before_pos = states[-1]['position']
    after_pos = target_pose

    com_before = np.mean(before_pos, axis=0)
    com_after = np.mean(after_pos, axis=0)

    # Difference of centroids equals the mean per-particle displacement when
    # both sets have the same shape, but (like the 3D variant) also works
    # when the particle counts differ.
    avg_displacement = com_after - com_before

    fig, ax = plt.subplots(figsize=(5.12, 5.12))
    try:
        # Draw only the arrow
        arrow_scale = 1.0  # Scale factor to make arrow more visible
        ax.arrow(com_before[0], com_before[1],
                 avg_displacement[0] * arrow_scale, avg_displacement[1] * arrow_scale,
                 color='red', width=0.005, head_width=0.03, head_length=0.04)

        # Set bounds and remove all decorations
        ax.set_xlim(lower_bound, upper_bound)
        ax.set_ylim(lower_bound, upper_bound)
        ax.set_aspect('equal')
        ax.axis('off')

        # Tight layout with no padding
        fig.tight_layout(pad=0)

        output_path = os.path.join(output_dir, "control_arrow.png")
        fig.savefig(output_path, dpi=100, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def create_control_comparison_figure_2d(states, target_pose, initial_pos, lower_bound, upper_bound, output_dir):
    """Save a 2D overview of the control effect with one centre-of-mass arrow.

    Scatters the initial, pre-control and target point sets, draws a red arrow
    from the pre-control centre of mass to the target centre of mass, and
    annotates the arrow with the displacement components. The image is
    written to ``<output_dir>/control_effect_average.png``.

    Args:
        states: Sequence whose last element is a mapping with a
            ``'position'`` entry (the pre-control particle positions).
        target_pose: Target particle positions, indexed as ``[:, k]``.
        initial_pos: Initial particle positions, indexed as ``[:, k]``.
        lower_bound: Lower axis limit, applied to both axes.
        upper_bound: Upper axis limit, applied to both axes.
        output_dir: Directory the PNG is written to.
    """
    before_pos = states[-1]['position']
    after_pos = target_pose

    com_before = np.mean(before_pos, axis=0)
    com_after = np.mean(after_pos, axis=0)

    # Difference of centroids equals the mean per-particle displacement when
    # both sets have the same shape, but (like the 3D variant) also works
    # when the particle counts differ.
    avg_displacement = com_after - com_before

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        # Plot initial state (transparent)
        ax.scatter(initial_pos[:, 0], initial_pos[:, 1], s=5, c='gray', alpha=0.5, label='Initial State')

        # Plot state before control
        ax.scatter(before_pos[:, 0], before_pos[:, 1], s=10, c='blue', alpha=0.5, label='State Before Control')

        # Plot target state (final after control)
        ax.scatter(after_pos[:, 0], after_pos[:, 1], s=15, c='none', edgecolor='green',
                   linewidths=1, alpha=0.7, label='Target Position')

        # Draw the single average arrow
        arrow_scale = 1.0  # Scale factor to make arrow more visible
        arrow_dx = avg_displacement[0] * arrow_scale
        arrow_dy = avg_displacement[1] * arrow_scale
        ax.arrow(com_before[0], com_before[1],
                 arrow_dx, arrow_dy,
                 color='red', width=0.005, head_width=0.03, head_length=0.04,
                 label='Average Control Direction')

        # Label the arrow at its midpoint with the displacement components.
        arrow_mid_x = com_before[0] + arrow_dx / 2
        arrow_mid_y = com_before[1] + arrow_dy / 2
        ax.text(arrow_mid_x, arrow_mid_y, f"Avg Control\n({avg_displacement[0]:.3f}, {avg_displacement[1]:.3f})",
                color='red', ha='center', va='center', fontsize=9,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

        ax.set_xlim(lower_bound, upper_bound)
        ax.set_ylim(lower_bound, upper_bound)
        ax.set_aspect('equal')
        ax.legend(loc='upper right')

        output_path = os.path.join(output_dir, "control_effect_average.png")
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def vis_state_comparison_2d(initial_pos, states, target_pose, lower_bound, upper_bound, output_dir):
    """Save a three-panel 2D comparison of initial, pre-control and target states.

    Each panel scatters columns 0 and 1 of one point set and shares the same
    axis limits. The image is written to ``<output_dir>/state_comparison.png``.

    Args:
        initial_pos: Initial particle positions, indexed as ``[:, k]``.
        states: Sequence whose last element is a mapping with a
            ``'position'`` entry (shown as the pre-control state).
        target_pose: Target particle positions, indexed as ``[:, k]``.
        lower_bound: Lower axis limit, applied to both axes.
        upper_bound: Upper axis limit, applied to both axes.
        output_dir: Directory the PNG is written to.
    """
    def _panels():
        # Yielded lazily so each point set is looked up only when its panel
        # is about to be drawn.
        yield initial_pos, 'blue', "Initial State"
        yield states[-1]['position'], 'blue', "State Before Control"
        yield target_pose, 'green', "Final State After Control"

    plt.figure(figsize=(15, 8))

    for slot, (points, colour, heading) in enumerate(_panels(), start=1):
        plt.subplot(1, 3, slot)
        plt.scatter(points[:, 0], points[:, 1], s=2, c=colour)
        plt.title(heading)
        plt.xlim(lower_bound, upper_bound)
        plt.ylim(lower_bound, upper_bound)

    plt.tight_layout()
    destination = os.path.join(output_dir, "state_comparison.png")
    plt.savefig(destination)
    plt.close()