import base64
import os
import requests
import shutil

import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet
from PIL import Image
from sentence_transformers import SentenceTransformer
import torch
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast


# OpenAI API key.
# The key is read from the OPENAI_API_KEY environment variable when it is set
# (and non-empty) so the secret does not have to be written into this file;
# otherwise the placeholder below is used and must be replaced by hand.
api_key = os.environ.get('OPENAI_API_KEY') or 'YOUR_OPENAI_API_KEY'

# HTTP headers sent with every chat-completions request.
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {api_key}",
}


# Function to encode the image
def encode_image(image_path):
    """Return the contents of the file at `image_path` as a base64 string.

    The file is read in binary mode, base64-encoded, and the encoded bytes
    are decoded as UTF-8 so the result is a `str`.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')

# Function to retrieve tokens that are common across multiple tokenizers
def get_candidate_tokens(tokenizers):
    """Return the word-like tokens shared by every tokenizer in `tokenizers`.

    For each tokenizer, every id in `range(len(tokenizer))` is converted to
    its token string, the tokenizer-specific word marker is removed, and the
    token is kept only if it

    * is longer than one character and is not numeric,
    * is not split into more than one token when encoded on its own, and
    * has at least one WordNet synset.

    Args:
        tokenizers: Iterable of `(tokenizer, tokenizer_type)` pairs. Each
            tokenizer must support `len()`, `convert_ids_to_tokens(id)` and
            `encode(text, add_special_tokens=False)`. For type `'clip'` a
            trailing `'</w>'` is stripped; for type `'t5'` a leading `'▁'`
            is stripped; tokens of any other type are used unchanged.

    Returns:
        The set of tokens that passed the filters for *all* tokenizers.
        An empty set is returned when `tokenizers` is empty.
    """
    candidate_tokens_list = []
    for i, (tokenizer, tokenizer_type) in enumerate(tokenizers):
        candidate_tokens = set()
        for j in range(len(tokenizer)):
            token = tokenizer.convert_ids_to_tokens(j)
            # Skip ids without a usable token string; indexing an empty
            # string (or slicing None) would otherwise raise.
            if not token:
                continue

            if tokenizer_type == 'clip' and token.endswith('</w>'):
                token = token[:-4]
            elif tokenizer_type == 't5' and token.startswith('▁'):
                token = token[1:]

            # Cheap string checks first so the tokenizer and WordNet are
            # only consulted for plausible words.
            if len(token) <= 1 or token.isnumeric():
                continue

            # Check if the token is actually encoded as a single token
            if len(tokenizer.encode(token, add_special_tokens=False)) > 1:
                continue

            # Filter out tokens that are not words
            if wordnet.synsets(token):
                candidate_tokens.add(token)
        print("len(candidate_tokens for tokenizer {}): {}".format(i + 1, len(candidate_tokens)))
        candidate_tokens_list.append(candidate_tokens)

    # set.intersection() needs at least one argument.
    if not candidate_tokens_list:
        return set()
    return set.intersection(*candidate_tokens_list)


# Load SD3 tokenizers
# All three tokenizers live in sub-folders of the same local model directory.
sd3_model_dir = 'stable-diffusion-3-medium-diffusers-b1148b4/'
tokenizer_one = CLIPTokenizer.from_pretrained(
    sd3_model_dir,
    subfolder="tokenizer",
)
tokenizer_two = CLIPTokenizer.from_pretrained(
    sd3_model_dir,
    subfolder="tokenizer_2",
)
tokenizer_three = T5TokenizerFast.from_pretrained(
    sd3_model_dir,
    subfolder="tokenizer_3",
)

# Each entry pairs a tokenizer with the type tag get_candidate_tokens() uses
# to decide which word marker to strip.
tokenizers = [(tokenizer_one, 'clip'), (tokenizer_two, 'clip'), (tokenizer_three, 't5')]

# Retrieve tokens that are common across SD3 tokenizers
# Use the first GPU when one is available and fall back to the CPU otherwise,
# so the script also runs on machines without CUDA.
qa_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
qa_model = SentenceTransformer('multi-qa-mpnet-base-cos-v1', device=qa_device)
# Sorted so the candidate order (and therefore tie-breaking between equally
# similar tokens) is the same on every run instead of following set order.
candidate_tokens = sorted(get_candidate_tokens(tokenizers))
print("len(candidate_tokens): {}".format(len(candidate_tokens)))


# Construct the training prompt for each image in the dataset
image_dir = './proposed_dataset/images/'
chat_completions_url = "https://api.openai.com/v1/chat/completions"
request_timeout = 120  # seconds; without a timeout a stalled connection would block forever
num_candidates = 5  # number of alternative completions shown to the user per question


def _request_message(payload):
    """Post `payload` to the chat-completions endpoint and return the reply message.

    Returns:
        The `choices[0]['message']` dict of the JSON response.

    Raises:
        RuntimeError: if the response is not JSON or does not contain
            `choices[0]['message']['content']` (for example an API error
            object). The raw response text is printed first.
    """
    response = requests.post(chat_completions_url, headers=headers, json=payload, timeout=request_timeout)
    try:
        message = response.json()['choices'][0]['message']
        if 'content' not in message:
            raise KeyError('content')
    except (ValueError, KeyError, IndexError, TypeError) as err:
        # Show the raw reply so the API's own error message is visible.
        print(response.text)
        raise RuntimeError(
            "Unexpected response from the chat completions API (HTTP {})".format(response.status_code)
        ) from err
    return message


def _request_candidates(payload):
    """Request `num_candidates` completions for `payload`, print them numbered from 1, and return their messages."""
    messages = []
    for i in range(num_candidates):
        message = _request_message(payload)
        print("Output {}: {}".format(i + 1, message['content']))
        messages.append(message)
    return messages


def _choose_message(messages):
    """Ask the user for a 1-based index into `messages` until a valid one is given, and return that message."""
    while True:
        idx = input("Which one do you prefer the most?: ")
        try:
            choice = int(idx)
        except ValueError:
            choice = 0
        if 1 <= choice <= len(messages):
            print()
            return messages[choice - 1]
        print("Please enter a number between 1 and {}.".format(len(messages)))


def _strip_final_period(text):
    """Return `text` without surrounding whitespace and without one trailing period, if present."""
    text = text.strip()
    return text[:-1] if text.endswith('.') else text


# Only files whose name ends in .jpg / .jpeg are processed.
filenames = sorted(filename for filename in os.listdir(image_dir) if filename.endswith(('.jpg', '.jpeg')))

# Embeddings of the candidate tokens do not depend on the image, so they are
# computed once (on first use) and reused for every image.
token_embeddings = None

for filename in filenames:
    # Path to your image
    image_path = os.path.join(image_dir, filename)

    # Getting the base64 string
    base64_image = encode_image(image_path)

    print("<{}>".format(filename))
    # Show a preview to the user; the context manager closes the file afterwards.
    with Image.open(image_path) as pil_image:
        pil_image.resize((512, 512)).show()

    print("----------------- 1. Captioning ------------------")
    while True:
        message_text = 'Describe the image in one detailed sentence, including the phrase "{}." "<new1>" is a special token that already describes the {} in the image. Do *not* describe the {} in duplicate with "<new1>."'
        phrase = input("Write the phrase you want to include in a caption: ")
        meaning = input('What does "<new1>" describe?: ')
        message_text = message_text.format(phrase, meaning, meaning)
        print("Input text: {}".format(message_text))
        print()

        payload = {
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": message_text
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300
        }
        candidates = _request_candidates(payload)

        # Repeat until the user is satisfied with any of the responses
        is_satisfied = input("Are you satisfied with any of the captions? [y/n]:")
        if is_satisfied == 'y':
            chosen = _choose_message(candidates)
            # Lower-case the caption and drop its final period (only if there is one).
            caption = _strip_final_period(chosen['content'].lower())
            # Keep the chosen answer in the conversation for the follow-up questions.
            payload['messages'].append(chosen)
            # The token limit applies to the captioning request only.
            del payload['max_tokens']
            break
        print()

    print("--------- 2. Selecting initializer token ---------")
    message_text = 'Infer the {} contained in "<new1>" in one detailed noun phrase. Do *not* mention any elements other than the {}.'.format(meaning, meaning)
    print("Input text: {}".format(message_text))
    print()

    payload['messages'].append({
        "role": "user",
        "content": message_text
    })
    while True:
        candidates = _request_candidates(payload)

        # Repeat until the user is satisfied with any of the responses
        is_satisfied = input("Are you satisfied with any of the phrases? [y/n]:")
        if is_satisfied == 'y':
            chosen = _choose_message(candidates)
            noun = _strip_final_period(chosen['content'].lower())
            payload['messages'].append(chosen)
            break
        print()

    # Encode query and tokens
    query = "What is the word that has the meaning of {}?".format(noun)
    print("Query: {}".format(query))
    query_embedding = qa_model.encode(query)
    if token_embeddings is None:
        token_embeddings = qa_model.encode(candidate_tokens)

    # Calculate similarities between query and tokens, and select tokens with the highest similarities
    similarities = qa_model.similarity(query_embedding, token_embeddings)[0].tolist()
    pairs = sorted(zip(candidate_tokens, similarities), key=lambda x: x[1], reverse=True)
    print("Answer:")
    top_tokens = []
    for token, similarity in pairs[:10]:
        print("{} {:.3f}".format(token, similarity))
        top_tokens.append(token)
    print()

    # Input the selected tokens back into GPT-4o
    message_text = 'Considering your previous answer, choose the best token to replace "<new1>" from the following tokens. Output the token as is: "{}."'.format('", "'.join(top_tokens))
    print("Input text: {}".format(message_text))
    print()

    payload['messages'].append({
        "role": "user",
        "content": message_text
    })
    while True:
        initializer_token = _request_message(payload)['content']
        print("Output: {}".format(initializer_token))

        # Repeat until the user is satisfied with the response
        is_satisfied = input("Are you satisfied with the output? [y/n]:")
        print()
        if is_satisfied == 'y':
            break

    # Save inference prompts for validation to a txt file
    print("--------- 3. Creating an inference file ----------")
    repeated_token = ' '.join([initializer_token] * 4)
    initialized_caption = caption.replace('<new1>', repeated_token)

    inference_prompts = []
    print("Write inference prompts:")
    # NOTE(review): unlike the prompts below, an empty answer here is stored
    # as an empty line rather than falling back to the caption; confirm that
    # this is intended.
    inference_prompts.append(input("{}\n".format(caption)))
    # For the remaining prompts, pressing Enter keeps the suggestion shown.
    inference_prompts.append(input("{}\n".format(initialized_caption)) or initialized_caption)
    inference_prompts.append(input("{}\n".format(caption)) or caption)
    inference_prompts.append(input(repeated_token) or repeated_token)
    inference_prompts.append(input("<new1>") or '<new1>')
    print()
    print()

    # Same base name as the image, with the extension replaced by .txt.
    save_path = os.path.join('proposed_dataset', os.path.splitext(filename)[0] + '.txt')
    # Captions may contain non-ASCII characters, so the encoding is fixed
    # instead of depending on the platform default.
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("{}\n".format('\n'.join(inference_prompts)))
