from io import BytesIO
from PIL import Image
import os

import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import datasets
import json


class DavisInference(Dataset):
    """Per-frame DAVIS dataset returning images and semantically relabelled masks.

    Layout read by this class (all relative to ``root``):

    * ``DAVIS/semantics/davis_semantics.json`` -- maps a sequence name to a
      dict ``{object id in the mask (as a string): category name}``.
    * ``DAVIS/semantics/categories.json`` -- maps a category name to a dict
      that contains at least an ``'id'`` entry.
    * ``DAVIS/ImageSets/2017/<split>.txt`` -- one sequence name per line.
    * ``DAVIS_<split>/JPEGImages/480p/<sequence>/`` -- the frames.
    * ``DAVIS_<split>/Annotations/480p/<sequence>/`` -- the masks, looked up
      under the frame's file name with a ``.png`` extension.

    Args:
        root: Directory containing the ``DAVIS`` and ``DAVIS_<split>`` folders.
        split: Split name; selects both the image-set file and the data folder.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the loaded mask.
    """

    def __init__(self, root, split, transform=None, target_transform=None):
        super().__init__()
        self.transform = transform
        self.target_transform = target_transform
        self.root = os.path.join(root, f'DAVIS_{split}')
        self.img_paths = []
        self.target_paths = []
        self.categories = []

        with open(os.path.join(root, 'DAVIS/semantics/davis_semantics.json'), 'r',
                  encoding='utf-8') as f:
            self.semantic_labels = json.load(f)

        with open(os.path.join(root, 'DAVIS/semantics/categories.json'), 'r',
                  encoding='utf-8') as f:
            self.categories_map = json.load(f)

        split_file = os.path.join(root, 'DAVIS', 'ImageSets', '2017', f'{split}.txt')
        with open(split_file, encoding='utf-8') as f:
            for line in f:
                category = line.strip()
                if not category:
                    # A blank (e.g. trailing) line would otherwise make us
                    # list the parent '480p' directory itself.
                    continue
                img_dir = os.path.join(self.root, 'JPEGImages', '480p', category)
                target_dir = os.path.join(self.root, 'Annotations', '480p', category)

                # os.listdir order is arbitrary; sort so frames of a sequence
                # are indexed in a deterministic (file-name) order.
                for img_name in sorted(os.listdir(img_dir)):
                    # Swap only the file extension: a plain str.replace on the
                    # full path would also rewrite 'jpg' inside directory names.
                    target_name = os.path.splitext(img_name)[0] + '.png'
                    self.img_paths.append(os.path.join(img_dir, img_name))
                    self.target_paths.append(os.path.join(target_dir, target_name))
                    self.categories.append(category)
        print(f'Found {len(self.img_paths)} images in {split} split, and {len(self.categories_map)} categories')

    def _remap_labels(self, target, semantic_labels):
        """Return a copy of ``target`` with object ids replaced by category ids.

        Args:
            target: Tensor of per-pixel object ids.
            semantic_labels: Dict mapping object id (string) to category name
                for the sequence the mask belongs to.

        Returns:
            A new tensor shaped like ``target``. Pixels whose id is not listed
            in ``semantic_labels`` are left at 0.
        """
        # Write into a fresh tensor so an already-assigned category id can
        # never be mistaken for a source object id of a later entry.
        new_target = torch.zeros_like(target)
        for source_key, category_name in semantic_labels.items():
            source_id = int(source_key)
            target_id = self.categories_map[category_name]['id']
            new_target[target == source_id] = target_id
        return new_target

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, relabelled_mask)``.

        The image is returned as produced by ``transform`` (or as the opened
        PIL image when no transform is set). The mask is always a tensor.

        Raises:
            IndexError: If ``idx`` is out of range.
            KeyError: If the sequence or one of its categories is missing from
                the semantic metadata.
        """
        img_path = self.img_paths[idx]
        target_path = self.target_paths[idx]
        semantic_labels = self.semantic_labels[self.categories[idx]]

        img = Image.open(img_path)
        target = Image.open(target_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        if not torch.is_tensor(target):
            # Without a tensor-producing target_transform the mask is still a
            # PIL image (or array); the relabelling below needs a tensor.
            target = torch.as_tensor(np.array(target))

        return img, self._remap_labels(target, semantic_labels)

    def __len__(self):
        """Return the total number of frames across all listed sequences."""
        return len(self.img_paths)
