import numpy as np
from functools import reduce
import copy, time
import sys
if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
    sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import cv2 as cv
import scipy.stats as ss
import math

class HashTable(object):
    """Multimap from keys to values, stored as two parallel lists.

    There is no hashing: ``get`` scans every stored key, and several values
    may be stored under the same key.
    """

    def __init__(self):
        # Parallel lists: the i-th key belongs to the i-th value.
        self.__key_llist = []
        self.__val_list = []

    def clear(self):
        """Empty both internal lists in place."""
        for store in (self.__key_llist, self.__val_list):
            store.clear()

    def add(self, key, val):
        """Append the pair (key, val); duplicate keys are allowed."""
        self.__key_llist.append(key)
        self.__val_list.append(val)

    def get(self, key):
        """Return all values stored under *key* in insertion order ([] if none)."""
        matches = []
        for stored_key, stored_val in zip(self.__key_llist, self.__val_list):
            if key == stored_key:
                matches.append(stored_val)
        return matches

    def remove(self, val):
        """Drop the first pair whose value equals *val*.

        Raises ValueError if *val* is not stored.  Fresh lists are built rather
        than deleting in place, so a list previously returned by ``travel`` is
        left untouched and may still be iterated while entries are removed.
        """
        drop_at = self.__val_list.index(val)
        self.__key_llist = [k for pos, k in enumerate(self.__key_llist) if pos != drop_at]
        self.__val_list = [v for pos, v in enumerate(self.__val_list) if pos != drop_at]

    def travel(self):
        """Return the internal list of stored values (the live list, not a copy)."""
        return self.__val_list

            
class Box(object):
    """A placed box: extents (x, y, z) and corner coordinates (lx, ly, lz).

    ``box_index`` identifies the box, ``offset`` is an optional caller-supplied
    value, and ``dir`` starts out as None.
    """

    def __init__(self, x, y, z, lx, ly, lz, box_index, offset=None):
        # Extents along each axis.
        self.x, self.y, self.z = x, y, z
        # Corner (minimum) coordinates.
        self.lx, self.ly, self.lz = lx, ly, lz
        self.box_index = box_index
        self.offset = offset
        self.dir = None

    def standardize(self):
        """Return the box as the tuple (x, y, z, lx, ly, lz)."""
        return (self.x, self.y, self.z, self.lx, self.ly, self.lz)

class Position(object):
    """A 3-D point; constructor arguments lx, ly, lz are stored as x, y, z."""

    def __init__(self, lx, ly, lz):
        self.x, self.y, self.z = lx, ly, lz


class Space(object):
    """Height-map model of a box-packing container.

    Constructor dimensions are scaled by 10: the height map ``plain`` has
    ``width*10`` x ``length*10`` cells and the container is ``height*10`` tall.
    A placed box of extents (x, y, z) with corner (lx, ly) raises
    ``plain[lx:lx+x, ly:ly+y]`` to its top surface (see update_height_graph).
    """

    def __init__(self, width=20, length=20, height=20, boxlist_len=30, act_len=100):
        # NOTE: act_len is accepted for interface compatibility; it is unused here.
        self.plain_size = np.array([width * 10, length * 10, height * 10])
        self.plain = np.zeros(shape=(width * 10, length * 10), dtype=np.uint8)
        self.boxes = []
        self.flags = []  # record rotation information
        self.height = height * 10
        self.solidmask = np.ones(shape=(width * 10, length * 10), dtype=np.uint8)
        # Box-relation matrix indexed by box_index (written by UpdateLayout_*);
        # row 0 is all ones except its first entry.
        self.layout_ = np.zeros((boxlist_len + 1, boxlist_len + 1))
        self.layout_[0] = np.ones(boxlist_len + 1)
        self.layout_[0][0] = 0
        # Candidate placement positions, keyed by _choice_key(x, y).
        self.position_choices_set = HashTable()
        self.position_choices_set.add(0, Position(0, 0, 0))
        self.position_choices_conner = HashTable()
        self.position_choices_conner.add(0, Position(0, 0, 0))

        # layout[0], layout[1] and layout[2] are written by UpdateLayout_z, _x, _y.
        self.layout = np.zeros((3, boxlist_len + 1, boxlist_len + 1))
        for mat in self.layout:
            mat[0] = np.ones(boxlist_len + 1)
            mat[0][0] = 0
        self.kToolSpongeHeight = 3
        self.kSlideLen = 10
        self.kToolLengthOriginal = 60
        self.kToolWidthOriginal = 40
        self.kSlideMinTolerantLen = 1
        self.offset = None
        self.kHTolerance = 2

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _choice_key(x, y):
        """Coarse key of a candidate position: floor(x/10)*10 + floor(y/10).

        Distinct 10-cell squares only get distinct keys while floor(y/10) < 10
        (i.e. y < 100); larger maps share keys between squares.
        """
        return math.floor(x / 10) * 10 + math.floor(y / 10)

    @staticmethod
    def _support_hull(mask):
        """Return the convex hull of all contour points of *mask*, or None.

        ``findContours`` returns (image, contours, hierarchy) in OpenCV 3.x and
        (contours, hierarchy) in 4.x; indexing with [-2] works for both.
        """
        contours = cv.findContours(mask, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)[-2]
        if len(contours) == 0:
            return None
        return cv.convexHull(np.concatenate(contours, axis=0))

    def _support_mask(self, sub_hmap):
        """Mark cells of *sub_hmap* within kHTolerance of its peak.

        Marked cells get value height+1 (> 0); all other cells become 0.
        The peak is converted to a Python int so that ``peak - kHTolerance``
        cannot wrap around on a uint8 map (NumPy >= 2 keeps uint8 scalars).
        """
        peak = int(np.max(sub_hmap))
        return np.where(sub_hmap >= peak - self.kHTolerance, sub_hmap + 1, 0)

    # --------------------------------------------------------------- height map

    def print_height_graph(self):
        """Print the current height map."""
        print(self.plain)

    def get_height_graph(self):
        """Rebuild an int32 height map from scratch by replaying ``self.boxes``."""
        plain = np.zeros(shape=self.plain_size[:2], dtype=np.int32)
        for box in self.boxes:
            plain = self.update_height_graph(plain, box)
        return plain

    def update_height_graph(self, plain, box):
        """Return a copy of *plain* with *box*'s footprint raised to its top.

        The whole footprint is set to max(current max in footprint, box.lz + box.z);
        *plain* itself is not modified.
        """
        plain = copy.deepcopy(plain)
        le = box.lx
        ri = box.lx + box.x
        up = box.ly
        do = box.ly + box.y
        max_h = np.max(plain[le:ri, up:do])
        max_h = max(max_h, box.lz + box.z)
        plain[le:ri, up:do] = max_h
        return plain

    def get_box_list(self):
        """Return the standardized tuples of all placed boxes, concatenated into one list."""
        vec = list()
        for box in self.boxes:
            vec += box.standardize()
        return vec

    def get_plain(self):
        """Return a copy of the current height map."""
        return copy.deepcopy(self.plain)

    def get_corners(self):
        """Collect height-change points of the map sampled on a 10-cell grid.

        Returns four lists of (x, y) points.  Each list starts with one container
        corner, then gets the border points where adjacent samples differ and
        interior points where the 2x2 neighbourhood
        (grid_0=plain[i-10, j], grid_1=plain[i-10, j-10], grid_2=plain[i, j-10],
        grid_3=plain[i, j]) is not a straight edge; list k receives the point when
        grid_k differs from one of its two neighbours in that neighbourhood.
        """
        width = self.plain_size[0]
        length = self.plain_size[1]
        guad = [list() for _ in range(4)]

        guad[0].append((width, 0))
        guad[1].append((width, length))
        guad[2].append((0, length))
        guad[3].append((0, 0))

        for i in range(10, width, 10):
            if self.plain[i, 0] != self.plain[i - 10, 0]:
                guad[0].append((i, 0))
                guad[3].append((i, 0))

        for i in range(10, width, 10):
            if self.plain[i, length - 10] != self.plain[i - 10, length - 10]:
                guad[1].append((i, length))
                guad[2].append((i, length))

        for j in range(10, length, 10):
            if self.plain[0, j] != self.plain[0, j - 10]:
                guad[2].append((0, j))
                guad[3].append((0, j))

        for j in range(10, length, 10):
            if self.plain[width - 10, j] != self.plain[width - 10, j - 10]:
                guad[0].append((width, j))
                guad[1].append((width, j))

        for i in range(10, width, 10):
            for j in range(10, length, 10):
                grid_0 = self.plain[i - 10, j]
                grid_1 = self.plain[i - 10, j - 10]
                grid_2 = self.plain[i, j - 10]
                grid_3 = self.plain[i, j]
                # Straight edges (two equal halves) are not corners.
                if grid_0 == grid_1 and grid_2 == grid_3:
                    continue
                if grid_0 == grid_3 and grid_1 == grid_2:
                    continue
                if grid_0 != grid_3 or grid_0 != grid_1:
                    guad[0].append((i, j))
                if grid_1 != grid_0 or grid_1 != grid_2:
                    guad[1].append((i, j))
                if grid_2 != grid_1 or grid_2 != grid_3:
                    guad[2].append((i, j))
                if grid_3 != grid_2 or grid_3 != grid_0:
                    guad[3].append((i, j))

        return guad

    # ------------------------------------------------------ candidate positions

    def UpdateChoices(self, box):
        """Refresh candidate positions after *box* has been placed."""
        curr = Position(box.lx, box.ly, box.lz)
        self.BatchDelete(box, curr)
        self.AddUpperChoicePos(box, curr)
        self.AddChoicePos(Position(box.lx + box.x, box.ly, box.lz))
        self.AddChoicePos(Position(box.lx, box.ly + box.y, box.lz))
        self.AddChoicePos(Position(box.lx, box.ly, box.lz + box.z))

    def AddUpperChoicePos(self, box, position):
        """Add candidates on top of *box* placed at *position*.

        One corner candidate at the box's (lx, ly) corner, plus a candidate every
        10 cells across the box's top face in position_choices_set.
        """
        top = position.z + box.z
        self.position_choices_conner.add(self._choice_key(position.x, position.y),
                                         Position(position.x, position.y, top))
        for x in range(position.x, position.x + box.x, 10):
            for y in range(position.y, position.y + box.y, 10):
                self.position_choices_set.add(self._choice_key(x, y), Position(x, y, top))

    def BatchDelete(self, box, position):
        """Remove corner candidates buried by *box* placed at *position*.

        A candidate is removed when its (x, y) lies inside the box footprint and
        its z is below the box top.  Returns the number of removed candidates.
        """
        top = position.z + box.z
        removed = 0
        # Iterate over a snapshot so removals cannot disturb the loop.
        for pos in list(self.position_choices_conner.travel()):
            if (position.x <= pos.x < position.x + box.x
                    and position.y <= pos.y < position.y + box.y
                    and pos.z < top):
                self.position_choices_conner.remove(pos)
                removed += 1
        return removed

    def AddChoicePos(self, position):
        """Add *position* to both candidate tables, dropped onto the height map.

        Positions outside the container are ignored.  If the map at (x, y) is
        lower than position.z, the candidate's z is lowered to the map height.
        """
        if (position.x >= self.plain_size[0] or position.y >= self.plain_size[1]
                or position.z >= self.plain_size[2]):
            return
        z = position.z
        surface = int(self.plain[position.x, position.y])
        if surface < z:
            z = surface
        key = self._choice_key(position.x, position.y)
        self.position_choices_set.add(key, Position(position.x, position.y, z))
        self.position_choices_conner.add(key, Position(position.x, position.y, z))

    # ------------------------------------------------------- slide collisions

    def CheckTwoEdges(self, plain, LB, RB, FB, BB, bottom_z, slide_dir):
        """Check the two leading edges of footprint [LB, RB) x [FB, BB) for *slide_dir*.

        Returns -1 if a cell on either edge is higher than *bottom_z*, else 1.
        Unknown directions print a message and return None.
        """
        # Clamp the footprint to the extent of the height map being checked.
        RB = min(RB, plain.shape[0])
        BB = min(BB, plain.shape[1])
        LB = max(LB, 0)
        FB = max(FB, 0)
        if slide_dir == 'kRightBack':
            for x in range(LB, RB):
                if plain[x, BB - 1] > bottom_z:
                    return -1
            for y in range(FB, BB):
                if plain[RB - 1, y] > bottom_z:
                    return -1
            return 1
        elif slide_dir == 'kRightFront':
            if FB - 1 >= 0:
                for x in range(LB, RB):
                    if plain[x, FB] > bottom_z:
                        return -1
            for y in range(FB, BB):
                if plain[RB - 1, y] > bottom_z:
                    return -1
            return 1
        elif slide_dir == 'kLeftFront':
            for x in range(LB, RB):
                if plain[x, FB] > bottom_z:
                    return -1
            for y in range(FB, BB):
                if plain[LB, y] > bottom_z:
                    return -1
            return 1
        elif slide_dir == 'kLeftBack':
            for x in range(LB, RB):
                if plain[x, BB - 1] > bottom_z:
                    return -1
            for y in range(FB, BB):
                if plain[LB, y] > bottom_z:
                    return -1
            return 1
        else:
            print("Invalid enum: slide_dir")

    def CheckSlideCollision(self, plain, LB, RB, FB, BB, bottom_z, type, slide_dir):
        """Return -1 if footprint [LB, RB) x [FB, BB) collides, else 1.

        The footprint is checked in place at *bottom_z*.  It is then moved
        diagonally in *slide_dir* one cell per step for kSlideLen steps while
        rising by one unit per step.  For type 'kTool', *bottom_z* is first
        lowered by kToolSpongeHeight (the tool's sponge is compressed).
        Assumes 0 <= LB and 0 <= FB, as guaranteed by SlideCollisionDir.
        """
        kSlideLen = self.kSlideLen
        # Sponge compression: the tool sits kToolSpongeHeight lower.
        if type == 'kTool':
            bottom_z -= self.kToolSpongeHeight
        # Check the bottom face.
        if np.any(plain[LB:RB, FB:BB] > bottom_z):
            return -1
        slide_x_sign = 1 if slide_dir in ('kRightBack', 'kRightFront') else -1
        slide_y_sign = 1 if slide_dir in ('kLeftBack', 'kRightBack') else -1
        for i in range(kSlideLen):
            if self.CheckTwoEdges(plain,
                                  LB + slide_x_sign * i, RB + slide_x_sign * i,
                                  FB + slide_y_sign * i, BB + slide_y_sign * i,
                                  i + bottom_z, slide_dir) == -1:
                return -1
        return 1

    def SlideCollisionDir(self, plain, x, y, lx, ly, z):
        """Check whether a box and the gripper tool can slide into place.

        Returns -1 if the x-by-y footprint at (lx, ly) leaves the container, and
        1 as soon as one of the four diagonal slide directions is free for the
        box and for one of the 8 tool-corner placements.  Returns False if no
        combination is free.
        """
        kToolLengthOriginal = self.kToolLengthOriginal
        kToolWidthOriginal = self.kToolWidthOriginal
        plain_size = [self.plain_size[0], self.plain_size[1]]
        if (lx + x) > plain_size[0] or (ly + y) > plain_size[1]:
            return -1
        if lx < 0 or ly < 0:
            return -1
        # Python int so that lz + z cannot wrap around on a uint8 map.
        lz = int(np.max(plain[lx:lx + x, ly:ly + y]))
        # 4 diagonal (45-degree) slide directions.
        slide_dirs = ['kRightBack', 'kRightFront', 'kLeftFront', 'kLeftBack']
        for slide_dir in slide_dirs:
            if self.CheckSlideCollision(plain, lx, lx + x, ly, ly + y, lz, 'kBox', slide_dir) == -1:
                continue
            # 8 tool corner directions: corners LeftFront, LeftBack, RightBack,
            # RightFront, each in original orientation (even a) and rotated
            # by 90 degrees (odd a: length and width swapped).
            for a in range(8):
                if a % 2 == 1:
                    tool_width, tool_length = kToolLengthOriginal, kToolWidthOriginal
                else:
                    tool_width, tool_length = kToolWidthOriginal, kToolLengthOriginal
                corner = a // 2
                if corner in (0, 1):  # tool anchored at the box's left edge
                    left, right = lx, min(plain_size[0], lx + tool_length)
                else:  # anchored at the right edge
                    left, right = max(0, lx + x - tool_length), lx + x
                if corner in (0, 3):  # anchored at the front edge
                    front, back = ly, min(plain_size[1], ly + tool_width)
                else:  # anchored at the back edge
                    front, back = max(0, ly + y - tool_width), ly + y
                if self.CheckSlideCollision(plain, left, right, front, back,
                                            lz + z, 'kTool', slide_dir) != -1:
                    return 1
        return False

    # ------------------------------------------------------- support checks

    def TryConvexHull(self, plain, x, y, lx, ly, z):
        """Score how well the x-by-y footprint at (lx, ly) supports a box of height z.

        Returns -1 if the footprint leaves the container, the box would exceed
        the container height, or the footprint centre lies outside the convex
        hull of the supporting cells.  Returns 1 if every cell supports the box;
        otherwise returns the fraction of cells at the peak support value.
        """
        # A footprint may end exactly at the far wall (lx + x == width).
        if (lx + x) > self.plain_size[0] or (ly + y) > self.plain_size[1]:
            return -1
        if lx < 0 or ly < 0:
            return -1
        sub_hmap = plain[lx:lx + x, ly:ly + y]
        max_h = int(np.max(sub_hmap))
        assert max_h >= 0
        if max_h + z > self.plain_size[2]:
            return -1

        support = self._support_mask(sub_hmap)
        if np.min(support) > 0:
            return 1
        hull = self._support_hull(support)
        if hull is None:
            return -1
        # OpenCV points are (column, row), i.e. (y, x) here.
        center = ((y - 1) / 2, (x - 1) / 2)
        if cv.pointPolygonTest(hull, center, measureDist=True) > 0:
            peak = np.max(support)
            return np.sum(support == peak) / (x * y)
        return -1

    def UpdateSolidMask(self, plain, box, solidmask):
        """Return the filled convex hull of *box*'s supporting cells, or None.

        The result is an x-by-y uint8 mask (1 inside the hull).  None is returned
        when every cell supports the box or no supporting contour exists.
        The *solidmask* argument is accepted for interface compatibility but unused.
        """
        lx, ly, x, y = box.lx, box.ly, box.x, box.y
        support = self._support_mask(plain[lx:lx + x, ly:ly + y])
        if np.min(support) > 0:
            return None
        hull = self._support_hull(support)
        if hull is None:
            return None
        drawing = np.zeros(shape=(x, y), dtype=np.uint8)
        cv.drawContours(drawing, [hull], -1, 1, cv.FILLED)
        return drawing

    # ------------------------------------------------------------ layouts

    def UpdateLayout_z(self, box):
        """Record boxes whose top face is at box.lz and that sufficiently overlap *box*.

        For each such box b, where min(b.lx+b.x-lx, lx+x-b.lx) >= x/2 (and the
        same in y), set layout[0][box][b] and layout_[b][box] to 1.
        """
        lx = box.lx
        ly = box.ly
        x = box.x
        y = box.y
        lz = box.lz
        for b in self.boxes:
            if (b.z + b.lz) == lz:
                delta_y = np.min((b.y + b.ly - ly, ly + y - b.ly))
                delta_x = np.min((b.x + b.lx - lx, lx + x - b.lx))
                if delta_x >= x / 2 and delta_y >= y / 2:
                    self.layout[0][box.box_index][b.box_index] = 1
                    self.layout_[b.box_index][box.box_index] = 1

    def UpdateLayout_x(self, box):
        """Record boxes whose +x face is at box.lx and that touch *box* in y and z.

        Sets layout[1][box][b] and layout_[b][box] to 1 for each such box b.
        """
        lz = box.lz
        ly = box.ly
        z = box.z
        y = box.y
        lx = box.lx
        for b in self.boxes:
            if (b.x + b.lx) == lx:
                delta_y = np.min((b.y + b.ly - ly, ly + y - b.ly))
                delta_z = np.min((b.z + b.lz - lz, lz + z - b.lz))
                if delta_z == z and delta_y > 0:
                    self.layout[1][box.box_index][b.box_index] = 1
                    self.layout_[b.box_index][box.box_index] = 1

    def UpdateLayout_y(self, box):
        """Record boxes whose +y face is at box.ly and that touch *box* in x and z.

        Sets layout[2][box][b] and layout_[b][box] to 1 for each such box b.
        """
        lx = box.lx
        lz = box.lz
        x = box.x
        z = box.z
        ly = box.ly
        for b in self.boxes:
            if (b.y + b.ly) == ly:
                delta_z = np.min((b.z + b.lz - lz, lz + z - b.lz))
                delta_x = np.min((b.x + b.lx - lx, lx + x - b.lx))
                if delta_x > 0 and delta_z == z:
                    self.layout[2][box.box_index][b.box_index] = 1
                    self.layout_[b.box_index][box.box_index] = 1

    # ------------------------------------------------------------ placement

    def get_ratio(self):
        """Return the total volume of placed boxes divided by the container volume."""
        vo = sum((box.x * box.y * box.z for box in self.boxes), 0.0)
        mx = self.plain_size[0] * self.plain_size[1] * self.plain_size[2]
        ratio = vo / mx
        assert ratio <= 1.0
        return ratio

    def idx_to_position(self, idx):
        """Convert a flat row-major index over the height map into (lx, ly)."""
        ly = idx % self.plain_size[1]
        lx = (idx - ly) // self.plain_size[1]
        return lx, ly

    def drop_box(self, box_size, idx, flag, box_index):
        """Try to place a box at flat index *idx*.

        box_size is (x, y, z); a truthy *flag* swaps x and y (rotation).  On
        success the box is recorded and the height map, solid mask and candidate
        positions are updated, and True is returned; otherwise returns False.
        """
        lx, ly = self.idx_to_position(idx)
        if flag:
            x, y = box_size[1], box_size[0]
        else:
            x, y = box_size[0], box_size[1]
        z = box_size[2]
        plain = self.plain
        # TryConvexHull validates the bounds, so the footprint below is non-empty.
        if not self.TryConvexHull(plain, x, y, lx, ly, z) > 0:
            return False
        new_h = int(np.max(plain[lx:lx + x, ly:ly + y]))
        box = Box(x, y, z, lx, ly, new_h, box_index)  # record rotated box
        self.boxes.append(box)
        self.flags.append(flag)
        smap = self.UpdateSolidMask(plain, box, self.solidmask)
        if smap is not None:
            self.solidmask[lx:lx + x, ly:ly + y] = smap
        self.plain = self.update_height_graph(plain, box)
        self.height = max(self.height, new_h + z)
        self.UpdateChoices(box)
        return True
