Source code for nwbwidgets.misc
from functools import partial
import scipy
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
import plotly.graph_objects as go
from ipywidgets import widgets, fixed, FloatProgress, Layout
from pynwb.epoch import TimeIntervals
from pynwb.misc import AnnotationSeries, Units, DecompositionSeries
from .analysis.spikes import compute_smoothed_firing_rate
from .controllers import (
make_trial_event_controller,
GroupAndSortController,
StartAndDurationController,
ProgressBar,
)
from .utils.dynamictable import infer_categorical_columns, extract_data_from_intervals
from .utils.mpl import create_big_ax
from .utils.plotly import event_group
from .utils.units import (
get_spike_times,
get_max_spike_time,
get_min_spike_time,
align_by_time_intervals,
get_unobserved_intervals,
)
from .utils.widgets import interactive_output, clean_axes
# Colours of the current matplotlib ``axes.prop_cycle``; the plotting helpers
# pick one per group as ``color_wheel[group % len(color_wheel)]``.
color_wheel = plt.rcParams["axes.prop_cycle"].by_key()["color"]
def show_annotations(annotations: AnnotationSeries, **kwargs):
    """Draw the timestamps of an AnnotationSeries as a single event raster.

    Parameters
    ----------
    annotations: pynwb.misc.AnnotationSeries
    **kwargs
        Forwarded unchanged to ``Axes.eventplot``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    figure, axis = plt.subplots()
    axis.eventplot(annotations.timestamps, **kwargs)
    axis.set_xlabel("time (s)")
    return figure
def show_session_raster(
    units: Units,
    time_window=None,
    units_window=None,
    show_obs_intervals=True,
    order=None,
    group_inds=None,
    labels=None,
    show_legend=True,
    progress_bar=None,
):
    """Plot a session-wide spike raster, one row per unit.

    Parameters
    ----------
    units: pynwb.misc.Units
    time_window: [float, float], optional
        Defaults to the span between the first and last spike in ``units``.
    units_window: [int, int], optional
        Only its first element is used, as the row offset. Defaults to
        ``[0, len(units)]``.
    show_obs_intervals: bool
        Shade the intervals in which each unit was not observed.
    order: array-like of int, optional
        Unit indices to plot, in row order. Defaults to all units.
    group_inds: array-like, optional
    labels: array-like, optional
    show_legend: bool
        default = True
        Does not show legend if color_by is None or 'id'.
    progress_bar: FloatProgress, optional
        If truthy, spike reading is wrapped in a ProgressBar.

    Returns
    -------
    matplotlib.axes.Axes
    """
    if time_window is None:
        time_window = [get_min_spike_time(units), get_max_spike_time(units)]
    if units_window is None:
        units_window = [0, len(units)]
    if order is None:
        order = np.arange(len(units), dtype="int")
    if progress_bar:
        this_iter = ProgressBar(order, desc="reading spike data", leave=False)
        progress_bar = this_iter.container
    else:
        this_iter = order
    data = [get_spike_times(units, unit, time_window) for unit in this_iter]
    if show_obs_intervals:
        unobserved_intervals_list = get_unobserved_intervals(units, time_window, order)
    else:
        unobserved_intervals_list = None
    ax = plot_grouped_events(
        data,
        time_window,
        group_inds=group_inds,
        labels=labels,
        show_legend=show_legend,
        offset=units_window[0],
        unobserved_intervals_list=unobserved_intervals_list,
        progress_bar=progress_bar,
    )
    ax.set_ylabel("unit #")
    if len(data) <= 30:
        # Label rows with unit ids. Index with `order` itself instead of
        # re-iterating `this_iter`, which may be an already-consumed progress bar.
        # dtype=int keeps an empty `order` usable as an index.
        unit_id_display = np.asarray(units.id.data[:])[np.asarray(order, dtype=int)]
        ax.set_yticklabels(unit_id_display)
    else:
        ax.axes.yaxis.set_visible(False)
    return ax
class RasterWidget(widgets.HBox):
    """Session-wide spike raster of a Units table with time-window and
    group/sort controls (matplotlib backend).

    Either controller may be supplied by the caller ("foreign"); a foreign
    controller is not displayed inside this widget.
    """

    def __init__(
        self,
        units: Units,
        foreign_time_window_controller: StartAndDurationController = None,
        foreign_group_and_sort_controller: GroupAndSortController = None,
        group_by=None,
    ):
        super().__init__()
        self.units = units

        if foreign_time_window_controller is not None:
            self.time_window_controller = foreign_time_window_controller
        else:
            # Own controller spanning every spike in the table.
            self.tmin = get_min_spike_time(units)
            self.tmax = get_max_spike_time(units)
            self.time_window_controller = StartAndDurationController(
                tmin=self.tmin, tmax=self.tmax
            )

        self.gas = (
            foreign_group_and_sort_controller
            if foreign_group_and_sort_controller
            else self.make_group_and_sort(group_by=group_by, control_order=False)
        )

        # Empty box handed to show_session_raster as its `progress_bar`.
        self.progress_bar = widgets.HBox()
        self.controls = dict(time_window=self.time_window_controller, gas=self.gas)
        raster_func = partial(
            show_session_raster, units=self.units, progress_bar=self.progress_bar
        )
        out_fig = interactive_output(raster_func, self.controls)

        panel_children = [self.progress_bar, out_fig]
        if not foreign_time_window_controller:
            # The time controller is shown only when this widget owns it.
            panel_children.insert(0, self.time_window_controller)
        right_panel = widgets.VBox(
            children=panel_children, layout=Layout(width="100%")
        )

        self.children = (
            [right_panel]
            if foreign_group_and_sort_controller
            else [self.gas, right_panel]
        )
        self.layout = Layout(width="100%")

    def make_group_and_sort(self, group_by=None, control_order=True):
        """Build a GroupAndSortController over this widget's units table."""
        return GroupAndSortController(
            self.units, group_by=group_by, control_order=control_order
        )
def show_decomposition_series(node, **kwargs):
    """Two-tab widget for a DecompositionSeries: "Fields" and "Traces".

    The traces tab is built the first time it is selected, replacing its
    "Rendering..." placeholder.

    Parameters
    ----------
    node: pynwb.misc.DecompositionSeries
    **kwargs
        Accepted for interface compatibility; not used.

    Returns
    -------
    ipywidgets.Tab
    """
    # Placeholders, swapped for real content once each tab is built.
    tab_contents = [widgets.HTML("Rendering...") for _ in range(2)]

    def on_selected_index(change):
        # Build the traces tab only if it was selected and is still a placeholder.
        if change.new != 1 or not isinstance(change.owner.children[1], widgets.HTML):
            return
        tab_contents[1] = show_decomposition_traces(node)
        change.owner.children = tab_contents

    field_lay = widgets.Layout(
        max_height="40px", max_width="500px", min_height="30px", min_width="130px"
    )
    field_rows = [
        widgets.HBox(
            children=[
                widgets.Label(key + ":", layout=field_lay),
                widgets.Label(str(val), layout=field_lay),
            ]
        )
        for key, val in node.fields.items()
    ]
    tab_contents[0] = widgets.VBox(field_rows)

    tab_nest = widgets.Tab()
    tab_nest.children = tab_contents
    for position, title in enumerate(("Fields", "Traces")):
        tab_nest.set_title(position, title)
    tab_nest.observe(on_selected_index, names="selected_index")
    return tab_nest
def show_decomposition_traces(node: DecompositionSeries):
    """Interactive stacked-trace plot of a DecompositionSeries.

    There is one subplot per band. Within a subplot, every selected channel is
    mean-subtracted and shifted by a constant vertical offset. The time-range
    controls are in milliseconds and are converted to sample indices with
    ``node.rate`` before slicing.

    Parameters
    ----------
    node: pynwb.misc.DecompositionSeries
        ``node.data`` is sliced as ``[time, channel, band]``. ``node.rate`` is
        assumed to be the sampling rate in Hz and not None — TODO confirm for
        timestamp-based series.

    Returns
    -------
    ipywidgets.VBox
        Control row above the interactive figure output.
    """
    nSamples = node.data.shape[0]
    nChannels = node.data.shape[1]
    nBands = node.data.shape[2]
    fs = node.rate
    duration_ms = int(1000 * nSamples / fs)

    def control_plot(x0, x1, ch0, ch1):
        # Controls are in ms, while the data are indexed by sample.
        s0 = int(x0 * fs / 1000)
        s1 = int(x1 * fs / 1000)
        # squeeze=False keeps a 2-D axes array, so a single band also works.
        fig, axs = plt.subplots(
            nrows=nBands, ncols=1, sharex=True, figsize=(14, 7), squeeze=False
        )
        for bd, ax in enumerate(axs[:, 0]):
            data = node.data[s0:s1, ch0 : ch1 + 1, bd]
            # Build the time axis (ms) from the actual slice length so it always
            # matches `data`, even when the request runs past the last sample.
            xx = (s0 + np.arange(data.shape[0])) * 1000.0 / fs
            mu_array = np.mean(data, 0)
            sd_array = np.std(data, 0)
            # Spacing between stacked channels: 5x the mean channel std.
            offset = np.mean(sd_array) * 5
            n_ch = data.shape[1]
            yticks = [i * offset for i in range(n_ch)]
            for i in range(n_ch):
                ax.plot(xx, data[:, i] - mu_array[i] + yticks[i])
            ax.set_ylabel("Ch #", fontsize=20)
            ax.set_yticks(yticks)
            ax.set_yticklabels([str(i) for i in range(ch0, ch0 + n_ch)])
            ax.tick_params(axis="both", which="major", labelsize=16)
            ax.set_xlabel("Time [ms]", fontsize=20)
        return fig

    # Controls
    field_lay = widgets.Layout(
        max_height="40px", max_width="100px", min_height="30px", min_width="70px"
    )
    x0 = widgets.BoundedIntText(
        value=0, min=0, max=int(1000 * nSamples / fs - 100), layout=field_lay
    )
    # Default end = full duration, expressed in ms like the bound itself.
    x1 = widgets.BoundedIntText(
        value=duration_ms, min=100, max=duration_ms, layout=field_lay
    )
    ch0 = widgets.BoundedIntText(
        value=0, min=0, max=int(nChannels - 1), layout=field_lay
    )
    ch1 = widgets.BoundedIntText(
        value=10, min=0, max=int(nChannels - 1), layout=field_lay
    )
    controls = {"x0": x0, "x1": x1, "ch0": ch0, "ch1": ch1}
    out_fig = widgets.interactive_output(control_plot, controls)

    # Assemble layout box
    lbl_x = widgets.Label("Time [ms]:", layout=field_lay)
    lbl_ch = widgets.Label("Ch #:", layout=field_lay)
    lbl_blank = widgets.Label(" ", layout=field_lay)
    hbox0 = widgets.HBox(children=[lbl_x, x0, x1, lbl_blank, lbl_ch, ch0, ch1])
    vbox = widgets.VBox(children=[hbox0, out_fig])
    return vbox
class PSTHWidget(widgets.VBox):
    """Peri-stimulus time histogram (PSTH) widget for a single unit.

    For each selected trial event, the top row shows a spike raster aligned to
    that event. The bottom row shows a firing-rate estimate, either a
    histogram or a Gaussian-smoothed curve.
    """

    def __init__(
        self,
        input_data: Units,
        trials: TimeIntervals = None,
        unit_index=0,
        unit_controller=None,
        ntt=1000,
    ):
        """
        Parameters
        ----------
        input_data: pynwb.misc.Units
        trials: pynwb.epoch.TimeIntervals, optional
            Defaults to the trials table of the NWBFile containing
            ``input_data``. A message is shown if there is none.
        unit_index: int
            Initial selection of the internally built unit dropdown.
        unit_controller: optional
            External widget whose value is a unit index. It replaces the
            internal dropdown.
        ntt: int
            Number of time points of the smoothed curve.
        """
        self.units = input_data
        super().__init__()
        if trials is None:
            self.trials = self.get_trials()
            if self.trials is None:
                self.children = [widgets.HTML("No trials present")]
                return
        else:
            self.trials = trials

        # Always needed: `update` titles the figure with the unit id, even
        # when the unit controller comes from outside.
        self.unit_ids = self.units.id.data[:]
        if unit_controller is None:
            n_units = len(self.unit_ids)
            self.unit_controller = widgets.Dropdown(
                options=[(str(self.unit_ids[x]), x) for x in range(n_units)],
                value=unit_index,
                description="unit",
                layout=Layout(width="200px"),
            )
        else:
            self.unit_controller = unit_controller

        self.trial_event_controller = make_trial_event_controller(
            self.trials, layout=Layout(width="200px"), multiple=True
        )
        self.start_ft = widgets.FloatText(
            -0.5,
            step=0.1,
            description="start (s)",
            layout=Layout(width="200px"),
            description_tooltip="Start time for calculation before or after (negative or positive) the reference point (aligned to)",
        )
        self.end_ft = widgets.FloatText(
            1.0,
            step=0.1,
            description="end (s)",
            layout=Layout(width="200px"),
            description_tooltip="End time for calculation before or after (negative or positive) the reference point (aligned to).",
        )
        self.psth_type_radio = widgets.RadioButtons(
            options=["histogram", "gaussian"], layout=Layout(width="100px")
        )
        self.bins_ft = widgets.IntText(
            30, min=0, description="# bins", layout=Layout(width="150px")
        )
        self.gaussian_sd_ft = widgets.FloatText(
            0.05,
            min=0.001,
            description="sd (s)",
            layout=Layout(width="150px"),
            active=False,
            step=0.01,
        )
        self.gas = self.make_group_and_sort(window=False, control_order=False)
        self.controls = dict(
            ntt=fixed(ntt),
            index=self.unit_controller,
            end=self.end_ft,
            start=self.start_ft,
            start_labels=self.trial_event_controller,
            gas=self.gas,
            plot_type=self.psth_type_radio,
            sigma_in_secs=self.gaussian_sd_ft,
            nbins=self.bins_ft,
        )
        out_fig = interactive_output(self.update, self.controls)
        self.children = [
            widgets.HBox(
                [
                    widgets.VBox(
                        [
                            self.gas,
                            widgets.HBox(
                                [
                                    self.psth_type_radio,
                                    widgets.VBox([self.gaussian_sd_ft, self.bins_ft]),
                                ]
                            ),
                        ]
                    ),
                    widgets.VBox(
                        [
                            self.unit_controller,
                            self.trial_event_controller,
                            self.start_ft,
                            self.end_ft,
                        ]
                    ),
                ]
            ),
            out_fig,
        ]

    def get_trials(self):
        """Return the trials table of the NWBFile containing the units.

        Returns None when the units are not inside an NWBFile or the file has
        no trials table.
        """
        nwbfile = self.units.get_ancestor("NWBFile")
        if nwbfile is None:
            return None
        return nwbfile.trials

    def make_group_and_sort(self, window=None, control_order=False):
        """Build a GroupAndSortController over the trials table."""
        return GroupAndSortController(
            self.trials, window=window, control_order=control_order
        )

    def update(
        self,
        index: int,
        start_labels: tuple = ("start_time",),
        start: float = 0.0,
        end: float = 1.0,
        order=None,
        group_inds=None,
        labels=None,
        sigma_in_secs=0.05,
        ntt: int = 1000,
        progress_bar=None,
        figsize=(12, 7),
        nbins=30,
        plot_type="histogram",
        align_line_color=(0.7, 0.7, 0.7),
    ):
        """Draw the raster and rate figure for one unit.

        Parameters
        ----------
        index: int
            Index of unit
        start_labels: tuple of str
            Trial columns to align on; one figure column per label.
        start: float
            Start time for calculation before or after (negative or positive) the reference point (aligned to).
        end: float
            End time for calculation before or after (negative or positive) the reference point (aligned to).
        order: array-like, optional
            Trial rows to use.
        group_inds: array-like, optional
            Group index of each selected trial.
        labels: array-like, optional
            Group names, indexed by group index.
        sigma_in_secs: float, optional
            standard deviation of gaussian kernel
        ntt: int
            Number of time points to use for smooth curve
        progress_bar: optional
        figsize: tuple, optional
        nbins: int
            Number of histogram bins (histogram mode).
        plot_type: {"histogram", "gaussian"}
        align_line_color: tuple
            RGB colour of the vertical line at t = 0.

        Returns
        -------
        matplotlib.Figure

        Raises
        ------
        ValueError
            If ``plot_type`` is not supported.
        """
        fig, axs = plt.subplots(2, len(start_labels), figsize=figsize, sharex=True)
        clean_axes(axs.ravel())
        ax1_ylims = []
        for i_s, start_label in enumerate(start_labels):
            # With a single label `axs` is 1-D (2 axes), otherwise 2-D.
            if len(start_labels) > 1:
                ax0 = axs[0, i_s]
                ax1 = axs[1, i_s]
            else:
                ax0 = axs[0]
                ax1 = axs[1]
            data = align_by_time_intervals(
                self.units,
                index,
                self.trials,
                start_label,
                start_label,
                start,
                end,
                order,
                progress_bar=progress_bar,
            )
            # Only the right-most column carries the legend.
            show_legend = i_s == len(start_labels) - 1
            show_psth_raster(
                data,
                start,
                end,
                group_inds,
                labels,
                show_legend=show_legend,
                ax=ax0,
                progress_bar=progress_bar,
                fontsize=12,
            )
            ax0.set_title(f"{start_label}")
            ax0.set_xticks([])
            ax0.set_xlabel("")
            if i_s > 0:
                ax0.set_ylabel("")
                # Raster always show the same number of trials. We can avoid showing tick labels too
                ax0.set_yticklabels([])
            if plot_type == "gaussian":
                self.bins_ft.layout.visibility = "hidden"
                self.bins_ft.layout.height = "0px"
                self.gaussian_sd_ft.layout.visibility = None
                self.gaussian_sd_ft.layout.height = None
                # expanded data so that gaussian smoother uses larger window than is viewed
                expanded_data = align_by_time_intervals(
                    units=self.units,
                    index=index,
                    intervals=self.trials,
                    start_label=start_label,
                    stop_label=start_label,
                    start=start - sigma_in_secs * 4,
                    end=end + sigma_in_secs * 4,
                    rows_select=order,
                    progress_bar=progress_bar,
                )
                show_psth_smoothed(
                    data=expanded_data,
                    ax=ax1,
                    start=start - sigma_in_secs * 4,
                    end=end + sigma_in_secs * 4,
                    group_inds=group_inds,
                    sigma_in_secs=sigma_in_secs,
                    ntt=ntt,
                )
            elif plot_type == "histogram":
                self.gaussian_sd_ft.layout.visibility = "hidden"
                self.gaussian_sd_ft.layout.height = "0px"
                self.bins_ft.layout.visibility = None
                self.bins_ft.layout.height = None
                show_histogram(data, ax1, start, end, group_inds, nbins=nbins)
            else:
                # Report the value actually passed in, not the widget state.
                raise ValueError("unsupported plot type {}".format(plot_type))
            ax1.set_xlim([start, end])
            if i_s == 0:
                ax1.set_ylabel("firing rate (Hz)", fontsize=12)
            ax1.set_xlabel("time (s)", fontsize=12)
            ax1.axvline(color=align_line_color)
            ax1_ylims.append(ax1.get_ylim())
        if len(start_labels) > 1:
            # Share one y range across the bottom (rate) axes.
            min_y = np.min(np.array(ax1_ylims)[:, 0])
            max_y = np.max(np.array(ax1_ylims)[:, 1])
            for i_b, ax_btm in enumerate(axs[1, :]):
                ax_btm.set_ylim(min_y, max_y)
                if i_b > 0:
                    ax_btm.set_ylabel("")
                    # After adjusting ylims we can avoid showing tick labels
                    ax_btm.set_yticklabels([])
        fig.suptitle(f"Unit {self.unit_ids[index]}", fontsize=15)
        fig.subplots_adjust(wspace=0.3)
        return fig
def show_histogram(
    data,
    ax: "plt.Axes",
    start: float,
    end: float,
    group_inds=None,
    nbins: int = 30,
    colors=None,
):
    """Draw a per-trial firing-rate histogram on ``ax``.

    Bar heights are spike counts divided by the number of trials and by the
    bin width, so they are in spikes per second per trial.

    Parameters
    ----------
    data: sequence of array-like
        Spike times of each trial, relative to the alignment point.
    ax: matplotlib.axes.Axes
    start, end: float
        Histogram range.
    group_inds: array-like of int, optional
        Group of each trial. One semi-transparent histogram is drawn per group.
        Lists and arrays are both accepted.
    nbins: int
    colors: sequence, optional
        Colours indexed by ``group % len(colors)``. Defaults to the module
        ``color_wheel``.

    Returns
    -------
    None
    """
    if not len(data):
        return
    if group_inds is None:
        height, x = np.histogram(np.hstack(data), bins=nbins, range=(start, end))
        width = x[1] - x[0]
        ax.bar(
            x[:-1],
            height / len(data) / width,
            edgecolor=(0.3, 0.3, 0.3),
            width=width,
            align="edge",
        )
        return
    if colors is None:
        colors = color_wheel
    data = np.asarray(data, dtype="object")
    # Coerce to an array: with a plain list `group_inds == group` would be a
    # single bool instead of an element-wise mask.
    group_inds = np.asarray(group_inds)
    for group in np.unique(group_inds):
        mask = group_inds == group
        height, x = np.histogram(np.hstack(data[mask]), bins=nbins, range=(start, end))
        width = x[1] - x[0]
        height = height / np.sum(mask) / width
        ax.bar(
            x[:-1],
            height,
            color=colors[group % len(colors)],
            edgecolor=(0.3, 0.3, 0.3),
            width=width,
            align="edge",
            alpha=0.6,
        )
def show_psth_smoothed(
    data,
    ax,
    start: float,
    end: float,
    group_inds=None,
    sigma_in_secs: float = 0.05,
    ntt: int = 1000,
    colors=None,
):
    """Plot the Gaussian-smoothed firing rate of each group (mean +/- 2 SEM).

    Parameters
    ----------
    data: sequence of array-like
        Spike times of each trial.
    ax: matplotlib.axes.Axes
    start, end: float
        Range of the evaluation grid.
    group_inds: array-like of int, optional
        Group of each trial. All trials form a single group when omitted.
        Lists and arrays are both accepted.
    sigma_in_secs: float
        Standard deviation of the smoothing kernel.
    ntt: int
        Number of evaluation time points.
    colors: sequence, optional
        Colours indexed by ``group % len(colors)``. Defaults to the module
        ``color_wheel``.

    Returns
    -------
    None
    """
    # Explicit submodule import: `import scipy` alone does not guarantee
    # `scipy.stats` is loaded on older SciPy versions.
    from scipy import stats

    if not len(data):  # TODO: when does this occur?
        return
    all_data = np.hstack(data)
    if not len(all_data):  # no spikes
        return
    if colors is None:
        colors = color_wheel
    tt = np.linspace(start, end, ntt)
    smoothed = np.array(
        [compute_smoothed_firing_rate(x, tt, sigma_in_secs) for x in data]
    )
    if group_inds is None:
        group_inds = np.zeros(len(smoothed), dtype=int)
    else:
        # A plain list would make `group_inds == group` a scalar, not a mask.
        group_inds = np.asarray(group_inds)
    for group in np.unique(group_inds):
        group_rates = smoothed[group_inds == group]
        this_mean = np.mean(group_rates, axis=0)
        err = stats.sem(group_rates, axis=0)
        color = colors[group % len(colors)]
        ax.plot(tt, this_mean, color=color)
        ax.fill_between(
            tt, this_mean - 2 * err, this_mean + 2 * err, alpha=0.2, color=color
        )
[docs]def plot_grouped_events(
data,
window,
group_inds=None,
colors=color_wheel,
ax=None,
labels=None,
show_legend=True,
offset=0,
unobserved_intervals_list=None,
progress_bar=None,
figsize=(8, 6),
fontsize=12,
):
"""
Parameters
----------
data: array-like
window: array-like [float, float]
Time in seconds
group_inds: array-like dtype=int, optional
colors: array-like, optional
ax: plt.Axes, optional
labels: array-like dtype=str, optional
show_legend: bool, optional
offset: number, optional
unobserved_intervals_list: array-like, optional
progress_bar: FloatProgress, optional
figsize: tuple, optional
fontsize: int, optional
Returns
-------
"""
data = np.asarray(data, dtype="object")
legend_kwargs = dict()
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
if hasattr(fig, "canvas"):
fig.canvas.header_visible = False
else:
legend_kwargs.update(bbox_to_anchor=(1.01, 1))
if group_inds is not None:
ugroup_inds = np.unique(group_inds)
handles = []
if progress_bar is not None:
this_iter = ProgressBar(
enumerate(ugroup_inds),
desc="plotting spikes",
leave=False,
total=len(ugroup_inds),
)
progress_bar = this_iter.container
else:
this_iter = enumerate(ugroup_inds)
for i, ui in this_iter:
color = colors[ugroup_inds[i] % len(colors)]
lineoffsets = np.where(group_inds == ui)[0] + offset
event_collection = ax.eventplot(
data[group_inds == ui],
orientation="horizontal",
lineoffsets=lineoffsets,
color=color,
)
handles.append(event_collection[0])
if show_legend:
ax.legend(
handles=handles[::-1],
labels=list(labels[ugroup_inds][::-1]),
loc="upper left",
bbox_to_anchor=(1.01, 1),
**legend_kwargs,
)
else:
ax.eventplot(
data,
orientation="horizontal",
color="k",
lineoffsets=np.arange(len(data)) + offset,
)
if unobserved_intervals_list is not None:
plot_unobserved_intervals(unobserved_intervals_list, ax, offset=offset)
ax.set_xlim(window)
ax.set_xlabel("time (s)", fontsize=fontsize)
ax.set_ylim(np.array([-0.5, len(data) - 0.5]) + offset)
if len(data) <= 30:
ax.set_yticks(range(offset, len(data) + offset))
ax.set_yticklabels(range(offset, len(data) + offset))
return ax
def plot_unobserved_intervals(
    unobserved_intervals_list, ax, offset=0, color=(0.85, 0.85, 0.85)
):
    """Shade, row by row, the intervals in which a unit was not observed.

    Parameters
    ----------
    unobserved_intervals_list: sequence
        For each row, a sequence of ``(start, stop)`` pairs.
    ax: matplotlib.axes.Axes
    offset: number
        Row offset, matching the one used for the raster rows.
    color: colour spec for the shaded rectangles.
    """
    for row, intervals in enumerate(unobserved_intervals_list):
        # Each rectangle is one row tall and centred on the row position.
        bottom = row - 0.5 + offset
        patches = [
            Rectangle((interval[0], bottom), interval[1] - interval[0], 1)
            for interval in intervals
        ]
        ax.add_collection(PatchCollection(patches, color=color))
def show_psth_raster(
    data,
    start=-0.5,
    end=2.0,
    group_inds=None,
    labels=None,
    ax=None,
    show_legend=True,
    align_line_color=(0.7, 0.7, 0.7),
    progress_bar: FloatProgress = None,
    fontsize=12,
) -> plt.Axes:
    """Trial raster aligned to an event, with a vertical line at t = 0.

    Parameters
    ----------
    data: array-like
        Spike times of each trial, relative to the alignment point.
    start: float
        Start time for calculation before or after (negative or positive) the reference point (aligned to).
    end: float
        End time for calculation before or after (negative or positive) the reference point (aligned to).
    group_inds: array-like, optional
    labels: array-like, optional
    ax: plt.Axes, optional
    show_legend: bool, optional
    align_line_color: array-like, optional
        [R, G, B] (0-1)
        Default = [0.7, 0.7, 0.7]
    progress_bar: FloatProgress, optional
    fontsize: int, optional

    Returns
    -------
    plt.Axes
        ``ax`` unchanged (possibly None) when ``data`` is empty.
    """
    if not len(data):
        return ax
    ax = plot_grouped_events(
        data,
        [start, end],
        group_inds=group_inds,
        colors=color_wheel,
        ax=ax,
        labels=labels,
        show_legend=show_legend,
        progress_bar=progress_bar,
        fontsize=fontsize,
    )
    ax.set_ylabel("trials", fontsize=fontsize)
    ax.axvline(color=align_line_color)
    return ax
def raster_grid(
    units: Units,
    time_intervals: TimeIntervals,
    index,
    start,
    end,
    rows_label=None,
    cols_label=None,
    trials_select=None,
    align_by="start_time",
) -> plt.Figure:
    """Grid of aligned rasters for one unit, split by up to two trial columns.

    Parameters
    ----------
    units: pynwb.misc.Units
    time_intervals: pynwb.epoch.TimeIntervals
    index: int
        Unit index.
    start: float
        Start time for calculation before or after (negative or positive) the reference point (aligned to).
    end: float
        End time for calculation before or after (negative or positive) the reference point (aligned to).
    rows_label: str, optional
        Column whose unique values define the grid rows.
    cols_label: str, optional
        Column whose unique values define the grid columns.
    trials_select: np.array(dtype=bool), optional
        Mask of trials to consider. Defaults to all trials.
    align_by: str, optional

    Returns
    -------
    plt.Figure

    Raises
    ------
    ValueError
        If ``time_intervals`` is None.
    """
    if time_intervals is None:
        raise ValueError("trials must exist (trials cannot be None)")
    if trials_select is None:
        trials_select = np.ones((len(time_intervals),)).astype("bool")

    def _levels(label):
        # Returns (per-trial values, unique values among selected trials
        # without NaN); (None, [None]) when the dimension is not split.
        if label is None:
            return None, [None]
        values = np.array(time_intervals[label][:])
        uniques = np.unique(values[trials_select])
        if uniques.dtype == np.float64:
            uniques = uniques[~np.isnan(uniques)]
        return values, uniques

    row_vals, urow_vals = _levels(rows_label)
    col_vals, ucol_vals = _levels(cols_label)

    fig, axs = plt.subplots(
        len(urow_vals),
        len(ucol_vals),
        sharex=True,
        sharey=True,
        squeeze=False,
        figsize=(10, 10),
    )
    big_ax = create_big_ax(fig)
    for i, row in enumerate(urow_vals):
        for j, col in enumerate(ucol_vals):
            ax = axs[i, j]
            cell_mask = trials_select.copy()
            if row is not None:
                cell_mask &= row_vals == row
            if col is not None:
                cell_mask &= col_vals == col
            cell_trials = np.where(cell_mask)[0]
            if len(cell_trials):
                data = align_by_time_intervals(
                    units,
                    index,
                    time_intervals,
                    align_by,
                    align_by,
                    start,
                    end,
                    cell_trials,
                )
                show_psth_raster(data, start, end, ax=ax)
            ax.set_xlabel("")
            ax.set_ylabel("")
            spec = ax.get_subplotspec()
            # Only outer panels carry the category value as a label.
            if spec.is_first_col():
                ax.set_ylabel(row)
            if spec.is_last_row():
                ax.set_xlabel(col)
    big_ax.set_xlabel(cols_label, labelpad=50)
    big_ax.set_ylabel(rows_label, labelpad=60)
    return fig
def plot_grouped_events_plotly(
    data,
    window=None,
    group_inds=None,
    colors=color_wheel,
    labels=None,
    show_legend=True,
    unobserved_intervals_list=None,
    progress_bar=None,
    fig=None,
    **kwargs,
):
    """Plotly counterpart of ``plot_grouped_events``.

    Groups are drawn one after another. Each group's rows are stacked
    directly above the previous group's.

    Parameters
    ----------
    data: array-like
        One sequence of event times per row.
    window, show_legend, unobserved_intervals_list, progress_bar
        Accepted for signature compatibility; not used here.
    group_inds: array-like of int, optional
        Group of each row. Lists and arrays are both accepted.
    colors: sequence
        Colours indexed by ``group % len(colors)``.
    labels: array-like, optional
        Group names, indexed by group value. Falls back to ``str(group)``.
    fig: plotly FigureWidget, optional
        A new ``go.FigureWidget`` is created when omitted.
    **kwargs
        Forwarded to ``event_group``.

    Returns
    -------
    go.FigureWidget
    """
    data = np.array(data, dtype=object)
    if fig is None:
        fig = go.FigureWidget()
    if group_inds is None:
        event_group(data, fig=fig, **kwargs)
    else:
        # A plain list would make `group_inds == ui` a scalar, not a mask.
        group_inds = np.asarray(group_inds)
        offset = 0
        for ui in np.unique(group_inds):
            this_data = data[group_inds == ui]
            label = str(ui) if labels is None else labels[ui]
            event_group(
                this_data,
                offset=offset,
                label=label,
                color=colors[ui % len(colors)],
                fig=fig,
                **kwargs,
            )
            offset += len(this_data)
    fig.update_layout(xaxis_title="time (s)")
    return fig
class RasterWidgetPlotly(widgets.HBox):
    """Session-wide spike raster of a Units table (plotly backend).

    The figure is redrawn whenever the time window or the group/sort
    controller changes. A checkbox toggles the legend.
    """

    def __init__(
        self,
        units: Units,
        foreign_time_window_controller: StartAndDurationController = None,
        foreign_group_and_sort_controller: GroupAndSortController = None,
        group_by=None,
        fig: go.FigureWidget = None,
    ):
        super().__init__()
        self.units = units
        if foreign_time_window_controller is None:
            self.tmin = get_min_spike_time(units)
            self.tmax = get_max_spike_time(units)
            self.time_window_controller = StartAndDurationController(
                tmin=self.tmin, tmax=self.tmax
            )
        else:
            self.time_window_controller = foreign_time_window_controller
        if foreign_group_and_sort_controller:
            self.gas = foreign_group_and_sort_controller
        else:
            self.gas = GroupAndSortController(
                dynamic_table=self.units, group_by=group_by
            )
        self.show_legend_cb = widgets.Checkbox(value=True, description="show legend")
        # Placeholder used in the foreign-controller layout below. It was
        # previously referenced without ever being created (AttributeError).
        self.progress_bar = widgets.HBox()
        if fig is None:
            self.fig = go.FigureWidget()
            self.fig.update_layout(margin=dict(l=20, r=20, t=30, b=20))
        else:
            self.fig = fig
        show_session_raster_plotly(
            self.units, self.fig, self.time_window_controller.value, **self.gas.value
        )

        # set children
        if foreign_time_window_controller:
            right_panel = widgets.VBox(
                children=[
                    self.progress_bar,
                    self.fig,
                ],
                layout=Layout(width="100%"),
            )
        else:
            right_panel = widgets.VBox(
                children=[self.time_window_controller, self.fig, self.show_legend_cb],
                layout=Layout(width="100%"),
            )
        if foreign_group_and_sort_controller:
            self.children = [right_panel]
        else:
            self.children = [self.gas, right_panel]
        self.layout = Layout(width="100%")

        self.time_window_controller.observe(self.update_fig, "value")
        self.gas.observe(self.update_fig, "value")
        self.show_legend_cb.observe(self.toggle_legend, "value")

    def toggle_legend(self, change):
        """Show or hide the legend according to the checkbox."""
        self.fig.update_layout(showlegend=self.show_legend_cb.value)

    def update_fig(self, change):
        """Clear the traces and redraw with the current controller values."""
        time_window = self.time_window_controller.value
        gas_kwargs = self.gas.value
        with self.fig.batch_update():
            self.fig.data = None
            show_session_raster_plotly(self.units, self.fig, time_window, **gas_kwargs)
def show_session_raster_plotly(
    units: Units, fig, time_window=None, order=None, progress_bar=None, **kwargs
):
    """Draw a session-wide spike raster into a plotly figure.

    Parameters
    ----------
    units: pynwb.misc.Units
    fig: plotly FigureWidget
        Figure to draw into.
    time_window: [float, float], optional
        Defaults to the span between the first and last spike in ``units``.
    order: array-like, optional
        Unit indices to plot, in row order. Defaults to all units.
    progress_bar: FloatProgress, optional
        If truthy, spike reading is wrapped in a ProgressBar.
    **kwargs
        Forwarded to ``plot_grouped_events_plotly`` (e.g. group_inds, labels),
        together with marker/line settings chosen from the number of units.

    Returns
    -------
    go.FigureWidget
    """
    if time_window is None:
        time_window = [get_min_spike_time(units), get_max_spike_time(units)]
    if order is None:
        order = np.arange(len(units), dtype="int")

    if not progress_bar:
        unit_iter = order
    else:
        unit_iter = ProgressBar(order, desc="reading spike data", leave=False)
        progress_bar = unit_iter.container

    spike_times = [get_spike_times(units, unit, time_window) for unit in unit_iter]

    n_units = len(order)
    fig.update_yaxes(tickvals=[], ticktext=[])
    # Large rasters use thinner lines and the default marker.
    if n_units > 100:
        kwargs.update(line_width=1)
    else:
        kwargs.update(marker="line-ns", line_width=2)
    fig = plot_grouped_events_plotly(data=spike_times, fig=fig, **kwargs)
    if n_units <= 40:
        fig.update_yaxes(
            tickvals=np.arange(n_units), ticktext=list(map(str, order))
        )
    fig.update_layout(
        title="units",
        xaxis_title="time (s)",
        legend=dict(x=1.0, y=0, traceorder="reversed"),
        xaxis=dict(range=time_window),
        yaxis=dict(range=[-0.5, n_units + 0.5]),
    )
    return fig
class UnitsAndTrialsControllerWidget(widgets.VBox):
    """Unit, grouping-column, alignment and time-range controls.

    Exposes ``controls`` (widgets keyed by plotting-function argument name)
    and ``fixed`` (constant arguments) for use with ``interactive_output``.
    """

    # Not referenced within this class.
    InnerWidget = None

    def __init__(
        self,
        units: Units,
        trials: TimeIntervals = None,
        unit_index=0,
        **kwargs
    ):
        """
        Creates a UnitsAndTrials controller that controls InnerWidget.

        Parameters
        ----------
        units: pynwb.misc.Units object
        trials: pynwb.epoch.TimeIntervals object
            Defaults to the trials table of the NWBFile containing ``units``.
            Only a message is shown if there is none, and in that case
            ``controls``/``fixed`` are not defined.
        unit_index: int
            Initially selected unit.
        **kwargs
            Stored on ``self.kwargs``.
        """
        super().__init__()
        self.units = units
        self.kwargs = kwargs

        # Check if there is trials table and create controller
        if trials is None:
            self.trials = self.get_trials()
            if self.trials is None:
                self.children = [widgets.HTML("No trials present")]
                return
        else:
            self.trials = trials

        # Create variables choice dropdowns
        groups = self.get_groups(self.trials)
        self.rows_controller = widgets.Dropdown(
            options=[None] + list(groups),
            description="rows",
            value=None,
        )
        self.rows_controller.observe(self.rows_callback, names="value")
        # Columns stay disabled until a rows variable is chosen.
        self.cols_controller = widgets.Dropdown(
            options=[None] + list(groups),
            description="cols",
            disabled=True,
        )

        # Unit controller
        unit_ids = self.units.id.data[:]
        n_units = len(unit_ids)
        self.unit_controller = widgets.Dropdown(
            options=[(str(unit_ids[x]), x) for x in range(n_units)],
            value=unit_index,
            description="unit",
        )

        # Trial event controller (align by)
        self.trial_event_controller = make_trial_event_controller(self.trials)

        # Start / End controllers
        self.start_ft = widgets.FloatText(
            -0.5,
            step=0.1,
            description="start (s)",
            layout=Layout(width="200px"),
            description_tooltip="Start time for calculation before or after (negative or positive) the reference point (aligned to)",
        )
        self.end_ft = widgets.FloatText(
            1.0,
            step=0.1,
            description="end (s)",
            layout=Layout(width="200px"),
            description_tooltip="End time for calculation before or after (negative or positive) the reference point (aligned to).",
        )

        self.fixed = dict(
            units=self.units,
            time_intervals=self.trials,
        )
        self.controls = {
            "index": self.unit_controller,
            "start": self.start_ft,
            "end": self.end_ft,
            "align_by": self.trial_event_controller,
            "rows_label": self.rows_controller,
            "cols_label": self.cols_controller,
        }
        self.children = [
            self.unit_controller,
            self.rows_controller,
            self.cols_controller,
            self.trial_event_controller,
            self.start_ft,
            self.end_ft,
        ]

    def get_trials(self):
        """Return the trials table of the NWBFile containing the units.

        Returns None when the units are not inside an NWBFile or the file has
        no trials table.
        """
        nwbfile = self.units.get_ancestor("NWBFile")
        if nwbfile is None:
            return None
        return nwbfile.trials

    def get_groups(self, trials):
        """Return the categorical columns of ``trials``.

        Iterating the result yields the candidate grouping column names.
        """
        return infer_categorical_columns(trials)

    def rows_callback(self, change):
        """
        Gets triggered when self.rows_controller changes. Enables the cols
        dropdown when a rows variable is selected; otherwise disables and
        resets it.
        """
        if change["new"] is None:
            self.cols_controller.disabled = True
            self.cols_controller.value = None
        else:
            self.cols_controller.disabled = False
class RasterGridWidget(widgets.VBox):
    """Interactive ``raster_grid`` for one unit, driven by a
    UnitsAndTrialsControllerWidget.

    When the controller is created here it is shown above the figure. A
    controller supplied by the caller is not displayed by this widget.
    """

    def __init__(
        self,
        units: Units,
        trials: TimeIntervals = None,
        unit_index=0,
        units_trials_controller=None,
    ):
        super().__init__()
        # Create Units and Trials controller
        if not units_trials_controller:
            units_trials_controller = UnitsAndTrialsControllerWidget(
                units=units,
                trials=trials,
                unit_index=unit_index,
            )
            self.children = [units_trials_controller]
        # Without a trials table the controller only shows a message and never
        # defines `controls`/`fixed`, so there is nothing to plot.
        if not hasattr(units_trials_controller, "controls"):
            return
        self.fig = interactive_output(
            f=raster_grid,
            controls=units_trials_controller.controls,
            fixed=units_trials_controller.fixed,
        )
        self.children += (self.fig,)
class TuningCurveWidget(widgets.VBox):
    """Interactive ``draw_tuning_curve`` for one unit, driven by a
    UnitsAndTrialsControllerWidget.

    When the controller is created here it is shown above the figure. A
    controller supplied by the caller is not displayed by this widget.
    """

    def __init__(
        self,
        units: Units,
        trials: TimeIntervals = None,
        unit_index=0,
        units_trials_controller=None,
    ):
        super().__init__()
        self.children = []
        # Create Units and Trials controller
        if not units_trials_controller:
            units_trials_controller = UnitsAndTrialsControllerWidget(
                units=units,
                trials=trials,
                unit_index=unit_index,
            )
            self.children = [units_trials_controller]
        # Without a trials table the controller only shows a message and never
        # defines `controls`/`fixed`, so there is nothing to plot.
        if not hasattr(units_trials_controller, "controls"):
            return
        self.fig = interactive_output(
            f=draw_tuning_curve,
            controls=units_trials_controller.controls,
            fixed=units_trials_controller.fixed,
        )
        self.children += (self.fig,)
class TuningCurveExtendedWidget(widgets.VBox):
    """Tuning curve and raster grid driven by one shared set of controls."""

    def __init__(
        self,
        units: Units,
        trials: TimeIntervals = None,
        unit_index=0
    ):
        super().__init__()
        common = dict(units=units, trials=trials, unit_index=unit_index)
        # A single controller instance is shared so both views follow the same
        # unit / variable / time selection.
        self.units_trials_controller = UnitsAndTrialsControllerWidget(**common)
        self.tuning_curve = TuningCurveWidget(
            units_trials_controller=self.units_trials_controller, **common
        )
        self.raster_grid = RasterGridWidget(
            units_trials_controller=self.units_trials_controller, **common
        )
        self.children = [
            self.units_trials_controller,
            self.tuning_curve,
            self.raster_grid,
        ]
def draw_tuning_curve(
    units: Units,
    time_intervals: TimeIntervals,
    index,
    start,
    end,
    rows_label=None,
    cols_label=None,
    align_by="start_time",
) -> plt.Figure:
    """Draw a 1-D or 2-D tuning curve depending on which variables are set.

    Returns an HTML prompt when no variable is selected, the 1-D curve when
    only ``rows_label`` is set, and the 2-D map when both are set.
    """
    if rows_label is None:
        return widgets.HTML("Select at least one variable")
    shared = dict(
        units=units,
        time_intervals=time_intervals,
        index=index,
        start=start,
        end=end,
        rows_label=rows_label,
        align_by=align_by,
    )
    if cols_label is not None:
        return draw_tuning_curve_2d(cols_label=cols_label, **shared)
    # Only one variable selected: 1-D tuning curve.
    return draw_tuning_curve_1d(**shared)
def draw_tuning_curve_1d(
    units: Units,
    time_intervals: TimeIntervals,
    index,
    start,
    end,
    rows_label=None,
    align_by="start_time",
) -> plt.Figure:
    """Bar plot of a unit's mean firing rate for each value of one trial column.

    The rate for a value is the total spike count in ``[start, end]``
    (relative to ``align_by``), divided by the number of trials with that
    value and by the window duration. Values are ordered with
    ``sort_mixed_type_list``.

    Parameters
    ----------
    units: pynwb.misc.Units
    time_intervals: pynwb.epoch.TimeIntervals
    index: int
        Unit index.
    start, end: float
        Window relative to the alignment point, in seconds.
    rows_label: str
        Trial column to split by.
    align_by: str
        Trial column to align on.

    Returns
    -------
    plt.Figure
    """
    rows_data, var1_classes = extract_data_from_intervals(time_intervals[rows_label])
    duration = end - start
    avg_rates = []
    for v1 in var1_classes:
        indexes = [i for i, d in enumerate(rows_data) if d == v1]
        data = align_by_time_intervals(
            units=units,
            index=index,
            intervals=time_intervals,
            start_label=align_by,
            stop_label=align_by,
            start=start,
            end=end,
            rows_select=indexes,
        )
        n_trials = len(data)
        if n_trials == 0:
            # No matching trials: np.hstack would fail and the rate is undefined.
            avg_rates.append(np.nan)
            continue
        n_spikes = len(np.hstack(data))
        avg_rates.append(n_spikes / (n_trials * duration))

    x = np.arange(len(var1_classes))  # the label locations
    si = sort_mixed_type_list(var1_classes)
    avg_rates_sorted = [avg_rates[i] for i in si]

    fig, ax = plt.subplots(figsize=(14, 7))
    width = 0.95  # the width of the bars
    ax.bar(x, avg_rates_sorted, width)
    ax.plot(x, avg_rates_sorted, "-o", color="k", lw=2)
    # Labels
    ax.set_ylabel("Avg rate")
    ax.set_xlabel(rows_label)
    ax.set_xticks(x)
    ax.set_xticklabels([var1_classes[i] for i in si], rotation=45)
    fig.tight_layout()
    # Previously the figure was built but never returned.
    return fig
def draw_tuning_curve_2d(
    units: Units,
    time_intervals: TimeIntervals,
    index,
    start,
    end,
    rows_label=None,
    cols_label=None,
    align_by="start_time",
) -> plt.Figure:
    """Heat map of a unit's mean firing rate over two trial columns.

    Cell (i, j) is the total spike count in ``[start, end]`` (relative to
    ``align_by``) over trials whose ``rows_label`` equals the i-th value and
    whose ``cols_label`` equals the j-th value. It is divided by the number of
    those trials and by the window duration. Cells with no trials stay 0.

    Parameters
    ----------
    units: pynwb.misc.Units
    time_intervals: pynwb.epoch.TimeIntervals
    index: int
        Unit index.
    start, end: float
        Window relative to the alignment point, in seconds.
    rows_label, cols_label: str
        Trial columns for the x and y axes of the map.
    align_by: str
        Trial column to align on.

    Returns
    -------
    plt.Figure
    """
    rows_data, var1_classes = extract_data_from_intervals(time_intervals[rows_label])
    cols_data, var2_classes = extract_data_from_intervals(time_intervals[cols_label])
    duration = end - start
    avg_rates = np.zeros((len(var1_classes), len(var2_classes)))
    # Trial-index sets per column value, computed once instead of per cell.
    cols_sets = [
        {ii for ii, d in enumerate(cols_data) if d == v2} for v2 in var2_classes
    ]
    for i, v1 in enumerate(var1_classes):
        rows_set = {ii for ii, d in enumerate(rows_data) if d == v1}
        for j, cols_set in enumerate(cols_sets):
            intersect = sorted(rows_set & cols_set)
            if not intersect:
                continue
            data = align_by_time_intervals(
                units=units,
                index=index,
                intervals=time_intervals,
                start_label=align_by,
                stop_label=align_by,
                start=start,
                end=end,
                rows_select=intersect,
            )
            n_trials = len(data)
            if n_trials == 0:
                continue
            n_spikes = len(np.hstack(data))
            avg_rates[i, j] = n_spikes / (n_trials * duration)

    fig, ax = plt.subplots(figsize=(14, 7))
    pos = ax.imshow(avg_rates.T, origin="lower", cmap="Greys")
    cbar = fig.colorbar(pos, ax=ax)
    cbar.set_label("spikes / second")
    # Labels
    ax.set_xticks(np.arange(len(var1_classes)))
    ax.set_yticks(np.arange(len(var2_classes)))
    ax.set_xlabel(rows_label)
    ax.set_ylabel(cols_label)
    ax.set_xticklabels(var1_classes, rotation=45)
    ax.set_yticklabels(var2_classes, rotation=45)
    return fig
def sort_mixed_type_list(x):
    """Return the indexes that sort a list of mixed types.

    Items convertible with ``float()`` come first, in numeric order. All other
    items follow, ordered by their ``str()`` form. Both sorts are stable, so
    equal values keep their original relative order.

    Parameters
    ----------
    x: iterable

    Returns
    -------
    list of int
        Positions into ``x``.
    """
    x_num = list()
    x_num_i = list()
    x_oth = list()
    x_oth_i = list()
    for i, xx in enumerate(x):
        try:
            value = float(xx)
        except (TypeError, ValueError, OverflowError):
            # Not numeric (e.g. a non-numeric string or None): sort as text.
            x_oth.append(str(xx))
            x_oth_i.append(i)
        else:
            x_num.append(value)
            x_num_i.append(i)
    x_num_si = np.argsort(x_num, kind="stable")
    x_oth_si = np.argsort(x_oth, kind="stable")
    return [x_num_i[ii] for ii in x_num_si] + [x_oth_i[ii] for ii in x_oth_si]