import numpy as np
import torch

try:
    import pyspng
except ImportError:
    pyspng = None

# from lrw_dataset import LRWDataset
from torch.utils.data import DataLoader

import os
import random

import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class RealEstateDataset(Dataset):
    """Fixed-length frame clips sampled from directories of PNG video frames.

    Directory layout read by the ``build_file_list*`` methods, for every root
    directory:

    * ``<root>/data/<mode>/<clip>/*.png`` -- the frames of one clip; frame
      order is the sorted order of the file names.
    * ``<root>/annotation/<clip>.txt`` -- only read by ``build_file_list`` and
      ``build_file_list_augment``. The first line is skipped (the original
      code comment calls it the URL); every following line is
      ``<timestamp> <float> <float> ...`` and there must be exactly one such
      line per frame.

    ``__init__`` indexes the data with ``build_file_list_simple`` (frames
    only, no annotations), so ``__getitem__`` returns an empty label tensor.
    """

    def __init__(self,
                 root_dir,
                 in_channels=3,
                 mode="train",
                 seed=42,
                 augmentations=False,
                 sample_size=8,
                 frame_interval=8,
                 crop_size=(256, 256),
                 reshape_size=(256, 256)):
        """Index all clips under ``root_dir`` and shuffle them with ``seed``.

        Args:
            root_dir: A single root directory (str) or a collection of them.
            in_channels: 1 (grayscale) or 3 (RGB) output channels.
            mode: Sub-directory of ``<root>/data`` to read, e.g. ``"train"``.
            seed: Seed used to shuffle the sample index.
            augmentations: Enables the reversed-clip samples in
                ``build_file_list`` / ``build_file_list_augment``. The image
                transform currently applies no augmentation either way.
            sample_size: Number of frames per sample.
            frame_interval: Stride (in frames) between consecutive frames of
                a sample.
            crop_size: Size passed to ``transforms.CenterCrop``.
            reshape_size: Size passed to ``transforms.Resize``.

        Raises:
            ValueError: If ``in_channels`` is neither 1 nor 3.
        """
        self.root_dirs = [root_dir] if isinstance(root_dir, str) else root_dir
        self.mode = mode
        self.in_channels = in_channels
        self.seed = seed
        self.augmentation = augmentations
        self.sample_size = sample_size
        self.frame_interval = frame_interval
        self.crop_size = crop_size
        self.reshape_size = reshape_size
        self.data_info = self.build_file_list_simple()
        self.transform = self.build_transform()

        # NOTE: this reseeds Python's *global* ``random`` state as a side
        # effect. Kept as-is because surrounding code may rely on it.
        random.seed(self.seed)
        random.shuffle(self.data_info)

    # ------------------------------------------------------------------
    # Private helpers shared by the three index builders.
    # ------------------------------------------------------------------

    def _iter_clips(self):
        """Yield ``(root_dir, clip_name, frame_paths)`` for each usable clip.

        Clip directories are visited in sorted order so that the index (and
        therefore the seeded shuffle) does not depend on the arbitrary order
        in which ``os.listdir`` returns entries. Clips with fewer than
        ``sample_size`` PNG frames are skipped.
        """
        for root_dir in self.root_dirs:
            data_dir = os.path.join(root_dir, "data", self.mode)
            for subdir in sorted(os.listdir(data_dir)):
                subdir_path = os.path.join(data_dir, subdir)
                if not os.path.isdir(subdir_path):
                    continue

                frame_paths = [
                    os.path.join(subdir_path, name)
                    for name in sorted(os.listdir(subdir_path))
                    if name.endswith(".png")
                ]
                if len(frame_paths) < self.sample_size:
                    continue

                yield root_dir, subdir, frame_paths

    def _load_annotations(self, root_dir, subdir, total_frames):
        """Return the per-frame annotation rows of a clip, or ``None``.

        ``None`` is returned (after printing a message) when the annotation
        file is missing or its row count differs from ``total_frames``.
        """
        annotation_file = os.path.join(root_dir, "annotation", f"{subdir}.txt")
        if not os.path.exists(annotation_file):
            print(f"Annotation file not found for {subdir}")
            return None

        with open(annotation_file, "r") as f:
            rows = [line.split() for line in f]
        rows = rows[1:]  # the first line is not a frame annotation (url)

        if len(rows) != total_frames:
            print(f"Mismatch between frames and annotations for {subdir}")
            return None
        return rows

    def _window_indices(self, start_frame, step):
        """Frame indices of one sample starting at ``start_frame``."""
        return [start_frame + i * step for i in range(self.sample_size)]

    def _make_sample(self, root_dir, subdir, start_frame, frame_paths,
                     indices, annotations=None, suffix=""):
        """Assemble one ``data_info`` entry for the frames at ``indices``.

        ``camera_params`` / ``timestamps`` are only included when
        ``annotations`` is given.
        """
        picked = [frame_paths[i] for i in indices]
        sample = {
            # ``[-2:-1]`` is the second-to-last character of the root path,
            # exactly like ``root_dir[-2]``, but yields "" instead of raising
            # for one-character roots such as ".".
            "id": root_dir[-2:-1] + '_' + subdir + '_' + str(start_frame) + suffix,
            "frame_paths": picked,
        }
        if annotations is not None:
            rows = [annotations[i] for i in indices]
            # Everything after the timestamp, as floats. The original code
            # comment describes this as 18 columns; not verified here.
            sample["camera_params"] = torch.tensor(
                [list(map(float, row[1:])) for row in rows],
                dtype=torch.float32,
            )
            sample["timestamps"] = [int(row[0]) for row in rows]  # ints
        sample["filenames"] = [os.path.basename(p) for p in picked]
        return sample

    # ------------------------------------------------------------------
    # Index builders.
    # ------------------------------------------------------------------

    def build_file_list(self):
        """Index annotated samples using forward windows with a start stride of 3.

        When ``self.augmentation`` is set, each sample is followed by the
        same window in reversed frame order (id suffix ``_reversed``).

        Returns:
            list[dict]: entries with keys ``id``, ``frame_paths``,
            ``camera_params``, ``timestamps`` and ``filenames``.
        """
        data_info = []
        # The bound uses sample_size * frame_interval, which is slightly
        # stricter than the frames a window actually touches; kept so the
        # set of generated samples stays unchanged.
        window = self.sample_size * self.frame_interval

        for root_dir, subdir, frame_paths in self._iter_clips():
            total_frames = len(frame_paths)
            annotations = self._load_annotations(root_dir, subdir, total_frames)
            if annotations is None:
                continue

            for start_frame in range(0, total_frames - window + 1, 3):
                indices = self._window_indices(start_frame, self.frame_interval)
                data_info.append(self._make_sample(
                    root_dir, subdir, start_frame, frame_paths, indices,
                    annotations=annotations))

                if self.augmentation:
                    data_info.append(self._make_sample(
                        root_dir, subdir, start_frame, frame_paths,
                        indices[::-1], annotations=annotations,
                        suffix='_reversed'))

        return data_info

    def build_file_list_simple(self):
        """Index samples from frames only (no annotation files are read).

        Returns:
            list[dict]: entries with keys ``id``, ``frame_paths`` and
            ``filenames``.
        """
        data_info = []
        window = self.sample_size * self.frame_interval

        for root_dir, subdir, frame_paths in self._iter_clips():
            for start_frame in range(0, len(frame_paths) - window + 1, 3):
                indices = self._window_indices(start_frame, self.frame_interval)
                data_info.append(self._make_sample(
                    root_dir, subdir, start_frame, frame_paths, indices))

        return data_info

    def build_file_list_augment(self):
        """Index annotated samples around anchor frames (start stride 6).

        Every anchor produces a forward window (``anchor + i * interval``)
        and, when ``self.augmentation`` is set, a backward window
        (``anchor - i * interval``, id suffix ``_reversed``).

        Returns:
            list[dict]: entries with keys ``id``, ``frame_paths``,
            ``camera_params``, ``timestamps`` and ``filenames``.
        """
        data_info = []
        # Number of frames between the first and last frame of a window.
        # Anchors must keep this distance from both ends of the clip;
        # otherwise forward windows run past the last frame (IndexError) and
        # backward windows silently wrap around through negative indices.
        # For frame_interval == 1 this is the original range.
        span = (self.sample_size - 1) * self.frame_interval

        for root_dir, subdir, frame_paths in self._iter_clips():
            total_frames = len(frame_paths)
            annotations = self._load_annotations(root_dir, subdir, total_frames)
            if annotations is None:
                continue

            for start_frame in range(span, total_frames - span, 6):
                forward = self._window_indices(start_frame, self.frame_interval)
                data_info.append(self._make_sample(
                    root_dir, subdir, start_frame, frame_paths, forward,
                    annotations=annotations))

                if self.augmentation:
                    backward = self._window_indices(start_frame, -self.frame_interval)
                    data_info.append(self._make_sample(
                        root_dir, subdir, start_frame, frame_paths, backward,
                        annotations=annotations, suffix='_reversed'))

        return data_info

    def find_index_by_name(self, id):
        """Return the position of the sample whose ``"id"`` equals ``id``, or -1."""
        for idx, sample in enumerate(self.data_info):
            if sample["id"] == id:
                return idx
        return -1

    def build_transform(self):
        """Build the per-frame transform: crop, resize, to-tensor, normalize.

        Frames are normalized with mean 0.5 and std 0.5 per channel. For
        ``in_channels == 1`` frames are converted to grayscale first.

        Raises:
            ValueError: If ``self.in_channels`` is neither 1 nor 3.
        """
        if self.in_channels not in (1, 3):
            raise ValueError(
                f"in_channels must be 1 or 3, got {self.in_channels!r}")

        # Placeholder for image-level augmentation: currently empty whether
        # or not ``self.augmentation`` is set (a random horizontal flip used
        # to be here and is disabled).
        augmentations = transforms.Compose([])

        steps = [
            transforms.CenterCrop(self.crop_size),
            transforms.Resize(self.reshape_size),
            augmentations,
        ]
        if self.in_channels == 1:
            steps += [
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
                transforms.Normalize([0.5, ], [0.5, ]),
            ]
        else:
            steps += [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        return transforms.Compose(steps)

    def __len__(self):
        """Number of indexed samples."""
        return len(self.data_info)

    def __getitem__(self, idx):
        """Load and transform the frames of sample ``idx``.

        Returns:
            dict: ``"frames"`` is a tensor of shape (C, T, H, W);
            ``"label"`` is an empty tensor (the default index built in
            ``__init__`` carries no camera parameters).
        """
        data = self.data_info[idx]
        frames = []

        for frame_path in data["frame_paths"]:
            # Context manager guarantees the file handle is released.
            with Image.open(frame_path) as raw:
                image = raw
                # RGBA / palette / grayscale PNGs would not yield 3 channels
                # and would fail in the 3-channel Normalize.
                if self.in_channels == 3 and raw.mode != "RGB":
                    image = raw.convert("RGB")
                frames.append(self.transform(image))

        frames = torch.stack(frames)  # (T, C, H, W)
        frames = frames.transpose(1, 0)  # (T, C, H, W) -> (C, T, H, W)

        # null label tensor
        label = torch.tensor([])

        return {
            "label": label,
            "frames": frames,
        }

class Dataset(torch.utils.data.Dataset):
    """Base class for video datasets described by a fixed raw shape.

    Subclasses override ``_load_raw_image`` and ``_load_raw_labels`` (and
    optionally ``close``). ``raw_shape`` is the shape of the whole dataset;
    its first entry is the number of items and the remaining entries are the
    per-item shape, which ``num_channels`` / ``resolution`` expect to be
    (T, C, H, W).

    NOTE: this class reuses the name ``Dataset``; from this point of the
    module on, the bare name refers to this class rather than to
    ``torch.utils.data.Dataset``.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw data; first entry is the item count.
        max_size    = None,     # Size limit. None = no limit. Not supported: raises if it would truncate.
        use_labels  = False,    # Stored in ``_use_labels``; not consulted by this class.
        xflip       = False,    # x-flip doubling. Not supported: raises if enabled.
        random_seed = 0,        # Unused while max_size subsampling is disabled.
    ):
        """Record the dataset description.

        Raises:
            NotImplementedError: If ``max_size`` is smaller than the dataset
                or ``xflip`` is true. Both features are disabled here;
                ``__getitem__`` does not go through ``_raw_idx`` / ``_xflip``.
        """
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None

        # Identity index over all items.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            # Previously guarded by ``assert(0)``, which is stripped under
            # ``python -O`` and would then silently truncate without
            # shuffling. Fail explicitly instead.
            raise NotImplementedError(
                'max_size subsampling is disabled for this dataset')

        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            # Same reasoning as above: without the assert the length would
            # double while __getitem__ cannot serve the extra indices.
            raise NotImplementedError(
                'xflip augmentation is disabled for this dataset')

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        """Return a copy of the instance dict with ``_raw_labels`` reset to None."""
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: never let an error escape from a finalizer,
        # but do not swallow KeyboardInterrupt / SystemExit either.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` (passed to the loaders unchanged)."""
        image = self._load_raw_image(idx)
        label = self._load_raw_labels(idx)
        return image, label

    def get_label(self, idx):
        """Return only the label for ``idx``."""
        label = self._load_raw_labels(idx)
        return label

    def get_details(self, idx):
        """Return extra per-item details; always an empty dict here."""
        return {}

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        """Per-item shape: ``raw_shape`` without its first (item-count) entry."""
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 4 # TCHW
        return self.image_shape[1]

    @property
    def resolution(self):
        assert len(self.image_shape) == 4 # TCHW
        assert self.image_shape[2] == self.image_shape[3]
        return self.image_shape[2]

    @property
    def label_shape(self):
        # NOTE(review): this returns the per-item *image* shape, not a label
        # shape; ``_label_shape`` is never used. Presumably a placeholder --
        # confirm with callers before changing.
        return list(self._raw_shape[1:])

    @property
    def label_dim(self):
        # Hard-coded; not derived from ``label_shape``.
        return 64

    @property
    def has_labels(self):
        """True when any entry of ``label_shape`` is non-zero."""
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return False


class RealEstate(Dataset):
    """Adapter exposing ``RealEstateDataset`` clips through the ``Dataset`` base API.

    Items are ``(video, label)`` pairs where ``video`` has layout
    (T, C, H, W) and ``label`` is whatever the wrapped dataset returns under
    ``"label"``, truncated to ``vid_length`` along its first axis.
    """

    def __init__(self,
                vid_length = 16,
                path = 'your dataset path',            # Root directory (or collection of roots) of the frame data.
                resolution = None,      # Ensure specific resolution, None = accept whatever the data has.
                in_channels = 3,
                **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        """Index the ``train`` split under ``path`` and derive the raw shape.

        Raises:
            IOError: If no clips are found under ``path`` or the frame
                height/width differ from ``resolution``.
        """
        self._path = path
        # Frames are center-cropped and resized to 256x256 by the wrapped
        # dataset.
        # NOTE(review): ``sample_size`` is not forwarded, so the wrapped
        # dataset uses its own default number of frames per clip (8 in this
        # file). ``vid_length`` only truncates, so values above that default
        # do not produce longer videos -- confirm this is intended.
        self._re = RealEstateDataset(self._path,
                        in_channels=in_channels,
                        mode="train",
                        augmentations=False,
                        seed=42,
                        crop_size=(256,256),
                        reshape_size=(256,256))
        self.vid_length = vid_length

        # Probing item 0 on an empty index would fail with an opaque
        # IndexError; report the actual problem instead.
        if len(self._re) == 0:
            raise IOError(f'No video clips found under {path!r}')

        name = 're'
        # [N, T, C, H, W]
        raw_shape = [len(self._re)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[3] != resolution or raw_shape[4] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    def close(self):
        pass

    def __getstate__(self):
        return dict(super().__getstate__())

    def _load_raw_image(self, raw_idx):
        """Return the frames of item ``raw_idx`` as (T, C, H, W), at most ``vid_length`` frames."""
        video = self._re[raw_idx]['frames'] # CTHW
        video = video.transpose(1,0)[:self.vid_length] # CTHW => TCHW
        return video

    def _load_raw_labels(self, raw_idx):
        """Return the label of item ``raw_idx`` truncated to ``vid_length`` entries."""
        labels = self._re[raw_idx]['label'][:self.vid_length]
        return labels # [seq_len, c_dim]
