import pandas as pd
import numpy as np
# from torch_geometric.data import Data
from logger import *

# Unary operation names; `unary_transform` applies one of these to each
# column of a feature cluster.
O1 = [
    'sqrt', 'square', 'sin', 'cos', 'tanh', 'stand_scaler',
    'minmax_scaler', 'quan_trans', 'sigmoid', 'log', 'reciprocal', 'cube',
]
# Binary arithmetic operator symbols used by `binary_transform`.
O2 = ['+', '-', '*', '/']
# Subset of O1 whose `op_sign` is a fitted transformer (has `fit_transform`)
# rather than a plain callable.
O3 = ['stand_scaler', 'minmax_scaler', 'quan_trans']

# Every supported operation: unary names first, then binary symbols.
operation_set = [*O1, *O2]

def add_unary(op, f_name):
    """Return the display name for feature ``f_name`` after unary op ``op``.

    Only the first three characters of ``op`` are kept, e.g.
    ``add_unary('sqrt', 'x')`` gives ``'sqr(x)'``.
    """
    abbrev = op[:3]
    return f'{abbrev}({f_name})'




def add_binary(op, pos_1, pos_2):
    """Return the parenthesised name ``(pos_1 op pos_2)`` with no spaces."""
    return '({}{}{})'.format(pos_1, op, pos_2)

def operate_two_features_new(f_cluster1, f_cluster2, op, op_func, f_names1, f_names2):
    """Combine every column of ``f_cluster1`` with ``f_cluster2`` via ``op_func``.

    Parameters
    ----------
    f_cluster1 : np.ndarray
        2-D array; each column ``f_cluster1[:, j]`` is the left operand.
    f_cluster2 : array-like
        Right operand passed unchanged for every column.
    op : str
        Operator symbol, used only to build the new feature names.
    op_func : callable
        Called as ``op_func(left_column, f_cluster2)``.
    f_names1 : sequence
        Names aligned with the columns of ``f_cluster1``.
    f_names2 : object
        Name of the right operand (converted with ``str``).

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        The stacked results transposed, so results become columns, and the
        array of generated names.
        NOTE(review): with zero input columns the first array has shape
        ``(0,)``, not ``(n_rows, 0)``.
    """
    col_indices = range(f_cluster1.shape[1])
    results = [op_func(f_cluster1[:, j], f_cluster2) for j in col_indices]
    labels = [add_binary(op, str(f_names1[j]), str(f_names2)) for j in col_indices]
    return np.array(results).T, np.array(labels)


# Per-operator predicates marking values outside the operator's domain.
# A column with at least one flagged value is skipped by `unary_transform`.
_UNARY_DOMAIN_VIOLATION = {
    'sqrt': lambda col: col < 0,
    'reciprocal': lambda col: col == 0,
    'log': lambda col: col <= 0,
}


def unary_transform(D_train: pd.DataFrame, op, op_sign, f_cluster, f_names1):
    """Apply the unary operation ``op`` to a cluster of feature columns.

    The last column of ``D_train`` is dropped first (presumably the target /
    label -- confirm with callers). ``f_cluster`` then indexes into the
    remaining columns.

    Parameters
    ----------
    D_train : pd.DataFrame
        Source data. Its last column is excluded from the transformation.
    op : str
        Operation name, expected to be one of ``O1``.
    op_sign : callable or transformer
        For ops in ``O3`` this is an object with ``fit_transform``, fitted on
        the whole cluster at once. For 'sqrt', 'reciprocal' and 'log' it is
        called once per column. Otherwise it is called once on the whole
        cluster array.
    f_cluster : sequence of int
        Column positions (among the feature columns) to transform.
    f_names1 : sequence
        Feature names aligned with ``f_cluster``.

    Returns
    -------
    f_generate : np.ndarray
        Generated feature columns.
    f_new_name : list of str
        Names of the generated features, built by ``add_unary``.
    parent : list or type of ``f_cluster``
        Source column position for each generated feature.
    """
    features = D_train.values[:, :-1]
    cluster = features[:, f_cluster]

    violates_domain = _UNARY_DOMAIN_VIOLATION.get(op)
    if violates_domain is not None:
        # Domain-restricted op: keep only columns where op is defined on
        # every value.
        f_new, f_new_name, parent = [], [], []
        for i in range(cluster.shape[1]):
            column = cluster[:, i]
            if np.any(violates_domain(column)):
                continue
            f_new.append(op_sign(column))
            f_new_name.append(add_unary(op, f_names1[i]))
            parent.append(f_cluster[i])
        # NOTE(review): if every column is skipped, this has shape (0,)
        # rather than (n_rows, 0). Kept for caller compatibility.
        f_generate = np.array(f_new).T
    elif op in O3:
        # Fitted transformer: fit and apply on the whole cluster at once.
        f_generate = op_sign.fit_transform(cluster)
        f_new_name = [add_unary(op, f_n) for f_n in f_names1]
        parent = f_cluster
    else:
        # Unrestricted op: apply directly to the whole cluster array.
        f_generate = op_sign(cluster)
        f_new_name = [add_unary(op, f_n) for f_n in f_names1]
        parent = f_cluster

    return f_generate, f_new_name, parent


def binary_transform(D_train: pd.DataFrame, op, op_func, f_cluster, act_ind, f_names1, f_names2):
    """Combine each column in ``f_cluster`` with column ``act_ind`` using ``op``.

    The last column of ``D_train`` is dropped before indexing (presumably the
    label -- confirm with callers).

    Parameters
    ----------
    D_train : pd.DataFrame
        Source data.
    op : str
        Binary operator symbol, expected to be one of ``O2``.
    op_func : callable
        Called as ``op_func(head_column, tail_column)``.
    f_cluster : sequence of int
        Positions of the left-operand ("head") columns.
    act_ind : int
        Position of the right-operand ("tail") column.
    f_names1 : sequence
        Names aligned with ``f_cluster``.
    f_names2 : object
        Name of the tail column.

    Returns
    -------
    tuple
        ``(f_generate, f_new_name, parent)``. When ``op`` is '/' and the tail
        column contains any zero, this is three empty lists. Otherwise it is
        the arrays from ``operate_two_features_new`` and ``f_cluster``
        itself as the parents.
    """
    features = D_train.values[:, :-1]
    head = features[:, f_cluster]
    tail = features[:, act_ind]

    # Skip division entirely rather than produce inf/nan columns.
    if op == '/' and np.any(tail == 0):
        return [], [], []

    generated, names = operate_two_features_new(head, tail, op, op_func, f_names1, f_names2)
    return generated, names, f_cluster

