# src/grouping/dynamic_grouper.py
import random
from typing import Callable, Dict, List, Optional

import numpy as np

class DynamicGrouper:
    """
    Compute per-client scores (e.g. Owen value) from the last round's metrics
    and convert them into group assignments for the next round.

    Scoring and grouping are pluggable through ``score_fn`` and ``group_fn``;
    when they are omitted, the defaults defined on this class are used
    (accuracy-based scores and quantile grouping).  Until the first call to
    ``on_round_end`` every client ``"0" .. str(n_clients - 1)`` is in group 0.
    """

    # --------------------------------------------------------------------- #
    # constructor
    # --------------------------------------------------------------------- #
    def __init__(
        self,
        n_clients: int,
        n_groups: int,
        score_fn: Optional[
            Callable[[List[Dict[str, float]]], Dict[str, float]]
        ] = None,                  # fn(list[metric_dict]) -> {cid: score}
        group_fn: Optional[
            Callable[[Dict[str, float], int], Dict[str, int]]
        ] = None,                  # fn({cid: score}, n_groups) -> {cid: grp}
        random_regroup: bool = False,         # if True: ignore scores, re-shuffle evenly
        rng: Optional[random.Random] = None,  # allow deterministic shuffles via seed
    ):
        """
        Args:
            n_clients: Number of clients; ids are the strings "0".."n_clients-1"
                for the initial map and for random regrouping.
            n_groups: Number of groups to spread the clients over.
            score_fn: Optional callback mapping the list of per-client metric
                dicts to ``{cid: score}``.  Defaults to ``_dummy_score_fn``.
            group_fn: Optional callback mapping ``({cid: score}, n_groups)`` to
                ``{cid: group}``.  Defaults to ``_quantile_grouping``.
            random_regroup: When True, ``on_round_end`` ignores the metrics and
                draws a balanced random assignment instead.
            rng: Random generator used for the shuffles; pass a seeded
                ``random.Random`` for reproducible assignments.
        """
        self.n_clients = n_clients
        self.n_groups = n_groups
        self.random_regroup = random_regroup
        # Explicit None check so the fallback never depends on truthiness.
        self.rng = rng if rng is not None else random.Random()

        # plug-in callbacks (fallback to defaults defined below)
        self.score_fn = score_fn if score_fn is not None else self._dummy_score_fn
        self.group_fn = group_fn if group_fn is not None else self._quantile_grouping

        # internal state
        self.latest_scores: Dict[str, float] = {}
        self.latest_groups: Dict[str, int] = {
            str(cid): 0 for cid in range(n_clients)
        }

    # --------------------------------------------------------------------- #
    # public API
    # --------------------------------------------------------------------- #
    def on_round_end(
        self,
        rnd: int,
        client_metrics: List[Dict[str, float]],
    ) -> None:
        """
        Call after `server.fit_round`.  Updates self.latest_groups so that the
        next round can query `get_group_map()`.

        Args:
            rnd: Round number.  Currently unused by this method.
            client_metrics: One metric dict per reporting client; passed
                unchanged to ``score_fn``.  Ignored when ``random_regroup``
                is set.

        With the default callbacks the resulting map only contains the
        clients that appear in ``client_metrics``; an empty list yields an
        empty map.
        """
        if self.random_regroup:
            self.latest_groups = self._balanced_random_groups()
            return

        # 1. compute scores (Owen value or whatever)
        self.latest_scores = self.score_fn(client_metrics)

        # 2. map scores -> groups
        self.latest_groups = self.group_fn(self.latest_scores, self.n_groups)

    def get_group_map(self) -> Dict[str, int]:
        """
        Return {cid: group_id} for the next round.

        The internal dict itself is returned (not a copy), so mutations by
        the caller are visible to this object.
        """
        return self.latest_groups

    # --------------------------------------------------------------------- #
    # default callbacks
    # --------------------------------------------------------------------- #
    @staticmethod
    def _dummy_score_fn(client_metrics: List[Dict[str, float]]) -> Dict[str, float]:
        """
        Score each client by its reported "accuracy", or 0.5 when absent.

        The key is whatever the metric dict holds under "cid"; a dict without
        that key is stored under ``None``, and later entries with the same
        key overwrite earlier ones.
        """
        return {m.get("cid"): m.get("accuracy", 0.5) for m in client_metrics}

    @staticmethod
    def _quantile_grouping(scores: Dict[str, float], n_groups: int) -> Dict[str, int]:
        """
        Higher score -> lower group index (0 is best).  Groups are filled by
        quantiles, so sizes are approximately balanced.

        Ties are broken by the iteration order of ``scores`` (earlier entries
        rank better).  An empty ``scores`` dict yields an empty mapping.
        """
        if not scores:
            # Nothing to rank; unpacking/argsort below would fail on no data.
            return {}

        cids = list(scores)
        vals = np.array(list(scores.values()))

        # Rank 0 = highest score.  A stable sort keeps equal scores in input
        # order, so the assignment does not depend on the sort implementation.
        order = np.argsort(-vals, kind="stable")
        ranks = np.argsort(order, kind="stable")

        # Integer floor division maps rank r to bucket floor(r * n_groups / N).
        quantiles = ranks * n_groups // len(cids)
        return {cid: int(q) for cid, q in zip(cids, quantiles)}

    # --------------------------------------------------------------------- #
    # helper: balanced random shuffle
    # --------------------------------------------------------------------- #
    def _balanced_random_groups(self) -> Dict[str, int]:
        """
        Return a uniformly random assignment where each group has either
        floor(n_clients / n_groups) or ceil(n_clients / n_groups) members.

        Client ids are the strings "0".."n_clients-1", shuffled with
        ``self.rng`` and then dealt out to the groups in contiguous runs.
        """
        cids = [str(cid) for cid in range(self.n_clients)]
        self.rng.shuffle(cids)

        # The first `extra` groups receive one member more than `base`.
        base, extra = divmod(self.n_clients, self.n_groups)

        groups: Dict[str, int] = {}
        start = 0
        for g in range(self.n_groups):
            size = base + (1 if g < extra else 0)
            for cid in cids[start:start + size]:
                groups[cid] = g
            start += size
        return groups

# ---------------------------------------------------------------------------
# Example: a *dummy* Owen-value scorer that returns random scores
# (swap into DynamicGrouper via score_fn=owen_value_score_fn)
# ---------------------------------------------------------------------------
def owen_value_score_fn(client_metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """
    Placeholder for a real Owen-value computation.

    Every metric dict is keyed by its "cid" entry (``None`` when missing) and
    given a pseudo-random score in [0, 1).  The generator is re-created with
    the same seed on every call, so identical inputs give identical scores.
    """
    generator = random.Random(23)  # fixed seed → reproducible
    scores = {}
    for metric in client_metrics:
        cid = metric.get("cid")
        scores[cid] = generator.random()
    return scores
