import numpy as np
import torch

try:
    import pyspng
except ImportError:
    pyspng = None

# from lrw_dataset import LRWDataset
from torch.utils.data import DataLoader

import os
import random

import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class FaceForensicsDataset(Dataset):
    """Fixed-length frame clips sampled from directories of PNG frames.

    Every immediate subdirectory of a root directory is treated as one video
    whose frames are the ``*.png`` files inside it, ordered by filename.
    A clip consists of ``sample_size`` frames taken every ``frame_interval``
    frames; a new clip starts every ``clip_stride`` frames. The clip list is
    shuffled once, deterministically, with ``seed``.

    Args:
        root_dir: A directory path, or a list of directory paths.
        in_channels: 1 (grayscale) or 3 (RGB) output channels.
        mode: Stored as ``self.mode``; not otherwise used by this class.
        seed: Seed for the one-time shuffle of the clip list.
        augmentations: If True, apply ``RandomHorizontalFlip(0.5)`` per frame.
        sample_size: Number of frames per clip.
        frame_interval: Frame-index step between consecutive frames of a clip.
        crop_size: Center-crop size (only applied when ``in_channels == 1``).
        reshape_size: Size passed to ``transforms.Resize``.
        clip_stride: Frame-index step between the first frames of successive
            clips from the same video (was hard-coded to 4).
    """

    def __init__(self,
                 root_dir,
                 in_channels=3,
                 mode="train",
                 seed=42,
                 augmentations=False,
                 sample_size=8,
                 frame_interval=4,
                 crop_size=(256, 256),
                 reshape_size=(256, 256),
                 clip_stride=4):
        self.root_dirs = [root_dir] if isinstance(root_dir, str) else root_dir
        self.mode = mode
        self.in_channels = in_channels
        self.seed = seed
        self.augmentation = augmentations
        self.sample_size = sample_size
        self.frame_interval = frame_interval
        self.crop_size = crop_size
        self.reshape_size = reshape_size
        self.clip_stride = clip_stride
        self.data_info = self.build_file_list()
        self.transform = self.build_transform()

        # A private RNG seeded the same way yields the same permutation as
        # random.seed(seed) + random.shuffle, without resetting the global
        # `random` state for the rest of the program.
        random.Random(self.seed).shuffle(self.data_info)

    def build_file_list(self):
        """Enumerate every clip as its frame paths and frame filenames.

        Returns:
            list[dict]: One entry per clip with keys ``"frame_paths"`` (full
            paths) and ``"filenames"`` (basenames), each ``sample_size`` long.
        """
        # Frames covered by one clip: the first frame plus (sample_size - 1) gaps.
        clip_span = (self.sample_size - 1) * self.frame_interval + 1
        offsets = [i * self.frame_interval for i in range(self.sample_size)]

        data_info = []
        for root_dir in self.root_dirs:
            # Sorted so the pre-shuffle order (and therefore the seeded
            # shuffle) does not depend on filesystem listing order.
            for subdir in sorted(os.listdir(root_dir)):
                subdir_path = os.path.join(root_dir, subdir)
                if not os.path.isdir(subdir_path):
                    continue

                frame_files = sorted(f for f in os.listdir(subdir_path) if f.endswith(".png"))

                # Last start index whose final frame is still in range. The old
                # bound (total - sample_size * frame_interval) dropped the last
                # frame_interval - 1 valid start positions of every video.
                last_start = len(frame_files) - clip_span
                for start in range(0, last_start + 1, self.clip_stride):
                    names = [frame_files[start + off] for off in offsets]
                    data_info.append({
                        "frame_paths": [os.path.join(subdir_path, n) for n in names],
                        "filenames": names,  # Frame filenames
                    })

        return data_info

    def build_transform(self):
        """Build the per-frame PIL image -> normalized tensor pipeline.

        Output values are mapped to [-1, 1] by ``Normalize(0.5, 0.5)``.

        Raises:
            ValueError: If ``in_channels`` is neither 1 nor 3.
        """
        if self.augmentation:
            augmentations = transforms.Compose([
                transforms.RandomHorizontalFlip(0.5),
            ])
        else:
            augmentations = transforms.Compose([])

        if self.in_channels == 1:
            return transforms.Compose([
                transforms.CenterCrop(self.crop_size),
                transforms.Resize(self.reshape_size),
                augmentations,
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
                transforms.Normalize([0.5, ], [0.5, ]),
            ])
        if self.in_channels == 3:
            return transforms.Compose([
                transforms.Resize(self.reshape_size),
                augmentations,
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ])
        raise ValueError(f"in_channels must be 1 or 3, got {self.in_channels!r}")

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        """Load one clip.

        Returns:
            dict: ``"frames"`` as a (C, T, H, W) tensor, ``"label"`` as an
            empty tensor, and ``"filenames"`` as the clip's frame basenames.
        """
        data = self.data_info[idx]
        frames = []

        for frame_path in data["frame_paths"]:
            # `with` closes the file promptly (PIL opens lazily and would
            # otherwise hold the handle until GC). convert("RGB") normalizes
            # palette/RGBA/L PNGs so the channel count matches Normalize.
            with Image.open(frame_path) as img:
                frames.append(self.transform(img.convert("RGB")))

        frames = torch.stack(frames)  # (T, C, H, W)
        frames = frames.transpose(1, 0)  # (T, C, H, W) -> (C, T, H, W)

        # null label tensor
        label = torch.tensor([])
        return {
            "frames": frames,
            "label": label,
            "filenames": data["filenames"],
        }

class Dataset(torch.utils.data.Dataset):
    """Base class for datasets of fixed-shape samples (StyleGAN-style API).

    Subclasses implement ``_load_raw_image`` and ``_load_raw_labels``; public
    indices are mapped to raw indices through ``self._raw_idx``.

    NOTE: this class shadows the ``torch.utils.data.Dataset`` name imported at
    the top of the file; classes defined after it that subclass ``Dataset``
    get this class as their base.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw data, e.g. (N, T, C, H, W).
        max_size    = None,     # max_size limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
        xflip       = False,    # Not supported; raises NotImplementedError if True.
        random_seed = 0,        # Random seed to use when applying max_size.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None

        # Apply max_size: keep a seeded random subset of raw indices, sorted so
        # access stays in storage order. __getitem__ maps through _raw_idx.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # xflip is not implemented: __getitem__ has no flip logic. Fail loudly
        # with an exception (an `assert` would be stripped under `python -O`).
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            raise NotImplementedError("xflip is not supported by this dataset")

    def close(self):  # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx):  # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self, raw_idx):  # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        # Drop cached raw labels when pickling (e.g. for DataLoader workers).
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup; never let a finalizer raise.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def _to_raw_idx(self, idx):
        """Map a public dataset index to the underlying raw index."""
        return int(self._raw_idx[idx])

    def __getitem__(self, idx):
        """Return ``(image, label)`` for public index ``idx``."""
        raw_idx = self._to_raw_idx(idx)
        image = self._load_raw_image(raw_idx)
        label = self._load_raw_labels(raw_idx)
        return image, label

    def get_label(self, idx):
        return self._load_raw_labels(self._to_raw_idx(idx))

    def get_details(self, idx):
        return {}

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        # Per-sample shape, i.e. raw_shape without the leading N.
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 4  # TCHW
        return self.image_shape[1]

    @property
    def resolution(self):
        assert len(self.image_shape) == 4  # TCHW
        assert self.image_shape[2] == self.image_shape[3]
        return self.image_shape[2]

    @property
    def label_shape(self):
        # NOTE(review): this returns the per-sample image shape, not a label
        # shape; has_labels depends on it. Confirm this is intended.
        return list(self._raw_shape[1:])

    @property
    def label_dim(self):
        # NOTE(review): hard-coded; not derived from label_shape.
        return 64

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return False


class FaceForensics(Dataset):
    """FaceForensics clips exposed through the ``Dataset`` base-class API.

    Wraps a ``FaceForensicsDataset`` and returns each clip as a (T, C, H, W)
    tensor truncated to at most ``vid_length`` frames.

    Args:
        vid_length: Maximum number of frames returned per clip.
            NOTE(review): the wrapped dataset is built with its default
            ``sample_size`` (8), so values above 8 do not yield longer clips.
            Confirm intent.
        path: Root directory (or list of directories) of per-video frame
            folders. When None, the original placeholder path is used.
        resolution: If given, frames must be ``resolution`` x ``resolution``.
        in_channels: 1 or 3, forwarded to ``FaceForensicsDataset``.
        **super_kwargs: Forwarded to the ``Dataset`` base class.

    Raises:
        IOError: If no clips are found or the frame resolution does not match.
    """

    def __init__(self,
                 vid_length = 16,
                 path = None,        # Path to the dataset root directory (or list of them).
                 resolution = None,  # Ensure specific resolution, None = highest available.
                 in_channels = 3,
                 **super_kwargs,     # Additional arguments for the Dataset base class.
    ):
        self._path = path
        # `path` used to be ignored in favor of the placeholder below.
        root = path if path is not None else 'your dataset path'
        self._re = FaceForensicsDataset(root,
                        in_channels=in_channels,
                        mode="train",
                        augmentations=False,
                        seed=42,
                        crop_size=(256, 256),
                        reshape_size=(256, 256))
        self.vid_length = vid_length

        if len(self._re) == 0:
            raise IOError(f'No video clips found under {root!r}')

        name = 'ffs'
        # The first clip's shape defines the per-sample shape: [N, T, C, H, W].
        raw_shape = [len(self._re)] + list(self._load_raw_image(0).shape)

        if resolution is not None and (raw_shape[3] != resolution or raw_shape[4] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    def close(self):
        pass

    def __getstate__(self):
        return dict(super().__getstate__())

    def _video_from_sample(self, sample):
        """Convert a wrapped-dataset sample's CTHW frames to TCHW, truncated."""
        return sample['frames'].transpose(1, 0)[:self.vid_length]  # CTHW => TCHW

    def _labels_from_sample(self, sample):
        """Extract the (truncated) label tensor from a wrapped-dataset sample."""
        return sample['label'][:self.vid_length]

    def __getitem__(self, idx):
        """Return ``(video, labels)`` for ``idx``, decoding the clip only once.

        The base-class ``__getitem__`` calls ``_load_raw_image`` and
        ``_load_raw_labels`` separately, and each of those reads and transforms
        every frame of the clip, so frames were decoded twice per sample.
        """
        sample = self._re[int(self._raw_idx[idx])]
        return self._video_from_sample(sample), self._labels_from_sample(sample)

    def _load_raw_image(self, raw_idx):
        return self._video_from_sample(self._re[raw_idx])

    def _load_raw_labels(self, raw_idx):
        return self._labels_from_sample(self._re[raw_idx])  # [seq_len, c_dim]
    