from typing import Dict, Union
from os.path import join
import torch
from copy import deepcopy
from conformal_fairness.data import BaseDataModule, BaseDataset
from conformal_fairness.constants import (
    ACS_INCOME,
    ACS_EDUC,
    Stage,
    FairnessMetric,
    ConformalMethod,
)
from conformal_fairness.config import (
    PrimitiveScoreConfig,
)
from conformal_fairness.conformal_predictors import APSScore, TPSScore
from conformal_fairness.cp_methods.transformations import (
    DiffusionTransformation,
    RegularizationTransformation,
)
from fed_config import (
    FedConfFairExptConfig,
    FixedDiffusionConfig,
    FixedRegularizedConfig,
)
from pokec_datamodule import PokecDataModule


def get_test_dm_probs(datamodule: BaseDataModule, probs: torch.Tensor):
    if isinstance(datamodule, PokecDataModule):
        return datamodule.load_test_dm_probs(probs)

    ds: BaseDataset = datamodule._base_dataset
    ds_test = deepcopy(ds)

    if isinstance(ds_test.X, list):
        ds_test.X = [datamodule.X[i] for i in datamodule.split_dict[Stage.TEST]]
    else:
        ds_test.X = datamodule.X[datamodule.split_dict[Stage.TEST]]
    ds_test.y = datamodule.y[datamodule.split_dict[Stage.TEST]]
    ds_test.sens = datamodule.sens[datamodule.split_dict[Stage.TEST]]
    ds_test.split_config.train = 0
    ds_test.split_config.valid = 0
    ds_test.split_config.calib = 0

    dm_test = deepcopy(datamodule)
    dm_test._base_dataset = ds_test
    dm_test._init_with_dataset(ds_test)
    probs = probs[datamodule.split_dict[Stage.TEST]]

    assert dm_test.num_points == probs.shape[0]
    return dm_test, probs


def get_filter_mask(
    fairness_metric, labels: torch.Tensor, groups: torch.Tensor, pos_label, group_id
):
    assert (
        labels.shape[0] == groups.shape[0]
    ), f"Got {labels.shape[0]} labels, but {groups.shape[0]} groups"

    match fairness_metric:
        case FairnessMetric.EQUAL_OPPORTUNITY.value:
            label_satisfied = labels == pos_label
            group_satisfied = groups == group_id
            return (label_satisfied & group_satisfied).reshape(-1)

        case FairnessMetric.PREDICTIVE_EQUALITY.value:
            label_not_satisfied = labels != pos_label
            group_satisfied = groups == group_id
            return (label_not_satisfied & group_satisfied).reshape(-1)

        case FairnessMetric.EQUALIZED_ODDS.value:
            label_satisfied = labels == pos_label
            label_not_satisfied = labels != pos_label
            group_satisfied = groups == group_id

            return (
                (label_satisfied & group_satisfied).reshape(-1),
                (label_not_satisfied & group_satisfied).reshape(-1),
            )
        case (
            FairnessMetric.DEMOGRAPHIC_PARITY.value
            | FairnessMetric.DISPARATE_IMPACT.value
            | FairnessMetric.OVERALL_ACC_EQUALITY.value
        ):
            return (groups == group_id).reshape(-1)
        case _:
            raise NotImplementedError(
                f"Filtering function not implemented for {fairness_metric}"
            )


def get_score_module(conformal_method, split_conf_input, alpha=None):
    if conformal_method in [ConformalMethod.TPS, ConformalMethod.APS]:
        assert isinstance(split_conf_input, PrimitiveScoreConfig)
    elif conformal_method == ConformalMethod.DAPS:
        assert isinstance(split_conf_input, FixedDiffusionConfig)
    elif conformal_method == ConformalMethod.RAPS:
        assert isinstance(split_conf_input, FixedRegularizedConfig)
    else:
        raise NotImplementedError

    if conformal_method == ConformalMethod.TPS:
        score_module = TPSScore(split_conf_input, alpha=alpha)
    elif conformal_method == ConformalMethod.APS:
        score_module = APSScore(split_conf_input, alpha=alpha)
    elif conformal_method == ConformalMethod.RAPS:
        score_module = FixedRAPSScore(split_conf_input, alpha=alpha)
    elif conformal_method == ConformalMethod.DAPS:
        score_module = FixedDAPSScore(split_conf_input, alpha=alpha)
    else:
        raise NotImplementedError

    return score_module


class FixedDAPSScore(APSScore):
    def __init__(self, config: PrimitiveScoreConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config

    def compute(self, probs, **kwargs):
        scores = super().compute(probs, **kwargs)

        daps_params = {
            DiffusionTransformation.DAPS_LAMBDA: self.config.daps_lambda,
        } | kwargs

        transform_module = DiffusionTransformation(config=self.config)

        scores = transform_module.transform(
            scores,
            **daps_params,
        )

        return scores


class FixedRAPSScore(APSScore):
    def __init__(self, config: PrimitiveScoreConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config

    def compute(self, probs, **kwargs):
        scores = super().compute(probs, **kwargs)

        raps_params = {
            RegularizationTransformation.RAPS_K: self.config.raps_k,
            RegularizationTransformation.RAPS_LAMBDA: self.config.raps_lambda,
        } | kwargs

        transform_module = RegularizationTransformation(config=self.config)

        if self.config.raps_mod:
            assert (
                self.config.use_aps_epsilon == False
            ), "Should not use randomized APS with modified RAPS"

        scores = transform_module.transform(
            scores,
            probs,
            raps_modified=self.config.raps_mod,
            **raps_params,
        )

        return scores


def get_fair_configs_res_path(conf: FedConfFairExptConfig):
    conf_dict: Dict[str, Union[int, str, float]] = dict()

    conf_dict["seed"] = conf.conformal_seed or conf.seed
    conf_dict["c"] = conf.closeness_measure

    out_file_name = ""
    if conf.conformal_method == ConformalMethod.APS.value:
        out_file_name = (
            "aps.csv" if conf.primitive_config.use_aps_epsilon else "aps_no_rand.csv"
        )
    elif conf.conformal_method == ConformalMethod.TPS.value:
        out_file_name = (
            "tps_classwise.csv"
            if conf.primitive_config.use_tps_classwise
            else "tps.csv"
        )
    elif conf.conformal_method == ConformalMethod.RAPS.value:
        conf_dict[RegularizationTransformation.RAPS_K] = (
            conf.regularization_config.raps_k
        )
        conf_dict[RegularizationTransformation.RAPS_LAMBDA] = (
            conf.regularization_config.raps_lambda
        )
        if conf.regularization_config.raps_mod:
            out_file_name = "raps_mod.csv"
        else:
            out_file_name = (
                "raps_rand_no_mod.csv"
                if conf.primitive_config.use_aps_epsilon
                else "raps_no_rand_no_mod.csv"
            )
    elif conf.conformal_method == ConformalMethod.DAPS.value:
        conf_dict[DiffusionTransformation.DAPS_LAMBDA] = (
            conf.diffusion_config.daps_lambda
        )
        out_file_name = "daps.csv"
    else:
        raise NotImplementedError

    ds_name = f"{conf.dataset.name}"

    if conf.dataset.sens_attrs:
        ds_name += "_".join([""] + conf.dataset.sens_attrs)

    out_dir = join(
        "analysis",
        "fairness_trials",
        ds_name,
        (
            f"{conf.num_clients}_clients"
            if ds_name not in (ACS_INCOME, ACS_EDUC)
            else conf.folktables_partition_type
        ),
        f"split_{conf.dataset_split_fractions.train}_{conf.dataset_split_fractions.valid}",
        conf.fairness_metric,
    )

    return conf_dict, out_dir, out_file_name
