
import os
import os
import glob
import json
import time
import logging
import warnings
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from dotenv import load_dotenv
import google.generativeai as genai
from google.ai.generativelanguage_v1beta.types import content

# Pull variables from a local .env file into the process environment before
# anything below reads them with os.getenv.
load_dotenv()

# Every launch writes into its own directory, named after the launch time
# (day.month.year-hour:minute:second). os.mkdir raises if it already exists.
_launch_stamp = datetime.now().strftime('%d.%m.%Y-%H:%M:%S')
WORKING_DIRECTORY_BASE_NAME = "ours_gemini_penguin_sepsites_" + _launch_stamp
os.mkdir(WORKING_DIRECTORY_BASE_NAME)


def upload_to_gemini(path, mime_type=None):
    """Upload a local file to Gemini and return the resulting file handle.

    Args:
        path: Location of the file to upload.
        mime_type: Optional MIME type forwarded to ``genai.upload_file``.

    Returns:
        The object returned by ``genai.upload_file``; its ``display_name``
        and ``uri`` are printed for traceability.

    See https://ai.google.dev/gemini-api/docs/prompting_with_media
    """
    uploaded = genai.upload_file(path, mime_type=mime_type)
    print("Uploaded file '{}' as: {}".format(uploaded.display_name, uploaded.uri))
    return uploaded


# Input table; the processing loop further down reads one image per row.
df = pd.read_csv("path_to_results.csv")

# Configure the Gemini client with the API key taken from the environment.
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

# Structured output: the reply must be a JSON object whose only (required)
# property is an integer named "count".
_count_schema = content.Schema(type=content.Type.INTEGER)
_response_schema = content.Schema(
    type=content.Type.OBJECT,
    enum=[],
    required=["count"],
    properties={"count": _count_schema},
)

# Sampling parameters plus the JSON response contract passed to the model.
generation_config = dict(
    temperature=1,
    top_p=0.95,
    top_k=64,
    max_output_tokens=8192,
    response_schema=_response_schema,
    response_mime_type="application/json",
)

# Create the model
model_name = "gemini-1.5-pro"
model = genai.GenerativeModel(model_name=model_name, generation_config=generation_config)


# Repeat the whole counting experiment three times. Run ``i`` logs to
# ``log{i}.log`` and writes ``output{i}.csv`` inside WORKING_DIRECTORY_BASE_NAME.
for i in range(3):
    logger = logging.getLogger(f'logger{i}')
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s %(name)s %(msecs)d %(levelname)s fn:%(funcName)s -- %(message)s")
    # File handler: this run's log file.
    fh = logging.FileHandler(f"{WORKING_DIRECTORY_BASE_NAME}/log{i}.log", encoding="utf-8")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    # Stream handler: mirror the same records to the console.
    sh = logging.StreamHandler()
    sh.setLevel(logging.DEBUG)
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    # Disable every other registered logger (including the loggers created by
    # earlier iterations of this loop) so only this run's logger emits records.
    for log_name, log_obj in logging.Logger.manager.loggerDict.items():
        if log_name != f"logger{i}":
            log_obj.disabled = True  # type: ignore
    warnings.filterwarnings("ignore")

    logger.info("model_name: %s", model_name)
    logger.info("generation_config %s", str(generation_config))

    out_df = list()
    obect_of_interest = "penguin"
    # The prompt is identical for every sub-image, so build it once per run.
    # NOTE(review): "visibile" is a typo, deliberately left untouched so the
    # text sent to the model stays identical to previous runs.
    prompt = f"How many {obect_of_interest}s are visibile in the image?"

    for index, row in tqdm(df.iterrows(), total=len(df)):
        image_name = row['image_name']
        folder_name = f"{index}.{image_name}"
        # Escape glob metacharacters so an image name containing *, ? or [
        # is matched literally instead of being treated as a pattern.
        subimage_pattern = f"results_folder/{glob.escape(folder_name)}/**/**_subimages/*.png"

        total_count = 0
        for image_path in sorted(glob.glob(subimage_pattern, recursive=True)):
            logger.info("LLM input text prompt %s", prompt)
            logger.info("LLM input image prompt %s", image_path)
            # Retry until a usable answer arrives; every failure is logged
            # and followed by a 5 second pause. There is no retry cap.
            while True:
                try:
                    uploaded_image = upload_to_gemini(image_path, mime_type="image/png")

                    chat_session = model.start_chat(history=[])
                    reply = chat_session.send_message([uploaded_image, prompt])
                    response = json.loads(reply.text)
                    # Extract the count inside the try block: a reply that
                    # parses as JSON but lacks a usable "count" is retried
                    # instead of aborting the whole run.
                    count = response['count']
                    break
                except Exception:
                    logger.exception("An exception Occurred")
                    time.sleep(5)

            total_count += count
            logger.info("LLM output %s", response)

        logger.info("total count for %s was %d", image_name, total_count)
        out_df.append((image_name, row['number_vector'], row['max'], row['weight'], total_count))
        # Rewrite the CSV after every row so the file on disk always holds
        # all rows processed so far in this run.
        pd.DataFrame(out_df, columns=["image_name", "number_vector", "max", "weight", "llm_count"]).to_csv(f"{WORKING_DIRECTORY_BASE_NAME}/output{i}.csv", index=False)
