import json
import os

import numpy as np

# Names of the map feature layers, listed in channel order.
base_map_features = ["pot_loc", "counter_loc", "onion_disp_loc", "dish_disp_loc", "serve_loc"]
variable_map_features = ["onions_in_pot", "onions_cook_time", "onion_soup_loc", "dishes", "onions"]
all_features = [*base_map_features, *variable_map_features]

# The named layers start at this channel; channels 0-9 come before them
# (presumably the per-player layers -- confirm against the state encoder).
_FIRST_MAP_CHANNEL = 10


def feature_idx(name):
    """Return the channel index of the feature layer called *name*.

    Raises ``ValueError`` when *name* is not an entry of ``all_features``.
    """
    position = all_features.index(name)
    return _FIRST_MAP_CHANNEL + position


# Shape of a single observation; the last axis holds the feature channels.
shape = (9, 5, 20)

# Project root: the directory one level above the folder containing this script.
_script_path = os.path.abspath(__file__)
proj_direc = os.path.dirname(os.path.dirname(_script_path))

# Buffer directory to augment, relative to the project root.
# Alternative dataset root used previously: 'Layout_Unident_S_Datasets'.
_dataset_parts = ('Layout_Unident_S_Datasets_plus', 'Unident_S_deliver_soup_use_pot2', 'buffer')
data_direc = os.path.join(proj_direc, *_dataset_parts)

# pot1: (4,2); pot2: (4,3)

# --- Augmentation settings -------------------------------------------------
NUM_AUGMENTATIONS = 10      # augmented copies written per original buffer
TARGET_POT = (4, 3)         # cell of pot2, the pot that gets filled (pot1 is at (4, 2))
ONIONS_IN_FULL_POT = 3      # onion count written into the target pot
MAX_COOK_TIME = 20          # cook time used for the first copy; upper bound (exclusive) for the rest
LAYER_SCALE = 255.          # factor applied to raw counts before they are written into a layer
BUFFER_ARRAYS = ('action', 'done', 'reward', 'state', 'next_state')


def _is_augmented_name(name):
    """Return True if *name* has the ``<item>_aug<k>`` form this script writes."""
    _, marker, suffix = name.rpartition('_aug')
    return bool(marker) and suffix.isdigit()


def _load_buffer(item_direc):
    """Read ``arr_0`` from every ``<name>.npz`` archive of one buffer directory.

    Each archive is closed as soon as its array has been read, so no file
    handles stay open while the remaining directories are processed.
    """
    arrays = {}
    for name in BUFFER_ARRAYS:
        with np.load(os.path.join(item_direc, f'{name}.npz')) as archive:
            arrays[name] = archive['arr_0']
    return arrays


def _select_augmentable(state, next_state):
    """Return the sorted sample indices that are eligible for augmentation.

    A sample qualifies when no cell of the ``onions_cook_time`` channel is
    positive in either ``state`` or ``next_state`` and, in the first agent's
    view, channel 0 exceeds the ``dishes`` channel in at least one cell
    (presumably: player 0 is not carrying a dish -- confirm).
    """
    cook_idx = feature_idx('onions_cook_time')
    cooking = (state[..., cook_idx] > 0).any(axis=(1, 2, 3))
    next_cooking = (next_state[..., cook_idx] > 0).any(axis=(1, 2, 3))
    # Direct comparison instead of "difference < 0": identical for signed and
    # float data, and it cannot wrap around if the arrays are unsigned.
    no_handling = (state[:, 0, :, :, 0] > state[:, 0, :, :, feature_idx('dishes')]).any(axis=(1, 2))
    return np.flatnonzero(no_handling & ~cooking & ~next_cooking)


# Augmentation principles (translated from the original notes):
#   * 1/2 of the augmented results change the pot contents; of those, half get
#     a full pot and the other half an arbitrary pot state.
#   * 1/4 of the augmented results change some other content.
# NOTE(review): the loop below implements only the pot rule, and only copy 0
# receives the maximum cook time -- the ratios above are not enforced.
flat_size = int(np.prod(shape))

for item in os.listdir(data_direc):
    item_direc = os.path.join(data_direc, item)
    # Skip stray files and the outputs of an earlier run; re-augmenting
    # "<item>_aug<k>" directories would make the dataset grow on every run.
    if not os.path.isdir(item_direc) or _is_augmented_name(item):
        continue

    # Everything that does not depend on the copy number is done once per item.
    buffer = _load_buffer(item_direc)
    # -1 infers the number of samples instead of hard-coding it.
    state = buffer['state'].reshape(-1, 2, *shape)
    next_state = buffer['next_state'].reshape(-1, 2, *shape)

    ## Augmentation rule 1: put a filled pot into samples where nothing is cooking.
    selected_id = _select_augmentable(state, next_state)
    num_selected = len(selected_id)

    with open(os.path.join(item_direc, 'info.txt'), 'r') as f:
        info = json.load(f)
    info['capacity'] = info['idx'] = info['size'] = num_selected
    info_text = json.dumps(info, indent=4, sort_keys=True)

    # Only the selected samples are saved. Fancy indexing copies, so the
    # in-place layer writes below never touch the loaded arrays.
    action = buffer['action'][selected_id]
    done = buffer['done'][selected_id]
    reward = buffer['reward'][selected_id]
    aug_state, aug_next_state = state[selected_id], next_state[selected_id]

    onions_pot_layer = np.zeros(shape[:2])
    onions_pot_layer[TARGET_POT] = LAYER_SCALE * ONIONS_IN_FULL_POT

    for i in range(NUM_AUGMENTATIONS):
        # Copy 0 uses the maximum cook time, the others a random one in [0, MAX_COOK_TIME).
        cook_time = MAX_COOK_TIME if i == 0 else np.random.randint(0, MAX_COOK_TIME)
        cook_time_layer = np.zeros(shape[:2])
        cook_time_layer[TARGET_POT] = LAYER_SCALE * cook_time

        # Both layers are overwritten completely on every pass, so the same
        # working arrays can be reused for all copies. The cook time is the
        # same in state and next_state.
        for frames in (aug_state, aug_next_state):
            frames[..., feature_idx('onions_in_pot')] = onions_pot_layer
            frames[..., feature_idx('onions_cook_time')] = cook_time_layer

        new_item_direc = os.path.join(data_direc, f"{item}_aug{i}")
        os.makedirs(new_item_direc, exist_ok=True)

        np.savez_compressed(os.path.join(new_item_direc, 'action.npz'), action)
        np.savez_compressed(os.path.join(new_item_direc, 'done.npz'), done)
        np.savez_compressed(os.path.join(new_item_direc, 'reward.npz'), reward)
        # Explicit flat size: reshape(-1) is ambiguous when nothing was selected.
        np.savez_compressed(os.path.join(new_item_direc, 'state.npz'),
                            aug_state.reshape(num_selected, 2, flat_size))
        np.savez_compressed(os.path.join(new_item_direc, 'next_state.npz'),
                            aug_next_state.reshape(num_selected, 2, flat_size))

        with open(os.path.join(new_item_direc, 'info.txt'), 'w') as f:
            f.write(info_text)