motornet.nets.callbacks

This module implements subclasses of tensorflow.keras.callbacks.Callback. For more information on what callbacks are and what they can achieve, refer to the tensorflow documentation on callbacks: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback.

class motornet_tf.nets.callbacks.BatchLogger

Bases: Callback

In tensorflow, the default callbacks log performance metrics at the end of each epoch. This callback logs performance metrics at the end of each batch instead, providing more frequent logging of training performance.

on_batch_end(batch, logs=None)

Called at the end of each batch. This saves the performance metrics contained in logs, as well as the model weights.

on_train_begin(logs=None)

Called at the beginning of training. This should only be called in train mode. Logs the initial model weights.

Parameters:
  • logs – Dictionary, currently no data is passed to this argument for this method.
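The per-batch logging described above can be sketched without TensorFlow as a plain class that exposes the same hook names Keras uses. This is a hypothetical illustration, not the motornet_tf implementation: the class name `BatchLoggerSketch` and its `get_weights` argument (any function returning the current weights) are assumptions, and the training loop is simulated by calling the hooks by hand.

```python
import copy


class BatchLoggerSketch:
    """Hypothetical stand-in for a per-batch logging callback.

    Mirrors the Keras hook names (on_train_begin, on_batch_end) but is not
    the motornet_tf implementation.
    """

    def __init__(self, get_weights):
        self.get_weights = get_weights  # assumed: returns current model weights
        self.history = {}               # metric name -> list of per-batch values
        self.weights_log = []           # one weight snapshot per logged step

    def on_train_begin(self, logs=None):
        # Store the initial weights before any update happens.
        self.weights_log.append(copy.deepcopy(self.get_weights()))

    def on_batch_end(self, batch, logs=None):
        # Append every metric reported for this batch, then snapshot weights.
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(value)
        self.weights_log.append(copy.deepcopy(self.get_weights()))


# Simulated training loop: three batches, one fake weight update per batch.
weights = [0.0]
logger = BatchLoggerSketch(lambda: weights)
logger.on_train_begin()
for batch, loss in enumerate([1.0, 0.5, 0.25]):
    weights[0] += 1.0  # stand-in for an optimizer step
    logger.on_batch_end(batch, logs={"loss": loss})

print(logger.history)           # {'loss': [1.0, 0.5, 0.25]}
print(len(logger.weights_log))  # 4 (initial snapshot + one per batch)
```

In an actual Keras callback, the weights would typically come from `self.model.get_weights()`, and the hooks would be invoked by `Model.fit` rather than by hand.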

class motornet_tf.nets.callbacks.BatchwiseLearningRateScheduler(scheduler, verbose=0)

Bases: LearningRateScheduler

The parent class adjusts the learning rate at the start of each epoch. This subclass instead applies the learning rate adjustment routine at the end of each batch, allowing for more frequent learning rate adjustments.

Parameters:
  • scheduler – A (potentially custom-made) function that takes a batch index (integer, indexed from 0) and current learning rate (float) as inputs and returns a new learning rate as output (float).

  • verbose – Integer, either 0 (quiet) or 1 (update messages toggled on).
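A scheduler matching the signature described above takes the batch index and the current learning rate and returns the new rate. The warm-up-then-decay rule below and its `warmup` and `decay` parameters are hypothetical choices for illustration; the loop at the end simulates the successive calls a batch-wise scheduler would make.

```python
def warmup_then_decay(batch, lr, warmup=100, decay=0.999):
    """Hypothetical batch-wise schedule.

    Keeps the learning rate unchanged for the first `warmup` batches, then
    multiplies the current rate by `decay` after every later batch.
    """
    if batch < warmup:
        return float(lr)
    return float(lr * decay)


# Simulate the calls made over 103 batches: 100 unchanged, then 3 decays.
lr = 1e-3
for batch in range(103):
    lr = warmup_then_decay(batch, lr)
print(lr)  # approximately 1e-3 * 0.999 ** 3
```

Such a function would then be passed as the scheduler argument, for example BatchwiseLearningRateScheduler(warmup_then_decay, verbose=1). Because the function receives the current rate, repeated multiplication by decay yields exponential decay without the scheduler having to keep any state of its own.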

on_batch_end(batch, logs=None)

Called at the end of each batch. The learning rate adjustment routine that the parent class implements in its on_epoch_begin method is implemented here instead.

Parameters:
  • batch – Integer, index of batch.

  • logs – Dictionary, currently no data is passed to this argument for this method.

on_epoch_begin(epoch, logs=None)

Called at the start of each epoch. This method is intentionally empty: it overrides the parent class's on_epoch_begin so that no learning rate adjustment occurs at the beginning of each epoch.
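The override pattern behind this class, an empty epoch-level hook combined with a schedule applied after every batch, can be sketched framework-free as follows. `BatchwiseSchedulerSketch` is hypothetical, and its `lr` attribute stands in for the optimizer's learning rate, which a real Keras callback would read and update through `self.model.optimizer`.

```python
class BatchwiseSchedulerSketch:
    """Hypothetical stand-in illustrating the batch-wise override pattern."""

    def __init__(self, scheduler, lr, verbose=0):
        self.scheduler = scheduler  # f(batch_index, current_lr) -> new_lr
        self.lr = lr                # stand-in for the optimizer's learning rate
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        # Intentionally empty: no adjustment at the start of an epoch.
        pass

    def on_batch_end(self, batch, logs=None):
        # The adjustment happens here, once per batch.
        new_lr = float(self.scheduler(batch, self.lr))
        if self.verbose > 0:
            print(f"Batch {batch}: learning rate set to {new_lr}.")
        self.lr = new_lr


# Two epochs of three batches each, halving the rate after every batch.
cb = BatchwiseSchedulerSketch(lambda batch, lr: lr * 0.5, lr=1.0)
for epoch in range(2):
    cb.on_epoch_begin(epoch)
    for batch in range(3):
        cb.on_batch_end(batch)
print(cb.lr)  # 0.015625 (six halvings)
```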

class motornet_tf.nets.callbacks.TensorflowFix

Bases: Callback

This callback implements a workaround for an issue that affects the saving of some tensorflow objects. See this GitHub issue for more details: https://github.com/tensorflow/tensorflow/issues/42872

Note

As of tensorflow nightly 2.6, this issue has been fixed. This callback is kept for backward compatibility.

on_train_batch_end(batch, logs=None)

Called at the end of a batch during training. This should only be called in train mode.

Parameters:
  • batch – Integer, index of batch.

  • logs – Dictionary, currently no data is passed to this argument for this method.

on_train_begin(logs=None)

Called at the beginning of training. This should only be called in train mode.

Parameters:
  • logs – Dictionary, currently no data is passed to this argument for this method.

class motornet_tf.nets.callbacks.TrainingPlotter(task, plot_freq=20, plot_n_t=100, plot_loss=True, plot_trials=3)

Bases: Callback

This callback plots the loss history and/or some test trials every plot_freq batches, to help monitor the model’s behaviour during the training session.

If plot_loss is toggled on, the losses will be displayed, including the total loss and contributing losses.

If test trials are toggled on (plot_trials > 0), cartesian position, muscle activation, muscle velocity and network unit activity will be displayed, each in their own subplot.

Parameters:
  • task – motornet.tasks.Task class or subclass, corresponding to the Task object given to the Network controller.

  • plot_freq – Integer, number of batches between successive plots.

  • plot_n_t – Integer, number of timesteps used when plotting test trials. This argument is ignored if plot_trials is set to 0.

  • plot_loss – Bool, whether or not to plot loss values.

  • plot_trials – Integer, number of trials to simulate and plot each time.
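The plotting cadence controlled by these parameters can be sketched as follows, with the actual drawing replaced by a record of which batches would trigger a plot and which panels it would contain. `PlotScheduleSketch` is hypothetical, and so is its firing rule: it assumes batches are indexed from 0 and that a plot is produced after every `plot_freq`-th batch, which may differ from the exact rule motornet_tf uses.

```python
class PlotScheduleSketch:
    """Hypothetical sketch of a 'plot every plot_freq batches' callback.

    Instead of drawing figures, it records which batches would trigger a
    plot and which panels would be drawn.
    """

    TRIAL_PANELS = ["cartesian position", "muscle activation",
                    "muscle velocity", "network activity"]

    def __init__(self, plot_freq=20, plot_loss=True, plot_trials=3):
        self.plot_freq = plot_freq
        self.plot_loss = plot_loss
        self.plot_trials = plot_trials
        self.losses = {}       # loss name -> list of per-batch values
        self.plot_events = []  # (batch index, list of panels)

    def on_batch_end(self, batch, logs=None):
        # Always accumulate the loss history, even between plots.
        for name, value in (logs or {}).items():
            self.losses.setdefault(name, []).append(value)
        # Assumed rule: fire after batches plot_freq - 1, 2 * plot_freq - 1, ...
        if (batch + 1) % self.plot_freq != 0:
            return
        panels = ["losses"] if self.plot_loss else []
        if self.plot_trials > 0:
            panels += self.TRIAL_PANELS
        self.plot_events.append((batch, panels))


plotter = PlotScheduleSketch(plot_freq=5)
for batch in range(12):
    plotter.on_batch_end(batch, {"loss": 1.0 / (batch + 1)})
print([b for b, _ in plotter.plot_events])  # [4, 9]

loss_only = PlotScheduleSketch(plot_freq=5, plot_trials=0)
for batch in range(5):
    loss_only.on_batch_end(batch, {"loss": 1.0})
print(loss_only.plot_events)  # [(4, ['losses'])]
```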

on_batch_end(batch, logs=None)

Called at the end of each batch. Plots the losses, as well as test trials if this option was toggled on at initialization.

Parameters:
  • batch – Integer, index of batch.

  • logs – Dictionary, currently no data is passed to this argument for this method.

on_train_begin(logs=None)

Called at the beginning of training. This should only be called in train mode. Keeps track of which batch was last visited, if there was a previous training session. This allows training information from previous sessions to be plotted during subsequent sessions.

This routine is implemented here rather than in the on_train_end() method to ensure logging occurs even if the user interrupts an ongoing training session.

Parameters:
  • logs – Dictionary, currently no data is passed to this argument for this method.
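The cross-session bookkeeping described for on_train_begin can be sketched as a running offset that is recomputed at the start of every session from the history already stored. `SessionOffsetSketch` and its `loss_history` attribute are hypothetical. Because the offset is set when training begins, an interrupted session, whose end-of-training hook never runs, still leaves a consistent history for the next one.

```python
class SessionOffsetSketch:
    """Hypothetical sketch: a batch counter that continues across sessions."""

    def __init__(self):
        self.loss_history = []  # (global batch index, loss); persists across sessions
        self.offset = 0

    def on_train_begin(self, logs=None):
        # Resume numbering where the stored history ends. Doing this at the
        # start of training also covers sessions that were interrupted.
        self.offset = len(self.loss_history)

    def on_batch_end(self, batch, logs=None):
        # Keras restarts `batch` at 0, so add the offset for a global index.
        global_batch = self.offset + batch
        self.loss_history.append((global_batch, (logs or {}).get("loss")))


cb = SessionOffsetSketch()

# First session: three batches, then an interruption (no end-of-training hook).
cb.on_train_begin()
for batch, loss in enumerate([1.0, 0.8, 0.6]):
    cb.on_batch_end(batch, {"loss": loss})

# Second session: numbering continues from the stored history.
cb.on_train_begin()
for batch, loss in enumerate([0.5, 0.4]):
    cb.on_batch_end(batch, {"loss": loss})

print([b for b, _ in cb.loss_history])  # [0, 1, 2, 3, 4]
```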