import pathlib
import pickle

from tqdm import tqdm
import blosc2
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from accelerate import Accelerator
import clip
import tree
from PIL import Image

import promptrl.utils as utils
from scripts.load_args import VHomeCaptionArgs
from scripts.load import load_checkpoint_agent

# Name of the input pickle; the output file name is derived from it below.
IMG_DATA_SUFF = 'data_v4.pkl'
# <grandparent of this file>/promptrl/data/virtualhome -- the directory holding
# both the input pickle and the captions file written by this script.
_DATA_DIR = pathlib.Path(__file__).parent.parent.resolve() / 'promptrl' / 'data' / 'virtualhome'
DATA_FILE = _DATA_DIR / IMG_DATA_SUFF
file_name = f'captions_{IMG_DATA_SUFF}'
WRITE_FILE = _DATA_DIR / file_name
# NOTE(review): BATCH_SIZE is not referenced anywhere in the visible script;
# kept in case something else reads it -- confirm before removing.
BATCH_SIZE = 64

# Logger without wandb reporting, passed to the checkpoint loader.
logger = utils.Logger(None, use_wandb=False)
accelerator = Accelerator()
args = VHomeCaptionArgs()

# Restore the agent from its checkpoint; eval() is called once here, before
# any sampling (presumably the usual torch inference mode -- TODO confirm).
agent = load_checkpoint_agent(accelerator, logger, args)
agent.eval()

def _decode_action(action):
    """Decode sampled token ids into a cleaned caption string.

    Only the first entry of ``action`` is decoded. The caption is the text
    between the first and second ``'['`` of the decoded string, with every
    ``'SEP]'`` occurrence removed and surrounding whitespace stripped.

    Args:
        action: Indexable collection of token-id sequences; ``action[0]`` is
            handed to ``agent.tokenizer.decode``. Presumably the
            (batch, seq_len) output of ``agent.direct_sample`` -- TODO confirm.

    Returns:
        The cleaned caption text. When the decoded string contains no ``'['``
        at all, the whole decoded string is cleaned and returned instead of
        raising ``IndexError``.
    """
    decoded = agent.tokenizer.decode(action[0], skip_special_tokens=True)
    segments = decoded.split('[', 2)
    # A generation without any '[' yields a single segment; indexing [1]
    # unconditionally would raise IndexError and abort the whole run before
    # any captions are written, so fall back to the full decoded text.
    body = segments[1] if len(segments) > 1 else segments[0]
    return body.replace('SEP]', '').strip()

# Fixed instruction given to the agent for every observation; tokenized once,
# squeezed along axis 0 and moved to the accelerator's device.
caption_goal = 'Your task is to: caption the following observation'
caption_goal_ids = agent.tokenizer(caption_goal, return_tensors='pt').input_ids.squeeze(0).to(accelerator.device)

print(f'Processing {DATA_FILE.name}...')
# NOTE: pickle.load can execute arbitrary code embedded in the file -- only
# run this script on data files from a trusted source.
with DATA_FILE.open('rb') as f:
    task_data = pickle.load(f)

# Each output row is a copy of the input row with 'obs' replaced by a list of
# caption strings, one per original observation.
write_data = []
# A single no_grad scope covers every sampling call below; re-entering it per
# observation adds nothing.
with torch.no_grad():
    for row in tqdm(task_data):
        all_caps = []
        for o in row['obs']:
            # Add a leading axis of size 1 before handing the observation to
            # the agent. task_id=1 is passed through as-is (meaning defined
            # by the agent implementation -- not visible here).
            o = np.expand_dims(o, 0)
            caption_ids = agent.direct_sample([o], [caption_goal_ids], {}, task_id=1, max_length=50)
            all_caps.append(_decode_action(caption_ids))

        # Build the new row in one step; the input row is never mutated.
        write_data.append({**row, 'obs': all_caps})

# Results are written once, after all rows have been captioned.
with WRITE_FILE.open('wb') as f:
    pickle.dump(write_data, f)
print('Done.')
