import os
import uuid
import json
import shutil
import pandas as pd
from typing import Any, List, Dict, Optional, Iterator


class DatasetManager:
    """Manage an on-disk dataset made of a CSV index and an image folder.

    Layout under ``dataset_dir``:

    * ``images/``      -- one ``<uuid>.png`` copy per entry
    * ``dataset.csv``  -- one row per entry with the columns in ``COLUMNS``

    The ``Answers`` and ``InputVariables`` columns hold JSON-encoded text.
    Every read re-parses the CSV and every ``add_entry`` rewrites it in full.
    """

    # Single source of truth for the CSV schema, shared by initialisation,
    # writes and the empty-dataset fallback.
    COLUMNS = [
        "ID",
        "SceneName",
        "Prompt",
        "Answers",
        "Correct",
        "ImagePath",
        "InputVariables",
    ]

    # Columns whose cells are JSON text and get decoded on read.
    _JSON_COLUMNS = ("Answers", "InputVariables")

    def __init__(self, dataset_dir: str):
        """Create the image folder and an empty CSV index if they are missing.

        Args:
            dataset_dir: Root directory of the dataset.
        """
        self.dataset_dir = dataset_dir
        self.images_dir = os.path.join(dataset_dir, "images")
        self.csv_path = os.path.join(dataset_dir, "dataset.csv")

        os.makedirs(self.images_dir, exist_ok=True)

        if not os.path.exists(self.csv_path):
            self._initialize_csv()

    def _initialize_csv(self):
        """Write a header-only CSV containing every column in ``COLUMNS``."""
        pd.DataFrame(columns=self.COLUMNS).to_csv(self.csv_path, index=False)

    def add_entry(
        self,
        scene_name: str,
        prompt: str,
        correct_answer: Any,
        temp_image_path: str,
        answer_variants: Optional[List[Any]] = None,
        input_variables: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Copy an image into the dataset and append its row to the CSV.

        Args:
            scene_name: Value stored in the ``SceneName`` column.
            prompt: Value stored in the ``Prompt`` column.
            correct_answer: Stored as ``str(correct_answer)`` in ``Correct``.
            temp_image_path: Path of the image file to copy.
            answer_variants: Accepted answers, stored as JSON; defaults to
                ``[correct_answer]``.
            input_variables: Extra variables, stored as JSON; defaults to an
                empty dict.

        Returns:
            The generated UUID string identifying the new entry.

        Raises:
            TypeError: If the answers or variables are not JSON-serialisable.
            OSError: If the image cannot be copied.
        """
        if answer_variants is None:
            answer_variants = [correct_answer]
        if input_variables is None:
            input_variables = {}

        # Serialise before touching the filesystem so that a failure here
        # does not leave an orphaned image behind.
        answers_json = json.dumps(answer_variants)
        variables_json = json.dumps(input_variables)

        instance_id = str(uuid.uuid4())
        image_filename = f"{instance_id}.png"
        shutil.copy(temp_image_path, os.path.join(self.images_dir, image_filename))

        new_row = {
            "ID": instance_id,
            "SceneName": scene_name,
            "Prompt": prompt,
            "Answers": answers_json,
            "Correct": str(correct_answer),
            # Stored relative to dataset_dir; readers join it back on.
            "ImagePath": f"images/{image_filename}",
            "InputVariables": variables_json,
        }
        row_df = pd.DataFrame([new_row])

        try:
            df = pd.read_csv(self.csv_path)
        except (pd.errors.EmptyDataError, FileNotFoundError):
            df = pd.DataFrame(columns=self.COLUMNS)

        if df.empty:
            # Concatenating onto an empty frame is deprecated in recent
            # pandas; build the first row directly, keeping any columns the
            # existing header already declared.
            columns = list(df.columns)
            columns += [name for name in row_df.columns if name not in columns]
            df = row_df.reindex(columns=columns)
        else:
            df = pd.concat([df, row_df], ignore_index=True)

        df.to_csv(self.csv_path, index=False)

        return instance_id

    def get_dataset(self) -> pd.DataFrame:
        """Return the CSV index as a DataFrame.

        Falls back to an empty frame with the ``COLUMNS`` schema when the CSV
        is missing or completely empty. JSON columns are returned undecoded.
        """
        try:
            return pd.read_csv(self.csv_path)
        except (pd.errors.EmptyDataError, FileNotFoundError):
            return pd.DataFrame(columns=self.COLUMNS)

    @staticmethod
    def _decode_json_field(value: Any) -> Any:
        """Parse a JSON text cell; return the value unchanged if not parseable.

        Empty CSV cells are read back as float NaN, which ``json.loads``
        rejects with ``TypeError``, so non-string values are passed through.
        """
        if not isinstance(value, str):
            return value
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value

    def _decode_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
        """Decode JSON columns and resolve ``ImagePath`` in a row dict."""
        for column in self._JSON_COLUMNS:
            if column in entry:
                entry[column] = self._decode_json_field(entry[column])

        image_path = entry.get("ImagePath")
        # A missing cell is NaN rather than a path; leave it untouched.
        if isinstance(image_path, str):
            entry["ImagePath"] = os.path.join(self.dataset_dir, image_path)

        return entry

    def get_entry(self, instance_id: str) -> Optional[Dict[str, Any]]:
        """Return the first row whose ``ID`` equals ``instance_id``.

        Returns:
            A dict with ``Answers`` and ``InputVariables`` JSON-decoded and
            ``ImagePath`` joined onto ``dataset_dir``, or ``None`` when no
            row matches.
        """
        df = self.get_dataset()
        matches = df[df["ID"] == instance_id]

        if matches.empty:
            return None

        return self._decode_entry(matches.iloc[0].to_dict())

    def iterate_dataset(self, batch_size: int = 1) -> Iterator[List[Dict[str, Any]]]:
        """Yield the dataset as lists of decoded row dicts.

        Args:
            batch_size: Maximum number of rows per yielded list; the last
                batch may be shorter.

        Yields:
            Lists of row dicts, decoded the same way as ``get_entry``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1. As this is a
                generator, the error surfaces on first iteration.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        df = self.get_dataset()

        for start in range(0, len(df), batch_size):
            batch = df.iloc[start : start + batch_size]
            yield [self._decode_entry(row.to_dict()) for _, row in batch.iterrows()]

    def get_stats(self) -> Dict[str, Any]:
        """Return the row count plus per-scene and per-prompt occurrence counts."""
        df = self.get_dataset()

        return {
            "total_entries": len(df),
            "scenes": df["SceneName"].value_counts().to_dict(),
            "prompts": df["Prompt"].value_counts().to_dict(),
        }

    def filter_by_scene(self, scene_name: str) -> pd.DataFrame:
        """Return the rows whose ``SceneName`` equals ``scene_name``."""
        df = self.get_dataset()
        return df[df["SceneName"] == scene_name]
