from typing import Optional
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Polygon
import numpy as np
import pandas as pd

from src.utils import reconstruction


# (x, y) column pairs read from ``imu_data``; pair i is drawn as IMU marker i.
REQUIRED_IMU_DATA_COLUMNS = (
    "imu_pelvis_global_x",
    "imu_pelvis_global_y",
    "imu_thigh_r_global_x",
    "imu_thigh_r_global_y",
    "imu_shank_r_global_x",
    "imu_shank_r_global_y",
    "imu_foot_r_global_x",
    "imu_foot_r_global_y",
    "imu_thigh_l_global_x",
    "imu_thigh_l_global_y",
    "imu_shank_l_global_x",
    "imu_shank_l_global_y",
    "imu_foot_l_global_x",
    "imu_foot_l_global_y",
)

# (x, y) column pairs read from ``cop_data``: right side first, then left.
REQUIRED_COP_DATA_COLUMNS = ("cop_rx", "cop_ry", "cop_lx", "cop_ly")


class MissingColumnsError(ValueError, AssertionError):
    """Raised when an input DataFrame lacks columns required for plotting.

    Also derives from ``AssertionError`` because this validation used to be
    done with ``assert``; existing ``except AssertionError`` handlers keep
    working, while the check itself no longer disappears under ``python -O``.
    """


def _require_columns(df: object, required: tuple, name: str) -> None:
    """Raise unless ``df`` is a DataFrame containing every column in ``required``."""
    if not isinstance(df, pd.DataFrame):
        raise TypeError(
            f"{name} must be a pandas DataFrame, got {type(df).__name__}"
        )
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise MissingColumnsError(
            f"{name} is missing required columns {missing}; "
            f"available columns: {list(df.columns)}"
        )


def visualize_movement(
    IK_data: pd.DataFrame,
    segment_positions: np.ndarray,
    imu_data: Optional[pd.DataFrame] = None,
    cop_data: Optional[pd.DataFrame] = None,
    output_file: str = "animation.gif",
    fps: int = 100,
):
    """Animate the reconstructed segment model and save it as a GIF.

    Segment start/end points come from
    ``reconstruction._reconstruct_segments_koelewijn(IK_data, segment_positions)``.
    They are drawn as 12 line segments plus three filled triangles. IMU
    positions and centre-of-pressure (COP) points can optionally be overlaid.
    The returned segment arrays are indexed here as ``[frame, coord, segment]``
    with coord 0 = x and 1 = y; that layout is inferred from the indexing
    below and should be confirmed against the reconstruction helper.

    Args:
        IK_data: Inverse-kinematics data. One animation frame is rendered per
            row (``len(IK_data)`` frames).
        segment_positions: Passed through unchanged to the reconstruction
            helper.
        imu_data: Optional DataFrame containing every column in
            ``REQUIRED_IMU_DATA_COLUMNS``. Rows are read positionally, one per
            frame, and each x/y pair is drawn as a marker.
        cop_data: Optional DataFrame containing every column in
            ``REQUIRED_COP_DATA_COLUMNS``, also read positionally per frame.
            Before drawing, ``cop_rx`` is clipped between the x-coordinates of
            segment ends 4 (lower) and 5 (upper). ``cop_lx`` is clipped the
            same way between segment ends 10 and 11. The caller's DataFrame
            is not modified.
        output_file: Path of the GIF, written with the ``pillow`` writer.
        fps: Frames per second of the saved GIF.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` that was saved.

    Raises:
        TypeError: If ``imu_data`` or ``cop_data`` is given but is not a
            DataFrame.
        MissingColumnsError: If a required column is missing. This exception
            is both a ``ValueError`` and an ``AssertionError``.
    """
    # Validate the optional inputs before doing any expensive work.
    if imu_data is not None:
        _require_columns(imu_data, REQUIRED_IMU_DATA_COLUMNS, "imu_data")
    if cop_data is not None:
        _require_columns(cop_data, REQUIRED_COP_DATA_COLUMNS, "cop_data")

    # Segment Reconstruction
    seg_starts, seg_ends = reconstruction._reconstruct_segments_koelewijn(
        IK_data, segment_positions
    )

    # Extract the per-frame coordinates once, instead of doing a pandas
    # column lookup for every point in every frame.
    imu_xy = None
    if imu_data is not None:
        imu_xy = imu_data[list(REQUIRED_IMU_DATA_COLUMNS)].to_numpy()

    cop_xy = None
    if cop_data is not None:
        # Clip a copy so the caller's DataFrame is left untouched.
        cop_data = cop_data.copy()
        cop_data["cop_rx"] = cop_data["cop_rx"].clip(
            seg_ends[:, 0, 4], seg_ends[:, 0, 5]
        )
        cop_data["cop_lx"] = cop_data["cop_lx"].clip(
            seg_ends[:, 0, 10], seg_ends[:, 0, 11]
        )
        cop_xy = cop_data[list(REQUIRED_COP_DATA_COLUMNS)].to_numpy()

    blue_col = "#1f77b4"
    red_col = "#d62728"
    grey_col = "#7f7f7f"

    # Per-segment colours: segments 0-2 and 7 grey, 3-6 red, 8-11 blue.
    line_colors = [grey_col] * 3 + [red_col] * 4 + [grey_col] + [blue_col] * 4
    # IMU marker i takes the colour of the listed segment. With the column
    # order above, right-side IMUs come out red and left-side IMUs blue.
    imu_point_colors = [line_colors[i] for i in (1, 3, 4, 6, 8, 9, 11)]
    triangle_colors = [line_colors[i] for i in (0, 6, 11)]
    # Triangle vertices: (start of segment a, end of segment b, end of segment c).
    triangle_segments = ((0, 0, 2), (5, 5, 6), (10, 10, 11))

    # Set up the figure and axis
    fig, ax = plt.subplots()
    lines = [ax.plot([], [], "-", lw=2, color=c)[0] for c in line_colors]
    imu_points = [ax.plot([], [], "o", color=c)[0] for c in imu_point_colors]
    cop_points = [ax.plot([], [], "x", color="black")[0] for _ in range(2)]
    joint_triangles = [
        Polygon(np.zeros((3, 2)), closed=True, color=c) for c in triangle_colors
    ]
    # Add the patches exactly once. init_func may be called repeatedly (start,
    # repeat, resize), and adding there would stack duplicate patches.
    for triangle in joint_triangles:
        ax.add_patch(triangle)
    ax.set_xlim(-0.5, 2.5)
    ax.set_ylim(-0.5, 2.5)
    ax.set_aspect("equal")

    # Artists redrawn on every frame (required for blitting). Markers for
    # inputs that were not supplied stay empty and are not returned.
    animated_artists = lines + joint_triangles
    if imu_xy is not None:
        animated_artists += imu_points
    if cop_xy is not None:
        animated_artists += cop_points

    def init_anim():
        """Clear all line/marker data before the first frame."""
        for artist in lines + imu_points + cop_points:
            artist.set_data([], [])
        return list(animated_artists)

    def update(frame):
        """Move the segments, triangles and markers to ``frame``."""
        for seg, line in enumerate(lines):
            start = seg_starts[frame, :, seg]
            end = seg_ends[frame, :, seg]
            line.set_data([start[0], end[0]], [start[1], end[1]])

        for triangle, (a, b, c) in zip(joint_triangles, triangle_segments):
            triangle.set_xy(
                np.array(
                    [
                        seg_starts[frame, :, a],
                        seg_ends[frame, :, b],
                        seg_ends[frame, :, c],
                    ]
                )
            )

        if imu_xy is not None:
            for i, point in enumerate(imu_points):
                x, y = imu_xy[frame, 2 * i], imu_xy[frame, 2 * i + 1]
                # Repeat the scalar so set_data receives sequences.
                point.set_data([x, x], [y, y])

        if cop_xy is not None:
            for i, point in enumerate(cop_points):
                x, y = cop_xy[frame, 2 * i], cop_xy[frame, 2 * i + 1]
                point.set_data([x, x], [y, y])

        return list(animated_artists)

    ani = FuncAnimation(
        fig, update, frames=len(IK_data), init_func=init_anim, blit=True, repeat=True
    )
    ani.save(str(output_file), writer="pillow", fps=fps)

    return ani
