import logging
from typing import List, Dict, Optional


import torch
import pickle
from ddsketch import DDSketch
from tdigest import TDigest
from flwr.client import NumPyClient
from flwr.common import Context, ConfigRecord, ArrayRecord, Array, Scalar
from lightning import seed_everything


from conformal_fairness.data.base_datamodule import BaseDataModule
from conformal_fairness.utils import get_label_scores
from conformal_fairness.constants import Stage, FairnessMetric
from conformal_fairness.cp_methods.scores import CPScore

from fed_config import FedConfFairExptConfig
from base_fed_model.task import load_data, load_acs_data
from folktables_datamodule import FolktablesDataModule
from pokec_datamodule import PokecDataModule
from fed_constants import (
    CLIENT_ID,
    CommFormulations,
    QuantileMethod,
    SpecialParameters,
    FedCPFitConstants,
    FedCFConstants,
    FedConfFairStage,
    STATE_CONFIGS,
    STATE_PARAMS,
)
from fed_utils import FixedDAPSScore, get_filter_mask


class CPFedClient(NumPyClient):
    def __init__(
        self,
        context: Context,
        datamodule: BaseDataModule,
        calib_probs: torch.Tensor,
        partition_id: int,
        num_partitions: int,
        client_formulation: int = CommFormulations.MORE_PRIVATE.value,
        use_mle: bool = False,
        all_probs: Optional[torch.Tensor] = None,
    ):
        """
        Set up a client for one data partition.

        context - Flower context; its `state` is kept as `client_state` so values survive between rounds
        datamodule - datamodule holding this client's data and its split indices
        calib_probs - probabilities of the calibration points only (first dim must equal the calibration split size)
        partition_id / num_partitions - identity of this client among all clients
        client_formulation - a CommFormulations value selecting what `fed_cf_local_cg` sends back
        use_mle - when True the estimates divide by n_k instead of n_k + 1
        all_probs - probabilities for every point of the partition; read by `fed_cp_fit`
            only when the score module is a FixedDAPSScore
        """
        # Identity of this client
        self.partition_id: int = partition_id
        self.num_partitions: int = num_partitions

        # Data held by this client
        self.datamodule: BaseDataModule = datamodule
        self.calib_probs = calib_probs
        self.all_probs = all_probs

        # Behaviour switches
        self.client_formulation = client_formulation
        self.use_mle = use_mle

        # number of calibration points
        calib_split = self.datamodule.split_dict[Stage.CALIBRATION]
        self.n_k: int = int(calib_split.shape[0])

        # The probs must cover exactly the calibration points, nothing more
        assert (
            self.n_k == self.calib_probs.shape[0]
        ), f"The number of client calibration points {self.n_k} is different than the number of calibration probs {self.calib_probs.shape[0]}"

        # The context state lets the client be stateful (i.e., retain information from previous rounds).
        # Create the two records used by the properties below if they are not there yet.
        state = context.state
        self.client_state = state
        if STATE_CONFIGS not in state.config_records:
            state.config_records[STATE_CONFIGS] = ConfigRecord()
        if STATE_PARAMS not in state.parameters_records:
            state.parameters_records[STATE_PARAMS] = ArrayRecord()

    # The following getters and setters make it easier to store/save the stateful values
    @property
    def calib_labels(self):
        """Entries of `datamodule.y` selected by the calibration split."""
        calib_split = self.datamodule.split_dict[Stage.CALIBRATION]
        return self.datamodule.y[calib_split]

    @property
    def calib_groups(self):
        """Entries of `datamodule.sens` selected by the calibration split."""
        calib_split = self.datamodule.split_dict[Stage.CALIBRATION]
        return self.datamodule.sens[calib_split]

    @property
    def fairness_metric(self):
        """Fairness metric stored in the client state, or None if it was never set."""
        stored_configs = self.client_state.config_records[STATE_CONFIGS]
        return stored_configs["f_metric"] if "f_metric" in stored_configs else None

    @fairness_metric.setter
    def fairness_metric(self, value):
        """Save the fairness metric in the client state under "f_metric"."""
        stored_configs = self.client_state.config_records[STATE_CONFIGS]
        stored_configs["f_metric"] = value

    def _load_state_tensor(self, key):
        """Return the array stored under `key` in the client state as a tensor, or None if absent."""
        stored_params = self.client_state.parameters_records[STATE_PARAMS]
        if key not in stored_params:
            return None
        return torch.from_numpy(stored_params[key].numpy())

    def _store_state_tensor(self, key, value):
        """Store tensor `value` under `key` in the client state (detached, on CPU, wrapped in an Array)."""
        self.client_state.parameters_records[STATE_PARAMS][key] = Array(
            value.detach().cpu().numpy()
        )

    # Every property below is backed by the client state: the getter returns None
    # until a value has been stored, and each read converts the stored array back
    # into a tensor (so hot loops should read a property once into a local).

    @property
    def _cached_scores(self):
        """Scores computed in `fed_cp_fit`, stored under "cached_scores"."""
        return self._load_state_tensor("cached_scores")

    @_cached_scores.setter
    def _cached_scores(self, value):
        self._store_state_tensor("cached_scores", value)

    @property
    def positive_labels(self):
        """Currently active positive labels, stored under "pos_labels"."""
        return self._load_state_tensor("pos_labels")

    @positive_labels.setter
    def positive_labels(self, value):
        self._store_state_tensor("pos_labels", value)

    @property
    def prior(self):
        """Tensor stored under "prior"."""
        return self._load_state_tensor("prior")

    @prior.setter
    def prior(self, value):
        self._store_state_tensor("prior", value)

    @property
    def U(self):
        """Tensor stored under "U"; `fed_cf_local_cg` divides the lower estimate by it."""
        return self._load_state_tensor("U")

    @U.setter
    def U(self, value):
        self._store_state_tensor("U", value)

    @property
    def L(self):
        """Tensor stored under "L"; `fed_cf_local_cg` divides the upper estimate by it."""
        return self._load_state_tensor("L")

    @L.setter
    def L(self, value):
        self._store_state_tensor("L", value)

    @property
    def prior_2(self):
        """Tensor stored under "prior_2"."""
        return self._load_state_tensor("prior_2")

    @prior_2.setter
    def prior_2(self, value):
        self._store_state_tensor("prior_2", value)

    @property
    def U_2(self):
        """Counterpart of `U` for the second (equalized-odds only) set of estimates, stored under "U_2"."""
        return self._load_state_tensor("U_2")

    @U_2.setter
    def U_2(self, value):
        self._store_state_tensor("U_2", value)

    @property
    def L_2(self):
        """Counterpart of `L` for the second (equalized-odds only) set of estimates, stored under "L_2"."""
        return self._load_state_tensor("L_2")

    @L_2.setter
    def L_2(self, value):
        self._store_state_tensor("L_2", value)

    def fit(self, parameters, config):
        """
        Run the client-side step of the stage requested by the server.

        No model is trained here; `fit` is only the entry point through which
        the server triggers one of the FedCP/FedCF computations.

        parameters - arrays sent by the server, forwarded to the stage handler
        config - must contain the stage under `FedConfFairStage.key`; for the CF_ITER
            stage it may also contain `SpecialParameters.key`, in which case the
            client state is updated before the iteration runs
        Returns the handler's (arrays, num_examples, metrics) with the client id
        added to the metrics.
        Raises ValueError if the stage is valid but has no handler in this client.
        """
        assert FedConfFairStage.key in config, "Need to know the fed_cf stage"

        fed_cf_stage = config[FedConfFairStage.key]
        assert (
            fed_cf_stage in FedConfFairStage.values
        ), f"Stage, {fed_cf_stage}, is not a valid stage"

        # The stage value (sent by the server config) indicates which stage of FedCP/FedCF we are in
        if fed_cf_stage == FedConfFairStage.FED_CP:
            p, n, m = self.fed_cp_fit(parameters, config)

        elif fed_cf_stage == FedConfFairStage.POP_STAT:
            # Assume fed_cp has already been run
            p, n, m = self.fed_cf_pop_stat(parameters, config)

        elif fed_cf_stage == FedConfFairStage.CF_ITER:
            if SpecialParameters.key in config:
                # The server attached an extra update (U/L values or new positive labels)
                self.update_from_special_parameters(parameters, config)

            p, n, m = self.fed_cf_local_cg(parameters, config)

        else:
            # Without this branch an unhandled stage would only surface as an
            # UnboundLocalError on the lines below.
            raise ValueError(
                f"Stage, {fed_cf_stage}, is not handled by the client"
            )

        # Adding client ID to metrics for logging later
        m[CLIENT_ID] = self.partition_id
        return p, n, m

    def fed_cp_fit(self, parameters, config: Dict[str, Scalar]):
        """
        Computes the scores of the calibration points and the client sketch
        config - must contain the quantile method (`QuantileMethod.key`) and the
            pickled score module (`FedCPFitConstants.SCORE_MODULE`)
        parameters - dummy variable to match function signature
        Returns an empty parameter list, the number of calibration points and a
        metrics dict holding the sketch under `FedCPFitConstants.QUANT_SKETCH`.
        Raises ValueError for an unsupported quantile method.
        """
        assert (
            QuantileMethod.key in config
        ), "Expected quantile_method in configs"
        assert (
            FedCPFitConstants.SCORE_MODULE in config
        ), "Expected score_module in configs"

        # NOTE(review): unpickling runs whatever the payload contains, so this is
        # only safe while the server sending the config is trusted.
        score_module: CPScore = pickle.loads(config[FedCPFitConstants.SCORE_MODULE])

        if not isinstance(score_module, FixedDAPSScore):
            scores = score_module.pipe_compute(self.calib_probs)
        else:
            # FixedDAPSScore is computed from all probs, then cut down to the calibration split
            scores_all_points = score_module.compute(
                self.all_probs, datamodule=self.datamodule
            )
            scores = scores_all_points[self.datamodule.split_dict[Stage.CALIBRATION]]
        self._cached_scores = scores

        # The mask selects everything since the cached scores only cover the calibration points
        select_all = torch.ones(self.n_k, dtype=torch.bool)
        label_scores = get_label_scores(
            labels=self.calib_labels,
            scores=self._cached_scores,
            mask=select_all,
            dataset=self.datamodule.name,
        )

        # This code is based upon the github repo provided by FedCP.
        # Modified from https://github.com/clu5/federated-conformal/blob/main/src/conformal.py
        #
        # Build the requested sketch. The sketch objects are pickled because they
        # cannot be sent over directly.
        quantile_method = config[QuantileMethod.key]
        if quantile_method == QuantileMethod.TDIGEST:
            digest = TDigest()
            digest.batch_update(label_scores.numpy())
            sketch = pickle.dumps(digest)

        elif quantile_method == QuantileMethod.DDSKETCH:
            dd_sketch = DDSketch()
            for value in label_scores.tolist():
                dd_sketch.add(value)
            sketch = pickle.dumps(dd_sketch)

        elif quantile_method == QuantileMethod.MEAN:
            # No sketch here: the local quantile itself is sent as a float
            sketch = float(score_module.compute_quantile(label_scores))

        else:
            raise ValueError(f"{quantile_method} not supported")

        # The sketch travels as a metric (a dict entry) since it is not a parameter
        return [], self.n_k, {FedCPFitConstants.QUANT_SKETCH: sketch}

    def fed_cf_pop_stat(self, parameters, config: Dict[str, Scalar]):
        """
        Computes the population statistics w.r.t the fairness metric
        It is used as a prior in our computations.
        Here we only compute the lower bound. The server has enough information to compute the upper bound from the lower
        config - needs to have the fairness metric (`FedCFConstants.F_M`)
        parameters - should only have one tensor which is the positive labels (assuming 0 indexing)

        Both the fairness metric and the positive labels are saved in the client
        state for the later stages.

        Returns (arrays, num_examples, metrics): `arrays` holds one
        (num_sensitive_groups, num_positive_labels) array, followed by a second
        one of the same shape when the metric is equalized odds; `metrics`
        carries the client formulation.
        """
        assert (
            FedCFConstants.F_M in config
        ), "The fairness metric needs to be provided to the client to compute population stats."

        self.fairness_metric = config[FedCFConstants.F_M]
        self.positive_labels = torch.from_numpy(parameters[0])

        # Read the state-backed values once: every property access converts the
        # stored record again, which is wasteful inside the loops below.
        positive_labels = self.positive_labels
        is_eq_odds = self.fairness_metric == FairnessMetric.EQUALIZED_ODDS.value
        num_groups = self.datamodule.num_sensitive_groups
        stat_shape = (num_groups, positive_labels.shape[0])

        # With use_mle the count is divided by n_k, otherwise by n_k + 1
        denom = self.n_k if self.use_mle else self.n_k + 1

        # Empty matrices to store pop stats computations
        prior = torch.zeros(stat_shape)
        prior_pe = torch.zeros(stat_shape) if is_eq_odds else None

        for i, pos in enumerate(positive_labels):
            for g_id in range(num_groups):
                f_m = self._get_filter_mask(pos, g_id)
                if is_eq_odds:
                    # for equalized odds prior is equal opportunity
                    # and prior_pe for predictive equality
                    eo_mask, pe_mask = f_m
                    # Tensor-level sum instead of the builtin, which would
                    # iterate the mask element by element in Python.
                    prior[g_id, i] = torch.as_tensor(eo_mask).sum() / denom
                    prior_pe[g_id, i] = torch.as_tensor(pe_mask).sum() / denom
                else:
                    prior[g_id, i] = torch.as_tensor(f_m).sum() / denom

        nd_array = [prior.cpu().detach().numpy()]
        if is_eq_odds:
            nd_array.append(prior_pe.cpu().detach().numpy())

        num_examples = self.n_k
        metrics = {FedCFConstants.CLIENT_FORM: self.client_formulation}

        return nd_array, num_examples, metrics

    def fed_cf_local_cg(self, parameters, config):
        """
        Compute this client's lower/upper estimates for every (group, positive label) pair.

        parameters - parameters[0] is the new set of lambdas; a score counts as
            "below" when it is <= the lambda it is compared against
        config - not read here
        Requires the cached scores, positive labels and fairness metric saved by
        the earlier stages.

        Returns (arrays, num_examples, metrics) where `arrays` depends on the
        client formulation:
          MORE_PRIVATE - pairwise differences `upper/L - lower/U` of shape
              (num_groups, num_groups, num_pos_labels), plus a second such array
              for equalized odds
          LOW_OVERHEAD - [lower, upper], plus [lower_pe, upper_pe] for equalized odds
        Raises RuntimeError if MORE_PRIVATE is used before U and L were received,
        and ValueError for an unknown client formulation.
        """
        curr_lmbda = torch.from_numpy(parameters[0])
        if self.partition_id == 0:
            logging.info(
                f"[Client {self.partition_id}/{self.num_partitions}] Current lambda: {curr_lmbda}"
            )

        # Read the state-backed values once: every property access converts the
        # stored record again, which is wasteful inside the loops below.
        positive_labels = self.positive_labels
        is_eq_odds = self.fairness_metric == FairnessMetric.EQUALIZED_ODDS.value
        num_groups = self.datamodule.num_sensitive_groups
        n_k = self.n_k
        use_mle = self.use_mle

        # Get the scores below the current threshold (one column per positive label)
        score_below = self._cached_scores[:, positive_labels] <= curr_lmbda

        def coverage_bounds(mask, col):
            """Return (lower, upper) for the points selected by `mask` in column `col`."""
            # Tensor-level sums instead of the builtin, which would iterate
            # the tensors element by element in Python.
            n_below = score_below[mask, col].sum()
            if use_mle:
                # Upper and lower are the same since a point estimate is an
                # interval with the same endpoints
                point = n_below / n_k
                return point, point
            n_mask = torch.as_tensor(mask).sum()
            # The validity of these calculations can be seen in the overleaf
            lower_bound = (n_below * n_mask) / ((n_mask + 1) * (n_k + 1))
            upper_bound = (n_below + 1) / (n_k + 1)
            return lower_bound, upper_bound

        bound_shape = (num_groups, positive_labels.shape[0])
        lower = torch.zeros(bound_shape)
        upper = torch.zeros(bound_shape)
        if is_eq_odds:
            lower_pe = torch.zeros(bound_shape)
            upper_pe = torch.zeros(bound_shape)

        # Loop to compute coverages for each group - label pair
        # TODO: make more efficient for demographic parity
        for i, pos in enumerate(positive_labels):
            for g_id in range(num_groups):
                f_m = self._get_filter_mask(pos, g_id)
                if is_eq_odds:
                    f_eo, f_pe = f_m  # eo and pe
                    lower[g_id, i], upper[g_id, i] = coverage_bounds(f_eo, i)
                    lower_pe[g_id, i], upper_pe[g_id, i] = coverage_bounds(f_pe, i)
                else:
                    lower[g_id, i], upper[g_id, i] = coverage_bounds(f_m, i)

        # Based on the formulation return the pos-label coverage for each group label pair
        # Or send the pairwise differences
        if self.client_formulation == CommFormulations.MORE_PRIVATE.value:
            scalers = {"U": self.U, "L": self.L}
            if is_eq_odds:
                scalers.update({"U_2": self.U_2, "L_2": self.L_2})
            missing = [name for name, value in scalers.items() if value is None]
            if missing:
                # Otherwise the divisions below fail with an opaque TypeError
                raise RuntimeError(
                    f"{missing} must be received from the server before running a MORE_PRIVATE iteration"
                )

            lower = lower / scalers["U"]
            upper = upper / scalers["L"]
            # Creates a tensor of size (num_groups, num_groups, num_pos_labels)
            final_nd_arrays = [
                (upper[:, None, :] - lower[None, :, :]).cpu().detach().numpy()
            ]

            if is_eq_odds:
                lower_pe = lower_pe / scalers["U_2"]
                upper_pe = upper_pe / scalers["L_2"]
                final_nd_arrays.append(
                    (upper_pe[:, None, :] - lower_pe[None, :, :]).cpu().detach().numpy()
                )

        elif self.client_formulation == CommFormulations.LOW_OVERHEAD.value:
            final_nd_arrays = [
                lower.cpu().detach().numpy(),
                upper.cpu().detach().numpy(),
            ]
            if is_eq_odds:
                final_nd_arrays += [
                    lower_pe.cpu().detach().numpy(),
                    upper_pe.cpu().detach().numpy(),
                ]

        else:
            # Previously this fell through and returned None as the parameter list
            raise ValueError(
                f"Unknown client formulation: {self.client_formulation}"
            )

        metrics = dict()
        num_examples = n_k

        return final_nd_arrays, num_examples, metrics

    def _get_filter_mask(self, pos_label, group_id):
        """
        Call `get_filter_mask` for the stored fairness metric on the calibration
        labels and groups, for the given positive label and group id.
        """
        return get_filter_mask(
            self.fairness_metric,
            self.calib_labels,
            self.calib_groups,
            pos_label,
            group_id,
        )

    def update_from_special_parameters(self, parameters, config) -> None:
        """
        Update U and L or the positive label values from the server's parameters.

        config - must contain `SpecialParameters.key` naming the update to apply
        parameters - the values are read from the END of the list:
            UPDATE_U_L: [-1] is L, [-2] is U and, for equalized odds, [-3] is L_2 and [-4] is U_2
            UPDATE_POS: [-1] is the new set of positive labels
        Raises NotImplementedError for a special parameter without a handler.
        """
        assert (
            SpecialParameters.key in config
        ), "Need to have 'special_params' to use this function."

        special_param = config[SpecialParameters.key]
        assert (
            special_param in SpecialParameters.values
        ), f"{special_param} parameters is not listed as a Special Parameter."

        if special_param == SpecialParameters.UPDATE_U_L:
            # getting the U and L values if needed
            sources = [("L", -1), ("U", -2)]
            if self.fairness_metric == FairnessMetric.EQUALIZED_ODDS.value:
                sources += [("L_2", -3), ("U_2", -4)]
            for attr_name, position in sources:
                setattr(self, attr_name, torch.from_numpy(parameters[position]))

        elif special_param == SpecialParameters.UPDATE_POS:
            # updating the active positive labels here
            incoming_labels = torch.tensor(parameters[-1])
            # True where a currently stored positive label is still a positive label
            still_active = torch.isin(self.positive_labels, incoming_labels)
            self.positive_labels = incoming_labels

            if self.client_formulation != CommFormulations.MORE_PRIVATE.value:
                # L and U are only reduced for the MORE_PRIVATE formulation
                return

            # keep only the columns of L and U that are still needed
            to_reduce = ["L", "U"]
            if self.fairness_metric == FairnessMetric.EQUALIZED_ODDS.value:
                to_reduce += ["L_2", "U_2"]
            for attr_name in to_reduce:
                setattr(self, attr_name, getattr(self, attr_name)[:, still_active])

        else:
            raise NotImplementedError

    def evaluate(self, parameters, config):
        """
        Placeholder evaluation: nothing is evaluated on the client.

        parameters, config - ignored
        Returns (loss, num_examples, metrics) as (0.0, 0, {}).
        """
        # The loss must be a float: Flower checks the type of the returned loss
        # and rejects a plain int 0.
        return 0.0, 0, dict()


def fair_client_fn(
    context: Context,
    args: FedConfFairExptConfig,
    datamodule: BaseDataModule,
    probs: torch.Tensor,
    client_formulations: List[int],
    client_mapping: Optional[torch.Tensor] = None,
    global_masks=None,
    global_client_mapping=None,
):
    """Construct a Client that will be run in a ClientApp.

    The partition handled by the client is read from the node config of
    `context`; the seed is offset by the partition id. How the partition is
    loaded depends on the datamodule type:
      FolktablesDataModule - `load_acs_data`, with `probs` split per client
          according to `client_mapping` (required in this case)
      PokecDataModule - the datamodule's own `load_partition`
      anything else - `load_data`

    `client_formulations` is indexed by partition id. All probs of the
    partition are handed to the client only when the conformal method is "daps".
    """
    # Read the node_config to fetch data partition associated to this node
    node_config = context.node_config
    partition_id = node_config["partition-id"]
    num_partitions = node_config["num-partitions"]

    seed_everything(args.seed + partition_id)

    if isinstance(datamodule, FolktablesDataModule):
        assert (
            client_mapping is not None
        ), "Expected client_mapping for FolktablesDataModule"

        # One entry of probs per client id present in the mapping
        probs_by_client = {
            part_id.item(): probs[client_mapping == part_id]
            for part_id in client_mapping.unique()
        }

        fed_datamodule, part_probs = load_acs_data(
            datamodule,
            partition_type=args.folktables_partition_type,
            partition_id=partition_id,
            global_masks=global_masks,
            global_client_mapping=global_client_mapping,
            probs=probs_by_client,
        )
    elif isinstance(datamodule, PokecDataModule):
        fed_datamodule, part_probs = datamodule.load_partition(partition_id, probs)
    else:
        fed_datamodule, part_probs = load_data(
            datamodule,
            num_partitions=num_partitions,
            partition_id=partition_id,
            probs=probs,
        )

    cl_form = client_formulations[partition_id]

    # Parse out the calibration probabilities from all the probs
    calib_split = fed_datamodule.split_dict[Stage.CALIBRATION]
    calib_part_probs = part_probs[calib_split]

    needs_all_probs = args.conformal_method.lower() == "daps"

    client = CPFedClient(
        context=context,
        datamodule=fed_datamodule,
        calib_probs=calib_part_probs,
        partition_id=partition_id,
        num_partitions=num_partitions,
        client_formulation=cl_form,
        use_mle=args.cf_opt.use_mle,
        all_probs=part_probs if needs_all_probs else None,
    )
    return client.to_client()
