import numpy as np
import pandas as pd
import polars as pl
import torch
from sklearn.preprocessing import OneHotEncoder, QuantileTransformer, StandardScaler, OrdinalEncoder
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline


from utils import set_seeds
from omegaconf import OmegaConf

from scipy import optimize
from .data_utils import CatEncoder
from .fast_dataloader import FastTensorDataLoader


# Root directory of the data assets. DataProcessor reads the dataset config from
# DATA_DIR/configs/<dataset>.yaml and the raw csv from DATA_DIR/raw/<dataset>/.
# The path is relative, so it depends on the current working directory.
DATA_DIR = Path('data')
    
def pick_coeffs(X, idxs_nas):
    """
    Draw random weights for the logistic missingness model.

    A Gaussian matrix with one row per column of X and one column per entry of
    idxs_nas is sampled from torch's global RNG (in the dtype of X). Each
    column is then divided by the standard deviation of the corresponding
    projection of X, so that every column of X @ coeffs has unit standard
    deviation.

    Args:
        X: 2-D tensor of inputs to the missingness model.
        idxs_nas: indices of the features that receive missings; only its
            length is used.

    Returns:
        Tensor of shape (X.shape[1], len(idxs_nas)).
    """
    n_inputs, n_outputs = X.shape[1], len(idxs_nas)
    weights = torch.randn(n_inputs, n_outputs, dtype=X.dtype)

    # spread of each projected column, kept 2-D so it broadcasts over the rows
    projection_std = torch.mm(X, weights).std(dim=0, keepdim=True)
    weights.div_(projection_std)

    return weights


def fit_intercepts(X, coeffs, p):
    """
    Find, per column of coeffs, the intercept that yields a mean probability of p.

    For every column j, the equation mean(sigmoid(X @ coeffs[:, j] + b)) = p is
    solved for b by bisection on the bracket [-500, 500].

    Args:
        X: 2-D tensor of inputs (rows are samples).
        coeffs: 2-D tensor of logistic weights with X.shape[1] rows.
        p: desired mean probability.

    Returns:
        1-D tensor holding one intercept per column of coeffs.

    Raises:
        ValueError: raised by scipy's bisect if the objective has the same
            sign at both ends of the bracket.
    """
    d_na = coeffs.shape[1]
    intercepts = torch.zeros(d_na)
    for j in range(d_na):
        # The linear predictor does not depend on the intercept, so compute it
        # once per column instead of on every bisection step.
        scores = X.mv(coeffs[:, j])

        def mean_prob_gap(b, scores=scores):
            return torch.sigmoid(scores + b).mean().item() - p

        intercepts[j] = optimize.bisect(mean_prob_gap, -500, 500)

    return intercepts


class DataProcessor:
    """
    Does not encode data based on validation and training set, only on training set!
    """
    def __init__(self, dataset, cat_encoding='ordinal', seed=42, val_prop=0.1, test_prop=0.1,
                 train_batch_size=4096, val_batch_size=4096, cat_min_freq=5,
                 missing_mechanism=None, seed_missings=42, p_miss=0.1, p_obs=0.3):
        """
        Load the dataset config, store all settings and run the preprocessing.

        Args:
            dataset: dataset name; selects the yaml config and the raw data folder.
            cat_encoding: 'ordinal' or 'onehot'.
            seed: seed passed to set_seeds and used for the data splits.
            val_prop: proportion of rows in the validation split.
            test_prop: proportion of rows in the test split.
            train_batch_size: upper bound for the training batch size.
            val_batch_size: upper bound for the validation / test batch size.
            cat_min_freq: min_frequency handed to the ordinal CatEncoder.
            missing_mechanism: None, 'mcar', 'mar' or 'mnar'.
            seed_missings: seed of the numpy generator used when simulating missings.
            p_miss: proportion of missing values to simulate.
            p_obs: proportion of features used as inputs of the MAR / MNAR model.
        """
        self.name = dataset

        # dataset description and location of the raw csv file
        cfg = OmegaConf.load(f'{DATA_DIR}/configs/{dataset}.yaml')
        self.cfg = cfg
        self.file_path = DATA_DIR/'raw'/dataset/cfg.csv_file
        self.separator = cfg.sep if 'sep' in cfg else ','

        # column roles; dequantized features are optional in the config
        self.target = cfg.target
        self.cat_cols = cfg.cat_features
        numerical = cfg.int_features + cfg.cont_features
        if 'dequant_features' in cfg:
            numerical = numerical + cfg.dequant_features
        self.num_cols = numerical
        self.task = cfg.task
        self.is_regression = (cfg.task == 'regression')

        # splitting, batching and encoding options
        self.seed = seed
        self.val_prop, self.test_prop = val_prop, test_prop
        self.train_batch_size, self.val_batch_size = train_batch_size, val_batch_size
        self.cat_min_freq = cat_min_freq
        self.cat_encoding = cat_encoding # 'ordinal' or 'onehot'

        # settings for simulated missing values
        assert missing_mechanism in (None, 'mcar', 'mar', 'mnar')
        self.missing_mechanism = missing_mechanism # None, 'mcar', 'mar', 'mnar'
        self.seed_missings = seed_missings
        self.p_miss, self.p_obs = p_miss, p_obs

        # preprocessing runs once, right at construction time
        self.data_preprocessed = False
        self.preprocess()
        
        
    def create_splits(self, data):
        """
        Split the row positions of ``data`` into train / validation / test indices.

        The test split is taken first, then the validation split is taken from
        the remainder so that it holds ``val_prop`` of all rows. Both splits use
        ``self.seed``; for non-regression tasks they are stratified by the target.
        A proportion of exactly 0 yields an empty index array for that split
        (scikit-learn rejects a test size of 0) and leaves the remaining indices
        in their incoming order.

        Args:
            data: polars DataFrame holding the target column (only its length
                is used for regression tasks).

        Returns:
            Dict with keys "train", "val" and "test" mapping to index arrays.

        Raises:
            ValueError: if the proportions are negative or leave no training data.
        """
        if not 0 <= self.test_prop < 1 or self.val_prop < 0 or self.val_prop + self.test_prop >= 1:
            raise ValueError(
                "val_prop and test_prop must be non-negative and sum to less than 1, "
                f"got val_prop={self.val_prop}, test_prop={self.test_prop}."
            )

        def holdout(indices, fraction, strata):
            # scikit-learn rejects test_size=0, so an empty hold-out is built by hand
            if fraction == 0:
                return indices, indices[:0]
            return train_test_split(
                indices, test_size=fraction, random_state=self.seed, stratify=strata
            )

        # train, validation, test split
        idx = np.arange(len(data))
        labels = None if self.is_regression else data.select(self.target)

        train_val_idx, test_idx = holdout(idx, self.test_prop, labels)

        # compute proportion of the remaining data required to achieve val_prop
        prop = self.val_prop / (1 - self.test_prop)
        train_idx, val_idx = holdout(
            train_val_idx,
            prop,
            None if labels is None else labels[train_val_idx],
        )

        return {"train": train_idx, "val": val_idx, "test": test_idx}
        
        
    def preprocess(self):
        """
        Load the raw csv, clean it, optionally simulate missings and create the splits.

        Steps, in order:
          1. seed the RNGs via ``set_seeds`` and read the csv file,
          2. keep only target, categorical and numerical columns and drop rows
             whose target is null,
          3. binarize the target for the 'diabetes' and 'covertype' datasets,
          4. add the target to the numerical (regression) or categorical columns,
          5. record per numerical column the number of decimals used to round
             generated values (``self.col_to_round_digits``),
          6. replace nulls in categorical columns by '_MISSING_',
          7. optionally mask values according to ``self.missing_mechanism``,
          8. create the splits (``self.splits``) and store the snapshot
             (``self.orig_data``, ``self.orig_schema``).

        Raises:
            AssertionError: if called twice, if numerical features contain
                nulls, or if nulls remain before simulating missings.
        """
        from decimal import Decimal

        def count_decimals(value):
            # Decimal parses plain ('0.25') and scientific ('1e-05') notation alike,
            # whereas splitting the string on '.' fails for the latter.
            exponent = Decimal(str(value)).as_tuple().exponent
            # non-finite values report a non-integer exponent
            if not isinstance(exponent, int):
                return 0
            return max(-exponent, 0)

        set_seeds(self.seed)
        assert not self.data_preprocessed, "Data has already been preprocessed."

        print(f"= Processing {self.name} with seed {self.seed} =")

        nan_vals = ['NA', 'NULL', 'null', 'nan', 'NaN', 'N/A', 'n/a', '', ' ', 'None', 'none', '?']
        if self.name == 'news':
            data  = pl.from_pandas(pd.read_csv(self.file_path))
        else:
            data = pl.read_csv(self.file_path, null_values=nan_vals,
                        infer_schema_length=100000,
                        separator=self.separator)

        n_full_sample = data.shape[0]
        data = data.select(pl.col([self.target] + self.cat_cols + self.num_cols))
        self.orig_cols = data.columns

        # remove rows with missings in target (only needed for ml efficiency tasks)
        data = data.filter(~pl.col(self.target).is_null())
        print(f"Rows with missings in target removed: {n_full_sample - data.shape[0]}")

        # numerical features must be complete; missings are only ever simulated
        n_num_miss = data.select(self.num_cols).null_count().to_numpy().sum().item()
        assert n_num_miss == 0, "Data should not have missings in numerical features."

        # create binary classification tasks
        if self.name == 'diabetes':
            data = data.with_columns(pl.when(pl.col(self.target) == 'NO').then(pl.lit("no")).otherwise(pl.lit("yes")).alias(self.target))
        elif self.name == 'covertype':
            data = data.with_columns(pl.when(pl.col(self.target) == 2).then(1).otherwise(0).alias(self.target))

        # depending on task, add target to categorical or numerical features
        if self.cfg.task == 'regression':
            self.num_cols = [self.target] + self.num_cols
        else:
            self.cat_cols = [self.target] + self.cat_cols
        print(f"Cat cols: {len(self.cat_cols)}")
        print(f"Num cols: {len(self.num_cols)}")

        # record number of digits to round generated numerical values
        self.col_to_round_digits = {}
        for d in data.select(self.num_cols).iter_columns():
            if d.dtype == pl.Int64:
                self.col_to_round_digits[d.name] = 0
            else:
                observed = d.filter(~d.is_null()).to_pandas()
                self.col_to_round_digits[d.name] = observed.map(count_decimals).max().item()

        # for categorical features, replace missings to encode as separate category
        data = data.with_columns(pl.col(self.cat_cols).fill_null('_MISSING_'))

        # simulate missings according to selected mechanism
        if self.missing_mechanism is not None:
            assert data.null_count().to_numpy().sum().item() == 0, "Data should have no missings before simulating."

        if self.missing_mechanism == 'mcar':
            print("Simulating MCAR...")
            rng = np.random.default_rng(seed=self.seed_missings)
            miss_mask = rng.random((*data.select(self.num_cols).shape,)) < self.p_miss
            miss_mask = torch.tensor(miss_mask, dtype=torch.bool)

            # avoid introducing missings in numerical target
            if self.is_regression:
                miss_mask[:, self.num_cols.index(self.target)] = False

        elif self.missing_mechanism == 'mar':
            print("Simulating MAR...")
            miss_mask = self.get_MAR_mask(data)

        elif self.missing_mechanism == 'mnar':
            print("Simulating MNAR...")
            # MNAR additionally masks the features that drive the missingness
            miss_mask, miss_mask_cat = self.get_MAR_mask(data, exclude_inputs=True)

        if self.missing_mechanism is not None:
            # apply mask to numerical features
            d = data.select(self.num_cols).to_pandas()
            d[miss_mask.numpy()] = np.nan
            data = data.with_columns([pl.Series(name=col, values=d[col].values) for col in self.num_cols])

            if self.missing_mechanism == 'mnar':
                # apply mask to categorical features
                d = data.select(self.cat_cols).to_pandas()
                d[miss_mask_cat.numpy()] = '_MISSING_'
                data = data.with_columns([pl.Series(name=col, values=d[col].values) for col in self.cat_cols])

        # create train/val/test splits
        self.splits = self.create_splits(data)

        # take snapshot of original data; simulated NaNs are stored as nulls
        self.orig_schema = data.schema
        self.orig_data = data.fill_nan(None)

        # set flag to avoid re-doing preprocessing
        self.data_preprocessed = True
        
        
    def get_MAR_mask(self, data, exclude_inputs=False):
        """
        Simulates missings in numerical features based on a subset of categorical and numerical features.
        Uses a random logistic model to determine the probability of missing values.
        Avoids introducing missings into the target to avoid problems for ML efficiency evaluation.

        Based on https://github.com/BorisMuzellec/MissingDataOT/blob/master/utils.py#L144.

        Args:
            data: polars DataFrame with the columns in self.num_cols and self.cat_cols.
            exclude_inputs: if True, the features driving the logistic model are
                themselves masked at random with probability self.p_miss, and a
                mask for the categorical features is returned as well.

        Returns:
            Boolean tensor with one column per entry of self.num_cols. If
            exclude_inputs is True, a tuple of that tensor and a boolean tensor
            with one column per entry of self.cat_cols.
        """
        # NOTE: feature subsets are drawn from this generator, while the model
        # coefficients and the Bernoulli draws use torch's global RNG.
        rng = np.random.default_rng(seed=self.seed_missings)

        def insert_false_column(mask, position):
            # add an all-False column at `position`, staying in torch so the
            # result does not depend on numpy wrapping a tensor
            filler = torch.zeros(mask.shape[0], 1, dtype=torch.bool)
            return torch.cat((mask[:, :position], filler, mask[:, position:]), dim=1)

        # remove target in regression case to avoid introducing missings to it
        if self.is_regression:
            num_cols = [col for col in self.num_cols if col != self.target]
        else:
            num_cols = self.num_cols

        X = data.select(num_cols).to_numpy().astype(float)
        X = StandardScaler().fit_transform(X)
        X = torch.tensor(X).float()
        miss_mask = torch.zeros_like(X).bool()
        n, k = X.shape

        # number of numerical features without missing values (at least one variable)
        k_obs = max(int(self.p_obs * k), 1)

        # number of numerical features that will have missing values
        k_na = k - k_obs
        assert k_na > 0, "At least one numerical feature should have missing values."

        # sample observed feature indices; all remaining ones receive missing values
        idxs_obs = rng.choice(k, k_obs, replace=False)
        idxs_nas = np.setdiff1d(np.arange(k), idxs_obs)

        # subsample categorical features that determine the missing values
        # if excluding inputs later on, cannot use binary target to avoid introducing missings to it
        if not self.is_regression and exclude_inputs:
            cat_cols = [col for col in self.cat_cols if col != self.target]
        else:
            cat_cols = self.cat_cols

        X_cat = OrdinalEncoder().fit_transform(data.select(cat_cols).to_numpy())
        X_cat = torch.tensor(X_cat).float()

        n_cat_inputs = max(int(self.p_obs * X_cat.shape[1]), 1)
        cat_idx = rng.choice(X_cat.shape[1], n_cat_inputs, replace=False)

        # combine observed numerical and selected categorical features
        X_combined = torch.column_stack((X[:, idxs_obs], X_cat[:, cat_idx]))

        # use randomly initialized logistic model to determine missing value probability
        # pick coefficients so that W^Tx has unit variance (avoids shrinking)
        coeffs = pick_coeffs(X_combined, idxs_nas)

        # pick the intercepts to have a desired amount of missing values
        # so per feature, get self.p_miss proportion of missing values
        intercepts = fit_intercepts(X_combined, coeffs, self.p_miss)

        ps = torch.sigmoid(X_combined.mm(coeffs) + intercepts)
        bern = torch.rand(n, k_na)
        miss_mask[:, idxs_nas] = bern < ps

        if exclude_inputs:
            # mask part of the features that are used to determine the missing values
            miss_mask[:, idxs_obs] = torch.rand(n, k_obs) < self.p_miss
            miss_mask_cat = torch.zeros_like(X_cat).bool()
            miss_mask_cat[:, cat_idx] = torch.rand(n, len(cat_idx)) < self.p_miss

            # re-insert the target column (never masked) that was left out above
            if self.is_regression:
                miss_mask = insert_false_column(miss_mask, self.num_cols.index(self.target))
            else:
                miss_mask_cat = insert_false_column(miss_mask_cat, self.cat_cols.index(self.target))

            assert miss_mask.shape[1] == len(self.num_cols)
            assert miss_mask_cat.shape[1] == len(self.cat_cols)

            return miss_mask, miss_mask_cat

        if self.is_regression:
            miss_mask = insert_false_column(miss_mask, self.num_cols.index(self.target))
        assert miss_mask.shape[1] == len(self.num_cols)

        return miss_mask
    
        
    def get_data_splits(self):
        """
        Returns the data splits as polars DataFrames.
        """
        train_data = self.orig_data[self.splits['train'], :].clone()
        val_data = self.orig_data[self.splits['val'], :].clone()
        test_data = self.orig_data[self.splits['test'], :].clone()
        
        return train_data, val_data, test_data
        

    def get_data_loaders(self, mean_impute=True, include_test=False):
        """
        Encode the data splits and wrap them in FastTensorDataLoader objects.

        Numerical features are quantile-transformed and standardized with an
        encoder fit on the training split. Categorical features are encoded with
        an encoder fit on training + validation data. If missings were simulated,
        a binary missing indicator per affected numerical feature is appended to
        the categorical features.

        Sets ``self.num_enc``, ``self.cat_enc``, ``self.X_cat_n_classes`` and, if
        missings were simulated, ``self.M_ind_cols``.

        Args:
            mean_impute: replace missing numerical values by the training mean of
                the feature.
            include_test: also return a loader for the test split. Only used for
                evaluation of LowRes model, to ensure that test data is
                preprocessed the same way as training data.

        Returns:
            (train_loader, val_loader), or (train_loader, val_loader, test_loader)
            if include_test is True. val_loader is None for an empty validation split.
        """
        train_data, val_data, test_data = self.get_data_splits()
        has_val = val_data.height > 0

        # categorical encoder is fit on train + val to ensure consistent values
        # for later evaluation as well
        # NOTE(review): the original concatenated the test split here when
        # include_test was set but discarded the result, so test rows never
        # influenced the encoder. That behaviour is kept; confirm it is intended.
        data = pl.concat([train_data, val_data])

        ####################################################
        # Handle numerical features

        X_num_train = train_data.select(self.num_cols).to_numpy()
        X_num_val = val_data.select(self.num_cols).to_numpy() if has_val else None
        X_num_test = test_data.select(self.num_cols).to_numpy()

        # create missing indicator features for numerical features
        if self.missing_mechanism is not None:
            M_ind_train = []
            M_ind_val = []
            M_ind_test = []
            self.M_ind_cols = []
            for i, label in enumerate(self.num_cols):
                d_train = X_num_train[:, i]
                d_val = X_num_val[:, i] if X_num_val is not None else None
                d_test = X_num_test[:, i]

                # only features with missings in the training split get an indicator
                if not np.isnan(d_train).any():
                    continue

                M_ind_train.append(np.isnan(d_train))
                M_ind_test.append(np.isnan(d_test))
                if d_val is not None:
                    M_ind_val.append(np.isnan(d_val))
                self.M_ind_cols.append(label)

                # mean-impute numerical missings per feature
                if mean_impute:
                    # take the mean before the training column is overwritten
                    train_mean = np.nanmean(d_train)
                    X_num_train[:, i] = np.nan_to_num(d_train, nan=train_mean, copy=False)
                    # without a validation split there is nothing to impute
                    if d_val is not None:
                        X_num_val[:, i] = np.nan_to_num(d_val, nan=train_mean, copy=False)
                    X_num_test[:, i] = np.nan_to_num(d_test, nan=train_mean, copy=False)

            M_ind_train = torch.tensor(np.column_stack(M_ind_train)).bool()
            M_ind_val = torch.tensor(np.column_stack(M_ind_val)).bool() if has_val else None
            M_ind_test = torch.tensor(np.column_stack(M_ind_test)).bool()

        # quantile-transform + standardize
        self.num_enc = Pipeline([
            ("quantile_transformer", QuantileTransformer(
                output_distribution="normal",
                n_quantiles=max(min(train_data.height // 30, 1000), 10),
                subsample=None,
                random_state=42,
            )),
            ("standard_scaler", StandardScaler()),
        ])
        X_num_train = self.num_enc.fit_transform(X_num_train)
        X_num_val = self.num_enc.transform(X_num_val) if has_val else None
        X_num_test = self.num_enc.transform(X_num_test)
        X_num_train = torch.tensor(X_num_train).float()
        X_num_val = torch.tensor(X_num_val).float() if has_val else None
        X_num_test = torch.tensor(X_num_test).float()

        ####################################################
        # Handle categorical features

        if self.cat_encoding == 'onehot':
            self.cat_enc = OneHotEncoder(
                handle_unknown="ignore", sparse_output=False, dtype=np.float32
            )
        else:
            self.cat_enc = CatEncoder(data.select(self.cat_cols), min_frequency=self.cat_min_freq)

        self.cat_enc.fit(data.select(self.cat_cols))
        X_cat_train = torch.tensor(self.cat_enc.transform(train_data.select(self.cat_cols))).long()
        # skip encoding for splits that are empty or not requested
        if has_val:
            X_cat_val = torch.tensor(self.cat_enc.transform(val_data.select(self.cat_cols))).long()
        else:
            X_cat_val = None
        if include_test:
            X_cat_test = torch.tensor(self.cat_enc.transform(test_data.select(self.cat_cols))).long()
        else:
            X_cat_test = None

        if self.cat_encoding == 'onehot':
            self.X_cat_n_classes = [len(cats) for cats in self.cat_enc.categories_]
        else:
            self.X_cat_n_classes = [self.cat_enc.idx_to_stats[i]['n_classes'] for i in range(len(self.cat_cols))]

        # add missing indicator features for numerical features
        if self.missing_mechanism is not None:
            X_cat_train = torch.column_stack((X_cat_train, M_ind_train)).long()
            X_cat_val = torch.column_stack((X_cat_val, M_ind_val)).long() if has_val else None
            X_cat_test = torch.column_stack((X_cat_test, M_ind_test)).long() if include_test else None

            # update n_classes
            self.X_cat_n_classes = self.X_cat_n_classes + [2] * M_ind_train.shape[1]

        ####################################################
        # Prepare loaders

        train_loader = FastTensorDataLoader(
            X_cat_train, X_num_train,
            batch_size=min(self.train_batch_size, train_data.height),
            shuffle=True,
            drop_last=True,
        )

        if has_val:
            val_loader = FastTensorDataLoader(
                X_cat_val, X_num_val,
                batch_size=min(self.val_batch_size, val_data.height),
                shuffle=False,
                drop_last=False,
            )
        else:
            val_loader = None

        if include_test:
            test_loader = FastTensorDataLoader(
                X_cat_test, X_num_test,
                batch_size=min(self.val_batch_size, test_data.height),
                shuffle=False,
                drop_last=False,
            )

            return train_loader, val_loader, test_loader

        return train_loader, val_loader
    
    
    def postprocess(self, X_cat_gen, X_num_gen, includes_miss_ind=True):
        """
        Postprocess synthetic data samples.
        Returns full synthetic data frame with same structure as original data.

        Requires the encoders set by ``get_data_loaders`` (``self.cat_enc``,
        ``self.num_enc`` and, with simulated missings, ``self.M_ind_cols``).

        Args:
            X_cat_gen: array of generated categorical codes; if missings were
                simulated and includes_miss_ind is True, the trailing columns
                hold the generated missing indicators.
            X_num_gen: array of generated numerical features in encoded space.
            includes_miss_ind: whether X_cat_gen carries missing indicators.

        Returns:
            polars DataFrame with the original column order and schema.
        """
        X_cat_gen = X_cat_gen.astype(np.int64)
        X_num_gen = X_num_gen.astype(np.float32)
        has_miss_ind = includes_miss_ind and self.missing_mechanism is not None

        ####################################################
        # Handle categorical features

        if has_miss_ind:
            # split off the missing indicators appended after the categorical features;
            # builtin bool is used because np.bool is absent from NumPy 1.24-1.26
            M_ind_gen = X_cat_gen[:, len(self.cat_cols):].astype(bool)
            assert M_ind_gen.shape[1] == len(self.M_ind_cols)
            X_cat_gen = X_cat_gen[:, :len(self.cat_cols)]
        X_cat_gen = self.cat_enc.inverse_transform(X_cat_gen)

        ####################################################
        # Handle numerical features

        X_num_gen = self.num_enc.inverse_transform(X_num_gen)
        X_num_gen = pd.DataFrame(X_num_gen, columns=self.num_cols)

        # rounding numerical values
        for col_name, decimals in self.col_to_round_digits.items():
            X_num_gen[col_name] = np.round(X_num_gen[col_name], decimals)

        # if there are missings, replace them with NaN
        if has_miss_ind:
            miss_mask = np.zeros_like(X_num_gen, dtype=bool)
            for i, col in enumerate(self.num_cols):
                if col in self.M_ind_cols:
                    miss_mask[:, i] = M_ind_gen[:, self.M_ind_cols.index(col)]

            X_num_gen[miss_mask] = np.nan

        ####################################################
        # Combine data and construct final dataframe

        # combine and reorder columns, make sure types are correct
        X_columns = np.concatenate((self.cat_cols, self.num_cols))
        X = pl.from_pandas(pd.DataFrame(np.column_stack((X_cat_gen, X_num_gen))))
        X.columns = X_columns

        # bring in correct order and apply original schema
        X = X.select(self.orig_cols)
        df_gen = pl.DataFrame(X, schema=self.orig_schema)

        return df_gen
