
import glob
import json
import logging
import mimetypes
import os
import time
import warnings
from datetime import datetime

import google.generativeai as genai
import pandas as pd
from google.ai.generativelanguage_v1beta.types import content
from tqdm import tqdm


# Output directory for this invocation. The suffix is the local launch time
# formatted as day.month.year-hour:minute:second.
WORKING_DIRECTORY_BASE_NAME = "base_gemini_fsc147_{}".format(
    datetime.now().strftime('%d.%m.%Y-%H:%M:%S')
)
# os.mkdir raises FileExistsError when the directory is already present.
os.mkdir(WORKING_DIRECTORY_BASE_NAME)


def upload_to_gemini(path, mime_type=None):
    """Send a local file to Gemini and hand back the uploaded-file object.

    The upload is delegated to ``genai.upload_file``; the display name and
    URI of the returned object are printed to stdout.

    See https://ai.google.dev/gemini-api/docs/prompting_with_media

    Args:
        path: Location of the file to upload.
        mime_type: Optional MIME type forwarded to ``genai.upload_file``.

    Returns:
        Whatever ``genai.upload_file`` returns for the uploaded file.
    """
    uploaded = genai.upload_file(path, mime_type=mime_type)
    print("Uploaded file '{}' as: {}".format(uploaded.display_name, uploaded.uri))
    return uploaded


# Table that the evaluation loop walks row by row.
df = pd.read_csv("path_to_results.csv")

# Gemini client setup: the API key is read from the GEMINI_API_KEY
# environment variable.
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

# Structured output: the reply must be a JSON object carrying one required
# integer field named "count".
_count_field_schema = content.Schema(type=content.Type.INTEGER)
_reply_schema = content.Schema(
    type=content.Type.OBJECT,
    enum=[],
    required=["count"],
    properties={"count": _count_field_schema},
)

# Key order is kept stable because str(generation_config) is written to the
# run logs.
generation_config = dict(
    temperature=1,
    top_p=0.95,
    top_k=64,
    max_output_tokens=8192,
    response_schema=_reply_schema,
    response_mime_type="application/json",
)

model_name = "gemini-1.5-pro"
model = genai.GenerativeModel(model_name=model_name, generation_config=generation_config)


# Run the complete evaluation three times. Run ``i`` logs to log{i}.log and
# writes its results to output{i}.csv inside WORKING_DIRECTORY_BASE_NAME.
for i in range(3):
    logger = logging.getLogger(f'logger{i}')
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s %(name)s %(msecs)d %(levelname)s fn:%(funcName)s -- %(message)s")
    # Every record goes both to the per-run log file and to the console.
    fh = logging.FileHandler(f"{WORKING_DIRECTORY_BASE_NAME}/log{i}.log", encoding="utf-8")
    sh = logging.StreamHandler()
    for handler in (fh, sh):
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # Disable every other registered logger (this includes the loggers of
    # earlier runs) and suppress all warnings, so only this run's logger
    # produces output.
    for log_name, log_obj in logging.Logger.manager.loggerDict.items():
        if log_name != f"logger{i}":
            log_obj.disabled = True  # type: ignore
    warnings.filterwarnings("ignore")

    logger.info("model_name: %s", model_name)
    logger.info("generation_config %s", str(generation_config))
    out_df = list()
    for _, row in tqdm(df.iterrows(), total=len(df)):
        image_name = row['image']
        image_path = f"data/fsc-147/images_384_VarV2/{image_name}"

        # NOTE(review): "visibile" is a typo for "visible". It is left
        # untouched here because the string is the model input and changing
        # it would alter what is sent to the model -- confirm before fixing.
        prompt = f"How many {row['object_of_interest']} are visibile in the image?"
        logger.info("LLM input text prompt %s", prompt)
        logger.info("LLM input image prompt %s", image_path)

        # A missing image can never succeed, so fail fast instead of letting
        # the retry loop below spin forever on it.
        if not os.path.isfile(image_path):
            logger.error("Image file not found: %s", image_path)
            raise FileNotFoundError(
                f"Image for row id {row['id']} not found: {image_path}"
            )

        # Declare the MIME type that matches the file extension; fall back to
        # the previously hard-coded PNG type when the extension is unknown.
        mime_type = mimetypes.guess_type(image_path)[0] or "image/png"

        # Retry until the upload, the request and the JSON parsing of the
        # reply all succeed; any exception is logged and retried after 5 s.
        while True:
            try:
                uploaded_image = upload_to_gemini(image_path, mime_type=mime_type)
                chat_session = model.start_chat(history=[])
                reply = chat_session.send_message([uploaded_image, prompt])
                response = json.loads(reply.text)
                break
            except Exception:
                logger.exception("An exception Occurred")
                time.sleep(5)
        logger.info("LLM output %s", str(response))

        total_count = response["count"]
        logger.info("total count for %s was %d", image_name, total_count)
        out_df.append((row['id'], row['image'], row['object_of_interest'], row['answer'], total_count))
        # The CSV is rewritten in full after every row (presumably so that
        # partial results survive an interrupted run).
        pd.DataFrame(out_df, columns=["id", "image", "object_of_interest", "answer", "llm_count"]).to_csv(f"{WORKING_DIRECTORY_BASE_NAME}/output{i}.csv", index=False)
