import json
from typing import List, Dict, Any
from evaluation.icon_dataset_with_qa_parser import IconDQAParser
from evaluation.real_diagram_dataset_annotation_parser import RealDQAParser
from models.gpt_4 import GPT4Model
from models.gemini import GeminiModel
import os
import sys
from constants import DATASETS_INFO
import time

# Seed forwarded to GPT4Model(seed=...) below; GeminiModel is built without it.
# NOTE(review): the trailing "63" looks like a previously used seed value -- confirm.
SEED = 1997  # 63


class DQAGenerate():
    """Ask a vision-language model every question of a diagram-QA dataset.

    The instance wires together a model wrapper (``GPT4Model`` when the model
    name contains ``"gpt"``, ``GeminiModel`` otherwise), a dataset parser
    chosen from the dataset name, and a JSON log file that is rewritten after
    every answered question.
    """

    # Diagram id at which ``get_reply`` starts answering; every question the
    # parser yields before it is skipped. Set to ``None`` (on the class or on
    # an instance) to answer from the first question.
    # NOTE(review): this default is the value that used to be hard-coded in
    # the loop, presumably to resume an interrupted run -- confirm it is still
    # wanted, because a dataset that never yields this id produces no answers.
    resume_from_id = "fw_2047_Q2b"

    # ``image_type`` values ``get_reply`` understands when few-shot is off.
    _IMAGE_TYPES = ("original_image", "no_image")

    def __init__(self, dataset_name: str, model_name: str, sample_path: str, log_path: str):
        """Build the model wrapper, look up dataset settings and open the parser.

        Args:
            dataset_name: Key into ``DATASETS_INFO``; also selects the parser
                (must contain ``"icon_dqa"`` or ``"real_dqa"``).
            model_name: Model identifier; names containing ``"gpt"`` use
                ``GPT4Model`` (seeded with ``SEED``), all others ``GeminiModel``.
            sample_path: Passed through to the dataset parser.
            log_path: File the JSON answer log is written to.
        """
        if "gpt" in model_name:
            self.model = GPT4Model(seed=SEED, model_name=model_name)
        else:
            self.model = GeminiModel(model_name=model_name)
        self.dataset_name = dataset_name
        self.sample_path = sample_path
        self.dataset_info = self.get_dataset_info(dataset_name)

        self.dataset = self.get_dataset_parser(sample_path)
        self.log = []  # Log list to store question evaluations

        self.log_file_name = log_path

    def get_dataset_info(self, dataset_key):
        """Return the ``DATASETS_INFO`` entry for ``dataset_key`` or ``None``."""
        return DATASETS_INFO.get(dataset_key)

    def get_dataset_parser(self, sample_path):
        """Return the parser matching ``self.dataset_name``.

        Returns ``None`` when the dataset name contains neither ``"icon_dqa"``
        nor ``"real_dqa"``.
        """
        if "icon_dqa" in self.dataset_name:
            return IconDQAParser(self.dataset_info["question_file"], sample_path, self.dataset_info["image_folder"])
        if "real_dqa" in self.dataset_name:
            return RealDQAParser(self.dataset_info["question_file"], sample_path, self.dataset_info["image_folder"])
        return None

    async def get_reply(self, image_type, use_cot, use_fewshot) -> None:
        """Query the model for each remaining question and log every answer.

        Questions are pulled from ``self.dataset.next()`` until it raises
        ``StopIteration``. While ``resume_from_id`` is not ``None``, questions
        are skipped until one with that ``diagram_id`` is seen; that question
        and all later ones are answered.

        Args:
            image_type: ``"original_image"`` to send the question's image, or
                ``"no_image"`` to send text only. Ignored when ``use_fewshot``
                is true.
            use_cot: Prepend a step-by-step reasoning instruction.
            use_fewshot: Prepend the few-shot instruction and send the
                dataset's few-shot example image along with the question image.

        Raises:
            ValueError: If ``use_fewshot`` is false and ``image_type`` is not
                one of the supported values.
        """
        system_message = self.dataset_info["system_message"]

        # Fail fast: previously an unknown image_type surfaced as an
        # UnboundLocalError on ``model_answer`` inside the loop.
        if not use_fewshot and image_type not in self._IMAGE_TYPES:
            raise ValueError(
                f"Unsupported image_type {image_type!r}; expected one of {self._IMAGE_TYPES}")

        resume_id = self.resume_from_id
        skipping = resume_id is not None

        while True:
            # Only the parser's own exhaustion signal ends the loop; a
            # StopIteration escaping from the model or formatter is a bug and
            # must not silently truncate the run.
            try:
                question_obj = self.dataset.next()
            except StopIteration:
                break  # Dataset is exhausted

            if skipping:
                if question_obj["diagram_id"] != resume_id:
                    continue
                skipping = False  # the marker question itself is answered

            # The prefixes are written back into question_obj, so the logged
            # question text includes them. CoT is applied first, which leaves
            # the few-shot instruction outermost.
            if use_cot:
                question_obj["question"] = "Think step by step before answering the question and show your reasoning. " + \
                    question_obj["question"]
            if use_fewshot:
                question_obj["question"] = "Following the in-context examples in the first image, answer the following question about the second image. " + \
                    question_obj["question"]

            full_question_text, answer_letter = self.dataset.format_qa(
                question_obj["question"], question_obj["answer_choices"], question_obj["correct_answer_text"])

            if use_fewshot:
                fewshot_image_path = self.dataset_info["fewshot_example_path"]
                model_answer = await self.model.answer_question_fewshot(
                    full_question_text, question_obj["image_path"], system_message, fewshot_image_path)
            elif image_type == "original_image":
                model_answer = await self.model.answer_question(
                    full_question_text, question_obj["image_path"], system_message)
            else:  # "no_image" -- guaranteed by the validation above
                model_answer = await self.model.answer_question_without_image(
                    full_question_text, system_message)

            self.log_answer(question_obj, answer_letter,
                            model_answer, system_message)

    def log_answer(self, question_obj, answer_letter, model_answer, system_message):
        """Assemble the log record for one answered question and persist it.

        Icon datasets additionally record whether the question type is
        ``"count"``; real-diagram datasets record the question component.
        """
        log_element = {
            'q_id': question_obj["diagram_id"],
            'image_path': question_obj["image_path"],
            'question': question_obj["question"],
            'answer_choices': question_obj["answer_choices"],
            'correct_answer_letter': answer_letter,
            'correct_answer': question_obj["correct_answer_text"],
            'model_answer': model_answer,
            'system_message': system_message
        }

        if "icon_dqa" in self.dataset_name:
            log_element['count_question'] = question_obj["question_type"] == "count"
        elif "real_dqa" in self.dataset_name:
            log_element['question_component'] = question_obj["question_component"]
        self._log_and_save(self.log_file_name, log_element)

    def _log_and_save(self, logfile, log_entry: Dict[str, Any]) -> None:
        """Append ``log_entry`` to the in-memory log and rewrite ``logfile``.

        The whole log is dumped after each question so an interrupted run
        still leaves a complete, valid JSON file behind.
        """
        self.log.append(log_entry)

        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError -- only create the directory when there is one.
        directory = os.path.dirname(logfile)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # json.dump escapes non-ASCII by default, so pinning the encoding does
        # not change the bytes written; it only removes the locale dependency.
        with open(logfile, 'w', encoding='utf-8') as f:
            json.dump(self.log, f, indent=4)


async def main(dataset_name, model_name, sample_file, log_path, image_type, use_cot, use_fewshot):
    """Create a ``DQAGenerate`` for the given dataset/model and run it once.

    ``sample_file`` and ``log_path`` are handed to the generator's
    constructor; ``image_type``, ``use_cot`` and ``use_fewshot`` are forwarded
    unchanged to ``get_reply``.
    """
    generator = DQAGenerate(
        dataset_name=dataset_name,
        model_name=model_name,
        sample_path=sample_file,
        log_path=log_path,
    )
    await generator.get_reply(
        image_type=image_type,
        use_cot=use_cot,
        use_fewshot=use_fewshot,
    )
