import os
import sys
from pathlib import Path
from itertools import cycle

import numpy as np
from PIL import Image
from einops import rearrange
import torch
from torch.utils.data.dataset import ConcatDataset, Dataset
import torchvision
import torchvision.transforms as T
import pytorchvideo.transforms as VT
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.data.clip_sampling import RandomClipSampler
from torchvision import datasets
from torch.utils.data import Dataset, IterableDataset
from data.vox2_preprocess import get_logger

# Root directory for every dataset, relative to the current working directory.
DATA_PATH = './data'
# VoxCeleb2 root; get_dataset expects 'dev/mp4' and 'test/mp4' beneath it.
VOXCELEB_PATH = f'{DATA_PATH}/vox2_mp4'
# NOTE: this project logger is bound to the name ``logging`` and therefore
# shadows the stdlib module inside this file; the datasets below call
# ``logging.info`` on it.
logging = get_logger(__name__)

class ImgDataset(Dataset):
    """Wrap an indexable dataset so each sample is a ``{'imgs': ...}`` dict.

    Args:
        data: any sized, indexable collection (e.g. a torchvision dataset).
        sdf: when False (the default) only element ``0`` of each item is kept,
            presumably the image of an ``(image, target)`` pair -- confirm
            against the wrapped dataset; when True the item is passed through
            unchanged.
    """

    def __init__(self, data, sdf=False):
        self.data = data
        self.sdf = sdf

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        sample = item if self.sdf else item[0]
        return {'imgs': sample}


class FFHQ(Dataset):
    """FFHQ images named in a list file, returned as ``{'imgs': ...}`` samples.

    Image ``<name>`` is read from ``<root_path>/imgs/<NN>000/<name>`` where
    ``NN`` are the first two characters of the file name (e.g. ``01234.png``
    lives in ``imgs/01000``).

    Args:
        root_path: dataset root directory (``str`` or ``pathlib.Path``).
        list_file: text file with one image file name per line; blank lines
            are ignored.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_path, list_file, transform=None):
        self.root_path = Path(root_path)
        self.transform = transform
        with open(list_file, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. trailing empty lines): they would map to a
            # non-existent ``imgs/000/`` path and crash in ``__getitem__``.
            self.img_list = [name for name in (line.strip() for line in f) if name]

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        folder_name = img_name[:2] + '000'
        img_path = self.root_path / 'imgs' / folder_name / img_name
        # Decode eagerly inside ``with`` so the file handle is released even
        # when no transform is given; force RGB to match ``Imagenette``.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return {
            'imgs': img,
        }


class Imagenette(Dataset):
    """Imagenette images found under ``<root_path>/{train,val}``.

    Every non-hidden file below the split directory is treated as an image.
    Files are listed in sorted path order so indices are reproducible.

    Args:
        root_path: dataset root directory (``str`` or ``pathlib.Path``).
        train: use the ``train`` split when True, otherwise ``val``.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_path, train=True, transform=None):
        self.root_path = Path(root_path)
        self.transform = transform
        self.train = train
        img_path = self.root_path / ('train' if self.train else 'val')
        # os.walk order is filesystem-dependent, so sort for reproducible
        # indices. Hidden files (e.g. '.DS_Store') are not images and would
        # make Image.open fail.
        self.imgs = sorted(
            Path(folder) / name
            for folder, _, files in os.walk(img_path)
            for name in files
            if not name.startswith('.')
        )

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # Decode inside ``with`` so the underlying file is closed promptly.
        with Image.open(self.imgs[index]) as raw:
            img = raw.convert('RGB')

        if self.transform:
            img = self.transform(img)
        return {
            'imgs': img,
        }


# class FrameDataset(IterableDataset):
#     def __init__(self, video_paths, frame_num=4, transform=None):
#         super().__init__()
#         self.video_paths = video_paths
#         self.frame_num = frame_num
#         self.transform = transform

#     def __len__(self):
#         return len(self.video_paths)

#     def __iter__(self):
#         for folder, files in cycle(self.video_paths):
#             try:
#                 files = sorted([file for file in files if file.endswith('.jpg')])
#                 if len(files) > 0:
#                     convert_img = lambda x: np.array(Image.open(x).convert('RGB')).astype(np.float64)
#                     x = [convert_img(os.path.join(folder, files[i])) for i in range(len(files))]

#                     if self.transform:
#                         x = torch.stack([self.transform(x[i]) for i in range(len(x))])
#                     yield {
#                         'videos': x[:self.frame_num],
#                     }
#             except Exception as e:
#                 logging.info(f'Error in {folder}: {e}')
#                 continue

class FrameDataset(Dataset):
    """Clips of pre-extracted ``.jpg`` frames, returned as ``{'videos': ...}``.

    Args:
        video_paths: sequence of ``(folder, files)`` pairs; only entries of
            ``files`` ending in ``.jpg`` are used, in sorted name order.
        frame_num: number of leading frames returned per clip.
        transform: optional callable applied to each frame (an RGB float64
            numpy array); the results are combined with ``torch.stack``.
            Without a transform, ``'videos'`` is a list of numpy arrays.
        max_frames: clips with more ``.jpg`` frames than this are skipped;
            ``None`` disables the upper bound. Defaults to 16, the previously
            hard-coded limit.

    ``__getitem__`` returns ``None`` for clips that are too short or too long,
    or that fail to load. NOTE(review): callers presumably drop these in a
    custom ``collate_fn`` -- confirm.
    """

    def __init__(self, video_paths, frame_num=4, transform=None, max_frames=16):
        super().__init__()
        self.video_paths = video_paths
        self.frame_num = frame_num
        self.transform = transform
        self.max_frames = max_frames

    def __len__(self):
        return len(self.video_paths)

    @staticmethod
    def _load_frame(path):
        """Decode one frame as an RGB float64 array without rescaling."""
        # float64 (not uint8) so T.ToTensor leaves values in [0, 255]; the
        # voxceleb transform in get_dataset divides by 255 itself.
        with Image.open(path) as img:
            return np.array(img.convert('RGB')).astype(np.float64)

    def __getitem__(self, idx):
        folder, files = self.video_paths[idx]
        try:
            frames = sorted(name for name in files if name.endswith('.jpg'))
            too_long = self.max_frames is not None and len(frames) > self.max_frames
            if len(frames) < self.frame_num or too_long:
                return None

            # Only the first ``frame_num`` frames are ever returned, so decode
            # just those instead of the whole clip.
            x = [self._load_frame(os.path.join(folder, name))
                 for name in frames[:self.frame_num]]

            if self.transform:
                x = torch.stack([self.transform(frame) for frame in x])

            return {
                'videos': x,
            }

        except Exception as e:
            logging.info(f'Error in {folder}: {e}')
            return None

## Iterable Dataset Version
# class VidDataset(IterableDataset):
#     def __init__(self, video_paths, clip_sampler, video_sampler, transform=None, decode_audio=False, duration=2):
#         super().__init__()
#         self.video_paths = video_paths
#         self.clip_sampler = clip_sampler
#         self.video_sampler = video_sampler
#         self.transform = transform
#         self.decode_audio = decode_audio
#         self.duration = duration

#     def __len__(self):
#         return len(self.video_paths)

#     def __iter__(self):
#         for folder, files in self.video_paths:
#             video_path = os.path.join(folder, files[0])
#             try:
#                 video = EncodedVideo.from_path(video_path, decode_audio=self.decode_audio)
#                 video_data = video.get_clip(0, self.duration)

#                 if self.transform:
#                     video_data = self.transform(video_data)
#                 x = rearrange(video_data['video'], 'c t h w -> t c h w')
#                 yield {
#                     'videos': x,
#                 }
#             except Exception as e:
#                 logging.info(f'Error in {video_path}: {e}')
#                 continue

## Map-style Dataset Version
# class VidDataset(Dataset):
#     def __init__(self, video_paths, clip_sampler, video_sampler, transform=None, decode_audio=False, duration=2):
#         self.video_paths = video_paths
#         self.clip_sampler = clip_sampler
#         self.video_sampler = video_sampler
#         self.transform = transform
#         self.decode_audio = decode_audio
#         self.duration = duration

#     def __len__(self):
#         return len(self.video_paths)

#     def __getitem__(self, idx):
#         try:
#             folder, files = self.video_paths[idx]
#             # vid = torch.cat([torchvision.io.read_video(os.path.join(folder, file), pts_unit='sec')[0] for file in files], dim=0)
#             video = EncodedVideo.from_path(os.path.join(folder, files[0]), decode_audio=self.decode_audio)
#             video_data = video.get_clip(0, self.duration)

#             if self.transform:
#                 video_data = self.transform(video_data)
#             x = rearrange(video_data['video'], 'c t h w -> t c h w')
#             return {
#                 'videos': x,
#             }
#         except Exception as e:
#             logging.info(e)
#             logging.info(folder)
#             return None


def _scale_to_unit(x):
    """Cast a tensor to float and divide by 255 (module-level so it pickles)."""
    return x.float() / 255.0


def _collect_video_dirs(root, min_files):
    """Return ``(dirpath, filenames)`` pairs under ``root`` with >= ``min_files`` entries.

    ``os.walk`` order is filesystem-dependent, so the result is sorted by
    directory path to keep dataset indices reproducible.
    """
    return sorted(
        (
            (dirpath, filenames)
            for dirpath, _, filenames in os.walk(root)
            if len(filenames) >= min_files
        ),
        key=lambda entry: entry[0],
    )


def get_dataset(P, dataset, only_test=False):
    """Build the datasets for ``dataset`` and record its metadata on ``P``.

    Args:
        P: mutable config namespace. Reads ``P.resolution`` (and
            ``P.inner_step`` for ``'voxceleb'``); sets ``P.data_type``,
            ``P.dim_in``, ``P.dim_out`` and ``P.data_size``. For ``'celeba'``
            it also increases ``P.resolution`` by 2 (see the note there).
        dataset: one of ``'celeba'``, ``'ffhq'``, ``'imagenette'``,
            ``'voxceleb'``.
        only_test: when True, return only the test split.

    Returns:
        ``test_set`` if ``only_test`` is True, otherwise
        ``(train_set, val_set)`` where ``val_set`` is the test split.

    Raises:
        NotImplementedError: if ``dataset`` is not a known name.
    """
    P.data_size = None

    if dataset == 'celeba':
        T_base = T.Compose([
            T.Resize(P.resolution),
            T.CenterCrop(P.resolution),
            T.Pad(1),
            T.ToTensor(),
        ])
        # T.Pad(1) adds one pixel on every side, so samples are resolution + 2
        # wide. NOTE(review): this mutates P, so calling get_dataset('celeba')
        # repeatedly with the same P grows the resolution each time.
        P.resolution += 2
        train_set = ImgDataset(
            datasets.CelebA(DATA_PATH, split='train',
                            target_type='attr', transform=T_base)
        )
        test_set = ImgDataset(
            datasets.CelebA(DATA_PATH, split='test',
                            target_type='attr', transform=T_base)
        )
        P.data_type = 'img'
        P.dim_in, P.dim_out = 2, 3
        P.data_size = (3, P.resolution, P.resolution)

    elif dataset == 'ffhq':
        root_path = Path(DATA_PATH) / 'ffhq'
        train_list_file = root_path / 'ffhqtrain.txt'
        test_list_file = root_path / 'ffhqvalidation.txt'

        T_base = T.Compose([
            T.Resize(P.resolution),
            T.CenterCrop(P.resolution),
            T.ToTensor(),
        ])
        train_set = FFHQ(root_path, train_list_file, transform=T_base)
        test_set = FFHQ(root_path, test_list_file, transform=T_base)

        P.data_type = 'img'
        P.dim_in, P.dim_out = 2, 3
        P.data_size = (3, P.resolution, P.resolution)

    elif dataset == 'imagenette':
        root_path = Path(DATA_PATH) / 'imagenette'
        T_base = T.Compose([
            T.Resize(P.resolution),
            T.CenterCrop(P.resolution),
            T.ToTensor(),
        ])
        train_set = Imagenette(root_path, train=True, transform=T_base)
        test_set = Imagenette(root_path, train=False, transform=T_base)

        P.data_type = 'img'
        P.dim_in, P.dim_out = 2, 3
        P.data_size = (3, P.resolution, P.resolution)

    elif dataset == 'voxceleb':
        dataset_path = Path(VOXCELEB_PATH)

        train_video_list = _collect_video_dirs(dataset_path / 'dev' / 'mp4', P.inner_step)
        logging.info(f'Number of train videos: {len(train_video_list)}')

        test_video_list = _collect_video_dirs(dataset_path / 'test' / 'mp4', P.inner_step)
        logging.info(f'Number of test videos: {len(test_video_list)}')

        # FrameDataset feeds float64 frames, which T.ToTensor does not rescale,
        # so the division by 255 happens explicitly here.
        transform = T.Compose([
            T.ToTensor(),
            T.Lambda(_scale_to_unit),
            T.Resize(P.resolution),
            T.CenterCrop(P.resolution),
        ])

        train_set = FrameDataset(
            train_video_list,
            transform=transform,
            frame_num=P.inner_step,
        )
        test_set = FrameDataset(
            test_video_list,
            transform=transform,
            frame_num=P.inner_step,
        )

        P.data_type = 'video'
        P.dim_in, P.dim_out = 3, 3
        P.data_size = (P.inner_step, 3, P.resolution, P.resolution)

    else:
        raise NotImplementedError(f'Unknown dataset: {dataset}')

    if only_test:
        return test_set

    # None of these datasets has a separate validation split; validate on test.
    return train_set, test_set
