import math
from typing import Tuple

import numpy as np

class SingleVariableFunction:
    """
    Base class for a scalar function f(x) whose interior extrema are known.

    Subclasses implement `f` and fill `local_minima` / `local_maxima` with the
    x-locations of the function's local extrema; `range_mapping` uses those
    locations to bound the function over an interval.
    """

    def __init__(self):
        # Both lists start empty; subclasses overwrite them with extremum locations.
        self.local_minima = []
        self.local_maxima = []

    def f(self, x):
        """Return the function value at x. Subclasses are required to override this."""
        raise NotImplementedError("Subclasses must implement this method.")

    def range_mapping(self, a: float, b: float) -> Tuple[float, float]:
        """
        Return (lowest, highest) function value over the input interval [a, b].

        The candidates are f at both endpoints plus f at every registered
        local minimum or maximum lying inside the interval (bounds inclusive).
        The endpoints may be given in either order.
        """
        # Normalise the interval so the lower bound comes first.
        lo, hi = (b, a) if a > b else (a, b)

        # Endpoints are always candidates.
        outputs = [self.f(lo), self.f(hi)]

        # Interior extrema are candidates only when they fall within [lo, hi].
        for locations in (self.local_minima, self.local_maxima):
            for point in locations:
                if lo <= point <= hi:
                    outputs.append(self.f(point))

        return min(outputs), max(outputs)


class SoftmaxFunction(SingleVariableFunction):
    """Single-variable softmax, i.e. the logistic sigmoid exp(z) / (1 + exp(z))."""

    def __init__(self):
        super().__init__()
        # There are no local minima or maxima for f(z) = softmax(z), so we leave the extrema lists empty.

    def f(self, z):
        """
        Compute softmax for a single variable (equivalent to sigmoid).

        Parameters:
        z (float): Input value.

        Returns:
        float: exp(z) / (1 + exp(z)), a value in [0, 1].
        """
        # sigmoid(z) = exp(-log(1 + exp(-z))). logaddexp evaluates the inner log
        # without overflowing, whereas the direct form exp(z) / (1 + exp(z))
        # becomes inf / inf = nan once exp(z) overflows.
        return np.exp(-np.logaddexp(0, np.negative(z)))


class BetaFunction(SingleVariableFunction):
    """Derivative of the single-variable softmax: beta(z) = s(z) * (1 - s(z))."""

    def __init__(self):
        super().__init__()
        # The derivative of softmax (beta) has a local maximum at z = 0.
        self.local_maxima = [0]  # Add the location of the local maximum.

    def f(self, z):
        """
        Compute the derivative of softmax for a single variable (beta function).

        Parameters:
        z (float): Input value.

        Returns:
        float: softmax(z) * (1 - softmax(z)).
        """
        # Overflow-free sigmoid: the direct form exp(z) / (1 + exp(z)) yields
        # nan once exp(z) overflows, which would poison min()/max() downstream.
        softmax_value = np.exp(-np.logaddexp(0, np.negative(z)))
        return softmax_value * (1 - softmax_value)


class GammaFunction(SingleVariableFunction):
    """Derivative of beta: gamma(z) = beta(z) * (1 - 2 * s(z)), with s the sigmoid."""

    def __init__(self):
        super().__init__()
        # Local extrema of gamma occur at z = ln(2 - sqrt(3)) and z = ln(2 + sqrt(3)).
        # gamma is positive for z < 0 (there s < 1/2, so 1 - 2s > 0) and negative
        # for z > 0, so the negative location is the maximum and the positive
        # location is the minimum.
        sqrt_3 = np.sqrt(3)
        self.local_maxima = [np.log(2 - sqrt_3)]  # Local maximum (z < 0, gamma > 0)
        self.local_minima = [np.log(2 + sqrt_3)]  # Local minimum (z > 0, gamma < 0)

    def f(self, z):
        """
        Compute the derivative of beta (gamma function).

        Parameters:
        z (float): Input value.

        Returns:
        float: beta(z) * (1 - 2 * softmax(z)).
        """
        # Overflow-free sigmoid: the direct form exp(z) / (1 + exp(z)) yields
        # nan once exp(z) overflows.
        softmax_value = np.exp(-np.logaddexp(0, np.negative(z)))
        beta_value = softmax_value * (1 - softmax_value)
        return beta_value * (1 - 2 * softmax_value)


class WtDeltaFunction(SingleVariableFunction):
    """Normalized delta: wt_delta(z) = delta / (1 - g_ii * delta), delta = beta_0 - beta(z)."""

    def __init__(self, beta_0, g_ii):
        """
        Initialize the WtDeltaFunction.

        Parameters:
        beta_0 (float): Initial value of beta.
        g_ii (float): Positive parameter from the Gram matrix (must be > 0).
        """
        super().__init__()
        self.beta_0 = beta_0
        self.g_ii = g_ii

        # Add local minimum at z = 0: beta peaks there, so delta is smallest,
        # and delta / (1 - g_ii * delta) is increasing in delta wherever the
        # denominator keeps its sign.
        self.local_minima = [0]

    def f(self, z):
        """
        Compute the normalized delta (wt_delta) at a given z.

        Parameters:
        z (float): Input value.

        Returns:
        float: Normalized delta.
        """
        # Compute beta(z) = softmax(z) * (1 - softmax(z)) using an overflow-free
        # sigmoid; the direct form exp(z) / (1 + exp(z)) yields nan once exp(z)
        # overflows.
        softmax_value = np.exp(-np.logaddexp(0, np.negative(z)))
        beta = softmax_value * (1 - softmax_value)

        # Compute delta and wt_delta
        delta = self.beta_0 - beta
        return delta / (1 - self.g_ii * delta)




def compute_z_roots_no_sympy(beta_0_value, g_ii_value):
    """
    Solve a*w^2 + b*w + c = 0 in closed form and map the roots back to z.

    Here w = (y - 1/2)^2 with y = sigmoid(z): every non-negative root w gives
    y = 1/2 +/- sqrt(w), and every y strictly inside (0, 1) gives z = logit(y).

    Parameters:
    beta_0_value (float): Initial value of beta.
    g_ii_value (float): Gram-matrix parameter. A value of 0 is accepted; the
        equation is then linear in w instead of quadratic.

    Returns:
    list: The z roots. For each accepted w the root from y = 1/2 + sqrt(w)
        comes first, followed by the one from y = 1/2 - sqrt(w). Empty when
        no root qualifies.
    """
    # Hardcoded coefficients of the quadratic equation: ax^2 + bx + c = 0
    a = -2.0 * g_ii_value
    b = 6.0 * beta_0_value * g_ii_value - 6.0
    c = -0.5 * beta_0_value * g_ii_value + 0.125 * g_ii_value + 0.5

    if a == 0:
        # g_ii == 0 removes the quadratic term; dividing by 2 * a would raise
        # ZeroDivisionError, so solve the remaining linear equation b*w + c = 0.
        w_candidates = [-c / b] if b != 0 else []
    else:
        # Hardcoded discriminant formula (the expanded form of b^2 - 4ac)
        discriminant = (
                36.0 * beta_0_value ** 2 * g_ii_value ** 2
                - 4.0 * beta_0_value * g_ii_value ** 2
                - 72.0 * beta_0_value * g_ii_value
                + 1.0 * g_ii_value ** 2
                + 4.0 * g_ii_value
                + 36.0
        )

        # If the discriminant is negative, return no roots
        if discriminant < 0:
            print("No real roots exist for the quadratic equation.")
            return []

        # Compute roots of the quadratic equation
        sqrt_discriminant = math.sqrt(discriminant)
        w_candidates = [
            (-b + sqrt_discriminant) / (2 * a),
            (-b - sqrt_discriminant) / (2 * a),
        ]

    # Filter out negative roots (w is a square, so it cannot be negative)
    valid_w_solutions = [w for w in w_candidates if w >= 0]

    # For each valid w, compute the corresponding y and z roots
    z_values = []
    for w_root in valid_w_solutions:
        # Compute y roots: y = 1/2 ± sqrt(w)
        half_width = math.sqrt(w_root)
        for y in (0.5 + half_width, 0.5 - half_width):
            if 0 < y < 1:  # y must be in the valid range for probabilities
                # Compute z = logit(y) = log(y / (1 - y))
                z_values.append(math.log(y / (1 - y)))

    return z_values


class WtDeltaDerivativeFunction(SingleVariableFunction):
    """Derivative of wt_delta with respect to z; its extrema are located at construction."""

    def __init__(self, beta_0, g_ii):
        """
        Initialize the WtDeltaDerivativeFunction.

        Parameters:
        beta_0 (float): Initial value of beta.
        g_ii (float): Positive parameter from the Gram matrix (must be > 0).

        Raises:
        ValueError: If the root finder does not yield exactly two extrema.
        """
        super().__init__()
        self.beta_0 = beta_0
        self.g_ii = g_ii

        # Compute local extrema (minima and maxima) during initialization
        z_min, z_max = self._find_extrema()
        self.local_minima = [z_min]
        self.local_maxima = [z_max]

    def f(self, z):
        """
        Compute the derivative of wt_delta with respect to z (d/dz wt_delta).

        Parameters:
        z (float): Input value.

        Returns:
        float: Derivative of wt_delta at z.
        """
        # Overflow-free sigmoid: the direct form exp(z) / (1 + exp(z)) yields
        # nan once exp(z) overflows.
        softmax_value = np.exp(-np.logaddexp(0, np.negative(z)))
        beta = softmax_value * (1 - softmax_value)
        delta = self.beta_0 - beta
        d_beta_dz = (1 - 2 * softmax_value) * beta

        # Quotient rule on delta / (1 - g_ii * delta) with d(delta)/dz = -d_beta_dz:
        #   numerator = -d_beta_dz * (1 - g_ii * delta) - delta * g_ii * d_beta_dz
        # The two g_ii * delta terms cancel, leaving -d_beta_dz.
        denominator = (1 - self.g_ii * delta) ** 2

        return -d_beta_dz / denominator

    def _find_extrema(self):
        """
        Find the z values where d/dz wt_delta achieves local minima and maxima.

        Returns:
        list: [z_min, z_max] where z_min is the local minimum and z_max is the local maximum.

        Raises:
        ValueError: If the root finder does not yield exactly two roots.
        """
        # Use the pre-defined compute_z_roots_no_sympy function to find extrema
        z_roots = compute_z_roots_no_sympy(self.beta_0, self.g_ii)

        # Ensure exactly two roots are returned (minima and maxima)
        if len(z_roots) != 2:
            raise ValueError("The algorithm did not return exactly two extrema.")

        # The root finder does not say which root is the minimum, so order the
        # two roots by the function value they produce: smaller value first.
        return sorted(z_roots, key=self.f)
