import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader
from PIL import Image

import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from animatediff.utils.util import zero_rank_print


class CustomRandomCrop:
    """Crop a random, uniformly scaled window out of an image tensor.

    A single scale factor is drawn from ``[min_scale, max_scale]`` and applied
    to both height and width, so the crop keeps the input's aspect ratio (up
    to integer truncation). The same window is applied across all leading
    dimensions, e.g. to every frame of a ``(T, C, H, W)`` clip.
    """

    def __init__(self, min_scale=0.9, max_scale=1.0):
        self.min_scale = min_scale
        self.max_scale = max_scale

    def __call__(self, image):
        """Return a random crop of ``image``, a tensor whose last two dims are (H, W).

        The input object itself is returned when the drawn scale leaves the
        spatial size unchanged.
        """
        height, width = image.shape[-2:]

        scale = random.uniform(self.min_scale, self.max_scale)
        new_width = int(scale * width)
        new_height = int(scale * height)

        if width == new_width and height == new_height:
            return image
        # Draw `left` before `top` (same RNG consumption order as before).
        left = random.randint(0, width - new_width)
        top = random.randint(0, height - new_height)
        # For an in-bounds window this is exactly what torchvision's
        # functional.crop does on tensors.
        return image[..., top:top + new_height, left:left + new_width]


class WebVid10M(Dataset):
    """Caption-annotated video-clip dataset read from one or more CSV files.

    Each CSV row must provide ``videoid`` (video file name; ``.mp4`` is
    appended when missing) and ``name`` (the caption). Optional extras:

    * ``mask_folder``: one directory per video holding ``%05d.jpg`` frames
      (1-based numbering), loaded as single-channel masks. Unless
      ``mask_oracle`` is set, only the first sampled frame keeps its mask and
      the others are zeroed.
    * ``add_rel_video``: also return a grayscale clip from a different video,
      chosen to share a caption keyword group with this one when possible.
    * ``add_mix_video``: with this probability, blend one frame of another
      video into the clip (0.75 / 0.25) and append its processed caption.
    """

    # Caption substrings that make a video eligible as a "mix" source.
    _MIX_KEYWORDS = ('run ', 'ran ', 'runs ', 'running ', 'chase ', 'chases ',
                     'chasing ', 'hit', 'hitting')
    # `__getitem__` makes 1 + _MAX_RETRIES attempts before giving up.
    _MAX_RETRIES = 5

    def __init__(
            self,
            csv_path, video_folder,
            sample_size=256, sample_stride=4, sample_n_frames=16,
            is_image=False,
            mask_folder=None, mask_oracle=False,
            add_rel_video=False,
            add_mix_video=0.0,
        ):
        """
        Args:
            csv_path: sequence of annotation CSV paths.
            video_folder: sequence of video directories, parallel to ``csv_path``.
            sample_size: output size passed to ``Resize``; an int means a square.
            sample_stride: frame stride between sampled frames.
            sample_n_frames: number of frames per clip.
            is_image: sample a single frame instead of a clip.
            mask_folder: optional sequence of mask roots, parallel to ``video_folder``.
            mask_oracle: keep the mask of every sampled frame, not just the first.
            add_rel_video: also return a related grayscale clip.
            add_mix_video: probability of blending in a frame of another video.
        """
        assert len(csv_path) == len(video_folder)
        if mask_folder is not None:
            assert len(video_folder) == len(mask_folder)
        self.mask_oracle = mask_oracle
        self.add_rel_video = add_rel_video
        self.add_mix_video = add_mix_video
        self.dataset = []
        for ii, annotation_file in enumerate(csv_path):
            zero_rank_print(f"loading annotations from {annotation_file} ...")
            with open(annotation_file, 'r') as csvfile:
                rows = list(csv.DictReader(csvfile))
            for row in rows:
                row['videopath'] = os.path.join(video_folder[ii], row['videoid'])
                if not row['videopath'].endswith('.mp4'):
                    row['videopath'] += '.mp4'
                if mask_folder is None:
                    row['maskpath'] = None
                else:
                    # Mask frames live in a directory named after the video, without extension.
                    maskpath = os.path.join(mask_folder[ii], row['videoid'])
                    row['maskpath'] = maskpath[:-4] if maskpath.endswith('.mp4') else maskpath
            self.dataset.extend(rows)

        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.sample_stride   = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image        = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.mask_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size),
        ])
        # NOTE(review): the mix branch calls `remove_verbs`, which presumably
        # depended on a spaCy pipeline that used to be loaded here; see get_batch.

    def _sample_clip_indices(self, video_length):
        """Return ``sample_n_frames`` evenly spaced frame indices from a random window.

        The window spans ``(sample_n_frames - 1) * sample_stride + 1`` frames,
        shrunk to the whole video when the video is shorter.
        """
        clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        return np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

    def _sample_other_index(self, idx, keywords=None, max_tries=10000):
        """Randomly draw a dataset index different from ``idx``.

        When ``keywords`` is given, the drawn entry's caption (``name``) must
        contain at least one of them. Raises RuntimeError after ``max_tries``
        rejected draws so an unsatisfiable request cannot hang a loader worker.
        """
        for _ in range(max_tries):
            candidate = random.randrange(self.length)
            if candidate == idx:
                continue
            if keywords is None or any(k in self.dataset[candidate]["name"] for k in keywords):
                return candidate
        raise RuntimeError(f"no suitable partner video found for index {idx} after {max_tries} draws")

    def get_batch(self, idx):
        """Load raw (untransformed) tensors for entry ``idx``.

        Returns ``(pixel_values, name, mask, rel_pixel_values, mix_pixel_values,
        mix_name)``; pixel tensors are scaled to [0, 1], and the optional parts
        are None when the corresponding feature is disabled.
        """
        video_dict = self.dataset[idx]
        video_dir, name, mask_dir = video_dict['videopath'], video_dict['name'], video_dict['maskpath']

        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            batch_index = self._sample_clip_indices(video_length)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        mask = None
        if mask_dir is not None:
            mask_sequence = []
            for i, ind in enumerate(batch_index):
                # Mask files are numbered from 1, frame indices from 0.
                with Image.open(os.path.join(mask_dir, "%05d.jpg" % (ind + 1))) as mask_image:
                    frame_mask = transforms.ToTensor()(mask_image.convert("L"))
                if not self.mask_oracle and i > 0:
                    frame_mask = torch.zeros_like(frame_mask)
                mask_sequence.append(frame_mask)
            mask = torch.stack(mask_sequence, dim=0)

        rel_pixel_values = None
        if self.add_rel_video:
            # Prefer a partner video from the same keyword group as this caption.
            if ' run' in name or ' ran' in name:
                rel_keywords = (' run', ' ran')
            elif ' hit' in name or ' bunch' in name:
                rel_keywords = (' hit', ' bunch')
            else:
                rel_keywords = None
            rel_idx = self._sample_other_index(idx, rel_keywords)
            rel_video_reader = VideoReader(self.dataset[rel_idx]["videopath"])
            rel_batch_index = self._sample_clip_indices(len(rel_video_reader))
            rel_frames = rel_video_reader.get_batch(rel_batch_index).asnumpy()
            del rel_video_reader
            # BT.601 luma weights collapse RGB into a single grayscale channel.
            rel_gray = np.dot(rel_frames, [0.299, 0.587, 0.114])
            rel_pixel_values = torch.from_numpy(rel_gray).unsqueeze(-1).permute(0, 3, 1, 2).contiguous().to(pixel_values.dtype)
            rel_pixel_values = rel_pixel_values / 255.

        mix_pixel_values = None
        mix_name = None
        if self.add_mix_video > 0.0 and random.random() < self.add_mix_video:
            mix_video_idx = self._sample_other_index(idx, self._MIX_KEYWORDS)
            mix_video_reader = VideoReader(self.dataset[mix_video_idx]["videopath"])
            mix_video_frame_idx = random.randint(0, len(mix_video_reader) - 1)
            mix_pixel_values = torch.from_numpy(mix_video_reader.get_batch([mix_video_frame_idx]).asnumpy()).permute(0, 3, 1, 2).contiguous()
            del mix_video_reader
            mix_pixel_values = mix_pixel_values / 255.
            # NOTE(review): `remove_verbs` is not defined or imported anywhere
            # visible in this file, so this line raises NameError, which
            # `__getitem__` catches and retries with another index. Restore the
            # helper before relying on add_mix_video.
            mix_name = remove_verbs(self.dataset[mix_video_idx]["name"]).replace('sks - ', 'sks-')

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name, mask, rel_pixel_values, mix_pixel_values, mix_name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a transformed sample dict, retrying with random indices on load errors.

        Raises:
            RuntimeError: if ``1 + _MAX_RETRIES`` consecutive loads fail.
        """
        for _ in range(self._MAX_RETRIES + 1):
            try:
                pixel_values, name, mask, rel_pixel_values, mix_pixel_values, mix_name = self.get_batch(idx)
                break
            except Exception as e:
                print("#################", idx, self.dataset[idx], e)
                idx = random.randint(0, self.length-1)
        else:
            raise RuntimeError(f"failed to load a sample after {self._MAX_RETRIES + 1} attempts")

        # Reseed before each transform so the random flip is identical for the
        # pixels, the mask and the related clip.
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        if mask is not None:
            random.seed(seed)
            torch.manual_seed(seed)
            mask = self.mask_transforms(mask)
            sample.update(mask=mask)
        if rel_pixel_values is not None:
            random.seed(seed)
            torch.manual_seed(seed)
            rel_pixel_values = self.mask_transforms(rel_pixel_values)
            sample.update(rel_pixel_values=rel_pixel_values)
        if mix_pixel_values is not None and mix_name is not None:
            mix_pixel_values = self.pixel_transforms(mix_pixel_values)
            sample["pixel_values"] = 0.75 * sample["pixel_values"] + 0.25 * mix_pixel_values
            sample["text"] = sample["text"] + ' ' + mix_name
        return sample


class ImageDataset(Dataset):
    """Caption-annotated image dataset read from a CSV with ``file_name`` and ``text`` columns."""

    # Number of load attempts in `__getitem__` before giving up.
    _MAX_ATTEMPTS = 20

    def __init__(
            self,
            csv_path, folder,
            sample_size=256,
        ):
        """
        Args:
            csv_path: annotation CSV path.
            folder: directory that ``file_name`` entries are relative to.
            sample_size: output size passed to ``Resize``; an int means a square.
        """
        zero_rank_print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.folder    = folder

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Load and transform image ``idx``; return ``(pixel_values, caption)``."""
        image_dict = self.dataset[idx]
        imageid, name = image_dict['file_name'], image_dict['text']

        image_dir = os.path.join(self.folder, f"{imageid}")
        # The context manager closes the file handle once the pixels are converted.
        with Image.open(image_dir) as image:
            pixel_values = self.pixel_transforms(image.convert("RGB"))

        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "text"}``, retrying with random indices on load errors.

        Raises:
            RuntimeError: if ``_MAX_ATTEMPTS`` consecutive loads fail.
        """
        for _ in range(self._MAX_ATTEMPTS):
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                print("#################", idx, self.dataset[idx], e)
                idx = random.randint(0, self.length-1)
        else:
            raise RuntimeError(f"failed to load an image after {self._MAX_ATTEMPTS} attempts")

        return dict(pixel_values=pixel_values, text=name)


class HAAVideo(Dataset):
    """Single-caption action video dataset with optional depth or HED conditioning.

    Entries are built from one of four layouts, depending on ``csv_path``:

    * a sequence (not str / PathLike): parallel sequences ``video_folder``,
      ``caption`` (and ``depth_folder`` / ``hed_folder`` when given); the
      category of each is the last path component of its video folder;
    * a CSV file: a single category taken from ``video_folder``'s last component;
    * a directory: every file in it is read as a CSV; the category is the file
      name before the first ``.``, and videos live in ``video_folder/<category>/``;
    * ``''``: every file in ``video_folder`` ending in ``mp4`` or ``avi``.

    For the CSV layouts, only rows among the first ``_MAX_CSV_ROWS`` whose 4th
    column is ``'0'`` and 5th column is ``'1'`` are used; row number ``i``
    maps to the video ``<category>_<i:03d>.mp4``. Depth takes precedence over
    HED when both folders are given.
    """

    # Extra copies appended for each id in `over_sample_ids` (plain entries only).
    _OVER_SAMPLE_COPIES = 5
    # Only this many leading rows of each annotation CSV are considered.
    _MAX_CSV_ROWS = 16
    # `__getitem__` makes 1 + _MAX_RETRIES attempts before giving up.
    _MAX_RETRIES = 5

    def __init__(
            self,
            csv_path, video_folder,
            caption,
            depth_folder = '',
            hed_folder = '',
            sample_size=256, sample_stride=4, sample_n_frames=16,
            random_crop = False,
            random_color = False,
            over_sample_ids = None,
            fix_dataset_scale = None,
        ):
        """
        Args:
            csv_path: annotation source; see the class docstring for layouts.
            video_folder: video root(s).
            caption: caption(s) attached to every entry of a source.
            depth_folder: depth-video root(s), or '' to disable.
            hed_folder: root(s) of per-video HED frame directories, or '' to disable.
            sample_size: output size passed to ``Resize``; an int means a square.
            sample_stride: frame stride between sampled frames.
            sample_n_frames: number of frames per clip.
            random_crop: add a ``CustomRandomCrop(0.85, 1.0)`` before resizing.
            random_color: also jitter saturation and hue.
            over_sample_ids: CSV row ids whose plain entries are duplicated.
            fix_dataset_scale: if set, keep a reproducible random subset of this size.

        Raises:
            ValueError: if ``fix_dataset_scale`` is outside ``[1, len(dataset)]``.
        """
        self.dataset = self._collect_entries(
            csv_path, video_folder, caption, depth_folder, hed_folder, over_sample_ids)
        if fix_dataset_scale is not None:
            if not 0 < fix_dataset_scale <= len(self.dataset):
                raise ValueError(
                    f"fix_dataset_scale must be in [1, {len(self.dataset)}], got {fix_dataset_scale}")
            # A private RNG with the same seed selects the same subset as
            # seeding the module RNG would, without resetting global `random` state.
            self.dataset = random.Random(fix_dataset_scale).sample(self.dataset, fix_dataset_scale)
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.sample_stride   = sample_stride
        self.sample_n_frames = sample_n_frames
        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)

        if not random_crop:
            self.shared_transforms = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Resize(sample_size, antialias=True),
            ])
        else:
            self.shared_transforms = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                CustomRandomCrop(0.85,1.0),
                transforms.Resize(sample_size, antialias=True),
            ])
        self.pixel_transforms = transforms.Compose([
            transforms.ColorJitter(brightness=0.2, contrast=0.2) if not random_color else transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    @classmethod
    def _selected_clip_ids(cls, csv_file):
        """Return the 0-based row numbers, among the first ``_MAX_CSV_ROWS`` rows
        of ``csv_file``, whose 4th column is ``'0'`` and 5th column is ``'1'``."""
        clip_ids = []
        with open(csv_file, 'r') as csvfile:
            for row_id, row in enumerate(csv.reader(csvfile)):
                if row_id >= cls._MAX_CSV_ROWS:
                    break
                if row[3] == '0' and row[4] == '1':
                    clip_ids.append(row_id)
        return clip_ids

    @classmethod
    def _entries_for_csv(cls, csv_file, video_dir, category, caption, depth_root, hed_root, over_sample_ids):
        """Build entries for one annotation CSV; ``depth_root``/``hed_root`` are None when unused."""
        entries = []
        for clip_id in cls._selected_clip_ids(csv_file):
            stem = '{}_{:03d}'.format(category, clip_id)
            video_path = os.path.join(video_dir, stem + '.mp4')
            if depth_root is not None:
                entries.append({'video_path': video_path, 'depth_path': os.path.join(depth_root, stem + '.mp4'), 'caption': caption})
            elif hed_root is not None:
                # HED frames are stored in a directory per clip (no extension).
                entries.append({'video_path': video_path, 'hed_path': os.path.join(hed_root, stem), 'caption': caption})
            else:
                copies = 1 + (cls._OVER_SAMPLE_COPIES if clip_id in over_sample_ids else 0)
                entries.extend({'video_path': video_path, 'caption': caption} for _ in range(copies))
        return entries

    @classmethod
    def _collect_entries(cls, csv_path, video_folder, caption, depth_folder, hed_folder, over_sample_ids):
        """Build the list of sample dicts for the layouts described on the class."""
        over_sample_ids = set(over_sample_ids or ())
        use_depth = depth_folder != ''
        use_hed = not use_depth and hed_folder != ''

        if csv_path == '':
            entries = []
            for f in os.listdir(video_folder):
                if not (f.endswith('mp4') or f.endswith('avi')):
                    continue
                entry = {'video_path': os.path.join(video_folder, f)}
                if use_depth:
                    entry['depth_path'] = os.path.join(depth_folder, f)
                elif use_hed:
                    # NOTE(review): unlike the CSV layouts, this HED path keeps
                    # the video's extension — confirm against the on-disk layout.
                    entry['hed_path'] = os.path.join(hed_folder, f)
                entry['caption'] = caption
                entries.append(entry)
            return entries

        if not isinstance(csv_path, (str, os.PathLike)):
            # Any non-path sequence (including a single-element list) of parallel sources.
            assert len(csv_path) == len(video_folder) and len(csv_path) == len(caption)
            sources = [
                (cp, video_folder[i], video_folder[i].split('/')[-1], caption[i],
                 depth_folder[i] if use_depth else None,
                 hed_folder[i] if use_hed else None)
                for i, cp in enumerate(csv_path)
            ]
        elif os.path.isfile(csv_path):
            sources = [(csv_path, video_folder, video_folder.split('/')[-1], caption,
                        depth_folder if use_depth else None,
                        hed_folder if use_hed else None)]
        elif os.path.isdir(csv_path):
            sources = []
            for csv_file in os.listdir(csv_path):
                category = csv_file.split('.')[0]
                sources.append((os.path.join(csv_path, csv_file), os.path.join(video_folder, category), category, caption,
                                depth_folder if use_depth else None,
                                hed_folder if use_hed else None))
        else:
            sources = []

        entries = []
        for source in sources:
            entries.extend(cls._entries_for_csv(*source, over_sample_ids))
        return entries

    def get_batch(self, idx):
        """Load raw tensors for entry ``idx``.

        Returns ``(pixel_values, depth_values, hed_values, caption, video_id)``;
        frame tensors are ``(T, C, H, W)`` scaled to [0, 1], and depth/HED are
        None when the entry has no such path.
        """
        video_dict = self.dataset[idx]
        video_path, video_caption = video_dict['video_path'], video_dict['caption']

        video_reader = VideoReader(video_path)
        video_length = len(video_reader)

        clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
        start_idx   = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        if self.sample_n_frames > 1:
            # Jitter interior frame indices by up to `max_stride`; the first and
            # last indices stay fixed.
            max_stride = int((clip_length - 1) // (self.sample_n_frames - 1) // 2 - 1)
            if max_stride > 0:
                random_ind = np.random.randint(-max_stride, max_stride+1, self.sample_n_frames)
                random_ind[0] = 0
                random_ind[-1] = 0
                batch_index = batch_index + random_ind

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        depth_values = None
        if 'depth_path' in video_dict:
            depth_reader = VideoReader(video_dict['depth_path'])
            if len(depth_reader) != video_length:
                raise ValueError(
                    f"depth video has {len(depth_reader)} frames, expected {video_length}: {video_dict['depth_path']}")
            depth_values = torch.from_numpy(depth_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
            depth_values = depth_values / 255.
            del depth_reader

        hed_values = None
        if 'hed_path' in video_dict:
            hed_sequence = []
            for ind in batch_index:
                # HED frame files are numbered from 1, frame indices from 0.
                with Image.open(os.path.join(video_dict['hed_path'], "%05d.jpg" % (ind + 1))) as hed_image:
                    hed_sequence.append(transforms.ToTensor()(hed_image.convert("RGB")))
            hed_values = torch.stack(hed_sequence, dim=0)

        return pixel_values, depth_values, hed_values, video_caption, video_dict['video_path'].split('/')[-1].split('.')[0]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a transformed sample dict, retrying with random indices on load errors.

        Raises:
            RuntimeError: if ``1 + _MAX_RETRIES`` consecutive loads fail.
        """
        for _ in range(self._MAX_RETRIES + 1):
            try:
                pixel_values, depth_values, hed_values, video_caption, video_id = self.get_batch(idx)
                break
            except Exception as e:
                print("#################", idx, self.dataset[idx], e)
                idx = random.randint(0, self.length-1)
        else:
            raise RuntimeError(f"failed to load a sample after {self._MAX_RETRIES + 1} attempts")

        # Concatenate along time so the conditioning frames get exactly the
        # same flip/crop/resize as the video frames, then split back.
        if depth_values is not None:
            pixel_values, depth_values = torch.chunk(self.shared_transforms(torch.cat([pixel_values, depth_values], dim=0)), 2, dim=0)
        elif hed_values is not None:
            pixel_values, hed_values = torch.chunk(self.shared_transforms(torch.cat([pixel_values, hed_values], dim=0)), 2, dim=0)
            hed_values = (hed_values - 0.5) / 0.5
        else:
            pixel_values = self.shared_transforms(pixel_values)
        pixel_values = self.pixel_transforms(pixel_values)
        if depth_values is not None:
            sample = dict(pixel_values=pixel_values, depth_values=depth_values, text=video_caption, video_id=video_id)
        elif hed_values is not None:
            sample = dict(pixel_values=pixel_values, hed_values=hed_values, text=video_caption, video_id=video_id)
        else:
            sample = dict(pixel_values=pixel_values, text=video_caption, video_id=video_id)
        return sample


class InternVidVideo(Dataset):
    """Uncaptioned video-clip dataset over every ``*mp4`` file in one folder."""

    # `__getitem__` makes 1 + _MAX_RETRIES attempts before giving up.
    _MAX_RETRIES = 3

    def __init__(
            self,
            video_folder,
            sample_size=512, sample_stride=4, sample_n_frames=16,
        ):
        """
        Args:
            video_folder: directory scanned (non-recursively) for files ending in ``mp4``.
            sample_size: output size; frames are resized on the short side to
                ``sample_size[0]`` and center-cropped to ``sample_size``.
            sample_stride: frame stride between sampled frames.
            sample_n_frames: number of frames per clip.
        """
        self.dataset = [
            {'video_path': os.path.join(video_folder, f)}
            for f in os.listdir(video_folder)
            if f.endswith('mp4')
        ]

        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.sample_stride   = sample_stride
        self.sample_n_frames = sample_n_frames
        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)

        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], antialias=True),
            transforms.CenterCrop(sample_size),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Load ``sample_n_frames`` evenly spaced frames from a random window of video ``idx``.

        Returns a ``(T, C, H, W)`` float tensor scaled to [0, 1].
        """
        video_reader = VideoReader(self.dataset[idx]['video_path'])
        video_length = len(video_reader)

        clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
        start_idx   = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        del video_reader
        return pixel_values / 255.

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{"pixel_values"}``, retrying with random indices on load errors.

        Raises:
            RuntimeError: if ``1 + _MAX_RETRIES`` consecutive loads fail.
        """
        for _ in range(self._MAX_RETRIES + 1):
            try:
                pixel_values = self.get_batch(idx)
                break
            except Exception as e:
                print("#################", idx, self.dataset[idx], e)
                idx = random.randint(0, self.length-1)
        else:
            raise RuntimeError(f"failed to load a sample after {self._MAX_RETRIES + 1} attempts")
        pixel_values = self.pixel_transforms(pixel_values)
        return dict(pixel_values=pixel_values)


if __name__ == "__main__":
    # Smoke test: iterate the dataset and print batch shapes.
    from animatediff.utils.util import save_videos_grid

    dataset = WebVid10M(
        csv_path=["/xxx/xxx/xxx/xxx/data/xxx/xxx/annotations/train_in_category.csv"],
        video_folder=["/xxx/xxx/xxx/xxx/data/xxx/xxx/videos_in_category"],
        sample_size=256,
        sample_stride=4, sample_n_frames=16,
        add_mix_video=0.2
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
    for idx, batch in enumerate(dataloader):
        # `rel_pixel_values` only exists when the dataset is built with add_rel_video=True.
        rel_pixel_values = batch.get("rel_pixel_values")
        rel_shape = rel_pixel_values.shape if rel_pixel_values is not None else None
        print(batch["pixel_values"].shape, len(batch["text"]), rel_shape)
        # for i in range(batch["pixel_values"].shape[0]):
        #     save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
