import os
import numpy as np
import torch
from torch.utils import data
from glob import glob
from PIL import Image

import random
import torchvision.transforms as T
import torch.nn.functional as F
from datasets.transforms import load_image_in_PIL, To_One_Hot

class DAVIS17_train(data.Dataset):
    """Per-frame training dataset over DAVIS 2017 videos.

    Every frame of every selected video is one sample. A sample is the window
    of ``2 * N + 1`` RGB frames centred on that frame together with a validity
    mask telling which window slots fall inside the video.
    """

    def __init__(self, args, max_obj_n=11):
        """Index the videos listed in the requested splits.

        Args:
            args: Configuration object. The attributes read here are
                ``root`` (dataset root holding ``ImageSets`` and
                ``JPEGImages``), ``train_splits`` (tested with ``in`` for
                ``"train"`` and ``"valid"``), ``N`` (number of neighbouring
                frames on each side), ``resize_to`` (indexed as
                ``[0]`` = height, ``[1]`` = width) and ``patch_size``.
            max_obj_n: Not used by this class; kept so existing callers that
                pass it keep working.
        """
        self.root = args.root
        self.train_splits = args.train_splits

        self.N = args.N
        self.relative_orders = list(range(-self.N, self.N + 1))

        self.resize_to = args.resize_to

        self.patch_size = args.patch_size
        self.token_num = (self.resize_to[0] * self.resize_to[1]) // (self.patch_size * self.patch_size)

        # === Get Video Names ===
        self.dataset_list = []
        image_sets_dir = os.path.join(self.root, "ImageSets", "2017")

        # === Train Set ===
        if "train" in self.train_splits:
            self.dataset_list.extend(
                self._read_split(os.path.join(image_sets_dir, "train.txt")))

        # === Val Set ===
        if "valid" in self.train_splits:
            self.dataset_list.extend(
                self._read_split(os.path.join(image_sets_dir, "val.txt")))

        # === Get Video Lengths ===
        # Only the number of frames matters here, so no sorting is needed.
        self.video_lengths = [
            len(glob(os.path.join(self.root, 'JPEGImages', '480p', video_name, '*.jpg')))
            for video_name in self.dataset_list
        ]

        # Sorted frame paths per video, filled lazily by ``_frame_paths``.
        self._frame_path_cache = {}

        self.create_idx_frame_mapping()

        # === Transformations ===
        self.resize = T.Resize(self.resize_to)
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    @staticmethod
    def _read_split(split_file):
        """Return the stripped, non-empty lines of a split listing file."""
        with open(split_file, 'r') as lines:
            return [name for name in (line.strip() for line in lines) if name]

    @staticmethod
    def _frame_number(path):
        """Return the integer encoded in a frame file name (``00012.jpg`` -> 12).

        ``os.path.basename`` is used instead of splitting on ``"/"`` so the
        key also works where paths use a different separator.
        """
        return int(os.path.basename(path).split(".")[0])

    def _frame_paths(self, video_name):
        """Return the video's ``*.jpg`` paths sorted by frame number.

        The listing is computed once per video and then reused, so the
        directory is not re-scanned for every sample. Files added to the
        directory afterwards are therefore not picked up by this instance.
        """
        cache = getattr(self, "_frame_path_cache", None)
        if cache is None:
            # Instances built without __init__ still get a working cache.
            cache = self._frame_path_cache = {}

        if video_name not in cache:
            img_dir = os.path.join(self.root, "JPEGImages", "480p", video_name)
            cache[video_name] = sorted(glob(os.path.join(img_dir, "*.jpg")),
                                       key=self._frame_number)
        return cache[video_name]

    def __len__(self):
        """Return the total number of frames over all indexed videos."""
        return sum(self.video_lengths)

    def transform(self, image):
        """Resize an image, convert it to a tensor and normalise its channels."""
        image = self.resize(image)
        image = self.to_tensor(image)
        image = self.normalize(image)

        return image

    def create_idx_frame_mapping(self):
        """Build ``self.mapping``: flat sample index -> (video name, frame index)."""
        self.mapping = [
            (video_name, video_frame_idx)
            for video_name, video_length in zip(self.dataset_list, self.video_lengths)
            for video_frame_idx in range(video_length)
        ]

    def get_rgb(self, video_name, frame_idx):
        """Load the window of frames centred on ``frame_idx``.

        Args:
            video_name: Name of the video directory under ``JPEGImages/480p``.
            frame_idx: Index of the centre frame within the sorted frame list.

        Returns:
            A tuple ``(input_frames, mask)``. ``input_frames`` has shape
            ``(2N + 1, 3, resize_to[0], resize_to[1])``; slots whose frame
            would lie before the first or after the last frame stay all-zero.
            ``mask`` has shape ``(2N + 1,)`` and is 1 for loaded slots and 0
            for the zero-filled ones.
        """
        img_list = self._frame_paths(video_name)
        frame_num = len(img_list)

        input_frames = torch.zeros((2 * self.N + 1, 3, self.resize_to[0], self.resize_to[1]), dtype=torch.float)
        mask = torch.ones(2 * self.N + 1)

        for i, frame_order in enumerate(self.relative_orders):
            frame_idx_real = frame_idx + frame_order

            # Explicit bounds check: a negative index must not wrap around to
            # the end of the list.
            if frame_idx_real < 0 or frame_idx_real >= frame_num:
                mask[i] = 0
                continue

            frame = load_image_in_PIL(img_list[frame_idx_real], 'RGB')
            input_frames[i] = self.transform(frame)

        return input_frames, mask

    def __getitem__(self, idx):
        """
        :return:
            input_features: RGB frames [t-N, ..., t+N]
                            in shape (2*N + 1, 3, H, W)

            mask: Mask for input_features indicating if frame is available
                            in shape (2*N + 1)

        """
        video_name, frame_idx = self.mapping[idx]
        # === Frame inputs ===
        input_frames, frame_masks = self.get_rgb(video_name, frame_idx)             # (2N + 1, 3, H, W), (2N + 1)

        return input_frames, frame_masks


class DAVIS17_val(data.Dataset):
    def __init__(self, args, max_obj_n=11):
        self.root = args.root

        self.N = args.N
        self.relative_orders = list(range(-self.N, self.N + 1))
        self.resize_to = args.resize_to
        self.patch_size = args.patch_size
        self.token_num = (self.resize_to[0] * self.resize_to[1]) // (self.patch_size * self.patch_size)
        self.max_obj_n = max_obj_n

        dataset_path = os.path.join(self.root, "ImageSets", "2017/val.txt")
        self.dataset_list = list()

        with open(os.path.join(dataset_path), 'r') as lines:
            for line in lines:
                dataset_name = line.strip()
                if len(dataset_name) > 0:
                    self.dataset_list.append(dataset_name)

        # === Transformations ===
        self.resize = T.Resize(self.resize_to)
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

        self.to_onehot = To_One_Hot(max_obj_n, shuffle=False)

    def __len__(self):
        return len(self.dataset_list)

    def transform(self, image):

        image = self.resize(image)
        image = self.to_tensor(image)
        image = self.normalize(image)

        return image

    def get_rgb(self, video_name):
        img_dir = os.path.join(self.root, "JPEGImages", "480p", video_name)
        img_list = sorted(glob(os.path.join(img_dir, "*.jpg")), key=lambda x: int(x.split("/")[-1].split(".")[0]))
        frame_num = len(img_list)
        
        input_frames = torch.zeros(frame_num, 3, self.resize_to[0], self.resize_to[1], dtype=torch.float)
        for i in range(frame_num):
            frame = load_image_in_PIL(img_list[i], 'RGB')
            frame = self.transform(frame)
            input_frames[i] = frame

        model_input = torch.zeros(frame_num + 2 * self.N, 3, self.resize_to[0], self.resize_to[1], dtype=torch.float)
        input_masks = torch.ones(frame_num + 2 * self.N)
        
        for frame_idx in range(frame_num + 2 * self.N):
            
            frame_idx_real = frame_idx - self.N

            if frame_idx_real < 0 or frame_idx_real >= frame_num:
                input_masks[frame_idx] = 0
                continue

            model_input[frame_idx] = input_frames[frame_idx_real]

        return model_input, input_masks

    def get_gt_masks(self, video_name):
        mask_path = os.path.join(self.root, "Annotations_unsupervised", "480p", video_name)
        mask_list = sorted(glob(os.path.join(mask_path, "*.png")), key=lambda x: int(x.split("/")[-1].split(".")[0]))
        frame_num = len(mask_list)

        first_mask = load_image_in_PIL(mask_list[0], 'P')
        first_mask_np = np.array(first_mask, np.uint8)
        H, W = first_mask_np.shape
        obj_n = first_mask_np.max() + 1
        
        masks = torch.zeros(frame_num, self.max_obj_n, H, W, dtype=torch.float)
        for i in range(frame_num):
            mask = load_image_in_PIL(mask_list[i], 'P')
            mask = np.array(mask, np.uint8)
            if i == 0:
                mask, obj_list = self.to_onehot(mask)
                obj_n = len(obj_list) + 1
            else:
                mask, _ = self.to_onehot(mask, obj_list)

            masks[i] = mask

        return masks, obj_n

    def __getitem__(self, idx):
        """
        :return:
            model_input: (#frames, 2N + 1, 3, H, W)
            input_masks: (#frames, 2*N + 1)
            masks: (#frames, #objects, H, W)
        """

        video_name = self.dataset_list[idx]

        # === DINO Features of Frames ===
        input_frames, frame_masks = self.get_rgb(video_name)             # (#frames + 2N, 3, H, W), (#frames + 2N)
        mask, obj_n = self.get_gt_masks(video_name)
        mask = mask[:, :obj_n]


        return input_frames, frame_masks, mask
