import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import LazyBatchNorm1d
from networks.resnet import ResNetEncoder, ResNet18

import networks.CLIP.clip.clip as clip
from networks.network_seq import Transformer as TransformerSeq
# NOTE(review): hard-coded to GPU index 2 when CUDA is available; this
# module-level `device` is not referenced anywhere in the code visible here.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

import gym

from networks.networks_base_alfred import FiLM, ImageConv, StateActionEncoder, LayerNorm, GoalEncoder, GoalEncoderLSTM

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence


class PriorEncoder(nn.Module):
    """MLP mapping an (initial state, goal state) embedding pair to a lang_emb_size vector.

    The two state embeddings are concatenated on the feature axis and passed
    through three Linear layers, each followed by a ReLU.
    """

    def __init__(self, lang_emb_size, state_emb_size, hidden_size=1024, *args, **kwargs) -> None:
        """
        Args:
            lang_emb_size: output width.
            state_emb_size: width of each of the two input state embeddings.
            hidden_size: width of the two hidden layers.
            *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size
        pair_size = 2 * state_emb_size
        # Layers are created in the order fc_1 -> fc_2 -> fc_3 so seeded
        # initialisation and state_dict keys stay the same.
        self.fc_1 = nn.Linear(pair_size, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, hidden_size)
        self.fc_3 = nn.Linear(hidden_size, lang_emb_size)

    def forward(self, state_emb_init, state_emb_goal):
        """Embed a pair of (bs, state_emb_size) tensors into (bs, lang_emb_size).

        NOTE(review): the ReLU after fc_3 makes every output component
        non-negative. Confirm this is intended, since this output is returned
        next to the language embedding in GRIFNetwork.get_prior_lang_emb.
        """
        features = torch.cat([state_emb_init, state_emb_goal], dim=1)
        for layer in (self.fc_1, self.fc_2, self.fc_3):
            features = F.relu(layer(features))
        return features

class DecisionMaker(nn.Module):
    """Goal-conditioned recurrent policy head.

    Visual features are modulated by a goal embedding through two residual
    FiLM blocks, run through a stacked LSTM, and mapped to per-step action
    logits by a small MLP (ff1 -> ReLU -> ff2 -> ReLU -> ff3).
    """

    def __init__(self, feature_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2) -> None:
        """
        Args:
            feature_size: width of the visual/goal features fed to the FiLM
                blocks and the LSTM, and of the MLP hidden layers.
            action_size: number of output logits per time step.
            lstm_hidden_size: hidden width of the LSTM.
            lstm_layers: number of stacked LSTM layers.
        """
        super().__init__()

        self.action_size = action_size
        self.feature_size = feature_size

        self.lstm = nn.LSTM(
            input_size=feature_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            # LSTM dropout is applied between layers only, so it is a no-op
            # (and triggers a PyTorch warning) for a single-layer LSTM.
            dropout=0.2 if lstm_layers > 1 else 0.0,
        )

        # ff1 consumes the LSTM output, whose width is lstm_hidden_size.
        # With the defaults (512/512) parameter shapes are unchanged, so
        # existing checkpoints still load.
        self.ff1 = nn.Linear(lstm_hidden_size, feature_size)
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, action_size)

        # FiLM comes from networks_base_alfred; presumably a goal-conditioned
        # feature-wise affine modulation -- TODO confirm.
        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

        # norm1, norm2 and fc_lstm_to_action are not used in forward(). Removing
        # them would change the state_dict keys and break strict loading of
        # existing checkpoints. Note they receive no gradients, which matters
        # e.g. for DDP's find_unused_parameters.
        self.fc_lstm_to_action = nn.Linear(lstm_hidden_size, action_size)

    def forward(self, vis_emb, goal_emb, seq_len, hidden_states=None, cell_states=None, single_step=False):
        """Compute per-step action logits.

        Args:
            vis_emb: visual features, (bs, seq_len, feature_size), because the
                LSTM is batch_first.
            goal_emb: one goal embedding per sequence, assumed
                (bs, feature_size). It is repeated over the time axis.
            seq_len: number of time steps in vis_emb.
            hidden_states, cell_states: optional initial LSTM state. Only
                hidden_states is checked: when it is None the LSTM starts from
                its default (zero) state and cell_states is ignored.
            single_step: not read; kept for call-site compatibility.

        Returns:
            (logits, h_n, c_n): logits is (bs, seq_len, action_size); h_n and
            c_n are the final LSTM states.
        """
        goal_emb = goal_emb.unsqueeze(1).repeat(1, seq_len, 1)

        # Two residual FiLM blocks condition the visual features on the goal.
        x = self.film1(vis_emb, goal_emb) + vis_emb
        x = self.film2(x, goal_emb) + x

        # nn.LSTM treats hx=None exactly like omitting the initial state.
        initial_state = None if hidden_states is None else (hidden_states, cell_states)
        x, (hidden_states, cell_states) = self.lstm(x, initial_state)

        x = self.relu1(self.ff1(x))
        x = self.relu2(self.ff2(x))
        x = self.ff3(x)

        return x, hidden_states, cell_states


class GRIFNetwork(nn.Module):
    """Policy conditioned on either a language goal or a state-derived prior goal.

    A single DecisionMaker is shared by two conditioning branches:

    * the language branch uses ``GoalEncoder(goals)``;
    * the prior branch uses ``PriorEncoder`` applied to the visual features of
      the first and the final valid state of each sequence.

    Every entry point tiles each frame ``config.history_frame`` times before
    encoding. This way ``forward`` and ``get_action`` feed ImageConv the same
    input layout.
    """

    def __init__(self, config, hidden_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2):
        """
        Args:
            config: object exposing ``history_frame`` (the only field read here).
            hidden_size: width of the goal, prior and visual embeddings.
            action_size: number of action logits per time step.
            lstm_hidden_size: hidden width of the DecisionMaker LSTM.
            lstm_layers: number of stacked LSTM layers.
        """
        super().__init__()

        self.config = config

        self.hidden_size = hidden_size
        self.action_size = action_size

        self.goal_encoder = GoalEncoder(hidden_size)
        self.prior_encoder = PriorEncoder(state_emb_size=hidden_size, lang_emb_size=hidden_size)

        self.vis_encoder = ImageConv(history_frame=config.history_frame, features_dim=hidden_size, if_bc=True)

        self.decision_maker = DecisionMaker(
            feature_size=hidden_size,
            action_size=action_size,
            lstm_hidden_size=lstm_hidden_size,
            lstm_layers=lstm_layers,
        )

        # Learnable scalar returned (un-exponentiated) by get_prior_lang_emb;
        # presumably a contrastive temperature -- how it is applied is up to
        # the caller.
        self.logit_scale = nn.Parameter(torch.ones([]))

    def _stack_history(self, states):
        """Tile every frame ``config.history_frame`` times along a new axis 2.

        Assumes ``states`` is (bs, seq_len, *frame_shape) -- TODO confirm the
        frame layout ImageConv expects. Returns a tensor of shape
        (bs, seq_len, history_frame, *frame_shape).
        """
        return states.unsqueeze(2).repeat(1, 1, self.config.history_frame, *[1] * (states.dim() - 2))

    def _encode_states(self, states):
        """Encode a frame sequence into (bs, seq_len, -1) visual features."""
        bs, seq_len = states.shape[0], states.shape[1]
        return self.vis_encoder(self._stack_history(states)).view(bs, seq_len, -1)

    def _goal_prior_emb(self, vis_emb, episode_length):
        """Prior goal embedding from the first and final valid state of each sequence.

        Args:
            vis_emb: (bs, seq_len, D) visual features.
            episode_length: (bs,) number of valid steps per sequence. Entries
                must lie in [1, seq_len], otherwise ``torch.gather`` fails.
        """
        # Index of the last valid step. It is moved to vis_emb's device so
        # gather also works when the lengths tensor lives on the CPU.
        last_idx = episode_length.to(device=vis_emb.device, dtype=torch.int64) - 1
        state_emb_init = vis_emb[:, 0, :]
        index = last_idx.unsqueeze(-1).expand(-1, vis_emb.shape[-1]).unsqueeze(1)  # (bs, 1, D)
        state_emb_goal = torch.gather(vis_emb, 1, index).squeeze(1)
        return self.prior_encoder(state_emb_init, state_emb_goal)

    def forward(self, states, goals, episode_length, hidden_states=None, cell_states=None, single_step=False):
        """Run both conditioning branches over a batch of sequences.

        Args:
            states: frame sequences, (bs, seq_len, *frame_shape).
            goals: input for ``GoalEncoder`` (format defined there).
            episode_length: (bs,) number of valid steps per sequence. The state
                at index ``episode_length - 1`` is used as the goal state.
            hidden_states, cell_states: optional initial LSTM state, shared by
                both branches.
            single_step: forwarded to DecisionMaker, which does not read it.

        Returns:
            (lang_policy_output, prior_policy_output), each being the
            ``(logits, h_n, c_n)`` tuple returned by DecisionMaker.
        """
        seq_len = states.shape[1]
        lang_emb = self.goal_encoder(goals)

        vis_emb = self._encode_states(states)
        goal_prior_emb = self._goal_prior_emb(vis_emb, episode_length)

        lang_policy_output = self.decision_maker(vis_emb, lang_emb, seq_len, hidden_states, cell_states, single_step)
        prior_policy_output = self.decision_maker(vis_emb, goal_prior_emb, seq_len, hidden_states, cell_states, single_step)

        return lang_policy_output, prior_policy_output

    def get_prior_lang_emb(self, states_seq, episode_length, goal):
        """Return the embeddings used to align the prior and language goals.

        Args:
            states_seq: frame sequences, (bs, seq_len, *frame_shape).
            episode_length: (bs,) number of valid steps per sequence.
            goal: input for ``GoalEncoder``.

        Returns:
            (goal_prior_emb, lang_emb, logit_scale): the PriorEncoder output,
            the GoalEncoder output, and the raw ``logit_scale`` parameter.
        """
        lang_emb = self.goal_encoder(goal)
        vis_emb = self._encode_states(states_seq)
        goal_prior_emb = self._goal_prior_emb(vis_emb, episode_length)
        return goal_prior_emb, lang_emb, self.logit_scale

    def get_action(self, states_seq, goal, hidden_states=None, cell_states=None, single_step=True):
        """Language-conditioned action logits for inference.

        Args:
            states_seq: frame sequences, (bs, seq_len, *frame_shape).
            goal: input for ``GoalEncoder``.
            hidden_states, cell_states: optional LSTM state from a previous call.
            single_step: forwarded to DecisionMaker, which does not read it.

        Returns:
            (output, h_n, c_n) from DecisionMaker.
        """
        lang_emb = self.goal_encoder(goal)
        seq_len = states_seq.shape[1]
        # Tile frames to config.history_frame exactly as in forward(). A
        # singleton history axis only matches training when history_frame == 1.
        vis_emb = self._encode_states(states_seq)
        return self.decision_maker(vis_emb, lang_emb, seq_len, hidden_states, cell_states, single_step)