from envs.colour_grid_world import ColourGridWorld
from envs.bridge_crossing import BridgeCrossing
from envs.media_streaming import MediaStreaming
from envs.colour_bomb_grid_world import ColourBombGridWorld
from envs.colour_bomb_grid_world_v2 import ColourBombGridWorldV2
from envs.pacman import Pacman
from algorithms.q_learning import Q_Learning, boltzmann_probs
from algorithms.value_iteration import Value_Iteration
import matplotlib.pyplot as plt 
from model_checking import pctl, dfa
from model_checking.helpers import construct_product_state_space, compute_product_state_action_transition_matrix, init_model_checker, init_dynamics_model
import numpy as np
import argparse
import sys
import importlib.util
import time
import jax 
import tensorflow as tf
from common.logger import Logger
from tqdm import tqdm
 
# ---------------------------------------------------------------------------
# Default configuration. Every value below is the default of a command-line
# flag declared in the ``__main__`` section and is rebound there after parsing.
# ---------------------------------------------------------------------------

# Task-policy Q-learning.
task_learning_rate = 0.1
task_discount_factor = 0.95
task_exploration_type = "boltzmann"
task_exploration_parameter = 0.05

# Safe (backup) policy Q-learning.
safe_learning_rate = 0.1
safe_discount_factor = 0.95
safe_exploration_type = "boltzmann"
safe_exploration_parameter = 0.01

# Value iteration, used only when the safe policy is pretrained.
pretrained_backup = False
value_iteration_steps = 1000
value_iteration_stopping_condition = 0.0

# Training.
property_path = "./properties/colour_grid_world/property_1.py"
num_frames = 100_000  # number of iterations of the training loop
counter_factual_experiences = True
cost_coefficient = 10.0
reward_shaping = "none"

# Shielding / model checking.
model_checking_type = "exact"
approximate_model = False
shielding_type = "task_prod"
num_samples = 512
satisfaction_probability = 0.9
safe_policy_mode = "explore"
prior_type = "none"

# Environment.
env_id = "colour_grid_world"
random_action_probability = 0.05
episode_length = 1000

# Logger.
log_every = 1000
logger_window_size = 100

# Miscellaneous.
device_type = "cpu"
seed = 0
logdir = "./logdir/"


def get_string_args():
    """Return a one-line, human-readable summary of the run configuration.

    Reads the module-level settings (``env_id``, ``property_path``,
    ``pretrained_backup``, ``model_checking_type``, ``approximate_model``,
    ``satisfaction_probability``, ``prior_type``, ``shielding_type``).
    The ``Prior Type`` field is only included when ``approximate_model``
    is truthy.
    """
    # Fields shared by both summary layouts, in output order.
    shared = (
        env_id,
        property_path.split('/')[-1],  # file name of the property module only
        pretrained_backup,
        model_checking_type,
        approximate_model,
        satisfaction_probability,
    )

    if not approximate_model:
        return "Env {}, Property {}, Pretrained {}, Model Checking Type {}, Approximate Model {}, Satisfaction Probability {}, Shielding Type {}".format(
            *shared, shielding_type
        )

    return "Env {}, Property {}, Pretrained {}, Model Checking Type {}, Approximate Model {}, Satisfaction Probability {}, Prior Type {}, Shielding Type {}".format(
        *shared, prior_type, shielding_type
    )
    
if __name__ == "__main__":
    # Script entry point. Parses the command line, rebinds the module-level
    # configuration with the parsed values, builds the environment, the task and
    # safe agents, the model checker and the dynamics model, and then runs the
    # shielded training loop for ``num_frames`` environment steps.

    parser = argparse.ArgumentParser()
    # task policy Q learning arguments
    parser.add_argument("--tp-lr", default=task_learning_rate)
    parser.add_argument("--tp-df", default=task_discount_factor)
    parser.add_argument("--tp-exploration", default=task_exploration_type)
    parser.add_argument("--tp-expl-parameter", default=task_exploration_parameter)

    # safe policy Q learning arguments
    parser.add_argument("--sp-lr", default=safe_learning_rate)
    parser.add_argument("--sp-df", default=safe_discount_factor)
    parser.add_argument("--sp-exploration", default=safe_exploration_type)
    parser.add_argument("--sp-expl-parameter", default=safe_exploration_parameter)

    # value iteration for pretrained safe policy arguments
    parser.add_argument("--pretrained-backup", action='store_true', default=pretrained_backup)
    parser.add_argument("--vi-steps", default=value_iteration_steps)
    parser.add_argument("--vi-stopping-condition", default=value_iteration_stopping_condition)

    # training arguments
    parser.add_argument("--property", default=property_path)
    parser.add_argument("--num-frames", default=num_frames)
    # store_false: args.no_cf stays at its default unless --no-cf is passed, in
    # which case it becomes False. It therefore holds the "use counter factual
    # experiences" flag directly, despite its name.
    parser.add_argument("--no-cf", action='store_false', default=counter_factual_experiences)
    parser.add_argument("--cost-coeff", default=cost_coefficient)
    parser.add_argument("--reward-shaping", default=reward_shaping)

    # shielding arguments
    parser.add_argument("--model-checking-type", default=model_checking_type)
    parser.add_argument("--approximate-model", action='store_true', default=approximate_model)
    parser.add_argument("--shielding-type", default=shielding_type)
    parser.add_argument("--num-samples", default=num_samples)
    parser.add_argument("--sat-prob", default=satisfaction_probability)
    parser.add_argument("--safe-policy-mode", default=safe_policy_mode)
    parser.add_argument("--prior-type", default=prior_type)

    # environment arguments
    parser.add_argument("--env", default=env_id)
    parser.add_argument("--random-action-probability", default=random_action_probability)
    parser.add_argument("--episode-length", default=episode_length)

    # logger arguments (defaults come from the module-level settings, like
    # every other flag, instead of duplicated literals)
    parser.add_argument("--log-every", default=log_every)
    parser.add_argument("--logger-window-size", default=logger_window_size)

    # misc
    parser.add_argument("--device-type", default=device_type)
    parser.add_argument("--seed", default=seed)
    parser.add_argument("--logdir", default=logdir)
    args = parser.parse_args()

    # Values given on the command line arrive as strings (no ``type=`` is set on
    # the flags), so each one is converted explicitly and the module-level
    # settings are rebound; get_string_args() reads these rebound globals.
    try:
        # task Q learning arguments
        task_learning_rate = float(args.tp_lr)
        task_discount_factor = float(args.tp_df)
        task_exploration_type = str(args.tp_exploration)
        task_exploration_parameter = float(args.tp_expl_parameter)

        # safe Q learning arguments
        safe_learning_rate = float(args.sp_lr)
        safe_discount_factor = float(args.sp_df)
        safe_exploration_type = str(args.sp_exploration)
        safe_exploration_parameter = float(args.sp_expl_parameter)

        # value iteration for pretrained safe policy arguments
        pretrained_backup = bool(args.pretrained_backup)
        value_iteration_steps = int(args.vi_steps)
        value_iteration_stopping_condition = float(args.vi_stopping_condition)

        # training arguments
        property_path = str(args.property)
        num_frames = int(args.num_frames)
        counter_factual_experiences = bool(args.no_cf)
        cost_coefficient = float(args.cost_coeff)
        reward_shaping = str(args.reward_shaping)

        # shielding arguments
        model_checking_type = str(args.model_checking_type)
        approximate_model = bool(args.approximate_model)
        shielding_type = str(args.shielding_type)
        num_samples = int(args.num_samples)
        satisfaction_probability = float(args.sat_prob)
        safe_policy_mode = str(args.safe_policy_mode)
        prior_type = str(args.prior_type)

        # environment arguments
        env_id = str(args.env)
        random_action_probability = float(args.random_action_probability)
        episode_length = int(args.episode_length)

        # logger arguments
        log_every = int(args.log_every)
        logger_window_size = int(args.logger_window_size)

        # misc
        device_type = str(args.device_type)
        seed = int(args.seed)
        logdir = str(args.logdir)
    except (TypeError, ValueError) as err:
        # Still surfaces as TypeError, but keeps the original cause and no
        # longer swallows unrelated exceptions such as KeyboardInterrupt.
        raise TypeError("could not convert a command-line argument to its expected type: {}".format(err)) from err

    string_args = get_string_args()
    print(string_args)

    # setup JAX device
    device = jax.devices(device_type)[0]
    cpu_device = jax.devices('cpu')[0]

    key = jax.random.PRNGKey(seed)
    np.random.seed(seed)

    # setup tensorboard
    summary_writer = tf.summary.create_file_writer(logdir)
    logger = Logger(log_every, tensorboard=True, summary_writer=summary_writer, stats_window_size=logger_window_size, prefix='rollout')

    # setup environment
    if env_id == "colour_grid_world":
        env = ColourGridWorld(seed=seed, random_action_probability=random_action_probability, episode_length=episode_length)
    elif env_id == "bridge_crossing":
        env = BridgeCrossing(seed=seed, random_action_probability=random_action_probability, episode_length=episode_length)
    elif env_id == "colour_bomb_grid_world":
        env = ColourBombGridWorld(seed=seed, random_action_probability=random_action_probability, episode_length=episode_length)
    elif env_id == "colour_bomb_grid_world_v2":
        env = ColourBombGridWorldV2(seed=seed, random_action_probability=random_action_probability, episode_length=episode_length)
    elif env_id == "media_streaming":
        env = MediaStreaming(seed=seed, episode_length=episode_length)
    elif env_id == "pacman":
        env = Pacman(seed=seed, episode_length=episode_length)
    else:
        # Fail here with a clear message instead of a NameError on ``env`` below.
        raise ValueError("unknown environment id: {!r}".format(env_id))

    n_states = env.n_states
    n_actions = env.n_actions

    # load cost function, automaton and pctl and lt properties
    spec = importlib.util.spec_from_file_location("property", property_path)

    # creates a new module based on spec
    properties = importlib.util.module_from_spec(spec)

    # executes the module in its own namespace
    # when a module is imported or reloaded.
    spec.loader.exec_module(properties)

    # get the cost function and corresponding automaton
    automaton = properties.automaton
    cost_function = dfa.Cost_Function(automaton, reward_shaping=reward_shaping, discount=safe_discount_factor)

    # define the product state space; it is indexed below as
    # prod_state_space[automaton_state, env_state]
    n_automaton_states = len(automaton.states)
    prod_state_space = construct_product_state_space(n_states, n_automaton_states)
    n_prod_states = np.prod(prod_state_space.shape)
    accepting_states = prod_state_space[automaton.accepting, :].flatten()

    # setup q learning agents: the task agent acts on environment states, the
    # safe agent acts on product states
    task_agent = Q_Learning(n_states, n_actions, alpha=task_learning_rate, discount=task_discount_factor, exploration=task_exploration_type, expl_parameter=task_exploration_parameter)
    if pretrained_backup:
        safe_agent = Value_Iteration(n_prod_states, n_actions, discount=safe_discount_factor, temperature=safe_exploration_parameter)
        product_transition_matrix = compute_product_state_action_transition_matrix(env.transition_matrix, automaton, env.labelling_fn)
        # accepting product states carry a reward of -cost_coefficient, all others 0
        cost_map = np.zeros(n_prod_states, dtype=np.float32)
        cost_map[accepting_states] = -cost_coefficient
        safe_agent.train(product_transition_matrix, cost_map, steps=value_iteration_steps, stopping_condition=value_iteration_stopping_condition)
    else:
        safe_agent = Q_Learning(n_prod_states, n_actions, alpha=safe_learning_rate, discount=safe_discount_factor, exploration=safe_exploration_type, expl_parameter=safe_exploration_parameter)

    # returns a model checker with model_checking_type that checks the corresponding property specified in the properties module
    model_checker = init_model_checker(env, properties, device, model_checking_type=model_checking_type, satisfaction_probability=satisfaction_probability, shielding_type=shielding_type)

    # kwargs to pass to the model checker
    mc_kwargs = {'num_samples': num_samples}

    # returns a dynamics model for the state space being checked; if approximate then the model is learned from experience
    dynamics_model = init_dynamics_model(env, automaton, model_checking_type=model_checking_type, approximate_model=approximate_model, shielding_type=shielding_type, prior_type=prior_type)

    # if the dynamics model is approximate and on the product state space then it is updated with counter factual experiences
    update_dynamics_with_cf = (approximate_model and (shielding_type in ['action_cond_safe', 'task_prod']))

    # if the backup policy is pretrained and fixed (and the transition matrix is given) we can precompute the action satisfaction set to save time
    if (model_checking_type == "exact") and pretrained_backup and (not approximate_model):
        action_sat = np.zeros((n_prod_states, n_actions), dtype=np.float32)
        for prod_state in tqdm(range(n_prod_states)):
            for action in range(n_actions):
                key, state_sat = model_checker.check(key, safe_agent.get_policy(), dynamics_model.get_model(), prod_state, action, **mc_kwargs)
                action_sat[prod_state, action] = state_sat.item()
        # Disabled debugging dump of the precomputed satisfaction set:
        # print(prod_state_space.shape)
        # grid_size = env.grid_size
        # for x in range(prod_state_space.shape[0]):
        #     print(x)
        #     for a in range(n_actions):
        #         print(a)
        #         print(action_sat[x*n_states:(x+1)*n_states, a].reshape(grid_size,grid_size))
        # assert False
        print("Computed action sat set ...")
    else:
        action_sat = None

    # training loop
    state, info = env.reset()
    labels = info["labels"]

    # step the cost function once on the initial labels to obtain the
    # automaton state that corresponds to the initial environment state
    cost_function.reset()
    _, _, automaton_state = cost_function.step(labels)

    for frame_idx in range(num_frames):

        prod_state = prod_state_space[automaton_state, state]

        # Propose an action with the task agent and ask the shield (precomputed
        # table or model checker) whether it satisfies the property.
        if shielding_type == "task_prod":
            task_action = task_agent.step(state)
            # task agent is defined on the state space we need to map it to the product state space
            if action_sat is not None:
                sat = action_sat[prod_state, task_action]
            else:
                # NOTE(review): np.repeat duplicates each policy row consecutively;
                # whether that matches the product-state index layout depends on
                # construct_product_state_space (np.tile may be intended) - confirm.
                key, sat = model_checker.check(key, np.repeat(task_agent.get_policy(), n_automaton_states, axis=0), dynamics_model.get_model(), prod_state, task_action, **mc_kwargs)
        elif shielding_type == "action_cond_safe":
            task_action = task_agent.step(state)
            if action_sat is not None:
                sat = action_sat[prod_state, task_action]
            else:
                key, sat = model_checker.check(key, safe_agent.get_policy(), dynamics_model.get_model(), prod_state, task_action, **mc_kwargs)
        else:
            # Previously ``assert NotImplementedError``, which always passes
            # (the class object is truthy) and let execution continue with
            # ``sat`` undefined.
            raise NotImplementedError("unsupported shielding type: {!r}".format(shielding_type))

        # Disabled alternative override rule (mask unsafe actions in the task Q-values):
        # if not sat:
        #     safe_probs = safe_agent.policy(prod_state, mode=safe_policy_mode)
        #     unsafe_actions = np.arange(n_actions)[safe_probs <= (1/n_actions - 1e-2)]
        #     task_q_vals = task_agent.Q[state].copy()
        #     task_q_vals[unsafe_actions] = -100.0
        #     safe_action = np.random.choice(n_actions, p=boltzmann_probs(task_q_vals, temp=task_exploration_parameter))
        # else:
        #     safe_action = None

        # the safe agent overrides the task agent whenever the check fails
        safe_action = safe_agent.step(prod_state, mode=safe_policy_mode) if not sat else None
        task_action = task_action if task_action is not None else task_agent.step(state)
        action = safe_action if not sat else task_action

        next_state, reward, terminated, info = env.step(action)
        labels = info["labels"]

        violation, cost, next_automaton_state = cost_function.step(labels)

        logger.step({
            'done' : terminated,
            'ep_rew' : reward,
            'ep_cost' : float(violation),
            'ep_len' : 1.0,
            'ep_overrides' : float(not sat),
            'is_success' : float(not bool(violation))
        })

        # update the task agent
        if not sat:
            assert safe_action is not None
            # the rejected task action is recorded as a zero-reward terminal transition
            tup = (state, task_action, 0.0, next_state, True)
            task_agent.update(tup)

        tup = (state, action, reward, next_state, terminated)
        task_agent.update(tup)

        # update the safe agent
        if (counter_factual_experiences and not pretrained_backup) or update_dynamics_with_cf:
            # run counter factual experiences and update the safe agent and/or dynamics model
            # simulate state, action, next_state, from all automaton states
            for counter_fac_automaton_state in automaton.states:
                counter_fac_prod_state = prod_state_space[counter_fac_automaton_state, state]
                counter_fac_next_automaton_state = automaton.transition(counter_fac_automaton_state, labels)
                counter_fac_next_prod_state = prod_state_space[counter_fac_next_automaton_state, next_state]
                if (counter_factual_experiences and not pretrained_backup):
                    counter_fac_viol = bool(counter_fac_next_automaton_state in automaton.accepting)
                    counter_fac_cost = float(counter_fac_viol) + cost_function.potential(counter_fac_automaton_state, counter_fac_next_automaton_state)
                    penalty = counter_fac_cost * cost_coefficient
                    counter_fac_terminated = terminated #or bool(counter_fac_viol)
                    tup = (counter_fac_prod_state, action, -penalty, counter_fac_next_prod_state, counter_fac_terminated)
                    safe_agent.update(tup)
                if update_dynamics_with_cf:
                    dynamics_model.update(counter_fac_next_prod_state, counter_fac_prod_state, action)
        else:
            if (not pretrained_backup):
                # single update from the transition that was actually observed
                next_prod_state = prod_state_space[next_automaton_state, next_state]
                penalty = cost * cost_coefficient
                simulated_terminated = terminated #or bool(violation)
                tup = (prod_state, action, -penalty, next_prod_state, simulated_terminated)
                safe_agent.update(tup)

        # update the dynamics model (on environment states when it was not
        # already updated on product states above)
        if not update_dynamics_with_cf:
            dynamics_model.update(next_state, state, action)

        state = next_state
        automaton_state = next_automaton_state

        # reset the environment and cost function if done
        if terminated:
            state, info = env.reset()
            labels = info["labels"]

            cost_function.reset()
            _, _, automaton_state = cost_function.step(labels)

    

