import pandas as pd
from OpenFE import OpenFE, get_candidate_features
from utils import node_to_formula, formula_to_node, calculate_new_features
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import StratifiedKFold, KFold
import numpy as np
from sklearn.metrics import f1_score


def rf_classify(n_jobs):
    """Build an unfitted random-forest classifier with a fixed seed.

    Args:
        n_jobs: Degree of parallelism forwarded to scikit-learn.

    Returns:
        A ``RandomForestClassifier`` with ``random_state=0`` so runs are reproducible.
    """
    return RandomForestClassifier(n_jobs=n_jobs, random_state=0)


def rf_regression(n_jobs):
    """Build an unfitted random-forest regressor with a fixed seed.

    Args:
        n_jobs: Degree of parallelism forwarded to scikit-learn.

    Returns:
        A ``RandomForestRegressor`` with ``random_state=0`` so runs are reproducible.
    """
    return RandomForestRegressor(n_jobs=n_jobs, random_state=0)


def rae(y_true: np.ndarray, y_pred: np.ndarray):
    """Return ``1 - RAE``, i.e. one minus the relative absolute error.

    The prediction's total absolute error is divided by the total absolute
    error of a constant predictor that always outputs ``y_true.mean()``.
    A perfect prediction scores 1; matching the mean baseline scores 0.

    Note: if ``y_true`` is constant the denominator is zero and the result
    follows NumPy's division semantics (inf/nan).
    """
    residual_total = np.sum(np.abs(y_pred - y_true))
    baseline_total = np.sum(np.abs(y_true.mean() - y_true))
    return 1 - residual_total / baseline_total


def f1_metric(y_true, y_pred):
    """Compute an F1 score, choosing the averaging mode from the label count.

    With more than two distinct values in ``y_true``, micro-averaged F1 is
    returned. Otherwise scikit-learn's default binary F1 is used, which scores
    the positive label ``pos_label=1``.
    """
    is_multiclass = np.unique(y_true).size > 2
    if is_multiclass:
        return f1_score(y_true, y_pred, average="micro")
    return f1_score(y_true, y_pred)


def process_cat(data, data_test, d_columns):
    """One-hot encode categorical columns jointly across train and test frames.

    Every column of ``data`` with pandas ``category`` dtype is treated as
    categorical, in addition to the names listed in ``d_columns``. Train and
    test are concatenated before ``pd.get_dummies`` so both splits receive the
    same dummy columns. The original categorical columns are then dropped.

    Args:
        data: Training feature frame.
        data_test: Test feature frame with the same columns as ``data``.
        d_columns: Names of columns to one-hot encode. May be ``None``. The
            caller's list is NOT modified.

    Returns:
        ``(data, data_test)``: the encoded frames, split back by the training
        row count. Non-categorical columns come first, followed by the dummies.
    """
    # Copy so the caller's list is not extended as a side effect; a shared
    # list would otherwise accumulate column names across repeated calls.
    d_columns = [] if d_columns is None else list(d_columns)
    for column in data.columns:
        if data[column].dtype == 'category' and column not in d_columns:
            d_columns.append(column)

    n_train = data.shape[0]
    df = pd.concat([data, data_test], axis=0)
    if d_columns:
        dummies = pd.get_dummies(df[d_columns])
        df = pd.concat([df.drop(columns=d_columns), dummies], axis=1)
    # Split positionally: the concatenated index may contain duplicate labels.
    return df.iloc[:n_train], df.iloc[n_train:]



def _fg_split_feature_types(data, cat_features, max_ordinal_cardinality=100):
    """Split non-categorical columns of ``data`` into ordinal and numerical lists.

    Columns with at most ``max_ordinal_cardinality`` distinct values count as
    ordinal; columns with more distinct values count as numerical.
    """
    ord_features, num_features = [], []
    for feature in data.columns:
        if feature in cat_features:
            continue
        if data[feature].nunique() <= max_ordinal_cardinality:
            ord_features.append(feature)
        else:
            num_features.append(feature)
    return ord_features, num_features


def _fg_metric(task, y):
    """Return the OpenFE metric name for ``task`` and the target Series ``y``."""
    if task != 'classification':
        return 'rmse'
    return 'multi_logloss' if y.nunique() > 2 else 'binary_logloss'


def _fg_initial_scores(task, y, index):
    """Build constant baseline predictions to pass to OpenFE as ``init_scores``.

    - Regression: the target mean for every row.
    - Multiclass: the empirical class frequencies, ordered by sorted label value,
      repeated for every row.
    - Binary: the logit of the target mean.

    NOTE(review): the binary branch assumes labels are encoded 0/1, and the
    multiclass branch presumably assumes labels 0..k-1. Confirm against callers.
    """
    n_rows = len(y)
    if task != 'classification':
        scores = np.full(n_rows, y.mean())
    elif y.nunique() > 2:
        counts = y.value_counts().sort_index().to_numpy(dtype=float)
        scores = np.tile(counts / counts.sum(), (n_rows, 1))
    else:
        p = y.mean()
        scores = np.full(n_rows, np.log(p / (1 - p)))
    return pd.DataFrame(scores, index=index)


def feature_generation(data, label, data_test, label_test,
                       d_columns, TASK, model, metric_function, n_jobs,
                       *, n_new_features=50):
    """Generate features with OpenFE, then train ``model`` and score it on the test split.

    Args:
        data: Training feature frame.
        label: Single-column training target frame (only its first column is used).
        data_test: Test feature frame.
        label_test: Single-column test target frame.
        d_columns: Names of categorical columns. The caller's list is not modified.
        TASK: ``'classification'``; any other value is treated as regression.
        model: Estimator exposing ``fit`` / ``predict``.
        metric_function: Callable ``(y_true, y_pred) -> score``.
        n_jobs: Parallelism forwarded to OpenFE and feature calculation.
        n_new_features: How many top-ranked OpenFE features to materialise.

    Returns:
        ``metric_function`` evaluated on the test predictions.
    """
    cat_features = list(d_columns)
    y = label[label.columns[0]]
    ord_features, num_features = _fg_split_feature_types(data, cat_features)
    candidate_features_list = get_candidate_features(numerical_features=num_features,
                                                     categorical_features=cat_features,
                                                     ordinal_features=ord_features)
    metric = _fg_metric(TASK, y)

    # The stratify argument is kept as the original frame so the split is unchanged.
    if TASK == 'classification':
        X_train, X_test, _, _ = train_test_split(data, label, test_size=0.33,
                                                 random_state=42, stratify=label)
    else:
        X_train, X_test, _, _ = train_test_split(data, label, test_size=0.33,
                                                 random_state=42)
    init_scores = _fg_initial_scores(TASK, y, data.index)

    ofe = OpenFE()
    ofe.fit(data=data, label=label,
            candidate_features_list=candidate_features_list,
            metric=metric,
            init_scores=init_scores,
            train_index=X_train.index, val_index=X_test.index,
            categorical_features=cat_features,
            remain_for_stage2=None,
            remain=2000,
            n_jobs=n_jobs, fold=1, task=TASK)

    new_features = [feature for feature, _ in ofe.new_features_list[:n_new_features]]
    data, data_test = calculate_new_features(data, data_test, new_features, n_jobs=n_jobs)
    # Pass a copy: process_cat may extend the list with generated categorical columns.
    data, data_test = process_cat(data, data_test, list(cat_features))
    # Fit on a 1-D target to avoid sklearn's column-vector conversion warning.
    model.fit(data, y)
    pred = model.predict(data_test)
    score = metric_function(label_test[label_test.columns[0]], pred)
    return score
