import numpy as np
from PIL import Image, ImageDraw
import random
import matplotlib.pyplot as plt
import torch


class SyntheticImageDataset:
    """Generate a toy classification dataset of single-shape RGB images.

    Every image is a white canvas with exactly one shape drawn on it in a
    random colour, size and position. The label of an image is the integer
    index of its shape in ``shapes`` (0 = square, 1 = circle, 2 = triangle).
    """

    def __init__(self, width, height, num_images):
        """Store the canvas size and the total number of images to generate.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.
            num_images: Total number of images produced by
                ``generate_dataset`` (train and test combined).
        """
        self.width = width
        self.height = height
        self.num_images = num_images
        self.shapes = ['square', 'circle', 'triangle']
        # Integer class ids, one per entry of ``shapes``; these are the labels.
        self.shape_idx = list(range(len(self.shapes)))

    def generate_image(self, shape) -> "Image.Image":
        """Draw one randomly coloured, sized and placed shape on a white canvas.

        Args:
            shape: Integer class id (0 = square, 1 = circle, 2 = triangle).
                Any other value leaves the canvas blank.

        Returns:
            A PIL ``RGB`` image of size ``(width, height)``.
        """
        image = Image.new('RGB', (self.width, self.height), color='white')
        draw = ImageDraw.Draw(image)

        # Plain Python ints keep the colour tuple independent of NumPy scalar types.
        color = tuple(np.random.randint(0, 256, 3).tolist())

        # The shape's bounding box is at most half the shorter side. On canvases
        # smaller than 20 px the usual 10 px minimum would make the range empty,
        # so the lower bound is clamped instead of letting randint raise.
        max_size = min(self.width, self.height) // 2
        size = random.randint(min(10, max_size), max_size)
        x = random.randint(0, self.width - size)
        y = random.randint(0, self.height - size)

        if shape == 0:
            draw.rectangle([x, y, x + size, y + size], fill=color)
        elif shape == 1:
            draw.ellipse([x, y, x + size, y + size], fill=color)
        elif shape == 2:
            # Isosceles triangle: base on the bottom edge of the box, apex top-centre.
            draw.polygon([(x, y + size), (x + size // 2, y), (x + size, y + size)], fill=color)

        return image

    def generate_dataset(self, class_weights=None, test_split=0.5, label_noise=0.0, flatten=True, *args, **kwargs):
        """Generate, shuffle and split the dataset, optionally corrupting labels.

        Args:
            class_weights: Sampling probability for each shape class, in the
                order of ``shapes``. Must have one entry per shape and sum
                to 1. Defaults to a uniform distribution.
            test_split: Fraction of the images placed in the test set.
            label_noise: Fraction of *training* labels that are replaced by a
                different, randomly chosen class. Test labels are never altered.
            flatten: If true, each image is flattened to a 1-D vector.
            *args: Accepted for call compatibility; ignored.
            **kwargs: Accepted for call compatibility; ignored.

        Returns:
            The tuple ``(train_images, test_images, train_labels, test_labels)``.
            Note the order: both image tensors come first. Images are float
            tensors with values in [0, 255]; labels are integer tensors of
            class ids.

        Raises:
            AssertionError: If ``class_weights`` has the wrong length or does
                not sum to 1.
        """
        if class_weights is None:
            class_weights = [1 / len(self.shapes)] * len(self.shapes)
        else:
            assert len(class_weights) == len(
                self.shapes), "class_weights must have the same length as the number of shapes"
            assert np.isclose(sum(class_weights), 1), "class_weights must sum to 1"

        images = []
        labels = []
        for _ in range(self.num_images):
            shape = random.choices(self.shape_idx, weights=class_weights, k=1)[0]
            image = self.generate_image(shape)
            images.append(np.array(image))
            labels.append(shape)

        # Convert to numpy arrays
        images = np.array(images)
        labels = np.array(labels)
        if flatten:
            images = images.reshape(images.shape[0], -1)

        # Shuffle images and labels with the same permutation
        indices = np.arange(self.num_images)
        np.random.shuffle(indices)
        images = images[indices]
        labels = labels[indices]

        # Split the data: the first (1 - test_split) share is the training set
        train_split = 1 - test_split
        split_idx = int(self.num_images * train_split)
        train_images = images[:split_idx]
        train_labels = labels[:split_idx]
        test_images = images[split_idx:]
        test_labels = labels[split_idx:]

        # Introduce label noise in the training set
        if label_noise > 0:
            num_noisy_labels = int(len(train_labels) * label_noise)
            noisy_indices = np.random.choice(len(train_labels), num_noisy_labels, replace=False)
            for idx in noisy_indices:
                current_label = train_labels[idx]
                # Labels are integer class ids, so the replacement must be drawn
                # from shape_idx; the string names in ``shapes`` never compare
                # equal to a label and cannot be stored in the integer array.
                candidates = [label for label in self.shape_idx if label != current_label]
                train_labels[idx] = random.choice(candidates)

        train_images = torch.tensor(train_images, dtype=torch.float)
        test_images = torch.tensor(test_images, dtype=torch.float)
        train_labels = torch.tensor(train_labels)
        test_labels = torch.tensor(test_labels)
        return train_images, test_images, train_labels, test_labels

if __name__ == "__main__":
    # Demo: generate a dataset and display the first training image with its label.
    dataset_generator = SyntheticImageDataset(width=32, height=32, num_images=10000)
    # generate_dataset returns (train_images, test_images, train_labels, test_labels);
    # flatten=False keeps the per-image (height, width, 3) layout that imshow needs.
    train_images, test_images, train_labels, test_labels = dataset_generator.generate_dataset(flatten=False)

    # Show the first image in the training set
    first_image = train_images[0]
    print(first_image.shape)
    # Pixels are floats in [0, 255]; imshow only interprets that range for
    # integer data, so convert back to uint8 before plotting.
    plt.imshow(first_image.numpy().astype("uint8"))
    plt.title(f"Label: {train_labels[0].item()}")
    plt.axis('off')  # Hide the axis
    plt.show()