import os
# import random
import numpy as np
import torch
import json
import random
from torch.utils import data
from glob import glob
from PIL import Image


import torchvision.transforms as T
import torch.nn.functional as F
from datasets.transforms import load_image_in_PIL, To_One_Hot

class MOVi_train(data.Dataset):
    """Frame-centred training dataset over MOVi videos stored as ``.npy`` frames.

    Every item is a temporal window of ``2 * N + 1`` RGB frames centred on one
    frame of one video, together with a validity mask telling which window
    slots fall inside the video.

    Directory layout read by this class (derived from the paths built below)::

        <root>/<split>/<video_name>/Frames/<frame_index>.npy

    where ``<split>`` is ``train`` and/or ``val``.
    """

    # Number of centre frames every video contributes to the index.
    # NOTE(review): the actual number of files on disk is not checked here;
    # ``get_rgb`` masks out window slots that have no matching file.
    FRAMES_PER_VIDEO = 24

    # (keyword searched in ``args.train_splits``, directory name under root)
    _SPLITS = (("train", "train"), ("valid", "val"))

    def __init__(self, args):
        """Index the requested splits and build the image transforms.

        :param args: object exposing ``root``, ``train_splits``, ``N``,
                     ``resize_to`` (H, W) and ``patch_size``.
        """
        self.root = args.root
        self.train_splits = args.train_splits

        self.N = args.N
        self.relative_orders = list(range(-self.N, self.N + 1))

        self.resize_to = args.resize_to
        self.patch_size = args.patch_size
        self.token_num = (self.resize_to[0] * self.resize_to[1]) // (self.patch_size * self.patch_size)

        # === Get Video Names and Lengths ===
        self.dataset_list = []
        self.video_lengths = []
        self.split_name = []

        for keyword, split_dir in self._SPLITS:
            if keyword not in self.train_splits:
                continue
            for video_name in self._list_videos(split_dir):
                self.dataset_list.append(video_name)
                self.video_lengths.append(self.FRAMES_PER_VIDEO)
                self.split_name.append(split_dir)

        self.create_idx_frame_mapping()

        # === Transformations ===
        self.resize = T.Resize(self.resize_to)
        self.resize_nn = T.Resize(self.resize_to, T.InterpolationMode.NEAREST)
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def __len__(self):
        """Return the total number of centre frames over all indexed videos."""
        return sum(self.video_lengths)

    def _list_videos(self, split_dir):
        """Return the sorted video directory names found under ``root/split_dir``."""
        # The trailing empty component makes the pattern end with a separator,
        # so only directories match.
        pattern = os.path.join(self.root, split_dir, "*", "")
        # dirname() strips the trailing separator; basename() then yields the
        # directory name whatever separator the platform uses.
        return [os.path.basename(os.path.dirname(path)) for path in sorted(glob(pattern))]

    @staticmethod
    def _frame_index(path):
        """Return the integer frame index encoded in a frame file name."""
        return int(os.path.basename(path).split(".")[0])

    def transform(self, image):
        """Resize ``image`` to ``resize_to`` and apply channel normalisation.

        NOTE(review): ``ToTensor`` is not applied, so no 1/255 scaling happens
        here, while the Normalize constants are the usual ImageNet mean/std for
        [0, 1] inputs -- confirm the stored arrays are already scaled.
        """
        image = self.resize(image)
        image = self.normalize(image)

        return image

    def create_idx_frame_mapping(self):
        """Build ``self.mapping``: flat index -> (video_name, frame_idx, split_name)."""
        self.mapping = [
            (video_name, frame_idx, split_name)
            for video_name, split_name, video_length in zip(
                self.dataset_list, self.split_name, self.video_lengths
            )
            for frame_idx in range(video_length)
        ]

    def get_rgb(self, idx):
        """Load the temporal window centred on flat index ``idx``.

        :return:
            input_frames: float tensor (2*N + 1, 3, H, W); slots without a
                          frame on disk stay zero
            mask: float tensor (2*N + 1), 1 where a frame was loaded, else 0
        """
        video_name, frame_idx, split_name = self.mapping[idx]
        img_dir = os.path.join(self.root, split_name, video_name, "Frames")
        # Sort numerically by frame index, not lexicographically by file name.
        img_list = sorted(glob(os.path.join(img_dir, "*.npy")), key=self._frame_index)
        frame_num = len(img_list)

        input_frames = torch.zeros((2 * self.N + 1, 3, self.resize_to[0], self.resize_to[1]), dtype=torch.float)
        mask = torch.ones(2 * self.N + 1)

        for slot, frame_order in enumerate(self.relative_orders):
            frame_idx_real = frame_idx + frame_order

            # Window slots that reach before the first or past the last frame
            # are left as zeros and flagged as unavailable.
            if not 0 <= frame_idx_real < frame_num:
                mask[slot] = 0
                continue

            # Stored arrays are indexed as (H, W, C); move channels first.
            frame = np.load(img_list[frame_idx_real])
            frame = torch.from_numpy(frame).permute(2, 0, 1).float()
            input_frames[slot] = self.transform(frame)

        return input_frames, mask

    def __getitem__(self, idx):
        """
        :return:
            input_features: RGB frames [t-N, ..., t+N]
                                in shape (2*N + 1, 3, H, W)
            frame_masks: Mask for input_features indicating if frame is available
                            in shape (2*N + 1)

        """
        return self.get_rgb(idx)             # (2N + 1, 3, H, W), (2N + 1)



class MOVi_val(data.Dataset):
    """Video-level validation dataset over the MOVi ``val`` split.

    Every item is one whole video: all of its RGB frames padded with ``N``
    empty slots on both ends, a validity mask for the padded sequence, and the
    one-hot ground-truth segmentation masks.

    Directory layout read by this class (derived from the paths built below)::

        <root>/val/<video_name>/Frames/<frame_index>.npy
        <root>/val/<video_name>/Annotations/<frame_index>.npy
    """

    # Nominal frame count recorded per video in ``video_lengths``; the loaders
    # below use the number of files actually found on disk instead.
    FRAMES_PER_VIDEO = 24

    def __init__(self, args, max_obj_n=20):
        """Index the validation videos and build the transforms.

        :param args: object exposing ``root``, ``N``, ``resize_to`` (H, W) and
                     ``patch_size``.
        :param max_obj_n: number of channels allocated for the one-hot masks.
        """
        self.root = args.root

        self.N = args.N
        self.relative_orders = list(range(-self.N, self.N + 1))
        self.resize_to = args.resize_to
        self.patch_size = args.patch_size
        self.token_num = (self.resize_to[0] * self.resize_to[1]) // (self.patch_size * self.patch_size)
        self.max_obj_n = max_obj_n

        # === Get Video Names and Lengths ===
        self.dataset_list = self._list_videos()
        self.video_lengths = [self.FRAMES_PER_VIDEO] * len(self.dataset_list)

        # === Transformations ===
        self.resize = T.Resize(self.resize_to)
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        self.to_one_hot = To_One_Hot(self.max_obj_n, shuffle=False)

    def __len__(self):
        """Return the number of validation videos."""
        return len(self.dataset_list)

    def _list_videos(self):
        """Return the sorted video directory names found under ``root/val``."""
        # glob() gives no ordering guarantee; sorting keeps the index -> video
        # assignment reproducible across runs and machines.
        pattern = os.path.join(self.root, "val", "*", "")
        return [os.path.basename(os.path.dirname(path)) for path in sorted(glob(pattern))]

    @staticmethod
    def _frame_index(path):
        """Return the integer frame index encoded in a ``.npy`` file name."""
        return int(os.path.basename(path).split(".")[0])

    def transform(self, image):
        """Resize ``image`` to ``resize_to`` and apply channel normalisation.

        NOTE(review): ``ToTensor`` is not applied, so no 1/255 scaling happens
        here, while the Normalize constants are the usual ImageNet mean/std for
        [0, 1] inputs -- confirm the stored arrays are already scaled.
        """
        image = self.resize(image)
        image = self.normalize(image)

        return image

    def get_rgb(self, video_name):
        """Load every frame of ``video_name`` into a sequence padded by ``N``.

        :return:
            model_input: float tensor (#frames + 2N, 3, H, W); the first and
                         last N slots are zero padding
            input_masks: float tensor (#frames + 2N), 1 for real frames and
                         0 for padding
        """
        img_dir = os.path.join(self.root, "val", video_name, "Frames")
        # Sort numerically by frame index, not lexicographically by file name.
        img_list = sorted(glob(os.path.join(img_dir, "*.npy")), key=self._frame_index)
        frame_num = len(img_list)
        padded_len = frame_num + 2 * self.N

        model_input = torch.zeros(padded_len, 3, self.resize_to[0], self.resize_to[1], dtype=torch.float)
        input_masks = torch.zeros(padded_len)
        input_masks[self.N:self.N + frame_num] = 1

        # Write each frame straight into its padded slot instead of staging
        # the whole video in a second tensor first.
        for offset, path in enumerate(img_list):
            # Stored arrays are indexed as (H, W, C); move channels first.
            frame = torch.from_numpy(np.load(path)).permute(2, 0, 1).float()
            model_input[self.N + offset] = self.transform(frame)

        return model_input, input_masks

    def get_gt_masks(self, video_name):
        """Load the annotations of ``video_name`` as one-hot masks.

        The object list returned by ``to_one_hot`` for the first frame is
        passed back in for every later frame.

        :return:
            masks: float tensor (#frames, max_obj_n, H, W) at the resolution of
                   the first annotation file
            obj_n: ``len(obj_list) + 1`` for the first frame's object list
                   (presumably objects plus background -- confirm against
                   ``To_One_Hot``)
        :raises IndexError: if the video has no annotation files.
        """
        mask_path = os.path.join(self.root, "val", video_name, "Annotations")
        mask_list = sorted(glob(os.path.join(mask_path, "*.npy")), key=self._frame_index)
        frame_num = len(mask_list)

        first_mask = np.load(mask_list[0])
        H, W = first_mask.shape

        masks = torch.zeros(frame_num, self.max_obj_n, H, W, dtype=torch.float)

        masks[0], obj_list = self.to_one_hot(first_mask)
        obj_n = len(obj_list) + 1

        for i, path in enumerate(mask_list[1:], start=1):
            masks[i], _ = self.to_one_hot(np.load(path), obj_list)

        return masks, obj_n

    def __getitem__(self, idx):
        """
        :return:
            input_frames: (#frames + 2N, 3, H, W), zero padded by N on each side
            frame_masks: (#frames + 2N), 1 for real frames and 0 for padding
            mask: (#frames, obj_n, H_mask, W_mask) one-hot ground truth,
                  truncated to the first obj_n channels
        """

        video_name = self.dataset_list[idx]

        # === Padded RGB frames and ground-truth masks ===
        input_frames, frame_masks = self.get_rgb(video_name)             # (#frames + 2N, 3, H, W), (#frames + 2N)
        mask, obj_n = self.get_gt_masks(video_name)
        mask = mask[:, :obj_n]

        return input_frames, frame_masks, mask