import abc
import os
from typing import Optional, List, Any, Dict, cast, Sequence, Callable, Union

import gin
import gym
import torch
from gym_minigrid.minigrid import Lava, WorldObj, Wall
from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR

from extensions.rl_babyai.babyai_offpolicy import create_babyai_offpolicy_data_iterator
from extensions.rl_minigrid.minigrid_constants import MINIGRID_EXPERT_TRAJECTORIES_DIR
from extensions.rl_minigrid.minigrid_environments import (
    FastCrossing,
    AskForHelpSimpleCrossing,
)
from extensions.rl_minigrid.minigrid_models import MiniGridSimpleConvRNN
from extensions.rl_minigrid.minigrid_sensors import EgocentricMiniGridSensor
from extensions.rl_minigrid.minigrid_tasks import (
    MiniGridTaskSampler,
    MiniGridTask,
    AskForHelpSimpleCrossingTask,
)
from onpolicy_sync.losses import PPO, A2C
from onpolicy_sync.losses.a2cacktr import A2CConfig
from onpolicy_sync.losses.imitation import Imitation
from onpolicy_sync.losses.ppo import PPOConfig
from rl_base.common import Loss
from rl_base.experiment_config import ExperimentConfig
from rl_base.sensor import SensorSuite, Sensor, ExpertActionSensor
from utils.experiment_utils import Builder, LinearDecay, TrainingPipeline, PipelineStage


class MiniGridBaseExperimentConfig(ExperimentConfig):
    """Base experiment.

    Holds the class-level defaults shared by the MiniGrid experiments defined
    on top of this class: sensor settings, training/testing sizes, and the
    default values used when a gin hyperparameter is set to `"default"`.
    Subclasses must implement `extra_tag`.
    """

    # Default MiniGrid values
    # Forwarded to `EgocentricMiniGridSensor` as `agent_view_size` and
    # `view_channels` respectively.
    AGENT_VIEW_SIZE: int = 7
    AGENT_VIEW_CHANNELS: int = 3

    # Training params
    # Number of training processes used when `machine_params` is left at "default".
    NUM_TRAIN_SAMPLERS: int = 20  # if torch.cuda.is_available() else 1
    # Rollout length: used as `num_steps` of the training pipeline and as the
    # `rollout_len` of the off-policy data iterator.
    ROLLOUT_STEPS: int = 100
    # Length of the learning-rate `LinearDecay`; also caps the steps read from
    # the off-policy dataset.
    TOTAL_TRAIN_STEPS = int(1e6)
    # When truthy, passed to the training task sampler as `max_tasks`.
    NUM_TRAIN_TASKS: Optional[int] = None
    # Total number of test tasks, split across the test processes.
    NUM_TEST_TASKS: int = 1000
    # GPU used when `machine_params` is left at "default"; `None` means no GPU.
    GPU_ID: Optional[int] = 1 if torch.cuda.is_available() else None
    # When true, `get_sensors` also returns an `ExpertActionSensor`.
    USE_EXPERT = False
    # Forwarded to `MiniGridSimpleConvRNN` as `rnn_type`.
    RNN_TYPE = "LSTM"
    # Forwarded to the task sampler as `cache_graphs`.
    CACHE_GRAPHS: bool = False
    # Forwarded to `TrainingPipeline` as `should_log`.
    SHOULD_LOG = True
    # Added to every test task seed in `test_task_sampler_args`.
    TEST_SEED_OFFSET = 0

    # Gin configurable default parameters
    # Each value below is what `hyperparams` returns when the corresponding
    # gin parameter is set to the string "default".
    DEFAULT_LR = 1e-3
    DEFAULT_TF_RATIO = 0.5
    DEFAULT_FIXED_ALPHA = 20
    DEFAULT_ALPHA_START = 0
    DEFAULT_ALPHA_STOP = 20

    @classmethod
    @gin.configurable
    def hyperparams(
        cls,
        lr: Optional[Union[str, float]] = None,
        tf_ratio: Optional[Union[str, float]] = None,
        fixed_alpha: Optional[Union[str, float]] = None,
        anneal_alpha_start: Optional[Union[str, float]] = None,
        anneal_alpha_stop: Optional[Union[str, float]] = None,
    ) -> Dict[str, Optional[float]]:
        """Hyperparameters to use during training.

        These values should be specified using `gin-params` from the command line.
        For instance, you can set the learning rate to be `1e-2` by adding the command
        line argument

        ```
        --gin_param="hyperparams.lr = 1e-2"
        ```

        Every parameter accepts either a number, the string `"default"` (which
        selects the matching `DEFAULT_*` class attribute), a numeric string
        such as `"1e-2"` (converted with `float`), or `None` (left unset).

        # Parameters

        lr : Learning rate to use during training.
        tf_ratio : The proportion of training for which teacher forcing is used (if applicable).
        fixed_alpha : The (fixed) alpha value used in ADVISOR (if applicable).
        anneal_alpha_start : The starting alpha value used in ADVISOR when allowing alpha to be annealed
            from `anneal_alpha_start` to `anneal_alpha_stop` during training.
        anneal_alpha_stop : The value of alpha annealed to during training (if applicable).

        # Returns

        Dictionary mapping each hyperparameter name to its resolved value, or
        to `None` when it was not set.

        # Raises

        ValueError : If a string other than `"default"` cannot be parsed as a float.
        """

        def to_value(
            value: Optional[Union[str, float]], default: float
        ) -> Optional[float]:
            if value is None:
                return None
            if isinstance(value, str):
                if value == "default":
                    return default
                # A quoted number (e.g. "1e-2") would otherwise leak through
                # as a `str` and break numeric consumers such as the optimizer.
                return float(value)
            return value

        return {
            "lr": to_value(lr, cls.DEFAULT_LR),
            "tf_ratio": to_value(tf_ratio, cls.DEFAULT_TF_RATIO),
            "fixed_alpha": to_value(fixed_alpha, cls.DEFAULT_FIXED_ALPHA),
            "anneal_alpha_start": to_value(anneal_alpha_start, cls.DEFAULT_ALPHA_START),
            "anneal_alpha_stop": to_value(anneal_alpha_stop, cls.DEFAULT_ALPHA_STOP),
        }

    @classmethod
    def _get_hyperparameter_by_key(cls, key: str) -> float:
        """Fetch one entry of `hyperparams()` and insist that it has been set.

        # Parameters

        key : Name of the hyperparameter, i.e. a key of the dictionary returned
            by `hyperparams`.

        # Returns

        The configured value. An `AssertionError` is raised when the value is
        `None`, i.e. when it was not provided through gin.
        """
        value = cls.hyperparams()[key]
        unset_message = "`{}` must be set using gin-params. See the `hyperparams` method.".format(
            key
        )
        assert value is not None, unset_message
        return value

    @classmethod
    def alpha(cls) -> float:
        """Return the fixed alpha value used by ADVISOR.

        The value has to be supplied on the command line, e.g.

        `--gin_param="hyperparams.fixed_alpha = YOUR_VALUE"`
        """
        hyperparameter_name = "fixed_alpha"
        return cls._get_hyperparameter_by_key(hyperparameter_name)

    @classmethod
    def anneal_alpha_start(cls) -> float:
        """Return the alpha value at which annealing starts.

        Anneal-alpha experiments require the user to supply it on the command
        line, e.g.

        `--gin_param="hyperparams.anneal_alpha_start = YOUR_VALUE"`
        """
        hyperparameter_name = "anneal_alpha_start"
        return cls._get_hyperparameter_by_key(hyperparameter_name)

    @classmethod
    def anneal_alpha_stop(cls) -> float:
        """Return the alpha value at which annealing stops.

        Anneal-alpha experiments require the user to supply it on the command
        line, e.g.

        `--gin_param="hyperparams.anneal_alpha_stop = YOUR_VALUE"`
        """
        hyperparameter_name = "anneal_alpha_stop"
        return cls._get_hyperparameter_by_key(hyperparameter_name)

    @classmethod
    def lr(cls):
        """Return the learning rate used during training.

        The value has to be supplied on the command line, e.g.

        `--gin_param="hyperparams.lr = YOUR_VALUE"`
        """
        hyperparameter_name = "lr"
        return cls._get_hyperparameter_by_key(hyperparameter_name)

    @classmethod
    @gin.configurable
    def tf_ratio(cls):
        """Return the portion of training for which teacher forcing is enabled.

        The value has to be supplied on the command line, e.g.

        `--gin_param="hyperparams.tf_ratio = YOUR_VALUE"`
        """
        hyperparameter_name = "tf_ratio"
        return cls._get_hyperparameter_by_key(hyperparameter_name)

    @classmethod
    @gin.configurable
    def task_name(cls, name: Optional[str] = None):
        """User-specified name string for the task to run.

        Set via command line options, e.g. use:

        `--gin_param="task_name.name = 'THE_TASK_NAME'"`

        where `THE_TASK_NAME` might be CrossingS25N10.

        # Parameters

        name : The task name; expected to be bound through gin.

        # Returns

        The task name that was passed in.

        # Raises

        NotImplementedError : If no task name was specified.
        """
        if name is None:
            # The example mirrors the docstring: single quotes around the task
            # name keep the surrounding double-quoted shell argument balanced.
            raise NotImplementedError(
                "Must specify a task name. E.g., use:"
                " --gin_param=\"task_name.name = 'CrossingS25N10'\""
            )
        return name

    @classmethod
    def task_info(cls):
        """Describe the task selected through `task_name`.

        # Returns

        A freshly built dictionary with the entries:
            - env_info: used to initialize the environment
            - tag: string to use for logging (it ends with `cls.extra_tag()`)
            - task_sampler_args: extra task sampler arguments (only present
              for some tasks)
            - env_class: callable of the underlying mini-grid environment class
            - task_class: callable of the corresponding task class
            - task_sampler_class: class used to sample tasks
            - name: the task name itself

        # Raises

        NotImplementedError : If the task name is not one of the known tasks.
        """
        name = cls.task_name()

        # One entry per supported task. Optional fields:
        #   corrupt_expert_within_actions_of_goal - forwarded to the task via
        #       `extra_task_kwargs` and included in the tag
        #   toggle_is_permenant - added to `env_info` when set
        #   repeat_failed - sets `repeat_failed_task_for_min_steps` to 1000
        #   ask_for_help - selects the ask-for-help environment and task
        specs = {
            "CrossingS25N10": {
                "tag_format": "Crossing{}S{}N{}{}",
                "size": 25,
                "num_crossings": 10,
                "obstacle_type": Lava,
                "repeat_failed": True,
            },
            # Each episode takes 4 * 25 * 25 = 2500 steps already, so there is
            # no need to repeat failed tasks for the 25x25 wall-crossing tasks.
            "WallCrossingS25N10": {
                "tag_format": "Crossing{}S{}N{}{}",
                "size": 25,
                "num_crossings": 10,
                "obstacle_type": Wall,
            },
            "WallCrossingCorruptExpertS25N10": {
                "tag_format": "WallCrossingCorruptExpert{}S{}N{}C{}{}",
                "size": 25,
                "num_crossings": 10,
                "obstacle_type": Wall,
                "corrupt_expert_within_actions_of_goal": 15,
            },
            "LavaCrossingCorruptExpertS15N7": {
                "tag_format": "LavaCrossingCorruptExpert{}S{}N{}C{}{}",
                "size": 15,
                "num_crossings": 7,
                "obstacle_type": Lava,
                "corrupt_expert_within_actions_of_goal": 10,
                "repeat_failed": True,
            },
            "AskForHelpSimpleCrossing": {
                "tag_format": "AskForHelpSimpleCrossing{}S{}N{}{}",
                "size": 15,
                "num_crossings": 7,
                "obstacle_type": Wall,
                "ask_for_help": True,
            },
            "AskForHelpSimpleCrossingOnce": {
                "tag_format": "AskForHelpSimpleCrossingOnce{}S{}N{}{}",
                "size": 25,
                "num_crossings": 10,
                "obstacle_type": Wall,
                "toggle_is_permenant": True,
                "repeat_failed": True,
                "ask_for_help": True,
            },
            "AskForHelpLavaCrossingOnce": {
                "tag_format": "AskForHelpLavaCrossingOnce{}S{}N{}{}",
                "size": 15,
                "num_crossings": 7,
                "obstacle_type": Lava,
                "toggle_is_permenant": True,
                "repeat_failed": True,
                "ask_for_help": True,
            },
            "AskForHelpLavaCrossingSmall": {
                "tag_format": "AskForHelpLavaCrossingSmall{}S{}N{}{}",
                "size": 9,
                "num_crossings": 4,
                "obstacle_type": Lava,
                "repeat_failed": True,
                "ask_for_help": True,
            },
        }

        if name not in specs:
            raise NotImplementedError("Haven't implemented {}".format(name))
        spec = specs[name]

        grid_size = spec["size"]
        num_crossings = spec["num_crossings"]
        obstacle_type: Callable[[], WorldObj] = spec["obstacle_type"]
        corrupt_within = spec.get("corrupt_expert_within_actions_of_goal")

        # Parameters needed to build the environment
        env_info = {
            "size": grid_size,
            "num_crossings": num_crossings,
            "obstacle_type": obstacle_type,
        }
        if spec.get("toggle_is_permenant", False):
            env_info["toggle_is_permenant"] = True

        # Tag fields: obstacle class name, grid size, number of crossings,
        # optionally the corrupt-expert distance, and finally the extra tag.
        tag_fields = [obstacle_type().__class__.__name__, grid_size, num_crossings]
        if corrupt_within is not None:
            tag_fields.append(corrupt_within)
        tag_fields.append(cls.extra_tag())

        sampler_args = {}
        if corrupt_within is not None:
            sampler_args["extra_task_kwargs"] = {
                "corrupt_expert_within_actions_of_goal": corrupt_within
            }
        if spec.get("repeat_failed", False):
            sampler_args["repeat_failed_task_for_min_steps"] = 1000

        if spec.get("ask_for_help", False):
            env_class, task_class = AskForHelpSimpleCrossing, AskForHelpSimpleCrossingTask
        else:
            env_class, task_class = FastCrossing, MiniGridTask

        output_data = {
            "env_info": env_info,
            "tag": spec["tag_format"].format(*tag_fields),
        }
        # The key is only present for tasks that need extra sampler arguments.
        if sampler_args:
            output_data["task_sampler_args"] = sampler_args
        output_data["env_class"] = env_class
        output_data["task_class"] = task_class
        output_data["task_sampler_class"] = MiniGridTaskSampler
        output_data["name"] = name
        return output_data

    @classmethod
    def tag(cls):
        """Return the logging tag of the currently selected task."""
        info = cls.task_info()
        return info["tag"]

    @classmethod
    @abc.abstractmethod
    def extra_tag(cls):
        """Return the suffix appended to the task tag; subclasses must override."""
        raise NotImplementedError()

    @classmethod
    def get_sensors(cls) -> Sequence[Sensor]:
        """Build the sensors used by the experiment.

        # Returns

        A list holding an `EgocentricMiniGridSensor` and, when `USE_EXPERT` is
        set, an `ExpertActionSensor` sized to the task's number of actions.
        """
        sensors: List[Sensor] = [
            EgocentricMiniGridSensor(
                config={
                    "agent_view_size": cls.AGENT_VIEW_SIZE,
                    "view_channels": cls.AGENT_VIEW_CHANNELS,
                }
            )
        ]
        if cls.USE_EXPERT:
            action_names = cls.task_info()["task_class"].class_action_names()
            sensors.append(ExpertActionSensor({"nactions": len(action_names)}))
        return sensors

    @classmethod
    @gin.configurable
    def machine_params(
        cls, mode="train", gpu_id="default", n_train_processes="default", **kwargs
    ):
        """Choose the number of processes and the GPUs for a given mode.

        # Parameters

        mode : One of `"train"`, `"valid"`, or `"test"`.
        gpu_id : GPU to use, or `"default"` to fall back on `cls.GPU_ID`.
        n_train_processes : Number of training processes, or `"default"` to
            fall back on `cls.NUM_TRAIN_SAMPLERS`.

        # Returns

        Dictionary with the keys `"nprocesses"` and `"gpu_ids"`.

        # Raises

        NotImplementedError : If `mode` is not one of the supported modes.
        """
        if mode == "train":
            nprocesses = (
                cls.NUM_TRAIN_SAMPLERS
                if n_train_processes == "default"
                else n_train_processes
            )
        elif mode == "valid":
            nprocesses = 0
        elif mode == "test":
            process_cap = 500 if torch.cuda.is_available() else 50
            nprocesses = min(cls.NUM_TEST_TASKS, process_cap)
        else:
            raise NotImplementedError("mode must be 'train', 'valid', or 'test'.")

        if gpu_id != "default":
            gpu_ids = [gpu_id]
        elif cls.GPU_ID is None:
            gpu_ids = []
        else:
            gpu_ids = [cls.GPU_ID]

        return {"nprocesses": nprocesses, "gpu_ids": gpu_ids}

    @classmethod
    def create_model(cls, **kwargs) -> nn.Module:
        """Instantiate the `MiniGridSimpleConvRNN` model for the selected task.

        The action space is sized from the task class' action names, and the
        object/color/state counts are read from the first sensor returned by
        `get_sensors` (cast to `EgocentricMiniGridSensor`).
        """
        sensors = cls.get_sensors()
        num_actions = len(cls.task_info()["task_class"].class_action_names())
        ego_sensor = cast(EgocentricMiniGridSensor, sensors[0])
        return MiniGridSimpleConvRNN(
            action_space=gym.spaces.Discrete(num_actions),
            num_objects=ego_sensor.num_objects,
            num_colors=ego_sensor.num_colors,
            num_states=ego_sensor.num_states,
            observation_space=SensorSuite(sensors).observation_spaces,
            hidden_size=128,
            rnn_type=cls.RNN_TYPE,
        )

    @classmethod
    def make_sampler_fn(cls, **kwargs) -> MiniGridTaskSampler:
        """Instantiate the selected task's sampler class with `kwargs`."""
        sampler_class = cls.task_info()["task_sampler_class"]
        return sampler_class(**kwargs)

    def train_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for a training task sampler.

        None of the parameters are read by this implementation; the result
        depends only on `task_info()`, `get_sensors()`, and class attributes.

        # Returns

        Dictionary holding the sensors, environment class and info, the graph
        caching flag, and the task class, extended with the task's
        `task_sampler_args` (if any) and with `max_tasks` when
        `NUM_TRAIN_TASKS` is truthy.
        """
        info = self.task_info()
        sampler_args: Dict[str, Any] = dict(
            sensors=self.get_sensors(),
            env_class=info.get("env_class"),
            env_info=info.get("env_info"),
            cache_graphs=self.CACHE_GRAPHS,
            task_class=info["task_class"],
        )
        # Task-specific extras are optional.
        sampler_args.update(info.get("task_sampler_args", {}))

        if self.NUM_TRAIN_TASKS:
            sampler_args["max_tasks"] = self.NUM_TRAIN_TASKS
        return sampler_args

    def valid_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]],
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Validation sampler arguments; not implemented by this class."""
        raise NotImplementedError()

    def test_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]],
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for a test task sampler.

        `NUM_TEST_TASKS` is split across `total_processes`, with the first
        `NUM_TEST_TASKS % total_processes` processes taking one extra task.
        Each process gets its own seeds, interleaved by process index and
        starting at `2 ** 31 - 1 + TEST_SEED_OFFSET`.

        # Returns

        The training sampler arguments without
        `repeat_failed_task_for_min_steps`, with sensors whose type name
        contains "Expert" removed, and with `task_seeds_list`, `max_tasks`,
        and `deterministic_sampling` set.
        """
        tasks_per_process, leftover = divmod(self.NUM_TEST_TASKS, total_processes)
        max_tasks = tasks_per_process + (process_ind < leftover)

        first_seed = 2 ** 31 - 1 + self.TEST_SEED_OFFSET + process_ind
        task_seeds_list = [
            first_seed + total_processes * i for i in range(max_tasks)
        ]

        assert 0 <= min(task_seeds_list) and max(task_seeds_list) <= 2 ** 32 - 1

        sampler_args = self.train_task_sampler_args(
            process_ind=process_ind,
            total_processes=total_processes,
            devices=devices,
            seeds=seeds,
            deterministic_cudnn=deterministic_cudnn,
        )
        # This training-only option is dropped from the test arguments.
        sampler_args.pop("repeat_failed_task_for_min_steps", None)

        non_expert_sensors = [
            s for s in sampler_args["sensors"] if "Expert" not in str(type(s))
        ]

        test_args = dict(sampler_args)
        test_args.update(
            task_seeds_list=task_seeds_list,
            max_tasks=max_tasks,
            deterministic_sampling=True,
            sensors=non_expert_sensors,
        )
        return test_args

    @classmethod
    def offpolicy_demo_defaults(cls, also_using_ppo: bool):
        """Default settings for training from off-policy expert demonstrations.

        The update budget is derived from the PPO defaults: when PPO is also
        used, the update repeats are split evenly between PPO and the
        off-policy updates; otherwise all of them go to the off-policy updates.

        # Parameters

        also_using_ppo : Whether PPO updates are run alongside the off-policy
            updates.

        # Returns

        Dictionary with the keys `"data_iterator_builder"` (a zero-argument
        callable creating the data iterator), `"ppo_update_repeats"`,
        `"ppo_num_mini_batch"`, and `"offpolicy_updates"`.
        """
        ppo_defaults = cls.rl_loss_default("ppo", 1)
        update_repeats = ppo_defaults["update_repeats"]
        num_mini_batch = ppo_defaults["num_mini_batch"]
        assert update_repeats % 2 == 0

        def build_data_iterator():
            # Built lazily so the task name is only resolved when the iterator
            # is actually requested.
            trajectories_path = os.path.join(
                MINIGRID_EXPERT_TRAJECTORIES_DIR,
                "MiniGrid-{}-v0{}.pkl".format(cls.task_name(), ""),
            )
            return create_babyai_offpolicy_data_iterator(
                path=trajectories_path,
                nrollouts=cls.NUM_TRAIN_SAMPLERS // num_mini_batch,
                rollout_len=cls.ROLLOUT_STEPS,
                instr_len=None,
                restrict_max_steps_in_dataset=cls.TOTAL_TRAIN_STEPS,
            )

        if also_using_ppo:
            ppo_update_repeats = update_repeats // 2
            ppo_num_mini_batch = num_mini_batch
            offpolicy_repeats = update_repeats // 2
        else:
            ppo_update_repeats = 0
            ppo_num_mini_batch = 0
            offpolicy_repeats = update_repeats

        return {
            "data_iterator_builder": build_data_iterator,
            "ppo_update_repeats": ppo_update_repeats,
            "ppo_num_mini_batch": ppo_num_mini_batch,
            "offpolicy_updates": num_mini_batch * offpolicy_repeats,
        }

    @classmethod
    def rl_loss_default(cls, alg: str, steps: Optional[int] = None):
        """Default loss builder and update settings for a training algorithm.

        # Parameters

        alg : One of `"ppo"`, `"a2c"`, or `"imitation"`.
        steps : Required for `"ppo"`, where it is passed to the `LinearDecay`
            used as the loss' `clip_decay`; ignored otherwise.

        # Returns

        Dictionary with the keys `"loss"` (a `Builder` for the loss),
        `"num_mini_batch"`, and `"update_repeats"`.

        # Raises

        AssertionError : If `alg` is `"ppo"` and `steps` is `None`.
        NotImplementedError : If `alg` is not a supported algorithm.
        """
        if alg == "ppo":
            # Identity check: `!= None` can be fooled by a custom `__ne__`.
            assert (
                steps is not None
            ), "`steps` must be specified when `alg` is 'ppo'."
            return {
                "loss": Builder(
                    PPO, kwargs={"clip_decay": LinearDecay(steps)}, default=PPOConfig,
                ),
                "num_mini_batch": 2,
                "update_repeats": 4,
            }
        elif alg == "a2c":
            return {
                "loss": Builder(A2C, default=A2CConfig,),
                "num_mini_batch": 1,
                "update_repeats": 1,
            }
        elif alg == "imitation":
            return {
                "loss": Builder(Imitation),
                "num_mini_batch": 2,  # if torch.cuda.is_available() else 1,
                "update_repeats": 4,
            }
        else:
            raise NotImplementedError(
                "Unknown algorithm `{}`, must be 'ppo', 'a2c', or 'imitation'.".format(
                    alg
                )
            )

    @classmethod
    def _training_pipeline(
        cls,
        named_losses: Dict[str, Union[Loss, Builder]],
        pipeline_stages: List[PipelineStage],
        num_mini_batch: int,
        update_repeats: int,
    ):
        """Build the `TrainingPipeline` shared by the experiments.

        # Parameters

        named_losses : Losses (or builders for them) keyed by name. GAE is
            disabled when a loss named `"reinforce_loss"` is present.
        pipeline_stages : Stages of the training pipeline.
        num_mini_batch : Number of mini-batches per update.
        update_repeats : Number of update repeats.

        # Returns

        A `TrainingPipeline` using Adam with the learning rate from `cls.lr()`
        and a `LambdaLR` scheduler driven by a `LinearDecay` over
        `cls.TOTAL_TRAIN_STEPS`.
        """
        learning_rate = cls.lr()
        uses_reinforce = "reinforce_loss" in named_losses

        adam_builder = Builder(optim.Adam, dict(lr=learning_rate))
        scheduler_builder = Builder(
            LambdaLR, {"lr_lambda": LinearDecay(steps=cls.TOTAL_TRAIN_STEPS)}  # type: ignore
        )

        return TrainingPipeline(
            save_interval=None,
            metric_accumulate_interval=1000,  # Log every 1000 steps
            optimizer_builder=adam_builder,
            num_mini_batch=num_mini_batch,
            update_repeats=update_repeats,
            max_grad_norm=0.5,
            num_steps=cls.ROLLOUT_STEPS,
            named_losses=named_losses,
            gamma=0.99,
            use_gae=not uses_reinforce,
            gae_lambda=1.0,
            advance_scene_rollout_period=None,
            should_log=cls.SHOULD_LOG,
            pipeline_stages=pipeline_stages,
            lr_scheduler_builder=scheduler_builder,
        )
