import math

import networkx as nx
import numpy as np
import pandas as pd
from torch import  lgamma
from abc import ABC, abstractmethod
import torch
from scipy.linalg import block_diag
class BaseScore(ABC):
    """Base class for the scorer.

    The score of a graph is the sum, over all nodes, of a local score
    that depends only on the node and its parent set. Subclasses
    implement :meth:`local_scores`.

    Parameters
    ----------
    data : pd.DataFrame
        The dataset. One column per variable.

    prior : `BasePrior` instance
        The prior over graphs p(G). Its ``num_variables`` attribute is
        set to the number of columns of ``data``.
    """

    def __init__(self, data, prior):
        self.data = data
        self.prior = prior
        self.column_names = list(data.columns)
        self.num_variables = len(self.column_names)
        self.prior.num_variables = self.num_variables

    def __call__(self, graphs):
        """Score a batch of graphs.

        Parameters
        ----------
        graphs : torch.Tensor or array-like
            Flattened adjacency matrices, either a single graph of shape
            ``(num_variables ** 2,)`` or a batch of shape
            ``(batch_size, num_variables ** 2)``. After reshaping to
            ``(batch_size, num_variables, num_variables)``, a non-zero
            entry ``[i, j]`` marks ``i`` as a parent of ``j``. NumPy
            arrays and nested lists are converted with
            ``torch.as_tensor``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_size,)`` with the summed local
            scores of each graph.
        """
        # Accept NumPy arrays / lists as well as tensors; tensors are
        # passed through unchanged.
        graphs = torch.as_tensor(graphs)
        assert graphs.ndim == 1 or graphs.ndim == 2
        if graphs.ndim < 2:
            graphs = graphs[None, :]

        num_variables = self.num_variables
        batch_size = graphs.shape[0]
        graphs = graphs.reshape(batch_size, num_variables, num_variables).float()

        device = graphs.device
        scores = torch.zeros(batch_size, device=device)
        # Slot j of a parent-index row holds j when j is a parent and the
        # placeholder index num_variables + j otherwise, so every graph in
        # the batch gets an index row of the same length.
        real_idx = torch.arange(num_variables, device=device)
        placeholder_idx = torch.arange(num_variables, 2 * num_variables, device=device)

        with torch.no_grad():
            for node in range(num_variables):
                incoming = graphs[:, :, node]  # (batch_size, num_variables)
                par_idx = torch.where(incoming != 0, real_idx, placeholder_idx)
                par_num = incoming.sum(-1).long()
                scores = scores + self.local_scores(node, par_idx, par_num)
        return scores

    @abstractmethod
    def local_scores(self, target, indices, indices_num):
        """Return the local score of ``target`` for each graph in the batch.

        Parameters
        ----------
        target : int
            Index of the node being scored.

        indices : torch.Tensor
            Integer tensor of shape ``(batch_size, num_variables)``. Entry
            ``[b, j]`` is ``j`` if ``j`` is a parent of ``target`` in graph
            ``b`` and ``num_variables + j`` otherwise.

        indices_num : torch.Tensor
            Integer tensor of shape ``(batch_size,)`` holding the sum of
            the ``target`` column of each adjacency matrix (the number of
            parents for 0/1 adjacency matrices).
        """
        pass

    def structure_prior(self, graph):
        """A (log) prior distribution over models. Currently unused (= uniform)."""
        return 0

def logdet(array):
    """Return the log of the absolute determinant of ``array``.

    ``array`` is a square matrix or a batch of square matrices; the sign
    computed by ``torch.slogdet`` is discarded.
    """
    return torch.slogdet(array).logabsdet
def ix_(array):
    """Build a (row, column) index pair selecting a square sub-matrix.

    For ``array`` of shape ``(..., k)`` this returns views of shape
    ``(..., k, 1)`` and ``(..., 1, k)``. Indexing a matrix ``M`` with the
    pair broadcasts to ``M[array[..., i], array[..., j]]``.
    """
    as_rows = array[..., None]
    as_cols = array[..., None, :]
    return as_rows, as_cols

class BGeScore(BaseScore):
    r"""BGe score.

    Parameters
    ----------
    data : pd.DataFrame
        A DataFrame containing the (continuous) dataset D. Each column
        corresponds to one variable. The dataset D is assumed to only
        contain observational data (a `INT` column will be treated as
        a continuous variable like any other).

    prior : `BasePrior` instance
        The prior over graphs p(G).

    mean_obs : np.ndarray or torch.Tensor (optional)
        Mean parameter of the Normal prior over the mean $\mu$. This array must
        have size `(N,)`, where `N` is the number of variables. By default,
        the mean parameter is 0.

    alpha_mu : float (default: 1.)
        Parameter $\alpha_{\mu}$ corresponding to the precision parameter
        of the Normal prior over the mean $\mu$.

    alpha_w : float (optional)
        Parameter $\alpha_{w}$ corresponding to the number of degrees of
        freedom of the Wishart prior of the precision matrix $W$. This
        parameter must satisfy `alpha_w > N - 1`, where `N` is the number
        of variables. By default, `alpha_w = N + 2`.
    """

    def __init__(self, data, prior, mean_obs=None, alpha_mu=1., alpha_w=None):
        num_variables = data.shape[1]
        if mean_obs is None:
            mean_obs = torch.zeros((num_variables,), dtype=torch.float)
        else:
            # The documented input type is a NumPy array; convert it so the
            # arithmetic below is pure torch. Tensors pass through as-is.
            mean_obs = torch.as_tensor(mean_obs)
        if alpha_w is None:
            alpha_w = num_variables + 2.

        super().__init__(data, prior)
        self.mean_obs = mean_obs
        self.alpha_mu = alpha_mu
        self.alpha_w = alpha_w

        self.num_samples = self.data.shape[0]
        # Earlier variant, kept for reference:
        #   t = alpha_mu * (alpha_w - n - 1) / (alpha_mu + 1), with T = t * I.
        self.t = self.alpha_mu / (self.alpha_mu + self.num_samples)

        # Assumes W^-1 of the Wishart prior is the identity.
        T = torch.eye(self.num_variables)
        data = torch.tensor(self.data.values, dtype=torch.float)
        data_mean = torch.mean(data, dim=0, keepdim=True)  # (1, n)
        data_centered = data - data_mean
        mean_shift = data_mean - self.mean_obs  # (1, n)

        # R = T + S_N + (N * alpha_mu) / (N + alpha_mu) * (v - x_mu)(v - x_mu)^T
        self.R = (T + torch.matmul(data_centered.T, data_centered)
                  + ((self.num_samples * self.alpha_mu) / (self.num_samples + self.alpha_mu))
                  * torch.matmul(mean_shift.T, mean_shift)
                  )
        # Batched scoring trick: embed R in the block matrix [[R, 0], [0, I]].
        # A non-parent slot j is indexed as n + j and therefore selects a row
        # and column of the identity block, so every graph in a batch indexes
        # a sub-matrix of the same size.
        self.block_R_I = torch.block_diag(self.R, torch.eye(self.num_variables))

        # Entry l of log_gamma_term is the parent-count-dependent constant
        # for a node with l parents, l = 0, ..., n - 1.
        all_parents = torch.arange(self.num_variables)
        self.log_gamma_term = (
            0.5 * (math.log(self.alpha_mu) - math.log(self.num_samples + self.alpha_mu))
            # + log Gamma((N + alpha_w - n + l + 1) / 2)
            + torch.lgamma(0.5 * (self.num_samples + self.alpha_w - self.num_variables + all_parents + 1))
            # - log Gamma((alpha_w - n + l + 1) / 2)
            - torch.lgamma(0.5 * (self.alpha_w - self.num_variables + all_parents + 1))
            # - (N / 2) * log(pi)
            - 0.5 * self.num_samples * math.log(math.pi)
            # NOTE(review): with self.t = alpha_mu / (alpha_mu + N) this equals
            # the first term above, so that constant is counted twice. It
            # replaced 0.5 * (alpha_w - n + 2 * l + 1) * log(t) from the
            # earlier variant; confirm this is intended.
            + 0.5 * math.log(self.t)
        )

    def local_scores(self, target, indices, indices_num):
        """Return the local score of ``target`` for each graph in the batch.

        Parameters
        ----------
        target : int
            Index of the node being scored.

        indices : torch.Tensor
            Integer tensor of shape ``(batch_size, num_variables)`` indexing
            into ``self.block_R_I``: slot ``j`` holds ``j`` for a parent and
            ``num_variables + j`` (identity block) otherwise.

        indices_num : torch.Tensor
            Integer tensor of shape ``(batch_size,)`` with the number of
            parents per graph; used to index ``self.log_gamma_term``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_size,)``.
        """
        num_parents = indices_num
        # Same index rows with the target's own slot switched on
        # (parents plus the target itself).
        variables = indices.clone()
        variables[:, target] = target

        log_term_r = (0.5 * (self.num_samples + self.alpha_w - self.num_variables + num_parents)
                      * logdet(self.block_R_I[ix_(indices)])
                      - 0.5 * (self.num_samples + self.alpha_w - self.num_variables + num_parents + 1)
                      * logdet(self.block_R_I[ix_(variables)])
                      )
        return self.log_gamma_term[num_parents] + log_term_r


# class BaseScore(ABC):
#     """Base class for the scorer.
#
#     Parameters
#     ----------
#     data : pd.DataFrame
#         The dataset.
#
#     prior : `BasePrior` instance
#         The prior over graphs p(G).
#     """
#
#     def __init__(self, data, prior):
#         self.data = data
#         self.prior = prior
#         self.column_names = list(data.columns)
#         self.num_variables = len(self.column_names)
#         self.prior.num_variables = self.num_variables
#
#     def __call__(self, graphs):
#         if not isinstance(graphs,np.ndarray):
#             graphs=graphs.numpy()
#         assert graphs.ndim==1 or graphs.ndim ==2
#         if graphs.ndim<2:
#             graphs=graphs[None,:]
#
#         batch_size=graphs.shape[0]
#         local_score=np.zeros(batch_size)
#         for i, graph in enumerate(graphs):
#             local_score[i] += self.structure_prior(graph)
#             for node,_ in enumerate(self.column_names):
#                 graph=graph.reshape(self.num_variables,self.num_variables)
#                 parents = np.nonzero(graph[:,node])[0]
#                 local_score[i]= local_score[i]+ self.local_scores(node, parents)
#         scores=torch.tensor(local_score,dtype=torch.float)
#         return scores
#     @abstractmethod
#     def local_scores(self, target, indices):
#         pass
#     def structure_prior(self, graph):
#         """A (log) prior distribution over models. Currently unused (= uniform)."""
#         return 0


# class BDeScore(BaseScore):
#     """BDe score.
#
#     Parameters
#     ----------
#     data : pd.DataFrame
#         A DataFrame containing the (discrete) dataset D. Each column
#         corresponds to one variable. If there is interventional data, the
#         interventional targets must be specified in the "INT" column (the
#         indices of interventional targets are assumed to be 1-based).
#
#     prior : `BasePrior` instance
#         The prior over graphs p(G).
#
#     equivalent_sample_size : float (default: 1.)
#         The equivalent sample size (of uniform pseudo samples) for the
#         Dirichlet hyperparameters. The score is sensitive to this value,
#         runs with different values might be useful.
#     """
#
#     def __init__(self, data, prior, equivalent_sample_size=10.):
#         if 'INT' in data.columns:  # Interventional data
#             # Indices should start at 0, instead of 1;
#             # observational data will have INT == -1.
#             self._interventions = data.INT.map(lambda x: int(x) - 1)
#             data = data.drop(['INT'], axis=1)
#         else:
#             self._interventions = np.full(data.shape[0], -1)
#
#         super().__init__(data, prior)
#         self.equivalent_sample_size = equivalent_sample_size
#         self.state_names = {column: sorted(self.data[column].cat.categories.tolist())
#                             for column in self.data.columns}
#
#
#     def state_counts(self, target, indices):
#         # Source: pgmpy.estimators.BaseEstimator.state_counts()
#         all_indices = indices
#         parents = [self.column_names[index] for index in all_indices]
#         variable = self.column_names[target]
#
#         data = self.data[self._interventions != target]
#         data = data[[variable] + parents].dropna()
#
#         if not parents:
#             # count how often each state of 'variable' occurred
#             state_count_data = data.loc[:, variable].value_counts()
#             state_counts = (state_count_data.reindex(self.state_names[variable]).fillna(0).to_frame())
#
#         else:
#             parents_states = [self.state_names[parent] for parent in parents]
#             # count how often each state of 'variable' occurred, conditional on parents' states
#             state_count_data = (data.groupby([variable] + parents).size().unstack(parents))
#
#             if not isinstance(state_count_data.columns, pd.MultiIndex):
#                 state_count_data.columns = pd.MultiIndex.from_arrays([state_count_data.columns])
#             # reindex rows & columns to sort them and to add missing ones
#             # missing row    = some state of 'variable' did not occur in data
#             # missing column = some state configuration of current 'variable's parents
#             #                  did not occur in data
#             row_index = self.state_names[variable]
#             column_index = pd.MultiIndex.from_product(parents_states, names=parents)
#             state_counts = state_count_data.reindex(index=row_index, columns=column_index).fillna(0)
#
#         return state_counts
#
#     def local_scores(self, target,indices):
#         counts = self.state_counts(target,indices)
#         counts = np.asarray(counts)
#         num_parents_states = counts.shape[1]
#
#         log_gamma_counts = np.zeros_like(counts, dtype=np.float_)
#         alpha = self.equivalent_sample_size / num_parents_states
#         beta = self.equivalent_sample_size / counts.size
#
#         # Compute log(gamma(counts + beta))
#         gammaln(counts + beta, out=log_gamma_counts)
#         # Compute the log-gamma conditional sample size
#         log_gamma_conds = np.sum(counts, axis=0, dtype=np.float_)
#         gammaln(log_gamma_conds + alpha, out=log_gamma_conds)
#
#         local_score = (
#                 np.sum(log_gamma_counts)
#                 - np.sum(log_gamma_conds)
#                 + num_parents_states * math.lgamma(alpha)
#                 - counts.size * math.lgamma(beta)
#         )
#         return local_score
#
#     def structure_prior(self, graph):
#         """
#         Implements the marginal uniform prior for the graph structure where each arc
#         is independent with the probability of an arc for any two nodes in either direction
#         is 1/4 and the probability of no arc between any two nodes is 1/2."""
#         nedges = float(graph.sum())
#         nnodes = float(self.num_variables)
#         possible_edges = nnodes * (nnodes - 1) / 2.0
#         score = -(nedges + possible_edges) * np.log(2.0)
#         return score
#
#
#
# class BicScore(BDeScore):
#
#     def local_scores(self, variable, parents):
#         'Computes a score that measures how much a \
#         given variable is "influenced" by a given list of potential parents.'
#
#         var_states = self.state_names[variable]
#         var_cardinality = len(var_states)
#         state_counts = self.state_counts(variable, parents)
#         sample_size = len(self.data)
#         num_parents_states = float(state_counts.shape[1])
#
#         counts = np.asarray(state_counts)
#         log_likelihoods = np.zeros_like(counts, dtype=float)
#
#         # Compute the log-counts
#         np.log(counts, out=log_likelihoods, where=counts > 0)
#
#         # Compute the log-conditional sample size
#         log_conditionals = np.sum(counts, axis=0, dtype=float)
#         np.log(log_conditionals, out=log_conditionals, where=log_conditionals > 0)
#
#         # Compute the log-likelihoods
#         log_likelihoods -= log_conditionals
#         log_likelihoods *= counts
#
#         score = np.sum(log_likelihoods)
#         score -= 0.5 * np.log(sample_size) * num_parents_states * (var_cardinality - 1)
#
#         return score


# Batch-wise computation of the score sum.
# Idea: compute the score on the block matrix [R   0]
#                                             [0   I]
# parents = torch.arange(self.num_variables, 2 * self.num_variables).repeat(batch_size, 1)
# edge_idx =graphs[:,:,node].nonzero()
# parents[edge_idx[:, 0], edge_idx[:, 1]] = edge_idx[:, 1]
#
# block_R_I = torch.block_diag(torch.tensor(self.R), torch.eye(self.num_variables))