import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import LazyBatchNorm1d
from networks.resnet import ResNetEncoder, ResNet18

import networks.CLIP.clip.clip as clip
from networks.network_seq import Transformer as TransformerSeq
# device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

import gym

from networks.networks_base_alfred import FiLM, ImageConv, StateActionEncoder, LayerNorm, GoalEncoder, GoalEncoderLSTM


class VideoEncoder(nn.Module):
    """Encode a sequence of frames into a single L2-normalised video embedding.

    Each step is embedded with ``ImageConv``, summed with a learned temporal
    position embedding, contextualised by ``TransformerSeq``, projected with
    ``self.proj``, layer-normalised, L2-normalised per step, mean-pooled over
    the valid (masked-in) steps and L2-normalised again.
    """

    def __init__(self, device, config, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.config = config
        # Kept for backward compatibility; forward() takes its device from the inputs.
        self.device = device

        self.vis_encoder = ImageConv(history_frame=config.history_frame, features_dim=config.hidden_size, if_bc=True)
        transformer_width = config.hidden_size
        # One head per 64 channels, but never zero heads for narrow models.
        transformer_heads = max(1, transformer_width // 64)
        transformer_layers = 12

        self.transformerClip = TransformerSeq(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
        )
        scale = transformer_width ** -0.5
        self.proj = nn.Parameter(scale * torch.randn(transformer_width, config.hidden_size))
        # Learned temporal positions; forward() rejects sequences longer than this table.
        self.frame_position_embeddings = nn.Embedding(350, config.hidden_size)
        # NOTE(review): not read inside this class; presumably used by a contrastive loss elsewhere — confirm.
        self.logit_scale = nn.Parameter(torch.ones([]))
        self.ln_post = LayerNorm(config.hidden_size)

    @staticmethod
    def _extended_attention_mask(attn_mask, neg_value=-1000000.0):
        """Expand a (batch, seq) keep-mask into a (batch, seq, seq) additive mask.

        Kept positions map to 0 and masked-out positions to ``neg_value``; every
        query row shares the same key mask. Non-floating masks (bool / int) are
        cast to float first because PyTorch rejects ``1.0 - bool_tensor``.
        NOTE(review): assumes TransformerSeq adds this mask to its attention
        logits — confirm.
        """
        if not torch.is_floating_point(attn_mask):
            attn_mask = attn_mask.to(torch.float)
        extended = (1.0 - attn_mask.unsqueeze(1)) * neg_value
        return extended.expand(-1, attn_mask.size(1), -1)

    def _mean_pooling_for_similarity_visual(self, visual_output, video_mask):
        """Average ``visual_output`` over axis 1, counting only masked-in steps.

        Args:
            visual_output: (batch, seq, dim) per-step features.
            video_mask: (batch, seq) mask; non-zero entries weight the steps.

        Returns:
            (batch, dim) masked mean. Rows whose mask sums to zero use a divisor
            of 1, so they come out as zeros instead of NaN.
        """
        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
        visual_output = visual_output * video_mask_un
        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
        video_out = torch.sum(visual_output, dim=1) / video_mask_un_sum
        return video_out

    def forward(self, states, attn_mask):
        """Embed a batch of frame sequences.

        Args:
            states: (batch, seq_len, ...) per-step observations; every step is
                tiled ``config.history_frame`` times along a new axis 2 before
                being passed to ``ImageConv``.
            attn_mask: (batch, seq_len) mask, 1 / True for valid steps. Bool,
                integer and floating masks are accepted.

        Returns:
            (batch, hidden_size) unit-norm video embedding (all-zero for a row
            with no valid steps).

        Raises:
            ValueError: if ``seq_len`` exceeds the number of position embeddings.
        """
        bs = states.shape[0]
        seq_len = states.shape[1]
        max_len = self.frame_position_embeddings.num_embeddings
        if seq_len > max_len:
            raise ValueError(f"VideoEncoder supports at most {max_len} frames, got {seq_len}")
        device = states.device

        states = states.unsqueeze(2).repeat(1, 1, self.config.history_frame, *[1] * (states.dim() - 2))
        vis_emb = self.vis_encoder(states)
        # reshape (unlike view) also accepts a non-contiguous encoder output.
        vis_emb = vis_emb.reshape(bs, seq_len, -1)

        # Build positions on the inputs' device so the module keeps working after .to(...).
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(bs, -1)
        visual_output = vis_emb + self.frame_position_embeddings(position_ids)

        extended_video_mask = self._extended_attention_mask(attn_mask)

        visual_output = visual_output.permute(1, 0, 2)  # NLD -> LND
        visual_output = self.transformerClip(visual_output, extended_video_mask)
        visual_output = visual_output.permute(1, 0, 2)  # LND -> NLD

        visual_output = visual_output @ self.proj
        visual_output = self.ln_post(visual_output)

        # F.normalize clamps the norm to eps: identical to x / ||x|| for non-zero
        # vectors, but a zero vector (e.g. a fully masked row) stays zero instead of NaN.
        visual_output = F.normalize(visual_output, dim=-1)
        visual_output = self._mean_pooling_for_similarity_visual(visual_output, attn_mask)
        visual_output = F.normalize(visual_output, dim=-1)

        return visual_output

class DecisionMaker(nn.Module):
    """Goal-conditioned recurrent head mapping per-step features to action logits.

    Visual features are modulated by the goal embedding through two residual
    FiLM layers, run through a (possibly stacked) LSTM, and projected to
    ``action_size`` logits by a three-layer MLP with ReLU activations.
    """

    def __init__(self, feature_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2) -> None:
        super().__init__()

        self.action_size = action_size
        self.feature_size = feature_size

        self.lstm = nn.LSTM(
            input_size=feature_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers (and warns when
            # dropout > 0 with a single layer), so disable it for one layer.
            dropout=0.2 if lstm_layers > 1 else 0.0,
        )

        # The LSTM emits lstm_hidden_size features per step, so ff1 must accept
        # that width. Unchanged from before when lstm_hidden_size == feature_size.
        self.ff1 = nn.Linear(lstm_hidden_size, feature_size)
        # NOTE(review): norm1 / norm2 are never applied in forward(); kept so
        # existing checkpoints still load with strict state_dict matching.
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, action_size)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

        # NOTE(review): unused in forward(); kept for checkpoint compatibility.
        self.fc_lstm_to_action = nn.Linear(lstm_hidden_size, action_size)

    def forward(self, vis_emb, goal_emb, batch_size, seq_len, hidden_states=None, cell_states=None, single_step=False):
        """Compute action logits for every step of a sequence.

        Args:
            vis_emb: per-step visual features; assumed (batch, seq_len,
                feature_size) to match the tiled goal — TODO confirm.
            goal_emb: goal embedding, presumably (batch, feature_size); it is
                tiled ``seq_len`` times along a new axis 1.
            batch_size: unused; kept for interface compatibility.
            seq_len: number of time steps to tile ``goal_emb`` over.
            hidden_states: optional LSTM hidden state from a previous call.
                When None, the LSTM starts from its default zero state.
            cell_states: LSTM cell state; must accompany ``hidden_states``.
            single_step: unused; kept for interface compatibility.

        Returns:
            ``(logits, hidden_states, cell_states)``. ``logits`` has a last
            dimension of ``action_size``.
        """
        goal_emb = goal_emb.unsqueeze(1).repeat(1, seq_len, 1)

        # Two residual FiLM blocks condition the visual stream on the goal.
        x = self.film1(vis_emb, goal_emb) + vis_emb
        x = self.film2(x, goal_emb) + x

        if hidden_states is None:
            x, (hidden_states, cell_states) = self.lstm(x)
        else:
            x, (hidden_states, cell_states) = self.lstm(x, (hidden_states, cell_states))

        x = self.ff1(x)
        x = self.relu1(x)
        x = self.ff2(x)
        x = self.relu2(x)
        x = self.ff3(x)

        return x, hidden_states, cell_states


class BCZNetwork(nn.Module):
    """Goal-conditioned behaviour-cloning policy with an auxiliary video encoder.

    ``forward`` maps observation sequences and goals to action logits via
    ``ImageConv`` + ``DecisionMaker``. ``get_video_lang_emb`` returns paired
    video and goal embeddings from ``VideoEncoder`` and the shared
    ``GoalEncoder``.
    """

    def __init__(self, config, hidden_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2):
        super().__init__()

        self.config = config

        self.hidden_size = hidden_size
        self.action_size = action_size

        self.goal_encoder = GoalEncoder(hidden_size)
        self.vis_encoder = ImageConv(history_frame=config.history_frame, features_dim=hidden_size, if_bc=True)

        self.decision_maker = DecisionMaker(
            feature_size=hidden_size,
            action_size=action_size,
            lstm_hidden_size=lstm_hidden_size,
            lstm_layers=lstm_layers,
        )

        # NOTE(review): VideoEncoder sizes itself from config.hidden_size, while
        # goal_encoder uses hidden_size; get_video_lang_emb only yields
        # same-width embeddings when the two agree — confirm at call sites.
        self.video_encoder = VideoEncoder(config.device, config)

    def forward(self, states, goals, hidden_states=None, cell_states=None, single_step=False):
        """Predict per-step action logits for a batch of trajectories.

        Args:
            states: (batch, seq_len, ...) observations; every step is tiled
                ``config.history_frame`` times along a new axis 2 before encoding.
            goals: goal input in whatever format ``GoalEncoder`` accepts.
            hidden_states: optional LSTM hidden state from a previous call.
            cell_states: optional LSTM cell state from a previous call.
            single_step: forwarded to ``DecisionMaker``.

        Returns:
            ``(logits, hidden_states, cell_states)`` from ``DecisionMaker``.
        """
        goal_emb = self.goal_encoder(goals)

        batch_size = states.shape[0]
        seq_len = states.shape[1]

        states = states.unsqueeze(2).repeat(1, 1, self.config.history_frame, *[1] * (states.dim() - 2))
        # Match VideoEncoder's handling of the same ImageConv configuration:
        # force a (batch, seq_len, features) layout for the per-step LSTM input.
        # This is a no-op if the encoder already returns that shape.
        states_emb = self.vis_encoder(states).reshape(batch_size, seq_len, -1)

        logits, hidden_states, cell_states = self.decision_maker(
            states_emb, goal_emb, batch_size, seq_len, hidden_states, cell_states, single_step
        )
        return logits, hidden_states, cell_states

    def get_video_lang_emb(self, states_seq, attn_mask, goal):
        """Return ``(video_emb, lang_emb)`` for video/goal alignment.

        ``video_emb`` comes from ``VideoEncoder`` over ``states_seq`` restricted
        by ``attn_mask``; ``lang_emb`` is ``GoalEncoder``'s embedding of
        ``goal``. The same goal encoder is shared with ``forward``.
        """
        video_emb = self.video_encoder(states_seq, attn_mask)
        lang_emb = self.goal_encoder(goal)
        return video_emb, lang_emb

