import torch
import torch.nn as nn
# Module-level compute device: CUDA device index 2 when CUDA is available,
# otherwise the CPU. Nothing in the visible part of this file reads it.
# NOTE(review): the GPU index is hard-coded and `is_available()` does not check
# that a device with index 2 exists — confirm on hosts with fewer than 3 GPUs.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

import gym

from networks.networks_base_alfred import FiLM, ImageConv, StateActionEncoder, LayerNorm, GoalEncoder, GoalEncoderLSTM

class DecisionMaker(nn.Module):
    """Goal-conditioned recurrent head that maps visual features to action logits.

    Pipeline (see ``forward``): two residual FiLM blocks modulate the visual
    embedding with the goal embedding, a batch-first LSTM runs over the
    sequence, and a three-layer MLP (Linear-ReLU-Linear-ReLU-Linear) projects
    the LSTM output to ``action_size`` logits.

    Args:
        feature_size: Width of the visual/goal embeddings and of the MLP
            hidden layers; also the LSTM input size.
        action_size: Number of output logits per time step.
        lstm_hidden_size: Hidden size of the LSTM.
        lstm_layers: Number of stacked LSTM layers.

    Note:
        ``norm1``, ``norm2`` and ``fc_lstm_to_action`` are registered but are
        not used by ``forward``. They are kept so that the parameter names in
        ``state_dict`` (and the order of random initialisation) stay the same
        for existing checkpoints and seeded runs.
    """

    def __init__(self, feature_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2) -> None:
        super().__init__()

        self.action_size = action_size
        self.feature_size = feature_size

        self.lstm = nn.LSTM(
            input_size=feature_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers and emits a
            # warning when a non-zero value is requested for a single layer.
            dropout=0.2 if lstm_layers > 1 else 0.0,
        )

        # ff1 consumes the LSTM output, whose last dimension is
        # lstm_hidden_size (not feature_size); sizing it from feature_size
        # broke any configuration where the two differ.
        self.ff1 = nn.Linear(lstm_hidden_size, feature_size)
        self.norm1 = nn.LayerNorm(feature_size)  # unused in forward; kept for checkpoint compatibility
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)  # unused in forward; kept for checkpoint compatibility
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, action_size)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

        # unused in forward; kept for checkpoint compatibility
        self.fc_lstm_to_action = nn.Linear(lstm_hidden_size, action_size)

    def forward(self, vis_emb, goal_emb, batch_size, seq_len, hidden_states=None, cell_states=None, single_step=False):
        """Compute action logits for a sequence of visual embeddings.

        Args:
            vis_emb: Visual embeddings fed to FiLM and then to the batch-first
                LSTM — presumably ``(batch, seq_len, feature_size)``; confirm
                against the visual encoder.
            goal_emb: Goal embedding with one row per batch element; it is
                given a sequence axis at dim 1 and tiled ``seq_len`` times —
                presumably ``(batch, feature_size)``.
            batch_size: Accepted for interface compatibility; not used here.
            seq_len: Number of time steps the goal embedding is tiled over.
            hidden_states: Optional initial LSTM hidden state. When ``None``
                the LSTM starts from its default (zero) state.
            cell_states: Optional initial LSTM cell state; only read when
                ``hidden_states`` is provided.
            single_step: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(logits, hidden_states, cell_states)`` where ``logits``
            has ``action_size`` entries on its last dimension for every LSTM
            output step, and the two states are the LSTM's final states.
        """
        # Broadcast the per-episode goal to every time step.
        goal_emb = goal_emb.unsqueeze(1).repeat(1, seq_len, 1)

        # Two residual FiLM blocks conditioned on the goal.
        x = self.film1(vis_emb, goal_emb) + vis_emb
        x = self.film2(x, goal_emb) + x

        if hidden_states is None:
            x, (hidden_states, cell_states) = self.lstm(x)
        else:
            x, (hidden_states, cell_states) = self.lstm(x, (hidden_states, cell_states))

        x = self.ff1(x)
        x = self.relu1(x)
        x = self.ff2(x)
        x = self.relu2(x)
        x = self.ff3(x)

        return x, hidden_states, cell_states


class BCNetwork(nn.Module):
    """Behaviour-cloning network: goal encoder + visual encoder + DecisionMaker.

    Args:
        config: Configuration object; only ``config.history_frame`` is read
            in this class.
        hidden_size: Embedding width shared by the goal encoder, the visual
            encoder (``features_dim``) and the decision maker
            (``feature_size``).
        action_size: Number of action logits produced per time step.
        lstm_hidden_size: Hidden size of the decision maker's LSTM.
        lstm_layers: Number of stacked LSTM layers in the decision maker.
    """

    def __init__(self, config, hidden_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2):
        super().__init__()

        self.config = config

        self.hidden_size = hidden_size
        self.action_size = action_size

        self.goal_encoder = GoalEncoder(hidden_size)

        self.visual_encoder = ImageConv(history_frame=config.history_frame, features_dim=hidden_size, if_bc=True)

        self.decision_maker = DecisionMaker(
            feature_size=hidden_size,
            action_size=action_size,
            lstm_hidden_size=lstm_hidden_size,
            lstm_layers=lstm_layers,
        )

    def forward(self, states, goals, hidden_states=None, cell_states=None, single_step=False):
        """Encode goals and observations, then predict action logits.

        Args:
            states: Observation tensor whose first two axes are read as
                ``(batch, seq_len)``; the layout of the remaining axes is
                whatever ``ImageConv`` expects (not visible here).
            goals: Goal input passed unchanged to the goal encoder.
            hidden_states: Optional initial LSTM hidden state, forwarded to
                the decision maker.
            cell_states: Optional initial LSTM cell state, forwarded to the
                decision maker.
            single_step: Forwarded to the decision maker.

        Returns:
            Tuple ``(logits, hidden_states, cell_states)`` as returned by the
            decision maker.
        """
        goal_emb = self.goal_encoder(goals)

        batch_size = states.shape[0]
        seq_len = states.shape[1]

        # Insert a history axis at dim 2 and tile the current frame
        # `history_frame` times along it; every other axis is left untouched.
        states = states.unsqueeze(2).repeat(1, 1, self.config.history_frame, *[1] * (states.dim() - 2))

        states_emb = self.visual_encoder(states)

        logits, hidden_states, cell_states = self.decision_maker(
            states_emb, goal_emb, batch_size, seq_len, hidden_states, cell_states, single_step
        )
        return logits, hidden_states, cell_states

