from __future__ import annotations

"""Contradictory task"""

import csv
import os
import random
from typing import Any, Dict, Iterable, List, Optional
from tqdm import tqdm

from .base import Task, register_task


@register_task("contradictory")
class ContradictoryTask(Task):
    """Run a model over the "contradictory" reasoning dataset and save its answers.

    Every row of the reasoning CSV is joined with the asset row whose
    ``subgraph_id`` equals the question ``id``, turned into a multimodal prompt,
    sent to the model runner, and the stripped answer is collected into
    ``<out_root>/<model name>/contradictory/contradictory_results.csv``.
    """

    # Columns of the results CSV, in output order.
    _RESULT_FIELDS = ("id", "rules", "question", "options", "option_role_map", "model_answer")

    def __init__(
        self,
        base_dir: str = "/path/to/dataset",
        out_root: str = "/path/to/output",
        order: str = "TAI",
    ) -> None:
        """Store the dataset locations and the modality selection.

        Args:
            base_dir: Dataset root; the reasoning CSV and the asset CSV are
                looked up at fixed relative paths below it.
            out_root: Root directory under which results are written.
            order: Letters choosing the modalities put into the prompt:
                ``T`` (text), ``I`` (image), ``A`` (audio). Any other character
                is ignored; a repeated letter adds the item again.
        """
        self.type = "contradictory"
        self.base_dir = base_dir
        self.quest_path = os.path.join(base_dir, f"reasoning_meta/reasoning_{self.type}_dataset.csv")
        self.asset_path = os.path.join(base_dir, f"assets/multimodal_datasets_{self.type}.csv")
        self.out_root = out_root
        self.order = order

    def _load_asset_dict(self, asset_csv: str) -> Dict[str, Dict[str, str]]:
        """Read the asset CSV into a mapping keyed by ``subgraph_id``.

        Args:
            asset_csv: Path of the asset CSV file.

        Returns:
            ``{subgraph_id: {"img_1": ..., "wav_1": ..., "txt_1": ...}}``. When a
            ``subgraph_id`` occurs more than once, the last row wins.

        Raises:
            KeyError: If a row lacks one of the expected columns.
        """
        asset_dict: Dict[str, Dict[str, str]] = {}
        with open(asset_csv, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                asset_dict[row["subgraph_id"]] = {
                    "img_1": row["modality1_img"],
                    "wav_1": row["modality2_wav"],
                    "txt_1": row["modality3_txt"],
                }
        return asset_dict

    def _iter_rows(self) -> Iterable[Dict[str, Any]]:
        """Yield question rows enriched with their ``img_1``/``wav_1``/``txt_1`` assets.

        Question rows whose ``id`` has no entry in the asset CSV are skipped
        without any notice.
        """
        assets = self._load_asset_dict(self.asset_path)
        # Read everything up front so the file is already closed while the
        # caller consumes the generator.
        with open(self.quest_path, newline="", encoding="utf-8") as f:
            data = list(csv.DictReader(f))
        for row in data:
            asset = assets.get(row["id"])
            if asset is None:
                continue
            row.update(asset)
            yield row

    def _build_user_content(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Assemble the prompt items for one row: modalities, then rules, then question."""
        user_content: List[Dict[str, Any]] = []
        for ch in self.order:
            if ch == "T":
                user_content.append({"type": "text", "text": row["txt_1"]})
            elif ch == "I":
                user_content.append({"type": "image", "image": row["img_1"]})
            elif ch == "A":
                user_content.append({"type": "audio", "audio": row["wav_1"]})

        # NOTE(review): the shuffle discards the sequence given by `order`, so
        # `order` only decides WHICH modalities are present, not where they
        # appear. It also uses the global `random` state, so runs are not
        # reproducible unless the caller seeds it. Confirm this is intended.
        random.shuffle(user_content)
        user_content.append({"type": "text", "text": f"\nRules are as follows: {row['rules']}\n"})
        # NOTE(review): the prompt reads "question_text" while the results row
        # in `run` reads "questions"; both columns must exist in the reasoning
        # CSV or a KeyError is raised. Confirm both are meant to be used.
        user_content.append({"type": "text", "text": row["question_text"]})
        return user_content

    def run(self, model_runner: Any, *, max_samples: Optional[int] = None) -> Dict[str, Any]:
        """Query the model on every (or the first ``max_samples``) row and save the answers.

        Args:
            model_runner: Object providing ``load_model()``,
                ``build_conversation(user_content)`` and ``run_model(conversation)``.
                Its optional ``name`` attribute names the output sub-directory
                (``/`` replaced by ``_``; defaults to ``"model"``).
            max_samples: If given, only the first ``max_samples`` joined rows
                are processed.

        Returns:
            ``{"status": "ok", "written_to": <output directory>}``. The results
            CSV itself is only written when at least one row was processed.

        Raises:
            KeyError: If a row lacks a column needed for the prompt or the results.
        """
        model_name_for_path = getattr(model_runner, "name", "model").replace("/", "_")
        out_dir = os.path.join(self.out_root, model_name_for_path, self.type)
        os.makedirs(out_dir, exist_ok=True)

        model_runner.load_model()

        out_csv = os.path.join(out_dir, "contradictory_results.csv")
        rows_out: List[Dict[str, Any]] = []

        # Kept as a list (not a lazy iterator) so tqdm can report a total.
        rows = list(self._iter_rows())
        if max_samples is not None:
            rows = rows[:max_samples]

        for row in tqdm(rows, desc="Running"):
            conversation = model_runner.build_conversation(self._build_user_content(row))
            pred = str(model_runner.run_model(conversation)).strip()

            rows_out.append(
                {
                    "id": row["id"],
                    "rules": row["rules"],
                    "question": row["questions"],
                    "options": row["options"],
                    "option_role_map": row.get("option_role_map", ""),
                    "model_answer": pred,
                }
            )

        # All answers are written in one go after the loop; nothing is written
        # if no row was processed.
        if rows_out:
            with open(out_csv, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=list(self._RESULT_FIELDS))
                writer.writeheader()
                writer.writerows(rows_out)

        return {"status": "ok", "written_to": out_dir}
