#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Python codes for 'Causal Discovery with Reinforcement Learning', ICLR 2020 (oral)
Authors: Shengyu Zhu, Huawei Noah's Ark Lab,
         Ignavier Ng, University of Toronto (work was done during an internship at Huawei Noah's Ark Lab)
         Zhitang Chen, Huawei Noah's Ark Lab
"""

import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
import logging
import platform
import random
import numpy as np
import pandas as pd
from pytz import timezone
from datetime import datetime
import matplotlib.pyplot as plt
import tensorflow as tf

from data_loader import DataGenerator_read_data
from models import Actor
from rewards import get_Reward_GS, get_Reward_MDS
from helpers.config_graph import get_config, print_config
from helpers.dir_utils import create_dir
from helpers.log_helper import LogHelper
from helpers.tf_utils import set_seed
from helpers.analyze_utils import convert_graph_int_to_adj_mat, graph_prunned_by_coef, \
                                  count_accuracy, graph_prunned_by_coef_2nd
from helpers.cam_with_pruning_cam import pruning_cam
from helpers.lambda_utils import BIC_lambdas
from helpers.other_utils import keep_skeleton, keep_skeleton_with_index

# Configure matplotlib for plotting
import matplotlib
matplotlib.use('Agg')


def _run_training_ops(sess, actor, feed):
    """Run one training step of the actor.

    The fetch list is kept identical to the one the training loops have
    always used, so exactly the same ops (including ``base_op``,
    ``train_step1`` and ``train_step2``) are executed. The fetched values
    are not needed by the caller and are discarded.
    """
    sess.run(
        [actor.decoder.adj_prob, actor.merged, actor.base_op,
         actor.test_scores, actor.log_softmax, actor.graph_batch,
         actor.reward_batch, actor.avg_baseline, actor.train_step1,
         actor.train_step2],
        feed_dict=feed)


def _update_lambdas(callreward, lambda1, lambda1_upper, lambda1_update_add,
                    lambda2, lambda2_upper, lambda2_update_mul):
    """Re-score recorded graphs and advance the penalty weights.

    Returns:
        (graph_int, lambda1, lambda1_upper, lambda2): the integer encoding of
        the first entry returned by ``update_all_scores`` together with the
        updated penalty weights.
    """
    ls_kv = callreward.update_all_scores(lambda1, lambda2)
    best = ls_kv[0]
    graph_int, score_min, cyc_min = np.int32(best[0]), best[1][1], best[1][-1]
    # When the last score component is (numerically) zero, the score of that
    # graph becomes the new cap for lambda1.
    if cyc_min < 1e-5:
        lambda1_upper = score_min
    lambda1 = min(lambda1 + lambda1_update_add, lambda1_upper)
    lambda2 = min(lambda2 * lambda2_update_mul, lambda2_upper)
    return graph_int, lambda1, lambda1_upper, lambda2


def _prune_without_last_node(graph_batch, inputdata, prune_setting, num_vars):
    """CAM-prune the sub-graph that excludes the last row/column.

    The R codes of CAM pruning operate on the graph form where (i,j)=1
    indicates i-th node -> j-th node, so the input graph is transposed before
    the call and the result is transposed back.
    """
    sub_graph = np.array(graph_batch[:-1, :-1]).T
    return np.transpose(pruning_cam(inputdata, sub_graph, prune_setting, num_vars))


def main():
    """Run the two training stages and save the resulting graphs.

    Stage 1 trains the actor with the ``get_Reward_GS`` reward and, every
    ``config.lambda_iter_num`` epochs, prunes the current best graph and saves
    it to ``config.out_path``. Stage 2 continues training with the
    ``get_Reward_MDS`` reward on graphs masked (via
    ``keep_skeleton_with_index``) by the last pruned graph from stage 1.

    Raises:
        ValueError: if ``config.read_data`` is false, or if
            ``config.lambda_flag_default`` is false (no other penalty-weight
            setup is implemented here).
        RuntimeError: if stage 1 finishes without ever producing a pruned
            graph, which stage 2 needs as its mask.
    """
    config, _ = get_config()
    # Reproducibility
    set_seed(config.seed)

    if not config.read_data:
        raise ValueError("Only support importing data from existing files")

    data_index_path = '{}/data_index.npy'.format(config.data_path)
    data_path = '{}/data.npy'.format(config.data_path)
    dag_index_path = '{}/dag_index.npy'.format(config.data_path)
    dag_path = '{}/dag.npy'.format(config.data_path)
    training_set = DataGenerator_read_data(data_index_path, data_path, dag_index_path, dag_path,
                                           config.normalize, config.transpose)
    config.max_length = training_set.inputdata_index.shape[1]
    # Pruning and evaluation work on all nodes except the last one.
    # NOTE(review): the last node looks like an extra index variable appended
    # to the observed ones -- confirm against data_loader.
    num_vars = config.max_length - 1

    # set penalty weights
    score_type = config.score_type
    reg_type = config.reg_type

    if not config.lambda_flag_default:
        # Fail early with a clear message instead of a NameError further down.
        raise ValueError("Only the default penalty weights (lambda_flag_default) are supported")

    sl, su, _strue = BIC_lambdas(training_set.inputdata_index, config, None, None,
                                 training_set.dag_index_true.T, reg_type, score_type)

    lambda1 = 0
    lambda1_upper = 5
    lambda1_update_add = 1
    lambda2 = 1/(10**(np.round(config.max_length/3)))
    lambda2_upper = 0.01
    lambda2_update_mul = 10
    lambda_iter_num = config.lambda_iter_num

    # actor
    actor = Actor(config)

    callreward_GS = get_Reward_GS(config, actor.batch_size, actor.input_dimension, training_set.inputdata_index,
                                  sl, su, lambda1_upper, score_type, reg_type, config.l1_graph_reg, False)
    callreward_MDS = get_Reward_MDS(config, actor.batch_size, actor.input_dimension, training_set.inputdata_index,
                                    sl, su, lambda1_upper, score_type, reg_type, config.l1_graph_reg, False)
    sess_config = tf.ConfigProto(log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # Run initialize op
        sess.run(tf.global_variables_initializer())

        # ---- Stage 1 ----
        aug_dag = None
        for i in range(1, config.stage1_epoch + 1):
            input_batch = training_set.train_batch(actor.batch_size, actor.input_dimension, with_index=True)
            graphs_feed = sess.run(actor.graphs, feed_dict={actor.input_: input_batch})
            reward_feed = callreward_GS.cal_rewards(graphs_feed, lambda1, lambda2)
            feed = {actor.input_: input_batch, actor.reward_: -reward_feed[:, 0], actor.graphs_: graphs_feed}
            _run_training_ops(sess, actor, feed)

            # update lambda1, lambda2
            if i % lambda_iter_num != 0:
                continue
            graph_int, lambda1, lambda1_upper, lambda2 = _update_lambdas(
                callreward_GS, lambda1, lambda1_upper, lambda1_update_add,
                lambda2, lambda2_upper, lambda2_update_mul)

            graph_batch = convert_graph_int_to_adj_mat(graph_int)
            pruned_core = _prune_without_last_node(graph_batch, training_set.inputdata,
                                                   config.stage1_prune, num_vars)
            # Re-attach the last node: its column is copied unpruned from the
            # original graph, its row is left at zero.
            graph_batch_pruned = np.zeros(graph_batch.shape)
            graph_batch_pruned[:-1, :-1] = pruned_core
            graph_batch_pruned[:, -1] = graph_batch[:, -1]

            # estimate accuracy
            acc_est2 = count_accuracy(training_set.dag_true, graph_batch_pruned[:-1, :-1].T)
            print("epoch: {}, acc:".format(i), acc_est2)
            print("graph_before_pruned: ", graph_batch)
            print("graph_pruned: ", graph_batch_pruned)
            np.save('{}/stage1_before_pruned.npy'.format(config.out_path), graph_batch)
            np.save('{}/stage1_pruned.npy'.format(config.out_path), graph_batch_pruned)
            aug_dag = graph_batch_pruned

        if aug_dag is None:
            raise RuntimeError(
                "Stage 1 produced no pruned graph: stage1_epoch ({}) must be at least "
                "lambda_iter_num ({})".format(config.stage1_epoch, lambda_iter_num))

        # ---- Stage 2 ----
        lambda1 = 0
        lambda1_upper = 6
        lambda1_update_add = 3
        lambda2 = 1 / (10 ** (np.round(config.max_length / 3)))
        lambda2_upper = 0.01
        lambda2_update_mul = 100

        for i in range(1, config.stage2_epoch + 1):
            input_batch = training_set.train_batch(actor.batch_size, actor.input_dimension, with_index=True)
            graphs_feed = sess.run(actor.graphs, feed_dict={actor.input_: input_batch})
            graphs_feed_masked, nons_vars = keep_skeleton_with_index(graphs_feed, aug_dag)
            reward_feed = callreward_MDS.cal_rewards(graphs_feed_masked, nons_vars, lambda1, lambda2, config.lambda3)
            # The actor is fed the unmasked graphs; only the reward sees the mask.
            feed = {actor.input_: input_batch, actor.reward_: -reward_feed[:, 0], actor.graphs_: graphs_feed}
            _run_training_ops(sess, actor, feed)

            # update lambda1, lambda2
            # NOTE(review): stage 2 triggers on (i + 1) whereas stage 1 triggers
            # on i; kept as is -- confirm which schedule is intended.
            if (i + 1) % lambda_iter_num != 0:
                continue
            graph_int, lambda1, lambda1_upper, lambda2 = _update_lambdas(
                callreward_MDS, lambda1, lambda1_upper, lambda1_update_add,
                lambda2, lambda2_upper, lambda2_update_mul)

            graph_batch = convert_graph_int_to_adj_mat(graph_int)
            # The pruned result already excludes the last node, so it is used
            # as is. Cropping it a second time would drop one more variable and
            # no longer line up with the stage-1 evaluation or with the
            # "before pruned" matrix saved below.
            graph_batch_pruned = _prune_without_last_node(graph_batch, training_set.inputdata,
                                                          config.stage2_prune, num_vars)

            # estimate accuracy
            acc_est2 = count_accuracy(training_set.dag_true, graph_batch_pruned.T)
            print("epoch: {}, acc:".format(i), acc_est2)
            print("graph: ", graph_batch_pruned)
            np.save('{}/stage2_before_pruned.npy'.format(config.out_path), graph_batch[:-1, :-1])
            np.save('{}/stage2_pruned.npy'.format(config.out_path), graph_batch_pruned)




# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
