import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import LazyBatchNorm1d
from networks.resnet import ResNetEncoder, ResNet18

import networks.CLIP.clip.clip as clip
from networks.network_seq import Transformer as TransformerSeq

from networks.networks_base_alfred import FiLM, ImageConv, StateActionEncoder, LayerNorm, GoalEncoder, GoalEncoderLSTM

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalises in float32.

    The input is upcast to float32 before normalisation and the result is
    cast back to the caller's dtype, so half-precision (fp16) activations
    can pass through this layer.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = super().forward(x.to(torch.float32))
        return normed.to(input_dtype)
    
class QNetwork(nn.Module):
    """Goal-conditioned recurrent Q-head producing one Q-value per action per step.

    Pipeline: two residual FiLM blocks condition the visual embedding on the
    goal embedding, a batch-first LSTM runs over the sequence, and a
    three-layer MLP maps each LSTM output to ``action_size`` Q-values.
    """

    def __init__(self, feature_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2) -> None:
        super().__init__()

        self.feature_size = feature_size
        self.action_size = action_size

        self.lstm = nn.LSTM(
            input_size=feature_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=0.2,
        )

        # ff1 consumes the LSTM output, whose last dim is lstm_hidden_size
        # (identical to the previous layout when the two sizes are equal).
        self.ff1 = nn.Linear(lstm_hidden_size, feature_size)
        # norm1 / norm2 are not used in forward(); they stay registered so
        # existing checkpoints (state_dict keys) keep loading.
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, action_size)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, vis_emb, goal_emb, seq_len, hidden_states=None, cell_states=None):
        """Compute per-step Q-values.

        Args:
            vis_emb: visual embeddings; assumed (batch, seq_len, feature_size)
                since the LSTM is batch-first — TODO confirm with callers.
            goal_emb: goal embedding, presumably (batch, feature_size); it is
                tiled ``seq_len`` times along a new dim 1.
            seq_len: number of time steps to tile the goal embedding over.
            hidden_states, cell_states: optional initial LSTM state. When
                ``hidden_states`` is None the LSTM uses its default (zero)
                initial state.

        Returns:
            ``(q_values, hidden_states, cell_states)``: Q-values with last dim
            ``action_size`` and the LSTM's final ``(h, c)``.
        """
        goal_emb = goal_emb.unsqueeze(1).repeat(1, seq_len, 1)

        # Two residual FiLM blocks condition the visual stream on the goal.
        x = self.film1(vis_emb, goal_emb) + vis_emb
        x = self.film2(x, goal_emb) + x

        initial_state = None if hidden_states is None else (hidden_states, cell_states)
        x, (hidden_states, cell_states) = self.lstm(x, initial_state)

        x = self.relu1(self.ff1(x))
        x = self.relu2(self.ff2(x))
        x = self.ff3(x)

        return x, hidden_states, cell_states
    
class ValueNetwork(nn.Module):
    """Goal-conditioned recurrent state-value head producing one scalar per step.

    Pipeline: two residual FiLM blocks condition the visual embedding on the
    goal embedding, a batch-first LSTM runs over the sequence, and a
    three-layer MLP maps each LSTM output to a single value.
    """

    def __init__(self, feature_size=512, lstm_hidden_size=512, lstm_layers=2) -> None:
        super().__init__()

        self.feature_size = feature_size

        self.lstm = nn.LSTM(
            input_size=feature_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=0.2,
        )

        # ff1 consumes the LSTM output, whose last dim is lstm_hidden_size
        # (identical to the previous layout when the two sizes are equal).
        self.ff1 = nn.Linear(lstm_hidden_size, feature_size)
        # norm1 / norm2 are not used in forward(); they stay registered so
        # existing checkpoints (state_dict keys) keep loading.
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, 1)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, vis_emb, goal_emb, seq_len, hidden_states=None, cell_states=None):
        """Compute per-step state values.

        Args:
            vis_emb: visual embeddings; assumed (batch, seq_len, feature_size)
                since the LSTM is batch-first — TODO confirm with callers.
            goal_emb: goal embedding, presumably (batch, feature_size); it is
                tiled ``seq_len`` times along a new dim 1.
            seq_len: number of time steps to tile the goal embedding over.
            hidden_states, cell_states: optional initial LSTM state. When
                ``hidden_states`` is None the LSTM uses its default (zero)
                initial state.

        Returns:
            ``(values, hidden_states, cell_states)``: values with last dim 1
            and the LSTM's final ``(h, c)``.
        """
        goal_emb = goal_emb.unsqueeze(1).repeat(1, seq_len, 1)

        # Two residual FiLM blocks condition the visual stream on the goal.
        x = self.film1(vis_emb, goal_emb) + vis_emb
        x = self.film2(x, goal_emb) + x

        initial_state = None if hidden_states is None else (hidden_states, cell_states)
        x, (hidden_states, cell_states) = self.lstm(x, initial_state)

        x = self.relu1(self.ff1(x))
        x = self.relu2(self.ff2(x))
        x = self.ff3(x)

        return x, hidden_states, cell_states

class ActorNetwork(nn.Module):
    """Goal-conditioned recurrent policy head producing ``action_size`` outputs per step.

    Pipeline: two residual FiLM blocks condition the visual embedding on the
    goal embedding, a batch-first LSTM runs over the sequence, and a
    three-layer MLP maps each LSTM output to ``action_size`` raw scores
    (no softmax is applied here).
    """

    def __init__(self, feature_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2) -> None:
        super().__init__()

        self.action_size = action_size
        self.feature_size = feature_size

        self.lstm = nn.LSTM(
            input_size=feature_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=0.2,
        )

        # ff1 consumes the LSTM output, whose last dim is lstm_hidden_size
        # (identical to the previous layout when the two sizes are equal).
        self.ff1 = nn.Linear(lstm_hidden_size, feature_size)
        # norm1 / norm2 are not used in forward(); they stay registered so
        # existing checkpoints (state_dict keys) keep loading.
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, action_size)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, vis_emb, goal_emb, seq_len, hidden_states=None, cell_states=None):
        """Compute per-step action scores.

        Args:
            vis_emb: visual embeddings; assumed (batch, seq_len, feature_size)
                since the LSTM is batch-first — TODO confirm with callers.
            goal_emb: goal embedding, presumably (batch, feature_size); it is
                tiled ``seq_len`` times along a new dim 1.
            seq_len: number of time steps to tile the goal embedding over.
            hidden_states, cell_states: optional initial LSTM state. When
                ``hidden_states`` is None the LSTM uses its default (zero)
                initial state.

        Returns:
            ``(scores, hidden_states, cell_states)``: scores with last dim
            ``action_size`` and the LSTM's final ``(h, c)``.
        """
        goal_emb = goal_emb.unsqueeze(1).repeat(1, seq_len, 1)

        # Two residual FiLM blocks condition the visual stream on the goal.
        x = self.film1(vis_emb, goal_emb) + vis_emb
        x = self.film2(x, goal_emb) + x

        initial_state = None if hidden_states is None else (hidden_states, cell_states)
        x, (hidden_states, cell_states) = self.lstm(x, initial_state)

        x = self.relu1(self.ff1(x))
        x = self.relu2(self.ff2(x))
        x = self.ff3(x)

        return x, hidden_states, cell_states
    
class IQLNetwork(nn.Module):
    """Goal-conditioned IQL model built from shared encoders and recurrent heads.

    A goal encoder and a visual encoder feed six heads:
    - two online Q-networks;
    - two target Q-networks, initialised as copies of the online ones and
      kept in eval mode;
    - a state-value network;
    - an actor network.
    """

    def __init__(self, config, hidden_size=512, action_size=7, lstm_hidden_size=512, lstm_layers=2) -> None:
        super().__init__()

        self.config = config

        self.hidden_size = hidden_size
        self.action_size = action_size

        self.goal_encoder = GoalEncoder(hidden_size)

        self.visual_encoder = ImageConv(history_frame=config.history_frame, features_dim=hidden_size, if_bc=True)

        q_kwargs = dict(
            feature_size=hidden_size,
            action_size=action_size,
            lstm_hidden_size=lstm_hidden_size,
            lstm_layers=lstm_layers,
        )
        # Construction order is kept as before so parameter-init RNG draws
        # and state_dict key order are unchanged.
        self.q_network_1 = QNetwork(**q_kwargs)
        self.q_network_2 = QNetwork(**q_kwargs)
        self.q_target_network_1 = QNetwork(**q_kwargs)
        self.q_target_network_2 = QNetwork(**q_kwargs)

        # Targets start as exact copies of the online Q-networks.
        self.q_target_network_1.load_state_dict(self.q_network_1.state_dict())
        self.q_target_network_2.load_state_dict(self.q_network_2.state_dict())

        self.q_target_network_1.eval()
        self.q_target_network_2.eval()

        self.value_network = ValueNetwork(
            feature_size=hidden_size,
            lstm_hidden_size=lstm_hidden_size,
            lstm_layers=lstm_layers)

        self.actor_network = ActorNetwork(
            feature_size=hidden_size,
            action_size=action_size,
            lstm_hidden_size=lstm_hidden_size,
            lstm_layers=lstm_layers)

    def train(self, mode=True):
        """Set train/eval mode on this module while keeping target Q-networks in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        undo the ``eval()`` calls made on the targets in ``__init__`` and
        re-enable the dropout inside their LSTMs.
        """
        super().train(mode)
        self.q_target_network_1.eval()
        self.q_target_network_2.eval()
        return self

    def _encode(self, states, goals):
        """Encode goals and observations shared by all heads.

        Each per-step observation is repeated ``config.history_frame`` times
        along a new dim 2 before the visual encoder. Assumes ``states`` is
        (batch, seq_len, ...) — dim 1 is used as ``seq_len``; TODO confirm.

        Returns:
            ``(state_emb, goal_emb, seq_len)``.
        """
        seq_len = states.shape[1]
        goal_emb = self.goal_encoder(goals)
        frames = states.unsqueeze(2).repeat(1, 1, self.config.history_frame, *[1] * (states.dim() - 2))
        state_emb = self.visual_encoder(frames)
        return state_emb, goal_emb, seq_len

    def get_state_values(self, state, goal, hidden_states=None, cell_states=None):
        """Return the value network's per-step output for ``state`` under ``goal``."""
        state_emb, goal_emb, seq_len = self._encode(state, goal)
        values, _, _ = self.value_network(state_emb, goal_emb, seq_len, hidden_states, cell_states)
        return values

    def get_q_values(self, states, goal, hidden_states=None, cell_states=None, target=False):
        """Return ``(q_values_1, q_values_2)`` from the twin Q-networks.

        When ``target`` is True, the target Q-networks are used instead of
        the online ones.
        """
        state_emb, goal_emb, seq_len = self._encode(states, goal)
        if target:
            q_net_1, q_net_2 = self.q_target_network_1, self.q_target_network_2
        else:
            q_net_1, q_net_2 = self.q_network_1, self.q_network_2
        q_values_1, _, _ = q_net_1(state_emb, goal_emb, seq_len, hidden_states, cell_states)
        q_values_2, _, _ = q_net_2(state_emb, goal_emb, seq_len, hidden_states, cell_states)
        return q_values_1, q_values_2

    def get_policy(self, state, goal, hidden_states=None, cell_states=None):
        """Return the actor network's per-step action scores for ``state`` under ``goal``."""
        state_emb, goal_emb, seq_len = self._encode(state, goal)
        action, _, _ = self.actor_network(state_emb, goal_emb, seq_len, hidden_states, cell_states)
        return action

    def get_seq_emb(self, states_seq, actions_seq, attn_mask):
        """Gather the per-step features of the chosen action and run them through ``self.lstm``.

        ``states_seq`` is viewed as (batch, seq_len, action_size, hidden_size).
        At each step, the slice selected by ``argmax(actions_seq, dim=2)`` is
        gathered. The resulting padded sequence is then packed using
        ``attn_mask.sum(1)`` as per-sequence lengths; presumably ``attn_mask``
        is a (batch, seq_len) 0/1 mask — confirm.

        NOTE(review): ``__init__`` never assigns ``self.lstm``, so this method
        raises AttributeError unless an LSTM is attached elsewhere. Confirm
        whether this method is still in use.
        """
        actions_seq = torch.argmax(actions_seq, dim=2, keepdim=True)
        index = actions_seq[..., None].expand(*(actions_seq.shape[:2]), 1, self.hidden_size)

        states_seq_ = states_seq.view(*(states_seq.shape[:2]), self.action_size, self.hidden_size)
        states_actions_seq = states_seq_.gather(dim=2, index=index).squeeze(2)

        # pack_padded_sequence requires the lengths on the CPU.
        lengths = attn_mask.sum(1).to('cpu')
        packed_input = pack_padded_sequence(states_actions_seq, lengths, batch_first=True, enforce_sorted=False)
        output, _ = self.lstm(packed_input)

        output, _ = pad_packed_sequence(output, batch_first=True)

        return output

    def forward(self, states, goals, hidden_states=None, cell_states=None, return_hidden=False):
        """Run the online Q-networks, the value network and the actor on one batch.

        Returns:
            ``((q_values_1, q_values_2), value, action)``. When
            ``return_hidden`` is True, the actor's final LSTM
            ``hidden_states`` and ``cell_states`` are appended.
        """
        state_emb, goal_emb, seq_len = self._encode(states, goals)

        # NOTE(review): every head receives the same incoming
        # hidden_states/cell_states, but only the actor's new state is
        # returned. Confirm this sharing is intended.
        q_values_1, _, _ = self.q_network_1(state_emb, goal_emb, seq_len, hidden_states, cell_states)
        q_values_2, _, _ = self.q_network_2(state_emb, goal_emb, seq_len, hidden_states, cell_states)

        value, _, _ = self.value_network(state_emb, goal_emb, seq_len, hidden_states, cell_states)
        action, hidden_states, cell_states = self.actor_network(state_emb, goal_emb, seq_len, hidden_states, cell_states)

        if return_hidden:
            return (q_values_1, q_values_2), value, action, hidden_states, cell_states
        return (q_values_1, q_values_2), value, action
  