import numpy as np
from dython.nominal import associations
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import jensenshannon
from sdmetrics.reports.single_table import QualityReport
from sklearn.preprocessing import OrdinalEncoder, MinMaxScaler
import polars as pl
import pandas as pd
import matplotlib.pyplot as plt


class SimilarityScores:
    """Compute similarity scores between real and generated tabular data.

    The data frames passed to this class are expected to be polars
    DataFrames (they are queried with polars expressions and converted via
    ``to_pandas()`` / ``to_numpy()``).

    Available metrics:

    - Jensen-Shannon score for categorical features. Note that
      ``scipy.spatial.distance.jensenshannon`` returns the Jensen-Shannon
      *distance*, i.e. the square root of the divergence (base 2, so the
      value lies in [0, 1]).
    - Wasserstein distance for continuous features (min-max scaled using the
      training data)
    - L2 norm of differences in correlation matrices
    - Absolute differences in correlation matrices (for visualization)
    - SDMetrics: columnwise density metrics (TVComplement, KSComplement)

    Attributes set by ``__init__``:

    - ``cat_cols``: names of the categorical columns
    - ``sim_test``: result of ``compute_similarity(df_trn, df_test)``
    - ``corr_train``: association matrix of the training data
    - ``corr_test_diffs``: result of ``compute_diff_in_corr(df_test)``
    """

    def __init__(self, df_trn, df_test, cat_cols):
        """Pre-compute train-vs-test reference scores.

        Args:
            df_trn: training data.
            df_test: held-out test data, compared against ``df_trn``.
            cat_cols: names of the categorical columns; every other column
                is treated as numerical.
        """
        self.cat_cols = cat_cols
        self.sim_test = self.compute_similarity(df_trn, df_test)

        # init data for correlation computation
        self.corr_train = self._compute_correlation(df_trn)
        self.corr_test_diffs = self.compute_diff_in_corr(df_test)

    # ------------------------------------------------------------------
    # small, dependency-light helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _min_max_avg(values):
        """Return ``(min, max, mean)`` of ``values``; all NaN if empty."""
        arr = np.asarray(values, dtype=float)
        if arr.size == 0:
            nan = float('nan')
            return nan, nan, nan
        return float(arr.min()), float(arr.max()), float(arr.mean())

    @staticmethod
    def _score_summary(scores):
        """Return min/max/avg of a pandas Series of scores as plain floats."""
        # float() works for numpy scalars and for the plain float NaN that a
        # reduction over an empty Series may return.
        return {'min': float(scores.min()),
                'max': float(scores.max()),
                'avg': float(scores.mean())}

    @staticmethod
    def _category_counts(series):
        """Map each distinct value of a polars Series to its count."""
        # Rename first so the count column can never clash with the
        # series' own name; rows() yields (value, count) tuples.
        return dict(series.alias('value').value_counts(name='n').rows())

    @staticmethod
    def _js_distance_from_counts(trn_counts, gen_counts):
        """Jensen-Shannon distance (base 2) between two count dictionaries.

        Both distributions are evaluated on the union of their categories,
        so categories missing on either side contribute zero mass instead
        of being dropped. Returns NaN if either side has no observations.
        """
        # training categories first, then categories only seen in generated
        # data; the same order is used for both probability vectors
        support = list(trn_counts)
        support.extend(c for c in gen_counts if c not in trn_counts)

        p = np.array([trn_counts.get(c, 0) for c in support], dtype=float)
        q = np.array([gen_counts.get(c, 0) for c in support], dtype=float)
        if p.sum() == 0 or q.sum() == 0:
            return float('nan')

        # jensenshannon normalizes p and q to sum to one internally
        return float(jensenshannon(p, q, base=2.0))

    @staticmethod
    def _summarize_corr_diff(corr_gen, corr_train, cat_cols):
        """Summarize the difference between two association matrices.

        Args:
            corr_gen: association matrix (pandas DataFrame) of generated data.
            corr_train: association matrix (pandas DataFrame) of training data.
            cat_cols: names of the categorical columns.

        Returns:
            dict with the absolute difference matrix, the L2 (Frobenius)
            norm of the full difference, of its categorical-only and
            numerical-only sub-blocks, and min/max/mean absolute difference.
        """
        cat_cols = list(cat_cols)

        # construct differences in correlations
        diff = corr_gen - corr_train
        abs_diff_corr = diff.abs()

        num_cols = diff.columns[~diff.columns.isin(cat_cols)]
        diff_cat_part = diff.loc[cat_cols, cat_cols]
        diff_num_part = diff.loc[num_cols, num_cols]

        # Reduce on the raw array: np.min/np.max/np.mean applied to a
        # DataFrame dispatch to the pandas methods, whose axis=None result
        # (scalar vs. per-column Series) depends on the pandas version.
        # NaN entries are skipped, as pandas does by default.
        abs_vals = abs_diff_corr.to_numpy(dtype=float)
        if abs_vals.size and not np.isnan(abs_vals).all():
            min_abs = float(np.nanmin(abs_vals))
            max_abs = float(np.nanmax(abs_vals))
            avg_abs = float(np.nanmean(abs_vals))
        else:
            min_abs = max_abs = avg_abs = float('nan')

        return {'corr_abs_diff': abs_diff_corr,
                'corr_l2_norm_diff': float(np.linalg.norm(diff.to_numpy(dtype=float))),
                'corr_l2_norm_diff_cat': float(np.linalg.norm(diff_cat_part.to_numpy(dtype=float))),
                'corr_l2_norm_diff_num': float(np.linalg.norm(diff_num_part.to_numpy(dtype=float))),
                'corr_min_abs_diff': min_abs,
                'corr_max_abs_diff': max_abs,
                'corr_avg_abs_diff': avg_abs}

    @staticmethod
    def _mixed_pair_avg_score(trend_scores, cat_cols):
        """Mean ContingencySimilarity score over categorical-numerical pairs.

        Args:
            trend_scores: pandas DataFrame with the columns 'Column 1',
                'Column 2', 'Metric' and 'Score'.
            cat_cols: names of the categorical columns.

        Returns:
            The mean score as float, NaN if there is no mixed-type pair.
        """
        cat_cols = list(cat_cols)
        contingency = trend_scores[trend_scores['Metric'] == 'ContingencySimilarity']
        first_is_cat = contingency['Column 1'].isin(cat_cols)
        second_is_cat = contingency['Column 2'].isin(cat_cols)
        # a pair is mixed iff exactly one of its two columns is categorical
        mixed = contingency[first_is_cat ^ second_is_cat]
        return float(mixed['Score'].mean())

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    def _compute_correlation(self, df):
        """Return the association matrix of ``df`` computed by dython.

        Uses Cramer's V for nominal-nominal pairs and Pearson correlation
        for numerical-numerical pairs.
        """
        corr = associations(df.to_pandas(), nominal_columns=self.cat_cols,
                            mark_columns=False, nom_nom_assoc='cramer',
                            num_num_assoc='pearson', plot=False,
                            multiprocessing=True, max_cpu_cores=4)['corr']
        # close plot automatically generated by associations
        plt.close()
        return corr

    def compute_diff_in_corr(self, df_gen):
        """Compare the association matrix of ``df_gen`` with the training one.

        Args:
            df_gen: generated (or test) data with the same columns as the
                training data.

        Returns:
            dict with the keys 'corr_abs_diff' (DataFrame),
            'corr_l2_norm_diff', 'corr_l2_norm_diff_cat',
            'corr_l2_norm_diff_num', 'corr_min_abs_diff',
            'corr_max_abs_diff' and 'corr_avg_abs_diff' (floats).
        """
        corr_gen = self._compute_correlation(df_gen)
        return self._summarize_corr_diff(corr_gen, self.corr_train, self.cat_cols)

    def compute_similarity(self, df_trn, df_gen):
        """Compute per-feature distribution distances between two data sets.

        Categorical features are compared with the Jensen-Shannon distance
        (base 2), numerical features with the Wasserstein distance after
        min-max scaling fitted on ``df_trn``.

        Args:
            df_trn: training data.
            df_gen: generated (or test) data.

        Returns:
            dict with max/min/avg over the categorical ('JD_*') and the
            numerical ('WD_*') features. A group without any feature is
            reported as NaN.
        """
        cat_set = set(self.cat_cols)

        # Jensen-Shannon distance for categorical features.
        # Categories might be missing in either data set; both
        # distributions are aligned on the union of observed categories.
        jd_vals = [
            self._js_distance_from_counts(self._category_counts(df_trn[name]),
                                          self._category_counts(df_gen[name]))
            for name in self.cat_cols
        ]

        # Wasserstein distance for continuous features, scaled to [0,1].
        # Select by the training column names so that both matrices have
        # their features in the same order.
        num_cols = [name for name in df_trn.columns if name not in cat_set]
        wd_vals = []
        if num_cols:
            scaler = MinMaxScaler()
            X_num_trn = scaler.fit_transform(df_trn.select(num_cols).to_numpy())
            X_num_gen = scaler.transform(df_gen.select(num_cols).to_numpy())
            wd_vals = [
                wasserstein_distance(X_num_trn[:, col_idx], X_num_gen[:, col_idx])
                for col_idx in range(X_num_trn.shape[1])
            ]

        jd_min, jd_max, jd_avg = self._min_max_avg(jd_vals)
        wd_min, wd_max, wd_avg = self._min_max_avg(wd_vals)
        return {'JD_max': jd_max, 'WD_max': wd_max,
                'JD_min': jd_min, 'WD_min': wd_min,
                'JD_avg': jd_avg, 'WD_avg': wd_avg}

    def compute_colwise_density_metrics(self, df_trn, df_gen):
        """Compute SDMetrics quality-report scores for shapes and trends.

        Categorical columns are ordinal-encoded (encoder fitted on the
        categories of both data sets) before being passed to
        ``QualityReport``.

        Args:
            df_trn: training data.
            df_gen: generated (or test) data.

        Returns:
            dict with
            'shape': min/max/avg column-shape scores for all columns
            ('all'), TVComplement columns ('cat') and KSComplement columns
            ('num'), and
            'trend': min/max/avg column-pair-trend scores plus 'avg_mixed',
            the mean ContingencySimilarity score over pairs of one
            categorical and one numerical column. Empty groups yield NaN.
        """
        cat_cols = list(self.cat_cols)
        cat_set = set(cat_cols)

        if cat_cols:
            ord_enc = OrdinalEncoder()
            # fit on the categorical columns of both data sets so that
            # categories seen in only one of them can still be encoded
            ord_enc.fit(pl.concat([df_trn.select(cat_cols),
                                   df_gen.select(cat_cols)], how='vertical'))
            X_cat_trn = ord_enc.transform(df_trn.select(cat_cols))
            X_cat_gen = ord_enc.transform(df_gen.select(cat_cols))

            # construct updated dataframe
            df_trn_enc = pl.concat(
                [pl.DataFrame(X_cat_trn, schema=cat_cols, orient='row').cast(pl.Int64),
                 df_trn.select(pl.all().exclude(cat_cols))], how='horizontal')
            df_gen_enc = pl.concat(
                [pl.DataFrame(X_cat_gen, schema=cat_cols, orient='row').cast(pl.Int64),
                 df_gen.select(pl.all().exclude(cat_cols))], how='horizontal')
        else:
            # nothing to encode
            df_trn_enc, df_gen_enc = df_trn, df_gen

        metadata = {'columns': {
            lab: {'sdtype': 'categorical' if lab in cat_set else 'numerical'}
            for lab in df_trn.columns
        }}

        # note that this automatically handles missings
        qual_report = QualityReport()
        qual_report.generate(df_trn_enc.to_pandas(), df_gen_enc.to_pandas(), metadata,
                             verbose=False)
        quality = qual_report.get_properties()

        #################################################
        # Extract Shape info

        density_scores = qual_report.get_details(property_name='Column Shapes')
        scores = self._score_summary(density_scores['Score'])
        # The overall average is taken from the report's property score.
        # NOTE(review): relies on get_properties() listing 'Column Shapes'
        # first and 'Column Pair Trends' second.
        scores['avg'] = float(quality['Score'].iloc[0])
        cat_scores = self._score_summary(
            density_scores[density_scores['Metric'] == 'TVComplement']['Score'])
        num_scores = self._score_summary(
            density_scores[density_scores['Metric'] == 'KSComplement']['Score'])

        #################################################
        # Extract Trend info

        trend_scores = qual_report.get_details(property_name='Column Pair Trends')
        avg_trend_score = float(quality['Score'].iloc[1])
        min_trend_score = float(trend_scores['Score'].min())
        max_trend_score = float(trend_scores['Score'].max())

        #################################################
        # Extract Trend info only for mixed-type pairs

        avg_mixed_trend_score = self._mixed_pair_avg_score(trend_scores, cat_cols)

        return {'shape': {'all': scores, 'cat': cat_scores, 'num': num_scores},
                'trend': {'min': min_trend_score, 'max': max_trend_score, 'avg': avg_trend_score,
                          'avg_mixed': avg_mixed_trend_score}}