Source code for nwbwidgets.brains
import numpy as np
import ipywidgets as widgets
import plotly.graph_objects as go
from pynwb.base import DynamicTable
import trimesh
from .base import df_to_hover_text
def make_cylinder_mesh(
    radius, height, sections=32, position=(0, 0, 0), direction=(1, 0, 0), **kwargs
):
    """Build a plotly ``Mesh3d`` trace for a single cylinder.

    The geometry is generated by ``trimesh.primitives.Cylinder``. Its placement
    is a 4x4 homogeneous transform: a rotation about the z axis ("yaw")
    multiplied by a rotation about the y axis ("pitch"), both derived from the
    first two components of the normalized ``direction``, with ``position``
    written into the translation column.

    Parameters
    ----------
    radius : float
        Cylinder radius, forwarded to trimesh.
    height : float
        Cylinder height, forwarded to trimesh.
    sections : int, optional
        Forwarded to trimesh as ``sections``.
    position : sequence of 3 floats, optional
        Translation part of the transform.
    direction : sequence of 3 floats, optional
        Orientation vector; it is normalized before use.
    **kwargs
        Forwarded unchanged to ``plotly.graph_objects.Mesh3d``.

    Returns
    -------
    plotly.graph_objects.Mesh3d

    Raises
    ------
    ValueError
        If ``direction`` has zero length, since it cannot be normalized.
    """
    direction = np.asarray(direction, dtype=float)
    norm = np.linalg.norm(direction)
    if norm == 0:
        raise ValueError("direction must be a non-zero vector")
    new_normal = direction / norm

    cosx, cosy = new_normal[:2]
    # Clamp at zero: rounding can leave a normalized component marginally
    # above 1, which would make the radicand negative and yield NaN.
    # Note that both sines are always taken as the non-negative root.
    sinx = np.sqrt(max(0.0, 1 - cosx ** 2))
    siny = np.sqrt(max(0.0, 1 - cosy ** 2))

    yaw = [[cosx, -sinx, 0, 0], [sinx, cosx, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
    pitch = [[cosy, 0, siny, 0], [0, 1, 0, 0], [-siny, 0, cosy, 0], [0, 0, 0, 1]]

    transform = np.dot(yaw, pitch)
    transform[:3, 3] = position

    cylinder = trimesh.primitives.Cylinder(
        radius=radius, height=height, sections=sections, transform=transform
    )

    x, y, z = cylinder.vertices.T
    i, j, k = cylinder.faces.T
    return go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs)
def make_cylinders(
    positions, directions, radius=1, height=1, sections=32, name="cylinders", **kwargs
):
    """Create one cylinder mesh trace per (position, direction) pair.

    ``positions`` and ``directions`` are walked in lockstep (stopping at the
    shorter one) and each pair is handed to ``make_cylinder_mesh`` together
    with the shared ``radius``, ``height`` and ``sections``. All traces share
    ``name`` as both their name and legend group; only the first trace is
    flagged to appear in the legend.

    Any extra keyword arguments are forwarded to ``make_cylinder_mesh``.

    Returns a list of the traces, in input order.
    """
    meshes = []
    for index, (center, axis) in enumerate(zip(positions, directions)):
        # A single legend entry stands in for the whole group.
        is_first = index == 0
        mesh = make_cylinder_mesh(
            radius=radius,
            height=height,
            sections=sections,
            position=center,
            direction=axis,
            name=name,
            legendgroup=name,
            showlegend=is_first,
            **kwargs
        )
        meshes.append(mesh)
    return meshes
class HumanElectrodesPlotlyWidget(widgets.VBox):
    """Interactive 3D view of electrodes drawn over a human brain surface.

    The widget stacks a plotly ``FigureWidget`` above two opacity sliders (one
    per hemisphere) and a dropdown selecting which column of the electrodes
    table is used to color the electrode markers. The two hemisphere meshes
    are always the first two traces of the figure; electrode traces follow.
    """

    def __init__(self, electrodes: DynamicTable, **kwargs):
        """Build the figure and its controls.

        Parameters
        ----------
        electrodes : DynamicTable
            Table whose ``x``, ``y`` and ``z`` columns give electrode
            positions. Its column names populate the "Color By:" dropdown,
            which starts on ``"group_name"``.
        **kwargs
            Accepted for call compatibility; currently unused.
        """
        super().__init__()
        self.electrodes = electrodes

        slider_kwargs = dict(
            value=1.0, min=0.0, max=1.0, style={"description_width": "initial"}
        )
        left_opacity_slider = widgets.FloatSlider(
            description="left hemi opacity", **slider_kwargs
        )
        right_opacity_slider = widgets.FloatSlider(
            description="right hemi opacity", **slider_kwargs
        )

        color_by_dropdown = widgets.Dropdown(
            options=list(electrodes.colnames),
            value="group_name",
            description="Color By:",
            disabled=False,
        )

        # Listen to ``value`` only. Without ``names`` the handlers run for
        # every trait; a dropdown selection also changes the string-valued
        # ``label`` trait, which made the figure redraw twice per selection.
        color_by_dropdown.observe(self.color_electrode_by, names="value")
        left_opacity_slider.observe(self.observe_left_opacity, names="value")
        right_opacity_slider.observe(self.observe_right_opacity, names="value")

        self.fig = go.FigureWidget()
        self.plot_human_brain()
        self.show_electrodes(electrodes, color_by_dropdown.value)

        sliders = widgets.HBox([left_opacity_slider, right_opacity_slider])
        self.children = [self.fig, widgets.VBox([sliders, color_by_dropdown])]

    @staticmethod
    def find_normals(points, k=3):
        """Estimate a normal vector at each point from its nearest neighbours.

        For every point, the ``k`` closest points (the point itself included,
        at distance zero) are fitted with a plane via
        ``skspatial.objects.Plane.best_fit`` and that plane's normal is kept.

        Parameters
        ----------
        points : array-like, shape (n, d)
            Point coordinates, one row per point.
        k : int, optional
            Number of nearest points used for each plane fit.

        Returns
        -------
        list of numpy.ndarray
            One normal per input point, in input order.
        """
        points = np.asarray(points)
        if len(points) == 0:
            return []

        # Imported lazily (and once, rather than per point) so the package is
        # only needed when normals are actually computed.
        from skspatial.objects import Plane

        # argpartition requires kth < n; clamping lets k == n work while
        # leaving the selected neighbours unchanged whenever k < n.
        kth = min(k, len(points) - 1)

        normals = []
        for point in points:
            distance = np.linalg.norm(points - point, axis=1)
            closest_inds = np.argpartition(distance, kth)
            close_points = points[closest_inds[:k]]
            normals.append(np.asarray(Plane.best_fit(close_points).normal))
        return normals

    def show_electrodes(self, electrodes: DynamicTable, color_by):
        """Add electrode marker traces to the figure, colored by a column.

        String, bytes and boolean columns are treated as categories: one
        trace per unique value, shown in the legend. Float and array-valued
        columns are treated as continuous: a single trace with a color bar.
        For any other column type a message is printed and nothing is added.

        Parameters
        ----------
        electrodes : DynamicTable
            Table providing ``x``, ``y``, ``z`` and the ``color_by`` column.
        color_by : str
            Name of the column that drives marker color.
        """
        positions = np.c_[electrodes.x[:], electrodes.y[:], electrodes.z[:]]

        column = electrodes[color_by]
        sample = column[0]
        if isinstance(sample, (bytes, str, np.bool_)):
            ugroups, group_inv = np.unique(column[:], return_inverse=True)
            colors = group_inv
            show_leg = True
            show_scale = False
        # ``np.float`` was only an alias of the builtin ``float`` and has been
        # removed from NumPy; ``np.floating`` also covers float32 and friends.
        elif isinstance(sample, (np.ndarray, float, np.floating)):
            colors = np.ravel(column[:])
            ugroups, group_inv = [0], np.zeros(len(colors), dtype=int)
            show_leg = False
            show_scale = True
        else:
            print("Not a valid data type")
            return

        c_max = np.max(colors)
        c_min = np.min(colors)

        # Convert the table once; each trace below gets the hover text of its
        # own rows so the labels line up with the points actually plotted.
        frame = electrodes.to_dataframe()

        # NOTE: an earlier, disabled variant drew groups whose name contains
        # 'GRID' as cylinders (make_cylinders) oriented by find_normals; every
        # group is currently drawn as scatter markers.
        for i, group in enumerate(ugroups):
            in_group = group_inv == i
            x, y, z = positions[in_group].T
            if isinstance(group, bytes):
                group = group.decode()
            label = str(group)

            self.fig.add_trace(
                go.Scatter3d(
                    mode="markers",
                    x=x,
                    y=y,
                    z=z,
                    name=label,
                    legendgroup=label,
                    marker=dict(
                        color=colors[in_group],
                        cmax=c_max,
                        cmin=c_min,
                        colorscale="Viridis",
                        colorbar=dict(title="Colorbar"),
                        showscale=show_scale,
                    ),
                    text=df_to_hover_text(frame.iloc[np.flatnonzero(in_group)]),
                    hoverinfo="text",
                    showlegend=show_leg,
                )
            )

        self.fig.update_layout(
            legend=dict(
                x=0,
                y=1,
            ),
        )

    def plot_human_brain(self, left_opacity=1.0, right_opacity=1.0):
        """Add the left and right pial surface meshes to the figure.

        The surfaces come from nilearn's ``fsaverage5`` dataset. The left
        hemisphere is added first, then the right, and the scene axes are
        hidden.

        Parameters
        ----------
        left_opacity : float, optional
            Opacity of the left hemisphere mesh.
        right_opacity : float, optional
            Opacity of the right hemisphere mesh.
        """
        # Imported lazily so nilearn is only required when a brain is drawn.
        from nilearn import datasets, surface

        mesh = datasets.fetch_surf_fsaverage("fsaverage5")

        def create_mesh(name, **kwargs):
            vertices, triangles = surface.load_surf_mesh(mesh[name])
            x, y, z = vertices.T
            i, j, k = triangles.T
            return go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs)

        kwargs = dict(
            color="lightgray",
            lighting=dict(specular=1, ambient=0.9, roughness=0.9, diffuse=0.9),
            hoverinfo="skip",
        )

        self.fig.add_trace(create_mesh("pial_left", opacity=left_opacity, **kwargs))
        self.fig.add_trace(create_mesh("pial_right", opacity=right_opacity, **kwargs))

        self.fig.update_layout(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            ),
            height=500,
            margin=dict(t=20, b=0),
        )

    def observe_left_opacity(self, change):
        """Apply a slider change to the left hemisphere mesh (trace 0)."""
        if "new" in change and isinstance(change["new"], float):
            self.fig.data[0].opacity = change["new"]

    def observe_right_opacity(self, change):
        """Apply a slider change to the right hemisphere mesh (trace 1)."""
        if "new" in change and isinstance(change["new"], float):
            self.fig.data[1].opacity = change["new"]

    def color_electrode_by(self, change):
        """Redraw the figure with electrodes colored by the selected column.

        All traces are cleared and both the brain and the electrodes are
        re-added. The hemisphere opacities in effect before the redraw are
        carried over so the figure stays consistent with the sliders.
        """
        if "new" in change and isinstance(change["new"], str):
            left_opacity = right_opacity = 1.0
            traces = self.fig.data
            if len(traces) >= 2:
                if traces[0].opacity is not None:
                    left_opacity = traces[0].opacity
                if traces[1].opacity is not None:
                    right_opacity = traces[1].opacity

            with self.fig.batch_update():
                self.fig.data = None
                self.plot_human_brain(
                    left_opacity=left_opacity, right_opacity=right_opacity
                )
                self.show_electrodes(self.electrodes, change["new"])