import pathlib
import pickle

from tqdm import tqdm
import blosc2
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
import clip
import tree
from PIL import Image

import promptrl.utils

IMG_DATA_SUFF = 'data_v4_ns.pkl'
# Directory holding both the raw image dataset and the embedded output file.
DATA_DIR = pathlib.Path(__file__).parent.parent.resolve() / 'promptrl' / 'data' / 'virtualhome'
DATA_FILE = DATA_DIR / IMG_DATA_SUFF

# Image encoder selection. 'clip' passes SUB_MODEL to clip.load; 'resnet'
# always builds ResNet-50 and uses SUB_MODEL only in the output file name.
BASE_MODEL = 'clip'
SUB_MODEL = "ViT-B/32"
#BASE_MODEL = 'resnet'
#SUB_MODEL = "50"

# When PATCHED, each frame is patchified with PATCH_DIM
# (via promptrl.utils.patchify_np) before encoding.
PATCHED = False
PATCH_DIM = (3, 3)

# The output name encodes the encoder (and patching) so different runs do not
# overwrite each other, e.g. 'clip_ViT-B32_data_v4_ns.pkl'. It lives in
# DATA_DIR, which must already exist for DATA_FILE, so no mkdir is needed.
file_name = f'{BASE_MODEL}_{SUB_MODEL.replace("/", "")}_' + ('patched_' if PATCHED else '') + IMG_DATA_SUFF
WRITE_FILE = DATA_DIR / file_name

# Batch size handed to promptrl.utils.map_structure_batched for encoder calls.
BATCH_SIZE = 64

print('Loading model...')
device = "cuda" if torch.cuda.is_available() else "cpu"
if BASE_MODEL == 'clip':
    # clip.load returns the model together with its matching image preprocessing.
    model, preprocess = clip.load(SUB_MODEL, device=device)
    forward = model.encode_image
elif BASE_MODEL == 'resnet':
    weights = ResNet50_Weights.DEFAULT
    preprocess = weights.transforms()
    model = resnet50(weights=weights)
    # Replace the classifier head so forward() returns the pre-classifier features.
    model.fc = torch.nn.Identity()
    model.to(device)
    forward = model.forward
else:
    raise NotImplementedError(f'Unsupported BASE_MODEL: {BASE_MODEL!r}')

# Torchvision models are constructed in training mode. Without eval(),
# BatchNorm would normalize with per-batch statistics (making each embedding
# depend on which other images share its batch) and would keep updating its
# running stats. For CLIP this is a harmless no-op.
model.eval()

def _prep(im_arr):
    """Convert a channels-last image array into a list of PIL images.

    If PATCHED is set, the array is first passed through
    promptrl.utils.patchify_np with PATCH_DIM. One PIL image is produced per
    entry along the leading axis (presumably frames or patches — confirm
    against patchify_np's output layout).
    """
    if PATCHED:
        im_arr = promptrl.utils.patchify_np(im_arr, PATCH_DIM)
    # Image.fromarray needs the colour channel as the last axis.
    assert im_arr.shape[-1] == 3
    return [Image.fromarray(frame) for frame in im_arr]

print(f'Processing {DATA_FILE.name}...')
# NOTE: pickle.load can execute arbitrary code; only point DATA_FILE at trusted data.
with open(DATA_FILE, 'rb') as in_file:
    task_data = pickle.load(in_file)

write_data = []
# Inference only: one no_grad context covers both image and text encoding.
with torch.no_grad():
    for row_idx, row in enumerate(tqdm(task_data)):
        # Explicit check (not assert) so validation survives `python -O`.
        if len(row['obs']) != len(row['actions']) + 1:
            raise ValueError(
                f"row {row_idx}: expected len(obs) == len(actions) + 1, "
                f"got {len(row['obs'])} obs and {len(row['actions'])} actions"
            )
        new_row = dict(row)
        # Keep one observation per action (the final observation is dropped)
        # and add a leading axis so _prep receives a 1-image batch per step.
        obs = [np.expand_dims(o, 0) for o in row['obs'][:-1]]
        obs_spl = tree.map_structure(_prep, obs)
        obs_spl = tree.map_structure(preprocess, obs_spl)
        obs_embeds = promptrl.utils.map_structure_batched(
            forward,
            obs_spl,
            pre_func=lambda x: x.to(device),
            post_func=lambda x: x.cpu().numpy(),
            batch_size=BATCH_SIZE,
        )
        # Regroup per original observation by stacking the embeddings of the
        # images _prep produced for it (presumably 1 image, or one per patch
        # when PATCHED -> (num_images, embed_dim); confirm with patchify_np).
        obs_embeds = tree.map_structure_up_to(obs, np.stack, obs_embeds)
        new_row['obs'] = obs_embeds
        if BASE_MODEL == 'clip':
            goal_tokens = clip.tokenize(row['goal']).to(device)
            new_row['goal_embed'] = model.encode_text(goal_tokens).squeeze(0).cpu().numpy()
        write_data.append(new_row)

# Persist the embedded rows to WRITE_FILE (same directory as the source data).
with open(WRITE_FILE, 'wb') as out_file:
    pickle.dump(write_data, out_file)
print('Done.')
