from pathlib import Path

import torch
from transformers import GPT2TokenizerFast, GPT2Config

from promptrl.prompts import CLIPEmbedPrompt
from promptrl.clipcap.prompt import ClipCapPrompt
from promptrl.model import GPT2PromptInputLM

def load_pretrained(num_tasks, soft_prompt_dim, deep_prompt_dim, patched=False, checkpoint_path=Path(__file__).parent.parent.parent / 'checkpoints' / 'clipcap'):
    """Build a GPT-2 tokenizer, prompt-input LM and CLIP prompt from ClipCap weights.

    Args:
        num_tasks: Forwarded to ``CLIPEmbedPrompt``.
        soft_prompt_dim: Forwarded to ``CLIPEmbedPrompt``.
        deep_prompt_dim: Forwarded to ``CLIPEmbedPrompt``.
        patched: Forwarded to ``CLIPEmbedPrompt``.
        checkpoint_path: Directory containing ``conceptual_weights_gpt.pt`` and
            ``conceptual_weights_prompt.pt``. May be a ``str`` or any path-like
            object. Defaults to ``checkpoints/clipcap`` located three
            directories above this file.

    Returns:
        A ``(tokenizer, model, prompt)`` tuple:
        the ``"gpt2"`` fast tokenizer with its pad token set to EOS, a
        ``GPT2PromptInputLM`` (default ``GPT2Config``) with the ClipCap GPT
        weights loaded, and a ``CLIPEmbedPrompt`` (``ViT-B/32``) whose
        ``obs_embed`` is replaced by a ``ClipCapPrompt`` with the ClipCap
        projection weights loaded.
    """
    # Accept plain strings too: ``str / str`` would raise TypeError below.
    checkpoint_path = Path(checkpoint_path)

    # Used both as CLIPEmbedPrompt's obs_embed_dim and as ClipCapPrompt's
    # constructor argument, so keep a single source of truth.
    # NOTE(review): presumably the ClipCap prefix length — confirm.
    obs_embed_dim = 10

    # GPT-2 has no dedicated pad token; reuse EOS so batched padding works.
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints (consider weights_only=True if the torch version allows).
    model = GPT2PromptInputLM(GPT2Config())
    # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts.
    gpt_state = torch.load(str(checkpoint_path / 'conceptual_weights_gpt.pt'), map_location='cpu')
    model.load_state_dict(gpt_state)
    model.config.pad_token_id = model.config.eos_token_id

    prompt = CLIPEmbedPrompt(num_tasks, model, soft_prompt_dim, deep_prompt_dim, obs_embed_dim=obs_embed_dim, clip_arch='ViT-B/32', patched=patched)

    # Swap the prompt's observation embedder for the pretrained ClipCap projection.
    clip_project = ClipCapPrompt(obs_embed_dim)
    prompt_state = torch.load(str(checkpoint_path / 'conceptual_weights_prompt.pt'), map_location='cpu')
    clip_project.load_state_dict(prompt_state)
    prompt.obs_embed = clip_project
    return tokenizer, model, prompt
