import argparse
import torch
import os
import json
import pandas as pd
from tqdm import tqdm
import shortuuid
import math
from PIL import Image
from io import BytesIO
import base64
import time
from abc import ABC, abstractmethod
import google.generativeai as genai
from utils.dataset_load import load_dataset


# Option letters, in order. get_options() looks each one up as a key of a
# question row, and eval_model() uses them to label the choices in the prompt.
all_options = ['A', 'B', 'C', 'D', 'E']


def load_image_from_base64(image):
    """Decode a base64-encoded image payload and open it with PIL."""
    decoded_bytes = base64.b64decode(image)
    buffer = BytesIO(decoded_bytes)
    return Image.open(buffer)


def is_none(value):
    """Return True when *value* represents a missing entry.

    A value counts as missing when it is ``None``, when ``pd.isna`` reports
    it as missing (NaN, ``pd.NA``, ``NaT``), or when it is the string
    ``'nan'`` / ``'none'`` in any letter case.

    Array-like values with several elements have no single truth value;
    they are treated as present (``False``) instead of raising ValueError.
    """
    if value is None:
        return True
    if isinstance(value, str):
        # Missing cells that were stringified upstream show up as text.
        return value.lower() in ('nan', 'none')
    try:
        # pd.isna returns an array for array-like input; bool() of a
        # multi-element (or empty) array raises ValueError.
        return bool(pd.isna(value))
    except ValueError:
        return False


def get_options(row, options):
    """Collect the leading run of option values present in *row*.

    Keys are looked up in the order given by *options*; collection stops at
    the first key that is absent from *row* or whose value is_none() reports
    as missing. Returns the values gathered up to that point.
    """
    collected = []
    for label in options:
        try:
            candidate = row[label]
        except KeyError:
            # No such key: the option list ends here.
            return collected
        if is_none(candidate):
            return collected
        collected.append(candidate)
    return collected


def get_pil_image(raw_image_data) -> Image.Image:
    """Return *raw_image_data* as a PIL image.

    Accepted inputs:
      * a ``PIL.Image.Image`` instance, returned unchanged;
      * a dict with a ``"bytes"`` key holding the encoded image bytes;
      * a ``str``, assumed to be a base64-encoded image.

    Raises:
        ValueError: if the input is none of the above.
    """
    if isinstance(raw_image_data, Image.Image):
        return raw_image_data

    # This module imports ``BytesIO`` directly (``from io import BytesIO``);
    # the name ``io`` itself is not bound, so ``io.BytesIO`` would raise
    # NameError here.
    if isinstance(raw_image_data, dict) and "bytes" in raw_image_data:
        return Image.open(BytesIO(raw_image_data["bytes"]))

    if isinstance(raw_image_data, str):  # Assuming this is a base64 encoded string
        image_bytes = base64.b64decode(raw_image_data)
        return Image.open(BytesIO(image_bytes))

    raise ValueError("Unsupported image data format")


class BaseModel(ABC):
    """Abstract base for model wrappers.

    Records the model name and the maximum batch size; concrete subclasses
    must implement ``generate`` and ``eval_forward``.
    """

    def __init__(self, model_name: str, *, max_batch_size: int = 1):
        self.max_batch_size = max_batch_size
        self.name = model_name

    @abstractmethod
    def generate(self, **kwargs):
        """Subclass hook; the base implementation does nothing."""
        return None

    @abstractmethod
    def eval_forward(self, **kwargs):
        """Subclass hook; the base implementation does nothing."""
        return None


class GeminiProVision(BaseModel):
    """Gemini vision model accessed through the ``google.generativeai`` SDK."""

    def __init__(self, api_key: str, model_name="gemini-pro-vision"):
        """Configure the SDK with *api_key* and remember *model_name*."""
        super().__init__(model_name)
        self.api_key = api_key
        genai.configure(api_key=self.api_key)
        # ``endpoint`` and ``headers`` are not read by any method of this
        # class (generate() goes through the SDK); they are kept because they
        # are public attributes.
        self.endpoint = "https://asia-northeast1-aiplatform.googleapis.com/v1/projects/UPD/locations/asia-northeast1/publishers/google/models/gemini-pro-vision:streamGenerateContent"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
        # Every listed harm category gets the BLOCK_NONE threshold.
        self.safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in (
                "HARM_CATEGORY_DANGEROUS",
                "HARM_CATEGORY_HARASSMENT",
                "HARM_CATEGORY_HATE_SPEECH",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "HARM_CATEGORY_DANGEROUS_CONTENT",
            )
        ]

    def generate(self, text_prompt: str, raw_image_data, *,
                 max_retries: int = 10, retry_delay: float = 10.0) -> str:
        """Ask the model to answer *text_prompt* about one image.

        Args:
            text_prompt: the text part of the request.
            raw_image_data: any input accepted by ``get_pil_image``; it is
                converted to RGB before being sent.
            max_retries: maximum number of requests to issue.
            retry_delay: seconds to wait between two attempts.

        Returns:
            The response text, or ``"I cannot answer."`` when no attempt
            produced readable text.
        """
        image = get_pil_image(raw_image_data).convert("RGB")
        # Use the name given to the constructor rather than a fixed literal,
        # so a non-default ``model_name`` actually takes effect.
        model = genai.GenerativeModel(model_name=self.name)

        for attempt in range(1, max_retries + 1):
            response = model.generate_content(
                [text_prompt, image],
                safety_settings=self.safety_settings,
                generation_config={"max_output_tokens": 256, "temperature": 0.0},
            )
            try:
                # Accessing ``.text`` raises ValueError when the response has
                # no usable text (presumably e.g. a blocked candidate --
                # confirm against the SDK documentation).
                return response.text
            except ValueError:
                print("Failed: Retrying...")
                print(response._result)
            # No point in waiting once the attempts are exhausted.
            if attempt < max_retries:
                time.sleep(retry_delay)
        return "I cannot answer."

    def eval_forward(self, **kwargs):
        """Delegate to the base class implementation."""
        return super().eval_forward(**kwargs)


def read_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    rather than causing a ``JSONDecodeError``. The file is read as UTF-8.
    """
    with open(file_path, 'r', encoding='utf-8') as json_file:
        return [json.loads(line) for line in json_file if line.strip()]


# Text appended to the question for each --prompt_id when
# --single-pred-prompt is set; ids not listed here append nothing.
_PROMPT_SUFFIXES = {
    0: '\n',
    1: '\n' + "Answer with the option's letter from the given choices directly.",
    2: '\n' + "If all the options are incorrect, answer \"F. None of the above\".",
    3: '\n' + "If the given image is irrelevant to the question, answer \"F. The image and question are irrelevant.\".",
}


def _build_question(row, options):
    """Return the hint (when present), the question and the lettered options as one string."""
    question = row['question']
    hint = row['hint']
    if not is_none(hint):
        question = hint + '\n' + question
    # get_options() yields at most len(all_options) values, so zip() pairs
    # every option with its letter.
    for option_char, option in zip(all_options, options):
        question = f"{question}\n{option_char}. {option}"
    return question


def eval_model(args):
    """Query the Gemini model for every dataset question and write JSONL answers.

    Reads from *args*: ``data_name``, ``answers_file``, ``gemini_api_key``,
    ``all_rounds``, ``single_pred_prompt`` and ``prompt_id``. One JSON object
    per (question, round) is written to ``answers_file`` and flushed
    immediately, so partial results survive an interruption.
    """
    questions = load_dataset(dataset_name=args.data_name)
    answers_file = os.path.expanduser(args.answers_file)
    # A bare file name such as the default "answer.jsonl" has an empty
    # dirname, and os.makedirs("") raises FileNotFoundError.
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)
    model = GeminiProVision(args.gemini_api_key)

    # The context manager closes the file even if a request raises mid-run.
    with open(answers_file, "w") as ans_file:
        for _, row in tqdm(questions.iterrows(), total=len(questions)):
            options = get_options(row, all_options)
            cur_option_char = all_options[:len(options)]
            num_rounds = len(options) if args.all_rounds else 1
            if num_rounds == 0:
                # --all-rounds with no options: nothing to ask for this row.
                continue

            # Everything below is identical for each round of a row, so it
            # is computed once (the image in particular is decoded once).
            idx = row['index']
            eval_type = row['type']
            image = load_image_from_base64(row['image'])
            cur_prompt = _build_question(row, options)
            prompt = cur_prompt
            if args.single_pred_prompt:
                prompt = prompt + _PROMPT_SUFFIXES.get(args.prompt_id, '')

            # NOTE(review): every round sends the same prompt; options are
            # not rotated between rounds. Confirm whether rotation was
            # intended -- if so, the prompt must be rebuilt per round.
            for round_idx in range(num_rounds):
                outputs = model.generate(prompt, image).strip()

                ans_file.write(json.dumps({"question_id": idx,
                                           "eval_type": eval_type,
                                           "round_id": round_idx,
                                           "prompt": cur_prompt,
                                           "text": outputs,
                                           "options": options,
                                           "option_char": cur_option_char,
                                           "answer_id": shortuuid.uuid(),
                                           "model_id": model.name,
                                           "prompt_detail": prompt,
                                           "metadata": {}}) + "\n")
                ans_file.flush()


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser(
        description="Run Gemini Pro Vision on a multiple-choice image dataset "
                    "and write the answers as JSON Lines.")
    parser.add_argument("--data-name", type=str, default="mmaad_aad_base",
                        help="dataset name passed to load_dataset()")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl",
                        help="path of the JSON Lines file to write")
    # The two chunk options are accepted but not read by eval_model().
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--gemini-api-key", type=str, default='',
                        help="API key used to configure google.generativeai")
    parser.add_argument("--all-rounds", action="store_true",
                        help="query each question once per available option")
    parser.add_argument("--single-pred-prompt", action="store_true",
                        help="append the instruction selected by --prompt-id")
    # "--prompt-id" matches the hyphenated style of the other flags; the
    # original "--prompt_id" spelling is kept as an alias.
    parser.add_argument("--prompt-id", "--prompt_id", dest="prompt_id",
                        default=0, type=int,
                        help="instruction variant to append (0-3)")
    args = parser.parse_args()

    eval_model(args)
