import numpy as np
import os
import pandas as pd

from tabsyn import src
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split


class TabularDataset(Dataset):
    """Map-style dataset yielding aligned numerical / categorical rows.

    ``dataset[i]`` returns the tuple ``(X_num[i], X_cat[i])`` and
    ``len(dataset)`` is ``X_num.shape[0]``.
    """

    def __init__(self, X_num, X_cat):
        self.X_num, self.X_cat = X_num, X_cat

    def __len__(self):
        n_rows = self.X_num.shape[0]
        return n_rows

    def __getitem__(self, index):
        num_row = self.X_num[index]
        cat_row = self.X_cat[index]
        return num_row, cat_row


def preprocess_from_numpy(
    data, 
    task_type='None', 
    inverse=False, 
    cat_encoding=None, 
    concat=True,
    n_classes=0,
    has_test=False
):
    """Quantile-normalize in-memory arrays built by ``make_dataset_from_numpy``.

    Args:
        data: mapping with optional ``'X_num'``, ``'X_cat'`` and ``'y'`` keys,
            each mapping split names (``'train'``, ``'test'``) to arrays.
        task_type: task name; ``'binclass'``/``'multiclass'`` attach the target
            to the categorical block, anything else to the numerical block.
        inverse: if true, also return the inverse transforms.
        cat_encoding: forwarded to ``src.Transformations``; when not ``None``
            the transformed dataset object is returned as-is.
        concat: whether the target is prepended to a feature block.
        n_classes: forwarded to ``src.Dataset``.
        has_test: if true, return ``(train, test)`` tuples for each feature
            block instead of just the train arrays.

    Returns:
        The transformed dataset when ``cat_encoding`` is not ``None``.
        Otherwise ``(X_num, X_cat, categories, d_numerical)``, followed by
        ``(num_inverse, cat_inverse)`` when ``inverse`` is true. A feature
        block that the dataset does not have is returned as ``None`` (as is
        its inverse transform), and ``d_numerical`` is 0 without numerical
        features.
    """
    T_dict = {}

    T_dict['normalization'] = "quantile"
    T_dict['num_nan_policy'] = 'mean'
    T_dict['cat_nan_policy'] =  None
    T_dict['cat_min_frequency'] = None
    T_dict['cat_encoding'] = cat_encoding
    T_dict['y_policy'] = "default"

    T = src.Transformations(**T_dict)

    dataset = make_dataset_from_numpy(
        data=data,
        T=T,
        task_type=task_type,
        change_val=False,
        concat=concat,
        n_classes=n_classes
    )

    if cat_encoding is not None:
        return dataset

    num_splits = dataset.X_num
    cat_splits = dataset.X_cat

    # Either block may be absent (e.g. a purely categorical table), so every
    # access is guarded instead of indexing None.
    X_train_num = num_splits['train'] if num_splits is not None else None
    X_train_cat = cat_splits['train'] if cat_splits is not None else None

    if has_test:
        X_num = (X_train_num, num_splits['test']) if num_splits is not None else None
        X_cat = (X_train_cat, cat_splits['test']) if cat_splits is not None else None
    else:
        X_num = X_train_num
        X_cat = X_train_cat

    # NOTE(review): relies on src.get_categories returning None for a None
    # input when there are no categorical features — confirm.
    categories = src.get_categories(X_train_cat)
    d_numerical = X_train_num.shape[1] if X_train_num is not None else 0

    if not inverse:
        return X_num, X_cat, categories, d_numerical

    num_inverse = dataset.num_transform.inverse_transform if X_num is not None else None
    cat_inverse = dataset.cat_transform.inverse_transform if X_cat is not None else None
    return X_num, X_cat, categories, d_numerical, num_inverse, cat_inverse



def preprocess(dataset_path, task_type = 'binclass', inverse = False, cat_encoding = None, concat = True):
    """Load a dataset directory with ``make_dataset`` and quantile-normalize it.

    Args:
        dataset_path: directory holding the ``*_train.npy`` / ``*_test.npy``
            arrays and ``info.json``.
        task_type: task name; ``'binclass'``/``'multiclass'`` attach the target
            to the categorical block, anything else to the numerical block.
        inverse: if true, also return the inverse transforms.
        cat_encoding: forwarded to ``src.Transformations``; when not ``None``
            the transformed dataset object is returned as-is.
        concat: whether the target is prepended to a feature block.

    Returns:
        The transformed dataset when ``cat_encoding`` is not ``None``.
        Otherwise ``(X_num, X_cat, categories, d_numerical)``, where ``X_num``
        and ``X_cat`` are ``(train, test)`` tuples, followed by
        ``(num_inverse, cat_inverse)`` when ``inverse`` is true. A feature
        block the dataset does not have is returned as ``None`` (as is its
        inverse transform), and ``d_numerical`` is 0 without numerical
        features.
    """
    T_dict = {}

    T_dict['normalization'] = "quantile"
    T_dict['num_nan_policy'] = 'mean'
    T_dict['cat_nan_policy'] =  None
    T_dict['cat_min_frequency'] = None
    T_dict['cat_encoding'] = cat_encoding
    T_dict['y_policy'] = "default"

    T = src.Transformations(**T_dict)

    dataset = make_dataset(
        data_path = dataset_path,
        T = T,
        task_type = task_type,
        change_val = False,
        concat = concat
    )

    if cat_encoding is not None:
        return dataset

    num_splits = dataset.X_num
    cat_splits = dataset.X_cat

    # make_dataset leaves a block as None when its .npy file is missing, so
    # guard every access instead of indexing None.
    X_num = (num_splits['train'], num_splits['test']) if num_splits is not None else None
    X_cat = (cat_splits['train'], cat_splits['test']) if cat_splits is not None else None

    # NOTE(review): relies on src.get_categories returning None for a None
    # input when there are no categorical features — confirm.
    categories = src.get_categories(X_cat[0] if X_cat is not None else None)
    d_numerical = X_num[0].shape[1] if X_num is not None else 0

    if not inverse:
        return X_num, X_cat, categories, d_numerical

    num_inverse = dataset.num_transform.inverse_transform if X_num is not None else None
    cat_inverse = dataset.cat_transform.inverse_transform if X_cat is not None else None
    return X_num, X_cat, categories, d_numerical, num_inverse, cat_inverse


def update_ema(target_params, source_params, rate=0.999):
    """
    Move each target parameter toward its source counterpart in place,
    computing ``target = rate * target + (1 - rate) * source``.
    :param target_params: iterable of tensors holding the running average.
    :param source_params: iterable of tensors with the current values,
        paired with ``target_params`` position by position.
    :param rate: decay factor; values closer to 1 make the average move slower.
    """
    decay = rate
    weight = 1 - rate
    for ema_param, live_param in zip(target_params, source_params):
        # detach() shares storage, so the in-place ops update ema_param
        # itself without being recorded by autograd.
        ema_view = ema_param.detach()
        ema_view.mul_(decay)
        ema_view.add_(live_param.detach(), alpha=weight)



def concat_y_to_X(X, y):
    """Prepend the target ``y`` as the first column of ``X``.

    ``y`` is reshaped to a single column. When ``X`` is ``None`` that column
    alone is returned.
    """
    y_col = y.reshape(-1, 1)
    if X is not None:
        return np.concatenate([y_col, X], axis=1)
    return y_col


def preprocess_group_data(
    group_num,
    group_cat,
):
    """Transform grouped features jointly, then split the result back per group.

    All groups are concatenated row-wise and passed through
    ``src.transform_dataset`` with the quantile settings below. The
    transformed rows are then sliced back into one array per input group,
    keeping the input order.

    Args:
        group_num: sequence of 2-D numerical arrays, one per group. Numerical
            features count as absent when the sequence is empty/None or its
            first array has zero columns.
        group_cat: sequence of 2-D categorical arrays, one per group, handled
            the same way. Presumably row counts match ``group_num`` group by
            group when both are present — verify against callers.

    Returns:
        ``(transformed_group_num, transformed_group_cat, d_numerical,
        categories)``. For an absent modality each group's entry is an empty
        ``(rows, 0)`` array.

    Raises:
        ValueError: if neither numerical nor categorical columns are present.
    """
    T_dict = {}

    T_dict['normalization'] = "quantile"
    T_dict['num_nan_policy'] = 'mean'
    T_dict['cat_nan_policy'] =  None
    T_dict['cat_min_frequency'] = None
    T_dict['cat_encoding'] = None
    T_dict['y_policy'] = "default"

    T = src.Transformations(**T_dict)

    group_num = [] if group_num is None else group_num
    group_cat = [] if group_cat is None else group_cat
    has_num = len(group_num) > 0 and group_num[0].shape[1] > 0
    has_cat = len(group_cat) > 0 and group_cat[0].shape[1] > 0
    if not (has_num or has_cat):
        raise ValueError(
            'preprocess_group_data requires numerical or categorical columns'
        )

    X_num = None
    X_cat = None
    if has_num:
        num_arr = np.concatenate(group_num, axis=0)
        X_num = {
            'train': num_arr.astype(float)
        }
        n_rows = num_arr.shape[0]
    if has_cat:
        cat_arr = np.concatenate(group_cat, axis=0)
        X_cat = {
            'train': cat_arr
        }
        n_rows = cat_arr.shape[0]

    # src.Dataset is given a placeholder all-zero, single-column target.
    y = {
        'train': np.zeros((n_rows, 1))
    }

    D = src.Dataset(
        X_num,
        X_cat,
        y,
        y_info={},
        task_type=src.TaskType('None'),
        n_classes=0
    )

    transformed = src.transform_dataset(D, T, None)
    transformed_num = transformed.X_num['train'] if transformed.X_num is not None else None
    transformed_cat = (
        transformed.X_cat['train'].astype(int) if transformed.X_cat is not None else None
    )

    categories = src.get_categories(transformed_cat)
    d_numerical = transformed_num.shape[1] if transformed_num is not None else 0

    # Slice the jointly transformed rows back into the original groups. The
    # row counts come from a modality that was actually concatenated above.
    groups = group_num if has_num else group_cat
    transformed_group_num = []
    transformed_group_cat = []
    start = 0
    for group in groups:
        end = start + group.shape[0]
        n_group_rows = end - start
        if transformed_num is not None:
            transformed_group_num.append(transformed_num[start:end])
        else:
            transformed_group_num.append(np.empty((n_group_rows, 0)))
        if transformed_cat is not None:
            transformed_group_cat.append(transformed_cat[start:end])
        else:
            transformed_group_cat.append(np.empty((n_group_rows, 0)))
        # Advance to the next group's rows; otherwise every group would get
        # the first group's slice.
        start = end

    return transformed_group_num, transformed_group_cat, d_numerical, categories


def make_dataset_from_numpy(
    data,
    T,
    task_type,
    change_val,
    concat=True,
    n_classes=0
):
    """Build a ``src.Dataset`` from in-memory arrays and transform it.

    Args:
        data: mapping with optional keys ``'X_num'``, ``'X_cat'`` and ``'y'``.
            Each present value maps split names to arrays; only the
            ``'train'`` and ``'test'`` splits are read.
        T: ``src.Transformations`` applied via ``src.transform_dataset``.
        task_type: for ``'binclass'``/``'multiclass'`` the target is attached
            to the categorical block, otherwise to the numerical block.
        change_val: if true, apply ``src.change_val`` before transforming.
        concat: if true, prepend each split's target as the first column of
            the block selected by ``task_type``.
        n_classes: forwarded to ``src.Dataset``.

    Returns:
        The result of ``src.transform_dataset(D, T, None)``.

    Raises:
        ValueError: if ``concat`` is true and a split that needs its target
            has no entry in ``data['y']``.
    """
    is_classification = task_type == 'binclass' or task_type == 'multiclass'
    # Block that receives the target column when concat is enabled.
    concat_key = 'X_cat' if is_classification else 'X_num'

    X_num = {} if data.get('X_num') is not None else None
    X_cat = {} if data.get('X_cat') is not None else None
    y = {} if data.get('y') is not None else None
    stores = {'X_num': X_num, 'X_cat': X_cat}

    for split in ['train', 'test']:
        # Read this split's target BEFORE the features, so concatenation
        # always uses the matching y (never an undefined or stale one).
        y_t = None
        if y is not None and split in data['y']:
            y_t = data['y'][split]
            y[split] = y_t

        for key, store in stores.items():
            if store is None or split not in data[key]:
                continue
            X_t = data[key][split]
            if concat and key == concat_key:
                if y_t is None:
                    raise ValueError(
                        f"concat=True needs data['y'][{split!r}] to attach "
                        f"the target to {key}"
                    )
                X_t = concat_y_to_X(X_t, y_t)
            store[split] = X_t

    # Drop feature blocks that have no columns.
    if X_num is not None and X_num['train'].shape[1] == 0:
        X_num = None

    if X_cat is not None and X_cat['train'].shape[1] == 0:
        X_cat = None

    D = src.Dataset(
        X_num,
        X_cat,
        y,
        y_info={},
        task_type=src.TaskType(task_type),
        n_classes=n_classes
    )

    if change_val:
        D = src.change_val(D)

    return src.transform_dataset(D, T, None)


def make_dataset(
    data_path: str,
    T: src.Transformations,
    task_type,
    change_val: bool,
    concat = True,
):
    """Load the ``train``/``test`` splits stored in ``data_path`` and transform them.

    A feature group (``X_num``, ``X_cat``, ``y``) is loaded only when its
    ``*_train.npy`` file exists in ``data_path``. When ``concat`` is true the
    target of each split is prepended as the first column of ``X_cat`` for
    ``'binclass'``/``'multiclass'`` tasks and of ``X_num`` otherwise.

    The task type and class count handed to ``src.Dataset`` come from
    ``info.json``; the ``task_type`` argument only decides where the target
    is concatenated.

    Returns:
        The result of ``src.transform_dataset(D, T, None)``.
    """
    def has_file(name):
        return os.path.exists(os.path.join(data_path, name))

    X_cat = {} if has_file('X_cat_train.npy') else None
    X_num = {} if has_file('X_num_train.npy') else None
    y = {} if has_file('y_train.npy') else None

    # Classification attaches y to the categorical block; regression and
    # other tasks attach it to the numerical block.
    y_goes_to_cat = task_type in ('binclass', 'multiclass')

    for split in ['train', 'test']:
        X_num_t, X_cat_t, y_t = src.read_pure_data(data_path, split)

        if X_num is not None:
            if concat and not y_goes_to_cat:
                X_num_t = concat_y_to_X(X_num_t, y_t)
            X_num[split] = X_num_t

        if X_cat is not None:
            if concat and y_goes_to_cat:
                X_cat_t = concat_y_to_X(X_cat_t, y_t)
            X_cat[split] = X_cat_t

        if y is not None:
            y[split] = y_t

    info = src.load_json(os.path.join(data_path, 'info.json'))

    D = src.Dataset(
        X_num,
        X_cat,
        y,
        y_info={},
        task_type=src.TaskType(info['task_type']),
        n_classes=info.get('n_classes')
    )

    if change_val:
        D = src.change_val(D)

    return src.transform_dataset(D, T, None)