
# Code adapted from https://github.com/W-zx-Y/SM-PPM/blob/master/dataset/gta5_dataset.py
import os.path as osp
import numpy as np
from torch.utils import data
from PIL import Image
import torch
from collections import namedtuple


class Endovis18DataSet(data.Dataset):
    """Segmentation dataset for EndoVis 2018 surgical-instrument frames.

    Every non-blank line of ``list_path`` is a file name.
    - The RGB frame is read from ``<root>/images/<name>``.
    - The label map is read from ``<root>/gtlabels/<name>``.

    Raw label ids are remapped to train ids through ``id_to_trainid``, which is
    derived from ``classes``. Any pixel value that is not a known id becomes
    ``ignore_label``.

    Adapted from the SM-PPM GTA5 loader, so a few names follow the Cityscapes
    conventions used there.
    """

    NalendovisClass = namedtuple('NalendovisClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    classes = [
        NalendovisClass('Background Tissue', 0, 0, 'surgical background tissue', 0, False, False, (0, 0, 0)),
        NalendovisClass('Bipolar Forceps', 1, 1, 'bipolar forceps', 1, False, False, (255, 55, 0)),
        NalendovisClass('Grasper', 2, 2, 'grasper', 2, False, False, (0, 255, 0)),
        NalendovisClass('Large Needle Driver', 3, 3, 'large needle driver', 3, False, False, (24, 55, 125)),
        NalendovisClass('Monopolar Curved Scissors', 4, 4, 'monopolar curved scissors', 4, False, False,
                        (255, 255, 125)),
        NalendovisClass('Ultrasound Probe', 5, 5, 'ultrasound probe', 5, False, False, (187, 155, 25)),
        NalendovisClass('Suction Irrigator', 6, 6, 'suction irrigator', 6, False, False, (0, 255, 255)),
        NalendovisClass('Clip Applier', 7, 7, 'clip applier', 7, False, False, (255, 128, 0)),
    ]
    # RGB palette indexed by train id. Classes whose train id is -1 or 255
    # are excluded; none are at the moment.
    train_id_to_color = np.array([c.color for c in classes if c.train_id not in (-1, 255)])
    # Class names indexed by train id. Filtered exactly like the palette so
    # that the indices line up.
    train_id_to_name = np.array([c.name for c in classes if c.train_id not in (-1, 255)])

    def __init__(self, root, list_path, ignore_label=255, transform=None):
        """
        Args:
            root: dataset root containing the ``images/`` and ``gtlabels/`` folders.
            list_path: text file with one sample file name per line.
                Surrounding whitespace is stripped and blank lines are skipped.
            ignore_label: value written to label pixels whose raw id is not a
                key of ``id_to_trainid``.
            transform: callable ``(image, label) -> (image, label)`` applied to
                the PIL image/label pair in ``__getitem__``. When None, the
                pair is converted to plain tensors instead.
        """
        self.root = root
        self.list_path = list_path
        self.ignore_label = ignore_label
        self.transform = transform

        # Use a context manager so the list file is closed right after reading.
        with open(list_path) as list_file:
            self.img_ids = [line.strip() for line in list_file if line.strip()]

        # Raw label id -> train id, derived from ``classes`` so the two cannot
        # drift apart. With the table above this is the identity on 0..7.
        self.id_to_trainid = {c.id: c.train_id for c in self.classes}

        self.files = [
            {
                "img": osp.join(self.root, "images/%s" % name),
                "label": osp.join(self.root, "gtlabels/%s" % name),
                "name": name,
            }
            for name in self.img_ids
        ]

    @classmethod
    def decode_target(cls, target, ignore_index=255):
        """Convert a map of train ids into RGB colours.

        Args:
            target: array-like of train ids, e.g. a label map. It is copied
                and cast to int64, so the caller's array is never modified.
                Float inputs, such as the labels returned by ``__getitem__``,
                are also accepted.
            ignore_index: value treated as "ignore". Those pixels are painted black.

        Returns:
            ``np.ndarray`` of shape ``np.shape(target) + (3,)`` with colours
            taken from ``train_id_to_color``.

        Raises:
            IndexError: if ``target`` contains an id that is neither a valid
                train id nor ``ignore_index``.
        """
        n_classes = len(cls.train_id_to_color)
        # The class palette has no row for ignored pixels, so append a black
        # one at index ``n_classes``.
        palette = np.concatenate(
            [cls.train_id_to_color, np.zeros((1, 3), dtype=cls.train_id_to_color.dtype)]
        )
        ids = np.array(target, dtype=np.int64)  # copy + integer cast for indexing
        ids[ids == ignore_index] = n_classes
        return palette[ids]

    def __len__(self):
        """Return the number of samples listed in ``list_path``."""
        return len(self.files)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple ``(image_path, label_path, image, label)``.
            - ``image`` is whatever ``transform`` returns.
            - ``label`` is a float32 tensor of train ids. Pixels whose raw
              value is not in ``id_to_trainid`` are set to ``ignore_label``.
            Shapes depend on ``transform``. The original author notes e.g.
            image torch.Size([3, 512, 512]) and label torch.Size([512, 512])
            with values 0-7.
        """
        datafiles = self.files[index]

        image = Image.open(datafiles["img"]).convert('RGB')
        label = Image.open(datafiles["label"])

        if self.transform is not None:
            image, label = self.transform(image, label)
        else:
            # No transform configured: fall back to plain tensors rather than
            # calling None. The image becomes CHW uint8 and the label keeps its
            # stored values as int64.
            image = torch.from_numpy(np.array(image)).permute(2, 0, 1)
            label = torch.from_numpy(np.array(label, dtype=np.int64))

        # Remap raw label ids to train ids. Anything not listed keeps the
        # ignore value.
        label_copy = torch.full(tuple(label.shape), float(self.ignore_label), dtype=torch.float32)
        for raw_id, train_id in self.id_to_trainid.items():
            label_copy[label == raw_id] = train_id
        return datafiles["img"], datafiles["label"], image, label_copy

    @classmethod
    def name(cls, target):
        """Return the class name(s) for the given train id(s)."""
        return cls.train_id_to_name[target]