import os

from collections import namedtuple
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision.transforms import functional

from utils import get_logger, to_float, handle_labels
from torchvision import transforms as T
import torchvision.transforms.functional as TF


# One row of the iSAID label table below. Fields, in order:
#   name         -- unique identifier of the class, e.g. 'ship', 'plane'.
#   id           -- integer id written into ground-truth id maps.
#   trainId      -- id used during training; several labels may share one.
#                   Must not exceed 255.
#   category     -- name of the coarse category the label belongs to.
#   categoryId   -- integer id of that category.
#   hasInstances -- whether the class distinguishes individual instances.
#   ignoreInEval -- whether pixels of this class are ignored in evaluation.
#   color        -- (c0, c1, c2) channel values of the class colour in the
#                   colour-coded label images.
#   m_color      -- the same colour packed into one int:
#                   c0 + 256 * c1 + 65536 * c2.
Label = namedtuple('Label', (
    'name',
    'id',
    'trainId',
    'category',
    'categoryId',
    'hasInstances',
    'ignoreInEval',
    'color',
    'm_color',
))

# The iSAID class table.
# - 'unlabeled' has trainId 255 and ignoreInEval=True; every other class has
#   trainId == id - 1.
# - m_color is always color[0] + 256*color[1] + 65536*color[2], e.g.
#   (0, 0, 63) -> 63 * 65536 = 4128768. color2id below is keyed on m_color.
#   Dataset.__getitem__ packs label-image pixels with the same formula, so
#   keep the two columns consistent when editing a row.
labels = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color          m_color (c0 + 256*c1 + 65536*c2)
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) , 0      ),
    Label(  'ship'                 ,  1 ,        0 , 'transport'       , 1       , True         , False        , (  0,  0, 63) , 4128768),
    Label(  'storage_tank'         ,  2 ,        1 , 'transport'       , 1       , True         , False        , (  0, 63, 63) , 4144896),
    Label(  'baseball_diamond'     ,  3 ,        2 , 'land'            , 2       , True         , False        , (  0, 63,  0) , 16128  ),
    Label(  'tennis_court'         ,  4 ,        3 , 'land'            , 2       , True         , False        , (  0, 63,127) , 8339200),
    Label(  'basketball_court'     ,  5 ,        4 , 'land'            , 2       , True         , False        , (  0, 63,191) , 12533504),
    Label(  'Ground_Track_Field'   ,  6 ,        5 , 'land'            , 2       , True         , False        , (  0, 63,255) , 16727808),
    Label(  'Bridge'               ,  7 ,        6 , 'land'            , 2       , True         , False        , (  0,127, 63) , 4161280),
    Label(  'Large_Vehicle'        ,  8 ,        7 , 'transport'       , 1       , True         , False        , (  0,127,127) , 8355584),
    Label(  'Small_Vehicle'        ,  9 ,        8 , 'transport'       , 1       , True         , False        , (  0,  0,127) , 8323072),
    Label(  'Helicopter'           , 10 ,        9 , 'transport'       , 1       , True         , False        , (  0,  0,191) , 12517376),
    Label(  'Swimming_pool'        , 11 ,       10 , 'land'            , 2       , True         , False        , (  0,  0,255) , 16711680),
    Label(  'Roundabout'           , 12 ,       11 , 'land'            , 2       , True         , False        , (  0,191,127) , 8371968),
    Label(  'Soccer_ball_field'    , 13 ,       12 , 'land'            , 2       , True         , False        , (  0,127,191) , 12549888),
    Label(  'plane'                , 14 ,       13 , 'transport'       , 1       , True         , False        , (  0,127,255) , 16744192),
    Label(  'Harbor'               , 15 ,       14 , 'transport'       , 1       , True         , False        , (  0,100,155) , 10183680),
]

# Packed label colour (the m_color column of `labels`) -> label id.
color2id = dict((entry.m_color, entry.id) for entry in labels)

# Label id -> human-readable class name, built from `labels`.
id2name = dict((entry.id, entry.name) for entry in labels)

def _collect_patches(directory, suffix='_instance_color_RGB.png'):
    """Return sorted ``(image_path, label_path)`` pairs found in *directory*.

    Every file whose name ends with *suffix* is a label image.  Its matching
    image is the file with the same stem and *suffix* replaced by ``.png``
    (``P0001_instance_color_RGB.png`` -> ``P0001.png``).
    """
    return [
        (
            os.path.join(directory, fn[:-len(suffix)] + '.png'),  # image
            os.path.join(directory, fn),                          # label
        )
        for fn in sorted(os.listdir(directory))
        if fn.endswith(suffix)
    ]


def get_isaid_datasets(args):
    """Build the iSAID training and validation datasets.

    Scans ``<args.isaid_dir>/train/images`` and ``<args.isaid_dir>/val/images``
    for ``*_instance_color_RGB.png`` label files and pairs each with its image.

    Args:
        args: run configuration; must provide ``isaid_dir`` and whatever
            ``get_logger`` reads.

    Returns:
        ``(training_dataset, validation_dataset)``.  The first is built with
        ``mode='train'`` and the second with ``mode='validation'``.
    """
    logger = get_logger(__name__, args)

    training_patches = _collect_patches(os.path.join(args.isaid_dir, 'train/images'))
    logger.info(f"Added {len(training_patches)} training patches.")

    validation_patches = _collect_patches(os.path.join(args.isaid_dir, 'val/images'))
    logger.info(f"Added {len(validation_patches)} validation patches.")

    training_dataset = Dataset(args, training_patches, mode='train')
    validation_dataset = Dataset(args, validation_patches, mode='validation')
    return training_dataset, validation_dataset


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of iSAID patches yielding ``(image, label)`` pairs.

    ``patch_list`` holds ``(image_path, label_path)`` tuples.  Label files are
    colour-coded images.  Each pixel colour is packed as
    ``c0 + 256*c1 + 65536*c2`` and translated to a label id through the
    module-level ``color2id`` table.
    """

    # Required height and width, in pixels, of every image patch.
    PATCH_SIZE = 896

    def __init__(self, args, patch_list, mode='train'):
        """Store the configuration and the list of patches.

        Args:
            args: run configuration; kept on ``self.args``.
            patch_list: sequence of ``(image_path, label_path)`` tuples.
            mode: ``'train'`` enables random flip / 90-degree rotation
                augmentation.  Any other value (e.g. ``'validation'``)
                disables it.
        """
        super().__init__()
        self.args = args
        self.patches = patch_list
        self.mode = mode

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, index):
        """Load, validate and (in ``'train'`` mode) augment patch ``index``.

        Returns:
            ``(img, label)``.  ``img`` is the tensor produced by
            ``TF.to_tensor``.  ``label`` is an int64 tensor of label ids with
            the same height and width as ``img``.

        Raises:
            ValueError: if the image is not ``PATCH_SIZE`` x ``PATCH_SIZE``,
                or if the label's size differs from the image's.
            KeyError: if the label image contains a colour absent from
                ``color2id``.
        """
        fn_img, fn_label = self.patches[index]

        # ``with`` closes the file handles Image.open keeps open lazily.  Both
        # conversions below copy the pixel data out before the file is closed.
        with Image.open(fn_img) as pil_img:
            img = TF.to_tensor(pil_img)
        if img.shape[1] != self.PATCH_SIZE or img.shape[2] != self.PATCH_SIZE:
            raise ValueError(
                f"{fn_img}: expected a {self.PATCH_SIZE}x{self.PATCH_SIZE} "
                f"image, got tensor shape {tuple(img.shape)}"
            )

        with Image.open(fn_label) as pil_label:
            # Read exact integer channels.  This avoids relying on
            # to_tensor(...) * 255 round-tripping exactly through float32.
            rgb = np.asarray(pil_label.convert('RGB'), dtype=np.int64)
        label = torch.from_numpy(self._rgb_to_ids(rgb))
        if tuple(label.shape) != tuple(img.shape[1:]):
            raise ValueError(
                f"{fn_label}: label size {tuple(label.shape)} does not match "
                f"image size {tuple(img.shape[1:])}"
            )

        if self.mode == 'train':
            # Draw from torch's RNG.  DataLoader reseeds it in every worker,
            # whereas numpy's global RNG can be duplicated across forked
            # workers on older PyTorch, which would repeat augmentations.
            turns = int(torch.randint(4, (1,)))
            vflip, hflip = torch.randint(2, (2,)).tolist()

            if vflip:
                img, label = img.flip(-2), label.flip(-2)
            if hflip:
                img, label = img.flip(-1), label.flip(-1)
            if turns:
                # Exact 90-degree turns with no interpolation.  They turn
                # counter-clockwise, the same sense as a positive TF.rotate angle.
                img = torch.rot90(img, turns, dims=(-2, -1))
                label = torch.rot90(label, turns, dims=(-2, -1))

        return img, label

    @staticmethod
    def _rgb_to_ids(rgb):
        """Translate an (H, W, 3) integer colour array into (H, W) int64 label ids.

        Pixels are packed as ``c0 + 256*c1 + 65536*c2``, the ``m_color``
        encoding that keys ``color2id``.  They are resolved with one
        vectorised binary search instead of a Python callback per pixel.

        Raises:
            KeyError: carrying the packed value of the first unknown colour.
        """
        packed = rgb[..., 0] + rgb[..., 1] * 256 + rgb[..., 2] * 65536
        pairs = sorted(color2id.items())
        keys = np.array([colour for colour, _ in pairs], dtype=np.int64)
        ids = np.array([label_id for _, label_id in pairs], dtype=np.int64)

        # searchsorted returns len(keys) for values above the largest key.
        # Clamp so the equality check below flags them as unknown instead of
        # indexing out of range.
        pos = np.minimum(np.searchsorted(keys, packed), len(keys) - 1)
        known = keys[pos] == packed
        if not known.all():
            raise KeyError(int(packed[~known][0]))
        return ids[pos]

