import torch
from torchvision.datasets import CelebA
from torchvision import datasets, transforms
from src.neural_nets.model_data.make_dataset import NumpyLoader
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd

def load_hieararchy_celeba(batch_size=8, include_headnode=False, flat=True, orthogonalise=False, normalise=False):
    """Build a shuffled CelebA training stream of NumPy batches plus label statistics.

    Args:
        batch_size: Number of samples per batch handed to the ``DataLoader``.
        include_headnode: Currently unused; kept so existing callers keep working.
        flat: Currently unused; kept so existing callers keep working.
        orthogonalise: Currently unused; kept so existing callers keep working.
        normalise: When true, a per-channel ``Normalize`` with mean
            (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225) is applied
            after resizing.

    Returns:
        A tuple ``(mean_labels, num_steps_per_epoch, batches)``:

        * ``mean_labels`` -- NumPy array with the per-column mean of the
          attribute file after mapping -1 to 0. It is computed over every row
          of ``list_attr_celeba.txt``, independent of the ``split`` used for
          the dataset below.
        * ``num_steps_per_epoch`` -- number of batches in one pass over the
          training split, counting a final partial batch.
        * ``batches`` -- the generator returned by ``torch_to_numpy``. It makes
          a single pass over the loader and is then exhausted.
    """
    # Shared preprocessing: tensor conversion followed by a resize to 64x64.
    steps = [transforms.ToTensor(), transforms.Resize((64, 64))]
    if normalise:
        steps.append(transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))
    celeba_transform = transforms.Compose(steps)

    # Load the CelebA training split with attribute targets (must already be on disk).
    celeba_dataset = CelebA(root='./data', split='train', target_type='attr', download=False, transform=celeba_transform)
    data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=True)

    attr_path = './data/celeba/list_attr_celeba.txt'
    # `sep=r'\s+'` is the documented replacement for the deprecated
    # `delim_whitespace=True`; header=1 takes the column names from the second line.
    attr_data = pd.read_csv(attr_path, sep=r'\s+', header=1)
    # Convert -1 to 0 so that the column mean is the fraction of positive labels.
    attr_data = (attr_data + 1) // 2
    mean_labels = attr_data.mean().to_numpy()

    # Ceiling division: a trailing partial batch still counts as one step.
    full_batches, leftover = divmod(len(celeba_dataset), batch_size)
    num_steps_per_epoch = full_batches + (1 if leftover else 0)

    return mean_labels, num_steps_per_epoch, torch_to_numpy(data_loader)

def torch_to_numpy(dataloader):
    """Yield every (images, attributes) batch of ``dataloader`` as NumPy arrays.

    Images are reordered with the axis permutation (0, 2, 3, 1), i.e. from the
    expected channel-first layout (batch, C, H, W) to channel-last
    (batch, H, W, C). Attributes are passed through ``(x + 1) // 2``, which
    maps -1 to 0 and leaves 0 and 1 unchanged.

    This is a generator: it walks ``dataloader`` once and is then exhausted.
    """
    # Axis permutation that moves the channel axis to the end (JAX convention).
    channel_last = (0, 2, 3, 1)

    for batch_images, batch_attrs in dataloader:
        pixels = batch_images.numpy().transpose(channel_last)
        # Shape: [batch_size, num_attributes]
        labels = np.floor_divide(batch_attrs.numpy() + 1, 2)
        yield pixels, labels

# if __name__ == "__main__":
