diff --git a/pyproject.toml b/pyproject.toml index 02ff6bf..2e924ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "trame-components", "trame-tauri>=0.6.2", "Pillow", + "trame-colormaps>=1.0.0", ] [project.optional-dependencies] diff --git a/src/e3sm_quickview/app.py b/src/e3sm_quickview/app.py index 75be15f..21c5c1d 100644 --- a/src/e3sm_quickview/app.py +++ b/src/e3sm_quickview/app.py @@ -18,6 +18,7 @@ from e3sm_quickview.components import css, dialogs, doc, drawers, file_browser, toolbars from e3sm_quickview.pipeline import EAMVisSource from e3sm_quickview.utils import cli, compute, perf +from e3sm_quickview.utils.colors import get_type_color from e3sm_quickview.view_manager import ViewManager v3.enable_lab() @@ -52,7 +53,7 @@ def __init__(self, server=None): "variables_selected": [], # Control 'Load Variables' button availability "variables_loaded": False, - # Dynamic type-color mapping (populated when data loads) + # Dimension type → Vuetify color mapping via utils/colors.py "variable_types": [], # Dimension arrays (will be populated dynamically) "midpoints": [], @@ -332,29 +333,31 @@ def download_state(self): views_to_export = state_content["views"] = [] for view_type, var_names in active_variables.items(): for var_name in var_names: - config = self.view_manager.get_view(var_name, view_type).config + view = self.view_manager.get_view(var_name, view_type) + config = view.config + cmap = view.colormap views_to_export.append( { "type": view_type, "name": var_name, "config": { - # lut - "preset": config.preset, - "invert": config.invert, - "color_blind": config.color_blind, - "use_log_scale": config.use_log_scale, - "discrete_log": config.discrete_log, - "n_discrete_colors": config.n_discrete_colors, - # layout + # view layout "order": config.order, "size": config.size, "offset": config.offset, "break_row": config.break_row, - # color range - "override_range": config.override_range, - "color_range": config.color_range, - "color_value_min": config.color_value_min, - "color_value_max": config.color_value_max, + }, + "colormap": { + "preset": cmap.preset, + "invert": cmap.invert, + "color_blind": cmap.color_blind, + "use_log_scale": cmap.use_log_scale, + "discrete_log": cmap.discrete_log, + "n_discrete_colors": cmap.n_discrete_colors, + "override_range": cmap.override_range, + "color_range": cmap.color_range, + "color_value_min": cmap.color_value_min, + "color_value_max": cmap.color_value_max, }, } ) @@ -405,14 +408,29 @@ async def _import_state(self, state_content): self.state.animation_track = data_sel["animation_track"] # Update view states + _COLORMAP_KEYS = { + "preset", "invert", "color_blind", "use_log_scale", + "discrete_log", "n_discrete_colors", "override_range", + "color_range", "color_value_min", "color_value_max", + } for view_state in state_content["views"]: view_type = view_state["type"] var_name = view_state["name"] - cfg = view_state["config"] - if "color_range" in cfg and isinstance(cfg["color_range"], list): - cfg["color_range"] = tuple(cfg["color_range"]) - config = self.view_manager.get_view(var_name, view_type).config - config.update(**cfg) + view = self.view_manager.get_view(var_name, view_type) + + cfg = dict(view_state["config"]) + # Backward compat: old state files store colormap fields in "config" + cmap_cfg = dict(view_state.get("colormap", {})) + if not cmap_cfg: + cmap_cfg = {k: cfg.pop(k) for k in list(cfg) if k in _COLORMAP_KEYS} + + # Layout config + view.config.update(**cfg) + # Colormap config + if "color_range" in cmap_cfg and isinstance(cmap_cfg["color_range"], list): + cmap_cfg["color_range"] = tuple(cmap_cfg["color_range"]) + if cmap_cfg: + view.colormap.update(**cmap_cfg) # Update layout self.state.aspect_ratio = state_content["layout"]["aspect-ratio"] @@ -470,9 +488,7 @@ async def data_loading_open(self, simulation, connectivity): ), ] - # Build dynamic type-color mapping - from e3sm_quickview.utils.colors import get_type_color - + # Dimension type → Vuetify color mapping via utils/colors.py dim_types = sorted( set( ", ".join(var.dimensions) @@ -618,7 +634,7 @@ def _on_slicing_change(self, var, ind_var, **_): self.source.UpdatePipeline() with perf.timed("tick.color_range"): - self.view_manager.update_color_range() + self.view_manager.update_color_range() # colormaps module with perf.timed("tick.render"): self.view_manager.render() @@ -656,7 +672,7 @@ def _on_downstream_change( self.source.UpdatePipeline() with perf.timed("downstream_change.color_range"): - self.view_manager.update_color_range() + self.view_manager.update_color_range() # colormaps module with perf.timed("downstream_change.render"): self.view_manager.render() diff --git a/src/e3sm_quickview/components/view.py b/src/e3sm_quickview/components/view.py index 689d457..804407f 100644 --- a/src/e3sm_quickview/components/view.py +++ b/src/e3sm_quickview/components/view.py @@ -107,212 +107,3 @@ def create_size_menu(name, config): ) -def create_bottom_bar(config, update_color_preset): - with config.provide_as("config"): - with html.Div( - classes="bg-blue-grey-darken-2 d-flex align-center", - style="height:1rem;position:relative;top:0;user-select:none;cursor:context-menu;", - ): - with v3.VMenu( - v_model="config.menu", - activator="parent", - location=( - "active_layout !== 'auto_layout' || config.size == 12 ? 'top' : 'end'", - ), - close_on_content_click=False, - ): - with v3.VCard(style="max-width: 360px;min-width: 360px;"): - with v3.VCardItem(classes="py-0 px-2"): - with html.Div(classes="d-flex align-center"): - v3.VIconBtn( - v_tooltip_bottom=( - "config.color_blind ? 'Toggle for all color presets' : 'Toggle for colorblind safe color presets'", - ), - icon=( - "config.color_blind ? 'mdi-shield-check-outline' : 'mdi-palette'", - ), - click="config.color_blind = !config.color_blind", - size="small", - text=( - "config.color_blind ? 'Colorblind Safe' : 'All Colors'", - ), - variant="text", - ) - v3.VIconBtn( - v_tooltip_bottom=( - "config.invert ? 'Toggle to normal preset' : 'Toggle to inverted preset'", - ), - icon=( - "config.invert ? 'mdi-invert-colors' : 'mdi-invert-colors-off'", - ), - click="config.invert = !config.invert", - size="small", - text=( - "config.invert ? 'Inverted Preset' : 'Normal Preset'", - ), - variant="text", - ) - v3.VIconBtn( - v_tooltip_bottom=( - "config.use_log_scale === 'linear' ? 'Toggle to log scale' : config.use_log_scale === 'log' ? 'Toggle to symlog scale' : 'Toggle to linear scale'", - ), - icon=( - "config.use_log_scale === 'log' ? 'mdi-math-log' : config.use_log_scale === 'symlog' ? 'mdi-sine-wave mdi-rotate-330' : 'mdi-stairs'", - ), - click="config.use_log_scale = config.use_log_scale === 'linear' ? 'log' : config.use_log_scale === 'log' ? 'symlog' : 'linear'", - size="small", - text=( - "config.use_log_scale === 'log' ? 'Log' : config.use_log_scale === 'symlog' ? 'SymLog' : 'Linear'", - ), - variant="text", - ) - v3.VIconBtn( - v_tooltip_bottom=( - "config.override_range ? 'Toggle to use data range' : 'Toggle to use custom range'", - ), - icon=( - "config.override_range ? 'mdi-arrow-expand-horizontal' : 'mdi-pencil'", - ), - click="config.override_range = !config.override_range", - size="small", - text=( - "config.override_range ? 'Custom Range' : 'Data Range'", - ), - variant="text", - ) - v3.VIconBtn( - v_tooltip_bottom=( - "config.discrete_log ? 'Switch to continuous colormap' : 'Switch to discrete colormap'", - ), - icon=( - "config.discrete_log ? 'mdi-view-sequential' : 'mdi-gradient-horizontal'", - ), - click="config.discrete_log = !config.discrete_log", - size="small", - text=( - "config.discrete_log ? 'Discrete' : 'Continuous'", - ), - variant="text", - ) - - v3.VTextField( - v_model="config.search", - clearable=True, - placeholder=("config.preset",), - click_clear="config.search = null", - single_line=True, - variant="solo", - density="compact", - flat=True, - hide_details="auto", - # style="min-width: 150px;", - classes="d-inline", - reverse=True, - ) - v3.VIconBtn( - icon="mdi-close", - size="small", - text="Close", - click="config.menu=false", - ) - - with v3.VCardItem( - v_show="config.discrete_log", - classes="py-0 mb-2", - ): - v3.VNumberInput( - v_model="config.n_discrete_colors", - hide_details=True, - density="compact", - variant="outlined", - flat=True, - label=( - "config.use_log_scale === 'linear' ? 'Colors per tick interval' : 'Colors per order of magnitude'", - ), - classes="mt-2", - step=[1], - min=[1], - max=[20], - ) - with v3.VCardItem( - v_show="config.override_range", classes="py-0 mb-2" - ): - v3.VTextField( - v_model="config.color_value_min", - hide_details=True, - density="compact", - variant="outlined", - flat=True, - label="Min", - classes="mt-2", - error=("!config.color_value_min_valid",), - ) - v3.VTextField( - v_model="config.color_value_max", - hide_details=True, - density="compact", - variant="outlined", - flat=True, - label="Max", - classes="mt-2", - error=("!config.color_value_max_valid",), - ) - v3.VDivider() - with v3.VList(density="compact", max_height="40vh"): - with v3.VListItem( - v_for="entry in (config.invert ? luts_inverted : luts_normal)", - v_show="(config.search?.length ? entry.name.toLowerCase().includes(config.search.toLowerCase()) : 1) && (!config.color_blind || entry.safe)", - key="entry.name", - subtitle=("entry.name",), - click=( - update_color_preset, - "[entry.name, config.invert, config.use_log_scale, config.discrete_log, config.n_discrete_colors, config.n_colors]", - ), - active=("config.preset === entry.name",), - ): - html.Img( - src=("entry.url",), - style="width:100%;min-width:20rem;height:1rem;", - classes="rounded", - ) - html.Div( - "{{ utils.quickview.formatRange(config.color_range?.[0], config.use_log_scale, config.color_range?.[0], config.color_range?.[1]) }}", - classes="text-caption px-2 text-no-wrap", - ) - with html.Div( - classes="rounded w-100", - style="height:70%;position:relative;", - ): - html.Img( - src=("config.lut_img",), - style="width:100%;height:2rem;", - draggable=False, - ) - with html.Div( - style="position:absolute;top:0;left:0;right:0;bottom:0;pointer-events:none;", - ): - with html.Div( - v_for="(tick, i) in config.color_ticks", - key="i", - style=( - "`position:absolute;left:${tick.position}%;top:0;height:100%;transform:translateX(-50%);display:flex;flex-direction:column;align-items:center;`", - ), - ): - html.Div( - style=( - "`width:1.5px;height:30%;background:${tick.color};`", - ), - ) - html.Span( - "{{ tick.label }}", - style=( - "`font-size:0.5rem;line-height:1;white-space:nowrap;color:${tick.color};`", - ), - ) - html.Div( - style=("`width:1.5px;flex:1;background:${tick.color};`",), - ) - html.Div( - "{{ utils.quickview.formatRange(config.color_range?.[1], config.use_log_scale, config.color_range?.[0], config.color_range?.[1]) }}", - classes="text-caption px-2 text-no-wrap", - ) diff --git a/src/e3sm_quickview/utils/color.py b/src/e3sm_quickview/utils/color.py deleted file mode 100644 index 43874d9..0000000 --- a/src/e3sm_quickview/utils/color.py +++ /dev/null @@ -1,125 +0,0 @@ -import base64 - -from paraview import servermanager, simple -from vtkmodules.vtkCommonCore import vtkUnsignedCharArray -from vtkmodules.vtkCommonDataModel import vtkImageData -from vtkmodules.vtkIOImage import vtkPNGWriter - - -def get_color_preset_names(): - presets = servermanager.vtkSMTransferFunctionPresets.GetInstance() - return [ - presets.GetPresetName(index) for index in range(presets.GetNumberOfPresets()) - ] - - -def generate_colormaps(): - color_maps = {} - samples = 255 - rgb = [0, 0, 0] - names = get_color_preset_names() - lut = simple.GetColorTransferFunction("to_generate_image") - vtk_lut = lut.GetClientSideObject() - colorArray = vtkUnsignedCharArray() - colorArray.SetNumberOfComponents(3) - colorArray.SetNumberOfTuples(samples) - imgData = vtkImageData() - imgData.SetDimensions(samples, 1, 1) - imgData.GetPointData().SetScalars(colorArray) - writer = vtkPNGWriter() - writer.WriteToMemoryOn() - writer.SetInputData(imgData) - writer.SetCompressionLevel(1) - - for name in names: - if name.endswith(")"): - skip_number = name[-2] - if skip_number in "0123456789": - continue - - imgs = [] - for inverted in range(2): - lut.ApplyPreset(name, True) - if inverted: - lut.InvertTransferFunction() - - v_min = lut.RGBPoints[0] - v_max = lut.RGBPoints[-4] - step = (v_max - v_min) / (samples - 1) - - for i in range(samples): - value = v_min + step * float(i) - vtk_lut.GetColor(value, rgb) - r = int(round(rgb[0] * 255)) - g = int(round(rgb[1] * 255)) - b = int(round(rgb[2] * 255)) - colorArray.SetTuple3(i, r, g, b) - - writer.Write() - img_bytes = writer.GetResult() - - base64_img = base64.standard_b64encode(img_bytes).decode("utf-8") - imgs.append(f"data:image/png;base64,{base64_img}") - - color_maps[name] = imgs - - return {k: {"normal": v[0], "inverted": v[1]} for k, v in color_maps.items()} - - -COLORBAR_CACHE = generate_colormaps() - - -def get_cached_colorbar_image(colormap_name, inverted=False): - """ - Get a cached colorbar image for a given colormap. - - Parameters: - ----------- - colormap_name : str - Name of the colormap (e.g., "Cool to Warm", "Rainbow Desaturated") - inverted : bool - Whether to get the inverted version - - Returns: - -------- - str - Base64-encoded PNG image as a data URI, or empty string if not found - """ - if colormap_name in COLORBAR_CACHE: - variant = "inverted" if inverted else "normal" - return COLORBAR_CACHE[colormap_name].get(variant, "") - - return "" - - -def lut_to_img(lut_proxy): - samples = 255 - rgb = [0, 0, 0] - vtk_lut = lut_proxy.GetClientSideObject() - colorArray = vtkUnsignedCharArray() - colorArray.SetNumberOfComponents(3) - colorArray.SetNumberOfTuples(samples) - imgData = vtkImageData() - imgData.SetDimensions(samples, 1, 1) - imgData.GetPointData().SetScalars(colorArray) - writer = vtkPNGWriter() - writer.WriteToMemoryOn() - writer.SetInputData(imgData) - writer.SetCompressionLevel(1) - - v_min = lut_proxy.RGBPoints[0] - v_max = lut_proxy.RGBPoints[-4] - step = (v_max - v_min) / (samples - 1) - - for i in range(samples): - value = v_min + step * float(i) - vtk_lut.GetColor(value, rgb) - r = int(round(rgb[0] * 255)) - g = int(round(rgb[1] * 255)) - b = int(round(rgb[2] * 255)) - colorArray.SetTuple3(i, r, g, b) - - writer.Write() - base64_img = base64.standard_b64encode(writer.GetResult()).decode("utf-8") - - return f"data:image/png;base64,{base64_img}" diff --git a/src/e3sm_quickview/utils/math.py b/src/e3sm_quickview/utils/math.py index 79364af..b4e80f9 100644 --- a/src/e3sm_quickview/utils/math.py +++ b/src/e3sm_quickview/utils/math.py @@ -9,38 +9,6 @@ from typing import List, Tuple, Optional -def calculate_linthresh(data): - """Calculate the linear threshold for symlog scaling. - - Excludes true zeros (values within ±tiny of the data dtype), - then returns min(abs(valid)). - - Operates on the original array without copies. - - Args: - data: numpy array of data values - - Returns: - linthresh value (float), floored at dtype tiny to avoid zero - """ - threshold = np.finfo(data.dtype).tiny - - # Find min |x| > threshold without allocating a copy. - # Using where= runs as a tight vectorized C loop, roughly 2-3 orders - # of magnitude faster than a Python for loop. - min_pos = np.nanmin(data, where=data > threshold, initial=np.inf) - # For negatives: max(data) where data < -threshold gives closest to zero - max_neg = np.nanmax(data, where=data < -threshold, initial=-np.inf) - min_abs = min(min_pos, -max_neg) - - if min_abs == np.inf: - linthresh = 1.0 - else: - linthresh = max(float(min_abs), float(np.finfo(data.dtype).tiny)) - - return linthresh - - def calculate_weighted_average( data_array: np.ndarray, weights: Optional[np.ndarray] = None ) -> float: @@ -165,236 +133,3 @@ def normalize_range( normalized = (value - old_min) / (old_max - old_min) return new_min + normalized * (new_max - new_min) - - -def get_nice_ticks(vmin, vmax, n, scale="linear", linthresh=None): - """Compute nicely spaced tick values for a given range and scale. - - Args: - vmin: Minimum data value - vmax: Maximum data value - n: Desired number of ticks - scale: One of 'linear', 'log', or 'symlog' - - Returns: - Sorted array of unique, snapped tick values. - """ - - def snap(val): - if np.isclose(val, 0, atol=1e-12): - return 0.0 - sign = np.sign(val) - val_abs = abs(val) - mag = 10 ** np.floor(np.log10(val_abs)) - residual = val_abs / mag - nice_steps = np.array([1.0, 2.0, 5.0, 10.0]) - best_step = nice_steps[np.abs(nice_steps - residual).argmin()] - return sign * best_step * mag - - if scale == "linear": - raw_ticks = np.linspace(vmin, vmax, n) - elif scale == "log": - # Use integer powers of 10 that fall strictly inside [vmin, vmax] - log_floor = linthresh if linthresh is not None else 1e-15 - safe_vmin = max(vmin, log_floor) - safe_vmax = max(vmax, log_floor) - start_exp = int(np.floor(np.log10(safe_vmin))) - stop_exp = int(np.ceil(np.log10(safe_vmax))) - powers = [ - 10.0**e - for e in range(start_exp, stop_exp + 1) - if safe_vmin <= 10.0**e <= safe_vmax - ] - # Fall back to log-spaced ticks when no powers of 10 are interior - if len(powers) < 2: - raw_ticks = np.geomspace(safe_vmin, safe_vmax, n) - else: - raw_ticks = np.array(powers) - elif scale == "symlog": - if linthresh is None: - linthresh = 1.0 - # Use powers of 10 as tick values, matching the LUT breakpoints - ticks_set = set() - if vmin < 0: - lo = max(linthresh, 1e-30) - for e in range( - int(np.floor(np.log10(lo))), int(np.floor(np.log10(abs(vmin)))) + 1 - ): - val = -(10.0**e) - if vmin <= val < 0: - ticks_set.add(val) - if vmax > 0: - lo = max(linthresh, 1e-30) - for e in range( - int(np.floor(np.log10(lo))), int(np.floor(np.log10(vmax))) + 1 - ): - val = 10.0**e - if 0 < val <= vmax: - ticks_set.add(val) - if vmin <= 0 <= vmax: - ticks_set.add(0.0) - raw_ticks = np.array(sorted(ticks_set)) - # Skip snap — powers of 10 are already nice - return raw_ticks - else: - raw_ticks = np.linspace(vmin, vmax, n) - - nice_ticks = np.array([snap(t) for t in raw_ticks]) - - # Force 0 for non-log scales if it's within range - if vmin <= 0 <= vmax and scale != "log": - idx = np.abs(nice_ticks).argmin() - nice_ticks[idx] = 0.0 - - return np.unique(np.sort(nice_ticks)) - - -def format_tick(val): - """Format a tick value as a concise human-readable string. - - Returns a string suitable for display on a colorbar. Powers of 10 are - shown as '10^N', very large/small values use scientific notation, and - intermediate values use fixed-point. - """ - if val == 0: - return "0" - - val_abs = abs(val) - log10 = np.log10(val_abs) - - if np.isclose(log10, np.round(log10), atol=1e-12): - exponent = int(np.round(log10)) - sign = "-" if val < 0 else "" - if exponent == 0: - return f"{sign}1" - if exponent == 1: - return f"{sign}10" - return f"{sign}10^{exponent}" - - if val_abs >= 1000 or val_abs <= 0.01: - return f"{val:.1e}" - return f"{int(val) if val == int(val) else val:.1f}" - - -def tick_contrast_color(r, g, b): - """Return '#fff' or '#000' for best contrast against the given RGB color. - - Uses the W3C relative luminance formula to decide. RGB values are - expected in [0, 1] range. - """ - luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b - return "#000" if luminance > 0.45 else "#fff" - - -def compute_color_ticks( - vmin, vmax, scale="linear", n=5, min_gap=7, edge_margin=3, linthresh=None -): - """Compute tick marks for a colorbar. - - Tick positions are computed in the space matching the scale: - - linear: position = (val - vmin) / (vmax - vmin) * 100 - - symlog: position = (symlog(val) - symlog(vmin)) / (symlog(vmax) - symlog(vmin)) * 100 - - The colorbar image is always the linear preset, so symlog ticks - appear at different positions than linear ticks for the same values. - - Args: - vmin: Minimum color range value - vmax: Maximum color range value - scale: One of 'linear', 'log', or 'symlog' - n: Desired number of ticks - min_gap: Minimum gap between ticks in percentage points - edge_margin: Minimum distance from edges (0% and 100%) in percentage points - - Returns: - List of dicts with 'position' (0-100 percentage) and 'label' keys. - """ - if vmin >= vmax: - return [] - - raw_n = n if scale == "linear" else n * 2 - ticks = get_nice_ticks(vmin, vmax, raw_n, scale, linthresh=linthresh) - data_range = vmax - vmin - - # Build mapping functions for non-linear tick positions - _symlog_fn = None - _log_min = _log_max = _log_range = None - - if scale == "symlog": - if linthresh is None: - linthresh = 1.0 - - def _symlog_fn(v): - v = np.asarray(v, dtype=float) - return np.sign(v) * np.log10(1.0 + np.abs(v) / linthresh) - - s_min = float(_symlog_fn(vmin)) - s_max = float(_symlog_fn(vmax)) - s_range = s_max - s_min - - elif scale == "log": - log_floor = linthresh if linthresh is not None else 1e-30 - safe_vmin = max(vmin, log_floor) - safe_vmax = max(vmax, log_floor) - _log_min = np.log10(safe_vmin) - _log_max = np.log10(safe_vmax) - _log_range = _log_max - _log_min - - # Build candidate list with position in the appropriate space - candidates = [] - has_zero = False - for t in ticks: - val = float(t) - if scale == "symlog" and s_range != 0: - pos = (float(_symlog_fn(val)) - s_min) / s_range * 100 - elif scale == "log" and _log_range and _log_range != 0 and val > 0: - pos = (np.log10(val) - _log_min) / _log_range * 100 - else: - pos = (val - vmin) / data_range * 100 - if edge_margin <= pos <= (100 - edge_margin): - is_zero = val == 0 - if is_zero: - has_zero = True - candidates.append( - { - "position": round(pos, 2), - "label": format_tick(val), - "priority": is_zero, - } - ) - - # Always include 0 when it falls within the range (for any scale) - if not has_zero and scale != "log": - if scale == "symlog" and s_range != 0: - zero_pos = (float(_symlog_fn(0.0)) - s_min) / s_range * 100 - else: - zero_pos = (0.0 - vmin) / data_range * 100 - if 0 <= zero_pos <= 100: - tick = {"position": round(zero_pos, 2), "label": "0", "priority": True} - # Insert in sorted order - inserted = False - for i, c in enumerate(candidates): - if tick["position"] <= c["position"]: - candidates.insert(i, tick) - inserted = True - break - if not inserted: - candidates.append(tick) - - # Filter out ticks that are too close together, but never remove priority ticks - result = [] - for tick in candidates: - is_priority = tick.get("priority", False) - if is_priority: - if result and (tick["position"] - result[-1]["position"]) < min_gap: - if not result[-1].get("priority", False): - result.pop() - result.append(tick) - elif not result or (tick["position"] - result[-1]["position"]) >= min_gap: - # Also check distance to next priority tick (look-ahead) - result.append(tick) - - # Clean up internal flags before returning - for tick in result: - tick.pop("priority", None) - return result diff --git a/src/e3sm_quickview/view_manager.py b/src/e3sm_quickview/view_manager.py index c5c5956..e6df3a3 100644 --- a/src/e3sm_quickview/view_manager.py +++ b/src/e3sm_quickview/view_manager.py @@ -2,11 +2,8 @@ import math import time -import numpy as np - # Rendering Factory import vtkmodules.vtkRenderingOpenGL2 # noqa: F401 -from paraview import simple from paraview.modules.vtkPVVTKExtensionsInteractionStyle import ( vtkPVInteractorStyle, vtkPVTrackballZoom, @@ -26,16 +23,10 @@ vtkRenderWindowInteractor, ) +from trame.dataclasses.colormaps import ColormapConfig +from trame.widgets.colormaps import HorizontalScalarBar from e3sm_quickview.components import view as tview -from e3sm_quickview.presets import COLOR_BLIND_SAFE from e3sm_quickview.utils import perf -from e3sm_quickview.utils.color import COLORBAR_CACHE, lut_to_img -from e3sm_quickview.utils.math import ( - calculate_linthresh, - compute_color_ticks, - format_tick, - tick_contrast_color, -) def auto_size_to_col(size): @@ -66,35 +57,16 @@ def auto_size_to_col(size): } -def lut_name(element): - return element.get("name").lower() - - class ViewConfiguration(dataclass.StateDataModel): + # --- View identity --- variable: str = dataclass.Sync(str) - preset: str = dataclass.Sync(str, "BuGnYl") - invert: bool = dataclass.Sync(bool, False) - color_blind: bool = dataclass.Sync(bool, False) - use_log_scale: str = dataclass.Sync(str, "linear") - discrete_log: bool = dataclass.Sync(bool, False) - n_discrete_colors: int = dataclass.Sync(int, 4) - color_value_min: str = dataclass.Sync(str, "0") - color_value_max: str = dataclass.Sync(str, "1") - color_value_min_valid: bool = dataclass.Sync(bool, True) - color_value_max_valid: bool = dataclass.Sync(bool, True) - color_range: list[float] = dataclass.Sync(tuple[float, float], (0, 1)) - override_range: bool = dataclass.Sync(bool, False) + + # --- Layout --- order: int = dataclass.Sync(int, 0) size: int = dataclass.Sync(int, 6) offset: int = dataclass.Sync(int, 0) break_row: bool = dataclass.Sync(bool, False) - menu: bool = dataclass.Sync(bool, False) swap_group: list[str] = dataclass.Sync(list[str], list) - search: str | None = dataclass.Sync(str) - n_colors: int = dataclass.Sync(int, 255) - lut_img: str = dataclass.Sync(str) - color_ticks: list = dataclass.Sync(list, list) - effective_color_range: list[float] = dataclass.Sync(tuple[float, float], (0, 1)) class VariableView(TrameComponent): @@ -123,39 +95,17 @@ def __init__(self, server, source, variable_name, variable_type, camera): self.actor = vtkActor(mapper=self.mapper) self.renderer.AddActor(self.actor) - # Lookup table color management - self.lut = simple.GetColorTransferFunction(variable_name) - self.lut.NanOpacity = 0.0 - - # Color mapping - self.mapper.SetScalarVisibility(1) - self.mapper.SetScalarModeToUseCellFieldData() - self.mapper.SelectColorArray(variable_name) - self.mapper.SetLookupTable(self.lut.GetClientSideObject()) - # Add annotation to the view (continents, gridlines) self.renderer.AddActor(source.continent.actor) self.renderer.AddActor(source.grid_lines.actor) - # Reactive behavior - self.config.watch( - ["color_value_min", "color_value_max"], - self.color_range_str_to_float, - ) - self.config.watch( - ["override_range", "color_range"], self.update_color_range, eager=True - ) - self.config.watch( - [ - "preset", - "invert", - "use_log_scale", - "discrete_log", - "n_discrete_colors", - ], - self.update_color_preset, - eager=True, - ) + # colormaps module: creates LUT, wires mapper, manages presets/range/ticks + self.colormap = ColormapConfig( + server, + mapper=self.mapper, + data_array_fn=lambda: self.data_array, + ).set_data_array(variable_name, lambda: self.data_array, "cell") + self.colormap.watch(["mapper_change"], lambda **_: self.render()) # GUI self._build_ui() @@ -187,705 +137,12 @@ def render(self): if self.ctx.view: self.ctx.view.update() - def update_color_preset( - self, - name, - invert, - log_scale, - discrete_log=False, - n_discrete_colors=4, - n_colors=255, - ): - self.config.preset = name - - # ApplyPreset resets range to [0,1], so always apply the linear - # preset first, rescale to the current range, then apply transforms - self._apply_linear_to_lut(invert) - self.lut.RescaleTransferFunction(*self.config.color_range) - - # Capture the linear colorbar image (always the same regardless of scale) - ctf = self.lut.GetClientSideObject() - self.config.effective_color_range = ctf.GetRange() - self.config.lut_img = lut_to_img(self.lut) - - # Save a reference to the linear LUT range for tick contrast sampling - linear_rgb_points = list(self.lut.RGBPoints) - - # Compute linthresh (smallest positive non-zero value) from data - # for log and symlog scales. - linthresh = None - if log_scale in ("log", "symlog"): - from vtkmodules.util.numpy_support import vtk_to_numpy - - arr = self.data_array - if arr is not None: - linthresh = calculate_linthresh(vtk_to_numpy(arr)) - else: - linthresh = 1.0 - - n_sub = max(1, min(20, int(n_discrete_colors))) - if log_scale == "linear" and discrete_log: - display_rgb_points = self._apply_discrete_linear_to_lut( - linear_rgb_points, n_sub - ) - if display_rgb_points is not None: - linear_rgb_points = display_rgb_points - elif log_scale == "log": - if discrete_log: - display_rgb_points = self._apply_discrete_log_to_lut( - linthresh, linear_rgb_points, n_sub - ) - if display_rgb_points is not None: - linear_rgb_points = display_rgb_points - else: - self._apply_log_to_lut(linthresh) - elif log_scale == "symlog": - if discrete_log: - display_rgb_points = self._apply_discrete_symlog_to_lut( - linthresh, linear_rgb_points, n_sub - ) - if display_rgb_points is not None: - linear_rgb_points = display_rgb_points - else: - self._apply_symlog_to_lut(linthresh, linear_rgb_points) - - self._compute_ticks(linthresh=linthresh, linear_rgb_points=linear_rgb_points) - - # For symlog (or any discrete mode), rebuild the client-side CTF as - # the VERY LAST step so nothing (proxy sync, _compute_ticks, lut_to_img) - # can overwrite it. - if log_scale == "symlog" or (discrete_log and log_scale in ("log", "linear")): - from vtkmodules.vtkRenderingCore import vtkColorTransferFunction - - pts = list(self.lut.RGBPoints) - ctf = vtkColorTransferFunction() - for i in range(0, len(pts), 4): - ctf.AddRGBPoint(pts[i], pts[i + 1], pts[i + 2], pts[i + 3]) - self._symlog_ctf = ctf # prevent GC - else: - self.lut.UpdateVTKObjects() - ctf = self.lut.GetClientSideObject() - - self.mapper.SetLookupTable(ctf) - self.mapper.Modified() - - self.render() - - def _apply_linear_to_lut(self, invert=False): - """Apply preset with linear scale.""" - self.lut.UseLogScale = 0 - self.lut.ApplyPreset(self.config.preset, True) - if invert: - self.lut.InvertTransferFunction() - - def _apply_discrete_linear_to_lut(self, linear_rgb_points, n_sub=1): - """Build a discrete (stepped) linear LUT. - - The data range is divided into N_INTERVALS equal-percentage intervals. - Each interval is then split into *n_sub* equal sub-bands, each with a - flat color sampled from the continuous linear LUT at the sub-band - midpoint. The boundary values are stored so ``_compute_ticks`` can - place tick marks at the exact same positions. - """ - N_INTERVALS = 4 - ctf = self.lut.GetClientSideObject() - x_min, x_max = ctf.GetRange() - data_range = x_max - x_min - if data_range == 0: - return - - # Evenly spaced boundaries (percentages of data range) - boundaries = [ - x_min + data_range * i / N_INTERVALS for i in range(N_INTERVALS + 1) - ] - # Store boundary values and their display positions (%) for tick alignment - self._discrete_tick_data = [ - {"val": boundaries[i], "pos": i / N_INTERVALS * 100} - for i in range(1, N_INTERVALS) - ] - - if len(boundaries) < 2: - return - - # Build a temporary linear CTF from the saved linear RGB points - from vtkmodules.vtkRenderingCore import vtkColorTransferFunction - - linear_ctf = vtkColorTransferFunction() - for i in range(0, len(linear_rgb_points), 4): - linear_ctf.AddRGBPoint( - linear_rgb_points[i], - linear_rgb_points[i + 1], - linear_rgb_points[i + 2], - linear_rgb_points[i + 3], - ) - - rgb = [0.0, 0.0, 0.0] - eps = data_range * 1e-9 - display_rgb_points = [] - render_rgb_points = [] - band_idx = 0 - total_bands = (len(boundaries) - 1) * n_sub - for i in range(len(boundaries) - 1): - lo = boundaries[i] - hi = boundaries[i + 1] - for j in range(n_sub): - # Sub-band edges in linear space - sub_lo = lo + (hi - lo) * j / n_sub - sub_hi = lo + (hi - lo) * (j + 1) / n_sub - sub_mid = (sub_lo + sub_hi) / 2.0 - linear_ctf.GetColor(sub_mid, rgb) - r, g, b = float(rgb[0]), float(rgb[1]), float(rgb[2]) - - is_first = band_idx == 0 - is_last = band_idx == total_bands - 1 - - if is_first: - display_rgb_points.extend([sub_lo, r, g, b]) - render_rgb_points.extend([sub_lo, r, g, b]) - else: - display_rgb_points.extend([sub_lo + eps, r, g, b]) - render_rgb_points.extend([sub_lo + eps, r, g, b]) - - if is_last: - display_rgb_points.extend([sub_hi, r, g, b]) - render_rgb_points.extend([sub_hi, r, g, b]) - else: - display_rgb_points.extend([sub_hi - eps, r, g, b]) - render_rgb_points.extend([sub_hi - eps, r, g, b]) - - band_idx += 1 - - # Generate the discrete banded colorbar image - self.lut.RGBPoints = display_rgb_points - self.config.lut_img = lut_to_img(self.lut) - - # Store rendering points on proxy - self.lut.UseLogScale = 0 - self.lut.RGBPoints = render_rgb_points - - return display_rgb_points - - def _apply_log_to_lut(self, linthresh): - """Transform the already-prepared LUT to log scale. - - Uses linthresh (smallest positive non-zero data value) as the floor - when the range includes zero or negative values. - The colorbar image is captured before this call, so it stays linear. - """ - ctf = self.lut.GetClientSideObject() - x_min, x_max = ctf.GetRange() - if x_max <= 0: - return - if x_min <= 0: - x_min = linthresh - self.lut.RescaleTransferFunction(x_min, x_max) - self.lut.MapControlPointsToLogSpace() - self.lut.UseLogScale = 1 - - def _apply_discrete_log_to_lut(self, linthresh, linear_rgb_points, n_sub=1): - """Build a discrete (stepped) log-scale LUT. - - Decade boundaries are powers of 10 from linthresh to x_max. - Each decade is split into *n_sub* equal sub-bands in log space, - each with a flat color sampled from the continuous linear LUT. - """ - ctf = self.lut.GetClientSideObject() - x_min, x_max = ctf.GetRange() - if x_max <= 0: - return - # Clamp floor - x_min = max(x_min, linthresh) - data_range = x_max - x_min - if data_range == 0: - return - - log_min = np.log10(x_min) - log_max = np.log10(x_max) - log_range = log_max - log_min - if log_range == 0: - return - - # Build decade boundaries - boundaries = [x_min] - e_lo = int(np.ceil(np.log10(x_min))) - e_hi = int(np.floor(np.log10(x_max))) - for e in range(e_lo, e_hi + 1): - val = 10.0**e - if x_min < val < x_max: - boundaries.append(val) - boundaries.append(x_max) - - if len(boundaries) < 2: - return - - # Store boundary values and their display positions (%) for tick alignment - log_min = np.log10(x_min) - log_max = np.log10(x_max) - log_range_val = log_max - log_min - self._discrete_tick_data = [] - for bv in boundaries[1:-1]: - pct = (np.log10(bv) - log_min) / log_range_val * 100 if log_range_val else 0 - self._discrete_tick_data.append({"val": bv, "pos": float(pct)}) - - # Build a continuous log CTF so discrete bands sample colours that - # match the continuous log rendering. - from vtkmodules.vtkRenderingCore import vtkColorTransferFunction - - linear_ctf = vtkColorTransferFunction() - for i in range(0, len(linear_rgb_points), 4): - linear_ctf.AddRGBPoint( - linear_rgb_points[i], - linear_rgb_points[i + 1], - linear_rgb_points[i + 2], - linear_rgb_points[i + 3], - ) - - n_samples = 256 - log_vals = np.linspace(log_min, log_max, n_samples) - log_ctf = vtkColorTransferFunction() - rgb_tmp = [0.0, 0.0, 0.0] - for lg in log_vals: - v = 10.0**lg - v = max(x_min, min(x_max, v)) - t = (lg - log_min) / log_range - x_lookup = x_min + t * data_range - linear_ctf.GetColor(x_lookup, rgb_tmp) - log_ctf.AddRGBPoint(v, rgb_tmp[0], rgb_tmp[1], rgb_tmp[2]) - - rgb = [0.0, 0.0, 0.0] - eps_data = data_range * 1e-9 - eps_lin = data_range * 1e-9 - display_rgb_points = [] - render_rgb_points = [] - band_idx = 0 - total_bands = (len(boundaries) - 1) * n_sub - for i in range(len(boundaries) - 1): - log_lo_decade = np.log10(boundaries[i]) - log_hi_decade = np.log10(boundaries[i + 1]) - for j in range(n_sub): - # Sub-band edges in log space - log_lo = log_lo_decade + (log_hi_decade - log_lo_decade) * j / n_sub - log_hi = ( - log_lo_decade + (log_hi_decade - log_lo_decade) * (j + 1) / n_sub - ) - log_mid = (log_lo + log_hi) / 2.0 - # Sample color from continuous log CTF at sub-band midpoint - v_mid = 10.0**log_mid - v_mid = max(x_min, min(x_max, v_mid)) - log_ctf.GetColor(v_mid, rgb) - r, g, b = float(rgb[0]), float(rgb[1]), float(rgb[2]) - - # Data-space boundaries for rendering - v_lo = 10.0**log_lo - v_hi = 10.0**log_hi - v_lo = max(x_min, min(x_max, v_lo)) - v_hi = max(x_min, min(x_max, v_hi)) - - # Linear positions for display image - t_lo_pos = (log_lo - log_min) / log_range - t_hi_pos = (log_hi - log_min) / log_range - d_lo = x_min + t_lo_pos * data_range - d_hi = x_min + t_hi_pos * data_range - - is_first = band_idx == 0 - is_last = band_idx == total_bands - 1 - - if is_first: - display_rgb_points.extend([d_lo, r, g, b]) - render_rgb_points.extend([float(v_lo), r, g, b]) - else: - display_rgb_points.extend([d_lo + eps_lin, r, g, b]) - render_rgb_points.extend([float(v_lo) + eps_data, r, g, b]) - - if is_last: - display_rgb_points.extend([d_hi, r, g, b]) - render_rgb_points.extend([float(v_hi), r, g, b]) - else: - display_rgb_points.extend([d_hi - eps_lin, r, g, b]) - render_rgb_points.extend([float(v_hi) - eps_data, r, g, b]) - - band_idx += 1 - - # Generate the discrete banded colorbar image - self.lut.RGBPoints = display_rgb_points - self.config.lut_img = lut_to_img(self.lut) - - # Store rendering points on proxy - self.lut.UseLogScale = 0 - self.lut.RGBPoints = render_rgb_points - - return display_rgb_points - - def _apply_symlog_to_lut(self, linthresh, linear_rgb_points=None): - """Build a symlog LUT with decade control points. - - Control points are placed at powers of 10 (and ±linthresh, 0 for - mixed-sign data). The RGB color for each control point is sampled - from the linear colorbar at the position where that value falls in - symlog space: t = (symlog(v) - symlog(min)) / (symlog(max) - symlog(min)). - """ - ctf = self.lut.GetClientSideObject() - x_min, x_max = ctf.GetRange() - data_range = x_max - x_min - if data_range == 0: - return - - def symlog(v): - v = np.asarray(v, dtype=float) - return np.sign(v) * np.log10(1.0 + np.abs(v) / linthresh) - - # Build control points: N uniform samples in symlog space, plus - # mandatory breakpoints at ±linthresh and 0 for exact transitions. - n_samples = 256 - s_min_val = float(symlog(x_min)) - s_max_val = float(symlog(x_max)) - s_range_bp = s_max_val - s_min_val - if s_range_bp == 0: - return - - # Uniform in symlog space → invert to data space - # Inverse of symlog: v = sign(s) * linthresh * (10^|s| - 1) - s_vals = np.linspace(s_min_val, s_max_val, n_samples) - breakpoints = [] - for s in s_vals: - v = float(np.sign(s) * linthresh * (10.0 ** abs(s) - 1.0)) - v = max(x_min, min(x_max, v)) - breakpoints.append(v) - - # Symlog range for normalization - s_min = float(symlog(x_min)) - s_max = float(symlog(x_max)) - s_range = s_max - s_min - if s_range == 0: - return - - # Build a standalone linear CTF for safe color sampling - from vtkmodules.vtkRenderingCore import vtkColorTransferFunction - - linear_ctf = vtkColorTransferFunction() - if linear_rgb_points: - src = linear_rgb_points - else: - src = list(self.lut.RGBPoints) - for i in range(0, len(src), 4): - linear_ctf.AddRGBPoint(src[i], src[i + 1], src[i + 2], src[i + 3]) - - # Sample RGB from the linear CTF at symlog-normalized positions - rgb = [0.0, 0.0, 0.0] - new_rgb_points = [] - display_rgb_points = [] - for v in breakpoints: - t = (float(symlog(v)) - s_min) / s_range - x_lookup = x_min + t * data_range - linear_ctf.GetColor(x_lookup, rgb) - r, g, b = float(rgb[0]), float(rgb[1]), float(rgb[2]) - new_rgb_points.extend([float(v), r, g, b]) - # Display points: uniform linear positions with symlog colors - display_rgb_points.extend([x_lookup, r, g, b]) - - # Regenerate colorbar image from display points so it matches the 3D - self.lut.UseLogScale = 0 - self.lut.RGBPoints = display_rgb_points - self.config.lut_img = lut_to_img(self.lut) - - # Store rendering points on proxy — the actual CTF used by the - # mapper is a standalone vtkColorTransferFunction built in - # update_color_preset to avoid proxy client-side object issues. - self.lut.RGBPoints = new_rgb_points - - def _apply_discrete_symlog_to_lut(self, linthresh, linear_rgb_points, n_sub=1): - """Build a discrete (stepped) symlog LUT. - - Each decade interval is split into *n_sub* equal sub-bands in symlog - space, each with a flat color sampled from the continuous LUT at the - sub-band midpoint. Twin control points with a tiny offset create hard - steps at the sub-band boundaries. The display image is also replaced - with a banded colorbar. - """ - ctf = self.lut.GetClientSideObject() - x_min, x_max = ctf.GetRange() - data_range = x_max - x_min - if data_range == 0: - return - - def symlog(v): - v = np.asarray(v, dtype=float) - return np.sign(v) * np.log10(1.0 + np.abs(v) / linthresh) - - # Build decade boundaries (same logic as symlog ticks) - boundaries = set() - if x_min < 0: - lo = max(linthresh, 1e-30) - for e in range( - int(np.floor(np.log10(lo))), - int(np.floor(np.log10(abs(x_min)))) + 1, - ): - val = -(10.0**e) - if x_min <= val < 0: - boundaries.add(val) - if x_max > 0: - lo = max(linthresh, 1e-30) - for e in range( - int(np.floor(np.log10(lo))), - int(np.floor(np.log10(x_max))) + 1, - ): - val = 10.0**e - if 0 < val <= x_max: - boundaries.add(val) - if x_min < 0 and x_max > 0: - boundaries.update((-linthresh, 0.0, linthresh)) - elif x_min < 0 and x_max <= 0: - if -linthresh >= x_min: - boundaries.add(-linthresh) - elif x_min >= 0 and x_max > 0: - if linthresh <= x_max: - boundaries.add(linthresh) - if x_min <= 0 <= x_max: - boundaries.add(0.0) - boundaries.add(x_min) - boundaries.add(x_max) - # Filter to only values within [x_min, x_max] - boundaries = sorted(b for b in boundaries if x_min <= b <= x_max) - - if len(boundaries) < 2: - return - - # Symlog range for normalization - s_min = float(symlog(x_min)) - s_max = float(symlog(x_max)) - s_range = s_max - s_min - if s_range == 0: - return - - # Store boundary values and their display positions (%) for tick alignment. - # All boundaries are used for discrete bands, but when x_min < 0 we - # thin the displayed ticks: always show 0, then only every other - # decade moving outward from 0 in each direction. - all_tick_data = [] - for bv in boundaries[1:-1]: - s_val = float(symlog(bv)) - pct = (s_val - s_min) / s_range * 100 - all_tick_data.append({"val": bv, "pos": float(pct)}) - - if x_min < 0: - # Exclude linthresh / -linthresh from tick labels - lt = float(linthresh) - filtered = [t for t in all_tick_data if abs(abs(t["val"]) - lt) > 1e-12] - # Separate into negative, zero, and positive - neg = [t for t in filtered if t["val"] < 0] - zero = [t for t in filtered if t["val"] == 0] - pos = [t for t in filtered if t["val"] > 0] - # Keep every other decade tick moving outward from 0 - neg_outward = list(reversed(neg)) - thinned_neg = [neg_outward[i] for i in range(0, len(neg_outward), 2)] - thinned_pos = [pos[i] for i in range(0, len(pos), 2)] - self._discrete_tick_data = sorted( - thinned_neg + zero + thinned_pos, key=lambda t: t["val"] - ) - else: - self._discrete_tick_data = all_tick_data - - # Build a continuous symlog CTF (same as _apply_symlog_to_lut) so - # discrete bands sample colours that match the continuous rendering. - from vtkmodules.vtkRenderingCore import vtkColorTransferFunction - - linear_ctf = vtkColorTransferFunction() - for i in range(0, len(linear_rgb_points), 4): - linear_ctf.AddRGBPoint( - linear_rgb_points[i], - linear_rgb_points[i + 1], - linear_rgb_points[i + 2], - linear_rgb_points[i + 3], - ) - - n_samples = 256 - s_vals = np.linspace(s_min, s_max, n_samples) - symlog_ctf = vtkColorTransferFunction() - rgb_tmp = [0.0, 0.0, 0.0] - for s in s_vals: - v = float(np.sign(s) * linthresh * (10.0 ** abs(s) - 1.0)) - v = max(x_min, min(x_max, v)) - t = (s - s_min) / s_range - x_lookup = x_min + t * data_range - linear_ctf.GetColor(x_lookup, rgb_tmp) - symlog_ctf.AddRGBPoint(v, rgb_tmp[0], rgb_tmp[1], rgb_tmp[2]) - - # For each decade interval, split into n_sub equal sub-bands in - # symlog space. Each sub-band gets a flat color sampled from the - # continuous symlog LUT at the sub-band midpoint. - rgb = [0.0, 0.0, 0.0] - eps_data = (x_max - x_min) * 1e-9 - eps_lin = data_range * 1e-9 - display_rgb_points = [] - render_rgb_points = [] - band_idx = 0 - total_bands = (len(boundaries) - 1) * n_sub - for i in range(len(boundaries) - 1): - s_lo_decade = float(symlog(boundaries[i])) - s_hi_decade = float(symlog(boundaries[i + 1])) - for j in range(n_sub): - # Sub-band edges in symlog space - s_lo = s_lo_decade + (s_hi_decade - s_lo_decade) * j / n_sub - s_hi = s_lo_decade + (s_hi_decade - s_lo_decade) * (j + 1) / n_sub - s_mid = (s_lo + s_hi) / 2.0 - - # Invert symlog to get data-space values - v_mid = float(np.sign(s_mid) * linthresh * (10.0 ** abs(s_mid) - 1.0)) - v_mid = max(x_min, min(x_max, v_mid)) - symlog_ctf.GetColor(v_mid, rgb) - r, g, b = float(rgb[0]), float(rgb[1]), float(rgb[2]) - - # Invert symlog to get data-space boundaries for rendering - v_lo = float(np.sign(s_lo) * linthresh * (10.0 ** abs(s_lo) - 1.0)) - v_hi = float(np.sign(s_hi) * linthresh * (10.0 ** abs(s_hi) - 1.0)) - v_lo = max(x_min, min(x_max, v_lo)) - v_hi = max(x_min, min(x_max, v_hi)) - - # Linear positions for display image - t_lo_pos = (s_lo - s_min) / s_range - t_hi_pos = (s_hi - s_min) / s_range - d_lo = x_min + t_lo_pos * data_range - d_hi = x_min + t_hi_pos * data_range - - is_first = band_idx == 0 - is_last = band_idx == total_bands - 1 - - if is_first: - display_rgb_points.extend([d_lo, r, g, b]) - render_rgb_points.extend([float(v_lo), r, g, b]) - else: - display_rgb_points.extend([d_lo + eps_lin, r, g, b]) - render_rgb_points.extend([float(v_lo) + eps_data, r, g, b]) - - if is_last: - display_rgb_points.extend([d_hi, r, g, b]) - render_rgb_points.extend([float(v_hi), r, g, b]) - else: - display_rgb_points.extend([d_hi - eps_lin, r, g, b]) - render_rgb_points.extend([float(v_hi) - eps_data, r, g, b]) - - band_idx += 1 - - # Generate the discrete banded colorbar image - self.lut.RGBPoints = display_rgb_points - self.config.lut_img = lut_to_img(self.lut) - - # Store rendering points on proxy — the actual CTF used by the - # mapper is a standalone vtkColorTransferFunction built in - # update_color_preset to avoid proxy client-side object issues. - self.lut.RGBPoints = render_rgb_points - - return display_rgb_points - - def color_range_str_to_float(self, color_value_min, color_value_max): - try: - min_value = float(color_value_min) - self.config.color_value_min_valid = not math.isnan(min_value) - except ValueError: - self.config.color_value_min_valid = False - - try: - max_value = float(color_value_max) - self.config.color_value_max_valid = not math.isnan(max_value) - except ValueError: - self.config.color_value_max_valid = False - - if self.config.color_value_min_valid and self.config.color_value_max_valid: - self.config.color_range = (min_value, max_value) - @property def data_array(self): self.source.data_reader.vtk_geometry.Update() ds = self.source.data_reader.vtk_geometry.GetOutput() return ds.GetCellData().GetArray(self.variable_name) - def update_color_range(self, *_): - if self.config.override_range: - skip_update = False - if math.isnan(self.config.color_range[0]): - skip_update = True - self.config.color_value_min_valid = False - - if math.isnan(self.config.color_range[1]): - skip_update = True - self.config.color_value_max_valid = False - - if skip_update: - return - - self.lut.RescaleTransferFunction(*self.config.color_range) - else: - data_array = self.data_array - if data_array: - data_range = data_array.GetRange() - self.config.color_range = data_range - self.config.color_value_min = str(data_range[0]) - self.config.color_value_max = str(data_range[1]) - self.config.color_value_min_valid = True - self.config.color_value_max_valid = True - self.lut.RescaleTransferFunction(*data_range) - - self.update_color_preset( - self.config.preset, - self.config.invert, - self.config.use_log_scale, - self.config.discrete_log, - self.config.n_discrete_colors, - ) - - def _compute_ticks(self, linthresh=None, linear_rgb_points=None): - vmin, vmax = self.config.color_range - - # For discrete mode, use pre-computed boundary positions. - # For continuous linear, use the same evenly spaced percentage ticks. - if self.config.discrete_log and hasattr(self, "_discrete_tick_data"): - ticks = [ - {"position": round(td["pos"], 2), "label": format_tick(td["val"])} - for td in self._discrete_tick_data - ] - elif self.config.use_log_scale == "linear": - N_INTERVALS = 4 - data_range = vmax - vmin - ticks = [] - if data_range > 0: - for i in range(1, N_INTERVALS): - val = vmin + data_range * i / N_INTERVALS - pos = i / N_INTERVALS * 100 - ticks.append({"position": round(pos, 2), "label": format_tick(val)}) - else: - ticks = compute_color_ticks( - vmin, vmax, scale=self.config.use_log_scale, n=5, linthresh=linthresh - ) - # Sample colors from the *linear* LUT so tick contrast matches the - # displayed colorbar image, not the log/symlog-remapped rendering LUT. - rgb_points = ( - linear_rgb_points if linear_rgb_points else list(self.lut.RGBPoints) - ) - if len(rgb_points) < 4: - self.config.color_ticks = [] - return - img_min = rgb_points[0] - img_max = rgb_points[-4] - img_range = img_max - img_min - if img_range == 0: - self.config.color_ticks = [] - return - # Build a temporary linear CTF to sample colors from - from vtkmodules.vtkRenderingCore import vtkColorTransferFunction - - linear_ctf = vtkColorTransferFunction() - for i in range(0, len(rgb_points), 4): - linear_ctf.AddRGBPoint( - rgb_points[i], rgb_points[i + 1], rgb_points[i + 2], rgb_points[i + 3] - ) - rgb = [0.0, 0.0, 0.0] - for tick in ticks: - t = tick["position"] / 100.0 - value = img_min + t * img_range - linear_ctf.GetColor(value, rgb) - tick["color"] = tick_contrast_color(rgb[0], rgb[1], rgb[2]) - self.config.color_ticks = ticks - def _build_ui(self): with DivLayout( self.server, template_name=self.name, connect_parent=False, classes="h-100" @@ -975,7 +232,8 @@ def _build_ui(self): size=(self.update_size, "[$event]"), ) - tview.create_bottom_bar(self.config, self.update_color_preset) + with self.colormap.provide_as(self.name): + HorizontalScalarBar(self.name, popup_location="top") class ViewManager(TrameComponent): @@ -1031,19 +289,6 @@ def __init__(self, server, source): rca.initialize(self.server) - self.state.luts_normal = [ - {"name": k, "url": v["normal"], "safe": k in COLOR_BLIND_SAFE} - for k, v in COLORBAR_CACHE.items() - ] - self.state.luts_inverted = [ - {"name": k, "url": v["inverted"], "safe": k in COLOR_BLIND_SAFE} - for k, v in COLORBAR_CACHE.items() - ] - - # Sort lists - self.state.luts_normal.sort(key=lut_name) - self.state.luts_inverted.sort(key=lut_name) - def _on_render_start(self, *_): if perf.is_enabled(): self._render_t0 = time.perf_counter() @@ -1159,9 +404,9 @@ def render(self): self.pending_render = False def update_color_range(self): + """Update color range on all views via colormaps module.""" for view in list(self._var2view.values()): - view.update_color_range() - self.render() + view.colormap.update_color_range() # colormaps module def get_view(self, variable_name, variable_type): view = self._var2view.get(variable_name) @@ -1282,7 +527,7 @@ def build_auto_layout(self, variables=None): # Create UI based on variables self.state.swap_groups = {} - # Build a lookup from type name to color from state.variable_types + # Vuetify color per dimension type (e.g. midpoint, interface) via utils/colors.py type_to_color = {vt["name"]: vt["color"] for vt in self.state.variable_types} with DivLayout(self.server, template_name="auto_layout") as self.ui: self.ui.root.classes = "all-variables" @@ -1295,7 +540,7 @@ def build_auto_layout(self, variables=None): if total_size == 0: continue - # Look up color from variable_types to match chip colors + # Border color matches dimension type chips via utils/colors.py border_color = type_to_color.get(", ".join(var_type), "primary") with v3.VAlert( border="start",