

import torch
import lightgbm as lgb
import pandas as pd
import numpy as np
from sklearn.preprocessing import OrdinalEncoder
from sklearn.mixture import GaussianMixture
from sklearn.mixture import BayesianGaussianMixture
from .disttree import DistTree


class Discretizer():
    """
    Discretize every numerical column into a small set of groups and keep
    per-group means / standard deviations.

    Variants:
    - GBM based (gbm): one single-tree LightGBM regressor per column; the
      (ordinal-encoded) tree predictions are the groups.
    - GMM based (gmm): one Dirichlet-process Bayesian Gaussian mixture per
      column; the (ordinal-encoded) most likely component is the group.
    - Distribution-tree based (dt): groups come from ``DistTree``.

    Columns that contain NaNs in the training data reserve group 0 for missing
    values; their other group IDs are shifted by one (see
    ``postprocess_groups`` / ``get_masks``). Groups whose training values have
    an empirical standard deviation of exactly 0 are flagged as "inflated".
    """

    _VARIANTS = ('gbm', 'gmm', 'dt')

    def __init__(self, X_num_trn, variant='gbm', k_max=20, perc_obs=0.03, seed=42, adjust_means=False, max_depth=3):
        """
        Fit the discretizer on training data.

        Args:
            X_num_trn: 2D tensor (rows x numerical columns); NaN marks a missing
                value. Converted with ``.numpy()``, so it must live on the CPU.
            variant: 'gbm', 'gmm' or 'dt'.
            k_max: max. number of leaves (gbm) / mixture components (gmm).
            perc_obs: fraction of rows used as ``min_data_in_leaf`` (gbm).
            seed: random seed for LightGBM, the mixture models and DistTree.
            adjust_means: if True, the gmm and dt variants use empirical
                per-group means / stds instead of the fitted model parameters
                (the gbm variant always uses empirical statistics).
            max_depth: depth passed to ``DistTree`` (dt variant only).

        Raises:
            ValueError: if ``variant`` is not one of 'gbm', 'gmm', 'dt'.
        """
        if variant not in self._VARIANTS:
            raise ValueError(f"Unknown variant {variant!r}; expected one of {self._VARIANTS}.")

        self.seed = seed
        self.variant = variant
        self.adjust_means = adjust_means
        self.max_depth = max_depth
        self.k_max = k_max
        self.perc_obs = perc_obs
        self.fit_gmm_ord_enc = True

        # check for missings in train data (assuming the same holds for validation / synthetic data)
        self.has_missing = torch.isnan(X_num_trn).any(0).numpy()

        # Column means of the observed values, duplicated to two identical rows
        # (presumably to avoid single-row edge cases in the predictors). Their
        # groups later supply the mean / std assigned to the missing group.
        d_mean = X_num_trn.nanmean(0, keepdim=True).repeat((2, 1))

        if self.variant == 'gbm':
            self.gbms = self._train_gbms(X_num_trn, k_max=k_max, perc_obs=perc_obs)
            groups, self.ord_encs = self._get_gbm_groups(X_num_trn, init=True)
            mean_groups = self._get_gbm_groups(d_mean, init=False)[0]
        elif self.variant == 'gmm':
            # (``_train_gmms`` is a BIC-selected, non-Bayesian alternative)
            self.gmms = self._train_bgmms(X_num_trn, k_max=k_max)
            # Append the mean row so the group encoder (fitted on this first
            # call) also sees the group of the average value; drop it again.
            groups = self._get_gmm_groups(torch.vstack((X_num_trn, d_mean[0])))[:-1]
            mean_groups = self._get_gmm_groups(d_mean)[0]
        else:  # 'dt'
            self.disttree = DistTree(max_depth, seed=seed)
            self.disttree.fit(X_num_trn)
            groups = self.disttree.get_groups(X_num_trn)
            mean_groups = self.disttree.get_groups(d_mean)[0]

        n_cols = X_num_trn.shape[1]
        # empirical per-group mean / std for every column (NaN groups excluded)
        col_stats = [self._group_stats(X_num_trn[:, i], groups[:, i]) for i in range(n_cols)]

        # get group-specific means and stds
        if self.adjust_means or self.variant == 'gbm':
            means = [torch.tensor(s['mean'].to_numpy(), dtype=torch.float32) for s in col_stats]
            stds = [torch.tensor(s['std'].to_numpy(), dtype=torch.float32) for s in col_stats]
        elif self.variant == 'gmm':
            means, stds = [], []
            for i in range(n_cols):
                mu, sd = self._select_gmm_components(
                    self.gmms[i].means_, self.gmms[i].covariances_,
                    self.gmm_ord_enc.categories_[i])
                means.append(mu)
                stds.append(sd)
        else:  # 'dt'
            means = [torch.tensor(m, dtype=torch.float32) for m in self.disttree.means]
            stds = [torch.tensor(s, dtype=torch.float32) for s in self.disttree.stds]

        # check for inflated values (empirical std = 0) and force their std to 0
        self.has_inflated = []
        self.infl_groups = []
        for i, stats in enumerate(col_stats):
            infl_idx = stats.loc[stats['std'] == 0].index.to_list()
            self.has_inflated.append(len(infl_idx) > 0)
            self.infl_groups.append(infl_idx)
            stds[i][infl_idx] = 0

        # Prepend the stats of the missing group (group 0): it inherits the
        # mean / std of the group that the column mean falls into.
        for i in range(n_cols):
            if self.has_missing[i]:
                g = int(mean_groups[i])
                means[i] = torch.cat((means[i][g].reshape(1), means[i]))
                stds[i] = torch.cat((stds[i][g].reshape(1), stds[i]))
        self.means = means
        self.stds = stds

    @staticmethod
    def _group_stats(x, g):
        """
        Return a DataFrame indexed by group ID with the columns 'mean' and 'std'
        of ``x`` within each group. Rows whose group is NaN are dropped by
        ``groupby``; ``std`` uses pandas' default ddof=1.
        """
        df = pd.DataFrame({'x': x.clone(), 'group': g})
        return df.groupby('group').agg(['mean', 'std']).droplevel(0, axis=1)

    @staticmethod
    def _select_gmm_components(gmm_means, gmm_covs, component_labels):
        """
        Return the (means, stds) of the mixture components that act as groups.

        Group IDs of the gmm variant are ordinal encodings of the predicted
        component labels, so group j corresponds to ``component_labels[j]``
        (the j-th sorted label seen when the encoder was fitted), not to
        component j. NaN labels (from missing values) are skipped.

        Args:
            gmm_means: fitted ``means_`` of a 1D mixture (one value per component).
            gmm_covs: fitted ``covariances_`` of a 1D mixture (one variance per
                component, e.g. shape (k, 1, 1) for 'full' covariances).
            component_labels: encoder categories for this column.

        Returns:
            Tuple of float32 tensors (means, stds), one entry per group.
        """
        labels = np.asarray(component_labels, dtype=float)
        labels = labels[~np.isnan(labels)].astype(int)
        mu = np.asarray(gmm_means, dtype=float).reshape(-1)[labels]
        var = np.asarray(gmm_covs, dtype=float).reshape(-1)[labels]
        return (torch.tensor(mu, dtype=torch.float32),
                torch.tensor(np.sqrt(var), dtype=torch.float32))

    def _get_gbm_groups(self, X: torch.Tensor, init=False):
        """
        Map every value to the ordinal-encoded prediction of its column's GBM.

        NaNs are replaced by the column's nan-mean for prediction and reset to
        NaN in the output. With ``init=True`` a new ``OrdinalEncoder`` is fitted
        per column and ``(groups, encoders)`` is returned; otherwise the stored
        ``self.ord_encs`` are used and only ``groups`` is returned.
        """
        groups = []
        ord_encs = []
        for i in range(X.shape[1]):
            d = X[:, i].clone()
            miss_mask = d.isnan()
            d[miss_mask] = d.nanmean()
            out = self.gbms[i].predict(pd.DataFrame(d)).reshape(-1, 1)
            if init:
                enc = OrdinalEncoder()
                group = enc.fit_transform(out)
                ord_encs.append(enc)
            else:
                group = self.ord_encs[i].transform(out)
            # index the numpy array with a numpy mask, not a torch tensor
            group[miss_mask.numpy()] = np.nan
            groups.append(group)

        groups = np.column_stack(groups)

        if init:
            return groups, ord_encs
        return groups

    def _get_gmm_groups(self, X: torch.Tensor):
        """
        Assign every value to its most likely mixture component and
        ordinal-encode the component labels (encoder fitted on the first call).

        NaNs are replaced by the column's nan-mean for prediction and are NaN in
        the output.
        """
        groups = []
        for i in range(X.shape[1]):
            d = X[:, i].clone()
            miss_mask = d.isnan()
            d[miss_mask] = d.nanmean()
            # assign class with highest probability (argmax)
            group = self.gmms[i].predict(d.reshape(-1, 1)).astype(float)
            group[miss_mask.numpy()] = np.nan
            groups.append(group)
        groups = np.column_stack(groups)

        if self.fit_gmm_ord_enc:
            self.fit_gmm_ord_enc = False
            self.gmm_ord_enc = OrdinalEncoder()
            self.gmm_ord_enc.fit(groups)
        return self.gmm_ord_enc.transform(groups)

    def encode(self, X: torch.Tensor):
        """
        Discretize ``X`` with the fitted model.

        Returns:
            (groups, mask): a long tensor of group IDs (group 0 = missing in
            columns with training missings) and a bool tensor marking missing
            or inflated entries.

        Raises:
            ValueError: if ``self.variant`` is unknown.
        """
        if self.variant == 'gbm':
            groups = self._get_gbm_groups(X, init=False)
        elif self.variant == 'gmm':
            groups = self._get_gmm_groups(X)
        elif self.variant == 'dt':
            groups = self.disttree.get_groups(X)
        else:
            raise ValueError(f"Unknown variant {self.variant!r}; expected one of {self._VARIANTS}.")

        return self.postprocess_groups(groups)

    def postprocess_groups(self, groups):
        """
        Build masks for raw group IDs and shift IDs so that 0 means "missing".

        Args:
            groups: numpy array (rows x columns) of group IDs, NaN for missing.
                Columns with training missings are modified in place.

        Returns:
            (groups, mask): long tensor of (shifted) group IDs and a bool tensor
            that is True for inflated entries and, if any training column had
            missings, also for missing entries.
        """
        n_cols = groups.shape[1]

        # inflated mask, computed on the unshifted IDs
        infl_mask = torch.column_stack([
            torch.tensor(np.isin(groups[:, i], self.infl_groups[i]), dtype=torch.bool)
            for i in range(n_cols)
        ])

        # missingness mask, computed before NaNs are replaced below
        miss_mask = np.column_stack([np.isnan(groups[:, i]) for i in range(n_cols)])

        # shift other groups by 1, so that group 0 is reserved for missings
        # NOTE(review): a NaN in a column without training missings stays NaN,
        # and its cast to long below is undefined — confirm inputs exclude this.
        for i in range(n_cols):
            if self.has_missing[i]:
                groups[:, i] = np.nan_to_num(groups[:, i] + 1, nan=0, copy=True)

        if self.has_missing.any():
            mask = torch.tensor(miss_mask, dtype=torch.bool) | infl_mask
        else:
            mask = infl_mask

        return torch.tensor(groups, dtype=torch.long), mask

    def get_masks(self, groups: torch.Tensor):
        """
        Compute masks for already-shifted group IDs (e.g. generated Z_num from
        the low-resolution model).

        Returns:
            (infl_mask, miss_mask): bool tensors with the shape of ``groups``;
            ``miss_mask`` is None when no training column had missings.
        """
        infl_mask = []
        miss_mask = []
        for i in range(groups.shape[1]):
            infl_ids = torch.tensor(self.infl_groups[i], dtype=torch.long)
            if self.has_missing[i]:
                # account for the shift: missing group = 0
                infl_ids = infl_ids + 1
                miss_mask.append(groups[:, i] == 0)
            else:
                miss_mask.append(torch.zeros_like(groups[:, i], dtype=torch.bool))
            infl_mask.append(torch.isin(groups[:, i], infl_ids))

        infl_mask = torch.column_stack(infl_mask)
        miss_mask = torch.column_stack(miss_mask) if self.has_missing.any() else None

        return infl_mask, miss_mask

    def _train_gbms(self, X: torch.Tensor, k_max=20, perc_obs=0.05):
        """
        Fit one single-round LightGBM regressor per column that predicts the
        column from itself (NaN rows dropped); its leaves define the groups.
        """
        n = X.shape[0]
        # NOTE(review): n counts rows with NaNs too, although they are dropped
        # before training — confirm this is the intended base for perc_obs.
        params = {'objective': 'regression', 'deterministic': True, 'verbosity': -1,
                  'seed': self.seed, 'max_depth': 5, 'num_leaves': k_max,
                  'min_data_in_leaf': int(n * perc_obs)}
        gbms = []
        for i in range(X.shape[1]):
            df = pd.DataFrame(X[:, i].clone()).dropna()
            data_trn = lgb.Dataset(df, label=df)
            gbms.append(lgb.train(params, data_trn, num_boost_round=1))
        return gbms

    def _train_gmms(self, X: torch.Tensor, k_max=20):
        """
        Fit one GaussianMixture per column (NaNs dropped), choosing the number
        of components in [2, k_max] by minimal BIC.
        """
        gmms = []
        candidates = list(range(2, k_max + 1))
        for i in range(X.shape[1]):
            d = X[:, i].clone()
            d = d[~d.isnan()].reshape(-1, 1)

            bic_results = []
            for k in candidates:
                gmm = GaussianMixture(n_components=k, random_state=self.seed)
                gmm.fit(d)
                bic_results.append(gmm.bic(d).item())
            best_k = candidates[np.argmin(bic_results).item()]

            # fit best model
            gmm = GaussianMixture(n_components=best_k, random_state=self.seed)
            gmms.append(gmm.fit(d))

        return gmms

    def _train_bgmms(self, X: torch.Tensor, k_max=20):
        """
        Fit one Dirichlet-process BayesianGaussianMixture with k_max components
        per column (NaNs dropped); the small concentration prior lets unused
        components receive little weight.
        """
        bgmms = []
        for i in range(X.shape[1]):
            d = X[:, i].clone()
            d = d[~d.isnan()]

            bgmm = BayesianGaussianMixture(n_components=k_max, random_state=self.seed,
                                           weight_concentration_prior_type='dirichlet_process',
                                           weight_concentration_prior=0.001, n_init=1)
            bgmms.append(bgmm.fit(d.reshape(-1, 1)))

        return bgmms
    
