import dataclasses
import numpy as np
import torch
import torch.nn.functional as F
import typing
if typing.TYPE_CHECKING:
    from .fq_models import ScoreFunction, InterviewNode
import scipy.optimize

def get_corrected_score(
    candidate_score: float,
    class_weights: dict[int, float],
    multipliers: dict[int, float],
) -> float:
    """Shift a probability in logit space by a weighted sum of class weights.

    Computes ``sigmoid(sum_c multipliers[c] * class_weights[c] + logit(score))``
    where ``logit`` is the inverse of the sigmoid.

    Args:
        candidate_score: Probability to correct. Values are clipped to
            ``[1e-15, 1 - 1e-15]`` before the logit is taken.
        class_weights: Per-class weights of the candidate, keyed by class id.
        multipliers: Per-class multipliers, keyed by class id. Only the
            classes listed here contribute to the correction.

    Returns:
        The corrected probability as a built-in ``float``.

    Raises:
        KeyError: If a class id in ``multipliers`` is missing from
            ``class_weights``.
    """
    eps = 1e-15
    # Coerce to a built-in (float64) float before clipping. A NumPy float32
    # scalar would keep the clip in float32, where ``1 - eps`` rounds to
    # exactly 1.0, so a saturated score would divide by zero below.
    probability = float(np.clip(float(candidate_score), eps, 1 - eps))
    logit = np.log(probability / (1 - probability))

    correction = 0.0
    for class_name, multiplier in multipliers.items():
        correction += multiplier * class_weights[class_name]

    return float(1 / (1 + np.exp(-(correction + logit))))

@dataclasses.dataclass
class FairnessSample:
    """One scored candidate plus the data needed to correct its score."""

    # c(x_i^j) = {c_1(x_i^j), c_2(x_i^j), ..., c_k(x_i^j)}, keyed by class id
    class_weights: dict[int, float]
    # j
    candidate_index: int
    # f_\phi(x_i^j)
    candidate_score: float
    # y^j
    candidate_ground_truth: int

    def get_corrected_score(self, multipliers: dict[int, float]) -> float:
        """Return this sample's score corrected with ``multipliers``.

        Forwards the sample's own score and class weights, together with
        ``multipliers``, to the module-level ``get_corrected_score``.
        """
        return get_corrected_score(
            candidate_score=self.candidate_score,
            class_weights=self.class_weights,
            multipliers=multipliers,
        )

def get_fairness_loss(
    fairness_samples: list[FairnessSample],
    multipliers: dict[int, float] | None,
) -> float:
    """Return the largest absolute class-weighted mean residual.

    For every class ``c`` (taken from the first sample's ``class_weights``)
    the mean over all samples of ``class_weights[c] * (score - ground_truth)``
    is computed; the result is the maximum absolute value of those means.

    Args:
        fairness_samples: Samples to evaluate; must not be empty.
        multipliers: When given, each sample's corrected score
            (``sample.get_corrected_score(multipliers)``) is used; when
            ``None``, the raw ``candidate_score`` is used.

    Returns:
        The maximum absolute per-class mean residual.

    Raises:
        ValueError: If ``fairness_samples`` is empty or the first sample has
            no classes.
        KeyError: If a sample lacks one of the first sample's classes.
    """
    if not fairness_samples:
        raise ValueError("fairness_samples must not be empty")

    class_names = list(fairness_samples[0].class_weights)
    if not class_names:
        raise ValueError("fairness samples must define at least one class")

    # The residual does not depend on the class, so compute it (and the
    # possibly corrected score) once per sample instead of once per class.
    residuals = [
        (
            sample.candidate_score
            if multipliers is None
            else sample.get_corrected_score(multipliers)
        )
        - sample.candidate_ground_truth
        for sample in fairness_samples
    ]

    return max(
        abs(
            float(
                np.mean(
                    [
                        sample.class_weights[class_name] * residual
                        for sample, residual in zip(fairness_samples, residuals)
                    ]
                )
            )
        )
        for class_name in class_names
    )


def make_fairness_samples(
    scores: list[tuple[float, float, "InterviewNode"]],
    target_skill: str,
    torch_device: str,
) -> list[FairnessSample]:
    """Build one ``FairnessSample`` per ``(logit_pos, logit_neg, node)`` triple.

    Args:
        scores: Triples of positive logit, negative logit and the node the
            logits were produced for.
        target_skill: Skill whose presence in ``node.candidate.skills``
            defines the ground-truth label (1 if present, else 0).
        torch_device: Device on which the two-way softmax is evaluated.

    Returns:
        Samples in the same order as ``scores``; ``candidate_index`` is the
        position of the triple in ``scores``.
    """
    fairness_samples = []

    for index, (logit_pos, logit_neg, node) in enumerate(scores):
        logits = torch.tensor([logit_pos, logit_neg], device=torch_device)
        # ``.item()`` yields a built-in float. Keeping the float32 NumPy
        # scalar would make later ``np.clip(score, eps, 1 - eps)`` calls run
        # in float32, where ``1 - 1e-15`` rounds to 1.0 and the clip no
        # longer protects the logit from dividing by zero.
        positive_probability = F.softmax(logits, dim=0)[0].item()

        ground_truth = 1 if target_skill in node.candidate.skills else 0

        fairness_samples.append(
            FairnessSample(
                class_weights=node.candidate.five_factors.to_int_dict(),
                candidate_index=index,
                candidate_score=positive_probability,
                candidate_ground_truth=ground_truth,
            )
        )
    return fairness_samples


@dataclasses.dataclass
class FairnessCorrectionResult:
    initial_fairness_loss: float
    fairness_loss: float | None
    distillation_loss: float | None
    distill_fairness_loss: float | None
    l_c: dict[int, float] | None
    initial_accuracy : float
    corrected_accuracy : float

def fairness_corrected_f(
        sample: FairnessSample,
        multipliers: dict[int, float] | None,
) -> float:
    """Return the fairness-corrected score f^* of ``sample``.

    Args:
        sample: Sample providing ``candidate_score`` and ``class_weights``.
        multipliers: Per-class multipliers l_c keyed by class id, or ``None``
            to return the uncorrected score unchanged.

    Returns:
        ``sample.candidate_score`` when ``multipliers`` is ``None``; otherwise
        the corrected probability as a built-in ``float``.

    Raises:
        KeyError: If a class id in ``multipliers`` is missing from
            ``sample.class_weights``.
    """
    # f^*(x_i^j) = \mathrm{softmax}\bigg(\sum_{c\in\mathcal C} l_c c(x_i^j) + \mathrm{softmax}^{-1}(f(x_i^j))\bigg)
    # where \mathrm{softmax} is the sigmoid function since this is 1D

    if multipliers is None:
        return sample.candidate_score

    # softmax^{-1} is the logit (inverse sigmoid), not another sigmoid.
    # Clip in float64 so that scores of exactly 0 or 1 stay finite.
    eps = 1e-15
    probability = float(np.clip(float(sample.candidate_score), eps, 1 - eps))
    inv_f = np.log(probability / (1 - probability))

    class_weights = 0.0
    for class_name, weight in multipliers.items():
        class_weights += weight * sample.class_weights[class_name]

    combined_score = class_weights + inv_f
    return float(1 / (1 + np.exp(-combined_score)))

def sigmoid(
        x: np.ndarray,
) -> np.ndarray:
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def inverse_sigmoid(
        x: np.ndarray,
) -> np.ndarray:
    """Return the element-wise logit ``log(x / (1 - x))`` of probabilities.

    Inputs are converted to float64 and clipped to ``[1e-15, 1 - 1e-15]`` so
    that probabilities of exactly 0 or 1 map to large finite values.
    """
    eps = 1e-15
    # Upcast before clipping: in float32 the bound ``1 - eps`` rounds to 1.0,
    # which would leave x == 1 unclipped and produce log(1 / 0) = inf.
    x = np.clip(np.asarray(x, dtype=np.float64), eps, 1 - eps)
    return np.log(x / (1 - x))

def minimized_function(
        l_c: np.ndarray,
        y_j: np.ndarray,
        f_x_ij: np.ndarray,
        weights: np.ndarray,
) -> float:
    """Mean binary cross-entropy (base-2 logarithm) of the corrected scores.

    Args:
        l_c: Multipliers being optimised, one per class.
        y_j: Ground-truth labels, one per sample.
        f_x_ij: Uncorrected scores, one per sample.
        weights: Class weights; ``l_c * weights`` is summed over axis 1 to
            give one correction term per sample.

    Returns:
        The negated mean of
        ``y * log2(p) + (1 - y) * log2(1 - p)`` with
        ``p = sigmoid(sum_c l_c * c(x) + inverse_sigmoid(f(x)))``.
    """
    eps = 1e-15

    corrected_logits = np.sum(l_c * weights, axis=1) + inverse_sigmoid(f_x_ij)
    # Keep probabilities strictly inside (0, 1) so both logarithms are finite.
    probs = np.clip(sigmoid(corrected_logits), eps, 1 - eps)

    log_likelihood = y_j * np.log2(probs) + (1 - y_j) * np.log2(1 - probs)
    return - float(np.mean(log_likelihood))



def compute_multipliers(
        samples: list[FairnessSample],

) -> dict[int, float]:
    """Fit per-class multipliers by minimising ``minimized_function``.

    The multipliers are optimised with L-BFGS-B, starting from 0.001 for
    every class and bounded to ``[0, 10]``.

    Args:
        samples: Samples to fit on; must not be empty. The classes are taken
            from the first sample and every sample must provide them.

    Returns:
        Mapping from class id to its fitted multiplier.

    Raises:
        ValueError: If ``samples`` is empty.
        KeyError: If a sample lacks one of the first sample's classes.
    """
    if not samples:
        raise ValueError("samples must not be empty")

    # One shared column order for the weight matrix and the returned dict,
    # so column i always belongs to class_names[i].
    class_names = sorted(samples[0].class_weights)
    initial_guess = np.full(len(class_names), 0.001)

    # Explicit float64 keeps the 1e-15 clipping inside the objective
    # meaningful even when the stored scores are float32 NumPy scalars.
    y_j = np.array(
        [sample.candidate_ground_truth for sample in samples],
        dtype=np.float64,
    )
    f_x_ij = np.array(
        [sample.candidate_score for sample in samples],
        dtype=np.float64,
    )
    weights = np.array(
        [
            [sample.class_weights[class_name] for class_name in class_names]
            for sample in samples
        ],
        dtype=np.float64,
    )

    # NOTE(review): the optimiser's ``success`` flag is not inspected; the
    # last iterate is used as-is.
    result = scipy.optimize.minimize(
        minimized_function,
        initial_guess,
        args=(y_j, f_x_ij, weights),
        method="L-BFGS-B",
        bounds=[(0.0, 10)] * len(class_names),
    )
    return {
        class_name: float(value)
        for class_name, value in zip(class_names, result.x)
    }




def fairness_correction(
    score_model: "ScoreFunction",
    scores: list[tuple[float, float, "InterviewNode"]],
    target_skill: str,
    epsilon: float,
    torch_device: str,
) -> FairnessCorrectionResult:
    """Fit fairness multipliers for ``scores`` and report before/after metrics.

    Args:
        score_model: Not used by this function; kept for interface
            compatibility.
        scores: ``(logit_pos, logit_neg, node)`` triples to correct.
        target_skill: Skill that defines the ground-truth label.
        epsilon: Largest fairness loss accepted after the correction.
        torch_device: Device used when normalising the logits.

    Returns:
        Fairness loss and accuracy before and after the correction, together
        with the fitted multipliers in ``l_c``. The distillation fields are
        always ``None``.

    Raises:
        ValueError: If the corrected fairness loss still exceeds ``epsilon``.
    """
    samples = make_fairness_samples(
        scores=scores,
        target_skill=target_skill,
        torch_device=torch_device,
    )

    initial_fairness_loss = get_fairness_loss(samples, None)

    # NOTE: the multipliers are fitted unconditionally, even when the initial
    # fairness loss is already below ``epsilon``.
    multipliers = compute_multipliers(samples)

    fairness_loss = get_fairness_loss(samples, multipliers)
    if fairness_loss > epsilon:
        raise ValueError(
            f"Fairness loss {fairness_loss} is greater than epsilon {epsilon}. "
            "This indicates that the fairness correction did not succeed."
        )

    initial_accuracy = get_accuracy(samples, None)
    corrected_accuracy = get_accuracy(samples, multipliers)

    return FairnessCorrectionResult(
        initial_fairness_loss=initial_fairness_loss,
        fairness_loss=fairness_loss,
        distillation_loss=None,
        distill_fairness_loss=None,
        # Report the fitted multipliers instead of discarding them.
        l_c=multipliers,
        initial_accuracy=initial_accuracy,
        corrected_accuracy=corrected_accuracy,
    )


def get_accuracy(
    samples: list[FairnessSample],
    multipliers: dict[int, float] | None,
) -> float:
    """Return the fraction of samples classified correctly at threshold 0.5.

    A sample counts as correct when its score is at least 0.5 and its ground
    truth is 1, or its score is below 0.5 and its ground truth is 0. The
    corrected score is used when ``multipliers`` is given, the raw
    ``candidate_score`` otherwise.
    """

    def _is_correct(sample: "FairnessSample") -> bool:
        if multipliers is None:
            score = sample.candidate_score
        else:
            score = sample.get_corrected_score(multipliers)
        label = sample.candidate_ground_truth
        return (score >= 0.5 and label == 1) or (score < 0.5 and label == 0)

    hits = sum(1 for sample in samples if _is_correct(sample))
    return hits / len(samples)
