from envs.colour_grid_world import ColourGridWorld
from envs.bridge_crossing import BridgeCrossing
from envs.media_streaming import MediaStreaming
from envs.colour_bomb_grid_world import ColourBombGridWorld
from envs.colour_bomb_grid_world_v2 import ColourBombGridWorldV2
from envs.pacman import Pacman
from algorithms.q_learning import Q_Learning
import matplotlib.pyplot as plt 
from model_checking import dfa
from model_checking import pctl
from model_checking.helpers import construct_product_state_space
import numpy as np
import argparse
import sys
import importlib.util
import time
import tensorflow as tf
from common.logger import Logger
 
# ---------------------------------------------------------------------------
# Default run settings. Every value below can be overridden from the command
# line; the __main__ block rebinds these module-level names after parsing so
# that get_string_args() reports the values actually used.
# ---------------------------------------------------------------------------

# Q learning arguments (passed to Q_Learning as alpha / discount /
# exploration / expl_parameter)
learning_rate = 0.1
discount_factor = 0.95
exploration_type = 'boltzmann'
exploration_parameter = 0.05

# training arguments
property_path = "./properties/colour_grid_world/property_1.py"  # module must define `automaton`
num_frames = 100_000                 # total environment steps in the training loop
counter_factual_experiences = True   # disabled with --no-cf
cost_coefficient = 10.0              # scales the automaton cost into a reward penalty
reward_shaping = "none"              # forwarded to dfa.Cost_Function

# environment arguments
env_id = "colour_grid_world"
random_action_probability = 0.05     # only used by the grid-world / bridge envs
episode_length = 1000

# logger arguments
log_every = 1000
logger_window_size = 100             # Logger stats_window_size

# misc
seed = 0
logdir = "./logdir/"                 # tensorboard summary directory

def get_string_args():
    """Build a one-line, human-readable summary of the current run settings.

    Reads the module-level settings (env_id, property_path,
    counter_factual_experiences, cost_coefficient, reward_shaping). Only the
    final '/'-separated component of ``property_path`` is reported.
    """
    # Last path component; identical to property_path.split('/')[-1].
    property_name = property_path.rpartition('/')[2]
    summary_template = ("Env {}, Property {}, Counter Factual Experiences {}, "
                        "Cost Coefficient {}, Reward Shaping {}")
    return summary_template.format(
        env_id,
        property_name,
        counter_factual_experiences,
        cost_coefficient,
        reward_shaping,
    )

if __name__ == "__main__":
    # Script entry point: parse the command line, build the environment and
    # the property's cost function / automaton, then train a tabular
    # Q-learning agent on the (automaton state, env state) product space.

    parser = argparse.ArgumentParser()

    # Q learning arguments.
    # Declaring `type=` lets argparse reject malformed values with a usage
    # message, instead of a post-hoc conversion wrapped in a bare except that
    # discarded the real error.
    parser.add_argument("--lr", type=float, default=learning_rate)
    parser.add_argument("--df", type=float, default=discount_factor)
    parser.add_argument("--exploration", type=str, default=exploration_type)
    parser.add_argument("--expl-parameter", type=float, default=exploration_parameter)

    # training arguments
    parser.add_argument("--property", type=str, default=property_path)
    parser.add_argument("--num-frames", type=int, default=num_frames)
    # Passing --no-cf turns counterfactual experience generation off.
    parser.add_argument("--no-cf", action='store_false', default=counter_factual_experiences)
    parser.add_argument("--cost-coeff", type=float, default=cost_coefficient)
    parser.add_argument("--reward-shaping", type=str, default=reward_shaping)

    # environment arguments
    env_choices = ("colour_grid_world", "bridge_crossing", "colour_bomb_grid_world",
                   "colour_bomb_grid_world_v2", "media_streaming", "pacman")
    parser.add_argument("--env", type=str, default=env_id, choices=env_choices)
    parser.add_argument("--random-action-probability", type=float, default=random_action_probability)
    parser.add_argument("--episode-length", type=int, default=episode_length)

    # logger arguments
    parser.add_argument("--log-every", type=int, default=log_every)
    parser.add_argument("--logger-window-size", type=int, default=logger_window_size)

    # misc
    parser.add_argument("--seed", type=int, default=seed)
    parser.add_argument("--logdir", type=str, default=logdir)
    args = parser.parse_args()

    # Rebind the module-level settings (this block runs at module scope), so
    # get_string_args() reports the values actually in use.
    learning_rate = args.lr
    discount_factor = args.df
    exploration_type = args.exploration
    exploration_parameter = args.expl_parameter

    property_path = args.property
    num_frames = args.num_frames
    counter_factual_experiences = bool(args.no_cf)
    cost_coefficient = args.cost_coeff
    reward_shaping = args.reward_shaping

    env_id = args.env
    random_action_probability = args.random_action_probability
    episode_length = args.episode_length

    log_every = args.log_every
    logger_window_size = args.logger_window_size

    seed = args.seed
    logdir = args.logdir

    string_args = get_string_args()
    print(string_args)

    np.random.seed(seed)

    # setup tensorboard
    summary_writer = tf.summary.create_file_writer(logdir)
    logger = Logger(log_every, tensorboard=True, summary_writer=summary_writer,
                    stats_window_size=logger_window_size, prefix='rollout')

    # Set up the environment. Only the grid-world / bridge environments
    # accept random_action_probability.
    if env_id == "colour_grid_world":
        env = ColourGridWorld(seed=seed, random_action_probability=random_action_probability, episode_length=episode_length)
    elif env_id == "bridge_crossing":
        env = BridgeCrossing(seed=seed, random_action_probability=random_action_probability, episode_length=episode_length)
    elif env_id == "colour_bomb_grid_world":
        env = ColourBombGridWorld(seed=seed, random_action_probability=random_action_probability, episode_length=episode_length)
    elif env_id == "colour_bomb_grid_world_v2":
        env = ColourBombGridWorldV2(seed=seed, random_action_probability=random_action_probability, episode_length=episode_length)
    elif env_id == "media_streaming":
        env = MediaStreaming(seed=seed, episode_length=episode_length)
    elif env_id == "pacman":
        env = Pacman(seed=seed, episode_length=episode_length)
    else:
        # Without this branch an unknown id surfaced later as NameError on `env`.
        raise ValueError("Unknown env id {!r}; expected one of {}".format(env_id, env_choices))
    n_states = env.n_states
    n_actions = env.n_actions

    # Load the property file as a module; it must define `automaton`.
    spec = importlib.util.spec_from_file_location("property", property_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when it cannot pick a loader
        # (e.g. the path lacks a recognised source suffix).
        raise ImportError("Cannot load property file {!r}".format(property_path))
    properties = importlib.util.module_from_spec(spec)
    # Execute the file in the new module's namespace.
    spec.loader.exec_module(properties)

    automaton = properties.automaton

    cost_function = dfa.Cost_Function(automaton, reward_shaping=reward_shaping, discount=discount_factor)

    # Define the product state space; it is indexed as
    # prod_state_space[automaton_state, env_state] below.
    prod_state_space = construct_product_state_space(n_states, len(automaton.states))
    n_prod_states = np.prod(prod_state_space.shape)

    # setup q learning agent
    agent = Q_Learning(n_prod_states, n_actions, alpha=learning_rate, discount=discount_factor,
                       exploration=exploration_type, expl_parameter=exploration_parameter)

    def reset_episode():
        """Reset the env and the cost function; return (env_state, automaton_state).

        The cost function is stepped once on the initial labels, so the
        returned automaton state already accounts for the start state.
        """
        init_state, init_info = env.reset()
        init_labels = init_info["labels"]
        cost_function.reset()
        _, _, init_automaton_state = cost_function.step(init_labels)
        return init_state, init_automaton_state

    # training loop
    state, automaton_state = reset_episode()

    for frame_idx in range(num_frames):

        prod_state = prod_state_space[automaton_state, state]
        action = agent.step(prod_state)

        next_state, reward, terminated, info = env.step(action)
        labels = info["labels"]

        violation, cost, next_automaton_state = cost_function.step(labels)

        # Per-step statistics; 'done' marks episode boundaries for the logger.
        logger.step({
            'done' : terminated,
            'ep_rew' : reward,
            'ep_cost' : float(violation),
            'ep_len' : 1.0,
            'ep_overrides' : 0.0,
            'is_success' : float(not bool(violation))
        })

        # update the agent with experience
        if counter_factual_experiences:
            # Replay the same (state, action, next_state, labels) transition
            # from every automaton state, since the automaton is known.
            for counter_fac_automaton_state in automaton.states:
                counter_fac_prod_state = prod_state_space[counter_fac_automaton_state, state]
                counter_fac_next_automaton_state = automaton.transition(counter_fac_automaton_state, labels)
                counter_fac_next_prod_state = prod_state_space[counter_fac_next_automaton_state, next_state]
                counter_fac_viol = bool(counter_fac_next_automaton_state in automaton.accepting)
                counter_fac_cost = float(counter_fac_viol) + cost_function.potential(counter_fac_automaton_state, counter_fac_next_automaton_state)
                penalty = counter_fac_cost * cost_coefficient
                # Treat a violation as terminal for the learning target only
                # (the env itself keeps running); presumably Q_Learning stops
                # bootstrapping on terminal transitions - crucial for learning.
                counter_fac_terminated = terminated or bool(counter_fac_viol)
                tup = (counter_fac_prod_state, action, reward-penalty, counter_fac_next_prod_state, counter_fac_terminated)
                agent.update(tup)
        else:
            next_prod_state = prod_state_space[next_automaton_state, next_state]
            penalty = cost * cost_coefficient
            # Same simulated termination as above, for the observed transition only.
            simulated_terminated = terminated or bool(violation)
            tup = (prod_state, action, reward-penalty, next_prod_state, simulated_terminated)
            agent.update(tup)

        state = next_state
        automaton_state = next_automaton_state

        if terminated:
            state, automaton_state = reset_episode()

    # Make sure buffered tensorboard events reach disk before the process exits.
    summary_writer.flush()


    

    

