Note
Go to the end to download the full example code
Adaptive Cursor Control Mapping
This example contains a 2D cursor-to-target task which processes input signals from a data acquisition device (e.g. from EMG hardware) and adaptively learns a linear mapping from input magnitude to cursor position via the recursive least squares (RLS) algorithm.
Once the cursor interface is shown, press the “Enter” key to begin. The target will move to some location on the screen and the subject should attempt to move the cursor toward the target. As input data is collected, the recursive least squares algorithm updates the weights of a linear mapping from input amplitude to cursor position. Once this training procedure is finished, the target changes color and the subject can attempt to hit the targets with the mapping now fixed.
import numpy
import random
from scipy.signal import butter
from axopy import pipeline
from axopy.features import mean_absolute_value
from axopy.experiment import Experiment
from axopy import util
from axopy.task import Task, Oscilloscope
from axopy.daq import NoiseGenerator
from axopy.timing import Counter
from axopy.gui.canvas import Canvas, Circle, Cross
class RLSMapping(pipeline.Block):
"""Linear mapping of EMG amplitude to position updated by RLS.
Parameters
----------
m : int
Number of vectors in the mapping.
k : int
Dimensionality of the mapping vectors.
lam : float
Forgetting factor.
"""
def __init__(self, m, k, lam, delta=0.001):
super(RLSMapping, self).__init__()
self.m = m
self.k = k
self.lam = lam
self.delta = delta
self._init()
@classmethod
def from_weights(cls, weights):
"""Construct an RLSMapping static weights."""
obj = cls(1, 1, 1)
obj.weights = weights
return obj
def _init(self):
self.w = numpy.zeros((self.k, self.m))
self.P = numpy.eye(self.m) / self.delta
def process(self, data):
"""Just applies the current weights to the input."""
self.y = data
self.xhat = self.y.dot(self.w.T)
return self.xhat
def update(self, x):
"""Update the weights with the teaching signal."""
z = self.P.dot(self.y.T)
g = z / (self.lam + self.y.dot(z))
e = x - self.xhat
self.w = self.w + numpy.outer(e, g)
self.P = (self.P - numpy.outer(g, z)) / self.lam
class CursorFollowing(Task):
# TODO split this into two tasks (a "training" task and a "practice" task).
# This would involve storing the RLS weights and loading them for the
# practice task. Probably a good idea to write a simple cursor interface
# class to share common code between the two tasks.
target_dist = 0.8
def __init__(self, pipeline):
super(CursorFollowing, self).__init__()
self.pipeline = pipeline
def prepare_design(self, design):
d = self.target_dist
target_positions = [(d, 0), (0, d), (-d, 0), (0, -d), (0, 0)]
for training in [True, False]:
block = design.add_block()
for x, y in target_positions:
block.add_trial(attrs={
'target_x': x,
'target_y': y,
'training': training
})
block.shuffle()
def prepare_graphics(self, container):
self.canvas = Canvas()
self.cursor = Circle(0.05, color='#aa1212')
self.target = Circle(0.1, color='#32b124')
self.canvas.add_item(self.target)
self.canvas.add_item(self.cursor)
self.canvas.add_item(Cross())
container.set_widget(self.canvas)
def prepare_daq(self, daqstream):
self.daqstream = daqstream
self.daqstream.start()
self.timer = Counter(50)
self.timer.timeout.connect(self.finish_trial)
def run_trial(self, trial):
if not trial.attrs['training']:
self.target.color = '#3224b1'
self._reset()
self.target.pos = trial.attrs['target_x'], trial.attrs['target_y']
self.target.show()
self.pipeline.clear()
self.connect(self.daqstream.updated, self.update)
def update(self, data):
xhat = self.pipeline.process(data)
self.cursor.pos = xhat
target_pos = numpy.array([self.trial.attrs['target_x'],
self.trial.attrs['target_y']])
if self.trial.attrs['training']:
self.pipeline.named_blocks['RLSMapping'].update(target_pos)
if self.cursor.collides_with(self.target):
self.finish_trial()
self.timer.increment()
def finish_trial(self):
self.disconnect(self.daqstream.updated, self.update)
self._reset()
self.next_trial()
def _reset(self):
self.cursor.pos = 0, 0
self.timer.reset()
self.target.hide()
def finish(self):
self.daqstream.stop()
def key_press(self, key):
if key == util.key_escape:
self.finish()
else:
super().key_press(key)
if __name__ == '__main__':
dev = NoiseGenerator(rate=2000, num_channels=4, read_size=200)
b, a = butter(4, (10/2000./2., 450/2000./2.), 'bandpass')
preproc_pipeline = pipeline.Pipeline([
pipeline.Windower(400),
pipeline.Centerer(),
pipeline.Filter(b, a=a, overlap=200),
])
main_pipeline = pipeline.Pipeline([
preproc_pipeline,
pipeline.Callable(mean_absolute_value),
RLSMapping(4, 2, 0.99)
])
Experiment(daq=dev, subject='test').run(
Oscilloscope(preproc_pipeline),
CursorFollowing(main_pipeline)
)
Total running time of the script: ( 0 minutes 0.000 seconds)