#!/usr/bin/python
# -*- coding: utf-8 -*-
import os
import torch
import torch.utils.data as data
import torch
import torchvision.transforms as transforms
import random
from PIL import Image, ImageOps
import numpy as np
import preprocess
import cv2

# Recognised image file suffixes; every base suffix is listed in its
# lower-case spelling followed by its upper-case spelling.
IMG_EXTENSIONS = [
    spelling
    for base_suffix in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for spelling in (base_suffix, base_suffix.upper())
]


def is_image_file(filename):
    """Return True when *filename* ends with one of the IMG_EXTENSIONS suffixes.

    The match is case-sensitive; only the spellings listed in IMG_EXTENSIONS
    are accepted.
    """
    # str.endswith accepts a tuple of candidate suffixes and tests them all.
    return filename.endswith(tuple(IMG_EXTENSIONS))


def default_loader(path):
    """Load the image at *path* and return it as an RGB ``PIL.Image``.

    ``Image.open`` is lazy and keeps the underlying file open. ``convert``
    loads the pixel data and returns a new, independent image, so the source
    image is opened in a ``with`` block and its file handle is released
    immediately. It is not left open until garbage collection, which can
    exhaust file descriptors in long-running data-loader workers.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


def disparity_loader(path):
    """Open the image at *path* and return it without any mode conversion.

    The image comes straight from ``Image.open``, so it is lazily loaded and
    keeps its native mode.
    """
    disparity = Image.open(path)
    return disparity


def _imread_unchanged(path):
    """Read *path* with OpenCV using ``IMREAD_UNCHANGED`` (no depth/colour conversion).

    ``cv2.imread`` signals failure by returning ``None``. Raise a clear
    ``FileNotFoundError`` here instead of letting it surface later as an
    ``AttributeError`` on the transpose.
    """
    arr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if arr is None:
        raise FileNotFoundError("could not read image: {!r}".format(path))
    return arr


def load_seg(x):
    """Read a label image from path *x* as a transposed int64 (``long``) tensor.

    The array is transposed with ``.T`` before conversion. Presumably this is
    to match the (W, H) layout used elsewhere; confirm with the consumers.
    """
    # Cast in NumPy rather than via ``.long()``. IMREAD_UNCHANGED yields uint16
    # for 16-bit PNGs, and older ``torch.from_numpy`` rejects uint16 arrays.
    return torch.from_numpy(_imread_unchanged(x).T.astype(np.int64))


def load_mask(x):
    """Read a mask image from path *x* as a transposed int64 tensor.

    Every pixel greater than 0 becomes 1; all other pixels become 0.
    """
    return torch.from_numpy(_imread_unchanged(x).T > 0).long()

# rgb_imgs: list of per-view path lists, e.g. [left_imgs, right_imgs, center_imgs];
# every inner list is indexed by the same sample index.

class BEVDataLoader(data.Dataset):
    """Map-style dataset of multi-view RGB images with optional BEV targets.

    Every path argument is a sequence indexed by sample index. ``rgb_imgs``
    is a list of such sequences, one per view (e.g.
    ``[left_imgs, right_imgs, center_imgs]``). Every view is read with
    ``loader``, resized to ``(tw, th)`` and passed through
    ``preprocess.get_transform(augment=False)``.

    Optional per-sample inputs:
        ipm_imgs:  image paths read with ``loader``, transformed, then
                   ``permute(0, 2, 1)``-ed (not resized).
        hmap:      paths read with OpenCV, clipped to ``[0, hmap_max]``,
                   divided by ``hmap_max`` and transposed.
        mask:      paths read by ``load_mask`` (pixel > 0 -> 1).
        seg_imgs:  top-view label paths read by ``load_seg``.
        seg_imgs2: front-view label paths read by ``load_seg``.
        cam_confs: files read with ``np.load`` (per-image camera configuration).
        imp_m:     files read with ``np.load``.

    Args:
        training: Stored as ``self.training``; samples are currently produced
            the same way regardless of its value.
        loader: Callable mapping a path to a PIL image.
        hmap_loader: Stored but not used by this class.
        th, tw: Output height and width. The first view's aspect ratio must
            equal ``tw / th``.
        mask_imgs: If True, set ``seg_img`` to -100 wherever ``mask < 0.5``.
            Presumably -100 is the loss' ignore index; confirm. This requires
            both ``seg_imgs`` and ``mask``.
        return_dict: Return a dict instead of the ``(inputs, targets)`` tuple.
        hmap_max: Clip/normalisation constant for ``hmap``; required when
            ``hmap`` is given.

    ``__getitem__`` returns:
        * tuple mode: ``(inputs, targets)``.
          - ``inputs`` is the transformed views, followed by the IPM image,
            ``cam_conf`` and ``imp_m``, each only if configured.
          - ``targets`` is ``[seg_img, seg_img_front, mask, hmap]`` restricted
            to the configured ones.
        * dict mode: key ``'rgb_imgs'`` plus whichever of ``'seg_img'``,
          ``'seg_img_front'``, ``'mask'``, ``'hmap'``, ``'img_ipm'``,
          ``'cam_conf'``, ``'imp_m'`` are configured.

    Raises (from ``__getitem__``):
        ValueError: the aspect ratio does not match ``tw / th``;
            ``mask_imgs`` is set without ``mask`` and ``seg_imgs``; or
            ``hmap`` is given without ``hmap_max``.
        FileNotFoundError: a height map cannot be read by OpenCV.
    """

    def __init__(
        self,
        rgb_imgs,
        training,
        ipm_imgs=None,
        hmap=None,
        mask=None,
        loader=default_loader,
        hmap_loader=None,
        seg_imgs=None,
        seg_imgs2=None,
        th=288,
        tw=512,
        mask_imgs=False,
        cam_confs=None,  # if each image has a different camera configuration
        imp_m=None,
        return_dict=False,
        hmap_max=None,
    ):
        self.rgb_imgs = rgb_imgs
        self.training = training
        self.ipm_imgs = ipm_imgs
        self.hmap = hmap
        self.mask = mask
        self.loader = loader
        self.hmap_loader = hmap_loader  # NOTE(review): stored but never used here
        self.seg_imgs = seg_imgs  # top view
        self.seg_imgs2 = seg_imgs2  # front view
        self.th = th  # height at training time
        self.tw = tw  # width at training time
        self.mask_imgs = mask_imgs
        self.cam_confs = cam_confs
        self.imp_m = imp_m
        self.return_dict = return_dict
        self.hmap_max = hmap_max

    def _load_hmap(self, path):
        """Read a height map, clip it to [0, hmap_max], divide by hmap_max and transpose."""
        if self.hmap_max is None:
            # Without this check the division below fails with an opaque TypeError.
            raise ValueError("hmap_max must be set when hmap paths are given")
        hmap = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if hmap is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("could not read height map: {!r}".format(path))
        hmap = np.clip(hmap.astype(np.float32), 0, self.hmap_max)
        return (hmap / self.hmap_max).T

    def __getitem__(self, index):
        """Load, resize and transform sample *index*; see the class docstring."""

        def load_optional(paths, load):
            # Load paths[index] only if that input was configured.
            return None if paths is None else load(paths[index])

        rgb_imgs = [self.loader(view_paths[index]) for view_paths in self.rgb_imgs]
        ipm_img = load_optional(self.ipm_imgs, self.loader)
        mask = load_optional(self.mask, load_mask)
        cam_conf = load_optional(self.cam_confs, np.load)
        imp_m = load_optional(self.imp_m, np.load)
        hmap = load_optional(self.hmap, self._load_hmap)
        seg_img = load_optional(self.seg_imgs, load_seg)
        seg_img2 = load_optional(self.seg_imgs2, load_seg)

        if self.mask_imgs:
            if seg_img is None or mask is None:
                raise ValueError("mask_imgs=True requires both seg_imgs and mask")
            seg_img[mask < 0.5] = -100

        w, h = rgb_imgs[0].size
        th, tw = self.th, self.tw
        # Resizing must not change the aspect ratio. The cross-multiplication
        # is exact for integer sizes, and unlike an assert it is not stripped
        # under ``python -O``.
        if w * th != tw * h:
            raise ValueError(
                "aspect ratio must not change: {} vs {}".format(w / h, tw / th))
        rgb_imgs = [img.resize((tw, th)) for img in rgb_imgs]

        targets = []
        ret_dict = {}
        for key, value in (('seg_img', seg_img),
                           ('seg_img_front', seg_img2),
                           ('mask', mask),
                           ('hmap', hmap)):
            if value is not None:
                targets.append(value)
                ret_dict[key] = value

        processed = preprocess.get_transform(augment=False)
        rgb_imgs = [processed(img) for img in rgb_imgs]
        ret_dict['rgb_imgs'] = rgb_imgs

        # Work on a copy. Appending extra inputs to ``rgb_imgs`` itself would
        # also mutate ret_dict['rgb_imgs'].
        inputs = list(rgb_imgs)
        if self.ipm_imgs is not None:
            ipm_img = processed(ipm_img).permute(0, 2, 1)
            ret_dict['img_ipm'] = ipm_img
            inputs.append(ipm_img)
        if self.cam_confs is not None:
            ret_dict['cam_conf'] = cam_conf
            inputs.append(cam_conf)
        if self.imp_m is not None:
            ret_dict['imp_m'] = imp_m
            inputs.append(imp_m)

        if self.return_dict:
            return ret_dict
        return (inputs, targets)

    def __len__(self):
        """Return the number of samples, i.e. the length of the first view's path list."""
        return len(self.rgb_imgs[0])



			