"""
GMVAE Encoder implementation following Dilokthanakul et al 2017
- qφ(w | h) - style (Gaussian)
- qφ(z | h) - skill (Gaussian)  
- qφ(y | w, z) - analytic ∝ pθ(z|w,y) p(y)
"""

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from utils import helpers as utl
from utils.helpers import get_device


class RNNEncoder_dme(nn.Module):
    def __init__(self,
                 args,
                 # network size
                 layers_before_gru=(),
                 hidden_size=64,
                 layers_after_gru=(),
                 class_dim=10,  # K - number of mixture components
                 latent_dim=32,  # D - latent dimension for both w and z
                 # actions, states, rewards
                 action_dim=2,
                 action_embed_dim=10,
                 state_dim=2,
                 state_embed_dim=10,
                 reward_size=1,
                 reward_embed_size=5,
                 ):
        """Build the recurrent GMVAE encoder.

        Args:
            args: configuration object, stored as ``self.args`` (``forward``
                hands it to ``utl.squash_action``).
            layers_before_gru: output sizes of the Linear layers applied to
                the concatenated embeddings before the GRU.
            hidden_size: hidden size of the single-layer GRU.
            layers_after_gru: output sizes of the Linear layers applied to
                the GRU output before the latent heads.
            class_dim: K, the number of mixture components.
            latent_dim: D, the dimensionality shared by ``w`` and ``z``.
            action_dim, action_embed_dim: input / output size of the action
                feature extractor.
            state_dim, state_embed_dim: input / output size of the state
                feature extractor.
            reward_size, reward_embed_size: input / output size of the
                reward feature extractor.
        """
        super().__init__()

        self.args = args
        self.class_dim = class_dim  # K
        self.latent_dim = latent_dim  # D
        self.hidden_size = hidden_size
        self.reparameterise = self._sample_gaussian

        # One feature extractor per input modality.
        self.state_encoder = utl.FeatureExtractor(state_dim, state_embed_dim, F.relu)
        self.action_encoder = utl.FeatureExtractor(action_dim, action_embed_dim, F.relu)
        self.reward_encoder = utl.FeatureExtractor(reward_size, reward_embed_size, F.relu)

        # Linear stack between the concatenated embeddings and the GRU.
        in_features = action_embed_dim + state_embed_dim + reward_embed_size
        self.fc_before_gru = nn.ModuleList([])
        for out_features in layers_before_gru:
            self.fc_before_gru.append(nn.Linear(in_features, out_features))
            in_features = out_features

        # Recurrent core.
        self.gru = nn.GRU(input_size=in_features,
                          hidden_size=hidden_size,
                          num_layers=1,
                          )

        # Zero the GRU biases and use orthogonal weight matrices.
        for param_name, tensor in self.gru.named_parameters():
            if 'bias' in param_name:
                nn.init.constant_(tensor, 0)
                continue
            if 'weight' in param_name:
                nn.init.orthogonal_(tensor)

        # Linear stack between the GRU output and the latent heads.
        in_features = hidden_size
        self.fc_after_gru = nn.ModuleList([])
        for out_features in layers_after_gru:
            self.fc_after_gru.append(nn.Linear(in_features, out_features))
            in_features = out_features

        # === GMVAE HEADS ===
        head_width = 512

        # q(w | h): style head, emits mu_w and logvar_w side by side.
        self.w_head = nn.Sequential(
            nn.Linear(in_features, head_width),
            nn.ReLU(),
            nn.Linear(head_width, 2 * latent_dim),
        )

        # q(z | h): skill head, emits mu_z and logvar_z side by side.
        self.z_head = nn.Sequential(
            nn.Linear(in_features, head_width),
            nn.ReLU(),
            nn.Linear(head_width, 2 * latent_dim),
        )

        # Prior network: maps w to K*D means followed by K*D log-variances,
        # i.e. one diagonal Gaussian over z per mixture component.
        self.prior_net_w2yz = nn.Sequential(
            nn.Linear(latent_dim, head_width),
            nn.ReLU(),
            nn.Linear(head_width, head_width),
            nn.ReLU(),
            nn.Linear(head_width, 2 * class_dim * latent_dim),
        )

    def _sample_gaussian(self, mu, logvar, num=None):
        """Draw one reparameterised sample from N(mu, exp(logvar)).

        Args:
            mu: mean tensor.
            logvar: log-variance tensor; the sample takes its shape.
            num: must be ``None``; drawing several samples is not supported.

        Returns:
            ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal.

        Raises:
            NotImplementedError: if ``num`` is not ``None``.
        """
        if num is not None:
            raise NotImplementedError

        sigma = (0.5 * logvar).exp()
        sample = torch.randn_like(sigma) * sigma
        # In-place add keeps the result in the noise tensor's shape and dtype.
        sample += mu
        return sample

    def reset_hidden(self, hidden_state, done):
        """Zero the hidden state wherever the BAMDP was done (i.e. a new task starts).

        Args:
            hidden_state: recurrent state to mask.
            done: mask that is 1 (or True) where the state must be reset and
                0 (or False) elsewhere. If its rank differs from
                ``hidden_state``, a 2-D mask gets a leading axis and a 1-D
                mask gets a leading and a trailing axis before broadcasting.

        Returns:
            ``hidden_state * (1 - done)``.
        """
        if done.dtype == torch.bool:
            # `1 - done` raises for boolean tensors, so turn the mask into
            # numbers first; numeric masks are left exactly as they came in.
            done = done.to(hidden_state.dtype)

        if hidden_state.dim() != done.dim():
            if done.dim() == 2:
                done = done.unsqueeze(0)
            elif done.dim() == 1:
                done = done.unsqueeze(0).unsqueeze(2)

        return hidden_state * (1 - done)

    def compute_analytic_y(self, w, z):
        """Compute the analytic component posterior q(y | w, z) ∝ p(z | w, y) p(y).

        Follows Dilokthanakul et al. 2017: the prior network maps ``w`` to one
        diagonal Gaussian over ``z`` per mixture component, and ``y`` is
        scored by the log-density of ``z`` under each one plus a uniform
        log-prior.

        Args:
            w: style latents, (B, D).
            z: skill latents, (B, D).

        Returns:
            Tuple ``(y_probs, y_hard, log_qy, mu_y, logvar_y)``:
            ``y_probs`` (B, K) soft assignment, ``y_hard`` (B, K) one-hot of
            its argmax as floats, ``log_qy`` (B, K) normalised log
            probabilities, ``mu_y`` / ``logvar_y`` (B, K, D) component
            parameters produced by the prior network.
        """
        num_rows = w.shape[0]
        k, d = self.class_dim, self.latent_dim

        # The prior net emits all K*D means first, then all K*D log-variances.
        component_params = self.prior_net_w2yz(w).view(num_rows, 2, k, d)  # (B, 2, K, D)
        mu_y, logvar_y = component_params[:, 0], component_params[:, 1]  # (B, K, D) each

        # Diagonal-Gaussian log-density of z under every component.
        z_per_component = z.unsqueeze(1).expand(-1, k, -1)  # (B, K, D)
        sq_err = (z_per_component - mu_y).pow(2)
        per_dim = (sq_err / torch.exp(logvar_y)) + logvar_y + np.log(2 * np.pi)
        ll_k = -0.5 * per_dim.sum(-1)  # (B, K)

        # Uniform prior over components: log p(y) = log(1/K).
        ll_k = ll_k + np.log(1.0 / k)

        # Normalise in log space, then exponentiate for the soft assignment.
        log_qy = ll_k - torch.logsumexp(ll_k, dim=1, keepdim=True)  # (B, K)
        y_probs = log_qy.exp()

        # Hard assignment: one-hot of the most probable component.
        y_hard = F.one_hot(y_probs.argmax(dim=1), k).float()

        return y_probs, y_hard, log_qy, mu_y, logvar_y

    def prior_mixture(self, batch_size, sample=True):
        """Create prior latents for the start of trajectories.

        ``w`` and ``z`` are each drawn from a standard normal (zero mean, zero
        log-variance) and the analytic component posterior is evaluated on
        that pair.

        Args:
            batch_size: number of prior samples to create.
            sample: accepted for interface compatibility; this method does
                not read it and always samples.

        Returns:
            Tuple ``(prior_y, prior_z, prior_mu, prior_var, prior_logits,
            prior_prob, prior_hidden_state, w)`` where ``prior_y`` and
            ``prior_prob`` are the same soft assignment (B, K), ``prior_z``
            is the sampled z (B, D), ``prior_mu`` / ``prior_var`` are the
            zero mean and unit variance of z, ``prior_logits`` is the
            normalised log-probability of each component (B, K),
            ``prior_hidden_state`` is an all-zero state of shape
            (1, B, hidden_size) and ``w`` is the sampled style latent (B, D).
        """
        device = get_device()

        # All-zero recurrent state. It used to be pushed through fc_after_gru
        # as well, but that result was never read, so the pass is dropped.
        hidden_state = torch.zeros((1, batch_size, self.hidden_size),
                                   requires_grad=True).to(device)

        # w ~ N(0, I)
        w_mu = torch.zeros(batch_size, self.latent_dim).to(device)
        w_logvar = torch.zeros(batch_size, self.latent_dim).to(device)
        w = self._sample_gaussian(w_mu, w_logvar)

        # z ~ N(0, I)
        z_mu = torch.zeros(batch_size, self.latent_dim).to(device)
        z_logvar = torch.zeros(batch_size, self.latent_dim).to(device)
        z = self._sample_gaussian(z_mu, z_logvar)

        # Analytic q(y | w, z); only the soft assignment and its log are used.
        y_probs, _, log_qy, _, _ = self.compute_analytic_y(w, z)

        # prior_y, prior_z, prior_mu, prior_var, prior_logits, prior_prob,
        # prior_hidden_state, w
        return y_probs, z, z_mu, torch.exp(z_logvar), log_qy, y_probs, hidden_state, w

    def forward(self, actions, states, rewards, hidden_state, return_prior,
                sample=True, detach_every=None, y_intercept=None, return_w_params=False):
        """Encode a trajectory into the GMVAE latents w, z and y.

        Inputs are collapsed to three dimensions (leading axes merged, last
        two kept) and the GRU consumes them as (seq_len, batch, features).

        Args:
            actions, states, rewards: trajectory tensors; actions are first
                passed through ``utl.squash_action``.
            hidden_state: initial GRU state, or ``None``. Replaced by the
                prior's zero state when ``return_prior`` is set.
            return_prior: if True, a prior step from ``prior_mixture`` is
                prepended to every returned sequence.
            sample: if True, w and z are sampled; otherwise their means are used.
            detach_every: if given, the GRU is run in chunks of this many
                steps and the hidden state is detached between chunks.
            y_intercept: accepted but not used by this method.
            return_w_params: if True, ``w_mu`` and ``w_logvar`` are appended
                to the outputs (for VAE internal use).

        Returns:
            If ``return_w_params`` is False (default, for metalearners):
                latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w (7 outputs)
            If ``return_w_params`` is True (for VAE internal use):
                latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w, w_mu, w_logvar (9 outputs)
            ``y`` and ``prob`` are the same soft-assignment tensor, and
            ``hidden_state`` is the GRU output sequence. When the latent
            sequence has a single step, the leading axis of every latent
            tensor (but not of the GRU output) is dropped.
        """
        actions = utl.squash_action(actions, self.args)

        # Merge every leading axis so each input is three-dimensional.
        actions = actions.reshape((-1, *actions.shape[-2:]))
        states = states.reshape((-1, *states.shape[-2:]))
        rewards = rewards.reshape((-1, *rewards.shape[-2:]))
        if hidden_state is not None:
            hidden_state = hidden_state.reshape((-1, *hidden_state.shape[-2:]))

        if return_prior:
            # The prior supplies step 0 of every output and the initial GRU state.
            (prior_y, prior_z, prior_mu, prior_var, prior_logits,
             _, prior_hidden_state, prior_w) = self.prior_mixture(actions.shape[1])
            hidden_state = prior_hidden_state.clone()

        # Embed each modality and join along the feature axis.
        h = torch.cat((self.action_encoder(actions),
                       self.state_encoder(states),
                       self.reward_encoder(rewards)), dim=2)

        for layer in self.fc_before_gru:
            h = F.relu(layer(h))

        if detach_every is None:
            output, _ = self.gru(h, hidden_state)
        else:
            # Truncated backprop: run the GRU chunk by chunk and cut the
            # graph through the hidden state after every chunk.
            num_chunks = int(np.ceil(h.shape[0] / detach_every))
            chunk_outputs = []
            for chunk_idx in range(num_chunks):
                start = chunk_idx * detach_every
                chunk_out, hidden_state = self.gru(h[start:start + detach_every], hidden_state)
                chunk_outputs.append(chunk_out)
                hidden_state = hidden_state.detach()
            output = torch.cat(chunk_outputs, dim=0)

        gru_h = output.clone()
        for layer in self.fc_after_gru:
            gru_h = F.relu(layer(gru_h))

        # === GMVAE computation ===
        seq_len, batch_size, _ = gru_h.shape
        gru_h_flat = gru_h.view(-1, gru_h.size(2))  # (seq_len * batch_size, hidden_dim)

        # w is drawn before z so the random stream is consumed in a fixed order.
        w_mu, w_logvar = self.w_head(gru_h_flat).chunk(2, dim=-1)
        w = self._sample_gaussian(w_mu, w_logvar) if sample else w_mu

        z_mu, z_logvar = self.z_head(gru_h_flat).chunk(2, dim=-1)
        z = self._sample_gaussian(z_mu, z_logvar) if sample else z_mu

        y_probs, _, log_qy, _, _ = self.compute_analytic_y(w, z)

        def to_sequence(flat):
            """Restore the (seq_len, batch_size, features) layout."""
            return flat.view(seq_len, batch_size, -1)

        w, w_mu, w_logvar = to_sequence(w), to_sequence(w_mu), to_sequence(w_logvar)
        z, z_mu, z_logvar = to_sequence(z), to_sequence(z_mu), to_sequence(z_logvar)
        y_probs, log_qy = to_sequence(y_probs), to_sequence(log_qy)

        if return_prior:
            output = torch.cat((prior_hidden_state, output))

            # The prior over w is a standard normal: zero mean, zero log-variance.
            prior_w_mu = torch.zeros_like(prior_w)
            prior_w_logvar = torch.zeros_like(prior_w)

            w = torch.cat((prior_w.unsqueeze(0), w))
            w_mu = torch.cat((prior_w_mu.unsqueeze(0), w_mu))
            w_logvar = torch.cat((prior_w_logvar.unsqueeze(0), w_logvar))
            z = torch.cat((prior_z.unsqueeze(0), z))
            z_mu = torch.cat((prior_mu.unsqueeze(0), z_mu))
            z_logvar = torch.cat((torch.log(prior_var + 1e-20).unsqueeze(0), z_logvar))
            y_probs = torch.cat((prior_y.unsqueeze(0), y_probs))
            log_qy = torch.cat((prior_logits.unsqueeze(0), log_qy))

        # A single-step sequence is returned without its time axis.
        if z_mu.shape[0] == 1:
            w, w_mu, w_logvar = w[0], w_mu[0], w_logvar[0]
            z, z_mu, z_logvar = z[0], z_mu[0], z_logvar[0]
            y_probs, log_qy = y_probs[0], log_qy[0]

        # Standard seven outputs; the soft assignment fills both the y and prob slots.
        outputs = (z, z_mu, z_logvar, output, y_probs, y_probs, w)
        if return_w_params:
            outputs = outputs + (w_mu, w_logvar)
        return outputs

    def sample_virtual_task(self, batch_size):
        """Sample virtual task variables for imagination from the generative model.

        Model: p(w) = N(0, I), p(z | w, y) = Σ_k y_k N(μ_k(w), σ_k²(w)),
        q(y | w, z) ∝ p(z | w, y) p(y).

        Procedure: draw ``w``; get the K component Gaussians from the prior
        network; draw a provisional z from a uniformly chosen component;
        score the components analytically on that z; finally draw z from a
        single Gaussian whose mean and variance are the y-weighted averages
        of the component means and variances.

        Args:
            batch_size: number of virtual tasks to sample.

        Returns:
            Tuple ``(w, y_soft, z, mu_z, logvar_z)``:
            - w: (batch_size, latent_dim) style sample from N(0, I)
            - y_soft: (batch_size, class_dim) analytic mixture probabilities
            - z: (batch_size, latent_dim) skill sample
            - mu_z: (batch_size, latent_dim) mean used for the final z draw
            - logvar_z: (batch_size, latent_dim) log-variance used for the final z draw
        """
        device = get_device()
        num_components = self.class_dim

        # 1. Style latent w ~ N(0, I).
        w = torch.randn(batch_size, self.latent_dim).to(device)

        # 2. Per-component parameters from the prior network: means first,
        #    log-variances second.
        component_params = self.prior_net_w2yz(w).view(
            batch_size, 2, num_components, self.latent_dim)  # (B, 2, K, D)
        mu_y = component_params[:, 0]  # (B, K, D)
        logvar_y = component_params[:, 1]  # (B, K, D)

        # 3. Provisional z from one uniformly chosen component per row.
        uniform_probs = torch.ones(batch_size, num_components).to(device) / num_components
        chosen = torch.multinomial(uniform_probs, 1).squeeze(-1)  # (B,)
        rows = torch.arange(batch_size)
        z_init = self._sample_gaussian(mu_y[rows, chosen], logvar_y[rows, chosen])

        # 4. Analytic component probabilities for the provisional z.
        y_soft = self.compute_analytic_y(w, z_init)[0]

        # 5. Blend the components with those probabilities.
        # NOTE(review): the blended variance is the weighted sum of component
        # variances only; the spread between component means is not added.
        # Confirm this is the intended approximation.
        weights = y_soft.unsqueeze(-1)  # (B, K, 1)
        mu_z = (weights * mu_y).sum(1)  # (B, D)
        blended_var = (weights * torch.exp(logvar_y)).sum(1)  # (B, D)
        logvar_z = torch.log(blended_var + 1e-8)

        # Final z draw from the blended Gaussian.
        z = self._sample_gaussian(mu_z, logvar_z)

        return w, y_soft, z, mu_z, logvar_z