from clip_interrogator import Config, Interrogator
from PIL import Image
import json, os
import tqdm
from RKME_utils import StaSpec, MMD
import pickle
import torch
from dataset_utils import get_clip_dataloader
import numpy as np
import PIL
import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from tqdm import trange
import packaging
import numpy as np

def get_imgs(path_name):
    """
    Recursively collect the paths of all image files below a directory.

    Parameters
    ----------
    path_name : str
        Root directory to search.

    Returns
    -------
    list of str
        One path per file whose name, compared case-insensitively, ends with a
        known image extension. The list is empty when the directory does not
        exist or contains no matching files.
    """
    image_suffixes = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    imagelist = []
    for parent, _dirnames, filenames in os.walk(path_name):
        imagelist.extend(
            os.path.join(parent, filename)
            for filename in filenames
            if filename.lower().endswith(image_suffixes)
        )
    # Return only once the whole tree has been walked; returning from inside
    # the loop would stop after the first directory and yield None when
    # os.walk produces nothing at all.
    return imagelist

# Holds the single shared tokenizer instance once it has been built, so that
# repeated tokenize() calls do not construct a new tokenizer every time.
_TOKENIZER_CACHE = []


def _get_tokenizer():
    """Return the shared tokenizer, constructing it on the first call."""
    if not _TOKENIZER_CACHE:
        _TOKENIZER_CACHE.append(_Tokenizer())
    return _TOKENIZER_CACHE[0]


def tokenize(texts, context_length: int = 77, truncate: bool = False):
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.

    Raises
    ------
    RuntimeError
        If an encoded input is longer than ``context_length`` and ``truncate`` is False.
    """
    # A bare `import packaging` does not by itself load the `version`
    # submodule; import it explicitly so the lookup below cannot depend on
    # some other library having imported it first.
    import packaging.version

    if isinstance(texts, str):
        texts = [texts]

    tokenizer = _get_tokenizer()
    sot_token = tokenizer.encoder["<|startoftext|>"]
    eot_token = tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]

    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                # Keep the first context_length tokens and make sure the
                # sequence still ends with the end-of-text marker.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        # Positions past len(tokens) stay zero and act as padding.
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result

def build_vector(mode):
    """
    Encode the images and pseudo-prompts of every model in the pool with CLIP.

    For each model name in the module-level ``models`` collection, this reads
    ``./Images/<model_name>/<mode>/prompts.json`` (a list of records with
    ``"path"`` and ``"prompt"`` keys), encodes every image and its prompt
    with CLIP ViT-B/32, and pickles the results to ``CLIP-prompts.pkl`` in
    the same directory as a dict with ``"images"`` and ``"prompts"`` lists.

    Parameters
    ----------
    mode : str
        Which split to process; must be ``"train"`` or ``"test"``.

    Raises
    ------
    AssertionError
        If ``mode`` is not one of the two supported splits.
    """
    assert mode in ["test", "train"]
    if not models:
        # Nothing to encode; avoid loading CLIP for an empty pool.
        return

    # Build Clip
    # The encoder does not depend on the model being processed, so it is
    # loaded once up front instead of once per pool entry.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    for model_name in models:
        model_path = f"./Images/{model_name}"
        spec_path = os.path.join(model_path, mode, "CLIP-prompts.pkl")
        prom_path = os.path.join(model_path, mode, "prompts.json")
        # Read Prompts
        with open(prom_path, "r") as f:
            prompts = json.load(f)
        # Build Vectors
        image_vec, prompt_vec = [], []
        with torch.no_grad():
            for idx in trange(len(prompts), desc=model_name):
                item = prompts[idx]
                # Close the image file as soon as it has been preprocessed
                # into a tensor, so file handles do not pile up.
                with Image.open(item["path"]) as img:
                    pixels = preprocess(img).unsqueeze(0).to(device)
                # Each input is a batch of one; [0, ...] drops the batch axis.
                image_vec.append(model.encode_image(pixels).cpu().numpy()[0, ...])
                token = tokenize(item["prompt"], truncate=True).to(device)
                prompt_vec.append(model.encode_text(token).cpu().numpy()[0, ...])
        # Save PKL
        with open(spec_path, "wb") as fw:
            pickle.dump({
                "images": image_vec,
                "prompts": prompt_vec
            }, fw)

# Load the model pool; the rest of the script iterates over its entries.
with open("ModelPool.json", "r") as pool_file:
    models = json.loads(pool_file.read())

# CLIP Interrogator instance used below to caption images.
config = Config(clip_model_name="ViT-L-14/openai")
ci_vitl = Interrogator(config)

# Generative pseudo-prompts
# Caption every image of every model in the pool and store the captions in a
# prompts.json file inside the corresponding split directory. All train
# splits are processed first, then all test splits.
for split in ("train", "test"):
    # Iterating the pool directly (instead of through enumerate) lets tqdm
    # see its length and display a real progress total.
    for model_name in tqdm.tqdm(models, desc=split):
        images_path = f"./Images/{model_name}/{split}"
        res = []
        # `or []` tolerates a helper that returns None for a directory it
        # could not walk, instead of failing with a TypeError here.
        for file in get_imgs(images_path) or []:
            # convert() returns an independent in-memory image, so the
            # underlying file can be closed right away.
            with Image.open(file) as raw_image:
                image = raw_image.convert('RGB')
            prompt = ci_vitl.interrogate_fast(image)
            res.append({
                "path": file,
                "prompt": prompt
            })
        with open(os.path.join(images_path, "prompts.json"), "w") as fw:
            json.dump(res, fw)

# Build vectors
# Encode the stored prompts and images, first for train and then for test.
for mode in ("train", "test"):
    build_vector(mode)