import numpy as np
from typing import Tuple, Optional

from causally.scm.causal_mechanism import CustomMechanism
from causally.scm.scm import AdditiveNoiseModel
from causally.graph.random_graph import GraphGenerator
from causally.scm.noise import Distribution, CustomNoise
from causally.scm.context import SCMContext
from causally.utils.graph import topological_order
from numpy.core.multiarray import array as array

class InvertibleAdditiveNoiseModel(AdditiveNoiseModel):
    """Bivariate additive noise model ``effect = f(cause) + noise``.

    With ``s = noise_scale`` the generating process is:

    * cause ~ Logistic(0, s),
    * ``f(x) = log(1 + exp(-x / s))``,
    * effect noise with density ``p(n) = exp(-2 n / s - exp(-n / s)) / s``.

    NOTE(review): the class name suggests this mechanism/noise combination is
    chosen so that an additive noise model also fits the anti-causal
    direction — confirm against the reference this construction comes from.

    Args:
        num_samples: Number of samples drawn by :meth:`sample`.
        graph_generator: Generator of the ground-truth graph. :meth:`sample`
            requires a two-node graph with exactly one edge.
        noise_scale: Strictly positive scale ``s`` shared by the cause
            distribution, the mechanism and the effect-noise density.
        scm_context: Optional SCM context, forwarded to the base class.
        seed: Optional seed, forwarded to the base class.

    Raises:
        ValueError: If ``noise_scale`` is not strictly positive.
    """

    def __init__(
        self,
        num_samples: int,
        graph_generator: GraphGenerator,
        noise_scale: float,
        scm_context: Optional[SCMContext] = None,
        seed: Optional[int] = None,
    ):
        # ``not > 0`` also rejects NaN; a non-positive scale makes the
        # effect-noise "density" below negative or undefined.
        if not noise_scale > 0:
            raise ValueError(
                f"noise_scale must be strictly positive, got {noise_scale}."
            )

        # Mechanism f(x) = log(1 + exp(-x / s)). ``logaddexp(0, z)`` computes
        # log(1 + exp(z)) without overflowing exp for large z.
        causal_mechanism = CustomMechanism(
            lambda x: np.logaddexp(0.0, -x / noise_scale)
        )
        super().__init__(
            num_samples, graph_generator, None, causal_mechanism, scm_context, seed
        )

        # Cause distribution Logistic(0, s) (equivalently, the difference of
        # two i.i.d. Gumbel(0, s) variables); sampled directly here.
        self.cause_noise_generator = Logistic(0, noise_scale)

        # Effect-noise density p(n) = exp(-2 n / s - exp(-n / s)) / s.
        # Substituting t = exp(-n / s) shows it integrates to Gamma(2) = 1; its
        # closed-form CDF is F(n) = (1 + exp(-n / s)) * exp(-exp(-n / s)).
        def pdf(n):
            return 1 / noise_scale * np.exp(
                -2 * n / noise_scale - np.exp(-n / noise_scale)
            )

        # The density is expressed in standardized units n / s, so the finite
        # support handed to CustomNoise must scale with s too. Fixed bounds of
        # +-5 distort the noise whenever s > 1, and severely so for larger s.
        # In standardized units [-5, 5] leaves out only ~2e-5 of the mass.
        self.effect_noise_generator = CustomNoise(
            pdf=pdf, a=-5 * noise_scale, b=5 * noise_scale
        )

    def sample(self) -> Tuple[np.ndarray, np.ndarray]:
        """Draw ``num_samples`` observations of the (cause, effect) pair.

        Returns:
            ``(X, adjacency)``: ``X`` has one column per graph node, in node
            index order, and ``adjacency`` is ``self.adjacency``.

        Raises:
            ValueError: If the graph is not a two-node graph with exactly one
                edge. An edgeless graph would otherwise be returned as ground
                truth for data that is causally dependent.
        """
        adjacency = self.adjacency.copy()

        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if adjacency.shape != (2, 2):
            raise ValueError(
                f"Expected a graph with 2 nodes, got adjacency of shape {adjacency.shape}."
            )
        order = [int(node) for node in topological_order(adjacency)]
        source, target = order
        if np.count_nonzero(adjacency) != 1 or not adjacency[source, target]:
            raise ValueError(
                "Expected exactly one edge between the two nodes of the graph."
            )

        cause_noise = self.cause_noise_generator.sample(size=(self.num_samples,))
        effect_noise = self.effect_noise_generator.sample(
            size=(self.num_samples,)
        ).squeeze(-1)

        # effect = f(cause) + effect_noise
        effect = self._sample_mechanism(cause_noise, effect_noise)

        # Columns follow node indices: the source node's column holds the cause.
        if order == [0, 1]:
            X = np.asarray([cause_noise, effect])
        else:
            X = np.asarray([effect, cause_noise])

        return X.transpose(), self.adjacency


class Logistic(Distribution):
    """Logistic noise distribution with location ``loc`` and scale ``scale``.

    Samples come from ``np.random.logistic``, i.e. NumPy's global random state.

    Args:
        loc: Location (mean) of the distribution.
        scale: Scale parameter of the distribution.
    """

    def __init__(self, loc: float, scale: float) -> None:
        super().__init__()
        self.loc = loc
        self.scale = scale

    def sample(self, size: tuple[int]) -> np.ndarray:
        """Draw i.i.d. logistic samples of shape ``size``.

        Args:
            size: One- or two-dimensional output shape, e.g.
                ``(num_samples,)`` or ``(num_samples, num_nodes)``.

        Returns:
            Array of shape ``size``.

        Raises:
            ValueError: If ``size`` is not one- or two-dimensional.
        """
        # The previous version built this ValueError without raising it, so the
        # check was a no-op. 1-D sizes are accepted because
        # InvertibleAdditiveNoiseModel.sample requests ``(num_samples,)``.
        if len(size) not in (1, 2):
            raise ValueError(
                f"Expected number of input dimensions is 1 or 2, but were given {len(size)}."
            )
        return np.random.logistic(self.loc, self.scale, size)