import lightgbm as lgb
import numpy as np
import pandas as pd
import torch
from kditransform import KDIDiscretizer
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.preprocessing import OrdinalEncoder

from disttree import DistTree


class Discretizer:
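    """
    Encoder that encodes x_num into z_num (discrete groups) using one of four discretization variants.

    Variants:
    - DT based (dt): groups from a DistTree
    - GMM based (gmm): groups from a Bayesian Gaussian mixture
    - GBM based (gbm): groups from a single LightGBM tree
    - KDI based (kdi): groups from a KDIDiscretizer
    """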

    def __init__(self, X_num_trn, variant="dt", k_max=20, perc_obs=0.03, seed=42, adjust_means=False, max_depth=3):
        self.seed = seed
        self.variant = variant
        self.adjust_means = adjust_means
        self.max_depth = max_depth
        self.k_max = k_max
        self.perc_obs = perc_obs
        self.fit_gmm_ord_enc = True

        # check for missings in train data (assuming the same holds for validation / synthetic data)
        self.has_missing = torch.isnan(X_num_trn).any(0).numpy()

        # get means of non-missing values to infer mean for missing group (after mean imputation)
        miss_mean = torch.nanmean(X_num_trn, dim=0)

        if self.variant == "gbm":
            self.gbms = self._train_gbms(X_num_trn, k_max=k_max, perc_obs=perc_obs)
            groups, self.ord_encs = self._get_gbm_groups(X_num_trn, init=True)
        elif self.variant == "gmm":
            self.gmms = self._train_bgmms(X_num_trn, k_max=k_max)
            groups = self._get_gmm_groups(X_num_trn)
        elif self.variant == "dt":
            self.disttree = DistTree(max_depth, seed=seed)
            self.disttree.fit(X_num_trn)
            groups = self.disttree.get_groups(X_num_trn)
        elif self.variant == "kdi":
            self.kdi = self._train_kdi(X_num_trn)
            groups = self._get_kdi_groups(X_num_trn)

        # get group-specific means and stds
        if self.adjust_means or self.variant == "gbm" or self.variant == "kdi":
            means = []
            stds = []
            for i in range(X_num_trn.shape[1]):
                df = pd.DataFrame({"x": X_num_trn[:, i].clone(), "group": groups[:, i]})
                df_stats = df.groupby("group").agg(["mean", "std"]).droplevel(0, axis=1)
                means.append(torch.tensor(df_stats["mean"].to_numpy(), dtype=torch.float32))
                stds.append(torch.tensor(df_stats["std"].to_numpy(), dtype=torch.float32))
        elif self.variant == "gmm":
            means = []
            stds = []
            for i in range(X_num_trn.shape[1]):
                means.append(torch.tensor(self.gmms[i].means_.squeeze(), dtype=torch.float32))
                stds.append(torch.tensor(np.sqrt(self.gmms[i].covariances_.squeeze()), dtype=torch.float32))
        elif self.variant == "dt":
            means = [torch.tensor(m, dtype=torch.float32) for m in self.disttree.means]
            stds = [torch.tensor(s, dtype=torch.float32) for s in self.disttree.stds]

        # check for inflated values (empirical var = 0)
        self.has_inflated = []
        self.infl_groups = []
        for i in range(X_num_trn.shape[1]):
            df = pd.DataFrame({"x": X_num_trn[:, i].clone(), "group": groups[:, i]})
            df_stats = df.groupby("group").agg(["mean", "std"]).droplevel(0, axis=1)
            infl_idx = df_stats.loc[df_stats["std"] == 0].index.to_list()
            self.has_inflated.append(len(infl_idx) > 0)
            self.infl_groups.append(infl_idx)

            # adjust std to zero
            stds[i][infl_idx] = 0

        # adjust means for missings (assign mean, std of group to which average X belongs)
        if self.has_missing.any():
            for i in range(X_num_trn.shape[1]):
                if self.has_missing[i]:
                    # update means with mean of missing group (= mean of group that we get for average x), similar for std dev.
                    # miss_mu = means[i][mean_groups[i].astype(int)].unsqueeze(0)
                    # miss_std = stds[i][mean_groups[i].astype(int)].unsqueeze(0)
                    # means[i] = torch.cat((miss_mu, means[i]))
                    # stds[i] = torch.cat((miss_std, stds[i]))
                    means[i] = torch.cat((miss_mean[i].unsqueeze(0), means[i]))
                    stds[i] = torch.cat((torch.zeros(1), stds[i]))
        self.means = means
        self.stds = stds

    def _get_gbm_groups(self, X: torch.Tensor, init=False):
        groups = []

        if init:
            ord_encs = []
        for i in range(X.shape[1]):
            d = X[:, i].clone()
            miss_mask = d.isnan()
            d[miss_mask] = d.nanmean()
            out = self.gbms[i].predict(pd.DataFrame(d))
            if init:
                enc = OrdinalEncoder()
                group = enc.fit_transform(out.reshape(-1, 1))
                ord_encs.append(enc)
            else:
                group = self.ord_encs[i].transform(out.reshape(-1, 1))
            group[miss_mask] = np.nan
            groups.append(group)

        groups = np.column_stack(groups)

        if init:
            return groups, ord_encs

        return groups

    def _get_gmm_groups(self, X: torch.Tensor):
        groups = []
        for i in range(X.shape[1]):
            d = X[:, i].clone()
            miss_mask = d.isnan()
            d[miss_mask] = d.nanmean()
            # assign class with highest probability (argmax)
            group = self.gmms[i].predict(d.reshape(-1, 1)).astype(float)
            group[miss_mask] = np.nan
            groups.append(group)
        groups = np.column_stack(groups)

        if self.fit_gmm_ord_enc:
            self.fit_gmm_ord_enc = False
            self.gmm_ord_enc = OrdinalEncoder()
            self.gmm_ord_enc.fit(groups)
        groups = self.gmm_ord_enc.transform(groups)

        return groups

    def _get_kdi_groups(self, X: torch.Tensor):
        groups = []
        for i in range(X.shape[1]):
            d = X[:, i].clone()
            miss_mask = d.isnan()
            d[miss_mask] = d.nanmean()
            group = self.kdi[i].transform(d.reshape(-1, 1)).astype(float)
            group[miss_mask] = np.nan
            groups.append(group)
        groups = np.column_stack(groups)

        return groups

    def encode(self, X: torch.Tensor):
        if self.variant == "gbm":
            groups = self._get_gbm_groups(X, init=False)
        elif self.variant == "gmm":
            groups = self._get_gmm_groups(X)
        elif self.variant == "dt":
            groups = self.disttree.get_groups(X)
        elif self.variant == "kdi":
            groups = self._get_kdi_groups(X)

        groups, mask = self.postprocess_groups(groups)

        return groups, mask

    def postprocess_groups(self, groups):
        # get inflated mask
        infl_mask = []
        for i in range(groups.shape[1]):
            mask = np.isin(groups[:, i], self.infl_groups[i])
            infl_mask.append(torch.tensor(mask, dtype=torch.bool))
        infl_mask = torch.column_stack(infl_mask)

        # shift other groups by 1, so that group 0 is reserved for missings
        # construct missingness mask
        miss_mask = []
        for i in range(groups.shape[1]):
            g_i = groups[:, i]
            miss_mask.append(np.isnan(g_i))

            # update group IDs, missing = 0
            if self.has_missing[i]:
                new_g_i = g_i.copy() + 1
                new_g_i = np.nan_to_num(new_g_i, nan=0, copy=True)
                groups[:, i] = new_g_i
        miss_mask = np.column_stack(miss_mask) if len(miss_mask) > 0 else None
        miss_mask = torch.tensor(miss_mask, dtype=torch.bool) if self.has_missing.any() else None

        # combine masks
        mask = miss_mask | infl_mask if miss_mask is not None else infl_mask

        # get into correct formats
        groups = torch.tensor(groups, dtype=torch.long)

        return groups, mask

    def get_masks(self, groups: torch.Tensor):
        # gets masks for generated Z_num (from low res model)

        # get inflated mask
        infl_mask = []
        for i in range(groups.shape[1]):
            # account for shift in groups if there are missings (then missing group = 0)
            infl_groups = (
                torch.tensor(self.infl_groups[i]) + 1 if self.has_missing[i] else torch.tensor(self.infl_groups[i])
            )
            mask = torch.isin(groups[:, i], infl_groups)
            infl_mask.append(mask)
        infl_mask = torch.column_stack(infl_mask)

        # get missingness mask
        miss_mask = []
        for i in range(groups.shape[1]):
            if self.has_missing[i]:
                miss_mask.append(groups[:, i] == 0)
            else:
                miss_mask.append(torch.zeros_like(groups[:, i]).bool())
        miss_mask = torch.column_stack(miss_mask) if self.has_missing.any() else None

        return infl_mask, miss_mask

    def _train_gbms(self, X: torch.Tensor, k_max=20, perc_obs=0.05):
        """
        Train one single-round LightGBM regressor per feature, using the feature's observed
        values as both input and label.

        Args:
            X: 2D tensor (observations x features); NaN rows are dropped per feature.
            k_max: passed as `num_leaves`.
            perc_obs: fraction of the total row count (missing rows included) used as
                `min_data_in_leaf`.

        Returns:
            List with one trained booster per feature.
        """
        n = X.shape[0]
        params = {
            "objective": "regression",
            "deterministic": True,
            "verbosity": -1,
            # honour the seed given to the constructor (was hard-coded to 42, the default seed)
            "seed": self.seed,
            "max_depth": 5,
            "num_leaves": k_max,
            "min_data_in_leaf": int(n * perc_obs),
        }

        gbms = []
        for i in range(X.shape[1]):
            df = pd.DataFrame(X[:, i].clone()).dropna()
            data_trn = lgb.Dataset(df, label=df)
            # hand every booster its own copy of the parameter dict
            gbms.append(lgb.train(dict(params), data_trn, num_boost_round=1))
        return gbms

    def _train_gmms(self, X: torch.Tensor, k_max=20):
        """
        Fit one Gaussian mixture per feature on its observed values, choosing the number of
        components in 2..k_max by lowest BIC.
        """
        candidate_ks = range(2, k_max + 1)
        fitted = []

        for col in range(X.shape[1]):
            values = X[:, col].clone()
            observed = values[~values.isnan()].reshape(-1, 1)

            # BIC of every candidate number of components
            bic_scores = [
                GaussianMixture(n_components=k, random_state=self.seed).fit(observed).bic(observed).item()
                for k in candidate_ks
            ]
            best_k = candidate_ks[np.argmin(bic_scores).item()]

            # fit best model
            best_model = GaussianMixture(n_components=best_k, random_state=self.seed)
            fitted.append(best_model.fit(observed))

        return fitted

    def _train_bgmms(self, X: torch.Tensor, k_max=20):
        """
        Fit one Bayesian Gaussian mixture (Dirichlet process prior, at most `k_max` components)
        per feature on its observed values and return the fitted models.
        """

        def fit_column(col):
            values = X[:, col].clone()
            observed = values[~values.isnan()]
            mixture = BayesianGaussianMixture(
                n_components=k_max,
                random_state=self.seed,
                weight_concentration_prior_type="dirichlet_process",
                weight_concentration_prior=0.001,
                n_init=1,
            )
            return mixture.fit(observed.reshape(-1, 1))

        return [fit_column(col) for col in range(X.shape[1])]

    def _train_kdi(self, X: torch.Tensor):
        """
        Fit one KDIDiscretizer per feature on its observed (non-NaN) values.

        Returns:
            List with one fitted discretizer per feature.
        """
        kdis = []
        for i in range(X.shape[1]):
            d = X[:, i].clone()
            d = d[~d.isnan()]
            # honour the seed given to the constructor (was hard-coded to 42, the default seed)
            kdi = KDIDiscretizer(random_state=self.seed)
            kdis.append(kdi.fit(d.reshape(-1, 1)))
        return kdis
