import numpy as np
import os
import lib
import pandas as pd
from tab_ddpm.modules import MLPDiffusion, ResNetDiffusion
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

def get_model(
    model_name,
    model_params,
    n_num_features,
    category_sizes
):
    """
    Build the denoising network selected by ``model_name``.

    :param model_name: ``'mlp'`` selects MLPDiffusion, ``'resnet'`` selects
        ResNetDiffusion.
    :param model_params: keyword arguments forwarded unchanged to the model
        constructor.
    :param n_num_features: not used by this function.
    :param category_sizes: not used by this function.
    :return: the constructed model instance.
    :raises ValueError: if ``model_name`` is neither ``'mlp'`` nor ``'resnet'``.
    """
    print(model_name)
    if model_name == 'mlp':
        model = MLPDiffusion(**model_params)
    elif model_name == 'resnet':
        model = ResNetDiffusion(**model_params)
    else:
        # A bare string cannot be raised in Python 3 (it fails with an
        # unrelated TypeError), so raise a real exception naming the bad value.
        raise ValueError(
            "Unknown model! Got {!r}; expected 'mlp' or 'resnet'.".format(model_name)
        )
    return model

def update_ema(target_params, source_params, rate=0.999):
    """
    Apply one exponential-moving-average step to the target parameters.

    Every target is overwritten in place with
    ``rate * target + (1 - rate) * source``; nothing is returned.

    :param target_params: sequence of parameters holding the running average.
    :param source_params: sequence of parameters supplying the new values,
        paired with ``target_params`` by position.
    :param rate: the EMA rate (closer to 1 means the average moves slower).
    """
    for ema_param, live_param in zip(target_params, source_params):
        # Work on detached views so the in-place update is not tracked by autograd.
        ema_values = ema_param.detach()
        ema_values.mul_(rate)
        ema_values.add_(live_param.detach(), alpha=1 - rate)

def concat_y_to_X(X, y):
    """
    Prepend ``y`` to ``X`` as its first column.

    :param X: 2-D feature array, or None when there are no features.
    :param y: array reshaped here into a single column.
    :return: ``y`` as a column followed by the columns of ``X``; just the
        ``y`` column when ``X`` is None.
    """
    y_column = y.reshape(-1, 1)
    if X is not None:
        return np.concatenate([y_column, X], axis=1)
    return y_column

def make_dataset_from_df(
        df,
        T,
        is_y_cond,
        ratios=[0.7, 0.2, 0.1],
        df_info=None,
        std=0
    ):
    """
    Split a DataFrame into train/val/test and build a transformed lib.Dataset.

    The order of the generated matrix is (X_num, X_cat): numerical columns
    first, then label-encoded categorical columns. All features end up in the
    numerical slot of the dataset; the categorical slot is always None.

    is_y_cond:
        concat: y is concatenated to X, the model learn a joint distribution of (y, X)
        embedding: y is not concatenated to X. During computations, y is embedded
            and added to the latent vector of X
        none: y column is completely ignored

    How does is_y_cond affect the generation of y?
    is_y_cond:
        concat: the model synthesizes (y, X) directly, so y is one of the columns
        embedding: y is first sampled using empirical distribution of y. The model only
            synthesizes X. When returning the generated data, we return the generated X
            and the sampled y. (y is sampled from empirical distribution, instead of being
            generated by the model)
            Note that in this way, y is still not independent of X, because the model has been
            adding the embedding of y to the latent vector of X during computations.
        none:
            y is synthesized using y's empirical distribution. X is generated by the model.
            In this case, y is completely independent of X.

    Position of y under is_y_cond == 'concat':
        n_classes == 0: y is prepended to the numerical columns, so it is the
            first column of the matrix.
        n_classes > 0: y is prepended to the categorical columns, so it comes
            after every numerical column and is NOT the first column.
    TODO (from the original author): code that assumes y is column 0 is only
    safe with n_classes == 0; use the returned column order otherwise.

    :param df: pandas DataFrame holding the feature columns and the y column.
    :param T: transformation settings forwarded to ``lib.transform_dataset``.
    :param is_y_cond: 'concat', 'embedding' or 'none' (see above).
    :param ratios: (train, val, test) fractions used for the two-stage split.
    :param df_info: dict with keys 'n_classes', 'cat_cols', 'num_cols',
        'y_col' and 'task_type'. 'cat_cols' / 'num_cols' may be None.
    :param std: if > 0, Gaussian noise with this standard deviation is added
        to the label-encoded categorical values.
    :return: tuple ``(dataset, label_encoders, column_orders)`` where
        ``dataset`` is the result of ``lib.transform_dataset``,
        ``label_encoders`` maps the position of a categorical column (within
        the categorical block) to its fitted LabelEncoder, and
        ``column_orders`` lists the column names in matrix order.
    :raises TypeError: if ``df_info`` is None.
    """
    if df_info is None:
        # Same exception type as the original subscripting failure, but with
        # a message that says what is actually missing.
        raise TypeError("df_info is required: expected a dict describing the columns of df")

    train_val_df, test_df = train_test_split(df, test_size=ratios[2], random_state=42)
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=ratios[1] / (ratios[0] + ratios[1]), random_state=42
    )
    split_frames = {'train': train_df, 'val': val_df, 'test': test_df}

    index_to_column = list(df.columns)
    column_to_index = {col: i for i, col in enumerate(index_to_column)}

    y_col = df_info['y_col']
    cat_cols = df_info['cat_cols']
    num_cols = df_info['num_cols']
    concat_y = is_y_cond == 'concat'

    y = {
        split: frame[y_col].values.astype(np.float32)
        for split, frame in split_frames.items()
    }

    if df_info['n_classes'] > 0:
        # Classification: under 'concat' the label is handled as one more
        # categorical column, placed ahead of the other categorical columns.
        X_cat = {} if cat_cols is not None or concat_y else None
        X_num = {} if num_cols is not None else None

        cat_cols_with_y = []
        if cat_cols is not None:
            cat_cols_with_y += cat_cols
        if concat_y:
            cat_cols_with_y = [y_col] + cat_cols_with_y

        if len(cat_cols_with_y) > 0:
            for split, frame in split_frames.items():
                X_cat[split] = frame[cat_cols_with_y].to_numpy(dtype=np.str_)

        if num_cols is not None:
            for split, frame in split_frames.items():
                X_num[split] = frame[num_cols].values.astype(np.float32)

        cat_column_orders = [column_to_index[col] for col in cat_cols_with_y]
        # num_cols may legitimately be None; treat that as "no columns".
        num_column_orders = [
            column_to_index[col] for col in (num_cols if num_cols is not None else [])
        ]

    else:
        # Regression: under 'concat' the target is handled as one more
        # numerical column, placed ahead of the other numerical columns.
        X_cat = {} if cat_cols is not None else None
        X_num = {} if num_cols is not None or concat_y else None

        num_cols_with_y = []
        if num_cols is not None:
            num_cols_with_y += num_cols
        if concat_y:
            num_cols_with_y = [y_col] + num_cols_with_y

        if len(num_cols_with_y) > 0:
            for split, frame in split_frames.items():
                X_num[split] = frame[num_cols_with_y].values.astype(np.float32)

        if cat_cols is not None:
            for split, frame in split_frames.items():
                X_cat[split] = frame[cat_cols].to_numpy(dtype=np.str_)

        # cat_cols may legitimately be None; treat that as "no columns".
        cat_column_orders = [
            column_to_index[col] for col in (cat_cols if cat_cols is not None else [])
        ]
        num_column_orders = [column_to_index[col] for col in num_cols_with_y]

    column_orders = num_column_orders + cat_column_orders
    column_orders = [index_to_column[index] for index in column_orders]

    label_encoders = {}
    # Decide from the arrays actually built, not from df_info['cat_cols']:
    # under classification + 'concat' the categorical block can consist of the
    # y column alone (cat_cols None or empty) and must still be encoded so the
    # matrix agrees with column_orders.
    has_cat_features = (
        X_cat is not None
        and 'train' in X_cat
        and X_cat['train'].shape[1] > 0
    )
    if has_cat_features:
        # Fit each encoder on all splits stacked together so every split
        # shares one category-to-integer mapping.
        X_cat_all = np.vstack((X_cat['train'], X_cat['val'], X_cat['test']))
        X_cat_converted = []
        for col_index in range(X_cat_all.shape[1]):
            label_encoder = LabelEncoder()
            X_cat_converted.append(label_encoder.fit_transform(X_cat_all[:, col_index]).astype(float))
            if std > 0:
                # add noise
                X_cat_converted[-1] += np.random.normal(0, std, X_cat_converted[-1].shape)
            label_encoders[col_index] = label_encoder

        X_cat_converted = np.vstack(X_cat_converted).T

        train_num = X_cat['train'].shape[0]
        val_num = X_cat['val'].shape[0]

        # Cut the stacked matrix back into the original splits.
        X_cat['train'] = X_cat_converted[: train_num, :]
        X_cat['val'] = X_cat_converted[train_num: train_num + val_num, :]
        X_cat['test'] = X_cat_converted[train_num + val_num:, :]

        if X_num:
            # Numerical columns first, encoded categorical columns after.
            for split in ('train', 'val', 'test'):
                X_num[split] = np.concatenate((X_num[split], X_cat[split]), axis=1)
        else:
            # No numerical block (None or empty): the encoded categorical
            # columns become the whole feature matrix.
            X_num = X_cat
            X_cat = None

    D = lib.Dataset(
        X_num,
        None,
        y,
        y_info={},
        task_type=lib.TaskType(df_info['task_type']),
        n_classes=df_info['n_classes']
    )

    return lib.transform_dataset(D, T, None), label_encoders, column_orders


def make_dataset_from_num_arrays(
        X_num_array,
        y_array,
        T,
        is_y_cond,
        task_type,
        n_classes,
        ratios=[0.7, 0.2, 0.1],
        num_cols_to_transform=0
    ):
    """
    Split pre-built arrays into train/val/test and build a transformed lib.Dataset.

    If y is categorical, then y_array has been label encoded, and NEEDS to be
    transformed along with the columns to transform

    If y is numerical, then y_array is passed as is and also NEEDS to be transformed
    along with the columns to transform

    y is prepended to X as its first column when ``is_y_cond == 'concat'``,
    and also when ``is_y_cond`` is falsy (behaviour kept from the original
    code). Only in the 'concat' case is ``num_cols_to_transform`` increased by
    one so that the prepended y column is counted among the transformed
    columns.

    :param X_num_array: feature array, split row-wise together with y_array.
    :param y_array: target array aligned with the rows of X_num_array.
    :param T: transformation settings forwarded to ``lib.transform_dataset``.
    :param is_y_cond: conditioning mode; see above for how it affects X.
    :param task_type: value passed to ``lib.TaskType``.
    :param n_classes: forwarded to ``lib.Dataset``.
    :param ratios: (train, val, test) fractions used for the two-stage split.
    :param num_cols_to_transform: forwarded to ``lib.transform_dataset`` as
        ``transform_cols_num`` (plus one under 'concat').
    :return: the result of ``lib.transform_dataset``.
    """
    train_ratio, val_ratio, test_ratio = ratios
    X_train, X_temp, y_train, y_temp = train_test_split(
        X_num_array, y_array, test_size=1 - train_ratio, random_state=42
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=test_ratio / (val_ratio + test_ratio), random_state=42
    )

    X_num = {'train': X_train, 'val': X_val, 'test': X_test}
    y = {'train': y_train, 'val': y_val, 'test': y_test}

    concat_mode = is_y_cond == 'concat'

    # The original only concatenated y for a falsy is_y_cond, yet bumped the
    # transform count for 'concat'; under 'concat' the count then covered a
    # y column that was never added. Concatenate in both cases so the count
    # and the matrix agree.
    if concat_mode or not is_y_cond:
        for split in ('train', 'val', 'test'):
            X_num[split] = concat_y_to_X(X_num[split], y[split])

    D = lib.Dataset(
        X_num,
        None,
        y,
        y_info={},
        task_type=lib.TaskType(task_type),
        n_classes=n_classes
    )

    if concat_mode:
        # Account for the y column that now sits at the front of X.
        num_cols_to_transform += 1

    return lib.transform_dataset(D, T, None, transform_cols_num=num_cols_to_transform)


 
def make_dataset(
    data_path: str,
    T: lib.Transformations,
    num_classes: int,
    is_y_cond: str,
    change_val: bool
):
    """
    Load the train/val/test splits stored under ``data_path`` and transform them.

    Each split is read with ``lib.read_pure_data``. When
    ``is_y_cond == 'concat'`` the target is prepended as the first column of
    the categorical features (``num_classes > 0``, classification) or of the
    numerical features (otherwise, regression).

    :param data_path: directory containing the split files and ``info.json``.
    :param T: transformation settings forwarded to ``lib.transform_dataset``.
    :param num_classes: > 0 selects the classification path, otherwise the
        regression path.
    :param is_y_cond: conditioning mode; only the value ``'concat'`` changes
        what this function does.
    :param change_val: if true, the dataset is passed through
        ``lib.change_val`` before being transformed.
    :return: the result of ``lib.transform_dataset``.
    """
    is_classification = num_classes > 0
    concat_y = is_y_cond == 'concat'

    has_cat_file = os.path.exists(os.path.join(data_path, 'X_cat_train.npy'))
    has_num_file = os.path.exists(os.path.join(data_path, 'X_num_train.npy'))

    # A feature block exists if its file is on disk, or if y is about to be
    # concatenated into it (categorical for classification, numerical for
    # regression).
    X_cat = {} if has_cat_file or (concat_y and is_classification) else None
    X_num = {} if has_num_file or (concat_y and not is_classification) else None
    y = {}

    for split in ['train', 'val', 'test']:
        X_num_t, X_cat_t, y_t = lib.read_pure_data(data_path, split)
        if concat_y:
            if is_classification:
                X_cat_t = concat_y_to_X(X_cat_t, y_t)
            else:
                X_num_t = concat_y_to_X(X_num_t, y_t)
        if X_num is not None:
            X_num[split] = X_num_t
        if X_cat is not None:
            X_cat[split] = X_cat_t
        y[split] = y_t

    info = lib.load_json(os.path.join(data_path, 'info.json'))

    D = lib.Dataset(
        X_num,
        X_cat,
        y,
        y_info={},
        task_type=lib.TaskType(info['task_type']),
        n_classes=info.get('n_classes')
    )

    if change_val:
        D = lib.change_val(D)

    return lib.transform_dataset(D, T, None)