import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Optional, Tuple
from efficientnet_pytorch import EfficientNet
from .base_model import BaseModel
from .self_attention import MultiLayerDecoder
import pdb
from .ssl_encoder import SSL_Encoder
from .vit import VisionTransformer

class VINT_V9_SA_SSL(BaseModel):
    """
    VINTv9 class: uses the architecture from v4, but uses a multi-headed
    self-attention block for the context.

    Goal-conditioning is standard: the goal embedding is concatenated to the
    context embeddings before they are passed to the attention decoder.
    """

    def __init__(
        self,
        ssl_path: str,
        ssl_model: str,
        context_size: int = 5,
        len_traj_pred: Optional[int] = 5,
        learn_angle: Optional[bool] = True,
        goal_embedding_size: Optional[int] = None,
        obs_encoding_size: Optional[int] = 512,
        mha_num_attention_heads: Optional[int] = 2,
        mha_num_attention_layers: Optional[int] = 2,
        mha_ff_dim_factor: Optional[int] = 4,
    ) -> None:
        """
        Build the two SSL encoders, the self-attention decoder and the
        distance/action prediction heads.

        Args:
            ssl_path: forwarded as the last positional argument of both
                `SSL_Encoder` instances (presumably a checkpoint path --
                confirm against `SSL_Encoder`).
            ssl_model: forwarded as the first positional argument of both
                `SSL_Encoder` instances (presumably the backbone name --
                confirm against `SSL_Encoder`).
            context_size: forwarded to `BaseModel`; the observation encoder
                is built for `1 + context_size` inputs.
            len_traj_pred: forwarded to `BaseModel`.
            learn_angle: forwarded to `BaseModel`.
            goal_embedding_size: accepted for interface compatibility but
                currently unused; the goal encoding size is always
                `obs_encoding_size`.
            obs_encoding_size: embedding size given to both encoders and used
                as the decoder's `embed_dim`.
            mha_num_attention_heads: forwarded to the decoder as `nhead`.
            mha_num_attention_layers: forwarded to the decoder as
                `num_layers`.
            mha_ff_dim_factor: forwarded to the decoder as `ff_dim_factor`.
        """
        super().__init__(context_size, len_traj_pred, learn_angle)
        self.obs_encoding_size = obs_encoding_size
        # The goal embedding is concatenated with the observation embeddings
        # in `forward`, so both share the same embedding size.
        self.goal_encoding_size = obs_encoding_size

        # Encodes the stacked context (1 + context_size inputs).
        self.obs_encoder = SSL_Encoder(
            ssl_model, obs_encoding_size,
            1 + context_size, ssl_path
        )

        # Encodes the (latest observation, goal) pair, i.e. 2 inputs.
        self.goal_encoder = SSL_Encoder(
            ssl_model, obs_encoding_size,
            2, ssl_path
        )

        # NOTE(review): `self.context_size` is not assigned in this class; it
        # is presumably set by `BaseModel.__init__` -- confirm.
        self.decoder = MultiLayerDecoder(
            embed_dim=self.obs_encoding_size,
            seq_len=self.context_size + 2,
            output_layers=[256, 128, 64, 32],
            nhead=mha_num_attention_heads,
            num_layers=mha_num_attention_layers,
            ff_dim_factor=mha_ff_dim_factor,
        )
        # Both heads consume the 32 features matching the last entry of the
        # decoder's `output_layers`.
        self.dist_predictor = nn.Sequential(
            nn.Linear(32, 1),
        )
        # NOTE(review): `len_trajectory_pred` and `num_action_params` are
        # presumably provided by `BaseModel` -- confirm.
        self.action_predictor = nn.Sequential(
            nn.Linear(32, self.len_trajectory_pred * self.num_action_params),
        )

    def forward(
        self, obs_img: torch.Tensor, goal_img: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the distance to the goal and the action trajectory.

        Args:
            obs_img: 4-D tensor of observations; it is sliced along dim 1
                from index `3 * context_size` onwards (presumably frames
                stacked channel-wise with 3 channels each -- confirm against
                the data loader).
            goal_img: 4-D goal tensor, concatenated with the sliced
                observation along dim 1.

        Returns:
            A tuple `(dist_pred, action_pred)`. `dist_pred` is the output of
            the distance head. `action_pred` is reshaped to
            `(batch, len_trajectory_pred, num_action_params)`; its first two
            action parameters are cumulatively summed along dim 1 and, when
            `learn_angle` is set, the remaining ones are normalized along
            the last dimension.
        """
        ### OBS_GOAL ENCODING STUFF ###
        obs_encoding = self.obs_encoder(obs_img)
        # Keep only the channels after the first `context_size` groups of 3
        # and pair them with the goal along the channel dimension.
        obsgoal_img = torch.cat(
            [obs_img[:, 3 * self.context_size:, :, :], goal_img], dim=1
        )
        obsgoal_embedding = self.goal_encoder(obsgoal_img)
        # Sanity check on the encoder output's embedding dimension.
        assert obsgoal_embedding.shape[2] == self.goal_encoding_size, (
            f"goal encoder produced embedding size "
            f"{obsgoal_embedding.shape[2]}, expected {self.goal_encoding_size}"
        )
        ###############################

        # Append the obs+goal embedding to the context embeddings along dim 1.
        obs_encoding = torch.cat((obs_encoding, obsgoal_embedding), dim=1)
        final_repr = self.decoder(obs_encoding)
        dist_pred = self.dist_predictor(final_repr)
        action_pred = self.action_predictor(final_repr)

        # augment outputs to match labels size-wise
        action_pred = action_pred.reshape(
            (action_pred.shape[0], self.len_trajectory_pred, self.num_action_params)
        )
        action_pred[:, :, :2] = torch.cumsum(
            action_pred[:, :, :2], dim=1
        )  # convert position deltas into waypoints
        if self.learn_angle:
            # `.clone()` gives F.normalize its own copy of the slice before
            # the result is written back in place.
            action_pred[:, :, 2:] = F.normalize(
                action_pred[:, :, 2:].clone(), dim=-1
            )  # normalize the angle prediction
        return dist_pred, action_pred