import json
import os
from collections import namedtuple

import torch.utils.data as data
from PIL import Image
import numpy as np

class Nalendovis(data.Dataset):
    """Segmentation dataset with the 8 instrument classes listed in ``classes``.

    Adapted from a torchvision-style Cityscapes loader. With the default layout,
    images are read from ``<root>/images/<split>/<subdir>/<file>`` and labels from
    ``<root>/labels/<split>/<subdir>/<file with '.png' -> '_labelTrainIds.png'>``.
    With ``dataset='ACDC'`` the ACDC layout is used instead: images under
    ``<root>/rgb_anon/<ACDC_sub>/<split>/<subdir>/`` and labels under
    ``<root>/gt/<ACDC_sub>/<split>/<subdir>/`` named
    ``<prefix before '_rgb_anon'>_gt_<suffix for target_type>``.

    **Parameters:**
        - **root** (string): Root directory of the dataset.
        - **dataset** (string, optional): ``'ACDC'`` selects the ACDC directory layout and
          label naming; any other value uses the ``images``/``labels`` layout.
        - **split** (string, optional): ``'train'``, ``'test'`` or ``'val'``.
        - **mode** (string, optional): Currently ignored; the label folder name is fixed by
          ``dataset`` (``'gt'`` for ACDC, ``'labels'`` otherwise). Kept for API compatibility.
        - **target_type** (string, optional): Label kind used by the ACDC naming scheme:
          ``'semantic'``, ``'instance'``, ``'color'``, ``'polygon'`` or ``'depth'``.
        - **transform** (callable, optional): Joint transform called as
          ``transform(image, target)`` that returns ``(image, target)``.
        - **ACDC_sub** (string, optional): ACDC condition sub-folder. When it equals
          ``'fog'``, only 1920x1080 images are kept (regardless of ``dataset``).
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    NalendovisClass = namedtuple('NalendovisClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    classes = [
        NalendovisClass('background', 0, 0, 'background', 0, False, False, (0, 0, 0)),
        NalendovisClass('Bipolar Forceps', 1, 1, 'Bipolar Forceps', 1, False, False, (255, 55, 0)),
        NalendovisClass('Prograsp Forceps', 2, 2, 'Prograsp Forceps', 2, False, False, (0, 255, 0)),
        NalendovisClass('Large Needle Driver', 3, 3, 'Large Needle Driver', 3, False, False, (24, 55, 125)),
        NalendovisClass('Monopolar Curved Scissors', 4, 4, 'Monopolar Curved Scissors', 4, False, False,
                        (255, 255, 125)),
        NalendovisClass('Unknow Ultrasound Probe', 5, 5, 'Unknow Ultrasound Probe', 5, False, False,
                        (187, 155, 25)),
        NalendovisClass('Suction Instrument', 6, 6, 'Suction Instrument', 6, False, False, (0, 255, 255)),
        NalendovisClass('Clip Applier', 7, 7, 'Clip Applier', 7, False, False, (255, 128, 0)),
    ]

    # Palette indexed by train_id: row i is the RGB colour of train_id i
    # (classes with train_id -1 or 255 are excluded).
    train_id_to_color = np.array([c.color for c in classes if c.train_id not in (-1, 255)])
    # Display name indexed by train_id.
    train_id_to_name = np.array([c.name for c in classes if c.train_id != 255])
    # Lookup table from the raw label id (pixel value in the label image) to the training train_id.
    id_to_train_id = np.array([c.train_id for c in classes])

    def __init__(self, root, dataset='nalendovis', split='train', mode='fine', target_type='semantic',
                 transform=None, ACDC_sub='night'):
        self.root = os.path.expanduser(root)
        self.dataset = dataset
        self.ACDC_sub = ACDC_sub
        is_acdc = self.dataset == 'ACDC'

        # NOTE: the ``mode`` argument is intentionally unused; the label folder
        # name is determined by the dataset layout.
        self.mode = 'gt' if is_acdc else 'labels'
        self.target_type = target_type

        if is_acdc:
            self.images_dir = os.path.join(self.root, 'rgb_anon', self.ACDC_sub, split)
            print(self.images_dir)
            self.targets_dir = os.path.join(self.root, self.mode, self.ACDC_sub, split)
        else:
            self.images_dir = os.path.join(self.root, 'images', split)
            self.targets_dir = os.path.join(self.root, self.mode, split)

        self.transform = transform

        self.split = split
        self.images = []
        self.targets = []

        if split not in ['train', 'test', 'val']:
            raise ValueError('Invalid split for mode! Please use split="train", split="test"'
                             ' or split="val"')

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')

        # The ACDC label suffix depends only on mode/target_type, so compute it once
        # (this also rejects an unknown target_type before scanning the disk).
        acdc_suffix = self._get_target_suffix(self.mode, self.target_type) if is_acdc else None

        # Each sub-directory of images_dir (a "city" in the Cityscapes layout) holds the
        # image files; labels live in the same-named sub-directory of targets_dir.
        # Listings are sorted so the sample order is deterministic across file systems.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            # Skip plain files and any nested folder named 'train'.
            if not os.path.isdir(img_dir) or city.lower() == 'train':
                continue
            target_dir = os.path.join(self.targets_dir, city)

            for file_name in sorted(os.listdir(img_dir)):
                img_path = os.path.join(img_dir, file_name)

                if self.ACDC_sub == "fog":
                    # Keep only 1920x1080 images. Image.open reads just the header,
                    # so .size is available without decoding the pixel data.
                    with Image.open(img_path) as img:
                        if img.size != (1920, 1080):
                            continue

                if is_acdc:
                    target_name = '{}_{}'.format(file_name.split('_rgb_anon')[0], acdc_suffix)
                else:
                    target_name = file_name.replace('.png', '_labelTrainIds.png')

                self.images.append(img_path)
                self.targets.append(os.path.join(target_dir, target_name))

    @classmethod
    def encode_target(cls, target):
        """Map raw label ids to train ids using the ``id_to_train_id`` lookup table.

        ``target`` may be anything ``np.array`` can convert (PIL image, ndarray, ...);
        every value must be a valid index into ``id_to_train_id``.
        """
        return cls.id_to_train_id[np.array(target)]

    @classmethod
    def decode_target(cls, target):
        """Convert a train-id map into an RGB colour image.

        Pixels equal to 255 (the conventional "ignore" value) are painted black.
        The input is not modified.

        Returns:
            np.ndarray: Array of shape ``target.shape + (3,)``.
        """
        # Work on a copy so the caller's label map is not mutated.
        target = np.array(target, copy=True)
        # Append one black row for "ignore" pixels. The old code indexed row 19,
        # which does not exist in an 8-row palette.
        ignore_row = np.zeros((1, 3), dtype=cls.train_id_to_color.dtype)
        palette = np.concatenate([cls.train_id_to_color, ignore_row])
        target[target == 255] = len(cls.train_id_to_color)
        return palette[target]

    @classmethod
    def name(cls, target):
        """Return the class name(s) for the given train id(s)."""
        return cls.train_id_to_name[target]

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Index into ``self.images`` / ``self.targets``.

        Returns:
            tuple: ``(image_path, target_path, image, target)``. ``image`` and ``target``
            are the outputs of ``self.transform`` (the PIL images when no transform is set),
            and ``target`` has additionally been passed through :meth:`encode_target`.
        """
        image_path = self.images[index]
        target_path = self.targets[index]

        # Image.open is lazy; load the pixels and close the file handles right away.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        with Image.open(target_path) as tgt:
            target = tgt.copy()

        if self.transform is not None:
            image, target = self.transform(image, target)
        # Map raw label ids to train ids.
        target = self.encode_target(target)
        # Return the image path, label path, image and label.
        return image_path, target_path, image, target

    def __len__(self):
        return len(self.images)

    def _load_json(self, path):
        """Parse and return the JSON document at ``path``."""
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode, target_type):
        """Return the ACDC-style label file suffix for ``target_type``.

        Raises:
            ValueError: If ``target_type`` is not a supported kind.
        """
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        elif target_type == 'polygon':
            return '{}_polygons.json'.format(mode)
        elif target_type == 'depth':
            return '{}_disparity.png'.format(mode)
        raise ValueError('Unknown target_type {!r}; expected one of '
                         "'instance', 'semantic', 'color', 'polygon', 'depth'".format(target_type))
