import torch
import os
import tifffile
import numpy as np
from glob import glob
from PIL import Image
import torchvision.transforms.functional as TF

from collections import namedtuple

# Colour legend of the ISPRS semantic-labelling ground truth.
# ``color`` is the RGB triple found in the label tiles; ``m_color`` is that
# triple packed into a single integer as R + G*256 + B*256*256, the same
# packing applied to label tensors before class lookup.
Label = namedtuple('Label', 'id name color m_color')

labels = [
    Label(class_id, class_name, rgb, rgb[0] + 256 * rgb[1] + 256 * 256 * rgb[2])
    for class_id, class_name, rgb in (
        (0, 'Impervious surfaces', [255, 255, 255]),
        (1, 'Building', [0, 0, 255]),
        (2, 'Low vegetation', [0, 255, 255]),
        (3, 'Tree', [0, 255, 0]),
        (4, 'Car', [255, 255, 0]),
        (5, 'Clutter/background', [255, 0, 0]),
    )
]

# Packed colour -> class id.
_color2id = {entry.m_color: entry.id for entry in labels}


def color2id(x):
    """Return the class id for packed colour ``x``, or 255 (ignore) if ``x`` is not a known colour."""
    return _color2id.get(x, 255)


def load_tiffasnp(file_path):
    """Read the TIFF at ``file_path`` and return its pixel array cast to float32.

    The array keeps the layout produced by ``tifffile.imread``. No cropping,
    channel reordering or value scaling is applied.
    """
    pixels = tifffile.imread(file_path)
    return pixels.astype(np.float32)


class ISPRS_Dataset(torch.utils.data.Dataset):
    """Sliding-window patch dataset over ISPRS semantic-labelling tiles.

    Every tile named in ``img_ids`` is read into memory when the dataset is
    built. Each tile is then covered by a grid of windows whose top-left
    corners are ``stride`` pixels apart in both directions. Windows that run
    past the tile border are padded: the image with 1.0, the label with 255.

    Args:
        img_path: path template with one ``{}`` placeholder, filled with each
            id from ``img_ids`` to locate the image tile.
        label_path: path template, same convention, for the colour-coded
            label tile.
        img_ids: tile ids substituted into both templates.
        patch: ``(height, width)`` of each window.
        enhanceIM_pad: context margin in pixels read around each window. The
            image keeps this margin, while the label marks it with 255. Every
            item is therefore ``(patch[0] + 2*pad, patch[1] + 2*pad)`` pixels.
        stride: grid step in pixels, used for rows and columns.
        mode: ``'train'`` enables random flip/rotation augmentation; any other
            value returns windows unaugmented.
    """

    def __init__(self, img_path, label_path,
                 img_ids,
                 patch=(896, 896), enhanceIM_pad=0, stride=512, mode='train'):
        super().__init__()

        self.img_path = img_path  # e.g. '.../top_potsdam_{}_RGB.tif'
        self.label_path = label_path  # e.g. '.../top_potsdam_{}_label.tif'
        self.img_ids = img_ids
        self.patch = patch
        self.enhanceIM_pad = enhanceIM_pad
        self.stride = stride
        self.mode = mode
        # images: per-tile (C, H, W) float tensors, divided by 255
        # labels: per-tile (H, W) float tensors of class ids 0-5, 255 = unknown colour
        # data:   one [tile_index, row_start, row_end, col_start, col_end] per window
        self.images, self.labels, self.data = self._load_data()

        print(f'The patch size is: {patch}, we use stride: {stride}')
        print(f'In {self.mode} mode, the images number is :{len(self.images)}, The total data is: {len(self.data)}')

    @staticmethod
    def _rgb_to_class_ids(label, lookup):
        """Map a (3, H, W) colour label tensor to an (H, W) tensor of class ids.

        Each pixel is packed as ``R + G*256 + B*256*256``. Each *distinct*
        packed value is passed once to ``lookup``, which returns its class id.
        The result has the same dtype as ``label``.
        """
        packed = label[0] + label[1] * 256 + label[2] * 256 * 256
        # One Python call per distinct colour instead of one per pixel
        # (Tensor.map_ would invoke the callable for every element).
        colors, inverse = torch.unique(packed, return_inverse=True)
        class_ids = torch.tensor([lookup(c) for c in colors.tolist()], dtype=packed.dtype)
        return class_ids[inverse]

    def _load_data(self,):
        """Read every tile and enumerate its sliding-window grid.

        Returns:
            ``(images, labels, data)`` as described by the attribute notes in
            ``__init__``.
        """
        stride = self.stride
        patch = self.patch
        images, labels, data = [], [], []
        for idx, img_id in enumerate(self.img_ids):
            img_file = self.img_path.format(img_id)
            label_file = self.label_path.format(img_id)

            im = TF.to_tensor(load_tiffasnp(img_file)) / 255

            label = TF.to_tensor(load_tiffasnp(label_file))
            if ('4_12' in img_id) or ('6_7' in img_id):
                # Known Potsdam data bug: these two label tiles are re-binarised
                # to 0/255 per channel before the colour lookup.
                label[label > 128] = 255
                label[label <= 128] = 0
                print('Potsdam data bug fix', label_file)
            label = self._rgb_to_class_ids(label, color2id)

            # Last value: number of pixels whose colour matched no class (id 255).
            print(img_file, im.shape, label.shape, torch.sum(label > 5))

            images.append(im)
            labels.append(label)

            # Window origins every ``stride`` pixels. Windows may extend past
            # the tile border and are padded in __getitem__.
            for i in range(0, im.shape[1], stride):
                for j in range(0, im.shape[2], stride):
                    data.append([idx, i, i + patch[0], j, j + patch[1]])

        return images, labels, data

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th window.

        ``image`` is a float tensor of shape ``(C, patch[0] + 2*pad, patch[1] + 2*pad)``.
        ``label`` is a long tensor of shape ``(patch[0] + 2*pad, patch[1] + 2*pad)``.
        In ``label``, 255 marks context-margin, out-of-tile and unknown pixels.
        """
        img_idx, i, i_end, j, j_end = self.data[idx]
        pad = self.enhanceIM_pad
        out_h = self.patch[0] + 2 * pad
        out_w = self.patch[1] + 2 * pad
        top = max(0, i - pad)
        left = max(0, j - pad)

        # The image crop carries context on all four sides. The label crop
        # stops at the window's bottom/right edge; that margin is filled with
        # 255 by the padding step below.
        img = self.images[img_idx][:, top:i_end + pad, left:j_end + pad]
        label = self.labels[img_idx][top:i_end, left:j_end].clone()
        # Mask the top/left context taken from the tile so it carries no loss.
        label[:min(i, pad)] = 255
        label[:, :min(j, pad)] = 255

        # Crops that ran past the tile edge are smaller than requested.
        if img.shape[1] < out_h or img.shape[2] < out_w:
            padded_img = torch.ones(img.shape[0], out_h, out_w)
            padded_img[:, :img.shape[1], :img.shape[2]] = img
            img = padded_img
        if label.shape[0] < out_h or label.shape[1] < out_w:
            padded_label = torch.full((out_h, out_w), 255.0)  # 255 = ignore index
            padded_label[:label.shape[0], :label.shape[1]] = label
            label = padded_label

        if self.mode == 'train':
            # Random flips/rotations, applied identically to image and label.
            # 90/270-degree turns are only drawn for square windows. TF.rotate
            # keeps the canvas size, so on a non-square window it would crop
            # content and zero-fill the corners, and 0 is a real class id.
            angles = [0, 90, 180, 270] if out_h == out_w else [0, 180]
            angle = np.random.choice(angles)
            vflip = np.random.choice([0, 1])
            hflip = np.random.choice([0, 1])

            if vflip:
                img = TF.vflip(img)
                label = TF.vflip(label)

            if hflip:
                img = TF.hflip(img)
                label = TF.hflip(label)

            if angle > 0:
                img = TF.rotate(img, angle.item())
                label = TF.rotate(label.unsqueeze(dim=0), angle.item())

        return img, label.squeeze(dim=0).long()

    def __len__(self,):
        """Number of windows across all tiles."""
        return len(self.data)


class PotsDam_Dataset(ISPRS_Dataset):
    """ISPRS Potsdam tiles with default paths under /scratch.

    ``img_ids`` defaults to ``TRAIN_IMG_IDS``. ``TEST_IMG_IDS`` lists the
    held-out test tiles.
    """

    # Tiles loaded when ``img_ids`` is not given.
    TRAIN_IMG_IDS = ('2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11', '4_12',
                     '5_10', '5_11', '5_12', '6_7', '6_8', '6_9', '6_10', '6_11', '6_12',
                     '7_7', '7_8', '7_9', '7_10', '7_11', '7_12')
    # Test tiles.
    TEST_IMG_IDS = ('2_13', '2_14', '3_13', '3_14', '4_13', '4_14', '4_15', '5_13', '5_14',
                    '5_15', '6_13', '6_14', '6_15', '7_13')

    def __init__(self, img_path='/scratch/forest/datasets/PostDam/2_Ortho_RGB/top_potsdam_{}_RGB.tif',
                 label_path='/scratch/forest/datasets/PostDam/5_Labels_all/top_potsdam_{}_label.tif',
                 img_ids=None,
                 patch=(896, 896), enhanceIM_pad=0, stride=512, mode='train'):
        '''
        img_ids: tile ids to load; ``None`` selects TRAIN_IMG_IDS.
        test_image ids: '2_13','2_14','3_13','3_14','4_13','4_14','4_15','5_13','5_14','5_15','6_13','6_14','6_15','7_13'
        '''
        if img_ids is None:
            # Fresh list per instance. A shared mutable default would be stored
            # as ``self.img_ids`` and leak mutations into later instances.
            img_ids = list(self.TRAIN_IMG_IDS)
        super().__init__(img_path, label_path, img_ids, patch, enhanceIM_pad, stride, mode)

        
class Vaihingen_Dataset(ISPRS_Dataset):
    """ISPRS Vaihingen tiles with default paths under /scratch.

    ``img_ids`` defaults to ``TRAIN_IMG_IDS``. ``TEST_IMG_IDS`` lists the
    held-out test tiles.
    """

    # Tiles loaded when ``img_ids`` is not given.
    TRAIN_IMG_IDS = ('1', '3', '5', '7', '11', '13', '15', '17', '21', '23', '26', '28',
                     '30', '32', '34', '37')
    # Test tiles.
    TEST_IMG_IDS = ('2', '4', '6', '8', '10', '12', '14', '16', '20', '22', '24', '27',
                    '29', '31', '33', '35', '38')

    def __init__(self, img_path='/scratch/forest/datasets/Vaihingen/ISPRS_semantic_labeling_Vaihingen/top/top_mosaic_09cm_area{}.tif',
                 label_path='/scratch/forest/datasets/Vaihingen/ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE/top_mosaic_09cm_area{}.tif',
                 img_ids=None,
                 patch=(768, 768), enhanceIM_pad=0, stride=512, mode='train'):
        '''
        img_ids: tile ids to load; ``None`` selects TRAIN_IMG_IDS.
        test_image ids: '2','4','6','8','10','12','14','16','20','22','24','27','29','31','33','35','38'
        '''
        if img_ids is None:
            # Fresh list per instance. A shared mutable default would be stored
            # as ``self.img_ids`` and leak mutations into later instances.
            img_ids = list(self.TRAIN_IMG_IDS)
        super().__init__(img_path, label_path, img_ids, patch, enhanceIM_pad, stride, mode)
