import torch
from torch import nn
from torch.nn.modules.module import Module

class Intrinsic_Reward(nn.Module):
    """Intrinsic-reward model with a positive head and a negative head.

    In non-RNN mode each head (``pos_rew_func`` / ``neg_rew_func``) is a small
    MLP ``Linear -> ReLU -> Linear(32, 1) -> Tanh``, so every reward it emits
    lies in [-1, 1]. The heads consume one flat vector built in ``forward`` by
    concatenating squeezed embeddings along dim 0.

    Args:
        args: Config object; this class reads ``args.rnn_int_rew`` and
            ``args.double_int_reward``.
        state_size: Size of each of the two state-sized slots in the MLP input.
        action_size: Size of each of the two action-sized slots in the MLP input.
        reward_size: Width of the learned turn-sign embedding (5 signs).
    """

    def __init__(self, args, state_size, action_size, reward_size=10):
        super().__init__()

        # One learned vector per turn sign; valid turn_sign indices are 0..4.
        self.sign_embedding = nn.Embedding(5, reward_size)
        self.args = args
        self.rnn_int_rew = args.rnn_int_rew
        if self.rnn_int_rew:
            # NOTE(review): this branch defines no ``pos_rew_func`` /
            # ``neg_rew_func``, so ``forward`` below cannot run in RNN mode;
            # presumably these modules are driven from elsewhere — confirm.
            self.rnn_encoder = nn.GRU(state_size + action_size + reward_size, 32, num_layers=1, batch_first=True)
            self.rew_encoder = nn.Linear(1, reward_size)
            self.int_rew_func = nn.Sequential(
                nn.ReLU(),
                nn.Linear(32, 1),
                nn.Tanh(),
            )
        else:
            # Both heads share one input width; ``double_int_reward`` only
            # changes which features ``forward`` feeds to the positive head.
            in_features = state_size * 2 + action_size * 2 + reward_size
            self.pos_rew_func = self._make_head(in_features)
            self.neg_rew_func = self._make_head(in_features)

    @staticmethod
    def _make_head(in_features):
        """Build one reward MLP: Linear(in_features, 32) -> ReLU -> Linear(32, 1) -> Tanh."""
        return nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Tanh(),
        )

    def forward(self, state_emb, action_emb, next_state_emb, turn_sign, target_emb, user_emb=None):
        """Score a transition with both heads.

        The embeddings are squeezed and concatenated along dim 0. This
        assumes they squeeze to 1-D tensors and that ``turn_sign`` is a scalar
        index (TODO confirm against callers). The concatenated length must
        equal the first Linear's ``in_features``
        (``state_size*2 + action_size*2 + reward_size``).

        The negative head always sees
        ``[state, action, next_state, sign, user]``. The positive head sees
        ``[state, action, next_state, target, sign]`` when
        ``args.double_int_reward`` is set, and otherwise sees the same vector
        as the negative head.

        Returns:
            ``(int_rew, neg_int_rew)``: tanh outputs in [-1, 1], each of shape
            ``(1,)`` for 1-D input.

        Raises:
            ValueError: if ``user_emb`` is None, because it is always part of
                the negative head's input.
        """
        if user_emb is None:
            raise ValueError("user_emb is required: it is part of the neg_rew_func input")

        sign_emb = self.sign_embedding.weight[turn_sign]
        state = state_emb.squeeze()
        action = action_emb.squeeze()
        next_state = next_state_emb.squeeze()

        int_rew_input = torch.cat((state, action, next_state, sign_emb, user_emb.squeeze()), 0)
        if self.args.double_int_reward:
            pos_rew_input = torch.cat((state, action, next_state, target_emb.squeeze(), sign_emb), 0)
            int_rew = self.pos_rew_func(pos_rew_input)
        else:
            int_rew = self.pos_rew_func(int_rew_input)
        neg_int_rew = self.neg_rew_func(int_rew_input)
        return int_rew, neg_int_rew

class Reward_Ensemble(nn.Module):
    """Ensemble of small reward MLPs whose predictions are averaged.

    Each member is ``Linear -> ReLU -> Linear(32, 1) -> Tanh``, so the mean
    prediction lies in [-1, 1]. All members take the same input width,
    ``state_size*2 + action_size*2 + reward_size``.

    Args:
        args: Config object; this class reads ``args.rnn_int_rew`` and
            ``args.double_int_reward``.
        state_size: Size of each of the two state-sized slots in the input.
        action_size: Size of each of the two action-sized slots in the input.
        reward_size: Width of the learned turn-sign embedding (5 signs).
        num_ensemble: Number of ensemble members.
    """

    def __init__(self, args, state_size, action_size, reward_size=10, num_ensemble=3):
        super().__init__()

        # One learned vector per turn sign; valid turn_sign indices are 0..4.
        self.sign_embedding = nn.Embedding(5, reward_size)
        self.num_ensemble = num_ensemble
        self.args = args
        self.rnn_int_rew = args.rnn_int_rew

        self.state_size = state_size
        self.action_size = action_size
        self.reward_size = reward_size

        self.construct_ensemble()

    def construct_ensemble(self):
        """(Re)build ``self.ensemble`` with ``num_ensemble`` freshly initialised members."""
        in_features = self.state_size * 2 + self.action_size * 2 + self.reward_size
        self.ensemble = nn.ModuleList(
            nn.Sequential(
                nn.Linear(in_features, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
                nn.Tanh(),
            )
            for _ in range(self.num_ensemble)
        )
        # NOTE(review): never populated in this file; kept in case callers read it.
        self.paramlst = []

    def r_hat(self, x):
        """Return the element-wise mean of every ensemble member's prediction on ``x``."""
        return torch.mean(torch.stack([member(x) for member in self.ensemble]), dim=0)

    def forward(self, state_emb, action_emb, next_state_emb, turn_sign, target_emb, user_emb=None):
        """Score a transition with the ensemble mean.

        The embeddings are squeezed and concatenated along dim 0. This
        assumes they squeeze to 1-D tensors and that ``turn_sign`` is a scalar
        index (TODO confirm against callers).

        The input vector depends on ``args.double_int_reward``:
        - set: ``[state, action, next_state, target, sign]`` (``user_emb``
          unused);
        - not set: ``[state, action, next_state, sign, user]``.

        Returns:
            ``(int_rew, int_rew)``: the same tensor twice, matching the
            ``(pos, neg)`` tuple shape returned by ``Intrinsic_Reward``.

        Raises:
            ValueError: if ``user_emb`` is None while
                ``args.double_int_reward`` is not set.
        """
        sign_emb = self.sign_embedding.weight[turn_sign]
        state = state_emb.squeeze()
        action = action_emb.squeeze()
        next_state = next_state_emb.squeeze()

        if self.args.double_int_reward:
            # Every member has the same input width, so the target-conditioned
            # vector is scored by the same ensemble mean.
            rew_input = torch.cat((state, action, next_state, target_emb.squeeze(), sign_emb), 0)
        else:
            if user_emb is None:
                raise ValueError("user_emb is required when args.double_int_reward is False")
            rew_input = torch.cat((state, action, next_state, sign_emb, user_emb.squeeze()), 0)
        int_rew = self.r_hat(rew_input)
        return int_rew, int_rew