Source code for nwbwidgets.misc
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import scipy
from ipywidgets import FloatProgress, Layout, fixed, widgets
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from pynwb.epoch import TimeIntervals
from pynwb.misc import AnnotationSeries, DecompositionSeries, Units
from .analysis.spikes import compute_smoothed_firing_rate
from .controllers import (
GroupAndSortController,
ProgressBar,
StartAndDurationController,
make_trial_event_controller,
)
from .utils.dynamictable import extract_data_from_intervals, infer_categorical_columns
from .utils.mpl import create_big_ax
from .utils.plotly import event_group
from .utils.units import (
align_by_time_intervals,
get_max_spike_time,
get_min_spike_time,
get_spike_times,
get_unobserved_intervals,
)
from .utils.widgets import clean_axes, interactive_output
# Colors of matplotlib's active property cycle; the plotting functions below index
# into this list (modulo its length) to pick one color per group.
color_wheel = plt.rcParams["axes.prop_cycle"].by_key()["color"]
def show_annotations(annotations: AnnotationSeries, **kwargs):
    """
    Draw the timestamps of an AnnotationSeries as an event plot.

    Parameters
    ----------
    annotations: pynwb.misc.AnnotationSeries
    kwargs
        Forwarded unchanged to ``Axes.eventplot``.

    Returns
    -------
    matplotlib.pyplot.Figure
    """
    figure, axes = plt.subplots()
    axes.eventplot(annotations.timestamps, **kwargs)
    axes.set_xlabel("time (s)")
    return figure
def show_session_raster(
    units: Units,
    time_window=None,
    units_window=None,
    show_obs_intervals=True,
    order=None,
    group_inds=None,
    labels=None,
    show_legend=True,
    progress_bar=None,
):
    """
    Plot a raster of the spike times of the selected units.

    Parameters
    ----------
    units: pynwb.misc.Units
    time_window: [int, int], optional
        Defaults to [min spike time, max spike time] of `units`.
    units_window: [int, int], optional
        Only the first element is used, as the row offset of the plot.
        Defaults to [0, len(units)].
    show_obs_intervals: bool
        If True, unobserved intervals are fetched and drawn on the raster.
    order: array-like, optional
        Indices of the units to plot, in display order. Defaults to all units.
    group_inds: array-like, optional
    labels: array-like, optional
    show_legend: bool
        default = True
    progress_bar: optional
        If truthy, spike reading is wrapped in a ProgressBar.

    Returns
    -------
    matplotlib Axes returned by plot_grouped_events
    """
    if time_window is None:
        time_window = [get_min_spike_time(units), get_max_spike_time(units)]
    if units_window is None:
        units_window = [0, len(units)]
    if order is None:
        order = np.arange(len(units), dtype="int")

    if progress_bar:
        this_iter = ProgressBar(order, desc="reading spike data", leave=False)
        progress_bar = this_iter.container
    else:
        this_iter = order

    data = [get_spike_times(units, unit, time_window) for unit in this_iter]

    if show_obs_intervals:
        unobserved_intervals_list = get_unobserved_intervals(units, time_window, order)
    else:
        unobserved_intervals_list = None

    ax = plot_grouped_events(
        data,
        time_window,
        group_inds=group_inds,
        labels=labels,
        show_legend=show_legend,
        offset=units_window[0],
        unobserved_intervals_list=unobserved_intervals_list,
        progress_bar=progress_bar,
    )
    ax.set_ylabel("unit #")

    if len(data) <= 30:
        # Index with `order` itself rather than re-iterating `this_iter`: when a
        # progress bar is in use, `this_iter` has already been consumed above.
        unit_id_display = np.array(units.id.data[:])[list(order)]
        ax.set_yticklabels(unit_id_display)
    else:
        ax.axes.yaxis.set_visible(False)
    return ax
class RasterWidget(widgets.HBox):
    """
    Session raster driven by a time-window controller and a group-and-sort
    controller. Either controller can be supplied by the caller ("foreign"),
    in which case it is used but not added to this widget's children.
    """

    def __init__(
        self,
        units: Units,
        foreign_time_window_controller: StartAndDurationController = None,
        foreign_group_and_sort_controller: GroupAndSortController = None,
        group_by=None,
    ):
        super().__init__()
        self.units = units

        # Time window: reuse the caller's controller or build one spanning all spikes.
        if foreign_time_window_controller is not None:
            self.time_window_controller = foreign_time_window_controller
        else:
            self.tmin = get_min_spike_time(units)
            self.tmax = get_max_spike_time(units)
            self.time_window_controller = StartAndDurationController(tmin=self.tmin, tmax=self.tmax)

        # Grouping / sorting: reuse the caller's controller or build a local one.
        if not foreign_group_and_sort_controller:
            self.gas = self.make_group_and_sort(group_by=group_by, control_order=False)
        else:
            self.gas = foreign_group_and_sort_controller

        self.progress_bar = widgets.HBox()
        self.controls = {"time_window": self.time_window_controller, "gas": self.gas}

        out_fig = interactive_output(
            partial(show_session_raster, units=self.units, progress_bar=self.progress_bar),
            self.controls,
        )

        # The time-window controller is only displayed here when it is our own.
        panel_children = [self.progress_bar, out_fig]
        if not foreign_time_window_controller:
            panel_children = [self.time_window_controller] + panel_children
        right_panel = widgets.VBox(children=panel_children, layout=Layout(width="100%"))

        # Likewise for the group-and-sort controller.
        if foreign_group_and_sort_controller:
            self.children = [right_panel]
        else:
            self.children = [self.gas, right_panel]

        self.layout = Layout(width="100%")

    def make_group_and_sort(self, group_by=None, control_order=True):
        """Build the GroupAndSortController for this widget's units."""
        return GroupAndSortController(self.units, group_by=group_by, control_order=control_order)
def show_decomposition_series(node, **kwargs):
    """
    Build a two-tab widget for `node`: a "Fields" tab listing ``node.fields``
    and a "Traces" tab that is rendered lazily the first time it is selected.
    """
    # "Rendering..." placeholders, one per tab.
    children = [widgets.HTML("Rendering...") for _ in range(2)]

    def on_selected_index(change):
        tab = change.owner
        # Only act on the Traces tab, and only while it still holds the placeholder.
        if change.new != 1 or not isinstance(tab.children[1], widgets.HTML):
            return
        children[1] = show_decomposition_traces(node)
        tab.children = children

    field_lay = widgets.Layout(max_height="40px", max_width="500px", min_height="30px", min_width="130px")
    field_rows = [
        widgets.HBox(
            children=[
                widgets.Label(key + ":", layout=field_lay),
                widgets.Label(str(val), layout=field_lay),
            ]
        )
        for key, val in node.fields.items()
    ]
    children[0] = widgets.VBox(field_rows)

    tab_nest = widgets.Tab()
    tab_nest.children = children
    for tab_index, tab_title in enumerate(("Fields", "Traces")):
        tab_nest.set_title(tab_index, tab_title)
    tab_nest.observe(on_selected_index, names="selected_index")
    return tab_nest
def show_decomposition_traces(node: DecompositionSeries):
    """
    Build an interactive view of a DecompositionSeries: one subplot per band,
    showing the selected channels as vertically offset traces.

    Parameters
    ----------
    node: pynwb.misc.DecompositionSeries
        ``node.data`` is indexed as [sample, channel, band]; ``node.rate`` is
        used to compute the bounds of the time controls.

    Returns
    -------
    widgets.VBox
        Controls row on top, figure output below.
    """
    nSamples = node.data.shape[0]
    nChannels = node.data.shape[1]
    nBands = node.data.shape[2]
    fs = node.rate

    # Produce figure
    def control_plot(x0, x1, ch0, ch1):
        # squeeze=False keeps a 2-D array of axes even when there is a single
        # band; otherwise plt.subplots returns a bare Axes and indexing fails.
        fig, axs = plt.subplots(nrows=nBands, ncols=1, sharex=True, figsize=(14, 7), squeeze=False)
        ax = axs[:, 0]
        n_selected = ch1 + 1 - ch0
        for bd in range(nBands):
            # NOTE(review): x0/x1 are labelled and bounded in milliseconds but are
            # used here directly as sample indices - confirm the intended unit.
            data = node.data[x0:x1, ch0 : ch1 + 1, bd]
            # Build x from the rows actually read, so a window reaching past the
            # end of the data cannot produce an x/y length mismatch.
            xx = np.arange(x0, x0 + data.shape[0])
            mu_array = np.mean(data, 0)
            sd_array = np.std(data, 0)
            # Vertical spacing between channels: 5x the mean channel std.
            offset = np.mean(sd_array) * 5
            yticks = [i * offset for i in range(n_selected)]
            for i in range(n_selected):
                ax[bd].plot(xx, data[:, i] - mu_array[i] + yticks[i])
            ax[bd].set_ylabel("Ch #", fontsize=20)
            ax[bd].set_yticks(yticks)
            ax[bd].set_yticklabels([str(i) for i in range(ch0, ch1 + 1)])
            ax[bd].tick_params(axis="both", which="major", labelsize=16)
        # x axis is shared, so only the last (bottom) subplot is labelled.
        ax[bd].set_xlabel("Time [ms]", fontsize=20)
        return fig

    # Controls
    field_lay = widgets.Layout(max_height="40px", max_width="100px", min_height="30px", min_width="70px")
    x0 = widgets.BoundedIntText(value=0, min=0, max=int(1000 * nSamples / fs - 100), layout=field_lay)
    x1 = widgets.BoundedIntText(value=nSamples, min=100, max=int(1000 * nSamples / fs), layout=field_lay)
    ch0 = widgets.BoundedIntText(value=0, min=0, max=int(nChannels - 1), layout=field_lay)
    ch1 = widgets.BoundedIntText(value=10, min=0, max=int(nChannels - 1), layout=field_lay)
    controls = {"x0": x0, "x1": x1, "ch0": ch0, "ch1": ch1}
    out_fig = widgets.interactive_output(control_plot, controls)

    # Assemble layout box
    lbl_x = widgets.Label("Time [ms]:", layout=field_lay)
    lbl_ch = widgets.Label("Ch #:", layout=field_lay)
    lbl_blank = widgets.Label(" ", layout=field_lay)
    hbox0 = widgets.HBox(children=[lbl_x, x0, x1, lbl_blank, lbl_ch, ch0, ch1])
    vbox = widgets.VBox(children=[hbox0, out_fig])
    return vbox
class PSTHWidget(widgets.VBox):
    """
    Peri-stimulus view of one unit: a trial raster on top and either a
    histogram or a gaussian-smoothed firing rate below, aligned to one or more
    trial event columns.
    """

    def __init__(
        self,
        input_data: Units,
        trials: TimeIntervals = None,
        unit_index=0,
        unit_controller=None,
        ntt=1000,
    ):
        """
        Parameters
        ----------
        input_data: pynwb.misc.Units
        trials: pynwb.epoch.TimeIntervals, optional
            If None, ``self.get_trials()`` is used to look the trials up.
        unit_index: int
            Initially selected unit (only used when no unit_controller is given).
        unit_controller: optional
            External widget whose value is the unit index.
        ntt: int
            Number of time points used for the smoothed curve.
        """
        self.units = input_data

        super().__init__()

        if trials is None:
            # NOTE(review): get_trials() is not defined in the class body visible
            # here; presumably supplied by a subclass - confirm.
            self.trials = self.get_trials()
            if self.trials is None:
                self.children = [widgets.HTML("No trials present")]
                return
        else:
            self.trials = trials

        # update() reads the unit ids for the figure title, so they must be
        # available even when an external unit_controller is supplied.
        self.unit_ids = self.units.id.data[:]

        if unit_controller is None:
            n_units = len(self.unit_ids)
            self.unit_controller = widgets.Dropdown(
                options=[(str(self.unit_ids[x]), x) for x in range(n_units)],
                value=unit_index,
                description="unit",
                layout=Layout(width="200px"),
            )
        else:
            self.unit_controller = unit_controller

        self.trial_event_controller = make_trial_event_controller(
            self.trials, layout=Layout(width="200px"), multiple=True
        )

        self.start_ft = widgets.FloatText(
            -0.5,
            step=0.1,
            description="start (s)",
            layout=Layout(width="200px"),
            description_tooltip="Start time for calculation before or after (negative or positive) the reference point (aligned to)",
        )
        self.end_ft = widgets.FloatText(
            1.0,
            step=0.1,
            description="end (s)",
            layout=Layout(width="200px"),
            description_tooltip="End time for calculation before or after (negative or positive) the reference point (aligned to).",
        )

        self.psth_type_radio = widgets.RadioButtons(options=["histogram", "gaussian"], layout=Layout(width="100px"))
        self.bins_ft = widgets.IntText(30, min=0, description="# bins", layout=Layout(width="150px"))
        self.gaussian_sd_ft = widgets.FloatText(
            0.05,
            min=0.001,
            description="sd (s)",
            layout=Layout(width="150px"),
            active=False,
            step=0.01,
        )

        self.gas = self.make_group_and_sort(window=False, control_order=False)

        self.controls = dict(
            ntt=fixed(ntt),
            index=self.unit_controller,
            end=self.end_ft,
            start=self.start_ft,
            start_labels=self.trial_event_controller,
            gas=self.gas,
            plot_type=self.psth_type_radio,
            sigma_in_secs=self.gaussian_sd_ft,
            nbins=self.bins_ft,
        )

        out_fig = interactive_output(self.update, self.controls)

        self.children = [
            widgets.HBox(
                [
                    widgets.VBox(
                        [
                            self.gas,
                            widgets.HBox(
                                [
                                    self.psth_type_radio,
                                    widgets.VBox([self.gaussian_sd_ft, self.bins_ft]),
                                ]
                            ),
                        ]
                    ),
                    widgets.VBox(
                        [
                            self.unit_controller,
                            self.trial_event_controller,
                            self.start_ft,
                            self.end_ft,
                        ]
                    ),
                ]
            ),
            out_fig,
        ]

    def make_group_and_sort(self, window=None, control_order=False):
        """Build the GroupAndSortController operating on this widget's trials."""
        return GroupAndSortController(self.trials, window=window, control_order=control_order)

    def update(
        self,
        index: int,
        start_labels: tuple = ("start_time",),
        start: float = 0.0,
        end: float = 1.0,
        order=None,
        group_inds=None,
        labels=None,
        sigma_in_secs=0.05,
        ntt: int = 1000,
        progress_bar=None,
        figsize=(12, 7),
        nbins=30,
        plot_type="histogram",
        align_line_color=(0.7, 0.7, 0.7),
    ):
        """
        Draw the PSTH figure: one column of (raster, rate) axes per alignment label.

        Parameters
        ----------
        index: int
            Index of unit
        start_labels: tuple of str, optional
            Trial column names to align on; one subplot column is drawn per label.
        start: float
            Start time for calculation before or after (negative or positive) the reference point (aligned to).
        end: float
            End time for calculation before or after (negative or positive) the reference point (aligned to).
        order
        group_inds
        labels
        sigma_in_secs: float, optional
            standard deviation of gaussian kernel
        ntt:
            Number of time points to use for smooth curve
        progress_bar:
        figsize: tuple, optional
        nbins: int, optional
            Number of histogram bins (plot_type "histogram").
        plot_type: str, optional
            "histogram" or "gaussian".
        align_line_color: tuple, optional
            Color of the vertical line at the alignment time.

        Returns
        -------
        matplotlib.Figure

        Raises
        ------
        ValueError
            If `plot_type` is neither "histogram" nor "gaussian".
        """
        n_cols = len(start_labels)
        # squeeze=False gives a 2-D axes array for any number of columns.
        fig, axs = plt.subplots(2, n_cols, figsize=figsize, sharex=True, squeeze=False)
        clean_axes(axs.ravel())

        ax1_ylims = []
        for i_s, start_label in enumerate(start_labels):
            ax0 = axs[0, i_s]
            ax1 = axs[1, i_s]

            data = align_by_time_intervals(
                self.units,
                index,
                self.trials,
                start_label,
                start_label,
                start,
                end,
                order,
                progress_bar=progress_bar,
            )

            # The legend is only drawn on the last (right-most) raster.
            show_legend = i_s == n_cols - 1

            show_psth_raster(
                data,
                start,
                end,
                group_inds,
                labels,
                show_legend=show_legend,
                ax=ax0,
                progress_bar=progress_bar,
                fontsize=12,
            )

            ax0.set_title(f"{start_label}")
            ax0.set_xticks([])
            ax0.set_xlabel("")
            if i_s > 0:
                ax0.set_ylabel("")
                # Raster always show the same number of trials. We can avoid showing tick labels too
                ax0.set_yticklabels([])

            if plot_type == "gaussian":
                # Show the gaussian sd control and collapse the bins control.
                self.bins_ft.layout.visibility = "hidden"
                self.bins_ft.layout.height = "0px"
                self.gaussian_sd_ft.layout.visibility = None
                self.gaussian_sd_ft.layout.height = None

                # expanded data so that gaussian smoother uses larger window than is viewed
                expanded_data = align_by_time_intervals(
                    units=self.units,
                    index=index,
                    intervals=self.trials,
                    start_label=start_label,
                    stop_label=start_label,
                    start=start - sigma_in_secs * 4,
                    end=end + sigma_in_secs * 4,
                    rows_select=order,
                    progress_bar=progress_bar,
                )
                show_psth_smoothed(
                    data=expanded_data,
                    ax=ax1,
                    start=start - sigma_in_secs * 4,
                    end=end + sigma_in_secs * 4,
                    group_inds=group_inds,
                    sigma_in_secs=sigma_in_secs,
                    ntt=ntt,
                )
            elif plot_type == "histogram":
                # Show the bins control and collapse the gaussian sd control.
                self.gaussian_sd_ft.layout.visibility = "hidden"
                self.gaussian_sd_ft.layout.height = "0px"
                self.bins_ft.layout.visibility = None
                self.bins_ft.layout.height = None
                show_histogram(data, ax1, start, end, group_inds, nbins=nbins)
            else:
                # Report the argument actually received, not the widget state.
                raise ValueError("unsupported plot type {}".format(plot_type))

            ax1.set_xlim([start, end])
            if i_s == 0:
                ax1.set_ylabel("firing rate (Hz)", fontsize=12)
            ax1.set_xlabel("time (s)", fontsize=12)
            ax1.axvline(color=align_line_color)
            ax1_ylims.append(ax1.get_ylim())

        if n_cols > 1:
            # Adjust bottom axes y axis
            ylims = np.array(ax1_ylims)
            min_y = np.min(ylims[:, 0])
            max_y = np.max(ylims[:, 1])
            for i_b, ax_btm in enumerate(axs[1, :]):
                ax_btm.set_ylim(min_y, max_y)
                if i_b > 0:
                    ax_btm.set_ylabel("")
                    # After adjusting ylims we can avoid showing tick labels
                    ax_btm.set_yticklabels([])

        fig.suptitle(f"Unit {self.unit_ids[index]}", fontsize=15)
        fig.subplots_adjust(wspace=0.3)
        return fig
def show_histogram(data, ax: plt.Axes, start: float, end: float, group_inds=None, nbins: int = 30):
    """
    Draw a trial-averaged firing-rate histogram on `ax`.

    Parameters
    ----------
    data: array-like
        One sequence of spike times per trial.
    ax: plt.Axes
    start: float
        Lower edge of the histogram range.
    end: float
        Upper edge of the histogram range.
    group_inds: array-like, optional
        Group index of each trial. If given, one semi-transparent histogram is
        drawn per group, colored from `color_wheel`.
    nbins: int, optional

    Returns
    -------
    None
    """
    if not len(data):
        return

    if group_inds is None:
        height, x = np.histogram(np.hstack(data), bins=nbins, range=(start, end))
        width = np.diff(x[:2])
        # counts -> rate: divide by number of trials and by bin width
        height = height / len(data) / width
        ax.bar(x[:-1], height, edgecolor=(0.3, 0.3, 0.3), width=width, align="edge")
        return

    data = np.asarray(data, dtype="object")
    # Needed so that `group_inds == group` is an elementwise mask even when a
    # plain list is passed (list == scalar is a single bool).
    group_inds = np.asarray(group_inds)
    for group in np.unique(group_inds):
        in_group = group_inds == group
        height, x = np.histogram(np.hstack(data[in_group]), bins=nbins, range=(start, end))
        width = np.diff(x[:2])
        height = height / np.sum(in_group) / width
        ax.bar(
            x[:-1],
            height,
            color=color_wheel[group % len(color_wheel)],
            edgecolor=(0.3, 0.3, 0.3),
            width=width,
            align="edge",
            alpha=0.6,
        )
def show_psth_smoothed(
    data,
    ax,
    start: float,
    end: float,
    group_inds=None,
    sigma_in_secs: float = 0.05,
    ntt: int = 1000,
):
    """
    Plot the smoothed firing rate per group as mean +/- 2 standard errors.

    Parameters
    ----------
    data: array-like
        One sequence of spike times per trial.
    ax: plt.Axes
    start: float
        First time point of the evaluation grid.
    end: float
        Last time point of the evaluation grid.
    group_inds: array-like, optional
        Group index of each trial. If None, all trials form one group (0).
    sigma_in_secs: float, optional
        Passed to compute_smoothed_firing_rate.
    ntt: int, optional
        Number of time points in the evaluation grid.

    Returns
    -------
    None
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.stats` is loaded on older SciPy versions.
    from scipy import stats as scipy_stats

    if not len(data):  # TODO: when does this occur?
        return
    if not len(np.hstack(data)):  # no spikes
        return

    tt = np.linspace(start, end, ntt)
    smoothed = np.array([compute_smoothed_firing_rate(x, tt, sigma_in_secs) for x in data])

    if group_inds is None:
        group_inds = np.zeros((len(smoothed)), dtype=int)
    else:
        # Ensure `group_inds == group` is an elementwise mask for list input too.
        group_inds = np.asarray(group_inds)

    for group in np.unique(group_inds):
        group_rates = smoothed[group_inds == group]
        this_mean = np.mean(group_rates, axis=0)
        err = scipy_stats.sem(group_rates, axis=0)
        color = color_wheel[group % len(color_wheel)]
        ax.plot(tt, this_mean, color=color)
        ax.fill_between(tt, this_mean - 2 * err, this_mean + 2 * err, alpha=0.2, color=color)
def plot_grouped_events(
    data,
    window,
    group_inds=None,
    colors=color_wheel,
    ax=None,
    labels=None,
    show_legend=True,
    offset=0,
    unobserved_intervals_list=None,
    progress_bar=None,
    figsize=(8, 6),
    fontsize=12,
):
    """
    Draw an event (raster) plot with one row per entry of `data`, optionally
    colored by group.

    Parameters
    ----------
    data: array-like
        One sequence of event times per row.
    window: array-like [float, float]
        Time in seconds
    group_inds: array-like dtype=int, optional
        Group index of each row. If None, all rows are drawn in black.
    colors: array-like, optional
        Indexed by group index modulo its length.
    ax: plt.Axes, optional
        If None, a new figure of size `figsize` is created.
    labels: array-like dtype=str, optional
        Legend labels, indexed by group index. If None, the group indices
        themselves are used as labels.
    show_legend: bool, optional
        Only has an effect when `group_inds` is given.
    offset: number, optional
        Added to the row positions and y limits.
    unobserved_intervals_list: array-like, optional
    progress_bar: FloatProgress, optional
        If not None, the per-group plotting loop is wrapped in a ProgressBar.
    figsize: tuple, optional
    fontsize: int, optional

    Returns
    -------
    plt.Axes
    """
    data = np.asarray(data, dtype="object")

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
        if hasattr(fig, "canvas"):
            fig.canvas.header_visible = False

    if group_inds is not None:
        # Elementwise masks below require an array, not a list.
        group_inds = np.asarray(group_inds)
        ugroup_inds = np.unique(group_inds)
        handles = []

        if progress_bar is not None:
            this_iter = ProgressBar(
                enumerate(ugroup_inds),
                desc="plotting spikes",
                leave=False,
                total=len(ugroup_inds),
            )
            progress_bar = this_iter.container
        else:
            this_iter = enumerate(ugroup_inds)

        for i, ui in this_iter:
            in_group = group_inds == ui
            event_collection = ax.eventplot(
                data[in_group],
                orientation="horizontal",
                lineoffsets=np.where(in_group)[0] + offset,
                color=colors[ui % len(colors)],
            )
            handles.append(event_collection[0])

        if show_legend:
            if labels is None:
                legend_labels = [str(ui) for ui in ugroup_inds]
            else:
                legend_labels = list(np.asarray(labels)[ugroup_inds])
            # bbox_to_anchor is passed exactly once; supplying it a second time
            # through an extra kwargs dict would raise a TypeError.
            ax.legend(
                handles=handles[::-1],
                labels=legend_labels[::-1],
                loc="upper left",
                bbox_to_anchor=(1.01, 1),
            )
    else:
        ax.eventplot(
            data,
            orientation="horizontal",
            color="k",
            lineoffsets=np.arange(len(data)) + offset,
        )

    if unobserved_intervals_list is not None:
        plot_unobserved_intervals(unobserved_intervals_list, ax, offset=offset)

    ax.set_xlim(window)
    ax.set_xlabel("time (s)", fontsize=fontsize)
    ax.set_ylim(np.array([-0.5, len(data) - 0.5]) + offset)
    # Per-row ticks are only drawn when there are few enough rows to read them.
    if len(data) <= 30:
        ax.set_yticks(range(offset, len(data) + offset))
        ax.set_yticklabels(range(offset, len(data) + offset))
    return ax
def plot_unobserved_intervals(unobserved_intervals_list, ax, offset=0, color=(0.85, 0.85, 0.85)):
    """
    Shade the unobserved intervals of each row on `ax`.

    For row number r, every interval (first element to second element) is
    drawn as a rectangle of height 1 whose lower edge is at r - 0.5 + offset.
    """
    for row_number, row_intervals in enumerate(unobserved_intervals_list):
        lower_edge = row_number - 0.5 + offset
        patches = []
        for interval in row_intervals:
            patches.append(Rectangle((interval[0], lower_edge), interval[1] - interval[0], 1))
        ax.add_collection(PatchCollection(patches, color=color))
def show_psth_raster(
    data,
    start=-0.5,
    end=2.0,
    group_inds=None,
    labels=None,
    ax=None,
    show_legend=True,
    align_line_color=(0.7, 0.7, 0.7),
    progress_bar: FloatProgress = None,
    fontsize=12,
) -> plt.Axes:
    """
    Raster of aligned trials with a vertical line marking the alignment time.

    Parameters
    ----------
    data: array-like
        One sequence of event times per trial. If empty, nothing is drawn.
    start: float
        Start time for calculation before or after (negative or positive) the reference point (aligned to).
    end: float
        End time for calculation before or after (negative or positive) the reference point (aligned to).
    group_inds: array-like, optional
    labels: array-like, optional
    ax: plt.Axes, optional
    show_legend: bool, optional
    align_line_color: array-like, optional
        [R, G, B] (0-1)
        Default = [0.7, 0.7, 0.7]
    progress_bar: FloatProgress, optional
    fontsize: int, optional

    Returns
    -------
    plt.Axes
        The axes drawn on; `ax` is returned untouched when `data` is empty.
    """
    if len(data) == 0:
        return ax

    raster_ax = plot_grouped_events(
        data,
        window=[start, end],
        group_inds=group_inds,
        colors=color_wheel,
        ax=ax,
        labels=labels,
        show_legend=show_legend,
        progress_bar=progress_bar,
        fontsize=fontsize,
    )
    raster_ax.set_ylabel("trials", fontsize=fontsize)
    raster_ax.axvline(color=align_line_color)
    return raster_ax
def raster_grid(
    units: Units,
    time_intervals: TimeIntervals,
    index,
    start,
    end,
    rows_label=None,
    cols_label=None,
    trials_select=None,
    align_by="start_time",
) -> plt.Figure:
    """
    Grid of trial rasters for one unit, one cell per combination of the unique
    values of the `rows_label` and `cols_label` columns.

    Parameters
    ----------
    units: pynwb.misc.Units
    time_intervals: pynwb.epoch.TimeIntervals
    index: int
    start: float
        Start time for calculation before or after (negative or positive) the reference point (aligned to).
    end: float
        End time for calculation before or after (negative or positive) the reference point (aligned to).
    rows_label: str, optional
    cols_label: str, optional
    trials_select: np.array(dtype=bool), optional
        Defaults to selecting every trial.
    align_by: str, optional

    Returns
    -------
    plt.Figure

    Raises
    ------
    ValueError
        If `time_intervals` is None.
    """
    if time_intervals is None:
        raise ValueError("trials must exist (trials cannot be None)")

    if trials_select is None:
        trials_select = np.ones((len(time_intervals),)).astype("bool")

    def _column_levels(label):
        # Returns (values for all trials, unique values among selected trials).
        # Without a label the grid has a single level, None.
        if label is None:
            return None, [None]
        values = np.array(time_intervals[label][:])
        levels = np.unique(values[trials_select])
        if levels.dtype == np.float64:
            levels = levels[~np.isnan(levels)]
        return values, levels

    row_vals, urow_vals = _column_levels(rows_label)
    col_vals, ucol_vals = _column_levels(cols_label)
    nrows = len(urow_vals)
    ncols = len(ucol_vals)

    fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, squeeze=False, figsize=(10, 10))
    big_ax = create_big_ax(fig)

    for i, row in enumerate(urow_vals):
        for j, col in enumerate(ucol_vals):
            ax = axs[i, j]

            cell_mask = trials_select.copy()
            if row is not None:
                cell_mask &= row_vals == row
            if col is not None:
                cell_mask &= col_vals == col
            cell_trials = np.where(cell_mask)[0]

            if not len(cell_trials):
                continue

            data = align_by_time_intervals(
                units,
                index,
                time_intervals,
                align_by,
                align_by,
                start,
                end,
                cell_trials,
            )
            show_psth_raster(data, start, end, ax=ax)

            # Only the outer cells keep a label: the level value of their row / column.
            ax.set_xlabel("")
            ax.set_ylabel("")
            subplot_spec = ax.get_subplotspec()
            if subplot_spec.is_first_col():
                ax.set_ylabel(row)
            if subplot_spec.is_last_row():
                ax.set_xlabel(col)

    big_ax.set_xlabel(cols_label, labelpad=50)
    big_ax.set_ylabel(rows_label, labelpad=60)
    return fig
def plot_grouped_events_plotly(
    data,
    window=None,
    group_inds=None,
    colors=color_wheel,
    labels=None,
    show_legend=True,
    unobserved_intervals_list=None,
    progress_bar=None,
    fig=None,
    **kwargs,
):
    """
    Add event traces for `data` to a plotly figure, one trace group per group index.

    Parameters
    ----------
    data: array-like
        One sequence of event times per row.
    group_inds: array-like, optional
        Group index of each row. Groups are stacked in increasing index order.
    colors: array-like, optional
        Indexed by group index modulo its length.
    labels: array-like, optional
        Indexed by group index; required when `group_inds` is given.
    fig: go.FigureWidget, optional
        Created if None.
    kwargs
        Forwarded to event_group.
    window, show_legend, unobserved_intervals_list, progress_bar
        Accepted but not used by this function.

    Returns
    -------
    go.FigureWidget
    """
    data = np.array(data, dtype=object)
    if fig is None:
        fig = go.FigureWidget()

    if group_inds is None:
        event_group(data, fig=fig, **kwargs)
    else:
        # Running row offset so that each group is drawn above the previous one.
        rows_drawn = 0
        for ui in np.unique(group_inds):
            group_rows = data[group_inds == ui]
            event_group(
                group_rows,
                offset=rows_drawn,
                label=labels[ui],
                color=colors[ui % len(colors)],
                fig=fig,
                **kwargs,
            )
            rows_drawn += len(group_rows)

    fig.update_layout(xaxis_title="time (s)")
    return fig
class RasterWidgetPlotly(widgets.HBox):
    """
    Plotly version of the session raster widget. The figure is redrawn whenever
    the time-window or group-and-sort controller changes; a checkbox toggles
    the legend.
    """

    def __init__(
        self,
        units: Units,
        foreign_time_window_controller: StartAndDurationController = None,
        foreign_group_and_sort_controller: GroupAndSortController = None,
        group_by=None,
        fig: go.FigureWidget = None,
    ):
        """
        Parameters
        ----------
        units: pynwb.misc.Units
        foreign_time_window_controller: StartAndDurationController, optional
            External time-window controller; if given it is used but not displayed here.
        foreign_group_and_sort_controller: GroupAndSortController, optional
            External group-and-sort controller; if given it is used but not displayed here.
        group_by: optional
            Passed to the locally created GroupAndSortController.
        fig: go.FigureWidget, optional
            Figure to draw into; created if None.
        """
        super().__init__()
        self.units = units

        if foreign_time_window_controller is None:
            self.tmin = get_min_spike_time(units)
            self.tmax = get_max_spike_time(units)
            self.time_window_controller = StartAndDurationController(tmin=self.tmin, tmax=self.tmax)
        else:
            self.time_window_controller = foreign_time_window_controller

        if foreign_group_and_sort_controller:
            self.gas = foreign_group_and_sort_controller
        else:
            self.gas = GroupAndSortController(dynamic_table=self.units, group_by=group_by)

        self.show_legend_cb = widgets.Checkbox(value=True, description="show legend")

        # Placeholder container, created as in RasterWidget. It is referenced in
        # the layout below, so it must exist before the children are assembled.
        self.progress_bar = widgets.HBox()

        if fig is None:
            self.fig = go.FigureWidget()
            self.fig.update_layout(margin=dict(l=20, r=20, t=30, b=20))
        else:
            self.fig = fig

        show_session_raster_plotly(self.units, self.fig, self.time_window_controller.value, **self.gas.value)

        # set children
        if foreign_time_window_controller:
            right_panel = widgets.VBox(
                children=[
                    self.progress_bar,
                    self.fig,
                ],
                layout=Layout(width="100%"),
            )
        else:
            right_panel = widgets.VBox(
                children=[self.time_window_controller, self.fig, self.show_legend_cb],
                layout=Layout(width="100%"),
            )

        if foreign_group_and_sort_controller:
            self.children = [right_panel]
        else:
            self.children = [self.gas, right_panel]

        self.layout = Layout(width="100%")

        self.time_window_controller.observe(self.update_fig, "value")
        self.gas.observe(self.update_fig, "value")
        self.show_legend_cb.observe(self.toggle_legend, "value")

    def toggle_legend(self, change):
        """Show or hide the legend according to the checkbox state."""
        self.fig.update_layout(showlegend=self.show_legend_cb.value)

    def update_fig(self, change):
        """Clear the figure's traces and redraw the raster from the current controller values."""
        time_window = self.time_window_controller.value
        gas_kwargs = self.gas.value

        # Batch the clear + redraw so the figure widget is updated once.
        with self.fig.batch_update():
            self.fig.data = None
            show_session_raster_plotly(self.units, self.fig, time_window, **gas_kwargs)
def show_session_raster_plotly(units: Units, fig, time_window=None, order=None, progress_bar=None, **kwargs):
    """
    Draw the spike raster of the selected units into a plotly figure.

    Parameters
    ----------
    units: pynwb.misc.Units
    fig: go.FigureWidget
        Figure to draw into.
    time_window: [int, int], optional
        Defaults to [min spike time, max spike time] of `units`.
    order: array-like, optional
        Indices of the units to plot. Defaults to all units.
    progress_bar: optional
        If truthy, spike reading is wrapped in a ProgressBar.
    kwargs
        Forwarded to plot_grouped_events_plotly (e.g. group_inds, labels),
        together with the marker / line-width settings chosen here.

    Returns
    -------
    go.FigureWidget
    """
    if time_window is None:
        time_window = [get_min_spike_time(units), get_max_spike_time(units)]
    if order is None:
        order = np.arange(len(units), dtype="int")

    if progress_bar:
        unit_iter = ProgressBar(order, desc="reading spike data", leave=False)
        progress_bar = unit_iter.container
    else:
        unit_iter = order

    spike_trains = [get_spike_times(units, unit, time_window) for unit in unit_iter]

    # Start from an axis without ticks; they are restored below for small unit counts.
    fig.update_yaxes(tickvals=[], ticktext=[])

    n_units = len(order)
    if n_units > 100:
        kwargs.update(line_width=1)
    else:
        kwargs.update(marker="line-ns", line_width=2)

    fig = plot_grouped_events_plotly(data=spike_trains, fig=fig, **kwargs)

    if n_units <= 40:
        fig.update_yaxes(tickvals=np.arange(n_units), ticktext=[str(i) for i in order])

    fig.update_layout(
        title="units",
        xaxis_title="time (s)",
        legend=dict(x=1.0, y=0, traceorder="reversed"),
        xaxis=dict(range=time_window),
        yaxis=dict(range=[-0.5, n_units + 0.5]),
    )
    return fig
class UnitsAndTrialsControllerWidget(widgets.VBox):
    """
    Control panel exposing a unit selector, row / column grouping variables,
    an alignment event and a start / end window. The widgets are published in
    ``self.controls`` and the data objects in ``self.fixed``.
    """

    InnerWidget = None

    def __init__(self, units: Units, trials: TimeIntervals = None, unit_index=0, **kwargs):
        """
        Creates a UnitsAndTrials controller that controls InnerWidget.

        Parameters
        ----------
        units: pynwb.misc.Units object
        trials: pynwb.epoch.TimeIntervals object
        unit_index: int
        """
        super().__init__()
        self.units = units
        self.kwargs = kwargs

        # Resolve the trials table; bail out with a message when there is none.
        # NOTE(review): get_trials() / get_groups() are not defined in the class
        # body visible here; presumably supplied by a subclass - confirm.
        if trials is not None:
            self.trials = trials
        else:
            self.trials = self.get_trials()
            if self.trials is None:
                self.children = [widgets.HTML("No trials present")]
                return

        # Grouping-variable dropdowns; "cols" stays disabled until "rows" is set.
        group_options = [None] + list(self.get_groups(self.trials))
        self.rows_controller = widgets.Dropdown(options=group_options, description="rows", value=None)
        self.rows_controller.observe(self.rows_callback, names="value")
        self.cols_controller = widgets.Dropdown(
            options=[None] + list(group_options[1:]),
            description="cols",
            disabled=True,
        )

        # Unit selector: shows the unit id, yields the unit's positional index.
        unit_ids = self.units.id.data[:]
        unit_options = [(str(unit_id), position) for position, unit_id in enumerate(unit_ids)]
        self.unit_controller = widgets.Dropdown(
            options=unit_options,
            value=unit_index,
            description="unit",
        )

        # Trial event controller (align by)
        self.trial_event_controller = make_trial_event_controller(self.trials)

        # Start / End controllers
        text_layout_width = "200px"
        self.start_ft = widgets.FloatText(
            -0.5,
            step=0.1,
            description="start (s)",
            layout=Layout(width=text_layout_width),
            description_tooltip="Start time for calculation before or after (negative or positive) the reference point (aligned to)",
        )
        self.end_ft = widgets.FloatText(
            1.0,
            step=0.1,
            description="end (s)",
            layout=Layout(width=text_layout_width),
            description_tooltip="End time for calculation before or after (negative or positive) the reference point (aligned to).",
        )

        self.fixed = {"units": self.units, "time_intervals": self.trials}
        self.controls = dict(
            index=self.unit_controller,
            start=self.start_ft,
            end=self.end_ft,
            align_by=self.trial_event_controller,
            rows_label=self.rows_controller,
            cols_label=self.cols_controller,
        )

        self.children = [
            self.unit_controller,
            self.rows_controller,
            self.cols_controller,
            self.trial_event_controller,
            self.start_ft,
            self.end_ft,
        ]

    def rows_callback(self, change):
        """
        Gets triggered when self.rows_controller changes. Updates other dropdown options.
        """
        rows_cleared = change["new"] is None
        self.cols_controller.disabled = rows_cleared
        if rows_cleared:
            self.cols_controller.value = None
class RasterGridWidget(widgets.VBox):
    """
    Raster grid figure driven by a UnitsAndTrialsControllerWidget. A controller
    created here is shown above the figure; one passed in by the caller is used
    but not added to this widget's children.
    """

    def __init__(
        self,
        units: Units,
        trials: TimeIntervals = None,
        unit_index=0,
        units_trials_controller=None,
    ):
        super().__init__()

        # Create Units and Trials controller
        if not units_trials_controller:
            units_trials_controller = UnitsAndTrialsControllerWidget(units=units, trials=trials, unit_index=unit_index)
            self.children = [units_trials_controller]

        self.fig = interactive_output(
            f=raster_grid,
            controls=units_trials_controller.controls,
            fixed=units_trials_controller.fixed,
        )
        self.children = self.children + (self.fig,)
class TuningCurveWidget(widgets.VBox):
    """
    Tuning curve figure driven by a UnitsAndTrialsControllerWidget. A controller
    created here is shown above the figure; one passed in by the caller is used
    but not added to this widget's children.
    """

    def __init__(
        self,
        units: Units,
        trials: TimeIntervals = None,
        unit_index=0,
        units_trials_controller=None,
    ):
        super().__init__()
        self.children = []

        # Create Units and Trials controller
        if not units_trials_controller:
            units_trials_controller = UnitsAndTrialsControllerWidget(units=units, trials=trials, unit_index=unit_index)
            self.children = [units_trials_controller]

        self.fig = interactive_output(
            f=draw_tuning_curve,
            controls=units_trials_controller.controls,
            fixed=units_trials_controller.fixed,
        )
        self.children = self.children + (self.fig,)
class TuningCurveExtendedWidget(widgets.VBox):
    """
    Stacks one shared UnitsAndTrialsControllerWidget, a TuningCurveWidget and a
    RasterGridWidget; both figures are driven by the same controller.
    """

    def __init__(self, units: Units, trials: TimeIntervals = None, unit_index=0):
        super().__init__()

        # Arguments common to the controller and both figure widgets.
        shared = dict(units=units, trials=trials, unit_index=unit_index)

        # Controller
        self.units_trials_controller = UnitsAndTrialsControllerWidget(**shared)

        # Tuning curve widget
        self.tuning_curve = TuningCurveWidget(units_trials_controller=self.units_trials_controller, **shared)

        # Raster grid widget
        self.raster_grid = RasterGridWidget(units_trials_controller=self.units_trials_controller, **shared)

        self.children = [self.units_trials_controller, self.tuning_curve, self.raster_grid]
def draw_tuning_curve(
    units: Units,
    time_intervals: TimeIntervals,
    index,
    start,
    end,
    rows_label=None,
    cols_label=None,
    align_by="start_time",
) -> plt.Figure:
    """
    Dispatch to the 1D or 2D tuning curve depending on which labels are set.

    Returns an HTML hint widget when `rows_label` is None, the result of
    draw_tuning_curve_1d when only `rows_label` is set, and the result of
    draw_tuning_curve_2d when both labels are set.
    """
    if rows_label is None:
        return widgets.HTML("Select at least one variable")

    if cols_label is not None:
        return draw_tuning_curve_2d(units, time_intervals, index, start, end, rows_label, cols_label, align_by)

    # 1D histogram
    return draw_tuning_curve_1d(units, time_intervals, index, start, end, rows_label, align_by)
def draw_tuning_curve_1d(
    units: Units,
    time_intervals: TimeIntervals,
    index,
    start,
    end,
    rows_label=None,
    align_by="start_time",
) -> plt.Figure:
    """
    Bar + line plot of the average firing rate of one unit for each class of
    the `rows_label` column.

    Parameters
    ----------
    units: pynwb.misc.Units
    time_intervals: pynwb.epoch.TimeIntervals
    index: int
        Index of unit
    start: float
        Window start relative to the `align_by` time.
    end: float
        Window end relative to the `align_by` time.
    rows_label: str
        Column of `time_intervals` whose classes form the x axis.
    align_by: str, optional

    Returns
    -------
    plt.Figure
    """
    rows_data, var1_classes = extract_data_from_intervals(time_intervals[rows_label])

    duration = end - start
    avg_rates = []
    for v1 in var1_classes:
        indexes = [i for i, d in enumerate(rows_data) if d == v1]
        data = align_by_time_intervals(
            units=units,
            index=index,
            intervals=time_intervals,
            start_label=align_by,
            stop_label=align_by,
            start=start,
            end=end,
            rows_select=indexes,
        )
        n_trials = len(data)
        n_spikes = len(np.hstack(data))
        # average rate = spikes per trial per unit of window duration
        avg_rates.append(n_spikes / (n_trials * duration))

    x = np.arange(len(var1_classes))  # the label locations
    # Display order of the classes: numeric classes first, then the rest.
    si = sort_mixed_type_list(var1_classes)
    avg_rates_sorted = [avg_rates[i] for i in si]

    fig, ax = plt.subplots(figsize=(14, 7))
    width = 0.95  # the width of the bars
    ax.bar(x, avg_rates_sorted, width)
    ax.plot(x, avg_rates_sorted, "-o", color="k", lw=2)

    # Labels
    ax.set_ylabel("Avg rate")
    ax.set_xlabel(rows_label)
    ax.set_xticks(x)
    ax.set_xticklabels([var1_classes[i] for i in si], rotation=45)

    fig.tight_layout()
    # Return the figure, as annotated and as draw_tuning_curve_2d does.
    return fig
def draw_tuning_curve_2d(
    units: Units,
    time_intervals: TimeIntervals,
    index,
    start,
    end,
    rows_label=None,
    cols_label=None,
    align_by="start_time",
) -> plt.Figure:
    """
    Heat map of the average firing rate of one unit for each combination of
    classes of the `rows_label` and `cols_label` columns.

    Parameters
    ----------
    units: pynwb.misc.Units
    time_intervals: pynwb.epoch.TimeIntervals
    index: int
        Index of unit
    start: float
        Window start relative to the `align_by` time.
    end: float
        Window end relative to the `align_by` time.
    rows_label: str
        Column whose classes are placed along the x axis.
    cols_label: str
        Column whose classes are placed along the y axis.
    align_by: str, optional

    Returns
    -------
    plt.Figure
    """
    rows_data, var1_classes = extract_data_from_intervals(time_intervals[rows_label])
    cols_data, var2_classes = extract_data_from_intervals(time_intervals[cols_label])

    # Trial indices of every class, computed once instead of inside the nested loop.
    row_members = [{ii for ii, d in enumerate(rows_data) if d == v1} for v1 in var1_classes]
    col_members = [{ii for ii, d in enumerate(cols_data) if d == v2} for v2 in var2_classes]

    duration = end - start
    # Combinations without any trial keep a rate of 0.
    avg_rates = np.zeros((len(var1_classes), len(var2_classes)))
    for i, in_row in enumerate(row_members):
        for j, in_col in enumerate(col_members):
            intersect = sorted(in_row & in_col)
            if len(intersect) > 0:
                data = align_by_time_intervals(
                    units=units,
                    index=index,
                    intervals=time_intervals,
                    start_label=align_by,
                    stop_label=align_by,
                    start=start,
                    end=end,
                    rows_select=intersect,
                )
                n_trials = len(data)
                n_spikes = len(np.hstack(data))
                avg_rates[i, j] = n_spikes / (n_trials * duration)

    fig, ax = plt.subplots(figsize=(14, 7))
    # Transposed so that the rows variable runs along x and the cols variable along y.
    pos = ax.imshow(avg_rates.T, origin="lower", cmap="Greys")
    cbar = fig.colorbar(pos, ax=ax)
    cbar.set_label("spikes / second")

    # Labels
    ax.set_xticks(np.arange(len(var1_classes)))
    ax.set_yticks(np.arange(len(var2_classes)))
    ax.set_xlabel(rows_label)
    ax.set_ylabel(cols_label)
    ax.set_xticklabels(var1_classes, rotation=45)
    ax.set_yticklabels(var2_classes, rotation=45)
    return fig
def sort_mixed_type_list(x):
    """
    Returns the indexes for a sorted list of mixed types.

    Elements that can be converted with ``float()`` come first, ordered by
    their numeric value; all remaining elements follow, ordered by their
    ``str()`` representation.

    Parameters
    ----------
    x: iterable

    Returns
    -------
    list of int
        Positions into `x`, in sorted order.
    """
    x_num = []
    x_num_i = []
    x_oth = []
    x_oth_i = []
    for i, xx in enumerate(x):
        try:
            value = float(xx)
        except (TypeError, ValueError, OverflowError):
            # Not convertible to a number: sort it by its string form instead.
            # Only conversion errors are caught, so KeyboardInterrupt and
            # unrelated failures still propagate.
            x_oth.append(str(xx))
            x_oth_i.append(i)
        else:
            x_num.append(value)
            x_num_i.append(i)
    x_num_si = np.argsort(x_num)
    x_oth_si = np.argsort(x_oth)
    return [x_num_i[ii] for ii in x_num_si] + [x_oth_i[ii] for ii in x_oth_si]