import skvideo
import torch
from torchvision import transforms, models
import torch.nn as nn
from torch.utils.data import Dataset
import skvideo.io
from PIL import Image
import os
import h5py
import numpy as np
import random
import time
import clip

class VideoDataset(Dataset):
    """Dataset that decodes whole videos for offline feature extraction.

    Indexing returns a dict with the video's name, its raw decoded frames
    (``'video1'``) and a normalized float tensor built from those same frames
    (``'video2'``).
    """

    def __init__(self, videos_dir, video_names):
        super().__init__()
        # Directory that every entry of ``video_names`` is resolved against.
        self.videos_dir = videos_dir
        # Video paths relative to ``videos_dir``; one dataset item per entry.
        self.video_names = video_names

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, idx):
        name = self.video_names[idx]
        raw = skvideo.io.vread(os.path.join(self.videos_dir, name))

        # Per-frame tensor conversion followed by channel normalization; the
        # mean/std values look like the usual ImageNet statistics.
        to_normalized = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # ``raw`` is indexed as (frames, height, width, channels).
        n_frames, height, width, n_channels = raw.shape[:4]
        normalized = torch.zeros([n_frames, n_channels, height, width])
        for position, pixels in enumerate(raw):
            normalized[position] = to_normalized(Image.fromarray(pixels))

        return {'name': name, 'video1': raw, 'video2': normalized}


def clip_feature(batch, text):
    print(batch.shape)
    output = torch.Tensor().to(device)
    i=0
    while (i < batch.shape[0]):
        output1 = torch.Tensor().to(device)
        b = batch[i, 0:224, 0:224, :]  # todo:第1-1块（第1行第1个）
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)

        b = batch[i, batch.shape[1]//2-112:batch.shape[1]//2+112, 0:224, :]  # todo:第2-1块（第2行第1个）
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)

        b = batch[i, batch.shape[1] - 224:batch.shape[1], 0:224, :]  # todo:第3-1块
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)

        b = batch[i, 0:224, batch.shape[2]//2-112:batch.shape[2]//2+112, :]  # todo:第1-2块
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)

        b = batch[i, batch.shape[1]//2-112:batch.shape[1]//2+112, batch.shape[2]//2-112:batch.shape[2]//2+112, :]  # todo:第2-2块
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)

        b = batch[i, batch.shape[1] - 224:batch.shape[1], batch.shape[2]//2-112:batch.shape[2]//2+112, :]  # todo:第3-2块
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)

        b = batch[i, 0:224, batch.shape[2]-224:batch.shape[2], :]  # todo:第1-3块
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)

        b = batch[i, batch.shape[1]//2-112:batch.shape[1]//2+112, batch.shape[2]-224:batch.shape[2], :]  # todo:第2-3块
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)

        b = batch[i, batch.shape[1]-224:batch.shape[1], batch.shape[2]-224:batch.shape[2], :]  # todo:第3-3块
        b = Image.fromarray(b)
        image = preprocess(b).unsqueeze(0).to(device)
        output1 = torch.cat((output1, image), 0)
        logits_per_image, logits_per_text = clip_model(output1, text)
        o = logits_per_image.softmax(dim=-1)

        frame_feature = o.contiguous().view(-1).unsqueeze(0)
        output = torch.cat((output, frame_feature), 0)
        #  print("[960,540,3]")
        i = i + 1
    return output




def get_features(text, current_video1, current_video2, frame_batch_size=64, device= 'cuda:0'):
    """Extract CLIP tile features for every frame of one video, in frame batches.

    Args:
        text: tokenized prompts forwarded to ``clip_feature``.
        current_video1: decoded frames; sliced along axis 0 into batches of
            ``frame_batch_size`` and passed to ``clip_feature``.
        current_video2: only its length (``shape[0]``) is used, to decide how
            many frames to process.
        frame_batch_size: number of frames handed to ``clip_feature`` at once;
            must be positive.
        device: device of the empty seed tensor used for concatenation.
            ``clip_feature`` is called without a device, so this argument does
            not control where the tile features are computed.

    Returns:
        The per-batch results of ``clip_feature`` concatenated along axis 0.

    Raises:
        ValueError: if ``frame_batch_size`` is not positive (it would
            otherwise loop forever).
    """
    if frame_batch_size <= 0:
        raise ValueError(f"frame_batch_size must be positive, got {frame_batch_size}")
    video_length = current_video2.shape[0]

    with torch.no_grad():
        chunks = [
            clip_feature(current_video1[start:min(start + frame_batch_size, video_length)], text)
            for start in range(0, video_length, frame_batch_size)
        ]
        # Concatenate once; the empty seed preserves the original dtype
        # promotion and yields an empty tensor for a zero-length video.
        output = torch.cat((torch.Tensor().to(device), *chunks), 0)
        print(output.shape)
    return output


if __name__ == "__main__":
    frame_batch_size = 64
    seed = 19920517
    # Seed every RNG and force deterministic cuDNN kernels for reproducible runs.
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    torch.utils.backcompat.broadcast_warning.enabled = True
    videos_dir = '/home/datasets/LSVQ/'  # videos dir
    features_dir = '/home/datasets/LSVQ/features/'  # features dir
    os.makedirs(features_dir, exist_ok=True)

    # NOTE: ``device``, ``clip_model`` and ``preprocess`` must stay bound at
    # module level -- clip_feature() reads them as globals.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with open('/home/datasets/examplar_data_labels/LSVQ/labels_test.txt', 'r', encoding='utf-8') as f:
        # The first comma-separated field of each line is the video path
        # relative to ``videos_dir``; blank lines (e.g. a trailing newline)
        # would otherwise become bogus video names.
        video_name = [line.split(',')[0].strip() for line in f if line.strip()]

    dataset = VideoDataset(videos_dir, video_name)
    clip_model, preprocess = clip.load("./models/ViT-B-32.pt", device)
    # 53 prompts; their order fixes the column layout of each tile's block of
    # the saved feature rows.
    text_list = ["good image block", "bad image block", "noisy image block", "hazy image block", "dark image block", "bright image block",
        "blurry image block", "over exposure image block", "sharp image block", "colorful image block", 'dull image block',
        "high contrast image block", "low contrast image block", "image block without noise", "image block without blur",
        "image block with additive gaussian noise", "image block with noise in color compression", "image block with spatially correlated noise",
        "image block with masked noise", "image block with high frequency noise", "image block with impulse noise",
        "image block with quantization noise", "image block with gaussian blur", "image block with motion blur", "image block with bokeh blur",
        "uniform color image block", "uneven color image block", "image block with chromatic aberration", "image block without chromatic aberration",
        "image block with distortions", "image block without distortions", "uniform illumination image block", "unevenly illuminated image block",
        "image block with sharpness loss", "image block without sharpness loss",
        "Light-hearted image block", "Depressing image block", "Comfortable image block", "Uncomfortable image block", "Sad image block",
        "Sentimental image block", "Fearful image block", "Exciting image block", "Satisfactory image block", "Calming image block",
        "Fascinating image block", "Interesting image block", "Impatient image block", "Tense image block", "Puzzling image block",
        "Delightful image block", "Outrageous image block", "Disgusting image block",
    ]
    text = clip.tokenize(text_list).to(device)
    for i in range(len(dataset)):
        current_data = dataset[i]
        name = current_data['name']
        current_video1 = current_data['video1']
        current_video2 = current_data['video2']
        print('Video {}: length {}'.format(i, current_video1.shape[0]))
        start_time = time.perf_counter()
        features = get_features(text, current_video1, current_video2, frame_batch_size, device)
        print(time.perf_counter() - start_time)
        # Name the output after the video file's stem. This works whether or
        # not the label carries a sub-directory, and keeps dots inside the stem.
        stem = os.path.splitext(os.path.basename(name))[0]
        np.save(os.path.join(features_dir, stem), features.to('cpu').numpy())
