import numpy as np
import copy
import torch
import os
import pandas as pd
import random

class BoxCreator(object):
    """Base class holding a first-in-first-out queue of upcoming boxes.

    Subclasses override :meth:`generate_box_size` so that each call appends
    one box to ``self.box_list``. Consumers look ahead with :meth:`preview`
    and advance the queue with :meth:`drop_box`.
    """

    def __init__(self):
        # Pending boxes, oldest first.
        self.box_list = []

    def reset(self):
        """Discard every pending box."""
        self.box_list.clear()

    def generate_box_size(self, **kwargs):
        """Hook for subclasses: append one box to ``self.box_list``.

        The base implementation does nothing, so calling :meth:`preview`
        with a positive length on a bare ``BoxCreator`` never returns.
        """
        pass

    def preview(self, length):
        """Return deep copies of the next ``length`` pending boxes.

        ``generate_box_size()`` is called repeatedly until the queue holds at
        least ``length`` boxes. The result is a deep copy, so callers may
        mutate it without affecting the queue.
        """
        while len(self.box_list) < length:
            self.generate_box_size()
        return copy.deepcopy(self.box_list[:length])

    def drop_box(self):
        """Remove the oldest pending box from the queue.

        Raises:
            IndexError: if there is no pending box to drop.
        """
        # The previous guard (`assert len(...) >= 0`) could never fail and
        # would be stripped under `python -O`; check emptiness explicitly and
        # keep raising IndexError, which is what `pop(0)` raised before.
        if not self.box_list:
            raise IndexError("drop_box() called with no pending boxes")
        self.box_list.pop(0)


class RandomBoxCreator(BoxCreator):
    """Box source that samples sizes from a fixed catalogue using fixed weights.

    Each generated box is one entry of ``default_box_set``, picked by drawing
    a uniform random number and walking the running sum of ``rate_list``.
    """

    # Disabled alternative catalogue: every (i, j, k) with sides in 2..6.
    # default_box_set = []
    # for i in range(5):
    #     for j in range(5):
    #         for k in range(5):
    #             default_box_set.append((2+i, 2+j, 2+k))

    # Catalogue of box sizes (17 entries), index-aligned with ``rate_list``.
    default_box_set = [
        [30, 25, 25],
        [32, 25, 39],
        [59, 17, 11],
        [52, 25, 15],
        [42, 24, 11],
        [40, 28, 32],
        [34, 27, 14],
        [39, 27, 15],
        [30, 25, 32],
        [59, 20, 13],
        [58, 15, 13],
        [20, 25, 33],
        [35, 25, 37],
        [59, 24, 18],
        [40, 28, 44],
        [58, 20, 15],
        [30, 25, 34],
    ]

    # Sampling weight of each catalogue entry. The values are used as-is
    # (no normalisation is performed anywhere in this class).
    rate_list = [
        0.05401756,
        0.0243079,
        0.1215395,
        0.05536799,
        0.10803511,
        0.03241053,
        0.08507765,
        0.08035111,
        0.04051317,
        0.06617151,
        0.09453072,
        0.03646185,
        0.0243079,
        0.0486158,
        0.02160702,
        0.06617151,
        0.042,
    ]

    def __init__(self, box_size_set=None):
        """Set up the sampler with the class-level catalogue and weights.

        ``box_size_set`` is accepted but not used: the class defaults are
        always installed (the assignment from the argument is disabled).
        The active catalogue is printed to stdout.
        """
        super().__init__()
        # self.box_set = box_size_set
        self.box_set, self.rate_list = (
            RandomBoxCreator.default_box_set,
            RandomBoxCreator.rate_list,
        )
        print(self.box_set)

    def generate_box_size(self, **kwargs):
        """Append one weighted-random catalogue entry to the pending queue."""
        # Disabled alternative: uniform sampling over the catalogue.
        # idx = np.random.randint(0, len(self.box_set))
        # self.box_list.append(self.box_set[idx])
        self.box_list.append(self._rate_random(self.box_set, self.rate_list))

    def _rate_random(self, data_list, rate_list):
        """Pick the element of ``data_list`` selected by one uniform draw.

        Returns the element at the first position whose running weight total
        is at least the drawn number; if no running total reaches the draw,
        the element at the last position of ``rate_list`` is returned.
        """
        draw = np.random.rand()
        running_total = 0
        for position, weight in enumerate(rate_list):
            running_total += weight
            if draw <= running_total:
                return data_list[position]
        # Loop exhausted without a hit: fall back to the last visited position.
        return data_list[position]
        

class LoadBoxCreator(BoxCreator):
    """Box source that replays box sizes sequentially from a pickled dataset.

    Boxes are read row by row through ``self.products.iloc``; the first three
    values of each row form one box. The read cursor ``self.box_index``
    persists across :meth:`reset` calls and restarts from row 0 once it runs
    past the end of the dataset.
    """

    # Historical cap on the read cursor; still applied so that datasets
    # larger than this keep wrapping at the same point as before.
    _MAX_BOX_INDEX = 10000000

    def __init__(self, data_name=None):
        """Load the dataset from its fixed path and start at row 0.

        ``data_name`` is accepted but not used; the path is hard-coded.
        """
        super().__init__()
        # NOTE(review): read_pickle can execute arbitrary code while
        # unpickling; only load dataset files from a trusted source.
        self.products = pd.read_pickle('./dataset/random_s1.pkl') #39500
        print("load data set successfully!")
        self.box_index = 0
        # Boxes handed out since the last reset(). Created here as well so
        # generate_box_size()/preview() work before the first reset().
        self.recorder = []

        # random.seed(7)
        # self.start = [item for item in range(500)]
        # random.shuffle(self.start)
        # self.start_idx = -1

    def _wrap_limit(self):
        """Return the cursor value at which reading restarts from row 0."""
        # Wrapping at the dataset length keeps the cursor inside the data;
        # the old fixed 1e7 bound alone let it run off the end of shorter
        # datasets and fail on the next iloc lookup.
        return min(len(self.products), self._MAX_BOX_INDEX)

    def reset(self, index=None):
        """Clear the pending queue and the recorder; keep the read cursor.

        ``index`` is accepted but not used.
        """
        self.box_list.clear()
        self.recorder = []
        if self.box_index >= self._wrap_limit():
            self.box_index = 0
        # self.start_idx += 1
        # self.box_index = self.start[self.start_idx]

    def generate_box_size(self, **kwargs):
        """Append the box at the read cursor to the queue and the recorder."""
        # Fetch the row once instead of once per dimension.
        row = self.products.iloc[self.box_index]
        box = [row[0], row[1], row[2]]
        self.box_list.append(box)
        self.recorder.append(box)
        self.box_index += 1
        if self.box_index >= self._wrap_limit():
            self.box_index = 0
        


    # def __init__(self, data_name=None):
    #     super().__init__()
    #     self.data_name = data_name
    #     print("load data set successfully!")
    #     self.index = 0
    #     self.box_index = 0
    #     self.traj_nums = len(torch.load(self.data_name))
    #     self.all_box_list = torch.load(self.data_name)

    # def reset(self, index=None):
    #     self.box_list.clear()
    #     box_trajs = torch.load(self.data_name)
    #     self.recorder = []
    #     if index is None:
    #         self.index += 1
    #     else:
    #         self.index = index
    #     self.boxes = box_trajs[self.index]
    #     for i in range(len(self.boxes)):
    #         self.boxes[i] = [b * 10 for b in self.boxes[i]]
    #     self.box_index = 0
    #     self.box_set = self.boxes
    #     # self.box_set.append([100, 100, 100])

    # def generate_box_size(self, **kwargs):
    #     if self.box_index < len(self.box_set):
    #         self.box_list.append(self.box_set[self.box_index])
    #         self.recorder.append(self.box_set[self.box_index])
    #         self.box_index += 1
    #     else:
    #         self.box_list.append((100, 100, 100))
    #         self.recorder.append((100, 100, 100))
    #         self.box_index += 1