import pickle
import pathlib
import re
from tqdm import tqdm

import numpy as np
import blosc2
import torch
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from promptrl.envs.alfworld_task import AlfworldFillDataset
from promptrl.envs.alfworldviz.action_lookup import ActionLookup

# ``data/`` folder located in the parent of this file's directory.
DATA_DIR = pathlib.Path(__file__).parent.parent.resolve() / 'data'

# Directory holding the pickled seq2seq trajectories for each ``data_mode``.
# Several modes share a directory and differ only in how the dataset
# interprets the stored observations.
PATHS = dict(
    img=DATA_DIR / 'img_seq2seq_data',
    # Placeholder entry. NOTE(review): the original note said "doesn't use",
    # but _load_tasks still reads pickles from this directory — confirm.
    imgdummy=DATA_DIR / 'clip_ViT-B32_seq2seq_data',
    resnet=DATA_DIR / 'img_seq2seq_data',
    resnetpt=DATA_DIR / 'img_seq2seq_data',
    resnetfz=DATA_DIR / 'resnet_50_seq2seq_data',
    clip=DATA_DIR / 'clip_ViT-B32_seq2seq_data',
    clipgoal=DATA_DIR / 'clip_ViT-B32_seq2seq_data',
    clippatch=DATA_DIR / 'clip_ViT-B32_patched_seq2seq_data',
    clipcap=DATA_DIR / 'clip_ViT-B32_seq2seq_data',
    unit=DATA_DIR / 'allcaps' / 'captions_seq2seq_data',
)

# Base ALFWorld environment configuration file.
CONFIG_PATH = DATA_DIR / 'alfworld_configs' / 'base_config.yaml'

def _load_tasks(task_type_ids, mode='train', parent_dir=PATHS['clip']):
    """Load and concatenate pickled seq2seq trajectories for several task types.

    Task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place,
    4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place.

    Args:
        task_type_ids: iterable of task-type ids; one pickle is read per id.
        mode: split name embedded in each file name. The two evaluation modes
            additionally redirect to a sibling directory of ``parent_dir``
            suffixed ``_ood`` ('eval_out_of_distribution') or ``_id``
            ('eval_in_distribution').
        parent_dir: directory containing the ``tw_alfred_seq2seq_*`` pickles.

    Returns:
        A list with the items of every loaded pickle, concatenated in
        ``task_type_ids`` order.

    Raises:
        FileNotFoundError: if an expected pickle file does not exist.
    """
    split_suffix = {
        'eval_out_of_distribution': '_ood',
        'eval_in_distribution': '_id',
    }.get(mode)
    if split_suffix is not None:
        # Use ``name`` rather than ``stem``: ``stem`` would silently drop
        # everything after a '.' in the directory name.
        parent_dir = parent_dir.parent / (parent_dir.name + split_suffix)

    datas = []
    for task_type_id in task_type_ids:
        pkl_path = parent_dir / f'tw_alfred_seq2seq_{mode}_task{task_type_id}_hc.pkl'
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with pkl_path.open('rb') as f:
            datas.extend(pickle.load(f))
    return datas

class AlfworldVizDataset(AlfworldFillDataset):
    """ALFWorld seq2seq dataset with image, precomputed-feature or caption observations.

    ``__init__`` loads pickled trajectories with ``_load_tasks``, tokenizes
    goals and actions (``_parse_data``), converts observations according to
    ``data_mode`` (``preprocess``) and fills ``self.data`` with the "forward"
    rows plus any requested auxiliary-task rows ("caption", "invdyn",
    "goalp", "admissible").

    Auxiliary rows carry the keys read by ``get_collator`` (``task``,
    ``log_task``, ``obs``, ``goal``, ``actions``, ``total_length``,
    ``fill_arr``, ``fill_locs``, ``attention_mask``, ``loss_mask``,
    ``output_ids``). Forward rows come from the parent's ``_preprocess_row``,
    which presumably adds the same keys — confirm in ``AlfworldFillDataset``.
    """

    def __init__(
        self,
        obs_type,
        tasks,
        model,
        tokenizer,
        num_samples=float('inf'),
        seed=42,
        action_burn_in=5,
        mode='train',
        n_positions=1024,
        per_obs_tokens=lambda *_: 0,
        data_mode='img',
        task_types=None,
        limit_frames=None,
    ):
        """Load, tokenize and expand the trajectories for ``tasks``.

        Args:
            obs_type: observation type; must start with ``'img'``.
            tasks: task-type ids forwarded to ``_load_tasks``.
            model: stored on ``self.model`` (not read in this class).
            tokenizer: tokenizer called with ``return_tensors='pt'`` and
                providing ``decode`` / ``batch_decode``.
            num_samples: maximum number of trajectories to keep; clipped to
                ``[0, number loaded]`` and drawn without replacement.
            seed: seed of ``self.rng``, the generator behind all sampling here.
            action_burn_in: stored on ``self.action_burn_in`` (not read in
                this class; presumably used by the parent).
            mode: data split passed to ``_load_tasks``. The "invdyn" and
                "goalp" auxiliary tasks are only built when ``mode == 'train'``.
            n_positions: token budget used when truncating observation history.
            per_obs_tokens: callable ``(obs_len, goal_len, k) -> int`` giving
                the number of placeholder tokens reserved for an observation.
            data_mode: key into ``PATHS`` selecting the data directory.
            task_types: mapping ``name -> (task_idx, log_task_idx, num)``.
                Defaults to ``{'forward': (0, 0, float('inf'))}``.
            limit_frames: if not None, keep only the last ``limit_frames``
                entries of each raw image observation.

        Raises:
            NotImplementedError: for an unknown key in ``task_types``.
        """
        assert obs_type.startswith('img')
        if task_types is None:
            # Built per call instead of being a shared mutable default argument.
            task_types = {'forward': (0, 0, float('inf'))}
        self.tasks = tasks
        self.model = model
        self.tokenizer = tokenizer
        self.rng = np.random.default_rng(seed)
        self.action_burn_in = action_burn_in
        self.n_positions = n_positions
        self.per_obs_tokens = per_obs_tokens
        self.data_mode = data_mode
        self.task_types = task_types
        self.limit_frames = limit_frames

        raw_data = _load_tasks(tasks, mode, parent_dir=PATHS[data_mode])
        self.num_samples = int(np.clip(num_samples, 0, len(raw_data)))
        raw_data = self.rng.choice(raw_data, self.num_samples, replace=False)

        raw_data = self._parse_data(raw_data)
        forward_data = self.preprocess(raw_data)
        self.data = []
        for task_type, (task_i, task_log_i, task_num) in task_types.items():
            if task_type == 'forward':
                self.data.extend(forward_data)
            elif task_type == 'caption':
                caption_data = self._captions(forward_data, task_i, log_task_idx=task_log_i, num_samples=task_num)
                self.data.extend(caption_data)
            elif task_type == 'invdyn':
                if mode == 'train':
                    inv_dyn_data = self._invdyn(forward_data, task_i, log_task_idx=task_log_i, num_samples=task_num)
                    self.data.extend(inv_dyn_data)
            elif task_type == 'goalp':
                if mode == 'train':
                    goalp_data = self._goalp(forward_data, task_i, log_task_idx=task_log_i, num_samples=task_num)
                    self.data.extend(goalp_data)
            elif task_type == 'admissible':
                adm_data = self._admissible(forward_data, task_i, log_task_idx=task_log_i, num_samples=task_num)
                self.data.extend(adm_data)
            else:
                raise NotImplementedError(f'Auxilliary task {task_type} not implemented.')

    def _parse_data(self, raw):
        """Tokenize goals and actions, dropping trajectories with fewer than two actions.

        Numeric object ids (`` \\d+``) are removed from each action, and the
        action is wrapped as ``' [SEP] <action> [SEP]'`` before tokenizing.
        Returns a list of shallow-copied row dicts.
        """
        def strip_ids(action):
            return re.sub(r' \d+', '', action)

        data = []
        for raw_row in raw:
            if len(raw_row['actions']) < 2:
                continue
            row = dict(raw_row)
            row['goal'] = self.tokenizer(raw_row['goal'], return_tensors='pt').input_ids.squeeze(0)
            row['actions'] = [
                self.tokenizer(' [SEP] ' + strip_ids(a) + ' [SEP]', return_tensors='pt').input_ids.squeeze(0)
                for a in raw_row['actions']
            ]
            data.append(row)
        return data

    def preprocess(self, data):
        """Convert each row's observations for the configured ``data_mode``.

        - directory stem containing 'img': blosc2-packed arrays are unpacked
          (optionally truncated to the last ``limit_frames`` entries);
        - directory stem containing 'captions': caption strings are tokenized;
        - otherwise: numpy features become float tensors and the row is
          flagged ``obs_precomputed=True``.

        ``goal_embed`` (if present) is converted to a float tensor, then each
        row is finished by the parent's ``_preprocess_row``.
        """
        new_data = []
        for src in data:
            row = dict(src)
            if 'img' in PATHS[self.data_mode].stem:
                row['obs'] = [blosc2.unpack_array(o) for o in src['obs']]
                row['obs_precomputed'] = False
                if self.limit_frames is not None:
                    row['obs'] = [o[-self.limit_frames:] for o in row['obs']]
            elif 'captions' in PATHS[self.data_mode].stem:
                row['obs'] = [self.tokenizer(o, return_tensors='pt').input_ids.squeeze(0) for o in src['obs']]
                row['obs_precomputed'] = False
            else:
                row['obs'] = [torch.from_numpy(o).float() for o in src['obs']]
                row['obs_precomputed'] = True
            if 'goal_embed' in row:
                row['goal_embed'] = torch.from_numpy(row['goal_embed']).float()
            new_data.append(self._preprocess_row(row))
        return new_data

    def get_collator(self):
        """Return a ``collate_fn`` that batches rows of this dataset.

        Observations are placed under ``obs`` for raw image / caption modes
        and under ``obs_precomputed`` otherwise. ``fill_arr``,
        ``attention_mask``, ``loss_mask`` and ``output_ids`` are padded to the
        longest row with ``pad_sequence`` (zero padding, batch-first); the
        remaining fields are passed through as lists.
        """
        def _collator(batch):
            if 'img' in PATHS[self.data_mode].stem or 'captions' in PATHS[self.data_mode].stem:
                obs = [row['obs'] for row in batch]
                obs_precomputed = None
            else:
                obs = None
                obs_precomputed = [row['obs'] for row in batch]
            data = {
                'tasks': [row['task'] for row in batch],
                'log_tasks': [row['log_task'] for row in batch],
                'obs': obs,
                'obs_precomputed': obs_precomputed,
                'goals': [row['goal'] for row in batch],
                'goals_precomputed': None,
                'actions': [row['actions'] for row in batch],
                'sep': self.tokenizer(' [SEP]', return_tensors='pt').input_ids,
                'total_length': [row['total_length'] for row in batch],
                'fill_arr': pad_sequence([row['fill_arr'] for row in batch], batch_first=True),
                'fill_locs': [row['fill_locs'] for row in batch],
                'attention_mask': pad_sequence([row['attention_mask'] for row in batch], batch_first=True),
                'loss_mask': pad_sequence([row['loss_mask'] for row in batch], batch_first=True),
                'output_ids': pad_sequence([row['output_ids'] for row in batch], batch_first=True),
            }
            return data
        return _collator

    def _strip_obs(self, obs):
        """Remove numeric object ids and repeated ``' a '``-separated chunks.

        The first occurrence of each chunk is kept, in original order.
        """
        obs = re.sub(r' \d+', '', obs)
        # dict keys preserve insertion order, giving an order-stable dedupe.
        return ' a '.join(dict.fromkeys(obs.split(' a ')))

    def _process_caption(self, task_idx, row, log_task_idx=None, sample=False):
        """Build caption rows for the transitions of one trajectory.

        Transition ``k`` pairs ``actions[k]`` with ``obs[k+1]`` and
        ``oracle[k+1]``. With ``sample`` one random transition is tried,
        otherwise all of them; ineligible transitions are skipped.
        """
        if log_task_idx is None:
            log_task_idx = task_idx
        new_rows = []
        if sample:
            # Too short to have a transition: yield nothing, like the zip path.
            if len(row['obs']) < 2:
                return new_rows
            idx = self.rng.choice(len(row['obs']) - 1)
            obs, action, oracle = row['obs'][idx+1], row['actions'][idx], row['oracle'][idx+1]
            new_row = self._make_caption_row(obs, action, oracle, task_idx, log_task_idx)
            if new_row is not None:
                new_rows.append(new_row)
        else:
            for obs, action, oracle in zip(row['obs'][1:], row['actions'], row['oracle'][1:]):
                new_row = self._make_caption_row(obs, action, oracle, task_idx, log_task_idx)
                if new_row is not None:
                    new_rows.append(new_row)
        return new_rows

    def _make_caption_row(self, obs, action, oracle, task_idx, log_task_idx=None):
        """Build a row that predicts the text observation from ``obs``.

        Returns None unless the action is 'go' or 'open' and the oracle text
        is not 'Nothing happens.'. The target is the stripped oracle text with
        the 'You arrive at loc.' / 'You open the ...' lead-in removed.
        """
        action = self.tokenizer.decode(action)
        # The decoded action starts with '[SEP]', so the verb is word 1.
        action_type = action.split()[1]
        if action_type != 'go' and action_type != 'open':
            return None
        if oracle == 'Nothing happens.':
            return None
        obs_text = self._strip_obs(oracle)
        if obs_text.startswith('You arrive at loc.'):
            # The prefix is 18 chars; slicing at 19 also drops the following space.
            obs_text = obs_text[19:]
        if obs_text.startswith('You open the'):
            obs_text = obs_text.split('. ', 1)[-1]
        obs_proced = self.tokenizer(' [SEP] ' + obs_text + ' [SEP]', return_tensors='pt').input_ids.squeeze(0)
        goal = self.tokenizer('Your task is to: caption the following observation', return_tensors='pt').input_ids.squeeze(0)

        total_length = self.per_obs_tokens(len(obs), len(goal), 0) + len(obs_proced)

        input_ids = torch.zeros((total_length,), dtype=torch.int64)
        attention_mask = torch.ones((total_length,))
        loss_mask = torch.zeros((total_length,))
        output_ids = torch.zeros((total_length,), dtype=torch.int64)

        # Positions [0, start_idx) are reserved for the observation; the
        # caption tokens follow and are the only supervised positions.
        start_idx = self.per_obs_tokens(len(obs), len(goal), 0)
        fill_locs = [(0, start_idx)]
        input_ids[start_idx:] = obs_proced
        output_ids[start_idx:] = obs_proced
        loss_mask[start_idx:] = 1.

        new_row = {
            'task': task_idx,
            'log_task': log_task_idx,
            'task_name': 'caption',
            'obs': [obs],
            'goal': goal,
            'actions': [obs_proced],
            'total_length': total_length,
            'fill_arr': input_ids,
            'fill_locs': fill_locs,
            'attention_mask': attention_mask,
            'loss_mask': loss_mask,
            'output_ids': output_ids,
            'no_truncate': True
        }
        return new_row

    def _captions(self, data, task_idx, log_task_idx=None, num_samples=None):
        """Build caption rows for ``data``.

        If ``num_samples`` is None or infinite, every eligible transition of
        every trajectory is used (sampling towards an infinite target would
        never terminate). Otherwise random trajectories are drawn with
        replacement until ``num_samples`` rows exist. NOTE: that loop never
        ends if no transition in ``data`` is eligible.
        """
        new_rows = []
        if num_samples is None or num_samples == float('inf'):
            for row in data:
                new_rows.extend(self._process_caption(task_idx, row, log_task_idx))
        else:
            while len(new_rows) < num_samples:
                row = self.rng.choice(data)
                new_rows.extend(self._process_caption(task_idx, row, log_task_idx, sample=True))
        return new_rows

    def _invdyn(self, data, task_idx, log_task_idx=None, num_samples=None):
        """Build inverse-dynamics rows for ``data``.

        Same sampling policy as ``_captions``. None or infinite
        ``num_samples`` means all eligible transitions. NOTE: the sampling
        loop never ends if no transition in ``data`` is eligible.
        """
        new_rows = []
        if num_samples is None or num_samples == float('inf'):
            for row in data:
                new_rows.extend(self._process_invdyn(task_idx, row, log_task_idx))
        else:
            while len(new_rows) < num_samples:
                row = self.rng.choice(data)
                new_rows.extend(self._process_invdyn(task_idx, row, log_task_idx, sample=True))
        return new_rows

    def _process_invdyn(self, task_idx, row, log_task_idx=None, sample=False):
        """Build inverse-dynamics rows for one trajectory.

        Each row uses ``(obs[k], obs[k+1], actions[k], oracle[k+1])`` for
        ``k >= 1``. With ``sample`` one random ``k`` is tried, otherwise all.
        """
        if log_task_idx is None:
            log_task_idx = task_idx
        new_rows = []
        if sample:
            # Needs k >= 1 and k + 1 < len(obs); otherwise yield nothing, like the zip path.
            if len(row['obs']) < 3:
                return new_rows
            idx = 1 + self.rng.choice(len(row['obs']) - 2)
            obs, next_obs, action, oracle = row['obs'][idx], row['obs'][idx+1], row['actions'][idx], row['oracle'][idx+1]
            new_row = self._make_invdyn_row(obs, next_obs, action, oracle, task_idx, log_task_idx)
            if new_row is not None:
                new_rows.append(new_row)
        else:
            for obs, next_obs, action, oracle in zip(row['obs'][1:], row['obs'][2:], row['actions'][1:], row['oracle'][2:]):
                new_row = self._make_invdyn_row(obs, next_obs, action, oracle, task_idx, log_task_idx)
                if new_row is not None:
                    new_rows.append(new_row)
        return new_rows

    def _make_invdyn_row(self, obs, next_obs, action, oracle, task_idx, log_task_idx=None):
        """Build a row that predicts ``action`` from ``(obs, next_obs)``.

        Returns None for 'look' / 'examine' / 'inventory' actions and for
        transitions whose oracle text is 'Nothing happens.'.
        """
        action_str = self.tokenizer.decode(action)
        action_type = action_str.split()[1]
        if action_type in ['look', 'examine', 'inventory']:
            return None
        if oracle == 'Nothing happens.':
            return None
        goal = self.tokenizer('Your task is to: predict the action that ocurred.', return_tensors='pt').input_ids.squeeze(0)

        total_length = self.per_obs_tokens(len(obs), len(goal), 0) + self.per_obs_tokens(len(next_obs), len(goal), 1) + len(action)

        input_ids = torch.zeros((total_length,), dtype=torch.int64)
        attention_mask = torch.ones((total_length,))
        loss_mask = torch.zeros((total_length,))
        output_ids = torch.zeros((total_length,), dtype=torch.int64)

        # Two observation slots followed by the supervised action tokens.
        end_first_idx = self.per_obs_tokens(len(obs), len(goal), 0)
        end_second_idx = end_first_idx + self.per_obs_tokens(len(next_obs), len(goal), 1)
        fill_locs = [(0, end_first_idx), (end_first_idx, end_second_idx)]
        input_ids[end_second_idx:] = action
        output_ids[end_second_idx:] = action
        loss_mask[end_second_idx:] = 1.

        new_row = {
            'task': task_idx,
            'log_task': log_task_idx,
            'task_name': 'invdyn',
            'obs': [obs, next_obs],
            'goal': goal,
            'actions': [None, action],
            'total_length': total_length,
            'fill_arr': input_ids,
            'fill_locs': fill_locs,
            'attention_mask': attention_mask,
            'loss_mask': loss_mask,
            'output_ids': output_ids,
            'no_truncate': True
        }
        return new_row

    def _goalp(self, data, task_idx, log_task_idx=None, num_samples=None):
        """Build one goal-prediction row per trajectory.

        Uses the first ``num_samples`` trajectories (all when None).
        """
        new_rows = []
        if num_samples is None:
            num_samples = len(data)
        for row in data:
            if len(new_rows) >= num_samples:
                break
            new_rows.append(self._make_goalp_row(row['obs'], row['goal'], task_idx, log_task_idx))
        return new_rows

    def _make_goalp_row(self, obs, goal, task_idx, log_task_idx=None):
        """Build a row that predicts the task goal from recent observations.

        Observations are taken from the end of ``obs`` backwards (index 0 is
        never used) while the row stays within ``self.n_positions`` tokens,
        and are concatenated with ``torch.cat`` along dim 0. The target is
        the goal restated as "Your task was to: ...". Requires
        ``len(obs) >= 2``.
        """
        goal_str = self.tokenizer.decode(goal)
        # Drop the first four space-separated words (presumably the
        # "Your task is to:" prefix) and restate the goal in the past tense.
        goal_str = ' '.join(goal_str.split(' ')[4:])
        goal_str = 'Your task was to: ' + goal_str
        goal = self.tokenizer(' [SEP] ' + goal_str + ' [SEP]', return_tensors='pt').input_ids.squeeze(0)
        task_goal = self.tokenizer('Your task is to: predict the task goal.', return_tensors='pt').input_ids.squeeze(0)

        # After the loop, ``start_idx`` is the earliest included observation.
        # The most recent observation is always included, even if it alone
        # exceeds the budget.
        obs_len = 0
        for start_idx in range(len(obs) - 1, 0, -1):  # do not include 0 index
            obs_len += len(obs[start_idx])
            next_obs_len = obs_len + len(obs[start_idx - 1])
            next_length = len(goal) + self.per_obs_tokens(next_obs_len, len(task_goal), 0)
            if next_length > self.n_positions:
                break
        end_obs_idx = self.per_obs_tokens(obs_len, len(task_goal), 0)
        total_length = len(goal) + end_obs_idx

        input_ids = torch.zeros((total_length,), dtype=torch.int64)
        attention_mask = torch.ones((total_length,))
        loss_mask = torch.zeros((total_length,))
        output_ids = torch.zeros((total_length,), dtype=torch.int64)

        obs_cat = torch.cat(obs[start_idx:], dim=0)
        assert obs_cat.shape[0] == obs_len

        fill_locs = [(0, end_obs_idx)]
        input_ids[end_obs_idx:] = goal
        output_ids[end_obs_idx:] = goal
        loss_mask[end_obs_idx:] = 1.

        new_row = {
            'task': task_idx,
            'log_task': log_task_idx,
            'task_name': 'goalp',
            'obs': [obs_cat],
            'goal': task_goal,
            'actions': [goal],
            'total_length': total_length,
            'fill_arr': input_ids,
            'fill_locs': fill_locs,
            'attention_mask': attention_mask,
            'loss_mask': loss_mask,
            'output_ids': output_ids,
            'no_truncate': True
        }
        return new_row

    def _admissible(self, data, task_idx, log_task_idx=None, num_samples=None):
        """Build action-admissibility classification rows.

        For every step ``i`` in ``1 .. len(obs) - 2`` of each trajectory, the
        admissible actions (ids stripped, de-duplicated, entries containing
        'nothing' removed) plus sampled negatives become candidate rows (see
        ``_make_admissible_rows``). Stops starting new trajectories once at
        least ``num_samples`` rows exist. None or ``<= 0`` means no limit.
        """
        action_lookup = ActionLookup()
        new_rows = []
        # Test for None first: ``None <= 0`` raises TypeError.
        if num_samples is None or num_samples <= 0:
            num_samples = float('inf')
        for row in tqdm(data):
            if len(new_rows) >= num_samples:
                break
            # Depends only on the trajectory, so decode once rather than per step.
            # NOTE: strip('[SEP]') strips any of the characters '[', 'S', 'E',
            # 'P', ']' from both ends, not the literal token.
            actions_s = self.tokenizer.batch_decode(row['actions'])
            actions_s = [s.strip().strip('[SEP]').strip() for s in actions_s]
            for i in range(1, len(row['obs']) - 1):
                # Order-preserving dedupe: iteration order of a set of str
                # depends on PYTHONHASHSEED, which would make the seeded
                # sampling below differ between runs.
                adm = list(dict.fromkeys(re.sub(r' \d+', '', s) for s in row['admissible'][i]))
                adm = [a for a in adm if 'nothing' not in a]  # buggy admissible sometimes gives nothing
                new_rows.extend(self._make_admissible_rows(action_lookup, row['obs'][:i+1], adm, actions_s[i], actions_s[1:], task_idx, log_task_idx))
        return new_rows

    def _make_admissible_rows(self, action_lookup, obs, admissible, action, action_seq, task_idx, log_task_idx=None):
        """Build valid/invalid classification rows for one step.

        The candidate pool is ``admissible`` plus ``len(admissible)``
        negatives from ``action_lookup.negative_samples`` (presumably drawn
        from ``action_seq`` — confirm in ActionLookup). Each candidate is kept
        with probability 0.1. Observation history is truncated like
        ``_make_goalp_row``; requires ``len(obs) >= 2``. ``action`` is
        accepted but not used here.
        """
        task_goal = self.tokenizer('Your task is to: predict whether the following action is valid.', return_tensors='pt').input_ids.squeeze(0)

        valid_toks = self.tokenizer(' [SEP] valid [SEP]', return_tensors='pt').input_ids.squeeze(0)
        invalid_toks = self.tokenizer(' [SEP] invalid [SEP]', return_tensors='pt').input_ids.squeeze(0)
        assert len(valid_toks) == len(invalid_toks)

        obs_len = 0
        for start_idx in range(len(obs) - 1, 0, -1):  # do not include 0 index
            obs_len += len(obs[start_idx])
            next_obs_len = obs_len + len(obs[start_idx - 1])
            # NOTE(review): the extra 20 looks like a budget for the candidate
            # action's tokens — confirm.
            next_length = len(valid_toks) + 20 + self.per_obs_tokens(next_obs_len, len(task_goal), 0)
            if next_length > self.n_positions:
                break

        # Identical for every candidate, so build them once.
        obs_cat = torch.cat(obs[start_idx:], dim=0)
        assert obs_cat.shape[0] == obs_len
        end_obs_idx = self.per_obs_tokens(obs_len, len(task_goal), 0)

        negative_pool = action_lookup.negative_samples(self.rng, admissible, num_negative=len(admissible), candidates=action_seq)
        pool = admissible + negative_pool
        rows = []
        for candidate in pool:
            if self.rng.random() > 0.1:  # trim training size (too large)
                continue
            positive = candidate in admissible
            class_toks = valid_toks if positive else invalid_toks
            candidate_ids = self.tokenizer(candidate, return_tensors='pt').input_ids.squeeze(0)

            total_length = len(valid_toks) + len(candidate_ids) + end_obs_idx
            input_ids = torch.zeros((total_length,), dtype=torch.int64)
            attention_mask = torch.ones((total_length,))
            loss_mask = torch.zeros((total_length,))
            output_ids = torch.zeros((total_length,), dtype=torch.int64)

            # Layout: [observation slot][candidate action][class tokens];
            # only the class tokens are supervised.
            fill_locs = [(0, end_obs_idx)]
            end_cand_idx = end_obs_idx + len(candidate_ids)
            input_ids[end_obs_idx:end_cand_idx] = candidate_ids
            output_ids[end_obs_idx:end_cand_idx] = candidate_ids

            input_ids[end_cand_idx:] = class_toks
            output_ids[end_cand_idx:] = class_toks
            loss_mask[end_cand_idx:] = 1.

            new_row = {
                'task': task_idx,
                'log_task': log_task_idx,
                'task_name': 'admissibility',
                'obs': [obs_cat],
                'goal': task_goal,
                'actions': [class_toks],
                'total_length': total_length,
                'fill_arr': input_ids,
                'fill_locs': fill_locs,
                'attention_mask': attention_mask,
                'loss_mask': loss_mask,
                'output_ids': output_ids,
                'no_truncate': True
            }
            rows.append(new_row)
        return rows
