from __future__ import annotations

import json
import logging
import re
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any

from src.a4s.llm_client import LLMClient
from src.a4s.schemas import OrchestratorConfig, ScenarioDefinition, ExpertInput, ExpertOutput
from src.a4s.agents import ProblemRefinerAgent, DomainExpertAgent, ConflictResolverAgent, ReportGeneratorAgent, DebateCritiqueAgent


@dataclass
class AblationConfig:
    """Settings that choose which pipeline stages an ablation run executes.

    ``rounds`` and ``roles`` are required; every stage toggle defaults to
    ``True`` so a preset only has to switch off the stage it ablates.
    """

    # Requested number of expert rounds; the orchestrator runs at least one.
    rounds: int
    # Expert roles to instantiate; when empty the scenario's expert plan is used.
    roles: List[str]
    # Produce a debate brief each round and append it to the shared frame.
    enable_debate: bool = True
    # Reconcile expert outputs via the resolver instead of plain aggregation.
    enable_conflict_resolution: bool = True
    # Pass the previous round's shared summary to the experts.
    use_shared_frame: bool = True
    # Refine the proposition into a scenario instead of using the minimal default.
    use_refinement: bool = True


class AblationOrchestrator:
    """Run the expert pipeline with individual stages switched on or off.

    Each stage (refinement, debate brief, conflict resolution, shared frame)
    is controlled by the ``AblationConfig`` passed to :meth:`run`. All
    intermediate artefacts are written under ``<out_dir>/logs`` and the final
    report and structured summary under ``<out_dir>``.
    """

    # Path separators plus characters that are invalid in file names on
    # common platforms; replaced so a role can never escape the log directory.
    _UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|\x00-\x1f]')

    def __init__(self, model_id: Optional[str] = None, temperature: float = 0.0) -> None:
        """Create the LLM client and the agents shared by every run.

        Args:
            model_id: Model to use; falls back to ``OrchestratorConfig().model_id``
                when falsy.
            temperature: Stored on the instance.
        """
        cfg = OrchestratorConfig()
        self.client = LLMClient(default_model=model_id or cfg.model_id)
        # NOTE(review): temperature is stored but not read anywhere in this
        # class — confirm whether the client/agents are expected to receive it.
        self.temperature = temperature
        self.refiner = ProblemRefinerAgent(self.client)
        # Held as a class, not an instance: run() builds one expert per role.
        self.expert = DomainExpertAgent
        self.resolver = ConflictResolverAgent(self.client)
        self.reporter = ReportGeneratorAgent(self.client)
        self.debate = DebateCritiqueAgent(self.client)

    def _default_scenario(self, proposition: str, roles: List[str]) -> ScenarioDefinition:
        """Build the minimal placeholder scenario used when refinement is disabled."""
        return ScenarioDefinition(
            proposition=proposition,
            premises=["As stated in proposition."],
            constraints=["Modern technology unless specified."],
            timescales=["Short-term", "Medium-term", "Long-term"],
            uncertainties=["Not specified."],
            expert_plan=roles,
            refinement_raw=f"Proposition: {proposition}",
        )

    @classmethod
    def _safe_filename(cls, role: str, fallback: str = "expert") -> str:
        """Turn a role label into a file-name stem that stays inside its directory.

        Unsafe characters become ``_`` and surrounding whitespace is dropped;
        a label with nothing left yields ``fallback``. Plain labels such as
        ``"Physics"`` are returned unchanged.
        """
        cleaned = cls._UNSAFE_FILENAME_CHARS.sub("_", str(role)).strip()
        return cleaned or fallback

    def _write_expert_logs(self, experts_dir: Path, outputs: List[ExpertOutput]) -> None:
        """Write each expert's raw text to ``<experts_dir>/<role>.md``.

        Role labels are sanitised first, and a repeated stem gets a numeric
        suffix (``_2``, ``_3``, ...) so one expert's log does not overwrite
        another's.
        """
        seen: Dict[str, int] = {}
        for out in outputs:
            stem = self._safe_filename(out.role)
            seen[stem] = seen.get(stem, 0) + 1
            if seen[stem] > 1:
                stem = f"{stem}_{seen[stem]}"
            (experts_dir / f"{stem}.md").write_text(out.raw_text or "", encoding="utf-8")

    @staticmethod
    def _format_conflict(conflict: Any) -> str:
        """Render a reconciliation result as the text of the shared frame."""
        lines: List[str] = ["Consensus Points:"]
        lines.extend(f"- {point}" for point in conflict.consensus_points)
        lines.append("Conditional Branches:")
        lines.extend(f"- {cond} -> {result}" for cond, result in conflict.conditional_branches.items())
        lines.append("Remaining Uncertainties:")
        lines.extend(f"- {item}" for item in conflict.remaining_uncertainties)
        lines.append("Notes:")
        lines.extend(f"- {note}" for note in conflict.notes)
        return "\n".join(lines)

    @staticmethod
    def _aggregate_highlights(outputs: List[ExpertOutput]) -> str:
        """List each expert's first conclusion; experts without conclusions are skipped."""
        lines: List[str] = ["Aggregated Expert Highlights:"]
        lines.extend(f"- {out.role}: {out.conclusions[0]}" for out in outputs if out.conclusions)
        return "\n".join(lines)

    def run(self, proposition: str, out_dir: Path, cfg: AblationConfig) -> Tuple[str, Dict[str, Any]]:
        """Execute one ablation run and write its artefacts to ``out_dir``.

        Args:
            proposition: The statement handed to the refiner, debate agent and reporter.
            out_dir: Output directory; created (with parents) if missing.
            cfg: Stage toggles, roles and round count for this run.

        Returns:
            ``(report, structured)`` where ``report`` is the reporter's text
            (also written to ``report.md``) and ``structured`` is the dict
            serialised to ``structured.json``.
        """
        out_dir.mkdir(parents=True, exist_ok=True)
        logs_root = out_dir / "logs"
        logs_root.mkdir(parents=True, exist_ok=True)

        # 1) Refinement or default scenario
        if cfg.use_refinement:
            scenario = self.refiner.refine(proposition)
            (logs_root / "scenario.md").write_text(scenario.refinement_raw or "", encoding="utf-8")
        else:
            scenario = self._default_scenario(proposition, cfg.roles)

        # Explicit roles win over the scenario's plan; always run at least one round.
        roles = cfg.roles or scenario.expert_plan
        total_rounds = max(cfg.rounds, 1)

        rounds_summaries: List[str] = []
        shared_summary: Optional[str] = None

        for round_index in range(1, total_rounds + 1):
            # 2) Expert pass; the previous shared frame is withheld when ablated.
            experts = [self.expert(role, self.client) for role in roles]
            expert_input = ExpertInput(
                scenario=scenario,
                shared_frame_summary=(shared_summary if cfg.use_shared_frame else None),
                round_index=round_index,
            )
            outputs: List[ExpertOutput] = [expert.run(expert_input) for expert in experts]

            # Log raw outputs
            round_dir = logs_root / f"round_{round_index}"
            experts_dir = round_dir / "experts"
            experts_dir.mkdir(parents=True, exist_ok=True)
            self._write_expert_logs(experts_dir, outputs)

            # Optional debate brief. Best-effort: a failure is logged and the
            # round continues without it rather than aborting the whole run.
            # NOTE(review): the expert outputs are not passed to synthesize() —
            # confirm that is the intended agent interface.
            debate_brief = ""
            if cfg.enable_debate:
                try:
                    debate_brief = self.debate.synthesize(proposition, scenario, round_index)
                    (round_dir / "debate_brief.md").write_text(debate_brief, encoding="utf-8")
                except Exception:
                    logging.getLogger(__name__).warning(
                        "Debate brief failed in round %d; continuing without it.",
                        round_index,
                        exc_info=True,
                    )

            # 3) Conflict resolution or simple aggregation
            if cfg.enable_conflict_resolution:
                conflict = self.resolver.reconcile(outputs)
                shared_summary = self._format_conflict(conflict)
                (round_dir / "conflict.md").write_text(conflict.raw_text or "", encoding="utf-8")
            else:
                # Simple aggregation as a "shared" frame without reconciliation
                shared_summary = self._aggregate_highlights(outputs)

            if debate_brief:
                shared_summary = (shared_summary or "") + "\n\nDebate Brief:\n" + debate_brief
            (round_dir / "shared_frame.md").write_text(shared_summary or "", encoding="utf-8")

            # Round summary (concise)
            round_summary = "\n".join(
                [f"Round {round_index} Summary", "Shared Frame:", shared_summary or "(empty)"]
            )
            (round_dir / "summary.md").write_text(round_summary, encoding="utf-8")
            rounds_summaries.append(round_summary)

        # 4) Final report
        report = self.reporter.generate(proposition, rounds_summaries)
        (out_dir / "report.md").write_text(report, encoding="utf-8")
        structured = {
            "proposition": proposition,
            "scenario": {
                "premises": scenario.premises,
                "constraints": scenario.constraints,
                "timescales": scenario.timescales,
                "uncertainties": scenario.uncertainties,
                "expert_plan": roles,
            },
            "rounds": rounds_summaries,
            # A copy, so the returned dict does not alias the live config object.
            "ablation": asdict(cfg),
        }
        (out_dir / "structured.json").write_text(
            json.dumps(structured, ensure_ascii=False, indent=2), encoding="utf-8"
        )
        return report, structured


def preset_two_experts_one_round() -> AblationConfig:
    """Return the preset with Physics and Economics experts, one round, all stages on."""
    config = AblationConfig(
        rounds=1,
        roles=["Physics", "Economics"],
        enable_debate=True,
        enable_conflict_resolution=True,
        use_shared_frame=True,
        use_refinement=True,
    )
    return config


def preset_no_debate(roles: List[str], rounds: int) -> AblationConfig:
    """Return a config for the given roles and rounds with only the debate stage off."""
    config = AblationConfig(
        rounds=rounds,
        roles=roles,
        enable_debate=False,  # the ablated stage
        enable_conflict_resolution=True,
        use_shared_frame=True,
        use_refinement=True,
    )
    return config


def preset_no_conflict(roles: List[str], rounds: int) -> AblationConfig:
    """Return a config for the given roles and rounds with only conflict resolution off."""
    config = AblationConfig(
        rounds=rounds,
        roles=roles,
        enable_debate=True,
        enable_conflict_resolution=False,  # the ablated stage
        use_shared_frame=True,
        use_refinement=True,
    )
    return config


def preset_no_shared_frame(roles: List[str], rounds: int) -> AblationConfig:
    """Return a config for the given roles and rounds with only the shared frame off."""
    config = AblationConfig(
        rounds=rounds,
        roles=roles,
        enable_debate=True,
        enable_conflict_resolution=True,
        use_shared_frame=False,  # the ablated stage
        use_refinement=True,
    )
    return config


def preset_no_refinement(roles: List[str], rounds: int) -> AblationConfig:
    """Return a config for the given roles and rounds with only refinement off."""
    config = AblationConfig(
        rounds=rounds,
        roles=roles,
        enable_debate=True,
        enable_conflict_resolution=True,
        use_shared_frame=True,
        use_refinement=False,  # the ablated stage
    )
    return config



