#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 15 11:19:33 2024

@author: XXXX
"""

import torch
import torch.nn as nn
import script_gru

# Create a general gating network class
class Gate(nn.Module):
    """Linear gating network that maps features to a distribution over modules.

    Args:
        dim: Size of the last (feature) dimension of the input.
        modules: Number of gate outputs.
        t: Temperature that divides the log-probabilities before the softmax.
        p_only: If True, ``forward`` returns only the probabilities; otherwise
            it returns the tuple ``(log_probabilities, probabilities)``.
        weight_init: Value forwarded to ``init_linear`` for the gate weights.
    """

    def __init__(self, dim, modules, t=1, p_only=False, weight_init=1.0):
        super().__init__()
        self.w_gate = nn.Linear(dim, modules, bias=False)
        self.t = torch.tensor(t, dtype=torch.float)
        # Kept as an attribute for code that accesses ``gate.softmax``; it now
        # normalises over the same (last) axis as ``get_logits``
        self.softmax = nn.Softmax(dim=-1)
        self.p_only = p_only
        init_linear(self.w_gate, weight_init)

    def get_logits(self, x):
        """Return log-probabilities: gate outputs normalised over the last axis."""
        l = self.w_gate(x)
        return l - l.logsumexp(dim=-1, keepdim=True)

    def get_probabilities(self, l):
        """Return the tempered softmax of ``l`` over the last (module) axis."""
        # Normalise over dim=-1 explicitly: the previous dim=1 softmax was only
        # correct for 2-D inputs and disagreed with get_logits for anything else
        return torch.softmax(l / self.t, dim=-1)

    def forward(self, x):
        """Return probabilities, or ``(log_probabilities, probabilities)``."""
        l = self.get_logits(x)
        p = self.get_probabilities(l)
        return p if self.p_only else (l, p)
    
# Add a categorical sampler that allows for gradients
class CatSample(Gate):
    """Gate whose probabilities are Gumbel-softmax samples of the log-probabilities.

    Takes the same arguments as ``Gate`` plus ``hard``, which is passed to
    ``gumbel_softmax`` on every call to ``get_probabilities``.
    """

    def __init__(self, dim, modules, t=1, p_only=False, weight_init=1.0, hard=True):
        super().__init__(dim, modules, t=t, p_only=p_only, weight_init=weight_init)
        self.hard = hard

    def get_probabilities(self, l):
        """Draw a Gumbel-softmax sample from ``l`` at temperature ``self.t``."""
        sample = nn.functional.gumbel_softmax(l, tau=self.t, hard=self.hard)
        return sample
      
# Model 0: mixture of experts where activation is latent state of HMM
class HMMMoE(nn.Module):
    def __init__(self, action_dim, action_hidden, action_layers, action_out, context_layers, n_modules,
                 use_gru=False, rank=4, true_modules=False, true_gating=False, flat_gating=False,
                 n_tasks=30, n_contexts=3, n_steps=25, n_operations=6,
                 weight_init=1, sigma_init=0.0):
        """Build the action modules, the context (gating) network and initial states.

        Args:
            action_dim: Input size of every action cell.
            action_hidden: Hidden size of the action cells and of the context cell.
            action_layers: Accepted but not used in this constructor.
            action_out: Output size of the linear readout; stored as ``output_dim``.
            context_layers: Accepted but not used in this constructor.
            n_modules: Number of action modules / gate outputs.
            use_gru: Use ``script_gru.scriptGRUCell`` for the context network
                instead of ``script_gru.scriptVanillaCell``.
            rank: If > 0 the action modules are ``script_gru.LowRankRNNCell``
                instances built with this rank, otherwise ``scriptVanillaCell``.
            true_modules: Replace the learned modules by fixed functions
                ``x + roll(y, i)`` and the readout by an identity.
            true_gating: Stored flag; ``nll`` then gates with the provided context.
            flat_gating: Stored flag; ``nll`` then uses a uniform transition matrix.
            n_tasks, n_contexts, n_steps, n_operations: Task sizes, stored as-is.
            weight_init: Value forwarded to ``init_linear`` and the script_gru cells.
            sigma_init: Initial value of the learnable ``log_sigma`` parameter.
        """
        super().__init__()

        # Copy task parameters
        self.n_tasks = n_tasks
        self.n_contexts = n_contexts
        self.n_steps = n_steps
        self.n_operations = n_operations
        self.output_dim = action_out

        # Copy network parameters
        self.n_modules = n_modules
        self.true_modules = true_modules
        self.true_gating = true_gating
        self.flat_gating = flat_gating
        self.hard = False

        # Get device (the initial states do not exist yet, so nothing is moved)
        self.set_device(move_vars=False)

        # Create independent recurrent networks
        # With true_modules each "network" is a fixed function; i=i binds the
        # shift at definition time to avoid the late-binding closure pitfall
        self.action_rnn = [lambda x, y, i=i: x + torch.roll(y, i, dims=-1) for i in range(self.n_modules)] \
            if self.true_modules else nn.ModuleList(
                    [script_gru.LowRankRNNCell(action_dim, action_hidden, rank, nonlinearity='tanh')
                     if rank > 0 else script_gru.scriptVanillaCell(action_dim, action_hidden, nonlinearity='tanh', weight_init=weight_init)
                     for _ in range(n_modules)])

        # Output is read out from hidden state
        self.action_output = torch.nn.Identity() if self.true_modules \
            else nn.Linear(action_hidden, action_out, bias=False)
        # NOTE(review): with true_modules this passes an nn.Identity to
        # init_linear; confirm that init_linear accepts modules without weights
        init_linear(self.action_output, weight_init)

        # Create the output noise parameter; it is consumed as log(sigma) in
        # multivariate_log_likelihood. The explicit float dtype keeps integer
        # values such as sigma_init=0 from producing an int64 tensor, which
        # cannot be wrapped in a gradient-requiring nn.Parameter
        self.log_sigma = nn.Parameter(
            torch.tensor(sigma_init, dtype=torch.float, device=self.device))

        # Create context RNN; its input is the data concatenated with the previous gating
        context_net = script_gru.scriptGRUCell if use_gru else script_gru.scriptVanillaCell
        self.context_rnn = context_net(
            action_dim + n_modules, action_hidden,
            nonlinearity='tanh', weight_init=weight_init)

        # The output of the context RNN is a sampled gating signal across modules
        self.context_output = CatSample(action_hidden, n_modules, t=1, p_only=True, weight_init=weight_init, hard=self.hard)

        # Set learnable initial state for context RNN
        self.context_h0 = nn.Parameter(torch.zeros(action_hidden, dtype=torch.float, device=self.device))

        # Set initial state for action RNN: a fixed zero tensor of output size
        # with true_modules, otherwise a learnable parameter of hidden size
        self.action_h0 = torch.zeros(action_out, dtype=torch.float, device=self.device) if self.true_modules \
            else nn.Parameter(torch.zeros(action_hidden, dtype=torch.float, device=self.device))

        # Set learnable initial context output (i.e. initial module)
        self.context_o0 = CatSample(1, n_modules, t=1, p_only=True, hard=False)
        
    def set_device(self, device=None, move_vars=True):
        """Select the device used by this model and optionally move the initial states.

        Args:
            device: Target device. When None, 'cuda:0' is used if CUDA is
                available and 'cpu' otherwise.
            move_vars: If True, ``action_h0`` and ``context_h0`` are moved to
                the selected device; both attributes must already exist.
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') \
            if device is None else device
        if not move_vars:
            return

        for name in ('action_h0', 'context_h0'):
            state = getattr(self, name)
            if not isinstance(state, nn.Parameter):
                # Plain tensors (action_h0 with true_modules) can simply be rebound
                setattr(self, name, state.to(self.device))
                continue
            # Tensor.to returns a plain tensor for a parameter that actually
            # moves, and nn.Module refuses to bind a plain tensor to a
            # registered parameter name, so parameters are updated in place
            with torch.no_grad():
                moved = state.to(self.device)
            if moved is state:
                # Already on the requested device
                continue
            try:
                # Keeps the parameter object (and any optimizer reference) intact
                state.data = moved
            except RuntimeError:
                # Backends whose tensors cannot share a parameter's storage
                # type need a fresh parameter instead
                setattr(self, name, nn.Parameter(moved, requires_grad=state.requires_grad))

    def t(self, tensor):
        """Return ``tensor`` converted to float32 and placed on ``self.device``.

        Shorthand that saves repeating ``dtype=torch.float, device=self.device``.
        """
        as_float = tensor.to(torch.float32)
        return as_float.to(self.device)
    
    def nll(self, data, context, target, N=500):
        """Run a bootstrap particle filter over a batch of sequences.

        At every step each particle updates its gating (from the ground-truth
        context, a uniform transition matrix, or the context RNN), runs all
        action modules, mixes them with the gating, is weighted by the Gaussian
        log-likelihood of the target, and is then resampled.

        Args:
            data: Inputs, batch x time x features.
            context: batch x time x modules. Used directly as gating when
                ``true_gating`` is set; its sum over the last axis at each
                step blends between resampling (1) and keeping particles (0).
            target: Targets, batch x time x output.
            N: Number of particles.

        Returns:
            Dict of tensors stacked over time: 'action_out', 'action_hidden',
            'context_out', 'context_hidden', 'activation' (batch x ptc x time
            x features), 'ancestor', 'weight' (batch x ptc x time) and
            'likelihood' (batch x time).
        """
        n_batch = data.shape[0]

        # Collect context and action signals throughout forward pass
        context_o, context_h, action_o, action_h = [], [], [], []
        activations, ancestors, weights, likelihoods = [], [], [], []

        # Expand data and target across particles
        data = data.unsqueeze(1).expand([-1, N, -1, -1]) # batch x ptc x time x data
        target = target.unsqueeze(1).expand([-1, N, -1, -1]) # batch x ptc x time x data

        # Initialise all states before first timestep
        c_h = self.context_h0.expand([n_batch, N, -1]) # batch x ptc x hidden
        a_h = self.action_h0.expand([n_batch, N, -1]) # batch x ptc x hidden
        c_o = self.context_o0(self.t(torch.ones([n_batch, N, 1]))) # batch x ptc x module

        # Loop invariants: identity "resampling" that keeps every particle,
        # log particle count, and (only for flat gating) the uniform transition
        p_keep = self.t(torch.eye(N)).expand([n_batch, -1, -1])
        log_n = torch.log(self.t(torch.tensor(N)))
        trans_mat = None
        if not self.true_gating and self.flat_gating:
            trans_mat = (torch.ones((self.n_modules, self.n_modules)) / self.n_modules).to(self.device)

        # Run through timesteps
        for t in range(self.n_steps):

            # For any nn.Module operation, I can only have one batch dimension
            # That means I'll need to flatten the first dimensions batch x ptc into batch * ptc
            d_t, c_o, c_h, a_h = [
                torch.flatten(d, 0, 1) for d in [data[:,:,t,:], c_o, c_h, a_h]]

            # 1. Sample updated state from transition

            if self.true_gating:
                # Use ground truth context as gating signal
                c_o = torch.flatten(context[:,t,:].unsqueeze(1).expand([-1,N,-1]), 0, 1)
            elif trans_mat is not None:
                # Use simple transition matrix to update gating
                c_o = torch.nn.functional.gumbel_softmax(torch.matmul(c_o, trans_mat), tau=1, hard=self.hard)
            else:
                # Step the context RNN
                c_i = torch.concatenate([c_o, d_t], -1)
                c_h = self.context_rnn(c_i, c_h) # batch * ptc x hidden

                # Sample new module activations as output from context RNN
                c_o = self.context_output(c_h) # batch * ptc x modules

            # 2. Calculate state emission

            # Run one step of action RNN
            a_h_n = [rnn(d_t, a_h) for rnn in self.action_rnn] # modules x [batch * ptc x hidden]

            # Gate action RNN: run all, then multiply by active module per batch
            a_h = torch.matmul(c_o[:,None,:], torch.stack(a_h_n, dim=1)).squeeze(1) # batch * ptc x hidden

            # Calculate action output given the hidden state of activated module
            a_o = self.action_output(a_h) # batch * ptc x data

            # Now I'm done with nn.Modules, so I can unflatten batch * ptc to batch x ptc
            d_t, c_o, c_h, a_h, a_o = [
                torch.unflatten(d, 0, [-1, N]) for d in [d_t, c_o, c_h, a_h, a_o]]

            # Log the gating output *before* resampling
            context_o.append(c_o)

            # 3. Update particle weights from likelihood

            # Get particle weights from target likelihood (isotropic Gaussian)
            w = self.multivariate_log_likelihood(a_o, target[:,:,t,:]) # batch x ptc

            # Calculate normalisation of particle weights
            W = w.logsumexp(dim=1, keepdim=True)

            # Normalise particle weights across particles
            w = w - W # batch x ptc

            # 4. Resample particles

            # Finally, resample particles from updated likelihood
            p = self.resample_systematic(w) # batch x ptc x ptc

            # Only resample particles for batches with feedback provided at this step
            do_feedback = torch.sum(context[:,t,:],-1)[:,None,None]
            p = p * do_feedback + p_keep * (1-do_feedback)

            # Update module and RNN hidden states according to sampled particles
            c_h = torch.matmul(p, c_h)
            a_h = torch.matmul(p, a_h)
            c_o = torch.matmul(p, c_o)

            # Append the current outputs and hidden states to lists
            context_h.append(c_h)
            action_o.append(a_o)
            action_h.append(a_h)
            activations.append(c_o)
            ancestors.append(torch.argmax(p, dim=-1))
            weights.append(w)
            # Marginal likelihood estimate: mean particle likelihood in log space.
            # Squeeze only the particle axis so a batch of size 1 keeps its axis
            likelihoods.append(W.squeeze(1) - log_n)

        # Concatenate the outputs and hidden states through time: time becomes
        # the second-to-last axis for per-feature tensors, the last one otherwise
        context_o, context_h, action_o, action_h, \
            activations, ancestors, weights, likelihoods = [
                torch.stack(y, axis=(-2 if y[0].ndim > 2 else -1)) for y in
                [context_o, context_h, action_o, action_h,
                 activations, ancestors, weights, likelihoods]]

        # Return results
        return {'action_out': action_o, 'action_hidden': action_h,
                'context_out': context_o, 'context_hidden': context_h,
                'activation': activations, 'ancestor': ancestors,
                'weight': weights, 'likelihood': likelihoods}
    
    def trace_back(self, output, N=1, sample=False):
        """Follow the ancestry of selected final particles back through time.

        Args:
            output: Dict returned by ``nll``.
            N: Number of final particles to trace per batch entry.
            sample: If True, draw the final particles from the categorical
                distribution given by the final log-weights; otherwise take
                the ``N`` particles with the highest final log-weight.

        Returns:
            Dict with the same keys as ``output``, restricted to the traced
            particles and stacked in forward time order.
        """
        # NOTE(review): select_along_dim is defined elsewhere; it is presumably a
        # gather of entries along `dim` by an index tensor - confirm there.

        # Get the output weights of the resampled particles *before* resampling
        w = select_along_dim(output['weight'][:,:,-1], output['ancestor'][:,:,-1], dim=1)

        # Then select based on that. First determine which particles you want to trace back.
        # Either sample from the final likelihoods, or deterministically take the top ones
        if sample:
            # The weights are normalised *log* probabilities, so they must be
            # passed as logits; as `probs` their negative values are invalid
            p_dist = torch.distributions.categorical.Categorical(logits=w)
            p_id = p_dist.sample((N,)).t()
        else:
            p_id = torch.argsort(w, descending=True)[:,:N]

        # Then iterate back, while rebuilding the output matrix across the selected particles
        context_o, context_h, action_o, action_h = [], [], [], []
        activations, ancestors, weights, likelihoods = [], [], [], []

        # Iterate through reversed time
        for t in reversed(range(self.n_steps)):
            # The resampled variables are collected from *currently selected* particles
            c_o = select_along_dim(output['context_out'][:,:,t], p_id, dim=1)
            c_h = select_along_dim(output['context_hidden'][:,:,t], p_id, dim=1)
            a_h = select_along_dim(output['action_hidden'][:,:,t], p_id, dim=1)
            act = select_along_dim(output['activation'][:,:,t], p_id, dim=1)
            # Then update the currently selected particles to their ancestors
            p_id = select_along_dim(output['ancestor'][:,:,t], p_id, dim=1)
            # And get non-resampled variables from *ancestor* indices
            a_o = select_along_dim(output['action_out'][:,:,t], p_id, dim=1)
            w = select_along_dim(output['weight'][:,:,t], p_id, dim=1)
            # W needs to be renormalised over the selected particles
            w = w + output['likelihood'][:,t].unsqueeze(1)
            W = w.logsumexp(dim=1, keepdim=True)
            w = w - W
            # Append the currently selected particle's properties
            ancestors.append(p_id)
            context_o.append(c_o)
            context_h.append(c_h)
            action_o.append(a_o)
            action_h.append(a_h)
            activations.append(act)
            weights.append(w)
            # Squeeze only the particle axis so a batch of size 1 keeps its axis
            likelihoods.append(W.squeeze(1)) # technically -log(N)

        # Concatenate the outputs and hidden states through time, but with time reversed
        context_o, context_h, action_o, action_h, \
            activations, ancestors, weights, likelihoods = [
                torch.stack(y[::-1], axis=(-2 if y[0].ndim > 2 else -1)) for y in
                [context_o, context_h, action_o, action_h,
                 activations, ancestors, weights, likelihoods]]

        # Return results
        return {'action_out': action_o, 'action_hidden': action_h,
                'context_out': context_o, 'context_hidden': context_h,
                'activation': activations, 'ancestor': ancestors,
                'weight': weights, 'likelihood': likelihoods}
    
    def set_hard_sampling(self, hard=True):
        """Set the ``hard`` flag on the model and on its context output sampler.

        Handy for evaluating with hard samples a model that was trained with
        soft sampling.
        """
        # Chained assignment binds left to right: the model first, then the sampler
        self.hard = self.context_output.hard = hard
    
    def multivariate_log_likelihood(self, data, targets):
        """Return the isotropic Gaussian log-likelihood of ``targets`` around ``data``.

        With K the size of the last axis and covariance ``sigma**2 * I`` this is
        ``-K/2 log(2 pi) - K log(sigma) - |data - targets|**2 / (2 sigma**2)``,
        where ``log(sigma)`` is ``self.log_sigma`` clipped from below at -5.

        Args:
            data: Tensor of shape ... x K.
            targets: Tensor broadcastable against ``data``.

        Returns:
            Tensor with the broadcast shape of the inputs minus the last axis.
        """
        diff = data - targets # shape ... x K
        K = diff.shape[-1]
        # Clip sigma so likelihood can't be decreased by going essentially 0
        log_sigma = torch.clip(self.log_sigma, -5)
        # Reduce only the feature axis: a bare squeeze() would also drop batch
        # or particle axes of size 1 and break callers that index them
        sq_dist = (diff * diff).sum(dim=-1)
        # Calculate likelihood
        return -0.5*K*torch.log(2*self.t(torch.tensor(torch.pi))) - K*log_sigma \
            - 0.5/torch.exp(2*log_sigma) * sq_dist
    
    def resample_systematic(self, weights):
        """Systematic resampling from normalised log-weights.

        Args:
            weights: Normalised log probabilities, batch x ptc.

        Returns:
            Boolean tensor batch x ptc x ptc; entry ``[b, i, j]`` is True when
            new particle ``i`` is a copy of old particle ``j``. Every row
            contains exactly one True.
        """
        # Get batch size and number of particles from the weight dimensions
        n_batch, N = weights.shape[0], weights.shape[-1]
        # Work on the device of the weights rather than assuming self.device
        device = weights.device
        # Calculate cumulative weights in log space for stability
        cum_weights = torch.exp(torch.logcumsumexp(weights, -1))
        # Systematic resampling starts with random offset between 0 and 1/N
        o = torch.rand((n_batch, 1), device=device) / N
        # Then set sampling points at regular intervals from offset
        s = o + torch.arange(N, device=device).expand([n_batch, -1]) / N
        # Left bin edges: a zero followed by the cumulative weights
        left = torch.cat([torch.zeros_like(o), cum_weights[:, :-1]], -1)
        # Right bin edges. The last bin is left open-ended: rounding can leave
        # the final cumulative weight slightly below 1 (or a sampling point at
        # exactly 1), and such a point would otherwise fall into no bin at all,
        # silently zeroing that particle's state in the subsequent matmul
        right = torch.cat([cum_weights[:, :-1], torch.full_like(o, float('inf'))], -1)
        # A sample falls in a bin when it is at/above the left and below the right edge
        return (s[:, :, None] >= left[:, None, :]) & (s[:, :, None] < right[:, None, :])

# Model 20: HMMMoE 
class HMMMoE_RL(HMMMoE):
    def __init__(self, action_dim, action_hidden, action_layers, action_out, context_layers, n_modules,
                 use_gru=False, rank=4, true_modules=False, true_gating=False,
                 n_tasks=30, n_contexts=3, n_steps=25, n_operations=6, task=0,
                 weight_init=0.1, sigma_init=0.05):
        """Build the modules, one readout per module, the context network and initial states.

        Args:
            action_dim: Input size of every action cell.
            action_hidden: Hidden size of the action cells and of the context cell.
            action_layers: Accepted but not used in this constructor.
            action_out: Output size of each readout; stored as ``output_dim``.
            context_layers: Accepted but not used in this constructor.
            n_modules: Number of action modules / gate outputs.
            use_gru: Use ``script_gru.scriptGRUCell`` for the context network
                instead of ``script_gru.scriptVanillaCell``.
            rank: If > 0 the action modules are ``script_gru.LowRankRNNCell``
                instances built with this rank, otherwise ``scriptVanillaCell``.
            true_modules: Stored flag; here it only selects a fixed zero tensor
                of size ``action_out`` as ``action_h0``.
            true_gating: Stored flag; the filters then gate with the provided context.
            n_tasks, n_contexts, n_steps, n_operations: Task sizes, stored as-is.
            task: Task identifier; ``nll`` treats the value 2 specially.
            weight_init: Value forwarded to ``init_linear`` and the script_gru cells.
            sigma_init: Initial value of the learnable ``log_sigma`` parameter.
        """
        # Inherit nn.Module init without the HMMoE init
        nn.Module.__init__(self)

        # Copy task parameters
        self.n_tasks = n_tasks
        self.n_contexts = n_contexts
        self.n_steps = n_steps
        self.n_operations = n_operations
        self.task = task # Obsolete now; should be set to 1
        self.output_dim = action_out

        # Copy network parameters
        self.n_modules = n_modules
        self.true_modules = true_modules
        self.true_gating = true_gating
        self.hard = False

        # Get device (the initial states do not exist yet, so nothing is moved)
        self.set_device(move_vars=False)

        # Create independent recurrent networks
        self.action_rnn = nn.ModuleList(
                    [script_gru.LowRankRNNCell(action_dim, action_hidden, rank, nonlinearity='tanh')
                     if rank > 0 else script_gru.scriptVanillaCell(action_dim, action_hidden, weight_init=weight_init)
                     for _ in range(n_modules)])

        # Output is read out from hidden state, with a separate readout per module
        self.action_output = nn.ModuleList(
            [nn.Linear(action_hidden, action_out, bias=False)
             for _ in range(n_modules)])
        for m in self.action_output:
            init_linear(m, weight_init)

        # Create the output noise parameter; it is consumed as log(sigma) in
        # multivariate_log_likelihood. The explicit float dtype keeps integer
        # values such as sigma_init=0 from producing an int64 tensor, which
        # cannot be wrapped in a gradient-requiring nn.Parameter
        self.log_sigma = nn.Parameter(
            torch.tensor(sigma_init, dtype=torch.float, device=self.device))

        # Create context RNN; its input is the data concatenated with the previous gating
        context_net = script_gru.scriptGRUCell if use_gru else script_gru.scriptVanillaCell
        self.context_rnn = context_net(action_dim + n_modules, action_hidden, weight_init=weight_init)

        # The output of the context RNN is a gating signal across modules;
        # p_only=False so both log-probabilities and a sample are returned
        self.context_output = CatSample(action_hidden, n_modules, t=1, p_only=False, weight_init=weight_init, hard=self.hard)

        # Set learnable initial state for context RNN
        self.context_h0 = nn.Parameter(torch.zeros(action_hidden, dtype=torch.float, device=self.device))

        # Set initial state for action RNN: a fixed zero tensor of output size
        # with true_modules, otherwise a learnable parameter of hidden size
        self.action_h0 = torch.zeros(action_out, dtype=torch.float, device=self.device) if self.true_modules \
            else nn.Parameter(torch.zeros(action_hidden, dtype=torch.float, device=self.device))

        # Set learnable initial context output (i.e. initial module)
        self.context_o0 = CatSample(1, n_modules, t=1, p_only=False, hard=False)
    
    def nll(self, env, context, target, N=500):
        """Run a bootstrap particle filter where module outputs act on an environment.

        At every step each particle updates its gating, resets its action
        hidden state in proportion to how much the gating changed, runs all
        modules on an all-zero input, turns the gated output into an action,
        lets ``env.transition`` produce the next state, is weighted by the
        Gaussian log-likelihood of the target state, and is then resampled.

        Args:
            env: Object providing ``task_init`` (initial state) and
                ``transition(state, action)``.
            context: batch x time x modules. Used directly as gating when
                ``true_gating`` is set; its sum over the last axis at each
                step blends between resampling (1) and keeping particles (0).
            target: Target states, batch x time x 2 (batch x time x steps x 2
                when ``self.task == 2``).
            N: Number of particles.

        Returns:
            Dict of tensors stacked over time with keys 'action_out',
            'action_hidden', 'context_out', 'context_hidden', 'pred_action',
            'pred_state' (before resampling), 'inf_action', 'inf_state'
            (after resampling), 'activation', 'ancestor', 'weight' and
            'likelihood'.
        """
        n_batch = context.shape[0]

        # Collect context and action signals throughout forward pass
        context_o, context_h, action_o, action_h = [], [], [], []
        action_pred, state_pred, action_inf, state_inf = [], [], [], []
        activations, ancestors, weights, likelihoods = [], [], [], []

        # Expand tasks and target across particles
        if self.task == 2:
            target = target[:, None, :, :, :].expand([-1, N, -1, -1, -1]) # batch x ptc x time x steps x 2
        else:
            target = target[:, None, :, :].expand([-1, N, -1, -1]) # batch x ptc x time x 2

        # Set initial state across particles
        s_t = self.t(env.task_init[None, :]).expand([n_batch, N, -1]) # batch x ptc x 2

        # Initialise all states before first timestep
        c_h = self.context_h0.expand([n_batch, N, -1]) # batch x ptc x hidden
        a_h = self.action_h0.expand([n_batch, N, -1]) # batch x ptc x hidden
        c_o = self.context_o0(self.t(torch.ones([n_batch, N, 1])))[1] # batch x ptc x module

        # Loop invariant: identity "resampling" that keeps every particle
        p_keep = self.t(torch.eye(N)).expand([n_batch, -1, -1])

        # Run through timesteps
        for t in range(self.n_steps):

            # For any nn.Module operation, I can only have one batch dimension
            # That means I'll need to flatten the first dimensions batch x ptc into batch * ptc
            # The network input d_t is all zeros with the shape of the state
            d_t, c_o_prev, c_h, a_h = [
                torch.flatten(d, 0, 1) for d in [torch.zeros_like(s_t), c_o, c_h, a_h]]

            # 1. Sample updated state from transition

            if self.true_gating:
                # Use ground truth context as gating signal
                c_o = torch.flatten(context[:,t,:].unsqueeze(1).expand([-1,N,-1]), 0, 1)
            else:
                # Step the context RNN
                c_i = torch.concatenate([c_o_prev, d_t], -1)
                c_h = self.context_rnn(c_i, c_h) # batch * ptc x hidden

                # Sample new module activations as output from context RNN
                c_o = self.context_output(c_h)[1] # batch * ptc x modules

            # Reset the hidden state when module changes (i.e. near-zero dot product)
            dot = torch.sum(c_o_prev * c_o, dim=-1, keepdim=True)
            # NOTE(review): an earlier comment said the initial hidden state must
            # be detached for stability, but no detach is applied here - confirm intent
            a_h = dot * a_h + (1 - dot) * self.action_h0.expand([n_batch * N, -1])

            # 2. Calculate state emission

            # Run one step of action RNN
            a_h_n = [rnn(d_t, a_h) for rnn in self.action_rnn] # modules x [batch * ptc x hidden]

            # Get potential output for each RNN
            a_o_n = [out(h) for out, h in zip(self.action_output, a_h_n)] # modules x [batch * ptc x data]

            # Gate action RNN: run all, then multiply by active module per batch
            a_h = torch.matmul(c_o[:,None,:], torch.stack(a_h_n, dim=1)).squeeze(1) # batch * ptc x hidden

            # Do the same for output
            a_o = torch.matmul(c_o[:,None,:], torch.stack(a_o_n, dim=1)).squeeze(1) # batch * ptc x data

            # Now I'm done with nn.Modules, so I can unflatten batch * ptc to batch x ptc
            c_o, c_h, a_h, a_o = [
                torch.unflatten(d, 0, [-1, N]) for d in [c_o, c_h, a_h, a_o]]

            # Generate the action from module output
            a_t = self.get_action(a_o)

            # Transition the state based on the action
            s_t = env.transition(s_t, a_t)

            # Flatten across steps for Bezier curve trajectories
            if self.task == 2:
                s_t = s_t.reshape(*s_t.shape[:-2], -1)

            # Log the gating output, action, and new state *before* resampling
            context_o.append(c_o)
            action_pred.append(a_t)
            state_pred.append(s_t)

            # 3. Update particle weights from likelihood

            # Get particle weights from target likelihood
            w = self.multivariate_log_likelihood(s_t, target[:,:,t,:])

            # Calculate normalisation of particle weights
            W = w.logsumexp(dim=1, keepdim=True)

            # Normalise particle weights across particles
            w = w - W # batch x ptc

            # 4. Resample particles

            # Finally, resample particles from updated likelihood
            p = self.resample_systematic(w) # batch x ptc x ptc

            # Only resample particles for batches with feedback provided at this step
            do_feedback = torch.sum(context[:,t,:],-1)[:,None,None]
            p = p * do_feedback + p_keep * (1-do_feedback)

            # Update module and RNN hidden states according to sampled particles
            c_h = torch.matmul(p, c_h)
            a_h = torch.matmul(p, a_h)
            c_o = torch.matmul(p, c_o)
            a_t = torch.matmul(p, a_t)
            s_t = torch.matmul(p, s_t)

            # Append the current outputs and hidden states to lists
            context_h.append(c_h)
            action_o.append(a_o)
            action_inf.append(a_t)
            state_inf.append(s_t)
            action_h.append(a_h)
            activations.append(c_o)
            ancestors.append(torch.argmax(p, dim=-1))
            weights.append(w)
            # Squeeze only the particle axis so a batch of size 1 keeps its axis
            likelihoods.append(W.squeeze(1)) # technically -log(N)

            # At the very end, in case of bezier task: set state to final position
            if self.task == 2:
                s_t = s_t[...,-2:]

        # Concatenate the outputs and hidden states through time: time becomes
        # the second-to-last axis for per-feature tensors, the last one otherwise
        context_o, context_h, action_o, action_h, \
            action_pred, state_pred, action_inf, state_inf, \
                activations, ancestors, weights, likelihoods = [
                    torch.stack(y, axis=(-2 if y[0].ndim > 2 else -1)) for y in
                    [context_o, context_h, action_o, action_h,
                     action_pred, state_pred, action_inf, state_inf,
                     activations, ancestors, weights, likelihoods]]

        # Return results
        return {'action_out': action_o, 'action_hidden': action_h,
                'context_out': context_o, 'context_hidden': context_h,
                'pred_action': action_pred, 'pred_state': state_pred,
                'inf_action': action_inf, 'inf_state': state_inf,
                'activation': activations, 'ancestor': ancestors,
                'weight': weights, 'likelihood': likelihoods}
    
    def nll_guided(self, env, context, target, N=500):
        """Run a guided particle filter whose proposal combines transition and likelihood.

        At every step the state reached by *each* module is scored against the
        target, the scores are added to the gating log-probabilities to form the
        proposal, a module is sampled per particle from that proposal, and the
        proposal's normaliser serves as the particle weight for resampling.

        Args:
            env: Object providing ``task_init`` (initial state) and
                ``transition(state, action)``.
            context: batch x time x modules. Used directly as gating when
                ``true_gating`` is set; its sum over the last axis at each
                step switches the likelihood term and the resampling on or off.
            target: Target states, batch x time (x steps) x 2.
            N: Number of particles.

        Returns:
            Dict of tensors stacked over time with keys 'action_out',
            'action_hidden', 'context_out', 'context_hidden', 'pred_action',
            'pred_state' (before resampling), 'inf_action', 'inf_state'
            (after resampling), 'activation', 'ancestor', 'weight' and
            'likelihood'.
        """
        n_batch = context.shape[0]

        # Collect context and action signals throughout forward pass
        context_o, context_h, action_o, action_h = [], [], [], []
        action_pred, state_pred, action_inf, state_inf = [], [], [], []
        activations, ancestors, weights, likelihoods = [], [], [], []

        # Expand target across particles and modules in dim 1 and 2, independent of what comes after
        target = target.unsqueeze(1).unsqueeze(2) # batch x 1 x 1 x time (x steps) x 2
        target = target.expand([-1, N, self.n_modules] + [-1] * (target.dim()-3)) # batch x ptc x modules x time (x steps) x 2

        # Set initial state across particles
        s_t = self.t(env.task_init[None, :]).expand([n_batch, N, -1]) # batch x ptc x 2

        # Initialise all states before first timestep
        c_h = self.context_h0.expand([n_batch, N, -1]) # batch x ptc x hidden
        a_h = self.action_h0.expand([n_batch, N, -1]) # batch x ptc x hidden
        c_o = self.context_o0(self.t(torch.ones([n_batch, N, 1])))[1] # batch x ptc x module

        # Loop invariant: identity "resampling" that keeps every particle
        p_keep = self.t(torch.eye(N)).expand([n_batch, -1, -1])

        # Run through timesteps
        for t in range(self.n_steps):
            # Select target at current timestep (axis 3), independent of what comes after.
            # Plain indexing replaces the list-of-slices index, which newer
            # PyTorch versions deprecate in favour of tuple indices
            t_t = target[:, :, :, t] # batch x ptc x modules (x steps) x 2

            # Feedback indicator per batch entry at this step
            feedback = torch.sum(context[:,t,:],-1) # batch

            # For any nn.Module operation, I can only have one batch dimension
            # That means I'll need to flatten the first dimensions batch x ptc into batch * ptc
            # The network input d_t is all zeros with the shape of the state
            d_t, s_t, t_t, c_o_prev, c_h, a_h = [
                torch.flatten(d, 0, 1) for d in [torch.zeros_like(s_t), s_t, t_t, c_o, c_h, a_h]]

            # 1. Retrieve log probabilities of the next hidden states

            if self.true_gating:
                # Use ground truth context as gating signal
                c_o = torch.flatten(context[:,t,:].unsqueeze(1).expand([-1,N,-1]), 0, 1)
                # Turn that into log probabilities; clip to avoid nans
                c_p = torch.log(torch.clip(c_o, 1e-15))
            else:
                # Step the context RNN
                c_i = torch.concatenate([c_o_prev, d_t], -1)
                c_h = self.context_rnn(c_i, c_h) # batch * ptc x hidden
                # Get log probabilities across modules
                c_p = self.context_output(c_h)[0] # batch * ptc x modules

            # 2. Calculate state emission to get likelihood

            # Reset hidden state for any module switches: each candidate module
            # keeps the hidden state in proportion to its previous activation
            a_h_n = c_o_prev[:,:,None] * a_h[:, None, :].expand([-1, self.n_modules, -1]) \
                + (1 - c_o_prev[:,:,None]) * self.action_h0[None, None, :].expand([n_batch * N, self.n_modules, -1])

            # Run one step of action RNN
            a_h_n = torch.stack([rnn(d_t, a_h_n[:,z,:]) for z, rnn in enumerate(self.action_rnn)], dim=1) # batch * ptc x modules x hidden

            # Get potential output for each RNN
            a_o_n = torch.stack([self.get_action(out(a_h_n[:,z,:]))
                                 for z, out in enumerate(self.action_output)], dim=1) # batch * ptc x modules x data

            # Transition the state based on the output
            s_t_n = env.transition(s_t[:, None, :].expand([-1, self.n_modules, -1]), a_o_n)

            # Get likelihood
            a_p = self.multivariate_log_likelihood(s_t_n, t_t) # batch * ptc x modules

            # Reset likelihood for batches without feedback at this step
            a_p = a_p * torch.flatten(feedback.unsqueeze(1).expand([-1,N]), 0, 1)[:,None]

            # 3. Calculate proposal: likelihood times transition

            # Proposal distribution q(z_t | z_t-1, y_t) = p(z_t | z_t-1) *  p(y_t | z_t)
            q = c_p + a_p # batch * ptc x modules

            # Get normalisation factor for q across modules - this is particle weight
            w = q.logsumexp(dim=-1, keepdim=True) # batch * ptc x 1

            # Normalise q so I can sample from it
            q = q - w

            # 4. Sample modules z_t from proposal distribution q

            # Sample a module for each particle
            c_o = torch.nn.functional.gumbel_softmax(q, hard=self.hard) # batch * ptc x modules

            # Use sampled gating to get particle hidden state
            a_h = torch.matmul(c_o[:,None,:], a_h_n).squeeze(1) # batch * ptc x hidden

            # Do the same for output
            a_o = torch.matmul(c_o[:,None,:], a_o_n).squeeze(1) # batch * ptc x data

            # The gated module output is used directly as the action
            a_t = a_o # batch * ptc x data

            # Transition the state based on the action
            s_t = env.transition(s_t, a_t) # batch * ptc x data

            # Now I'm done with nn.Modules, so I can unflatten batch * ptc to batch x ptc
            # Squeeze only the trailing axis of w so that batch * ptc == 1 still unflattens
            c_o, c_h, a_h, a_o, a_t, s_t, w = [
                torch.unflatten(d, 0, [-1, N])
                for d in [c_o, c_h, a_h, a_o, a_t, s_t, w.squeeze(-1)]]

            # Log the gating output, action, and new state *before* resampling
            context_o.append(c_o)
            action_pred.append(a_t)
            state_pred.append(s_t)

            # 5. Resample particles from particle weights

            # Calculate normalisation of particle weights. This is marginal likelihood!
            W = w.logsumexp(dim=1, keepdim=True)

            # Normalise particle weights across particles
            w = w - W # batch x ptc

            # Finally, resample particles from updated likelihood
            p = self.resample_systematic(w) # batch x ptc x ptc

            # Only resample particles for batches with feedback provided at this step
            do_feedback = feedback[:,None,None]
            p = p * do_feedback + p_keep * (1-do_feedback)

            # Update module and RNN hidden states according to sampled particles
            c_h = torch.matmul(p, c_h)
            a_h = torch.matmul(p, a_h)
            c_o = torch.matmul(p, c_o)
            a_t = torch.matmul(p, a_t)
            s_t = torch.matmul(p, s_t)

            # Append the current outputs and hidden states to lists
            context_h.append(c_h)
            action_o.append(a_o)
            action_inf.append(a_t)
            state_inf.append(s_t)
            action_h.append(a_h)
            activations.append(c_o)
            ancestors.append(torch.argmax(p, dim=-1))
            weights.append(w)
            # Squeeze only the particle axis so a batch of size 1 keeps its axis
            likelihoods.append(W.squeeze(1)) # technically -log(N)

        # Concatenate the outputs and hidden states through time: time becomes
        # the second-to-last axis for per-feature tensors, the last one otherwise
        context_o, context_h, action_o, action_h, \
            action_pred, state_pred, action_inf, state_inf, \
                activations, ancestors, weights, likelihoods = [
                    torch.stack(y, axis=(-2 if y[0].ndim > 2 else -1)) for y in
                    [context_o, context_h, action_o, action_h,
                     action_pred, state_pred, action_inf, state_inf,
                     activations, ancestors, weights, likelihoods]]

        # Return results
        return {'action_out': action_o, 'action_hidden': action_h,
                'context_out': context_o, 'context_hidden': context_h,
                'pred_action': action_pred, 'pred_state': state_pred,
                'inf_action': action_inf, 'inf_state': state_inf,
                'activation': activations, 'ancestor': ancestors,
                'weight': weights, 'likelihood': likelihoods}
        
    def get_action(self, a_o):
        """Map a module output onto an action.

        Actions are continuous here, so the module output is handed back
        unchanged (the very same object, no copy and no clipping).

        Args:
            a_o: module output to interpret as an action.

        Returns:
            ``a_o`` itself.
        """
        action = a_o
        return action
    
    def trace_back(self, output, N=1, sample=False):
        """Rebuild the model output along the lineages of selected particles.

        ``N`` particles are chosen at the final time step, then followed
        backwards through ``self.n_steps`` steps using the stored ancestor
        indices. Every entry of ``output`` is re-collected along those
        lineages.

        Args:
            output: dict with the keys 'action_out', 'action_hidden',
                'context_out', 'context_hidden', 'pred_action', 'pred_state',
                'inf_action', 'inf_state', 'activation', 'ancestor', 'weight'
                and 'likelihood'. The indexing below requires particle-wise
                entries to be laid out as batch x particle x time
                (x features) and 'likelihood' as batch x time; 'weight' is
                treated as log-weights.
            N: number of particles to trace back per batch element.
            sample: if True, draw the N starting particles independently from
                the categorical distribution defined by the final log-weights
                (so a particle can be drawn more than once); otherwise take
                the N particles with the largest final weight.

        Returns:
            dict with the same keys (in the same order) as ``output``, where
            the particle axis now has one entry per traced particle and time
            runs forward again. 'weight' is renormalised over the traced
            particles and 'likelihood' holds the matching log-normaliser.
        """
        # Final weights of the resampled particles: resampled particle i
        # carries the pre-resampling weight of its ancestor.
        w = select_along_dim(output['weight'][:, :, -1],
                             output['ancestor'][:, :, -1], dim=1)

        # Decide which final particles to follow back in time
        if sample:
            # w lives in log space (it is normalised with logsumexp), so it has
            # to be passed as logits; passed positionally it would be taken as
            # probabilities, which non-positive values can never be.
            p_dist = torch.distributions.categorical.Categorical(logits=w)
            p_id = p_dist.sample((N,)).t()  # batch x N
        else:
            p_id = torch.argsort(w, descending=True)[:, :N]  # batch x N

        # Variables logged *after* resampling: read at the current particle ids
        resampled_keys = ('context_out', 'context_hidden', 'action_hidden',
                          'activation', 'inf_action', 'inf_state')
        # Variables logged *before* resampling: read at the ancestor ids
        ancestral_keys = ('action_out', 'pred_action', 'pred_state')

        # One list of per-step tensors per output key, in the returned order
        traced = {key: [] for key in (
            'action_out', 'action_hidden', 'context_out', 'context_hidden',
            'pred_action', 'pred_state', 'inf_action', 'inf_state',
            'activation', 'ancestor', 'weight', 'likelihood')}

        # Iterate through reversed time
        for t in reversed(range(self.n_steps)):
            # The resampled variables come from the *currently selected* particles
            for key in resampled_keys:
                traced[key].append(
                    select_along_dim(output[key][:, :, t], p_id, dim=1))

            # Then move the selection to the ancestors of those particles
            p_id = select_along_dim(output['ancestor'][:, :, t], p_id, dim=1)

            # And read the non-resampled variables at the *ancestor* indices
            for key in ancestral_keys:
                traced[key].append(
                    select_along_dim(output[key][:, :, t], p_id, dim=1))
            w = select_along_dim(output['weight'][:, :, t], p_id, dim=1)

            # Undo the stored normalisation, then renormalise over the traced
            # particles only
            w = w + output['likelihood'][:, t].unsqueeze(1)
            W = w.logsumexp(dim=1, keepdim=True)
            w = w - W

            traced['ancestor'].append(p_id)
            traced['weight'].append(w)
            # squeeze(1) rather than squeeze(): keep the batch axis even when
            # the batch holds a single element (original note: technically
            # -log(N))
            traced['likelihood'].append(W.squeeze(1))

        # Steps were collected backwards: flip them, then stack through time.
        # Time goes on the second-to-last axis when a trailing feature axis
        # exists, and on the last axis otherwise.
        return {key: torch.stack(steps[::-1],
                                 dim=(-2 if steps[0].ndim > 2 else -1))
                for key, steps in traced.items()}
   
# Model 21: control model that is just an RNN
class RNNControl(nn.Module):
    """Control model: one recurrent cell followed by a small read-out network.

    At every time step the data for that step is concatenated with the task
    input and fed to the recurrent cell; the action is read out from the new
    hidden state through Linear -> ReLU -> Linear.

    Args:
        in_dim: size of the recurrent cell input (data and task concatenated).
        hidden_dim: size of the recurrent hidden state.
        out_dim: size of the action read-out.
        n_tasks, n_contexts, n_steps, n_operations: task parameters, stored on
            the instance; only ``n_steps`` is used by ``forward``.
        task_id: if False, the task input is replaced by zeros.
        weight_init: passed on to ``script_gru.scriptGRUCell``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, 
                 n_tasks=30, n_contexts=3, n_steps=25, n_operations=6,
                 task_id=True, weight_init=1):
        super().__init__()

        # Task parameters
        self.n_tasks, self.n_contexts = n_tasks, n_contexts
        self.n_steps, self.n_operations = n_steps, n_operations

        # Network parameters
        self.in_dim, self.hidden_dim, self.out_dim = in_dim, hidden_dim, out_dim
        self.task_id = task_id

        # Pick the default device
        self.set_device()

        # Recurrent core
        self.action_rnn = script_gru.scriptGRUCell(
            in_dim, hidden_dim,
            nonlinearity='tanh', weight_init=weight_init)

        # Read-out from the hidden state
        readout_layers = [nn.Linear(hidden_dim, hidden_dim),
                          nn.ReLU(),
                          nn.Linear(hidden_dim, out_dim)]
        self.action_out = nn.Sequential(*readout_layers)

        # Learnable initial hidden state, shared by all batch elements
        self.action_h0 = nn.Parameter(
            torch.zeros(hidden_dim, dtype=torch.float, device=self.device))

    def set_device(self, device=None):
        """Store the device to use; without argument, cuda:0 if available, else cpu."""
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device

    def t(self, tensor):
        """Shorthand: cast a tensor to float and move it to ``self.device``."""
        as_float = tensor.float()
        return as_float.to(self.device)

    def forward(self, data, task):
        """Run the recurrent cell over ``self.n_steps`` time steps.

        Args:
            data: input indexed as ``data[:, t, :]``, i.e. batch x time x features.
            task: task input, concatenated to the data of every step along the
                last axis; zeroed out when ``self.task_id`` is False.

        Returns:
            dict with 'action_out', 'action_hidden' and 'action_in', each
            stacked through time on axis 1.
        """
        # Every batch element starts from the same learned hidden state
        batch_size = data.shape[0]
        hidden = self.action_h0.expand([batch_size, -1])

        # Blank out the task input if the task id should not be provided
        if not self.task_id:
            task = torch.zeros_like(task)

        # Per-step records, keyed by their name in the returned dict
        trace = {'action_out': [], 'action_hidden': [], 'action_in': []}

        for step in range(self.n_steps):
            # Input of this step: data followed by the task input
            rnn_in = torch.cat((data[:, step, :], task), dim=-1)

            # Advance the recurrent state by one step
            hidden = self.action_rnn(rnn_in, hidden)

            # Read out the action and record everything for this step
            trace['action_out'].append(self.action_out(hidden))
            trace['action_hidden'].append(hidden)
            trace['action_in'].append(rnn_in)

        # Stack the per-step records through time
        return {key: torch.stack(steps, dim=1) for key, steps in trace.items()}
    
    
def select_along_dim(tensor, indices, dim=1):
    """Pick entries along one dimension while keeping all trailing dimensions.

    E.g. for a tensor of shape a x b x c x d x e and indices of shape a x i,
    the result with dim=1 has shape a x i x c x d x e, where
    result[n, j] = tensor[n, indices[n, j]].

    Args:
        tensor: tensor to select from.
        indices: integer tensor holding, per leading index, the positions to
            select along ``dim``; it may have fewer dimensions than ``tensor``.
        dim: dimension of ``tensor`` along which to select.

    Returns:
        Tensor shaped like ``tensor``, except that ``dim`` has the size of
        ``indices`` along that dimension.
    """
    # Pad the indices with trailing singleton axes up to the rank of the tensor
    for _ in range(tensor.dim() - indices.dim()):
        indices = indices.unsqueeze(-1)

    # Broadcast the indices over every dimension except the selected one
    target_shape = [*tensor.shape]
    target_shape[dim] = indices.size(dim)
    broadcast_indices = indices.expand(target_shape)

    # gather does the actual selection along dim
    return tensor.gather(dim, broadcast_indices)

def init_linear(module, weight_init=1.0, bias_init=0.0):
    """Initialise a linear layer in place.

    The weight gets Xavier-uniform values scaled by ``weight_init`` (used as
    the gain); the bias, if the layer has one, is filled with ``bias_init``.

    Args:
        module: layer exposing ``weight`` and ``bias`` (``bias`` may be None).
        weight_init: gain of the Xavier-uniform weight initialisation.
        bias_init: constant written to every bias entry.
    """
    init = torch.nn.init
    init.xavier_uniform_(module.weight, gain=weight_init)

    # Layers built with bias=False have nothing more to initialise
    bias = module.bias
    if bias is None:
        return
    init.constant_(bias, bias_init)