End-to-end differentiable memory through attention mechanisms.

The Neural Turing Machine (NTM) (Graves, Wayne, and Danihelka 2014) is a memory-augmented neural network architecture introduced by Alex Graves and colleagues from DeepMind in 2014. Although its architectural components are loosely inspired by Alan Turing’s Turing Machine, Graves later mentioned during a lecture that the metaphor has become somewhat strained and that “Neural Von Neumann Machine” would perhaps have been a more fitting name, as the key idea of the architecture is to decouple memory from computation (Graves 2018). Similar to a CPU interacting with RAM, a controller network learns to perform read/write operations on a large external memory matrix. This separation aims to alleviate a shortcoming of commonly used recurrent networks, namely that their number of parameters, and with it their computational cost, grows quadratically as the hidden state’s capacity is increased. The second important contribution of the paper was its novel use of soft attention mechanisms to index discrete memory locations. Ubiquitous today, attention can be used to make all memory interactions differentiable. This makes the NTM trainable end-to-end using the backpropagation algorithm. The result is an expressive architecture that is capable of learning simple algorithms from input-output samples. Or, framed differently, a differentiable computer that programs itself based on inputs from its environment. The original paper showed that the NTM outperforms Long Short-Term Memory networks on a range of memory tasks.

Since the NTM was published, different variations have been proposed and new memory-augmented architectures have been developed. Furthermore, over the years these architectures have found their way into the artificial brains of agents in various reinforcement learning settings. In spite of the bells and whistles proposed by newer architectures, I think the Neural Turing Machine remains a great introduction to neural memory mechanisms. The paper is very accessible and introduces important aspects and primitives to think about, such as associative recall and how attention can be used for operating on discrete data structures in a differentiable manner.

My main goal for this post is to explain the NTM architecture using a bottom-up approach. I will begin each section by going through the theory and the mathematical formulations. I tried to include visual examples to help the reader build an intuitive understanding of the different memory operations. Each theory section is followed by a code interlude, in which we will implement what we have learned using the PyTorch library. By the end of this post, we will have created a working implementation of the Neural Turing Machine. I will conclude by addressing some of the architecture’s shortcomings and presenting some tricks that can help to stabilise the training process.

The implementation was inspired by existing open source repositories (Ha 2017; Zana 2017). I adopted the modular style, the `HeadBase` class for shared functionality between heads, the learned initialisation strategy, as well as the hidden state unpacking scheme. I tried to simplify the way a module’s parameters and/or state can be initialised and reset. By introducing appropriate methods for this purpose, it becomes easier to experiment with different initialisations, which, as we will see later, turns out to be essential to make the training process numerically stable.

Figure 1 shows the main components of the Neural Turing Machine. It comprises a controller network \(C\) (usually a recurrent model such as an LSTM, but feed-forward is also possible), a memory matrix \(M\) and a set of read and write heads, \(R\) and \(W\). At each time step \(t\), the controller receives an input \(\boldsymbol{x}_t\) from the environment. Using the current input and its previous state \(\boldsymbol{h}_{t-1}\), it then produces parameters or *instructions* for the heads. By following these instructions, a head’s current focus is shifted to specific rows or *slots* of the memory matrix and a read or write operation is executed. Finally, the read-out vectors are routed back to the controller, where they are incorporated into the new state \(\boldsymbol{h}_{t}\) and the final output is generated.

In the following sections we will explore in detail how the read and write operations are defined and how each head produces its attention vector.

Based on the instructions it receives from the controller, a head selects one or more memory locations for reading or writing by attending to the slots in memory to varying degrees^{1}. A head can thereby focus sharply on the memory at a single location or weakly at many locations.

How exactly this attention mechanism works will be discussed in detail in the Section on memory addressing. Before that, we will take a look at the mathematical operations that underlie the reading and writing processes.

Assume that at a given time step \(t\), we have a read head \(R\) and a memory \(M_t\). In our example, the memory has 4 distinct memory slots, each of which stores a 5-dimensional binary vector^{2}. Based on the current input from the controller, the read head produces a normalised attention vector \(\boldsymbol{w}^\text{r}_t\) with values between 0 and 1 and a total sum equal to 1. The head thereby determines how much focus should be given to each of the memory’s locations.

Using these weightings, we can define a reading procedure that reads a vector \(\boldsymbol{r}_t\) from memory:

\[ \begin{equation} \boldsymbol{r}_{t} = \sum_{i} \boldsymbol{w}^\text{r}_{t}(i) M_{t}(i) \end{equation} \]

This operation can be visualized in two steps. First, each element of the focus vector \(\boldsymbol{w}^\text{r}_t\) is multiplied with the corresponding row in memory. Second, the rows are added together to produce the *read vector* \(\boldsymbol{r}_t\):

The resulting read vector is simply a weighted sum of the contents of the locations.
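To make the two steps concrete, here is a minimal sketch of the read operation for a single (unbatched) memory. The memory contents and the weightings are made-up values chosen for illustration; they are not taken from the figures of this post.

```python
import torch

# Hypothetical memory: 4 slots, each holding a 5-dimensional binary vector.
M = torch.tensor([
    [1., 0., 1., 0., 1.],
    [0., 1., 1., 0., 0.],
    [1., 1., 0., 0., 1.],
    [0., 0., 0., 1., 1.],
])

# Illustrative read weightings: non-negative and summing to 1.
w_r = torch.tensor([0.1, 0.7, 0.2, 0.0])

# Step 1: scale every row by its attention value (broadcast over columns).
scaled_rows = w_r.unsqueeze(1) * M   # shape (4, 5)

# Step 2: add the scaled rows together to obtain the read vector.
r = scaled_rows.sum(dim=0)           # shape (5,)

# The same result expressed as a single vector-matrix product.
r_matmul = w_r @ M

print(r)  # values: 0.3, 0.9, 0.8, 0.0, 0.3
```

Because most of the attention sits on the second slot, the read vector is dominated by that row, with smaller contributions from the first and third slots.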

Now let’s see how information can be written to the memory.

Inspired by Long Short-Term Memory networks (with forget gates), the writing operation is decomposed into two steps: erase and add. At any given time step \(t\), we have a write head \(W\) with attention weightings \(\boldsymbol{w}^\text{w}_t\) and a memory matrix from the previous time step \(M_{t-1}\). Before we can delete information, we need to determine to which degree each of the stored elements should be erased. For this purpose, we introduce the *erase vector* \(\boldsymbol{e}_t\) with the same dimensions as a single memory location and values in the range \([0, 1]\). Each element corresponds to a cell in a given location and indicates what fraction of the cell’s value should be erased. For example, if an element in \(\boldsymbol{e}_t\) has the value \(1\), the corresponding cell value in the memory location will be completely removed. By multiplying the transpose of the weightings with the erase vector we obtain a matrix with the same dimensions as our memory (in our example \(4 \times 5\)):

Each row of the resulting matrix represents a different version of the original erase vector, scaled by the head’s attention value at that location^{3}. We can interpret this matrix as an *erase filter*. Its contents describe what fraction of a given memory cell should be removed. By subtracting it from a second matrix \(E\) of ones we turn it into a *remain filter* with the opposite effect.

To conclude the erase step, we simply calculate the element-wise product between our filter and the memory of the previous time step:

Mathematically, we can express the above steps with the following equation, where \(\tilde{M}_{t}\) is the erased memory matrix:

\[\tilde{M}_{t} = M_{t-1} \circ \left[E-{\boldsymbol{w}_t^\text{w}}^{\top} \boldsymbol{e}_t\right]\]

The next step in the writing process is to generate the information with which the old memory should be updated. To achieve this, we introduce a real-valued *add vector* \(\boldsymbol{a}_t\) and again multiply it with the transpose of the attention vector. This operation produces an *update matrix*:

Again, each row of the matrix represents a different version of the original add vector, scaled by the head’s attention value at that location. Finally, the writing operation is completed by adding the update matrix to the erased memory:

The full writing process can be expressed with the following set of equations:

\[ \begin{align} \tilde{M}_{t} &= M_{t-1} \circ \left[E-{\boldsymbol{w}_t^\text{w}}^{\top} \boldsymbol{e}_t\right]\\ M_t &= \tilde{M}_{t} + {\boldsymbol{w}_t^\text{w}}^{\top} \boldsymbol{a}_t \end{align} \]

or more compactly:

\[ M_t = M_{t-1} \circ \left[E-{\boldsymbol{w}_t^\text{w}}^{\top} \boldsymbol{e}_t\right]+ {\boldsymbol{w}_t^\text{w}}^{\top} \boldsymbol{a}_t \]
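As a quick sanity check of these equations, the following sketch applies one erase-and-add step to a small unbatched memory. All values are illustrative assumptions: the write head is focused sharply on the second slot and the erase vector is all ones, so that slot is expected to be overwritten by the add vector while every other slot stays untouched.

```python
import torch

# Hypothetical previous memory: 4 slots with 5 cells each.
M_prev = torch.tensor([
    [1., 0., 1., 0., 1.],
    [0., 1., 1., 0., 0.],
    [1., 1., 0., 0., 1.],
    [0., 0., 0., 1., 1.],
])

w_w = torch.tensor([0., 1., 0., 0.])     # write weightings (sharp focus on slot 1)
e = torch.ones(5)                        # erase everything in the attended slot
a = torch.tensor([0., 0., 0., 1., 1.])   # new content to add

# Erase filter: outer product of weightings and erase vector, shape (4, 5).
erase_filter = w_w.unsqueeze(1) * e.unsqueeze(0)

# Remain filter: the matrix of ones E minus the erase filter.
remain_filter = torch.ones_like(M_prev) - erase_filter

# Erase step followed by the add step.
M_erased = M_prev * remain_filter
M_new = M_erased + w_w.unsqueeze(1) * a.unsqueeze(0)

print(M_new[1])  # the attended slot now holds the add vector
```

With softer weightings or erase values below 1, the attended slots would end up as a blend of their old contents and the add vector rather than being replaced outright.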

We begin our Neural Turing Machine implementation with the memory module and the reading and writing mechanisms.

We will make use of the following imports:

```
import torch
from torch import nn
import torch.nn.functional as F
from collections import namedtuple
from torch.nn.utils import clip_grad_norm_
```

The memory inherits from `nn.Module` and is defined by the number of rows, the number of columns and the actual data. The memory contents are represented as a 3-dimensional tensor with shape `(batch_size, num_rows, num_columns)`. In addition, we define an `init_state` function that can be used to initialise or reset the memory bank. It will be called when the Neural Turing Machine module is instantiated or reset.

```
class Memory(nn.Module):
    def __init__(self, num_rows, num_cols):
        super(Memory, self).__init__()
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.data = None  # allocated by init_state

    def init_state(self, batch_size, device):
        # Reset the memory bank to zeros: (batch_size, num_rows, num_cols)
        self.data = torch.zeros(batch_size, self.num_rows, self.num_cols).to(device)
```

Let’s now move to the head implementations. Note that both read and write heads inherit from the `HeadBase` class, which in turn inherits its traits from `nn.Module`. For now, it is only important to know that the `HeadBase` class contains functionality that is shared between both read and write heads, e.g. addressing mechanisms (covered in the next section).

In order to implement the read operation, a read head needs access to two parameters: the weightings `w` and the memory contents. While `w` is passed as an argument from outside, the memory bank is accessed via the `memory` attribute. `w` is defined as a tensor of shape `(batch_size, num_rows)`.

```
class ReadHead(HeadBase):
    def __init__(self, memory):
        super(ReadHead, self).__init__(memory)

    def read(self, w):
        # (batch, 1, rows) @ (batch, rows, cols) -> (batch, 1, cols) -> (batch, cols)
        return torch.matmul(w.unsqueeze(1), self.memory.data).squeeze(1)
```

The `write` method takes weightings `w`, erase vectors `e` and add vectors `a` as input and updates the memory contents. Both `e` and `a` are defined as tensors of size `(batch_size, num_columns)`.

```
class WriteHead(HeadBase):
    def __init__(self, memory):
        super(WriteHead, self).__init__(memory)

    def erase(self, w, e):
        # Batched outer product of w and e gives the erase filter
        return self.memory.data * (1 - w.unsqueeze(2) * e.unsqueeze(1))

    def write(self, w, e, a):
        memory_erased = self.erase(w, e)
        # Batched outer product of w and a gives the update matrix
        self.memory.data = memory_erased + (w.unsqueeze(2) * a.unsqueeze(1))
```

At this point, we have seen the mathematical definitions and one possible implementation of the reading and writing operations. One essential question remains to be answered though. How can a head select discrete memory locations in a differentiable manner? The solution to this problem is arguably the key innovation of the Neural Turing Machine architecture. It should be clear from the previous sections that the authors re-framed the discrete indexing problem as one of attention. Under this view, we are now looking for a function that takes parameters (instructions) generated by the controller as input and produces a categorical distribution over all memory locations. The Neural Turing Machine implements this function by combining two complementary addressing mechanisms: *content-based* and *location-based* addressing.

Content-based or *associative* addressing means that the head selects those memory locations whose contents are the most similar to a *key vector* produced by the controller. Location-based addressing allows a head to shift its current focus to adjacent memory slots within a specified range.

Figure 2 shows the 4-step pipeline implemented by each head to map controller outputs to attention vectors. Each sub-step yields intermediate weightings \(\boldsymbol{w}\).

The first step is content-based addressing. The result is an attention vector \(\boldsymbol{w}_t^C\) termed the *content-weightings*. The focus is strongest for locations whose contents are the most similar to the *key* emitted by the controller, as measured by the cosine similarity \(D\) between the two vectors:

\[D[\mathbf{u}, \mathbf{v}]=\frac{\mathbf{u} \cdot \mathbf{v}}{\|\mathbf{u}\| \cdot\|\mathbf{v}\|}\]

Two vectors are similar in the cosine sense when the angle between them is small, that is, when the cosine of that angle is close to 1. The cosine similarity can take on the value 1 (if the vectors point in the same direction), -1 (if the vectors point in perfectly opposite directions) or any value in between (any other angle). For example, let’s say we have the following key \(\boldsymbol{k}\) and memory matrix \(M\). Comparing the key to each entry in the memory yields a similarity vector.

To attenuate or amplify the focus, the similarity vector is multiplied by a real-valued, positive scalar \(\beta\), which we call the *key strength*.

Finally, to convert the scaled similarity vector into attention weightings we simply apply the softmax function.

We can visualize how the similarity values behave as the key-strength \(\beta\) is increased:

We can observe that for small values of \(\beta\) the focus vector will be diffuse and for large values, the vector will be sharply focused on the row that is most similar to the key.
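The effect of the key strength can also be checked numerically. The sketch below computes content-weightings for a single unbatched memory at three values of \(\beta\); the memory, the key and the chosen \(\beta\) values are illustrative assumptions. The key is a slightly perturbed copy of the first row, so that row should receive the largest weight.

```python
import torch
import torch.nn.functional as F

# Hypothetical memory (4 slots x 5 cells) and a key resembling the first row.
M = torch.tensor([
    [1., 0., 1., 0., 1.],
    [0., 1., 1., 0., 0.],
    [1., 1., 0., 0., 1.],
    [0., 0., 0., 1., 1.],
])
k = torch.tensor([1., 0., 1., 0., 0.9])

# Cosine similarity between the key and every memory row, shape (4,).
similarity = F.cosine_similarity(k.unsqueeze(0).expand_as(M), M, dim=1)

# Scale by the key strength and normalise with a softmax.
weights = {}
for beta in (1.0, 5.0, 50.0):
    weights[beta] = F.softmax(beta * similarity, dim=0)
    print(f"beta={beta:>4}: {weights[beta]}")
```

For the smallest \(\beta\) the weight is spread over all four slots, while for the largest \(\beta\) nearly all of it lands on the best-matching slot.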

The main advantage of content-based addressing is how simple it is to retrieve information from memory. The controller merely needs to produce an approximation to the exact, stored item. There are situations, however, where the specific location is more important than the memory contents themselves. This brings us to location-based addressing. But first, let’s take a look at how content-based addressing can be realized in PyTorch.

Since memory addressing is shared by read as well as write heads, the addressing mechanisms will live inside the `HeadBase` class. The `focus_head` method will contain the entire memory addressing pipeline (see Figure 2) to produce the weightings for the current time step. Up until this point, however, we have only seen content-based addressing, which is implemented by the `_content_weight` method. Its parameters are a batch of keys `k` of size `(batch_size, num_columns)` and key strengths `beta` of size `(batch_size, 1)`.

```
class HeadBase(nn.Module):
    def __init__(self, memory):
        super(HeadBase, self).__init__()
        self.memory = memory
        self.init_params()

    def focus_head(self, k, beta):
        w_c = self._content_weight(k, beta)
        return w_c

    def _content_weight(self, k, beta):
        # Broadcast each key over all memory rows: (batch, rows, cols)
        k = k.unsqueeze(1).expand_as(self.memory.data)
        similarity_scores = F.cosine_similarity(k, self.memory.data, dim=2)
        # Normalise the scaled similarities over the memory rows
        w = F.softmax(beta * similarity_scores, dim=1)
        return w

    def forward(self, h):
        raise NotImplementedError

    def init_state(self, batch_size):
        self.batch_size = batch_size

    def init_params(self):
        pass
```

For some computational processes we do not care about the actual stored information, but it is crucial to consistently access a specific memory location. For example, when we compute a simple function like \(f(x, y) = xy\), we are indifferent about the actual values of \(x\) and \(y\), but the variables need to be accessed at consistent locations in memory. For this purpose, we implement a mechanism that enables the Neural Turing Machine to shift its heads to specific memory slots that lie within a given shift range from its current position. Sequential shifts at each time step allow the network to implement basic looping constructs.

As shown in Figure 2, location-based addressing is carried out in three distinct steps: *Gated Interpolation*, *Convolutional Shift* and *Sharpen*. In the following sections, we are going to look at each of these in isolation. As a quick aside before we continue, note that strictly speaking associative addressing is more general than location-based addressing. This is because location information could be stored alongside the actual contents in the memory. In the original paper, however, the authors mentioned that providing location-based addressing as a primitive operation proved essential for some forms of generalisation (Graves, Wayne, and Danihelka 2014).

Gated interpolation controls to what extent the content-based addressing mechanism should be used. The controller emits a parameter \(g\) in the range \((0,1)\), which we will call the *interpolation gate*. Using \(g\), we calculate new weightings as an interpolation between the final attention vector of the previous time step \(\boldsymbol{w}_{t-1}\) and the current content-weightings \(\boldsymbol{w}_t^C\):

\[ \begin{equation} \boldsymbol{w}_{t}^{g} = g_{t} \boldsymbol{w}_{t}^{C}+\left(1-g_{t}\right) \boldsymbol{w}_{t-1} \end{equation} \]

Note that if \(g\) is \(0\) at a given time step, the content-weightings \(\boldsymbol{w}_t^C\) are multiplied by zero and hence entirely ignored. Similarly, if \(g\) is \(1\), the weightings of the previous time step \(\boldsymbol{w}_{t-1}\) won’t influence the intermediate weightings \(\boldsymbol{w}_t^g\) for further calculations. If \(g\) is a value between \(0\) and \(1\) exclusive, the two vectors will be *blended* together in proportion to \(g\).
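The three cases can be verified with a few lines of code. This is a minimal sketch with made-up weightings; the helper name `gated_interpolation` is chosen for illustration only.

```python
import torch

# Illustrative inputs: content addressing points at slot 2,
# while the previous time step was focused on slot 0.
w_c = torch.tensor([0., 0., 1., 0.])
w_prev = torch.tensor([1., 0., 0., 0.])


def gated_interpolation(w_c, w_prev, g):
    # Blend the content-weightings with the previous weightings.
    return g * w_c + (1 - g) * w_prev


w_g0 = gated_interpolation(w_c, w_prev, 0.0)     # content-weightings ignored
w_g1 = gated_interpolation(w_c, w_prev, 1.0)     # previous weightings ignored
w_mix = gated_interpolation(w_c, w_prev, 0.25)   # 25% content, 75% previous

print(w_g0, w_g1, w_mix)
```

Since both inputs are normalised and the gate forms a convex combination, the blended weightings again sum to 1.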

In this step, we introduce the central mechanism that enables a head to shift its current focus and attend to adjacent memory locations.

The shift operation is implemented as a one-dimensional, circular convolution, where a kernel \(\boldsymbol{s}_t\) emitted by the controller is convolved over the head’s attention vector. \(\boldsymbol{s}_t\) represents a normalised distribution over the allowed integer shifts and you can think of it as *shift-instructions*. For example, if only shifts by one position are allowed, \(\boldsymbol{s}_t\) will be a vector comprised of three elements, each of which can be interpreted as one of the following instructions:

\[(\texttt{shift 1 forward},\quad \texttt{maintain focus},\quad \texttt{shift 1 backward})\]

Of course, we can have a shift kernel with an allowed shift range greater than one. In the general case, the shift weightings will have \(2n+1\) elements, where \(n\) is the largest absolute shift value. For instance, for a maximum shift range of two locations, \(\boldsymbol{s}_t\) will now be a \(5\)-dimensional vector:

\[(\texttt{shift +2},\quad \texttt{shift +1},\quad \texttt{shift 0},\quad \texttt{shift -1},\quad \texttt{shift -2})\]

The shifted attention vector \(\tilde{\boldsymbol{w}}_{t}\) can be obtained using:

\[ \begin{equation} \tilde{\boldsymbol{w}}_{t}(i) = \sum_{j=0}^{N-1} \boldsymbol{w}_{t}^{g}(j) \boldsymbol{s}_{t}(i-j) \end{equation} \]

Here the index \(i-j\) is computed modulo \(N\), which is what makes the convolution circular. In practice, we first unroll the current attention vector \(\boldsymbol{w}_{t}^{g}\) based on the specified maximum shift value in a circular manner. This means that we append the last element to the beginning and the first element to the end of the sequence. For a maximum shift value \(n > 1\), the sequence is padded with the first \(n\) elements at the tail and the last \(n\) elements at the head. In this way, if a head is focussed on the last row of the memory and performs a forward shift by one position, its focus will move back to the first memory location. Analogously, a backward shift from the first slot will shift the head’s focus to the end of the memory bank.

Finally, we convolve the shift weightings over the unrolled sequence.

Let’s illustrate this with an example. The maximum shift range is set to 1 position and we have the following attention vector \(\boldsymbol{w}_t^g\) and shift weightings \(\boldsymbol{s}_t\) (note that this filter implements a backward-shift):

Our next step is to pad the sequence in a circular manner. We append the last element to the head and the first element to the tail:

Now we convolve the shift weightings \(\boldsymbol{s}_t\) over the unrolled weightings to calculate the shifted weightings \(\tilde{\boldsymbol{w}}_t\):

Compare \(\boldsymbol{w}_{t}^{g}\) with \(\tilde{\boldsymbol{w}}_t\). We can observe that the head’s attention has shifted one position backwards.

We have seen now how we can leverage the circular convolution operation to shift a head’s focus. There is, however, one more aspect of this implementation that we need to consider. Remember that the shift weightings are a normalised distribution over the possible shift instructions. This distribution, however, does not need to be sharply focussed on only a single instruction. How does an unfocussed shift-kernel shift the attention on the memory locations? Assume that we have a maximum shift range of one, a sharply focussed attention vector \(\boldsymbol{w}^g_t = (1, 0, 0, 0)\) (four memory slots) and slightly dispersed shift weightings \(\boldsymbol{s}_t = (0.7, 0.1, 0.2)\):

After we apply our shift operation, we obtain the following shifted focus vector \(\tilde{\boldsymbol{w}}_t\):

By inspecting the new attention vector we can see that the focus has been dispersed. This gives us a new insight into how our shift mechanism operates. We can observe that \(70\%\) of a given element’s value has been shifted one position forward, \(20\%\) one position backwards and \(10\%\) remains at the original position. Therefore, each dimension in \(\boldsymbol{s}_t\) specifies the amount of an element’s value that is subject to the instruction at that dimension. For example, the first dimension in \(\boldsymbol{s}_t = (0.7, 0.1, 0.2)\) can be interpreted as “For each element in \(\boldsymbol{w}\), shift 70% of its original value one position forward”. The main takeaway here is that the shift weightings represent how a head’s focus should be distributed over memory locations after the shift operation and that unfocussed shift weightings lead to dispersions in the resulting attention vector.
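The dispersion example can be reproduced in a few lines. The snippet below is a minimal, unbatched sketch of the pad-then-convolve procedure using the values from the text; the variable names are illustrative. Note that `F.conv1d` computes a cross-correlation (the kernel is not flipped), which is why the first kernel entry acts as the forward shift here.

```python
import torch
import torch.nn.functional as F

max_shift = 1
w_g = torch.tensor([1.0, 0.0, 0.0, 0.0])   # sharp focus on the first slot
s = torch.tensor([0.7, 0.1, 0.2])          # (forward, maintain, backward)

# Circular padding: last element at the head, first element at the tail.
unrolled = torch.cat([w_g[-max_shift:], w_g, w_g[:max_shift]])  # length 6

# conv1d expects inputs of shape (batch, channels, length)
# and filters of shape (out_channels, in_channels, kernel_size).
w_shifted = F.conv1d(unrolled.view(1, 1, -1), s.view(1, 1, -1)).view(-1)

print(w_shifted)  # values: 0.1, 0.7, 0.0, 0.2
```

Reading off the result: 70% of the focus has moved forward to the second slot, 10% stays on the first slot and 20% wraps around backwards to the last slot.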

To reduce the dispersions caused by the shift operation, each head receives a scalar parameter \(\gamma \geq 1\) that controls the sharpness of the weightings. The final focus vector \(\boldsymbol{w}_t\) is normalised as follows:

\[ \begin{equation} \boldsymbol{w}_{t}(i) = \frac{\tilde{\boldsymbol{w}}_{t}(i)^{\gamma_{t}}}{\sum_{j} \tilde{\boldsymbol{w}}_{t}(j)^{\gamma_{t}}} \end{equation} \]

We can visualize the effect on the memory weightings as \(\gamma\) is increased:

The larger \(\gamma\) is, the more focussed the final attention vector will be.
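As a small numerical illustration, the sketch below sharpens a dispersed attention vector for a few values of \(\gamma\). The input vector and the chosen \(\gamma\) values are assumptions for the sake of the example, and the small constant in the denominator only guards against division by zero.

```python
import torch

# A dispersed attention vector, e.g. the outcome of an unfocussed shift.
w_shifted = torch.tensor([0.1, 0.7, 0.0, 0.2])


def sharpen(w, gamma):
    # Raise each weight to the power gamma, then renormalise.
    w = w.pow(gamma)
    return w / (w.sum() + 1e-16)


sharpened = {}
for gamma in (1.0, 2.0, 10.0):
    sharpened[gamma] = sharpen(w_shifted, gamma)
    print(f"gamma={gamma:>4}: {sharpened[gamma]}")
```

With \(\gamma = 1\) the vector is left unchanged, with \(\gamma = 2\) the largest weight grows from 0.7 to roughly 0.91, and with \(\gamma = 10\) the focus is almost entirely on a single slot.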

Let us now build out the location-based addressing mechanism. Each step we have seen above will be implemented as a separate method. All of these will then be called successively by the `focus_head` method to produce the final weightings.

Gated interpolation takes weightings from the previous time step `prev_w` with shape `(batch_size, num_rows)` and the interpolation gate `g` with shape `(batch_size, 1)` as input and blends `prev_w` and the content-weightings together.

```
class HeadBase(nn.Module):
    ...
    # prev_w must be passed in, since it is used by the interpolation step
    def focus_head(self, k, beta, prev_w, g):
        w_c = self._content_weight(k, beta)
        w_g = self._gated_interpolation(w_c, prev_w, g)
        return w_g
    ...
    def _gated_interpolation(self, w, prev_w, g):
        return g*w + (1-g)*prev_w
    ...
```

Note that for the convolutional shift we introduce a new attribute `max_shift` inside the constructor. This integer value indicates the maximum number of positions a head can shift at each time step. The shifted weightings are computed using the `_mod_shift` method. The shift weightings `s` are passed to the method as a tensor with shape `(batch_size, 2 * max_shift + 1)`. As PyTorch requires 3-dimensional shapes for convolutional filters, we transform the tensor to be of shape `(batch_size, 1, 2 * max_shift + 1)`. Next, we apply the filters to a batch of unrolled weightings. Because every filter is applied to every sequence, we obtain `batch_size` outputs for each input sequence, i.e. an output of shape `(batch_size, batch_size, num_rows)`. We are, however, only interested in the output sequence of the \(i\)-th filter for the \(i\)-th unrolled sequence. For this reason, we need to filter the results with `[range(self.batch_size), range(self.batch_size)]`. The method returns a batch of shifted weightings with shape `(batch_size, num_rows)`.

```
class HeadBase(nn.Module):
    def __init__(self, memory, hidden_size, max_shift):
        super(HeadBase, self).__init__()
        self.memory = memory
        self.hidden_size = hidden_size
        self.max_shift = max_shift

    def focus_head(self, k, beta, prev_w, g, s):
        w_c = self._content_weight(k, beta)
        w_g = self._gated_interpolation(w_c, prev_w, g)
        w_s = self._mod_shift(w_g, s)
        return w_s  # return the shifted weightings, not w_g
    ...
    def _mod_shift(self, w, s):
        # Circular padding with max_shift elements on either side
        unrolled = torch.cat([w[:, -self.max_shift:], w, w[:, :self.max_shift]], 1)
        # Apply every filter to every sequence, then keep the matching pairs
        return F.conv1d(unrolled.unsqueeze(1),
                        s.unsqueeze(1))[range(self.batch_size), range(self.batch_size)]
    ...
```
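To build some confidence in the diagonal indexing used by `_mod_shift`, the following standalone sketch compares it against a plain per-sample loop on random inputs. The sizes and the seed are arbitrary. The last variant, which uses the `groups` argument of `F.conv1d`, is not part of the implementation in this post; it is shown only as a possible alternative that avoids computing the unused filter/sequence pairs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_rows, max_shift = 3, 6, 1

# Random normalised weightings and shift kernels (illustrative data).
w = F.softmax(torch.randn(batch_size, num_rows), dim=1)
s = F.softmax(torch.randn(batch_size, 2 * max_shift + 1), dim=1)

# Circular padding, as in _mod_shift.
unrolled = torch.cat([w[:, -max_shift:], w, w[:, :max_shift]], dim=1)

# Variant 1: every filter on every sequence -> (batch, batch, num_rows),
# then keep the i-th filter's output for the i-th sequence.
all_pairs = F.conv1d(unrolled.unsqueeze(1), s.unsqueeze(1))
idx = torch.arange(batch_size)
batched = all_pairs[idx, idx]

# Variant 2: reference implementation, one sample at a time.
looped = torch.stack([
    F.conv1d(unrolled[i].view(1, 1, -1), s[i].view(1, 1, -1)).view(-1)
    for i in range(batch_size)
])

# Variant 3 (assumed alternative): treat the batch as channels and use a
# grouped convolution so that each sequence only meets its own filter.
grouped = F.conv1d(unrolled.unsqueeze(0), s.unsqueeze(1),
                   groups=batch_size).squeeze(0)

print(all_pairs.shape, batched.shape)
```

All three variants should agree up to floating-point error, and each row of the result still sums to 1 because a circular shift only redistributes the attention mass.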

To complete the `focus_head` method we implement `_sharpen`. In addition to the weightings, it takes the parameter `gamma` of shape `(batch_size, 1)` as input and computes the final attention vector `w`.

```
class HeadBase(nn.Module):
    ...
    def focus_head(self, k, beta, prev_w, g, s, gamma):
        w_c = self._content_weight(k, beta)
        w_g = self._gated_interpolation(w_c, prev_w, g)
        w_s = self._mod_shift(w_g, s)
        w = self._sharpen(w_s, gamma)
        return w
    ...
    def _sharpen(self, w, gamma):
        w = w.pow(gamma)
        # Small constant avoids division by zero
        return torch.div(w, w.sum(1).view(-1, 1) + 1e-16)
    ...
```

Let’s do a quick recap and review the capabilities of the memory addressing mechanism. By combining the different operations, the network can operate in three complementary modes. First, the addressing-system can access contents purely based on the similarity with the produced key, without using location-based addressing. Second, content-based weightings can be chosen and shifted by a given number of locations. What this means in computational terms is that a head finds a block of data in memory and subsequently shifts its focus to access a particular element within that block. Third, a weighting from the previous time-step can be chosen and rotated without making use of the content-addressing mechanism. This enables the system to iterate over a sequence of memory locations by performing sequential shifts.

To sum up, we have described an addressing system that allows the network to learn how to produce attention vectors to access memory locations for read/write operations. All the mechanisms are differentiable, which makes the system trainable end-to-end.

In this code interlude, we are going to finish the read and write head implementations, build a simple LSTM-based controller and create the `NTMCell` class that encapsulates all of our modules.

We need to implement two important functionalities. First, each head needs a way to extract the memory addressing parameters from the controller’s output. Second, we need to define the unique forward passes for both read and write heads. We implement the first point by defining a `hidden_state_unpacking_scheme` for each head. This method returns a list of tuples, each of which describes the length and the activation function for a given parameter. The activation functions ensure that all parameters lie within the correct numerical range. Each head is equipped with a fully-connected layer to transform the input from the controller into an internal representation, whose length is equal to the sum of the lengths of all parameters defined in the unpacking scheme. We will implement a shared method `unpack_hidden_state` in the `HeadBase` class that allows a head to split its internal vector into a tuple of addressing parameters. Let’s now look at the individual `forward` passes.

At every time step a read head takes a batch of hidden vectors `h` (the controller’s instructions) with shape `(batch_size, hidden_size)` and the previous weightings `prev_w` as input. It then transforms the hidden vectors using an internal fully-connected layer `fc` (defined in `HeadBase`). Next, it uses `unpack_hidden_state` to extract the individual parameters and uses these to produce new weightings `w` by calling `focus_head`. After the head’s focus is determined, it calls the `read` method to retrieve a read vector with shape `(batch_size, num_cols)` from memory. Finally, the read vectors and the current weightings are returned.

We initialise a head’s state using the `init_state` method. For a read head, it returns the initial read and attention vectors. The attention weightings are set to focus sharply on the first memory location.

```
class ReadHead(HeadBase):
    def __init__(self, memory, hidden_sz, max_shift):
        super(ReadHead, self).__init__(memory, hidden_sz, max_shift)

    def hidden_state_unpacking_scheme(self):
        return [
            # size, activation-function
            (self.memory.num_cols, torch.tanh),                    # k
            (1,                    F.softplus),                    # β
            (1,                    torch.sigmoid),                 # g
            (2*self.max_shift+1,   lambda x: F.softmax(x, dim=1)), # s
            (1,                    lambda x: 1 + F.softplus(x))    # γ
        ]

    def read(self, w):
        return torch.matmul(w.unsqueeze(1), self.memory.data).squeeze(1)

    def forward(self, h, prev_w):
        k, beta, g, s, gamma = self.unpack_hidden_state(self.fc(h))
        w = self.focus_head(k, beta, prev_w, g, s, gamma)
        read = self.read(w)
        return read, w

    def init_state(self, batch_size, device):
        self.batch_size = batch_size
        reads = torch.zeros(batch_size, self.memory.num_cols).to(device)
        read_focus = torch.zeros(batch_size, self.memory.num_rows).to(device)
        read_focus[:, 0] = 1.  # focus sharply on the first memory location
        return reads, read_focus
```

The write head follows the same principle but its unpacking scheme now also includes the erase vector `e` and add vector `a`, both of shape `(batch_size, num_cols)`. After extracting the parameters, the head calls the `write` method to update the previous memory. Finally, the head returns the weightings used for the writing operation.

For write heads, the `init_state` method only returns the head’s initial attention vectors. Again, each of these is set to focus sharply on the first memory location.

```
class WriteHead(HeadBase):
    def __init__(self, memory, hidden_sz, max_shift):
        # Pass hidden_sz on (hidden_size is not defined in this scope)
        super(WriteHead, self).__init__(memory, hidden_sz, max_shift)

    def hidden_state_unpacking_scheme(self):
        return [
            # size, activation-function
            (self.memory.num_cols, torch.tanh),                    # k
            (1,                    F.softplus),                    # β
            (1,                    torch.sigmoid),                 # g
            (2*self.max_shift+1,   lambda x: F.softmax(x, dim=1)), # s
            (1,                    lambda x: F.softplus(x) + 1),   # γ
            (self.memory.num_cols, torch.sigmoid),                 # e
            (self.memory.num_cols, torch.tanh)                     # a
        ]

    def erase(self, w, e):
        return self.memory.data * (1 - w.unsqueeze(2) * e.unsqueeze(1))

    def write(self, w, e, a):
        memory_erased = self.erase(w, e)
        self.memory.data = memory_erased + (w.unsqueeze(2) * a.unsqueeze(1))

    def forward(self, h, prev_w):
        k, beta, g, s, gamma, e, a = self.unpack_hidden_state(self.fc(h))
        w = self.focus_head(k, beta, prev_w, g, s, gamma)
        self.write(w, e, a)
        return w

    def init_state(self, batch_size, device):
        self.batch_size = batch_size
        write_focus = torch.zeros(batch_size, self.memory.num_rows).to(device)
        write_focus[:, 0] = 1.  # focus sharply on the first memory location
        return write_focus
```

Here we show the final implementation of the `HeadBase` class. Take a look at the `unpack_hidden_state` method to see how the individual parameters are extracted from the hidden vector and how the appropriate activation functions are applied.

```
class HeadBase(nn.Module):
    def __init__(self, memory, hidden_size, max_shift):
        super(HeadBase, self).__init__()
        self.memory = memory
        self.hidden_size = hidden_size
        self.max_shift = max_shift
        # Output size equals the total length of all addressing parameters
        self.fc = nn.Linear(hidden_size,
                            sum([s for s, _ in self.hidden_state_unpacking_scheme()]))
        self.init_params()

    def forward(self, h):
        raise NotImplementedError

    def hidden_state_unpacking_scheme(self):
        raise NotImplementedError

    def unpack_hidden_state(self, h):
        chunk_idxs, activations = zip(*self.hidden_state_unpacking_scheme())
        # Split along the feature dimension into one chunk per parameter
        chunks = torch.split(h, chunk_idxs, dim=1)
        return tuple(activation(chunk) for chunk, activation in zip(chunks, activations))

    def focus_head(self, k, beta, prev_w, g, s, gamma):
        w_c = self._content_weight(k, beta)
        w_g = self._gated_interpolation(w_c, prev_w, g)
        w_s = self._mod_shift(w_g, s)
        w = self._sharpen(w_s, gamma)
        return w

    def _content_weight(self, k, beta):
        k = k.unsqueeze(1).expand_as(self.memory.data)
        similarity_scores = F.cosine_similarity(k, self.memory.data, dim=2)
        w = F.softmax(beta * similarity_scores, dim=1)
        return w

    def _gated_interpolation(self, w, prev_w, g):
        return g*w + (1-g)*prev_w

    def _mod_shift(self, w, s):
        unrolled = torch.cat([w[:, -self.max_shift:], w, w[:, :self.max_shift]], 1)
        return F.conv1d(unrolled.unsqueeze(1),
                        s.unsqueeze(1))[range(self.batch_size), range(self.batch_size)]

    def _sharpen(self, w, gamma):
        w = w.pow(gamma)
        return torch.div(w, w.sum(1).view(-1, 1) + 1e-16)

    def init_state(self, batch_size, device):
        raise NotImplementedError

    def init_params(self):
        pass
```
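The unpacking step can be hard to picture from the tensor code alone. The following is a minimal plain-Python sketch of the same idea: a flat vector is cut into consecutive chunks, and each chunk gets its own activation. The helper names (`unpack`, `elementwise`, `softmax`, `softplus`) and the three-entry scheme are assumptions made for this illustration; the real method works on batched tensors via `torch.split`.

```python
import math


def elementwise(fn):
    """Lift a scalar function so that it acts on a whole chunk."""
    return lambda chunk: [fn(v) for v in chunk]


def softplus(x):
    return math.log1p(math.exp(x))


def softmax(chunk):
    m = max(chunk)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in chunk]
    total = sum(exps)
    return [x / total for x in exps]


def unpack(values, scheme):
    """Split a flat list into consecutive chunks and apply one activation per chunk."""
    chunks, start = [], 0
    for size, activation in scheme:
        chunks.append(activation(values[start:start + size]))
        start += size
    assert start == len(values), "scheme sizes must cover the whole vector"
    return chunks


# hypothetical scheme: key k (2 values), key strength beta (1), shift weights s (3)
scheme = [(2, elementwise(math.tanh)), (1, elementwise(softplus)), (3, softmax)]
k, beta, s = unpack([0.0] * 6, scheme)
```

With an all-zero input, `k` is `[0.0, 0.0]`, `beta` is `[log 2]` and `s` is uniform over the three shifts. This shows how each activation constrains its parameter to the required range, whatever the linear layer emits.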

The controller could be either a recurrent or a fully-connected neural network. As an example, I use a simple recurrent controller based on stacked LSTM-cells. The controller takes the concatenation of the current inputs `xb` and the previous read-vectors `prev_reads` as input (see implementation of `NTMCell`) and returns a hidden state `h` of shape `(batch_size, hidden_size)`. This hidden state is used to parametrise the heads. Moreover, it returns a list of tuples `hidden_states` containing the hidden and cell states of each recurrent layer.

The method `init_state` returns the initial hidden and cell states for the stacked LSTM.

```
class LSTMCellController(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers  # needed by init_state
        self.fc = nn.Linear(input_size, hidden_size)
        self.cells = nn.ModuleList(
            [nn.LSTMCell(hidden_size, hidden_size) for _ in range(num_layers)]
        )
        self.init_params()

    def forward(self, xb, prev_hidden_states):
        hidden_states = []
        h = self.fc(xb)
        # feed the output of each cell into the next one
        for i, cell in enumerate(self.cells):
            h, c = cell(h, prev_hidden_states[i])
            hidden_states += [(h, c)]
        return h, hidden_states

    def init_state(self, batch_size, device):
        h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        # one (h, c) tuple per layer
        return list(zip(h, c))

    def init_params(self):
        pass
```

Finally, we can finish our implementation by defining the `NTMCell` class, which will encapsulate all the modules we have created so far. Besides computing the output, its main function is to instantiate the memory, the heads and the controller. The NTM’s `state` is represented as a named tuple and includes the read vectors returned by the read heads, the focus weights used by both the read and write heads and the hidden and cell states of the LSTM-controller.

Let us now take a look at the `forward` method. At each time step, the NTM-cell receives inputs `xb` with shape `(batch_size, input_size)`. If no input is specified, the NTM will receive a batch of zero vectors. The first step is to unpack the previous state. Next, we concatenate the previous read vectors `prev_reads` with the current input `xb`. The concatenated tensor is then passed to the controller, which returns its hidden state (the parameters for the heads) and a list of tuples storing the hidden and cell state of each layer. In the second step, we iterate over the list of heads and each head executes a read or write operation. Attention weightings for read and write heads as well as read vectors are collected in separate lists. These lists, together with the controller’s hidden and cell states, are then stored in a new state tuple. In a final step, the hidden state returned by the controller is concatenated with the current read-out vectors and passed through a fully-connected layer to create the final output.

Have a look at the `init_state` method to see how the NTM initialises the states of its submodules.

```
from collections import namedtuple

class NTMCell(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 output_size,
                 memory_num_rows,
                 memory_num_cols,
                 controller_num_layers=3,
                 num_heads=1,
                 max_shift=1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_heads = num_heads
        self.max_shift = max_shift
        self.memory_num_rows = memory_num_rows
        self.memory_num_cols = memory_num_cols
        self.state_container = namedtuple('state', [
            'read_vectors', 'read_focus_weights', 'write_focus_weights', 'hidden_states'
        ])
        self.batch_size = None
        # module instantiations
        self.memory = Memory(memory_num_rows, memory_num_cols)
        self.controller_num_layers = controller_num_layers
        self.controller = LSTMCellController(input_size + num_heads * memory_num_cols,
                                             hidden_size, controller_num_layers)
        # create heads
        self.read_heads = nn.ModuleList([
            ReadHead(self.memory, hidden_size, max_shift)
            for _ in range(num_heads)
        ])
        self.write_heads = nn.ModuleList([
            WriteHead(self.memory, hidden_size, max_shift)
            for _ in range(num_heads)
        ])
        self.fc = nn.Linear(hidden_size + num_heads * memory_num_cols, output_size)
        self.init_params()

    def forward(self, xb=None):
        if xb is None:
            xb = torch.zeros(self.batch_size, self.input_size).to(self.device)
        # unpacking the previous state
        prev_reads = self.state.read_vectors
        prev_hiddens = self.state.hidden_states
        prev_read_foci = self.state.read_focus_weights
        prev_write_foci = self.state.write_focus_weights
        xb = torch.cat([xb, *prev_reads], dim=1)
        # controller output
        h, hidden_states = self.controller(xb, prev_hiddens)
        # read and write
        reads = []
        read_foci = []
        write_foci = []
        for i, (read_head, write_head) in enumerate(zip(self.read_heads, self.write_heads)):
            read, read_focus = read_head(h, prev_read_foci[i])  # read
            write_focus = write_head(h, prev_write_foci[i])     # write
            reads += [read]
            read_foci += [read_focus]
            write_foci += [write_focus]
        # pack new state
        self.state = self.state_container(reads, read_foci, write_foci, hidden_states)
        # output
        out = torch.cat([h, *reads], dim=1)
        return self.fc(out)

    def init_params(self):
        pass

    def init_state(self, batch_size, device):
        self.batch_size = batch_size
        self.device = device
        # Initialize the memory
        self.memory.init_state(batch_size, device)
        # hidden state init
        hidden_states = self.controller.init_state(batch_size, device)
        # init heads and collect initial read vectors, read foci and write foci
        reads = []
        read_foci = []
        write_foci = []
        for rh, wh in zip(self.read_heads, self.write_heads):
            read, read_focus = rh.init_state(batch_size, device)
            write_focus = wh.init_state(batch_size, device)
            reads += [read]
            read_foci += [read_focus]
            write_foci += [write_focus]
        # pack the initial state
        self.state = self.state_container(reads, read_foci, write_foci, hidden_states)
```

That’s the entire architecture! Before you can use the NTM, however, you will have to initialise the network’s weights. Every module containing learnable parameters has a method `init_params` that is called when the module is instantiated. You can use this to experiment with different parameter initialisations in a modular way.

Finally, the `NTMCell` can be initialised and reset in the following way:

```
ntm = NTMCell(input_size,
              hidden_size,
              output_size,
              memory_num_rows,
              memory_num_cols,
              controller_num_layers,
              num_heads,
              max_shift)
ntm = ntm.to(device)
ntm.init_state(batch_size, device) # init state of all modules
...
# reset within training loop
ntm.init_state(batch_size, device)
```

The main shortcoming of the Neural Turing Machine and arguably the one that prevents mainstream adoption is that the architecture is notoriously difficult to train. Even though the writing mechanism makes the architecture more powerful than read-only systems, learning how to write information to and read back from memory creates an extremely difficult coupled learning problem (Graves 2018). In the next section, we will look at some tricks that can help to mitigate numerical instabilities.

Another problem is that temporal information is neglected. Think about how hearing a specific song or smelling a particular smell takes you back to a related moment in the past. This association is not an isolated memory. We usually also remember a kind of temporal window around the retrieved moment (Graves 2018). Even though our shift mechanism can move a head’s focus to adjacent memory locations, adjacency in memory does not imply that the neighbouring entries were written at neighbouring points in time.

The Differentiable Neural Computer (Graves et al. 2016) is an architecture that tries to resolve some of these shortcomings. Among other improvements, it replaces the shift mechanism with a temporal link matrix and extends the NTM by a differentiable free list to allocate and deallocate memory.

In this last code interlude, I will list some tricks that can help to make the training of the NTM numerically stable. Many of these were adopted from the paper by Collier and Beel (2018), as well as from existing open-source implementations.

In my experience, learning the initial values of the hidden and cell state, the initial read vector and the initial focus of read/write-heads often helped to stabilise the training process.

To learn the initial hidden and cell states of an LSTM controller, we add two new tensors `h_bias` and `c_bias` to our constructor. These tensors hold trainable parameters that are updated during training. When the NTM is instantiated or reset, `init_state` is called. The bias tensors are extended to entire batches and moved to the appropriate device. Additionally, it can be helpful to scale down the initial random biases using a small value `alpha`, e.g. 0.05.

```
class LSTMCellController(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        ...
        # alpha is the small scaling constant mentioned above, e.g. 0.05
        self.h_bias = nn.Parameter(torch.rand(self.num_layers, 1, self.hidden_size) * alpha)
        self.c_bias = nn.Parameter(torch.rand(self.num_layers, 1, self.hidden_size) * alpha)
        ...

    def init_state(self, batch_size, device):
        # broadcast the learned initial states over the batch dimension
        h = self.h_bias.clone().repeat(1, batch_size, 1).to(device)
        c = self.c_bias.clone().repeat(1, batch_size, 1).to(device)
        return list(zip(h, c))
```

We can follow the same approach to learn the initialisation of read and focus vectors. The only difference is that we need to apply the softmax activation function to the attention vectors as these need to be normalised.

```
class ReadHead(HeadBase):
    def __init__(self, memory, hidden_size, max_shift):
        super(ReadHead, self).__init__(memory, hidden_size, max_shift)
        self.read_bias = nn.Parameter(torch.randn(1, self.memory.num_cols))
        self.read_focus_bias = nn.Parameter(torch.randn(1, self.memory.num_rows))
    ...

    def init_state(self, batch_size, device):
        self.batch_size = batch_size
        reads = self.read_bias.clone().repeat(batch_size, 1).to(device)
        read_focus = self.read_focus_bias.clone().repeat(batch_size, 1).to(device)
        # softmax turns the learned logits into a normalised attention vector
        return reads, torch.softmax(read_focus, dim=1)
```

And again for the initial attention vector for the write heads…

```
class WriteHead(HeadBase):
    def __init__(self, memory, hidden_size, max_shift):
        super(WriteHead, self).__init__(memory, hidden_size, max_shift)
        self.write_focus_bias = nn.Parameter(torch.rand(1, self.memory.num_rows))
    ...

    def init_state(self, batch_size, device):
        self.batch_size = batch_size
        write_focus = self.write_focus_bias.clone().repeat(batch_size, 1).to(device)
        return torch.softmax(write_focus, dim=1)
```

For the memory state, we deviate from our learned initialisation strategy. Collier and Beel (2018) showed that on a subset of the original algorithmic tasks the best approach was to initialise the memory contents with a small constant value, e.g. \(10^{-6}\). If the memory matrix is large, constant initialisation also saves a substantial number of trainable parameters compared with a learned initialisation.

```
class Memory(nn.Module):
    def __init__(self, num_rows, num_cols):
        super(Memory, self).__init__()
        self.num_rows = num_rows
        self.num_cols = num_cols
        # constant (non-trainable) initial memory contents
        self.mem_bias = torch.full((num_rows, num_cols), 1e-6)

    def init_state(self, batch_size, device):
        # one copy of the memory matrix per batch element
        self.data = self.mem_bias.clone().repeat(batch_size, 1, 1).to(device)
```

Two more tricks we can use to make the training process more stable are to clip the gradients and to initialise the bias of the forget-gate to a positive value \(\geq 1\). Using larger positive values for the bias pushes the forget-gate’s sigmoid activation close to one. This has the effect that at the beginning of the training phase the cell behaves like the original LSTM without a forget gate. It will not explicitly forget anything until it has learned to forget (Gers, Schmidhuber, and Cummins 1999). Although PyTorch does not have an explicit way to set the bias of the forget-gate, we can create our own function by exploiting the fact that the bias vector of an `LSTMCell` (or `LSTM`) layer has the structure `[bias_ig | bias_fg | bias_gg | bias_og]`.

```
def set_forget_gate_(tensor, val=1.0):
    """
    Initialise the biases of the forget gate to `val`, and all other gates to 0.
    """
    # gates are (b_hi|b_hf|b_hg|b_ho) of shape (4*hidden_size)
    tensor.data.zero_()
    hidden_size = tensor.shape[0] // 4
    # the forget gate occupies the second quarter of the bias vector
    tensor.data[hidden_size:(2 * hidden_size)] = val
```
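As a rough illustration of the layout and of why a positive bias keeps the gate open, the following dependency-free sketch builds the same bias pattern with plain lists and evaluates the sigmoid at the resulting values. The helper names `forget_gate_bias` and `sigmoid` are made up for this example. The second value assumes that both bias vectors of an `LSTMCell` (`bias_ih` and `bias_hh`) are set to `val`, so that their sum enters the gate.

```python
import math


def forget_gate_bias(hidden_size, val=1.0):
    """Bias laid out as [input | forget | cell | output], forget slice set to val."""
    bias = [0.0] * (4 * hidden_size)
    bias[hidden_size:2 * hidden_size] = [val] * hidden_size
    return bias


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


bias = forget_gate_bias(hidden_size=3, val=1.0)

# gate activation at zero pre-activation, i.e. driven by the bias alone
single = sigmoid(1.0)        # only one bias vector set to val
double = sigmoid(1.0 + 1.0)  # both bias vectors set to val (assumption, see above)
print(bias)
print(round(single, 3), round(double, 3))
```

With `val=1.0` the gate starts at roughly 0.73 if one bias vector is set and at roughly 0.88 if both are. Larger values of `val` move it closer to one.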

We can use `set_forget_gate_` within our LSTM’s `init_params` method to set the bias value.

```
class LSTMCellController(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        ...
        self.cells = nn.ModuleList(
            [nn.LSTMCell(hidden_size, hidden_size) for _ in range(num_layers)]
        )
        ...
        self.init_params()
    ...

    def init_params(self):
        ...
        for param in self.cells.parameters():
            if param.dim() < 2:
                # 1-D parameters are the bias vectors (bias_ih and bias_hh)
                set_forget_gate_(param, 1.0)
            else:
                ...
```

Both reference papers (Graves, Wayne, and Danihelka 2014; Collier and Beel 2018) employed gradient clipping to prevent spikes in the gradients. In PyTorch, gradient clipping can be implemented inside the training loop using the `clip_grad_norm_` function:

```
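from torch.nn.utils import clip_grad_norm_

loss.backward()
# rescale all gradients if their combined norm exceeds max_norm
clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
optimizer.zero_grad()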
```
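For intuition, clipping by norm can be sketched in plain Python. The sketch below treats all gradients as one long vector, computes its L2 norm and rescales every entry by the same factor if that norm exceeds `max_norm`. The function `clip_by_global_norm`, the `eps` constant and the toy gradients are assumptions for illustration; this is not a copy of the PyTorch implementation.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a list of gradient vectors so that their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # never scale up: gradients below the threshold are left unchanged
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in vec] for vec in grads], total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # joint L2 norm is 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
```

Because every entry is multiplied by the same factor, the direction of the update is preserved and only its length is bounded. This is what makes norm clipping a gentle remedy against gradient spikes.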

In this post, we have learned about the Neural Turing Machine, a differentiable computer that can learn simple programs using inputs from its environment. We covered the mathematical definitions and implemented the architecture from scratch using the PyTorch library. I hope you leave with an intuitive understanding of the main architectural components and memory access mechanisms. As I get deeper into the literature, I will cover more architectures, both historical and contemporary. Moreover, I want to complement my usual readings in artificial intelligence and machine learning with more related research from neuroscience. Finally, I am planning to experiment with memory-enhanced agents in different reinforcement learning settings, with the goal of discovering useful memory priors.

If you have questions or find errors please reach out on Twitter or send an email to schmidinger.n AT gmail.com.

Collier, Mark, and Joeran Beel. 2018. “Implementing Neural Turing Machines.” *Lecture Notes in Computer Science*, 94–104. https://doi.org/10.1007/978-3-030-01424-7_10.

Gers, Felix A., Jürgen Schmidhuber, and Fred Cummins. 1999. “Learning to Forget: Continual Prediction with LSTM.”

Graves, Alex. 2018. “Advanced Deep Learning & Reinforcement Learning Lecture 17: Attention and Memory in Deep Learning.” *YouTube*. DeepMind. https://www.youtube.com/watch?v=Q57rzaHHO0k&list=PLqYmG7hTraZDNJre23vqCGIVpfZ_K2RZs&index=15.

Graves, Alex, Greg Wayne, and Ivo Danihelka. 2014. “Neural Turing Machines.” https://arxiv.org/abs/1410.5401.

Graves, Alex, Greg Wayne, Malcolm Reynolds, Tim Harley, Ivo Danihelka, Agnieszka Grabska-Barwińska, Sergio Gómez Colmenarejo, et al. 2016. “Hybrid Computing Using a Neural Network with Dynamic External Memory.” *Nature* 538 (7626): 471.

Ha, Junsoo. 2017. “Neural Turing Machines Pytorch Implementation.” *GitHub Repository*. GitHub. https://github.com/kuc2477/pytorch-ntm.

Zana, Guy. 2017. “PyTorch Neural Turing Machine (NTM).” *GitHub Repository*. GitHub. https://github.com/loudinthecloud/pytorch-ntm.

The way the heads are parameterised by the outputs of the controller is an important lesson for developing neural architectures. First, come up with a functional form for a desired mechanism and then let another network generate parameters for the given function.

Note that I chose binary vectors for my examples because intermediate values are easy to calculate in one’s mind and patterns of calculations become more apparent visually. In practice, however, the memory will usually hold real-valued vectors.

Note that to fully erase a cell, a head needs to sharply focus on a single memory location and the corresponding element of \(\boldsymbol{e}\) needs to be \(1\).

For attribution, please cite this work as

Schmidinger (2019, Dec. 25). Niklas Schmidinger: Exploring Neural Turing Machines. Retrieved from www.niklasschmidinger.com/posts/2019-12-25-neural-turing-machines/

BibTeX citation

@misc{schmidinger2019exploring, author = {Schmidinger, Niklas}, title = {Niklas Schmidinger: Exploring Neural Turing Machines}, url = {www.niklasschmidinger.com/posts/2019-12-25-neural-turing-machines/}, year = {2019} }