Source code for nwbwidgets.behavior
from abc import abstractmethod
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from plotly import graph_objects as go
from pynwb import TimeSeries
from pynwb.behavior import SpatialSeries, BehavioralEvents
from nwbwidgets import base
from .utils.timeseries import (
get_timeseries_tt,
get_timeseries_in_units,
timeseries_time_to_ind,
)
from .timeseries import (
AlignMultiTraceTimeSeriesByTrialsConstant,
AlignMultiTraceTimeSeriesByTrialsVariable,
AbstractTraceWidget,
SingleTracePlotlyWidget,
SeparateTracesPlotlyWidget,
)
from .base import lazy_tabs
from .controllers import StartAndDurationController
def show_behavioral_events(beh_events: BehavioralEvents, neurodata_vis_spec: dict):
    """Build an accordion with one entry per time series in ``beh_events``.

    Parameters
    ----------
    beh_events : BehavioralEvents
        Container whose ``time_series`` mapping is displayed.
    neurodata_vis_spec : dict
        Visualization spec handed unchanged to ``base.dict2accordion``.

    Returns
    -------
    Whatever ``base.dict2accordion`` returns for the given series.
    """
    event_series = beh_events.time_series
    # Forwarded verbatim as keyword arguments; presumably matplotlib-style
    # options drawing each event as a "|" tick with no connecting line.
    marker_style = {"ls": "", "marker": "|"}
    return base.dict2accordion(event_series, neurodata_vis_spec, **marker_style)
def show_spatial_series(node: SpatialSeries, **kwargs):
    """Plot a SpatialSeries with a figure type chosen by its dimensionality.

    * 1-D data: value over time on matplotlib axes.
    * 2 columns: x/y trajectory on matplotlib axes with equal aspect.
    * 3 columns: 3-D scatter using ipyvolume (imported lazily).

    Parameters
    ----------
    node : SpatialSeries
        Series to plot; data and unit come from ``get_timeseries_in_units``.
    **kwargs
        Forwarded to the underlying plot call (``ax.plot`` or ``p3.scatter``).

    Returns
    -------
    The created figure (matplotlib figure for 1-D/2-D, ipyvolume figure for 3-D).

    Raises
    ------
    NotImplementedError
        If the data has a second axis whose length is not 2 or 3.
    """
    data, unit = get_timeseries_in_units(node)
    tt = get_timeseries_tt(node)

    def axis_label(name):
        # Append the unit only when one is available.
        return "{} ({})".format(name, unit) if unit else name

    if len(data.shape) == 1:
        fig, ax = plt.subplots()
        ax.plot(tt, data, **kwargs)
        # Time is on the horizontal axis, so the unit-aware "x" label belongs
        # on the vertical axis (previously it overwrote the time label).
        ax.set_xlabel("t (sec)")
        ax.set_ylabel(axis_label("x"))
    elif data.shape[1] == 2:
        fig, ax = plt.subplots()
        ax.plot(data[:, 0], data[:, 1], **kwargs)
        ax.set_xlabel(axis_label("x"))
        ax.set_ylabel(axis_label("y"))
        ax.axis("equal")
    elif data.shape[1] == 3:
        # Imported here so ipyvolume is only required for 3-D data.
        import ipyvolume.pylab as p3

        fig = p3.figure()
        p3.scatter(data[:, 0], data[:, 1], data[:, 2], **kwargs)
        # Fit each axis to the range of the corresponding data column.
        p3.xlim(np.min(data[:, 0]), np.max(data[:, 0]))
        p3.ylim(np.min(data[:, 1]), np.max(data[:, 1]))
        p3.zlim(np.min(data[:, 2]), np.max(data[:, 2]))
    else:
        raise NotImplementedError(
            "show_spatial_series supports 1-D data or 2/3 columns, "
            "got shape {}".format(data.shape)
        )
    return fig
def route_spatial_series(spatial_series, **kwargs):
    """Pick the tab set for a spatial series based on the shape of its data.

    1-D data gets "over time" and "trial aligned" tabs. Data with 2 or 3
    columns additionally gets a "trace" tab backed by the matching 2-D/3-D
    trace widget. Any other column count yields an HTML notice instead.

    Parameters
    ----------
    spatial_series
        Object exposing ``data.shape``; passed on to ``lazy_tabs``.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    The result of ``lazy_tabs`` or a ``widgets.HTML`` message for unsupported
    shapes.
    """
    shape = spatial_series.data.shape

    if len(shape) == 1:
        tab_factories = {
            "over time": SingleTracePlotlyWidget,
            "trial aligned": trial_align_spatial_series,
        }
    else:
        # Column count -> widget used for the "trace" tab.
        trace_widgets = {
            2: SpatialSeriesTraceWidget2D,
            3: SpatialSeriesTraceWidget3D,
        }
        n_columns = shape[1]
        if n_columns not in trace_widgets:
            return widgets.HTML("Unsupported number of dimensions.")
        tab_factories = {
            "over time": SeparateTracesPlotlyWidget,
            "trace": trace_widgets[n_columns],
            "trial aligned": trial_align_spatial_series,
        }

    return lazy_tabs(tab_factories, spatial_series)
class SpatialSeriesTraceWidget(AbstractTraceWidget):
    """Base for trace widgets that redraw a spatial series per time window.

    ``set_out_fig`` calls ``self.plot_data(data, units, tt)`` once and
    ``self.update_plot(data, tt)`` from its change handler; neither method is
    defined in this class, so subclasses are expected to supply them.
    """

    def __init__(
        self,
        timeseries: TimeSeries,
        foreign_time_window_controller: StartAndDurationController = None,
        **kwargs,
    ):
        """Forward all arguments unchanged to ``AbstractTraceWidget``."""
        super().__init__(
            timeseries=timeseries,
            foreign_time_window_controller=foreign_time_window_controller,
            **kwargs,
        )

    def set_out_fig(self):
        """Draw the current time window and redraw when the window changes."""
        series = self.controls["timeseries"].value

        def load_window():
            # Read the window at call time so each redraw sees the latest value.
            window = self.controls["time_window"].value
            first = timeseries_time_to_ind(series, window[0])
            last = timeseries_time_to_ind(series, window[1])
            values, window_units = get_timeseries_in_units(series, first, last)
            times = get_timeseries_tt(series, first, last)
            return values, window_units, times

        values, units, times = load_window()
        self.plot_data(values, units, times)

        def on_change(change):
            # Units are fixed by the initial plot; only data and times refresh.
            new_values, _, new_times = load_window()
            self.update_plot(new_values, new_times)

        self.controls["time_window"].observe(on_change)
class SpatialSeriesTraceWidget2D(SpatialSeriesTraceWidget):
    """Trace widget drawing columns 0 and 1 of the data as an x/y path.

    Markers are colored by the time values using the Viridis colorscale.
    """

    def plot_data(self, data, units, tt):
        """Create ``self.out_fig`` from the given data, units and times.

        ``units`` of ``None`` is shown as "no units" in the axis titles.
        """
        unit_label = "no units" if units is None else units

        path = go.Scattergl(
            x=list(data[:, 0]),
            y=list(data[:, 1]),
            mode="markers+lines",
            marker_color=tt,
            marker_colorscale="Viridis",
            marker_colorbar={"thickness": 20, "title": "time (s)"},
            marker_size=5,
        )
        self.out_fig = go.FigureWidget(data=path)
        self.out_fig.update_layout(
            title=self.timeseries.name,
            xaxis_title="x ({})".format(unit_label),
            yaxis_title="y ({})".format(unit_label),
        )

    def update_plot(self, data, tt):
        """Replace the coordinates and marker colors of the existing trace."""
        trace = self.out_fig.data[0]
        trace.x = list(data[:, 0])
        trace.y = list(data[:, 1])
        self.out_fig.update_traces(marker_color=list(tt))
class SpatialSeriesTraceWidget3D(SpatialSeriesTraceWidget):
    """Trace widget drawing columns 0, 1 and 2 of the data as a 3-D path.

    Markers are colored by the time values using the Viridis colorscale.
    """

    def plot_data(self, data, units, tt):
        """Create ``self.out_fig`` from the given data, units and times.

        Parameters
        ----------
        data
            Array indexed as ``data[:, i]`` for i in 0..2 (x, y, z).
        units
            Unit string for the axis titles; ``None`` is shown as "no units".
        tt
            Time values used as marker colors.
        """
        if units is None:
            units = "no units"

        self.out_fig = go.FigureWidget(
            data=go.Scatter3d(
                x=data[:, 0],
                y=data[:, 1],
                z=data[:, 2],
                mode="markers+lines",
                marker_color=tt,
                marker_colorscale="Viridis",
                marker_colorbar=dict(thickness=20, title="time (s)"),
                marker_size=5,
            )
        )

        # 3-D axis titles live under layout.scene. The z axis was previously
        # mislabeled "x"; it now carries its own name.
        self.out_fig.update_layout(
            scene=dict(
                xaxis_title=f"x ({units})",
                yaxis_title=f"y ({units})",
                zaxis_title=f"z ({units})",
            )
        )

    def update_plot(self, data, tt):
        """Replace the coordinates and marker colors of the existing trace."""
        self.out_fig.data[0].x = list(data[:, 0])
        self.out_fig.data[0].y = list(data[:, 1])
        self.out_fig.data[0].z = list(data[:, 2])
        self.out_fig.update_traces(marker_color=list(tt))
def trial_align_spatial_series(spatial_series, trials=None):
    """Create a trial-alignment widget for a spatial series.

    The trace selector offers one option per data column, labelled "x", "y"
    and "z" (at most three). The widget class depends on whether the series
    has a sampling rate.

    Parameters
    ----------
    spatial_series
        Series exposing ``data.shape`` and ``rate``.
    trials : optional
        Passed through to the alignment widget as ``trials``.

    Returns
    -------
    ``AlignMultiTraceTimeSeriesByTrialsVariable`` when ``rate`` is ``None``,
    otherwise ``AlignMultiTraceTimeSeriesByTrialsConstant``.
    """
    shape = spatial_series.data.shape
    # 1-D data has no second axis; indexing shape[1] would raise IndexError,
    # so treat it as a single "x" trace.
    # NOTE(review): confirm the alignment widgets accept 1-D data downstream.
    n_traces = shape[1] if len(shape) > 1 else 1
    options = [("x", 0), ("y", 1), ("z", 2)][:n_traces]

    if spatial_series.rate is None:
        widget_cls = AlignMultiTraceTimeSeriesByTrialsVariable
    else:
        widget_cls = AlignMultiTraceTimeSeriesByTrialsConstant

    return widget_cls(
        time_series=spatial_series,
        trials=trials,
        trace_controller_kwargs=dict(options=options),
    )