from __future__ import annotations

"""Equivalent task"""

import csv
import os
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple

from .base import Task, register_task

from tqdm import tqdm


@register_task("equivalent")
class EquivalentTask(Task):
    """Task that queries a model on the "equivalent" reasoning dataset.

    Every question row is joined with the assets (image, audio, text) of the
    matching ``subgraph_id`` and posed to the model once per configured
    modality combination. The answers of each combination are written to
    their own CSV file.
    """

    def __init__(
        self,
        base_dir: str = "/path/to/dataset",
        out_root: str = "/path/to/output",
        combos: Optional[List[Tuple[bool, bool, bool]]] = None,
    ) -> None:
        """Configure dataset locations and the modality combinations to run.

        Args:
            base_dir: Dataset root; the question and asset CSV paths are
                derived from it.
            out_root: Root directory under which result files are written.
            combos: ``(use_image, use_audio, use_text)`` flag triples to
                evaluate. When ``None``, each single modality plus the
                all-modalities combination is used.
        """
        self.type = "equivalent"
        self.base_dir = base_dir
        self.quest_path = os.path.join(base_dir, "reasoning_meta/reasoning_equivalent_dataset.csv")
        self.asset_path = os.path.join(base_dir, "assets/multimodal_datasets_equivalent.csv")
        self.out_root = out_root
        self.combos = (
            combos
            if combos is not None
            else [
                (True, False, False),
                (False, True, False),
                (False, False, True),
                (True, True, True),
            ]
        )

    # Data loading -------------------------------------------------------
    def _load_asset_dict(self, asset_csv: str) -> Dict[str, Dict[str, str]]:
        """Read the asset CSV into a mapping keyed by ``subgraph_id``.

        Args:
            asset_csv: Path of the asset CSV file.

        Returns:
            ``{subgraph_id: {"img": ..., "wav": ..., "txt": ...}}`` taken from
            the ``modality1_img`` / ``modality1_wav`` / ``modality1_txt``
            columns. If a ``subgraph_id`` occurs more than once, the last
            row wins.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If a required column is missing.
        """
        asset_dict: Dict[str, Dict[str, str]] = {}
        with open(asset_csv, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                asset_dict[row["subgraph_id"]] = {
                    "img": row["modality1_img"],
                    "wav": row["modality1_wav"],
                    "txt": row["modality1_txt"],
                }
        return asset_dict

    def _iter_rows(self) -> Iterable[Dict[str, Any]]:
        """Yield question rows enriched with their asset values.

        Each yielded row is the CSV row dict extended with ``info_img``,
        ``info_wav`` and ``info_text``. Rows whose ``id`` has no entry in the
        asset mapping are skipped silently.
        """
        assets = self._load_asset_dict(self.asset_path)
        # Read everything up front so the file is closed before yielding.
        with open(self.quest_path, newline="", encoding="utf-8") as f:
            data = list(csv.DictReader(f))

        for row in data:
            asset = assets.get(row["id"])
            if asset is None:
                continue
            row["info_img"] = asset["img"]
            row["info_wav"] = asset["wav"]
            row["info_text"] = asset["txt"]
            yield row

    # Prompt helpers -----------------------------------------------------
    @staticmethod
    def _combo_label(use_image: bool, use_audio: bool, use_text: bool) -> str:
        """Return the file-name label of a modality combination (e.g. ``Image_Text``)."""
        names = ("Image", "Audio", "Text")
        flags = (use_image, use_audio, use_text)
        return "_".join(name for name, flag in zip(names, flags) if flag) or "None"

    @staticmethod
    def _build_user_content(
        row: Dict[str, Any], use_image: bool, use_audio: bool, use_text: bool
    ) -> List[Dict[str, Any]]:
        """Assemble the user message parts for one row and modality combination."""
        user_content: List[Dict[str, Any]] = []
        if use_image:
            user_content.append({"type": "image", "image": row["info_img"]})
        if use_audio:
            user_content.append({"type": "audio", "audio": row["info_wav"]})
        if use_text:
            user_content.append({"type": "text", "text": row["info_text"]})

        # Only the modality parts are shuffled; rules and question always
        # come last, in this order.
        random.shuffle(user_content)
        rules = row["rules"]
        user_content.append({"type": "text", "text": f"\nRules are as follows: {rules}\n"})
        user_content.append({"type": "text", "text": row["question_text"]})
        return user_content

    # Runner -------------------------------------------------------------
    def run(self, model_runner: Any, *, max_samples: Optional[int] = None) -> Dict[str, Any]:
        """Run the model on every row for each modality combination.

        Results go to
        ``<out_root>/<model name with "/" replaced by "_">/<task type>/<combo>_result.csv``.
        No file is written for a combination that produced no rows.

        Args:
            model_runner: Object providing ``load_model()``,
                ``build_conversation(user_content)`` and
                ``run_model(conversation)``; an optional ``name`` attribute
                is used for the output directory (default ``"model"``).
            max_samples: If given, only the first ``max_samples`` joined rows
                are evaluated (plain slice semantics).

        Returns:
            ``{"status": "ok", "written_to": <output directory>}``.
        """
        model_name_for_path = getattr(model_runner, "name", "model").replace("/", "_")
        out_dir = os.path.join(self.out_root, model_name_for_path, self.type)
        os.makedirs(out_dir, exist_ok=True)

        # The joined rows do not depend on the combo, so read them once
        # instead of re-parsing both CSV files per combination. Doing it
        # before load_model() also surfaces dataset problems before the
        # (potentially expensive) model load.
        rows = list(self._iter_rows())
        if max_samples is not None:
            rows = rows[:max_samples]

        model_runner.load_model()

        for use_image, use_audio, use_text in self.combos:
            combo_label = self._combo_label(use_image, use_audio, use_text)
            out_csv = os.path.join(out_dir, f"{combo_label}_result.csv")

            rows_out: List[Dict[str, Any]] = []
            for row in tqdm(rows, desc="Running"):
                user_content = self._build_user_content(row, use_image, use_audio, use_text)
                conversation = model_runner.build_conversation(user_content)
                pred = str(model_runner.run_model(conversation)).strip()

                # NOTE(review): the prompt uses the `question_text` column
                # while the result logs `questions`; confirm both columns
                # exist in the question CSV.
                rows_out.append(
                    {
                        "id": row["id"],
                        "rules": row["rules"],
                        "question": row["questions"],
                        "options": row["options"],
                        "gt_answer": row.get("correct_answer", ""),
                        "model_answer": pred,
                    }
                )

            if rows_out:
                with open(out_csv, "w", newline="", encoding="utf-8") as f:
                    writer = csv.DictWriter(f, fieldnames=list(rows_out[0].keys()))
                    writer.writeheader()
                    writer.writerows(rows_out)

        return {"status": "ok", "written_to": out_dir}
