from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
from group_discovery.geometry_2d import polar_decomposition
from group_discovery.geometry_3d import matrix_to_xyz_angles
from group_discovery.logger import get_logger
from group_discovery.utils import (
    fig_to_img,
    sample_random_batch,
    seed_all,
)
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import trange

# Short alias for PyTorch's matrix exponential (not referenced in the visible
# part of this file).
expm = torch.linalg.matrix_exp

from mpl_toolkits.mplot3d.art3d import Poly3DCollection


def _draw_progression_cell(ax, verts, ref_verts, lim):
    """Draw one solid tetrahedron plus its faint reference on a 3D axes.

    Args:
        ax: A matplotlib 3D axes to draw into.
        verts: Array of the 4 vertices of the current tetrahedron.
        ref_verts: Array of the 4 vertices of the reference tetrahedron.
        lim: Half-width of the symmetric axis limits applied to x, y and z.
    """
    # Tetrahedron faces as index triples into the 4 vertices.
    faces = [
        [0, 1, 2],
        [0, 1, 3],
        [0, 2, 3],
        [1, 2, 3],
    ]
    face_colors = ["red", "green", "blue", "yellow"]
    vertex_colors = ["red", "green", "blue", "yellow"]

    # Filled faces without edges; the edges are drawn separately in black so
    # they are not affected by the face transparency.
    ax.add_collection3d(
        Poly3DCollection(
            [verts[face] for face in faces],
            facecolors=face_colors,
            alpha=0.6,
            edgecolors="none",
            linewidths=0,
        )
    )

    for face in faces:
        face_verts = verts[face]
        # Close the loop by repeating the first vertex at the end.
        outline = np.vstack([face_verts, face_verts[0:1]])
        ax.plot(
            outline[:, 0],
            outline[:, 1],
            outline[:, 2],
            color="black",
            linewidth=1.0,
        )

    # Colored vertices in order: red, green, blue, yellow.
    for v_idx, v_color in enumerate(vertex_colors):
        ax.scatter(
            verts[v_idx, 0],
            verts[v_idx, 1],
            verts[v_idx, 2],
            color=v_color,
            s=30,
            edgecolors="black",
            linewidths=0.5,
            zorder=10,
        )

    # Reference tetrahedron: very transparent faces with dashed gray edges.
    ax.add_collection3d(
        Poly3DCollection(
            [ref_verts[face] for face in faces],
            facecolors=face_colors,
            alpha=0.1,
            edgecolors="lightgray",
            linewidths=1.2,
            linestyles="--",
        )
    )
    for v_idx, v_color in enumerate(vertex_colors):
        ax.scatter(
            ref_verts[v_idx, 0],
            ref_verts[v_idx, 1],
            ref_verts[v_idx, 2],
            color=v_color,
            s=40,
            alpha=0.3,  # Reference vertices are more transparent.
            edgecolors="lightgray",
            linewidths=0.5,
            zorder=5,
        )

    # Symmetric limits, cubic box, fixed camera, no axis decorations.
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_zlim(-lim, lim)
    ax.set_box_aspect([1, 1, 1])
    ax.view_init(elev=20, azim=45)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_axis_off()


def plot_progression_grid(x_t, tf_t, x_1, titles=None):
    """
    Plot the progression of up to 5 samples in 3D, one row per sample.

    Each cell shows a tetrahedron (fixed face colors, black edges) at one time
    step, overlaid on a faint copy of the corresponding reference shape.

    Args:
        x_t: Tensor indexed as ``x_t[step, sample]``, each entry holding the
            4 vertices of a tetrahedron (assumed 3D coordinates).
        tf_t: Accepted for interface compatibility; it is not drawn.
        x_1: Tensor indexed as ``x_1[sample]`` with the reference vertices.
            Its largest absolute value sets the shared axis limits.
        titles: Optional column titles, applied to the first row only.

    Returns:
        Whatever ``fig_to_img`` returns for the rendered figure.
    """
    max_rows = 5

    x_1 = x_1.detach().cpu()
    max_val = torch.abs(x_1).max().item()
    # Guard against a degenerate (all-zero) reference, which would otherwise
    # produce zero-width axis limits.
    lim = max_val if max_val > 0 else 1.0

    steps = x_t.detach().cpu().numpy().astype(float)
    refs = x_1.numpy().astype(float)

    # Never index past the available samples (the original assumed >= 5).
    nr = min(max_rows, steps.shape[1], refs.shape[0])
    nc = steps.shape[0]

    fig = plt.figure(figsize=(1.5 * nc, 1.5 * nr), dpi=300)
    try:
        for r in range(nr):
            ref_verts = refs[r]  # Same reference for every column of the row.
            for c in range(nc):
                ax = fig.add_subplot(nr, nc, r * nc + c + 1, projection="3d")
                _draw_progression_cell(ax, steps[c, r], ref_verts, lim)

                # Column titles with minimal padding.
                if r == 0 and titles is not None and c < len(titles):
                    ax.set_title(titles[c], fontsize=14, pad=0)

        # Negative spacing packs the (axis-less) 3D cells tightly together.
        fig.subplots_adjust(
            left=0.00,
            right=1.00,
            top=0.97,
            bottom=0.00,
            hspace=-0.18,
            wspace=-0.18,
        )

        fig.canvas.draw()
        return fig_to_img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def _draw_scatter_cell(ax, verts):
    """Draw a single colored tetrahedron on a 3D axes with fixed [-1, 1] limits."""
    # Tetrahedron faces as index triples into the 4 vertices.
    faces = [
        [0, 1, 2],
        [0, 1, 3],
        [0, 2, 3],
        [1, 2, 3],
    ]
    face_colors = ["red", "green", "blue", "yellow"]
    vertex_colors = ["red", "green", "blue", "yellow"]

    # Filled faces without edges; edges are drawn separately in black.
    ax.add_collection3d(
        Poly3DCollection(
            [verts[face] for face in faces],
            facecolors=face_colors,
            alpha=0.6,
            edgecolors="none",
            linewidths=0,
        )
    )

    for face in faces:
        face_verts = verts[face]
        # Close the loop by repeating the first vertex at the end.
        outline = np.vstack([face_verts, face_verts[0:1]])
        ax.plot(
            outline[:, 0],
            outline[:, 1],
            outline[:, 2],
            color="black",
            linewidth=1.0,
        )

    # Colored vertices in order: red, green, blue, yellow.
    for v_idx, v_color in enumerate(vertex_colors):
        ax.scatter(
            verts[v_idx, 0],
            verts[v_idx, 1],
            verts[v_idx, 2],
            color=v_color,
            s=40,
            edgecolors="black",
            linewidths=0.5,
            zorder=10,
        )

    ax.set_box_aspect([1, 1, 1])
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_zlim(-1, 1)

    # Remove all axis decorations and fix the camera.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_axis_off()
    ax.view_init(elev=20, azim=45)


def plot_3D_scatter_grid(x, titles=None):
    """Render a grid of tetrahedra, one per ``x[row, col]`` entry.

    Args:
        x: Array or tensor indexed as ``x[row, col]``; each entry holds the
            4 vertices of a tetrahedron (assumed 3D coordinates). The first
            two dimensions define the grid layout.
        titles: Optional column titles, applied to the top row only.

    Returns:
        Whatever ``fig_to_img`` returns for the rendered figure.
    """
    # Convert once up front. Tensors are detached and moved to the CPU first,
    # because ``Tensor.numpy()`` fails for CUDA or grad-tracking tensors.
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    else:
        x = np.asarray(x)

    nr, nc = x.shape[0], x.shape[1]
    fig = plt.figure(figsize=(1.5 * nc, 1.5 * nr), dpi=150)
    try:
        for row in range(nr):
            for col in range(nc):
                ax = fig.add_subplot(nr, nc, row * nc + col + 1, projection="3d")
                _draw_scatter_cell(ax, x[row, col])

                # Titles only on the top row, pulled down into the cell.
                if titles is not None and row == 0 and col < len(titles):
                    ax.set_title(titles[col], fontsize=12, pad=-10)

        # Negative spacing packs the (axis-less) 3D cells tightly together.
        fig.subplots_adjust(
            left=0.00,
            right=1.00,
            top=1.00,
            bottom=0.00,
            hspace=-0.18,
            wspace=-0.18,
        )

        fig.canvas.draw()
        return fig_to_img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def plot_so3_distribution(
    rots: torch.Tensor,
    gt_rots=None,
    fig=None,
    ax=None,
    show_color_wheel: bool = True,
    canonical_rotation=torch.eye(3, dtype=torch.float32),
):
    """
    Map intrinsic xyz Euler angles to Mollweide projection.

    - X-axis: yaw angle (Z rotation)
    - Y-axis: pitch angle (Y rotation)
    - Color: roll angle (X rotation)

    Args:
        rots: Batch of rotation matrices; moved to the CPU before use.
        gt_rots: Optional ground-truth rotation(s), a single matrix or a
            batch. Drawn as larger, black-edged markers.
        fig: Figure that owns ``ax``. Ignored (replaced) when ``ax`` is None;
            looked up from ``ax`` when omitted.
        ax: Existing axes to draw into. When None, a new Mollweide figure is
            created and closed again after rendering.
        show_color_wheel: Whether to add the polar roll-angle legend.
        canonical_rotation: Matrix right-multiplied onto all rotations before
            the angles are extracted. The shared default tensor is never
            modified in place.

    Returns:
        Whatever ``fig_to_img`` returns for the rendered figure.
    """
    owns_figure = ax is None
    if owns_figure:
        fig = plt.figure(figsize=(8, 4), dpi=400)
        fig.subplots_adjust(left=0.10, bottom=0.12, right=0.90, top=0.95)
        ax = fig.add_subplot(111, projection="mollweide")
    elif fig is None:
        # An axes was supplied without its figure; the figure is required for
        # the color wheel and for rendering.
        fig = ax.figure

    try:
        cmap = plt.cm.hsv
        # Do all matrix products on the CPU so inputs on other devices work.
        canonical_rotation = canonical_rotation.cpu()
        rots = rots.cpu()
        rots = rots @ canonical_rotation
        scatterpoint_scaling = 3

        # Convert to XYZ Euler angles
        roll, pitch, yaw = matrix_to_xyz_angles(rots)

        # Add small random jitter to prevent exact overlaps.
        # NOTE(review): this jitter is uniform in [0, jitter_scale), i.e.
        # one-sided, whereas the ground-truth jitter below is Gaussian.
        jitter_scale = 0.02  # Adjust this to control jitter amount
        yaw_jittered = yaw + np.random.rand(len(yaw)) * jitter_scale
        pitch_jittered = pitch + np.random.rand(len(pitch)) * jitter_scale * 0.5

        # Color based on roll angle
        # Normalize roll from [-π, π] to [0, 1] for colormap
        colors = cmap((roll + np.pi) / (2.0 * np.pi))
        if colors.ndim == 1:
            colors = colors.reshape(1, -1)
        colors[:, -1] = 0.6  # Set alpha to 0.6 for regular points

        # Display the distribution
        ax.scatter(
            yaw_jittered,
            pitch_jittered,
            s=scatterpoint_scaling,
            c=colors,
        )

        # Handle ground truth rotations if provided
        if gt_rots is not None:
            # Same device handling as for ``rots`` above.
            gt_rots = gt_rots.cpu()
            if gt_rots.dim() == 2:
                gt_rots = gt_rots.unsqueeze(0)
            gt_rots = gt_rots @ canonical_rotation

            # Get XYZ angles for ground truth
            roll_gt, pitch_gt, yaw_gt = matrix_to_xyz_angles(gt_rots)

            # Add jitter to GT points
            yaw_gt_jittered = yaw_gt + np.random.randn(len(yaw_gt)) * 3 * jitter_scale
            pitch_gt_jittered = (
                pitch_gt + np.random.randn(len(pitch_gt)) * 3 * jitter_scale * 0.5
            )

            # Color based on roll angle
            colors_gt = cmap((roll_gt + np.pi) / (2.0 * np.pi))
            if colors_gt.ndim == 1:
                colors_gt = colors_gt.reshape(1, -1)
            colors_gt[:, -1] = 1.0

            ax.scatter(
                yaw_gt_jittered,
                pitch_gt_jittered,
                s=60,  # Larger for ground truth
                c=colors_gt,
                edgecolors="black",
                marker="o",
                linewidth=1.0,
                zorder=10,
                alpha=0.5,
            )

        ax.grid()
        ax.set_xticklabels([])
        ax.tick_params(axis="both", which="major", labelsize=8)
        ax.tick_params(axis="both", which="minor", labelsize=8)

        # Add labels
        ax.set_xlabel("Yaw (Z rotation)", fontsize=14)
        ax.set_ylabel("Pitch (Y rotation)", fontsize=14)

        if show_color_wheel:
            # Add a color wheel showing the roll angle to color conversion
            ax_cw = fig.add_axes([0.81, 0.15, 0.15, 0.15], projection="polar")
            theta = np.linspace(0, 2 * np.pi, 200)
            radii = np.linspace(0.5, 0.6, 2)
            theta_grid, _ = np.meshgrid(theta, radii)

            # Wheel angle 0 maps to colormap value 0, which the scatter above
            # uses for roll = -π; hence the tick labels start at -180°.
            colormap_val = theta_grid / (2 * np.pi)
            ax_cw.pcolormesh(theta, radii, colormap_val, cmap=cmap, shading="auto")
            ax_cw.set_yticklabels([])

            # Set tick labels for roll angles
            tick_angles = np.linspace(0, 2 * np.pi, 8, endpoint=False)
            ax_cw.set_xticks(tick_angles)
            ax_cw.tick_params(axis="x", pad=-1)
            ax_cw.set_xticklabels(
                [
                    "-180°",
                    "-135°",
                    "-90°",
                    "-45°",
                    "0°",
                    "45°",
                    "90°",
                    "135°",
                ],
                fontsize=10,
            )
            ax_cw.spines["polar"].set_visible(False)

            # Label the wheel on its own axes rather than on pyplot's implicit
            # "current" axes, which may belong to another figure.
            ax_cw.text(
                0.5,
                0.5,
                "Roll (X)",
                fontsize=12,
                horizontalalignment="center",
                verticalalignment="center",
                transform=ax_cw.transAxes,
            )

        fig.canvas.draw()
        return fig_to_img(fig)
    finally:
        # Only close figures created here; a caller-supplied figure stays open.
        if owns_figure:
            plt.close(fig)


def plot_x_t_centroids(x_t_all, t_values, titles=None):
    """Plot centroid trajectories over time for matrix flow.

    Only the centroid of each sample (mean over axis 2 of ``x_t_all``) is
    drawn, using its first two coordinates. The horizontal coordinate is
    shifted in proportion to time so that the start and end distributions
    appear side by side.

    Args:
        x_t_all: Tensor indexed as ``[T, B, N, D]`` (steps, samples, points,
            coordinates); only coordinates 0 and 1 are plotted.
        t_values: Tensor of per-step times, one per entry along the first
            axis of ``x_t_all``.
        titles: Optional labels; the first is placed above the start
            (t = 0) and the second above the position for t = 1.

    Returns:
        Whatever ``fig_to_img`` returns for the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        t_vals = t_values.cpu().numpy()
        x_t_all = x_t_all.cpu().numpy()  # [T, B, N, D]

        _, B = x_t_all.shape[:2]

        x_centroid = x_t_all.mean(axis=2)  # [T, B, D]

        # Spread the time axis over three times the horizontal data extent.
        x_range = x_centroid[:, :, 0].max() - x_centroid[:, :, 0].min()
        shift_scale = x_range * 3

        x_centroid_shifted = x_centroid.copy()
        x_centroid_shifted[:, :, 0] += shift_scale * t_vals.reshape(-1, 1)

        ax.set_aspect("equal")
        ax.axis("off")

        y_max = x_centroid[:, :, 1].max()

        if titles is not None:
            # zip() tolerates fewer than two titles instead of raising.
            for x_pos, label in zip((0, shift_scale), titles):
                ax.text(
                    x_pos, y_max * 1.3, label, fontsize=16, ha="center", color="black"
                )

        for i in range(B):
            # Plot trajectory as a line
            ax.plot(
                x_centroid_shifted[:, i, 0],
                x_centroid_shifted[:, i, 1],
                "-",
                linewidth=1,
                alpha=0.1,
                color="orange",
            )
            # Start and end points
            for step in (0, -1):
                ax.plot(
                    x_centroid_shifted[step, i, 0],
                    x_centroid_shifted[step, i, 1],
                    "o",
                    color="blue",
                    alpha=0.5,
                    markersize=4,
                )

        ax.set_xlabel("Time")

        fig.canvas.draw()
        return fig_to_img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


@hydra.main(version_base=None, config_path="../conf", config_name="train")
def main(cfg):
    """Train the configured model and periodically log diagnostics.

    Every epoch runs one training pass and one evaluation pass and logs their
    results. Every ``cfg.test.epoch_interval`` epochs it additionally logs a
    progression grid, a grid of random samples, the distribution of learned
    rotations, the mean Frobenius norm of ``P - I`` from the polar
    decomposition, and centroid trajectories, then saves model and optimizer
    checkpoints under ``cfg.save_dir / "ckpt"``.

    Args:
        cfg: Hydra configuration (loaded from ``../conf/train``).
    """
    # Kept as in the original; the returned logger is not used below.
    log = get_logger(__name__)

    # Seed
    seed_all(cfg.seed)
    torch.backends.cudnn.deterministic = True

    # Set up wandb and make checkpoints dir
    wandb_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    logger = instantiate(
        cfg.logger, wandb_config=wandb_config, _convert_="all", _recursive_=False
    )

    # Datasets/Dataloaders
    train_dataset = instantiate(cfg.dataset.train)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=cfg.dataset.batch_size,
        shuffle=True,
        num_workers=cfg.dataset.num_workers,
    )
    if "Random" in cfg.dataset.train.dist.base_dist._target_:
        # Reuse the training distribution object for the test set.
        dist = train_dataset.dist
        test_dataset = instantiate(cfg.dataset.test, dist=dist)
    else:
        test_dataset = instantiate(cfg.dataset.test)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=cfg.dataset.batch_size,
        shuffle=False,
        num_workers=cfg.dataset.num_workers,
    )

    # Save dir
    ckpt_dir = Path(cfg.save_dir) / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Model
    model = instantiate(cfg.model).to(cfg.device)

    # Training
    optimizer = instantiate(cfg.optimizer, model.parameters())

    # For test progression: 11 evenly spaced step indices in [0, n_steps].
    # They must be integers because they are used to index the sampled
    # trajectory below; float index arrays are rejected.
    times = np.linspace(0, cfg.test.n_steps, 11).round().astype(int).tolist()

    for epoch in trange(cfg.train.epochs, desc="Train"):
        # Train
        train_results = model.train_net(train_dataloader, optimizer)
        logger.log_all("train", train_results, {"epoch": epoch})

        # Test
        test_results = model.eval_net(test_dataloader)
        logger.log_all("test", test_results, {"epoch": epoch})

        if (epoch + 1) % (cfg.test.epoch_interval) == 0:
            # Plot progression
            B = 5
            test_batch = sample_random_batch(test_dataset, B, cfg.device)

            # [T + 1, B, N, D]
            x_t, _, tf_t = model.sample_all(
                test_batch, cfg.test.n_steps, return_transform=True
            )
            x_t = x_t[times]
            tf_t = tf_t[times]

            titles = [f"t={t / cfg.test.n_steps}" for t in times]
            progression_img = plot_progression_grid(
                x_t, tf_t, test_batch, titles + ["Orig"]
            )
            logger.log_image(
                "test/progression",
                progression_img,
                {"epoch": epoch},
            )

            # Visualize random samples
            B = 48
            test_batch = sample_random_batch(test_dataset, B, cfg.device)

            test_samples = model.sample(
                test_batch, cfg.test.n_steps, return_transform=False
            )
            # Arrange the 48 samples as a 6 x 8 grid.
            test_samples = test_samples.reshape(6, 8, *test_samples.shape[1:])

            samples_img = plot_3D_scatter_grid(test_samples.cpu().numpy())

            logger.log_image(
                "test/samples",
                samples_img,
                {"epoch": epoch},
            )

            # Plot distribution
            B = 1_000
            test_batch, test_tf = sample_random_batch(
                test_dataset, B, cfg.device, return_transform=True
            )
            _, orig_tf, transforms = model.sample(
                test_batch, cfg.test.n_steps, return_transform=True
            )

            if model.prior_dist.group == "GL(3,C)":
                # Check for reflections (negative determinant)
                dets = torch.linalg.det(transforms)
                n_reflections = (dets.real < 0).sum().item()

                logger.log_all(
                    "test", {"reflection_ratio": n_reflections / B}, {"epoch": epoch}
                )

            # Extract angles
            canonicalized_transforms = (
                test_tf.transpose(-2, -1) @ orig_tf.transpose(-2, -1) @ transforms
            ).transpose(-2, -1)
            R, P = polar_decomposition(canonicalized_transforms)
            gt_rots = test_dataset.dist.base_dist.locs

            angles_hist_img = plot_so3_distribution(R, gt_rots=gt_rots)
            logger.log_image(
                "test/learned_transforms_distribution",
                angles_hist_img,
                {"epoch": epoch},
            )

            # Plot the Frobenius norm of (P - I). The identity is created on
            # P's own device so the subtraction cannot hit a device mismatch.
            identity = torch.eye(P.shape[-1], device=P.device).unsqueeze(0)  # [1, D, D]
            P_norms = torch.linalg.norm(P - identity, ord="fro", dim=(1, 2)).mean()
            logger.log_all(
                "test",
                {"norm(rem - I)": P_norms.item()},
                {"epoch": epoch},
            )

            # Plot trajectories over time
            B = 300
            test_batch = sample_random_batch(test_dataset, B, cfg.device)
            x_t_all = model.sample_all(test_batch, cfg.test.n_steps)
            t_values = torch.linspace(0, 1, cfg.test.n_steps + 1).to(cfg.device)

            traj_img = plot_x_t_centroids(
                x_t_all, t_values, [model.prior_dist.group, cfg.dataset.group]
            )
            logger.log_image(
                "test/prob_path",
                traj_img,
                {"epoch": epoch},
            )

            # Save model checkpoint (overwritten at every interval)
            torch.save(model.state_dict(), ckpt_dir / "model.pt")
            torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.pt")


# Script entry point: run the Hydra-decorated ``main`` only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
