import numpy as np
import os

# Names of the feature layers in channel order; feature_idx() maps a name to
# its channel by its position in all_features.

# Layout-level layers (locations of map objects).
base_map_features = [
    "pot_loc",
    "counter_loc",
    "onion_disp_loc",
    "tomato_disp_loc",
    "dish_disp_loc",
    "serve_loc",
]

# Layers describing pot contents, soups and loose items.
variable_map_features = [
    "onions_in_pot",
    "tomatoes_in_pot",
    "onions_in_soup",
    "tomatoes_in_soup",
    "soup_cook_time_remaining",
    "soup_done",
    "dishes",
    "onions",
    "tomatoes",
]

urgency_features = ["urgency"]

# 6 + 9 + 1 = 16 named channels, concatenated in the order above.
all_features = [*base_map_features, *variable_map_features, *urgency_features]
def feature_idx(name):
    """Return the last-axis channel index of the named feature layer.

    The named features start after 10 leading channels, so with the 16 names
    in ``all_features`` the indices run from 10 to 25.

    Raises:
        ValueError: if *name* is not in ``all_features``.
    """
    # NOTE(review): the 10 leading channels are not named in this file;
    # presumably player-related layers — confirm against the encoder.
    leading_channels = 10
    position = all_features.index(name)
    return leading_channels + position
# Shape of one encoded observation: (grid rows, grid cols, feature channels).
shape = (5, 7, 26)

# Project root: two directory levels above this script's own directory.
proj_direc = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
# Transition buffer to augment; every sub-directory holds one item's .npz files.
data_direc = os.path.join(
    proj_direc,
    'Layout_DistantTomato_Datasets',
    'Diatant_Tomato_deliver_soup_use_pot2',  # (sic) must match the on-disk name
    'buffer',
)

# Pot coordinates (row, col) in this layout: pot1 = (2, 5), pot2 = (2, 4).
# The augmentation targets pot2.
pot_loc = [2, 4]
# NOTE(review): apparently from another layout — pot1, pot2, pot3 at
# (1, 0), (2, 0), (3, 0).

# Augmentation 1: put a synthetic 3-ingredient soup into the (empty) target pot.
# Each source item yields _N_AUGS new items: indices < _N_TOMATO_AUGS get a
# tomato soup, the rest an onion soup; the first index of each group
# (0 and _N_TOMATO_AUGS) is a finished soup, the others are still cooking.
_N_AUGS = 15
_N_TOMATO_AUGS = 5


def _is_augmented(name):
    """Return True if *name* has the ``<item>_aug<k>`` form this script writes.

    Lets a re-run skip the script's own outputs instead of augmenting them again.
    """
    _, sep, suffix = name.rpartition('_aug')
    return bool(sep) and suffix.isdigit()


def _load_npz(item_direc, file_name):
    """Load and return the ``arr_0`` array of ``item_direc/file_name``.

    The NpzFile is used as a context manager so its file handle is closed.
    """
    with np.load(os.path.join(item_direc, file_name)) as npz:
        return npz['arr_0']


def _select_augmentable(state_, next_state_):
    """Return the sorted row indices that may receive a synthetic soup.

    Both arrays have shape (n, 2, *shape). A row is kept when:
    - the pot cell at ``pot_loc`` shows no cook time remaining and no finished
      soup in ``state_``, nor cook time remaining in ``next_state_``, in
      either of the two axis-1 slices; and
    - in axis-1 slice 0, some cell has a ``dishes`` value smaller than
      channel 0.
    """
    r, c = pot_loc
    cook_ch = feature_idx('soup_cook_time_remaining')
    pot_busy = (
        (state_[:, :, r, c, cook_ch] > 0).any(axis=1)
        | (next_state_[:, :, r, c, cook_ch] > 0).any(axis=1)
        | (state_[:, :, r, c, feature_idx('soup_done')] > 0).any(axis=1)
    )
    # NOTE(review): channel 0 is presumably player 0's location layer, making
    # this "player 0 is not holding a dish" — confirm. Compared directly
    # (instead of `dish - ch0 < 0`) so unsigned dtypes cannot wrap around.
    no_dish = (
        state_[:, 0, :, :, feature_idx('dishes')] < state_[:, 0, :, :, 0]
    ).any(axis=(1, 2))
    return np.flatnonzero(no_dish & ~pot_busy)


def _augment(states, next_states, ingredient, soup_done):
    """Return augmented copies of *states* / *next_states* (shape (k, 2, *shape)).

    Only the ``pot_loc`` cell of the touched channels is written, in both
    axis-1 slices; every other cell keeps its original value.

    Args:
        states, next_states: source arrays; they are not modified.
        ingredient: 'tomatoes_in_soup' or 'onions_in_soup'; its pot cell is
            set to 255 * 3 (presumably 3 ingredients at a 255 scale).
        soup_done: if True, the soup_done pot cell is set to 255 in both
            arrays. Otherwise one cook time in [1, 20] is drawn per row and
            written to ``states``, and that value minus 1 to ``next_states``.
    """
    states, next_states = states.copy(), next_states.copy()
    r, c = pot_loc
    for arr in (states, next_states):
        arr[:, :, r, c, feature_idx(ingredient)] = 255. * 3
        if soup_done:
            arr[:, :, r, c, feature_idx('soup_done')] = 255.
    if not soup_done:
        cook_ch = feature_idx('soup_cook_time_remaining')
        # Shape (k, 1) broadcasts over axis 1: both slices share a row's cook time.
        cook_time = np.random.randint(1, 21, size=(states.shape[0], 1))
        states[:, :, r, c, cook_ch] = cook_time
        next_states[:, :, r, c, cook_ch] = cook_time - 1
    return states, next_states


# Source items only: skip stray files and outputs of earlier runs.
item_list = sorted(
    item for item in os.listdir(data_direc)
    if os.path.isdir(os.path.join(data_direc, item)) and not _is_augmented(item)
)
for item in item_list:
    item_direc = os.path.join(data_direc, item)

    # Loading and row selection depend only on the item, so do them once.
    action_ = _load_npz(item_direc, 'action.npz')
    done_ = _load_npz(item_direc, 'done.npz')
    reward_ = _load_npz(item_direc, 'reward.npz')
    state_ = _load_npz(item_direc, 'state.npz').reshape(-1, 2, *shape)
    next_state_ = _load_npz(item_direc, 'next_state.npz').reshape(-1, 2, *shape)
    n_rows = state_.shape[0]

    selected_id = _select_augmentable(state_, next_state_)
    n_selected = len(selected_id)
    # Fancy indexing copies, so the loaded arrays themselves are never modified.
    action_sel = action_[selected_id]
    done_sel = done_[selected_id]
    reward_sel = reward_[selected_id]
    state_sel = state_[selected_id]
    next_state_sel = next_state_[selected_id]

    with open(os.path.join(item_direc, 'info.txt'), 'r') as f:
        info_template = f.read()

    flat_dim = int(np.prod(shape))
    for i in range(_N_AUGS):
        new_item = f"{item}_aug{i}"
        ingredient = 'tomatoes_in_soup' if i < _N_TOMATO_AUGS else 'onions_in_soup'
        soup_done = i in (0, _N_TOMATO_AUGS)
        aug_state, aug_next_state = _augment(state_sel, next_state_sel, ingredient, soup_done)

        new_item_direc = os.path.join(data_direc, new_item)
        os.makedirs(new_item_direc, exist_ok=True)
        np.savez_compressed(os.path.join(new_item_direc, 'action.npz'), action_sel)
        np.savez_compressed(os.path.join(new_item_direc, 'done.npz'), done_sel)
        np.savez_compressed(os.path.join(new_item_direc, 'reward.npz'), reward_sel)
        # States are saved as (k, 2, rows * cols * channels); explicit sizes
        # keep the reshape valid when no row was selected (k == 0).
        np.savez_compressed(os.path.join(new_item_direc, 'state.npz'),
                            aug_state.reshape(n_selected, 2, flat_dim))
        np.savez_compressed(os.path.join(new_item_direc, 'next_state.npz'),
                            aug_next_state.reshape(n_selected, 2, flat_dim))

        # NOTE(review): plain substring replacement — it also rewrites any other
        # occurrence of the item name or row count in info.txt; confirm format.
        txt_str = info_template.replace(item, new_item)
        txt_str = txt_str.replace(str(n_rows), str(n_selected))
        with open(os.path.join(new_item_direc, 'info.txt'), 'w') as f:
            f.write(txt_str)