from typing import Any, Dict, List, Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import layer_init, layer_init_lstm, layer_init_gru

from regym.rl_algorithms.utils import _extract_from_rnn_states, extract_subtree, copy_hdict
from regym.rl_algorithms.networks.bodies import LSTMBody

import wandb

def _register_nan_checks(model):
    def check_grad(module, grad_in, grad_out):
        #wandb.log({f"{type(module).__name__}_gradients": wandb.Histogram(grad_in)})
        if any([torch.any(torch.isnan(gi.data)) for gi in grad_in if gi is not None]):
            print(type(module).__name__)
            import ipdb; ipdb.set_trace()

    model.apply(lambda module: module.register_backward_hook(check_grad))


class BasicDCEMHeads(nn.Module):
    """
    Base class of the memory heads: a single linear layer maps the query
    input to one read key per head.

    Sub-classes are expected to implement :meth:`read` and :meth:`write`.
    """
    def __init__(
        self,
        memory, 
        input_dim=256, 
        nbr_heads=1, 
        simplified=False,
        ):
        """
        :param memory: memory module; only its `mem_dim` attribute is read here.
        :param input_dim: size of the last dimension of the query input.
        :param nbr_heads: number of read keys generated per query.
        :param simplified: flag stored as-is on the instance (not used in this class).
        """
        # Fix: the original called super() with an undefined class name.
        super().__init__()

        self.memory = memory
        self.mem_dim = self.memory.mem_dim
        self.nbr_heads = nbr_heads
        self.input_dim = input_dim
        self.simplified = simplified 

        self.generate_query_net()

    def generate_query_net(self) :
        """
        Builds `self.query_net`, which generates:
        kr: read keys (nbr_heads x mem_dim values per sample).
        """
        self.query_dim = self.nbr_heads*self.memory.mem_dim
        
        self.query_net = layer_init(
            nn.Linear(
                self.input_dim, 
                self.query_dim
            ),
            w_scale=1e-3,
            init_type='ortho',
        )
    
    def write(self, memory_state, ctrl_inputs):
        raise NotImplementedError

    def read(self, memory_state, ctrl_inputs):
        raise NotImplementedError

    def forward(self, query_input):
        """
        :param query_input: tensor whose last dimension is `input_dim`.
        :returns: dict with entry 'kr' (see :meth:`_generate_addressing`).
        """
        # Fix: the original read an undefined name instead of the parameter.
        query_output = self.query_net(query_input)
        return self._generate_addressing(query_output)
           
    def _generate_addressing(self, query_output) :
        """
        Slices the query net output into addressing quantities.

        :param query_output: (batch_size x nbr_heads*mem_dim) tensor.
        :returns: dict with 'kr', the tanh-squashed read keys, 
            reshaped to (batch_size x nbr_heads x mem_dim).
        """
        odict = {}
        
        start = 0
        end = self.nbr_heads*self.mem_dim
        odict['kr'] = torch.tanh(query_output[:,start:end]).reshape(-1, self.nbr_heads, self.mem_dim)
        
        return odict

    
class ReadWriteHeads(BasicDCEMHeads):
    """
    Heads reading from and writing to the external memory.

    The query net inherited from the base class only emits the read keys
    ('kr'); every other addressing quantity used below must be put in `odict`
    by the caller.
    """
    def __init__(
        self, 
        memory, 
        nbr_heads=1, 
        input_dim=256,
        simplified=False,
        ):
        # Fix: the original inherited from an undefined class name.
        super().__init__(
            memory=memory,
            input_dim=input_dim,
            nbr_heads=nbr_heads,
            simplified=simplified,
        )
    
    def _update_usage_vector(
        self,
        prev_usage_vector,
        free_gates,
        prev_read_weights,
        prev_write_weights,
        ):
        """
        Computes the new usage of each memory slot from the previous usage,
        the previous write weights (usage increases) and the free gates
        applied to the previous read weights (usage decreases).

        :param prev_usage_vector: (batch_size x mem_nbr_slots)
        :param free_gates: reshaped to (batch_size x nbr_heads x 1)
        :param prev_read_weights: (batch_size x nbr_heads x mem_nbr_slots)
        :param prev_write_weights: (batch_size x nbr_heads x mem_nbr_slots)
        :returns: updated usage, (batch_size x mem_nbr_slots)
        """
        batch_size = prev_usage_vector.shape[0]
        # ensure minimum usage for stability:
        prev_usage_vector = 5e-3+(1-5e-3)*prev_usage_vector
        
        # Retention: product over heads of what each free gate leaves in place.
        psi = torch.prod(1 - free_gates.reshape(batch_size, -1, 1) * prev_read_weights, dim=1)
        # (batch_size x nbr_mem_slots)
        
        # With a single write head we would have:
        # usage = prev_usage + prev_write_weights - prev_usage*prev_write_weights
        # With multiple write heads, the probability that a slot has been
        # written by at least one head is the complement of the product of
        # the complements:
        reg_prev_write_weights = (1-torch.prod(1-prev_write_weights, dim=1))
        # (batch_size x mem_nbr_slots)
        usage = prev_usage_vector + (1 - prev_usage_vector) * reg_prev_write_weights
        usage = usage * psi
        return usage

    def forward(
        self,
        query_input,
        ):
        # Fix: the original forwarded an undefined name instead of the parameter.
        return super().forward(query_input=query_input)

    def write(
        self, 
        memory_state, 
        odict, 
        prev_usage_vector,
        prev_read_weights,
        prev_write_weights,
        ):
        """
        Allocation- and content-based write.

        Reads the entries 'f', 'kw', 'betaw', 'gw', 'ga', 'erase' and 'write'
        of :param odict:, and stores 'usage_vector', 'write_weights' and
        'allocation_weights' in it.

        :returns: (new memory state, updated usage vector, write weights)
        """
        batch_size = prev_usage_vector.shape[0]
        updated_usage_vector = self._update_usage_vector(
            prev_usage_vector=prev_usage_vector,
            free_gates=odict['f'],
            prev_read_weights=prev_read_weights,
            prev_write_weights=prev_write_weights,
        )
        # (batch_size x mem_nbr_slots)
        
        # Adapted from:
        # https://github.com/ixaxaar/pytorch-dnc/blob/33e35326db74c7ccd45360d6668682e60b407d1f/dnc/memory.py#L84
        ## Compute free list (slots sorted by increasing usage):
        sorted_usage, phi = torch.topk(
            updated_usage_vector,
            k=self.memory.mem_nbr_slots,
            dim=1,
            largest=False,
        )

        ## Compute 1-index-delayed cum. product of sorted usages:
        delayed_sorted_usage = torch.cat([
            torch.ones(*sorted_usage.shape[:-1], 1).to(phi.device),
            sorted_usage,],
            dim=-1,
        )
        delayed_prod_sorted_usage = torch.cumprod(
            delayed_sorted_usage,
            dim=-1,
        )[...,:-1] # j-th slot only gets the cumprod till (j-1)-th slot.
        
        sorted_allocation_weights = (1-sorted_usage)*delayed_prod_sorted_usage
        #(batch_size x mem_nbr_slots)
        
        # Unsort allocation weights 
        # by reversing sorting (== by sorting the sorted indices):
        _, unsorted_indices = torch.topk(
            phi,
            k=self.memory.mem_nbr_slots,
            dim=1,
            largest=False,
        )
        # and then re-order the sorted allocation weights:
        allocation_weights = torch.gather(
            sorted_allocation_weights,
            dim=1,
            index=unsorted_indices.long(),
        ).reshape(batch_size, 1, self.memory.mem_nbr_slots)
        # (batch_size x 1 x mem_nbr_slots)

        # Content Addressing :
        wc = self.memory.content_addressing(memory_state, odict['kw'], odict['betaw'])

        # Interpolation between content and allocation:
        write_weights = odict['gw']*(odict['ga']*allocation_weights+(1-odict['ga'])*wc)
        #(batch_size x 1 x nbr_mem_slots  )
        new_memory_state = self.memory.write(
            memory_state=memory_state,
            w=write_weights,
            erase=odict['erase'],
            add=odict['write'],
        )
        
        odict['usage_vector'] = updated_usage_vector
        odict['write_weights'] = write_weights
        odict['allocation_weights'] = allocation_weights

        return new_memory_state, updated_usage_vector, write_weights 
    
    def simplified_write(
        self,
        memory_state:torch.Tensor,
        odict:Dict[str,torch.Tensor],
        discount_factor:float,
        timestep:torch.Tensor,
        prev_ret_write_weights:torch.Tensor,
        prev_write_weights:torch.Tensor,
        vector_to_write:Optional[torch.Tensor]=None,
        ):
        """
        Writes to the slot indexed by the timestep while the memory is not
        full, and to the least used slot (argmin of odict['usage_vector'])
        afterwards.

        :param timestep: tensor holding one timestep per batch element
            (batch_size elements in total).
        :param vector_to_write: defaults to odict['write'] when None.
        :returns: (new memory state, write weights, retroactive write weights)
        """
        batch_size = memory_state.shape[0]
        nbr_slots = self.memory.mem_nbr_slots
        device = timestep.device

        # Write weights:
        ## 1 where the timestep still addresses a valid slot, 0 otherwise:
        bfilter = (timestep < nbr_slots).long().reshape(batch_size, 1, 1)
        
        # Fix: index_fill_ with a (batch_size,) index sets every listed slot
        # for EVERY batch element; scatter_ sets one slot per batch element.
        ts_index = (bfilter*timestep.reshape(batch_size, 1, 1)).long()
        ts_write_weights = torch.zeros(batch_size, 1, nbr_slots, device=device).scatter_(
            dim=-1,
            index=ts_index,
            value=1.0,
        )

        _, least_used_index = odict['usage_vector'].min(dim=-1, keepdim=True)
        # (batch_size, 1)
        lu_index = least_used_index.long().reshape(batch_size, 1, 1).to(device)
        lu_write_weights = torch.zeros(batch_size, 1, nbr_slots, device=device).scatter_(
            dim=-1,
            index=lu_index,
            value=1.0,
        )
        
        write_weights = bfilter*ts_write_weights+(1-bfilter)*lu_write_weights

        # Retroactive Adressing:
        ## Interpolation between prev_write_weights and prev_retroactive_weights:
        ret_write_weights = discount_factor*prev_ret_write_weights+(1-discount_factor)*prev_write_weights
        #(batch_size x 1 x nbr_mem_slots  )
        
        if vector_to_write is None:
            vector_to_write = odict['write']

        new_memory_state = self.memory.simplified_write(
            memory_state=memory_state,
            write_weights=write_weights,
            ret_write_weights=ret_write_weights,
            vector_to_write=vector_to_write,
        )
        
        return new_memory_state, write_weights, ret_write_weights

    def read(
        self,
        memory_state:torch.Tensor,
        odict:Dict[str,torch.Tensor],
        prev_usage_vector:torch.Tensor,
        k:int,
        ):
        """
        Content-based read of the k most similar memories of each head.

        Stores 'read_weights', 'read_vectors' and 'usage_vector' in :param odict:.

        :param memory_state: batch_size x NS x mem_dim
            where NS = min(nbr_memory_slots, max(timestep, k))
        :param odict: must hold the read keys 'kr'; the sharpness 'betar' is
            optional and defaults to 1.
        :param prev_usage_vector: (batch_size x NS) usage before this read.
        :param k: number of memories retrieved per head (clamped to NS).
        :returns: (read vectors, updated usage vector)
        """
        # Content Addressing :
        # The query net of this class only emits 'kr', hence the fallback.
        read_weights = self.memory.content_addressing(
            memory_state, 
            odict['kr'], 
            odict.get('betar', 1.0),
        )
        #( batch_size, nbrHeads, NS)
        odict['read_weights'] = read_weights
        
        topk_similarities , topk_positions = torch.topk(
            read_weights,
            k=min(k, read_weights.shape[-1]),
            dim=-1,
            largest=True,
        )
        #TODO: maybe sorting is important?
        #( batch_size, nbrHeads, k)

        # NOTE(review): requires the memory module to accept `positions`
        # in its read method -- confirm against the memory class in use.
        topk_memories = self.memory.read(
            memory_state=memory_state, 
            positions=topk_positions
        )
        #( batch_size, nbrHeads, k, mem_dim)

        topk_weighted_memories = topk_similarities.softmax(dim=-1).unsqueeze(-1)*topk_memories

        # Fix: `read_vectors` was never defined. The weighted memories are
        # summed over the k retrieved slots.
        # NOTE(review): summing is an assumption; it yields one mem_dim vector
        # per head -- confirm this is the intended aggregation.
        read_vectors = topk_weighted_memories.sum(dim=-2)
        #( batch_size, nbrHeads, mem_dim)
        odict['read_vectors'] = read_vectors
        
        updated_usage_vector = prev_usage_vector + read_weights.sum(dim=1)
        #( batch_size, nbr_mem_slots)
        odict['usage_vector'] = updated_usage_vector 

        return read_vectors, updated_usage_vector 


class DNCController(LSTMBody):
    """
    LSTM controller with two output functions: a linear read-out of the last
    LSTM layer (`linear_output`) and a linear projection of the vectors read
    from memory (`output_fn`), summed in :meth:`forward_external_output_fn`.
    """
    def __init__(
        self, 
        input_dim=32, 
        hidden_units=None, 
        output_dim=32, 
        mem_nbr_slots=128, 
        mem_dim= 32, 
        nbr_read_heads=1, 
        nbr_write_heads=1,
        extra_inputs_infos: Optional[Dict]=None,
        ):
        """
        :param hidden_units: list of LSTM layer sizes; defaults to [512].
        :param extra_inputs_infos: Dictionnary containing the shape of the lstm-relevant extra inputs.
        """
        # Fix: fresh containers per instance instead of shared mutable defaults.
        if hidden_units is None:
            hidden_units = [512]
        if extra_inputs_infos is None:
            extra_inputs_infos = {}

        # Only the raw input feeds the LSTM here: whatever else must be
        # concatenated to it is expected to be added by the encapsulating module.
        LSTMinput_size = input_dim
        
        super().__init__(
            state_dim=LSTMinput_size,
            hidden_units=hidden_units,
            gate=None,
            extra_inputs_infos=extra_inputs_infos,
        )

        self.input_dim = input_dim
        self.hidden_units = hidden_units
        self.output_dim = output_dim
        self.mem_nbr_slots = mem_nbr_slots
        self.mem_dim = mem_dim
        self.nbr_read_heads = nbr_read_heads
        self.nbr_write_heads = nbr_write_heads

        self.build_controller()

    def build_controller(self):
        """
        Builds the two output layers, both producing `output_dim` features.
        """
        controller_lstm_output_dim = self.hidden_units[-1]
        # Output Function:
        self.linear_output = layer_init(
            nn.Linear(
                controller_lstm_output_dim,
                self.output_dim,
            ),
            w_scale=1e-3,
            init_type='ortho',
        )
        
        # External Outputs :
        # input = (r0_{t}, ..., rN_{t})
        self.EXTinput_size = self.mem_dim * self.nbr_read_heads
        self.output_fn = nn.Sequential(
            layer_init(
                nn.Linear(
                    self.EXTinput_size, 
                    self.output_dim
                ),
                w_scale=1e-3,
            )
        )

    def forward_external_output_fn(self, vt_output, slots_read) :
        """
        :param vt_output: controller read-out, (batch_size x output_dim).
        :param slots_read: vectors read from memory; flattened to 
            (batch_size x -1) and expected to hold mem_dim*nbr_read_heads values.
        :returns: vt_output + output_fn(flattened slots_read)
        """
        batch_size = slots_read.shape[0]
        rslots_read = slots_read.reshape(batch_size, -1)
        return vt_output + self.output_fn(rslots_read)
    
    def forward_controller(self, inputs):
        '''
        :param inputs: input to LSTM cells. Structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).
        hidden_states: list of hidden_state(s) one for each self.layers.
        cell_states: list of hidden_state(s) one for each self.layers.

        :returns: tuple (augmented_x, vt, x, frame_states) where augmented_x is
            the input concatenated with the extra inputs, vt the linear 
            read-out of the last layer, x the last layer's (gated) hidden 
            state and frame_states the updated copy of the recurrent states.
        '''
        # WARNING: it is imperative to make a copy 
        # of the frame_state, otherwise any changes 
        # will be repercuted onto the current frame_state
        x, frame_states = inputs[0], copy_hdict(inputs[1])
        
        recurrent_neurons = extract_subtree(
            in_dict=frame_states,
            node_id='lstm',
        )

        extra_inputs = extract_subtree(
            in_dict=frame_states,
            node_id='extra_inputs',
        )

        extra_inputs = [v[0].to(x.dtype).to(x.device) for v in extra_inputs.values()]
        if len(extra_inputs): x = torch.cat([x]+extra_inputs, dim=-1)
        augmented_x = x 

        if next(self.layers[0].parameters()).is_cuda and not(x.is_cuda):    x = x.cuda() 
        hidden_states, cell_states = recurrent_neurons['hidden'], recurrent_neurons['cell']

        batch_size = x.size(0)
        next_hstates, next_cstates = [], []
        for layer, hx, cx in zip(self.layers, hidden_states, cell_states):
            if hx.size(0) == 1: # then we have just resetted the values, we need to expand those:
                hx = torch.cat([hx]*batch_size, dim=0)
                cx = torch.cat([cx]*batch_size, dim=0)
            elif hx.size(0) != batch_size:
                raise NotImplementedError("Sizes of the hidden states and the inputs do not coincide.")

            # Fix: the original condition was always true once the layer is
            # on GPU; this is the equivalent, readable form.
            if next(layer.parameters()).is_cuda:
                hx = hx.cuda()
                cx = cx.cuda() 

            # VDN (3D inputs) is not supported; the unreachable flatten/reshape
            # code that followed the raise has been dropped.
            if len(x.shape)==3:
                raise NotImplementedError

            nhx, ncx = layer(x, (hx, cx))
            next_hstates.append(nhx)
            next_cstates.append(ncx)
            
            # Consider not applying activation functions on last layer's output?
            x = nhx if self.gate is None else self.gate(nhx)
        
        vt = self.linear_output(x.reshape(batch_size,-1))

        frame_states.update({'lstm':
            {'hidden': next_hstates, 
            'cell': next_cstates}
        })

        return augmented_x, vt, x, frame_states
    
    def get_reset_states(self, cuda=False, repeat=1):
        """
        :returns: {'lstm': {'hidden': [...], 'cell': [...]}} with one zero 
            tensor of shape (repeat x layer.hidden_size) per layer and per entry.
        """
        hidden_states, cell_states = [], []
        for layer in self.layers:
            # Fix: hidden and cell states no longer alias the same tensor, so
            # an in-place update of one cannot corrupt the other.
            h = torch.zeros(repeat, layer.hidden_size)
            c = torch.zeros(repeat, layer.hidden_size)
            if cuda:
                h = h.cuda()
                c = c.cuda()
            hidden_states.append(h)
            cell_states.append(c)
        return {'lstm':{'hidden': hidden_states, 'cell': cell_states}}

    def get_feature_shape(self):
        return self.output_dim


def asp(t, K=8):
    """
    Sparsifies :param t: by keeping, along the last dimension, only its
    :param K: largest entries; every other entry is set to zero.

    :param t: dense tensor (any rank >= 1).
    :param K: number of entries kept per row; clamped to the size of the
        last dimension, in which case every entry is kept.
    :returns: sparse tensor of the same shape as :param t:.
    """
    # topk raises when k exceeds the dimension size:
    K = min(K, t.shape[-1])
    top_values, top_indices = t.topk(k=K, dim=-1, largest=True, sorted=False)
    kept = torch.zeros_like(t).scatter_(dim=-1, index=top_indices, src=top_values)
    return kept.to_sparse()


class DNCMemory(nn.Module) :
    """
    Stateless memory operations (content addressing, write, read): the memory
    content itself is passed in and returned by every method.
    """
    def __init__(
        self, 
        mem_nbr_slots, 
        mem_dim,
        sparse_K=0,
        ):
        """
        :param mem_nbr_slots: number of memory slots.
        :param mem_dim: size of each slot.
        :param sparse_K: when non-zero, reset states are returned as sparse tensors.
        """
        super().__init__()

        self.mem_nbr_slots = mem_nbr_slots
        self.mem_dim = mem_dim
        self.sparse_K = sparse_K

        self.initialize_memory()

    def initialize_memory(self) :
        # Constant, null initial memory:
        self.init_mem = torch.zeros((1, self.mem_nbr_slots,self.mem_dim))
        
    def get_reset_states(self, cuda=False, repeat=1):
        """
        :returns: {'memory': [tensor]} where tensor is the initial memory 
            repeated :param repeat: times along the batch dimension.
        """
        h = self.init_mem.clone().repeat(repeat, 1 , 1)
        if self.sparse_K!=0:    h = h.to_sparse()
        if cuda:
            h = h.cuda()
        return {'memory': [h]}

    def content_addressing(
        self,
        memory,
        k,
        beta
        ):
        """
        Softmax over the slots of the beta-sharpened cosine similarity between
        each key and each memory slot.

        :param memory: (batch_size x nbr_mem_slot x mem_dim)
        :param k: keys, (batch_size x nbrHeads x mem_dim)
        :param beta: sharpness, broadcastable to (batch_size x nbrHeads x 1)
        :returns: weights, (batch_size x nbrHeads x nbr_mem_slot)
        """
        nbrHeads = k.size()[1]
        
        # expand gives the per-head view without copying the memory:
        memory_bhSMidx = memory.unsqueeze(1).expand(-1, nbrHeads, -1, -1).to(k.device)
        # (batch_size, nbrHeads, nbr_mem_slot, mem_dim)
        kmat = k.unsqueeze(2)
        # (batch_size, nbrHeads, 1, mem_dim)
        cossim = F.cosine_similarity( kmat, memory_bhSMidx, dim=-1)
        #(batch_size x nbrHeads x nbr_mem_slot )
        w = F.softmax( beta * cossim, dim=-1)
        #(batch_size x nbrHeads x nbr_mem_slot )
        return w 

    def write(
        self, 
        memory_state, 
        w, 
        erase, 
        add,
        ):
        """
        Erase-then-add write, applied head after head.

        :param memory_state: (batch_size, nbr_mem_slot, mem_dim)
        :param w: (batch_size, nbrHeads, nbr_mem_slot)
        :param erase: (batch_size, nbrHeads, mem_dim)
        :param add: (batch_size, nbrHeads, mem_dim)
        :returns: new memory, same shape as :param memory_state:.
        """
        nmemory = memory_state

        nh = erase.shape[1]
        # Outer products of the weights with the erase/add vectors:
        # (batch_size, nbrHeads, nbr_mem_slot, mem_dim)
        e = torch.matmul(w.unsqueeze(-1), erase.unsqueeze(2))
        a = torch.matmul(w.unsqueeze(-1), add.unsqueeze(2))
        for hidx in range(nh):
            nmemory = nmemory*(1-e[:,hidx])+a[:,hidx]
        
        return nmemory

    def simplified_write(
        self,
        memory_state,
        write_weights,
        ret_write_weights,
        vector_to_write,
        ):
        """
        Additive write: the vector is added to the first half of the slots 
        addressed by :param write_weights: and to the second half of the 
        slots addressed by :param ret_write_weights:.

        :param memory_state: (batch_size, nbr_mem_slot, 2*mem_dim)
        :param write_weights: (batch_size, nbrHeads, nbr_mem_slot)
        :param ret_write_weights: (batch_size, nbrHeads, nbr_mem_slot)
        :param vector_to_write: (batch_size, nbrHeads, mem_dim)
        """
        nmemory = memory_state

        nh = write_weights.shape[1]
        zero = torch.zeros_like(vector_to_write)
        z_write = torch.cat([vector_to_write, zero], dim=-1)
        z_ret = torch.cat([zero, vector_to_write], dim=-1)

        ret = torch.matmul(ret_write_weights.unsqueeze(-1), z_ret.unsqueeze(2))
        add = torch.matmul(write_weights.unsqueeze(-1), z_write.unsqueeze(2))
        for hidx in range(nh):
            nmemory = nmemory+ret[:,hidx]+add[:,hidx]
        return nmemory
        
    def read(self, memory_state, w=None, positions=None):
        """
        Reads from memory, either softly or by position.

        :param memory_state: (batch_size x nbr_mem_slot x mem_dim)
        :param w: soft read weights, (batch_size x nbrHeads x nbr_mem_slot);
            yields the weighted sums, (batch_size x nbrHeads x mem_dim).
        :param positions: slot indices, (batch_size x nbrHeads x k); yields
            the addressed slots, (batch_size x nbrHeads x k x mem_dim).
            Takes precedence over :param w: when both are given.
        :raises ValueError: when neither :param w: nor :param positions: is given.
        """
        if positions is not None:
            nbrHeads = positions.shape[1]
            mem_dim = memory_state.shape[-1]
            per_head_memory = memory_state.unsqueeze(1).expand(-1, nbrHeads, -1, -1)
            index = positions.long().unsqueeze(-1).expand(-1, -1, -1, mem_dim)
            return torch.gather(per_head_memory, dim=2, index=index)

        if w is None:
            raise ValueError("Either `w` or `positions` must be provided.")

        reading = torch.matmul(w, memory_state)
        #(batch_size x nbrHeads x mem_dim)
        return reading
        

class DCEMBody(nn.Module) :
    """
    Encapsulating module wiring together a memory, an LSTM controller and
    the read/write heads.

    NOTE(review): the writing step of :meth:`forward` is still work in
    progress, see the note there.
    """
    def __init__(
        self,
        input_dim=32, 
        hidden_units=512, 
        output_dim=32, 
        mem_nbr_slots=128, 
        mem_dim= 32, 
        nbr_read_heads=1,
        K=4,
        clip=20,
        sparse_K=0,
        extra_inputs_infos: Optional[Dict]=None,
        kwargs:Optional[Dict[str,Any]]=None,
        ):
        """
        :param input_dim: size of the feed-forward input.
        :param hidden_units: size (int) or list of sizes of the controller's LSTM layers.
        :param output_dim: size of the output features.
        :param mem_nbr_slots: number of memory slots.
        :param mem_dim: size of each memory slot.
        :param nbr_read_heads: number of read heads.
        :param K: number of memories retrieved per head when reading.
        :param clip: bound applied (both signs) to the controller's hidden output.
        :param sparse_K: when non-zero, the memory state is stored sparsely,
            keeping its sparse_K largest entries per slot.
        :param extra_inputs_infos: forwarded to the controller.
        :param kwargs: extra options; 'vdn' and 'vdn_nbr_players' are read.
        """
        # Fix: the original called super() with an undefined class name.
        super().__init__()

        # Fix: the default is an int, whereas the last element was indexed.
        if isinstance(hidden_units, int):
            hidden_units = [hidden_units]

        self.input_dim = input_dim
        self.hidden_units = list(hidden_units)
        self.hidden_dim = self.hidden_units[-1]
        self.output_dim = output_dim
        # Fix: fresh containers per instance instead of shared mutable defaults.
        self.extra_inputs_infos = {} if extra_inputs_infos is None else extra_inputs_infos
        self.kwargs = {} if kwargs is None else kwargs

        self.mem_nbr_slots = mem_nbr_slots
        self.mem_dim = mem_dim
        self.sparse_K = sparse_K

        self.nbr_read_heads = nbr_read_heads
        self.K = K 

        self.clip = clip 

        self.build_memory()
        self.build_controller()
        self.build_heads()

    def build_memory(self) :
        # Fix: `DCEMMemory` is not defined; DNCMemory is the memory class
        # of this file and accepts these exact keyword arguments.
        self.memory = DNCMemory(
            mem_nbr_slots=self.mem_nbr_slots,
            mem_dim=self.mem_dim,
            sparse_K=self.sparse_K,
        )
        
    def build_controller(self) :
        # Fix: `DCEMController` is not defined; DNCController is the 
        # controller class of this file and accepts these keyword arguments.
        self.controller = DNCController( 
            # taking into account the key and value modalities embeddings:
            #TODO: parameterise the dimension of the embeddings separatly?
            input_dim=self.input_dim+self.mem_dim*2, 
            hidden_units=self.hidden_units, 
            output_dim=self.output_dim, 
            mem_nbr_slots=self.mem_nbr_slots, 
            mem_dim=self.mem_dim, 
            nbr_read_heads=self.nbr_read_heads, 
            extra_inputs_infos=self.extra_inputs_infos,
        )

    def build_heads(self) :
        # NOTE(review): forward() feeds the heads with the concatenation of
        # both modality embeddings and the previous hidden states, which looks
        # wider than hidden_dim -- confirm the expected input_dim.
        self.readWriteHeads = ReadWriteHeads(
            memory=self.memory,
            nbr_heads=self.nbr_read_heads, 
            input_dim=self.hidden_dim,
        )

    def _reset_weights(self, cuda=False, repeat=1, nbr_heads=1):
        """
        :returns: single-element list holding a zero tensor of shape
            (repeat x nbr_heads x mem_nbr_slots).
        """
        prev_w = torch.zeros((repeat, nbr_heads, self.mem_nbr_slots))
        if cuda:
            prev_w = prev_w.cuda()
        return [prev_w]
            
    def get_reset_states(self, cuda=False, repeat=1):
        """
        As an encapsulating module, it is its responsability
        to call get_reset_states on the encapsulated elements.

        :returns: {'dcem': {'dcem_body': ..., 'dcem_controller': ..., 'dcem_memory': ...}}
        """
        # NOTE(review): forward() also reads 'prev_ret_write_weights' from 
        # 'dcem_body', which is not initialised here.
        hdict = {'dcem_body':{}}

        usage = torch.zeros(repeat, self.mem_nbr_slots)
        if cuda:    usage = usage.cuda()
        hdict['dcem_body']['prev_usage_vector'] = [usage]
    
        # -1 so that the first forward pass is timestep 0:
        timestep = (-1)*torch.ones(repeat, 1, 1)
        if cuda:    timestep = timestep.cuda()
        hdict['dcem_body']['prev_timestep'] = [timestep]

        hdict['dcem_controller'] = self.controller.get_reset_states(repeat=repeat, cuda=cuda)
        hdict['dcem_memory'] = self.memory.get_reset_states(repeat=repeat, cuda=cuda)
        return {'dcem':hdict}
    
    def _extract_prev_hidden_state(self, state_dict:Dict[str,Any]):
        """
        Concatenates, for each actor, the flattened hidden states of every
        recurrent layer found under the first 'hidden' node of :param state_dict:.

        :returns: tensor of shape (batch_size x sum of the hidden sizes).
        """
        # Fix: `extract_subtrees` is not imported at module level.
        # NOTE(review): presumably defined next to `extract_subtree` -- confirm.
        from regym.rl_algorithms.utils import extract_subtrees

        # Extract 'hidden''s list:
        hiddens = extract_subtrees(
            in_dict=state_dict, 
            node_id="hidden",
        )
        # List[List[Tensor]]
        
        vdn = self.kwargs.get('vdn', False)
        vdn_nbr_players = self.kwargs.get('vdn_nbr_players', 2)
        
        nbr_rnn_modules = len(hiddens[0])
        batch_size = hiddens[0][0].shape[0]

        mult = 0
        if vdn and batch_size!=1: 
            batch_size = batch_size // vdn_nbr_players
            # NOTE(review): `player_idx` is never set in this class; it must
            # be assigned by the caller before using the vdn option.
            mult = self.player_idx
        
        hiddens = torch.stack([
            torch.cat([
                hiddens[0][part_id][mult*batch_size+actor_id].reshape(-1) 
                for part_id in range(nbr_rnn_modules)
                ],
                dim=0,
            )
            for actor_id in range(batch_size)
            ],
            dim=0,
        )
        # batch_size x nbr_parts*hidden_dims
        
        return hiddens 

    def forward(self, inputs):
        """
        :param inputs: (x, frame_states); frame_states must contain a 'dcem'
            node holding the reset states plus the 'key_modality_embedding' and
            'value_modality_embedding' entries (single-element lists).
        :returns: (external output, updated copy of frame_states)
        """
        x, frame_states = inputs[0], copy_hdict(inputs[1])

        # Fix: the state dict was read under one name and used under another
        # (undefined) one, with 'dnc*' keys whereas get_reset_states() emits
        # 'dcem*' keys.
        dcem_state_dict = extract_subtree(
            in_dict=frame_states,
            node_id='dcem',
        )

        key_modality_embedding = dcem_state_dict['key_modality_embedding'][0]
        key_modality_embedding = key_modality_embedding.to(x.dtype).to(x.device)
        value_modality_embedding = dcem_state_dict['value_modality_embedding'][0]
        value_modality_embedding = value_modality_embedding.to(x.dtype).to(x.device)
        
        # Reading from memory:
        ## Compute query q(v_t,l_t,h_t-1):
        prev_hidden_state = self._extract_prev_hidden_state(dcem_state_dict)
        prev_hidden_state = prev_hidden_state.to(x.dtype).to(x.device)
        
        query_input = torch.cat([key_modality_embedding, value_modality_embedding, prev_hidden_state], dim=-1)
        
        odict = self.readWriteHeads(query_input=query_input)
        
        ## Memory Read :
        prev_usage_vector = dcem_state_dict['dcem_body']['prev_usage_vector'][0].to(x.device)
        memory_state = dcem_state_dict['dcem_memory']['memory'][0].to(x.device) 
        if memory_state.is_sparse:
            memory_state = memory_state.to_dense()

        read_vec, new_usage_vector = self.readWriteHeads.read(
            memory_state=memory_state,
            odict=odict,
            prev_usage_vector=prev_usage_vector,
            k=self.K,
        )

        # Writing to memory:
        # Fix: devices were taken from `vt`, which is only computed below.
        timestep = 1+dcem_state_dict['dcem_body']['prev_timestep'][0].to(x.device)
        prev_ret_write_weights = dcem_state_dict['dcem_body']['prev_ret_write_weights'][0].to(x.device)
        #(batch_size x nbrHeads x nbr_mem_slot )
        
        # NOTE(review): this call targets a key/value write API that the
        # ReadWriteHeads class of this file does not implement (its write()
        # takes odict/prev_usage_vector/prev_read_weights/prev_write_weights),
        # and `self.discount_factor` is never assigned in this class. Both
        # need to be settled before this step can run.
        written_memory_state, new_ret_write_weights =self.readWriteHeads.write(
            memory_state=memory_state,
            key_to_write=key_modality_embedding,
            value_to_write=value_modality_embedding,
            # Fix: the read weights are stored in odict by read().
            read_weights=odict['read_weights'],
            discount_factor=self.discount_factor,
            timestep=timestep,
            prev_ret_write_weights=prev_ret_write_weights,
        )

        # updating frame state:
        # (no write weights are produced by the call above, so the original
        # 'prev_write_weights' update, which used an undefined name, is gone)
        dcem_state_dict['dcem_body']['prev_timestep'] = [timestep]
        dcem_state_dict['dcem_body']['prev_ret_write_weights'] = [new_ret_write_weights]
        dcem_state_dict['dcem_body']['prev_usage_vector'] = [new_usage_vector]

        x = torch.cat([x, key_modality_embedding, value_modality_embedding], dim=-1)
        
        # Controller Outputs :
        # output : batch_dim x hidden_dim
        # state : ( h, c) 
        controller_inputs = [x, dcem_state_dict['dcem_controller']]
        augmented_x, vt, nx, dcem_state_dict['dcem_controller'] = self.controller.forward_controller(controller_inputs)
        
        # clip the controller output
        nx = torch.clamp(nx, -self.clip, self.clip)
             
        # External Output Function :
        ext_output = self.controller.forward_external_output_fn( 
            vt_output=vt,
            slots_read=read_vec,
        )

        if self.sparse_K!=0:
            written_memory_state = asp(written_memory_state, K=self.sparse_K)
        dcem_state_dict['dcem_memory']['memory'] = [written_memory_state]
        
        frame_states.update({'dcem':dcem_state_dict})
        
        return ext_output, frame_states 

    def get_feature_shape(self):
        return self.output_dim
