import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from skimage.draw import polygon, ellipse
from skimage.transform import rotate
import argparse

from tqdm import tqdm


def draw_shape(shape_type, resolution, scale, orientation_deg):
    """
    Draw a single shape centred on a blank square canvas (no translation yet).

    Args:
        shape_type: ``"square"``, ``"ellipse"`` (drawn as a circle) or ``"heart"``.
        resolution: Side length of the canvas in pixels.
        scale: Size factor; the shape's reference radius is
            ``(resolution // 2) * scale`` pixels.
        orientation_deg: Rotation angle in degrees handed to
            ``skimage.transform.rotate`` (counter-clockwise); 0 skips rotation.

    Returns:
        ``float32`` array of shape ``(resolution, resolution)``. Without
        rotation it is a binary mask (0.0 / 1.0). Rotation resamples the
        canvas with scikit-image's default interpolation, so edge pixels may
        take intermediate values; scikit-image's default ``clip=True`` keeps
        them within the input range [0, 1].

    Raises:
        ValueError: If ``shape_type`` is not recognised.
    """
    canvas = np.zeros((resolution, resolution), dtype=np.float32)
    cx, cy = resolution // 2, resolution // 2
    r = resolution // 2 * scale

    if shape_type == "square":
        # Axis-aligned square inscribed in the circle of radius r
        # (its half-diagonal equals r).
        half = r / np.sqrt(2)
        coords = np.array([
            [cx - half, cy - half],
            [cx - half, cy + half],
            [cx + half, cy + half],
            [cx + half, cy - half]
        ])
        # Columns of `coords` are (x, y); polygon() expects (rows, cols).
        rr, cc = polygon(coords[:, 1], coords[:, 0], shape=canvas.shape)
        canvas[rr, cc] = 1.0

    elif shape_type == "ellipse":
        # Equal radii, so this is a circle of radius r.
        rr, cc = ellipse(cy, cx, r, r, shape=canvas.shape)
        canvas[rr, cc] = 1.0

    elif shape_type == "heart":
        # Classic parametric heart curve sampled at 100 points.
        t = np.linspace(0, 2 * np.pi, 100)
        x = 16 * np.sin(t) ** 3
        y = (13 * np.cos(t) - 5 * np.cos(2 * t) -
             2 * np.cos(3 * t) - np.cos(4 * t))
        # Scale each axis so its largest absolute excursion equals r; y is
        # negated because image rows grow downwards.
        x = x / np.max(np.abs(x)) * r + cx
        y = -y / np.max(np.abs(y)) * r + cy
        rr, cc = polygon(y, x, shape=canvas.shape)
        canvas[rr, cc] = 1.0

    else:
        raise ValueError(f"Unknown shape type: {shape_type}")

    if orientation_deg != 0:
        # Depending on the scikit-image version, rotate() may return float64;
        # cast back so the output dtype does not depend on whether the shape
        # was rotated.
        canvas = rotate(
            canvas, angle=orientation_deg, resize=False, preserve_range=True
        ).astype(np.float32, copy=False)

    return canvas


def translate_mask(mask, x, y, resolution):
    """
    Translate the mask to a normalized position (x, y) in [0, 1].
    Handles edge clipping (no wraparound).
    """
    target_x = int(np.round(x * (resolution - 1)))
    target_y = int(np.round(y * (resolution - 1)))

    shape_coords = np.argwhere(mask > 0)
    if shape_coords.size == 0:
        return np.zeros_like(mask)

    min_y, min_x = shape_coords.min(axis=0)
    max_y, max_x = shape_coords.max(axis=0)
    h, w = max_y - min_y, max_x - min_x

    top = target_y - h // 2
    left = target_x - w // 2
    bottom = top + h + 1
    right = left + w + 1

    canvas = np.zeros_like(mask)

    # Compute valid bounds
    mask_top = max(0, -top)
    mask_left = max(0, -left)
    canvas_top = max(0, top)
    canvas_left = max(0, left)
    mask_bottom = mask_top + min(resolution - canvas_top, h + 1 - mask_top)
    mask_right = mask_left + min(resolution - canvas_left, w + 1 - mask_left)

    canvas_bottom = canvas_top + (mask_bottom - mask_top)
    canvas_right = canvas_left + (mask_right - mask_left)

    # Copy only valid region
    cropped_mask = mask[min_y + mask_top:min_y + mask_bottom,
                        min_x + mask_left:min_x + mask_right]
    canvas[canvas_top:canvas_bottom, canvas_left:canvas_right] = cropped_mask
    return canvas


def generate_sprite_image(shape="square", x=0.5, y=0.5, scale=1.0, orientation=0.0,
                          resolution=64, add_noise=False):
    """
    Render one dSprite-style image.

    The shape is drawn centred (with scale and rotation applied) and then
    moved so its bounding box is centred on the normalized position (x, y).
    When ``add_noise`` is true, zero-mean Gaussian noise with standard
    deviation 0.05 is added from NumPy's global random state, and the result
    is clipped to [0, 1].
    """
    centred = draw_shape(shape, resolution, scale, orientation)
    image = translate_mask(centred, x, y, resolution)
    if not add_noise:
        return image
    noisy = image + np.random.normal(0, 0.05, size=image.shape)
    return np.clip(noisy, 0, 1)

def render_from_csv_to_h5(csv_path, output_h5_path, resolution=64, file_name="sprites.h5"):
    """
    Render every CSV row to a sprite image and store images plus metadata in HDF5.

    Args:
        csv_path: Path to a CSV that must contain the columns ``index``,
            ``shape_im``, ``x_im``, ``y_im``, ``scale_im`` and ``orientation_im``.
        output_h5_path: Directory that receives the HDF5 file (created if missing).
        resolution: Side length, in pixels, of each square image.
        file_name: Name of the HDF5 file inside ``output_h5_path``.

    The written file holds an ``images`` dataset of shape
    ``(N, resolution, resolution)`` (float32) and a ``metadata`` group with one
    dataset per CSV column. Rows are sorted by ``index``; ``images[k]`` and
    ``metadata/<col>[k]`` always describe the same row.

    Raises:
        ValueError: If any required column is missing from the CSV.
    """
    df = pd.read_csv(csv_path)

    required = {"index", "shape_im", "x_im", "y_im", "scale_im", "orientation_im"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")

    # Stable sort for deterministic ordering, then renumber rows 0..N-1.
    # iterrows() yields index *labels*: without reset_index the images would
    # land at their pre-sort positions while the metadata below is written in
    # sorted order, silently misaligning the two.
    df = df.sort_values("index", kind="mergesort").reset_index(drop=True)
    N = len(df)

    if output_h5_path:
        os.makedirs(output_h5_path, exist_ok=True)
    file_name = os.path.join(output_h5_path, file_name)

    with h5py.File(file_name, "w") as f:
        img_ds = f.create_dataset("images", shape=(N, resolution, resolution), dtype=np.float32)
        meta_grp = f.create_group("metadata")

        for i, row in tqdm(df.iterrows(), total=N, desc="Rendering & writing"):
            img_ds[i] = generate_sprite_image(
                shape=row["shape_im"],
                x=float(row["x_im"]),
                y=float(row["y_im"]),
                scale=float(row["scale_im"]),
                orientation=float(row["orientation_im"]),
                resolution=resolution,
                add_noise=False
            )

        for col in df.columns:
            # Object (string) columns become fixed-length byte strings; other
            # columns keep their pandas/numpy dtype.
            # NOTE(review): astype('S') encodes ASCII only; non-ASCII text
            # would raise here.
            dtype = 'S' if df[col].dtype == object else df[col].dtype
            meta_grp.create_dataset(col, data=df[col].values.astype(dtype))

    print(f"Saved {N} images and metadata to {file_name}")


if __name__ == "__main__":
    # Command-line entry point: render every row of --csv into one HDF5 file
    # inside --output_dir.
    parser = argparse.ArgumentParser()
    # required=True makes argparse print usage and exit when --csv is omitted,
    # rather than surfacing a traceback later.
    parser.add_argument("--csv", type=str, required=True, help="CSV file with shape/x/y/scale/orientation")
    parser.add_argument("--output_dir", type=str, default="generated_sprites", help="Directory to save images")
    parser.add_argument("--resolution", type=int, default=64, help="Image resolution")
    args = parser.parse_args()

    if not os.path.isfile(args.csv):
        raise FileNotFoundError(f"CSV file not found: {args.csv}")
    print(f"Generating sprites from CSV: {args.csv}")
    render_from_csv_to_h5(args.csv, output_h5_path=args.output_dir, resolution=args.resolution)
