from envs.colour_grid_world import ColourGridWorld
from algorithms.q_learning import Q_Learning
import matplotlib.pyplot as plt 
from model_checking import dfa
from model_checking import lt
from model_checking import pctl
import numpy as np
from tools.plotting import draw_grid
import argparse
import sys
import importlib.util
import time
import tensorflow as tf
 
"""Q learning parameters"""
learning_rate = 0.1
discount_factor = 0.95
exploration_type = 'boltzmann'
exploration_parameter = 0.05

"""training parameters"""
property_path = "./properties/property_1.py"
steps = 100

"""environment parameters"""
random_action_probability = 0.05
episode_length = 1000

"""misc"""
seed = 0
logdir = "./logdir/"

def get_string_args():
    """Build a one-line, human-readable summary of the run configuration.

    The values come from the module-level settings (property path, learning
    rate, discount factor, exploration type and parameter). The property
    file is reported by the part of its path after the last '/'.
    """
    property_name = property_path.rsplit('/', 1)[-1]
    summary = "Property {}, Learning rate {}, Discount {}, Exploration {}, Expl Param {}"
    return summary.format(
        property_name,
        learning_rate,
        discount_factor,
        exploration_type,
        exploration_parameter,
    )

if __name__ == "__main__":
    # Script entry point: parse settings, build the environment, agent and
    # property cost function, then run `steps` training episodes, logging
    # per-episode reward and cost to TensorBoard and stdout.
    #
    # NOTE: the settings are deliberately rebound at module level (not inside
    # a function) so that get_string_args() reports the values in use.

    parser = argparse.ArgumentParser(
        description="Train a Q-learning agent on ColourGridWorld, logging "
                    "per-episode reward and cost.")

    # Q learning parameters
    parser.add_argument("--lr", type=float, default=learning_rate,
                        help="step size, passed to Q_Learning as alpha")
    parser.add_argument("--df", type=float, default=discount_factor,
                        help="discount factor")
    parser.add_argument("--exploration", type=str, default=exploration_type,
                        help="exploration strategy name passed to Q_Learning")
    parser.add_argument("--expl_parameter", type=float, default=exploration_parameter,
                        help="parameter of the exploration strategy")

    # training parameters
    parser.add_argument("--property", type=str, default=property_path,
                        help="path to a Python file defining `cost_function`")
    parser.add_argument("--steps", type=int, default=steps,
                        help="number of training episodes")

    # environment parameters
    parser.add_argument("--random_action_probability", type=float,
                        default=random_action_probability,
                        help="forwarded to ColourGridWorld")
    parser.add_argument("--episode_length", type=int, default=episode_length,
                        help="forwarded to ColourGridWorld")

    # misc
    parser.add_argument("--seed", type=int, default=seed,
                        help="seed for NumPy's global RNG")
    parser.add_argument("--logdir", type=str, default=logdir,
                        help="TensorBoard log directory")

    # argparse performs the type conversion and reports bad values with a
    # usage message, replacing the previous bare `except: raise TypeError`.
    args = parser.parse_args()

    # Q learning parameters
    learning_rate = args.lr
    discount_factor = args.df
    exploration_type = args.exploration
    exploration_parameter = args.expl_parameter

    # training parameters
    property_path = args.property
    steps = args.steps

    # environment parameters
    random_action_probability = args.random_action_probability
    episode_length = args.episode_length

    # misc
    seed = args.seed
    logdir = args.logdir

    string_args = get_string_args()
    print(string_args)

    # Only NumPy's global RNG is seeded here; TensorFlow and Python's
    # `random` module are not.
    np.random.seed(seed)

    # setup tensorboard
    summary_writer = tf.summary.create_file_writer(logdir)

    # setup environment
    env = ColourGridWorld(random_action_probability=random_action_probability,
                          episode_length=episode_length)
    n_states = env.n_states
    n_actions = env.n_actions

    # Load the cost function by executing the property file as a standalone
    # module named "property".
    spec = importlib.util.spec_from_file_location("property", property_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for paths it has no loader
        # for (e.g. a non-.py suffix); fail clearly instead of with an
        # AttributeError inside module_from_spec.
        raise ImportError("cannot load property file {!r}".format(property_path))
    properties = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(properties)

    # Expected to provide reset() and step(labels) -> (cost, automaton_state).
    cost_function = properties.cost_function

    # setup q learning agent
    agent = Q_Learning(n_states, n_actions,
                       alpha=learning_rate,
                       discount=discount_factor,
                       exploration=exploration_type,
                       expl_parameter=exploration_parameter)

    # training loop (`steps` counts episodes)
    with summary_writer.as_default():
        for episode in range(steps):
            episode_reward = 0.0
            episode_cost = 0.0
            episode_steps = 0

            state, info = env.reset()
            cost_function.reset()
            # Step the cost function once on the initial labels.
            # NOTE(review): the cost returned here is discarded and not added
            # to episode_cost; confirm this is intended.
            _initial_cost, _automaton_state = cost_function.step(info["labels"])

            start_time = time.perf_counter()

            terminated = False
            while not terminated:
                action = agent.step(state)
                next_state, reward, terminated, info = env.step(action)
                cost, _ = cost_function.step(info["labels"])

                episode_cost += cost
                episode_reward += reward
                episode_steps += 1

                agent.update((state, action, reward, next_state, terminated))
                state = next_state

            elapsed = time.perf_counter() - start_time
            # FPS uses the steps actually taken (the env may terminate before
            # episode_length steps) and guards against a zero elapsed time.
            fps = episode_steps / elapsed if elapsed > 0 else float("inf")

            tf.summary.scalar('reward', episode_reward, step=episode)
            tf.summary.scalar('cost', episode_cost, step=episode)
            template = "Episode {} completed in {} second. Reward {}, Cost {},  FPS {}"
            print(template.format(episode + 1,
                                  elapsed,
                                  episode_reward,
                                  episode_cost,
                                  fps))

    # Make sure buffered summaries are written before the process exits.
    summary_writer.flush()