from continual_rl.experiments.experiment import Experiment
from continual_rl.experiments.tasks.image_task import ImageTask
from continual_rl.experiments.tasks.minigrid_task import MiniGridTask
from continual_rl.utils.env_wrappers import wrap_deepmind, make_atari
from continual_rl.envs.incremental_classification_env import IncrementalClassificationEnv, DatasetIds


def get_mnist_task(ids, dataset_id):
    """Create a deferred constructor for an MNIST incremental-classification environment.

    Args:
        ids: class ids the environment is allowed to serve (passed as ``allowed_class_ids``).
        dataset_id: which dataset to load (callers in this file pass a ``DatasetIds`` member).

    Returns:
        A zero-argument callable that builds an ``IncrementalClassificationEnv`` reading from
        "tmp/mnist", with 100 steps per episode.
    """
    # The constructor arguments are captured here, in this function's own scope, so a caller
    # looping over ids gets one factory per id. Without that, every factory would see the
    # loop's final id (i.e. 9).
    env_kwargs = dict(data_dir="tmp/mnist",
                      num_steps_per_episode=100,
                      allowed_class_ids=ids,
                      dataset_id=dataset_id)
    return lambda: IncrementalClassificationEnv(**env_kwargs)


def create_mnist_full(use_early_stopping):
    """Build the sequential MNIST experiment: learn one digit at a time, testing on all seen so far.

    For each class id 0..9 two tasks are appended, in this order:
      1. a training ImageTask restricted to that single class (``DatasetIds.MNIST_TRAIN``),
         run for 300000 timesteps;
      2. an eval-mode ImageTask over classes 0..class_id inclusive (``DatasetIds.MNIST_TEST``),
         run for 10000 timesteps.

    Args:
        use_early_stopping: if True, every training task gets an early-stopping predicate that
            fires once the mean value passed to it exceeds 97. If False, ``None`` is passed.

    Returns:
        An Experiment over the 20 interleaved train/eval tasks with ``continual_testing_freq=10``.
    """
    # NOTE(review): 97 presumably refers to mean episode return out of the 100-step episodes
    # built by get_mnist_task; confirm against how the early-stopping value is computed.
    early_stopping_lambda = (lambda mean_val, _: mean_val > 97) if use_early_stopping else None
    recall_mnist_sequential_full_tasks = []
    for class_id in range(10):
        # Train on this class individually
        recall_mnist_sequential_full_tasks.append(
            ImageTask(action_space_id=0,
                      env_spec=get_mnist_task([class_id], dataset_id=DatasetIds.MNIST_TRAIN),
                      num_timesteps=300000, time_batch_size=1, eval_mode=False, image_size=[28, 28],
                      grayscale=True, early_stopping_lambda=early_stopping_lambda))

        # Test on the full set of classes seen so far (0..class_id inclusive)
        recall_mnist_sequential_full_tasks.append(
            ImageTask(action_space_id=0,
                      env_spec=get_mnist_task(list(range(class_id + 1)), dataset_id=DatasetIds.MNIST_TEST),
                      num_timesteps=10000, time_batch_size=1, eval_mode=True, image_size=[28, 28],
                      grayscale=True))

    return Experiment(tasks=recall_mnist_sequential_full_tasks, continual_testing_freq=10)


# Atari games used by the six-game cycle experiments; a game's index here is its action_space_id.
_ATARI_CYCLE_GAME_IDS = (
    'SpaceInvadersNoFrameskip-v4',
    'HeroNoFrameskip-v4',
    'KrullNoFrameskip-v4',
    'StarGunnerNoFrameskip-v4',
    'BeamRiderNoFrameskip-v4',
    'MsPacmanNoFrameskip-v4',
)


def _atari_env_spec(game_id):
    """Return a zero-argument factory that builds the DeepMind-wrapped Atari env ``game_id``.

    ``game_id`` is bound as a function parameter, so each factory keeps its own game even when
    factories are created in a loop (avoids the late-binding closure pitfall).
    """
    return lambda: wrap_deepmind(
        make_atari(game_id),
        clip_rewards=False,
        frame_stack=False,  # Handled separately
        scale=False,
    )


def _atari_task(action_space_id, game_id, num_timesteps):
    """Build a training ImageTask for one Atari game (84x84 grayscale, time_batch_size=4)."""
    return ImageTask(action_space_id=action_space_id,
                     env_spec=_atari_env_spec(game_id),
                     num_timesteps=num_timesteps, time_batch_size=4, eval_mode=False,
                     image_size=[84, 84], grayscale=True)


def _atari_cycle_tasks(num_timesteps):
    """Build one training task per game in _ATARI_CYCLE_GAME_IDS, each on its own action space."""
    return [_atari_task(action_space_id, game_id, num_timesteps)
            for action_space_id, game_id in enumerate(_ATARI_CYCLE_GAME_IDS)]


def _minigrid_task(env_id, num_timesteps, eval_mode=False, **task_kwargs):
    """Build a MiniGridTask on action space 0 with time_batch_size=1.

    Extra keyword arguments (e.g. ``early_stopping_lambda``) are forwarded only when given, so
    MiniGridTask's own defaults apply otherwise.
    """
    return MiniGridTask(action_space_id=0, env_spec=env_id, num_timesteps=num_timesteps,
                        time_batch_size=1, eval_mode=eval_mode, **task_kwargs)


def get_available_experiments():
    """Return a dict mapping experiment name to its Experiment, for every predefined experiment.

    Every Experiment, including the two MNIST ones, is constructed eagerly on each call.
    """
    experiments = {

        "atari_simple_space_invaders": Experiment(
            tasks=[_atari_task(0, 'SpaceInvadersNoFrameskip-v4', 50000000)],
            continual_testing_freq=50),

        "atari_simple_breakout": Experiment(
            tasks=[_atari_task(0, 'BreakoutNoFrameskip-v4', 50000000)],
            continual_testing_freq=50),

        # TODO: only one "cycle" so continual testing doesn't take forever
        "atari_cycle": Experiment(
            tasks=_atari_cycle_tasks(50000000),
            continual_testing_freq=50),

        "atari_mini_cycle_full": Experiment(
            tasks=_atari_cycle_tasks(10000000),
            continual_testing_freq=50, cycle_count=5),

        "minigrid_2room_random": Experiment(
            tasks=[_minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 750000)],
            continual_testing_freq=10),

        "minigrid_4room_random": Experiment(
            tasks=[_minigrid_task('MiniGrid-MultiRoom-N4-S5-v0', 750000)],
            continual_testing_freq=10),

        "minigrid_keycorridorS3R1": Experiment(
            tasks=[_minigrid_task('MiniGrid-KeyCorridorS3R1-v0', 1500000)],
            continual_testing_freq=10),

        "minigrid_empty8x8_unlock": Experiment(
            tasks=[_minigrid_task('MiniGrid-Empty-8x8-v0', 1000000),
                   _minigrid_task('MiniGrid-Unlock-v0', 1500000),
                   _minigrid_task('MiniGrid-Empty-8x8-v0', 10000, eval_mode=True),
                   ], continual_testing_freq=10),

        "minigrid_empty8x8_2room_unlock": Experiment(
            tasks=[_minigrid_task('MiniGrid-Empty-8x8-v0', 500000),
                   _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 750000),
                   _minigrid_task('MiniGrid-Empty-8x8-v0', 10000, eval_mode=True),
                   _minigrid_task('MiniGrid-Unlock-v0', 2000000),
                   _minigrid_task('MiniGrid-Empty-8x8-v0', 10000, eval_mode=True),
                   _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 10000, eval_mode=True),
                   ], continual_testing_freq=10),

        "minigrid_empty8x8_2room_unlock_pseudo_stop_early": Experiment(
            tasks=[_minigrid_task('MiniGrid-Empty-8x8-v0', 200000),
                   _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 300000),
                   _minigrid_task('MiniGrid-Empty-8x8-v0', 10000, eval_mode=True),
                   _minigrid_task('MiniGrid-Unlock-v0', 500000),
                   _minigrid_task('MiniGrid-Empty-8x8-v0', 10000, eval_mode=True),
                   _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 10000, eval_mode=True),
                   ], continual_testing_freq=10),

        "minigrid_empty8x8_2room_4room_unlock": Experiment(
            tasks=[_minigrid_task('MiniGrid-Empty-8x8-v0', 500000),
                   _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 750000),
                   _minigrid_task('MiniGrid-MultiRoom-N4-S5-v0', 1000000),
                   _minigrid_task('MiniGrid-Unlock-v0', 1000000),
                   _minigrid_task('MiniGrid-Empty-8x8-v0', 10000, eval_mode=True),
                   _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 10000, eval_mode=True),
                   _minigrid_task('MiniGrid-MultiRoom-N4-S5-v0', 10000, eval_mode=True),
                   ], continual_testing_freq=10),

        "minigrid_empty8x8_2room_4room_unlock_longer": Experiment(
            tasks=[_minigrid_task('MiniGrid-Empty-8x8-v0', 500000),
                   _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 750000),
                   _minigrid_task('MiniGrid-MultiRoom-N4-S5-v0', 1500000),
                   _minigrid_task('MiniGrid-Unlock-v0', 1500000),
                   _minigrid_task('MiniGrid-Empty-8x8-v0', 10000, eval_mode=True),
                   _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 10000, eval_mode=True),
                   _minigrid_task('MiniGrid-MultiRoom-N4-S5-v0', 10000, eval_mode=True),
                   ], continual_testing_freq=10),

        "minigrid_2room_unlock_4room": Experiment(
            tasks=[_minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 750000),
                   _minigrid_task('MiniGrid-Unlock-v0', 1500000),
                   _minigrid_task('MiniGrid-MultiRoom-N4-S5-v0', 1500000),
                   ], continual_testing_freq=10),

        "minigrid_2room_unlock_keycorridor": Experiment(
            tasks=[_minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 750000),
                   _minigrid_task('MiniGrid-Unlock-v0', 1500000),
                   _minigrid_task('MiniGrid-KeyCorridorS3R2-v0', 1500000),
                   ], continual_testing_freq=10),

        "minigrid_2room_unlock_keycorridorS3R1": Experiment(
            tasks=[_minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 750000),
                   _minigrid_task('MiniGrid-Unlock-v0', 1500000),
                   _minigrid_task('MiniGrid-KeyCorridorS3R1-v0', 1500000),
                   ], continual_testing_freq=10),

        "minigrid_unlock_keycorridorS3R1": Experiment(
            tasks=[_minigrid_task('MiniGrid-Unlock-v0', 1500000),
                   _minigrid_task('MiniGrid-KeyCorridorS3R1-v0', 1500000),
                   ], continual_testing_freq=10),

        "minigrid_2room_unlock_keycorridorS3R1_stop_early": Experiment(
            tasks=[
                _minigrid_task('MiniGrid-MultiRoom-N2-S4-v0', 750000,
                               early_stopping_lambda=(lambda mean_val, task_timesteps: mean_val > 0.78)),
                _minigrid_task('MiniGrid-Unlock-v0', 1500000,
                               early_stopping_lambda=(lambda mean_val, task_timesteps: mean_val > 0.95)),
                _minigrid_task('MiniGrid-KeyCorridorS3R1-v0', 1500000,
                               early_stopping_lambda=(lambda mean_val, task_timesteps: mean_val > 0.92)),
            ], continual_testing_freq=3),

        "mnist_sequential_full": create_mnist_full(use_early_stopping=False),
        "mnist_sequential_full_stop_early": create_mnist_full(use_early_stopping=True),
    }

    return experiments
