import copy
from typing import Any, Dict

import carla
import gym

from car_dreamer.toolkit import EnvMonitorOpenCV, Observer, WorldManager

from .carla_left_turn_env import CarlaLeftTurnEnv
from .carla_message_env import CarlaMessageEnv
from .carla_right_turn_env import CarlaRightTurnEnv


class CarlaMultitaskEnv(gym.Env):
    """Gym environment that cycles through several CARLA tasks, one task per episode.

    Each task name listed in ``config.tasks`` gets its own sub-environment.
    Tasks whose configs resolve to the same ``world.town`` reuse one CARLA
    client, world and vehicle manager; all tasks share a single monitor.

    Episodes are assigned round-robin: the n-th call to :meth:`reset`
    (0-based) selects ``config.tasks[n % len(config.tasks)]``, and
    :meth:`step` / :meth:`render` forward to that task's environment.
    """

    def __init__(self, config):
        """Build every task environment described by ``config``.

        Args:
            config: Root configuration. Must expose ``tasks`` (a sequence of
                task names). For each task name it may also expose an
                attribute of that name whose ``env`` entry overrides the root
                config for that task.

        Raises:
            ValueError: If ``config.tasks`` is empty, or a task name does not
                start with a known task prefix.
        """
        self._config = config
        self._tasks = config.tasks
        if not self._tasks:
            raise ValueError("At least one task should be provided.")
        self._task_configs = self._prepare_task_configs()
        self._current_task_index = 0
        self._episode_counter = 0

        # task name -> environment. Must exist before _init_shared_resources(),
        # which registers the first environment built for each town.
        self._envs = {}
        # town -> {"client", "world", "vehicle_manager", "monitor"}
        self._shared_resources = self._init_shared_resources()
        self._create_task_envs()

        # Spaces are taken from the first environment; all tasks are assumed
        # to expose compatible spaces (not checked here).
        self.observation_space = self._get_observation_space()
        self.action_space = self._get_action_space()

    def _get_task_env(self, task):
        """Return the environment class for ``task``, chosen by name prefix.

        Raises:
            ValueError: If ``task`` matches none of the known prefixes.
        """
        if task.startswith("carla_message"):
            env_class = CarlaMessageEnv
        elif task.startswith("carla_left_turn"):
            env_class = CarlaLeftTurnEnv
        elif task.startswith("carla_right_turn"):
            env_class = CarlaRightTurnEnv
        else:
            raise ValueError(f"Unknown task: {task}")
        return env_class

    def _init_shared_resources(self):
        """Build the first environment for each distinct town and record its CARLA handles.

        The first task (in config order) for each town is constructed without
        injected world/observer, so it sets up its own client, world and vehicle
        manager; those handles are then stored for reuse by the remaining tasks
        on the same town. One monitor, built from the first task's config, is
        shared by every environment.

        Returns:
            Dict mapping town name to a dict with keys ``"client"``, ``"world"``,
            ``"vehicle_manager"`` and ``"monitor"``.
        """
        monitor = EnvMonitorOpenCV(next(iter(self._task_configs.values())))
        shared_resources = {}
        for task, task_config in self._task_configs.items():
            town = task_config.world.town
            if town in shared_resources:
                # This town is already set up; the task is built later in
                # _create_task_envs() on top of the shared handles.
                continue

            env = self._get_task_env(task)(task_config, monitor=monitor)
            self._envs[task] = env

            # Borrow the handles from the env's (private) world manager.
            shared_resources[town] = {
                "client": env._world._client,
                "world": env._world._world,
                "vehicle_manager": env._world._vehicle_manager,
                "monitor": monitor,
            }
        return shared_resources

    def _create_task_envs(self):
        """Build the environments not yet created, reusing their town's shared CARLA handles."""
        for task, task_config in self._task_configs.items():
            if task in self._envs:
                continue
            town = task_config.world.town
            shared_res = self._shared_resources[town]
            world_manager = WorldManager(
                task_config, client=shared_res["client"], world=shared_res["world"], vehicle_manager=shared_res["vehicle_manager"]
            )
            observer = Observer(world_manager, task_config.observation)

            self._envs[task] = self._get_task_env(task)(task_config, world=world_manager, observer=observer, monitor=shared_res["monitor"])

    def _prepare_task_configs(self) -> Dict[str, Any]:
        """Derive one config per task by applying that task's ``env`` overrides.

        For each task name, ``config.<task>.env`` (when present) is merged into
        a deep copy of the root config via ``update``, and the returned value is
        used as that task's config. A task with no section, or whose section has
        no ``env`` entry, gets an unmodified copy of the root config.

        Returns:
            Dict mapping task name to its resolved config, in ``config.tasks`` order.
        """
        task_configs = {}
        for task in self._tasks:
            task_config = copy.deepcopy(self._config)
            # A default of None (rather than {}) lets a missing section fall
            # through; `{}.env` would raise AttributeError.
            task_specific = getattr(self._config, task, None)
            env_overrides = getattr(task_specific, "env", None)
            if env_overrides is not None:
                task_config = task_config.update(env_overrides)
            task_configs[task] = task_config
        return task_configs

    def _get_observation_space(self):
        """Return the observation space of the first task environment."""
        return next(iter(self._envs.values())).observation_space

    def _get_action_space(self):
        """Return the action space of the first task environment."""
        return next(iter(self._envs.values())).action_space

    def reset(self):
        """Select the next task round-robin, reset its environment and return its observation."""
        self._current_task_index = self._episode_counter % len(self._tasks)
        current_task = self._tasks[self._current_task_index]
        print(f"Resetting for task: {current_task}")
        obs = self._envs[current_task].reset()
        self._episode_counter += 1
        return obs

    def step(self, action):
        """Forward ``action`` to the current task's environment and return its result unchanged."""
        current_task = self._tasks[self._current_task_index]
        return self._envs[current_task].step(action)

    def render(self):
        """Render the current task's environment."""
        current_task = self._tasks[self._current_task_index]
        return self._envs[current_task].render()

    def close(self):
        """Close every task environment, then batch-destroy all actors in each shared world."""
        for env in self._envs.values():
            env.close()
        for resources in self._shared_resources.values():
            # NOTE(review): get_actors() returns every actor in the world, which
            # may include map-owned actors (e.g. traffic lights, the spectator);
            # confirm that destroying all of them is intended.
            destroy_commands = [carla.command.DestroyActor(actor) for actor in resources["world"].get_actors()]
            resources["client"].apply_batch(destroy_commands)
