import pathlib
import pickle

from tqdm import tqdm
import blosc2
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from accelerate import Accelerator
import clip
import tree
from PIL import Image

import promptrl.utils as utils
from scripts.load_args import CaptionArgs
from scripts.load import load_checkpoint_agent

# Dataset variant to process; it names both the input and the output folder.
IMG_DATA_SUFF = 'seq2seq_data_ood'

# Pickled task files are read from <repo>/promptrl/data/img_<suffix>, where
# <repo> is the resolved grandparent directory of this script.
DATA_DIR = pathlib.Path(__file__).parent.parent.resolve().joinpath(
    'promptrl', 'data', f'img_{IMG_DATA_SUFF}')

# Captioned copies go to the sibling folder captions_<suffix>, which is
# created up front so the final write cannot fail on a missing directory.
dir_name = f'captions_{IMG_DATA_SUFF}'
WRITE_DIR = DATA_DIR.parent / dir_name
WRITE_DIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): not referenced anywhere in the visible part of this script.
BATCH_SIZE = 1

# Logger is built with wandb reporting switched off.
logger = utils.Logger(None, use_wandb=False)
accelerator = Accelerator()
args = CaptionArgs()

# Restore the agent from its checkpoint; eval() presumably switches it to
# inference mode (torch.nn.Module semantics) -- confirm against the agent class.
agent = load_checkpoint_agent(accelerator, logger, args)
agent.eval()

def _decode_action(action):
    """Decode sampled token ids into a cleaned caption string.

    ``action[0]`` is decoded with the module-level ``agent.tokenizer``
    (``skip_special_tokens=True``).  The caption is the text between the
    first and the second ``'['`` of the decoded string (or up to the end of
    the string when there is only one ``'['``), with every ``'SEP]'``
    occurrence removed and surrounding whitespace stripped.

    Args:
        action: Indexable container whose first element is passed to the
            tokenizer's ``decode`` (presumably a batch of token-id
            sequences -- confirm against ``agent.direct_sample``).

    Returns:
        The cleaned caption.  When the decoded text contains no ``'['`` at
        all, the whole decoded text is returned stripped, instead of raising
        ``IndexError`` as ``split('[', 2)[1]`` would.
    """
    decoded = agent.tokenizer.decode(action[0], skip_special_tokens=True)

    _, bracket, tail = decoded.partition('[')
    if not bracket:
        # No delimiter in the output: fall back to the raw text so a single
        # unexpected generation does not abort a long captioning run.
        return decoded.strip()

    # Keep only the text up to the next '[' (everything, if there is none).
    segment = tail.partition('[')[0]
    return segment.replace('SEP]', '').strip()

# Fixed instruction passed as the goal for every observation.  It is tokenized
# once here and reused for each sample; squeeze(0) drops the leading dimension
# (presumably the size-1 batch axis added by return_tensors='pt' -- confirm)
# before the ids are moved to the accelerator's device.
caption_goal = 'Your task is to: caption the following observation'
caption_goal_ids = (
    agent.tokenizer(caption_goal, return_tensors='pt')
    .input_ids
    .squeeze(0)
    .to(accelerator.device)
)

# For every pickle in DATA_DIR: replace each row's packed observations with
# generated captions and write the result under the same file name in
# WRITE_DIR.
for pkl_path in DATA_DIR.glob('*.pkl'):
    print(f'Processing {pkl_path.name}...')
    # NOTE: pickle.load can execute arbitrary code; only run this on trusted
    # data files.
    with pkl_path.open('rb') as src:
        task_data = pickle.load(src)

    captioned_rows = []
    # A single no_grad scope covers all sampling for this file.
    with torch.no_grad():
        for row in tqdm(task_data):
            captions = []
            for packed in row['obs']:
                # Each observation is stored blosc2-packed; unpack it before
                # handing it to the agent.
                frame = blosc2.unpack_array(packed)
                sampled_ids = agent.direct_sample(
                    [frame], [caption_goal_ids], {}, task_id=1, max_length=50)
                captions.append(_decode_action(sampled_ids))

            # Shallow copy of the row with 'obs' swapped for the captions;
            # the loaded row itself is left untouched.
            captioned_rows.append({**row, 'obs': captions})

    with (WRITE_DIR / pkl_path.name).open('wb') as dst:
        pickle.dump(captioned_rows, dst)
print('Done.')
