from .grid.wrappers import (
    FullyObsWrapper,
    ActionMasking,
    RGBImgObsWrapper,
    PartialObsWrapper,
    GoalOffsetWrapper,
)
from .grid.simple_grid import MultiRoomGrid
from .grid.two_goal_grid import TwoGoalsGrid
from .grid.babyai import (
    BabyAIGoToRedBall,
    BabyAIGoToRedBallGrey,
)
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs import (
    EnvCreator,
    ParallelEnv,
    TransformedEnv,
)
from torchrl.envs.transforms import Compose, StepCounter, ToTensorImage, RewardScaling
import metaworld
from torchrl.envs.gym_like import default_info_dict_reader
from environments.metaworld.wrapper import SawyerReachEnv
from .overcook.single_agent_overcooked import SingleAgentOvercooked
from .overcook.wrappers import FullyObservableOvercookedWrapper, EnvMethodWrapper, ActionMasking_cook


def _image_transforms(env_config, use_rgb):
    """Return the ``Compose`` pipeline shared by the image-observation envs.

    Applies ``StepCounter(env_config['max_steps'])`` followed by
    ``ToTensorImage`` on the ``"image"`` key.
    """
    return Compose(
        StepCounter(env_config['max_steps']),
        # from_int follows use_rgb; presumably only RGB renders hold integer
        # pixel values that need rescaling -- TODO confirm for symbolic grids.
        ToTensorImage(from_int=use_rgb, in_keys=['image']),
    )


def _to_torchrl_grid_env(env, env_config, use_rgb):
    """Apply ``ActionMasking``, wrap with ``GymWrapper`` and add the image transforms."""
    env = ActionMasking(env)
    env = GymWrapper(env)
    return TransformedEnv(env, _image_transforms(env_config, use_rgb))


def _make_overcooked_env(env_config, use_rgb):
    """Build the single-agent Overcooked env wrapped for TorchRL."""
    env = SingleAgentOvercooked(
        layout_name=env_config["layout_name"],
        horizon=env_config["horizon"],
        reward_shaping_horizon=env_config["reward_shaping_horizon"],
        random_layout=env_config.get("random_layout", False),
        random_recipe=env_config.get("random_recipe", False),
        force_ingredients=env_config.get("force_ingredients", False),
    )
    env = EnvMethodWrapper(env)
    env = FullyObservableOvercookedWrapper(env)
    env = GymWrapper(env)
    return TransformedEnv(env, _image_transforms(env_config, use_rgb))


def _make_metaworld_env(env_config):
    """Build a Metaworld env from ``env_config['config']`` (``"<suite>_<task>"``).

    Only the ``ml10`` suite and the ``reach-v2`` task are supported.

    Raises:
        ValueError: if the suite or task is unsupported, or the task is not
            among the suite's training tasks.
    """
    # partition() splits on the first "_" only; a config without "_" yields
    # an empty task name, which the task check below rejects.
    task_suit, _, task_name = env_config['config'].partition('_')
    if task_suit != 'ml10':
        raise ValueError(f'Unknown task suit {task_suit}')
    # Validate the task name before constructing the benchmark so that an
    # unsupported task raises this ValueError instead of an IndexError from
    # an empty task lookup.
    if task_name != 'reach-v2':
        raise ValueError(f'Unknown task {task_name}')

    tasks = metaworld.ML10()
    # First matching training task, as before.
    task = next((t for t in tasks.train_tasks if t.env_name == task_name), None)
    if task is None:
        raise ValueError(f'Task {task_name} not found in {task_suit} train tasks')

    env = SawyerReachEnv(
        base_penalty=env_config['base_penalty'],
        max_steps=env_config['max_steps'],
    )
    env.set_task(task)
    reader = default_info_dict_reader(["success"])
    env = GymWrapper(env).set_info_dict_reader(info_dict_reader=reader)
    return TransformedEnv(env, Compose(
        StepCounter(env_config['max_steps']),
        RewardScaling(
            scale=env_config.get('reward_scale', 1),
            loc=env_config.get('reward_loc', 0),
        ),
    ))


def _make_babyai_env(env_name, env_config, fully_obs):
    """Instantiate one of the (unwrapped) BabyAI go-to-red-ball envs.

    Raises:
        ValueError: if ``env_name`` is not a supported BabyAI env.
    """
    # Check the name before reading config keys so an unknown env name is
    # reported as ValueError rather than a KeyError for a missing option.
    if env_name not in ('babyai_gotoredball', 'babyai_gotoredballgrey'):
        raise ValueError(f"Env name not recognized: {env_name}")
    kwargs = dict(
        room_size=env_config['room_size'],
        num_dists=env_config['num_dists'],
        max_steps=env_config['max_steps'],
    )
    if env_name == 'babyai_gotoredball':
        return BabyAIGoToRedBall(**kwargs)
    if not fully_obs:
        # The partially observable grey variant is built with agent_view_size=3.
        kwargs['agent_view_size'] = 3
    return BabyAIGoToRedBallGrey(**kwargs)


def env_maker(config, seed=None):
    """Build a single TorchRL environment described by ``config``.

    Args:
        config: Mapping with ``'env'`` (environment name) and ``'env_config'``
            (per-environment options). Supported names: ``'multi_grid'``,
            ``'two_goal_grid'``, ``'single_cook'``, ``'metaworld'``,
            ``'babyai_gotoredball'`` and ``'babyai_gotoredballgrey'``.
            Options read here include ``use_rgb`` (default False),
            ``fully_observable`` (default True, used by the BabyAI envs) and
            ``max_steps``. The two grid envs receive the whole ``env_config``
            as constructor kwargs.
        seed: Optional seed. For the grid and BabyAI envs it is applied via
            ``env.reset(seed=seed)`` on the raw env before wrapping.

    Returns:
        A ``TransformedEnv`` around a ``GymWrapper`` of the requested env.

    Raises:
        ValueError: for an unknown env name, or an unsupported Metaworld
            suite/task.
    """
    env_config, env_name = config['env_config'], config['env']
    use_rgb = env_config.get('use_rgb', False)
    fully_obs = env_config.get("fully_observable", True)

    if env_name in ('multi_grid', 'two_goal_grid'):
        grid_cls = MultiRoomGrid if env_name == 'multi_grid' else TwoGoalsGrid
        env = grid_cls(**env_config)
        if seed is not None:
            env.reset(seed=seed)
        env = RGBImgObsWrapper(env) if use_rgb else FullyObsWrapper(env)
        return _to_torchrl_grid_env(env, env_config, use_rgb)

    # NOTE(review): `seed` is not applied for 'single_cook' or 'metaworld';
    # confirm whether SingleAgentOvercooked / SawyerReachEnv accept
    # reset(seed=...) before wiring it through.
    if env_name == 'single_cook':
        return _make_overcooked_env(env_config, use_rgb)

    if env_name == 'metaworld':
        return _make_metaworld_env(env_config)

    env = _make_babyai_env(env_name, env_config, fully_obs)
    if seed is not None:
        env.reset(seed=seed)
    if use_rgb:
        env = RGBImgObsWrapper(env)
    if fully_obs:
        env = FullyObsWrapper(env)
    else:
        env = GoalOffsetWrapper(PartialObsWrapper(env))
    return _to_torchrl_grid_env(env, env_config, use_rgb)


def parallel_env_maker(config, num_envs, device='cpu', base_seed=None):
    """Create a TorchRL ``ParallelEnv`` of ``num_envs`` copies of ``env_maker(config)``.

    Args:
        config: Environment config forwarded unchanged to ``env_maker``.
        num_envs: Number of sub-environments (workers).
        device: Device passed to ``ParallelEnv``.
        base_seed: If given, worker ``i`` is built with
            ``seed=base_seed + i`` so workers do not share a seed. If ``None``,
            no seed is passed.

    Returns:
        The ``ParallelEnv`` instance.
    """
    # One shared creator keeps ParallelEnv on its single-task path (a list of
    # distinct creators would switch it to heterogeneous/multi-task mode).
    # Per-worker seeds are delivered through ``create_env_kwargs`` instead.
    # Previously only the index-0 creator was ever built, so every worker
    # received the same seed.
    creator = EnvCreator(lambda seed=None: env_maker(config, seed=seed))
    if base_seed is None:
        return ParallelEnv(num_envs, creator, device=device)
    worker_kwargs = [{"seed": base_seed + i} for i in range(num_envs)]
    return ParallelEnv(
        num_envs, creator, create_env_kwargs=worker_kwargs, device=device
    )

