import math
from dataclasses import dataclass
from typing import Optional, override

import numpy as np
import rustworkx as rx

from .. import DatasetConfig, register_dataset
from .simple_synthetic import SimpleSyntheticDataset, SimpleSyntheticDatasetConfig


@dataclass
class SyntheticGraphDatasetConfig(DatasetConfig):
    """Configuration for the synthetic complete-graph dataset.

    One output is produced per undirected edge of a complete graph, so the two
    size fields are linked by ``output_dim == n_nodes * (n_nodes - 1) // 2``.
    At least one of ``n_nodes`` / ``output_dim`` must be set; the dataset
    derives the other (and checks consistency when both are given).
    """

    test_size: int
    predictor_size: int
    calibrator_size: int
    input_dim: int
    poly_degree: int
    # Number of graph nodes; derived from ``output_dim`` when left as None.
    n_nodes: Optional[int] = None
    # Number of edges (= number of outputs); derived from ``n_nodes`` when None.
    output_dim: Optional[int] = None


class SyntheticGraphDataset(SimpleSyntheticDataset):
    """Synthetic dataset whose outputs are weights of a complete graph's edges.

    Each of the ``output_dim`` outputs corresponds to one undirected edge
    ``(u, v)`` with ``u < v`` of a complete graph on ``n_nodes`` nodes. Edges
    are enumerated in upper-triangular row-major order:
    ``(0, 1), (0, 2), ..., (0, n-1), (1, 2), ...``. Decisions are made by a
    maximum-weight matching over the predicted edge weights.
    """

    # rustworkx.max_weight_matching consumes integer weights (via weight_fn),
    # so float scores are multiplied by this factor and truncated with int().
    _WEIGHT_SCALE = 1e6

    def __init__(
        self, config: SyntheticGraphDatasetConfig, params: Optional[dict] = None
    ):
        """Build the dataset, deriving whichever of n_nodes/output_dim is missing.

        Args:
            config: Dataset configuration. At least one of ``n_nodes`` and
                ``output_dim`` must be set; if both are set they must agree.
            params: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If neither size is given, the two sizes disagree,
                ``n_nodes`` is not positive, or ``output_dim`` is not of the
                form ``n * (n - 1) / 2`` for a positive integer ``n``.
        """
        # NOTE(review): config is assigned both before and after
        # super().__init__ — presumably so the base initializer can see it and
        # any base-class reassignment is undone afterwards; confirm before
        # removing either assignment.
        self.config = config

        n_nodes, output_dim = self._resolve_graph_size(
            config.n_nodes, config.output_dim
        )

        synthetic_config = SimpleSyntheticDatasetConfig(
            type=config.type,
            test_size=config.test_size,
            predictor_size=config.predictor_size,
            calibrator_size=config.calibrator_size,
            input_dim=config.input_dim,
            poly_degree=config.poly_degree,
            output_dim=output_dim,
        )
        super().__init__(synthetic_config)
        self.config = config
        self.n_nodes = n_nodes
        self.top_k = None  # Graph datasets use max-weight matching, not top-k

        # Output column index -> (u, v) endpoints, precomputed once so graph
        # construction in decision_function does no index arithmetic.
        self._edge_endpoints: list[tuple[int, int]] = [
            (u, v) for u in range(n_nodes) for v in range(u + 1, n_nodes)
        ]

    @staticmethod
    def _resolve_graph_size(
        n_nodes: Optional[int], output_dim: Optional[int]
    ) -> tuple[int, int]:
        """Return a consistent ``(n_nodes, output_dim)`` pair.

        Raises:
            ValueError: On missing, inconsistent, or impossible sizes.
        """
        if n_nodes is None:
            if output_dim is None:
                raise ValueError(
                    "Either n_nodes or output_dim must be provided. Available options: n_nodes, output_dim"
                )
            # Solve n * (n - 1) / 2 == output_dim, i.e.
            # n = (1 + sqrt(1 + 8 * output_dim)) / 2, with an exact integer
            # square root so large values are not misjudged by float rounding.
            if output_dim < 0:
                raise ValueError(
                    f"output_dim {output_dim} does not correspond to a valid number of nodes"
                )
            discriminant = 1 + 8 * output_dim
            root = math.isqrt(discriminant)
            if root * root != discriminant:
                raise ValueError(
                    f"output_dim {output_dim} does not correspond to a valid number of nodes"
                )
            # discriminant is odd, so a perfect-square root is odd and
            # (1 + root) is even: the division is exact.
            return (1 + root) // 2, output_dim

        if n_nodes < 1:
            raise ValueError(f"n_nodes must be a positive integer, got {n_nodes}")
        expected_output_dim = n_nodes * (n_nodes - 1) // 2
        if output_dim is not None and output_dim != expected_output_dim:
            raise ValueError(
                f"output_dim {output_dim} does not match n_nodes {n_nodes}; expected {expected_output_dim}"
            )
        return n_nodes, expected_output_dim

    def scale(self) -> float:
        """Return ``n_nodes / 2`` as a float.

        NOTE(review): presumably a normalizer for matching-based quantities;
        any matching on ``n_nodes`` nodes has at most ``n_nodes // 2`` edges.
        """
        return float(self.n_nodes / 2)

    @override
    def decision_function(self, y_pred: np.ndarray) -> np.ndarray:
        """Apply maximum weight matching to graph predictions.

        Args:
            y_pred: Array of shape ``(..., n_nodes * (n_nodes - 1) / 2)``
                holding one predicted weight per edge, in the column order of
                ``self._edge_endpoints``.

        Returns:
            Array with the same shape as ``y_pred`` (and the dtype produced by
            ``np.zeros_like``) containing 1 for edges selected by the matching
            of each row and 0 elsewhere.

        Raises:
            ValueError: If ``y_pred`` is a scalar or its last dimension does
                not equal the number of edges.
        """
        y_pred = np.asarray(y_pred)
        n_edges = len(self._edge_endpoints)
        if y_pred.ndim == 0 or y_pred.shape[-1] != n_edges:
            raise ValueError(
                f"Expected last dimension of size {n_edges} for a graph with "
                f"{self.n_nodes} nodes, got shape {y_pred.shape}"
            )

        # Explicit row count (instead of -1) keeps the reshape well-defined
        # even when there are zero rows or zero edges.
        n_rows = math.prod(y_pred.shape[:-1])
        y_flat = y_pred.reshape(n_rows, n_edges)

        ohe = np.zeros_like(y_flat)
        for row, scores in enumerate(y_flat):
            for edge_idx in self._matched_edge_indices(scores):
                ohe[row, edge_idx] = 1.0

        return ohe.reshape(y_pred.shape)

    def _matched_edge_indices(self, scores: np.ndarray) -> list[int]:
        """Return output column indices of the max-weight matching for one row."""
        n_nodes = self.n_nodes
        weight_scale = self._WEIGHT_SCALE

        graph = rx.PyGraph()
        graph.add_nodes_from(range(n_nodes))
        graph.extend_from_weighted_edge_list(
            [(u, v, float(w)) for (u, v), w in zip(self._edge_endpoints, scores)]
        )

        matching = rx.max_weight_matching(
            graph, max_cardinality=False, weight_fn=lambda w: int(w * weight_scale)
        )

        indices = []
        for u, v in matching:
            if u > v:
                u, v = v, u
            # Inverse of the upper-triangular row-major enumeration: rows
            # 0..u-1 contribute sum_{k<u} (n - 1 - k) edges before row u.
            indices.append(u * (2 * n_nodes - u - 1) // 2 + (v - u - 1))
        return indices


# Public names exported by this module.
__all__ = [
    "SyntheticGraphDatasetConfig",
    "SyntheticGraphDataset",
]

# Register the config/dataset pair under the name "synthetic_graph" with the
# parent package's dataset registry (register_dataset, imported above).
register_dataset("synthetic_graph", SyntheticGraphDatasetConfig, SyntheticGraphDataset)
