import os
from typing import List

import joblib
import pandas as pd
import torch
from clip_interrogator import clip_interrogator
from tqdm import tqdm

from clip_experiments.extract_clip_embeds import root_dataset
from clip_interogator_captions.data_handler import get_dataloaders
from clip_interogator_captions.model import get_configs

class CFG:
    """Static settings for the captioning run.

    The class object itself (not an instance) is passed to ``get_configs``;
    how each field is consumed is defined there. The driver loop in this
    script additionally reads ``device`` directly.
    """

    # Runtime settings.
    device = "cuda"
    seed = 42

    # Caption model identifier.
    blip_model_path = "blip-large"

    # CLIP model identifiers, local weights file and cache directory.
    ci_clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
    clip_model_name = "ViT-H-14"
    clip_model_path = "./clip_interogator_captions/clip-interrogator-models-X/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    cache_path = "./stable-diffusion-image-to-prompts/clip_interogator_captions/clip-interrogator-models-X"

    # Sentence-embedding model and (presumably) its vector size; neither is
    # referenced anywhere else in this file -- confirm against get_configs.
    sentence_model_path = "./stable-diffusion-image-to-prompts/all-MiniLM-L6-v2"
    embedding_length = 384

# Build the interrogator from the project's config factory.
model_config = get_configs(CFG)
ci = clip_interrogator.Interrogator(model_config)

# Similarity module used by interrogate(); it reduces over dimension 1.
cos = torch.nn.CosineSimilarity(dim=1)


def _embeds_to_device_tensor(label_table):
    """Stack a label table's numpy embeddings into one tensor on ``ci.device``."""
    stacked = torch.stack([torch.from_numpy(embed) for embed in label_table.embeds])
    return stacked.to(ci.device)


# One stacked embedding tensor per label table, kept on the interrogator's device.
mediums_features_array = _embeds_to_device_tensor(ci.mediums)
movements_features_array = _embeds_to_device_tensor(ci.movements)
flavors_features_array = _embeds_to_device_tensor(ci.flavors)

# The prompt table drives the dataloader, which uses the interrogator's two preprocessors.
df = pd.read_csv(f"{root_dataset}/prompts.csv")
dataloader = get_dataloaders(df, ci.clip_preprocess, ci.caption_processor)

def interrogate(caption, image_features) -> str:
    """Extend a caption with the labels most similar to an image embedding.

    ``image_features`` is scored with the module-level ``cos`` against the
    medium, movement and flavor embedding tensors. The best medium, the best
    movement and the three best flavors are appended to ``caption``; the
    medium is left out when the caption already starts with it.

    Args:
        caption: Caption text that opens the prompt.
        image_features: Image embedding to compare with each label table
            (the driver loop passes a single row of the batch features).

    Returns:
        The assembled prompt after ``clip_interrogator._truncate_to_fit`` has
        been applied with ``ci.tokenize``.
    """

    def closest_labels(table, table_features, count):
        # Labels of the `count` highest-scoring entries, best first.
        scores = cos(image_features, table_features)
        return [table.labels[idx] for idx in scores.topk(count).indices]

    medium = closest_labels(ci.mediums, mediums_features_array, 1)[0]
    movement = closest_labels(ci.movements, movements_features_array, 1)[0]
    flaves = ", ".join(closest_labels(ci.flavors, flavors_features_array, 3))

    # Skip the medium when the caption already leads with it.
    prompt = (
        f"{caption}, {movement}, {flaves}"
        if caption.startswith(medium)
        else f"{caption}, {medium}, {movement}, {flaves}"
    )
    return clip_interrogator._truncate_to_fit(prompt, ci.tokenize)

def generate_caption(clip_interogator, inputs) -> List[str]:
    """Caption a batch of already-preprocessed images.

    Args:
        clip_interogator: Interrogator-like object providing ``caption_model``,
            ``caption_processor``, ``config``, ``dtype`` and
            ``_prepare_caption``.
        inputs: Pixel values for the caption model; forwarded to ``generate``
            as ``pixel_values``. No caption-processor call is made here, so
            the caller is expected to have preprocessed the images.

    Returns:
        One whitespace-stripped caption per decoded token sequence.

    Raises:
        AssertionError: If the interrogator has no caption model loaded.
    """
    assert clip_interogator.caption_model is not None, "No caption model loaded."
    clip_interogator._prepare_caption()

    # Every caption model except the "git-" family gets its input cast to
    # the interrogator's dtype before generation.
    if not clip_interogator.config.caption_model_name.startswith('git-'):
        inputs = inputs.to(clip_interogator.dtype)

    tokens = clip_interogator.caption_model.generate(
        pixel_values=inputs,
        max_new_tokens=clip_interogator.config.caption_max_length,
    )
    decoded = clip_interogator.caption_processor.batch_decode(tokens, skip_special_tokens=True)
    return [text.strip() for text in decoded]
def image_to_features(self, images) -> torch.Tensor:
    """Encode images with the CLIP model and scale each embedding to unit norm.

    Args:
        self: Interrogator-like object exposing ``_prepare_clip`` and
            ``clip_model``. This is a module-level function, so the object
            is passed explicitly (the driver loop passes ``ci``).
        images: Image batch handed straight to ``clip_model.encode_image``;
            no CLIP preprocessing is applied here.

    Returns:
        The encoded features divided by their norm along the last dimension.
    """
    self._prepare_clip()
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            embeddings = self.clip_model.encode_image(images)
            # Normalise in place so the encoder's output tensor is reused.
            embeddings.div_(embeddings.norm(dim=-1, keepdim=True))
    return embeddings


# Single streaming pass over the dataset: each batch is captioned, embedded
# and turned into final prompts right away. Feature tensors therefore live
# only for the duration of their own batch instead of accumulating on the
# device until the whole dataset has been processed.
final_output = {}
for images_clip, images_blip, filenames in tqdm(dataloader):
    images_clip = images_clip.float().to(CFG.device)
    images_blip = images_blip.to(CFG.device)

    captions = generate_caption(ci, images_blip)
    image_features = image_to_features(ci, images_clip)

    # Captions, feature rows and filenames are aligned per sample in a batch.
    for filename, caption, features in zip(filenames, captions, image_features):
        final_output[filename] = interrogate(caption, features)

# Persist the {filename: prompt} mapping next to the dataset.
joblib.dump(final_output, f"{root_dataset}/captions_clip.joblib")
