import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

import numpy as np
import torch as th

class MLPNaviEncoder(nn.Module):
    """Task-conditioned reward predictor for a two-agent navigation setting.

    Maps dynamic-dim inputs (state + joint actions) to a fixed-dim embedding
    through a two-layer hypernetwork whose weights are generated from a task
    representation, then decodes that embedding into a scalar reward.

    The input layout is hand-coded: the state must have exactly 6 features,
    of which only the first ``state_nf`` (4) are embedded, and the joint
    action consists of ``n_agents`` (2) blocks of ``action_nf`` (5) features.
    """

    def __init__(self, decomposer, args):
        """Build the input embeddings, hypernetworks and reward decoder.

        Args:
            decomposer: Not used by this class.
            args: Config object providing ``task_repre_dim``,
                ``state_latent_dim``, ``action_latent_dim``,
                ``input_latent_dim``, ``embed_dim`` and
                ``explainer_hypernet_embed``.
        """
        super().__init__()
        self.args = args
        self.task_repre_dim = args.task_repre_dim
        self.state_latent_dim = args.state_latent_dim
        self.action_latent_dim = args.action_latent_dim
        self.input_latent_dim = args.input_latent_dim
        self.embed_dim = args.embed_dim
        # NOTE: the following sizes are hand-coded for this environment.
        self.state_nf = 4  # number of leading state features that are embedded
        self.n_agents = 2
        self.action_nf = 5  # action features per agent
        self.reward_dim = 1  # the decoder predicts a single reward value
        hypernet_embed = args.explainer_hypernet_embed

        # Input embeddings: state and joint action are embedded separately,
        # then fused into one latent vector of size input_latent_dim.
        self.state_embed = nn.Linear(self.state_nf, self.state_latent_dim)
        self.action_embed = nn.Linear(self.action_nf * self.n_agents, self.action_latent_dim)
        self.input_embed = nn.Linear(self.state_latent_dim + self.action_latent_dim, self.input_latent_dim)

        # Hypernetworks: generate per-task weights/bias from task_repre.
        self.hyper_w_1 = nn.Sequential(
            nn.Linear(self.task_repre_dim, hypernet_embed),
            nn.ReLU(),
            nn.Linear(hypernet_embed, self.embed_dim * self.input_latent_dim),
        )
        self.hyper_b_1 = nn.Linear(self.task_repre_dim, self.embed_dim)
        self.hyper_w_final = nn.Sequential(
            nn.Linear(self.task_repre_dim, hypernet_embed),
            nn.ReLU(),
            nn.Linear(hypernet_embed, self.embed_dim * self.embed_dim),
        )

        # Reward prediction head on top of the task-conditioned embedding.
        self.reward_decoder = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(),
            nn.Linear(self.embed_dim, self.reward_dim),
        )

    def _flatten_actions(self, actions):
        """Reshape ``actions`` to ``[bs, n_agents * feats]`` and return ``(actions, bs)``."""
        if actions.dim() == 3:
            # Layout [bs, n_agents, feats].
            if actions.shape[1] != self.n_agents:
                raise ValueError(
                    f"Invalid actions shape {actions.shape}: expected {self.n_agents} agents on dim 1"
                )
            bs = actions.shape[0]
        elif actions.dim() == 2:
            # Flattened per-agent layout [bs * n_agents, feats].
            if actions.shape[0] % self.n_agents != 0:
                raise ValueError(
                    f"Invalid actions shape {actions.shape}: "
                    f"dim 0 is not divisible by n_agents={self.n_agents}"
                )
            bs = actions.shape[0] // self.n_agents
        else:
            raise ValueError(f"Invalid actions shape {actions.shape}")
        return actions.reshape(bs, self.n_agents * actions.shape[-1]), bs

    def forward(self, obs, state, actions, task_repre):
        """Predict a reward for every transition in the batch.

        Args:
            obs: Not used by this method.
            state: Tensor of shape ``[bs, 6]``.
            actions: Either ``[bs, n_agents, action_nf]`` or the flattened
                per-agent layout ``[bs * n_agents, action_nf]``.
            task_repre: Task representation of shape ``[n, task_repre_dim]``
                (only the first row is used) or a single ``[task_repre_dim]``
                vector.

        Returns:
            Tuple ``(pred_reward, bs)``: ``pred_reward`` has shape
            ``[bs, reward_dim]`` (``reward_dim`` is 1) and ``bs`` is the
            batch size inferred from ``actions``.

        Raises:
            ValueError: If ``state`` or ``actions`` do not have the expected
                layout, or their batch sizes disagree.
        """
        if state.dim() != 2 or state.shape[-1] != 6:
            # We ensure two agents and one landmark exist in the environment.
            raise ValueError(f"Expected state of shape [bs, 6], got {tuple(state.shape)}")
        if task_repre.dim() == 1:
            task_repre = task_repre.unsqueeze(0)
        # A single task representation is shared by the whole batch.
        task_repre = task_repre[0:1]  # [1, task_repre_dim]

        actions, bs = self._flatten_actions(actions)
        if state.shape[0] != bs:
            raise ValueError(
                f"Batch size mismatch: state has {state.shape[0]} rows, actions imply {bs}"
            )

        # Embed the leading state features and the joint action, then fuse.
        state_latent = self.state_embed(state[:, :self.state_nf])  # [bs, state_latent_dim]
        action_latent = self.action_embed(actions)  # [bs, action_latent_dim]
        input_latent = self.input_embed(th.cat([state_latent, action_latent], dim=-1))  # [bs, input_latent_dim]

        # First task-conditioned layer: the generated weights have a leading
        # dim of 1 and broadcast over the batch inside matmul.
        w1 = self.hyper_w_1(task_repre).view(-1, self.input_latent_dim, self.embed_dim)
        b1 = self.hyper_b_1(task_repre).view(-1, 1, self.embed_dim)
        hidden = F.relu(th.matmul(input_latent.unsqueeze(1), w1) + b1)  # [bs, 1, embed_dim]

        # Second task-conditioned layer; its weights are forced non-negative via abs.
        w_final = th.abs(self.hyper_w_final(task_repre)).view(-1, self.embed_dim, self.embed_dim)
        y = th.matmul(hidden, w_final).squeeze(1)  # [bs, embed_dim]

        pred_reward = self.reward_decoder(y)  # [bs, reward_dim]
        return pred_reward, bs