import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from svgpathtools import parse_path, wsvg
from tqdm import tqdm


def path_length(path):
    """
    Sum the arc lengths of all segments making up a path.

    Parameters:
    - path: svgpathtools.Path object (any iterable of objects exposing a
      ``length()`` method is accepted)

    Returns:
    - Total length of all segments; 0 for a path with no segments
    """
    total = 0
    for seg in path:
        total += seg.length()
    return total


def approximate_path_to_points(path, num_points):
    """
    Sample a path at evenly spaced values of its parameter and return 2-D points.

    Parameters:
    - path: svgpathtools.Path object (anything with a ``point(t)`` method
      returning a complex number is accepted)
    - num_points: Number of samples, taken at ``np.linspace(0, 1, num_points)``.
      The spacing is uniform in the path parameter t, not necessarily in arc
      length.

    Returns:
    - ndarray of shape ``(num_points, 2)`` whose columns are ``(x, -y)``. The
      imaginary part (y) is negated.
      NOTE(review): get_paths already mirrors y via ``scaled(1, -1)``, so for
      paths built there the two flips cancel. Confirm this is the intended
      orientation.
    """
    params = np.linspace(0, 1, num_points)
    samples = np.array([path.point(t) for t in params])
    xs = samples.real
    ys = -samples.imag  # mirror vertically
    return np.stack([xs, ys], axis=-1)


def get_paths(attributes):
    """
    Parse the ``"d"`` path strings of one sample into vertically mirrored paths.

    Parameters:
    - attributes: iterable of mappings with a ``"d"`` key holding an SVG
      path-data string (presumably the per-path SVG attributes of one digit;
      confirm against the dataset)

    Returns:
    - List of svgpathtools.Path objects, each scaled by (1, -1) (y mirrored),
      in input order
    """
    return [parse_path(attr["d"]).scaled(1, -1) for attr in attributes]


def _points_per_path(lengths, total_points):
    """
    Split ``total_points`` across paths proportionally to their lengths.

    Each path first receives ``int(total_points * length / total_length)``
    points. The leftover points are then handed out one at a time, round-robin
    starting from the first path, so the counts always sum to exactly
    ``total_points``. If every path has zero length, the points are split
    evenly instead of dividing by zero.

    Parameters:
    - lengths: list of per-path lengths
    - total_points: number of points to distribute

    Returns:
    - List of non-negative ints, one per path, summing to ``total_points``

    Raises:
    - ValueError: if ``lengths`` is empty (there is nothing to sample from)
    """
    if not lengths:
        raise ValueError("cannot distribute points over a sample with no paths")
    total_length = sum(lengths)
    if total_length > 0:
        counts = [int(total_points * length / total_length) for length in lengths]
    else:
        # Every path is degenerate; fall back to an even split.
        counts = [total_points // len(lengths)] * len(lengths)
    remainder = total_points - sum(counts)
    for k in range(remainder):
        counts[k % len(counts)] += 1
    return counts


def _sample_digit(paths, num_points):
    """
    Turn the paths of one sample into a single ``(num_points, 2)`` point array.

    Parameters:
    - paths: list of svgpathtools.Path objects
    - num_points: total number of points for the whole sample

    Returns:
    - ndarray of shape ``(num_points, 2)``: the per-path samples concatenated
      in path order
    """
    # Compute each length once; segment.length() can be expensive.
    lengths = [path_length(path) for path in paths]
    counts = _points_per_path(lengths, num_points)
    return np.concatenate(
        [approximate_path_to_points(path, n) for path, n in zip(paths, counts)]
    )


def _normalize_to_unit_range(data):
    """
    Affinely map ``data`` so its global minimum/maximum become -1/1.

    A single min/max is taken over all coordinates (x and y together), so the
    aspect ratio is preserved. If all values are equal, zeros are returned
    instead of NaN.
    """
    low, high = data.min(), data.max()
    span = high - low
    if span == 0:
        return np.zeros_like(data)
    return 2 * (data - low) / span - 1


def _process_split(mnist, split, data_dir, num_points, fig, scatter):
    """
    Convert one dataset split into SVG files, point clouds, PNG previews and an .npz.

    Writes into ``data_dir / split``:
    - ``svg/<i>.svg``: the mirrored paths of sample i (always rewritten)
    - ``png/<i>.png``: scatter plot of sample i's points (skipped if it exists)
    - ``orig_mnist_<split>.npz``: ``data`` of shape
      ``(num_samples, num_points, 2)`` and the split's ``labels``
    """
    cur_dir = data_dir / split
    svg_dir = cur_dir / "svg"
    svg_dir.mkdir(parents=True, exist_ok=True)
    png_dir = cur_dir / "png"
    png_dir.mkdir(parents=True, exist_ok=True)

    samples = []
    for idx, attributes in enumerate(tqdm(mnist[f"{split}_data"], desc=split)):
        paths = get_paths(attributes)
        wsvg(paths=paths, filename=f"{svg_dir}/{idx}.svg", openinbrowser=False)
        samples.append(_sample_digit(paths, num_points))

    # NOTE(review): bounds are computed per split, so train and test are scaled
    # independently. Confirm this is intended rather than reusing train bounds.
    data = _normalize_to_unit_range(np.stack(samples, axis=0))

    for idx, points in enumerate(tqdm(data, desc="png")):
        png_path = png_dir / f"{idx}.png"
        if png_path.exists():
            continue
        scatter.set_offsets(points)
        fig.savefig(png_path, dpi=100, bbox_inches="tight")

    np.savez(
        cur_dir / f"orig_mnist_{split}.npz",
        data=data,
        labels=mnist[f"{split}_labels"],
    )


def main(args):
    """
    Convert the SVG MNIST dataset into fixed-size 2-D point clouds.

    Parameters:
    - args: namespace with ``data_path`` (dataset file containing
      ``train_data``/``train_labels``/``test_data``/``test_labels``) and
      ``num_points`` (points sampled per digit)

    Outputs are written under
    ``<dirname(data_path)>/<num_points>_vectors/orig/<split>/``.
    """
    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects; only run
    # this on a trusted dataset file.
    mnist = np.load(args.data_path, encoding="latin1", allow_pickle=True)

    data_dir = Path(args.data_path).parent / f"{args.num_points}_vectors" / "orig"

    fig, ax = plt.subplots(1, 1)
    scatter = ax.scatter([], [], color="k", marker="o")
    ax.set(xlim=[-1.2, 1.2], ylim=[-1.2, 1.2])
    ax.axis("equal")

    try:
        # Process both the train and the test split.
        for split in ["train", "test"]:
            _process_split(mnist, split, data_dir, args.num_points, fig, scatter)
    finally:
        plt.close(fig)
        # An NpzFile keeps the archive open until closed. Other return types of
        # np.load (plain arrays, unpickled objects) have nothing to close.
        if hasattr(mnist, "close"):
            mnist.close()


if __name__ == "__main__":
    # Command-line entry point: parse options and run the conversion.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_path", type=str, default="data/mnist.svg.npz", help="Path to the dataset"
    )
    cli.add_argument(
        "--num_points", type=int, default=200, help="Number of points to sample from SVG"
    )
    main(cli.parse_args())
