from typing import Optional
from typing import Sequence
from typing import Sequence, Union, Optional, Dict, Tuple, Type, List

import attr
import gym
from gym.spaces.dict import Dict as SpaceDict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

from allenact.algorithms.onpolicy_sync.losses import PPO
from allenact.algorithms.onpolicy_sync.losses.ppo import PPOConfig
from allenact.base_abstractions.sensor import Sensor
from allenact.embodiedai.sensors.vision_sensors import RGBSensor, DepthSensor
from allenact.utils.experiment_utils import (
    Builder,
    TrainingPipeline,
    PipelineStage,
    LinearDecay,
)
from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.embodiedai.models.visual_nav_models import VisualNavActorCritic
from allenact.algorithms.onpolicy_sync.policy import (
    ObservationType,
    DistributionType
)
from allenact_plugins.clip_plugin.clip_preprocessors import (
    NaivePreprocessor
)
from projects.object_navigation.navigation.model import IthorDisentangledVAE
from projects.object_navigation.baseline_configs.navigation_base import update_with_auxiliary_losses

# fmt: off
try:
    # Habitat may not be installed. In that case define an empty placeholder class
    # under the same name so the isinstance() checks that reference it (e.g. the
    # goal-sensor lookup in the PointNav mixin) still work and simply never match.
    from allenact_plugins.habitat_plugin.habitat_sensors import TargetCoordinatesSensorHabitat
except ImportError:
    # Placeholder with no behavior; configured sensors will not be instances of it.
    class TargetCoordinatesSensorHabitat:  #type:ignore
        pass
# fmt: on

from projects.plugins.robothor_plugin.robothor_sensors import GPSCompassSensorRoboThor
from projects.plugins.robothor_plugin.robothor_tasks import PointNavTask
from projects.plugins.navigation_plugin.pointnav.models import PointNavActorCritic


@attr.s(kw_only=True)
class PointNavUnfrozenResNetWithGRUActorCriticMixin:
    """Build a ``PointNavActorCritic`` (CNN backbone + GRU) from the configured sensors.

    The RGB, depth and goal-sensor uuids are discovered from ``self.sensors`` by
    type; the remaining attributes are forwarded to the model constructor.
    """

    backbone: str = attr.ib()
    sensors: Sequence[Sensor] = attr.ib()
    auxiliary_uuids: Sequence[str] = attr.ib()
    add_prev_actions: bool = attr.ib()
    multiple_beliefs: bool = attr.ib()
    belief_fusion: Optional[str] = attr.ib()

    def create_model(self, **kwargs) -> nn.Module:
        """Create the point-navigation actor-critic.

        Args:
            **kwargs: must contain ``sensor_preprocessor_graph``; its
                ``observation_spaces`` is passed to the model.

        Returns:
            A ``PointNavActorCritic`` instance.

        Raises:
            ValueError: if no ``GPSCompassSensorRoboThor`` or
                ``TargetCoordinatesSensorHabitat`` is among ``self.sensors``.
        """
        # RGB and depth are optional: None is passed through when absent.
        rgb_uuid = next(
            (s.uuid for s in self.sensors if isinstance(s, RGBSensor)), None
        )
        depth_uuid = next(
            (s.uuid for s in self.sensors if isinstance(s, DepthSensor)), None
        )
        goal_sensor_uuid = next(
            (
                s.uuid
                for s in self.sensors
                if isinstance(
                    s, (GPSCompassSensorRoboThor, TargetCoordinatesSensorHabitat)
                )
            ),
            None,
        )
        if goal_sensor_uuid is None:
            # A bare next() would surface as an opaque StopIteration; fail clearly.
            raise ValueError(
                "PointNav model requires a GPSCompassSensorRoboThor or "
                "TargetCoordinatesSensorHabitat among the configured sensors."
            )

        return PointNavActorCritic(
            # Env and task
            action_space=gym.spaces.Discrete(len(PointNavTask.class_action_names())),
            observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces,
            rgb_uuid=rgb_uuid,
            depth_uuid=depth_uuid,
            goal_sensor_uuid=goal_sensor_uuid,
            # RNN: a smaller hidden size is used when multiple beliefs are kept
            # for more than one auxiliary task.
            hidden_size=228
            if self.multiple_beliefs and len(self.auxiliary_uuids) > 1
            else 512,
            num_rnn_layers=1,
            rnn_type="GRU",
            add_prev_actions=self.add_prev_actions,
            action_embed_size=4,
            # CNN
            backbone=self.backbone,
            resnet_baseplanes=32,
            embed_coordinates=False,
            coordinate_dims=2,
            # Aux
            auxiliary_uuids=self.auxiliary_uuids,
            multiple_beliefs=self.multiple_beliefs,
            beliefs_fusion=self.belief_fusion,
        )


class PointNavPPOMixin:
    """Provide the PPO (plus optional auxiliary losses) training pipeline."""

    @staticmethod
    def training_pipeline(
        auxiliary_uuids: Sequence[str],
        multiple_beliefs: bool,
        normalize_advantage: bool,
        advance_scene_rollout_period: Optional[int] = None,
        *,
        ppo_steps: int = 2000000,
        lr_decay_steps: int = 75000000,
    ) -> TrainingPipeline:
        """Build a single-stage PPO ``TrainingPipeline``.

        Args:
            auxiliary_uuids: auxiliary losses to add alongside PPO (resolved by
                ``update_with_auxiliary_losses``).
            multiple_beliefs: forwarded to ``update_with_auxiliary_losses``.
            normalize_advantage: forwarded to the ``PPO`` loss.
            advance_scene_rollout_period: forwarded to ``TrainingPipeline``.
            ppo_steps: length of the (only) pipeline stage, in steps.
            lr_decay_steps: horizon passed to ``LinearDecay`` for the LR
                schedule. NOTE(review): the default (75M) is much larger than
                the default ``ppo_steps`` (2M). Assuming ``LinearDecay`` ramps the
                LR multiplier from 1 to 0 over ``steps``, the LR decays only a few
                percent during training. Confirm that this is intended.

        Returns:
            The configured ``TrainingPipeline``.
        """
        lr = 3e-4
        num_mini_batch = 1
        update_repeats = 4
        num_steps = 128
        save_interval = 50000
        # Log rarely on GPU runs and on every step when debugging on CPU.
        log_interval = 10000 if torch.cuda.is_available() else 1
        gamma = 0.99
        use_gae = True
        gae_lambda = 0.95
        max_grad_norm = 0.5

        # name -> (loss instance, loss weight)
        named_losses = {
            "ppo_loss": (PPO(**PPOConfig, normalize_advantage=normalize_advantage), 1.0)
        }
        named_losses = update_with_auxiliary_losses(
            named_losses=named_losses,
            auxiliary_uuids=auxiliary_uuids,
            multiple_beliefs=multiple_beliefs,
        )

        return TrainingPipeline(
            save_interval=save_interval,
            metric_accumulate_interval=log_interval,
            optimizer_builder=Builder(optim.Adam, dict(lr=lr)),
            num_mini_batch=num_mini_batch,
            update_repeats=update_repeats,
            max_grad_norm=max_grad_norm,
            num_steps=num_steps,
            named_losses={key: val[0] for key, val in named_losses.items()},
            gamma=gamma,
            use_gae=use_gae,
            gae_lambda=gae_lambda,
            advance_scene_rollout_period=advance_scene_rollout_period,
            pipeline_stages=[
                PipelineStage(
                    loss_names=list(named_losses.keys()),
                    max_stage_steps=ppo_steps,
                    loss_weights=[val[1] for val in named_losses.values()],
                )
            ],
            lr_scheduler_builder=Builder(
                LambdaLR, {"lr_lambda": LinearDecay(steps=lr_decay_steps)}
            ),
        )

# Policy mixins: preprocessor and actor-critic model construction for the VAE-based agent
@attr.s(kw_only=True)
class VAEPreprocessGRUActorCriticMixin:
    """Wire a CLIP-normalised RGB preprocessor to a ``VAEObjectNavActorCritic``.

    ``preprocessors()`` builds a ``NaivePreprocessor`` that writes to the
    ``"rgb_clip_vit"`` observation. ``create_model()`` builds the actor-critic
    that reads that observation.
    """

    sensors: Sequence[Sensor] = attr.ib()
    screen_size: int = attr.ib()  # not read by this mixin
    goal_sensor_type: Type[Sensor] = attr.ib()
    target_types: List[str] = attr.ib()  # not read by this mixin
    pool: bool = attr.ib(default=False)
    pooling_type: str = attr.ib()
    clip_model_type: str = attr.ib()
    # Path passed to torch.load for the pretrained VAE weights.
    source_model: Optional[str] = attr.ib(default=None)

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Return the single RGB+goal preprocessor used by this agent.

        Side effects: sets ``self.goal_sensor_uuid`` and
        ``self.preprocessor_output_shape``.

        Raises:
            AssertionError: if there is no ``RGBSensor``, no sensor of
                ``self.goal_sensor_type``, or the RGB sensor's normalisation
                does not match the preprocessor's CLIP constants.
        """
        rgb_sensor = next((s for s in self.sensors if isinstance(s, RGBSensor)), None)
        # Look the goal sensor up by the same type used to derive its uuid (and by
        # create_model). A hard-coded sensor class here could disagree with
        # goal_sensor_type and let a None uuid reach the preprocessor.
        goal_sensor = next(
            (s for s in self.sensors if isinstance(s, self.goal_sensor_type)), None
        )
        self.goal_sensor_uuid = None if goal_sensor is None else goal_sensor.uuid
        preprocessor_model = NaivePreprocessor

        assert rgb_sensor is not None, "An RGBSensor is required in `sensors`."
        assert (
            goal_sensor is not None
        ), f"A sensor of type {self.goal_sensor_type} is required in `sensors`."

        # The RGB sensor must already apply CLIP's normalisation constants.
        assert (
            np.linalg.norm(
                np.array(rgb_sensor._norm_means)
                - np.array(preprocessor_model.CLIP_RGB_MEANS)
            )
            < 1e-5
        ), "RGBSensor mean normalisation does not match CLIP_RGB_MEANS."
        assert (
            np.linalg.norm(
                np.array(rgb_sensor._norm_sds)
                - np.array(preprocessor_model.CLIP_RGB_STDS)
            )
            < 1e-5
        ), "RGBSensor std normalisation does not match CLIP_RGB_STDS."

        preprocessor = preprocessor_model(
            rgb_input_uuid=rgb_sensor.uuid,
            goal_sensor_uuid=self.goal_sensor_uuid,
            clip_model_type=self.clip_model_type,
            pool=self.pool,
            pooling_type=self.pooling_type,
            class_emb_only=False,
            output_uuid="rgb_clip_vit",
        )

        self.preprocessor_output_shape = preprocessor.output_shape

        return [preprocessor]

    def create_model(self, num_actions: int, **kwargs) -> nn.Module:
        """Build the actor-critic over the ``"rgb_clip_vit"`` observation.

        Args:
            num_actions: size of the discrete action space.
            **kwargs: must contain ``sensor_preprocessor_graph``; its
                ``observation_spaces`` is passed to the model.
        """
        goal_sensor_uuid = next(
            (s.uuid for s in self.sensors if isinstance(s, self.goal_sensor_type)),
            None,
        )
        return VAEObjectNavActorCritic(
            action_space=gym.spaces.Discrete(num_actions),
            observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces,
            goal_sensor_uuid=goal_sensor_uuid,
            hidden_size=1024,
            rgb_preprocessor_uuid='rgb_clip_vit',
            # Presumably the 32-dim VAE content latent plus 2 goal features; must
            # match what forward_encoder concatenates. TODO confirm.
            embedding_dim=32 + 2,
            source_model=self.source_model,
        )


class VAEObjectNavActorCritic(VisualNavActorCritic):
    """Actor-critic whose visual encoder is a frozen, pretrained ``IthorDisentangledVAE``.

    The observation under ``rgb_preprocessor_uuid`` is split along its last dim
    into a flattened 3x224x224 image and trailing goal features. The image goes
    through the VAE encoder's ``get_feature``, and the result is concatenated
    with the goal features before the recurrent state encoder.
    """

    def __init__(
        self,
        # base params
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str,
        rgb_preprocessor_uuid: str,

        # RNN
        hidden_size=1024,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        # custom params
        embedding_dim: int = 512,
        source_model=None
    ):
        """
        Args:
            action_space: discrete action space for the actor head.
            observation_space: forwarded to ``VisualNavActorCritic``.
            goal_sensor_uuid: accepted for interface compatibility; not read here.
            rgb_preprocessor_uuid: observation key holding image + goal features.
            hidden_size: RNN hidden size.
            num_rnn_layers, rnn_type, add_prev_actions, add_prev_action_null_token,
                action_embed_size: forwarded to ``create_state_encoders``.
            embedding_dim: per-step observation embedding size fed to the RNN.
                It must equal the encoder feature size plus the number of goal
                features produced by ``forward_encoder``.
            source_model: path of the VAE ``state_dict`` to load (required).

        Raises:
            ValueError: if ``source_model`` is None.
        """
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size
        )

        assert rgb_preprocessor_uuid is not None
        if source_model is None:
            # torch.load(None) would otherwise fail with an unrelated-looking error.
            raise ValueError(
                "source_model must be the path to a pretrained IthorDisentangledVAE state_dict"
            )

        self.rgb_preprocessor_uuid = rgb_preprocessor_uuid

        self.create_state_encoders(
            obs_embed_size=embedding_dim,
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
        )

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=embedding_dim,
            action_embed_size=action_embed_size,
        )

        self.train()

        self.embedder = IthorDisentangledVAE(class_latent_size=16, content_latent_size = 32)
        # Load to CPU first so GPU-saved checkpoints also load on CPU-only hosts;
        # load_state_dict copies the tensors into the embedder's own parameters.
        self.embedder.load_state_dict(torch.load(source_model, map_location="cpu"))
        print("Turning off gradients in both the image and the text encoder")
        for param in self.embedder.parameters():
            param.requires_grad_(False)
        # NOTE(review): the frozen embedder still follows this module's train()/eval()
        # mode; if it contains dropout/batch-norm, consider keeping it in eval().
        # Kept for backward compatibility only: forward_encoder uses the embedder's
        # actual device rather than this value.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    @property
    def is_blind(self) -> bool:
        """This agent always consumes visual input."""
        return False

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Embed the image with the frozen VAE encoder and append the goal features.

        Assumes ``observations[self.rgb_preprocessor_uuid]`` has shape
        ``(B, env_n, 3*224*224 + goal_dim)`` (the code indexes it as 3-D).

        Returns:
            Tensor of shape ``(B, env_n, feature_dim + goal_dim)``.
        """
        # Follow the embedder's parameters so this works on whichever device the
        # model was moved to; a hard-coded "cuda:0" breaks non-zero GPUs and CPU runs.
        device = next(self.embedder.parameters()).device
        obs = observations[self.rgb_preprocessor_uuid].to(device)
        image_numel = 3 * 224 * 224
        x = obs[:, :, :image_numel].detach().clone()
        B, env_n, _ = x.shape
        x = x.view(B * env_n, 3, 224, 224)
        goal = obs[:, :, image_numel:].detach().clone()
        x = self.embedder.encoder.get_feature(x)
        x = x.view(B, env_n, x.size(-1))
        return torch.cat([x, goal], dim=-1)