import pathlib
import json
import yaml
from copy import deepcopy
import atexit

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

import alfworld.agents.environment as environment
import alfworld.agents.modules.generic as generic

from promptrl.envs.alfworldviz.alfworld_viz_env import AlfworldVizEnv

# Resource locations, both under the `data` directory that sits beside the
# directory containing this file (i.e. <file dir>/../data).
DATA_PATH = pathlib.Path(__file__).parent.parent.resolve().joinpath('data', 'seq2seq_data')
CONFIG_PATH = pathlib.Path(__file__).parent.parent.resolve().joinpath('data', 'alfworld_configs', 'base_config.yaml')

def _load_tasks(task_type_ids, mode='train'):
    """Load and concatenate the seq2seq trajectory records for several task types.

    For every id in ``task_type_ids`` the file
    ``DATA_PATH / 'tw_alfred_seq2seq_{mode}_task{id}_hc.json'`` is read and the
    list stored under its ``'data'`` key is appended to the result.

    Args:
        task_type_ids: iterable of task-type ids to load.
        mode: split name interpolated into the file name (default ``'train'``).

    Returns:
        A single flat list with the ``'data'`` entries of every requested file,
        in the order the ids were given.

    Raises:
        FileNotFoundError: if a requested task file does not exist.
        KeyError: if a file has no top-level ``'data'`` key.
    """
    # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place
    datas = []
    for task_type_id in task_type_ids:
        task_file = DATA_PATH / f'tw_alfred_seq2seq_{mode}_task{task_type_id}_hc.json'
        # Open in binary mode so that json detects the encoding (UTF-8/16/32,
        # with or without BOM) itself instead of relying on the platform's
        # locale-dependent default text encoding.
        with task_file.open('rb') as f:
            task_data = json.load(f)
        datas.extend(task_data['data'])
    return datas

def _get_alf_env_loader(env_type, controller_type, tokenizer, task_type_ids, batch_size, train_eval='train'):
    """Build a zero-argument factory that creates an initialised ALFWorld environment.

    The base YAML config at ``CONFIG_PATH`` is read once, when this function is
    called, and patched with the requested task types and controller type. The
    environment itself is only constructed when the returned callable is
    invoked.

    Args:
        env_type: ``'AlfworldVizEnv'`` selects the local ``AlfworldVizEnv``
            class; any other value is looked up as an attribute of
            ``alfworld.agents.environment`` (e.g. 'AlfredTWEnv',
            'AlfredThorEnv' or 'AlfredHybrid').
        controller_type: written to ``config['controller']['type']``
            ('oracle', 'oracle_astar', 'mrcnn' or 'mrcnn_astar').
        tokenizer: accepted for signature compatibility; not used here.
        task_type_ids: written to ``config['env']['task_types']``.
        batch_size: forwarded to ``env.init_env``.
        train_eval: forwarded to the environment constructor.

    Returns:
        A callable taking no arguments that constructs the environment, calls
        ``init_env`` on it, registers its ``close`` method with ``atexit`` and
        returns it.
    """
    with CONFIG_PATH.open() as config_file:
        config = yaml.safe_load(config_file)

    # Override the parts of the base config that depend on the caller.
    config['env']['task_types'] = task_type_ids
    config['controller']['type'] = controller_type

    # Resolve the environment class eagerly so an unknown env_type fails here
    # rather than when the factory is eventually called.
    env_cls = AlfworldVizEnv if env_type == 'AlfworldVizEnv' else getattr(environment, env_type)

    def _initializer():
        created_env = env_cls(config, train_eval=train_eval)
        created_env.init_env(batch_size=batch_size)
        # Make sure the environment is closed when the interpreter exits.
        atexit.register(created_env.close)
        return created_env

    return _initializer

class AlfworldFillDataset(Dataset):
    """ALFWorld expert trajectories laid out as alternating obs-slot / action-token sequences.

    Each trajectory becomes one flat sequence made of "chunks". A chunk is a
    run of reserved positions for one observation (left as zeros in
    ``fill_arr``; their extent is recorded in ``fill_locs``) followed by the
    token ids of the action taken after that observation. Loss is enabled only
    on the action tokens.

    Trajectories longer than ``n_positions`` are split into truncation windows
    that always keep the first chunk and a contiguous run of later chunks;
    ``__getitem__`` samples one window per access.
    """

    def __init__(self, obs_type, tasks, model, tokenizer, num_samples=-1, seed=42, max_obs_length=6, action_burn_in=5, mode='train', n_positions=1024, per_obs_tokens=None, data_mode=None, task_types=None, limit_frames=None):
        """Load, optionally subsample, tokenize and pre-compute all trajectories.

        Args:
            obs_type: accepted but not used by this class.
            tasks: task-type ids passed to ``_load_tasks``.
            model: accepted but not used by this class.
            tokenizer: callable used as ``tokenizer(text, return_tensors='pt')``
                whose result exposes ``input_ids``.
            num_samples: if positive and smaller than the number of loaded
                trajectories, that many are drawn without replacement.
            seed: seed of the numpy generator used for subsampling and for
                choosing truncation windows.
            max_obs_length: stored on the instance; not used by this class.
            action_burn_in: number of leading actions of a truncated window
                that are excluded from the loss (see ``_preprocess_row``).
            mode: split name passed to ``_load_tasks``.
            n_positions: maximum number of sequence positions per window.
            per_obs_tokens: callable
                ``(obs_token_count, goal_token_count, position) -> int`` giving
                the number of positions reserved for an observation. Must be
                provided whenever there is at least one trajectory.
            data_mode: stored on the instance; not used by this class.
            task_types: accepted but not used by this class.
            limit_frames: accepted but not used by this class.
        """
        super().__init__()
        self.tasks = tasks
        self.tokenizer = tokenizer
        self.rng = np.random.default_rng(seed)
        self.max_obs_length = max_obs_length
        self.action_burn_in = action_burn_in
        self.n_positions = n_positions
        self.per_obs_tokens = per_obs_tokens
        self.data_mode = data_mode

        self.num_samples = num_samples
        raw_data = _load_tasks(tasks, mode)
        if 0 < num_samples < len(raw_data):
            raw_data = self.rng.choice(raw_data, num_samples, replace=False)

        self.data = self._parse_data(raw_data)
        self.preprocess()

    def _parse_data(self, raw):
        """Turn raw trajectory records into rows of tokenized observations and actions.

        Each raw record supplies ``'g'`` (stored as ``'id'``), ``'task'`` and
        ``'steps'``. The observation of step ``i`` (i >= 1) is split from the
        right on ``' [SEP] '`` into ``(history, obs, env_action)``. A step whose
        ``env_action`` differs from the action recorded on the previous step is
        treated as a repeat of the last kept step and skipped.

        Returns:
            A list of dict rows. ``row['obs']`` and ``row['actions']`` are
            equally long lists of 1-D token-id tensors (the observation after
            the final action is dropped); ``row['goal']`` holds the tokenized
            goal sentence.
        """
        data = []
        for raw_row in raw:
            row = {'id': raw_row['g'], 'task': raw_row['task'], 'steps': raw_row['steps']}
            row['sep'] = sep = ' [SEP] '
            row['goals'] = 'Your task is to: ' + row['task']
            row['obs'] = row['lang_obs'] = raw_row['steps'][0]['obs']

            step_obs = []
            actions = []
            steps = raw_row['steps']
            for prev_step, step in zip(steps, steps[1:]):
                action = prev_step['action']
                _, obs, env_action = step['obs'].rsplit(sep, 2)
                if action != env_action:
                    # Sanity check: the environment echoed the previous
                    # (obs, action) pair, so nothing new is recorded.
                    assert obs == step_obs[-1]
                    assert env_action == actions[-1]
                else:
                    step_obs.append(obs)
                    actions.append(env_action)
            row['step_obs'] = step_obs

            actions_formatted = [' [SEP] ' + a + ' [SEP]' for a in actions]
            obs_formatted = [row['obs']] + [' ' + o for o in step_obs]
            row['actions'] = [self.tokenizer(a, return_tensors='pt').input_ids.squeeze(0) for a in actions_formatted]
            row['obs'] = [self.tokenizer(o, return_tensors='pt').input_ids.squeeze(0) for o in obs_formatted]
            # The observation that follows the last action has no action to
            # predict, so it is dropped.
            row['obs'].pop(-1)
            row['goal'] = self.tokenizer(row.pop('goals'), return_tensors='pt').input_ids.squeeze(0)
            data.append(row)

        return data

    def __len__(self):
        """Return the number of trajectories."""
        return len(self.data)

    def preprocess(self):
        """Replace every row in ``self.data`` by its pre-computed form (see ``_preprocess_row``)."""
        self.data = [self._preprocess_row(dict(row)) for row in self.data]

    def _preprocess_row(self, row):
        """Lay out one trajectory as a flat sequence and compute its truncation windows.

        Reads ``row['obs']``, ``row['actions']`` and ``row['goal']`` and adds:

        * ``total_length``: reserved observation positions plus action tokens.
        * ``fill_arr`` / ``output_ids``: int64 tensors with the action token
          ids at the action positions and zeros elsewhere.
        * ``attention_mask``: all ones; ``loss_mask``: ones on action tokens.
        * ``fill_locs``: ``(start, end)`` of every reserved observation span.
        * ``chunk_locs``: ``(start, end)`` of every observation+action chunk.
        * ``trunc_idxs``, ``trunc_masks``, ``trunc_fill_locs``,
          ``trunc_loss_idx``: one entry per truncation window, giving the chunk
          indices kept, a boolean mask over the full sequence, the
          ``fill_locs`` re-based to the window, and the window position before
          which the loss is disabled (burn-in).

        A window is saved whenever chunks had to be dropped to fit the current
        chunk into ``n_positions``, and once more at the last chunk.

        Raises:
            ValueError: if the first chunk together with some later chunk (or
                the first chunk alone) exceeds ``n_positions``, so no window
                can hold it.
        """
        row['task'] = 0  # used to indicate primary / auxiliary tasks
        row['log_task'] = 0  # used for indicating which tasks to separate when logging
        row['task_name'] = 'forward'

        # Positions reserved for each observation; computed once and reused for
        # both the total length and the layout below.
        goal_len = len(row['goal'])
        obs_slots = [self.per_obs_tokens(len(o), goal_len, pos) for pos, o in enumerate(row['obs'])]
        row['total_length'] = sum(obs_slots) + sum(len(a) for a in row['actions'])

        input_ids = torch.zeros((row['total_length'],), dtype=torch.int64)
        attention_mask = torch.ones((row['total_length'],))
        loss_mask = torch.zeros((row['total_length'],))
        output_ids = torch.zeros((row['total_length'],), dtype=torch.int64)
        fill_locs = []
        chunk_locs = []

        cursor = 0
        for i, action in enumerate(row['actions']):
            # make space for obs embeds
            chunk_start = cursor
            obs_end = chunk_start + obs_slots[i]
            fill_locs.append((chunk_start, obs_end))

            # fill in action tokens
            action_end = obs_end + action.shape[0]
            loss_mask[obs_end:action_end] = 1.
            input_ids[obs_end:action_end] = action
            output_ids[obs_end:action_end] = action

            chunk_locs.append((chunk_start, action_end))
            cursor = action_end
        assert cursor == row['total_length']

        row['fill_arr'] = input_ids
        row['fill_locs'] = fill_locs
        row['attention_mask'] = attention_mask
        row['loss_mask'] = loss_mask
        row['output_ids'] = output_ids

        # Calculate all truncation windows for a sequence
        trunc_idxs = []
        trunc_masks = []
        trunc_fill_locs = []
        trunc_loss_idx = []
        chunk_queue = []  # (chunk index, (start, end)) of the chunks in the current window
        last_chunk = len(chunk_locs) - 1
        for i, chunk in enumerate(chunk_locs):
            chunk_len = chunk[1] - chunk[0]
            save_window = False
            while sum(e - s for _, (s, e) in chunk_queue) + chunk_len > self.n_positions:
                if len(chunk_queue) < 2:
                    # Only the initial chunk (or nothing) is left to drop.
                    raise ValueError(
                        f'chunk {i} ({chunk_len} positions) does not fit into '
                        f'n_positions={self.n_positions} together with the initial chunk'
                    )
                chunk_queue.pop(1)  # always keep initial obs/action
                save_window = True
            if i == last_chunk:
                save_window = True
            chunk_queue.append((i, chunk))
            if not save_window:
                continue

            trunc_idxs.append([chunk_idx for chunk_idx, _ in chunk_queue])
            window_mask = torch.zeros((row['total_length'],), dtype=torch.bool)
            for _, (s, e) in chunk_queue:
                window_mask[s:e] = True
            trunc_masks.append(window_mask)

            if len(chunk_queue) == 1:
                # Single-action trajectory: the window is just the initial
                # chunk, nothing was removed and no burn-in applies.
                trunc_fill_locs.append([row['fill_locs'][0]])
                trunc_loss_idx.append(0)
                continue

            # calculate trunc fill locs: chunks after the first are contiguous,
            # so they all shift left by the size of the gap that was removed.
            trunc_remove = chunk_queue[1][1][0] - chunk_queue[0][1][1]
            window_fill_locs = [row['fill_locs'][0]]
            for chunk_idx, _ in chunk_queue[1:]:
                s, e = row['fill_locs'][chunk_idx]
                window_fill_locs.append((s - trunc_remove, e - trunc_remove))
            trunc_fill_locs.append(window_fill_locs)

            # calculate loss index for burn in
            if chunk_queue[1][0] == 1:
                # Nothing was dropped after the initial chunk.
                trunc_loss_idx.append(0)
            else:
                burn_in_idx = min(len(window_fill_locs) - 1, self.action_burn_in + 1)
                trunc_loss_idx.append(window_fill_locs[burn_in_idx][0])

        row['trunc_idxs'] = trunc_idxs
        row['trunc_masks'] = trunc_masks
        row['trunc_fill_locs'] = trunc_fill_locs
        row['trunc_loss_idx'] = trunc_loss_idx
        row['chunk_locs'] = chunk_locs
        return row

    def __getitem__(self, data_idx):
        """Return one randomly chosen truncation window of trajectory ``data_idx``.

        Rows carrying a truthy ``'no_truncate'`` entry are returned unchanged.
        Otherwise a window is sampled with ``self.rng``: the first window is
        weighted by the number of chunks it contains, every other window by 1.
        The returned shallow copy has ``obs``, ``actions``, ``fill_locs``,
        ``total_length``, ``fill_arr``, ``output_ids``, ``attention_mask`` and
        ``loss_mask`` restricted to that window.
        """
        row = dict(self.data[data_idx])
        if row.get('no_truncate', False):
            return row

        # calculate which truncation to select
        num_windows = len(row['trunc_idxs'])
        p = np.ones(num_windows)
        p[0] = len(row['trunc_idxs'][0])
        idx = self.rng.choice(num_windows, p=p/p.sum())

        kept_chunks = row['trunc_idxs'][idx]
        window_mask = row['trunc_masks'][idx]

        row['obs'] = [row['obs'][i] for i in kept_chunks]
        row['actions'] = [row['actions'][i] for i in kept_chunks]

        row['fill_locs'] = row['trunc_fill_locs'][idx]

        row['total_length'] = int(window_mask.sum().item())
        row['fill_arr'] = row['fill_arr'][window_mask]
        row['output_ids'] = row['output_ids'][window_mask]
        row['attention_mask'] = row['attention_mask'][window_mask]

        row['loss_mask'] = row['loss_mask'][window_mask]
        if idx != 0:  # case of burn in
            # NOTE(review): window 0 never gets the burn-in mask, even when it
            # was itself produced by truncation (trunc_loss_idx[0] != 0) --
            # confirm this is intended.
            row['loss_mask'] = row['loss_mask'].clone()
            row['loss_mask'][:row['trunc_loss_idx'][idx]] = 0

        return row

    def get_collator(self):
        """Return a collate function that batches rows produced by ``__getitem__``.

        The collated dict holds per-row Python lists for ``tasks``,
        ``log_tasks``, ``obs``, ``goals``, ``actions``, ``total_length`` and
        ``fill_locs``; the tokenized ``' [SEP]'`` string under ``sep``; and
        batch-first, zero-padded tensors for ``fill_arr``, ``attention_mask``,
        ``loss_mask`` and ``output_ids``.
        """
        def _collator(batch):
            data = {
                'tasks': [row['task'] for row in batch],
                'log_tasks': [row['log_task'] for row in batch],
                'obs': [row['obs'] for row in batch],
                'goals': [row['goal'] for row in batch],
                'actions': [row['actions'] for row in batch],
                'sep': self.tokenizer(' [SEP]', return_tensors='pt').input_ids,
                'total_length': [row['total_length'] for row in batch],
                'fill_arr': pad_sequence([row['fill_arr'] for row in batch], batch_first=True),
                'fill_locs': [row['fill_locs'] for row in batch],
                'attention_mask': pad_sequence([row['attention_mask'] for row in batch], batch_first=True),
                'loss_mask': pad_sequence([row['loss_mask'] for row in batch], batch_first=True),
                'output_ids': pad_sequence([row['output_ids'] for row in batch], batch_first=True),
            }
            return data
        return _collator

if __name__ == '__main__':
    # Manual smoke test: build an AlfredThorEnv loader for task type 1 with the
    # 'mrcnn_astar' controller, create the environment and print its first reset.
    from transformers import GPT2TokenizerFast

    gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    make_env = _get_alf_env_loader(
        'AlfredThorEnv',
        'mrcnn_astar',
        gpt2_tokenizer,
        [1],
        1,
        train_eval='eval_in_distribution',
    )
    thor_env = make_env()
    print(thor_env.reset())
