import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Dataset, Subset
from src.neural_nets.model_data.make_dataset import NumpyLoader
import numpy as np
from src.neural_nets.model_data.make_dataset import  HierarchyDatasetGenerator

class HierarchyImageDataset(Dataset):
    """Pair each image of a wrapped dataset with a hierarchical ("3-hot") target.

    Each underlying ``(image, label)`` item is remapped through ``label_mapping``
    to a row index of ``labels_array``; that row is returned as the target.

    Args:
        dataset: indexable dataset yielding ``(image_tensor, label)`` pairs where
            ``image_tensor`` is laid out as ``(C, H, W)``. The size of the first
            item is used as the size of every image.
        labels_array: array whose rows are the target vectors, indexed by the
            remapped label.
        label_mapping: mapping from original dataset labels to row indices of
            ``labels_array``.
        flat: if True, returned images are flattened to 1-D; otherwise they are
            ``(H, W, C)`` arrays.
        orthogonalise: if True, each image is pasted into its own cell of a 3x3
            canvas chosen by its remapped label, so different classes occupy
            disjoint pixels.
    """

    def __init__(self, dataset, labels_array, label_mapping, flat=True, orthogonalise=False):
        self.dataset = dataset
        self.labels_array = labels_array
        # Maps original dataset labels to row indices of ``labels_array``.
        self.label_mapping = label_mapping
        self.flat = flat
        self.orthogonalise = orthogonalise

        # Infer the per-image size from the first sample, a (C, H, W) tensor.
        first_image, _ = self.dataset[0]
        self.image_height, self.image_width = first_image.size()[1], first_image.size()[2]

        # The orthogonalised canvas is a 3x3 grid of image-sized cells.
        self.large_image_height = self.image_height * 3
        self.large_image_width = self.image_width * 3

        # Grid cell (row, col) for each remapped label, in row-major order; only
        # used when orthogonalising. All nine cells are defined, so up to nine
        # classes fit; with fewer classes the trailing cells stay empty (e.g. with
        # 8 classes, labels 0-7, cell (2, 2) is never used).
        self.positions = {label: divmod(label, 3) for label in range(9)}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for item ``idx``.

        ``image`` is a NumPy array in channel-last ``(H, W, C)`` layout, placed on
        the 3x3 canvas when ``orthogonalise`` is set and flattened to 1-D when
        ``flat`` is set. ``target`` is the row of ``labels_array`` selected by
        the remapped label.
        """
        image, label = self.dataset[idx]
        new_label = self.label_mapping[label]
        three_hot_vector = self.labels_array[new_label]
        # (C, H, W) tensor -> (H, W, C) array: channel-last, the JAX convention.
        image_np = np.transpose(image.numpy(), (1, 2, 0))

        if self.orthogonalise:
            row, col = self.positions[new_label]
            start_y = row * self.image_height
            start_x = col * self.image_width
            large_image = np.zeros(
                (self.large_image_height, self.large_image_width, image_np.shape[2]),
                dtype=np.float32,
            )
            # Copy all channels so multi-channel inputs are preserved on the canvas.
            large_image[start_y:start_y + self.image_height,
                        start_x:start_x + self.image_width, :] = image_np
            image_np = large_image

        if self.flat:
            image_np = image_np.reshape(-1)
        return image_np, three_hot_vector

def load_hieararchy_mnist(batch_size=8, include_headnode=False, flat=True, orthogonalise=False, normalise=False):
    """Build a loader over 8 randomly chosen MNIST digits with hierarchical targets.

    Args:
        batch_size: batch size for the returned ``NumpyLoader``.
        include_headnode: forwarded to ``HierarchyDatasetGenerator``.
        flat: forwarded to ``HierarchyImageDataset`` (flatten images).
        orthogonalise: forwarded to ``HierarchyImageDataset`` (3x3 canvas layout).
        normalise: if True, apply the standard MNIST mean/std normalisation
            after ``ToTensor``.

    Returns:
        ``(mean_labels, loader)`` where ``mean_labels`` is the column-wise mean of
        the generated label array and ``loader`` is a ``NumpyLoader`` over the
        filtered MNIST training set.
    """
    # Create the hierarchy dataset generator for "3-hot" labels.
    hierarchy_generator = HierarchyDatasetGenerator(include_headnode=include_headnode,
                                                    include_bias_input=False)
    _, labels = hierarchy_generator.create_dataset()
    # Mean over label rows (used as the constant-solution baseline).
    mean_labels = np.mean(labels, axis=0)

    # Shuffle label rows in place so the digit -> target assignment is random.
    np.random.shuffle(labels)

    transform_steps = [transforms.ToTensor()]
    if normalise:
        transform_steps.append(transforms.Normalize((0.1307,), (0.3081,)))
    mnist_transform = transforms.Compose(transform_steps)

    # Load the full MNIST training dataset.
    full_train_dataset = MNIST(root='./data', train=True, download=True, transform=mnist_transform)

    # Sample 8 distinct digit classes.
    random_classes = np.random.choice(range(10), 8, replace=False)
    # Select samples by reading the label tensor directly; iterating the dataset
    # would decode and transform every image just to inspect its label.
    targets = np.asarray(full_train_dataset.targets)
    subset_indices = np.flatnonzero(np.isin(targets, random_classes)).tolist()

    # Map original digit labels to new labels 0-7 (row indices into ``labels``).
    label_mapping = {int(original): new for new, original in enumerate(random_classes)}
    filtered_train_dataset = Subset(full_train_dataset, subset_indices)

    # Wrap with the custom dataset that returns "3-hot" targets.
    custom_dataset = HierarchyImageDataset(filtered_train_dataset,
                                           labels,
                                           label_mapping,
                                           flat=flat,
                                           orthogonalise=orthogonalise)
    return mean_labels, NumpyLoader(custom_dataset, batch_size=batch_size)

# a dataset for the imbalanced mnist
class ImbalancedMNISTDataset(Dataset):
    """Two-class MNIST subset in which the first class is oversampled 2:1.

    Two distinct digit classes are drawn at random. Every index of the first
    drawn class is listed twice and every index of the second once, and the
    combined list is shuffled. Items are ``(image, one_hot)`` pairs with images
    in channel-last ``(H, W, C)`` layout and a length-2 one-hot label whose
    position follows the order of ``random_classes``.

    Args:
        root: directory passed to ``MNIST``.
        train: load the training split if True, else the test split.
        normalise: if True, apply the standard MNIST mean/std normalisation
            after ``ToTensor``.
        download: forwarded to ``MNIST``. Defaults to False, so the data must
            already exist under ``root``.
    """

    def __init__(self, root, train=True, normalise=False, download=False):
        transform_steps = [transforms.ToTensor()]
        if normalise:
            transform_steps.append(transforms.Normalize((0.1307,), (0.3081,)))
        mnist_transform = transforms.Compose(transform_steps)
        self.mnist_dataset = MNIST(root=root, train=train, download=download, transform=mnist_transform)
        # Randomly pick 2 distinct classes; position 0 is the oversampled one.
        self.random_classes = np.random.choice(10, 2, replace=False)
        self.filter_and_balance_classes()

    def filter_and_balance_classes(self):
        """Set ``self.indices``: first class's indices twice, second's once, shuffled."""
        # Read labels from ``targets`` instead of iterating the dataset, which
        # would decode and transform every image just to look at its label.
        targets = np.asarray(self.mnist_dataset.targets)
        first, second = (np.flatnonzero(targets == cls).tolist() for cls in self.random_classes)
        balanced_indices = first * 2 + second
        np.random.shuffle(balanced_indices)  # mix the order of indices in place
        self.indices = balanced_indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, one_hot_label)`` for position ``idx`` of the balanced list."""
        original_idx = self.indices[idx]
        image, label = self.mnist_dataset[original_idx]
        # (C, H, W) tensor -> (H, W, C) array: channel-last, the JAX convention.
        image_np = np.transpose(image.numpy(), (1, 2, 0))

        # One-hot over the two selected classes, in ``random_classes`` order.
        one_hot_label = np.zeros(2, dtype=np.float32)
        one_hot_label[self.random_classes.tolist().index(label)] = 1.0
        return image_np, one_hot_label
    
def load_imbalanced_mnist(batch_size=8, normalise=False):
    """Return ``(mean_labels, loader)`` for the 2:1 imbalanced two-class MNIST set.

    ``mean_labels`` is the fixed vector ``[2/3, 1/3]``, the nominal label mean of
    the 2:1 oversampling. It matches the true mean only when both selected
    classes have equal sample counts. ``loader`` is a shuffling ``NumpyLoader``
    over ``ImbalancedMNISTDataset`` built from the training split under
    ``./data``.
    """
    dataset = ImbalancedMNISTDataset(root='./data', train=True, normalise=normalise)
    nominal_means = np.array([2 / 3, 1 / 3])
    loader = NumpyLoader(dataset, batch_size=batch_size, shuffle=True)
    return nominal_means, loader
    

if __name__ == "__main__":
    # Load the hierarchy MNIST dataset; the loader function returns a
    # (mean_labels, train_loader) tuple, so unpack it.
    mean_labels, train_loader = load_hieararchy_mnist()

