import functools
from bisect import bisect
from abc import abstractmethod
import scipy
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets, fixed, Layout
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.colors import DEFAULT_PLOTLY_COLORS
from pynwb import TimeSeries
from pynwb.epoch import TimeIntervals
from .controllers import (
StartAndDurationController,
GroupAndSortController,
RangeController,
ProgressBar
)
from .utils.plotly import multi_trace
from .utils.timeseries import (
get_timeseries_tt,
get_timeseries_maxt,
get_timeseries_mint,
timeseries_time_to_ind,
get_timeseries_in_units,
)
from .utils.widgets import interactive_output, set_plotly_callbacks
from .controllers.misc import make_trial_event_controller
# Colors of matplotlib's default property cycle; used as the default palette
# for coloring trace / trial groups in the plotting helpers of this module.
color_wheel = plt.rcParams["axes.prop_cycle"].by_key()["color"]
def show_ts_fields(node):
    """Render a TimeSeries' descriptive attributes as read-only text boxes.

    Parameters
    ----------
    node: TimeSeries
        Object whose ``description``, ``unit``, ``resolution`` and
        ``conversion`` attributes are displayed, each via ``repr``.

    Returns
    -------
    widgets.VBox
        One disabled ``widgets.Text`` per attribute, in the order above.
    """
    field_names = ("description", "unit", "resolution", "conversion")
    text_boxes = [
        widgets.Text(value=repr(getattr(node, name)), description=name, disabled=True)
        for name in field_names
    ]
    return widgets.VBox(text_boxes)
def show_timeseries_mpl(
    time_series: TimeSeries,
    time_window=None,
    ax=None,
    zero_start=False,
    xlabel=None,
    ylabel=None,
    title=None,
    figsize=None,
    **kwargs
):
    """Plot a TimeSeries with matplotlib, optionally restricted to a time window.

    The time window is converted to sample indices and all drawing is
    delegated to ``show_indexed_timeseries_mpl``.

    Parameters
    ----------
    time_series: TimeSeries
    time_window: [float, float], optional
        Start and end time in seconds, converted to indices with
        ``timeseries_time_to_ind``. When omitted, the whole series is plotted.
    ax: plt.Axes, optional
        Axes to draw on; a new figure is created when omitted.
    zero_start: bool
        Shift the time axis so that the first plotted sample is at 0.
    xlabel: str, optional
        Forwarded unchanged.
    ylabel: str, optional
        Forwarded unchanged.
    title: str, optional
        Forwarded unchanged.
    figsize: tuple, optional
        Size of the figure created when *ax* is not given.
    kwargs
        Remaining keyword arguments, forwarded to
        ``show_indexed_timeseries_mpl``.

    Returns
    -------
    plt.Axes
        The axes that were drawn on.
    """
    if time_window is None:
        istart, istop = 0, None
    else:
        window_start, window_stop = time_window[0], time_window[1]
        istart = timeseries_time_to_ind(time_series, window_start)
        istop = timeseries_time_to_ind(time_series, window_stop)

    forwarded = dict(
        istart=istart,
        istop=istop,
        ax=ax,
        zero_start=zero_start,
        xlabel=xlabel,
        ylabel=ylabel,
        title=title,
        figsize=figsize,
    )
    return show_indexed_timeseries_mpl(time_series, **forwarded, **kwargs)
def show_indexed_timeseries_mpl(
    node: TimeSeries,
    istart=0,
    istop=None,
    ax=None,
    zero_start=False,
    xlabel="time (s)",
    ylabel=None,
    title=None,
    figsize=None,
    neurodata_vis_spec=None,
    **kwargs
):
    """Plot the samples between ``istart`` and ``istop`` of a TimeSeries with matplotlib.

    Parameters
    ----------
    node: TimeSeries
    istart: int
        First sample index to plot.
    istop: int, optional
        Stop index, forwarded to the timeseries helpers; None plots to the end.
    ax: plt.Axes, optional
        Axes to draw on; a new figure of size *figsize* is created if omitted.
    zero_start: bool
        Shift the time axis so that the first plotted sample is at 0.
    xlabel: str
        Passed to ``ax.set_xlabel``.
    ylabel: str, optional
        Defaults to the unit reported by ``get_timeseries_in_units`` for the
        plotted data, falling back to ``node.unit``. No label is set when
        both are empty.
    title: str, optional
    figsize: tuple, optional
    neurodata_vis_spec
        Unused here; kept so that callers passing it keep working.
    kwargs
        Passed through to ``ax.plot``.

    Returns
    -------
    plt.Axes
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    tt = get_timeseries_tt(node, istart=istart, istop=istop)
    # Guard tt[0] so that an empty selection does not raise IndexError.
    if zero_start and len(tt):
        tt = tt - tt[0]

    data, unit = get_timeseries_in_units(node, istart=istart, istop=istop)
    if ylabel is None:
        # Label the axis with the unit of the data actually plotted (the
        # plotly counterpart does the same); fall back to the stored unit.
        ylabel = unit or node.unit or None

    ax.plot(tt, data, **kwargs)
    ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.autoscale(enable=True, axis="x", tight=True)
    return ax
def show_indexed_timeseries_plotly(
    timeseries: TimeSeries,
    istart: int = 0,
    istop: int = None,
    time_window: list = None,
    trace_range: list = None,
    offsets=None,
    fig: go.FigureWidget = None,
    col=None,
    row=None,
    zero_start=False,
    scatter_kwargs: dict = None,
    figure_kwargs: dict = None,
):
    """Plot a TimeSeries, or a sub-range of it, as plotly line traces.

    The samples to show are selected either by index (*istart*/*istop*) or
    by time (*time_window*), not both. Each selected trace (column) becomes
    one ``go.Scattergl`` line on subplot (*row*, *col*) of *fig*.

    Parameters
    ----------
    timeseries: TimeSeries
    istart: int
        First sample index. Must satisfy ``0 <= istart < n_samples``.
    istop: int, optional
        Stop index. Must satisfy ``istart < istop <= n_samples``; None means
        up to the end.
    time_window: [float, float], optional
        Start and end time, converted with ``timeseries_time_to_ind``.
    trace_range: [int, int], optional
        Column range ``[first, stop)`` to plot; defaults to all columns.
    offsets: array-like, optional
        Vertical offset added to each plotted trace; defaults to zeros.
    fig: go.FigureWidget, optional
        Figure to draw into; a new 1x1 subplot figure is created if omitted.
    col: int, optional
        Subplot column, default 1.
    row: int, optional
        Subplot row, default 1.
    zero_start: bool
        Shift the time axis so that the first plotted sample is at 0.
    scatter_kwargs: dict, optional
        Extra arguments for every ``go.Scattergl``.
    figure_kwargs: dict, optional
        Overrides for the default layout. ``xaxis`` / ``yaxis`` entries
        replace (do not merge with) the defaults and are applied to the
        (row, col) subplot; all other entries go to ``fig.update_layout``.

    Returns
    -------
    go.FigureWidget

    Raises
    ------
    ValueError
        If both an index range and *time_window* are given, if the index
        range is out of bounds or empty, or if *trace_range* is out of bounds.
    """
    if istart != 0 or istop is not None:
        if time_window is not None:
            raise ValueError("input either time window or istart/stop but not both")
        n_samples = timeseries.data.shape[0]
        # istop must lie strictly after istart: an empty selection would
        # otherwise fail later with IndexError when reading tt[0].
        if not (
            0 <= istart < n_samples
            and (istop is None or istart < istop <= n_samples)
        ):
            raise ValueError("enter correct istart/stop values")
        t_istart, t_istop = istart, istop
    elif time_window is not None:
        t_istart = timeseries_time_to_ind(timeseries, time_window[0])
        t_istop = timeseries_time_to_ind(timeseries, time_window[1])
    else:
        t_istart, t_istop = istart, istop

    tt = get_timeseries_tt(timeseries, istart=t_istart, istop=t_istop)
    data, unit = get_timeseries_in_units(timeseries, istart=t_istart, istop=t_istop)
    if len(data.shape) == 1:
        data = data[:, np.newaxis]

    n_traces = data.shape[1]
    if trace_range is not None:
        if not (0 <= trace_range[0] < n_traces and 0 < trace_range[1] <= n_traces):
            raise ValueError("enter correct trace range")
        trace_istart, trace_istop = trace_range[0], trace_range[1]
    else:
        trace_istart, trace_istop = 0, n_traces

    if offsets is None:
        offsets = np.zeros(trace_istop - trace_istart)
    if zero_start:
        tt = tt - tt[0]
    scatter_kwargs = dict() if scatter_kwargs is None else scatter_kwargs
    if fig is None:
        fig = go.FigureWidget(make_subplots(rows=1, cols=1))
    row = 1 if row is None else row
    col = 1 if col is None else col

    for i, trace_id in enumerate(range(trace_istart, trace_istop)):
        fig.add_trace(
            go.Scattergl(
                x=tt, y=data[:, trace_id] + offsets[i], mode="lines", **scatter_kwargs
            ),
            row=row,
            col=col,
        )

    input_figure_kwargs = dict(
        xaxis=dict(title_text="time (s)", range=[tt[0], tt[-1]]),
        yaxis=dict(title_text=unit),
        title=timeseries.name,
    )
    if figure_kwargs is not None:
        input_figure_kwargs.update(figure_kwargs)
    fig.update_xaxes(input_figure_kwargs.pop("xaxis"), row=row, col=col)
    fig.update_yaxes(input_figure_kwargs.pop("yaxis"), row=row, col=col)
    fig.update_layout(**input_figure_kwargs)
    return fig
def plot_traces(
    timeseries: TimeSeries,
    time_window=None,
    trace_window=None,
    title: str = None,
    ylabel: str = "traces",
    **kwargs
):
    """Plot the traces of a 2-D TimeSeries stacked vertically with matplotlib.

    Each trace is shifted up by a multiple of a common gap, equal to 20 times
    the median per-trace standard deviation, so that traces do not overlap.

    Parameters
    ----------
    timeseries: TimeSeries
        2-D series. Data stored as (traces, time) instead of (time, traces)
        is detected when its second dimension equals the number of
        timestamps in the window; such data is transposed.
    time_window: [float, float], optional
        Start time and end time in seconds.
    trace_window: [int int], optional
        Index range ``[first, stop)`` of traces to view; defaults to all
        traces.
    title: str, optional
    ylabel: str, optional
    kwargs
        Passed through to ``ax.plot``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    if time_window is None:
        t_ind_start, t_ind_stop = 0, None
    else:
        t_ind_start = timeseries_time_to_ind(timeseries, time_window[0])
        t_ind_stop = timeseries_time_to_ind(timeseries, time_window[1])
    tt = get_timeseries_tt(timeseries, t_ind_start, t_ind_stop)

    # Fix for data stored with the wrong orientation, (traces, time). This
    # must be decided before defaulting trace_window: otherwise time samples
    # were counted as traces and the offsets failed to broadcast.
    transposed = timeseries.data.shape[1] == len(tt)
    if trace_window is None:
        trace_window = [0, timeseries.data.shape[0 if transposed else 1]]
    if transposed:
        mini_data = timeseries.data[
            trace_window[0] : trace_window[1], t_ind_start:t_ind_stop
        ].T
    else:
        mini_data = timeseries.data[
            t_ind_start:t_ind_stop, trace_window[0] : trace_window[1]
        ]

    # Size offsets and tick labels by the traces actually read, in case
    # trace_window extends past the last trace.
    n_traces = mini_data.shape[1]
    # nanmedian: a single all-NaN trace must not turn the gap (and with it
    # every offset) into NaN.
    gap = np.nanmedian(np.nanstd(mini_data, axis=0)) * 20
    offsets = np.arange(n_traces) * gap

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(tt, mini_data + offsets, **kwargs)
    ax.set_xlabel("time (s)")
    if np.isfinite(gap) and n_traces:
        ax.set_ylim(-gap, offsets[-1] + gap)
    ax.set_xlim(tt[0], tt[-1])
    ax.set_yticks(offsets)
    ax.set_yticklabels(np.arange(trace_window[0], trace_window[0] + n_traces))
    if title is not None:
        ax.set_title(title)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    return fig
def show_timeseries(node, **kwargs):
    """Choose a plotly widget for *node* from the dimensionality of its data.

    1-D data gets a ``SingleTracePlotlyWidget`` and 2-D data a
    ``BaseGroupedTraceWidget``; *kwargs* are forwarded to the chosen widget.

    Raises
    ------
    ValueError
        For data of any other dimensionality.
    """
    n_dims = len(node.data.shape)
    if n_dims == 1:
        return SingleTracePlotlyWidget(node, **kwargs)
    if n_dims == 2:
        return BaseGroupedTraceWidget(node, **kwargs)
    message = "Visualization for TimeSeries that has data with shape {} not implemented"
    raise ValueError(message.format(node.data.shape))
def _prep_timeseries(time_series: TimeSeries, time_window=None, order=None):
    """Pull a window of data out of *time_series* and stack its traces for plotting.

    Parameters
    ----------
    time_series: TimeSeries
    time_window: [float, float], optional
        Start and end time, converted with ``timeseries_time_to_ind``; the
        whole series is used when omitted.
    order: sequence of int, optional
        Columns to extract, in display order (repeats allowed). For 2-D data
        it defaults to all columns; it is ignored for 1-D data.

    Returns
    -------
    mini_data, tt, offsets
        For 2-D data: the selected columns, each shifted up by its entry in
        ``offsets`` (multiples of 20 times the median per-column std). For
        1-D data: the raw samples and ``offsets == [0]``. When every
        selected value is NaN, ``mini_data`` and ``offsets`` are None.
    """
    if time_window is None:
        t_ind_start, t_ind_stop = 0, None
    else:
        t_ind_start = timeseries_time_to_ind(time_series, time_window[0])
        t_ind_stop = timeseries_time_to_ind(time_series, time_window[1])
    tt = get_timeseries_tt(time_series, t_ind_start, t_ind_stop)

    if len(time_series.data.shape) > 1:
        if order is None:
            order = np.arange(time_series.data.shape[1])
        # Read the sorted unique columns once, then use the inverse mapping to
        # restore the requested order (including repeats). Presumably this is
        # done because h5py-backed datasets only accept increasing index lists.
        unique_sorted_order, inverse_sort = np.unique(order, return_inverse=True)
        mini_data = time_series.data[t_ind_start:t_ind_stop, unique_sorted_order][
            :, inverse_sort
        ]
        if np.all(np.isnan(mini_data)):
            return None, tt, None
        # nanmedian: one all-NaN column must not turn every offset into NaN.
        gap = np.nanmedian(np.nanstd(mini_data, axis=0)) * 20
        offsets = np.arange(len(order)) * gap
        mini_data = mini_data + offsets
    else:
        mini_data = time_series.data[t_ind_start:t_ind_stop]
        offsets = [0]
    return mini_data, tt, offsets
def plot_grouped_traces(
    time_series: TimeSeries,
    time_window=None,
    order=None,
    ax=None,
    figsize=(8, 7),
    group_inds=None,
    labels=None,
    colors=color_wheel,
    show_legend=True,
    dynamic_table_region_name=None,
    window=None,
    **kwargs
):
    """Plot traces of a TimeSeries stacked vertically with matplotlib, optionally colored by group.

    Parameters
    ----------
    time_series: TimeSeries
    time_window: [float, float], optional
        Start and end time to plot; the whole series when omitted.
    order: sequence, optional
        Traces to plot, in display order. Defaults to all columns (or
        ``[0]`` for 1-D data). When *group_inds* is given, these are row ids
        of the dynamic table region named *dynamic_table_region_name*.
    ax: plt.Axes, optional
        Axes to draw on; a new figure of size *figsize* is created if omitted.
    figsize: tuple
    group_inds: array-like of int, optional
        Group of each entry of *order*; traces are colored per group.
    labels: array-like, optional
        Legend label per group id. When omitted, the group ids are used.
    colors: sequence
        Palette, indexed by group id modulo its length.
    show_legend: bool
        Draw a legend (only when *group_inds* is given).
    dynamic_table_region_name: str, optional
        Attribute of *time_series* whose ``data`` lists the row id stored in
        each data column; used to map *order* to column indices.
    window: [int, int], optional
        Slice ``[start, stop)`` applied to *order* when not grouping.
    kwargs
        Ignored.

    Returns
    -------
    None
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    if order is None:
        if len(time_series.data.shape) > 1:
            order = np.arange(time_series.data.shape[1])
        else:
            order = [0]

    if group_inds is not None:
        group_inds = np.asarray(group_inds)
        # Map each requested row id to the data column that holds it.
        row_ids = getattr(time_series, dynamic_table_region_name).data[:]
        channel_inds = [np.argmax(row_ids == x) for x in order]
    elif window is not None:
        order = order[window[0] : window[1]]
        channel_inds = order
    else:
        channel_inds = order

    if len(channel_inds):
        mini_data, tt, offsets = _prep_timeseries(time_series, time_window, channel_inds)
    else:
        mini_data = None
        tt = time_window

    if mini_data is None:
        # Nothing to draw: keep the x range with an invisible NaN line when
        # there is a time axis to show at all.
        if tt is not None:
            ax.plot(tt, np.ones_like(tt) * np.nan, color="k")
        return

    if group_inds is not None:
        ugroup_inds = np.unique(group_inds)
        handles = []
        for ui in ugroup_inds:
            color = colors[ui % len(colors)]
            lines_handle = ax.plot(tt, mini_data[:, group_inds == ui], color=color)
            handles.append(lines_handle[0])
        if show_legend:
            if labels is None:
                legend_labels = [str(ui) for ui in ugroup_inds]
            else:
                legend_labels = list(np.asarray(labels)[ugroup_inds])
            ax.legend(
                handles=handles[::-1],
                labels=legend_labels[::-1],
                loc="upper left",
                bbox_to_anchor=(1.01, 1),
            )
    else:
        ax.plot(tt, mini_data, color="k")

    ax.set_xlim((tt[0], tt[-1]))
    ax.set_xlabel("time (s)")
    if len(offsets) > 1:
        # Pad half a gap above the top trace and below the bottom one.
        ax.set_ylim(
            offsets[0] - (offsets[1] - offsets[0]) / 2,
            offsets[-1] + (offsets[-1] - offsets[-2]) / 2,
        )
    if len(order) <= 30:
        ax.set_yticks(offsets)
        ax.set_yticklabels(order)
    else:
        ax.set_yticks([])
def plot_grouped_traces_plotly(
    time_series: TimeSeries,
    time_window,
    order,
    group_inds=None,
    labels=None,
    colors=color_wheel,
    fig=None,
    **kwargs
):
    """Plot traces of a TimeSeries stacked vertically with plotly, optionally colored by group.

    Parameters
    ----------
    time_series: TimeSeries
    time_window: [float, float] or None
        Start and end time to plot; the whole series when None.
    order: sequence of int
        Columns to plot, in display order.
    group_inds: array-like of int, optional
        Group of each entry of *order*; each group becomes one colored trace.
    labels: sequence, optional
        Label per group id. When omitted, the group ids are used.
    colors: sequence
        Palette, indexed by group id modulo its length.
    fig: go.FigureWidget, optional
        Figure to draw into; a new one is created if omitted.
    kwargs
        Ignored.

    Returns
    -------
    go.FigureWidget
    """
    mini_data, tt, offsets = _prep_timeseries(time_series, time_window, order)
    if fig is None:
        fig = go.FigureWidget()
    fig.update_layout(title=time_series.name, xaxis_title="time (s)")
    if mini_data is None:
        # Every selected value is NaN in this window: there is nothing to draw.
        return fig

    if group_inds is not None:
        group_inds = np.asarray(group_inds)
        for ui in np.unique(group_inds)[::-1]:
            color = colors[ui % len(colors)]
            label = str(ui) if labels is None else labels[ui]
            group_data = mini_data[:, group_inds == ui].T
            multi_trace(tt, group_data, color, label, fig=fig)
    else:
        multi_trace(tt, mini_data.T, "black", fig=fig)
    fig.update_yaxes(tickvals=list(offsets), ticktext=[str(i) for i in order])
    return fig
# self.layout = widgets.Layout(width="100%")
class AlignMultiTraceTimeSeriesByTrialsAbstract(widgets.VBox):
    """Base widget for viewing one trace of a TimeSeries aligned to trial events.

    Builds the shared controls: trace selector, trial-event selector,
    before/after window lengths, group-and-sort controller, an
    "align to zero" checkbox and, optionally, a "show SEM" checkbox.
    Subclasses must implement ``update``, which is handed to
    ``set_plotly_callbacks`` together with ``self.controls``.
    """

    def __init__(
        self,
        time_series: TimeSeries,
        trials: TimeIntervals = None,
        trace_index=0,
        trace_controller=None,
        trace_controller_kwargs=None,
        sem=True,
    ):
        """
        Parameters
        ----------
        time_series: TimeSeries
            Series to align. Its data (and its timestamps, when it has no
            ``rate``) are read fully into memory.
        trials: TimeIntervals, optional
            Trials to align to; defaults to ``self.get_trials()``. When no
            trials are found, the widget only shows "No trials present".
        trace_index: int
            Initially selected trace in the default dropdown.
        trace_controller: widget, optional
            Pre-built trace selector that replaces the default dropdown.
        trace_controller_kwargs: dict, optional
            Overrides for the default dropdown's constructor arguments.
        sem: bool
            Add a "show SEM" checkbox (control key ``sem``).
        """
        self.time_series = time_series
        # Read data (and timestamps for irregularly sampled series) once;
        # subclasses slice these in-memory arrays per trial.
        self.time_series_data = time_series.data[()]
        self.time_series_timestamps = None
        if time_series.rate is None:
            self.time_series_timestamps = time_series.timestamps[()]
        super().__init__()

        if trials is None:
            self.trials = self.get_trials()
            if self.trials is None:
                self.children = [widgets.HTML("No trials present")]
                return
        else:
            self.trials = trials

        if trace_controller is None:
            # 1-D data holds a single trace; indexing shape[1] would raise.
            data_shape = self.time_series.data.shape
            ntraces = data_shape[1] if len(data_shape) > 1 else 1
            input_trace_controller_kwargs = dict(
                options=list(range(ntraces)),
                value=trace_index,
                description="trace",
                layout=Layout(width="200px"),
            )
            if trace_controller_kwargs is not None:
                input_trace_controller_kwargs.update(trace_controller_kwargs)
            self.trace_controller = widgets.Dropdown(**input_trace_controller_kwargs)
        else:
            self.trace_controller = trace_controller

        self.trial_event_controller = make_trial_event_controller(
            self.trials, layout=Layout(width="200px")
        )
        self.before_ft = widgets.FloatText(
            0.5, min=0, description="before (s)", layout=Layout(width="200px")
        )
        self.after_ft = widgets.FloatText(
            2.0, min=0, description="after (s)", layout=Layout(width="200px")
        )
        self.gas = self.make_group_and_sort(window=False, control_order=False)
        self.align_to_zero_cb = widgets.Checkbox(description="align to zero")

        self.controls = dict(
            index=self.trace_controller,
            after=self.after_ft,
            before=self.before_ft,
            start_label=self.trial_event_controller,
            gas=self.gas,
            align_to_zero=self.align_to_zero_cb,
        )
        vbox_cols = [
            [self.gas, self.align_to_zero_cb],
            [
                self.trace_controller,
                self.trial_event_controller,
                self.before_ft,
                self.after_ft,
            ],
        ]
        if sem:
            self.sem_cb = widgets.Checkbox(description="show SEM")
            self.controls.update(sem=self.sem_cb)
            vbox_cols[0].append(self.sem_cb)

        out_fig = set_plotly_callbacks(self.update, self.controls)
        self.children = [widgets.HBox([widgets.VBox(i) for i in vbox_cols]), out_fig]

    def get_trials(self):
        """Return the trials of the NWBFile containing the series, or None."""
        nwbfile = self.time_series.get_ancestor("NWBFile")
        # NOTE(review): get_ancestor appears to return None when the series is
        # not inside an NWBFile; guard instead of raising AttributeError.
        if nwbfile is None:
            return None
        return nwbfile.trials

    def make_group_and_sort(
        self, window=None, control_order=False, control_limit=False
    ):
        """Build the GroupAndSortController over ``self.trials``."""
        return GroupAndSortController(
            self.trials,
            window=window,
            control_order=control_order,
            control_limit=control_limit,
        )

    def plot_group(self, group_inds, data_trialized, time_trialized, fig, order):
        """Draw every trial onto *fig*, with one multi-line trace per group.

        Parameters
        ----------
        group_inds: np.ndarray of int
            Group of the trial at each position of *order*.
        data_trialized: sequence of np.ndarray
            Per-trial data, indexed by trial number.
        time_trialized: sequence of np.ndarray
            Per-trial time axis, indexed by trial number.
        fig: go.FigureWidget
        order: sequence of int
            Trial numbers in display order.

        Returns
        -------
        go.FigureWidget
        """
        for group in np.unique(group_inds):
            line_color = color_wheel[group % len(color_wheel)]
            pb = ProgressBar(
                np.where(group_inds == group)[0],
                desc=f"plotting {group} data",
                leave=False,
            )
            group_data = []
            group_ts = []
            for position in pb:
                trial_idx = order[position]
                group_data.append(data_trialized[trial_idx])
                group_ts.append(time_trialized[trial_idx])
            fig = multi_trace(
                group_ts,
                group_data,
                fig=fig,
                color=line_color,
                label=str(group),
                insert_nans=True,
            )
        fig.update_layout(xaxis_title="time (s)", yaxis_title=self.time_series.name)
        # np.concatenate/np.min raise on empty input: only set the x range
        # when there is at least one time sample.
        tt_flat = np.concatenate(time_trialized) if len(time_trialized) else np.array([])
        if tt_flat.size:
            fig.update_layout(xaxis_range=(np.min(tt_flat), np.max(tt_flat)))
        return fig
class AlignMultiTraceTimeSeriesByTrialsConstant(
    AlignMultiTraceTimeSeriesByTrialsAbstract
):
    """Trial-aligned view of a TimeSeries sampled at a constant ``rate``.

    Includes a "show SEM" toggle that plots each group's mean with a band
    of +/- 2 SEM instead of the individual trials.
    """

    def __init__(
        self,
        time_series: TimeSeries,
        trials: TimeIntervals = None,
        trace_index=0,
        trace_controller=None,
        trace_controller_kwargs=None,
    ):
        self.time_series = time_series
        super().__init__(
            time_series=time_series,
            trials=trials,
            trace_index=trace_index,
            trace_controller=trace_controller,
            trace_controller_kwargs=trace_controller_kwargs,
            sem=True,
        )

    @functools.lru_cache()
    def align_data(self, start_label, before, after, index=None):
        """Cut one window of samples per trial around ``trials[start_label]``.

        A full window starts ``before`` seconds before the event and has
        ``int((before + after) * rate)`` samples. Windows that run off either
        end of the recording are truncated instead of wrapping around.

        Returns
        -------
        (list of np.ndarray, list of np.ndarray)
            Per-trial data (column *index* for 2-D data) and per-trial time
            relative to the event, in seconds. The data arrays may be views of
            ``self.time_series_data`` and results are cached, so callers must
            not modify them in place.
        """
        rate = self.time_series.rate
        t0 = self.time_series.starting_time
        n_samples = self.time_series_data.shape[0]
        n_window = int((before + after) * rate)
        window_starts = np.array(self.trials[start_label][:]) - before
        out_data_aligned = []
        out_ts_aligned = []
        for window_start in window_starts:
            idx_start = int(np.floor((window_start - t0) * rate))
            # Clip to the recording: a negative start must not wrap around.
            lo = min(max(idx_start, 0), n_samples)
            hi = min(max(idx_start + n_window, 0), n_samples)
            # Sample k of the full window sits at k / rate - before; this also
            # keeps x and y the same length for truncated windows.
            out_ts_aligned.append(np.arange(lo - idx_start, hi - idx_start) / rate - before)
            if len(self.time_series_data.shape) > 1 and index is not None:
                out_data_aligned.append(self.time_series_data[lo:hi, index])
            else:
                out_data_aligned.append(self.time_series_data[lo:hi])
        return out_data_aligned, out_ts_aligned

    def update(
        self,
        index: int,
        start_label: str = "start_time",
        before: float = 0.0,
        after: float = 1.0,
        order=None,
        group_inds=None,
        labels=None,
        align_to_zero=False,
        sem=False,
        fig: go.FigureWidget = None,
    ):
        """Redraw the trial-aligned traces.

        Parameters
        ----------
        index: int
            Trace (column) to show.
        start_label: str
            Trials column holding the event times to align to.
        before: float
            Seconds shown before each event.
        after: float
            Seconds shown after each event.
        order: sequence of int, optional
            Trial numbers in display order; defaults to all trials.
        group_inds: array-like of int, optional
            Group of the trial at each position of *order*; defaults to a
            single group 0.
        labels: sequence, optional
            Per-group text used for the mean trace in SEM mode.
        align_to_zero: bool
            Subtract from each trial its value at the first sample with
            relative time > 0 (clamped to the last sample).
        sem: bool
            Plot each group's mean +/- 2 SEM instead of individual trials.
        fig: go.FigureWidget, optional
            Figure to reuse; it is cleared first.

        Returns
        -------
        go.FigureWidget
        """
        data, time_ts_aligned = self.align_data(start_label, before, after, index)
        # align_data is cached and may return views of the raw data: work on a
        # copy of the list and never modify its arrays in place.
        data = list(data)
        if order is None:
            order = np.arange(len(self.trials))
        if group_inds is None:
            group_inds = np.zeros(len(order), dtype=int)
        group_inds = np.asarray(group_inds)

        if align_to_zero:
            for trial_no in order:
                trial_data = data[trial_no]
                if len(trial_data) == 0:
                    continue
                zero_id = min(bisect(time_ts_aligned[trial_no], 0), len(trial_data) - 1)
                data[trial_no] = trial_data - trial_data[zero_id]

        fig = go.FigureWidget() if fig is None else fig
        fig.data = []
        fig.layout = {}
        if not sem:
            return self.plot_group(group_inds, data, time_ts_aligned, fig, order)

        from scipy import stats  # explicit: `import scipy` may not load scipy.stats

        rate = self.time_series.rate
        n_window = int((before + after) * rate)
        grid = np.arange(n_window) / rate - before
        # Stack the trials of `order` on a common sample grid (one row per
        # position in `order`); samples missing from truncated windows are NaN.
        # Assumes one trace per trial, i.e. 1-D data or a selected index.
        stacked = np.full((len(order), n_window), np.nan)
        for row_no, trial_no in enumerate(order):
            trial_data = np.asarray(data[trial_no], dtype=float)
            if trial_data.size == 0:
                continue
            first = int(np.rint((time_ts_aligned[trial_no][0] + before) * rate))
            stacked[row_no, first : first + len(trial_data)] = trial_data

        for group in np.unique(group_inds):
            group_rows = stacked[group_inds == group]
            this_mean = np.nanmean(group_rows, axis=0)
            err = np.ma.filled(stats.sem(group_rows, axis=0, nan_policy="omit"), np.nan)
            color = color_wheel[group % len(color_wheel)]
            plot_kwargs = dict()
            if labels is not None:
                plot_kwargs.update(text=labels[group])
            fig.add_scattergl(x=grid, y=this_mean - 2 * err, line_color=color)
            fig.add_scattergl(
                x=grid, y=this_mean + 2 * err, line_color=color, fill="tonexty", opacity=0.2
            )
            fig.add_scattergl(x=grid, y=this_mean, line_color=color, **plot_kwargs)
        fig.update_layout(xaxis_title="time (s)", yaxis_title=self.time_series.name)
        return fig
class AlignMultiTraceTimeSeriesByTrialsVariable(
    AlignMultiTraceTimeSeriesByTrialsAbstract
):
    """Trial-aligned view of a TimeSeries with explicit (irregular) timestamps."""

    def __init__(
        self,
        time_series: TimeSeries,
        trials: TimeIntervals = None,
        trace_index=0,
        trace_controller=None,
        trace_controller_kwargs=None,
    ):
        self.time_series = time_series
        super().__init__(
            time_series=time_series,
            trials=trials,
            trace_index=trace_index,
            trace_controller=trace_controller,
            trace_controller_kwargs=trace_controller_kwargs,
            sem=False,
        )

    @functools.lru_cache()
    def align_data(self, start_label, before, after, index=None):
        """Cut the samples in ``[event - before, event + after]`` for each trial.

        Events are read from ``trials[start_label]``. The timestamps are
        assumed to be sorted in ascending order.

        Returns
        -------
        (list of np.ndarray, list of np.ndarray)
            Per-trial data (column *index* for 2-D data) and per-trial
            timestamps relative to the event, in seconds. The data arrays may
            be views of ``self.time_series_data`` and results are cached, so
            callers must not modify them in place.
        """
        timestamps = self.time_series_timestamps
        events = np.array(self.trials[start_label][:])
        # The window is measured from the event itself; `before` was
        # previously subtracted twice, shifting the window and the time axis.
        idx_starts = np.searchsorted(timestamps, events - before, side="left")
        idx_stops = np.searchsorted(timestamps, events + after, side="right")
        out_data_aligned = []
        out_ts_aligned = []
        for event, idx_start, idx_stop in zip(events, idx_starts, idx_stops):
            # True time relative to the event; also safe for empty windows.
            out_ts_aligned.append(timestamps[idx_start:idx_stop] - event)
            if len(self.time_series_data.shape) > 1 and index is not None:
                out_data_aligned.append(self.time_series_data[idx_start:idx_stop, index])
            else:
                out_data_aligned.append(self.time_series_data[idx_start:idx_stop])
        return out_data_aligned, out_ts_aligned

    def update(
        self,
        index: int,
        start_label: str = "start_time",
        before: float = 0.0,
        after: float = 1.0,
        order=None,
        group_inds=None,
        labels=None,
        align_to_zero=False,
        fig: go.FigureWidget = None,
    ):
        """Redraw the trial-aligned traces.

        Parameters
        ----------
        index: int
            Trace (column) to show.
        start_label: str
            Trials column holding the event times to align to.
        before: float
            Seconds shown before each event.
        after: float
            Seconds shown after each event.
        order: sequence of int, optional
            Trial numbers in display order; defaults to all trials.
        group_inds: array-like of int, optional
            Group of the trial at each position of *order*; defaults to a
            single group 0.
        labels: sequence, optional
            Unused in this view.
        align_to_zero: bool
            Subtract from each trial its value at the first sample with
            relative time > 0 (clamped to the last sample).
        fig: go.FigureWidget, optional
            Figure to reuse; it is cleared first.

        Returns
        -------
        go.FigureWidget
        """
        data, time_ts_aligned = self.align_data(start_label, before, after, index)
        # align_data is cached and may return views of the raw data: work on a
        # copy of the list and never modify its arrays in place.
        data = list(data)
        if order is None:
            order = np.arange(len(self.trials))
        if group_inds is None:
            group_inds = np.zeros(len(order), dtype=int)
        group_inds = np.asarray(group_inds)

        if align_to_zero:
            for trial_no in order:
                trial_data = data[trial_no]
                if len(trial_data) == 0:
                    continue
                zero_id = min(bisect(time_ts_aligned[trial_no], 0), len(trial_data) - 1)
                data[trial_no] = trial_data - trial_data[zero_id]

        fig = fig if fig is not None else go.FigureWidget()
        fig.data = []
        fig.layout = {}
        return self.plot_group(group_inds, data, time_ts_aligned, fig, order)