from simple_rl.agents import RandomAgent, QLearningAgent
from TestEnv import GymMDP
from simple_rl.run_experiments import run_agents_on_mdp
from agents import QLearningAgentSimple, DynaQLearningAgent, RLangDynaQLearningAgent, SkillsAndObjWrapper
from utils import hash_state
from dynamic_grounding import get_knowledge_from_file, SmartStateFeaturizer, get_stable_knowledge

def get_agents_param_sweep(gym_mdp, gamma_vals, epsilon_vals, alpha_vals, anneal_vals, default_q_vals, actions):
    """Build one Dyna-Q agent per combination of the given hyperparameter values.

    Args:
        gym_mdp: The environment being swept. Not used by this function; kept
            so the signature matches ``get_rlang_agents_param_sweep``.
        gamma_vals: Discount factors to try.
        epsilon_vals: Exploration rates to try.
        alpha_vals: Learning rates to try.
        anneal_vals: Values for the agents' ``anneal`` flag.
        default_q_vals: Initial Q-values to try.
        actions: Action ids shared by every agent.

    Returns:
        A list of ``DynaQLearningAgent`` instances, one per point of the
        Cartesian product of the value lists, ordered gamma-major and
        default_q-minor (the same order as nested loops in argument order).
    """
    from itertools import product

    agents = []
    for gamma, epsilon, alpha, anneal, default_q in product(
        gamma_vals, epsilon_vals, alpha_vals, anneal_vals, default_q_vals
    ):
        # Agents that differ only in `anneal` would otherwise share a name.
        # simple_rl presumably groups results by agent name (verify), so mark
        # annealed agents; names of non-annealed agents are unchanged.
        suffix = "-anneal" if anneal else ""
        agents.append(
            DynaQLearningAgent(
                actions=actions,
                state_hash_func=hash_state,
                gamma=gamma,
                epsilon=epsilon,
                alpha=alpha,
                anneal=anneal,
                default_q=default_q,
                name=f"Dyna-Q-Agent-g:{gamma}-e:{epsilon}-a:{alpha}-q:{default_q}{suffix}",
            )
        )
    return agents

def get_rlang_agents_param_sweep(gym_mdp, gamma_vals, epsilon_vals, alpha_vals, anneal_vals, default_q_vals, actions,
                                 rlang_file="rlang_advice/midmazelava.rlang"):
    """Build one RLang-guided Dyna-Q agent per combination of hyperparameter values.

    Every agent uses RLang policy, plan and effects advice (``use_policy``,
    ``use_plan`` and ``use_effects`` are all True, ``policy_epsilon=1.0``).

    Args:
        gym_mdp: The environment; it is bound into the knowledge loader
            handed to each agent.
        gamma_vals: Discount factors to try.
        epsilon_vals: Exploration rates to try.
        alpha_vals: Learning rates to try.
        anneal_vals: Values for the agents' ``anneal`` flag.
        default_q_vals: Initial Q-values to try.
        actions: Action ids shared by every agent.
        rlang_file: Default RLang advice file read by the knowledge loader.

    Returns:
        A list of ``RLangDynaQLearningAgent`` instances, one per point of the
        Cartesian product of the value lists, ordered gamma-major and
        default_q-minor.
    """
    from itertools import product

    def get_knowledge(state, filename=rlang_file):
        # Bind the environment so agents only need to supply a state
        # (and, optionally, a different advice file).
        return get_knowledge_from_file(state, filename, gym_mdp)

    agents = []
    for gamma, epsilon, alpha, anneal, default_q in product(
        gamma_vals, epsilon_vals, alpha_vals, anneal_vals, default_q_vals
    ):
        # Keep names unique across anneal values; non-annealed names unchanged.
        suffix = "-anneal" if anneal else ""
        agents.append(
            RLangDynaQLearningAgent(
                actions=actions,
                get_knowledge=get_knowledge,
                state_hash_func=hash_state,
                use_policy=True,
                use_plan=True,
                use_effects=True,
                policy_epsilon=1.0,
                gamma=gamma,
                epsilon=epsilon,
                alpha=alpha,
                anneal=anneal,
                default_q=default_q,
                name=f"RLang-Dyna-Q-Agent-Combined-g:{gamma}-e:{epsilon}-a:{alpha}-q:{default_q}{suffix}",
            )
        )
    return agents


# Parameter sweep over Dyna-Q-learning hyperparameters (gamma, epsilon, alpha, anneal flag, default Q-value).
# Each agent is run for 3 instances of 25 episodes (at most 500 steps per episode), and cumulative rewards are plotted.

def perform_param_sweep(env_name, action_ids, env_nickname=None, seed=None, *args, **kwargs):
    """
    Perform a parameter sweep of Dyna-Q agents on a given Gym environment.

    Args:
        env_name: Environment id passed to ``GymMDP``.
        action_ids: Primitive action ids available to the agents. This
            sequence is not modified.
        env_nickname: Short environment name. For the maze/room environments
            listed below, skills generated by ``SmartStateFeaturizer`` are
            installed on the environment and their ids appended to the actions.
        seed: Environment seed forwarded to ``GymMDP``.
        *args, **kwargs: Forwarded to ``GymMDP``.

    Agents are run with ``run_agents_on_mdp`` for 3 instances of 25 episodes
    (at most 500 steps each); a cumulative-reward plot is produced but not
    opened.
    """
    # NOTE(review): *args are forwarded positionally alongside env_name=...;
    # a non-empty *args probably binds to env_name and raises TypeError.
    # Confirm against GymMDP's signature before relying on *args.
    gym_mdp = GymMDP(env_name=env_name, render=False, wrapper=SkillsAndObjWrapper, seed=seed, max_steps=2000, *args, **kwargs)

    # Copy: extending an alias would mutate the caller's list in place
    # (and fail outright if a tuple was passed).
    actions = list(action_ids)
    if env_nickname in ("HardMaze", "MidMaze", "DoorKey", "MultiRoom", "HardMazeLight", "LavaCrossing", "MidMazeLava", "LockedRoom"):
        gym_mdp.reset()
        state = gym_mdp.get_init_state()
        knowledge = get_stable_knowledge()
        statefeaturizer = SmartStateFeaturizer(knowledge=knowledge)
        gym_mdp.env.skill_dict, skill_names = statefeaturizer.generate_skill_dict(state)
        # The keys of skill_names become additional action ids for the agents.
        actions.extend(skill_names.keys())
        print(actions)

    gamma_vals = [0.99]
    epsilon_vals = [0.01, 0.1]
    alpha_vals = [0.01, 0.1]
    anneal_vals = [False]
    default_q_vals = [0.0]

    agents = get_agents_param_sweep(gym_mdp, gamma_vals, epsilon_vals, alpha_vals, anneal_vals, default_q_vals, actions)
    # To sweep RLang-guided agents instead, swap in:
    # agents = get_rlang_agents_param_sweep(gym_mdp, gamma_vals, epsilon_vals, alpha_vals, anneal_vals, default_q_vals, actions)

    run_agents_on_mdp(agents,
                      gym_mdp, instances=3, episodes=25, steps=500, open_plot=False, verbose=False,
                      cumulative_plot=True, track_disc_reward=False)
    