import numpy as np
import torch

try:
    import pyspng
except ImportError:
    pyspng = None

# from lrw_dataset import LRWDataset
from torch.utils.data import DataLoader

import os
import random

import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import re

class KTHDataset(Dataset):
    """Clip dataset built from KTH-style per-video frame folders.

    ``root_dir`` is walked recursively (in sorted order, so indices are
    reproducible). Every directory whose path contains
    ``person<id>_<action>_d<scene>`` contributes its ``image*.png`` frames,
    sorted by file name. Each valid start position yields one clip of
    ``sample_size`` frames taken every ``frame_interval`` frames.

    Args:
        root_dir: Root of the directory tree to index.
        in_channels: 1 -> grayscale, normalized with mean=std=0.5;
            3 -> RGB, normalized with the ImageNet mean/std.
        mode: Augmentation is only enabled when this is ``"train"``.
        sample_size: Frames per clip.
        frame_interval: Stride (in frames) between consecutive clip frames.
        augmentations: If true (and ``mode == "train"``), flip each clip
            horizontally with probability 0.5. The decision is made once per
            clip so that all frames of a clip are flipped consistently.
        resolution: Stored as ``self.resolution`` (128 when None); frames
            are not resized.
        seed: Stored as ``self.seed``; not otherwise used here.
        query, crop_size, reshape_size: Accepted for call compatibility;
            currently unused.

    Raises:
        ValueError: if ``in_channels`` is neither 1 nor 3.
    """

    # The one-hot label has 100 slots indexed by
    # (person_id - 1) * 4 + (scene_id - 1), i.e. 4 scene ids per person.
    _NUM_SCENES = 4
    _NUM_LABELS = 100

    def __init__(self,
                 root_dir,
                 in_channels=1,
                 mode="train",
                 sample_size=16,
                 frame_interval=8,
                 augmentations=False,
                 resolution=None,
                 seed=42,
                 query=None,
                 crop_size=(112, 112),
                 reshape_size=(112, 112)):
        self.root_dir = root_dir
        self.seed = seed
        # Fall back to 128 only when the caller did not ask for a resolution.
        self.resolution = 128 if resolution is None else resolution
        self.sample_size = sample_size
        self.frame_interval = frame_interval
        self.in_channels = in_channels
        self.mode = mode
        self.augmentation = augmentations if mode == 'train' else False
        self.data_info = self.build_file_list()
        self.transform = self.build_transform()

    def build_file_list(self):
        """Index every clip under ``self.root_dir``.

        Returns:
            A list of dicts with keys ``person_id`` (int), ``action`` (str),
            ``scene_id`` (int) and ``frame_paths`` (``sample_size`` PNG paths).
        """
        data_info = []
        pattern = re.compile(r'person(\d+)_(\w+)_d(\d+)')
        # Offset of the last frame used by a clip relative to its start.
        span = (self.sample_size - 1) * self.frame_interval

        for root, dirs, files in os.walk(self.root_dir):
            # os.walk yields directories in filesystem order; sorting in place
            # makes traversal (and therefore clip indices) deterministic.
            dirs.sort()
            match = pattern.search(root)
            if not match:
                continue
            person_id, action, scene_id = match.groups()
            frame_paths = [
                os.path.join(root, f)
                for f in sorted(files)
                if f.startswith("image") and f.endswith(".png")
            ]
            # A clip starting at s reads frames s, s + interval, ..., s + span,
            # so every s with s + span <= len(frame_paths) - 1 is valid.
            for start in range(len(frame_paths) - span):
                data_info.append({
                    "person_id": int(person_id),
                    "action": action,
                    "scene_id": int(scene_id),
                    "frame_paths": [
                        frame_paths[start + i * self.frame_interval]
                        for i in range(self.sample_size)
                    ],
                })

        return data_info

    def build_transform(self):
        """Return the per-frame PIL image -> normalized tensor transform.

        Random flipping is deliberately not part of this transform: it is
        applied once per clip in ``__getitem__`` so frames stay consistent.

        Raises:
            ValueError: if ``self.in_channels`` is neither 1 nor 3.
        """
        if self.in_channels not in (1, 3):
            raise ValueError(
                f"in_channels must be 1 or 3, got {self.in_channels!r}")
        if self.in_channels == 1:
            return transforms.Compose([
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ])
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        """Load clip ``idx``.

        Returns:
            A dict with ``frames`` (tensor shaped (C, T, H, W)), ``label``
            (one-hot float tensor of length 100), ``person_id`` and
            ``scene_id`` (tensors) and ``action`` (str).
        """
        data = self.data_info[idx]
        frames = []

        for frame_path in data["frame_paths"]:
            # The context manager closes the file handle PIL keeps open.
            with Image.open(frame_path) as image:
                if self.in_channels == 3:
                    # Grayscale / palette / RGBA frames would otherwise not
                    # match the 3-channel Normalize.
                    image = image.convert("RGB")
                frames.append(self.transform(image))

        frames = torch.stack(frames)       # (T, C, H, W)
        frames = frames.transpose(1, 0)    # (C, T, H, W)
        if self.augmentation and torch.rand(1).item() < 0.5:
            # One flip decision for the whole clip (flip along width).
            frames = frames.flip(-1)

        label = torch.zeros(self._NUM_LABELS)
        label[(data["person_id"] - 1) * self._NUM_SCENES + (data["scene_id"] - 1)] = 1
        sample = {
            "frames": frames,
            "label": label,
            "person_id": torch.tensor(data["person_id"]),
            "scene_id": torch.tensor(data["scene_id"]),
            "action": data["action"],
        }
        return sample

class Dataset(torch.utils.data.Dataset):
    """Base class for fixed-shape datasets.

    Subclasses implement ``_load_raw_image`` and ``_load_raw_labels`` and may
    override ``close``. ``raw_shape`` is ``[N, ...]`` with N items.
    ``num_channels`` and ``resolution`` additionally require the per-item
    shape to be (T, C, H, W).

    ``max_size`` smaller than N and ``xflip=True`` are not supported and
    raise ``NotImplementedError``. ``__getitem__`` indexes raw items directly,
    so neither subsetting nor flipping would actually be honored.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw data: [N, T, C, H, W] for video.
        max_size    = None,     # Must be None or >= N; subsetting is not implemented.
        use_labels  = False,    # Stored as self._use_labels; not consulted by this class.
        xflip       = False,    # Must be False; x-flip augmentation is not implemented.
        random_seed = 0,        # Unused; only referenced by the disabled max_size shuffle.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None

        # Unsupported options raise instead of `assert`, so the check is not
        # stripped under `python -O` (which would silently continue).
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if max_size is not None and self._raw_idx.size > max_size:
            # np.random.RandomState(random_seed).shuffle(self._raw_idx)
            raise NotImplementedError(
                f'max_size={max_size} is smaller than the dataset '
                f'({self._raw_idx.size} items); subsetting is not supported')

        # Per-item flip flags; always zero because flipping is unsupported.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            raise NotImplementedError('xflip augmentation is not supported')

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        # Drop cached labels from the pickled state.
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup; a failing close() must not escape a finalizer.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return ``(image, label)`` for raw item ``idx``."""
        image = self._load_raw_image(idx)
        label = self._load_raw_labels(idx)
        return image, label

    def get_label(self, idx):
        label = self._load_raw_labels(idx)
        return label

    def get_details(self, idx):
        return {}

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        """Per-item shape, i.e. ``raw_shape[1:]``."""
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 4 # TCHW
        return self.image_shape[1]

    @property
    def resolution(self):
        assert len(self.image_shape) == 4 # TCHW
        assert self.image_shape[2] == self.image_shape[3]
        return self.image_shape[2]

    @property
    def label_shape(self):
        # NOTE(review): this mirrors image_shape (raw_shape[1:]) rather than
        # the shape of what _load_raw_labels returns, so label_dim is the
        # per-item leading dimension (T for video), not the label size.
        # Kept as-is because callers may size conditioning from it — confirm.
        return list(self._raw_shape[1:])

    @property
    def label_dim(self):
        # assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return False


class KTH(Dataset):
    """KTH clips exposed through the ``Dataset`` base-class interface.

    Wraps a ``KTHDataset`` (train mode, no augmentation) and returns each item
    as ``(video, label)``. ``video`` is a (T, C, H, W) tensor holding at most
    ``vid_length`` frames, and ``label`` is the wrapped dataset's one-hot label.

    Raises:
        ValueError: if ``path`` is not given.
        IOError: if no clips are found under ``path``, or if the frames do
            not match ``resolution``.
    """

    def __init__(self,
                vid_length = 16,        # Max frames per video; clips are truncated to this length.
                path = None,            # Root directory of the frame folders (walked by KTHDataset; zip is not supported).
                resolution = None,      # If given, frames must be exactly resolution x resolution.
                in_channels = 3,        # Passed through to KTHDataset (1 or 3).
                **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        if path is None:
            raise ValueError('KTH requires `path`: the root directory of the extracted frame folders')
        self._path = path
        self._kth = KTHDataset(path,
                        in_channels=in_channels,
                        mode="train",
                        augmentations=False,
                        seed=42,
                        query=None,
                        crop_size=(128,128),
                        reshape_size=(128,128))
        # NOTE(review): KTHDataset is built with its default sample_size (16),
        # so vid_length > 16 still yields 16-frame videos — confirm intent.
        self.vid_length = vid_length
        if len(self._kth) == 0:
            raise IOError(f'No KTH clips found under {path!r}')

        # NOTE(review): no action filter is applied here despite the name;
        # presumably `path` only contains walking clips — confirm.
        name = 'kth_walking'
        raw_shape = [len(self._kth)] + list(self._load_raw_image(0).shape)

        # raw_shape is [N, T, C, H, W]; indices 3 and 4 are H and W.
        if resolution is not None and (raw_shape[3] != resolution or raw_shape[4] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    def close(self):
        pass

    def __getstate__(self):
        return dict(super().__getstate__())

    def _clip_to_video(self, frames):
        # KTHDataset yields (C, T, H, W); return (T, C, H, W) truncated to vid_length.
        return frames.transpose(1, 0)[:self.vid_length]

    def __getitem__(self, idx):
        # Load the clip once. The base implementation calls _load_raw_image
        # and _load_raw_labels separately, which decodes every frame twice.
        sample = self._kth[idx]
        return self._clip_to_video(sample['frames']), sample['label']

    def _load_raw_image(self, raw_idx):
        return self._clip_to_video(self._kth[raw_idx]['frames'])

    def _load_raw_labels(self, raw_idx):
        # Note: this decodes the whole clip just to read its label.
        labels = self._kth[raw_idx]['label']
        return labels