import numpy as np
import os
import lib
import pandas as pd
from tab_ddpm.modules import MLPDiffusion, ResNetDiffusion
from sklearn.model_selection import train_test_split

def get_model(
    model_name,
    model_params,
    n_num_features,
    category_sizes
):
    """Instantiate the diffusion backbone selected by ``model_name``.

    :param model_name: ``'mlp'`` for ``MLPDiffusion`` or ``'resnet'`` for
        ``ResNetDiffusion``.
    :param model_params: keyword arguments forwarded verbatim to the model
        constructor.
    :param n_num_features: accepted for interface compatibility; not used here.
    :param category_sizes: accepted for interface compatibility; not used here.
    :return: the constructed model.
    :raises ValueError: if ``model_name`` is not a known backbone.
    """
    print(model_name)
    if model_name == 'mlp':
        return MLPDiffusion(**model_params)
    if model_name == 'resnet':
        return ResNetDiffusion(**model_params)
    # The original `raise "Unknown model!"` raised a bare string, which Python
    # rejects with an unrelated TypeError; raise a real exception instead.
    raise ValueError(
        f"Unknown model! Expected 'mlp' or 'resnet', got {model_name!r}"
    )

def update_ema(target_params, source_params, rate=0.999):
    """
    Blend ``source_params`` into ``target_params`` in place as an exponential
    moving average: each target becomes ``rate * target + (1 - rate) * source``.

    The arithmetic runs on ``.detach()``-ed views, so the update writes the
    target storage directly without building an autograd graph.

    :param target_params: the EMA parameters, modified in place (presumably
        torch tensors; anything exposing ``detach``/``mul_``/``add_`` works).
    :param source_params: the parameters to blend in, paired with the targets
        by position (extra items on either side are ignored, as with ``zip``).
    :param rate: the EMA decay; values closer to 1 make the target move slower.
    """
    source_weight = 1 - rate
    for ema_param, live_param in zip(target_params, source_params):
        ema_view = ema_param.detach()
        ema_view.mul_(rate)
        ema_view.add_(live_param.detach(), alpha=source_weight)

def concat_y_to_X(X, y):
    """Return ``y`` as a column vector, prepended to ``X`` as column 0 if given.

    :param X: 2-D feature matrix with one row per element of ``y``, or None.
    :param y: target values; reshaped to shape ``(-1, 1)``.
    :return: ``y`` as a single column when ``X`` is None, otherwise
        ``[y | X]`` concatenated along axis 1.
    """
    y_column = y.reshape(-1, 1)
    if X is not None:
        return np.concatenate([y_column, X], axis=1)
    return y_column

def _frame_splits_to_arrays(splits, cols, as_str):
    """Return ``{split_name: ndarray}`` holding columns ``cols`` of every split frame.

    Categorical columns (``as_str=True``) are converted to unicode strings;
    numerical columns are converted to float32.
    """
    if as_str:
        return {name: frame[cols].to_numpy(dtype=np.str_) for name, frame in splits.items()}
    return {name: frame[cols].values.astype(np.float32) for name, frame in splits.items()}


def make_dataset_from_df(
        df,
        T,
        is_y_cond,
        ratios=(0.7, 0.2, 0.1),
        df_info=None
    ):
    """Split ``df`` into train/val/test and return it as a transformed ``lib.Dataset``.

    The test split takes ``ratios[2]`` of the rows; the remainder is split into
    train/val with val taking ``ratios[1] / (ratios[0] + ratios[1])`` of it.
    Both splits use ``random_state=42``, so the partition is reproducible.

    :param df: source DataFrame containing every feature column and the target.
    :param T: transformations forwarded to ``lib.transform_dataset``.
    :param is_y_cond: when False the target column is also appended to the
        feature matrix of the matching kind (categorical if
        ``df_info['n_classes'] > 0``, numerical otherwise).
    :param ratios: ``(train, val, test)`` fractions.
    :param df_info: mapping with keys ``'n_classes'``, ``'cat_cols'``,
        ``'num_cols'`` (column lists or None), ``'y_col'`` and ``'task_type'``.
        Required despite the None default.
    :return: the result of ``lib.transform_dataset(D, T, None)``.
    :raises TypeError: if ``df_info`` is not provided.
    """
    if df_info is None:
        raise TypeError("make_dataset_from_df() requires a df_info mapping")

    train_val_df, test_df = train_test_split(df, test_size=ratios[2], random_state=42)
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=ratios[1] / (ratios[0] + ratios[1]), random_state=42
    )
    splits = {'train': train_df, 'val': val_df, 'test': test_df}

    y_col = df_info['y_col']
    # Copy into fresh lists so appending y never mutates the caller's df_info.
    cat_cols = list(df_info['cat_cols']) if df_info['cat_cols'] is not None else []
    num_cols = list(df_info['num_cols']) if df_info['num_cols'] is not None else []

    if not is_y_cond:
        # NOTE(review): y is appended as the LAST column here, whereas
        # make_dataset/concat_y_to_X put it first (column 0) — confirm which
        # layout downstream consumers expect.
        if df_info['n_classes'] > 0:
            cat_cols.append(y_col)
        else:
            num_cols.append(y_col)

    # An empty column list yields None rather than an empty dict or
    # zero-width arrays, matching make_dataset's "no features" representation.
    X_cat = _frame_splits_to_arrays(splits, cat_cols, as_str=True) if cat_cols else None
    X_num = _frame_splits_to_arrays(splits, num_cols, as_str=False) if num_cols else None
    y = {name: frame[y_col].values.astype(np.float32) for name, frame in splits.items()}

    D = lib.Dataset(
        X_num,
        X_cat,
        y,
        y_info={},
        task_type=lib.TaskType(df_info['task_type']),
        n_classes=df_info['n_classes']
    )

    return lib.transform_dataset(D, T, None)

 
def make_dataset(
    data_path: str,
    T: lib.Transformations,
    num_classes: int,
    is_y_cond: bool,
    change_val: bool
):
    """Load pre-split arrays from ``data_path`` and return a transformed ``lib.Dataset``.

    Categorical / numerical feature dicts are created only when
    ``X_cat_train.npy`` / ``X_num_train.npy`` exist in ``data_path``. When
    ``is_y_cond`` is False the target is also folded into the features as
    column 0 (via ``concat_y_to_X``): into the categorical matrix for
    classification (``num_classes > 0``), into the numerical one otherwise.
    The target container is created in that case even if its file is absent.

    :param data_path: directory holding the split arrays and ``info.json``.
    :param T: transformations forwarded to ``lib.transform_dataset``.
    :param num_classes: > 0 selects the classification layout, else regression.
    :param is_y_cond: whether y is used as a condition rather than a feature.
    :param change_val: if True, the dataset is passed through ``lib.change_val``.
    :return: the result of ``lib.transform_dataset(D, T, None)``.
    """
    is_classification = num_classes > 0
    has_cat_file = os.path.exists(os.path.join(data_path, 'X_cat_train.npy'))
    has_num_file = os.path.exists(os.path.join(data_path, 'X_num_train.npy'))

    fold_into_cat = (not is_y_cond) and is_classification
    fold_into_num = (not is_y_cond) and not is_classification

    X_cat = {} if (has_cat_file or fold_into_cat) else None
    X_num = {} if (has_num_file or fold_into_num) else None
    y = {}

    for split in ('train', 'val', 'test'):
        num_part, cat_part, target = lib.read_pure_data(data_path, split)
        if fold_into_cat:
            cat_part = concat_y_to_X(cat_part, target)
        elif fold_into_num:
            num_part = concat_y_to_X(num_part, target)
        if X_num is not None:
            X_num[split] = num_part
        if X_cat is not None:
            X_cat[split] = cat_part
        y[split] = target

    info = lib.load_json(os.path.join(data_path, 'info.json'))

    dataset = lib.Dataset(
        X_num,
        X_cat,
        y,
        y_info={},
        task_type=lib.TaskType(info['task_type']),
        n_classes=info.get('n_classes')
    )

    if change_val:
        dataset = lib.change_val(dataset)

    return lib.transform_dataset(dataset, T, None)