import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
from tqdm import tqdm
from torch import Tensor as T
import json
import pandas as pd


from utils.stable_diffusion import (
    load_text_components,
    compute_text_embedding,
)

device = "cuda"  # torch device the text encoder is moved to (and run on) in get_norms()


def get_batches(lst: list, batch_size: int) -> list:
    """Split ``lst`` into consecutive chunks of at most ``batch_size`` items.

    Each chunk is a slice of ``lst`` taken in order; the last chunk holds
    the remainder and may be shorter than ``batch_size``.
    """
    chunk_starts = range(0, len(lst), batch_size)
    return [lst[start : start + batch_size] for start in chunk_starts]


def _load_captions(annotations_path: str) -> list:
    """Return the ``caption`` field of every entry in a COCO captions file, in file order."""
    # Context manager closes the handle; COCO annotation files are UTF-8 JSON.
    with open(annotations_path, encoding="utf-8") as f:
        data = json.load(f)
    return [annotation["caption"] for annotation in data["annotations"]]


@torch.no_grad()
def get_norms(
    *,
    annotations_path: str = "/home/datasets/coco2014/annotations/captions_train2014.json",
    output_path: str = "results/coco_l2_norms.npz",
    batch_size: int = 8192,
) -> T:
    """Encode the COCO captions and save the L2 norm of each flattened embedding.

    The "v1-4" tokenizer/text encoder from ``load_text_components`` is moved
    to ``device`` and put in eval mode. Captions read from
    ``annotations_path`` are encoded in batches of ``batch_size``. Each
    embedding is flattened to ``77 * 768`` values and reduced to its L2 norm.

    Args:
        annotations_path: COCO-style captions JSON (must contain an
            ``"annotations"`` list whose entries have a ``"caption"`` key).
        output_path: Destination ``.npz`` file; the norms are stored under
            the key ``"norms"``. Missing parent directories are created.
        batch_size: Number of captions passed to ``compute_text_embedding``
            per call.

    Returns:
        1-D float32 CPU tensor of norms, in annotation order (empty if there
        are no captions). Presumably one entry per caption, assuming
        ``compute_text_embedding`` returns ``(batch, 77, 768)`` — confirm.
    """
    tokenizer, text_encoder = load_text_components("v1-4")
    text_encoder.to(device)
    text_encoder.eval()

    captions = _load_captions(annotations_path)

    batch_norms = []
    for prompts in tqdm(get_batches(captions, batch_size), desc="Encoding retain set"):
        embeddings = compute_text_embedding(prompts, tokenizer, text_encoder)
        batch_norms.append(
            embeddings
            # assumes 77 tokens x 768 dims per prompt (SD v1 text encoder) — TODO confirm
            .reshape(-1, 77 * 768)
            # Upcast before the sum of squares: if the encoder runs in half
            # precision (not visible here), summing 59,136 squares could overflow.
            .float()
            .norm(dim=-1)
            # Reduce on the device first so only one scalar per row is copied
            # to the host, not the full embedding batch.
            .cpu()
        )
    # torch.cat raises on an empty list, so fall back to an empty tensor.
    norms = torch.cat(batch_norms, dim=0) if batch_norms else torch.empty(0)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.savez(output_path, norms=norms.numpy())
    return norms


if __name__ == "__main__":
    # Script entry point: run get_norms(), which encodes the COCO train2014
    # captions and writes the embedding norms to an .npz file.
    get_norms()
