import os
import pdb
import re
import glob
import pickle
import random
import numpy as np

from torch.utils.data import Dataset

from utils_bc.utils_llm import get_pretrained_tokenizer
from utils_bc.utils_data_process import *
from utils_bc.utils_graph import state_one_hot, filter_redundant_nodes


class Dataloader(Dataset):
    """Behavior-cloning dataset over pickled episode files.

    Each item loads one episode pickle, samples a single time step uniformly at
    random, and encodes the observation, action history, goal and target action
    for agent ``self.agent_id`` (0) via ``_get_data``.
    """

    # Sub-directory of ``args.data_dir`` holding the ``*.p`` episodes of each split.
    _SPLIT_DIRS = {'train': 'train_10000', 'val': 'val_10000'}

    def __init__(self, args, split):
        """Index the episode files of ``split`` and load the tokenizer.

        Args:
            args: config object; reads ``data_dir``, ``data_info``,
                ``model_type`` and ``model_name_or_path``.
            split: ``'train'`` or ``'val'``.

        Raises:
            ValueError: if ``split`` is not a known split name.
        """
        print('-----------------------------------------------------------------------------------')
        print('loading %s data from %s' % (split, args.data_dir))
        print('-----------------------------------------------------------------------------------')

        self.args = args
        self.data_info = args.data_info

        if split not in self._SPLIT_DIRS:
            # Fail fast with a clear message instead of hitting an unbound `data`.
            raise ValueError('unknown split %r; expected one of %s' % (split, sorted(self._SPLIT_DIRS)))

        # glob's result order is filesystem-dependent; sorting keeps the
        # index -> episode-file mapping reproducible across machines and runs.
        data = sorted(glob.glob(os.path.join(args.data_dir, self._SPLIT_DIRS[split], '*.p')))

        self.data = data
        self.agent_id = 0
        self.length = len(data)
        print('%s: there are %d data' % (split, self.length))

        self.tokenizer = get_pretrained_tokenizer(model_type=args.model_type, model_name_or_path=args.model_name_or_path)

    def __len__(self):
        """Return the number of episode files in this split."""
        return self.length

    def __getitem__(self, idx):
        """Load episode ``idx`` and encode one randomly sampled time step.

        Returns:
            A list of numpy arrays, one per field produced by ``_get_data``,
            each with a leading dimension of size 1.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # data_dir at trusted episode files.
        with open(self.data[idx], 'rb') as f:
            data = pickle.load(f)

        partial_obs = data['partial_obs']
        actions = data['actions']
        env_task_goal = data['env_task_goal']
        init_unity_graph = data['init_unity_graph']

        ## sample a time step
        step_i = random.choice(range(len(actions)))
        sample = self._get_data(partial_obs, actions, env_task_goal, init_unity_graph, step_i, len(actions))

        # Wrap every field in a length-1 leading axis (a batch of one sample).
        return [np.array([field]) for field in sample]

    def _get_data(self, partial_obs, actions, env_task_goal, init_unity_graph, step_i, len_actions):
        """Encode time step ``step_i`` of one episode into model inputs and target.

        Note: replaces ``partial_obs[step_i][agent_id]['nodes']`` in place with
        the output of ``filter_redundant_nodes``.

        Returns:
            List ``[obs_node, obs_node_mask, obs_node_state, obs_node_state_mask,
            obs_node_coords, obs_node_coords_mask, history_action_index,
            history_action_index_mask, goal_index, goal_index_mask,
            output_action, len_actions]``.
        """
        output_act = actions[step_i][self.agent_id]
        input_obs = partial_obs[step_i][self.agent_id]
        input_obs['nodes'] = filter_redundant_nodes(input_obs['nodes'])

        ## current observation
        input_obs_node, input_obs_node_mask, input_obs_node_state, input_obs_node_state_mask, input_obs_node_coords, input_obs_node_coords_mask = get_observation_input(self.args, self.data_info, input_obs, self.agent_id)

        ## get history actions
        history_action_index, history_action_index_mask = get_history_action_input(self.args, self.data_info, self.agent_id, actions, step_i, self.tokenizer)

        ## get goal
        goal_index, goal_index_mask = get_goal_input(self.args, self.data_info, self.agent_id, env_task_goal, init_unity_graph, self.tokenizer)

        ## action output (computed on the filtered node list)
        output_action = get_action_output(self.data_info, input_obs, output_act)

        input_data = [input_obs_node, input_obs_node_mask, input_obs_node_state, input_obs_node_state_mask, input_obs_node_coords, input_obs_node_coords_mask,
                      history_action_index, history_action_index_mask, goal_index, goal_index_mask,
                      output_action, len_actions]

        return input_data



def get_observation_input(args, data_info, input_obs, agent_id):
    """Encode an agent's partial observation as fixed-size, padded arrays.

    Args:
        args: not read by this function.
        data_info: dict of vocabularies and size limits; reads
            ``max_node_length``, ``max_node_class_name_gpt2_length``,
            ``gpt2_eos_token``, ``vocabulary_node_state_word_index_dict`` and
            the per-class GPT-2 token / mask lookup tables.
        input_obs: dict whose ``'nodes'`` list holds node dicts with ``'id'``,
            ``'class_name'``, ``'states'`` and optionally ``'bounding_box'``.
        agent_id: 0-based agent index; the agent's own node has id ``agent_id + 1``.

    Returns:
        ``(node_tokens, node_token_mask, node_state, node_state_mask,
        node_coords, node_coords_mask)``. Each has ``max_node_length`` rows;
        rows past the observed nodes are padding (EOS tokens / zeros).
    """
    nodes = input_obs['nodes']
    node_count = len(nodes)
    max_nodes = data_info['max_node_length']

    # --- node names: per-class GPT-2 token rows, padded with EOS-filled rows ---
    class_tokens = np.stack([data_info['vocabulary_node_class_name_word_index_dict_gpt2_padding'][n['class_name']] for n in nodes])
    class_token_masks = np.stack([data_info['vocabulary_node_class_name_word_index_dict_gpt2_padding_mask'][n['class_name']] for n in nodes])

    pad_shape = (max_nodes - len(class_tokens), data_info['max_node_class_name_gpt2_length'])
    input_obs_node = np.concatenate((class_tokens, np.full(pad_shape, data_info['gpt2_eos_token'], dtype=float)), axis=0)
    input_obs_node_mask = np.concatenate((class_token_masks, np.zeros(pad_shape)), axis=0)

    # --- node states: one multi-hot row per observed node ---
    state_vocab = data_info['vocabulary_node_state_word_index_dict']
    state_rows = np.stack([state_one_hot(state_vocab, n['states']) for n in nodes])
    input_obs_node_state = np.zeros((max_nodes, len(state_vocab)))
    input_obs_node_state_mask = np.zeros(max_nodes)
    input_obs_node_state[:len(state_rows)] = state_rows
    input_obs_node_state_mask[:len(state_rows)] = 1

    # --- node geometry: [center relative to the agent, size]; zeros without a box ---
    own_nodes = [n for n in nodes if n['id'] == agent_id + 1]
    assert len(own_nodes)==1 and own_nodes[0]['class_name']=='character'
    agent_center = np.array(own_nodes[0]['bounding_box']['center'])

    geometry = []
    for n in nodes:
        if 'bounding_box' in n:
            box = n['bounding_box']
            geometry.append(np.concatenate([np.array(box['center']) - agent_center, np.array(box['size'])]))
        else:
            geometry.append(np.zeros(6))

    input_obs_node_coords = np.zeros((max_nodes, 6))  # 6 = 3 center + 3 size
    input_obs_node_coords_mask = np.zeros(max_nodes)
    input_obs_node_coords[:node_count] = np.stack(geometry)
    input_obs_node_coords_mask[:node_count] = 1

    return input_obs_node, input_obs_node_mask, input_obs_node_state, input_obs_node_state_mask, input_obs_node_coords, input_obs_node_coords_mask
    


def get_history_action_input(args, data_info, agent_id, acts, step_i, tokenizer):
    """Encode the agent's recent goal-relevant actions as padded GPT-2 arrays.

    Only actions before ``step_i`` that contain ``[putback]``, ``[putin]``,
    ``[close]`` or ``[switchon]`` are kept. They are converted to language, and
    the most recent ``max_task_num - 1`` sentences are tokenized. Tokenizations
    are memoized in ``data_info['history_action_gpt2_padding']`` and
    ``data_info['history_action_gpt2_padding_mask']`` (``data_info`` is mutated).

    Args:
        args: not read by this function.
        data_info: dict with ``max_task_num``, ``max_history_action_gpt2_length``,
            ``gpt2_eos_token`` and the two memo dicts above.
        agent_id: index into each per-step action entry.
        acts: per-step action entries; ``acts[t][agent_id]`` is an action script string.
        step_i: current step; only ``acts[:step_i]`` is considered.
        tokenizer: callable returning a dict with ``'input_ids'``.

    Returns:
        ``(history_index, history_mask)``, each of shape
        ``(max_task_num - 1, max_history_action_gpt2_length)``; unused rows are
        EOS-filled / zero.
    """
    num_slots = data_info['max_task_num'] - 1
    max_len = data_info['max_history_action_gpt2_length']
    eos = data_info['gpt2_eos_token']

    history_action_gpt2_token = np.zeros([num_slots, max_len]) + eos
    history_action_gpt2_token_mask = np.zeros([num_slots, max_len])

    previous_acts = acts[:step_i]
    # Nothing to encode without past steps or without any history slot. The
    # num_slots check is required: history_actions[-0:] would return the whole
    # list (not an empty one) and then fail to fit into the (0, L) buffer.
    if len(previous_acts) == 0 or num_slots <= 0:
        return history_action_gpt2_token, history_action_gpt2_token_mask

    goal_markers = ('[putback]', '[putin]', '[close]', '[switchon]')
    goal_actions = [step[agent_id] for step in previous_acts if any(marker in step[agent_id] for marker in goal_markers)]

    goal_actions_parsed = [parse_language_from_action_script(action) for action in goal_actions]
    history_actions = get_history_action_input_language(goal_actions_parsed)
    history_actions = history_actions[-num_slots:]
    if len(history_actions) == 0:
        return history_action_gpt2_token, history_action_gpt2_token_mask

    token_cache = data_info['history_action_gpt2_padding']
    mask_cache = data_info['history_action_gpt2_padding_mask']

    # Tokenize and memoize only sentences not seen before.
    new_tokens = {sentence: tokenizer(sentence)['input_ids'] for sentence in history_actions if sentence not in token_cache}
    for sentence, ids in new_tokens.items():
        index = np.zeros([max_len]) + eos
        mask = np.zeros([max_len])
        index[:len(ids)] = ids
        mask[:len(ids)] = 1
        token_cache[sentence] = index
        mask_cache[sentence] = mask

    stacked = np.stack([token_cache[sentence] for sentence in history_actions])
    stacked_mask = np.stack([mask_cache[sentence] for sentence in history_actions])
    history_action_gpt2_token[:len(stacked)] = stacked
    history_action_gpt2_token_mask[:len(stacked_mask)] = stacked_mask

    return history_action_gpt2_token, history_action_gpt2_token_mask
        

def get_goal_input(args, data_info, agent_id, env_task_goal, init_unity_graph, tokenizer):
    """Tokenize the agent's task goal into fixed-size, padded GPT-2 arrays.

    Each sentence returned by ``get_goal_language`` becomes one row of
    ``max_subgoal_gpt2_length`` token ids (EOS-padded) plus a 0/1 mask row.
    Tokenizations are memoized in ``data_info['subgoal_gpt2_padding']`` and
    ``data_info['subgoal_gpt2_padding_mask']`` (``data_info`` is mutated).

    Args:
        args: not read by this function.
        data_info: dict with ``max_task_num``, ``max_subgoal_gpt2_length``,
            ``gpt2_eos_token`` and the two memo dicts above.
        agent_id: index into ``env_task_goal[0]``.
        env_task_goal: goal container; ``env_task_goal[0][agent_id]`` is used.
        init_unity_graph: passed through to ``get_goal_language``.
        tokenizer: callable returning a dict with ``'input_ids'``.

    Returns:
        ``(goal_index, goal_index_mask)``, each of shape
        ``(max_task_num, max_subgoal_gpt2_length)``; rows past the last sentence
        are EOS-filled / zero. A goal with no sentences yields only padding.
    """
    max_len = data_info['max_subgoal_gpt2_length']
    eos = data_info['gpt2_eos_token']
    token_cache = data_info['subgoal_gpt2_padding']
    mask_cache = data_info['subgoal_gpt2_padding_mask']

    task_goal = env_task_goal[0][agent_id]
    # Materialize once: the sentences are iterated twice below.
    task_goal_languages = list(get_goal_language(task_goal, init_unity_graph))

    # Tokenize and memoize only sentences not seen before.
    new_tokens = {sentence: tokenizer(sentence)['input_ids'] for sentence in task_goal_languages if sentence not in token_cache}
    for sentence, ids in new_tokens.items():
        index = np.zeros([max_len]) + eos
        mask = np.zeros([max_len])
        index[:len(ids)] = ids
        mask[:len(ids)] = 1
        token_cache[sentence] = index
        mask_cache[sentence] = mask

    goal_index = np.zeros([data_info['max_task_num'], max_len]) + eos
    goal_index_mask = np.zeros([data_info['max_task_num'], max_len])

    # np.stack raises on an empty list; an empty goal keeps the all-padding rows.
    if task_goal_languages:
        goal_index[:len(task_goal_languages)] = np.stack([token_cache[sentence] for sentence in task_goal_languages])
        goal_index_mask[:len(task_goal_languages)] = np.stack([mask_cache[sentence] for sentence in task_goal_languages])

    return goal_index, goal_index_mask



def get_action_output(data_info, input_obs, output_act):
    """Convert an action script string into ``[action_index, node_position]``.

    The last ``[name]`` token gives the action, which is looked up in
    ``data_info['vocabulary_action_name_word_index_dict']``. The last ``(id)``
    token gives the target object id, which is resolved to its position in
    ``input_obs['nodes']``.

    Raises:
        IndexError: if the script has no ``[...]``, ``<...>`` or ``(...)`` token.
        KeyError: if the action name is not in the vocabulary.
        AssertionError: unless exactly one observed node has the target id.
    """
    def last_match(pattern):
        return re.findall(pattern, output_act)[-1]

    name = last_match(r"\[([A-Za-z0-9_]+)\]")
    # The object name itself is not used, but extracting it still requires
    # the script to contain a <...> token.
    last_match(r"\<([A-Za-z0-9_]+)\>")
    target = last_match(r"\(([A-Za-z0-9_]+)\)")
    action_index = data_info['vocabulary_action_name_word_index_dict'][name]

    positions = []
    for pos, node in enumerate(input_obs['nodes']):
        if node['id'] == int(target):
            positions.append(pos)
    assert len(positions) == 1

    return np.array([action_index, positions[0]])



