import numpy as np

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))``, element-wise.

    Args:
        x: A scalar or NumPy array.

    Returns:
        A value (or array of the same shape as ``x``) in ``[0, 1]``.
    """
    # For very negative x, exp(-x) overflows to inf and the quotient is
    # exactly 0.0, which is the correct limit.  Only the overflow warning
    # (or error, under np.seterr(over="raise")) is unwanted, so silence it
    # locally instead of changing the arithmetic.
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))

class ATPNN:
    """One-variable sigmoid basis whose values sum to zero over the fitting data.

    With ``s(x) = sigmoid((x - b) / gamma)`` the basis is::

        forward(x) = c * s(x) + sigmoid((b - x) / gamma)

    where ``c = -sum(sigmoid((b - x) / gamma)) / sum(s(x))`` is computed once
    from the data passed to the constructor, so the basis sums to zero over
    that data.  ``eta`` stores the mean of ``s`` over the same data.
    """

    def __init__(self, var, data, b: float, gamma: float, bin_function=sigmoid):
        """Fix the centering constants ``eta`` and ``c`` from ``data``.

        Args:
            var: Column index of ``data`` this basis reads.
            data: 2-D array indexed as ``data[:, var]``.
            b: Location of the sigmoid.
            gamma: Scale of the sigmoid (used as a divisor, so must be non-zero).
            bin_function: Accepted for backward compatibility but not used;
                the module-level ``sigmoid`` is always applied.
        """
        self.variable = var     # column of the data matrix used by this basis
        self.b = b              # sigmoid location
        self.gamma = gamma      # sigmoid scale

        rising, falling = self._sigmoid_pair(data[:, self.variable])
        self.eta = np.mean(rising, axis=0).item()
        self.c = (-np.sum(falling, axis=0) / np.sum(rising, axis=0)).item()

    def _sigmoid_pair(self, x):
        """Return ``(sigmoid((x - b) / gamma), sigmoid((b - x) / gamma))``."""
        rising = sigmoid((x - self.b) / self.gamma)
        falling = sigmoid((self.b - x) / self.gamma)
        return rising, falling

    def forward(self, data: np.ndarray) -> np.ndarray:
        """Evaluate the basis on column ``self.variable`` of ``data``.

        Args:
            data: 2-D array indexed as ``data[:, self.variable]``.

        Returns:
            ``c * sigmoid((x - b) / gamma) + sigmoid((b - x) / gamma)`` for
            the selected column ``x``.
        """
        rising, falling = self._sigmoid_pair(data[:, self.variable])
        return self.c * rising + falling                # (n,)

    def b_score_function(self, data: np.ndarray) -> np.ndarray:
        """Per-row derivative of the basis with respect to ``b``.

        The formula accounts for the dependence of ``c`` on ``b``; it combines
        the stored ``eta`` and ``c`` with a mean taken over ``data``, so it
        equals the derivative of ``forward`` when ``data`` is the array the
        instance was constructed with.

        Args:
            data: 2-D array indexed as ``data[:, self.variable]``.

        Returns:
            Array with one derivative value per row of ``data``.
        """
        x = data[:, self.variable]
        sigma = sigmoid((x - self.b) / self.gamma)      # (n,)
        slope = sigma * (1 - sigma)                     # sigmoid'(t) at t = (x - b) / gamma
        centering_part = -(sigma * np.mean(slope)) / (self.eta**2 * self.gamma)
        direct_part = -((self.c - 1) * slope) / (self.gamma)
        return centering_part + direct_part             # (n,)

    def gamma_score_function(self, data: np.ndarray) -> np.ndarray:
        """Per-row derivative of the basis with respect to ``gamma``.

        Like ``b_score_function``, this combines the stored ``eta`` and ``c``
        with a mean over ``data``, so it equals the derivative of ``forward``
        when ``data`` is the array the instance was constructed with.

        Args:
            data: 2-D array indexed as ``data[:, self.variable]``.

        Returns:
            Array with one derivative value per row of ``data``.
        """
        x = data[:, self.variable]
        sigma = sigmoid((x - self.b) / self.gamma)      # (n,)
        # d sigma / d gamma for each row
        dsigma = -(x - self.b) / (self.gamma ** 2) * sigma * (1 - sigma)    # (n,)
        return (np.mean(dsigma) / self.eta**2) * sigma + (self.c - 1) * dsigma

    def component_forward(self, var_array: np.ndarray) -> np.ndarray:
        """Evaluate the basis directly on values of its own variable.

        Args:
            var_array: Values of the variable; any shape that squeezes to a
                1-D array, or to a single value.

        Returns:
            1-D array of basis values, one per input value.
        """
        # squeeze() turns a single observation (shape (1,) or (1, 1)) into a
        # 0-d array; atleast_1d restores the (1,) shape so it is accepted too.
        values = np.atleast_1d(np.squeeze(var_array))   # (n,)
        assert values.ndim == 1
        rising, falling = self._sigmoid_pair(values)
        return self.c * rising + falling                # (n,)


class TPNN:
    """Scaled product of ATPNN bases used as one additive model component.

    The component's contribution to the linear predictor is
    ``height * z * prod_k structure[k].forward(data)``; with an empty
    ``structure`` the contribution is identically zero.  ``y_dist`` selects
    the response model used by the likelihood and score methods:
    ``'normal'``, ``'poisson'`` (mean ``exp(predictor)``) or ``'ber'``
    (success probability ``sigmoid(predictor)``).
    """

    def __init__(self, p: int, y_dist: str, structure: "list[ATPNN] | None" = None, height = 0. , z: int=1):
        """Store the component's configuration.

        Args:
            p: Number of input variables; ``variables`` is ``arange(p)``.
            y_dist: One of ``'normal'``, ``'poisson'``, ``'ber'``.
            structure: Bases whose product forms the component; ``None``
                means no bases.
            height: Scale of the component; forced to 0.0 when ``structure``
                is empty.
            z: Multiplier applied to the component; a falsy value switches
                the component off.
        """
        self.variables = np.arange(p)
        self.y_dist = y_dist
        self.structure = structure if structure is not None else []      # list[ATPNN]
        # A component without bases has no product to scale, so its height
        # is pinned to zero regardless of the argument.
        self.height = height if self.structure else 0.0
        self.z = z

    def prod_structure(self, data: np.ndarray) -> np.ndarray:
        """Return the row-wise product of all bases evaluated on ``data``.

        Args:
            data: 2-D array passed to each basis' ``forward``.

        Returns:
            Array of shape ``(n,)``; all zeros when ``structure`` is empty.
        """
        n_rows = data.shape[0]
        if not self.structure:
            return np.zeros((n_rows,))
        # Stack once instead of re-stacking inside a loop.  The leading
        # column of ones keeps the result floating point, as before.
        columns = [np.ones((n_rows,))]
        columns.extend(atpnn.forward(data) for atpnn in self.structure)     # each (n,)
        return np.prod(np.column_stack(columns), axis=1)

    def log_likelihood(self, lambda_t: np.ndarray, data: np.ndarray, y:np.ndarray, nui = None) -> float:
        """Log-likelihood of ``y`` given predictor ``lambda_t + component``.

        Additive terms that do not depend on the predictor are omitted.

        Args:
            lambda_t: Array of shape ``(n,)`` added to this component's output.
            data: 2-D input array.
            y: Responses, shape ``(n,)``.
            nui: For ``'normal'``, a mapping with a non-zero ``'sigma2'``
                (error variance); not read otherwise.

        Returns:
            The log-likelihood as a scalar.

        Raises:
            NotImplementedError: If ``y_dist`` is not supported.
        """
        forward_t = lambda_t + self.height * self.z * self.prod_structure(data)
        if self.y_dist == 'normal':
            assert nui['sigma2']
            var_error = nui['sigma2']
            # y ~ N(forward_t, var_error)
            return -np.sum((y - forward_t)**2).item()/(2*var_error)
        if self.y_dist == 'poisson':
            # y ~ Poisson(exp(forward_t))
            return np.sum(y * forward_t - np.exp(forward_t)).item()
        if self.y_dist == 'ber':
            # y ~ Ber(sigmoid(forward_t)).  logaddexp(0, f) == log(1 + exp(f))
            # but does not overflow to inf for large f.
            return np.sum(y * forward_t - np.logaddexp(0, forward_t)).item()
        raise NotImplementedError(f"unsupported y_dist: {self.y_dist!r}")

    def height_score_function(self, lambda_t: np.ndarray, data: np.ndarray, y: np.ndarray, nui = None) -> float:
        """Derivative of ``log_likelihood`` with respect to ``height``.

        Args:
            lambda_t: Array of shape ``(n,)`` added to this component's output.
            data: 2-D input array.
            y: Responses, shape ``(n,)``.
            nui: For ``'normal'``, a mapping with a non-zero ``'sigma2'``;
                not read otherwise.

        Returns:
            The score as a scalar.

        Raises:
            NotImplementedError: If ``y_dist`` is not supported.
        """
        spline = self.prod_structure(data)                      # (n,)
        forward_t = lambda_t + self.height * self.z * spline    # (n,)
        if self.y_dist == 'normal':
            assert nui['sigma2']
            var_error = nui['sigma2']
            height_score = np.sum((y - forward_t) * spline).item()/var_error
        elif self.y_dist == 'poisson':
            height_score = np.sum(y * spline - np.exp(forward_t) * spline).item()
        elif self.y_dist == 'ber':
            # sigmoid(f) rather than exp(f) / (1 + exp(f)): the latter is
            # inf / inf = nan once exp(f) overflows.
            height_score = np.sum(y * spline - spline * sigmoid(forward_t))
        else:
            raise NotImplementedError(f"unsupported y_dist: {self.y_dist!r}")

        return height_score * self.z                            # scalar

    def b_gamma_score_function(self, lambda_t: np.ndarray, data: np.ndarray, y: np.ndarray, nui = None):
        """Derivatives of the log-likelihood w.r.t. every basis' ``b`` and ``gamma``.

        Args:
            lambda_t: Array of shape ``(n,)`` added to this component's output.
            data: 2-D input array.
            y: Responses, shape ``(n,)``.
            nui: For ``'normal'``, a mapping with ``'sigma2'``; not read
                otherwise.

        Returns:
            ``(bs_score, gammas_score)``, one entry per basis in
            ``structure``.  Both are zero arrays of shape ``(d,)`` when ``z``
            is falsy.  Otherwise the results are squeezed, so they have shape
            ``(d,)`` for ``d > 1`` and are 0-d arrays for ``d == 1``.

        Raises:
            NotImplementedError: If ``y_dist`` is not supported.
        """
        d = len(self.structure)
        assert d > 0

        if not self.z :
            return np.zeros((d,)), np.zeros((d,))

        n_rows = data.shape[0]
        basis = np.ones((n_rows, d))
        for idx, atpnn in enumerate(self.structure):
            basis[:, idx] = atpnn.forward(data)

        # Derivative of the log-likelihood w.r.t. the predictor, per row.
        m = lambda_t + self.forward(data)
        if self.y_dist == 'normal':
            first_term = (y - m)/nui['sigma2']
        elif self.y_dist == 'poisson':
            first_term = y - np.exp(m)
        elif self.y_dist == 'ber' :
            first_term = y - sigmoid(m)
        else :
            raise NotImplementedError(f"unsupported y_dist: {self.y_dist!r}")

        b_columns = []
        gamma_columns = []
        for idx, atpnn in enumerate(self.structure):
            b_score_idx = atpnn.b_score_function(data)                  # (n,)
            gamma_score_idx = atpnn.gamma_score_function(data)          # (n,)

            # Product of every basis except idx; for d == 1 this is a
            # product over zero columns, i.e. all ones.
            others = np.prod(basis[:, np.arange(d) != idx], axis = 1)   # (n,)

            b_columns.append(b_score_idx * others)
            gamma_columns.append(gamma_score_idx * others)

        b_second_term = np.column_stack(b_columns)                      # (n, d)
        gamma_second_term = np.column_stack(gamma_columns)              # (n, d)
        residual_row = first_term.reshape((1, -1))                      # (1, n)

        bs_score = (residual_row @ b_second_term).squeeze() * self.height * self.z
        gammas_score = (residual_row @ gamma_second_term).squeeze() * self.height * self.z

        return bs_score, gammas_score

    def forward(self, data: np.ndarray) -> np.ndarray:
        """Return this component's contribution ``prod_structure * height * z``.

        Args:
            data: 2-D input array.

        Returns:
            Array of shape ``(n,)``.
        """
        return self.prod_structure(data) * self.height * self.z    # (n,)

