from bisect import bisect_right, bisect_left
import numpy as np
from pynwb.misc import Units
def get_spike_times(units: Units, index, in_interval):
    """Return the spike times of one unit that fall inside a closed time interval.

    All units' spike times are stored back to back in ``units["spike_times"].target``.
    ``units["spike_times"].data[i]`` is the exclusive end offset of unit ``i``'s slice,
    and the previous entry (or 0 for the first unit) is its start offset. The binary
    searches (``bisect``) are restricted to that slice.

    Spike times within a unit are assumed to be sorted in ascending order, because
    ``bisect`` requires it.

    Parameters
    ----------
    units: pynwb.misc.Units
    index: int
        Row (unit) index.
    in_interval: tuple of float
        ``(start_time, stop_time)``. Both bounds are inclusive.

    Returns
    -------
    np.ndarray
        Spike times ``t`` of unit ``index`` with ``start_time <= t <= stop_time``.
    """
    spike_index = units["spike_times"]
    offsets = spike_index.data
    flat_times = spike_index.target

    lo = offsets[index - 1] if index else 0
    hi = offsets[index]
    t_begin, t_end = in_interval

    # bisect_left includes t == t_begin; bisect_right includes t == t_end.
    first = bisect_left(flat_times, t_begin, lo, hi)
    last = bisect_right(flat_times, t_end, first, hi)
    return np.asarray(flat_times[first:last])
def get_min_spike_time(units: Units):
    """Efficiently retrieve the first spike time across all units.

    Only the first stored spike of each unit is read from
    ``units["spike_times"].target``. This assumes each unit's spike times are
    sorted in ascending order. Units without any spikes are skipped.

    Parameters
    ----------
    units: pynwb.misc.Units

    Returns
    -------
    scalar
        The smallest first-spike time among all units that have at least one spike.

    Raises
    ------
    ValueError
        If no unit has any spike.
    """
    st = units["spike_times"]
    # ``st.data`` holds the exclusive end offset of each unit's slice in ``st.target``.
    # A unit's start offset is the previous unit's end offset (0 for the first unit).
    ends = np.asarray(st.data[:], dtype=np.int64)
    starts = np.concatenate(([0], ends))[:-1]
    # An empty unit has start == end. Its "first spike" offset would point into the
    # next unit's data, or past the end of the buffer for trailing empty units.
    first_inds = starts[starts < ends]
    if first_inds.size == 0:
        raise ValueError("no spike times found in units")
    return np.min([st.target.data[int(i)] for i in first_inds])
def get_max_spike_time(units: Units):
    """Efficiently retrieve the last spike time across all units.

    Only the last stored spike of each unit is read from
    ``units["spike_times"].target``. This assumes each unit's spike times are
    sorted in ascending order. Units without any spikes are skipped.

    Parameters
    ----------
    units: pynwb.misc.Units

    Returns
    -------
    scalar
        The largest last-spike time among all units that have at least one spike.

    Raises
    ------
    ValueError
        If no unit has any spike.
    """
    st = units["spike_times"]
    # ``st.data`` holds the exclusive end offset of each unit's slice in ``st.target``.
    # A unit's start offset is the previous unit's end offset (0 for the first unit).
    ends = np.asarray(st.data[:], dtype=np.int64)
    starts = np.concatenate(([0], ends))[:-1]
    # For an empty unit, ``end - 1`` would be the previous unit's last spike, or -1
    # for a leading empty unit, which wraps around to the end of the buffer.
    last_inds = ends[starts < ends] - 1
    if last_inds.size == 0:
        raise ValueError("no spike times found in units")
    return np.max([st.target.data[int(i)] for i in last_inds])
def align_by_times(units: Units, index, starts, stops):
    """Yield one unit's spike times inside each window, relative to the window start.

    Spike times are assumed to be sorted in ascending order, because
    ``np.searchsorted`` requires it.

    Args:
        units: pynwb.misc.Units
        index: int
            Row (unit) index.
        starts: array-like of float
            Window start times. A scalar is treated as a single window.
        stops: array-like of float
            Window stop times, paired element-wise with ``starts``.

    Yields:
        np.ndarray
            For each window ``i``: the spike times ``t`` with
            ``starts[i] <= t < stops[i]``, minus ``starts[i]``.
    """
    # np.asarray makes slicing and subtraction work even when the stored spike times
    # come back as a plain Python list (e.g. an in-memory table).
    unit_spike_data = np.asarray(units["spike_times"][index])
    # Callers may pass a 0-d value (e.g. after np.squeeze of a single row).
    # zip() needs something iterable.
    starts = np.atleast_1d(starts)
    stops = np.atleast_1d(stops)
    istarts = np.searchsorted(unit_spike_data, starts)
    istops = np.searchsorted(unit_spike_data, stops)
    for start, istart, istop in zip(starts, istarts, istops):
        yield unit_spike_data[istart:istop] - start
def align_by_trials(
    units: Units,
    index,
    start_label="start_time",
    stop_label=None,
    start=-0.5,
    end=1.0,
):
    """Align one unit's spike times to the trials of the NWB file that contains ``units``.

    This is a thin wrapper around ``align_by_time_intervals`` that uses the parent
    file's ``trials`` table as the intervals.

    Args:
        units: pynwb.misc.Units
        index: int
            Row (unit) index.
        start_label: str
            default: 'start_time'
        stop_label: str
            default: None (just align to start_label)
        start: float
            Start time for calculation before or after (negative or positive) the reference point (aligned to).
        end: float
            End time for calculation before or after (negative or positive) the reference point (aligned to).
    Returns:
        list of np.ndarray, one entry per trial, as returned by align_by_time_intervals.
    Raises:
        ValueError: if the parent file's ``trials`` attribute is None.
    """
    trials = units.get_ancestor("NWBFile").trials
    if trials is None:
        # Fail early with a clear message instead of a TypeError deep inside the
        # interval lookup.
        raise ValueError("the NWB file containing these units has no trials table")
    return align_by_time_intervals(
        units, index, trials, start_label, stop_label, start, end
    )
def align_by_time_intervals(
    units: Units,
    index,
    intervals,
    start_label="start_time",
    stop_label="stop_time",
    start=0.0,
    end=0.0,
    rows_select=(),
    progress_bar=None,
):
    """Align one unit's spike times to each selected row of an intervals table.

    For each selected row, the window
    ``[row[start_label] + start, row[stop_label] + end)`` is cut from the unit's
    spike train by ``align_by_times``. The spikes are returned relative to
    ``row[start_label]``.

    Args:
        units: pynwb.misc.Units
        index: int
            Row (unit) index.
        intervals: pynwb.epoch.TimeIntervals
        start_label: str
            default: 'start_time'
        stop_label: str
            default: 'stop_time'. None means use start_label.
        start: float
            Start time for calculation before or after (negative or positive) the reference point (aligned to).
        end: float
            End time for calculation before or after (negative or positive) the reference point (aligned to).
        rows_select: array_like, optional
            Sub-selects specific rows. The default ``()`` selects all rows.
        progress_bar: FloatProgress, optional
            If provided, its ``value`` is updated to the fraction (0 to 1) of intervals
            processed.
    Returns:
        list of np.ndarray, one per selected row.
    """
    if stop_label is None:
        stop_label = start_label
    # np.squeeze drops singleton axes. np.atleast_1d then guarantees a 1-D array even
    # when a single row is selected, so the result can be iterated.
    starts = np.atleast_1d(
        np.squeeze(np.array(intervals[start_label][:])[rows_select] + start)
    )
    stops = np.atleast_1d(
        np.squeeze(np.array(intervals[stop_label][:])[rows_select] + end)
    )
    n_intervals = len(starts)
    if progress_bar is not None:
        progress_bar.value = 0
        progress_bar.description = "reading spike data"
    out = []
    for i, x in enumerate(align_by_times(units, index, starts, stops)):
        # x is relative to the window start (reference + start). Adding start makes
        # it relative to the reference point itself.
        out.append(x + start)
        if progress_bar is not None:
            # Progress is measured over intervals, not over units.
            progress_bar.value = (i + 1) / n_intervals
    return out
def get_unobserved_intervals(units, time_window, units_select=()):
    """Return, for each selected unit, the parts of ``time_window`` not covered by its obs_intervals.

    The unit's observation intervals are presumably sorted and non-overlapping.
    The gap computation between consecutive rows relies on this; verify against
    the producer of the data.

    Args:
        units: pynwb.misc.Units
        time_window: (float, float)
            ``(window_start, window_stop)``.
        units_select: iterable of int
            Unit rows to process. The default ``()`` processes none.
    Returns:
        list
            ``[]`` if the table has no ``obs_intervals`` column. Otherwise there is
            one entry per unit in ``units_select``. The entry is an (n, 2) object
            array of ``[gap_start, gap_stop]`` rows, or ``[time_window]`` when none
            of the unit's observation intervals overlaps the window.
    """
    if "obs_intervals" not in units:
        return []
    unobserved_intervals_list = []
    for i_unit in units_select:
        # TODO: use bisect here
        # reshape(-1, 2) keeps the two-column layout when a unit has no intervals.
        # Otherwise the array is 1-D and the [:, 1] indexing below fails.
        intervals = np.array(units["obs_intervals"][i_unit], dtype="object").reshape(
            -1, 2
        )
        # keep only observation intervals that overlap the window
        these_obs_intervals = intervals[
            (intervals[:, 1] > time_window[0]) & (intervals[:, 0] < time_window[1])
        ]
        if len(these_obs_intervals):
            # gaps between consecutive overlapping observation intervals
            unobs_intervals = np.c_[
                these_obs_intervals[:-1, 1], these_obs_intervals[1:, 0]
            ]
            # handle unobserved interval on lower bound of window
            if these_obs_intervals[0, 0] > time_window[0]:
                unobs_intervals = np.vstack(
                    ([time_window[0], these_obs_intervals[0, 0]], unobs_intervals)
                )
            # handle unobserved interval on upper bound of window
            if these_obs_intervals[-1, 1] < time_window[1]:
                unobs_intervals = np.vstack(
                    (unobs_intervals, [these_obs_intervals[-1, 1], time_window[1]])
                )
        else:
            unobs_intervals = [time_window]
        unobserved_intervals_list.append(unobs_intervals)
    return unobserved_intervals_list