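# Re-export existing decoders with new names for GMVAE compatibility.
# NOTE: the classes defined below reuse these exact names and therefore shadow
# these imported aliases; importers of this module get the local definitions.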
from models.decoder_mixture_ext import (
    StateTransitionDecoder_mixture_ext as StateTransitionDecoder_dme,
    RewardDecoder_mixture_ext as RewardDecoder_dme,
    TaskDecoder_mixture_ext as TaskDecoder_dme
)

# GMVAE decoders - simplified to only condition on z
import torch
import torch.nn as nn
from torch.nn import functional as F
from utils import helpers as utl
from utils.helpers import get_device


class StateTransitionDecoder_dme(nn.Module):
    """
    GMVAE state-transition decoder that only conditions on z.

    Predicts the next state from the latent vector, the current state and the
    action, using a ReLU MLP on their concatenated embeddings.
    """
    def __init__(self,
                 args,
                 layers,
                 class_dim,
                 latent_dim,
                 action_dim,
                 action_embed_dim,
                 state_dim,
                 state_embed_dim,
                 pred_type='deterministic',
                 dropout_rate=0.0
                 ):
        """
        Args:
            args: run configuration, forwarded to ``utl.squash_action`` in ``forward``.
            layers: hidden-layer widths of the MLP, applied in order.
            class_dim: unused here; kept so the signature matches the other decoders.
            latent_dim: size of the latent vector z.
            action_dim: raw action size fed to the action encoder.
            action_embed_dim: size of the action embedding.
            state_dim: raw state size; also the (per-component) output size.
            state_embed_dim: size of the state embedding.
            pred_type: 'gaussian' makes the output layer emit ``2 * state_dim``
                values (presumably two parameters per state dim -- confirm
                against the loss); any other value emits ``state_dim`` values.
            dropout_rate: if > 0, dropout is applied to the state/action embeddings.
        """
        super(StateTransitionDecoder_dme, self).__init__()

        self.args = args
        # Stored for consistency with RewardDecoder_dme / TaskDecoder_dme so that
        # callers can inspect how to interpret the output.
        self.pred_type = pred_type

        self.state_encoder = utl.FeatureExtractor(state_dim, state_embed_dim, F.relu)
        self.action_encoder = utl.FeatureExtractor(action_dim, action_embed_dim, F.relu)

        self.dropout = nn.Dropout(p=dropout_rate)
        self.drop_input = dropout_rate > 0.0

        # Input to the MLP is the concatenation (z, state embedding, action embedding).
        curr_input_dim = latent_dim + state_embed_dim + action_embed_dim
        self.fc_layers = nn.ModuleList([])
        for layer_width in layers:
            self.fc_layers.append(nn.Linear(curr_input_dim, layer_width))
            curr_input_dim = layer_width

        # output layer
        if pred_type == 'gaussian':
            self.fc_out = nn.Linear(curr_input_dim, 2 * state_dim)
        else:
            self.fc_out = nn.Linear(curr_input_dim, state_dim)

    def forward(self, latent_state, state, actions):
        """
        Decode the next-state prediction.

        ``latent_state``, the state embedding and the action embedding are
        concatenated along the last dimension, so all leading dimensions must match.

        Returns:
            The raw output of the final linear layer (no output activation).
        """
        # Action normalisation (to the env bounds) is done here, inside the decoder.
        actions = utl.squash_action(actions, self.args)

        ha = self.action_encoder(actions)
        hs = self.state_encoder(state)
        if self.drop_input:
            ha = self.dropout(ha)
            hs = self.dropout(hs)
        h = torch.cat((latent_state, hs, ha), dim=-1)

        for layer in self.fc_layers:
            h = F.relu(layer(h))

        return self.fc_out(h)


class RewardDecoder_dme(nn.Module):
    """
    Simplified GMVAE reward decoder that only conditions on z.

    Compared with RewardDecoder_mixture_ext there is no y argument and no
    pzy() branch. Two modes:

    * ``multi_head=True``: the MLP sees only z and emits one value per state
      (``num_states`` outputs); the state/action inputs are ignored.
    * ``multi_head=False``: the MLP sees z plus embeddings of the next state and,
      optionally, the action and the previous state; it emits 2 values for
      ``pred_type == 'gaussian'`` and 1 value otherwise.
    """
    def __init__(self,
                 args,
                 layers,
                 class_dim,
                 latent_dim,
                 action_dim,
                 action_embed_dim,
                 state_dim,
                 state_embed_dim,
                 num_states,
                 multi_head=False,
                 pred_type='deterministic',
                 input_prev_state=True,
                 input_action=True,
                 dropout_rate=0.0
                 ):
        super(RewardDecoder_dme, self).__init__()

        self.args = args
        self.pred_type = pred_type
        self.multi_head = multi_head
        self.input_prev_state = input_prev_state
        self.input_action = input_action
        self.dropout = nn.Dropout(p=dropout_rate)
        self.drop_input = dropout_rate > 0.0

        if multi_head:
            # One output head per state; only z goes into the MLP.
            in_width = latent_dim
            out_width = num_states
        else:
            # Encoders are registered before the MLP, matching the original
            # module (and parameter-initialisation) order.
            self.state_encoder = utl.FeatureExtractor(state_dim, state_embed_dim, F.relu)
            self.action_encoder = (
                utl.FeatureExtractor(action_dim, action_embed_dim, F.relu)
                if input_action else None
            )

            # z + next-state embedding, plus the optional extras.
            in_width = latent_dim + state_embed_dim
            in_width += state_embed_dim if input_prev_state else 0
            in_width += action_embed_dim if input_action else 0
            out_width = 2 if pred_type == 'gaussian' else 1

        self.fc_layers = nn.ModuleList()
        for hidden_width in layers:
            self.fc_layers.append(nn.Linear(in_width, hidden_width))
            in_width = hidden_width
        self.fc_out = nn.Linear(in_width, out_width)

    def forward(self, latent_state, next_state, prev_state=None, actions=None):
        """
        Decode the reward prediction from z (no y argument, no pzy() branch).

        In the single-head mode the MLP input is the concatenation, along the
        last dimension, of (z, next-state emb., [action emb.], [prev-state emb.]).

        Returns:
            The raw output of the final linear layer (no output activation).
        """
        # Action normalisation (to the env bounds) is done here, inside the decoder.
        if actions is not None:
            actions = utl.squash_action(actions, self.args)

        def embed(encoder, inputs):
            # Encode one input and optionally apply input dropout.
            features = encoder(inputs)
            if self.drop_input:
                features = self.dropout(features)
            return features

        if self.multi_head:
            h = latent_state.clone()
        else:
            parts = [latent_state, embed(self.state_encoder, next_state)]
            if self.input_action:
                parts.append(embed(self.action_encoder, actions))
            if self.input_prev_state:
                parts.append(embed(self.state_encoder, prev_state))
            h = torch.cat(parts, dim=-1)

        for linear in self.fc_layers:
            h = F.relu(linear(h))

        return self.fc_out(h)


class TaskDecoder_dme(nn.Module):
    """
    Task decoder conditioned only on the latent z.

    Runs z through a ReLU MLP and a final linear layer whose width is
    ``task_dim`` when ``pred_type == 'task_description'`` and ``num_tasks``
    otherwise. ``class_dim`` is accepted for signature parity but unused.
    """
    def __init__(self,
                 layers,
                 class_dim,
                 latent_dim,
                 pred_type,
                 task_dim,
                 num_tasks,
                 ):
        super(TaskDecoder_dme, self).__init__()

        # "task_description" or "task id"
        self.pred_type = pred_type

        # Chain of widths: z -> hidden_1 -> ... -> hidden_k.
        widths = [latent_dim] + list(layers)
        self.fc_layers = nn.ModuleList(
            nn.Linear(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        )

        if pred_type == 'task_description':
            head_width = task_dim
        else:
            head_width = num_tasks
        self.fc_out = nn.Linear(widths[-1], head_width)

    def forward(self, latent_state):
        """Return the raw output of the final linear layer for the given z."""
        hidden = latent_state
        for linear in self.fc_layers:
            hidden = F.relu(linear(hidden))
        return self.fc_out(hidden)