# ########################################
# #
# #  Custom DAG SEM
# #
# ########################################


# import numpy as np
# from typing import Tuple
# from numpy.typing import NDArray

# from src.sem.abstract import StructuralEquationModel as SEM


# # specify default parameters
# TREATMENT_DIMENSION: int=32
# CONFOUNDER_DIMENSION: int=TREATMENT_DIMENSION
# OUTCOME_DIMENSION: int=1
# NOISE_STD: float=0.1


# class LinearSimulationSEM(SEM):
#     def __init__(
#             self,
#             treatment_dimension: int=TREATMENT_DIMENSION,
#             confounder_dimension: int=CONFOUNDER_DIMENSION,
#             outcome_dimension: int=OUTCOME_DIMENSION,
#         ):
#         self.treatment_dimension = treatment_dimension
#         self.confounder_dimension = confounder_dimension
#         self.outcome_dimension = outcome_dimension

#         self.W_UX = np.random.randn(
#             confounder_dimension,
#             treatment_dimension
#         )
#         self.W_UY = np.random.randn(
#             confounder_dimension,
#             outcome_dimension
#         )
#         self.W_XY = np.random.randn(
#             treatment_dimension,
#             outcome_dimension
#         )
        
#         super(LinearSimulationSEM, self).__init__()
    
#     def sample(
#             self, N: int= 1, kappa: float=1.0, intervention: bool=False, **kwargs
#         ) -> Tuple[NDArray, NDArray]:

#         U = np.random.randn(
#             N, self.confounder_dimension
#         )
#         N_X = np.random.randn(
#             N, self.treatment_dimension
#         )
#         N_Y = np.random.randn(
#             N, self.outcome_dimension
#         )

#         if intervention:
#             X = N_X
#         else:
#             X = (
#                 U @ self.W_UX
#                 + NOISE_STD * N_X
#             )
        
#         Y = (
#             X @ self.W_XY               # f(X)
#             + kappa * U @ self.W_UY     # \xi
#             + NOISE_STD * N_Y           # noise
#         )

#         self.varXi = np.var(Y - X @ self.W_XY)
#         self.varEXiX = np.var(
#             X @ np.linalg.pinv(X) @ (
#                 Y - X @ self.W_XY
#             )
#         )
#         # for debugging
#         # print('linear Var(xi): ', self.varXi)
#         # print('linear Var(E[xi|X]): ', self.varEXiX)

#         return X, Y


########################################
##
##  Cyclic SEM from Akbar et al. (2025)
##
########################################


import numpy as np
from typing import Tuple
from numpy.typing import NDArray

from src.sem.abstract import StructuralEquationModel as SEM


# Default parameters used by LinearSimulationSEM.
TREATMENT_DIMENSION: int = 32                    # width of the treatment X
CONFOUNDER_DIMENSION: int = TREATMENT_DIMENSION  # width of the confounder C
OUTCOME_DIMENSION: int = 1                       # width of the outcome Y
NOISE_STD: float = 0.1                           # scale of the additive Gaussian noise
SIMULTANEITY: bool = False                       # when False the Y -> X feedback weights are zero; also the default for `normalize`


class LinearSimulationSEM(SEM):
    """Linear structural equation model with an optional Y -> X feedback loop.

    Variables use the row-vector convention (one sample per row):
        C : confounders, ``C ~ N(0, I)`` of width ``confounder_dimension``.
        X : treatment of width ``treatment_dimension``.
        Y : outcome of width ``outcome_dimension``.

    Observational equations, solved jointly by ``sample``::

        X = C @ W_CX + kappa * Y @ W_YX + NOISE_STD * N_X
        Y = kappa * C @ W_CY + X @ W_XY + NOISE_STD * N_Y

    where ``W_CX`` / ``W_CY`` are the first ``treatment_dimension`` / last
    ``outcome_dimension`` columns of ``W_CXY``. With ``intervention=True``
    both the confounding into X and the feedback Y -> X are removed, and X is
    unit-variance Gaussian noise.

    ``W_YX`` is identically zero unless simultaneity is enabled, in which
    case the model is cyclic.
    """

    def __init__(
            self,
            treatment_dimension: int=TREATMENT_DIMENSION,
            confounder_dimension: int=CONFOUNDER_DIMENSION,
            outcome_dimension: int=OUTCOME_DIMENSION,
            normalize: bool=SIMULTANEITY,
            simultaneity: 'bool | None'=None,
        ):
        """Draw random Gaussian weight matrices for the SEM.

        Args:
            treatment_dimension: width of X.
            confounder_dimension: width of C.
            outcome_dimension: width of Y.
            normalize: if True, rescale ``W_XY`` (and ``W_YX`` when it is
                non-zero) to unit Frobenius norm.
            simultaneity: if True, the feedback weights ``W_YX`` are non-zero
                and the model is cyclic. ``None`` (default) falls back to the
                module-level ``SIMULTANEITY`` flag, read at construction time.
        """
        if simultaneity is None:
            simultaneity = SIMULTANEITY

        self.treatment_dimension = treatment_dimension
        self.confounder_dimension = confounder_dimension
        self.outcome_dimension = outcome_dimension

        self.W_CXY = np.random.randn(
            confounder_dimension,
            treatment_dimension+outcome_dimension
        )
        self.W_XY = np.random.randn(
            treatment_dimension, outcome_dimension
        )
        # W_YX is always drawn, even when it is zeroed out, so the global
        # random stream (and hence seeded experiments) is the same either way.
        self.W_YX = float(simultaneity) * np.random.randn(
                outcome_dimension, treatment_dimension
        )

        if normalize:
            self.W_XY = self.W_XY / np.linalg.norm(self.W_XY)
            norm_YX = np.linalg.norm(self.W_YX)
            # Without simultaneity W_YX is all zeros; dividing by its zero norm
            # would fill it with NaN and turn every sample into NaN.
            if norm_YX > 0:
                self.W_YX = self.W_YX / norm_YX

        super().__init__()

    def sample(
            self, N: int= 1, kappa: float=1.0, intervention: bool=False, **kwargs
        ) -> Tuple[NDArray, NDArray]:
        """Draw ``N`` i.i.d. samples ``(X, Y)`` from the (possibly cyclic) SEM.

        Args:
            N: number of samples (rows).
            kappa: confounding strength; scales both C -> Y and the
                feedback Y -> X.
            intervention: if True, sample the interventional regime. X is
                unit-variance noise, disconnected from C and Y.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``X`` of shape (N, treatment_dimension) and ``Y`` of shape
            (N, outcome_dimension).

        Side effects:
            Sets ``self.varXi``, the variance of ``xi = Y - X @ W_XY``.
            Sets ``self.varEXiX``, the variance of the least-squares
            projection of ``xi`` onto the column span of X.

        Raises:
            AssertionError: in the observational regime, if an eigenvalue of
                ``kappa * W_YX @ W_XY`` is close to 1 (singular system).
        """
        d_x, d_y = self.treatment_dimension, self.outcome_dimension

        # Check SEM solvability (unique stationary distribution). With the
        # block structure of T below, det(I - T) = det(I - kappa*W_YX@W_XY).
        # The system is therefore singular iff that loop gain has an
        # eigenvalue of 1. For outcome_dimension == 1 this reduces to the
        # scalar condition (1 - kappa * tau @ f) != 0. Under intervention the
        # feedback block is zero, so I - T is always invertible.
        if not intervention:
            eigvals = np.linalg.eigvals(kappa * (self.W_YX @ self.W_XY))
            feedback_strength = eigvals[np.argmin(np.abs(eigvals - 1.0))]
            if np.isclose(feedback_strength, 1.0):
                # Explicit raise: unlike `assert`, this survives `python -O`.
                raise AssertionError(
                    f'SEM may not be solvable with kappa={kappa}. Condition κ*fᵀτ={feedback_strength:.4f} is too close to 1.'
                )

        C = np.random.randn(N, self.confounder_dimension)
        N_XY = np.random.randn(N, d_x + d_y)
        # Y noise is always scaled to NOISE_STD. X noise is scaled only in the
        # observational regime; under intervention X is unit-variance noise.
        if not intervention:
            N_XY[:, :d_x] *= NOISE_STD
        N_XY[:, -d_y:] *= NOISE_STD

        # Structural mechanism XY -> XY (row-vector form, XY = exog + XY @ T):
        # the upper-right block carries X -> Y, the lower-left block Y -> X.
        feedback = np.zeros_like(self.W_YX) if intervention else kappa * self.W_YX
        T = np.block([
            [np.zeros((d_x, d_x)), self.W_XY],
            [feedback, np.zeros((d_y, d_y))],
        ])

        # Exogenous inputs: confounding into X (cut under intervention) and
        # kappa-scaled confounding into Y, plus noise.
        exogenous = C @ self.W_CXY
        if intervention:
            exogenous[:, :d_x] = 0.0
        exogenous[:, -d_y:] *= kappa
        exogenous += N_XY

        # Solve the cyclic SEM: XY @ (I - T) = exogenous.
        eye = np.eye(d_x + d_y)
        XY = exogenous @ np.linalg.inv(eye - T)
        X, Y = XY[:, :d_x], XY[:, -d_y:]

        xi = Y - X @ self.W_XY
        self.varXi = np.var(xi)
        # Associate as X @ (pinv(X) @ xi) so the N x N hat matrix
        # X @ pinv(X) is never materialised (O(N*d) instead of O(N^2) memory).
        self.varEXiX = np.var(X @ (np.linalg.pinv(X) @ xi))

        return X, Y

