#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Python codes for 'Causal Discovery with Reinforcement Learning', ICLR 2020 (oral)
Authors: Shengyu Zhu, Huawei Noah's Ark Lab,
         Ignavier Ng, University of Toronto (work was done during an internship at Huawei Noah's Ark Lab)
         Zhitang Chen, Huawei Noah's Ark Lab
"""

import os
import logging
import platform
import random
import numpy as np
import pandas as pd
from pytz import timezone
from datetime import datetime
import matplotlib.pyplot as plt
import tensorflow as tf

from data_loader import DataGenerator_read_data
from models import Actor
from rewards import get_Reward_BIC, get_Reward_MDS
from helpers.config_graph import get_config, print_config
from helpers.dir_utils import create_dir
from helpers.log_helper import LogHelper
from helpers.tf_utils import set_seed
from helpers.analyze_utils import convert_graph_int_to_adj_mat, graph_prunned_by_coef, \
                                  count_accuracy, graph_prunned_by_coef_2nd
# from helpers.cam_with_pruning_cam import pruning_cam
from helpers.lambda_utils import BIC_lambdas, BIC_MDS_lambdas
from helpers.other_utils import keep_skeleton
# Configure matplotlib for plotting
# Select the 'Agg' backend so the script does not need an interactive display.
# NOTE(review): matplotlib.pyplot is already imported earlier in this file;
# some matplotlib versions require use() to be called before pyplot is
# imported for the backend change to take effect -- confirm against the
# installed version.
import matplotlib
matplotlib.use('Agg')


def _prune_graph(graph_batch, data, reg_type, qr_threshold=None):
    """Prune an estimated graph with the routine matching ``reg_type``.

    Args:
        graph_batch: adjacency matrix produced by ``convert_graph_int_to_adj_mat``.
        data: observations handed to the pruning routine.
        reg_type: regression type; one of 'LR', 'QR' or 'GPR'.
        qr_threshold: optional ``th`` forwarded to the 'QR' pruning routine.
            When None, that routine's own default threshold is used.

    Returns:
        The pruned graph as a numpy array.

    Raises:
        ValueError: if ``reg_type`` is not one of the supported values.
    """
    if reg_type == 'LR':
        return np.array(graph_prunned_by_coef(graph_batch, data, th=0.7))
    if reg_type == 'QR':
        if qr_threshold is None:
            return np.array(graph_prunned_by_coef_2nd(graph_batch, data))
        return np.array(graph_prunned_by_coef_2nd(graph_batch, data, th=qr_threshold))
    if reg_type == 'GPR':
        # The module-level import of pruning_cam is commented out in this
        # file, so the name has to be brought into scope here. Importing it
        # lazily keeps the 'LR' and 'QR' paths independent of that module.
        from helpers.cam_with_pruning_cam import pruning_cam
        # The R code of CAM pruning operates on the graph form where (i,j)=1
        # indicates i-th node -> j-th node, so we transpose the input graph
        # and transpose the output graph back.
        return np.transpose(pruning_cam(data, np.array(graph_batch).T))
    raise ValueError("Unsupported reg_type for graph pruning: {}".format(reg_type))


def _run_training_ops(sess, actor, input_batch, graphs_feed, reward_feed):
    """Feed one scored batch to the actor and run its summary and training ops.

    Args:
        sess: active TensorFlow session.
        actor: the ``Actor`` whose placeholders and ops are used.
        input_batch: batch fed to ``actor.input_``.
        graphs_feed: sampled graphs fed to ``actor.graphs_``.
        reward_feed: rewards from the reward calculator; the negated first
            column is fed to ``actor.reward_``.
    """
    feed = {
        actor.input_: input_batch,
        actor.reward_: -reward_feed[:, 0],
        actor.graphs_: graphs_feed,
    }
    # The fetched values are not used by the caller; the ops are run for
    # their effect on the model (train_step1 / train_step2).
    sess.run([actor.merged, actor.base_op, actor.test_scores, actor.log_softmax,
              actor.graph_batch, actor.reward_batch, actor.avg_baseline,
              actor.train_step1, actor.train_step2], feed_dict=feed)


def main():
    """Run the two-stage graph search and save the final pruned graph.

    Stage 1 scores graphs sampled by the actor with the BIC reward. Stage 2
    masks the sampled graphs with ``keep_skeleton`` using the pruned graph
    from stage 1 and scores them with the MDS reward. Every
    ``config.lambda_iter_num`` iterations the penalty weights are updated and
    the current best graph is pruned and evaluated against the true graph.
    The last pruned graph is written to ``config.out_path`` with ``np.save``.

    Raises:
        ValueError: if ``config.read_data`` or ``config.lambda_flag_default``
            is false, or if ``config.reg_type`` has no pruning routine.
        RuntimeError: if stage 1 ends before any pruned graph was produced.
    """
    # Get running configuration
    config, _ = get_config()
    # Reproducibility
    set_seed(config.seed)

    if not config.read_data:
        raise ValueError("Only support importing data from existing files")
    data_index_path = '{}/data_index.npy'.format(config.data_path)
    data_path = '{}/data.npy'.format(config.data_path)
    dag_index_path = '{}/dag_index.npy'.format(config.data_path)
    dag_path = '{}/dag.npy'.format(config.data_path)
    training_set = DataGenerator_read_data(data_index_path, data_path, dag_index_path, dag_path,
                                           config.normalize, config.transpose)
    config.max_length = training_set.inputdata_index.shape[1]

    score_type = config.score_type
    reg_type = config.reg_type

    # Everything below depends on the values derived in this branch; without
    # it the script previously failed later with a NameError.
    if not config.lambda_flag_default:
        raise ValueError("Only the default lambda settings (lambda_flag_default) are supported")
    sl_index, su_index, _ = BIC_lambdas(training_set.inputdata_index, None, None,
                                        training_set.dag_index_true.T, reg_type, score_type)
    sl, su, _ = BIC_MDS_lambdas(training_set.inputdata, None, None, training_set.dag_true.T,
                                reg_type, score_type, config.n_domains)
    lambda1 = 0
    lambda1_upper = 5
    lambda1_update_add = 1
    lambda2 = 1/(10**(np.round(config.max_length/3)))
    lambda2_upper = 0.01
    lambda2_update_mul = 10
    lambda_iter_num = config.lambda_iter_num

    # actor
    actor = Actor(config)
    callreward_BIC = get_Reward_BIC(actor.batch_size, config.max_length, actor.input_dimension,
                                    training_set.inputdata_index, sl_index, su_index, lambda1_upper,
                                    score_type, reg_type, config.l1_graph_reg, False)
    # NOTE(review): sl/su are derived from training_set.inputdata and stage 2
    # prunes against training_set.inputdata, yet this reward is built with
    # training_set.inputdata_index -- confirm this is intended.
    callreward_MDS = get_Reward_MDS(actor.batch_size, config.max_length-1, actor.input_dimension,
                                    training_set.inputdata_index, sl, su, lambda1_upper,
                                    score_type, reg_type, config.l1_graph_reg, False, config.n_domains)

    sess_config = tf.ConfigProto(log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # Run initialize op
        sess.run(tf.global_variables_initializer())

        # ---- Stage 1: BIC reward ----
        graph_batch_pruned = None
        for i in range(1, config.stage1_epoch + 1):
            input_batch = training_set.train_batch(actor.batch_size, actor.input_dimension, with_index=True)
            graphs_feed = sess.run(actor.graphs, feed_dict={actor.input_: input_batch})
            reward_feed = callreward_BIC.cal_rewards(graphs_feed, lambda1, lambda2)
            _run_training_ops(sess, actor, input_batch, graphs_feed, reward_feed)
            # update lambda1, lambda2
            if (i+1) % lambda_iter_num == 0:
                ls_kv = callreward_BIC.update_all_scores(lambda1, lambda2)
                # The first entry is taken as the current best graph
                # (presumably the list is sorted best-first -- confirm).
                graph_int, score_min, cyc_min = np.int32(ls_kv[0][0]), ls_kv[0][1][1], ls_kv[0][1][-1]
                if cyc_min < 1e-5:
                    lambda1_upper = score_min
                lambda1 = min(lambda1+lambda1_update_add, lambda1_upper)
                lambda2 = min(lambda2*lambda2_update_mul, lambda2_upper)
                graph_batch = convert_graph_int_to_adj_mat(graph_int)
                graph_batch_pruned = _prune_graph(graph_batch, training_set.inputdata_index, reg_type)
                # estimate accuracy
                acc_est2 = count_accuracy(training_set.dag_index_true, graph_batch_pruned.T)
                print("epoch: {}, acc:".format(i), acc_est2)
                print("graph: ", graph_batch_pruned)

        # Stage 2 needs the pruned graph from stage 1 as its mask.
        if graph_batch_pruned is None:
            raise RuntimeError(
                "Stage 1 finished without producing a pruned graph: stage1_epoch ({}) is too "
                "small for lambda_iter_num ({})".format(config.stage1_epoch, lambda_iter_num))
        aug_dag = graph_batch_pruned

        # ---- Stage 2: MDS reward ----
        # lambda1, lambda1_upper and lambda2 are not reset here; they carry
        # over from stage 1.
        for i in range(1, config.stage2_epoch + 1):
            input_batch = training_set.train_batch(actor.batch_size, actor.input_dimension, with_index=True)
            graphs_feed = sess.run(actor.graphs, feed_dict={actor.input_: input_batch})
            graphs_feed_masked, nons_vars = keep_skeleton(graphs_feed, aug_dag)
            reward_feed = callreward_MDS.cal_rewards(graphs_feed_masked, nons_vars, lambda1, lambda2,
                                                     config.lambda3)
            # The unmasked graphs are what is fed back to the actor.
            _run_training_ops(sess, actor, input_batch, graphs_feed, reward_feed)
            # update lambda1, lambda2
            if (i+1) % lambda_iter_num == 0:
                ls_kv = callreward_MDS.update_all_scores(lambda1, lambda2)
                graph_int, score_min, cyc_min = np.int32(ls_kv[0][0]), ls_kv[0][1][1], ls_kv[0][1][-1]
                if cyc_min < 1e-5:
                    lambda1_upper = score_min
                lambda1 = min(lambda1+lambda1_update_add, lambda1_upper)
                lambda2 = min(lambda2*lambda2_update_mul, lambda2_upper)
                graph_batch = convert_graph_int_to_adj_mat(graph_int)
                graph_batch_pruned = _prune_graph(graph_batch, training_set.inputdata, reg_type,
                                                  qr_threshold=0.1)
                # estimate accuracy
                acc_est2 = count_accuracy(training_set.dag_true, graph_batch_pruned.T)
                print("epoch: {}, acc:".format(i), acc_est2)
                print("graph: ", graph_batch_pruned)

        # If stage 2 never reached a lambda update, this is still the
        # stage-1 pruned graph.
        np.save(config.out_path, graph_batch_pruned)


# Script entry point: start the run only when this file is executed directly.
if __name__ == "__main__":
    main()
