import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm


def rotation_matrix(angle):
    """Return the 2x2 planar rotation matrix for ``angle`` (in radians).

    The result is ``[[cos, -sin], [sin, cos]]``, i.e. the matrix that rotates
    column vectors counter-clockwise by ``angle``.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)

    return np.array([[cos_a, -sin_a], [sin_a, cos_a]])


def _cyclic_group_elements(group):
    """Return the rotation matrices of the cyclic group named ``group``.

    Args:
        group: Group name of the form ``"C<n>"`` with ``n`` a positive
            integer, e.g. ``"C4"``.

    Returns:
        A list of ``n`` 2x2 rotation matrices for the angles
        ``0, 2*pi/n, ..., (n-1)*2*pi/n``.

    Raises:
        ValueError: If ``group`` is not a string of the form ``"C<n>"`` or
            ``n`` is not a positive integer.
    """
    if not isinstance(group, str) or not group.startswith("C"):
        raise ValueError(
            f"Unsupported group {group!r}: only cyclic groups 'C<n>' "
            "(e.g. 'C4') have group elements defined."
        )

    try:
        order = int(group[1:])
    except ValueError:
        raise ValueError(
            f"Invalid cyclic group {group!r}: expected 'C<n>' with integer n."
        ) from None

    if order < 1:
        raise ValueError(
            f"Invalid cyclic group {group!r}: the order must be at least 1."
        )

    angle = 2 * np.pi / order
    return [rotation_matrix(i * angle) for i in range(order)]


def main(args):
    """Generate a group-transformed copy of the dataset under ``args.data_path``.

    For each of the ``train`` and ``test`` splits, loads
    ``<data_path>/<split>/orig_mnist_<split>.npz``, applies rotations from the
    cyclic group ``args.group`` to every ``args.subsample``-th sample, writes a
    scatter-plot PNG per transformed sample and saves the transformed data,
    the applied matrices and the labels to
    ``<data_path>/../<group>/<name>/<split>/<group>_mnist_<split>.npz``.

    Args:
        args: Namespace with the attributes ``data_path``, ``group``, ``seed``,
            ``all_transforms`` and ``subsample``.

    Raises:
        ValueError: If ``args.subsample`` is smaller than 1 or ``args.group``
            is not a cyclic group ``"C<n>"``. Other groups (including
            ``"GL2"``) have no group elements defined and are rejected here
            instead of failing midway through the run.
    """
    # Validate everything before touching the filesystem or creating a figure.
    if args.subsample < 1:
        raise ValueError(
            f"subsample must be a positive integer, got {args.subsample!r}."
        )
    group_elements = _cyclic_group_elements(args.group)

    data_dir = Path(args.data_path)

    # The global NumPy seed drives the per-sample random choice below.
    np.random.seed(args.seed)

    name = "all_g" if args.all_transforms else "single_g"
    name = f"{name}_seed{args.seed}_subsample{args.subsample}"
    root_dir = data_dir.parent / args.group / name

    fig, ax = plt.subplots(1, 1)
    try:
        # A single scatter artist is reused; only its offsets change per image.
        scatter = ax.scatter([], [], color="k", marker="o")
        ax.set(xlim=[-1.2, 1.2], ylim=[-1.2, 1.2])
        ax.axis("equal")

        for split in ("train", "test"):
            save_dir = root_dir / split
            png_dir = save_dir / "png"
            png_dir.mkdir(parents=True, exist_ok=True)

            # NOTE(review): the original comment here read "Temporary bug fix,
            # revert later" without saying what was worked around - confirm
            # whether this load path is still the intended one.
            # SECURITY: allow_pickle=True can execute arbitrary code when the
            # archive contains pickled objects; only load trusted files.
            with np.load(
                data_dir / f"{split}/orig_mnist_{split}.npz", allow_pickle=True
            ) as dataset:
                data = dataset["data"]
                labels = dataset["labels"]

            transformed_data = []
            transforms = []
            new_labels = []

            for i, x in enumerate(tqdm(data, desc=split)):
                if i % args.subsample != 0:
                    continue

                if args.all_transforms:
                    # One output per group element, named "<sample>_<element>".
                    selected = [
                        (f"{i}_{elem}.png", m)
                        for elem, m in enumerate(group_elements)
                    ]
                else:
                    # One randomly chosen group element per sample.
                    idx = np.random.choice(len(group_elements))
                    selected = [(f"{i}.png", group_elements[idx])]

                for png_name, matrix in selected:
                    # Assumes x is an (n_points, 2) array so that each row is
                    # transformed by the matrix - TODO confirm.
                    transformed_x = x @ matrix.T

                    # Save png
                    scatter.set_offsets(transformed_x)
                    fig.savefig(png_dir / png_name, dpi=100, bbox_inches="tight")

                    transformed_data.append(transformed_x)
                    transforms.append(matrix)
                    new_labels.append(labels[i])

            transformed_data = np.stack(transformed_data, axis=0)
            transforms = np.stack(transforms, axis=0)
            new_labels = np.stack(new_labels, axis=0)

            np.savez(
                save_dir / f"{args.group}_mnist_{split}.npz",
                data=transformed_data,
                transform=transforms,
                labels=new_labels,
            )
    finally:
        # Release the figure even if loading, plotting or saving fails.
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()

    # Both the dataset location and the group are mandatory: main() cannot
    # run without them, so let argparse report a usage error up front instead
    # of failing later on a None value.
    parser.add_argument(
        "--data_path", type=str, required=True, help="Path to the dataset"
    )
    parser.add_argument(
        "--group",
        type=str,
        required=True,
        help="Symmetry group to use for random augmentations",
    )
    parser.add_argument("--seed", type=int, default=1234, help="Random seed")
    parser.add_argument(
        "--all_transforms",
        action="store_true",
        help=(
            "Perform all transformations for each data sample, only for discrete groups"
        ),
    )
    parser.add_argument(
        "--subsample",
        type=int,
        default=1,
        help=(
            "Transform every (subsample) samples in the dataset, makes dataset smaller."
        ),
    )

    args = parser.parse_args()

    main(args)
