import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
from dython.nominal import associations
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
from scipy.stats import wasserstein_distance
from sdmetrics.reports.single_table import QualityReport
from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder


def no_miss(df):
    """Return ``df`` restricted to its complete rows.

    The result is whatever ``df.drop_nulls()`` returns; ``df`` itself is only
    asked for that one method.
    """
    complete_rows = df.drop_nulls()
    return complete_rows


class SimilarityScores:
    """
    Computes similarity scores between real and generated data.

    Access to various different metrics:

    - Jensen Shannon divergence for categorical features
    - Wasserstein distance for continuous features
    - L2 norm of differences in correlation matrices
    - Absolute differences in correlation matrices (for visualization)
    - SDMetrics: columnwise density metrics (TVComplement, KSComplement)

    The association matrix of the training data (``corr_train``) is expensive to
    build, so it is computed on the first call to ``compute_diff_in_corr`` and
    cached afterwards instead of being computed in the constructor.
    """

    def __init__(self, df_trn, df_test, cat_cols):
        """Store the configuration and score the test data against the training data.

        Args:
            df_trn: Training (real) data as a polars DataFrame.
            df_test: Held-out data with the same columns as ``df_trn``.
            cat_cols: Names of the categorical columns; every other column is
                treated as numerical.
        """
        self.cat_cols = cat_cols
        # Kept so that the training association matrix can be built on demand.
        self._df_trn = df_trn
        # Filled lazily by ``_get_corr_train``; may also be assigned by the caller.
        self.corr_train = None
        self.sim_test = self.compute_similarity(df_trn, df_test)

    # ------------------------------------------------------------------
    # Correlation based metrics

    def _compute_correlation(self, df):
        """Return the association matrix of ``df`` as computed by dython.

        Categorical pairs use Cramer's V and numerical pairs use Pearson
        correlation, as configured below.
        """
        corr = associations(
            df.to_pandas(),
            nominal_columns=self.cat_cols,
            mark_columns=False,
            nom_nom_assoc="cramer",
            num_num_assoc="pearson",
            plot=False,
            multiprocessing=True,
            max_cpu_cores=4,
        )["corr"]
        # close plot automatically generated by associations
        plt.close()
        return corr

    def _get_corr_train(self):
        """Return the cached training association matrix, computing it on first use."""
        if getattr(self, "corr_train", None) is None:
            self.corr_train = self._compute_correlation(self._df_trn)
        return self.corr_train

    def compute_diff_in_corr(self, df_gen):
        """Compare the association matrix of ``df_gen`` with the training one.

        Rows of ``df_gen`` containing missing values are dropped first.

        Args:
            df_gen: Generated data with the same columns as the training data.

        Returns:
            dict with the element-wise absolute difference (``corr_abs_diff``),
            the L2 norm of the difference for the full matrix and for its
            categorical-only and numerical-only sub-blocks, and the min / max /
            mean absolute difference. All entries are NaN when ``df_gen`` has
            no complete rows.
        """
        corr_train = self._get_corr_train()
        df = no_miss(df_gen)

        if df.is_empty():
            # if generated data is empty, return NaNs
            num_cols = corr_train.shape[0]
            return {
                "corr_abs_diff": np.full((num_cols, num_cols), np.nan),
                "corr_l2_norm_diff": np.nan,
                "corr_l2_norm_diff_cat": np.nan,
                "corr_l2_norm_diff_num": np.nan,
                "corr_min_abs_diff": np.nan,
                "corr_max_abs_diff": np.nan,
                "corr_avg_abs_diff": np.nan,
            }

        corr_gen = self._compute_correlation(df)

        # construct differences in correlations
        diff = corr_gen - corr_train
        abs_diff_corr = np.abs(diff)

        # categorical x categorical and numerical x numerical sub-blocks
        num_cols = diff.columns[~diff.columns.isin(self.cat_cols)]
        diff_cat_part = diff.loc[self.cat_cols, self.cat_cols]
        diff_num_part = diff.loc[num_cols, num_cols]

        # Reduce over the raw array so the results are scalars regardless of how
        # the installed pandas version dispatches numpy reductions on DataFrames.
        abs_values = abs_diff_corr.to_numpy()

        return {
            "corr_abs_diff": abs_diff_corr,
            "corr_l2_norm_diff": np.linalg.norm(diff).item(),
            "corr_l2_norm_diff_cat": np.linalg.norm(diff_cat_part).item(),
            "corr_l2_norm_diff_num": np.linalg.norm(diff_num_part).item(),
            "corr_min_abs_diff": np.min(abs_values).item(),
            "corr_max_abs_diff": np.max(abs_values).item(),
            "corr_avg_abs_diff": np.mean(abs_values).item(),
        }

    # ------------------------------------------------------------------
    # Marginal distribution metrics

    @staticmethod
    def _js_distance(p, q, base=2.0, axis=0):
        """Return the Jensen-Shannon distance between two weight vectors.

        ``p`` and ``q`` are normalised to sum to one along ``axis`` before the
        divergence is computed; the result is the square root of the divergence
        expressed in logarithm base ``base`` (natural log if ``base`` is None).
        If one of the inputs sums to zero the result is NaN.
        """
        p = np.asarray(p)
        q = np.asarray(q)
        p = p / np.sum(p, axis=axis, keepdims=True)
        q = q / np.sum(q, axis=axis, keepdims=True)
        m = (p + q) / 2.0
        divergence = np.sum(rel_entr(p, m), axis=axis) + np.sum(rel_entr(q, m), axis=axis)
        if base is not None:
            divergence = divergence / np.log(base)
        # quick fix to ensure non-negative JSD, to be fixed upstream, see https://github.com/scipy/scipy/pull/20786
        divergence = np.clip(divergence, 0, None)
        return np.sqrt(divergence / 2.0)

    @staticmethod
    def _column_wasserstein(real, gen):
        """Return the Wasserstein distance between two 1-D samples, ignoring NaNs.

        NaN is returned when either sample has no non-NaN value, because
        ``wasserstein_distance`` rejects empty inputs.
        """
        real = np.asarray(real, dtype=float)
        gen = np.asarray(gen, dtype=float)
        real = real[~np.isnan(real)]
        gen = gen[~np.isnan(gen)]
        if real.size == 0 or gen.size == 0:
            return np.nan
        return wasserstein_distance(real, gen)

    @staticmethod
    def _max_min_avg(values):
        """Return ``(max, min, mean)`` of ``values`` as floats, or NaNs if it is empty."""
        if len(values) == 0:
            return np.nan, np.nan, np.nan
        return np.max(values).item(), np.min(values).item(), np.mean(values).item()

    def compute_similarity(self, df_trn, df_gen):
        """Compare the per-column marginal distributions of two datasets.

        Categorical columns are scored with the Jensen-Shannon distance over the
        categories observed in ``df_trn``; numerical columns are min-max scaled
        with the range of ``df_trn`` and scored with the Wasserstein distance.

        Args:
            df_trn: Reference (training) data as a polars DataFrame.
            df_gen: Data to compare against ``df_trn``; same columns.

        Returns:
            dict with keys ``JD_max``, ``WD_max``, ``JD_min``, ``WD_min``,
            ``JD_avg`` and ``WD_avg``. A column that cannot be scored (for
            example no non-missing generated values) contributes NaN, and a
            metric with no applicable columns is reported as NaN.
        """
        # Jensen Shannon distance for categorical features
        # NOTE(review): the original comment states missings are encoded as a separate
        # category; that encoding is presumably done upstream -- confirm.
        jd_vals = []
        for d in df_trn.select(self.cat_cols).iter_columns():
            p_trn = d.value_counts(normalize=True, name="p")

            # compute proportions for generated data
            # note that some categories might not be present in the generated data
            # also ensures same order of categories for both datasets
            # (categories that occur only in the generated data are not counted)
            exprs = [(pl.col("val") == v).sum().alias(v) for v in p_trn[d.name]]
            p_gen = pl.DataFrame({"val": df_gen[d.name]}).select(exprs)
            p_gen = p_gen.transpose(include_header=True, header_name=d.name, column_names=["p"])
            p_gen = p_gen.with_columns((pl.col("p") / (pl.col("p").sum() + 1e-8)))
            jd_vals.append(self._js_distance(p_trn["p"], p_gen["p"]))

        # Wasserstein distance for continuous features, scaled to [0,1]
        # missings are kept as NaNs
        raw_trn = df_trn.select(pl.all().exclude(self.cat_cols)).to_numpy()
        raw_gen = df_gen.select(pl.all().exclude(self.cat_cols)).to_numpy()

        wd_vals = []
        if raw_trn.ndim == 2 and raw_trn.shape[1] > 0:
            scaler = MinMaxScaler()
            X_num_trn = scaler.fit_transform(raw_trn)
            # The scaler rejects zero-row input; an empty frame needs no scaling anyway.
            X_num_gen = scaler.transform(raw_gen) if raw_gen.shape[0] > 0 else raw_gen
            wd_vals = [
                self._column_wasserstein(X_num_trn[:, col_idx], X_num_gen[:, col_idx])
                for col_idx in range(X_num_trn.shape[1])
            ]

        jd_max, jd_min, jd_avg = self._max_min_avg(jd_vals)
        wd_max, wd_min, wd_avg = self._max_min_avg(wd_vals)
        return {
            "JD_max": jd_max,
            "WD_max": wd_max,
            "JD_min": jd_min,
            "WD_min": wd_min,
            "JD_avg": jd_avg,
            "WD_avg": wd_avg,
        }

    # ------------------------------------------------------------------
    # SDMetrics based metrics

    def _ordinal_encode(self, encoder, df):
        """Return ``df`` with its categorical columns replaced by integer codes.

        The encoded categorical columns come first, followed by the remaining
        columns in their original order.
        """
        codes = encoder.transform(df.select(self.cat_cols))
        return pl.concat(
            [
                pl.DataFrame(codes, schema=self.cat_cols).cast(pl.Int64),
                df.select(pl.all().exclude(self.cat_cols)),
            ],
            how="horizontal",
        )

    @staticmethod
    def _min_max_avg(scores):
        """Summarise a pandas Series of scores as a ``min`` / ``max`` / ``avg`` dict."""
        return {"min": float(scores.min()), "max": float(scores.max()), "avg": float(scores.mean())}

    @staticmethod
    def _mixed_pair_trend_avg(trend_scores, cat_cols):
        """Return the mean ContingencySimilarity score over mixed-type column pairs.

        A pair is mixed when exactly one of its two columns is listed in
        ``cat_cols``. NaN is returned when there is no such pair.
        """
        contingency = trend_scores[trend_scores["Metric"] == "ContingencySimilarity"]
        first_is_cat = contingency["Column 1"].isin(list(cat_cols))
        second_is_cat = contingency["Column 2"].isin(list(cat_cols))
        mixed = contingency[first_is_cat != second_is_cat]
        if mixed.empty:
            return float("nan")
        return float(mixed["Score"].mean())

    def compute_colwise_density_metrics(self, df_trn, df_gen):
        """Score ``df_gen`` against ``df_trn`` with the SDMetrics quality report.

        Categorical columns are ordinal-encoded with an encoder fitted on both
        datasets together before the report is generated.

        Args:
            df_trn: Reference (training) data as a polars DataFrame.
            df_gen: Generated data with the same columns.

        Returns:
            dict with two entries: ``shape`` holds min / max / avg of the
            "Column Shapes" scores for all, categorical (TVComplement) and
            numerical (KSComplement) columns; ``trend`` holds min / max / avg of
            the "Column Pair Trends" scores plus ``avg_mixed``, the average over
            categorical-numerical pairs (NaN if there are none).
        """
        ord_enc = OrdinalEncoder()
        ord_enc.fit(df_trn.vstack(df_gen).select(self.cat_cols))

        # construct updated dataframes
        df_trn_enc = self._ordinal_encode(ord_enc, df_trn)
        df_gen_enc = self._ordinal_encode(ord_enc, df_gen)

        metadata = {
            "columns": {
                lab: {"sdtype": "categorical" if lab in self.cat_cols else "numerical"}
                for lab in df_trn.columns
            }
        }

        # note that this automatically handles missings
        qual_report = QualityReport()
        qual_report.generate(df_trn_enc.to_pandas(), df_gen_enc.to_pandas(), metadata, verbose=False)
        quality = qual_report.get_properties()

        #################################################
        # Extract Shape info

        density_scores = qual_report.get_details(property_name="Column Shapes")

        scores = {
            "min": float(density_scores["Score"].min()),
            "max": float(density_scores["Score"].max()),
            # overall property score as reported by SDMetrics (first row of the summary)
            "avg": float(quality["Score"][0]),
        }
        cat_scores = self._min_max_avg(density_scores[density_scores["Metric"] == "TVComplement"]["Score"])
        num_scores = self._min_max_avg(density_scores[density_scores["Metric"] == "KSComplement"]["Score"])

        #################################################
        # Extract Trend info

        trend_scores = qual_report.get_details(property_name="Column Pair Trends")
        avg_trend_score = float(quality["Score"][1])
        min_trend_score = float(trend_scores["Score"].min())
        max_trend_score = float(trend_scores["Score"].max())

        #################################################
        # Extract Trend info only for mixed-type pairs

        avg_mixed_trend_score = self._mixed_pair_trend_avg(trend_scores, self.cat_cols)

        return {
            "shape": {"all": scores, "cat": cat_scores, "num": num_scores},
            "trend": {
                "min": min_trend_score,
                "max": max_trend_score,
                "avg": avg_trend_score,
                "avg_mixed": avg_mixed_trend_score,
            },
        }
