#!/usr/bin/env python3
"""
Generate skip_pairing datasets for fixed n, k (optionally refined).

Partitions are noncrossing partitions of [1..n] with n-k+1 blocks,
enumerated in a fixed deterministic order.
The q,t distributions are read from qt_narayana_n{N}_k{k}.json:
- unrefined: coefficients for n.
- refined: coefficients for n=m minus n=m-1, where m is the max
  non-singleton element in the family.
"""

from __future__ import annotations

import argparse
import json
import re
from dataclasses import dataclass
from itertools import combinations
from pathlib import Path
from typing import Iterable


@dataclass(frozen=True)
class SingleLine:
    """Wrapper marking a list that must be rendered on a single JSON line."""

    # The wrapped list; it is serialized with a plain ``json.dumps`` call.
    value: list


class SingleLineJSONEncoder(json.JSONEncoder):
    """JSON encoder that keeps SingleLine lists on one line.

    Each ``SingleLine`` object is first emitted as a quoted placeholder
    string; after the regular (possibly indented) encoding is finished the
    placeholders are swapped for the compact rendering of the wrapped list.
    """

    # Matches a quoted placeholder produced by ``default``.
    _MARKER_RE = re.compile(r'"(__SL_\d+__)"')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # placeholder string -> compact JSON text of the wrapped list
        self._single_line: dict[str, str] = {}

    def default(self, obj):
        """Replace a SingleLine by a placeholder; defer anything else."""
        if isinstance(obj, SingleLine):
            marker = f"__SL_{id(obj)}__"
            self._single_line[marker] = json.dumps(obj.value)
            return marker
        return super().default(obj)

    def encode(self, obj):
        """Encode *obj*, then substitute every placeholder in one pass."""
        # Forget placeholders of earlier calls so a reused encoder does not
        # accumulate stale entries.
        self._single_line = {}
        result = super().encode(obj)
        if not self._single_line:
            return result
        payloads = self._single_line
        # One regex pass instead of one str.replace per placeholder; strings
        # that merely look like a placeholder are left untouched.
        return self._MARKER_RE.sub(
            lambda match: payloads.get(match.group(1), match.group(0)),
            result,
        )


def partition_from_blocks(n: int, blocks: Iterable[Iterable[int]]) -> list[list[int]]:
    """Complete *blocks* to a partition of 1..n by adding singleton blocks.

    Every given block is sorted, every element of 1..n that occurs in no
    block becomes a singleton, and the resulting blocks are returned ordered
    by their first (smallest) element.
    """
    explicit = [sorted(members) for members in blocks]
    covered = {element for members in explicit for element in members}
    singletons = [[element] for element in range(1, n + 1) if element not in covered]
    return sorted(explicit + singletons, key=lambda members: members[0])


def _interleaves(outer: list[int], inner: list[int]) -> bool:
    """Return True if a < b < c < d exists with a, c in *outer* and b, d in *inner*.

    Both lists must be non-empty and sorted ascending.  Taking a as the
    smallest element of *outer* and d as the largest element of *inner* is
    always the best choice, so only the middle pair (b, c) has to be searched.
    """
    lowest, highest = outer[0], inner[-1]
    return any(lowest < b < c < highest for b in inner for c in outer)


def is_noncrossing(partition: list[list[int]]) -> bool:
    """Return True if no two blocks of *partition* cross.

    Two blocks cross when elements a < b < c < d exist such that a and c lie
    in one block while b and d lie in the other.  Empty blocks are ignored.
    Both orientations of every pair of blocks are tested, so the result does
    not depend on the order in which the blocks are listed.
    """
    blocks = [sorted(block) for block in partition if block]
    for first, second in combinations(blocks, 2):
        if _interleaves(first, second) or _interleaves(second, first):
            return False
    return True


def skip_stat(partition: list[list[int]]) -> int:
    """Return the skip statistic of *partition*.

    Each block with more than one element contributes
    ``max - min - size + 1``, i.e. the number of integers strictly between
    its extremes that are not in the block; singletons contribute nothing.
    """
    return sum(
        max(members) - min(members) - len(members) + 1
        for members in partition
        if len(members) > 1
    )


def max_non_singleton(partition: list[list[int]]) -> int:
    """Return the largest element lying in a block with more than one element.

    Raises:
        ValueError: if every block of *partition* has at most one element,
            in which case the value is undefined.
    """
    largest = max(
        (max(block) for block in partition if len(block) > 1),
        default=None,
    )
    if largest is None:
        # Same exception type as a bare max() on an empty sequence, but with
        # a message that names the actual problem.
        raise ValueError("partition has no non-singleton block")
    return largest


def size_partitions(total: int, max_part: int | None = None) -> Iterable[list[int]]:
    """Yield the integer partitions of *total* as non-increasing lists.

    Parts never exceed *max_part* (unbounded when it is ``None``).  Larger
    leading parts are produced first, and ``total == 0`` yields the single
    empty partition.
    """
    cap = total if max_part is None else max_part
    if total == 0:
        yield []
        return
    for first in reversed(range(1, min(cap, total) + 1)):
        # The remaining parts may not exceed the part just chosen.
        yield from ([first] + tail for tail in size_partitions(total - first, first))


def blocks_for_sizes(elems: list[int], sizes: list[int]) -> Iterable[list[list[int]]]:
    """Yield every way to split *elems* into blocks of the given *sizes*.

    The i-th block of each yielded list has ``sizes[i]`` elements, listed in
    the order they occur in *elems*.  Blocks of equal size are
    interchangeable, so each unordered family of blocks is produced exactly
    once: among blocks of the same size, a later block must have a larger
    leading element than the previous block of that size.  The elements of
    *elems* are assumed to be distinct.
    """

    def place(remaining, pending, last_lead):
        # last_lead maps a block size to the leading element of the most
        # recently placed block of that size.
        if not pending:
            yield []
            return
        size, later = pending[0], pending[1:]
        bound = last_lead.get(size)
        for combo in combinations(remaining, size):
            # Requiring the globally smallest element here would be wrong:
            # that element may belong to a block of a different size.
            if bound is not None and combo[0] <= bound:
                continue
            chosen = set(combo)
            left_over = [x for x in remaining if x not in chosen]
            for rest in place(left_over, later, {**last_lead, size: combo[0]}):
                yield [list(combo)] + rest

    yield from place(list(elems), list(sizes), {})


def partitions_general(n: int, k: int) -> list[list[list[int]]]:
    """Return the noncrossing partitions of 1..n with n-k+1 blocks.

    The non-singleton blocks have sizes ``p + 1`` for the parts ``p`` of an
    integer partition of ``k - 1``.  For each such size profile every
    supporting subset of 1..n and every split of it into blocks is tried, and
    the candidates that pass ``is_noncrossing`` are kept in generation order.
    """
    universe = range(1, n + 1)
    found: list[list[list[int]]] = []
    for excess in size_partitions(k - 1):
        block_sizes = [extra + 1 for extra in excess]
        for support in combinations(universe, sum(block_sizes)):
            candidates = (
                partition_from_blocks(n, chosen)
                for chosen in blocks_for_sizes(list(support), block_sizes)
            )
            found.extend(cand for cand in candidates if is_noncrossing(cand))
    return found


def load_qt_map(path: Path) -> tuple[dict[int, list[list[int]]], dict[int, dict[int, dict[int, int]]]]:
    """Load the q,t coefficient tables stored in the JSON file at *path*.

    The file must hold an object whose ``"data"`` list contains entries with
    an ``"n"`` value and a ``"terms"`` list of ``[q_exp, t_exp, coeff]``
    triples.

    Returns:
        A pair ``(qt_terms, qt_map)`` where ``qt_terms[n]`` is the raw term
        list of entry n and ``qt_map[n][q_exp][t_exp]`` is the matching
        coefficient.
    """
    # Read as UTF-8 explicitly (the encoding write_json uses) instead of
    # relying on the platform's locale default.
    data = json.loads(path.read_text(encoding="utf-8"))
    qt_terms: dict[int, list[list[int]]] = {}
    qt_map: dict[int, dict[int, dict[int, int]]] = {}
    for entry in data["data"]:
        n_val = entry["n"]
        terms = entry["terms"]
        qt_terms[n_val] = terms
        by_q: dict[int, dict[int, int]] = {}
        for q_exp, t_exp, coeff in terms:
            by_q.setdefault(q_exp, {})[t_exp] = coeff
        qt_map[n_val] = by_q
    return qt_terms, qt_map


def find_qt_file(n: int, k: int, directory: Path | None = None) -> Path:
    """Locate the smallest ``qt_narayana_n{N}_k{k}.json`` file with N >= n.

    Args:
        n: minimum value the N encoded in the file name must reach.
        k: value of k encoded in the file name.
        directory: directory to search; defaults to the current working
            directory.

    Returns:
        Path of the matching file with the smallest N; ties are broken by
        file name so the choice does not depend on directory listing order.

    Raises:
        FileNotFoundError: if no matching file has N >= n.
    """
    base = Path(".") if directory is None else Path(directory)
    candidates = []
    for path in base.glob(f"qt_narayana_n*_k{k}.json"):
        try:
            n_part = path.stem.split("_k")[0].split("qt_narayana_n")[1]
            max_n = int(n_part)
        except (IndexError, ValueError):
            # File name does not carry a parsable N; ignore it.
            continue
        if max_n >= n:
            candidates.append((max_n, path.name, path))
    if not candidates:
        raise FileNotFoundError(f"No qt_narayana_n*_k{k}.json with max n >= {n}.")
    return min(candidates, key=lambda item: item[:2])[2]


def distribution(qt_terms: dict[int, list[list[int]]], n_val: int, skip: int) -> dict[int, int]:
    """Return ``{t_exp: coeff}`` for the terms of *n_val* whose q exponent is *skip*.

    An *n_val* missing from *qt_terms* gives an empty mapping.
    """
    terms = qt_terms.get(n_val, [])
    return {t_exp: coeff for q_exp, t_exp, coeff in terms if q_exp == skip}


def refined_distribution(
    qt_terms: dict[int, list[list[int]]],
    qt_map: dict[int, dict[int, dict[int, int]]],
    n_val: int,
    skip: int,
    k: int,
) -> dict[int, int]:
    """Return the coefficients of *n_val* minus those of ``n_val - 1``.

    Only terms of *n_val* whose q exponent equals *skip* are considered.  The
    coefficient of ``n_val - 1`` with the same exponents is subtracted when
    ``n_val - 1 >= k`` (otherwise nothing is subtracted), and entries whose
    difference is zero are left out.
    """
    previous_n = n_val - 1
    # Coefficients to subtract, keyed by t exponent; empty below the k bound.
    baseline = qt_map.get(previous_n, {}).get(skip, {}) if previous_n >= k else {}
    result: dict[int, int] = {}
    for q_exp, t_exp, coeff in qt_terms.get(n_val, []):
        if q_exp == skip:
            delta = coeff - baseline.get(t_exp, 0)
            if delta:
                result[t_exp] = delta
    return result


def write_json(path: Path, families: list[dict[str, object]]) -> None:
    """Write *families* to *path* as indented JSON, one partition per line.

    Each family is emitted with the keys ``count``, ``distribution``, ``key``
    and ``partitions`` (in that order); every partition is wrapped in
    ``SingleLine`` so the encoder renders it on a single line.  The file is
    written as UTF-8 and ends with a newline.
    """
    entries = []
    for fam in families:
        entry = {
            "count": fam["count"],
            "distribution": fam["distribution"],
            "key": fam["key"],
            "partitions": [SingleLine(part) for part in fam["partitions"]],
        }
        entries.append(entry)
    text = json.dumps({"families": entries}, indent=2, cls=SingleLineJSONEncoder)
    path.write_text(text + "\n", encoding="utf-8")


def _render_distribution(dist: dict[int, int]) -> dict[str, int]:
    """Return *dist* with string keys, ordered by the keys' string form.

    The ordering is lexicographic on the strings (so "10" precedes "2").
    """
    return {str(t_exp): dist[t_exp] for t_exp in sorted(dist, key=str)}


def main() -> None:
    """Command-line entry point: build and write one skip_pairing dataset.

    Parses ``n``, ``k`` and ``--refined``, loads the matching q,t table,
    enumerates the partitions, groups them by skip statistic (and, when
    refined, by max non-singleton element as well) and writes the families to
    ``skip_pairing_n{n}_k{k}[_refined].json`` in the current directory.
    """
    parser = argparse.ArgumentParser(
        description="Generate skip_pairing datasets for fixed n,k."
    )
    parser.add_argument("n", type=int, help="n (must satisfy n >= k >= 1).")
    parser.add_argument("k", type=int, help="k (must satisfy n >= k >= 1).")
    parser.add_argument("--refined", action="store_true", help="Refine by max non-singleton.")
    args = parser.parse_args()

    if args.k < 1 or args.n < args.k:
        parser.error("Require n >= k >= 1.")
    if args.refined and args.k == 1:
        # With k = 1 the only partition consists of singletons, so the max
        # non-singleton element used for refinement does not exist.
        parser.error("--refined requires k >= 2 (for k = 1 every block is a singleton).")

    qt_path = find_qt_file(args.n, args.k)
    qt_terms, qt_map = load_qt_map(qt_path)

    partitions = partitions_general(args.n, args.k)

    families: list[dict[str, object]] = []
    if args.refined:
        refined_groups: dict[tuple[int, int], list[list[list[int]]]] = {}
        for partition in partitions:
            group_key = (max_non_singleton(partition), skip_stat(partition))
            refined_groups.setdefault(group_key, []).append(partition)
        # Tuple ordering sorts by max element first, then by skip value.
        for max_val, skip_val in sorted(refined_groups):
            parts = refined_groups[(max_val, skip_val)]
            dist = refined_distribution(qt_terms, qt_map, max_val, skip_val, args.k)
            families.append(
                {
                    "key": [args.k - 1, max_val, skip_val],
                    "count": len(parts),
                    "partitions": parts,
                    "distribution": _render_distribution(dist),
                }
            )
        suffix = "_refined"
    else:
        skip_groups: dict[int, list[list[list[int]]]] = {}
        for partition in partitions:
            skip_groups.setdefault(skip_stat(partition), []).append(partition)
        for skip_val in sorted(skip_groups):
            parts = skip_groups[skip_val]
            dist = distribution(qt_terms, args.n, skip_val)
            families.append(
                {
                    "key": [args.k - 1, skip_val],
                    "count": len(parts),
                    "partitions": parts,
                    "distribution": _render_distribution(dist),
                }
            )
        suffix = ""

    output = Path(f"skip_pairing_n{args.n}_k{args.k}{suffix}.json")
    write_json(output, families)
    print(f"Wrote {output}")


# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
