import csv
import math
from fractions import Fraction
from pathlib import Path
from typing import Dict, List, Optional, Tuple

def bytes_fedavg_round(param_count: int,
                       dtype_bytes: int = 4,
                       include_down: bool = True) -> Tuple[int, int, int]:
    """Return one client's FedAvg traffic for a single round.

    The client uploads its full dense model (``param_count * dtype_bytes``
    bytes) and, when ``include_down`` is truthy, downloads a dense model of
    the same size.

    Returns:
        ``(uplink_bytes_per_client, downlink_bytes_per_client,
        total_per_client)``.
    """
    model_bytes = param_count * dtype_bytes
    if not include_down:
        return model_bytes, 0, model_bytes
    return model_bytes, model_bytes, model_bytes + model_bytes


def bytes_topk_round(param_count: int,
                     k_frac: float,
                     dtype_bytes: int = 4,
                     index_bytes: int = 4,
                     include_down: bool = True,
                     downlink_dense: bool = True) -> Tuple[int, int, int]:
    """Per-round bytes for Top-k methods, per *client*.

    Uplink carries ``k = ceil(param_count * k_frac)`` values, each paired with
    one index, i.e. ``k * (dtype_bytes + index_bytes)`` bytes. Downlink is the
    dense model (``param_count * dtype_bytes``) when both ``include_down`` and
    ``downlink_dense`` are truthy, otherwise 0.

    Args:
        param_count: Number of model parameters.
        k_frac: Fraction of parameters sent on the uplink, in [0, 1].
        dtype_bytes: Bytes per transmitted value.
        index_bytes: Bytes per transmitted index.
        include_down: Whether to count downlink traffic at all.
        downlink_dense: Whether the downlink is a dense model.

    Returns:
        ``(uplink, downlink, total)`` in bytes.

    Raises:
        ValueError: If ``k_frac`` is not within [0, 1] (including NaN).
    """
    if not 0.0 <= k_frac <= 1.0:
        raise ValueError(f"k_frac must be in [0, 1], got {k_frac!r}")
    # Multiply exactly: a float product can land just above an integer
    # (100 * 0.07 == 7.000000000000001), which would make ceil() overcount k.
    k = math.ceil(param_count * Fraction(str(k_frac)))
    up = k * (dtype_bytes + index_bytes)
    down = param_count * dtype_bytes if (include_down and downlink_dense) else 0
    return up, down, up + down


def bytes_cser_round(param_count: int,
                     base_k_frac: float,
                     round_idx: int,
                     H: int,
                     avg_frac: float,
                     dtype_bytes: int = 4,
                     index_bytes: int = 4,
                     include_down: bool = True,
                     downlink_dense: bool = True) -> Tuple[int, int, int]:
    """CSER per-round bytes, per *client*.

    Every round uploads ``ceil(param_count * base_k_frac)`` value/index pairs.
    When ``H > 0`` and ``round_idx`` is a multiple of ``H`` (rounds H, 2H, ...
    for 1-based indices), an extra ``ceil(param_count * avg_frac)``
    value/index pairs of residual are uploaded as well. Downlink is the dense
    model (``param_count * dtype_bytes``) when both ``include_down`` and
    ``downlink_dense`` are truthy, otherwise 0.

    Returns:
        ``(uplink, downlink, total)`` in bytes.

    Raises:
        ValueError: If ``base_k_frac`` is outside [0, 1], or if ``H > 0`` and
            ``avg_frac`` is outside [0, 1].
    """
    if not 0.0 <= base_k_frac <= 1.0:
        raise ValueError(f"base_k_frac must be in [0, 1], got {base_k_frac!r}")
    if H > 0 and not 0.0 <= avg_frac <= 1.0:
        raise ValueError(f"avg_frac must be in [0, 1], got {avg_frac!r}")

    entry_bytes = dtype_bytes + index_bytes  # one value plus its index
    # Exact rational products: float math can land just above an integer
    # (100 * 0.07 == 7.000000000000001) and make ceil() overcount by one.
    up = math.ceil(param_count * Fraction(str(base_k_frac))) * entry_bytes
    if H > 0 and round_idx % H == 0:
        up += math.ceil(param_count * Fraction(str(avg_frac))) * entry_bytes
    down = param_count * dtype_bytes if (include_down and downlink_dense) else 0
    return up, down, up + down


def accumulate_comm(acc_list: List[float]) -> List[float]:
    """Return the running totals (prefix sums) of ``acc_list``.

    The running total starts at ``0.0``, so integer inputs come back as
    floats, e.g. ``[1, 2, 3] -> [1.0, 3.0, 6.0]``. An empty input gives ``[]``.
    """
    def _running_totals(values):
        # Yields the total after each element has been added.
        total = 0.0
        for value in values:
            total += value
            yield total

    return list(_running_totals(acc_list))


# Built-in parameter counts per model name; callers can add or override
# entries through the ``model_params`` argument below.
_DEFAULT_MODEL_PARAMS: Dict[str, int] = {
    "ResNet9": 6_575_370,
    "ResNet18": 12_556_426,
    "ResNet34": 21_797_672,
}


def load_accuracy_vs_commucated_GBytes(
    acc_by_round: Dict[str, List[float]],
    model_name: str,
    m: int,
    topk_fracs: Dict[str, float],
    CSER_H: int,
    CSER_avg_frac: float,
    dtype_bytes: int = 4,
    index_bytes: int = 4,
    include_downlink: bool = True,
    downlink_dense: bool = True,
    model_params: Optional[Dict[str, int]] = None,
) -> List[Dict]:
    """Pair each round's accuracy with the cumulative bytes communicated so far.

    Args:
        acc_by_round: Maps a method label to a list of accuracies; element
            ``i`` is reported as round ``i + 1``. The label's prefix before the
            first ``_``, upper-cased, selects the byte model: ``FEDAVG``,
            ``EF`` / ``SAEF`` / ``SAPEF`` (Top-k every round) or ``CSER``.
        model_name: Key into the parameter-count table (``ResNet9``,
            ``ResNet18``, ``ResNet34``, plus any keys in ``model_params``).
        m: Multiplier applied to the per-client totals for the
            ``*_selected`` columns (per client x m).
        topk_fracs: Top-k fraction per method. It is looked up by the full
            label first, then by the upper-cased prefix, then by the prefix as
            written (e.g. ``'Saef'`` for ``'Saef_top1'``).
        CSER_H: Period for CSER's extra residual upload (see
            ``bytes_cser_round``; non-positive disables it).
        CSER_avg_frac: Fraction of parameters in CSER's extra upload.
        dtype_bytes: Bytes per transmitted value.
        index_bytes: Bytes per transmitted sparse index.
        include_downlink: Whether downlink bytes are counted.
        downlink_dense: Whether Top-k/CSER downlink is counted as a dense model.
        model_params: Optional ``{model_name: param_count}`` entries merged
            over (and overriding) the built-in table.

    Returns:
        One dict per (method, round) holding the accuracy and the cumulative
        bytes per client and for ``m`` clients, in bytes, GB (1e9 bytes) and
        GiB (2**30 bytes).

    Raises:
        ValueError: Unknown model name, unknown method family, or no Top-k
            fraction for a Top-k/CSER method. Method errors are raised even
            when that method's accuracy list is empty.
    """
    param_table = dict(_DEFAULT_MODEL_PARAMS)
    if model_params:
        param_table.update(model_params)
    if model_name not in param_table:
        raise ValueError(f"Unknown model name: {model_name}")
    param_count = param_table[model_name]

    def base_name(method: str) -> str:
        # 'EF_top1' -> 'EF', 'SAEF' -> 'SAEF'
        return method.split('_', 1)[0].upper()

    def get_k(method: str) -> float:
        # Exact label first, then the upper-cased family name, then the family
        # name exactly as written; earlier keys keep precedence.
        raw_base = method.split('_', 1)[0]
        for key in (method, raw_base.upper(), raw_base):
            if key in topk_fracs:
                return float(topk_fracs[key])
        raise ValueError(f"No top-k fraction provided for method '{method}'")

    def per_round_bytes(method: str, n_rounds: int) -> List[int]:
        # Total (uplink + downlink) bytes per client for rounds 1..n_rounds.
        kind = base_name(method)
        if kind == "FEDAVG":
            # Round-independent, so compute once and repeat.
            _, _, tot = bytes_fedavg_round(
                param_count, dtype_bytes, include_downlink
            )
            return [tot] * n_rounds
        if kind in {"EF", "SAEF", "SAPEF"}:
            # Round-independent, so compute once and repeat.
            _, _, tot = bytes_topk_round(
                param_count, get_k(method), dtype_bytes, index_bytes,
                include_downlink, downlink_dense
            )
            return [tot] * n_rounds
        if kind == "CSER":
            # Depends on the 1-based round index (periodic extra upload).
            k_frac = get_k(method)
            return [
                bytes_cser_round(
                    param_count, k_frac, r, CSER_H, CSER_avg_frac,
                    dtype_bytes, index_bytes, include_downlink, downlink_dense
                )[2]
                for r in range(1, n_rounds + 1)
            ]
        raise ValueError(f"Unknown method: {method}")

    rows: List[Dict] = []
    for method, acc_vec in acc_by_round.items():
        cum_per_client = accumulate_comm(per_round_bytes(method, len(acc_vec)))
        for r, (acc, cum_client) in enumerate(zip(acc_vec, cum_per_client), start=1):
            # System-wide volume = per-client volume x m.
            cum_selected = cum_client * m
            rows.append({
                "model": model_name,
                "method": method,
                "round": r,
                "acc": acc,
                "cum_bytes_per_client": cum_client,
                "cum_GB_per_client":    cum_client / 1e9,
                "cum_GiB_per_client":   cum_client / (1024**3),
                "cum_bytes_selected":   cum_selected,
                "cum_GB_selected":      cum_selected / 1e9,
                "cum_GiB_selected":     cum_selected / (1024**3),
            })
    return rows
