import os
import json
import shutil
import tensorflow._api.v2.compat.v1 as tf
import numpy as np 
import copy

from utils.pipeline import Pipeline
from utils.oneline import OneLine
from utils import config

# Switch the tf.compat.v1 API to TensorFlow 1.x semantics (graph mode) so the
# TF1-style calls in this module, e.g. tf.get_variable in spectral_norm, behave
# as they did in TF1. This is a global side effect executed at import time.
tf.disable_v2_behavior()

def path_check(dic_path):
    """Ensure every directory listed in ``dic_path`` exists.

    Each value is resolved against the current working directory (an absolute
    value is used unchanged, per ``os.path.join``) and created together with
    any missing parent directories. Paths that already exist are left alone.

    Args:
        dic_path: mapping whose values are directory paths; keys are ignored.
    """
    for path in dic_path.values():
        full_path = os.path.join(os.getcwd(), path)
        if not os.path.exists(full_path):
            # exist_ok=True closes the check-then-create race: if another
            # process creates the directory between the exists() test and
            # this call, makedirs no longer raises FileExistsError.
            os.makedirs(full_path, exist_ok=True)
            
def copy_conf_file(dic_path, dic_agent_conf, dic_traffic_env_conf):
    """Dump the agent and traffic-environment configs into the work directory.

    Writes ``agent.conf`` and then ``traffic_env.conf`` as JSON (4-space
    indent) under ``dic_path["PATH_TO_WORK_DIRECTORY"]``, overwriting any
    existing files.

    Args:
        dic_path: mapping that must contain ``"PATH_TO_WORK_DIRECTORY"``.
        dic_agent_conf: JSON-serializable agent configuration.
        dic_traffic_env_conf: JSON-serializable environment configuration.
    """
    work_dir = dic_path["PATH_TO_WORK_DIRECTORY"]
    for file_name, conf in (("agent.conf", dic_agent_conf),
                            ("traffic_env.conf", dic_traffic_env_conf)):
        # The context manager flushes and closes the handle deterministically,
        # even if json.dump raises; previously the handles were left to the GC.
        with open(os.path.join(work_dir, file_name), "w") as f:
            json.dump(conf, f, indent=4)
    
def copy_cityflow_file(dic_path, dic_traffic_env_conf):
    """Copy the traffic and roadnet files from the data dir to the work dir.

    The file names come from ``dic_traffic_env_conf["TRAFFIC_FILE"]`` and
    ``dic_traffic_env_conf["ROADNET_FILE"]`` (copied in that order). Each file
    keeps its name and is copied from ``dic_path["PATH_TO_DATA"]`` into
    ``dic_path["PATH_TO_WORK_DIRECTORY"]``.
    """
    source_dir = dic_path["PATH_TO_DATA"]
    target_dir = dic_path["PATH_TO_WORK_DIRECTORY"]
    for conf_key in ("TRAFFIC_FILE", "ROADNET_FILE"):
        file_name = dic_traffic_env_conf[conf_key]
        shutil.copy(os.path.join(source_dir, file_name),
                    os.path.join(target_dir, file_name))

def spectral_norm(w, iteration=1):
    """Normalize ``w`` by a power-iteration estimate of its spectral norm.

    ``w`` is reshaped to 2-D as ``[-1, w_shape[-1]]`` (leading axes merged,
    last axis kept). The largest singular value ``sigma`` is estimated with
    ``iteration`` steps of power iteration, starting from a persistent,
    non-trainable variable ``u``. ``w / sigma`` is then reshaped back to the
    original static shape.

    NOTE: ``u`` is created with ``tf.get_variable("u", ...)`` in the current
    variable scope, so separate calls presumably need separate scopes (or
    reuse) — confirm against callers.

    Args:
        w: tensor whose static shape (``w.shape.as_list()``) is used both for
            the ``u`` variable's shape and for the final reshape, so it is
            expected to be fully defined.
        iteration: number of power-iteration steps; must be >= 1. One step is
            usually enough.

    Returns:
        ``(w_norm, u)``: the normalized tensor in ``w``'s original shape, and
        the ``u`` variable. Evaluating ``w_norm`` also runs the op that stores
        the updated estimate ``u_hat`` into ``u``.

    Raises:
        ValueError: if ``iteration < 1`` (no estimate would be produced).
    """
    # Validate before creating any graph variables; with zero steps v_hat
    # would stay None and tf.stop_gradient(None) would fail obscurely.
    if iteration < 1:
        raise ValueError(
            "spectral_norm requires iteration >= 1, got %r" % (iteration,))

    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]],
                        initializer=tf.random_normal_initializer(),
                        trainable=False)

    u_hat = u
    v_hat = None
    # Power iteration: alternately project through W^T and W, normalizing
    # each time. u_hat has shape [1, n] and v_hat has shape [1, m] for 2-D W [m, n].
    for _ in range(iteration):
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    # The singular-vector estimates are treated as constants for backprop.
    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    # sigma = v_hat . W . u_hat^T, a [1, 1] estimate of the top singular value.
    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    # Ops built inside this context depend on the assignment, so computing
    # w_norm also persists u_hat into u for the next call/step.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm, u

def merge(dic_tmp, dic_to_change):
    """Return ``dic_tmp`` overlaid with ``dic_to_change`` without mutating it.

    The template is deep-copied first, so nested values of ``dic_tmp`` are
    never shared with the result. Entries from ``dic_to_change`` then replace
    or add top-level keys; their values are inserted as-is, not copied.
    """
    overlaid = copy.deepcopy(dic_tmp)
    overlaid.update(dic_to_change)
    return overlaid


def pipeline_wrapper(dic_agent_conf, dic_traffic_env_conf, dic_path, offline=False):
    """Build a ``Pipeline`` from the given configs and run it.

    The configs and the ``offline`` flag are forwarded to ``Pipeline``
    unchanged. It is run with ``multi_process=False``, and a completion
    marker is printed afterwards. Returns ``None``.
    """
    runner = Pipeline(
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=dic_traffic_env_conf,
        dic_path=dic_path,
        offline=offline,
    )
    runner.run(multi_process=False)
    print("pipeline_wrapper end")


def oneline_wrapper(dic_agent_conf, dic_traffic_env_conf, dic_path):
    """Build a ``OneLine`` runner on top of the config defaults and train it.

    ``dic_traffic_env_conf`` and ``dic_path`` are overlaid onto
    ``config.dic_traffic_env_conf`` and ``config.DIC_PATH`` respectively via
    ``merge``, which deep-copies the defaults so they are not mutated.
    ``dic_agent_conf`` is passed through as-is. Returns ``None``.
    """
    full_env_conf = merge(config.dic_traffic_env_conf, dic_traffic_env_conf)
    full_path = merge(config.DIC_PATH, dic_path)
    runner = OneLine(dic_agent_conf=dic_agent_conf,
                     dic_traffic_env_conf=full_env_conf,
                     dic_path=full_path)
    runner.train()
