import numpy as np
from scipy.signal import stft
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.colors import DEFAULT_PLOTLY_COLORS
from ipywidgets import widgets, ValueWidget
from pynwb import TimeSeries
from pynwb.base import DynamicTable
from pynwb.ecephys import SpikeEventSeries, ElectricalSeries
from .base import fig2widget, lazy_tabs, render_dataframe
from .timeseries import BaseGroupedTraceWidget
from .brains import HumanElectrodesPlotlyWidget
def show_spectrogram(nwbobj: TimeSeries, channel=0, **kwargs):
    """Plot a log-magnitude STFT spectrogram of one trace of a TimeSeries.

    Parameters
    ----------
    nwbobj : TimeSeries
        Series whose ``data`` is sliced and whose ``rate`` is passed to
        ``scipy.signal.stft`` as the sampling frequency.
    channel : int, optional
        Column of ``data`` to analyse when ``data`` has more than one
        dimension. Ignored for 1-D data, which holds a single trace.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``; nothing is returned.
    """
    data = nwbobj.data
    # 1-D data has no channel axis, so ``data[:, channel]`` would raise;
    # use the whole trace in that case.
    if len(data.shape) == 1:
        signal = data[:]
    else:
        signal = data[:, channel]

    fig, ax = plt.subplots()
    # NOTE(review): nperseg evaluates to 34 samples; ``2 ** 17`` may have been
    # intended -- confirm before changing, it alters the time/frequency grid.
    freqs, times, zxx = stft(signal, nwbobj.rate, nperseg=2 * 17)
    t_max = max(times)
    f_max = max(freqs)
    ax.imshow(
        np.log(np.abs(zxx)),
        aspect="auto",
        extent=[0, t_max, 0, f_max],
        origin="lower",
    )
    ax.set_ylim(0, f_max)
    ax.set_xlabel("time (s)")
    ax.set_ylabel("frequency (Hz)")
    fig.show()
def show_electrodes(electrodes_table):
    """Assemble the tabs shown for an electrodes table.

    A ``table`` tab (``render_dataframe``) is always present. If the first
    entry of the ``x`` column is NaN (position is not defined), an
    ``electrode_groups`` tab is added. Otherwise the species of the subject
    on the ancestor ``NWBFile`` selects an extra tab: ``CCF`` for mouse,
    ``render`` for human, and none for a missing subject or other species.

    Parameters
    ----------
    electrodes_table
        Table exposing an ``x`` column and ``get_ancestor``; it is also
        forwarded unchanged to ``lazy_tabs``.

    Returns
    -------
    The result of ``lazy_tabs(tabs, electrodes_table)``.
    """
    tabs = {"table": render_dataframe}

    has_position = not np.isnan(electrodes_table.x[0])
    if has_position:
        subject = electrodes_table.get_ancestor("NWBFile").subject
        # A missing subject yields no species-specific tab.
        species = None if subject is None else subject.species
        if species in ("mouse", "Mus musculus"):
            tabs["CCF"] = show_ccf
        elif species in ("human", "Homo sapiens"):
            tabs["render"] = HumanElectrodesPlotlyWidget
    else:
        tabs["electrode_groups"] = ElectrodeGroupsWidget

    return lazy_tabs(tabs, electrodes_table)
def show_ccf(electrodes_table=None, **kwargs):
    """Create a ``CCFWidget``, optionally with electrode position markers.

    Parameters
    ----------
    electrodes_table : optional
        When given, it is converted with ``to_dataframe()``, grouped by the
        ``group_name`` column, and each group's ``x``/``y``/``z`` columns are
        passed to the widget as one array in the ``markers`` list.
    **kwargs
        Forwarded to ``CCFWidget``. They are applied after ``markers`` is
        set, so an explicit ``markers`` keyword replaces the derived one.

    Returns
    -------
    CCFWidget
    """
    # Imported inside the function, so ccfwidget is only loaded when this
    # viewer is actually called.
    from ccfwidget import CCFWidget

    widget_kwargs = {}
    if electrodes_table is not None:
        by_group = electrodes_table.to_dataframe().groupby("group_name")
        widget_kwargs["markers"] = [
            group_df[["x", "y", "z"]].to_numpy() for _, group_df in by_group
        ]
    widget_kwargs.update(kwargs)
    return CCFWidget(**widget_kwargs)
def show_spike_event_series(ses: SpikeEventSeries, **kwargs):
    """Build a widget for browsing the waveforms of individual spike events.

    The widget pairs an info/controls column (spike count, channel count and
    a bounded spike-index field) with a plot of the selected spike. Each
    channel is drawn in grey and the across-channel mean in black.

    Parameters
    ----------
    ses : SpikeEventSeries
        Series whose ``data`` is indexed by spike along the first axis.
        3-D data is read as (spike, sample, channel). 2-D data is read as
        (spike, sample), i.e. a single channel.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    ipywidgets.HBox
        Box holding the controls column and the interactive plot output.
    """
    shape = ses.data.shape
    n_spikes = shape[0]
    # 2-D data has no channel axis: shape[1] is the per-spike sample count, not
    # a channel count, so report (and plot) exactly one channel.
    # NOTE(review): the 3-D branch treats the last axis as channels; the NWB
    # schema may order the axes (event, channel, sample) -- confirm.
    n_channels = shape[2] if len(shape) == 3 else 1

    def control_plot(spk_ind):
        """Plot the waveform(s) of spike ``spk_ind`` and return it as a widget."""
        fig, ax = plt.subplots(figsize=(9, 5))
        waveforms = np.asarray(ses.data[spk_ind])
        if waveforms.ndim == 1:
            # Promote a single trace to one column so per-channel plotting and
            # the mean over axis 1 work for both layouts.
            waveforms = waveforms[:, np.newaxis]
        # A 2-D array is plotted column-wise: one grey line per channel.
        ax.plot(waveforms, color="#d9d9d9")
        ax.plot(waveforms.mean(axis=1), color="k")
        ax.set_xlabel("Time")
        ax.set_ylabel("Amplitude")
        fig.show()
        return fig2widget(fig)

    # Controls
    field_lay = widgets.Layout(
        max_height="40px", max_width="100px", min_height="30px", min_width="70px"
    )
    spk_ind = widgets.BoundedIntText(
        value=0, min=0, max=n_spikes - 1, layout=field_lay
    )
    out_fig = widgets.interactive_output(control_plot, {"spk_ind": spk_ind})

    # Assemble layout box
    spikes_row = widgets.HBox(
        children=[
            widgets.Label("N° spikes:", layout=field_lay),
            widgets.Label(str(n_spikes), layout=field_lay),
        ]
    )
    channels_row = widgets.HBox(
        children=[
            widgets.Label("N° channels:", layout=field_lay),
            widgets.Label(str(n_channels), layout=field_lay),
        ]
    )
    selector_row = widgets.HBox(
        children=[widgets.Label("Spike ID:", layout=field_lay), spk_ind]
    )
    info_column = widgets.VBox(children=[spikes_row, channels_row, selector_row])
    return widgets.HBox(children=[info_column, out_fig])