from numpy.linalg import matrix_rank
from causallearn.utils.cit import CIT
import numpy as np
import pandas as pd
from scipy.stats.distributions import chi2

def split_and_diagonalize(arr,num):
    """Place equal-width column groups of a 2-D array on a block diagonal.

    ``arr`` (``r x c``) is cut into ``num`` consecutive column groups of
    width ``c // num``.  Group ``k`` is copied into rows ``k*r:(k+1)*r`` and
    into its own original column range of a zero-filled float array.
    All other entries stay zero.

    Parameters
    ----------
    arr : numpy.ndarray
        Two-dimensional array whose column count is divisible by ``num``.
    num : int
        Number of column groups, i.e. diagonal blocks.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(arr.shape[0] * num, arr.shape[1])``.
    """
    assert arr.shape[1] % num == 0, 'Not integer divisible.'
    n_rows = arr.shape[0]
    blocks = np.split(arr, num, axis=1)
    out = np.zeros((n_rows * len(blocks), arr.shape[1]))
    # Walk down the diagonal: each block advances one full row band and its
    # own width in columns.
    row_off = col_off = 0
    for block in blocks:
        width = block.shape[1]
        out[row_off:row_off + block.shape[0], col_off:col_off + width] = block
        row_off += n_rows
        col_off += width
    return out

class ProxyTest_mulx():
    r"""Chi-square test built from a discrete proxy ``W`` and optional discrete strata ``X``.

    ``pdData`` must hold discrete columns ``'A'``, ``'Y'`` and ``'W'``.
    Every column whose name contains ``'X'`` (substring match) is treated as
    a stratifying covariate. Every column whose name contains ``'W'`` is
    treated as a proxy column.

    Let ``q`` be the stacked estimates of ``P(Y=y | X, A)`` over the first
    ``level_y - 1`` levels of ``Y``. Let ``Q0 = I_{level_y-1} kron Q``, where
    ``Q`` is the matrix of ``P(W | X, A)`` (block-diagonal over the X cells).
    The statistic is ``n * ||Omega Sigma^{-1/2} q||^2``, where ``Omega``
    projects onto the orthogonal complement of the row space of
    ``Q0 Sigma^{-1/2}``. It is therefore zero exactly when ``q = Q0^T beta``
    for some ``beta``. It is referred to a chi-square distribution with
    ``(level_a - level_w) * (level_y - 1) * level_x`` degrees of freedom.
    NOTE(review): presumably this is the proxy-based test of ``Y`` vs ``A``
    given ``X``; confirm the causal reading against the accompanying method.

    Parameters
    ----------
    pdData : pandas.DataFrame
        The observations.
    level_a, level_y, level_w : int
        Number of distinct values of ``A``, ``Y`` and ``W``; checked against
        the data on construction.
    level_x : int
        Number of cells of the joint ``X`` columns; must be 1 when there are
        no ``X`` columns.
    eps : float
        Clipping constant that moves empirical probabilities of exactly 0
        or 1 into the open interval.

    Raises
    ------
    AssertionError
        If a ``level_*`` argument disagrees with the data.
    """

    def __init__(self, pdData, level_a=15, level_y=15, level_w=9, level_x=1, eps=1e-5):

        self.pdData = pdData
        self.n = self.pdData.shape[0]

        self.level_a = level_a
        self.level_y = level_y
        self.level_w = level_w
        self.level_x = level_x
        # Raised explicitly instead of via ``assert`` so the checks survive
        # ``python -O``; AssertionError is kept for callers that catch it.
        for column, level in (('A', level_a), ('Y', level_y), ('W', level_w)):
            observed = self.pdData[column].unique().shape[0]
            if level != observed:
                raise AssertionError(
                    'Column {} has {} distinct values but {} levels were given.'.format(
                        column, observed, level))

        self.eps = eps

    def _columns_containing(self, tag):
        """Return the names of the ``pdData`` columns that contain ``tag`` as a substring."""
        return [var for var in self.pdData.columns if tag in var]

    def _est_qy_hat(self,):
        """Estimate ``q = P(Y | X, A)`` and its diagonal covariance.

        Sets the following attributes:

        * ``pa_hat``: ``P(A)`` when there are no X columns. Otherwise the
          joint ``P(X..., A)`` over every cell, with empty cells clipped
          to ``eps``.
        * ``conditon_pob_ay``: the ``P(Y | X..., A)`` table, clipped away
          from 0 and 1.
        * ``q_hat``: all but the last ``Y`` row of that table, flattened so
          that ``Y`` is the outer index.
        * ``Sigmay_hat``: ``diag(q (1 - q) / pa)``, the covariance of
          ``q_hat`` up to the ``1 / n`` factor applied in ``_proxytest``.
        """
        X = self._columns_containing('X')
        strata = [self.pdData[name] for name in X] + [self.pdData["A"]]
        if not X:
            self.pa_hat = self.pdData["A"].value_counts(
                normalize=True).sort_index().values
        else:
            # Joint cell frequencies over every (X..., A) combination. They are
            # taken from the same crosstab layout used below for P(Y | X..., A),
            # so the cell order matches ``q_hat``. This replaces a
            # ``columns=['joint_prob']`` crosstab whose handling depends on the
            # pandas version.
            counts = pd.crosstab(self.pdData["Y"], strata, dropna=False).sum(axis=0)
            joint = counts / counts.sum()
            joint[joint == 0] = self.eps
            joint[joint == 1] -= self.level_a*self.eps
            self.pa_hat = joint.values

        pa_hat_copy = np.tile(self.pa_hat, self.level_y-1)
        self.conditon_pob_ay = pd.crosstab(
            self.pdData["Y"], strata, normalize='columns', dropna=False)
        # Keep probabilities strictly inside (0, 1) so q (1 - q) stays positive.
        self.conditon_pob_ay[self.conditon_pob_ay == 0] = self.eps
        self.conditon_pob_ay[self.conditon_pob_ay ==
                             1] -= self.level_y*self.eps
        # Drop the last Y level and stack Y-major, so the layout matches
        # ``np.tile(pa_hat, level_y - 1)``.
        self.q_hat = self.conditon_pob_ay.iloc[:-1, :].values.ravel(order='C')
        self.Sigmay_hat = np.diag(self.q_hat * (1-self.q_hat) / pa_hat_copy)

    def _est_Q_hat(self,):
        """Estimate ``Q = P(W | X, A)`` and ``Q0_hat = I_{level_y-1} kron Q``.

        Without X columns, ``Q_hat`` is the ``P(W | A)`` table. With X
        columns, the ``P(W | X..., A)`` table is first clipped away from 0
        and 1. Its columns are then cut into ``level_x`` equal groups (one
        per X cell when ``level_x`` matches the data), which are laid out
        block-diagonally by ``split_and_diagonalize``.

        Raises
        ------
        AssertionError
            If ``Q0_hat`` does not have full row rank.
        """
        X = self._columns_containing('X')
        W = [self.pdData[name] for name in self._columns_containing('W')]
        if not X:
            conditon_pob_wa = pd.crosstab(
                W, self.pdData["A"], normalize='columns', dropna=False)
            self.Q_hat = conditon_pob_wa.values
            expected_rank = self.level_w*(self.level_y - 1)
        else:
            strata = [self.pdData[name] for name in X] + [self.pdData["A"]]
            conditon_pob_wa = pd.crosstab(
                W, strata, normalize='columns', dropna=False)
            conditon_pob_wa[conditon_pob_wa == 0] = self.eps
            conditon_pob_wa[conditon_pob_wa == 1] -= self.level_w*self.eps
            self.Q_hat = split_and_diagonalize(conditon_pob_wa.values, self.level_x)
            expected_rank = self.level_x*self.level_w*(self.level_y - 1)
        self.Q0_hat = np.kron(np.eye(self.level_y - 1), self.Q_hat)
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
        if matrix_rank(self.Q0_hat) != expected_rank:
            raise AssertionError(
                'Warning, the matrix Q does not have full row rank.')

    def _proxytest(self, todiscrete=True):
        r"""Compute the p-value of the chi-square proxy test.

        The statistic is ``Ty = n * ||Omega Sigma^{-1/2} q||^2``, where
        ``Omega = I - Sigma^{-1/2} Q0^T (Q0 Sigma^{-1} Q0^T)^{-1} Q0 Sigma^{-1/2}``.
        Its upper-tail probability is taken under a chi-square distribution
        with ``(level_a - level_w) * (level_y - 1) * level_x`` degrees of
        freedom.

        A large p-value means the data are compatible with H0
        (``q = Q0^T beta`` for some ``beta``). A small one is evidence
        against it.

        Parameters
        ----------
        todiscrete : bool
            Unused; kept for backward compatibility.

        Returns
        -------
        float
            The p-value.

        Raises
        ------
        ValueError
            If ``q_hat`` does not have ``level_a * level_x * (level_y - 1)``
            entries, e.g. because ``level_x`` does not match the data.
        AssertionError
            Propagated from ``_est_Q_hat`` when ``Q0_hat`` is rank deficient.
        """
        self._est_qy_hat()
        self._est_Q_hat()

        dim = self.level_a*self.level_x*(self.level_y-1)
        if self.q_hat.shape[0] != dim:
            raise ValueError(
                'q_hat has {} entries, expected level_a*level_x*(level_y-1) = {}.'.format(
                    self.q_hat.shape[0], dim))

        # Sigmay_hat is diagonal, so its inverse square root is elementwise.
        inv_sqrt = 1.0 / np.sqrt(np.diag(self.Sigmay_hat))
        B = self.Q0_hat * inv_sqrt      # Q0 @ Sigma^{-1/2}
        z = inv_sqrt * self.q_hat       # Sigma^{-1/2} @ q
        # Omega @ z is z minus its projection onto the row space of B.
        # solve() avoids explicitly inverting B B^T = Q0 Sigma^{-1} Q0^T.
        Xiy = z - B.T @ np.linalg.solve(B @ B.T, B @ z)
        Ty = self.n * float(Xiy @ Xiy)

        p_value = chi2.sf(
            Ty, (self.level_a-self.level_w)*(self.level_y-1)*self.level_x)

        return p_value
