import torch
import json
import os
import cv2
import numpy as np
import pandas as pd
import math
import random
from PIL import Image
import torchvision.transforms as transforms
from decord import VideoReader, cpu
from torch.utils.data import Dataset
import jsonlines

from .label2text import label2text


class ucf(Dataset):
    """Video dataset driven by ``train.txt`` / ``test.txt`` listings.

    The class name suggests UCF-101, but nothing in this file confirms the
    exact dataset.  Every non-empty line of the listing names one video,
    relative to the dataset root.

    Each item is a dict with:

    * ``"image"``: float tensor ``[C, st * H, st * W]`` with
      ``st = int(sqrt(T))``; the first ``st * st`` sampled frames are tiled
      row-major and scaled to ``[-1, 1]``.
    * ``"text"``: the first ``/``-separated component of the video path
      (presumably the action-class directory).
    * ``"fps"``: the constant ``30``.
    """

    def __init__(self, dataset, H, W, T, train, lim=None):
        """Read the video listing.

        Args:
            dataset: Root directory containing ``train.txt`` / ``test.txt``.
            H: Frame height requested from the decoder.
            W: Frame width requested from the decoder.
            T: Number of frames sampled per clip.
            train: Read ``train.txt`` when true, ``test.txt`` otherwise.
            lim: Optional approximate cap; the listing is subsampled with a
                stride of ``len // lim`` (never below 1).

        Raises:
            ValueError: If ``lim`` is given and is smaller than 1.
        """
        self.datasetroot = dataset
        self.dataset = []
        list_name = "train.txt" if train else "test.txt"
        with open(os.path.join(self.datasetroot, list_name), 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                # Train lines may carry a trailing " <label>" column; keep
                # only the path.  Test lines are used as-is.
                self.dataset.append(line.split(' ')[0] if train else line)
        if lim is not None:
            self.dataset = self._subsample(self.dataset, lim)

        self.video_length = T
        self.frame_stride = 4
        self.st = int(math.sqrt(T))
        self.resolution = [H, W]
        print(f"dataset size: {len(self.dataset)}")

    @staticmethod
    def _subsample(items, lim):
        """Return every ``len(items) // lim``-th element, with stride >= 1."""
        if lim < 1:
            raise ValueError(f"lim must be a positive integer, got {lim}")
        # Guard against a zero step when lim exceeds the number of items.
        step = max(1, len(items) // lim)
        return items[::step]

    def _make_dataset(self):
        return

    def _open_reader(self, video_path, index):
        """Open the first decodable video with at least ``video_length`` frames.

        Starts with ``video_path`` and, on failure, walks forward through
        ``self.dataset`` (wrapping around), trying each entry at most once.

        Returns:
            ``(video_reader, index)`` where ``index`` is the entry that was
            actually opened.

        Raises:
            RuntimeError: If no entry yields a usable video.
        """
        total = len(self.dataset)
        for _ in range(max(total, 1)):
            try:
                video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
            except Exception:
                print(f"Load video failed! path = {video_path}")
            else:
                if len(video_reader) >= self.video_length:
                    return video_reader, index
            if total == 0:
                break
            # Move on to the next listed video; the path must follow the index,
            # otherwise the same broken file would be retried forever.
            index = (index + 1) % total
            video_path = os.path.join(self.datasetroot, self.dataset[index])
        raise RuntimeError(
            f"No usable video with at least {self.video_length} frames found in {self.datasetroot}"
        )

    def _sample_frame_indices(self, num_frames):
        """Pick ``video_length`` frame indices for a random clip.

        Frames are spaced by ``frame_stride``; if the video is too short for
        that, consecutive frames are used instead.  ``num_frames`` must be at
        least ``video_length``.
        """
        all_frames = list(range(0, num_frames, self.frame_stride))
        if len(all_frames) < self.video_length:
            all_frames = list(range(num_frames))

        # select random clip
        rand_idx = random.randint(0, len(all_frames) - self.video_length)
        # Index into all_frames so the stride is actually applied.
        return all_frames[rand_idx:rand_idx + self.video_length]

    def _load_clip(self, video_path, index):
        """Decode one clip; return ``(frames [C, T, H, W] in [-1, 1], used_index)``."""
        video_reader, index = self._open_reader(video_path, index)
        frame_indices = self._sample_frame_indices(len(video_reader))
        frames = video_reader.get_batch(frame_indices)
        assert frames.shape[0] == self.video_length, f'{len(frames)}, self.video_length={self.video_length}'

        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
        assert (frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), \
            f'frames={frames.shape}, self.resolution={self.resolution}'
        frames = (frames / 255 - 0.5) * 2
        return frames, index

    def video(self, video_path, index):
        """Return a ``[C, T, H, W]`` float clip scaled to ``[-1, 1]``.

        Args:
            video_path: Path of the video to try first.
            index: Position of that video in ``self.dataset``; used to pick
                the next candidate when loading fails or the video is shorter
                than ``video_length`` frames.

        Raises:
            RuntimeError: If no listed video can be used.
        """
        frames, _ = self._load_clip(video_path, index)
        return frames

    def catvideo(self, video_tensor):
        """Tile the first ``st * st`` frames of ``[C, T, H, W]`` into one image.

        Frames are laid out row-major on an ``st x st`` grid, giving a tensor
        of shape ``[C, st * H, st * W]``.  Raises ``IndexError`` if the clip
        has fewer than ``st * st`` frames.
        """
        channels, num_frames, height, width = video_tensor.shape
        output_height = height * self.st  # e.g. 4 x 256 = 1024
        output_width = width * self.st    # e.g. 4 x 256 = 1024

        # Canvas holding the single stitched image.
        stitched_image = torch.zeros((channels, output_height, output_width))
        for i in range(self.st):  # row
            for j in range(self.st):  # column
                # Index of the frame that goes into this cell.
                frame_idx = i * self.st + j
                # Top-left corner of the cell.
                start_y = i * height
                start_x = j * width
                stitched_image[:, start_y:start_y + height, start_x:start_x + width] = video_tensor[:, frame_idx, :, :]

        return stitched_image

    def __getitem__(self, index):
        """Return ``{"image", "text", "fps"}`` for the video at ``index``."""
        p = self.dataset[index]
        video_path = os.path.join(self.datasetroot, p)

        video, used_index = self._load_clip(video_path, index)
        # Take the label from the entry that was actually decoded, which can
        # differ from ``index`` when unusable videos were skipped.
        text = self.dataset[used_index].split('/')[0]
        img = self.catvideo(video)
        return {
            "image": img,
            "text": text,
            "fps": 30
        }

    def __len__(self):
        return len(self.dataset)


class msrvtt(Dataset):
    """Caption/video dataset read from a ``*_videodatainfo.json`` file.

    The JSON file must contain a ``"sentences"`` list whose entries provide
    ``"caption"`` and ``"video_id"``.  Videos are looked up as
    ``<root>/<TrainValVideo|TestVideo>/<video_id>.mp4``.

    Each item is a dict with:

    * ``"image"``: float tensor ``[C, st * H, st * W]`` with
      ``st = int(sqrt(T))``; the first ``st * st`` sampled frames are tiled
      row-major and scaled to ``[-1, 1]``.
    * ``"text"``: the caption.
    * ``"fps"``: the constant ``30``.
    """

    def __init__(self, dataset, H, W, T, train, lim=None):
        """Load the sentence annotations.

        Args:
            dataset: Root directory with the annotation JSON and video folders.
            H: Frame height requested from the decoder.
            W: Frame width requested from the decoder.
            T: Number of frames sampled per clip.
            train: Use ``train_val_videodatainfo.json`` / ``TrainValVideo``
                when true, ``test_videodatainfo.json`` / ``TestVideo``
                otherwise.
            lim: Optional approximate cap; the sentence list is subsampled
                with a stride of ``len // lim`` (never below 1).

        Raises:
            ValueError: If ``lim`` is given and is smaller than 1.
        """
        self.datasetroot = dataset
        if train:
            info_name = "train_val_videodatainfo.json"
            self.dataset_suffix = "TrainValVideo"
        else:
            info_name = "test_videodatainfo.json"
            self.dataset_suffix = "TestVideo"
        with open(os.path.join(self.datasetroot, info_name), 'r') as file:
            self.dataset = json.load(file)['sentences']
        if lim is not None:
            self.dataset = self._subsample(self.dataset, lim)

        self.video_length = T
        self.frame_stride = 1
        self.H = H
        self.W = W
        self.st = int(math.sqrt(T))
        self.resolution = [H, W]

        print(f"dataset size: {len(self.dataset)}")

    @staticmethod
    def _subsample(items, lim):
        """Return every ``len(items) // lim``-th element, with stride >= 1."""
        if lim < 1:
            raise ValueError(f"lim must be a positive integer, got {lim}")
        # Guard against a zero step when lim exceeds the number of items.
        step = max(1, len(items) // lim)
        return items[::step]

    def _make_dataset(self):
        return

    def _video_path(self, index):
        """Return the ``.mp4`` path of the video referenced by entry ``index``."""
        video_id = self.dataset[index]['video_id']
        return os.path.join(self.datasetroot, self.dataset_suffix, video_id + '.mp4')

    def _open_reader(self, video_path, index):
        """Open the first decodable video with at least ``video_length`` frames.

        Starts with ``video_path`` and, on failure, walks forward through
        ``self.dataset`` (wrapping around), trying each entry at most once.

        Returns:
            ``(video_reader, index)`` where ``index`` is the entry that was
            actually opened.

        Raises:
            RuntimeError: If no entry yields a usable video.
        """
        total = len(self.dataset)
        for _ in range(max(total, 1)):
            try:
                video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
            except Exception:
                print(f"Load video failed! path = {video_path}")
            else:
                if len(video_reader) >= self.video_length:
                    return video_reader, index
            if total == 0:
                break
            # Move on to the next entry; the path must follow the index,
            # otherwise the same broken file would be retried forever.
            index = (index + 1) % total
            video_path = self._video_path(index)
        raise RuntimeError(
            f"No usable video with at least {self.video_length} frames found in {self.datasetroot}"
        )

    def _sample_frame_indices(self, num_frames):
        """Pick ``video_length`` frame indices for a random clip.

        Frames are spaced by ``frame_stride``; if the video is too short for
        that, consecutive frames are used instead.  ``num_frames`` must be at
        least ``video_length``.
        """
        all_frames = list(range(0, num_frames, self.frame_stride))
        if len(all_frames) < self.video_length:
            all_frames = list(range(num_frames))

        # select random clip
        rand_idx = random.randint(0, len(all_frames) - self.video_length)
        # Index into all_frames so the stride is actually applied.
        return all_frames[rand_idx:rand_idx + self.video_length]

    def _load_clip(self, video_path, index):
        """Decode one clip; return ``(frames [C, T, H, W] in [-1, 1], used_index)``."""
        video_reader, index = self._open_reader(video_path, index)
        frame_indices = self._sample_frame_indices(len(video_reader))
        frames = video_reader.get_batch(frame_indices)
        assert frames.shape[0] == self.video_length, f'{len(frames)}, self.video_length={self.video_length}'

        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
        assert (frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), \
            f'frames={frames.shape}, self.resolution={self.resolution}'
        frames = (frames / 255 - 0.5) * 2
        return frames, index

    def video(self, video_path, index):
        """Return a ``[C, T, H, W]`` float clip scaled to ``[-1, 1]``.

        Args:
            video_path: Path of the video to try first.
            index: Position of that video's entry in ``self.dataset``; used
                to pick the next candidate when loading fails or the video is
                shorter than ``video_length`` frames.

        Raises:
            RuntimeError: If no listed video can be used.
        """
        frames, _ = self._load_clip(video_path, index)
        return frames

    def catvideo(self, video_tensor):
        """Tile the first ``st * st`` frames of ``[C, T, H, W]`` into one image.

        Frames are laid out row-major on an ``st x st`` grid, giving a tensor
        of shape ``[C, st * H, st * W]``.  Raises ``IndexError`` if the clip
        has fewer than ``st * st`` frames.
        """
        channels, num_frames, height, width = video_tensor.shape
        output_height = height * self.st  # e.g. 4 x 256 = 1024
        output_width = width * self.st    # e.g. 4 x 256 = 1024

        # Canvas holding the single stitched image.
        stitched_image = torch.zeros((channels, output_height, output_width))
        for i in range(self.st):  # row
            for j in range(self.st):  # column
                # Index of the frame that goes into this cell.
                frame_idx = i * self.st + j
                # Top-left corner of the cell.
                start_y = i * height
                start_x = j * width
                stitched_image[:, start_y:start_y + height, start_x:start_x + width] = video_tensor[:, frame_idx, :, :]

        return stitched_image

    def __getitem__(self, index):
        """Return ``{"image", "text", "fps"}`` for the sentence at ``index``."""
        video_path = self._video_path(index)

        video, used_index = self._load_clip(video_path, index)
        # Take the caption from the entry that was actually decoded, which
        # can differ from ``index`` when unusable videos were skipped.
        caption = self.dataset[used_index]['caption']
        img = self.catvideo(video)
        return {
            "image": img,
            "text": caption,
            "fps": 30
        }

    def __len__(self):
        return len(self.dataset)
    



class webvid(Dataset):
    """Video/caption dataset mixed with tiled still-image samples.

    Videos are listed in ``<root>/data/top_videos.csv``.  Rows are read by
    position: column 0 is the video id, column 3 the page directory and
    column 4 the caption; any further columns (e.g. a clip score) are
    ignored.  Videos live at ``<root>/data/videos/<page_dir>/<videoid>.mp4``.

    In training mode an fps value is drawn from ``[0, 1, 2, 4, 8, 15, 30]``
    per item.  ``0`` yields a grid of still images of one class from
    ``datasets/ImageNet``; any other value yields a video clip sampled with a
    stride of ``frame_rate // fps``.  In evaluation mode fps is always 8.

    Each item is a dict with ``"image"`` (float tensor
    ``[C, st * H, st * W]`` in ``[-1, 1]``, ``st = int(sqrt(T))``),
    ``"text"`` and ``"fps"``.
    """

    def __init__(self, dataset, H, W, T, train, lim=None):
        """Load the video table and the per-class image listing.

        Args:
            dataset: Root directory containing ``data/top_videos.csv``.
            H: Frame/tile height.
            W: Frame/tile width.
            T: Number of frames (or still images) per sample.
            train: Enables random fps selection in ``__getitem__``.
            lim: Optional approximate cap; rows are subsampled with a stride
                of ``len // lim`` (never below 1).

        Raises:
            ValueError: If ``lim`` is given and is smaller than 1.
        """
        self.datasetroot = dataset
        self.dataset = pd.read_csv(os.path.join(self.datasetroot, "data", "top_videos.csv"))
        if lim is not None:
            self.dataset = self._subsample(self.dataset, lim)

        self.video_length = T
        self.st = int(math.sqrt(T))
        self.frame_rate = 30
        self.H = H
        self.W = W
        self.resolution = [H, W]
        self.is_train = train
        print(f"dataset size: {len(self.dataset)}")

        # image_dataset[c] lists the image paths of class c (1000 classes).
        self.image_dataset = [[] for _ in range(1000)]
        with open("datasets/ImageNet/train.txt", 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                # "<relative path> <class index>"; split on the last space.
                path, index = line.rsplit(' ', 1)
                self.image_dataset[int(index)].append(path)

    @staticmethod
    def _subsample(items, lim):
        """Return every ``len(items) // lim``-th row/element, with stride >= 1."""
        if lim < 1:
            raise ValueError(f"lim must be a positive integer, got {lim}")
        # Guard against a zero step when lim exceeds the number of items.
        step = max(1, len(items) // lim)
        return items[::step]

    def _make_dataset(self):
        return

    def getvideopath(self, index):
        """Return the ``.mp4`` path for the row at position ``index``."""
        row = self.dataset.iloc[index]
        # Positional access works whether or not extra trailing columns exist.
        videoid, page_dir = row.iloc[0], row.iloc[3]
        video_path = os.path.join(self.datasetroot, "data", "videos", page_dir, str(videoid) + '.mp4')
        return video_path

    def _open_reader(self, index, frame_stride):
        """Open the first decodable video long enough for a strided clip.

        Starts at row ``index`` and walks forward (wrapping around), trying
        each row at most once.  A video qualifies when it has at least
        ``video_length * frame_stride`` frames.

        Returns:
            ``(video_reader, index)`` where ``index`` is the row that was
            actually opened.

        Raises:
            RuntimeError: If no row yields a usable video.
        """
        total = len(self.dataset)
        for _ in range(total):
            video_path = self.getvideopath(index)
            try:
                video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
            except Exception:
                # Report the path that actually failed.
                print(f"Load video failed! path = {video_path}")
            else:
                if len(video_reader) >= self.video_length * frame_stride:
                    return video_reader, index
            index = (index + 1) % total
        raise RuntimeError(
            f"No usable video with at least {self.video_length * frame_stride} frames found in {self.datasetroot}"
        )

    def _sample_frame_indices(self, num_frames, frame_stride):
        """Pick ``video_length`` frame indices spaced by ``frame_stride``.

        Falls back to consecutive frames if the strided list is too short.
        ``num_frames`` must be at least ``video_length``.
        """
        all_frames = list(range(0, num_frames, frame_stride))
        if len(all_frames) < self.video_length:
            all_frames = list(range(num_frames))

        # select random clip
        rand_idx = random.randint(0, len(all_frames) - self.video_length)
        return all_frames[rand_idx:rand_idx + self.video_length]

    def _load_clip(self, index, frame_stride):
        """Decode one clip; return ``(frames [C, T, H, W] in [-1, 1], used_index)``."""
        frame_stride = int(frame_stride)
        video_reader, index = self._open_reader(index, frame_stride)
        frame_indices = self._sample_frame_indices(len(video_reader), frame_stride)
        frames = video_reader.get_batch(frame_indices)
        assert frames.shape[0] == self.video_length, f'{len(frames)}, self.video_length={self.video_length}'

        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
        assert (frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), \
            f'frames={frames.shape}, self.resolution={self.resolution}'
        frames = (frames / 255 - 0.5) * 2
        return frames, index

    def video(self, index, frame_stride):
        """Return a ``[C, T, H, W]`` float clip scaled to ``[-1, 1]``.

        Args:
            index: Row to try first; later rows are tried when loading fails
                or the video is too short.
            frame_stride: Spacing between sampled frames.

        Raises:
            RuntimeError: If no listed video can be used.
        """
        frames, _ = self._load_clip(index, frame_stride)
        return frames

    def catvideo(self, video_tensor):
        """Tile the first ``st * st`` frames of ``[C, T, H, W]`` into one image.

        Frames are laid out row-major on an ``st x st`` grid, giving a tensor
        of shape ``[C, st * H, st * W]``.  Raises ``IndexError`` if the clip
        has fewer than ``st * st`` frames.
        """
        channels, num_frames, height, width = video_tensor.shape
        output_height = height * self.st  # e.g. 4 x 256 = 1024
        output_width = width * self.st    # e.g. 4 x 256 = 1024

        # Canvas holding the single stitched image.
        stitched_image = torch.zeros((channels, output_height, output_width))
        for i in range(self.st):  # row
            for j in range(self.st):  # column
                # Index of the frame that goes into this cell.
                frame_idx = i * self.st + j
                # Top-left corner of the cell.
                start_y = i * height
                start_x = j * width
                stitched_image[:, start_y:start_y + height, start_x:start_x + width] = video_tensor[:, frame_idx, :, :]

        return stitched_image

    def _image_grid_sample(self):
        """Build an ``fps == 0`` sample from still images of one random class.

        Images of the class are visited in random order until
        ``video_length`` of them have been read successfully.

        Raises:
            RuntimeError: If the class does not provide enough readable images.
        """
        # NOTE(review): only classes 0..99 of the 1000 loaded are ever drawn;
        # kept as in the original - confirm this is intended.
        class_id = int(np.random.randint(100))
        candidates = self.image_dataset[class_id]

        images = []
        for pos in np.random.permutation(len(candidates)):
            path = os.path.join('datasets/ImageNet', candidates[pos])
            img = cv2.imread(path)
            if img is None:
                continue  # unreadable or missing file
            img = cv2.resize(img, (self.W, self.H))
            # cv2 decodes to BGR; convert so channel order matches the RGB
            # video frames produced by the video branch.
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            if len(images) == self.video_length:
                break
        if len(images) < self.video_length:
            raise RuntimeError(
                f"ImageNet class {class_id} has only {len(images)} readable images, "
                f"need {self.video_length}"
            )

        grid_img = np.zeros((self.H * self.st, self.W * self.st, 3), dtype=np.uint8)
        for i in range(self.st):
            for j in range(self.st):
                grid_img[i * self.H:(i + 1) * self.H, j * self.W:(j + 1) * self.W, :] = images[i * self.st + j]
        grid_img = torch.tensor(grid_img).float()
        grid_img = (grid_img / 255 - 0.5) * 2
        grid_img = grid_img.permute(2, 0, 1)  # [h,w,c] -> [c,h,w]
        return {
            "image": grid_img,
            "text": label2text[class_id],
            "fps": 0
        }

    def __getitem__(self, index):
        """Return ``{"image", "text", "fps"}`` for row ``index``.

        With ``fps == 0`` (training only) the row is ignored and a still-image
        grid is returned instead.
        """
        if self.is_train:
            fps = int(np.random.choice([0, 1, 2, 4, 8, 15, 30]))
        else:
            fps = 8

        if fps == 0:
            return self._image_grid_sample()

        video, used_index = self._load_clip(index, self.frame_rate // fps)
        # Take the caption from the row that was actually decoded, which can
        # differ from ``index`` when unusable videos were skipped.
        name = self.dataset.iloc[used_index].iloc[4]
        img = self.catvideo(video)
        return {
            "image": img,
            "text": name,
            "fps": fps
        }

    def __len__(self):
        return len(self.dataset)


class text(Dataset):
    """Prompt-only dataset: one caption per line of a UTF-8 text file.

    ``H``, ``W``, ``T``, ``train`` and ``lim`` are accepted by the
    constructor but are not used.
    """

    def __init__(self, dataset, H, W, T, train, lim=None):
        """Read every line of the file at ``dataset`` into ``self.dataset``."""
        self.datasetroot = dataset

        with open(dataset, 'r', encoding='utf-8') as handle:
            # Whitespace is stripped from both ends of each line; blank lines
            # are kept as empty strings.
            self.dataset = list(map(str.strip, handle))

    def __getitem__(self, index):
        """Return the prompt at ``index`` with no image and a fixed fps of 60."""
        return {
            "image": None,
            "text": self.dataset[index],
            "fps": 60
        }

    def __len__(self):
        """Number of prompts read from the file."""
        return len(self.dataset)

class jdb(Dataset):
    """Image/prompt dataset that presents a still image as a tiled grid.

    Annotations are read from ``<root>/train_anno_realease_repath.jsonl``;
    each record must provide ``"img_path"`` and ``"prompt"``.  Images are
    loaded from ``<root>/imgs/<img_path without its first two characters>``.

    Each item is a dict with:

    * ``"image"``: float tensor ``[C, st * H, st * W]`` made of ``st x st``
      copies of the resized image, ``st = int(sqrt(T))``.
    * ``"text"``: the record's prompt.
    * ``"fps"``: the constant ``120``.
    """

    def __init__(self, dataset, H, W, T, train, lim=None):
        """Load the annotation records.

        Args:
            dataset: Root directory with the ``.jsonl`` file and ``imgs``.
            H: Tile height.
            W: Tile width.
            T: Nominal frame count; only ``int(sqrt(T))`` is used, as the
                grid side length.
            train: Stored as ``self.is_train``; not otherwise used here.
            lim: Optional approximate cap; records are subsampled with a
                stride of ``len // lim`` (never below 1).

        Raises:
            ValueError: If ``lim`` is given and is smaller than 1.
        """
        self.datasetroot = dataset
        with open(os.path.join(self.datasetroot, 'train_anno_realease_repath.jsonl'), 'r') as f:
            self.dataset = list(jsonlines.Reader(f))

        self.lim = lim
        if lim is not None:
            self.dataset = self._subsample(self.dataset, lim)

        self.video_length = T
        self.st = int(math.sqrt(T))
        self.frame_rate = 30
        self.H = H
        self.W = W
        self.resolution = [H, W]
        self.is_train = train

    @staticmethod
    def _subsample(items, lim):
        """Return every ``len(items) // lim``-th element, with stride >= 1."""
        if lim < 1:
            raise ValueError(f"lim must be a positive integer, got {lim}")
        # Guard against a zero step when lim exceeds the number of items.
        step = max(1, len(items) // lim)
        return items[::step]

    def _make_dataset(self):
        return

    def catimage(self, image_tensor):
        """Repeat a ``[C, H, W]`` image on an ``st x st`` grid.

        Returns a tensor of shape ``[C, st * H, st * W]``.
        """
        channels, height, width = image_tensor.shape
        output_height = height * self.st  # e.g. 4 x 256 = 1024
        output_width = width * self.st    # e.g. 4 x 256 = 1024

        stitched_image = torch.zeros((channels, output_height, output_width))
        for i in range(self.st):  # row
            for j in range(self.st):  # column
                start_y = i * height
                start_x = j * width
                stitched_image[:, start_y:start_y + height, start_x:start_x + width] = image_tensor

        return stitched_image

    def __getitem__(self, index):
        """Return ``{"image", "text", "fps"}`` for the record at ``index``.

        If the image cannot be loaded, the following records are tried in
        turn (wrapping around), each at most once.

        Raises:
            IndexError: If ``index`` is out of range.
            RuntimeError: If no record yields a loadable image.
        """
        total = len(self.dataset)
        record = self.dataset[index]
        image = None
        for _ in range(total):
            try:
                # presumably img_path starts with "./" - the first two
                # characters are dropped; confirm against the annotations.
                img_path = os.path.join(self.datasetroot, 'imgs', record['img_path'][2:])
                with Image.open(img_path) as source:
                    # PIL's resize expects (width, height).
                    image = source.convert('RGB').resize((self.W, self.H))
                break
            except Exception:
                print('error!')
                # Wrap around instead of running past the end of the list.
                index = (index + 1) % total
                record = self.dataset[index]
        if image is None:
            raise RuntimeError(f"No loadable image found in {self.datasetroot}")

        # NOTE(review): ToTensor scales to [0, 1], whereas the video datasets
        # in this file scale to [-1, 1] - confirm this is intended.
        transform = transforms.ToTensor()  # converts the PIL image to a tensor
        image_tensor = transform(image)
        img = self.catimage(image_tensor)
        return {
            "image": img,
            "text": record['prompt'],
            "fps": 120
        }

    def __len__(self):
        return len(self.dataset)