Writing Custom Keras Layers

If the existing Keras layers don’t meet your requirements, you can create a custom layer. For simple, stateless custom operations, you are probably better off using layer_lambda() layers. For any custom operation that has trainable weights, however, you should implement your own layer.
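As a minimal sketch of the stateless case, the snippet below wraps an element-wise square in layer_lambda(). The model shape and the choice of k_square() are illustrative assumptions, and the code assumes a keras R package version (2.x) with a working TensorFlow backend.

```r
library(keras)

# A stateless operation has no weights to train, so a lambda layer suffices.
# The layer sizes here are arbitrary and chosen only for illustration.
model <- keras_model_sequential() %>%
  layer_dense(units = 8, input_shape = c(4)) %>%
  layer_lambda(f = function(x) k_square(x))
```

Because the lambda holds no state, nothing needs to be defined beyond the function itself.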

The example below illustrates the skeleton of a Keras custom layer. The mnist_antirectifier example includes another demonstration of creating a custom layer.

KerasLayer R6 Class

To create a custom Keras layer, you create an R6 class derived from KerasLayer. There are three methods to implement (only one of which, call(), is required for all types of layer):

build(input_shape): This is where you will define your weights. Note that if your layer doesn’t define trainable weights then you need not implement this method.

call(x): This is where the layer’s logic lives. Unless you want your layer to support masking, you only have to care about the first argument passed to call: the input tensor.

compute_output_shape(input_shape): In case your layer modifies the shape of its input, you should specify here the shape transformation logic. This allows Keras to do automatic shape inference. If you don’t modify the shape of the input then you need not implement this method.
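The three methods above might fit together as in the following sketch. The class name CustomLayer, the output_dim field, and the choice of a single kernel weight are illustrative assumptions; the layer assumes a 2-D input of shape (batch, features) and a keras R package version (2.x) that exports KerasLayer.

```r
library(keras)

# Hypothetical layer: multiplies its input by a trainable kernel,
# similar in spirit to a dense layer without bias or activation.
CustomLayer <- R6::R6Class("CustomLayer",

  inherit = KerasLayer,

  public = list(

    output_dim = NULL,
    kernel = NULL,

    # Store the configuration supplied by the caller.
    initialize = function(output_dim) {
      self$output_dim <- output_dim
    },

    # Create the trainable weight once the input shape is known.
    build = function(input_shape) {
      self$kernel <- self$add_weight(
        name = "kernel",
        shape = list(input_shape[[2]], self$output_dim),
        initializer = initializer_random_normal(),
        trainable = TRUE
      )
    },

    # The layer's logic: matrix product of the input tensor and the kernel.
    call = function(x, mask = NULL) {
      k_dot(x, self$kernel)
    },

    # The last dimension changes from input_shape[[2]] to output_dim.
    compute_output_shape = function(input_shape) {
      list(input_shape[[1]], self$output_dim)
    }
  )
)
```

Only call() is strictly required here; build() exists because the layer has a weight, and compute_output_shape() exists because the output's last dimension differs from the input's.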

The function that wraps the custom layer accepts object as its first parameter (the object will either be a Keras sequential model or another Keras layer). The object parameter enables the layer to be composed with other layers using the magrittr pipe (%>%) operator.

It converts its output_dim argument to integer using the as.integer() function. This is done as a convenience to the user because Keras variables are strongly typed (you can’t pass a float if an integer is expected). This enables users of the function to write output_dim = 32 rather than output_dim = 32L.

Some additional parameters not used by the layer (name and trainable) are in the function signature. Custom layer functions can include any of the core layer function arguments (input_shape, batch_input_shape, batch_size, dtype, name, trainable, and weights) and they will be automatically forwarded to the Layer base class.
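A wrapper function with the properties described above could look like the sketch below. The names layer_custom and CustomLayer are hypothetical, the compact class is included only to keep the example self-contained, and the code assumes a keras R package version (2.x) that exports KerasLayer and create_layer() together with a working TensorFlow backend.

```r
library(keras)

# Compact hypothetical layer so the wrapper has something to instantiate.
CustomLayer <- R6::R6Class("CustomLayer",
  inherit = KerasLayer,
  public = list(
    output_dim = NULL,
    kernel = NULL,
    initialize = function(output_dim) {
      self$output_dim <- output_dim
    },
    build = function(input_shape) {
      self$kernel <- self$add_weight(
        name = "kernel",
        shape = list(input_shape[[2]], self$output_dim),
        initializer = initializer_random_normal(),
        trainable = TRUE
      )
    },
    call = function(x, mask = NULL) {
      k_dot(x, self$kernel)
    },
    compute_output_shape = function(input_shape) {
      list(input_shape[[1]], self$output_dim)
    }
  )
)

# Wrapper function: `object` comes first so the layer can be piped,
# output_dim is coerced to integer, and name/trainable are forwarded
# to the base layer class rather than used by CustomLayer itself.
layer_custom <- function(object, output_dim, name = NULL, trainable = TRUE) {
  create_layer(CustomLayer, object, list(
    output_dim = as.integer(output_dim),
    name = name,
    trainable = trainable
  ))
}

# Usage: compose the custom layer with a built-in layer via %>%.
model <- keras_model_sequential() %>%
  layer_dense(units = 16, input_shape = c(8)) %>%
  layer_custom(output_dim = 32)
```

Thanks to the as.integer() call, output_dim = 32 and output_dim = 32L behave the same from the caller's point of view.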