import os
import json
import random
import glob
import torch
from torch.utils.data import Dataset
import numpy as np

from datasets import load_from_disk
import torchvision.transforms as transforms
from io import BytesIO

from PIL import Image
import PIL.Image
try:
    import pyspng
except ImportError:
    pyspng = None


class CustomDataset(Dataset):
    """Dataset pairing raw images with pre-computed feature arrays and labels.

    Expected layout under ``data_dir``::

        images/   image files (any extension PIL registers, or ``.npy``)
        vae-sd/   feature files (same extension filter) plus ``dataset.json``

    ``dataset.json`` must hold a ``"labels"`` entry that ``dict()`` can turn
    into a mapping from feature file name (relative to ``vae-sd``, with
    forward slashes) to its label.

    NOTE(review): images and features are matched purely by their position in
    the two sorted file lists, not by name -- presumably both trees use the
    same relative names; confirm against the preprocessing script.
    """

    def __init__(self, data_dir):
        """Index the image/feature files under ``data_dir`` and load labels.

        Args:
            data_dir: Root directory containing ``images`` and ``vae-sd``.

        Raises:
            KeyError: If a feature file has no entry in ``dataset.json``, or
                if the stacked label array is neither 1-D nor 2-D.
        """
        PIL.Image.init()
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}

        self.images_dir = os.path.join(data_dir, 'images')
        self.features_dir = os.path.join(data_dir, 'vae-sd')

        # images
        self._image_fnames = self._relative_files(self.images_dir)
        self.image_fnames = sorted(
            name for name in self._image_fnames
            if self._file_ext(name) in supported_ext
        )
        # features (dataset.json is dropped here by the extension filter)
        self._feature_fnames = self._relative_files(self.features_dir)
        self.feature_fnames = sorted(
            name for name in self._feature_fnames
            if self._file_ext(name) in supported_ext
        )
        # labels
        with open(os.path.join(self.features_dir, 'dataset.json'), 'rb') as f:
            label_by_name = dict(json.load(f)['labels'])
        # Keys in dataset.json use forward slashes, so normalise Windows
        # separators before the lookup.
        labels = np.array(
            [label_by_name[name.replace('\\', '/')] for name in self.feature_fnames]
        )
        # 1-D labels are stored as int64, 2-D labels as float32.
        self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    @staticmethod
    def _relative_files(root):
        """Return the set of every file path below ``root``, relative to it."""
        return {
            os.path.relpath(os.path.join(dirpath, fname), start=root)
            for dirpath, _dirs, files in os.walk(root) for fname in files
        }

    def _file_ext(self, fname):
        """Return the lower-cased extension of ``fname`` (with leading dot)."""
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        """Return the number of samples.

        Raises:
            AssertionError: If the image and feature file counts differ.
        """
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of image files and feature files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an ``(image, features, label)`` tensor tuple.

        ``.npy`` images are reshaped to ``(-1, H, W)`` using their last two
        axes; every other image is decoded (via ``pyspng`` for PNGs when it is
        installed, otherwise PIL) and moved from H x W (x C) to C x H x W.
        Features are read with ``np.load`` from the features directory.
        """
        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if image_ext == '.npy':
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            else:
                if image_ext == '.png' and pyspng is not None:
                    image = pyspng.load(f.read())
                else:
                    image = np.array(PIL.Image.open(f))
                # H x W (x C) -> C x H x W; grayscale gains a channel axis.
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])

def get_feature_dir_info(root):
    """Count samples and per-sample caption files in a feature directory.

    Every ``*.npy`` file directly under ``root`` whose name contains an
    underscore is treated as a caption file named ``<sample>_<caption>.npy``;
    the remaining ``*.npy`` files are counted as samples.

    Returns:
        ``(num_data, n_captions)`` where ``n_captions`` maps each sample
        index in ``range(num_data)`` to its number of caption files.

    Raises:
        ValueError: If a caption file stem does not split into exactly two
            ``_``-separated parts, or the first part is not an integer.
        KeyError: If a caption file refers to an index outside
            ``range(num_data)``.
    """
    caption_paths = glob.glob(os.path.join(root, '*_*.npy'))
    all_paths = glob.glob(os.path.join(root, '*.npy'))

    num_data = len(all_paths) - len(caption_paths)
    n_captions = dict.fromkeys(range(num_data), 0)

    for path in caption_paths:
        stem, _ext = os.path.splitext(os.path.basename(path))
        sample_id, _caption_id = stem.split('_')
        n_captions[int(sample_id)] += 1

    return num_data, n_captions


class DatasetFactory(object):
    """Base holder for a train/test dataset pair plus dataset-level metadata.

    Subclasses are expected to assign ``self.train`` / ``self.test`` and to
    override the members that raise ``NotImplementedError`` here.
    """

    def __init__(self):
        self.train = None
        self.test = None

    def get_split(self, split, labeled=False):
        """Return the dataset stored for ``split`` (``"train"`` or ``"test"``).

        Raises ``ValueError`` for any other split name. ``labeled`` is only
        checked when ``has_label`` is false, in which case requesting labels
        trips an assertion.
        """
        if split not in ("train", "test"):
            raise ValueError
        chosen = self.train if split == "train" else self.test

        if not self.has_label:
            assert not labeled
        # Labeled and unlabeled requests currently hand back the same object.
        return chosen

    def unpreprocess(self, v):  # to B C H W and [0, 1]
        """Map ``v`` from [-1, 1] to [0, 1], clamping anything outside."""
        rescaled = 0.5 * (v + 1.)
        rescaled.clamp_(0., 1.)
        return rescaled

    @property
    def has_label(self):
        """Whether the datasets carry labels; always true in the base class."""
        return True

    @property
    def data_shape(self):
        """Shape of one data item; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def data_dim(self):
        """Total number of elements in one data item of ``data_shape``."""
        shape = self.data_shape
        return int(np.prod(shape))

    @property
    def fid_stat(self):
        """FID statistics reference; ``None`` unless a subclass overrides it."""
        return None

    def sample_label(self, n_samples, device):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

    def label_prob(self, k):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

class MSCOCOFeatureDataset(Dataset):
    """Per-index image, feature array and randomly chosen caption array.

    For index ``i`` the directory ``root`` is expected to contain ``i.png``,
    ``i.npy`` and one or more caption files ``i_<k>.npy``.
    """
    # the image features are got through sample

    def __init__(self, root):
        self.root = root
        # Sample count and captions-per-sample are derived from the file names.
        self.num_data, self.n_captions = get_feature_dir_info(root)

    def __len__(self):
        return self.num_data

    def __getitem__(self, index):
        """Return ``(x, z, c)``: image as C x H x W array, features, caption.

        One of the sample's caption files is picked uniformly at random on
        every access.
        """
        image_path = os.path.join(self.root, f'{index}.png')
        with open(image_path, 'rb') as fh:
            pixels = np.array(PIL.Image.open(fh))
        # H x W (x C) -> C x H x W
        height, width = pixels.shape[:2]
        x = pixels.reshape(height, width, -1).transpose(2, 0, 1)

        z = np.load(os.path.join(self.root, f'{index}.npy'))

        last_caption = self.n_captions[index] - 1
        k = random.randint(0, last_caption)
        c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
        return x, z, c


class CFGDataset(Dataset):  # for classifier free guidance
    """Wrap a 3-tuple dataset and randomly replace the condition.

    Each access draws one uniform random number; when it falls below
    ``p_uncond`` the third element of the sample is swapped for
    ``empty_token``.
    """

    def __init__(self, dataset, p_uncond, empty_token):
        self.dataset = dataset
        self.p_uncond = p_uncond
        self.empty_token = empty_token

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        x, z, y = self.dataset[item]
        drop_condition = random.random() < self.p_uncond
        return x, z, (self.empty_token if drop_condition else y)

class MSCOCO256Features(DatasetFactory):
    """Factory for the pre-extracted MSCOCO feature datasets.

    Per the original author's note: the moments calculated by Stable
    Diffusion image encoder & the contexts calculated by clip.

    ``mode='val'`` fills ``self.test`` from ``<path>/val``; any other mode
    fills ``self.train`` from ``<path>/train`` and, when ``cfg`` is truthy,
    wraps it in ``CFGDataset``. ``<path>/empty_context.npy`` is loaded into
    ``self.empty_context`` in both cases.
    """

    def __init__(self, path, cfg=True, p_uncond=0.1, mode='train'):
        super().__init__()
        print('Prepare dataset...')

        validating = mode == 'val'
        if validating:
            self.test = MSCOCOFeatureDataset(os.path.join(path, 'val'))
            # assert len(self.test) == 40504
        else:
            self.train = MSCOCOFeatureDataset(os.path.join(path, 'train'))
            # assert len(self.train) == 82783

        self.empty_context = np.load(os.path.join(path, 'empty_context.npy'))

        if not validating and cfg:  # classifier free guidance
            assert p_uncond is not None
            print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
            self.train = CFGDataset(self.train, p_uncond, self.empty_context)

    @property
    def data_shape(self):
        """Shape of one data item reported by this factory."""
        return 4, 32, 32

    @property
    def fid_stat(self):
        """Relative path of the FID statistics file for the validation set."""
        return 'assets/fid_stats/fid_stats_mscoco256_val.npz'
    
# ---------------------------------
# CC3M 
# ---------------------------------
class CC3MImageCaptionDataset(Dataset):
    """
    Image-caption dataset backed by a Hugging Face dataset saved on disk
    (loaded with ``load_from_disk``).
    """

    def __init__(self, root_dir, size=256, augmentation=False, image_key='jpg', caption_key='txt'):
        """
        Args:
            root_dir (str): Directory passed to ``load_from_disk``.
            size (int): Images are resized to ``size`` x ``size``.
            augmentation (bool): Add a random horizontal flip (p=0.5).
            image_key (str): Record field holding the image.
            caption_key (str): Record field holding the caption.
        """
        self.dataset = load_from_disk(root_dir)
        self.augmentation = augmentation
        self.image_key = image_key
        self.caption_key = caption_key

        # Resize -> (optional flip) -> tensor -> normalise 3 channels with
        # mean 0.5 / std 0.5.
        steps = [transforms.Resize((size, size))]
        if augmentation:
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for record ``idx``.

        The image field may be a dict with a ``'bytes'`` entry, a path to an
        existing file, or any object exposing ``convert``; anything else
        raises ``ValueError``. The image is converted to RGB and passed
        through ``self.transform``; the caption is returned as stored.
        """
        # Unwrap index objects that expose .item() into a plain value.
        if hasattr(idx, 'item'):
            idx = idx.item()
        record = self.dataset[idx]
        raw = record[self.image_key]

        if isinstance(raw, dict) and 'bytes' in raw:
            # Case 1: encoded image bytes
            source = Image.open(BytesIO(raw['bytes']))
        elif isinstance(raw, str) and os.path.exists(raw):
            # Case 2: path to an image file
            source = Image.open(raw)
        elif hasattr(raw, 'convert'):
            # Case 3: already an image-like object
            source = raw
        else:
            raise ValueError(f"Unsupported image format: {type(raw)}")

        image = self.transform(source.convert("RGB"))
        return image, record[self.caption_key]
    
class CC3MCFGDataset(Dataset):
    """Wrap a 2-tuple dataset and randomly replace the second element.

    Each access draws one uniform random number; when it falls below
    ``p_uncond`` the second element of the sample is swapped for
    ``empty_token``.
    """

    def __init__(self, dataset, p_uncond, empty_token):
        self.dataset = dataset
        self.p_uncond = p_uncond
        self.empty_token = empty_token

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        x, y = self.dataset[item]
        drop_condition = random.random() < self.p_uncond
        return x, (self.empty_token if drop_condition else y)

class CC3M256Dataset(DatasetFactory):
    """Factory for the CC3M image-caption datasets.

    ``mode='val'`` fills ``self.test`` from ``<path>/val``; any other mode
    fills ``self.train`` from ``<path>/train``. An ``empty_context.npy``
    array is loaded into ``self.empty_context`` in both cases.
    """

    # Directory the training branch used to be hard-wired to; kept only as a
    # fallback so existing setups that rely on it keep working.
    _LEGACY_EMPTY_CONTEXT_DIR = "/gpfs/projects/bsc70/bsc193242/Data/coco256_features"

    def __init__(self, path, cfg=True, p_uncond=0.1, mode='train', empty_context_path=None):
        """
        Args:
            path: Dataset root containing the ``train`` / ``val`` directories.
            cfg: Accepted for signature parity with the other factories;
                currently ignored (see NOTE below).
            p_uncond: Currently ignored, like ``cfg``.
            mode: ``'val'`` builds the test split, anything else the train
                split.
            empty_context_path: Explicit location of the empty-context
                ``.npy`` file. When ``None``, ``<path>/empty_context.npy`` is
                used; in training mode the legacy cluster location is tried
                instead if that file does not exist.
        """
        super().__init__()
        print('Prepare dataset...')
        is_val = mode == 'val'
        if is_val:
            self.test = CC3MImageCaptionDataset(os.path.join(path, 'val'))
        else:
            self.train = CC3MImageCaptionDataset(os.path.join(path, 'train'))

        if empty_context_path is None:
            empty_context_path = os.path.join(path, 'empty_context.npy')
            if not is_val and not os.path.exists(empty_context_path):
                # Previous behaviour: training always read from this
                # absolute path, which only exists on one cluster.
                empty_context_path = os.path.join(
                    self._LEGACY_EMPTY_CONTEXT_DIR, 'empty_context.npy')
        self.empty_context = np.load(empty_context_path)

        # NOTE: classifier-free-guidance wrapping is disabled for CC3M, so
        # `cfg` and `p_uncond` have no effect. To enable it:
        #if cfg:  # classifier free guidance
        #    assert p_uncond is not None
        #    print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
        #    self.train = CC3MCFGDataset(self.train, p_uncond, self.empty_context)