from sdmetrics.single_column import MissingValueSimilarity
import numpy as np
from dython.nominal import associations
import polars as pl
import pandas as pd
import lightgbm as lgb
from sklearn.preprocessing import OrdinalEncoder
from sklearn.metrics import roc_auc_score
from scipy.stats import kendalltau
import matplotlib.pyplot as plt


class MissingEvaluator():
    """Evaluate how well generated data reproduces missingness patterns.

    Three complementary views are offered:

    - univariate missingness distributions (``eval_similarity``)
    - correlations between missingness indicators and categorical features
      (``eval_correlation``)
    - conditional distributions, via LightGBM models trained to predict the
      missingness indicator of each column (``eval_cond_dist``)

    Data frames passed to the methods are accessed through the polars API
    (``select``, ``is_null``, ``height``, ``to_pandas``).

    Attributes:
        num_cols: Names of the numerical columns.
        cat_cols: Names of the categorical columns.
        num_cols_w_miss: Numerical columns that contain at least one null in
            the training frame given to the constructor.
        max_obs: Maximum number of training rows used to fit a model; larger
            training sets are randomly subsampled.
        boost_rounds: Number of boosting rounds per LightGBM model.
    """

    def __init__(self, df_trn, num_cols, cat_cols, max_obs=100_000, boost_rounds=200):
        """Store the configuration and detect numerical columns with nulls.

        Args:
            df_trn: Training data frame; only inspected to find which
                numerical columns contain nulls.
            num_cols: Names of the numerical columns.
            cat_cols: Names of the categorical columns.
            max_obs: Row cap for the training data used in ``eval_cond_dist``.
            boost_rounds: Number of boosting rounds for each LightGBM model.
        """
        self.num_cols = num_cols
        self.cat_cols = cat_cols
        self.num_cols_w_miss = [col for col in num_cols if df_trn[col].is_null().any()]
        self.max_obs = max_obs
        self.boost_rounds = boost_rounds

    @staticmethod
    def _mean_or_nan(values):
        """Return the mean of ``values`` as a float, or NaN for an empty array."""
        return values.mean().item() if values.size else float('nan')

    @staticmethod
    def _missing_indicator(df, col):
        """Return a 0/1 integer array flagging the null entries of ``df[col]``."""
        return df[col].is_null().cast(int).to_numpy()

    def eval_cond_dist(self, df_trn, df_tst, df_gen, seed=42):
        """Compare models predicting missingness when fit on real vs. generated data.

        For every column in ``num_cols_w_miss`` two LightGBM classifiers are
        trained to predict that column's missingness indicator, one on the
        training data and one on the generated data. Both are scored by ROC
        AUC on the test data, and their gain-based feature importances are
        compared with Kendall's tau.

        Args:
            df_trn: Real training data frame.
            df_tst: Real test data frame used for scoring both models.
            df_gen: Generated data frame.
            seed: Seed used for subsampling and for LightGBM.

        Returns:
            dict with the keys
            ``miss_cond_trn_avg`` / ``miss_cond_gen_avg`` (mean test AUC),
            ``miss_cond_abs_diff`` (mean absolute AUC difference),
            ``miss_cond_raw_trn`` / ``miss_cond_raw_gen`` (per-column AUCs),
            ``miss_feat_imp_raw`` (per-column Kendall's tau) and
            ``miss_feat_imp_dist`` (mean Kendall's tau).
            The averages are NaN when no column has missing values.
        """
        auc_trn_all = []
        auc_gen_all = []
        tau_all = []
        for col in self.num_cols_w_miss:
            # target is the missingness indicator of `col`
            data_trn, data_gen, X_tst, y_tst = self._prep_gbm_data(
                df_trn, df_tst, df_gen, col, seed=seed)

            # one model per data source, both evaluated on the same test set
            auc_trn, feat_imp_trn = self._train_gbm(data_trn, X_tst, y_tst, seed=seed)
            auc_gen, feat_imp_gen = self._train_gbm(data_gen, X_tst, y_tst, seed=seed)

            # rank agreement of the two models' feature importances
            stat, _ = kendalltau(feat_imp_trn, feat_imp_gen, nan_policy='raise')

            auc_trn_all.append(auc_trn)
            auc_gen_all.append(auc_gen)
            tau_all.append(float(stat))

        results_trn = np.array(auc_trn_all, dtype=float)
        results_gen = np.array(auc_gen_all, dtype=float)
        results_feat_imp = np.array(tau_all, dtype=float)

        return {
            'miss_cond_trn_avg': self._mean_or_nan(results_trn),
            'miss_cond_gen_avg': self._mean_or_nan(results_gen),
            'miss_cond_abs_diff': self._mean_or_nan(np.abs(results_trn - results_gen)),
            'miss_cond_raw_trn': results_trn,
            'miss_cond_raw_gen': results_gen,
            'miss_feat_imp_raw': results_feat_imp,
            'miss_feat_imp_dist': self._mean_or_nan(results_feat_imp),
        }

    def _encode_categoricals(self, df_trn, df_gen, df_test):
        """Ordinal-encode the categorical columns of all three frames.

        The encoder is fit on the union of the three frames so that every
        category receives a code. Returns pandas frames with ``category``
        dtype in the order (train, generated, test).
        """
        # select before concatenating: only the categorical columns need to
        # share a schema across the three frames
        cat_frames = [df.select(self.cat_cols) for df in (df_trn, df_gen, df_test)]
        cat_enc = OrdinalEncoder().fit(pl.concat(cat_frames))
        return tuple(
            pd.DataFrame(cat_enc.transform(frame).astype(int),
                         columns=self.cat_cols, dtype='category')
            for frame in cat_frames
        )

    def _subsample_training(self, X_num, X_cat, y, seed):
        """Randomly keep at most ``max_obs`` aligned rows of the training data.

        ``X_num`` and ``X_cat`` are pandas frames and ``y`` is a NumPy array,
        all with the same number of rows. Inputs are returned unchanged when
        they already fit within ``max_obs``.
        """
        n_obs = len(y)
        if n_obs <= self.max_obs:
            return X_num, X_cat, y
        rng = np.random.default_rng(seed)
        idx = rng.choice(n_obs, self.max_obs, replace=False)
        # y is a NumPy array, so it is indexed positionally (no `.iloc`)
        return X_num.iloc[idx], X_cat.iloc[idx], y[idx]

    def _prep_gbm_data(self, df_trn, df_test, df_gen, target_col, seed=42):
        """Build LightGBM inputs for predicting the missingness of ``target_col``.

        Features are the ordinal-encoded categorical columns plus all
        numerical columns except ``target_col`` itself. The label is 1 where
        ``target_col`` is null and 0 otherwise. Only the training data is
        subsampled to ``max_obs`` rows.

        Args:
            df_trn: Real training data frame.
            df_test: Real test data frame.
            df_gen: Generated data frame.
            target_col: Numerical column whose missingness is the target.
            seed: Seed for the subsampling of the training data.

        Returns:
            Tuple ``(data_trn, data_gen, X_tst, y_tst)`` where the first two
            are ``lgb.Dataset`` objects and the last two are the test
            features (pandas frame) and test labels (NumPy array).
        """
        # numerical features, without the column whose missingness is predicted
        X_num_trn = df_trn.select(self.num_cols).to_pandas().drop(target_col, axis=1)
        X_num_gen = df_gen.select(self.num_cols).to_pandas().drop(target_col, axis=1)
        X_num_test = df_test.select(self.num_cols).to_pandas().drop(target_col, axis=1)

        X_cat_trn, X_cat_gen, X_cat_test = self._encode_categoricals(df_trn, df_gen, df_test)

        # construct targets: missingness indicators of the target column
        y_trn = self._missing_indicator(df_trn, target_col)
        y_tst = self._missing_indicator(df_test, target_col)
        y_gen = self._missing_indicator(df_gen, target_col)

        # subsample if necessary to limit needed resources
        X_num_trn, X_cat_trn, y_trn = self._subsample_training(X_num_trn, X_cat_trn, y_trn, seed)

        X_trn = pd.concat((X_cat_trn, X_num_trn), axis=1)
        X_tst = pd.concat((X_cat_test, X_num_test), axis=1)
        X_gen = pd.concat((X_cat_gen, X_num_gen), axis=1)

        data_trn = lgb.Dataset(X_trn, label=y_trn,
                               categorical_feature=self.cat_cols)
        data_gen = lgb.Dataset(X_gen, label=y_gen,
                               categorical_feature=self.cat_cols)

        return data_trn, data_gen, X_tst, y_tst

    def _train_gbm(self, data_trn, X_tst, y_tst, seed=42):
        """Fit a binary LightGBM model and score it on the test data.

        Args:
            data_trn: ``lgb.Dataset`` to train on.
            X_tst: Test features.
            y_tst: Binary test labels.
            seed: LightGBM seed.

        Returns:
            Tuple ``(auc, feat_imp)`` with the test ROC AUC and the
            gain-based feature importances of the fitted model.
        """
        params = {'objective': 'binary', 'deterministic': True, 'verbosity': -1, 'seed': seed,
                  'max_depth': 5, 'num_leaves': 2**5-1}
        gbm = lgb.train(params, data_trn, num_boost_round=self.boost_rounds)

        # predicted probabilities of the positive (missing) class
        y_pred = gbm.predict(X_tst)
        auc = roc_auc_score(y_tst, y_pred)

        feat_imp = gbm.feature_importance(importance_type='gain')

        return auc, feat_imp

    def _cat_and_indicator_frame(self, df):
        """Join the categorical columns of ``df`` with its missingness indicators.

        The indicator columns carry the names of the numerical columns they
        were derived from.
        """
        indicators = np.column_stack(
            [self._missing_indicator(df, col) for col in self.num_cols_w_miss])
        return pl.concat((df.select(self.cat_cols),
                          pl.DataFrame(indicators, schema=self.num_cols_w_miss)),
                         how='horizontal')

    @staticmethod
    def _association_matrix(df_cat):
        """Return the pairwise Cramer's V matrix, treating every column as nominal."""
        corr = associations(df_cat.to_pandas(),
                            nominal_columns=df_cat.columns,
                            mark_columns=False, nom_nom_assoc='cramer',
                            num_num_assoc='pearson', plot=False,
                            multiprocessing=True, max_cpu_cores=4)['corr']
        plt.close()
        return corr

    def eval_correlation(self, df_trn, df_gen):
        """Compare association matrices of real and generated data.

        Associations are computed only between categorical features and
        missingness indicators, which avoids dealing with missing values in
        the numerical features.

        Args:
            df_trn: Real training data frame.
            df_gen: Generated data frame.

        Returns:
            dict with ``miss_corr_l2_diff`` (norm of the difference of the
            two association matrices) and ``miss_corr_abs_diff`` (the
            element-wise absolute difference).

        Raises:
            ValueError: If no numerical column has missing values.
        """
        if not self.num_cols_w_miss:
            raise ValueError("eval_correlation requires at least one numerical column "
                             "with missing values")

        corr_trn = self._association_matrix(self._cat_and_indicator_frame(df_trn))
        corr_gen = self._association_matrix(self._cat_and_indicator_frame(df_gen))

        corr_diff = corr_gen - corr_trn

        return {
            'miss_corr_l2_diff': np.linalg.norm(corr_diff).item(),
            'miss_corr_abs_diff': np.abs(corr_diff),
        }

    def eval_similarity(self, df_trn, df_gen):
        """Score the per-column similarity of missing-value proportions.

        Uses ``MissingValueSimilarity`` on every column in
        ``num_cols_w_miss``.

        Args:
            df_trn: Real training data frame.
            df_gen: Generated data frame.

        Returns:
            dict with the mean, minimum and maximum score across columns
            (``miss_sim_mean``, ``miss_sim_min``, ``miss_sim_max``).

        Raises:
            ValueError: If no numerical column has missing values.
        """
        if not self.num_cols_w_miss:
            raise ValueError("eval_similarity requires at least one numerical column "
                             "with missing values")

        miss_sim = MissingValueSimilarity()
        miss_results = np.array([
            float(miss_sim.compute(df_trn[col].to_pandas(), df_gen[col].to_pandas()))
            for col in self.num_cols_w_miss
        ])

        return {'miss_sim_mean': miss_results.mean().item(),
                'miss_sim_min': miss_results.min().item(),
                'miss_sim_max': miss_results.max().item()}