import torch
import torch.nn as nn
import os
import numpy as np
import random
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms as T
from torchvision.datasets.celeba import CelebA
from PIL import Image


class TwoCropTransform:
    """Callable wrapper that runs one transform twice on the same input.

    The wrapped ``transform`` is invoked two separate times per call, so a
    transform with internal randomness may yield two different results.
    """

    def __init__(self, transform):
        # Any callable taking a single argument.
        self.transform = transform

    def __call__(self, x):
        """Return a 2-tuple holding the results of both transform calls."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

class CelebAsplit:
    """Thin wrapper around torchvision's ``CelebA`` exposing a target and bias attributes.

    Args:
        split: split name, forwarded to ``CelebA`` and also appended to ``root``.
        transform: image transform, forwarded to ``CelebA``.
        conditional: when truthy, items also carry the bias attributes.
        root: base directory; the dataset is read from ``root + split + '/'``.
    """

    def __init__(self, split, transform, conditional, root = 'data/CelebA/'):
        self.transform = transform
        self.conditional = conditional
        # Each split lives in its own sub-directory below the base root.
        self.root = root + split +'/'

        celeba_kwargs = dict(
            root = self.root,
            split = split,
            target_type = 'attr',
            transform = transform,
            download = False,
        )
        self.celeba = CelebA(**celeba_kwargs)

        # Attribute column used as the label: 'Attractive'.
        self.target_idx = 2
        # Attribute columns used as biases:
        # 'Big_Lips', 'Black_Hair', 'Mouth_Slightly_Open'.
        self.bias_idx = [6,8,21]

        # Slice the target / bias columns out of the attribute table once.
        self.attr = self.celeba.attr
        self.targets = self.attr[:, self.target_idx]
        self.biases = self.attr[:, self.bias_idx]

    def __getitem__(self, index):
        """Return ``(img, target, bias)`` if conditional, else ``(img, target)``."""
        # The attribute vector returned by CelebA is discarded; target and
        # bias come from the pre-sliced columns instead.
        image, _unused_attr = self.celeba[index]
        label = self.targets[index]
        bias_attrs = self.biases[index]

        if not self.conditional:
            return image, label
        return image, label, bias_attrs

    def __len__(self):
        """Number of samples, taken from the target column."""
        return len(self.targets)
    
class UTKFaceDataset(Dataset):
    """UTKFace images with ethnicity as target and ``(age, gender)`` as bias.

    File names are parsed as ``<age>_<gender>_<ethnicity>_<rest>.jpg``.
    ``.jpg`` files whose name does not split into exactly four
    ``_``-separated fields are skipped; other extensions are ignored.

    Args:
        conditional: when truthy, ``__getitem__`` additionally returns
            ``(embedding, bias_tensor)`` as a third element.
        root: directory containing the image files.
        transform: optional callable applied to the loaded RGB image.
        split: ``'train'`` selects the first ``1 - test_size`` fraction of the
            shuffled samples; any other value selects the remainder.
        test_size: fraction of samples held out of the training split.
        random_seed: seed for the shuffle that defines the split.
        embedding_dim: width of the per-bias embedding vectors.
    """

    def __init__(self, conditional, root = '../UTKFace/', transform=None, split='train', test_size=0.2, random_seed=42, embedding_dim = 8):
        self.root_dir = root
        self.transform = transform
        self.image_paths = []
        self.ethnicities = []
        self.biases = []
        self.conditional = conditional
        self.embedding_dim = embedding_dim

        # os.listdir() returns entries in arbitrary order. Sorting keeps the
        # sample indices -- and therefore the seeded train/test split --
        # identical across instances, runs and machines.
        for filename in sorted(os.listdir(self.root_dir)):
            if not filename.endswith('.jpg'):
                continue
            fields = filename.split('_')
            if len(fields) != 4:
                continue
            age, gender, ethnicity, _ = fields

            age = float(age)
            gender = int(gender)
            ethnicity = int(ethnicity)

            # Record the image path, bias and target for this sample.
            self.image_paths.append(os.path.join(self.root_dir, filename))
            self.biases.append((age, gender))
            self.ethnicities.append(ethnicity)

        # Index lookup table for the embedding. Sorting the unique biases
        # makes the bias -> row mapping independent of set iteration order.
        self.bias_to_idx_lookup = {
            bias: i for i, bias in enumerate(sorted(set(self.biases)))
        }

        # Frozen embedding table with one row per distinct bias. Weights use
        # nn.Embedding's default initialisation, so they differ per instance
        # unless the torch RNG is seeded by the caller.
        self.embedding_layer = nn.Embedding(len(self.bias_to_idx_lookup), self.embedding_dim).requires_grad_(False)

        # Train/test split with a fixed seed. A private Random instance gives
        # the same permutation as seeding the module-level generator, without
        # resetting the process-wide random state as a side effect.
        num_samples = len(self.image_paths)
        indices = list(range(num_samples))
        random.Random(random_seed).shuffle(indices)
        split_index = int(num_samples * (1 - test_size))

        if split == 'train':
            self.indices = indices[:split_index]
        else:
            self.indices = indices[split_index:]

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, target)`` or ``(image, target, (embedding, bias))``.

        ``idx`` is a position within the selected split; it is mapped to the
        underlying sample through ``self.indices``.
        """
        idx = self.indices[idx]
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')
        bias = self.biases[idx]

        target = torch.tensor(self.ethnicities[idx])

        if self.transform is not None:
            image = self.transform(image)

        if self.conditional:
            embedding = self.embedding_layer(torch.tensor(self.bias_to_idx_lookup[bias]))
            return image, target, (embedding, torch.tensor(list(bias)))
        else:
            return image, target
    
def get_loader(dataset, batch_size, split, conditional, shuffle, num_workers, aug = True, two_crop = True):
    """Build a ``DataLoader`` for one of the supported face datasets.

    Args:
        dataset: dataset name, ``'CelebA'`` or ``'UTKFace'``.
        batch_size: forwarded to ``DataLoader``.
        split: split name forwarded to the dataset class. ``'test'`` selects
            the resize-only transform; every other value selects a
            training-style transform.
        conditional: forwarded to the dataset class (controls whether bias
            information is returned with each sample).
        shuffle: forwarded to ``DataLoader``.
        num_workers: forwarded to ``DataLoader``.
        aug: for non-test splits, use random-resized-crop augmentation when
            truthy, otherwise a plain resize; both include a horizontal flip.
        two_crop: wrap the transform in ``TwoCropTransform`` so each image
            yields two transformed views.

    Returns:
        A ``DataLoader`` over the requested dataset with ``pin_memory=True``.

    Raises:
        ValueError: if ``dataset`` is not a supported dataset name.
    """
    # Validate first so an unknown name fails with a clear message rather
    # than an unbound ``img_size`` deep inside the transform construction.
    image_sizes = {'CelebA': 128, 'UTKFace': 128}
    if dataset not in image_sizes:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(image_sizes)}"
        )
    img_size = image_sizes[dataset]

    # set transformations
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]

    if split == 'test':
        transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(means, stds),
        ])
    elif aug:
        transform = T.Compose([
            T.RandomResizedCrop(size=img_size, scale=(0.2, 1.)),
            T.RandomHorizontalFlip(),
            #T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            #T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize(means, stds),
        ])
    else:
        transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(means, stds),
        ])

    if two_crop:
        transform = TwoCropTransform(transform)

    # Keep the name and the dataset object in separate variables so the
    # dataset instance is never compared against a string.
    if dataset == 'CelebA':
        data = CelebAsplit(split = split, transform = transform, conditional = conditional)
    else:
        data = UTKFaceDataset(split=split, transform=transform, conditional=conditional)

    #data = Subset(data, range(100))
    dataloader = DataLoader(dataset = data, batch_size = batch_size,
                            shuffle = shuffle, num_workers = num_workers,
                            pin_memory = True)

    return dataloader



# trainloader = get_loader(dataset= 'UTKFace', batch_size= 10, split = 'train', conditional = True, 
#                                 shuffle = True, num_workers = 2, two_crop = True)

# testloader = get_loader(dataset='UTKFace', batch_size= 10, split = 'test', conditional = False,
#                                 shuffle = True, num_workers = 2, two_crop = False)


# train_batch = next(iter(trainloader))
# test_batch = next(iter(testloader))









