import gym
import symbolic_behaviour_benchmark
import numpy as np 
import random

from symbolic_behaviour_benchmark.utils import wrappers
from symbolic_behaviour_benchmark.rule_based_agents import build_WrappedPositionallyDisentangledSpeakerAgent
from symbolic_behaviour_benchmark.rule_based_agents import build_WrappedPositionallyDisentangledListenerAgent


def test_env(
    nbr_communication_rounds=2,
    vocab_size=10,
    max_sentence_length=1,
    nbr_latents=2,
    min_nbr_values_per_latent=4,
    max_nbr_values_per_latent=5,
    nbr_object_centric_samples=1,
    nbr_distractors=3,
    allow_listener_query=False,
    use_communication_channel_permutations=False,
    nbr_steps=3,
    debug=False,
):
    """Run a short rollout of the test environment with rule-based agents.

    Creates ``SymbolicBehaviourBenchmark-ReceptiveConstructiveTestEnv-v0``
    through ``gym.make``, wraps it with
    ``wrappers.s2b_wrap(env, combined_actions=True)``, builds the wrapped
    positionally-disentangled speaker (``player_idx=0``) and listener
    (``player_idx=1``) agents, and then alternates "agents pick actions" /
    "environment steps" for ``nbr_steps`` steps. The environment is always
    closed on exit, even when a step or an agent raises.

    Args:
        nbr_communication_rounds: Forwarded to ``gym.make`` and to both
            agent builders.
        vocab_size: Forwarded to ``gym.make`` and to both agent builders.
        max_sentence_length: Forwarded to ``gym.make`` and to both agent
            builders.
        nbr_latents: Forwarded to ``gym.make`` and to both agent builders.
        min_nbr_values_per_latent: Forwarded to ``gym.make``.
        max_nbr_values_per_latent: Forwarded to ``gym.make``.
        nbr_object_centric_samples: Forwarded to ``gym.make``.
        nbr_distractors: Forwarded to ``gym.make``.
        allow_listener_query: Forwarded to ``gym.make``.
        use_communication_channel_permutations: Forwarded to ``gym.make``.
        nbr_steps: Number of ``env.step`` calls to perform. Defaults to 3,
            the number of steps this script has always executed.
        debug: When true, drop into ``ipdb`` before every ``env.step`` call
            and once more after the last one, so that the actions and the
            step outputs can be inspected interactively. Off by default so
            the function can run unattended; ``ipdb`` is only imported when
            this flag is set.

    Returns:
        None.
    """
    env = gym.make(
        "SymbolicBehaviourBenchmark-ReceptiveConstructiveTestEnv-v0",
        nbr_communication_rounds=nbr_communication_rounds,
        vocab_size=vocab_size,
        max_sentence_length=max_sentence_length,
        nbr_latents=nbr_latents,
        min_nbr_values_per_latent=min_nbr_values_per_latent,
        max_nbr_values_per_latent=max_nbr_values_per_latent,
        nbr_object_centric_samples=nbr_object_centric_samples,
        nbr_distractors=nbr_distractors,
        use_communication_channel_permutations=use_communication_channel_permutations,
        allow_listener_query=allow_listener_query,
    )

    try:
        # NOTE(review): this wrapper instance is not used below. It is kept
        # because its constructor is not visible from this file -- confirm
        # it has no side effects on `env` and then remove it.
        dcaw_env = wrappers.DiscreteCombinedActionWrapper(env)

        env = wrappers.s2b_wrap(env, combined_actions=True)

        obs, infos = env.reset()

        # Both agents share the same configuration; only the builder and
        # the player index differ.
        agent_kwargs = dict(
            action_space_dim=env.action_space.n,
            vocab_size=vocab_size,
            max_sentence_length=max_sentence_length,
            nbr_communication_rounds=nbr_communication_rounds,
            nbr_latents=nbr_latents,
        )

        speaker_agent = build_WrappedPositionallyDisentangledSpeakerAgent(
            player_idx=0,
            **agent_kwargs,
        )
        speaker_agent.set_nbr_actor(1)
        speaker_agent.reset_actors()

        listener_agent = build_WrappedPositionallyDisentangledListenerAgent(
            player_idx=1,
            **agent_kwargs,
        )
        listener_agent.set_nbr_actor(1)
        listener_agent.reset_actors()

        step_output = None
        for step_idx in range(nbr_steps):
            # Index 0 of the observations/infos is fed to the speaker and
            # index 1 to the listener, each wrapped in a single-item list.
            speaker_action = speaker_agent.take_action(
                state=[obs[0]],
                infos=[infos[0]],
            )
            listener_action = listener_agent.take_action(
                state=[obs[1]],
                infos=[infos[1]],
            )

            if debug:
                import ipdb
                ipdb.set_trace()

            step_output = env.step(action=[speaker_action, listener_action])

            # The output of the final step is not consumed by the agents,
            # so it is left packed (available as `step_output` in the
            # debugger) rather than unpacked.
            if step_idx + 1 < nbr_steps:
                obs, _reward, _done, infos = step_output

        if debug:
            import ipdb
            ipdb.set_trace()
    finally:
        # Release the environment even if the rollout above raised.
        env.close()

if __name__ == "__main__":
    # Seed Python's and NumPy's global random number generators before
    # running the environment check.
    # Following: https://pytorch.org/docs/stable/notes/randomness.html
    seed = 3
    """
    torch.manual_seed(seed)
    if hasattr(torch.backends, "cudnn") and not(args.fast):
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    """

    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    # Override two of the defaults: a single communication round with
    # two-token sentences.
    overrides = {
        "nbr_communication_rounds": 1,
        "max_sentence_length": 2,
    }
    test_env(**overrides)
