from pathlib import Path

import torch
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule

from promptrl.clipcap.model import ClipCaptionModel
from promptrl.clipcap import load_pretrained, ClipCapPrompt
from promptrl.envs.alfworld_viz import AlfworldVizDataset

def _untangle(chkpt_path, prefix_length=10):
    """Split a combined ClipCap checkpoint into separate LM and prompt checkpoints.

    Loads the full ``ClipCaptionModel`` state dict from *chkpt_path*, copies its
    ``gpt`` weights into a ``GPT2PromptInputLM`` and its ``clip_project`` weights
    into a ``ClipCapPrompt``, then saves the two state dicts next to the input
    file as ``<stem>_gpt.pt`` and ``<stem>_prompt.pt``.

    Args:
        chkpt_path: path (``str`` or ``Path``) of the combined checkpoint.
        prefix_length: prefix length used to build both models; it must match
            the one the checkpoint was trained with. Defaults to 10.
    """
    # NOTE(review): GPT2PromptInputLM is not imported in this file; confirm its
    # source module, otherwise this function raises NameError.
    chkpt_path = Path(chkpt_path)  # accept str too; .parent/.stem need a Path

    print('Loading clipcap model..')
    clipcap = ClipCaptionModel(prefix_length)
    clipcap.load_state_dict(torch.load(str(chkpt_path), map_location='cpu'))

    # Language model: reuse the GPT config and copy its weights across.
    model = GPT2PromptInputLM(clipcap.gpt.config)
    model.load_state_dict(clipcap.gpt.state_dict())

    # Prompt model: only the CLIP projection head is transferred.
    prompt_model = ClipCapPrompt(prefix_length)
    prompt_model.clip_project.load_state_dict(clipcap.clip_project.state_dict())

    print('Saving..')
    out_dir = chkpt_path.parent
    torch.save(model.state_dict(), out_dir / f'{chkpt_path.stem}_gpt.pt')
    torch.save(prompt_model.state_dict(), out_dir / f'{chkpt_path.stem}_prompt.pt')

def _test(im_path, chkpt_path, beam_search=False):
    """Caption a single image with a combined ClipCap checkpoint and print it.

    Args:
        im_path: image file to caption.
        chkpt_path: path of the combined ``ClipCaptionModel`` state dict.
        beam_search: if true, take the first result of ``generate_beam``;
            otherwise use ``generate2``.
    """
    # NOTE(review): clip, GPT2Tokenizer, generate_beam and generate2 are not
    # imported in this file; confirm where they come from before running this.
    print('Loading clip..')
    device = "cpu"
    if torch.cuda.is_available():
        device = 'cuda'
    clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

    print('Loading clipcap model..')
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    prefix_length = 10
    caption_model = ClipCaptionModel(prefix_length)
    checkpoint = torch.load(str(chkpt_path), map_location='cpu')
    caption_model.load_state_dict(checkpoint)
    caption_model = caption_model.eval().to(device)

    print(f'Captioning {im_path}..')
    with PIL.Image.open(im_path) as pil_image:
        pixels = preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            # CLIP features are cast to float32 before the projection head.
            features = clip_model.encode_image(pixels).to(device, dtype=torch.float32)
            prefix_embed = caption_model.clip_project(features).reshape(1, prefix_length, -1)
            caption = (
                generate_beam(caption_model, tokenizer, embed=prefix_embed)[0]
                if beam_search
                else generate2(caption_model, tokenizer, embed=prefix_embed)
            )

    print(f'Result: {caption}')

def _stack_prompt(prompt):
    """Batch the ``inputs_embeds`` / ``attention_mask`` entries of *prompt* in place.

    Each entry is replaced with ``torch.stack`` of its current value, so the
    dict can be splatted into ``model.generate``. Returns *prompt* itself.
    """
    for key in ('inputs_embeds', 'attention_mask'):
        prompt[key] = torch.stack(prompt[key])
    return prompt


def _test2(im_path):
    """Smoke-test the pretrained prompt pipeline on one image and on Alfworld.

    Loads tokenizer, language model and prompt model via ``load_pretrained``.
    It first captions *im_path* and prints the decoded text. It then captions
    the precomputed observations of the first ``AlfworldVizDataset`` sample
    and prints the batch-decoded texts.
    """
    device = 'cuda' if torch.cuda.is_available() else "cpu"
    print('Loading..')
    tokenizer, model, prompt_model = load_pretrained(1, 0, 0, patched=False)
    model.to(device)
    prompt_model.to(device)
    model.eval()
    prompt_model.eval()

    print(f'Captioning {im_path}..')
    # Close the image file deterministically; np.array copies the pixel data.
    with PIL.Image.open(im_path) as pil_image:
        im = np.array(pil_image)
    im = np.expand_dims(im, 0)  # add a leading batch axis
    # Assuming a PIL-style (H, W[, C]) array, axis 2 is now the width.
    im = im[:, :, :1278]  # even 3x3 patches
    with torch.no_grad():  # inference only; skip autograd graph construction
        prompt = _stack_prompt(prompt_model([0], [im], [1], ['']))
        outputs = model.generate(**prompt, max_length=20).output_ids
    print(tokenizer.decode(outputs[0]))

    print('Captioning alfworld..')
    dataset = AlfworldVizDataset('img', [1], model, tokenizer, num_samples=100, data_mode='clip')
    # Each entry of dataset[0]['obs'][0] gets a leading batch dim and moves to device.
    obs_precomputed = [o.unsqueeze(0).to(device) for o in dataset[0]['obs'][0]]
    n_obs = len(obs_precomputed)
    with torch.no_grad():
        prompt2 = _stack_prompt(
            prompt_model([0] * n_obs, None, [1] * n_obs, [''] * n_obs,
                         obs_precomputed=obs_precomputed)
        )
        outputs2 = model.generate(**prompt2, max_length=20).output_ids
    print(tokenizer.batch_decode(outputs2))

if __name__ == '__main__':
    # The checkpoints directory hangs off the directory three `.parent` hops up.
    base_dir = Path(__file__).parent.parent.parent
    model_path = base_dir.joinpath('checkpoints', 'clipcap', 'conceptual_weights.pt')
    im_path = Path.home().joinpath('cow.jpg')

    # Alternative manual entry points (disabled):
    # _untangle(model_path)
    # _test(im_path, model_path, beam_search=False)
    _test2(im_path)
