From 98ab4b39abdd91b8421458984fef892c2d26666c Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Wed, 24 Jun 2026 10:49:46 +0200 Subject: [PATCH] Allow subclasses of FigureClass to be passed to plot_raw/plot_epochs * Allow subclasses of `MNEBrowseFigure` to be passed to plot_raw/plot_epochs, as well as the corresponding `plot(...)` methods of the raw and epochs classes. * Add an example showing onionskinning of MEG traces --- doc/changes/dev/13979.newfeature.rst | 1 + doc/changes/names.inc | 1 + examples/visualization/onionskin.py | 149 +++++++++++++++++++++++++++ mne/epochs.py | 2 + mne/io/base.py | 2 + mne/utils/docs.py | 8 ++ mne/viz/_mpl_figure.py | 6 +- mne/viz/epochs.py | 5 + mne/viz/raw.py | 5 + 9 files changed, 178 insertions(+), 1 deletion(-) create mode 100644 doc/changes/dev/13979.newfeature.rst create mode 100644 examples/visualization/onionskin.py diff --git a/doc/changes/dev/13979.newfeature.rst b/doc/changes/dev/13979.newfeature.rst new file mode 100644 index 00000000000..15b01a9cbb0 --- /dev/null +++ b/doc/changes/dev/13979.newfeature.rst @@ -0,0 +1 @@ +Allow subclasses of `MNEBrowseFigure` to be passed to plot_raw/plot_epochs, as well as the corresponding `plot(...)` methods of the raw and epochs classes, by :newcontrib:`Frankie Robertson` diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 6a08df3c88d..cc68731f900 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -110,6 +110,7 @@ .. _Felix Raimundo: https://github.com/gamazeps .. _Florian Hofer: https://github.com/hofaflo .. _Florin Pop: https://github.com/florin-pop +.. _Frankie Robertson: https://github.com/frankier .. _Frederik Weber: https://github.com/Frederik-D-Weber .. _Fu-Te Wong: https://github.com/zuxfoucault .. _Gennadiy Belonosov: https://github.com/Genuster diff --git a/examples/visualization/onionskin.py b/examples/visualization/onionskin.py new file mode 100644 index 00000000000..05e8988a357 --- /dev/null +++ b/examples/visualization/onionskin.py @@ -0,0 +1,149 @@ +""" +.. _ex-arrowmap: + +============================================================== +Advanced plotting customization by subclassing MNEBrowseFigure +============================================================== + +This example shows how plot_epochs(...) and plot_raw(...) can be customized by +subclassing MNEBrowseFigure and using the `figure_class` argument. +It plots one EEG trace overlaid ("onion-skinned") on top of another. + +This example is "bad code" in a few ways: + * Since the interface for MNEBrowseFigure is not public, it is liable to + break between minor and even patch versions of MNE without warning + * Some functionality is reimplemented from MNEBrowseFigure in a more or + less copy-paste style + * The code is backend-specific, in particular it is limited to the + matplotlib backend, and will not work with the qt browser + * Since there is no way to pass another EEG directly to the MNEBrowseFigure, + it is passed through a global variable + +Nevertheless, the example shows that the "escape hatch" of using a subclass is +available when other customization possibilities offered by MNE are not +sufficient. +""" + +from mne.datasets import eegbci +from mne.io import read_raw_edf +from mne.viz import set_browser_backend +from mne.viz._mpl_figure import MNEBrowseFigure as MNEBrowseFigureOrig + +set_browser_backend("matplotlib") + + +onionskin_eeg = None + + +def _set_onionskin_eeg(eeg): + global onionskin_eeg + onionskin_eeg = eeg + + +class OnionskinMNEBrowseFigure(MNEBrowseFigureOrig): + """ + Subclass of MNEBrowseFigure adding in onion-skin functionality, + i.e. plotting one EEG trace overlaid on top of another. + """ + + def __init__(self, *args, **kwargs): + import numpy as np + + super().__init__(*args, **kwargs) + onionskin_kwargs = { + **self.mne.trace_kwargs, + } + self.mne.onionskins = self.mne.ax_main.plot( + np.full((1, self.mne.n_channels), np.nan), **onionskin_kwargs + ) + + def _update_data(self): + import numpy as np + + from mne.io.base import BaseRaw + + super()._update_data() + if not onionskin_eeg: + self.mne.onionskin_data = None + return + start, stop = self._get_start_stop() + if isinstance(onionskin_eeg, BaseRaw): + if stop is None: + data = onionskin_eeg[:, start:] + else: + data = onionskin_eeg[:, start:stop] + data = data[0] + else: + ix_start = np.searchsorted( + self.mne.boundary_times, self.mne.t_start - self.mne.sampling_period + ) + ix_stop = ix_start + self.mne.n_epochs + item = slice(ix_start, ix_stop) + print(type(onionskin_eeg)) + data = np.concatenate( + onionskin_eeg.get_data(item=item, copy=False), axis=-1 + ) + data = self._process_data(data, start, stop, picks=self.mne.picks) + self.mne.onionskin_data = data + + def _draw_traces(self): + import numpy as np + from matplotlib.colors import to_rgba_array + from matplotlib.patches import Rectangle + + super()._draw_traces() + if self.mne.onionskin_data is None: + return + picks = self.mne.picks + offset_ixs = ( + picks + if self.mne.butterfly and self.mne.ch_selections is None + else slice(None) + ) + offsets = self.mne.trace_offsets[offset_ixs] + + ch_colors = to_rgba_array(self.mne.ch_colors) + ch_colors[:, 3] *= 0.5 + + decim = np.ones_like(picks) + data_picks_mask = np.isin(picks, self.mne.picks_data) + decim[data_picks_mask] = self.mne.decim + # decim can vary by channel type, so compute different `times` vectors + decim_times = { + decim_value: self.mne.times[::decim_value] + self.mne.first_time + for decim_value in set(decim) + } + + time_range = (self.mne.times + self.mne.first_time)[[0, -1]] + ylim = self.mne.ax_main.get_ylim() + for ii, line in enumerate(self.mne.onionskins): + this_offset = offsets[ii] + this_times = decim_times[decim[ii]] + this_data = ( + this_offset - self.mne.onionskin_data[ii] * self.mne.scale_factor + ) + this_data = this_data[..., :: decim[ii]] + clip = 0.2 if self.mne.butterfly else 0.5 + bottom = max(this_offset - clip, ylim[1]) + height = min(2 * clip, ylim[0] - bottom) + rect = Rectangle( + xy=np.array([time_range[0], bottom]), + width=time_range[1] - time_range[0], + height=height, + transform=self.mne.ax_main.transData, + ) + line.set_clip_path(rect) + line.set_xdata(this_times) + line.set_ydata(this_data) + color = ch_colors[ii] + line.set_color(color) + line.set_zorder(self.mne.zorder["data"] - 1) + + +subjects = [1] +runs = [1, 2] +raw_fnames = eegbci.load_data(subjects, runs) +_set_onionskin_eeg(read_raw_edf(raw_fnames[0], preload=True)) +read_raw_edf(raw_fnames[1], preload=True).plot( + block=True, figure_class=OnionskinMNEBrowseFigure +) diff --git a/mne/epochs.py b/mne/epochs.py index 442359f3d65..dbda7b0d281 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1327,6 +1327,7 @@ def plot( overview_mode=None, splash=True, annotation_colors=None, + figure_class=None, ): return plot_epochs( self, @@ -1354,6 +1355,7 @@ def plot( overview_mode=overview_mode, splash=splash, annotation_colors=annotation_colors, + figure_class=figure_class, ) @copy_function_doc_to_method_doc(plot_topo_image_epochs) diff --git a/mne/io/base.py b/mne/io/base.py index 597723d72b8..f142763dcf9 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -2001,6 +2001,7 @@ def plot( overview_mode=None, splash=True, verbose=None, + figure_class=None, ): return plot_raw( self, @@ -2042,6 +2043,7 @@ def plot( overview_mode=overview_mode, splash=splash, verbose=verbose, + figure_class=figure_class, ) @property diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 564bb4abb0d..eeb81830b3d 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1714,6 +1714,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): and if absent, falls back to ``'estimated'``. """ +docdict["figure_class"] = """ +figure_class : class + The backend specific `MNEBrowseFigure` class to use. This is typically used + to pass a subclass in order to customize the plot. This parameter requires + cooperation from the backend, and is currently only supported by the + ``matplotlib`` backend. +""" + docdict["fig_background"] = """ fig_background : None | array A background image for the figure. This must be a valid input to diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 283398a1469..b0504ed3f06 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -2618,7 +2618,11 @@ def _init_browser(**kwargs): """Instantiate a new MNE browse-style figure.""" from mne.io import BaseRaw - fig = _figure(toolbar=False, FigureClass=MNEBrowseFigure, layout=None, **kwargs) + figure_class = kwargs.pop("figure_class", None) + if figure_class is None: + figure_class = MNEBrowseFigure + + fig = _figure(toolbar=False, FigureClass=figure_class, layout=None, **kwargs) # splash is ignored (maybe we could do it for mpl if we get_backend() and # check if it's Qt... but seems overkill) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 0154f357588..97b34aa0703 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -765,6 +765,7 @@ def plot_epochs( overview_mode=None, splash=True, annotation_colors=None, + figure_class=None, ): """Visualize epochs. @@ -875,6 +876,9 @@ def plot_epochs( will trigger a warning. If ``None`` (default), automatic colors are used. .. versionadded:: 1.12.1 + %(figure_class)s + + .. versionadded:: 1.13 Returns ------- @@ -1089,6 +1093,7 @@ def plot_epochs( theme=theme, overview_mode=overview_mode, splash=splash, + figure_class=figure_class, ) fig = _get_browser(show=show, block=block, **params) diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 5e6febc550c..ca2a0d14de2 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -75,6 +75,7 @@ def plot_raw( overview_mode=None, splash=True, verbose=None, + figure_class=None, ): """Plot raw data. @@ -227,6 +228,9 @@ def plot_raw( .. versionadded:: 1.6 %(verbose)s + %(figure_class)s + + .. versionadded:: 1.13 Returns ------- @@ -435,6 +439,7 @@ def plot_raw( theme=theme, overview_mode=overview_mode, splash=splash, + figure_class=figure_class, ) fig = _get_browser(show=show, block=block, **params)