import os
import torch
import yaml
import argparse
import importlib
from torch.utils.data import DataLoader
from model_packs import get_model_and_processors


# Explicitly opt in to tokenizer parallelism via the environment.
# Presumably read by the tokenizer backend behind the processor — confirm.
# NOTE(review): main() builds a DataLoader with worker processes; verify that
# enabling tokenizer parallelism together with those workers is intended.
os.environ.update(TOKENIZERS_PARALLELISM="true")


def get_collate_fn(processor, max_length):
    """Build a DataLoader collate function bound to ``processor``.

    Args:
        processor: Callable accepting ``images=`` and ``text=`` keyword
            arguments plus padding/truncation options.
        max_length: Length every text sequence is padded/truncated to.

    Returns:
        A function mapping a batch of ``(image, text)`` items to whatever
        ``processor`` returns for that batch.
    """
    # Options that are identical for every batch.
    fixed_options = dict(
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    def collate_fn(batch):
        # Each item is indexed as (image, text); images are forced to RGB.
        rgb_images = []
        for sample in batch:
            rgb_images.append(sample[0].convert("RGB"))
        captions = [sample[1] for sample in batch]

        return processor(images=rgb_images, text=captions, **fixed_options)

    return collate_fn


@torch.no_grad()
def main(dataset_name, vlm_key):
    """Extract normalized image/text embeddings for a dataset and save them.

    Loads the model pack for ``vlm_key``, iterates the dataset provided by
    ``custom_datasets.proxy.<dataset_name>.streaming_dataset()``, computes
    L2-normalized image and text features batch by batch, and writes the
    concatenated results to ``embeddings/<dataset_name>/image.pth`` and
    ``embeddings/<dataset_name>/text.pth``.

    Args:
        dataset_name: Name of the module under ``custom_datasets.proxy``;
            also used as the output sub-directory name.
        vlm_key: Key passed to ``get_model_and_processors`` to select the model.

    Raises:
        RuntimeError: If the dataset yields no batches (nothing to save).
    """
    # Prefer the GPU, but fall back to CPU so the script does not crash on
    # hosts without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 512  # Reduced slightly for stability, increase if VRAM allows

    # Load Model
    model_pack = get_model_and_processors(vlm_key)
    model = model_pack["model"].to(device)
    processor = model_pack["processor"]

    # Prep Paths
    save_dir = f'embeddings/{dataset_name}'
    os.makedirs(save_dir, exist_ok=True)

    # Dataset & Loader
    module = importlib.import_module(f'custom_datasets.proxy.{dataset_name}')
    dataset = module.streaming_dataset()

    # NOTE(review): if streaming_dataset() returns an iterable-style dataset
    # that does not shard itself per worker, num_workers > 1 may yield
    # duplicated samples — confirm against the dataset implementation.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        collate_fn=get_collate_fn(processor, model_pack["max_length"]),
        prefetch_factor=2,
    )

    img_embs, txt_embs = [], []
    print(f"Starting extraction for {vlm_key} on {dataset_name}...")

    model.eval()
    for batch_idx, inputs in enumerate(dataloader):
        print(f"Batch index: {batch_idx}", flush=True)

        inputs = {
            k: v.to(device) for k, v in inputs.items()
        }

        # NOTE(review): relies on get_*_features returning an object that
        # exposes `.pooler_output`; confirm for every model in model_packs.
        img_features = model.get_image_features(
            pixel_values=inputs['pixel_values']
        ).pooler_output
        txt_features = model.get_text_features(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        ).pooler_output

        # L2-normalize along the feature dimension (in place).
        img_features /= img_features.norm(p=2, dim=-1, keepdim=True)
        txt_features /= txt_features.norm(p=2, dim=-1, keepdim=True)

        # Move results off the device so accumulated embeddings live in host RAM.
        img_embs.append(img_features.cpu())
        txt_embs.append(txt_features.cpu())

    # torch.cat fails with an opaque message on an empty list; report the
    # actual problem instead (same exception type as before).
    if not img_embs:
        raise RuntimeError(
            f"No batches were produced for dataset '{dataset_name}'; "
            "nothing to save."
        )

    img_embs = torch.cat(img_embs)
    txt_embs = torch.cat(txt_embs)

    print(img_embs.shape, txt_embs.shape)

    torch.save(img_embs, f"{save_dir}/image.pth")
    torch.save(txt_embs, f"{save_dir}/text.pth")

    print("Extraction Complete.")

if __name__ == '__main__':
    # CLI: a single positional argument naming the dataset to process.
    cli = argparse.ArgumentParser()
    cli.add_argument('dataset', type=str)
    dataset_arg = cli.parse_args().dataset

    # The model key is taken from the `base_model` entry of configs.yaml,
    # resolved relative to the current working directory.
    with open('configs.yaml', 'r') as config_file:
        config = yaml.safe_load(config_file)
    vlm_key = config['base_model']

    main(dataset_arg, vlm_key)
