import pathlib
import pickle

from tqdm import tqdm
import blosc2
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
import clip
import tree
from PIL import Image

import promptrl.utils

# Dataset suffix: selects the input image directory and is part of the output directory name.
IMG_DATA_SUFF = 'seq2seq_data_ood'
# Shared root for both the input images and the written embeddings.
_DATA_ROOT = pathlib.Path(__file__).parent.parent.resolve() / 'promptrl' / 'data'
DATA_DIR = _DATA_ROOT / f'img_{IMG_DATA_SUFF}'

# Backbone selection. Alternative CLIP configuration:
#BASE_MODEL = 'clip'
#SUB_MODEL = "ViT-B/32"
BASE_MODEL = 'resnet'
SUB_MODEL = "50"

# Optional patching of each image before embedding (see _prep).
PATCHED = False
PATCH_DIM = (3, 3)

# Output directory name encodes backbone, sub-model (slashes stripped), patching, and dataset.
_patch_tag = 'patched_' if PATCHED else ''
dir_name = f'{BASE_MODEL}_{SUB_MODEL.replace("/", "")}_{_patch_tag}{IMG_DATA_SUFF}'
WRITE_DIR = _DATA_ROOT / dir_name
WRITE_DIR.mkdir(parents=True, exist_ok=True)

# Number of preprocessed images sent through the backbone per forward call.
BATCH_SIZE = 64

# Load the image backbone and bind `preprocess` (PIL image -> tensor) and
# `forward` (batched tensor -> embedding tensor) for the processing loop below.
print('Loading model...')
device = "cuda" if torch.cuda.is_available() else "cpu"
if BASE_MODEL == 'clip':
    # clip.load returns the model together with its matching preprocessing transform.
    model, preprocess = clip.load(SUB_MODEL, device=device)
    forward = model.encode_image
elif BASE_MODEL == 'resnet':
    # NOTE(review): SUB_MODEL is not consulted here; this branch always loads ResNet-50.
    weights = ResNet50_Weights.DEFAULT
    preprocess = weights.transforms()
    model = resnet50(weights=weights)
    # Replace the classification head so the network outputs pooled features, not logits.
    model.fc = torch.nn.Identity()
    model.to(device)
    # torchvision builds models in training mode. Switch to eval so BatchNorm uses its
    # stored running statistics. Otherwise each embedding depends on the other images in
    # its batch, and the running stats keep changing during extraction (no_grad does not
    # prevent that).
    model.eval()
    # Call the module itself rather than .forward so any registered hooks still run.
    forward = model
else:
    raise NotImplementedError(f'Unsupported BASE_MODEL: {BASE_MODEL!r}')

def _prep(im_arr):
    """Split a stacked image array into a list of PIL images along its first axis.

    When PATCHED is set, the array is first passed through
    ``promptrl.utils.patchify_np`` with PATCH_DIM. The (possibly patched) array
    must be channels-last with exactly 3 channels.
    """
    arr = promptrl.utils.patchify_np(im_arr, PATCH_DIM) if PATCHED else im_arr
    assert arr.shape[-1] == 3  # channel at end
    # Iterating an array walks its leading axis, yielding arr[0], arr[1], ...
    return [Image.fromarray(frame) for frame in arr]

# Embed every pickled task file in DATA_DIR. Each file is written to WRITE_DIR under the
# same name, with row['obs'] replaced by embeddings (plus 'goal_embed' for CLIP).
# Sorted for a deterministic processing order across runs and filesystems.
pkl_paths = sorted(DATA_DIR.glob('*.pkl'))
if not pkl_paths:
    # glob on a missing or empty directory yields nothing; make that visible.
    print(f'No .pkl files found in {DATA_DIR}')
for p in pkl_paths:
    print(f'Processing {p.name}...')
    # NOTE(review): pickle.load can execute arbitrary code; only run this on trusted data.
    with p.open('rb') as f:
        task_data = pickle.load(f)

    write_data = []
    # A single no_grad context covers both image and text encoding.
    with torch.no_grad():
        for row in tqdm(task_data):
            new_row = dict(row)  # shallow copy so the loaded row is not mutated
            # Each entry of row['obs'] is a blosc2-packed array (presumably a stack of frames).
            obs = [blosc2.unpack_array(o) for o in row['obs']]
            obs_spl = tree.map_structure(_prep, obs)
            obs_spl = tree.map_structure(preprocess, obs_spl)
            obs_embeds = promptrl.utils.map_structure_batched(
                forward,
                obs_spl,
                pre_func=lambda x: x.to(device),
                post_func=lambda x: x.cpu().numpy(),
                batch_size=BATCH_SIZE,
            )
            # Re-stack per-image embeddings into one array per original obs entry
            # (presumably frames x embed_dim -- TODO confirm).
            obs_embeds = tree.map_structure_up_to(obs, np.stack, obs_embeds)
            new_row['obs'] = obs_embeds
            if BASE_MODEL == 'clip':
                tokens = clip.tokenize(row['goal']).to(device)
                new_row['goal_embed'] = model.encode_text(tokens).squeeze(0).cpu().numpy()
            write_data.append(new_row)

    # Write to a temp file, then rename, so an interrupted run never leaves a truncated
    # .pkl that looks complete. The '.tmp' suffix also keeps it out of '*.pkl' globs.
    out_path = WRITE_DIR / p.name
    tmp_path = out_path.with_name(out_path.name + '.tmp')
    with tmp_path.open('wb') as f:
        pickle.dump(write_data, f)
    tmp_path.replace(out_path)
print('Done.')
