from tabscm.diffusion_regressor import DiffusionRegressor
from sklearn.neighbors import KernelDensity
from tqdm import tqdm
import xgboost as xgb
import numpy as np
import networkx as nx



def is_categorical(node: int, INFO) -> bool:
    """Return whether column ``node`` should be treated as categorical.

    A node is categorical when it is listed in ``INFO["cat_col_idx"]``. For
    binary-classification datasets (``INFO['task_type'] == 'binclass'``) the
    target column(s) in ``INFO["target_col_idx"]`` are categorical too.
    """
    binary_task = INFO['task_type'] == 'binclass'
    if node in INFO["cat_col_idx"]:
        return True
    # Only consult the target list for binclass tasks.
    return binary_task and node in INFO["target_col_idx"]

def get_num_classes(node: int, INFO) -> int:
    """Return the class count stored for ``node`` in ``INFO['n_classes']``."""
    class_counts = INFO['n_classes']
    return class_counts[node]



class RootSamplerKDE:
    """Sampler for a continuous root node based on a Gaussian KDE of its data.

    ``bandwidth`` is passed straight to ``KernelDensity`` and is therefore in
    the units of ``y``.
    """

    def __init__(self, y, bandwidth=0.05):
        estimator = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
        # KernelDensity is fit on a single-feature 2-D column.
        column = np.asarray(y).reshape(-1, 1)
        estimator.fit(column)
        self.kde = estimator

    def __call__(self, n):
        """Draw ``n`` values from the fitted density as a 1-D array."""
        draws = self.kde.sample(n)
        return draws.ravel()



class RootSamplerCategorical:
    """Sampler for a categorical root node using its empirical frequencies."""

    def __init__(self, y):
        # Distinct observed values (sorted by np.unique) and their relative frequencies.
        self.values, frequencies = np.unique(y, return_counts=True)
        self.probs = frequencies / frequencies.sum()

    def __call__(self, n):
        """Draw ``n`` values i.i.d. from the empirical distribution (global NumPy RNG)."""
        return np.random.choice(self.values, n, p=self.probs)

def fit_scm_from_dag(data: np.ndarray, dag: nx.DiGraph, INFO, device: str, **params_regressor) -> dict:
    """Fit one mechanism per node of ``dag`` and return them as an SCM dict.

    Nodes are visited in topological order. Each value of the returned dict is
    either:

    * a callable root sampler for nodes without parents
      (``RootSamplerCategorical`` for categorical nodes, ``RootSamplerKDE``
      otherwise), or
    * a tuple ``(model, parents, None, n_classes)`` for nodes with parents.
      ``model`` is an XGBoost ``multi:softprob`` booster for categorical nodes
      (``n_classes`` taken from ``INFO``). For continuous nodes it is a
      ``DiffusionRegressor`` and ``n_classes`` is ``None``. The third slot is
      always ``None`` here.

    Args:
        data: 2-D array whose column ``data[:, node]`` holds the values of
            ``node`` (assumes node ids are column indices). Categorical labels
            are assumed to already be integer-encoded in ``[0, n_classes)``,
            as XGBoost requires — confirm upstream.
        dag: Causal DAG over the nodes.
        INFO: Dataset metadata, read through ``is_categorical`` and
            ``get_num_classes``.
        device: Forwarded to ``DiffusionRegressor``.
        **params_regressor: Extra keyword arguments forwarded to
            ``DiffusionRegressor``.

    Returns:
        Dict mapping each node to its fitted mechanism (see above).
    """
    scm = {}
    num_iter = len(dag.nodes)

    for node in tqdm(nx.topological_sort(dag), total=num_iter, desc='Fitting nodes', position=0):
        parents = list(dag.predecessors(node))
        y = data[:, node]

        if not parents:
            # Root node: only its marginal distribution needs to be modelled.
            if is_categorical(node, INFO):
                scm[node] = RootSamplerCategorical(y)
            else:
                scm[node] = RootSamplerKDE(y)
            continue

        X = data[:, parents]

        if is_categorical(node, INFO):
            n_classes = get_num_classes(node, INFO)
            dtrain = xgb.DMatrix(X, label=y)
            params = {
                'num_class': n_classes,
                # softprob yields per-class probabilities, which the samplers
                # draw from instead of taking the argmax.
                'objective': 'multi:softprob',
                # No evaluation sets are passed to xgb.train below, so this
                # metric is not evaluated during this fit.
                'eval_metric': 'aucpr',
                'tree_method': 'hist',
                'eta': 0.2,
                'max_depth': 30,
                'alpha': 1.5,
                'lambda': 1.5,
            }
            model = xgb.train(
                params=params,
                dtrain=dtrain,
                num_boost_round=500,
            )
            scm[node] = (model, parents, None, n_classes)
        else:
            model = DiffusionRegressor(device=device, **params_regressor)
            model.fit(X, y)
            scm[node] = (model, parents, None, None)
    return scm



def _postprocess_predictions(y_preds, min_val, max_val):
        y_int = np.round(y_preds).astype(int)
        y_clipped = np.clip(y_int, min_val, max_val)
        return y_clipped


def sample_from_scm(scm, dag, n_samples, INFO) -> np.ndarray:
    """Draw ``n_samples`` observational rows from a fitted SCM.

    Observational sampling is ancestral sampling with no interventions. This
    therefore delegates to ``sample_interventions_from_scm`` with an empty
    intervention mapping instead of keeping a second copy of the same
    per-node sampling loop. Fixes to that loop then apply to both entry
    points.

    Args:
        scm: Mapping node -> mechanism, as returned by ``fit_scm_from_dag``.
        dag: DAG whose topological order defines the sampling order.
        n_samples: Number of rows to generate.
        INFO: Dataset metadata used for categorical / integer handling.

    Returns:
        Array of shape ``(n_samples, len(scm))``, one column per node.
    """
    return sample_interventions_from_scm(scm, dag, n_samples, {}, INFO)



def sample_interventions_from_scm(scm, dag, n_samples, intervention, INFO) -> np.ndarray:
    """Ancestrally sample from a fitted SCM under hard (do-) interventions.

    Nodes are visited in topological order of ``dag``. A node present in
    ``intervention`` is set to its given value and its mechanism is skipped.
    Its descendants then condition on that value. All other nodes are drawn
    from their mechanism:

    * root samplers (callables) are called with ``n_samples``. For
      non-categorical roots whose ``INFO['col_dtype']`` is ``'int'``, draws
      are rounded and clipped to ``[floor(min), ceil(max)]`` from
      ``INFO['column_info'][str(node)]``;
    * categorical children draw one class per row from the XGBoost
      predicted class probabilities;
    * other children use ``model.predict(X_parents, node=node, INFO=INFO)``.

    Args:
        scm: Mapping node -> mechanism, as returned by ``fit_scm_from_dag``:
            a callable root sampler or a ``(model, parents, noise_std,
            n_classes)`` tuple.
        dag: DAG whose topological order defines the sampling order.
        n_samples: Number of rows to generate.
        intervention: Mapping ``{node: value}``. ``value`` is flattened and
            broadcast into the node's column. A scalar or length-1 value fixes
            every row; a length-``n_samples`` array sets per-row values.
        INFO: Dataset metadata used for categorical / integer handling.

    Returns:
        Array of shape ``(n_samples, len(scm))``. Column ``node`` holds that
        node's samples (assumes node ids are column indices).
    """
    n_vars = len(scm)
    data = np.zeros((n_samples, n_vars))

    for node in tqdm(nx.topological_sort(dag), total=len(dag.nodes), desc='Sampling each nodes'):
        if node in intervention:
            # do(node = value): bypass the mechanism entirely.
            data[:, node] = np.ravel(intervention[node])
            continue

        model_info = scm[node]

        if callable(model_info):  # root sampler
            data[:, node] = model_info(n_samples)
            if not isinstance(model_info, RootSamplerCategorical) and INFO['col_dtype'][node] == 'int':
                # KDE draws are continuous; snap them back to the observed integer range.
                col_info = INFO['column_info'][str(node)]
                data[:, node] = _postprocess_predictions(
                    data[:, node],
                    min_val=int(np.floor(col_info['min'])),
                    max_val=int(np.ceil(col_info['max'])),
                )
            continue

        model, parents, _noise_std, n_classes = model_info
        X_parents = data[:, parents]

        if is_categorical(node, INFO):
            prob_preds = model.predict(xgb.DMatrix(X_parents))
            # One categorical draw per row from that row's predicted class probabilities.
            data[:, node] = np.array([np.random.choice(n_classes, p=probs) for probs in prob_preds])
        else:
            data[:, node] = model.predict(X_parents, node=node, INFO=INFO)

    return data