from torchvision import transforms
from PIL import Image
from PIL import ImageFile

# Process-wide Pillow setting: load truncated image files instead of raising.
# NOTE(review): no image is opened with PIL in the code visible here (features
# come from .npy files) — confirm this is still needed.
ImageFile.LOAD_TRUNCATED_IMAGES = True

import os
import numpy as np
import torch
import torch.utils.data as data
import jsonlines
import random
import time
from clip import clip
import numpy as np

class AllDataset(data.Dataset):
    """Dataset of (shot features, caption) items plus a random false pair.

    Each JSON Lines record of ``args.json_file`` must provide:
      * ``'shots'``: names of sub-directories of ``args.data_dir``, each
        holding a ``fea.npy`` feature array;
      * ``'caption'``: text that is split on ``'.'`` into sentences, each
        tokenized with ``clip.tokenize``.

    Shot and sentence sequences are truncated / zero-padded to
    ``args.seq_len`` and returned with float masks (1 = real entry,
    0 = padding). An item with no shots yields a single all-zero placeholder
    whose mask entry stays 1.
    """

    # Placeholder for an item that lists no shots. NOTE(review): presumably
    # the per-shot shape of fea.npy — confirm against the feature extractor.
    _EMPTY_SHOT_SHAPE = (8, 512)

    def __init__(self, args):
        """Read every record of ``args.json_file`` into ``self.info``.

        Args:
            args: namespace providing ``data_dir``, ``seq_len`` and
                ``json_file``.
        """
        self.data_dir = args.data_dir
        self.seq_len = args.seq_len

        json_file = args.json_file
        with open(json_file, "r", encoding="utf8") as f:
            self.info = list(jsonlines.Reader(f))

        print('Load {} {}'.format(json_file, len(self.info)))

    def __getitem__(self, index):
        """Return the padded features/tokens for item ``index``.

        Returns:
            ``(shots, false_shots, text_infos, shot_mask, false_mask,
            text_mask)``: shot features of this item, shot features of a
            different randomly chosen item, stacked CLIP token sequences of
            the caption's sentences, and the three float masks of shape
            ``(seq_len,)``.

        Raises:
            ValueError: if the dataset has fewer than two items (no false
                pair can be drawn).
        """
        item = self.info[index]

        shots, shot_mask = self._load_shots(item)

        # get random shots for false (non-matching) pairs
        false_item = self.info[self._sample_negative_index(index)]
        false_shots, false_mask = self._load_shots(false_item)

        # get texts: one CLIP token sequence per sentence
        sentences = self._split_sentences(item['caption'])
        if not sentences:
            # An empty caption would leave nothing to pad from; tokenize the
            # empty string so the padding has a template of the right
            # shape/dtype (mirrors the zero placeholder used for shots).
            sentences = ['']
        # Only the first seq_len sentences are kept, so tokenize just those.
        text_tokens = [
            clip.tokenize(sentence).squeeze(0)
            for sentence in sentences[:self.seq_len]
        ]
        text_infos, text_mask = self._pad_to_seq_len(text_tokens, self.seq_len)

        return shots, false_shots, text_infos, shot_mask, false_mask, text_mask

    def __len__(self):
        return len(self.info)

    def _load_shots(self, item):
        """Load, truncate and zero-pad the features listed in ``item['shots']``.

        Returns:
            ``(features stacked along dim 0, float mask of shape (seq_len,))``.
        """
        # Shots beyond seq_len would be discarded anyway, so don't read them.
        features = [
            torch.from_numpy(np.load(os.path.join(self.data_dir, shot, 'fea.npy')))
            for shot in item['shots'][:self.seq_len]
        ]
        if not features:
            features.append(torch.zeros(self._EMPTY_SHOT_SHAPE))
        return self._pad_to_seq_len(features, self.seq_len)

    def _sample_negative_index(self, index):
        """Return a uniformly random index into ``self.info`` other than ``index``.

        Raises:
            ValueError: if there are fewer than two items; a rejection loop
                would never terminate in that case.
        """
        n = len(self.info)
        if n < 2:
            raise ValueError(
                'Need at least 2 items to sample a false pair, got {}'.format(n))
        # Normalize negative (Python-style) indices so the comparison below
        # really excludes the current item.
        index = index % n
        random_index = random.randint(0, n - 1)
        while random_index == index:
            random_index = random.randint(0, n - 1)
        return random_index

    @staticmethod
    def _split_sentences(text):
        """Split ``text`` on ``'.'`` into sentences, each re-terminated by ``'.'``.

        Empty or whitespace-only fragments (from trailing or repeated periods)
        are dropped.
        """
        return [
            fragment.strip() + '.'
            for fragment in text.strip().split('.')
            if fragment.strip()
        ]

    @staticmethod
    def _pad_to_seq_len(items, seq_len):
        """Zero-pad a non-empty list of equally shaped tensors up to ``seq_len``.

        Padding entries are ``torch.zeros_like(items[0])``; lists already at
        least ``seq_len`` long are left as they are (callers truncate first).

        Returns:
            ``(torch.stack(items, dim=0), mask)`` where ``mask`` is a float
            tensor of shape ``(seq_len,)`` with 0 at padded positions.
        """
        items = list(items)
        mask = torch.ones([seq_len], dtype=torch.float)
        n_pad = seq_len - len(items)
        if n_pad > 0:
            mask[len(items):] = 0
            items = items + [torch.zeros_like(items[0]) for _ in range(n_pad)]
        return torch.stack(items, dim=0), mask