#!/usr/bin/env python3
"""
Create CI Control Slice from FO Dataset

Creates a diagnostic dataset where:
- YES worlds are taken from FO instances as-is (no resampling)
- NO worlds are added post-hoc using trap logic

This allows comparing CI v1 vs "CI-from-FO" to measure selection bias.

Usage:
    python -m concept_synth.analysis.make_ci_control_from_fo \
        --fo_dataset results/ad_benchmark/ad_benchmark_v1.yaml \
        --n 50 \
        --no_worlds 2 \
        --seed 0 \
        --out data/ci_control_from_fo_v1.yaml
"""

import argparse
import copy
import os
import random
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

# Bootstrap path
# Make the `concept_synth` package importable even when it is not already on
# sys.path (for example when this file is executed directly).
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upwards from this file until a directory named "concept_synth" is
    # found, then put that directory's parent on sys.path.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # dirname() no longer changes the path: the filesystem root was
            # reached without finding the package directory.
            break
        _path = parent
    # Retry the import; this raises ModuleNotFoundError again if the walk
    # above did not find the package.
    from concept_synth.bootstrap import add_repo_root
add_repo_root(__file__)

from concept_synth.fol.model import FiniteModel
from concept_synth.formula_filters import formula_matches_target
from concept_synth.io_utils import load_from_yaml, save_to_yaml
from concept_synth.sexpr_parser import parse_sexpr_formula


def build_model_from_world(world_dict: Dict[str, Any]) -> Tuple[FiniteModel, Dict[str, int]]:
    """Translate a serialized world into a FiniteModel.

    Reads ``world_dict["domain"]`` (list of constant names) and
    ``world_dict["predicates"]`` (predicate name -> extension list). Only the
    unary predicates ``P``/``Q`` and the binary predicates ``R``/``S`` are
    transferred to the model.

    Binary extensions are expected as strings of the form ``"(a, b)"``; any
    entry that is not such a string, does not split into exactly two parts,
    or mentions a constant outside the domain is silently skipped. Unary
    entries outside the domain are skipped as well.

    Returns:
        ``(model, const_to_idx)`` where ``const_to_idx`` maps each constant
        name to its position in the domain list.
    """
    constants = world_dict.get("domain", [])
    model = FiniteModel(len(constants))
    const_to_idx = {name: position for position, name in enumerate(constants)}

    extensions = world_dict.get("predicates", {})

    # Unary predicates: each entry is a constant name.
    for symbol in ("P", "Q"):
        members = extensions[symbol] if symbol in extensions else None
        if not isinstance(members, list):
            continue
        for member in members:
            if member in const_to_idx:
                model.set_unary(symbol, const_to_idx[member], True)

    # Binary predicates: each entry is a "(a, b)" string.
    for symbol in ("R", "S"):
        edges = extensions[symbol] if symbol in extensions else None
        if not isinstance(edges, list):
            continue
        for entry in edges:
            if not (isinstance(entry, str) and entry.startswith("(")):
                continue
            endpoints = entry.strip("()").replace(" ", "").split(",")
            if len(endpoints) != 2:
                continue
            source, dest = endpoints
            if source in const_to_idx and dest in const_to_idx:
                model.set_binary(symbol, const_to_idx[source], const_to_idx[dest], True)

    return model, const_to_idx


def get_target_set(world_dict: Dict[str, Any]) -> Set[str]:
    """Return the positive target elements of a world as a set.

    Looks up ``world_dict["targetExtension"]["T_true"]``; a missing
    ``targetExtension`` or a missing ``T_true`` yields an empty set.
    """
    extension = world_dict.get("targetExtension", {})
    positives = extension.get("T_true", [])
    return {element for element in positives}


def generate_random_world(domain_size: int, seed: int, edge_density: float = 0.3) -> Dict[str, Any]:
    """Build a seeded random world over the constants ``a0 .. a{domain_size-1}``.

    Each constant is put into ``P`` and into ``Q`` independently with
    probability 0.5; each ordered pair (self-pairs included) is put into
    ``R`` and into ``S`` independently with probability ``edge_density``.
    Binary facts are encoded as ``"(a, b)"`` strings.

    The returned dict has the keys ``domain``, ``domainSize`` and
    ``predicates``; it carries no target extension.
    """
    rng = random.Random(seed)
    constants = [f"a{index}" for index in range(domain_size)]

    # Draw order matters for reproducibility: all of P, then all of Q.
    unary = {}
    for symbol in ("P", "Q"):
        unary[symbol] = [name for name in constants if rng.random() < 0.5]

    # For every ordered pair, one draw decides R and the next one decides S.
    r_edges = []
    s_edges = []
    for source in constants:
        for dest in constants:
            fact = f"({source}, {dest})"
            for bucket in (r_edges, s_edges):
                if rng.random() < edge_density:
                    bucket.append(fact)

    world = {"domain": constants, "domainSize": domain_size}
    world["predicates"] = {
        "P": unary["P"],
        "Q": unary["Q"],
        "R": r_edges,
        "S": s_edges,
    }
    return world


def compute_target_for_world(world_dict: Dict[str, Any], formula) -> Set[str]:
    """Return the set of domain elements for which ``formula`` holds.

    The world is converted to a model, then ``formula`` is evaluated once per
    element of ``world_dict["domain"]`` with ``formula.free_var`` bound to
    that element's index. An element whose evaluation raises is silently
    left out of the result.
    """
    model, index_of = build_model_from_world(world_dict)

    satisfying = set()
    for constant in world_dict["domain"]:
        try:
            assignment = {formula.free_var: index_of[constant]}
            if formula.evaluate(model, assignment):
                satisfying.add(constant)
        except Exception:
            # Best effort: an evaluation failure counts as "does not hold".
            continue

    return satisfying


def find_no_world_for_instance(
    gold_formula,
    yes_worlds: List[Dict[str, Any]],
    trap_formulas: List,
    max_attempts: int = 500,
    seed_base: int = 0,
) -> Optional[Dict[str, Any]]:
    """
    Find a NO world where:
    - Gold formula does NOT match the target
    - At least one trap formula DOES match the target

    This creates a world that eliminates shortcut hypotheses.

    The target of a candidate world is taken from the extension of a trap
    formula. (A target computed from the gold formula is matched by the gold
    formula by construction, so such a world could never satisfy the first
    condition above.)

    Args:
        gold_formula: Parsed gold formula of the instance.
        yes_worlds: YES worlds of the instance. Currently unused; kept so the
            call signature stays stable.
        trap_formulas: Parsed trap formulas to derive candidate targets from.
        max_attempts: Maximum number of random worlds to try.
        seed_base: Seed of this search. The ``k``-th candidate world is
            generated with seed ``seed_base + k``.

    Returns:
        A world dict with ``targetExtension``, ``worldId``, ``splitLabel``
        ("NO") and ``observationMode`` ("full") filled in, or None if no
        suitable world was found within ``max_attempts``.
    """
    rng = random.Random(seed_base)

    for attempt in range(max_attempts):
        # Generate random world
        domain_size = rng.randint(6, 10)
        world = generate_random_world(domain_size, seed_base + attempt)
        model, const_to_idx = build_model_from_world(world)

        # Visit the traps in a seeded random order so that the first trap in
        # the list does not end up defining (almost) every NO world.
        candidates = list(trap_formulas)
        rng.shuffle(candidates)

        for trap in candidates:
            try:
                target = compute_target_for_world(world, trap)
                # NOTE(review): formula_matches_target presumably checks that
                # the formula's extension equals `target` - confirm.
                trap_matches = formula_matches_target(trap, model, target, const_to_idx)
            except Exception:
                # A trap that cannot be evaluated on this world is skipped.
                continue

            if not target or len(target) == domain_size:
                # Trivial target, skip
                continue
            if not trap_matches:
                continue
            if formula_matches_target(gold_formula, model, target, const_to_idx):
                # Gold agrees with this trap here, so the world would not be
                # a NO world for gold.
                continue

            # Found a good NO world. Emit the target in domain order: plain
            # set iteration order of strings varies between interpreter runs.
            world["targetExtension"] = {
                "T_true": [d for d in world["domain"] if d in target],
                "T_false": [d for d in world["domain"] if d not in target],
            }
            world["worldId"] = f"no_{attempt}"
            world["splitLabel"] = "NO"
            world["observationMode"] = "full"
            return world

    return None


def generate_simple_traps(gold_formula_str: str) -> List:
    """
    Build the fixed list of simple trap formulas.

    The traps are small formulas over P, Q, R and S that might match the data
    accidentally. ``gold_formula_str`` is accepted but not used. Trap strings
    that fail to parse are silently dropped from the result.
    """
    trap_sources = (
        # Atomic and negated atoms
        "(P x)",
        "(Q x)",
        "(not (P x))",
        "(not (Q x))",
        # Boolean combinations
        "(and (P x) (Q x))",
        "(or (P x) (Q x))",
        # Quantified over one edge relation
        "(exists y (R x y))",
        "(exists y (S x y))",
        "(forall y (or (not (R x y)) (P y)))",
        "(forall y (or (not (S x y)) (Q y)))",
    )

    parsed = []
    for source in trap_sources:
        try:
            formula = parse_sexpr_formula(source)
        except Exception:
            continue
        parsed.append(formula)

    return parsed


def convert_fo_to_ci(
    fo_problem: Dict[str, Any], n_no_worlds: int = 2, seed: int = 0, max_no_attempts: int = 2000
) -> Optional[Dict[str, Any]]:
    """
    Convert an FO problem to CI format.

    - Keep FO worlds as YES worlds
    - Generate NO worlds using trap logic

    Args:
        fo_problem: FO record with a ``problem`` dict (worlds, signature,
            instanceId) and a ``problemDescription`` dict (hiddenTarget and
            band/family metadata).
        n_no_worlds: Number of NO worlds to generate.
        seed: Base seed; the ``i``-th NO world search uses ``seed + i * 1000``.
        max_no_attempts: Total attempt budget, split evenly across the NO
            world searches (each search gets at least one attempt).

    Returns:
        The CI record (``problem``, ``problemDescription``, ``problemType``,
        ``llmResults``), or None if the gold formula is missing or cannot be
        parsed, there are no worlds, the gold formula does not match a YES
        world, or NO world generation fails.
    """
    problem = fo_problem.get("problem", {})
    desc = fo_problem.get("problemDescription", {})

    # Get gold formula
    hidden_target = desc.get("hiddenTarget", {})
    gold_str = hidden_target.get("formula", "")

    if not gold_str:
        return None

    try:
        gold_formula = parse_sexpr_formula(gold_str)
    except Exception:
        return None

    # Get YES worlds (all FO worlds); copies are relabelled, the input is
    # left untouched.
    yes_worlds = []
    for world in problem.get("worlds", []):
        world_copy = copy.deepcopy(world)
        world_copy["splitLabel"] = "YES"
        world_copy["worldId"] = world_copy.get("worldId", "").replace("train_", "yes_")
        yes_worlds.append(world_copy)

    if not yes_worlds:
        return None

    # Verify gold matches every YES world. This is cheap, so it runs before
    # the (potentially long) random search for NO worlds.
    for world in yes_worlds:
        model, const_to_idx = build_model_from_world(world)
        target = get_target_set(world)
        if not formula_matches_target(gold_formula, model, target, const_to_idx):
            return None  # Gold doesn't match YES world

    # Generate trap formulas
    traps = generate_simple_traps(gold_str)

    # Generate NO worlds
    no_worlds = []
    if n_no_worlds > 0:
        # Never hand a search a zero budget (happens when
        # n_no_worlds > max_no_attempts with plain floor division).
        attempts_per_world = max(1, max_no_attempts // n_no_worlds)
    else:
        attempts_per_world = 0
    for i in range(n_no_worlds):
        no_world = find_no_world_for_instance(
            gold_formula,
            yes_worlds,
            traps,
            max_attempts=attempts_per_world,
            seed_base=seed + i * 1000,
        )

        if no_world is None:
            # Failed to find NO world
            return None

        # Re-number by position within this instance.
        no_world["worldId"] = f"no_{i}"
        no_worlds.append(no_world)

    # Verify gold does not match any NO world
    for world in no_worlds:
        model, const_to_idx = build_model_from_world(world)
        target = get_target_set(world)
        if formula_matches_target(gold_formula, model, target, const_to_idx):
            return None  # Gold matches NO world (shouldn't happen)

    # Build CI problem
    all_worlds = yes_worlds + no_worlds

    ci_problem = {
        "instanceId": f"CI_from_FO_{problem.get('instanceId', 'unknown')}",
        "schemaVersion": "fol-concept-synth-v1",
        "scenario": "C",
        "signature": problem.get("signature", {}),
        "backgroundAxioms": [],
        "worlds": all_worlds,
        "task": {
            "yesWorldIds": [w["worldId"] for w in yes_worlds],
            "noWorldIds": [w["worldId"] for w in no_worlds],
        },
    }

    ci_desc = {
        "scenario": "C",
        "seed": seed,
        "hiddenTarget": hidden_target,
        "source_fo_instance": problem.get("instanceId", ""),
        "source_fo_band": desc.get("ad_band", ""),
        "c_band": "control_from_fo",
        "c_gold_family_id": desc.get("gold_family_id"),
        "c_gold_subfamily_key": desc.get("gold_subfamily_key"),
        "c_gold_is_lift_hard": desc.get("gold_is_lift_hard", False),
        "numYesWorlds": len(yes_worlds),
        "numNoWorlds": len(no_worlds),
        "control_generation": True,
        "control_version": "1.0",
    }

    return {
        "problem": ci_problem,
        "problemDescription": ci_desc,
        "problemType": "foInduction",
        "llmResults": [],
    }


def generate_ci_control_dataset(
    fo_dataset_path: Path,
    output_path: Path,
    n_instances: int = 50,
    n_no_worlds: int = 2,
    seed: int = 0,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Generate CI control dataset from FO dataset.

    Loads the FO dataset, keeps the problems whose scenario is "AD", shuffles
    them with ``seed`` and converts them one by one until ``n_instances``
    conversions succeeded or the problems run out. The result is written to
    ``output_path`` (parent directories are created).

    Args:
        fo_dataset_path: Path of the FO dataset YAML.
        output_path: Path of the CI dataset YAML to write.
        n_instances: Number of CI instances to produce.
        n_no_worlds: Number of NO worlds per instance.
        seed: Seed for the shuffle; the problem at shuffled position ``i`` is
            converted with seed ``seed + i * 10000``.
        verbose: Print progress to stdout.

    Returns:
        Summary stats (paths, requested/generated/attempted/failed counts,
        failure rate, parameters and generation timestamp).
    """
    if verbose:
        print(f"Loading FO dataset from {fo_dataset_path}...")

    fo_problems = load_from_yaml(str(fo_dataset_path))

    if verbose:
        print(f"Loaded {len(fo_problems)} FO problems")

    # Filter to AD scenario
    fo_problems = [p for p in fo_problems if p.get("problem", {}).get("scenario") == "AD"]

    if verbose:
        print(f"Found {len(fo_problems)} AD problems")

    # Shuffle and select
    rng = random.Random(seed)
    rng.shuffle(fo_problems)

    ci_problems = []
    failures = 0
    attempted = 0

    for i, fo_problem in enumerate(fo_problems):
        if len(ci_problems) >= n_instances:
            break

        attempted += 1
        ci_problem = convert_fo_to_ci(fo_problem, n_no_worlds=n_no_worlds, seed=seed + i * 10000)

        if ci_problem is not None:
            ci_problems.append(ci_problem)
            if verbose and len(ci_problems) % 10 == 0:
                print(f"  Generated {len(ci_problems)}/{n_instances} instances...")
        else:
            failures += 1

    # Nothing may have been attempted (no AD problems, or n_instances <= 0),
    # so guard the division once and reuse the value for print and summary.
    failure_rate = failures / attempted if attempted > 0 else 0

    if verbose:
        print(f"\nGenerated {len(ci_problems)} CI control instances")
        print(f"Failed: {failures} ({failure_rate*100:.1f}% failure rate)")

    # Save dataset
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_to_yaml(ci_problems, str(output_path))

    if verbose:
        print(f"Saved to {output_path}")

    summary = {
        "source_fo_dataset": str(fo_dataset_path),
        "output_path": str(output_path),
        "n_requested": n_instances,
        "n_generated": len(ci_problems),
        "n_attempted": attempted,
        "n_failed": failures,
        "failure_rate": failure_rate,
        "n_no_worlds": n_no_worlds,
        "seed": seed,
        "generated_at": datetime.now().isoformat(),
    }

    return summary


def main():
    """Command-line entry point: parse arguments, build the dataset, print a summary."""
    parser = argparse.ArgumentParser(
        description="Generate CI control dataset from FO",
    )

    # (flags, keyword options) in the order they are registered.
    argument_specs = (
        (("--fo_dataset",), {"required": True, "help": "Path to FO (AD) dataset YAML"}),
        (("--n",), {"type": int, "default": 50, "help": "Number of instances to generate"}),
        (("--no_worlds",), {"type": int, "default": 2, "help": "Number of NO worlds per instance"}),
        (("--seed",), {"type": int, "default": 0, "help": "Random seed"}),
        (("--out",), {"required": True, "help": "Output YAML path"}),
        (("--quiet", "-q"), {"action": "store_true"}),
    )
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()

    run_options = {
        "fo_dataset_path": Path(args.fo_dataset),
        "output_path": Path(args.out),
        "n_instances": args.n,
        "n_no_worlds": args.no_worlds,
        "seed": args.seed,
        "verbose": not args.quiet,
    }
    summary = generate_ci_control_dataset(**run_options)

    # Print summary
    if args.quiet:
        return

    rule = "=" * 60
    print("\n" + rule)
    print("CI Control Dataset Summary")
    print(rule)
    for key in summary:
        print(f"  {key}: {summary[key]}")


# Run the command-line interface only when this file is executed as a script
# (or via `python -m`), not when it is imported.
if __name__ == "__main__":
    main()
