from .space import Space
import numpy as np
import copy
import gym
from .cutCreator import CuttingBoxCreator
from .mdCreator  import MDlayerBoxCreator
from .binCreator import RandomBoxCreator, LoadBoxCreator, BoxCreator
import math

class PackingGame(gym.Env):
    def __init__(self, box_creator=None, container_size=(12, 10, 15),
                 location_masks=None, box_set=None, data_name=None, test=False,
                 data_type='load', boxlist_len=None, enable_rotation=False, **kwargs):
        """Create the packing environment.

        Args:
            box_creator: a ``BoxCreator`` supplying boxes. When None and
                ``test`` is False, one is built from ``data_type``/``box_set``.
            container_size: container extents; the first two entries define
                ``self.area = container_size[0] * container_size[1]``.
            location_masks, data_name: accepted but not used by this constructor.
            box_set: box data handed to the selected box creator.
            test: when True, a ``LoadBoxCreator(box_set)`` is always used,
                replacing any ``box_creator`` argument.
            data_type: one of ``'sample'``, ``'depen'``, ``'md'``, ``'load'``
                (only consulted when ``box_creator`` is None and not testing).
            boxlist_len: sequence whose element ``[1]`` is passed to ``Space``
                in ``reset``; defaults to ``[20, 20]``.
            enable_rotation: when True the action/observation lengths double.
            **kwargs: ignored.
        """
        self.box_creator = box_creator
        self.bin_size = container_size
        self.area = int(self.bin_size[0] * self.bin_size[1])
        self.space = Space()
        self.can_rotate = enable_rotation
        self.box_index = 0
        # None sentinel: a list default would be one object shared by every
        # environment constructed without an explicit boxlist_len.
        self.boxlist_len = [20, 20] if boxlist_len is None else boxlist_len

        if not test and box_creator is None:
            if data_type == 'sample':
                print('using random data')
                self.box_creator = RandomBoxCreator(box_set)
            elif data_type == 'depen':
                # Concatenate the first and last box_set entries into one list
                # (presumably lower/upper size bounds -- confirm with CuttingBoxCreator).
                low = list(box_set[0])
                up = list(box_set[-1])
                low.extend(up)
                print(low)
                self.box_creator = CuttingBoxCreator(container_size, low, self.can_rotate)
            elif data_type == 'md':
                print('using md data')
                self.box_creator = MDlayerBoxCreator(container_size, [box_set[0][0], box_set[-1][0]])
            elif data_type == 'load':
                print('using load data')
                self.box_creator = LoadBoxCreator(box_set)
            # An unrecognised data_type leaves box_creator as None and trips this check.
            assert isinstance(self.box_creator, BoxCreator)

        if test:
            # Test mode always loads box_set, overriding any box_creator argument.
            self.box_creator = LoadBoxCreator(box_set)

        # One entry per grid cell, doubled when rotated placements are enabled.
        self.act_len = self.area * (1 + self.can_rotate)
        self.obs_len = self.area * (1 + self.can_rotate)
        self.action_space = gym.spaces.Discrete(self.act_len)
        self.observation_space = gym.spaces.Box(low=0.0, high=self.space.height, shape=(self.obs_len,))
        

    def get_box_ratio(self):
        """Return the upcoming box's volume divided by the container volume.

        The container volume is the product of the first three entries of
        ``self.space.plain_size``.
        """
        box = self.next_box
        dims = self.space.plain_size
        box_volume = box[0] * box[1] * box[2]
        container_volume = dims[0] * dims[1] * dims[2]
        return box_volume / container_volume


    def reset(self):
        """Start a new episode and return the generated box list.

        Resets the box creator, replaces ``self.space`` with a fresh ``Space``
        built from the container size, ``boxlist_len[1]`` and ``act_len``,
        generates the box sequence, and seeds ``volumes``, ``yx`` and ``zx``
        from the first box in ``box_creator.box_list``.
        """
        creator = self.box_creator
        creator.reset()
        self.space = Space(*self.bin_size, self.boxlist_len[1], self.act_len)
        creator.generate_box_size()

        first = creator.box_list[0]
        self.volumes = first[0] * first[1] * first[2]
        self.yx = first[1] / first[0]
        self.zx = first[2] / first[0]
        return creator.box_list

    def get_obses(self):
        """Return the placement mask(s) for the upcoming box, flattened to 1-D.

        Without rotation this is the unrotated mask from
        ``get_possible_position``. With rotation enabled the unrotated and
        rotated masks are stacked, so the rotated scores follow the
        unrotated ones in the flat result.
        """
        masks = [self.get_possible_position(rotate=False)]
        if self.can_rotate:
            masks.append(self.get_possible_position(rotate=True))
        # Pass the shape positionally: the ``newshape=`` keyword is deprecated
        # since NumPy 2.1.
        return np.reshape(np.stack(masks), (-1,))

    def get_obses_sub(self, action):
        """Return the flattened fine (10x10) placement mask for one coarse action.

        ``action`` values in ``[0, area)`` are unrotated placements and values
        from ``area`` upward are rotated ones, where
        ``area = bin_size[0] * bin_size[1]``. The offset within the selected
        half is decoded as ``x = offset // 10``, ``y = offset % 10`` and passed
        to ``get_possible_position`` as the ``sub`` cell.
        """
        area = self.bin_size[0] * self.bin_size[1]
        rotated = math.floor(action / area) != 0
        # Decode relative to the start of the selected half so that ``y`` is
        # right even when ``area`` is not a multiple of the stride.
        offset = action - area if rotated else action
        # NOTE(review): stride 10 must match the key layout used in
        # get_possible_position (key_i*10 + key_j); it equals bin_size[1] for
        # the default container -- confirm before generalizing.
        x, y = offset // 10, offset % 10
        mask = self.get_possible_position(rotate=rotated, sub=(x, y))
        return np.reshape(mask, (-1,))

    def get_hmap(self):
        """Stack the height map with constant feature planes for the upcoming box.

        The result's first channel is ``self.space.plain``. It is followed by six
        planes, each filled with one value over the first two dimensions of
        ``self.space.plain_size``:

        1. the box's x extent;
        2. the box's y extent;
        3. the box's z extent;
        4. ``ceil(volume / 1000)``;
        5. ``ceil(10 * y/x)``;
        6. ``ceil(10 * z/x)``.

        Side effect: ``self.volumes``, ``self.yx`` and ``self.zx`` are
        overwritten from the upcoming box.
        """
        box = self.next_box
        footprint = self.space.plain_size[:2]

        def constant_plane(value):
            # int32 base kept so NumPy type promotion matches the original planes.
            return np.ones(footprint, dtype=np.int32) * value

        volume = box[0] * box[1] * box[2]
        self.volumes = volume
        self.yx = box[1] / box[0]
        self.zx = box[2] / box[0]

        channels = (
            constant_plane(box[0]),
            constant_plane(box[1]),
            constant_plane(box[2]),
            constant_plane(np.ceil(volume / 1000)).astype(np.int32),
            constant_plane(np.ceil(self.yx * 10)),
            constant_plane(np.ceil(self.zx * 10)),
        )
        return np.stack((self.space.plain, *channels))

    # @property
    # def cur_observation(self):
    #     # hmap = self.bin_size[2] - self.space.plain
    #     size = self.get_box_plain()
    #     mask1 = self.get_possible_position(rotate=False)
    #     if self.can_rotate:
    #         mask2 = self.get_possible_position(rotate=True)
    #         return np.reshape(np.stack((*size, mask1, mask2)), newshape=(-1,))
    #     else:
    #         return np.reshape(np.stack((*size, mask1)), newshape=(-1,))

    @property
    def next_box(self):
        """The first box returned by ``self.box_creator.preview(1)``."""
        upcoming = self.box_creator.preview(1)
        return upcoming[0]

    def get_possible_position(self, rotate, plain=None, sub=None):
        """Score candidate placements of the upcoming box.

        Args:
            rotate: if truthy, the box's first two extents are swapped.
            plain: height map passed to ``TryConvexHull``; defaults to
                ``self.space.plain``.
            sub: None for a coarse mask over the whole container, or a
                ``(key_i, key_j)`` cell index for a 10x10 fine mask of that cell.

        Returns:
            A float64 array, shaped ``(bin_size[0], bin_size[1])`` when ``sub``
            is None and ``(10, 10)`` otherwise. Cells stay at ``1e-16`` unless a
            candidate position gets ``TryConvexHull(...) > 0``.

            - Coarse mask: a cell scores 100 if any of its corner candidates
              passes. Otherwise it scores the ``TryConvexHull`` value of the
              cell origin ``(key_i*10, key_j*10)`` when that is positive.
            - Fine mask: every passing corner candidate scores 100 at
              ``(pos.x % 10, pos.y % 10)``.
            - If no position passes at all, the whole mask is set to 100.
        """
        box = self.next_box
        if rotate:
            x, y = box[1], box[0]
        else:
            x, y = box[0], box[1]
        z = box[2]

        if plain is None:
            plain = self.space.plain

        corners = self.space.position_choices_conner
        # NOTE(review): the literal 10 serves both as the row stride of the
        # corner-candidate keys and as the fine-cell scale; confirm against
        # Space before generalizing to other container sizes.
        # Track feasibility explicitly: origin fallbacks store raw
        # TryConvexHull values, which may be below 1, so a sum-based
        # "nothing found" test could flatten a mask that has feasible entries.
        found_any = False

        if sub is None:
            action_mask = np.full((self.bin_size[0], self.bin_size[1]), 1e-16)
            for key_i in range(self.bin_size[0]):
                for key_j in range(self.bin_size[1]):
                    for pos in corners.get(key_i * 10 + key_j):
                        if self.space.TryConvexHull(plain, x, y, pos.x, pos.y, z) > 0:
                            action_mask[key_i, key_j] = 100
                            found_any = True
                            break
                    else:
                        # No corner candidate fits: probe the cell origin instead.
                        dist = self.space.TryConvexHull(plain, x, y, key_i * 10, key_j * 10, z)
                        if dist > 0:
                            action_mask[key_i, key_j] = dist
                            found_any = True
        else:
            action_mask = np.full((10, 10), 1e-16)
            key_i, key_j = sub
            for pos in corners.get(key_i * 10 + key_j):
                if self.space.TryConvexHull(plain, x, y, pos.x, pos.y, z) > 0:
                    action_mask[pos.x % 10, pos.y % 10] = 100
                    found_any = True

        if not found_any:
            # Nothing feasible: fall back to a uniform mask.
            action_mask[:, :] = 100
        return action_mask

    def step(self, action):
        """Drop the upcoming box at the chosen position and advance the episode.

        Args:
            action: indexable; ``action[0]`` is the placement index. Indices
                ``>= area * 100`` request the rotated orientation (rotation must
                be enabled) and are shifted down by ``area * 100`` before being
                passed to ``Space.drop_box``.

        Returns:
            ``(self.space.layout_, reward, done, info)``. ``info`` always holds
            ``'counter'`` (``len(self.space.boxes)``) and ``'ratio'``
            (``self.space.get_ratio()``).

            - If ``drop_box`` fails: reward is 0.0, done is True, and ``info``
              also holds an all-ones ``'mask'`` of length ``act_len``.
            - Otherwise: reward is ``10 * get_box_ratio()``, done is False, and
              the box is removed from the creator.
        """
        idx = action[0]
        rotated_base = self.area * 100
        rotate = False
        if idx >= rotated_base:
            assert self.can_rotate
            idx = idx - rotated_base
            rotate = True

        # NOTE(review): box_index is not advanced in this method or in reset();
        # confirm whether drop_box expects a per-box id here.
        succeeded = self.space.drop_box(self.next_box, idx, rotate, self.box_index)
        if not succeeded:
            info = {'counter':len(self.space.boxes), 'ratio':self.space.get_ratio(), 'mask':np.ones(shape=self.act_len)}
            return self.space.layout_, 0.0, True, info

        # Read the ratio before box_creator.drop_box() removes the current box,
        # since next_box is read from the creator's preview.
        box_ratio = self.get_box_ratio()
        self.box_creator.drop_box()  # remove current box from the list

        info = {'counter': len(self.space.boxes), 'ratio': self.space.get_ratio()}
        return self.space.layout_, box_ratio * 10, False, info

    def _get_dis(self, mov):
        """Return ``np.linalg.norm(mov, ord=1)`` truncated to ``int``.

        For a 1-D ``mov`` this is its L1 (Manhattan) length.
        """
        l1_length = np.linalg.norm(mov, ord=1)
        return int(l1_length)

    def _min_mov(self, point, targets, lx, ly):
        """Find the target nearest to ``point`` by ``_get_dis`` distance.

        Args:
            point: reference position, subtracted from each target.
            targets: iterable of positions; each is converted to an int32 array.
            lx, ly: accepted for interface compatibility; not used.

        Returns:
            ``(min_vec, min_dis)``, where ``min_vec = target - point`` for the
            first target with the smallest distance. Only distances below 1000
            can be selected. If none qualify, or ``targets`` is empty, the
            result is ``(None, 1000)``.
        """
        min_dis = 1000  # cutoff: distances >= 1000 are never selected
        min_vec = None
        for target in targets:
            cur_vec = np.array(target, dtype=np.int32) - point
            cur_dis = self._get_dis(cur_vec)
            if cur_dis < min_dis:
                min_dis, min_vec = cur_dis, cur_vec
        return min_vec, min_dis