import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as TT
from accelerate.logging import get_logger
from torch.utils.data import Dataset, Sampler
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
import os
from PIL import Image
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord  # isort:skip
from torchvision.io import read_image, write_png

decord.bridge.set_bridge("torch")

logger = get_logger(__name__)

# Default candidate sizes used by VideoDatasetWithResizingEvent when no buckets are passed.
# Height and width share the same grid of values; each gets its own list object.
_SPATIAL_BUCKETS = (256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536)
HEIGHT_BUCKETS = list(_SPATIAL_BUCKETS)
WIDTH_BUCKETS = list(_SPATIAL_BUCKETS)
# Default candidate clip lengths (in frames).
FRAME_BUCKETS = [16, 24, 32, 48, 64, 80]

# from bs_ergb
def convert_and_fix_event_pixels(data, upper_limit, fix_overflows=True):
    """Turn raw event coordinates into integer pixel indices in ``[0, upper_limit]``.

    The input is cast to int32. When ``fix_overflows`` is set, every value greater than
    ``upper_limit * 32`` is shifted down by 65536 (presumably undoing a 16-bit wrap-around).
    The values are then divided by 32, rounded to the nearest integer (ties to even),
    cast to int16 and clipped to ``[0, upper_limit]``.

    Args:
        data: NumPy array of raw coordinates.
        upper_limit: largest valid pixel index.
        fix_overflows: whether to apply the 65536 correction.

    Returns:
        int16 NumPy array with the same shape as ``data``.
    """
    coords = data.astype(np.int32)
    if fix_overflows:
        wrapped = coords > upper_limit * 32
        coords = np.where(wrapped, coords - 65536, coords)
    pixels = np.rint(coords / 32.0).astype(np.int16)
    return np.clip(pixels, 0, upper_limit)

class VideoDatasetWithResizingEvent(Dataset):
    """Dataset of (normal, low-light, event) video triplets drawn from several datasets.

    ``data_root`` is a ``+``-separated list of directories; every sub-directory of a root
    is one sequence. Roots are balanced by repeating the sequence lists of smaller roots
    up to the size of the largest one. Each sample is loaded by a dataset-specific
    ``_preprocess_video*`` method chosen from substrings of the sequence path, resized to
    the nearest bucket resolution and mapped from [0, 255] to [-1, 1].
    """

    # Only these MVSEC sequences are kept.
    _MVSEC_SEQUENCES = frozenset(
        {
            'indoor_flying1_data',
            'indoor_flying2_data',
            'indoor_flying3_data',
            'indoor_flying4_data',
            'outdoor_day1_data',
            'outdoor_day2_data',
        }
    )

    def __init__(
        self,
        data_root: str,
        max_num_frames: int = 49,
        voxel_grid_channel: int = 3,
        id_token: Optional[str] = None,
        height_buckets: Optional[List[int]] = None,
        width_buckets: Optional[List[int]] = None,
        frame_buckets: Optional[List[int]] = None,
        load_tensors: bool = False,
        random_flip: Optional[float] = None,
        image_to_video: bool = False,
        eval: bool = False,
        brightness_transform: bool = True
    ) -> None:
        """Scan the roots in ``data_root`` and build the balanced list of sequence paths.

        Args:
            data_root: one or more directories joined with ``+``.
            max_num_frames: upper bound used when choosing a training frame bucket; BS-ERGB
                sequences with fewer directory entries than this are skipped.
            voxel_grid_channel: number of temporal bins of the event voxel grid.
            id_token: optional prefix prepended to the prompt.
            height_buckets / width_buckets / frame_buckets: candidate sizes (module defaults
                when not given).
            load_tensors, image_to_video: stored only; not read by this class.
            random_flip: horizontal flip probability (disabled when falsy).
            eval: use every frame and the evaluation low-light inputs.
            brightness_transform: use ``ColorJitter`` to synthesize low-light frames.
        """
        super().__init__()

        self.data_root = data_root.split("+")
        self.max_num_frames = max_num_frames
        self.id_token = id_token or ""
        self.height_buckets = height_buckets or HEIGHT_BUCKETS
        self.width_buckets = width_buckets or WIDTH_BUCKETS
        self.frame_buckets = frame_buckets or FRAME_BUCKETS
        self.load_tensors = load_tensors
        self.random_flip = random_flip
        self.image_to_video = image_to_video
        self.voxel_grid_channel = voxel_grid_channel
        self.resolutions = [
            (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets
        ]
        self.eval = eval
        self.video_transforms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(random_flip)
                if random_flip
                else transforms.Lambda(self.identity_transform),
                transforms.Lambda(self.scale_transform),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        if brightness_transform:
            brightness_factor = [0, 1.5]
            self.brightness_transform = transforms.ColorJitter(brightness=brightness_factor)
        else:
            self.brightness_transform = transforms.Lambda(self.identity_transform)

        seq_list_all = []
        for root in self.data_root:
            seq_list = [
                os.path.join(root, seq)
                for seq in sorted(os.listdir(root))
                if os.path.isdir(os.path.join(root, seq))
            ]
            if "BS-ERGB" in root:
                # filter out the sequences with fewer than self.max_num_frames entries
                seq_list = [seq for seq in seq_list if len(os.listdir(seq)) >= self.max_num_frames]
            if "MVSEC" in root:
                seq_list = [seq for seq in seq_list if os.path.basename(seq) in self._MVSEC_SEQUENCES]
            seq_list_all.append(seq_list)
            print(f"Video Dataset: {root}, {len(seq_list)}")

        # Balance the roots by repeating each sequence list up to the size of the largest one.
        max_length = max((len(seq_list) for seq_list in seq_list_all), default=0)
        video_paths = []
        for seq_list in seq_list_all:
            if not seq_list:
                # An empty root contributes nothing (and would otherwise divide by zero).
                continue
            ratio = max_length / len(seq_list)
            full_repeats = int(ratio)
            video_paths.extend(seq_list * full_repeats)
            # Add the remaining fraction of the sequence list.
            remaining = int(round(len(seq_list) * (ratio - full_repeats)))
            video_paths.extend(seq_list[:remaining])
        self.video_paths = [Path(seq) for seq in video_paths]
        print(f"Total number of videos: {len(self.video_paths)}")

    def _generate_voxel_grid(self, event, height, width):
        """Accumulate events into a ``[voxel_grid_channel, height, width]`` float32 grid.

        ``event`` is an ``[N, 4]`` tensor with columns (timestamp, x, y, polarity), N >= 1.
        Timestamps are mapped linearly from [first row, last row] onto the temporal bins
        (this assumes rows are sorted by time -- TODO confirm for every dataset). Out-of-range
        bins are clamped. Polarity 0 contributes -1; any other value is added as-is.
        The input tensor is not modified.
        """
        num_bins = self.voxel_grid_channel
        # float64 keeps precision for large absolute timestamps.
        timestamps = event[:, 0].to(torch.float64)
        t_start = timestamps[0]
        duration = timestamps[-1] - t_start
        if duration > 0:
            ch = ((timestamps - t_start) / duration * num_bins).long()
        else:
            # All events share one timestamp: use the first bin instead of dividing by zero.
            ch = torch.zeros(timestamps.shape, dtype=torch.long)
        ch.clamp_(0, num_bins - 1)
        ex = event[:, 1].long()
        ey = event[:, 2].long()
        # copy=True so the polarity remapping below never writes into the caller's tensor.
        ep = event[:, 3].to(torch.float32, copy=True)
        ep[ep == 0] = -1

        voxel_grid = torch.zeros((num_bins, height, width), dtype=torch.float32)
        voxel_grid.index_put_((ch, ey, ex), ep, accumulate=True)
        return voxel_grid

    def _find_nearest_resolution(self, height, width):
        """Return the ``(height, width)`` bucket with the smallest L1 distance to the input size."""
        nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
        return nearest_res[1], nearest_res[2]

    def convert_event_to_rgb(self, events, height, width):
        """Render an ``[N, 4]`` event tensor as a ``[3, height, width]`` image in [0, 255].

        The voxel grid is standardized with its own mean and std, clipped to [-3, 3] and
        mapped linearly onto [0, 255], then rounded. An empty event tensor yields a constant
        127.5 image; a grid with zero spread (e.g. events cancelling out) yields a constant
        image (128 after rounding) instead of NaN.
        """
        if events.shape[0] == 0:
            return torch.full((3, height, width), 127.5)
        voxel_grid = self._generate_voxel_grid(events, height, width)
        assert voxel_grid.shape[0] == 3, "voxel_grid should be [3, H, W]"
        std = voxel_grid.std()
        voxel_grid = voxel_grid - voxel_grid.mean()
        if std > 0:
            voxel_grid = voxel_grid / std
        voxel_grid = torch.clamp(voxel_grid, min=-3, max=3)
        return ((voxel_grid + 3) / 6 * 255).round()

    def _sample_frame_indices(self, video_num_frames: int) -> List[int]:
        """Pick a random contiguous window whose length is one of ``frame_buckets``.

        Among the buckets that fit in the video, the one closest to
        ``min(video_num_frames, max_num_frames)`` is used, so the window never runs past
        the last frame.

        Raises:
            ValueError: if every frame bucket is longer than ``video_num_frames``.
        """
        target = min(video_num_frames, self.max_num_frames)
        fitting = [bucket for bucket in self.frame_buckets if bucket <= video_num_frames]
        if not fitting:
            raise ValueError(
                f"Video has {video_num_frames} frames, fewer than the smallest frame bucket "
                f"{min(self.frame_buckets)}"
            )
        nearest_frame_bucket = min(fitting, key=lambda bucket: abs(bucket - target))
        start_idx = random.randint(0, video_num_frames - nearest_frame_bucket)
        return list(range(start_idx, start_idx + nearest_frame_bucket))

    def _synthesize_low_frame(self, frame_normal: torch.Tensor) -> torch.Tensor:
        """Apply ``brightness_transform`` to a copy of a frame; with probability 0.3 return zeros instead."""
        frame_low = self.brightness_transform(frame_normal.clone())
        if random.random() < 0.3:
            frame_low = torch.zeros_like(frame_low)
        return frame_low

    def _resize_and_transform(self, *videos: torch.Tensor) -> List[torch.Tensor]:
        """Resize ``[F, C, H, W]`` videos to the bucket nearest the first one, then apply ``video_transforms``.

        Each video is transformed as a whole so a random flip treats all of its frames
        alike. torch's global RNG state is restored before every video so all videos of one
        sample receive the same flip decision; ``RandomHorizontalFlip`` samples with
        ``torch.rand``, so reseeding Python's ``random`` would not synchronize it.
        """
        nearest_res = list(self._find_nearest_resolution(videos[0].shape[2], videos[0].shape[3]))
        rng_state = torch.get_rng_state()
        transformed = []
        for video in videos:
            torch.set_rng_state(rng_state)
            transformed.append(self.video_transforms(resize(video, nearest_res)))
        return transformed

    @staticmethod
    def _load_sde_events(npz_path: Path) -> torch.Tensor:
        """Load the ``arr_0`` event array of an SDE ``.npz`` file as a tensor.

        A 1-D (structured) array is converted by stacking its ``timestamp``, ``x``, ``y``
        and ``polarity`` fields column-wise; any other array is used as-is (presumably
        already ``[N, 4]``). Load errors are reported with the file path and re-raised.
        """
        try:
            with np.load(npz_path) as npz:
                event_input = npz["arr_0"]
            if event_input.ndim == 1:
                event_input = np.stack(
                    [event_input["timestamp"], event_input["x"], event_input["y"], event_input["polarity"]],
                    axis=1,
                )
        except Exception:
            print(f"loading event error @: {npz_path}")
            raise
        return torch.from_numpy(event_input)

    def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load an SDE sequence: ``normal/*.png`` targets, low-light PNGs and ``*.npz`` events.

        In eval mode the low-light frames and events come from ``evaluation_extreme_005-135``
        and every frame is used; otherwise they come from ``low`` and a random bucket-sized
        window is sampled. During training, 40% of low-light frames are instead synthesized
        from the normal frame.

        Returns:
            ``(image, video_normal, video_low, video_event)``; ``image`` is the first normal frame.
        """
        low_dir = path / ("evaluation_extreme_005-135" if self.eval else "low")
        normal_frames = sorted((path / "normal").glob('*.png'))
        low_frames = sorted(low_dir.glob('*.png'))
        event_files = sorted(low_dir.glob('*.npz'))
        video_num_frames = len(normal_frames)
        if self.eval:
            frame_indices = list(range(video_num_frames))
        else:
            frame_indices = self._sample_frame_indices(video_num_frames)

        video_normal, video_low, video_event = [], [], []
        for i in frame_indices:
            frame_normal = read_image(str(normal_frames[i]))
            video_normal.append(frame_normal)

            # random brightness change
            if self.eval or random.random() < 0.6:
                frame_low = read_image(str(low_frames[i]))
            else:
                frame_low = self._synthesize_low_frame(frame_normal)
            video_low.append(frame_low)

            event_input = self._load_sde_events(event_files[i])
            video_event.append(self.convert_event_to_rgb(event_input, frame_normal.shape[1], frame_normal.shape[2]))

        video_normal, video_low, video_event = self._resize_and_transform(
            torch.stack(video_normal), torch.stack(video_low), torch.stack(video_event)
        )
        image = video_normal[:1].clone()
        return image, video_normal, video_low, video_event

    def _preprocess_video_SDSD(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load an SDSD sequence: ``GT/*.png`` targets and low-light PNGs; events are all zeros.

        Low-light frames come from ``evaluation_extreme_005-135`` in eval mode and from
        ``input`` otherwise. Frame selection and low-light synthesis follow ``_preprocess_video``.
        """
        low_dir = path / ("evaluation_extreme_005-135" if self.eval else "input")
        normal_frames = sorted((path / "GT").glob('*.png'))
        low_frames = sorted(low_dir.glob('*.png'))
        video_num_frames = len(normal_frames)
        if self.eval:
            frame_indices = list(range(video_num_frames))
        else:
            frame_indices = self._sample_frame_indices(video_num_frames)

        video_normal, video_low = [], []
        for i in frame_indices:
            frame_normal = read_image(str(normal_frames[i]))
            video_normal.append(frame_normal)

            # random brightness change
            if self.eval or random.random() < 0.6:
                frame_low = read_image(str(low_frames[i]))
            else:
                frame_low = self._synthesize_low_frame(frame_normal)
            video_low.append(frame_low)

        video_normal, video_low = self._resize_and_transform(torch.stack(video_normal), torch.stack(video_low))
        image = video_normal[:1].clone()

        # dummy event
        video_event = torch.zeros_like(video_low)
        return image, video_normal, video_low, video_event

    def _preprocess_video_COCO(self, path: Path, width=240, height=180) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load a ``frames/*.png`` + ``events/*.npz`` sequence (COCO, ECD, MVSEC, HQF).

        Single-channel frames are repeated to 3 channels. The low-light input is all zeros
        in eval mode and synthesized from the normal frame otherwise. Events (fields ``t``,
        ``x``, ``y``, ``p``) are rendered at ``height`` x ``width``. In eval mode the frame
        ids come from ``selected_ids.txt`` (last entry dropped) when present.
        """
        normal_frames = sorted((path / "frames").glob('*.png'))
        event_files = sorted((path / "events").glob('*.npz'))
        # NOTE(review): the last frame is never sampled; presumably one event file per frame interval.
        video_num_frames = len(normal_frames) - 1
        id_path = path / "selected_ids.txt"
        if not self.eval:
            frame_indices = self._sample_frame_indices(video_num_frames)
        elif id_path.exists():
            with open(id_path, 'r') as id_file:
                frame_indices = [int(line.strip()) for line in id_file.readlines()]
            frame_indices = frame_indices[:-1]
        else:
            frame_indices = list(range(video_num_frames))

        video_normal, video_low, video_event = [], [], []
        for i in frame_indices:
            frame_normal = read_image(str(normal_frames[i]))
            if frame_normal.shape[0] == 1:
                frame_normal = frame_normal.repeat(3, 1, 1)
            video_normal.append(frame_normal)

            if self.eval:
                # eval feeds an all-black low-light input
                frame_low = torch.zeros_like(frame_normal)
            else:
                frame_low = self._synthesize_low_frame(frame_normal)
            video_low.append(frame_low)

            with np.load(event_files[i]) as npz:
                events = np.stack((npz["t"], npz['x'], npz['y'], npz['p']), axis=1)
            video_event.append(self.convert_event_to_rgb(torch.from_numpy(events), height, width))

        video_normal, video_low, video_event = self._resize_and_transform(
            torch.stack(video_normal), torch.stack(video_low), torch.stack(video_event)
        )
        image = video_normal[:1].clone()
        return image, video_normal, video_low, video_event

    def _preprocess_video_BS(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load a BS-ERGB clip: frames from the matching ``clip_videos/*.mp4``, events from ``*.npz``.

        A random bucket-sized window is always sampled. Low-light frames are synthesized from
        the normal frames. Event timestamps are divided by 1e6 and x/y are converted with
        ``convert_and_fix_event_pixels`` against the video size.
        """
        video_file = str(path).replace("clip_events", "clip_videos") + '.mp4'
        event_files = sorted(path.glob('*.npz'))
        video_reader = decord.VideoReader(video_file)
        video_num_frames = len(video_reader) - 1
        height, width = video_reader[0].shape[:2]
        frame_indices = self._sample_frame_indices(video_num_frames)

        video_normal, video_low, video_event = [], [], []
        for i in frame_indices:
            frame_normal = video_reader[i].permute(2, 0, 1)  # [H, W, 3] -> [3, H, W]
            video_normal.append(frame_normal)
            video_low.append(self._synthesize_low_frame(frame_normal))

            with np.load(event_files[i]) as npz:
                timestamp = npz["timestamp"] / 1e6
                x = convert_and_fix_event_pixels(npz['x'], width - 1)
                y = convert_and_fix_event_pixels(npz['y'], height - 1)
                p = npz['polarity']
            events = torch.from_numpy(np.stack((timestamp, x, y, p), axis=1))
            video_event.append(self.convert_event_to_rgb(events, height, width))

        video_normal, video_low, video_event = self._resize_and_transform(
            torch.stack(video_normal), torch.stack(video_low), torch.stack(video_event)
        )
        image = video_normal[:1].clone()
        return image, video_normal, video_low, video_event

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Load one sample, dispatching on substrings of its sequence path.

        Returns a dict with the prompt, ``image`` (first normal frame), the normal,
        low-light and event videos as ``[F, C, H, W]`` tensors, their size metadata and a
        ``seq_name`` (empty for COCO and BS-ERGB samples).

        Raises:
            ValueError: if the path matches no known dataset.
        """
        # A list (such as a batch yielded by BucketSampler) is passed through unchanged.
        if isinstance(index, list):
            return index
        path = self.video_paths[index]
        posix_path = path.as_posix()
        clip_name = posix_path.split('/')[-1]
        seq_name = ""
        if 'SDE/event_out_' in posix_path:
            ## SDE Dataset
            image, video_normal, video_low, video_event = self._preprocess_video(path)
            seq_name = 'SDE_out_extreme_005-135/' + clip_name
        elif 'SDE/event_in_' in posix_path:
            ## SDE Dataset
            image, video_normal, video_low, video_event = self._preprocess_video(path)
            seq_name = 'SDE_in_extreme_005-135/' + clip_name
        elif 'sdsd' in posix_path:
            ## SDSD Dataset
            image, video_normal, video_low, video_event = self._preprocess_video_SDSD(path)
            seq_name = 'SDSD/' + clip_name
        elif 'COCO' in posix_path:
            ## COCO Dataset
            try:
                image, video_normal, video_low, video_event = self._preprocess_video_COCO(path)
            except Exception:
                print(f"Error in COCO Dataset: {path}, index: {index}")
                # NOTE(review): this fallback only works if video_paths[0] is an SDE sequence.
                image, video_normal, video_low, video_event = self._preprocess_video(self.video_paths[0])
        # only for testing
        elif 'BS-ERGB' in posix_path:
            ## BS-ERGB Dataset
            try:
                image, video_normal, video_low, video_event = self._preprocess_video_BS(path)
            except Exception:
                print(f"Error in BS-ERGB Dataset: {path}, index: {index}")
                # NOTE(review): this fallback only works if video_paths[0] is an SDE sequence.
                image, video_normal, video_low, video_event = self._preprocess_video(self.video_paths[0])
        elif 'ECD' in posix_path:
            ## ECD Dataset
            image, video_normal, video_low, video_event = self._preprocess_video_COCO(path)
            seq_name = 'ECD/' + clip_name
        elif 'MVSEC' in posix_path:
            ## MVSEC Dataset
            image, video_normal, video_low, video_event = self._preprocess_video_COCO(path, width=346, height=260)
            seq_name = 'MVSEC/' + clip_name
        elif 'HQF' in posix_path:
            ## HQF Dataset
            image, video_normal, video_low, video_event = self._preprocess_video_COCO(path)
            seq_name = 'HQF/' + clip_name
        else:
            raise ValueError(f"Unknown dataset: {path}")

        return {
            "prompt": self.id_token + "high quality video",
            "image": image,
            "video_normal": video_normal,
            "video_low": video_low,
            "video_event": video_event,
            "video_metadata": {
                "num_frames": video_normal.shape[0],
                "height": video_normal.shape[2],
                "width": video_normal.shape[3],
            },
            "seq_name": seq_name
        }

    @staticmethod
    def identity_transform(x):
        """Return ``x`` unchanged."""
        return x

    @staticmethod
    def scale_transform(x):
        """Scale [0, 255] pixel values to [0, 1]."""
        return x / 255.0

    def __len__(self) -> int:
        """Number of (balanced, possibly repeated) sequence paths."""
        return len(self.video_paths)


class BucketSampler(Sampler):
    r"""
    PyTorch Sampler that groups 3D data by height, width and frames.

    Samples are read from ``data_source`` and collected per
    ``(num_frames, height, width)`` key taken from their ``video_metadata``; a bucket is
    yielded (as a list of samples) as soon as it holds ``batch_size`` entries.

    Args:
        data_source (`VideoDataset`):
            A PyTorch dataset object that is an instance of `VideoDataset`.
        batch_size (`int`, defaults to `8`):
            The batch size to use for training.
        shuffle (`bool`, defaults to `True`):
            Whether or not to shuffle the data in each batch before dispatching to dataloader.
        drop_last (`bool`, defaults to `False`):
            Whether or not to drop incomplete buckets of data after completely iterating over all data
            in the dataset. If set to True, only batches that have `batch_size` number of entries will
            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
            and batches that do not have `batch_size` number of entries will also be yielded.
    """

    def __init__(
        self, data_source, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
    ) -> None:
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.buckets = {resolution: [] for resolution in data_source.resolutions}

        self._raised_warning_for_drop_last = False

    def __len__(self):
        """Upper-bound estimate of the number of batches (exact only without bucketing effects)."""
        if self.drop_last and not self._raised_warning_for_drop_last:
            self._raised_warning_for_drop_last = True
            logger.warning(
                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
            )
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        """Yield lists of samples that share the same ``(num_frames, height, width)``.

        Incomplete buckets are yielded at the end unless ``drop_last`` is set; in both
        cases they are emptied, so nothing carries over into the next iteration.
        """
        for data in self.data_source:
            video_metadata = data["video_metadata"]
            key = (video_metadata["num_frames"], video_metadata["height"], video_metadata["width"])

            # setdefault: a sample's size may not be one of data_source.resolutions.
            bucket = self.buckets.setdefault(key, [])
            bucket.append(data)
            if len(bucket) == self.batch_size:
                if self.shuffle:
                    random.shuffle(bucket)
                self.buckets[key] = []
                yield bucket

        for key, bucket in list(self.buckets.items()):
            if not bucket:
                continue
            self.buckets[key] = []
            if self.drop_last:
                continue
            if self.shuffle:
                random.shuffle(bucket)
            yield bucket
