import numpy as np
import pandas as pd
from collections import defaultdict
from DQN import *
from graph import *
from itertools import chain

from sklearn.feature_selection import SelectKBest, SelectFromModel
from sklearn.linear_model import LogisticRegression, Lasso, Ridge, RidgeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics.pairwise import euclidean_distances
from scipy.sparse.csgraph import laplacian
from scipy.linalg import eigh
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score, pairwise_distances
from sklearn.metrics import precision_score, recall_score, f1_score, mean_absolute_error, mean_squared_error
from sklearn.preprocessing import StandardScaler, MinMaxScaler, QuantileTransformer
from scipy.special import expit
from sklearn.metrics import make_scorer
from sklearn import linear_model
from sklearn.svm import LinearSVC
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split, cross_val_score
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.cluster import AgglomerativeClustering
from sklearn.feature_selection import mutual_info_regression
import torch
import torch.nn as nn
import torch.utils.data as Data
from torch_geometric.data import Data as gData
from logger import error, info

# data preprocess
def preprocess(data:pd.DataFrame):
    """Rescale the feature columns of ``data`` and rename every column.

    All columns except the last are passed through a ``MinMaxScaler`` with
    range (-1, 1); the last column is carried over untouched.  The returned
    frame is rebuilt from the raw values, so it has a fresh default index,
    and its columns are named ``'0'``, ``'1'``, ... in positional order.
    """
    raw = data.values
    n_columns = len(list(data.columns))
    scaled = MinMaxScaler(feature_range=(-1,1)).fit_transform(raw[:, :-1])
    target = raw[:, -1]
    rebuilt = pd.concat([pd.DataFrame(scaled), pd.DataFrame(target)], axis=1)
    rebuilt.columns = [str(position) for position in range(n_columns)]
    return rebuilt

# downstream task
def relative_absolute_error(y_test, y_predict):
    """Return the relative absolute error of ``y_predict`` against ``y_test``.

    The value is the summed absolute prediction error divided by the summed
    absolute deviation of ``y_test`` from its own mean.
    """
    truth = np.array(y_test)
    predicted = np.array(y_predict)
    abs_residual = np.sum(np.abs(truth - predicted))
    abs_spread = np.sum(np.abs(np.mean(truth) - truth))
    return abs_residual / abs_spread

def downstream_task_new(data:pd.DataFrame, task_type:str, state_num=0):
    """Score a dataset with a 5-fold cross-validated random forest.

    The last column of ``data`` is the target, everything before it is the
    feature matrix.  ``state_num`` seeds both the fold splitter and the model.

    Returns
      * ``'cls'``  -> mean weighted F1 over the stratified folds
      * ``'reg'``  -> mean of ``1 - relative_absolute_error`` over the folds
      * ``'rank'`` -> ``None`` (branch is not implemented)
      * anything else -> ``-1``
    """
    features, target = data.iloc[:, :-1], data.iloc[:, -1]

    if task_type == 'rank':
        # Not implemented: the caller gets None.
        return None
    if task_type not in ('cls', 'reg'):
        return -1

    fold_scores = []
    if task_type == 'cls':
        model = RandomForestClassifier(random_state=state_num)
        splitter = StratifiedKFold(n_splits=5, random_state=state_num, shuffle=True)
        for train_idx, test_idx in splitter.split(features, target):
            model.fit(features.iloc[train_idx, :], target.iloc[train_idx])
            predicted = model.predict(features.iloc[test_idx, :])
            fold_scores.append(f1_score(target.iloc[test_idx], predicted, average='weighted'))
    else:
        splitter = KFold(n_splits=5, random_state=state_num, shuffle=True)
        model = RandomForestRegressor(random_state=state_num)
        for train_idx, test_idx in splitter.split(features):
            model.fit(features.iloc[train_idx, :], target.iloc[train_idx])
            predicted = model.predict(features.iloc[test_idx, :])
            fold_scores.append(1 - relative_absolute_error(target.iloc[test_idx], predicted))
    return np.mean(fold_scores)
    

# head clusters
def select_meta_cluster1(g:gData, clusters:defaultdict, feature_names, epsilon, dqn_cluster:DQN1):
    """Epsilon-greedy choice of the head cluster.

    Every cluster in ``clusters`` is scored with ``dqn_cluster.get_q_value``.
    With probability governed by ``epsilon`` a random cluster is taken,
    otherwise the one with the highest score.

    Returns ``(graph_emb, action_emb, f_cluster, f_names)`` where
    ``graph_emb`` is element 2 of the last scoring result, ``action_emb`` is
    element 1 of the chosen cluster's result, ``f_cluster`` is the chosen
    cluster's member list and ``f_names`` the matching feature names.

    NOTE: this name is defined a second time further down in this file with
    the same behavior; the later definition is the one left bound at import.
    """
    cluster_keys, scores, head_embs = [], [], []
    for cluster_key, members in clusters.items():
        result = dqn_cluster.get_q_value(g, members)
        scores.append(result[0].detach().numpy()[0])
        head_embs.append(result[1])
        cluster_keys.append(cluster_key)
    # Taken from the result of the last cluster that was scored.
    graph_emb = result[2]

    exploit = np.random.uniform() > epsilon
    picked = np.argmax(scores) if exploit else np.random.randint(0, len(clusters))

    f_cluster = list(clusters[cluster_keys[picked]])
    f_names = np.array(feature_names)[f_cluster]
    info('Current select feature names : ' + str(f_names))
    return graph_emb, head_embs[picked], f_cluster, f_names


def generate_next_state_of_meta_cluster1(g:gData, clusters, dqn_cluster:DQN1):
    """Greedily pick the head cluster for the next state.

    Every cluster is scored with ``dqn_cluster.get_q_value_next_state`` and
    the best-scoring one is selected.

    Returns ``(graph_emb, action_emb, f_cluster)`` where ``graph_emb`` is
    element 2 of the last scoring result, ``action_emb`` is element 1 of the
    winning cluster's result and ``f_cluster`` is the winning cluster's
    member collection as stored in ``clusters``.

    Raises ``ValueError`` when ``clusters`` is empty.
    """
    if not clusters:
        raise ValueError('clusters must contain at least one cluster')

    q_vals, action_list, cluster_keys = [], [], []
    for key, value in clusters.items():
        result = dqn_cluster.get_q_value_next_state(g, value)
        q_vals.append(result[0].detach().numpy()[0])
        action_list.append(result[1])
        cluster_keys.append(key)

    # Taken from the result of the last cluster that was scored.
    graph_emb = result[2]
    best = int(np.argmax(q_vals))
    action_emb = action_list[best]
    # argmax yields a position in iteration order, not a dictionary key.
    # Map it back to the key so the right cluster is returned (and so a
    # defaultdict is not silently extended with an empty entry).
    f_cluster = clusters[cluster_keys[best]]
    return graph_emb, action_emb, f_cluster



# operation
def select_operation(head_state, dqn_operation:DQN2, operation_set, eps_threshold):
    """Choose an operation for ``head_state``.

    The index comes from ``dqn_operation.choose_action``; the function
    returns that index together with the matching entry of ``operation_set``.
    """
    chosen_index = dqn_operation.choose_action(head_state, eps_threshold)
    return chosen_index, operation_set[chosen_index]

def generate_next_state_of_meta_operation(head_state, dqn_operation:DQN2, operation_set):
    """Choose the operation for the next state.

    The index comes from ``dqn_operation.choose_next_action``; the function
    returns that index together with the matching entry of ``operation_set``.
    """
    chosen_index = dqn_operation.choose_next_action(head_state)
    return chosen_index, operation_set[chosen_index]


def cube(x):
    """Return ``x`` raised to the third power."""
    return pow(x, 3)

def justify_operation_type(o):
    """Translate an operation name into the object that implements it.

    Scaler names yield a freshly constructed scikit-learn transformer,
    ``'cube'`` yields the module's ``cube`` function, and the remaining
    known names map to NumPy ufuncs (or ``expit`` for ``'sigmoid'``).
    For any other value a hint is printed and ``o`` is returned unchanged.
    """
    # Transformers are built on demand so each call gets its own instance.
    if o == 'stand_scaler':
        return StandardScaler()
    if o == 'minmax_scaler':
        return MinMaxScaler(feature_range=(-1, 1))
    if o == 'quan_trans':
        return QuantileTransformer(random_state=0)
    if o == 'cube':
        return cube

    elementwise = {
        'sqrt': np.sqrt,
        'square': np.square,
        'sin': np.sin,
        'cos': np.cos,
        'tanh': np.tanh,
        'reciprocal': np.reciprocal,
        '+': np.add,
        '-': np.subtract,
        '/': np.divide,
        '*': np.multiply,
        'exp': np.exp,
        'sigmoid': expit,
        'log': np.log,
    }
    if isinstance(o, str) and o in elementwise:
        return elementwise[o]

    print('Please check your operation!')
    return o

# tail clusters
# def select_meta_cluster2(head_emb, graph_emb, op, feature_names, epsilon, dqn_cluster:DQN3):
#     q_vals, nodes_list,  = [], []
#     for ind in range(len(graph_emb)):
#         node_emb = graph_emb[ind]
#         q_value = dqn_cluster.get_q_value(head_emb, torch.tensor(op,dtype=torch.long), node_emb).detach().numpy()[0]
#         q_vals.append(q_value)
#         nodes_list.append(node_emb)
#     if np.random.uniform() > epsilon:
#         act_id = np.argmax(q_vals)
#     else:
#         act_id = np.random.randint(0, len(graph_emb))
#     action = nodes_list[act_id]
#     f_names = np.array(feature_names)[act_id] 
#     info('current select feature name : ' + str(f_names))
#     return act_id, action, f_names

def select_meta_cluster1(g:gData, clusters:defaultdict, feature_names, epsilon, dqn_cluster:DQN1):
    """Epsilon-greedy choice of the head cluster.

    Every cluster in ``clusters`` is scored with ``dqn_cluster.get_q_value``.
    With probability governed by ``epsilon`` a random cluster is taken,
    otherwise the one with the highest score.

    Returns ``(graph_emb, action_emb, f_cluster, f_names)`` where
    ``graph_emb`` is element 2 of the last scoring result, ``action_emb`` is
    element 1 of the chosen cluster's result, ``f_cluster`` is the chosen
    cluster's member list and ``f_names`` the matching feature names.

    NOTE: a same-named definition with the same behavior appears earlier in
    this file; this later one replaces it when the module is imported.
    """
    cluster_keys, scores, head_embs = [], [], []
    for cluster_key, members in clusters.items():
        result = dqn_cluster.get_q_value(g, members)
        scores.append(result[0].detach().numpy()[0])
        head_embs.append(result[1])
        cluster_keys.append(cluster_key)
    # Taken from the result of the last cluster that was scored.
    graph_emb = result[2]

    exploit = np.random.uniform() > epsilon
    picked = np.argmax(scores) if exploit else np.random.randint(0, len(clusters))

    f_cluster = list(clusters[cluster_keys[picked]])
    f_names = np.array(feature_names)[f_cluster]
    info('Current select feature names : ' + str(f_names))
    return graph_emb, head_embs[picked], f_cluster, f_names

def select_meta_cluster2(head_emb, graph_emb, op, clusters:defaultdict, feature_names, epsilon, dqn_cluster:DQN3):
    """Epsilon-greedy choice of the tail (operand) cluster.

    Each cluster is represented by the mean of its members' rows in
    ``graph_emb`` and scored with ``dqn_cluster.get_q_value`` together with
    ``head_emb`` and the operation index ``op``.

    Returns ``(operand, action, f_names)``: the chosen cluster's member
    indices, the mean embedding of those members, and their feature names.
    """
    cluster_keys, scores = [], []
    for cluster_key, members in clusters.items():
        op_tensor = torch.tensor(op,dtype=torch.long)
        pooled = torch.mean(graph_emb[members],dim=0)
        score = dqn_cluster.get_q_value(head_emb, op_tensor, pooled)
        scores.append(score.detach().numpy()[0])
        cluster_keys.append(cluster_key)

    exploit = np.random.uniform() > epsilon
    picked = np.argmax(scores) if exploit else np.random.randint(0, len(clusters))

    operand = clusters[cluster_keys[picked]]  # node indices of the chosen cluster
    action = torch.mean(graph_emb[operand],dim=0)
    f_names = np.array(feature_names)[operand]
    info('current select feature name : ' + str(f_names))
    return operand, action, f_names
        

def generate_next_state_of_meta_cluster2(head_emb, graph_emb, op, clusters:defaultdict, dqn_cluster:DQN3):
    """Greedily pick the tail-cluster action for the next state.

    Each cluster is represented by the mean of its members' rows in
    ``graph_emb``.  Candidates are scored with
    ``dqn_cluster.get_q_value_next_state`` -- the same call
    ``generate_next_state_of_meta_node2`` uses for its next-state scoring --
    and the mean embedding of the best-scoring cluster is returned.

    NOTE(review): this previously called ``get_q_value``; switching to the
    next-state variant assumes the two differ the way the sibling next-state
    helpers imply -- confirm against the DQN3 implementation.
    """
    q_vals, action_list = [], []
    for value in clusters.values():
        # Pool once and reuse for both scoring and the returned action.
        cluster_action = torch.mean(graph_emb[value], dim=0)
        q_value = dqn_cluster.get_q_value_next_state(
            head_emb, torch.tensor(op, dtype=torch.long), cluster_action
        ).detach().numpy()[0]
        q_vals.append(q_value)
        action_list.append(cluster_action)
    return action_list[int(np.argmax(q_vals))]
        
def select_meta_node2(head_emb, graph_emb, op, feature_names, epsilon, dqn_cluster:DQN3):
    """Epsilon-greedy choice of a single tail (operand) node.

    Each row of ``graph_emb`` is scored with ``dqn_cluster.get_q_value``
    together with ``head_emb`` and the operation index ``op``.

    Returns ``([act_id], action, f_names)``: the chosen node index wrapped in
    a list, that node's embedding, and its feature name.
    """
    num_nodes = len(graph_emb)
    scores, candidates = [], []
    for node_id in range(num_nodes):
        candidate = graph_emb[node_id]
        op_tensor = torch.tensor(op,dtype=torch.long)
        score = dqn_cluster.get_q_value(head_emb, op_tensor, candidate)
        scores.append(score.detach().numpy()[0])
        candidates.append(candidate)

    exploit = np.random.uniform() > epsilon
    act_id = np.argmax(scores) if exploit else np.random.randint(0, num_nodes)

    f_names = np.array(feature_names)[act_id]
    info('current select feature name : ' + str(f_names))
    return [act_id], candidates[act_id], f_names

def generate_next_state_of_meta_node2(head_emb, graph_emb, op, dqn_cluster:DQN3):
    """Greedily pick the tail-node action for the next state.

    Each row of ``graph_emb`` is scored with
    ``dqn_cluster.get_q_value_next_state``; the embedding of the
    best-scoring row is returned.
    """
    candidates = [graph_emb[node_id] for node_id in range(len(graph_emb))]
    scores = [
        dqn_cluster.get_q_value_next_state(
            head_emb, torch.tensor(op,dtype=torch.long), candidate
        ).detach().numpy()[0]
        for candidate in candidates
    ]
    return candidates[np.argmax(scores)]

def column_exists(all_data, column):
    """Report whether ``column`` equals, element for element, any column of ``all_data``.

    ``all_data`` is transposed so that each of its columns becomes a row,
    which is then compared against ``column`` with exact equality.
    """
    per_column_match = np.all(column == np.transpose(all_data), axis=1)
    return np.any(per_column_match)
    

def test_task_new(Dg:pd.DataFrame, task='cls', state_num=0):
    """Evaluate a dataset with 5-fold cross-validation and average the metrics.

    The last column of ``Dg`` is the target, everything before it is the
    feature matrix.  ``state_num`` seeds the fold splitter and, for the
    random-forest tasks, the model.

    Returns
      * ``'cls'``  -> (accuracy, weighted precision, weighted recall, weighted F1)
      * ``'reg'``  -> (MAE, MSE, relative absolute error)
      * ``'det'``  -> (average precision, macro F1, ROC AUC) from a 5-NN classifier
      * ``'rank'`` -> ``None`` (branch is not implemented)
      * anything else -> ``-1``
    """
    features, target = Dg.iloc[:, :-1], Dg.iloc[:, -1]

    if task == 'rank':
        # Not implemented: the caller gets None.
        return None
    if task not in ('cls', 'reg', 'det'):
        return -1

    def fold_frames(train_idx, test_idx):
        # (X_train, y_train, X_test, y_test) for one fold.
        return (features.iloc[train_idx, :], target.iloc[train_idx],
                features.iloc[test_idx, :], target.iloc[test_idx])

    # One tuple of metric values per fold, in the order they are reported.
    per_fold = []
    if task == 'cls':
        model = RandomForestClassifier(random_state=state_num)
        splitter = StratifiedKFold(n_splits=5, random_state=state_num, shuffle=True)
        for train_idx, test_idx in splitter.split(features, target):
            X_train, y_train, X_test, y_test = fold_frames(train_idx, test_idx)
            model.fit(X_train, y_train)
            predicted = model.predict(X_test)
            per_fold.append((
                accuracy_score(y_test, predicted),
                precision_score(y_test, predicted, average='weighted', zero_division=0),
                recall_score(y_test, predicted, average='weighted', zero_division=0),
                f1_score(y_test, predicted, average='weighted', zero_division=0),
            ))
    elif task == 'reg':
        splitter = KFold(n_splits=5, random_state=state_num, shuffle=True)
        model = RandomForestRegressor(random_state=state_num)
        for train_idx, test_idx in splitter.split(features):
            X_train, y_train, X_test, y_test = fold_frames(train_idx, test_idx)
            model.fit(X_train, y_train)
            predicted = model.predict(X_test)
            per_fold.append((
                mean_absolute_error(y_test, predicted),
                mean_squared_error(y_test, predicted),
                relative_absolute_error(y_test, predicted),
            ))
    else:
        splitter = KFold(n_splits=5, random_state=state_num, shuffle=True)
        model = KNeighborsClassifier(n_neighbors=5)
        for train_idx, test_idx in splitter.split(features):
            X_train, y_train, X_test, y_test = fold_frames(train_idx, test_idx)
            model.fit(X_train, y_train)
            predicted = model.predict(X_test)
            per_fold.append((
                average_precision_score(y_test, predicted),
                f1_score(y_test, predicted, average='macro'),
                roc_auc_score(y_test, predicted),
            ))

    # Transpose fold-major rows into one sequence per metric and average each.
    return tuple(np.mean(metric_values) for metric_values in zip(*per_fold))
    

def feature_distance(features):
    """Return the pairwise Euclidean distances between the columns of ``features``.

    ``features`` is converted to a tensor and transposed so that each column
    becomes a row before ``torch.cdist`` is applied; the result is returned
    as a NumPy array.
    """
    by_column = torch.tensor(features).transpose(-1, 0)
    distances = torch.cdist(
        by_column,
        by_column,
        p=2.0,
        compute_mode='use_mm_for_euclid_dist_if_necessary',
    )
    return distances.numpy()

def spectral_clustering(g_pyg:gData, mode='a', alpha=0.5):
    """
    Group the nodes of ``g_pyg`` into clusters via a spectral embedding.

    g_pyg: graph object; ``x``, ``edge_index`` and ``num_nodes`` are read
    mode: 'a' blends the adjacency matrix with feature similarity,
          's' uses the adjacency matrix only,
          'f' uses feature similarity only
    alpha: weight factor between adjacency matrix and feature similarity
           (mode 'a' only): W = alpha * A + (1 - alpha) * S

    The number of clusters is int(sqrt(number of rows of g_pyg.x)).

    Returns a defaultdict(list) mapping cluster label -> list of node indices.
    Raises ValueError for an unknown ``mode``.
    """
    if mode not in ('a', 's', 'f'):
        # Previously an unknown mode fell through to the feature branch with
        # no features and failed with an unrelated error.
        raise ValueError("mode must be one of 'a', 's' or 'f', got %r" % (mode,))

    # Symmetric 0/1 adjacency matrix built from the edge list.
    edge_index = g_pyg.edge_index
    num_nodes = g_pyg.num_nodes
    A = np.zeros([num_nodes, num_nodes])
    A[edge_index[0], edge_index[1]] = 1
    A[edge_index[1], edge_index[0]] = 1
    num_clusters = int(np.sqrt(g_pyg.x.shape[0]))

    if mode == 's':
        W = A
    else:
        # Gaussian-style similarity from pairwise feature distances, scaled
        # by the standard deviation of those distances.
        S = euclidean_distances(g_pyg.x)
        S = np.exp(-S ** 2 / S.std())
        W = alpha * A + (1 - alpha) * S if mode == 'a' else S

    L = laplacian(W, normed=True)
    eigenvalues, eigenvectors = eigh(L)
    eigenvectors = eigenvectors / np.linalg.norm(eigenvectors, axis=1, keepdims=True)
    # NOTE(review): the full eigenvector matrix is handed over as a
    # *precomputed distance matrix*.  Textbook spectral clustering would
    # cluster the leading eigenvectors with a Euclidean metric instead --
    # confirm this is intended before changing it, since it alters results.
    clustering = AgglomerativeClustering(n_clusters=num_clusters, metric='precomputed', linkage='single').fit(eigenvectors)

    clusters = defaultdict(list)
    for ind, item in enumerate(clustering.labels_):
        clusters[item].append(ind)
    return clusters

def update_data(g_pyg:gData, D_train:pd.DataFrame, parent, op_index, f_generate, final_name):
    """Add newly generated feature columns to the training frame and the graph.

    For each entry ``parent[i]`` the candidate column ``f_generate.T[i]`` is
    considered.  Candidates that already exist among the current feature
    columns are skipped.  Every accepted candidate
      * becomes a new graph node whose id is its column position, connected
        from ``parent[i]`` by an edge of weight ``op_index + 1``;
      * is inserted into ``D_train`` just before the last (target) column,
        under the name ``str(final_name[i])``.

    Returns ``(g_pyg, D_train, feature_names)`` with ``feature_names`` being
    the column names of the updated frame.
    """
    train_data = D_train.values[:, :-1]

    for i, head_node in enumerate(parent):
        new_train_node = f_generate.T[i]
        if column_exists(train_data, new_train_node):
            info("The transformed feature exists, continue to next feature.")
            continue
        new_train_node = new_train_node.reshape(-1, 1)
        train_data = np.hstack((train_data, new_train_node))
        new_node_feature, _, _ = nodes_feature(new_train_node)

        # The new node takes the position of the column that was just appended.
        new_node_id = train_data.shape[1] - 1
        g_nx = pyg2nx(g_pyg)
        g_nx.add_node(new_node_id, node_feature=torch.tensor(new_node_feature[0], dtype=torch.float))
        g_nx.add_weighted_edges_from([(head_node, new_node_id, op_index + 1)])
        g_pyg = nx2pyg(g_nx)

        # Build the column on D_train's own index: concat(axis=1) aligns on
        # index labels, so a default RangeIndex here would misalign rows (and
        # introduce NaNs) whenever D_train is not indexed 0..n-1.
        df_new_node = pd.DataFrame(
            new_train_node,
            index=D_train.index,
            columns=[str(final_name[i])],
        )
        D_train = pd.concat(
            [D_train.iloc[:, :-1], df_new_node, D_train.iloc[:, -1]],
            axis=1,
        )

    feature_names = list(D_train.columns)
    return g_pyg, D_train, feature_names


# node-wise prune
def prune(g_pyg:gData, D_train:pd.DataFrame, train_original:pd.DataFrame, FEATURE_LIMIT):
    """Shrink the feature set and the feature graph to the most informative nodes.

    ``SelectKBest`` with ``mutual_info_regression`` picks ``FEATURE_LIMIT``
    feature columns of ``D_train`` (last column is the target).  Those
    columns, every original feature (the first ``train_original.shape[1] - 1``
    columns) and every ancestor on the path from a kept node back to an
    original feature are retained; all other nodes are removed from the graph
    and all other feature columns are dropped from the frame.

    Returns ``(g_pyg, D_train)``; the frame keeps its feature columns in
    ascending node-id order followed by the target column.
    """
    selector = SelectKBest(mutual_info_regression, k=FEATURE_LIMIT).fit(D_train.iloc[:, :-1], D_train.iloc[:, -1])
    cols = selector.get_support(indices=True)

    n_original = train_original.shape[1] - 1
    original_ids = set(range(n_original))
    n_features = D_train.shape[1] - 1

    begin_nodes = g_pyg.edge_index[0].numpy()
    end_nodes = g_pyg.edge_index[1].numpy()

    # Keep the selected nodes, the original features, and every ancestor
    # reached by following incoming edges back to an original feature.
    sel_node = set(original_ids)
    for start in cols.tolist():
        node = int(start)
        sel_node.add(node)
        while node not in original_ids:
            node_ind = np.argwhere(end_nodes == node)
            node = int(begin_nodes[node_ind[0][0]])
            sel_node.add(node)

    g_nx = pyg2nx(g_pyg)
    for i in sorted(set(range(n_features)) - sel_node):
        g_nx.remove_node(i)
    g_pyg = nx2pyg(g_nx)

    # Sort explicitly: set iteration order is not guaranteed to be ascending,
    # and the kept columns must stay in node-id order.
    sel_node = sorted(sel_node)
    sel_node.append(D_train.shape[1] - 1)  # target column goes last

    feature_names = D_train.columns
    feature_new = [feature_names[i] for i in sel_node]
    D_train = pd.DataFrame(D_train.values[:, sel_node])
    D_train.columns = feature_new
    return g_pyg, D_train


def complex_reward(g_pyg:gData, cluster, train_original:pd.DataFrame):
    """Return the mean of ``1 / exp(path length)`` over the nodes of ``cluster``.

    For each node the path length counts the node itself plus every step
    taken along incoming edges until a node among the first
    ``train_original.shape[1] - 1`` ids is reached.
    """
    original_ids = np.arange(train_original.shape[1]-1)
    sources = g_pyg.edge_index[0].numpy()
    targets = g_pyg.edge_index[1].numpy()

    total = 0
    for current in cluster:
        depth = 1  # the node itself
        while current not in original_ids:
            # Step to the source of the first edge pointing at this node.
            incoming = np.argwhere(targets == current)
            current = sources[incoming[0][0]]
            depth += 1
        total += (1 / np.exp(depth))

    return total / len(cluster)