import csv
import json
import os
import random
from typing import List, Tuple

from evaluation.dataset_parser import DatasetParser

# Seed the global ``random`` module at import time so that module-level
# calls such as ``random.shuffle`` yield a reproducible sequence. Note this
# reseeds process-wide state for every importer of this module.
random.seed(a=89)


class RealDQAParser(DatasetParser):
    """Parse a RealDQA-style CSV into multiple-choice question records.

    The first CSV row is a template row: for each question slot (``Q1a`` ...
    ``Q5``) it holds question text with a ``[]`` placeholder. Each following
    row describes one diagram (identified by its ``image_url`` stem) and
    supplies the value substituted into that placeholder, plus one correct
    and three distractor answers per slot.

    Answer choices are shuffled once per question, and the permutation is
    persisted to ``<dataset filename>_answer_order.json`` next to the
    dataset, so later runs present the choices in the same order.
    """

    # CSV column prefix of each question slot, paired positionally with the
    # descriptive suffix that is reported as ``question_component``.
    _BASE_KEYS = ("Q1a", "Q1b", "Q2a", "Q2b",
                  "Q3a", "Q3b", "Q4a", "Q4b", "Q5")
    _DESCRIPTIVE_KEYS = ("entity_KF_NC", "entity_KF_NR",
                         "relation_KF_NC", "relation_KF_NR",
                         "entity_KR_NC", "entity_KR_NR",
                         "relation_KR_NC", "relation_KR_NR", "whole_KR")
    # In "foodweb" datasets only these slots use the "[]" template; every
    # other slot's cell already contains the complete question text.
    _FOODWEB_TEMPLATED_KEYS = frozenset({"Q2a", "Q4a"})

    def load_data(self, sample_file: str):
        """Prepare the parser to iterate over ``self.dataset_file``.

        NOTE(review): ``self.dataset_file`` (and ``self.image_folder``, read
        while iterating) are assumed to be set by ``DatasetParser`` - confirm.

        Args:
            sample_file: Optional path to a file listing one diagram id per
                line. When falsy, every diagram in the dataset is used.
        """
        self.sample_file = sample_file
        self.sample_ids = self.sample_questions() if sample_file else None
        self.json_file = self.generate_json_filename(self.dataset_file)
        self.answer_order_dict = self.load_json_dictionary()
        self.data_iterator = self.create_iterator(self.dataset_file)

    def generate_json_filename(self, dataset_file: str) -> str:
        """Return the answer-order JSON path that sits next to ``dataset_file``.

        E.g. ``data/set.csv`` -> ``data/set.csv_answer_order.json``.
        """
        dataset_path, dataset_filename = os.path.split(dataset_file)
        return os.path.join(dataset_path, f"{dataset_filename}_answer_order.json")

    def load_json_dictionary(self):
        """Return the persisted answer orders, or ``{}`` if none exist yet."""
        if not os.path.exists(self.json_file):
            return {}
        with open(self.json_file, 'r', encoding='utf-8') as file:
            return json.load(file)

    def save_json_dictionary(self):
        """Persist ``self.answer_order_dict`` to ``self.json_file``.

        The data is first written to a temporary sibling file and then moved
        into place. An interruption mid-write therefore cannot leave a
        truncated JSON file that would make ``load_json_dictionary`` fail on
        the next run.
        """
        tmp_file = f"{self.json_file}.tmp"
        with open(tmp_file, 'w', encoding='utf-8') as file:
            json.dump(self.answer_order_dict, file, indent=4)
        os.replace(tmp_file, self.json_file)

    def create_iterator(self, dataset_file):
        """Yield one question record per (diagram, question slot) pair.

        Args:
            dataset_file: Path to the UTF-8 CSV dataset.

        Yields:
            dict with keys ``diagram_id`` (``"<image stem>_<slot>"``),
            ``question``, ``answer_choices`` (shuffled),
            ``correct_answer_text``, ``image_path`` and
            ``question_component``.

        A shuffle generated for a question not already present in
        ``self.answer_order_dict`` is saved to ``self.json_file`` before the
        record is yielded.
        """
        slots = list(zip(self._BASE_KEYS, self._DESCRIPTIVE_KEYS))
        is_foodweb = "foodweb" in dataset_file
        # Build the lookup set once so each per-row membership test is O(1).
        wanted_ids = set(self.sample_ids) if self.sample_file else None

        # newline='' is required by the csv module for correct handling of
        # line breaks inside quoted fields.
        with open(dataset_file, 'r', encoding='utf-8', newline='') as file:
            reader = csv.DictReader(file)

            # The first row holds the question templates. Consume it
            # unconditionally: filtering it by sample id would otherwise make
            # the first *sampled* data row act as the template (and drop it).
            template_row = next(reader, None)
            if template_row is None:
                return
            question_base = {
                q_base_key: template_row[f"{q_base_key}_{descr}"]
                for q_base_key, descr in slots
            }

            for row in reader:
                diagram_id = row["image_url"].split(".")[0]
                if wanted_ids is not None and diagram_id not in wanted_ids:
                    continue
                image_path = os.path.join(self.image_folder, row["image_url"])

                for q_base_key, descr in slots:
                    question_variable = row[f"{q_base_key}_{descr}"]
                    if is_foodweb and q_base_key not in self._FOODWEB_TEMPLATED_KEYS:
                        question = question_variable
                    else:
                        # NOTE(review): an empty cell still yields the
                        # template with "[]" removed, which is non-empty and
                        # therefore not skipped below - confirm intended.
                        question = question_base[q_base_key].replace(
                            "[]", question_variable)
                    if not question:
                        continue

                    correct_answer_text = row[f"{q_base_key}_correct_anw"]
                    answer_choices = [
                        correct_answer_text,
                        row[f"{q_base_key}_fake_anw1"],
                        row[f"{q_base_key}_fake_anw2"],
                        row[f"{q_base_key}_fake_anw3"],
                    ]

                    q_id = f"{diagram_id}_{q_base_key}"
                    if q_id in self.answer_order_dict:
                        order = self.answer_order_dict[q_id]
                    else:
                        order = list(range(len(answer_choices)))
                        random.shuffle(order)
                        self.answer_order_dict[q_id] = order
                        # Save right away so the order survives even if the
                        # consumer stops iterating early.
                        self.save_json_dictionary()
                    answer_choices = [answer_choices[j] for j in order]

                    yield {
                        "diagram_id": q_id,
                        "question": question,
                        "answer_choices": answer_choices,
                        "correct_answer_text": correct_answer_text,
                        "image_path": image_path,
                        "question_component": descr
                    }

    def next(self):
        """Return the next record; raise ``StopIteration`` when exhausted."""
        return next(self.data_iterator)

    def sample_questions(self):
        """Return the diagram ids listed in ``self.sample_file``.

        The file holds one id per line. Surrounding whitespace is stripped
        and blank lines are skipped, so a trailing newline does not add an
        empty id.
        """
        with open(self.sample_file, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file if line.strip()]

    @staticmethod
    def format_qa(question: str, answer_choices: List[str], answer: str) -> Tuple[str, str]:
        """Render a question with lettered choices and locate the answer.

        Returns:
            ``(text, letter)``. ``text`` is ``question`` followed by one
            ``"A) choice"`` line per choice. ``letter`` is the label of the
            first choice equal to ``answer``.

        Raises:
            ValueError: if ``answer`` is not one of ``answer_choices``.
        """
        lettered_choices = "\n".join(
            f"{chr(65 + i)}) {choice}" for i, choice in enumerate(answer_choices))
        question_with_answer_choices = f"{question}\n{lettered_choices}"
        answer_letter = chr(65 + answer_choices.index(answer))
        return question_with_answer_choices, answer_letter
