from typing import Dict, Any, List, Optional, Tuple, Union

import gin
import gym
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import LambdaLR

from extensions.rl_lighthouse.lighthouse_environment import LightHouseEnvironment
from extensions.rl_lighthouse.lighthouse_sensors import FactorialDesignCornerSensor
from extensions.rl_lighthouse.lighthouse_tasks import FindGoalLightHouseTaskSampler
from extensions.rl_lighthouse.lighthouse_util import StopIfNearOptimal
from models.basic_models import LinearActorCritic, RNNActorCritic
from onpolicy_sync.losses import PPO, A2C
from onpolicy_sync.losses.a2cacktr import A2CConfig
from onpolicy_sync.losses.imitation import Imitation
from onpolicy_sync.losses.ppo import PPOConfig
from rl_base.common import Loss
from rl_base.experiment_config import ExperimentConfig
from rl_base.sensor import SensorSuite, ExpertPolicySensor, Sensor
from rl_base.task import TaskSampler
from utils.experiment_utils import Builder, LinearDecay, PipelineStage, TrainingPipeline


class BaseLightHouseExperimentConfig(ExperimentConfig):
    """Base experiment config shared by the LightHouse experiments.

    The class-level constants below parameterize the sensors, the model,
    the task samplers, and the training pipeline built by the methods of
    this class.
    """

    # Passed as `world_dim` to the sensors and task samplers; the action space
    # has `2 * WORLD_DIM` discrete actions.
    WORLD_DIM = 2
    # Passed as `view_radius` to the agent's sensor and to the optimal episode
    # length computation.
    VIEW_RADIUS = 1
    # Passed as `expert_view_radius` to the expert policy sensor. A falsy value
    # (e.g. `None` or `0`) means that no expert sensor is created.
    EXPERT_VIEW_RADIUS = 15
    # Passed as `world_radius` to the task samplers.
    WORLD_RADIUS = 15
    # Passed as `degree` to `FactorialDesignCornerSensor`; must be 1 whenever
    # `RECURRENT_MODEL` is `True` (checked in `get_sensors`).
    DEGREE = -1
    # Passed as `max_steps` to the task samplers.
    MAX_STEPS = 1000
    # GPU used when `machine_params` is called with `gpu_id="default"`;
    # `None` results in an empty `gpu_ids` list.
    GPU_ID: Optional[int] = None
    # Default number of training processes (see `machine_params`).
    NUM_TRAIN_SAMPLERS: int = 20 if torch.cuda.is_available() else 2
    # Total number of test tasks, split across the test processes.
    NUM_TEST_TASKS = 200
    # Selects `RNNActorCritic` (True) or `LinearActorCritic` (False).
    RECURRENT_MODEL = False
    # Number of steps over which the learning rate is linearly decayed.
    TOTAL_TRAIN_STEPS = int(3e5)
    # Forwarded as `should_log` to the training pipeline; enabled only when
    # CUDA is unavailable.
    SHOULD_LOG = not torch.cuda.is_available()

    # Constant added to every test task seed (see `test_task_sampler_args`).
    TEST_SEED_OFFSET: int = 0

    # `DEFAULT_LR` chosen by optimizing the performance of imitation
    # learning (see the `lighthouse_optimize_lr.py` and `summarize_lr_optimization.py` scripts).
    DEFAULT_LR = 0.0242

    # Cache of sensor lists keyed by
    # `(VIEW_RADIUS, WORLD_DIM, DEGREE or None, EXPERT_VIEW_RADIUS)`.
    # Defined on the base class, so it is shared by subclasses that do not
    # rebind it.
    _SENSOR_CACHE: Dict[
        Tuple[int, int, Optional[int], Optional[int]], List[Sensor]
    ] = {}

    @classmethod
    def lr(cls) -> float:
        """Return the learning rate of this config (`DEFAULT_LR`)."""
        return cls.DEFAULT_LR

    @classmethod
    def get_sensors(cls) -> List[Sensor]:
        """Return the sensors for the current class-level settings.

        The list always starts with a `FactorialDesignCornerSensor`; an
        `ExpertPolicySensor` is appended when `EXPERT_VIEW_RADIUS` is truthy.
        Results are memoized in `_SENSOR_CACHE`, and the cached list object
        itself is returned, so callers should not mutate it.

        Raises:
            AssertionError: If `RECURRENT_MODEL` is set and `DEGREE != 1`.
        """
        assert (
            not cls.RECURRENT_MODEL
        ) or cls.DEGREE == 1, "DEGREE must equal 1 when RECURRENT_MODEL is True."

        # The degree is dropped from the key for recurrent models, where the
        # assertion above has already pinned it to 1.
        key = (
            cls.VIEW_RADIUS,
            cls.WORLD_DIM,
            (None if cls.RECURRENT_MODEL else cls.DEGREE),
            cls.EXPERT_VIEW_RADIUS,
        )

        if key not in cls._SENSOR_CACHE:
            sensors = [
                FactorialDesignCornerSensor(
                    config={
                        "view_radius": cls.VIEW_RADIUS,
                        "world_dim": cls.WORLD_DIM,
                        "degree": cls.DEGREE,
                    }
                )
            ]
            if cls.EXPERT_VIEW_RADIUS:
                sensors.append(
                    ExpertPolicySensor(
                        {
                            "expert_args": {
                                "expert_view_radius": cls.EXPERT_VIEW_RADIUS
                            },
                            "nactions": 2 * cls.WORLD_DIM,
                        }
                    )
                )
            cls._SENSOR_CACHE[key] = sensors

        return cls._SENSOR_CACHE[key]

    @classmethod
    def optimal_ave_ep_length(cls):
        """Return `LightHouseEnvironment.optimal_ave_ep_length` for this config.

        The world dimension, world radius, and view radius are taken from the
        class-level constants.
        """
        return LightHouseEnvironment.optimal_ave_ep_length(
            world_dim=cls.WORLD_DIM,
            world_radius=cls.WORLD_RADIUS,
            view_radius=cls.VIEW_RADIUS,
        )

    @classmethod
    def get_early_stopping_criterion(cls):
        """Build a `StopIfNearOptimal` criterion for this config.

        The criterion is centred on the optimal average episode length with
        an allowed deviation of 5% of that value and a `min_memory_size` of 50.
        """
        optimal_ave_ep_length = cls.optimal_ave_ep_length()

        return StopIfNearOptimal(
            optimal=optimal_ave_ep_length,
            deviation=optimal_ave_ep_length * 0.05,
            min_memory_size=50,
        )

    @classmethod
    def rl_loss_default(cls, alg: str, steps: Optional[int] = None) -> Dict[str, Any]:
        """Return the default loss builder and update settings for `alg`.

        Args:
            alg: One of `"ppo"`, `"a2c"`, or `"imitation"`.
            steps: Number of steps over which the PPO clip parameter is
                linearly decayed. Required for `"ppo"`, ignored otherwise.

        Returns:
            A dict with the keys `"loss"`, `"num_mini_batch"`, and
            `"update_repeats"`.

        Raises:
            AssertionError: If `alg == "ppo"` and `steps` is `None`.
            NotImplementedError: If `alg` is not a supported algorithm.
        """
        if alg == "ppo":
            assert steps is not None, "`steps` must be given when alg == 'ppo'."
            return {
                "loss": Builder(
                    PPO, kwargs={"clip_decay": LinearDecay(steps)}, default=PPOConfig,
                ),
                "num_mini_batch": 2,
                "update_repeats": 4,
            }
        elif alg == "a2c":
            return {
                "loss": Builder(A2C, default=A2CConfig),
                "num_mini_batch": 1,
                "update_repeats": 1,
            }
        elif alg == "imitation":
            return {
                "loss": Builder(Imitation),
                "num_mini_batch": 2,
                "update_repeats": 4,
            }
        else:
            raise NotImplementedError(
                f"Unknown algorithm {alg!r}; expected 'ppo', 'a2c', or 'imitation'."
            )

    @classmethod
    def _training_pipeline(
        cls,
        named_losses: Dict[str, Union[Loss, Builder]],
        pipeline_stages: List[PipelineStage],
        num_mini_batch: int,
        update_repeats: int,
        lr: Optional[float] = None,
    ):
        """Assemble a `TrainingPipeline` with this config's shared defaults.

        Args:
            named_losses: Mapping from loss name to a loss (or its builder).
            pipeline_stages: Stages of the training pipeline.
            num_mini_batch: Forwarded to the pipeline unchanged.
            update_repeats: Forwarded to the pipeline unchanged.
            lr: Initial Adam learning rate; `DEFAULT_LR` is used when `None`.

        Returns:
            The configured `TrainingPipeline`.
        """
        # NOTE: the learning rate is NOT rescaled here based on
        # `num_mini_batch` or `update_repeats`. Callers who want a smaller
        # step size for many mini-batches / update repeats must pass `lr`.
        lr = cls.DEFAULT_LR if lr is None else lr
        num_steps = 100
        metric_accumulate_interval = cls.MAX_STEPS * 10  # Log every 10 max length tasks
        save_interval = 500000
        gamma = 0.99

        # GAE is switched off only when a loss named "reinforce_loss" is used.
        use_gae = "reinforce_loss" not in named_losses
        gae_lambda = 1.0
        max_grad_norm = 0.5

        return TrainingPipeline(
            save_interval=save_interval,
            metric_accumulate_interval=metric_accumulate_interval,
            optimizer_builder=Builder(optim.Adam, dict(lr=lr)),
            num_mini_batch=num_mini_batch,
            update_repeats=update_repeats,
            max_grad_norm=max_grad_norm,
            num_steps=num_steps,
            named_losses=named_losses,
            gamma=gamma,
            use_gae=use_gae,
            gae_lambda=gae_lambda,
            advance_scene_rollout_period=None,
            should_log=cls.SHOULD_LOG,
            pipeline_stages=pipeline_stages,
            lr_scheduler_builder=Builder(
                LambdaLR, {"lr_lambda": LinearDecay(steps=cls.TOTAL_TRAIN_STEPS)}  # type: ignore
            ),
        )

    @classmethod
    @gin.configurable
    def machine_params(
        cls, mode="train", gpu_id="default", n_train_processes="default", **kwargs
    ) -> Dict[str, Any]:
        """Return the number of processes and the GPU ids to use for `mode`.

        Args:
            mode: `"train"`, `"valid"`, or `"test"`.
            gpu_id: A GPU id, or `"default"` to fall back on `GPU_ID`.
            n_train_processes: Number of training processes, or `"default"`
                to use `NUM_TRAIN_SAMPLERS`. Only consulted in train mode.
            **kwargs: Accepted but unused.

        Returns:
            A dict with the keys `"nprocesses"` and `"gpu_ids"`.

        Raises:
            NotImplementedError: If `mode` is not one of the supported modes.
        """
        if mode == "train":
            if n_train_processes == "default":
                nprocesses = cls.NUM_TRAIN_SAMPLERS
            else:
                nprocesses = n_train_processes
        elif mode == "valid":
            # No validation processes are used.
            nprocesses = 0
        elif mode == "test":
            # Never use more processes than there are test tasks.
            nprocesses = min(
                cls.NUM_TEST_TASKS, 500 if torch.cuda.is_available() else 50
            )
        else:
            raise NotImplementedError("mode must be 'train', 'valid', or 'test'.")

        if gpu_id == "default":
            gpu_ids = [] if cls.GPU_ID is None else [cls.GPU_ID]
        else:
            gpu_ids = [gpu_id]

        return {"nprocesses": nprocesses, "gpu_ids": gpu_ids}

    @classmethod
    def create_model(cls, **kwargs) -> nn.Module:
        """Create the actor-critic model for this config.

        An LSTM-based `RNNActorCritic` is returned when `RECURRENT_MODEL` is
        set and a `LinearActorCritic` otherwise. Both read their input from
        the first sensor returned by `get_sensors`. `**kwargs` is unused.
        """
        sensors = cls.get_sensors()
        # Constructor arguments common to both model types.
        shared_kwargs = dict(
            input_key=sensors[0].uuid,
            action_space=gym.spaces.Discrete(2 * cls.WORLD_DIM),
            observation_space=SensorSuite(sensors).observation_spaces,
        )
        if cls.RECURRENT_MODEL:
            return RNNActorCritic(rnn_type="LSTM", **shared_kwargs)
        return LinearActorCritic(**shared_kwargs)

    @classmethod
    def make_sampler_fn(cls, **kwargs) -> TaskSampler:
        """Create a `FindGoalLightHouseTaskSampler` from `kwargs`."""
        return FindGoalLightHouseTaskSampler(**kwargs)

    def train_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Return the task sampler keyword arguments for a training process.

        Args:
            process_ind: Index of the process; used to pick its seed.
            total_processes: Unused here.
            devices: Unused here.
            seeds: Optional per-process seeds; `seeds[process_ind]` is used
                as the sampler seed when given.
            deterministic_cudnn: Unused here.

        Returns:
            A dict with the keys `"world_dim"`, `"world_radius"`,
            `"max_steps"`, `"sensors"`, `"action_space"`, and `"seed"`.
        """
        return {
            "world_dim": self.WORLD_DIM,
            "world_radius": self.WORLD_RADIUS,
            "max_steps": self.MAX_STEPS,
            "sensors": self.get_sensors(),
            "action_space": gym.spaces.Discrete(2 * self.WORLD_DIM),
            "seed": seeds[process_ind] if seeds is not None else None,
        }

    def valid_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Not implemented for this config.

        `machine_params` assigns zero processes to the `"valid"` mode.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Validation task sampler arguments are not implemented."
        )

    def test_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Return the task sampler keyword arguments for a test process.

        The `NUM_TEST_TASKS` test tasks are split as evenly as possible over
        the processes, with the first `NUM_TEST_TASKS % total_processes`
        processes receiving one extra task. Task seeds are interleaved across
        processes so that no two processes share a seed.

        Args:
            process_ind: Index of this test process.
            total_processes: Total number of test processes.
            devices: Forwarded to `train_task_sampler_args`.
            seeds: Forwarded to `train_task_sampler_args`.
            deterministic_cudnn: Forwarded to `train_task_sampler_args`.

        Returns:
            The training sampler arguments extended with `"task_seeds_list"`,
            `"max_tasks"`, and `"deterministic_sampling"`.

        Raises:
            AssertionError: If any task seed falls outside `[0, 2 ** 32 - 1]`.
        """
        tasks_per_process, num_with_extra = divmod(
            self.NUM_TEST_TASKS, total_processes
        )
        max_tasks = tasks_per_process + int(process_ind < num_with_extra)

        first_seed = 2 ** 31 - 1 + self.TEST_SEED_OFFSET + process_ind
        task_seeds_list = [first_seed + total_processes * i for i in range(max_tasks)]

        # `all` (rather than `min`/`max`) keeps this check valid when the
        # process has been assigned no tasks and the seed list is empty.
        assert all(
            0 <= seed <= 2 ** 32 - 1 for seed in task_seeds_list
        ), "Test task seeds must lie in [0, 2 ** 32 - 1]."

        train_sampler_args = self.train_task_sampler_args(
            process_ind=process_ind,
            total_processes=total_processes,
            devices=devices,
            seeds=seeds,
            deterministic_cudnn=deterministic_cudnn,
        )
        return {
            **train_sampler_args,
            "task_seeds_list": task_seeds_list,
            "max_tasks": max_tasks,
            "deterministic_sampling": True,
        }
