from castle.algorithms.gradient.notears.linear import *
import torch
import os

class notears_prior(Notears):
    def __init__(self, lambda1=0.1,
                 sigma=1.0,
                 loss_type='l2',
                 max_iter=100,
                 h_tol=1e-8,
                 rho_max=1e+16,
                 w_threshold=0.3,
                 prior_type=None,
                 regular_type='l1',
                 device_type='gpu',
                 device_ids=0,
                 W_initial=None,
                 sigma_MCP=0.2):
        """Store the solver configuration and pick the torch device.

        Parameters
        ----------
        lambda1 : float
            Regularization weight passed to ``notears_linear`` by ``learn``.
        sigma : float
            Scale used by the ``'pdf'`` loss, which divides the squared
            error by ``sigma ** 2``.
        loss_type : str
            One of ``'l2'``, ``'logistic'``, ``'poisson'``, ``'pdf'``,
            ``'nll'``.
        max_iter : int
            Maximum number of outer (dual ascent) iterations.
        h_tol : float
            Stop once the constraint value is ``<= h_tol``.
        rho_max : float
            Stop once the penalty coefficient reaches this value.
        w_threshold : float
            Weights with a smaller absolute value are zeroed.
        prior_type : str or None
            ``'soft'``, ``'gredient'`` (spelling as used by the solver),
            ``'order'``, ``'bound'``, or ``None`` for no prior.
        regular_type : str
            ``'l1'`` or ``'MCP'``.
        device_type : str
            ``'gpu'`` selects CUDA; any other value selects the CPU.
        device_ids : int or str
            Written to ``CUDA_VISIBLE_DEVICES`` when truthy.
        W_initial : numpy.ndarray or None
            Optional [d, d] warm start for the weight matrix.
        sigma_MCP : float
            Width parameter of the MCP regularizer.

        Raises
        ------
        ValueError
            If ``device_type == 'gpu'`` but CUDA is not available.
        """
        super().__init__()

        self.lambda1 = lambda1
        self.sigma = sigma
        self.loss_type = loss_type
        self.max_iter = max_iter
        self.h_tol = h_tol
        self.rho_max = rho_max
        self.w_threshold = w_threshold

        # Prior information; normally filled in later through load_prior().
        self.w_prior = None
        self.prob_prior = 0
        self.ground_truth = None
        # _prior() reads this attribute, so it must exist even when
        # load_prior() was never called. 1 matches load_prior()'s default.
        self.adaptive_degree = 1
        self.prior_type = prior_type
        self.regular_type = regular_type
        self.device_type = device_type
        self.device_ids = device_ids
        self.W_initial = W_initial
        self.sigma_MCP = sigma_MCP

        if torch.cuda.is_available():
            logging.info('GPU is available.')
        else:
            logging.info('GPU is unavailable.')
            if self.device_type == 'gpu':
                raise ValueError("GPU is unavailable, "
                                 "please set device_type = 'cpu'.")
        if self.device_type == 'gpu':
            # NOTE(review): the default device_ids=0 is falsy, so the
            # environment variable is left untouched for device 0 -
            # confirm whether that is intended.
            if self.device_ids:
                os.environ['CUDA_VISIBLE_DEVICES'] = str(self.device_ids)
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        self.device = device

    def learn(self, data, columns=None, **kwargs):
        """
        Run the solver on ``data`` and store the resulting graph.

        Parameters
        ----------
        data: castle.Tensor or numpy.ndarray
            Two-dimensional [n, d] sample matrix to learn from.
        columns : Index or array-like
            Column labels forwarded to ``Tensor`` when the results are
            wrapped for storage on the instance.
        **kwargs
            Accepted but not used.

        Returns
        -------
        W_est : numpy.ndarray
            [d, d] weight matrix returned by ``notears_linear``.
        Omega_est : numpy.ndarray
            [d, d] diagonal matrix holding the diagonal of
            ``(I - W_est).T @ S @ (I - W_est)`` with ``S = X.T @ X / n``.

        Notes
        -----
        Sets ``self.weight_causal_matrix`` and ``self.causal_matrix``.
        """
        n_samples, n_vars = data.shape

        solver_options = dict(
            lambda1=self.lambda1,
            sigma=self.sigma,
            loss_type=self.loss_type,
            max_iter=self.max_iter,
            h_tol=self.h_tol,
            rho_max=self.rho_max,
            w_threshold=self.w_threshold,
            W_initial=self.W_initial,
            sigma_MCP=self.sigma_MCP,
        )
        W_est = self.notears_linear(data, **solver_options)

        # The squared-error style losses work on column-centred samples.
        samples = data
        if self.loss_type == 'l2' or self.loss_type == 'pdf':
            samples = samples - np.mean(samples, axis=0, keepdims=True)

        residual_map = np.eye(n_vars) - W_est
        second_moment = (1.0 / n_samples) * (samples.T @ samples)
        Omega_est = np.diag(np.diag(residual_map.T @ second_moment @ residual_map))

        # Binary adjacency: an edge exists where the weight reaches the threshold.
        causal_matrix = (abs(W_est) >= self.w_threshold).astype(int)

        labelled = Tensor(data, columns=columns)
        self.weight_causal_matrix = Tensor(W_est,
                                           index=labelled.columns,
                                           columns=labelled.columns)
        self.causal_matrix = Tensor(causal_matrix,
                                    index=labelled.columns,
                                    columns=labelled.columns)

        return W_est, Omega_est
        
    def notears_linear(self, X, lambda1=0.1, loss_type='pdf', max_iter=100,
                       h_tol=1e-8, rho_max=1e+16, w_threshold=0.3, sigma=1.0,
                       W_initial=None, sigma_MCP=0.2):
        """Solve min_W L(W; X) + lambda1 * R(W) s.t. h(W) = 0 using augmented Lagrangian.

        W is optimised as a doubled vector ``[w_pos, w_neg]`` with
        ``W = w_pos - w_neg``, using L-BFGS-B under the box constraints
        returned by ``self.bound_prior``.

        Args:
            X (np.ndarray): [n, d] sample matrix
            lambda1 (float): regularization penalty parameter
            loss_type (str): l2, logistic, poisson, pdf, nll
            max_iter (int): max num of dual ascent steps
            h_tol (float): exit if |h(w_est)| <= htol
            rho_max (float): exit if rho >= rho_max
            w_threshold (float): drop edge if |weight| < threshold
            sigma (float): scale of the 'pdf' loss (divides by sigma ** 2)
            W_initial (np.ndarray): optional [d, d] warm start
            sigma_MCP (float): width parameter of the MCP regularizer

        Returns:
            W_est (np.ndarray): [d, d] estimated weights, thresholded

        Raises:
            ValueError: on an unknown ``loss_type`` or ``self.regular_type``
        """
        n, d = X.shape

        # Penalty weight used in the GRADIENT only. With the 'gredient' prior
        # the sign is flipped on entries where the prior is non-zero. It is
        # kept in a local so that self.lambda1 is never overwritten: mutating
        # it turned the instance's scalar penalty into an array and broke any
        # later call. Without a prior matrix the plain penalty is used.
        # NOTE(review): the objective value keeps the uniform lambda1, as in
        # the original formulation - confirm that this mismatch is intended.
        if self.prior_type == 'gredient' and self.w_prior is not None:
            grad_lambda = np.where(self.w_prior != 0, -lambda1, lambda1)
        else:
            grad_lambda = lambda1

        def _loss(W, sigma=sigma):
            """Evaluate value and gradient of loss."""
            M = X @ W
            if loss_type == 'l2':
                R = X - M
                loss = 0.5 / X.shape[0] * (R ** 2).sum()
                G_loss = - 1.0 / X.shape[0] * X.T @ R
            elif loss_type == 'logistic':
                loss = 1.0 / X.shape[0] * (np.logaddexp(0, M) - X * M).sum()
                G_loss = 1.0 / X.shape[0] * X.T @ (sigmoid(M) - X)
            elif loss_type == 'poisson':
                S = np.exp(M)
                loss = 1.0 / X.shape[0] * (S - X * M).sum()
                G_loss = 1.0 / X.shape[0] * X.T @ (S - X)
            elif loss_type == 'pdf':
                R = X - M
                loss = 0.5 / X.shape[0] * (R ** 2).sum() / sigma**2
                G_loss = - 1.0 / X.shape[0] * X.T @ R / sigma**2
            elif loss_type == 'nll':
                # Gradient of this loss is obtained through torch autograd.
                n, d = X.shape
                X_torch = torch.from_numpy(X).to(self.device)
                W_torch = torch.from_numpy(W).to(self.device)
                W_torch.requires_grad = True

                if self.regular_type == 'l1':
                    loss = 0.5 * torch.sum(torch.log(torch.sum(torch.square(X_torch - X_torch @ W_torch), dim=0)))
                elif self.regular_type == 'MCP':
                    loss = 0.5 * torch.sum(torch.log((torch.sum(torch.square(X_torch - X_torch @ W_torch), dim=0))/n)) - \
                        torch.linalg.slogdet(torch.eye(d).to(self.device) - W_torch)[1]
                else:
                    raise ValueError('unknown regular type')

                loss.backward()
                G_loss = W_torch.grad.cpu().detach().numpy()
                loss = loss.cpu().detach().numpy()
            else:
                raise ValueError('unknown loss type')
            return loss, G_loss

        def _h(W):
            """Evaluate value and gradient of acyclicity constraint."""
            # Polynomial form (Yu et al. 2019) instead of the matrix
            # exponential trace(expm(W * W)) - d (Zheng et al. 2018):
            # slightly faster at the cost of numerical stability.
            M = np.eye(d) + W * W / d
            E = np.linalg.matrix_power(M, d - 1)
            h = (E.T * M).sum() - d
            G_h = E.T * W * 2

            return h, G_h

        def _prior_order(W, w_prior=None, device="cpu"):
            """Evaluate value and gradient of the ordering-prior penalty.

            Returns (0, zeros) when no prior matrix is given.
            """
            if w_prior is None:
                return 0, np.zeros_like(W)

            # Non-zero pattern of w_prior + w_prior^2 + ... + w_prior^d,
            # i.e. the pairs linked by a walk of length 1..d in the prior.
            tclo = 0
            w_ = np.eye(d)
            for _ in range(d):
                w_ = np.matmul(w_, w_prior)
                tclo = tclo + w_
            tc = tclo != 0
            tc_torch = torch.from_numpy(tc).to(device)
            W_torch = torch.from_numpy(W).requires_grad_(True).to(device)
            # After .to(device) the tensor may no longer be a leaf; keep its grad.
            W_torch.retain_grad()
            w_ = torch.eye(d, dtype=torch.float64).to(device)
            states_W = []
            for _ in range(d):
                w_ = torch.matmul(w_, W_torch * W_torch)
                states_W.append(w_)

            # Sum of the powers (W * W)^k, k = 1..d, of the learned matrix.
            sum_W = 0
            for w in states_W:
                sum_W += w
            # Weighted by the TRANSPOSED prior closure.
            loss = torch.sum(tc_torch.t() * sum_W)
            # Backpropagate to obtain the gradient with respect to W.
            loss.backward()
            grad = W_torch.grad.cpu().numpy()
            return loss.cpu().detach().numpy(), grad

        def _regularization(w):
            """Return the penalty value and its gradients for w_pos and w_neg."""
            if self.regular_type == 'l1':
                # Plain sum of the doubled variables; constant unit gradient.
                return w.sum(), 1, 1
            elif self.regular_type == 'MCP':
                MCP = abs(w) - w * w / (2 * sigma_MCP)
                MCP[np.abs(w) > sigma_MCP] = sigma_MCP / 2

                GMCP = np.sign(w) - w / sigma_MCP
                GMCP[np.abs(w) > sigma_MCP] = 0

                return MCP.sum() - np.trace(w[:d*d].reshape([d, d])) - np.trace(w[d*d:].reshape([d, d])), \
                    GMCP[:d * d].reshape([d, d]), GMCP[d * d:].reshape([d, d])
            # Fail with a clear message instead of an opaque unpacking error.
            raise ValueError('unknown regular type')

        def _adj(w):
            """Convert doubled variables ([2 d^2] array) back to original variables ([d, d] matrix)."""
            return (w[:d * d] - w[d * d:]).reshape([d, d])

        def _func(w):
            """Evaluate value and gradient of augmented Lagrangian for doubled variables ([2 d^2] array)."""
            W = _adj(w)
            loss, G_loss = _loss(W)
            h, G_h = _h(W)

            regular, G_regular_pos, G_regular_neg = _regularization(w)
            obj = loss + lambda1 * regular
            G_smooth = G_loss
            if self.prior_type == 'soft':
                prior, G_prior = self._prior(W, self.w_prior, self.prob_prior)
                obj += prior
                G_smooth += G_prior
            elif self.prior_type == 'order':
                # The ordering penalty is folded into the constraint term.
                w_, G_w_ = _prior_order(W, self.w_prior, device=self.device)
                h = h + w_
                G_h = G_h + G_w_

            obj += 0.5 * rho * h * h + alpha * h
            G_smooth += (rho * h + alpha) * G_h

            g_obj = np.concatenate((G_smooth + grad_lambda * G_regular_pos,
                                    - G_smooth + grad_lambda * G_regular_neg), axis=None)
            return obj, g_obj

        w_in = np.zeros(2 * d * d)
        if W_initial is not None:
            # Split the warm start into its positive and negative parts.
            w_pos = np.where(W_initial > 0, W_initial, 0)
            w_neg = np.where(W_initial < 0, -W_initial, 0)
            w_in = np.concatenate([w_pos.flatten(), w_neg.flatten()])
        w_est, rho, alpha, h = w_in, 1.0, 0.0, np.inf  # double w_est into (w_pos, w_neg)
        bnds = self.bound_prior(d)
        if loss_type == 'l2' or loss_type == 'pdf':
            X = X - np.mean(X, axis=0, keepdims=True)
        for _ in range(max_iter):
            w_new, h_new = None, None
            while rho < rho_max:
                sol = sopt.minimize(_func, w_est, method='L-BFGS-B', jac=True, bounds=bnds)
                w_new = sol.x
                h_new, _ = _h(_adj(w_new))
                if h_new > 0.25 * h:
                    rho *= 10
                else:
                    break
            if w_new is None:
                # rho already reached rho_max before any solve was run
                # (e.g. rho_max <= 1): keep the current estimate.
                break
            w_est, h = w_new, h_new
            alpha += rho * h
            if h <= h_tol or rho >= rho_max:
                break
        W_est = _adj(w_est)
        W_est[np.abs(W_est) < w_threshold] = 0
        return W_est
    
    def load_param(self,lambda1=0.1, W_initial=None,sigma_MCP=0.2):
        """Overwrite the penalty weight, warm-start matrix and MCP width.

        All three attributes are always replaced; an argument that is
        omitted resets its attribute to the default shown in the signature.
        """
        self.lambda1, self.W_initial, self.sigma_MCP = lambda1, W_initial, sigma_MCP
        
    def load_prior(self,w_prior=None, prob_prior=0,ground_truth=None,adaptive_degree=1):
        """Store the prior matrix and its companion settings on the instance.

        Every attribute is overwritten on each call; an omitted argument
        resets its attribute to the default shown in the signature.
        """
        prior_settings = {
            'w_prior': w_prior,
            'prob_prior': prob_prior,
            'ground_truth': ground_truth,
            'adaptive_degree': adaptive_degree,
        }
        for attr_name, attr_value in prior_settings.items():
            setattr(self, attr_name, attr_value)


    def bound_prior(self,d):
        """Build the L-BFGS-B box constraints for the doubled variables.

        The returned list has ``2 * d * d`` ``(low, high)`` pairs: first
        the positive half, then the negative half, each in row-major
        order. Diagonal entries are always pinned to ``(0, 0)``.

        With ``prior_type == 'bound'`` each entry where ``self.w_prior``
        equals 1 gets a randomly drawn sign: the half matching that sign is
        bounded below by 0.3 and the opposite half is pinned to zero.
        Otherwise every off-diagonal entry is simply ``(0, None)``.
        """
        pinned, free = (0, 0), (0, None)
        # NOTE(review): 0.3 equals the default w_threshold - confirm
        # whether the two are meant to stay in sync.
        forced = (0.3, None)

        if self.prior_type != 'bound':
            return [pinned if row == col else free
                    for _half in range(2)
                    for row in range(d)
                    for col in range(d)]

        # One random sign per prior edge, drawn in row-major order.
        edge_sign = np.zeros((d, d), dtype=int)
        for row in range(d):
            for col in range(d):
                if self.w_prior[row][col] == 1:
                    edge_sign[row][col] = np.random.choice([1, -1], p=[0.5, 0.5])

        bnds = []
        # orientation +1 builds the positive half, -1 the negative half.
        for orientation in (1, -1):
            for row in range(d):
                for col in range(d):
                    oriented = orientation * edge_sign[row][col]
                    if row == col or oriented < 0:
                        bnds.append(pinned)
                    elif oriented > 0:
                        bnds.append(forced)
                    else:
                        bnds.append(free)
        return bnds
    
    def _prior(self, W, w_prior=None, prob_prior=0):
        """Evaluate value and gradient of the soft-prior penalty.

        Each weight is mapped to ``W_b = |2 * sigmoid(W) - 1|`` in [0, 1).
        The penalty is the negated sum of the scaled log-probabilities:
        ``log(prob_exist)`` over entries where ``w_prior == 1`` and
        ``log(prob_forb)`` over entries where ``w_prior == -1``.

        Parameters
        ----------
        W : numpy.ndarray
            Current weight matrix.
        w_prior : numpy.ndarray or None
            Prior matrix; entries equal to 1 or -1 contribute.
        prob_prior : float
            Mixing probability used in ``prob_exist`` and ``prob_forb``.

        Returns
        -------
        tuple
            ``(penalty, gradient)`` as numpy values, or ``(0, zeros)``
            when ``w_prior`` is None.
        """
        if w_prior is None:
            return 0, np.zeros_like(W)

        # adaptive_degree may be unset if w_prior was assigned without
        # load_prior(); fall back to load_prior()'s default of 1 instead
        # of raising AttributeError.
        degree = getattr(self, 'adaptive_degree', 1)

        W = torch.from_numpy(W).to(self.device)
        w_prior = torch.from_numpy(w_prior).to(self.device)
        adaptive_degree = torch.tensor(degree).to(self.device)
        W.requires_grad = True
        W_b = torch.abs(2*torch.sigmoid(W)-1)
        prob_exist = (W_b * prob_prior + (1-W_b) * (1-prob_prior))
        prob_forb = ((1-W_b) * prob_prior + W_b * (1-prob_prior))
        # The small constant keeps the logarithm finite at probability zero.
        log_prob_exist = torch.log(prob_exist + 1e-9) * adaptive_degree
        log_prob_forb = torch.log(prob_forb + 1e-9) * adaptive_degree
        prior = torch.sum(log_prob_exist[w_prior == 1]) + \
            torch.sum(log_prob_forb[w_prior == -1])
        # Negate so that minimising the objective maximises the log-prior.
        prior = - prior
        prior.backward()
        G_prior = W.grad
        return prior.cpu().detach().numpy(), G_prior.cpu().detach().numpy()
    
