import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import Dataset, Subset
from src.neural_nets.model_data.make_dataset import NumpyLoader
import numpy as np
from src.neural_nets.model_data.make_dataset import  HierarchyDatasetGenerator
from src.neural_nets.model_data.load_mnist import HierarchyImageDataset

def load_hieararchy_cifar(batch_size=8, include_headnode=False, flat=True, orthogonalise=False, normalise=False):
    """Build a training loader pairing CIFAR-10 images with hierarchy ("3-hot") labels.

    Eight of the ten CIFAR-10 classes are drawn at random, the training set is
    restricted to those classes, and each kept class is remapped to an index
    in 0-7 that is handed to ``HierarchyImageDataset`` together with the
    (shuffled) hierarchy label matrix.

    Args:
        batch_size: Batch size passed to ``NumpyLoader``.
        include_headnode: Forwarded to ``HierarchyDatasetGenerator``.
        flat: Forwarded to ``HierarchyImageDataset``.
        orthogonalise: Forwarded to ``HierarchyImageDataset``.
        normalise: When true, a per-channel ``transforms.Normalize`` step is
            applied after ``ToTensor``.

    Returns:
        A tuple ``(mean_labels, loader)`` where ``mean_labels`` is the mean of
        the hierarchy labels over axis 0 and ``loader`` is a shuffling
        ``NumpyLoader`` over the filtered, relabelled dataset.

    Note:
        Uses the global NumPy RNG (``np.random``) both for shuffling the label
        rows and for picking the classes; seed it beforehand for
        reproducibility. The CIFAR-10 data is downloaded to ``./data`` if it
        is not already present.
    """
    # Create the hierarchy dataset generator for "3-hot" labels
    hierarchy_generator = HierarchyDatasetGenerator(include_headnode=include_headnode,
                                                    include_bias_input=False)
    _, labels = hierarchy_generator.create_dataset()

    # Mean over labels for OCS. Taken along axis 0, so the in-place shuffle
    # below (which only permutes along the first axis) does not affect it.
    mean_labels = np.mean(labels, axis=0)

    # Randomise which label row ends up paired with which class index.
    # presumably `labels` has one row per kept class (8) - TODO confirm.
    np.random.shuffle(labels)

    transform_steps = [transforms.ToTensor()]
    if normalise:
        # NOTE(review): these constants look like the widely used ImageNet
        # channel statistics rather than CIFAR-10 ones - confirm intended.
        transform_steps.append(
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        )
    cifar_transform = transforms.Compose(transform_steps)

    # Load the full cifar training dataset
    full_train_dataset = CIFAR10(root='./data', train=True, download=True, transform=cifar_transform)

    # Sample 8 random classes out of the 10 available.
    random_classes = np.random.choice(range(10), 8, replace=False)

    # Select samples by reading the stored integer targets directly instead of
    # iterating the dataset: indexing the dataset would decode and transform
    # every image only to throw it away. The resulting indices are identical
    # and in the same ascending order.
    targets = np.asarray(full_train_dataset.targets)
    subset_indices = np.flatnonzero(np.isin(targets, random_classes)).tolist()

    # Create a label mapping from original labels to new labels (0-7)
    label_mapping = {original: new for new, original in enumerate(random_classes)}

    # Create a subset
    filtered_train_dataset = Subset(full_train_dataset, subset_indices)

    # Create the custom dataset with "3-hot" vectors
    custom_dataset = HierarchyImageDataset(filtered_train_dataset,
                                           labels,
                                           label_mapping,
                                           flat=flat,
                                           orthogonalise=orthogonalise)

    # return the mean labels together with the train loader
    return mean_labels, NumpyLoader(custom_dataset, batch_size=batch_size, shuffle=True)

if __name__ == "__main__":
    # Load the hierarchy cifar dataset. The loader function returns a
    # (mean_labels, loader) tuple, so unpack it rather than binding the whole
    # tuple to a name that suggests it is the loader.
    mean_labels, train_loader = load_hieararchy_cifar()

