import pickle
import pathlib
import re
from tqdm import tqdm

import numpy as np
import blosc2
import torch
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from promptrl.envs.alfworld_viz import AlfworldVizDataset

# Directory holding the VirtualHome pickles: <three levels above this file>/data/virtualhome.
DATA_DIR = pathlib.Path(__file__).parent.parent.parent.resolve().joinpath('data', 'virtualhome')

# Pickle file backing each ``data_mode`` accepted by VirtualHomeDataset.
# Note that 'imgdummy' points at the same CLIP-embedding file as 'clip'.
PATHS = dict(
    img=DATA_DIR.joinpath('data_v4.pkl'),
    clip=DATA_DIR.joinpath('clip_ViT-B32_data_v4.pkl'),
    imgdummy=DATA_DIR.joinpath('clip_ViT-B32_data_v4.pkl'),
    unit=DATA_DIR.joinpath('captions_data_v4.pkl'),
)

def _load_tasks(mode='train', file_path=PATHS['img']):
    """Unpickle and return the trajectory data for one dataset split.

    ``mode='novel_tasks'`` and ``mode='novel_scenes'`` read the sibling file whose
    stem carries an ``_nt`` / ``_ns`` suffix (e.g. ``data_v4.pkl`` ->
    ``data_v4_nt.pkl``). Any other mode reads ``file_path`` unchanged.

    Note: this uses ``pickle``, so it must only be pointed at trusted files.
    """
    split_suffix = None
    if mode == 'novel_tasks':
        split_suffix = '_nt'
    elif mode == 'novel_scenes':
        split_suffix = '_ns'

    target = file_path
    if split_suffix is not None:
        target = file_path.parent / f'{file_path.stem}{split_suffix}.pkl'

    with target.open('rb') as handle:
        return pickle.load(handle)

class VirtualHomeDataset(AlfworldVizDataset):
    """VirtualHome trajectories packaged as ``AlfworldVizDataset``-style samples.

    Construction loads a pickled list of trajectory dicts (via ``_load_tasks``).
    It tokenizes goals and actions and builds a textual "oracle" caption per
    timestep from the interactable objects. Observations are converted according
    to ``data_mode``. Finally, the per-task sample lists requested by
    ``task_types`` are assembled into ``self.data``.
    """

    # Matches the " (<digits>)" instance suffix stripped from action strings.
    _ACTION_ID_RE = re.compile(r' \(\d+\)')

    def __init__(self, obs_type, model, tokenizer, num_samples=float('inf'), seed=42, action_burn_in=5,
                 mode='train', n_positions=1024, per_obs_tokens=lambda *_: 0, data_mode='img',
                 task_types=None, limit_frames=None):
        """Load, parse and preprocess one VirtualHome split.

        Args:
            obs_type: accepted for interface compatibility; not read here.
            model: stored as ``self.model``.
            tokenizer: called as ``tokenizer(text, return_tensors='pt')`` and its
                ``.input_ids`` used (presumably a HuggingFace tokenizer; verify).
            num_samples: accepted for interface compatibility; not read here.
            seed: seed for ``self.rng`` (a NumPy Generator).
            action_burn_in, n_positions: stored on the instance (not read in this class
                body; presumably consumed by inherited helpers).
            mode: 'novel_tasks' / 'novel_scenes' select the corresponding split file;
                anything else loads the base file.
            per_obs_tokens: callable returning the number of token slots reserved for
                observations; called as ``per_obs_tokens(len(obs), len(goal), 0)``.
            data_mode: key into ``PATHS`` ('img', 'clip', 'imgdummy', 'unit').
            task_types: mapping ``task name -> (task_idx, log_task_idx, num_samples)``
                for 'forward', 'caption' and 'invdyn'. Defaults to
                ``{'forward': (0, 0, float('inf'))}``.
            limit_frames: optional limit applied to image observations in ``preprocess``.
        """
        # A None sentinel avoids sharing one mutable default dict across instances.
        if task_types is None:
            task_types = {'forward': (0, 0, float('inf'))}

        self.model = model
        self.tokenizer = tokenizer
        self.rng = np.random.default_rng(seed)
        self.action_burn_in = action_burn_in
        self.n_positions = n_positions
        self.per_obs_tokens = per_obs_tokens
        self.data_mode = data_mode
        self.task_types = task_types
        self.limit_frames = limit_frames

        raw_data = _load_tasks(mode, file_path=PATHS[data_mode])
        forward_data = self.preprocess(self._parse_data(raw_data))

        # _captions / _invdyn are not defined in this class body (presumably inherited).
        self.data = []
        for task_type, (task_i, task_log_i, task_num) in task_types.items():
            if task_type == 'forward':
                self.data.extend(forward_data)
            elif task_type == 'caption':
                self.data.extend(
                    self._captions(forward_data, task_i, log_task_idx=task_log_i, num_samples=task_num))
            elif task_type == 'invdyn':
                self.data.extend(
                    self._invdyn(forward_data, task_i, log_task_idx=task_log_i, num_samples=task_num))
            # NOTE(review): any other task_type is silently ignored — confirm intended.

    @staticmethod
    def _join_with_and(items):
        """Join a non-empty sequence as an English list: 'x', 'x, and y', 'x, y, and z'.

        A single item is returned unchanged.
        """
        if len(items) == 1:
            return items[0]
        return ', '.join(items[:-1]) + ', and ' + items[-1]

    @classmethod
    def _strip_action_id(cls, action):
        """Remove every ' (<digits>)' instance suffix from an action string."""
        return cls._ACTION_ID_RE.sub('', action)

    @classmethod
    def _oracle_caption(cls, interactable):
        """Describe one timestep's interactable objects as a sentence.

        'character' is discarded from ``interactable`` IN PLACE (so it must support
        ``discard``, e.g. a set). The remaining names are sorted and prefixed with 'a '.
        """
        interactable.discard('character')
        objs = ['a ' + o for o in sorted(interactable)]
        if not objs:
            return 'In front of you, you see nothing.'
        return 'In front of you, you see ' + cls._join_with_and(objs) + '.'

    def _parse_data(self, raw):
        """Normalize raw trajectory dicts and tokenize their goals and actions.

        Each returned row is a shallow copy of the raw row in which:
          * 'goal' (a non-empty list of strings) becomes token ids for
            "Your task is to: a, b, and c.";
          * 'actions' become token ids of " [SEP] <action> [SEP]" with instance
            suffixes like " (3)" removed;
          * 'admissible' holds the per-timestep action strings with suffixes removed;
          * 'oracle' holds one caption per action built from 'interactable_objects'.

        Note: ``raw`` is modified in place. Misspelled keys are renamed, and
        'character' is discarded from each ``interactable_objects`` entry.
        """
        data = []
        for raw_row in raw:
            # Repair misspelled keys present in some dataset dumps.
            if 'admissible' not in raw_row and 'admissible_actions' in raw_row:
                raw_row['admissible'] = raw_row.pop('admissible_actions')
            if 'visible_objects_raw' not in raw_row and 'visible_obects_raw' in raw_row:
                raw_row['visible_objects_raw'] = raw_row.pop('visible_obects_raw')

            row = dict(raw_row)
            assert isinstance(raw_row['goal'], list)
            assert len(raw_row['goal']) > 0
            goal = f'Your task is to: {self._join_with_and(raw_row["goal"])}.'
            row['goal'] = self.tokenizer(goal, return_tensors='pt').input_ids.squeeze(0)
            row['actions'] = [
                self.tokenizer(' [SEP] ' + self._strip_action_id(a) + ' [SEP]',
                               return_tensors='pt').input_ids.squeeze(0)
                for a in raw_row['actions']
            ]
            row['admissible'] = [
                [self._strip_action_id(a) for a in timestep] for timestep in raw_row['admissible']
            ]
            row['oracle'] = [
                self._oracle_caption(raw_row['interactable_objects'][t])
                for t in range(len(row['actions']))
            ]
            data.append(row)
        return data

    def preprocess(self, data):
        """Convert each parsed row's observations for ``self.data_mode`` and finish it.

        * image mode: each observation gets a leading axis (``np.expand_dims(o, 0)``)
          and the final observation is dropped; ``obs_precomputed`` is False.
        * caption-text mode ('captions' in the file stem): observations are tokenized,
          final one dropped; ``obs_precomputed`` is False.
        * otherwise (precomputed embeddings): observations become float tensors;
          ``obs_precomputed`` is True.

        Raises:
            AssertionError: if a row does not end up with one observation per action.
        """
        data_stem = PATHS[self.data_mode].stem
        # PATHS['img'] is 'data_v4.pkl', whose stem contains no 'img', so a stem-only
        # check never selected the image branch. Key on data_mode as well.
        is_image = self.data_mode == 'img' or 'img' in data_stem
        is_caption_text = 'captions' in data_stem

        new_data = []
        for src in data:
            row = dict(src)
            if is_image:
                row['obs'] = [np.expand_dims(o, 0) for o in src['obs'][:-1]]
                row['obs_precomputed'] = False
                if self.limit_frames is not None:
                    # NOTE(review): after expand_dims axis 0 has length 1, so this slice
                    # leaves each obs unchanged — confirm which axis should be limited.
                    row['obs'] = [o[-self.limit_frames:] for o in row['obs']]
            elif is_caption_text:
                row['obs'] = [self.tokenizer(o, return_tensors='pt').input_ids.squeeze(0)
                              for o in src['obs'][:-1]]
                row['obs_precomputed'] = False
            else:
                row['obs'] = [torch.from_numpy(o).float() for o in src['obs']]
                row['obs_precomputed'] = True
            assert len(row['obs']) == len(row['actions'])
            if 'goal_embed' in row:
                row['goal_embed'] = torch.from_numpy(row['goal_embed']).float()
            # _preprocess_row is not defined in this class body (presumably inherited).
            new_data.append(self._preprocess_row(row))

        return new_data

    def _make_caption_row(self, obs, action, oracle, task_idx, log_task_idx=None):
        """Build a 'caption' sample whose target is the oracle caption of one observation.

        Args:
            obs: a single observation; ``len(obs)`` is passed to ``per_obs_tokens`` and
                the observation is stored as the row's only entry in 'obs'.
            action: unused. It is kept for call compatibility: unlike the Alfworld
                variant, no action-type restriction is applied here.
            oracle: caption text to predict.
            task_idx, log_task_idx: stored under 'task' and 'log_task'.

        Returns:
            dict with the caption goal, the tokenized caption as the single "action",
            and ``fill_arr`` / ``attention_mask`` / ``loss_mask`` / ``output_ids`` tensors
            of length ``per_obs_tokens(len(obs), len(goal), 0) + len(caption tokens)``.
            The first ``start_idx`` positions are left zero and recorded in
            ``fill_locs``. The loss covers only the caption tokens.
        """
        obs_proced = self.tokenizer(' [SEP] ' + oracle + ' [SEP]', return_tensors='pt').input_ids.squeeze(0)
        goal = self.tokenizer('Your task is to: caption the following observation',
                              return_tensors='pt').input_ids.squeeze(0)

        start_idx = self.per_obs_tokens(len(obs), len(goal), 0)
        total_length = start_idx + len(obs_proced)

        input_ids = torch.zeros((total_length,), dtype=torch.int64)
        attention_mask = torch.ones((total_length,))
        loss_mask = torch.zeros((total_length,))
        output_ids = torch.zeros((total_length,), dtype=torch.int64)

        input_ids[start_idx:] = obs_proced
        output_ids[start_idx:] = obs_proced
        loss_mask[start_idx:] = 1.

        return {
            'task': task_idx,
            'log_task': log_task_idx,
            'task_name': 'caption',
            'obs': [obs],
            'goal': goal,
            'actions': [obs_proced],
            'total_length': total_length,
            'fill_arr': input_ids,
            'fill_locs': [(0, start_idx)],
            'attention_mask': attention_mask,
            'loss_mask': loss_mask,
            'output_ids': output_ids,
            'no_truncate': True,
        }
