from dataclasses import dataclass, field
from conformal_fairness.config import (
    BaseExptConfig,
    ConfExptConfig,
    ConfFairExptConfig,
    DiffusionConfig,
    RegularizedConfig,
)
from typing import List, Union, Optional
from fed_constants import FOLKTABLES_ALL, FOLKTABLES_CONTINENTAL_ALL, CommFormulations


@dataclass
class FixedRegularizedConfig(RegularizedConfig):
    """RegularizedConfig whose RAPS defaults are fixed at raps_k=3, raps_lambda=0.1."""

    raps_k: int = 3
    raps_lambda: float = 0.1


@dataclass
class FixedDiffusionConfig(DiffusionConfig):
    """DiffusionConfig whose DAPS default is fixed at daps_lambda=0.2."""

    daps_lambda: float = 0.2


@dataclass
class FedBaseExptConfig(BaseExptConfig):
    """Base experiment config extended with federated-learning settings.

    After the parent initialization, ``num_clients`` is overridden when
    ``folktables_partition_type`` identifies a known Folktables partitioning:
    any type containing "small" -> 4, containing "large" -> 8, equal to
    FOLKTABLES_CONTINENTAL_ALL -> 48, equal to FOLKTABLES_ALL -> 51.
    Otherwise the configured ``num_clients`` is left untouched.
    """

    num_server_rounds: int = 3
    fraction_fit: float = 0.1
    num_clients: int = 10
    # Partition scheme identifier (a string); empty means "no override".
    folktables_partition_type: str = ""

    def __post_init__(self):
        super().__post_init__()
        partition = self.folktables_partition_type

        # Substring markers are checked first, so they take precedence over
        # the exact-name matches below.
        for marker, client_count in (("small", 4), ("large", 8)):
            if marker in partition:
                self.num_clients = client_count
                return

        for partition_name, client_count in (
            (FOLKTABLES_CONTINENTAL_ALL, 48),
            (FOLKTABLES_ALL, 51),
        ):
            if partition == partition_name:
                self.num_clients = client_count
                return


@dataclass
class FedConfExptConfig(ConfExptConfig):
    """Conformal experiment config extended with federated settings.

    The RAPS and DAPS sub-configs default to the ``Fixed*`` variants, so their
    hyperparameters start from fixed values. After the parent initialization,
    ``num_clients`` is overridden from ``folktables_partition_type``:
    containing "small" -> 4, containing "large" -> 8, equal to
    FOLKTABLES_CONTINENTAL_ALL -> 48, equal to FOLKTABLES_ALL -> 51.
    """

    num_clients: int = 10
    folktables_partition_type: str = ""

    quantile_method: str = "ddsketch"
    debug_mode: bool = False

    # Defaults to the fixed-parameter RAPS config.
    regularization_config: Optional[FixedRegularizedConfig] = field(
        default_factory=FixedRegularizedConfig
    )

    # Defaults to the fixed-parameter DAPS config.
    diffusion_config: Optional[FixedDiffusionConfig] = field(
        default_factory=FixedDiffusionConfig
    )

    def __post_init__(self):
        super().__post_init__()
        partition = self.folktables_partition_type

        # Substring markers win over the exact-name matches.
        for marker, client_count in (("small", 4), ("large", 8)):
            if marker in partition:
                self.num_clients = client_count
                return

        for partition_name, client_count in (
            (FOLKTABLES_CONTINENTAL_ALL, 48),
            (FOLKTABLES_ALL, 51),
        ):
            if partition == partition_name:
                self.num_clients = client_count
                return

    # TODO: consider validating that, when the conformal method is RAPS,
    # ``regularization_config`` is an instance of FixedRegularizedConfig.


@dataclass
class CFOptConfig:
    """Optimization hyperparameters for the CF framework."""

    num_opt_rounds: int = 25
    lr: float = 0.01  # presumably the optimizer learning rate
    momentum: float = 0.5
    use_mle: bool = False


@dataclass
class FedConfFairExptConfig(FedConfExptConfig, ConfFairExptConfig):
    """Federated conformal-fairness experiment configuration.

    Combines FedConfExptConfig with ConfFairExptConfig and adds per-client
    communication formulations plus the CF optimization hyperparameters.
    """

    # Either a single formulation value applied to every client, or one value
    # per client. It is normalized to a per-client list in __post_init__.
    client_formulations: Union[List[int], int] = field(
        default=CommFormulations.LOW_OVERHEAD.value
    )
    cf_opt: CFOptConfig = field(default_factory=CFOptConfig)

    def __post_init__(self):
        """Initialize parents, then broadcast or validate ``client_formulations``.

        Raises:
            ValueError: if a sequence of formulations is given whose length
                differs from ``num_clients``.
        """
        # The parent __post_init__ may override num_clients based on the
        # partition type, so normalize only after it has run.
        super().__post_init__()
        if isinstance(self.client_formulations, int):
            self.client_formulations = [self.client_formulations] * self.num_clients
        elif len(self.client_formulations) != self.num_clients:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"client_formulations has {len(self.client_formulations)} "
                f"entries but num_clients is {self.num_clients}; provide one "
                "formulation per client or a single int for all clients."
            )
