import torch
import numpy as np
import pandas as pd
import os
import math
import time
import sys
import argparse
import datetime
from collections import defaultdict
from copy import deepcopy
from tools import *
from graph import *
from DQN import *
from logger import *
from operate import *
from task_mapping import task_dict, task_type, task_measure

# Absolute directory containing this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Make the script directory and the current working directory importable.
# NOTE(review): these entries are appended *after* the project `import *`
# lines above, so they only affect imports performed later at runtime.
sys.path.extend([BASE_DIR, './'])
# Fixed torch seed applied at import time.
torch.manual_seed(0)

def init_param():
    """Parse the experiment's command-line options from ``sys.argv``.

    Unrecognized arguments are tolerated (``parse_known_args``) and discarded.

    The three ``store_false`` switches -- ``--save_model``, ``--train_mode`` and
    ``--cluster`` -- default to ``True``; passing the flag sets it to ``False``.
    ``--operand_cluster`` is a normal ``store_true`` switch (default ``False``).

    Returns:
        argparse.Namespace: the parsed options. Dashed flag names are exposed
        with underscores (``--hidden-dim`` -> ``hidden_dim``,
        ``--log-level`` -> ``log_level``).
    """
    parser = argparse.ArgumentParser(description="PyTorch Experiment")

    option_table = (
        # data / task selection
        ('--file_name', dict(type=str, default='airfoil', help='data file name')),
        ('--task', dict(type=str, default='ng',
                        help='ng/cls/reg/det/rank, if provided ng, the model will take the task type in config')),
        ('--seed', dict(type=int, default=0, help='random_state')),
        ('--data_path', dict(type=str, default='./data', help='the data base path')),
        ('--hidden-dim', dict(type=int, default=32, help='the hidden dimension of graph encode network')),
        # store_false: defaults to True, the flag disables saving
        ('--save_model', dict(action="store_false", help='train or fine tune, save the model checkpoints')),
        ('--log-level', dict(type=str, default='info', help='log level')),
        # exploration budget
        ('--episodes', dict(type=int, default=3, help='episodes for training')),
        ('--steps', dict(type=int, default=10, help='steps for each episode')),
        ('--memory', dict(type=int, default=8, help='memory capacity')),
        # store_false: defaults to True, the flag switches to test mode
        ('--train_mode', dict(action="store_false",
                              help='if train_mode is False, ensure the models exist in model path')),
        # epsilon schedule
        ('--eps_start', dict(type=float, default=0.9, help='eps start')),
        ('--eps_end', dict(type=float, default=0.5, help='eps end')),
        ('--eps_decay', dict(type=int, default=100, help='eps decay')),
        # store_false: defaults to True, the flag disables group-wise clustering
        ('--cluster', dict(action="store_false", help='group wise')),
        # store_true: defaults to False
        ('--operand_cluster', dict(action='store_true', help='operand cluster or not')),
        ('--enlarge_num', dict(type=int, default=4, help='feature space enlarge')),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    known_args, _ignored = parser.parse_known_args()
    return known_args

def _singleton_clusters(num_features):
    """Return a ``defaultdict(list)`` placing each feature index ``i`` in its own cluster ``[i]``."""
    clusters = defaultdict(list)
    for i in range(num_features):
        clusters[i] = [i]
    return clusters


def _load_agent_weights(model_dir, dqn_head, dqn_tail, dqn_operation):
    """Load each agent's saved eval-net weights into both its eval and target nets.

    Args:
        model_dir: directory expected to hold ``head_eval_net.pth``,
            ``tail_eval_net.pth`` and ``operation_eval_net.pth``.
        dqn_head, dqn_tail, dqn_operation: agents exposing ``eval_net`` and
            ``target_net`` modules.

    Returns:
        True if all three checkpoints were loaded; False (after logging an
        error) as soon as one checkpoint file is missing.
    """
    agents = (('head', dqn_head), ('tail', dqn_tail), ('operation', dqn_operation))
    for agent_name, agent in agents:
        ckpt_path = os.path.join(model_dir, agent_name + '_eval_net.pth')
        if not os.path.exists(ckpt_path):
            error("The RL model doesn't exist.")
            return False
        # Only the eval net is saved; the target net is initialised from it too.
        state_dict = torch.load(ckpt_path)
        agent.eval_net.load_state_dict(state_dict)
        agent.target_net.load_state_dict(state_dict)
        info(f"Loading {agent_name} agent: eval net and target net.")
    return True


def _save_agent_weights(model_dir, dqn_head, dqn_tail, dqn_operation):
    """Save each agent's eval-net ``state_dict`` into ``model_dir`` (created if missing)."""
    os.makedirs(model_dir, exist_ok=True)
    agents = (('head', dqn_head), ('tail', dqn_tail), ('operation', dqn_operation))
    for agent_name, agent in agents:
        torch.save(agent.eval_net.state_dict(), os.path.join(model_dir, agent_name + '_eval_net.pth'))


def model_train(param):
    """Search for feature transformations with three cooperating DQN agents.

    Per step, the head agent selects a feature cluster, the operation agent
    selects an operation from ``operation_set`` and, for binary operations
    (``O2``), the tail agent selects the second operand(s). The step reward is
    the change in ``downstream_task_new`` performance; in train mode
    ``complex_reward`` is added and the agents store transitions and learn
    once their replay memories exceed ``MEMORY_CAPACITY``.

    The best feature set over all episodes is written as CSV under
    ``./results/<file_name>/`` and compared with the original data through
    ``test_task_new``. In test mode (``train_mode`` False) agent weights are
    first loaded from ``./model/<file_name>/``; in train mode with
    ``save_model`` they are saved there afterwards.

    Args:
        param: dict of options, as produced by ``vars(init_param())``.

    Raises:
        ValueError: if ``param['task']`` is neither 'ng' nor in ``task_type``,
            or if the resolved task has no test-metric report
            (only 'reg', 'cls' and 'det' are reported).
    """
    # --seed used to be parsed but ignored; seed the global RNGs from it so runs
    # are reproducible. (The module also calls torch.manual_seed(0) at import.)
    seed = param.get('seed', 0)
    torch.manual_seed(seed)
    np.random.seed(seed)

    data_path = os.path.join(param['data_path'], param['file_name'] + '.hdf')
    info('Read the data from {}'.format(data_path))
    data_key = 'train' if param['train_mode'] else 'test'
    train_data = pd.read_hdf(data_path, key=data_key)

    # 'ng' means: use the task type registered for this dataset in task_dict.
    if param['task'] == 'ng':
        task_name = task_dict[param['file_name']]
    else:
        if param['task'] not in task_type:
            raise ValueError('Unknown task {!r}; expected one of {}'.format(param['task'], task_type))
        task_name = param['task']
    info('The task is performing ' + task_name + ' on dataset ' + param['file_name'])
    measure = task_measure[task_name]
    info('The related measurement is ' + measure)
    info("Initialize the data and features...")
    D_train = preprocess(train_data)

    if param['file_name'] == 'ap_omentum_ovary':
        # Dataset-specific pre-filter: keep the k features with the highest
        # mutual_info_regression score against the target (last column).
        # NOTE(review): SelectKBest / mutual_info_regression / MinMaxScaler are
        # not imported in this file directly; presumably they come via `import *`.
        k = 100
        selector = SelectKBest(mutual_info_regression, k=k).fit(D_train.iloc[:, :-1], D_train.iloc[:, -1])
        cols = selector.get_support()
        X_new = D_train.iloc[:, :-1].loc[:, cols]
        D_train = pd.concat([X_new, D_train.iloc[:, -1]], axis=1)

    train_original = D_train.copy()
    init_train_per = downstream_task_new(train_original, task_name, state_num=0)
    info('The base performance of train data is: {}'.format(init_train_per))

    # RL-related hyperparameters
    EPISODES = param['episodes']
    STEPS = param['steps']
    HIDDEN_DIM = param['hidden_dim']
    MEMORY_CAPACITY = param['memory']
    EPS_START = param['eps_start']
    EPS_END = param['eps_end']
    EPS_DECAY = param['eps_decay']
    # Upper bound on the number of columns before pruning/rollback kicks in.
    FEATURE_LIMIT = train_original.shape[1] * param['enlarge_num']
    N_ACTIONS = len(operation_set)

    # initialize the feature-state transformation graph
    x = D_train.values[:, :-1]
    feature_names = D_train.columns
    g_pyg = get_graph(x, feature_names)
    FEA_DIM = g_pyg.x.shape[1]
    ACTION_DIM = g_pyg.x.shape[1]
    NODE_DIM = 64

    dqn_head = DQN1(feature_dim=FEA_DIM, hidden_dim=HIDDEN_DIM, output_dim=NODE_DIM, num_relation=N_ACTIONS, MEMORY_CAPACITY=MEMORY_CAPACITY)
    dqn_operation = DQN2(input_dim=NODE_DIM*2, num_rel=N_ACTIONS, MEMORY_CAPACITY=MEMORY_CAPACITY)
    dqn_tail = DQN3(num_rel=N_ACTIONS, op_dim=ACTION_DIM, input_dim=NODE_DIM*2, node_dim=NODE_DIM, MEMORY_CAPACITY=MEMORY_CAPACITY)

    # test mode: load the state dicts saved by a previous training run
    if not param['train_mode']:
        if not os.path.exists('./model'):
            error("The model path doesn't exist.")
            return
        model_dir = os.path.join('./model', param['file_name'])
        if not _load_agent_weights(model_dir, dqn_head, dqn_tail, dqn_operation):
            return

    episode = 0
    steps_done = 0
    best_train_per = init_train_per
    D_OPT = D_train.copy()

    info('Initialize the model hyperparameter configure.')
    info('Epsilon start with {}, end with {}, the decay is {}.'.format(EPS_START, EPS_END, EPS_DECAY))
    info('The training start...')
    training_start_time = time.time()

    while episode < EPISODES:
        # Every episode restarts from the original feature set.
        local_train_per = init_train_per
        D_train = train_original.copy()
        feature_names = D_train.columns
        old_train_per = init_train_per
        eps_start_time = time.time()
        step = 0
        g_pyg = get_graph(D_train.values[:, :-1], feature_names)
        D_LOCAL_OPT = D_train.copy()
        graph_local_best = deepcopy(g_pyg)

        while step < STEPS:
            feature_names = D_train.columns
            x = D_train.values[:, :-1]
            info(f"Current features are: {list(feature_names)}")
            step_start_time = time.time()
            steps_done += 1
            # Epsilon decays exponentially from EPS_START toward EPS_END as steps_done grows.
            eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * steps_done / EPS_DECAY)

            if param['cluster']:
                clusters = spectral_clustering(g_pyg)
            else:
                clusters = _singleton_clusters(x.shape[1])

            graph_emb, head_emb, f_cluster, f_names1 = select_meta_cluster1(g_pyg, clusters, feature_names, eps_threshold, dqn_head)

            op_index, op = select_operation(head_emb, dqn_operation, operation_set, eps_threshold)
            info(f"Current operation is: {op}.")

            if param['train_mode']:
                com_reward = complex_reward(g_pyg, f_cluster, train_original)
                dqn_head.store_graph(g_pyg)
                dqn_head.store_cluster(f_cluster)

            if op in O1:
                op_sign = justify_operation_type(op)
                f_generate, final_name, parent = unary_transform(D_train, op, op_sign, f_cluster, f_names1)
                g_pyg, D_train, feature_names = update_data(g_pyg, D_train, parent, op_index, f_generate, final_name)
                info(f"Current features are: {feature_names}")
            elif op in O2:
                op_func = justify_operation_type(op)
                if param['operand_cluster']:
                    act_cluster_ind, tail_emb, f_names2 = select_meta_cluster2(head_emb, graph_emb, op_index, clusters, feature_names, eps_threshold, dqn_tail)
                else:
                    act_cluster_ind, tail_emb, f_names2 = select_meta_node2(head_emb, graph_emb, op_index, feature_names, eps_threshold, dqn_tail)
                for ind, act_ind in enumerate(act_cluster_ind):
                    if param['operand_cluster']:
                        f_generate, final_name, parent = binary_transform(D_train, op, op_func, f_cluster, act_ind, f_names1, f_names2[ind])
                    else:
                        f_generate, final_name, parent = binary_transform(D_train, op, op_func, f_cluster, act_ind, f_names1, f_names2)

                    # Skip empty or non-finite (inf) results instead of adding them.
                    if len(f_generate) == 0:
                        info('The transformation is invalid.')
                        continue
                    if np.isinf(f_generate).any():
                        info('The transformation contains inf.')
                        continue
                    # Rescale large-magnitude outputs to [0, 1].
                    if np.max(f_generate) > 1000:
                        scaler = MinMaxScaler()
                        f_generate = scaler.fit_transform(f_generate)
                    g_pyg, D_train, feature_names = update_data(g_pyg, D_train, parent, op_index, f_generate, final_name)
                info(f"Current features are: {feature_names}")

            new_train_per = downstream_task_new(D_train, task_name, state_num=0)
            reward = new_train_per - old_train_per
            old_train_per = new_train_per

            if new_train_per > best_train_per:
                best_train_per = new_train_per
                D_OPT = D_train.copy()

            if new_train_per > local_train_per:
                local_train_per = new_train_per
                graph_local_best = deepcopy(g_pyg)
                D_LOCAL_OPT = D_train.copy()

            if param['train_mode']:
                # All three agents receive the same reward.
                r_c1 = r_op = r_c2 = reward + com_reward
                if param['cluster']:
                    clusters = spectral_clustering(g_pyg)
                else:
                    clusters = _singleton_clusters(len(feature_names) - 1)

                next_graph_emb, next_head_emb, next_f_cluster = generate_next_state_of_meta_cluster1(g_pyg, clusters, dqn_head)
                dqn_head.store_graph(g_pyg)
                dqn_head.store_cluster(next_f_cluster)
                next_op_ind, new_op = generate_next_state_of_meta_operation(next_head_emb, dqn_operation, operation_set)
                # The tail agent only has a transition when both the current and
                # the next operation are binary.
                if op in O2 and new_op in O2:
                    if param['operand_cluster']:
                        next_tail_emb = generate_next_state_of_meta_cluster2(next_head_emb, next_graph_emb, next_op_ind, clusters, dqn_tail)
                    else:
                        next_tail_emb = generate_next_state_of_meta_node2(next_head_emb, next_graph_emb, next_op_ind, dqn_tail)
                    dqn_tail.store_transition(head_emb.detach().numpy(), op_index, tail_emb.detach().numpy(), r_c2,
                                              next_head_emb.detach().numpy(), next_op_ind, next_tail_emb.detach().numpy())

                dqn_head.store_transition(r_c1)
                dqn_operation.store_transition(head_emb.detach().numpy(), op_index, r_op, next_head_emb.detach().numpy())
                if dqn_head.neg_memory_counter > dqn_head.MEMORY_CAPACITY and dqn_head.pos_memory_counter > dqn_head.MEMORY_CAPACITY:
                    dqn_head.learn()
                if dqn_operation.neg_memory_counter > dqn_operation.MEMORY_CAPACITY and dqn_operation.pos_memory_counter > dqn_operation.MEMORY_CAPACITY:
                    dqn_operation.learn()
                if dqn_tail.neg_memory_counter > dqn_tail.MEMORY_CAPACITY and dqn_tail.pos_memory_counter > dqn_tail.MEMORY_CAPACITY:
                    dqn_tail.learn()

            # Too many features and this step is not the episode best: prune early
            # in training, otherwise roll back to the episode's best state.
            if D_train.shape[1] > FEATURE_LIMIT and new_train_per != local_train_per:
                if episode < EPISODES * 0.3:
                    g_pyg, D_train = prune(g_pyg, D_train, train_original, FEATURE_LIMIT)
                else:
                    # Copy rather than alias, in case later updates (update_data /
                    # prune, not visible here) modify their inputs in place and
                    # would otherwise corrupt the saved episode-best snapshot.
                    D_train = D_LOCAL_OPT.copy()
                    g_pyg = deepcopy(graph_local_best)

            info('New train performance is: {:.6f}  Best train performance is: {:.6f}  Base train performance is: {:.6f}'.
                 format(new_train_per, best_train_per, init_train_per))
            info('Episode {}, Step {} ends!'.format(episode, step))
            info('Current spend time for step-{} is: {:.1f}s'.format(step, time.time() - step_start_time))
            step += 1
        # Log before incrementing so the index matches the per-step logs above.
        info('Current spend time for episode-{} is: {:.1f}s'.format(episode, time.time() - eps_start_time))
        episode += 1

    info('Total spend time for is: {:.1f}s'.format(time.time() - training_start_time))
    info('Exploration ends!')
    info('Begin evaluation...')

    # train mode and save model
    if param['train_mode'] and param['save_model']:
        _save_agent_weights(os.path.join('./model', param['file_name']), dqn_head, dqn_tail, dqn_operation)

    # store the best feature set found
    OPT_PATH = os.path.join('./results', param['file_name'])
    os.makedirs(OPT_PATH, exist_ok=True)
    formatted_date = datetime.datetime.now().strftime("%m%d_%H%M")
    D_OPT.to_csv(os.path.join(OPT_PATH, f"episode_{param['episodes']}_step{param['steps']}_per{best_train_per:.5f}_{formatted_date}.csv"))

    if task_name == 'reg':
        mae0, rmse0, rae0 = test_task_new(train_original, task=task_name, state_num=0)
        mae1, rmse1, rae1 = test_task_new(D_OPT, task=task_name, state_num=0)
        info('MAE on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(mae0, mae1))
        info('RMSE on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(rmse0, rmse1))
        info('1 - RAE on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(1-rae0, 1-rae1))
    elif task_name == 'cls':
        acc0, precision0, recall0, f1_0 = test_task_new(train_original, task=task_name, state_num=0)
        acc1, precision1, recall1, f1_1 = test_task_new(D_OPT, task=task_name, state_num=0)
        info('ACC on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(acc0, acc1))
        info('Precision on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(precision0, precision1))
        info('Recall on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(recall0, recall1))
        info('F1 Score on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(f1_0, f1_1))
    elif task_name == 'det':
        map0, f1_0, ras0 = test_task_new(train_original, task=task_name, state_num=0)
        map1, f1_1, ras1 = test_task_new(D_OPT, task=task_name, state_num=0)
        info('Average Precision on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(map0, map1))
        info('F1 Score on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(f1_0, f1_1))
        info('ROC AUC Score on:\n \
             original data is {:.6f}, best train data is {:.6f}.'.format(ras0, ras1))
    else:
        error('wrong task name!!!!!')
        # `assert False` would be stripped under `python -O`; raise explicitly.
        raise ValueError('No test-metric report for task {!r}'.format(task_name))
    info('Total using time: {:.1f}s'.format(time.time() - training_start_time))


if __name__ == "__main__":
    # Script entry point: parse options, set up the run log, run the experiment.
    # Any exception is reported through the project `error` logger and then
    # re-raised so the process still exits with the traceback.
    try:
        # NOTE(review): `params` is a module-level global here; other code in this
        # module may read it by that name, so keep it.
        params = vars(init_param())
        if not os.path.exists('./log'):
            os.mkdir('./log')
        run_stamp = datetime.datetime.now().strftime("%m%d_%H%M")
        log_name = f"./log/{params['file_name']}_episode{params['episodes']}_steps{params['steps']}_{run_stamp}"
        # NOTE(review): `log` is never used below; presumably the Logger must stay
        # referenced for the whole run — confirm before removing the binding.
        log = Logger(params, log_name)
        model_train(params)
    except Exception as exception:
        error(exception)
        raise