import numpy as np
from PIL import Image
import os
import torchvision.transforms as transforms
from tqdm import tqdm

# Preprocessing pipeline: convert an image to a tensor, then normalize each
# channel with the commonly used ImageNet mean/std constants below.
# NOTE(review): not referenced in this part of the file; presumably consumed
# by code elsewhere — confirm before changing.
imagenet_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)


class StripesDataset:
    """Index-addressable collection of stripe images stored on disk.

    Item ``i`` is the image file ``<folder>/<angles[i]>.jpg``. This is the same
    naming scheme the generation loop in this file uses. The file name is the
    ``str()`` of each angle, so ``0`` and ``0.0`` refer to different files
    ("0.jpg" vs "0.0.jpg").

    Args:
        folder: Directory containing the images.
        angles: Orientation angles whose images make up the dataset, in order.

    Raises:
        FileNotFoundError: if ``folder`` is not a directory or any expected
            image file is missing. The error is raised explicitly rather than
            via ``assert`` so the check still runs under ``python -O``.
    """

    def __init__(self, folder: str, angles: list):
        if not os.path.isdir(folder):
            raise FileNotFoundError(f"Invalid folder: {folder}")
        self.filenames = [
            os.path.join(folder, f"{orientation_angle}.jpg")
            for orientation_angle in angles
        ]
        # Check every path up front, so a bad angle fails at construction
        # time instead of on first access.
        missing = [f for f in self.filenames if not os.path.isfile(f)]
        if missing:
            raise FileNotFoundError(f"Invalid filename: {', '.join(missing)}")

    def __getitem__(self, idx):
        """Return the image at position ``idx`` with its pixel data loaded.

        ``Image.open`` alone reads lazily and keeps the file open. Here the
        file is opened in a ``with`` block and ``load()`` is called before it
        closes, so no file handle outlives this call.
        """
        with Image.open(self.filenames[idx]) as img:
            img.load()
        return img

    def __len__(self):
        """Return the number of images (one per requested angle)."""
        return len(self.filenames)


def generate_sinusoidal_waves(size, wave_length, orientation_angle):
    """Render a grayscale sinusoidal grating as a 3-channel PIL image.

    Pixel (row ``y``, column ``x``) has intensity
    ``sin(2*pi*(x*cos(theta) + y*sin(theta)) / wave_length)``. This value is
    rescaled from [-1, 1] to [0, 255], and ``theta`` is ``orientation_angle``
    converted to radians.

    Args:
        size: Width and height of the square output image, in pixels.
        wave_length: Period of the wave in pixels, measured along the direction
            given by ``orientation_angle``. Must be positive.
        orientation_angle: Direction of the wave vector, in degrees. With 0 the
            intensity varies along x only, which gives vertical stripes.

    Returns:
        A ``size`` x ``size`` ``PIL.Image.Image`` built from a uint8 array with
        three identical channels.

    Raises:
        ValueError: if ``wave_length`` is not positive. Zero would otherwise
            divide by zero, and the resulting NaNs would be cast to uint8
            with undefined results.
    """
    if wave_length <= 0:
        raise ValueError(f"wave_length must be positive, got {wave_length}")

    theta = np.deg2rad(orientation_angle)

    # X holds column indices and Y holds row indices (meshgrid's default "xy" indexing).
    X, Y = np.meshgrid(np.arange(size), np.arange(size))
    phase = 2 * np.pi * (X * np.cos(theta) + Y * np.sin(theta)) / wave_length

    # Map sin's [-1, 1] range onto [0, 1] (0 = black, 1 = white).
    gray = (np.sin(phase) + 1) / 2

    # astype truncates rather than rounds. This is kept so that generated
    # files stay byte-identical to earlier runs.
    gray_u8 = (gray * 255).astype(np.uint8)

    # Replicate the single channel into three identical channels (H, W, 3).
    rgb = np.repeat(gray_u8[:, :, np.newaxis], 3, axis=2)

    return Image.fromarray(rgb)


# Dataset generation: write one sinusoidal-grating image per orientation angle
# into `stripes_folder`, replacing any .jpg files already there. These are
# plain module-level statements, so they run whenever this file is executed
# or imported.
size = 224  # side length of each square image, in pixels
wave_length = 30  # grating period, in pixels
possible_angles = [i * 22.5 for i in range(0, 8)]  # 0.0, 22.5, ..., 157.5 degrees
print(possible_angles)
# For a finer sweep, use e.g. possible_angles = list(range(0, 180)).

stripes_folder = "./stripes/"

os.makedirs(stripes_folder, exist_ok=True)  # no-op if the folder already exists

# Remove previously generated images. This is done in-process rather than with
# os.system("rm ..."), which is POSIX-only, prints an error when there is
# nothing to delete, and silently ignores failures.
for stale_name in os.listdir(stripes_folder):
    stale_path = os.path.join(stripes_folder, stale_name)
    if stale_name.endswith(".jpg") and os.path.isfile(stale_path):
        os.remove(stale_path)

for orientation_angle in tqdm(possible_angles):
    # File names are "<angle>.jpg" (e.g. "22.5.jpg"), which is what
    # StripesDataset expects.
    image_filename = os.path.join(stripes_folder, f"{orientation_angle}.jpg")
    stripe_image = generate_sinusoidal_waves(
        size, wave_length=wave_length, orientation_angle=orientation_angle
    )
    stripe_image.save(image_filename)

print("Dataset generation complete!")
