import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import LazyBatchNorm1d

from networks.network_seq import Transformer as TransformerSeq

import gym

from ALFRED.networks.networks_base_alfred import FiLM, ImageConv, StateActionEncoder, LayerNorm, GoalEncoder, GoalEncoderLSTM

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

class DecisionMaker(nn.Module):
    """Per-action scalar Q head conditioned on a goal embedding through FiLM.

    The incoming embedding is reshaped so each of ``action_size`` candidate
    actions owns a ``feature_size`` vector. Two residual FiLM layers inject the
    goal, then an MLP (ff1 -> relu -> ff2 -> relu -> ff3) maps every action's
    vector to one scalar.
    """

    def __init__(self, feature_size=512, action_size=15) -> None:
        super().__init__()
        self.feature_size = feature_size
        self.action_size = action_size
        # NOTE: construction order is kept stable on purpose: it fixes the
        # parameter order (optimizer checkpoints) and the seeded init values.
        self.ff1 = nn.Linear(feature_size, feature_size)
        # norm1/norm2 are never applied in forward(); they stay registered so
        # existing state_dict checkpoints still load.
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, 1)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, state_action_emb, goal, h_c=None):
        """Score every action.

        Args:
            state_action_emb: Tensor whose first two dims are kept and whose
                remaining elements are viewed as ``(action_size, feature_size)``.
            goal: Tensor whose first two dims match ``state_action_emb``; it is
                broadcast across the action axis.
            h_c: Passed through untouched (kept for API parity with the
                recurrent heads).

        Returns:
            ``(q_values, h_c)`` where ``q_values`` has the trailing singleton
            dim of ``ff3`` squeezed away.
        """
        n_actions = self.action_size
        per_action = state_action_emb.view(*state_action_emb.shape[:2], n_actions, self.feature_size)
        goal_per_action = goal.unsqueeze(2).expand(*goal.shape[:2], n_actions, goal.shape[-1])

        hidden = per_action
        for film in (self.film1, self.film2):
            hidden = film(hidden, goal_per_action) + hidden

        for linear, activation in ((self.ff1, self.relu1), (self.ff2, self.relu2)):
            hidden = activation(linear(hidden))

        q_values = self.ff3(hidden).squeeze(-1)
        return q_values, h_c

class DecisionMakerLSTM(nn.Module):
    """Per-action scalar Q head conditioned on a goal (FiLM) and a history state.

    Same as ``DecisionMaker`` except that, before the MLP, the history tensor
    ``ht`` is broadcast over the action axis and concatenated in front of each
    action's FiLM-mixed features (hence ``ff1`` takes ``2 * feature_size``).
    """

    def __init__(self, feature_size=512, action_size=15) -> None:
        super().__init__()
        self.feature_size = feature_size
        self.action_size = action_size
        # Registration order intentionally unchanged (parameter order / seeded init).
        self.ff1 = nn.Linear(2 * feature_size, feature_size)
        # norm1/norm2 are not used in forward(); kept for checkpoint compatibility.
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, 1)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, state_action_emb, goal, ht):
        """Score every action given the goal and the history tensor ``ht``.

        Args:
            state_action_emb: Viewed as ``(*shape[:2], action_size, feature_size)``.
            goal: First two dims match ``state_action_emb``; broadcast over actions.
            ht: First two dims match ``state_action_emb``; its last dim plus
                ``feature_size`` must equal ``ff1``'s input width.

        Returns:
            Q-values with the trailing singleton dim of ``ff3`` squeezed away.
        """
        n_actions = self.action_size
        per_action = state_action_emb.view(*state_action_emb.shape[:2], n_actions, self.feature_size)
        goal_per_action = goal.unsqueeze(2).expand(*goal.shape[:2], n_actions, goal.shape[-1])

        mixed = per_action
        for film in (self.film1, self.film2):
            mixed = film(mixed, goal_per_action) + mixed

        # History first, features second: ff1's weight layout depends on this order.
        history = ht.unsqueeze(2).expand(*ht.shape[:2], n_actions, ht.shape[-1])
        hidden = self.relu1(self.ff1(torch.cat([history, mixed], dim=-1)))
        hidden = self.relu2(self.ff2(hidden))
        return self.ff3(hidden).squeeze(-1)

class DecisionMakerC51(nn.Module):
    """Distributional (categorical) per-action head conditioned on a goal via FiLM.

    Like ``DecisionMaker``, but every action receives a softmax distribution
    over ``n_atoms`` atoms instead of a single scalar.

    Args:
        feature_size: Width of each per-action embedding.
        action_size: Number of candidate actions packed into the input.
        n_atoms: Number of distribution atoms. Defaults to 101, the value this
            class always used. It must match the length of the atom support the
            caller multiplies the probabilities with.
    """

    def __init__(self, feature_size=512, action_size=15, n_atoms=101) -> None:
        super().__init__()
        self.n_atoms = n_atoms
        self.feature_size = feature_size
        self.action_size = action_size

        # NOTE: norm1/norm2/norm3 and relu3 are never used in forward(), and
        # norm2/norm3 do not even match the widths of ff2/ff3. They stay
        # registered, in this order, so state_dict checkpoints and parameter
        # order remain compatible.
        self.ff1 = nn.Linear(feature_size, feature_size)
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, int(feature_size/2))
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(int(feature_size/2), self.n_atoms)
        self.norm3 = nn.LayerNorm(feature_size)
        self.relu3 = nn.ReLU()

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, state_action_emb, goal, h_c=None):
        """Return per-action atom probabilities.

        Args:
            state_action_emb: Viewed as ``(*shape[:2], action_size, feature_size)``.
            goal: First two dims match ``state_action_emb``; broadcast over actions.
            h_c: Passed through untouched (API parity with the recurrent heads).

        Returns:
            ``(probs, h_c)`` where ``probs`` ends in ``(action_size, n_atoms)``
            and is softmax-normalised over the last (atom) dim.
        """
        state_action_emb = state_action_emb.view(*(state_action_emb.shape[:2]), self.action_size, self.feature_size)
        goal = goal.unsqueeze(2).expand(*(goal.shape[:2]), self.action_size, goal.shape[-1])

        x = self.film1(state_action_emb, goal) + state_action_emb
        x = self.film2(x, goal) + x

        x = self.relu1(self.ff1(x))
        x = self.relu2(self.ff2(x))
        # ff3 already yields (..., action_size, n_atoms); the former view to
        # that same shape was a no-op and has been dropped.
        logits = self.ff3(x)

        return F.softmax(logits, dim=-1), h_c

class DecisionMakerC51LSTM(nn.Module):
    """Distributional per-action head conditioned on a goal (FiLM) and history ``ht``.

    Like ``DecisionMakerC51``, but ``ht`` is broadcast over the action axis
    and concatenated in front of each action's FiLM-mixed features before the
    MLP (hence ``ff1`` takes ``2 * feature_size``).

    Args:
        feature_size: Width of each per-action embedding.
        action_size: Number of candidate actions packed into the input.
        n_atoms: Number of distribution atoms. Defaults to 51, the value this
            class always used. It must match the length of the atom support the
            caller multiplies the probabilities with.
    """

    def __init__(self, feature_size=512, action_size=15, n_atoms=51) -> None:
        super().__init__()
        self.n_atoms = n_atoms
        self.feature_size = feature_size
        self.action_size = action_size

        # NOTE: norm1/norm2/norm3 and relu3 are unused in forward() (norm2/norm3
        # don't match ff2/ff3 widths); kept, in order, for checkpoint compatibility.
        self.ff1 = nn.Linear(2 * feature_size, feature_size)
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, int(feature_size/2))
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(int(feature_size/2), self.n_atoms)
        self.norm3 = nn.LayerNorm(feature_size)
        self.relu3 = nn.ReLU()

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, state_action_emb, goal, ht):
        """Return per-action atom probabilities given goal and history.

        Args:
            state_action_emb: Viewed as ``(*shape[:2], action_size, feature_size)``.
            goal: First two dims match ``state_action_emb``; broadcast over actions.
            ht: First two dims match ``state_action_emb``; broadcast over actions.

        Returns:
            Probabilities ending in ``(action_size, n_atoms)``, softmax-normalised
            over the atom dim.
        """
        state_action_emb = state_action_emb.view(*(state_action_emb.shape[:2]), self.action_size, self.feature_size)
        goal = goal.unsqueeze(2).expand(*(goal.shape[:2]), self.action_size, goal.shape[-1])

        # 4-D from here on: (*leading two dims, action_size, feature_size).
        x = self.film1(state_action_emb, goal) + state_action_emb
        x = self.film2(x, goal) + x

        # History first, features second: ff1's weight layout depends on this order.
        history = ht.unsqueeze(2).expand(*(ht.shape[:2]), self.action_size, ht.shape[-1])
        x = self.relu1(self.ff1(torch.cat([history, x], dim=-1)))
        x = self.relu2(self.ff2(x))
        # ff3 already yields (..., action_size, n_atoms); no reshape is needed.
        logits = self.ff3(x)

        return F.softmax(logits, dim=-1)


class Network(nn.Module):
    """Two online Q heads plus target copies, with an LSTM history encoder.

    The head class is picked from the ``DecisionMaker*`` family according to
    whether ``config.model_type`` contains ``'C51'`` and whether
    ``config.LSTM`` is set. ``self.lstm`` encodes the sequence of chosen
    state/action embeddings into the history tensor ``ht`` that the LSTM head
    variants consume.

    Args:
        atoms: Atom support; multiplied with C51 head probabilities in
            ``get_q_values`` (presumably a 1-D tensor whose length equals the
            head's ``n_atoms`` — confirm against the caller).
        action_size: Number of discrete actions. It is also forwarded to the
            Q heads so that every reshape in this class agrees on it.
        config: Must provide ``device``, ``history_frame``, ``model_type`` and
            ``LSTM``. ``get_Q_value`` also reads ``train_state_q`` and
            ``train_goal_q``.
        hidden_size: Width of all embeddings.
    """

    def __init__(self, atoms, action_size, config, hidden_size=64):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.atoms = atoms
        self.action_size = action_size
        self.device = torch.device(config.device)

        self.goal_encoder = GoalEncoder(hidden_size)
        self.visual_encoder = ImageConv(history_frame=config.history_frame, features_dim=hidden_size)

        self.align_net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

        if 'C51' in config.model_type and config.LSTM:
            Q_Net = DecisionMakerC51LSTM
        elif 'C51' not in config.model_type and config.LSTM:
            Q_Net = DecisionMakerLSTM
        elif 'C51' in config.model_type and not config.LSTM:
            Q_Net = DecisionMakerC51
        else:
            Q_Net = DecisionMaker

        # The heads reshape their input with their own action_size, while
        # get_seq_emb/get_next_ht reshape the same tensors with
        # self.action_size. Pass it through so the two agree; previously the
        # heads silently used their default of 15.
        self.q_net_1 = Q_Net(hidden_size, action_size)
        self.q_net_2 = Q_Net(hidden_size, action_size)

        self.q_target_net_1 = Q_Net(hidden_size, action_size)
        self.q_target_net_2 = Q_Net(hidden_size, action_size)

        self.q_target_net_1.load_state_dict(self.q_net_1.state_dict())
        self.q_target_net_2.load_state_dict(self.q_net_2.state_dict())

        # NOTE(review): a later .train() call on this module switches the
        # targets back to train mode as well.
        self.q_target_net_1.eval()
        self.q_target_net_2.eval()

        embed_dim = hidden_size
        transformer_width = hidden_size
        transformer_heads = transformer_width // 64
        transformer_layers = 12

        self.transformerClip = TransformerSeq(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
        )

        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=2, batch_first=True)

        scale = transformer_width ** -0.5
        self.proj = nn.Parameter(scale * torch.randn(transformer_width, embed_dim))
        self.frame_position_embeddings = nn.Embedding(300, hidden_size)
        self.logit_scale = nn.Parameter(torch.ones([]))
        self.ln_post = LayerNorm(transformer_width)

        self.init_out = nn.Parameter(torch.randn(1, 1, embed_dim))

    def update(self):
        """Hard-copy the online Q heads' weights into the target heads."""
        self.q_target_net_1.load_state_dict(self.q_net_1.state_dict())
        self.q_target_net_2.load_state_dict(self.q_net_2.state_dict())

    def get_policy(self, states, goals):
        """Run the actor on encoded states/goals.

        NOTE(review): ``self.actor_network`` is never created in ``__init__``,
        so this raises AttributeError unless it is assigned elsewhere.
        """
        lang_emb = self.goal_encoder(goals)
        vis_emb = self.visual_encoder(states)
        action = self.actor_network(vis_emb, lang_emb)
        return action

    def get_init_out(self):
        """Return the learned ``(1, 1, hidden_size)`` initial output parameter."""
        return self.init_out

    def get_next_ht(self, states_seq, action, ht):
        """Advance the LSTM by one step using the embedding of ``action``.

        Args:
            states_seq: Last dim is viewed as ``(action_size, hidden_size)``.
            action: Converted with ``torch.tensor`` and unsqueezed to ``(1, 1)``,
                so presumably a single int action for batch 1, seq 1 — confirm.
            ht: LSTM state tuple passed straight to ``self.lstm`` (or None).

        Returns:
            ``(out, ht)`` as returned by ``nn.LSTM``.
        """
        states_seq_ = states_seq.view(*(states_seq.shape[:2]), self.action_size, self.hidden_size)

        action = torch.tensor(action, device=self.device).unsqueeze(0).unsqueeze(0)
        action = action[..., None].expand(*(action.shape[:2]), 1, self.hidden_size)

        # Pick the embedding slice belonging to the chosen action.
        states_actions_seq = states_seq_.gather(dim=2, index=action).squeeze(2)

        # Compact the LSTM weights once (avoids cuDNN non-contiguous warnings).
        if not getattr(self, '_flattened', False):
            self.lstm.flatten_parameters()
            self._flattened = True

        out, ht = self.lstm(states_actions_seq, ht)
        return out, ht

    @staticmethod
    def _head_output(result):
        """Drop the pass-through ``h_c`` that non-LSTM heads return as ``(out, h_c)``."""
        return result[0] if isinstance(result, tuple) else result

    def get_q_values(self, states_seq, goals_emb, ht, target=False):
        """Evaluate both online (or target) Q heads.

        Args:
            states_seq: Per-action embeddings; the heads view its trailing
                elements as ``(action_size, hidden_size)``.
            goals_emb: Goal embeddings aligned with the first two dims of
                ``states_seq``.
            ht: History tensor for LSTM heads (passed through by the others).
            target: Use the target heads instead of the online ones.

        Returns:
            For C51 models, ``(q_1, q_2, probs_1, probs_2)``, where each q is the
            expectation ``sum(atoms * probs)`` over atoms. Otherwise ``(q_1, q_2)``.
            Non-LSTM heads' ``(out, h_c)`` tuples are unwrapped, so every config
            returns tensors. Previously the C51 non-LSTM path crashed on
            ``atoms * tuple``.
        """
        if target:
            head_1, head_2 = self.q_target_net_1, self.q_target_net_2
        else:
            head_1, head_2 = self.q_net_1, self.q_net_2

        q_values_1 = self._head_output(head_1(states_seq, goals_emb, ht))
        q_values_2 = self._head_output(head_2(states_seq, goals_emb, ht))

        if 'C51' in self.config.model_type:
            probs_1, probs_2 = q_values_1, q_values_2
            q_values_1 = (self.atoms * probs_1).sum(dim=-1)
            q_values_2 = (self.atoms * probs_2).sum(dim=-1)

            return q_values_1, q_values_2, probs_1, probs_2
        else:
            return q_values_1, q_values_2

    def get_state_values(self, states, goals):
        """Evaluate a value head on encoded states/goals.

        NOTE(review): ``self.value_network`` is never created in ``__init__``,
        so this raises AttributeError unless it is assigned elsewhere.
        """
        goal_emb = self.goal_encoder(goals)
        states_emb = self.visual_encoder(states)
        values = self.value_network(states_emb, goal_emb)

        return values

    def get_seq_emb(self, states_seq, actions_seq, attn_mask):
        """Run the LSTM over the embeddings of the actions actually taken.

        Args:
            states_seq: ``(B, S, action_size * hidden_size)``-compatible tensor.
            actions_seq: Per-step action scores/one-hots; argmax over dim 2
                selects the action.
            attn_mask: ``(B, S)`` mask whose row sums give the valid lengths.

        Returns:
            ``(B, S, hidden_size)`` LSTM outputs, zero-padded past each length.
            ``total_length`` keeps ``S`` even when every row is padded, so the
            output still aligns with ``states_seq`` in the LSTM heads.
        """
        actions_seq = torch.argmax(actions_seq, dim=2, keepdim=True)
        actions_seq = actions_seq[..., None].expand(*(actions_seq.shape[:2]), 1, self.hidden_size)

        states_seq_ = states_seq.view(*(states_seq.shape[:2]), self.action_size, self.hidden_size)
        states_actions_seq = states_seq_.gather(dim=2, index=actions_seq).squeeze(2)

        # Lengths must live on the CPU for pack_padded_sequence.
        packed_input = pack_padded_sequence(states_actions_seq, attn_mask.sum(1).to('cpu'), batch_first=True, enforce_sorted=False)
        output, _ = self.lstm(packed_input)

        output, _ = pad_packed_sequence(output, batch_first=True, total_length=states_actions_seq.shape[1])

        return output

    def get_Q_value(self, states, action, goals):
        """Legacy Q evaluation path.

        NOTE(review): ``self.state_action_encoder`` and ``self.decision_maker``
        are never created in ``__init__``, and the goal width is hard-coded to
        1024 here. This path raises AttributeError unless those attributes are
        assigned elsewhere.
        """
        state_action_emb, action_emb, action_para = self.state_action_encoder(states, action)

        goals = goals.repeat(1, self.action_size, 1).view(goals.shape[0] * self.action_size, 1024)

        goals = self.goal_encoder(goals)

        if not self.config.train_state_q:
            state_action_emb = state_action_emb.detach()
        if not self.config.train_goal_q:
            goals = goals.detach()

        if 'C51' in self.config.model_type:
            probs = self.decision_maker(state_action_emb, goals, action_emb, action_para)
            q_values = (self.atoms * probs).sum(dim=-1)
            return q_values, probs
        else:
            q_values = self.decision_maker(state_action_emb, goals, action_emb, action_para)
            return q_values, None

    def forward(self, states, goals):
        """Score every action for each state by tiling states over actions.

        Relies on ``get_Q_value`` and so has the same missing-attribute caveat.
        """
        if len(goals.shape) == 2:
            goals = goals.unsqueeze(1)
        if len(goals.shape) == 1:
            goals = goals.unsqueeze(0)

        bs = states.shape[0]
        k = self.action_size
        # The repeat pattern assumes `states` is 5-D — confirm against callers.
        states = states.unsqueeze(1).repeat(1, k, 1, 1, 1, 1).view(bs * k, *states.shape[1:])

        actions = torch.arange(k, device=self.device)
        duplicated_actions = actions.unsqueeze(0).repeat(bs, 1).view(bs * k, 1)

        q_values, probs = self.get_Q_value(states, duplicated_actions, goals)
        if 'C51' not in self.config.model_type:
            q_values = q_values.view(bs, k)

        return q_values, probs

class ActorNetwork(nn.Module):
    """Goal-conditioned policy head producing one logit per action.

    Two residual FiLM layers inject the goal into the state embedding, then an
    MLP (ff1 -> relu -> ff2 -> relu -> ff3) maps it to ``action_size`` outputs.

    Args:
        feature_size: Width of the state and goal embeddings.
        action_size: Number of output logits. Previously this argument was
            stored but ignored (the output width was hard-coded to 7); the
            default of 7 keeps existing behavior.
    """

    def __init__(self, feature_size=512, action_size=7) -> None:
        super().__init__()
        self.action_size = action_size
        # Registration order kept stable (parameter order / seeded init).
        self.ff1 = nn.Linear(feature_size, feature_size)
        # norm1/norm2 are unused in forward(); kept for checkpoint compatibility.
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, feature_size)
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(feature_size, action_size)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, state_action_emb, goal):
        """Return raw (unnormalised) action outputs of width ``action_size``."""
        x = self.film1(state_action_emb, goal) + state_action_emb
        x = self.film2(x, goal) + x

        x = self.relu1(self.ff1(x))
        x = self.relu2(self.ff2(x))
        return self.ff3(x)
    
class ValueNetwork(nn.Module):
    """Goal-conditioned state-value head producing a single output.

    The goal is injected through two residual FiLM layers, then an MLP narrows
    ``feature_size -> feature_size -> feature_size/2 -> 1``.
    """

    def __init__(self, feature_size) -> None:
        super().__init__()
        # Registration order kept stable (parameter order / seeded init).
        self.ff1 = nn.Linear(feature_size, feature_size)
        # norm1/norm2 are not used in forward(); kept for checkpoint compatibility.
        self.norm1 = nn.LayerNorm(feature_size)
        self.relu1 = nn.ReLU()

        self.ff2 = nn.Linear(feature_size, int(feature_size/2))
        self.norm2 = nn.LayerNorm(feature_size)
        self.relu2 = nn.ReLU()

        self.ff3 = nn.Linear(int(feature_size/2), 1)

        self.film1 = FiLM(feature_size, feature_size)
        self.film2 = FiLM(feature_size, feature_size)

    def forward(self, state_emb, goal):
        """Return the value estimate (trailing dim of size 1) for ``state_emb``."""
        hidden = state_emb
        for film in (self.film1, self.film2):
            hidden = film(hidden, goal) + hidden

        for linear, activation in ((self.ff1, self.relu1), (self.ff2, self.relu2)):
            hidden = activation(linear(hidden))

        return self.ff3(hidden)