from open_flamingo import create_model_and_transforms
import torch
from peft import LoraConfig,get_peft_model,prepare_model_for_int8_training

# Build the OpenFlamingo model (CLIP "ViT-L-14" vision encoder + MPT-1B
# language model) from locally cached weights, then load the OpenFlamingo
# checkpoint on top of it.
# NOTE(review): every path below is relative, so it resolves against the
# current working directory, not this file's location.
CHECKPOINT_PATH = "../OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt"

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="./.cache/clip/ViT-L-14.pt",
    lang_encoder_path="./.cache/huggingface/hub/mpt-1b-redpajama-200b/",
    tokenizer_path="./.cache/huggingface/hub/mpt-1b-redpajama-200b/",
    # Presumably must match the layout the checkpoint was trained with,
    # otherwise its cross-attention weights will not line up — TODO confirm.
    cross_attn_every_n_layers=1,
    # cache_dir is intentionally not passed (it defaults to ~/.cache); the
    # encoder and tokenizer locations are given explicitly above.
)

# Load tensors onto the CPU: nothing above moves the model to a GPU, and this
# lets a checkpoint saved from GPU tensors load on a machine without that GPU.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
# strict=False tolerates model parameters that are absent from the checkpoint
# (presumably the encoder weights loaded separately above — TODO confirm).
load_result = model.load_state_dict(state_dict, strict=False)
# load_state_dict copied the values into the model; drop the second copy.
del state_dict

# strict=False would also silently accept a checkpoint that matches nothing
# (wrong file, "module." prefix from DDP, ...), leaving the pretrained layers
# at their initial values. Report the mismatch instead of hiding it.
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint keys did not "
        f"match any model parameter, e.g. {load_result.unexpected_keys[:5]}"
    )
print(
    f"Checkpoint loaded from {CHECKPOINT_PATH}; "
    f"{len(load_result.missing_keys)} model keys were not in the checkpoint."
)

# Preprocess Dataset



# train

# Wrap the model with LoRA adapters and report how many parameters train.
# NOTE(review): with default settings get_peft_model leaves only the injected
# adapter parameters trainable, so parameters that were trainable before this
# call become frozen — confirm that is intended (see LoraConfig's
# `modules_to_save` otherwise).
lora_config = LoraConfig(
    r=16,               # rank of the low-rank update matrices
    lora_alpha=16,      # LoRA scaling numerator (scale = lora_alpha / r)
    # Sub-module names to adapt. Judging by their names these are attention
    # projections (separate q / kv / out, plus a fused Wqkv) — verify against
    # print(model) that each name actually matches modules.
    target_modules=[
        "to_q",
        "to_kv",
        "to_out",
        "Wqkv",
    ],
    lora_dropout=0.05,  # dropout on the LoRA branch
    bias="none",        # do not make any bias terms trainable
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


