Source code for omnipose.plot

from .utils import rescale, torch_norm
from .color import sinebow

import matplotlib as mpl
mpl.rcParams['svg.fonttype'] = 'none'  # keep text as real text in the SVG
mpl.rcParams['text.usetex'] = False      # Avoid LaTeX (which converts text to paths)

from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# import matplotlib.pyplot as plt

import types

import numpy as np
from matplotlib.backend_bases import GraphicsContextBase, RendererBase

from mpl_toolkits.axes_grid1 import ImageGrid

from skimage import img_as_ubyte


def setup():
    """Prepare a Jupyter/IPython session for omnipose-style figures.

    Side effects, in order:

    * injects ``mpl``, ``plt``, ``widgets`` and ``display`` into the notebook's
      global namespace;
    * adds a CSS rule that centres images in notebook output cells;
    * switches matplotlib to the inline backend;
    * updates ``matplotlib.rcParams`` (300 dpi, 2x2 in figures, gray text and
      axes, transparent backgrounds, frameless legends, ...);
    * patches ``matplotlib.axes.Axes.legend`` so legends default to a
      frameless box just outside the right edge of the axes.

    Calling it more than once is safe: the legend patch is applied only once
    instead of wrapping the previous wrapper again.

    Raises
    ------
    RuntimeError
        If there is no active IPython session.
    """
    import functools

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import ipywidgets as widgets
    from IPython import get_ipython
    from IPython.display import display, HTML
    from matplotlib.axes import Axes as _Axes

    # get_ipython() returns None outside IPython; fail with a clear message
    # instead of a NameError/AttributeError further down.
    ipython = get_ipython()
    if ipython is None:
        raise RuntimeError("setup() must be called from an IPython/Jupyter session")

    # Inject the usual plotting names into the notebook's global namespace.
    ipython.user_global_ns['mpl'] = mpl
    ipython.user_global_ns['plt'] = plt
    ipython.user_global_ns['widgets'] = widgets
    ipython.user_global_ns['display'] = display

    # Custom CSS to center plots
    display(HTML("""
<style>
    .jp-OutputArea-output img {
        display: block;
        margin: 0 auto;
    }
</style>
"""))

    # Set matplotlib inline for Jupyter notebooks
    ipython.run_line_magic('matplotlib', 'inline')

    rc_params = {
        'figure.dpi': 300,
        'figure.figsize': (2, 2),
        'image.cmap': 'gray',
        'image.interpolation': 'nearest',
        'figure.frameon': False,
        'axes.grid': False,
        'axes.facecolor': 'none',      # Transparent axes
        'figure.facecolor': 'none',    # Transparent figure background
        'savefig.facecolor': 'none',   # Transparent save background
        'text.color': 'gray',          # Gray text for flexibility
        'axes.labelcolor': 'gray',
        'xtick.color': 'gray',
        'ytick.color': 'gray',
        'axes.edgecolor': 'gray',
        # Legend defaults - place legend outside axes on the right, no frame
        'legend.loc': 'center left',
        'legend.frameon': False,
        'legend.framealpha': 0,
        'legend.borderaxespad': 0.0,
        'lines.scale_dashes': False,
    }
    mpl.rcParams.update(rc_params)

    # Patch Axes.legend to default to outside-right placement with no frame.
    # The marker attribute prevents stacking a new wrapper on every call.
    if not getattr(_Axes.legend, '_omnipose_outside_legend', False):
        _orig_legend = _Axes.legend

        @functools.wraps(_orig_legend)
        def _legend(self, *args, **kwargs):
            kwargs.setdefault('loc', 'center left')
            kwargs.setdefault('bbox_to_anchor', (1.02, 0.5))
            kwargs.setdefault('frameon', False)
            kwargs.setdefault('framealpha', 0)
            kwargs.setdefault('borderaxespad', 0.0)
            return _orig_legend(self, *args, **kwargs)

        _legend._omnipose_outside_legend = True
        _Axes.legend = _legend
def figure(nrow=None, ncol=None, aspect=1, **kwargs):
    """Create a ``matplotlib.figure.Figure`` (not via pyplot) with subplots.

    Parameters
    ----------
    nrow, ncol : int, optional
        Grid size. When both are given, ``nrow * ncol`` subplots are created
        and returned as a flat list (subplot index order); otherwise a single
        axes filling the figure is returned.
    aspect : float
        Width/height ratio applied when ``figsize`` is a scalar.
    **kwargs
        Forwarded to ``Figure``. A scalar ``figsize`` (default 2) becomes
        ``(figsize * aspect, figsize)``; lists, tuples, arrays and ``None``
        are passed through unchanged.

    Returns
    -------
    (fig, list_of_axes) when both ``nrow`` and ``ncol`` are given, else (fig, ax).
    """
    size = kwargs.get('figsize', 2)
    scalar_size = size is not None and not isinstance(size, (list, tuple, np.ndarray))
    if scalar_size:
        size = (size * aspect, size)
    kwargs['figsize'] = size

    fig = Figure(**kwargs)

    if nrow is None or ncol is None:
        return fig, fig.add_subplot(111)

    panels = [fig.add_subplot(nrow, ncol, k) for k in range(1, nrow * ncol + 1)]
    return fig, panels
class GC(GraphicsContextBase):
    """Graphics context that draws lines with round end caps."""

    def __init__(self):
        super().__init__()
        # Go through the public setter: recent matplotlib stores a CapStyle
        # enum in ``_capstyle`` and ``get_capstyle()`` reads its ``.name``,
        # so assigning the bare string 'round' breaks renderers that query it.
        self.set_capstyle('round')
def custom_new_gc(self):
    """Stand-in for ``RendererBase.new_gc`` that hands out a round-capped ``GC``.

    ``self`` is only there so the function can be bound as a method; it is
    not used.
    """
    gc = GC()
    return gc
def plot_edges(shape, affinity_graph, neighbors, coords,
               figsize=1, fig=None, ax=None, extent=None, slc=None, pic=None,
               edgecol=[.75]*3+[.5], linewidth=0.15, step_inds=None,
               cmap='inferno', origin='lower', bounds=None):
    """Draw affinity-graph edges as line segments over an affinity heat map.

    NOTE: ``plot_edges`` is defined again later in this module; that later
    definition replaces this one at import time, so this version is only
    kept for reference.

    Edges are obtained from ``affinity_to_edges`` and each becomes a segment
    between pixel centres. When ``pic`` is None, a heat map of the per-pixel
    edge count is drawn underneath. ``bounds`` is accepted but unused.

    Returns
    -------
    ``(summed_affinity, affinity_cmap)`` when the heat map was built here,
    otherwise ``(None, None)``.
    """
    # import core here because that can take a while to load
    from .core import affinity_to_edges
    from .utils import get_neigh_inds
    from matplotlib.collections import LineCollection

    # TODO: make edges appear even on image borders or when the target index
    # is 0 (this reminder used to be print()-ed on every call).
    nstep, npix = affinity_graph.shape
    coords = tuple(coords)
    indexes, neigh_inds, ind_matrix = get_neigh_inds(tuple(neighbors), coords, shape)
    if step_inds is None:
        step_inds = np.arange(nstep)
    px_inds = np.arange(npix)
    edge_list = affinity_to_edges(affinity_graph.astype(bool), neigh_inds, step_inds, px_inds)

    # (npix, ndim) coordinates; flip the axis order for plotting and shift to
    # pixel centres.
    aff_coords = np.array(coords).T
    segments = np.stack([[aff_coords[:, ::-1][e] + 0.5 for e in edge] for edge in edge_list])

    # Process-wide patch: every renderer now gets a round-capped GC.
    RendererBase.new_gc = types.MethodType(custom_new_gc, RendererBase)

    newfig = fig is None and ax is None
    if newfig:
        # The old test `type(figsize) is not (list or tuple)` only compared
        # against `list` (`list or tuple` evaluates to `list`), so a tuple
        # figsize was wrapped a second time.
        if not isinstance(figsize, (list, tuple)):
            figsize = (figsize, figsize)
        fig, ax = figure(figsize=figsize)

    if extent is None:
        extent = np.array([0, shape[1], 0, shape[0]])

    nopic = pic is None
    if nopic:
        summed_affinity = np.zeros(shape, dtype=int)
        summed_affinity[coords] = np.sum(affinity_graph, axis=0)
        # Reversed colormap sampled at 9 levels, with an all-zero (transparent)
        # RGBA entry prepended so pixels with no edges are not coloured.
        colors = mpl.colormaps.get_cmap(cmap).reversed()(np.linspace(0, 1, 9))
        colors = np.vstack((np.array([0]*4), colors))
        affinity_cmap = mpl.colors.ListedColormap(colors)
        pic = affinity_cmap(summed_affinity)

    ax.imshow(pic[slc] if slc is not None else pic, extent=extent, origin=origin)
    line_segments = LineCollection(segments, color=edgecol, linewidths=linewidth)
    ax.add_collection(line_segments)

    if newfig:
        ax.set_axis_off()
        ax.invert_yaxis()
        canvas = FigureCanvas(fig)
        canvas.draw()

    if nopic:
        return summed_affinity, affinity_cmap
    else:
        return None, None


import numpy as np
import matplotlib as mpl
import types
from matplotlib.collections import LineCollection
from matplotlib.backend_bases import RendererBase
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
def plot_edges(
    shape,
    affinity_graph,
    neighbors,
    coords,
    figsize=1,
    fig=None,
    ax=None,
    extent=None,
    slc=None,
    pic=None,
    edgecol=[.75]*3 + [.5],
    linewidth=0.15,
    step_inds=None,
    cmap='inferno',
    origin='lower',
    bounds=None,
):
    """Render an affinity graph as line segments laid over an optional image.

    Edges are built directly from ``affinity_graph`` and the neighbour table
    returned by ``get_neigh_inds``. Neighbour indices < 0 are treated as
    out-of-bounds and skipped. Every other connected pair, including pixels
    on the image border and pixel index 0, gets a segment.

    Parameters
    ----------
    shape : tuple of int
        Shape of the image that ``coords`` index into.
    affinity_graph : ndarray, shape (nstep, npix)
        Nonzero entries mark an edge from a pixel along a neighbour step.
    neighbors : array-like
        Neighbour offsets, forwarded to ``get_neigh_inds``.
    coords : sequence of ndarray
        One coordinate array per image axis, one entry per pixel.
    figsize : float or sequence
        Size of a newly created figure (only used when ``fig`` and ``ax``
        are both None). A scalar means a square figure.
    fig, ax : optional
        Existing figure/axes. If both are None a new figure is created.
    extent : array-like, optional
        ``imshow`` extent; defaults to ``[0, shape[1], 0, shape[0]]``.
    slc : optional
        Index applied to the background image before it is shown.
    pic : ndarray, optional
        Background image. If None, a heat map of per-pixel edge counts is built.
    edgecol : colour
        Colour of the edge segments.
    linewidth : float
        Width of the edge segments.
    step_inds : array-like of int, optional
        Neighbour steps to draw; defaults to all ``nstep`` steps.
    cmap : str
        Colormap name for the generated heat map.
    origin : str
        ``imshow`` origin.
    bounds : optional
        Currently unused.

    Returns
    -------
    ``(summed_affinity, affinity_cmap)`` when the heat map was generated
    here, otherwise ``(None, None)``.

    Raises
    ------
    ValueError
        If there are no edges to plot.
    """
    # imports that take time are kept local
    from .utils import get_neigh_inds

    nstep, npix = affinity_graph.shape
    coords = tuple(coords)

    # neighbour lookup table: neigh_inds[s, p] is the neighbour of pixel p
    # along step s (negative when it does not exist)
    _, neigh_inds, _ = get_neigh_inds(tuple(neighbors), coords, shape)

    if step_inds is None:
        step_inds = np.arange(nstep)
    px_inds = np.arange(npix)

    # ---------------------------------------------------------------------
    # Build the edge segments, one vectorised block per neighbour step
    # ---------------------------------------------------------------------
    # np.array(coords).T is (npix, ndim); flip the axis order for imshow and
    # shift to pixel centres once, instead of once per edge.
    centres = np.array(coords).T[:, ::-1] + 0.5
    chunks = []
    for s in step_inds:
        mask = affinity_graph[s].astype(bool)  # where an edge exists
        if not mask.any():
            continue
        src_idx = px_inds[mask]
        dst_idx = neigh_inds[s, mask]
        valid = dst_idx >= 0  # drop out-of-bounds neighbours
        # (k, 2, ndim): start and end point of each of the k segments
        chunks.append(np.stack((centres[src_idx[valid]], centres[dst_idx[valid]]), axis=1))

    segments = np.concatenate(chunks) if chunks else np.empty((0, 2, centres.shape[1]))
    if len(segments) == 0:
        raise ValueError("No edges found to plot; check affinity_graph and neighbours.")

    # ---------------------------------------------------------------------
    # Figure / axes handling
    # ---------------------------------------------------------------------
    # Process-wide patch: every renderer now gets a round-capped GC.
    RendererBase.new_gc = types.MethodType(custom_new_gc, RendererBase)

    newfig = fig is None and ax is None
    if newfig:
        if not isinstance(figsize, (list, tuple, np.ndarray)):
            figsize = (figsize, figsize)
        fig = Figure(figsize=figsize)
        ax = fig.add_subplot(111)

    if extent is None:
        extent = np.array([0, shape[1], 0, shape[0]])

    # ---------------------------------------------------------------------
    # Background image (edge-count heat map) - created if not supplied
    # ---------------------------------------------------------------------
    nopic = pic is None
    if nopic:
        summed_affinity = np.zeros(shape, dtype=int)
        summed_affinity[coords] = np.sum(affinity_graph, axis=0)

        # reversed colormap sampled at 9 levels, with an all-zero RGBA entry
        # prepended so pixels with no edges stay transparent
        colors = mpl.colormaps.get_cmap(cmap).reversed()(np.linspace(0, 1, 9))
        colors = np.vstack((np.array([0]*4), colors))
        affinity_cmap = mpl.colors.ListedColormap(colors)
        pic = affinity_cmap(summed_affinity)

    ax.imshow(pic[slc] if slc is not None else pic, extent=extent, origin=origin)

    # ---------------------------------------------------------------------
    # Draw edges
    # ---------------------------------------------------------------------
    line_segments = LineCollection(segments, color=edgecol, linewidths=linewidth)
    ax.add_collection(line_segments)

    if newfig:
        ax.set_axis_off()
        ax.invert_yaxis()
        canvas = FigureCanvas(fig)
        canvas.draw()

    if nopic:
        return summed_affinity, affinity_cmap
    return None, None
# @njit # def colorize(im,colors=None,color_weights=None,offset=0): # N = len(im) # if colors is None: # angle = np.arange(0,1,1/N)*2*np.pi+offset # angles = np.stack((angle,angle+2*np.pi/3,angle+4*np.pi/3),axis=-1) # colors = (np.cos(angles)+1)/2 # if color_weights is not None: # colors *= color_weights # rgb = np.zeros((im.shape[1], im.shape[2], 3)) # for i in range(N): # for j in range(3): # rgb[..., j] += im[i] * colors[i, j] # rgb /= N # return rgb # @njit
def colorize(im, colors=None, color_weights=None, offset=0, channel_axis=-1):
    """Blend a stack of single-channel images into one colour image.

    Each channel ``im[i]`` is tinted with ``colors[i]`` and the tinted
    channels are averaged.

    Parameters
    ----------
    im : array-like, shape (N, *spatial)
        N channel images (any number of spatial dimensions).
    colors : ndarray, shape (N, K), optional
        Colour for each channel. Defaults to N evenly spaced hues on a cosine
        colour wheel (K = 3).
    color_weights : array-like, shape (N,), optional
        Per-channel scale applied to the colours. The caller's ``colors``
        array is not modified.
    offset : float
        Phase (radians) added to the default hue angles.
    channel_axis : int
        Position of the colour axis in the output: -1 (default) gives
        ``(*spatial, K)``, 0 gives ``(K, *spatial)``.

    Returns
    -------
    ndarray
        ``mean_i(im[i, ..., None] * colors[i])``, with the colour axis at
        ``channel_axis``.
    """
    im = np.asanyarray(im)
    n_channels = len(im)

    if colors is None:
        # arange(N) / N always gives exactly N angles; np.arange with a float
        # step (0, 1, 1/N) can round up to N + 1 entries.
        angle = np.arange(n_channels) / n_channels * 2 * np.pi + offset
        angles = np.stack((angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3), axis=-1)
        colors = (np.cos(angles) + 1) / 2
    else:
        colors = np.asanyarray(colors)

    if color_weights is not None:
        # out-of-place so a caller-supplied colors array is left untouched
        colors = colors * np.expand_dims(color_weights, -1)

    # (N, 1, ..., 1, K) so the colours broadcast against im[..., None] for any
    # number of spatial dimensions (previously hard-coded to two)
    bshape = (colors.shape[0],) + (1,) * (im.ndim - 1) + (colors.shape[1],)
    rgb = (im[..., None] * colors.reshape(bshape)).mean(axis=0)

    if channel_axis != -1:
        rgb = np.moveaxis(rgb, -1, channel_axis)
    return rgb
# def colorize_GPU(im, colors=None, color_weights=None, offset=0, channel_axis=-1): # import torch # N = im.shape[0] # device = im.device # if colors is None: # angle = torch.linspace(0, 1, N, device=device) * 2 * np.pi + offset # angles = torch.stack((angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3), dim=-1) # colors = (torch.cos(angles) + 1) / 2 # if color_weights is not None: # colors *= color_weights.unsqueeze(-1) # # colors /= color_weights.sum() # im = im.unsqueeze(-1) # Add an extra dimension to `im` # # Perform the multiplication and mean computation using `einsum` - way faster than using view # rgb = torch.einsum('ijkl,il->jkl', im.float(), colors.float()) / N # return rgb # def colorize_GPU(im, colors=None, color_weights=None, offset=0, intervals=None): # import torch # import string # N = im.shape[0] # Number of channels # device = im.device # if colors is None: # angle = torch.linspace(0, 1, N, device=device) * 2 * np.pi + offset # angles = torch.stack((angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3), dim=-1) # colors = (torch.cos(angles) + 1) / 2 # Generate RGB colors # if color_weights is not None: # colors *= color_weights.unsqueeze(-1) # Apply color weights to colors # # Determine the number of spatial dimensions # num_spatial_dims = im.ndim - 1 # Exclude the channel dimension # # Generate index letters for einsum (excluding 'c' and 'l') # idx_letters = ''.join(letter for letter in string.ascii_lowercase if letter not in {'c', 'l'}) # spatial_indices = idx_letters[:num_spatial_dims] # # Build the einsum equation dynamically # im_indices = 'c' + spatial_indices # colors_indices = 'c l' # 'l' corresponds to the RGB channels # output_indices = spatial_indices + 'l' # einsum_eq = f'{im_indices},{colors_indices}->{output_indices}' # # print('einsum_eq:',einsum_eq) # # Perform the weighted sum across channels to produce the RGB image # rgb = torch.einsum(einsum_eq, im.float(), colors.float()) / N # return rgb
def colorize_GPU(im, colors=None, color_weights=None, intervals=None, offset=0):
    """Torch counterpart of ``colorize`` with optional channel grouping.

    Channels are split into consecutive groups of sizes ``intervals``
    (default: a single group of all C channels). Each channel is tinted by
    its colour and each group is averaged, giving one colour image per group.

    Parameters
    ----------
    im : torch.Tensor, shape (C, *spatial)
    colors : torch.Tensor, shape (C, K), optional
        Per-channel colours. Defaults to C evenly spaced hues on a cosine
        colour wheel (K = 3), excluding the endpoint so no two channels share
        a hue (same convention as ``colorize`` and ``colorize_dask``).
    color_weights : torch.Tensor, shape (C,), optional
        Per-channel scale for the colours; the caller's ``colors`` tensor is
        not modified.
    intervals : sequence of int, optional
        Consecutive group sizes (expected to sum to C).
    offset : float
        Phase (radians) added to the default hue angles.

    Returns
    -------
    torch.Tensor of shape (N, *spatial, K), or (*spatial, K) when N == 1.
    """
    import string

    import numpy as np
    import torch

    try:
        from opt_einsum import contract  # big difference on CPU
    except ImportError:  # same contraction, built into torch
        contract = torch.einsum

    C = im.shape[0]  # number of input channels
    device = im.device

    # Build the einsum equation for any number of spatial dimensions; spatial
    # indices avoid 'c' (channel) and 'l' (colour), 'N' is the group axis.
    num_spatial_dims = im.ndim - 1
    idx_letters = ''.join(ch for ch in string.ascii_lowercase if ch not in {'c', 'l'})
    spatial_indices = idx_letters[:num_spatial_dims]
    einsum_eq = f'c{spatial_indices},cN,cl->N{spatial_indices}l'

    if intervals is None:
        intervals = [C]
    N = len(intervals)  # number of groups

    # (C, N) averaging matrix: column i holds 1/length for the channels of group i
    aggregator = torch.zeros(C, N, device=device)
    start = 0
    for i, length in enumerate(intervals):
        aggregator[start:start + length, i] = 1 / length
        start += length

    if colors is None:
        # torch.linspace(0, 1, C) included the endpoint, so channel 0 and
        # channel C-1 got the same hue (for C == 2 every channel did).
        angle = torch.arange(C, device=device) / C * 2 * np.pi + offset
        angles = torch.stack((angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3), dim=-1)
        colors = (torch.cos(angles) + 1) / 2

    if color_weights is not None:
        # out-of-place so a caller-supplied colors tensor is left untouched
        colors = colors * color_weights.unsqueeze(-1)

    rgb = contract(einsum_eq, im.float(), aggregator.float(), colors.float())

    # drop the group axis when there is only one group
    if N == 1:
        rgb = rgb.squeeze(0)
    return rgb
def colorize_dask(im_dask, colors=None, color_weights=None, intervals=None, offset=0):
    """Colourise a (C, *spatial) dask array with one ``da.dot`` contraction.

    NOTE: ``colorize_dask`` is defined again later in this module; the last
    definition is the one bound to the name after import.

    Channels are split into consecutive groups of sizes ``intervals``
    (default: all C channels in one group). Each channel is tinted by its
    colour and each group is averaged.

    Parameters
    ----------
    im_dask : dask array, shape (C, *spatial)
    colors : ndarray, shape (C, K), optional
        Defaults to C evenly spaced hues (K = 3).
    color_weights : ndarray, shape (C,), optional
        Per-channel colour scale; the caller's ``colors`` is not modified.
    intervals : sequence of int, optional
        Consecutive group sizes (expected to sum to C).
    offset : float
        Phase (radians) added to the default hue angles.

    Returns
    -------
    dask array of shape (N, *spatial, K), or (*spatial, K) when N == 1.
    """
    import dask.array as da
    import numpy as np

    C = im_dask.shape[0]
    spatial_shape = im_dask.shape[1:]

    # If intervals is not provided, treat the entire channel set as one interval.
    if intervals is None:
        intervals = [C]
    N = len(intervals)

    # (C, N) averaging matrix
    aggregator = np.zeros((C, N), dtype=np.float32)
    start = 0
    for i, size in enumerate(intervals):
        aggregator[start:start + size, i] = 1.0 / size
        start += size

    # Default colours: (C, 3)
    if colors is None:
        angle = np.linspace(0, 1, C, endpoint=False) * 2 * np.pi + offset
        angles = np.stack([angle, angle + 2*np.pi/3, angle + 4*np.pi/3], axis=-1)
        colors = (np.cos(angles) + 1.0) / 2.0

    if color_weights is not None:
        # out-of-place so a caller-supplied colors array is left untouched
        colors = colors * color_weights[:, None]
    n_col = colors.shape[-1]

    # (C, N, K) -> (C, N*K)
    aggregator_colors = aggregator[..., None] * colors[:, None, :]
    agg_col_reshaped = aggregator_colors.reshape(C, N * n_col)

    # (C, Z*Y*X)
    im_flat = im_dask.reshape(C, -1)

    # (N*K, C) . (C, Z*Y*X) -> (N*K, Z*Y*X)
    out_flat = da.dot(agg_col_reshaped.T, im_flat)

    # (N, K, *spatial) -> (N, *spatial, K)
    out_reshaped = out_flat.reshape(N, n_col, *spatial_shape)
    out_final = da.moveaxis(out_reshaped, 1, -1)

    if N == 1:
        out_final = out_final.squeeze(axis=0)
    return out_final
def colorize_dask_fast(im_dask, colors=None, color_weights=None, intervals=None, offset=0):
    """Colourise a (C, *spatial) dask array via one precomputed weight matrix.

    Channels are split into consecutive groups of sizes ``intervals``
    (default: all C channels in one group). Each channel is tinted by its
    colour and each group is averaged.

    Parameters
    ----------
    im_dask : dask array, shape (C, *spatial)
    colors : ndarray, shape (C, K), optional
        Defaults to C evenly spaced hues (K = 3).
    color_weights : ndarray, shape (C,), optional
        Per-channel colour scale; the caller's ``colors`` is not modified.
    intervals : sequence of int, optional
        Consecutive group sizes (expected to sum to C).
    offset : float
        Phase (radians) added to the default hue angles.

    Returns
    -------
    dask array of shape (N, *spatial, K), or (*spatial, K) when N == 1.
    """
    import dask.array as da
    import numpy as np

    C = im_dask.shape[0]
    spatial_shape = im_dask.shape[1:]

    if intervals is None:
        intervals = [C]
    N = len(intervals)

    # Precompute aggregator matrix (C, N)
    aggregator = np.zeros((C, N), dtype=np.float32)
    start = 0
    for i, size in enumerate(intervals):
        aggregator[start:start + size, i] = 1.0 / size
        start += size

    # Compute colors (C, 3)
    if colors is None:
        angle = np.linspace(0, 1, C, endpoint=False) * 2 * np.pi + offset
        angles = np.stack([angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3], axis=-1)
        colors = (np.cos(angles) + 1.0) / 2.0

    if color_weights is not None:
        # out-of-place so a caller-supplied colors array is left untouched
        colors = colors * color_weights[:, None]
    n_col = colors.shape[-1]

    # Final weighting matrix (C, N, K), collapsed early to (C, N*K)
    weights = aggregator[..., None] * colors[:, None, :]
    weights_flat = weights.reshape(C, N * n_col)

    # Flatten input to (C, Z*Y*X) and contract: (N*K, Z*Y*X)
    im_flat = im_dask.reshape(C, -1)
    out_flat = da.dot(weights_flat.T, im_flat)

    # (N, K, *spatial) -> (N, *spatial, K)
    out = out_flat.reshape(N, n_col, *spatial_shape)
    out = da.moveaxis(out, 1, -1)

    # If single interval, squeeze it out
    if N == 1:
        out = out.squeeze(axis=0)
    return out
# A tensordot-based variant (colorize_dask_2: da.tensordot over the channel
# axis followed by moving the group axis to the front) was tried and was slow.
def colorize_dask(im_dask, colors=None, color_weights=None, intervals=None, offset=0):  # slower
    """Colourise a (C, *spatial) array with one opt_einsum ellipsis contraction.

    NOTE: ``colorize_dask`` is defined again later in this module; the last
    definition is the one bound to the name after import.

    Channels are split into consecutive groups of sizes ``intervals``
    (default: all C channels in one group). Each channel is tinted by its
    colour and each group is averaged.

    Parameters
    ----------
    im_dask : array, shape (C, *spatial)
    colors : ndarray, shape (C, K), optional
        Defaults to C evenly spaced hues (K = 3).
    color_weights : ndarray, shape (C,), optional
        Per-channel colour scale; the caller's ``colors`` is not modified.
    intervals : sequence of int, optional
        Consecutive group sizes (expected to sum to C).
    offset : float
        Phase (radians) added to the default hue angles.

    Returns
    -------
    Array of shape (N, *spatial, K), or (*spatial, K) when N == 1.
    """
    import numpy as np
    from opt_einsum import contract

    C = im_dask.shape[0]
    if intervals is None:
        intervals = [C]
    N = len(intervals)

    # Build the aggregator matrix of shape (C, N)
    aggregator = np.zeros((C, N), dtype=np.float32)
    start = 0
    for i, size in enumerate(intervals):
        aggregator[start:start+size, i] = 1.0 / size
        start += size

    # Generate default colors if none provided; result shape is (C, 3)
    if colors is None:
        angle = np.linspace(0, 1, C, endpoint=False) * 2 * np.pi + offset
        angles = np.stack([angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3], axis=-1)
        colors = (np.cos(angles) + 1.0) / 2.0

    if color_weights is not None:
        # out-of-place so a caller-supplied colors array is left untouched
        colors = colors * color_weights[:, None]

    # Combine aggregator and colors: (C, N, K)
    aggregator_colors = aggregator[..., None] * colors[:, None, :]

    out = contract('c..., cnr -> n...r', im_dask, aggregator_colors)

    # If only one interval is used, squeeze out the extra axis
    if N == 1:
        out = out.squeeze(axis=0)
    return out
def colorize_dask(im_dask, colors=None, color_weights=None, intervals=None, offset=0):
    """Colourise a (C, *spatial) dask array by averaging tinted channel groups.

    Channels are split into consecutive groups of sizes ``intervals``
    (default: all C channels in one group). Each channel is tinted by its
    colour and each group is averaged, giving one colour image per group.

    Parameters
    ----------
    im_dask : dask array, shape (C, *spatial)
    colors : ndarray, shape (C, K), optional
        Defaults to C evenly spaced hues (K = 3).
    color_weights : ndarray, shape (C,), optional
        Per-channel colour scale; the caller's ``colors`` is not modified.
    intervals : sequence of int, optional
        Consecutive group sizes (expected to sum to C).
    offset : float
        Phase (radians) added to the default hue angles.

    Returns
    -------
    dask array of shape (N, *spatial, K), or (*spatial, K) when N == 1.
    """
    import dask
    import numpy as np
    from opt_einsum import contract

    C = im_dask.shape[0]
    spatial_shape = im_dask.shape[1:]
    # int(): np.prod(()) is the float 1.0, which reshape rejects for 1-D input
    spatial_size = int(np.prod(spatial_shape))

    if intervals is None:
        intervals = [C]
    N = len(intervals)

    # Build aggregator: shape (C, N)
    aggregator = np.zeros((C, N), dtype=np.float32)
    start = 0
    for i, size in enumerate(intervals):
        aggregator[start : start + size, i] = 1.0 / size
        start += size

    # Default color generation: shape (C, 3)
    if colors is None:
        angle = np.linspace(0, 1, C, endpoint=False) * 2 * np.pi + offset
        angles = np.stack([angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3], axis=-1)
        colors = (np.cos(angles) + 1.0) / 2.0

    if color_weights is not None:
        # out-of-place so a caller-supplied colors array is left untouched
        colors = colors * color_weights[:, None]

    # (C, N, K), kept as a NumPy array to avoid Dask overhead
    agg_colors = aggregator[..., None] * colors[:, None, :]

    # Flatten input from (C, *spatial) -> (C, P)
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        im_flat = im_dask.reshape(C, spatial_size)

    # c: channel, X: flattened pixels, N: groups, r: colour components
    out_flat = contract('cX,cNr->NXr', im_flat, agg_colors)

    # (N, P, K) -> (N, *spatial, K)
    out = out_flat.reshape(N, *spatial_shape, agg_colors.shape[-1])

    # For a single interval, remove the interval dimension
    if N == 1:
        out = out[0]
    return out
def colorize_dask_matmul(im_dask, colors=None, color_weights=None, intervals=None, offset=0):
    """Colourise a (C, *spatial) dask array using a single matrix multiply.

    A faster version of colorize_dask that uses a single matrix multiply
    instead of explicit loops or opt_einsum for the core contraction step.

    Channels are averaged in consecutive groups of sizes ``intervals``
    (default: all C channels in one group), each channel first tinted by its
    colour. The reshapes below assume 3 colour components per channel.

    Returns
    -------
    dask array of shape (N, *spatial, 3), or (*spatial, 3) when N == 1.
    """
    import numpy as np
    import dask

    n_chan, *spatial = im_dask.shape
    spatial_shape = tuple(spatial)
    n_pix = np.prod(spatial_shape)

    groups = [n_chan] if intervals is None else intervals
    n_groups = len(groups)

    # (C, N): column g averages the channels belonging to group g
    averaging = np.zeros((n_chan, n_groups), dtype=np.float32)
    first = 0
    for g, width in enumerate(groups):
        averaging[first:first + width, g] = 1.0 / width
        first += width

    if colors is None:
        theta = np.linspace(0, 1, n_chan, endpoint=False) * 2 * np.pi + offset
        phases = (0.0, 2 * np.pi / 3, 4 * np.pi / 3)
        colors = (np.cos(np.stack([theta + p for p in phases], axis=-1)) + 1.0) / 2.0

    if color_weights is not None:
        colors = colors * color_weights[:, None]

    # (C, N, 3) collapsed to (C, N*3) so one matmul does the whole contraction
    weight_matrix = averaging[:, :, np.newaxis] * colors[:, np.newaxis, :]
    weight_matrix = weight_matrix.reshape(n_chan, n_groups * 3).astype(np.float32)

    # (C, *spatial) -> (C, P), cast to float32 to match the weights
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        pixels = im_dask.reshape(n_chan, n_pix, merge_chunks=True)
        pixels = pixels.astype(np.float32)

    # (P, C) @ (C, N*3) -> (P, N*3) -> (N, P, 3) -> (N, *spatial, 3)
    product = pixels.T @ weight_matrix
    result = (
        product
        .reshape(n_pix, n_groups, 3)
        .transpose(1, 0, 2)
        .reshape(n_groups, *spatial_shape, 3)
    )

    # a single group keeps the historical (*spatial, 3) output
    return result[0] if n_groups == 1 else result
def apply_ncolor(masks, offset=0, cmap=None, max_depth=20, expand=True, maxv=1, greedy=False):
    """Recolour a label image with a small number of colours via ``ncolor``.

    ``ncolor.label`` relabels ``masks`` (connectivity 2) using ``n`` colour
    indices. Those indices are then mapped to RGBA.

    Parameters
    ----------
    masks : ndarray
        Label image.
    offset : float
        Hue offset for the default ``sinebow`` palette.
    cmap : str, callable or None
        Colormap name (wrapped in ``cmap.Colormap``) or a callable colormap.
        If None, a ``sinebow`` palette with one colour per index is used.
    max_depth, expand, greedy :
        Forwarded to ``ncolor.label``.
    maxv : float
        Divisor applied to the rescaled labels before an explicit ``cmap``.

    Returns
    -------
    ndarray
        The colormap applied to the relabelled image.
    """
    import ncolor
    from cmap import Colormap

    if isinstance(cmap, str):
        cmap = Colormap(cmap)

    relabeled, n_colors = ncolor.label(
        masks,
        max_depth=max_depth,
        return_n=True,
        conn=2,
        expand=expand,
        greedy=greedy,
    )

    if cmap is not None:
        return cmap(rescale(relabeled) / maxv)

    palette = np.array(list(sinebow(n_colors, offset=offset).values()))
    return mpl.colors.ListedColormap(palette)(relabeled)
def set_outline(ax, outline_color=None, outline_width=0):
    """Strip ticks from ``ax`` and either style or hide its border.

    Ticks are always removed and the axes patch is made fully transparent.
    If ``outline_color`` is given and ``outline_width > 0``, every spine is
    drawn in that colour and width; otherwise every spine is hidden.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    ax.patch.set_alpha(0)

    show_border = outline_color is not None and outline_width > 0
    for spine in ax.spines.values():
        if show_border:
            spine.set_edgecolor(outline_color)
            spine.set_linewidth(outline_width)
        else:
            spine.set_visible(False)
def imshow(imgs, figsize=2, ax=None, hold=False, titles=None, title_size=8,
           spacing=0.05, textcolor=[0.5]*3, dpi=300, text_scale=1,
           outline_color=None,  # e.g. [0.5]*3
           outline_width=0.5,   # e.g. 0.5
           show=False, **kwargs):
    """Display one image, or a list of images side by side.

    Ticks are always hidden. An outline (coloured border) is drawn around
    each image when ``outline_color`` is not None and ``outline_width > 0``;
    otherwise the spines are hidden (see ``set_outline``).

    Parameters
    ----------
    imgs : array-like or list of array-like
        A single image, or a list of images laid out in one row.
    figsize : float or (w, h)
        Size of each panel; a scalar means square panels.
    ax : Axes, optional
        Single-image mode only: draw into this axes (implies ``hold=True``).
    hold : bool
        If True, return the figure instead of displaying it.
    titles : str or list, optional
        Title of the single image, or one title per image; images without
        a corresponding entry are left untitled.
    title_size : float or None
        Title font size; None derives it from ``figsize`` and ``text_scale``.
    spacing : float
        Currently unused.
    textcolor : colour
        Title colour.
    dpi : int
        Resolution of the figure created in single-image mode.
    text_scale : float
        Factor used when ``title_size`` is None.
    outline_color, outline_width :
        Border settings passed to ``set_outline``.
    show : bool
        Currently unused.
    **kwargs
        Forwarded to ``Axes.imshow``.

    Returns
    -------
    The figure when ``hold`` is True (or ``ax`` was given); otherwise None,
    after showing the figure with ``IPython.display.display``.
    """
    if isinstance(imgs, list):
        # -------------------------------------------------------------
        # Several images side by side
        # -------------------------------------------------------------
        n_imgs = len(imgs)
        if titles is None:
            titles = [None] * n_imgs
        else:
            # pad so zip() below does not silently drop trailing images
            titles = list(titles) + [None] * (n_imgs - len(titles))

        # A (w, h) figsize sets each panel's size; previously a tuple was
        # repeated by `figsize * len(imgs)`, producing an invalid figsize.
        if isinstance(figsize, (list, tuple)):
            panel_w, panel_h = figsize
        else:
            panel_w = panel_h = figsize

        if title_size is None:
            title_size = panel_w / n_imgs * text_scale

        fig, axes = figure(
            nrow=1,
            ncol=n_imgs,
            figsize=(panel_w * n_imgs, panel_h),
            frameon=False,
            facecolor=[0, 0, 0, 0],
        )
        for this_ax, img, ttl in zip(axes, imgs, titles):
            this_ax.imshow(img, **kwargs)
            set_outline(this_ax, outline_color, outline_width)
            this_ax.set_facecolor([0, 0, 0, 0])
            if ttl is not None:
                this_ax.set_title(ttl, fontsize=title_size, color=textcolor)
    else:
        # -------------------------------------------------------------
        # A single image
        # -------------------------------------------------------------
        if not isinstance(figsize, (list, tuple)):
            figsize = (figsize, figsize)
        if title_size is None:
            title_size = figsize[0] * text_scale

        if ax is None:
            subplot_args = {
                'frameon': False,
                'figsize': figsize,
                'facecolor': [0, 0, 0, 0],
                'dpi': dpi,
            }
            fig, ax = figure(**subplot_args)
        else:
            # drawing into a caller's axes: hand the figure back, don't display
            hold = True
            fig = ax.get_figure()

        ax.imshow(imgs, **kwargs)
        set_outline(ax, outline_color, outline_width)
        ax.set_facecolor([0, 0, 0, 0])
        if titles is not None:
            ax.set_title(titles, fontsize=title_size, color=textcolor)

    if hold:
        return fig

    # `display` is not a module-level name here; it only resolved before when
    # IPython happened to expose it as a builtin.
    from IPython.display import display
    display(fig)
    return None
# def get_cmap(masks): # lut = ncolor.get_lut(masks) # make sure int64 # c = sinebow(lut.max()) # colors = [c[l] for l in lut] # cmap = mpl.colors.ListedColormap(colors) # return cmap # @njit() # def rgb_flow(dP,transparency=False,mask=None,norm=False): # """ dP is 2 x Y x X => 'optic' flow representation # Parameters # ------------- # dP: NDarray, float # Flow field component stack [B,dy,dx] # transparency: bool, default False # magnitude of flow controls opacity, not lightness (clear background) # mask: 2D array # Multiplies each RGB component to suppress noise # """ # mag = np.sqrt(np.sum(dP**2,axis=1)) # if norm: # mag = np.clip(utils.normalize99(mag), 0, 1.).astype(np.float32) # angles = np.arctan2(dP[:,1], dP[:,0])+np.pi # a = 2 # r = ((np.cos(angles)+1)/a) # g = ((np.cos(angles+2*np.pi/3)+1)/a) # b = ((np.cos(angles+4*np.pi/3)+1)/a) # if transparency: # im = np.stack((r,g,b,mag),axis=-1) # else: # im = np.stack((r*mag,g*mag,b*mag),axis=-1) # if mask is not None and transparency and dP.shape[0]<3: # im[...,-1] *= mask # im = (np.clip(im, 0, 1) * 255).astype(np.uint8) # return im # from numba import jit # @jit(nopython=True) # @njit() # def rgb_flow(dP, transparency=True, mask=None, norm=True): # mag = np.sqrt(np.sum(dP**2,axis=1)).reshape(1, -1) # vecs = dP[:,0] + dP[:,1]*1j # roots = np.exp(1j * np.pi * (2 * np.arange(3) / 3 +1)) # rgb = (np.real(roots * vecs.reshape(-1, 1) / np.max(mag)).T + 1 ) / 2 # if norm: # # mag = np.clip(utils.normalize99(mag), 0, 1.).astype(np.float32) # mag -= np.min(mag) # mag /= np.max(mag) # shape = dP.shape # newshape = (shape[0], shape[3], shape[2], 3+transparency) # # newshape = (shape[0], shape[2], shape[3], 3+transparency) # if transparency: # im = np.concatenate((rgb, mag), axis=0) # else: # im = rgb * mag # im = (np.clip(im.T.reshape(newshape), 0, 1) * 255).astype(np.uint8) # # im = np.swapaxes(im,1,2) # return im # @njit() # def rgb_flow(dP, transparency=True, mask=None, norm=True): # mag = np.sqrt(np.sum(dP**2,axis=1)) 
# vecs = dP[:,0] + dP[:,1]*1j # roots = np.exp(1j * np.pi * (2 * np.arange(3) / 3 +1)).reshape((1, 1, 1, -1)) # rgb = (np.real(vecs[...,None]*roots / np.max(mag)) + 1 ) / 2 # if norm: # mag -= np.min(mag) # mag /= np.max(mag) # shape = dP.shape # newshape = (shape[0], shape[2], shape[3], 3+transparency) # print(rgb.shape,newshape, mag.shape, vecs.shape) # if transparency: # im = np.empty(newshape) # im[..., :3] = rgb # im[..., 3] = mag # else: # im = rgb * mag # im = (np.clip(im, 0, 1) * 255).astype(np.uint8) # return im
def rgb_flow(dP, transparency=True, mask=None, norm=True, device=None):
    """Encode a stack of 2-component flow fields as uint8 RGB(A) images.

    Meant for stacks of dP, unsqueeze if using on a single plane.

    The flow direction sets the hue: the complex vector ``dP[:, 0] + i dP[:, 1]``
    is projected onto three unit roots spaced 120 degrees apart. The flow
    magnitude either becomes the alpha channel (``transparency=True``) or
    scales the RGB values.

    Parameters
    ----------
    dP : torch.Tensor or ndarray, shape (B, 2, *spatial)
        Flow components along axis 1. NumPy input is moved to ``device``.
    transparency : bool
        Append the magnitude as a 4th (alpha) channel instead of multiplying.
    mask : optional
        Accepted for API compatibility; currently ignored.
    norm : bool
        Min-max normalise the magnitude to [0, 1].
    device : torch.device, optional
        Device for NumPy input (default CPU); a tensor keeps its own device.

    Returns
    -------
    torch.Tensor, uint8, shape (B, *spatial, 3 or 4)
    """
    import torch

    if device is None:
        device = torch.device('cpu')
    if isinstance(dP, torch.Tensor):
        device = dP.device
    else:
        dP = torch.from_numpy(dP).to(device)

    mag = torch_norm(dP, dim=1)
    vecs = dP[:, 0] + dP[:, 1] * 1j
    # unit complex numbers at angles pi, pi + 2pi/3, pi + 4pi/3
    roots = torch.exp(1j * np.pi * (2 * torch.arange(3, device=device) / 3 + 1))

    # An all-zero field used to give 0/0 = NaN here; dividing by 1 instead
    # leaves the (zero) vectors unchanged.
    peak = torch.max(mag)
    peak = torch.where(peak > 0, peak, torch.ones_like(peak))
    # trailing broadcast works for any number of spatial dimensions
    rgb = (torch.real(vecs.unsqueeze(-1) * roots / peak) + 1) / 2

    if norm:
        mag = mag - torch.min(mag)
        span = torch.max(mag)
        # constant magnitude -> all zeros instead of NaN
        mag = mag / torch.where(span > 0, span, torch.ones_like(span))

    if transparency:
        im = torch.cat((rgb, mag[..., None]), dim=-1)
    else:
        im = rgb * mag[..., None]

    im = (torch.clamp(im, 0, 1) * 255).type(torch.uint8)
    return im
def create_colormap(image, labels):
    """
    Create a colormap based on the average color of each label in the image.

    Parameters
    ----------
    image: ndarray
        An RGB (or grayscale) image whose leading dimensions match ``labels``.
    labels: ndarray
        A 2D array of non-negative integer labels corresponding to the image.

    Returns
    -------
    colormap: ndarray, uint8, shape (labels.max() + 1, 3)
        Row ``k`` is the mean colour of label ``k`` (truncated to uint8);
        labels that do not occur stay black.

    Raises
    ------
    ValueError
        If the image's leading dimensions do not match ``labels``.
    """
    # Ensure the image is in the range 0-255
    image = img_as_ubyte(image)
    labels = np.asarray(labels)
    if image.shape[:labels.ndim] != labels.shape:
        raise ValueError(
            f"image shape {image.shape} does not match labels shape {labels.shape}"
        )

    # int() avoids wrap-around when labels are e.g. uint8 with max 255
    n_labels = int(labels.max()) + 1
    flat_labels = labels.ravel().astype(np.intp)
    # one row per pixel: (P, 3) for RGB input, (P, 1) for grayscale
    pixels = image.reshape(flat_labels.size, -1)

    # Per-label pixel counts and per-channel sums in one pass each, instead of
    # one boolean mask over the whole image per label.
    counts = np.bincount(flat_labels, minlength=n_labels)
    sums = np.stack(
        [np.bincount(flat_labels, weights=pixels[:, ch], minlength=n_labels)
         for ch in range(pixels.shape[1])],
        axis=1,
    )

    colormap = np.zeros((n_labels, 3), dtype=np.uint8)
    present = counts > 0
    # float -> uint8 assignment truncates, as the per-label mean did before
    colormap[present] = sums[present] / counts[present, None]
    return colormap
def color_from_RGB(im, rgb, m, bd=None, mode='inner', connectivity=2):
    """Overlay the labels ``m`` on ``im``, colouring each label from ``rgb``.

    Each label's colour is its mean colour in ``rgb`` (``create_colormap``).
    The alpha map passed to ``skimage.color.label2rgb`` is 0.5 inside
    labelled pixels, 1 on label boundaries and 0 on the background.

    Parameters
    ----------
    im : ndarray
        Image to draw the overlay on.
    rgb : ndarray
        Image the per-label colours are sampled from.
    m : ndarray
        Label image (0 = background).
    bd : ndarray of bool, optional
        Boundary mask; computed with ``find_boundaries`` when None.
    mode, connectivity :
        Forwarded to ``skimage.segmentation.find_boundaries``.

    Returns
    -------
    ndarray
        The overlay produced by ``label2rgb``.
    """
    # ncolor is not imported at module level, so import it here
    import ncolor
    from skimage import color

    if bd is None:
        from skimage.segmentation import find_boundaries
        bd = find_boundaries(m, mode=mode, connectivity=connectivity)

    alpha = (m > 0) * .5
    alpha[bd] = 1
    alpha = np.stack([alpha] * 3, axis=-1)

    m = ncolor.format_labels(m)
    cmap = create_colormap(rgb, m)
    clrs = rescale(cmap[1:])  # drop the background row
    overlay = color.label2rgb(m, im, clrs, bg_label=0, alpha=alpha)
    return overlay
def split_list(lst, N):
    """Chunk ``lst`` into consecutive slices of length ``N``; the last may be shorter."""
    chunks = []
    for start in range(0, len(lst), N):
        chunks.append(lst[start:start + N])
    return chunks
def image_grid(images, column_titles=None, row_titles=None, plot_labels=None,
               xticks=(), yticks=(), outline=False, outline_color=(0.5, 0.5, 0.5),
               outline_width=.5, padding=0.05, interset_padding=0.1,
               fontsize=8, fontcolor=(0.5, 0.5, 0.5), facecolor=None, figsize=6,
               dpi=300, order='ij', reverse_row=False,
               stack_direction='horizontal', lpad=0.05, lpos='top_middle',
               return_axes=False, fig=None, offset=(0, 0), supcolor=None,
               right_justify_rows=False,
               **kwargs):
    """Lay out nested lists of images as a grid of axes on one Figure.

    Parameters
    ----------
    images : list
        One image set, given as a list of rows (``order='ij'``) or of columns
        (any other order) of images. Images need ``.shape``; ``None`` leaves
        an empty axes. A list of such sets (detected by ``images[0][0]`` being
        a list) is stacked according to ``stack_direction``.
    column_titles, row_titles : list of str, optional
        Titles drawn above the first row and to the left of the first column.
        In column-major order the two lists are swapped.
    plot_labels : nested list, optional
        Per-axes labels shaped like ``images``. Missing entries are skipped.
    xticks, yticks : sequence
        Passed to ``set_xticks`` / ``set_yticks`` on every axes.
    outline, outline_color, outline_width
        Draw spines with this color and width, or hide them when ``outline``
        is False.
    padding : float
        Gap between axes, in units of the fixed axes size (1.0). Also used as
        the title offset in axes coordinates.
    interset_padding : float
        Gap between stacked image sets.
    fontsize, fontcolor
        Text style for labels. ``supcolor`` (default ``fontcolor``) colors titles.
    facecolor
        Figure facecolor. ``None`` gives a frameless, transparent figure.
    figsize : float or (float, float)
        A scalar fixes the width (``'ij'``) or height (other orders); the
        other side follows the layout's aspect ratio. A (width, height)
        sequence is used as-is. Ignored when ``fig`` is given.
    dpi : int
        Figure dpi when a new figure is created.
    order : str
        ``'ij'`` for row-major nesting; anything else for column-major.
    reverse_row : bool
        Unused; kept for API compatibility.
    stack_direction : {'horizontal', 'vertical'}
        How multiple image sets are placed relative to each other.
    lpad, lpos
        Label inset and position key (see ``label_positions`` below).
    return_axes : bool
        Also return the axes list and their normalized positions.
    fig : Figure, optional
        Existing figure to draw into.
    offset : (float, float)
        Added to the normalized (figure-fraction) left/bottom positions.
    right_justify_rows : bool
        In ``'ij'`` order, right-align rows shorter than the widest row.
    **kwargs
        Forwarded to ``Axes.imshow``.

    Returns
    -------
    Figure, or (Figure, list of Axes, [lefts, bottoms, widths, heights])
    """
    if supcolor is None:
        supcolor = fontcolor

    label_positions = {
        'top_middle': {'coords': (0.5, 1 - lpad), 'va': 'top', 'ha': 'center'},
        'bottom_left': {'coords': (lpad, lpad), 'va': 'bottom', 'ha': 'left'},
        'bottom_middle': {'coords': (0.5, lpad), 'va': 'bottom', 'ha': 'center'},
        'top_left': {'coords': (lpad, 1 - lpad), 'va': 'top', 'ha': 'left'},
        'above_middle': {'coords': (.5, 1 + lpad), 'va': 'bottom', 'ha': 'center'},
    }

    # A list of lists of lists means several image sets. Otherwise wrap the
    # single set (and its labels) so both cases share the code below.
    if isinstance(images[0][0], list):
        multiple_sets = True
    else:
        multiple_sets = False
        images = [images]
        plot_labels = [plot_labels] if plot_labels is not None else None

    n_sets = len(images)
    ij = order == 'ij'

    # Swap the title lists when using column-major order.
    if not ij:
        column_titles, row_titles = row_titles, column_titles

    all_left, all_bottom, all_width, all_height = [], [], [], []
    total_offset_x = 0
    total_offset_y = 0
    p = padding   # gap between axes
    base = 1.0    # fixed width (ij) or height (!ij)

    for set_idx, image_set in enumerate(images):
        # Grid dimensions
        if ij:
            nrows = len(image_set)
            ncols = max(len(row) for row in image_set)
        else:
            ncols = len(image_set)
            nrows = max(len(col) for col in image_set)

        positions = []
        if ij:
            # Constant widths; each row's height follows its first real image.
            cur_bottom = total_offset_y
            for row in image_set:
                rep = next((im for im in row if im is not None), None)
                ratio = (rep.shape[0] / rep.shape[1]) if rep is not None else 1.0
                h = ratio * base
                row_offset = ((ncols - len(row)) * (base + p)) if right_justify_rows else 0
                for c in range(len(row)):
                    left = total_offset_x + row_offset + c * (base + p)
                    positions.append((left, cur_bottom, base, h))
                cur_bottom += h + p
            set_span_x = (base + p) * ncols - p
            set_span_y = cur_bottom - total_offset_y - p
        else:
            # Constant heights; each column's width follows its first real image.
            cur_left = total_offset_x
            for col in image_set:
                rep = next((im for im in col if im is not None), None)
                aspect = (rep.shape[1] / rep.shape[0]) if rep is not None else 1.0
                w = aspect * base
                for r in range(len(col)):
                    bottom = total_offset_y + r * (base + p)
                    positions.append((cur_left, bottom, w, base))
                cur_left += w + p
            set_span_x = cur_left - total_offset_x - p
            set_span_y = (base + p) * nrows - p

        lefts, bottoms, widths, heights = zip(*positions)
        all_left.extend(lefts)
        all_bottom.extend(bottoms)
        all_width.extend(widths)
        all_height.extend(heights)

        # Shift the next set past this one
        if multiple_sets and set_idx < n_sets - 1:
            if stack_direction == 'horizontal':
                total_offset_x += set_span_x + interset_padding
            elif stack_direction == 'vertical':
                total_offset_y += set_span_y + interset_padding

    # Normalize positions to figure fractions
    lefts = np.array(all_left)
    bottoms = np.array(all_bottom)
    widths = np.array(all_width)
    heights = np.array(all_height)
    max_w = max(lefts + widths)
    max_h = max(bottoms + heights)
    lefts /= max_w
    widths /= max_w
    # Layout was built bottom-up; flip so the first row ends up on top.
    bottoms = (max_h - bottoms - heights) / max_h
    heights /= max_h

    # Use the existing figure if provided; otherwise create a new one.
    if fig is None:
        if isinstance(figsize, (list, tuple)):
            # Explicit (width, height); previously this crashed on tuple * float.
            figsize = tuple(figsize)
        elif ij:
            figsize = (figsize, figsize * max_h / max_w)
        else:
            figsize = (figsize * max_w / max_h, figsize)
        fig = Figure(figsize=figsize,
                     frameon=False if facecolor is None else True,
                     facecolor=[0] * 4 if facecolor is None else facecolor,
                     dpi=dpi)

    lefts += offset[0]
    bottoms += offset[1]

    axes = [fig.add_axes([left, bottom, width, height])
            for left, bottom, width, height in zip(lefts, bottoms, widths, heights)]

    # Fill the axes in the same nesting order the positions were generated in.
    idx = 0
    for set_idx, image_set in enumerate(images):
        for row_idx, row in enumerate(image_set):
            for col_idx, img in enumerate(row):
                ax = axes[idx]
                idx += 1
                ax.set_xticks(xticks)
                ax.set_yticks(yticks)
                ax.patch.set_alpha(0)
                if img is not None:
                    ax.imshow(img, **kwargs)

                # Per-axes labels
                if plot_labels is not None:
                    try:
                        label = plot_labels[set_idx][row_idx][col_idx]
                    except IndexError:
                        label = None
                    if label is not None:
                        spec = label_positions[lpos]
                        text = ax.text(spec['coords'][0], spec['coords'][1], label,
                                       fontsize=fontsize, color=fontcolor,
                                       va=spec['va'], ha=spec['ha'],
                                       transform=ax.transAxes)
                        if img is None:
                            text.set_color([.5] * 4)

                # Column titles (only on the first set when stacked vertically)
                if column_titles is not None:
                    want_title = (
                        (ij and row_idx == 0 and col_idx < len(column_titles))
                        or (not ij and col_idx == 0 and row_idx < len(column_titles))
                    )
                    if want_title and (stack_direction != 'vertical' or set_idx == 0):
                        title_idx = col_idx if ij else row_idx
                        ax.text(0.5, 1 + p, column_titles[title_idx], rotation=0,
                                fontsize=fontsize, color=supcolor,
                                va='bottom', ha='center', transform=ax.transAxes)

                # Row titles (only on the first set when stacked horizontally)
                if row_titles is not None:
                    want_title = (
                        (ij and col_idx == 0 and row_idx < len(row_titles))
                        or (not ij and row_idx == 0 and col_idx < len(row_titles))
                    )
                    if want_title and (stack_direction != 'horizontal' or set_idx == 0):
                        title_idx = row_idx if ij else col_idx
                        ax.text(-p, 0.5, row_titles[title_idx], rotation=0,
                                fontsize=fontsize, color=supcolor,
                                va='center', ha='right', transform=ax.transAxes)

                # Outline or hide spines
                if outline:
                    for s in ax.spines.values():
                        s.set_color(outline_color)
                        s.set_linewidth(outline_width)
                else:
                    for s in ax.spines.values():
                        s.set_visible(False)

    if return_axes:
        pos = [lefts, bottoms, widths, heights]
        return fig, axes, pos
    return fig
def color_grid(colors, **kwargs):
    """Display each color as a 1x1 RGB swatch, one swatch per row, via ``image_grid``.

    ``colors`` may be a single color or a sequence of RGB/RGBA colors; any
    alpha channel is dropped. Extra keyword arguments go to ``image_grid``.
    """
    palette = np.array(colors)
    if palette.ndim == 1:  # a single color -> a one-entry palette
        palette = palette.reshape(1, -1)
    if palette.shape[-1] == 4:  # RGBA -> RGB
        palette = palette[:, :3]

    rows = []
    for rgb in palette:
        rows.append([np.full((1, 1, 3), rgb, dtype=np.float32)])
    return image_grid(rows, **kwargs)
# from https://stackoverflow.com/a/63530703/13326811
def colored_line_segments(xs, ys, zs=None, color='k', mid_colors=False):
    """Split a polyline into per-vertex segments with per-vertex colors.

    One segment is produced per vertex: segment ``i`` runs from vertex
    ``i-1`` to vertex ``i``. Segment 0 is degenerate (from vertex 0 to
    itself), which keeps ``segs`` aligned with the per-vertex colors.

    Parameters
    ----------
    xs, ys : sequence of float
        Vertex coordinates.
    zs : sequence of float, optional
        Third coordinate; when given, points are 3D.
    color : str or array-like
        A matplotlib color string used for every vertex, or one RGB triple
        per vertex.
    mid_colors : bool
        If True, each segment's color is the average of its two endpoint
        colors; otherwise it is the color of its end vertex.

    Returns
    -------
    segs : list of [start, end] point lists
    colors : list of RGBA tuples (alpha 1), one per segment
    """
    if isinstance(color, str):
        from matplotlib.colors import colorConverter
        rgb = colorConverter.to_rgba(color)[:-1]
        color = np.array([rgb for _ in range(len(xs))])

    segs = []
    seg_colors = []
    last_color = [color[0][0], color[0][1], color[0][2]]
    end = [xs[0], ys[0]]
    if zs is not None:
        end.append(zs[0])
    else:
        zs = [None] * len(xs)

    for x, y, z, c in zip(xs, ys, zs, color):
        if mid_colors:
            seg_colors.append([(chan + last) * .5 for chan, last in zip(c, last_color)])
        else:
            seg_colors.append(c)
        # Keep all three RGB channels; the previous ``c[:-1]`` dropped blue,
        # so the next midpoint color only had two components.
        last_color = c[:3]
        if z is not None:
            start = [end[0], end[1], end[2]]
            end = [x, y, z]
        else:
            start = [end[0], end[1]]
            end = [x, y]
        segs.append([start, end])

    colors = [(*seg_color, 1) for seg_color in seg_colors]
    return segs, colors
def segmented_resample(xs, ys, zs=None, color='k', n_resample=100, mid_colors=False):
    """Linearly upsample a polyline and split it into short colored segments.

    Each of the ``len(xs) - 1`` original segments is sampled at ``n_resample``
    evenly spaced points, giving ``n_resample - 1`` short segments each.
    Colors are linearly interpolated between the per-vertex colors.

    Parameters
    ----------
    xs, ys : sequence of float
        Vertex coordinates.
    zs : sequence of float, optional
        Third coordinate; when given, segments and data are 3D.
    color : str or array-like
        A matplotlib color string used for every vertex, or one RGB triple per
        vertex (anything ``np.asarray`` turns into shape ``(n, >=3)``).
    n_resample : int
        Samples per original segment, including both endpoints.
    mid_colors : bool
        If True, each short segment takes the average of its two endpoint
        colors; otherwise the color at its end point.

    Returns
    -------
    segs : list of [start, end] point lists, one per short segment
    colors : list of RGBA tuples (alpha 1), exactly one per entry of ``segs``
    data : [hiResXs, hiResYs] (plus hiResZs when ``zs`` is given), the
        resampled vertex coordinates
    """
    from scipy.interpolate import interp1d

    n_points = len(xs)
    if isinstance(color, str):
        from matplotlib.colors import colorConverter
        rgb = colorConverter.to_rgba(color)[:-1]
        color = np.array([rgb for _ in range(n_points)])

    # Channels first so each channel can be interpolated on its own.
    # asarray also accepts a plain list of RGB triples.
    RGB = np.asarray(color).swapaxes(0, 1)
    t = np.linspace(0, 1, n_resample)

    segs = []
    seg_colors = []
    hiResXs = [xs[0]]
    hiResYs = [ys[0]]
    if zs is not None:
        hiResZs = [zs[0]]

    for i in range(n_points - 1):
        xHiRes = interp1d([0, 1], xs[i:i + 2])(t)
        yHiRes = interp1d([0, 1], ys[i:i + 2])(t)
        hiResXs.extend(xHiRes[1:])
        hiResYs.extend(yHiRes[1:])
        R_HiRes = interp1d([0, 1], RGB[0][i:i + 2])(t)
        G_HiRes = interp1d([0, 1], RGB[1][i:i + 2])(t)
        B_HiRes = interp1d([0, 1], RGB[2][i:i + 2])(t)

        lastColor = [R_HiRes[0], G_HiRes[0], B_HiRes[0]]
        end = [xHiRes[0], yHiRes[0]]
        if zs is not None:
            zHiRes = interp1d([0, 1], zs[i:i + 2])(t)
            hiResZs.extend(zHiRes[1:])
            end.append(zHiRes[0])
        else:
            zHiRes = [None] * len(xHiRes)

        # No color is emitted for the start point itself; the previous extra
        # append here made ``colors`` longer than ``segs`` with mid_colors.
        for x, y, z, r, g, b in zip(xHiRes[1:], yHiRes[1:], zHiRes[1:],
                                    R_HiRes[1:], G_HiRes[1:], B_HiRes[1:]):
            if mid_colors:
                seg_colors.append([(chan + lastChan) * .5
                                   for chan, lastChan in zip((r, g, b), lastColor)])
            else:
                seg_colors.append([r, g, b])
            lastColor = [r, g, b]
            if z is not None:
                start = [end[0], end[1], end[2]]
                end = [x, y, z]
            else:
                start = [end[0], end[1]]
                end = [x, y]
            segs.append([start, end])

    colors = [(*seg_color, 1) for seg_color in seg_colors]
    data = [hiResXs, hiResYs]
    if zs is not None:
        data = [hiResXs, hiResYs, hiResZs]
    return segs, colors, data
def faded_segment_resample(xs, ys, zs=None, color='k', fade_len=20, n_resample=100, direction='Head'):
    """Resample a polyline with ``segmented_resample`` and fade its segment alphas.

    ``direction='Head'`` makes every segment transparent except the last
    ``fade_len``, whose alpha ramps 0 -> 1. Any other value ramps alpha
    1 -> 0 over the first ``fade_len`` segments and hides the rest.
    ``fade_len`` is capped at the number of segments.

    Returns
    -------
    segs, colors, hiResData
        As from ``segmented_resample``, with each color's alpha replaced.
    """
    segs, colors, hiResData = segmented_resample(xs, ys, zs, color, n_resample)
    n_segs = len(segs)
    fade_len = min(fade_len, n_segs)

    hidden = np.zeros(n_segs - fade_len)
    if direction == 'Head':
        alphas = np.concatenate((hidden, np.linspace(0, 1, fade_len)))
    else:
        alphas = np.concatenate((np.linspace(1, 0, fade_len), hidden))

    faded = []
    for rgba, alpha in zip(colors, alphas):
        faded.append((*rgba[:-1], alpha))
    return segs, faded, hiResData
# https://stackoverflow.com/a/27537018/13326811 def _get_perp_line(current_seg, out_of_page, linewidth): perp = np.cross(current_seg, out_of_page)[0:2] perp_unit = _get_unit_vector(perp) current_seg_perp_line = perp_unit*linewidth return current_seg_perp_line def _get_unit_vector(vector): vector_size = (vector[0]**2 + vector[1]**2)**0.5 vector_unit = vector / vector_size return vector_unit[0:2]
def colored_line(x, y, ax, z=None, line_width=1, MAP='jet'):
    """Draw a ribbon along ``(x, y)`` colored by cumulative distance along the line.

    Uses ``pcolormesh`` with Gouraud shading on a 2-column mesh. One edge is
    the polyline itself; the other is offset perpendicular to it by
    ``line_width``, with mitered joints at interior vertices.

    Parameters
    ----------
    x, y : sequence of float
        Polyline vertices (lists or arrays).
    ax : matplotlib Axes
        Axes to draw on.
    z : unused
        Accepted for API compatibility. The color value is arc length, not ``z``.
    line_width : float
        Perpendicular offset of the ribbon's second edge, in data units.
    MAP : str or Colormap
        Colormap name looked up in ``mpl.colormaps``, or a Colormap instance.

    Returns
    -------
    The artist returned by ``ax.pcolormesh``, or None when there are fewer
    than two points (no segment direction to offset along).
    """
    num_pts = len(x)
    if num_pts < 2:
        return None

    xs = np.zeros((num_pts, 2))
    ys = np.zeros((num_pts, 2))
    zs = np.zeros((num_pts, 2))
    dist = 0
    out_of_page = [0, 0, 1]

    for i in range(num_pts):
        # Column 0 is the source line itself
        xs[i][0] = x[i]
        ys[i][0] = y[i]
        if i > 0:
            x_delta = x[i] - x[i - 1]
            y_delta = y[i] - y[i - 1]
            dist += (x_delta ** 2 + y_delta ** 2) ** 0.5
            zs[i] = [dist, dist]

        # Arrays (not lists) so the vector helpers work for list inputs too.
        if i == num_pts - 1:
            current_seg = np.array([x[i] - x[i - 1], y[i] - y[i - 1], 0], dtype=float)
        else:
            current_seg = np.array([x[i + 1] - x[i], y[i + 1] - y[i], 0], dtype=float)
        current_seg_perp = _get_perp_line(current_seg, out_of_page, line_width)

        # End points: simple perpendicular offset
        if i == 0 or i == num_pts - 1:
            xs[i][1] = xs[i][0] + current_seg_perp[0]
            ys[i][1] = ys[i][0] + current_seg_perp[1]
            continue

        current_pt = np.array([x[i], y[i]], dtype=float)
        current_seg_unit = _get_unit_vector(current_seg)
        previous_seg = np.array([x[i] - x[i - 1], y[i] - y[i - 1], 0], dtype=float)
        previous_seg_perp = _get_perp_line(previous_seg, out_of_page, line_width)
        previous_seg_unit = _get_unit_vector(previous_seg)

        # Miter point: find s with
        #   current_pt + previous_seg_perp + s * previous_seg_unit
        #   == current_pt + current_seg_perp - s * current_seg_unit,
        # i.e. s * (previous_seg_unit + current_seg_unit) == current_seg_perp - previous_seg_perp.
        # Solve with the larger-magnitude component. Always using x gave 0/0 = NaN
        # on vertical runs, where both unit vectors have x == 0.
        denom = previous_seg_unit + current_seg_unit
        k = int(np.argmax(np.abs(denom)))
        if denom[k] != 0:
            scalar = (current_seg_perp[k] - previous_seg_perp[k]) / denom[k]
        else:
            # The line doubles straight back; there is no finite miter.
            scalar = 0.0
        new_pt = current_pt + previous_seg_perp + scalar * previous_seg_unit
        xs[i][1] = new_pt[0]
        ys[i][1] = new_pt[1]

    cm = mpl.colormaps[MAP] if isinstance(MAP, str) else MAP
    return ax.pcolormesh(xs, ys, zs, shading='gouraud', cmap=cm)
def color_swatches(colors, figsize=0.5, dpi=150, fontsize=5, fontcolor='w',
                   padding=0.05, titles=None, ncol=None):
    """Show colors as a grid of 1x1 swatches with optional titles.

    Parameters
    ----------
    colors : color or sequence of colors
        RGB or RGBA values. Alpha is dropped, as in ``color_grid``.
    figsize : float
        Width per column; the total width passed to ``image_grid`` is
        ``figsize * ncol``.
    dpi, fontsize, fontcolor
        Forwarded to ``image_grid``.
    padding : float
        Gap between swatches, forwarded to ``image_grid``. Previously this
        argument was ignored and 0.05 was always used.
    titles : sequence of str, optional
        One label per swatch, laid out with the same ``ncol`` chunking.
    ncol : int, optional
        Swatches per row; defaults to the number of colors.

    Returns
    -------
    Figure from ``image_grid``.
    """
    colors = np.array(colors)
    # A single color -> a one-entry palette
    if colors.ndim == 1:
        colors = colors.reshape(1, -1)
    # RGBA -> RGB so np.full can fill a 3-channel swatch
    if colors.shape[-1] == 4:
        colors = colors[:, :3]
    # Count swatches only after reshaping; a single color is one swatch, not 3.
    if ncol is None:
        ncol = len(colors)

    swatches = [np.full((1, 1, 3), color, dtype=np.float32) for color in colors]
    return image_grid(split_list(swatches, ncol),
                      plot_labels=split_list(titles, ncol) if titles is not None else None,
                      padding=padding,
                      fontsize=fontsize,
                      fontcolor=fontcolor,
                      facecolor=[0] * 4,
                      figsize=figsize * ncol,
                      dpi=dpi)
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Build a new colormap from ``cmap`` restricted to ``[minval, maxval]``.

    ``cmap`` may be a registered colormap name or a Colormap. It is sampled
    at ``n`` evenly spaced points, and those samples define a
    ``LinearSegmentedColormap`` named ``trunc(<name>,<min>,<max>)``.
    """
    from matplotlib.colors import LinearSegmentedColormap

    if isinstance(cmap, str):
        cmap = mpl.colormaps[cmap]
    label = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    samples = cmap(np.linspace(minval, maxval, n))
    return LinearSegmentedColormap.from_list(label, samples)
from operator import sub
def get_aspect(ax):
    """Return the axes' display aspect divided by its data aspect.

    The display aspect is (figure height * axes height fraction) /
    (figure width * axes width fraction). The data aspect is the y-limit span
    over the x-limit span.
    """
    fig_w, fig_h = ax.get_figure().get_size_inches()
    ax_w, ax_h = ax.get_position().bounds[2:]
    display_ratio = (fig_h * ax_h) / (fig_w * ax_w)

    y_lo, y_hi = ax.get_ylim()
    x_lo, x_hi = ax.get_xlim()
    # Both spans are computed as (first - second), so their signs cancel.
    data_ratio = (y_lo - y_hi) / (x_lo - x_hi)
    return display_ratio / data_ratio
from .utils import kernel_setup, get_neighbors from .core import boundary_to_masks, masks_to_affinity, get_contour import matplotlib.patches as mpatches import matplotlib.path as mpath from skimage.segmentation import find_boundaries from scipy.interpolate import splprep, splev from matplotlib.collections import PatchCollection
def vector_contours(fig, ax, mask,
                    crop=None,
                    smooth_factor=5,
                    color='r',
                    linewidth=1,
                    y_offset=0,
                    x_offset=0,
                    pad=2,
                    mode='constant',
                    zorder=1,
                    ):
    """Draw smoothed, closed vector outlines of every label in ``mask``.

    The mask is edge-padded by ``pad``, optionally cropped, and then padded by
    one more zero pixel. Contours are traced on the cleaned masks,
    spline-smoothed, shifted back by the total padding, and added to ``ax`` as
    a ``PatchCollection``.

    Parameters
    ----------
    fig : unused
        Kept for API compatibility.
    ax : Axes or list of Axes
        Target axes. A list gets one collection per axes.
    mask : ndarray
        2D label image.
    crop : index, optional
        Applied to the edge-padded mask; outline coordinates are then in the
        cropped frame.
    smooth_factor : float
        ``splprep`` smoothing is ``len(points) / smooth_factor``, so larger
        values smooth less.
    color, linewidth, zorder
        PathPatch edge style.
    y_offset, x_offset : float
        Added to contour coordinates before fitting. Floats previously failed
        on the in-place add to the integer coordinate array.
    pad : int or np.pad-style pad width
        Edge-mode padding, removed again from the final coordinates. Per-axis
        pads are matched to the right coordinate: axis 0 -> y, axis 1 -> x.
    mode : unused
        Kept for API compatibility (padding always uses 'edge').
    """
    msk = np.pad(mask, pad, mode='edge')
    if crop is not None:
        # Crop the mask to the specified region
        msk = msk[crop]
    # Always add a 1-pixel zero border; the "+ 1" in the shifts below undoes it.
    msk = np.pad(msk, 1, mode='constant', constant_values=0)

    # Leading pad per axis, normalized the way np.pad normalizes pad_width.
    pad_before = np.broadcast_to(np.asarray(pad), (np.ndim(mask), 2))[:, 0]
    y_shift = pad_before[-2] + 1  # rows
    x_shift = pad_before[-1] + 1  # columns

    # set up dimensions
    dim = msk.ndim
    shape = msk.shape
    steps, inds, idx, fact, sign = kernel_setup(dim)

    # remove spur points - this method is way easier than running core._despur() on the original affinity graph
    bd = find_boundaries(msk, mode='inner', connectivity=2)
    msk, bounds, _ = boundary_to_masks(bd, binary_mask=msk > 0, connectivity=1, min_size=0)

    # generate affinity graph
    coords = np.nonzero(msk)
    neighbors = get_neighbors(tuple(coords), steps, dim, shape)  # shape (d,3**d,npix)
    affinity_graph = masks_to_affinity(msk, coords, steps, inds, idx, fact, sign, dim, neighbors)

    # find contours
    contour_map, contour_list, unique_L = get_contour(msk, affinity_graph, coords, neighbors, cardinal_only=True)

    shift = np.array([x_offset, y_offset], dtype=float)
    patches = []
    for contour in contour_list:
        if len(contour) > 1:
            # (row, col) -> (x, y); float so fractional offsets can be added
            pts = np.stack([c[contour] for c in coords]).T[:, ::-1].astype(float)
            pts += shift

            tck, u = splprep(pts.T, u=None, s=len(pts) / smooth_factor, per=1)
            u_new = np.linspace(u.min(), u.max(), len(pts))
            x_new, y_new = splev(u_new, tck, der=0)

            # x comes from columns (axis 1), y from rows (axis 0)
            points = np.column_stack([x_new - x_shift, y_new - y_shift])

            path = mpath.Path(points, closed=True)
            patch = mpatches.PathPatch(path,
                                       fill=None,
                                       edgecolor=color,
                                       linewidth=linewidth,
                                       zorder=zorder,
                                       capstyle='round')
            patches.append(patch)

    # Add a PatchCollection to the axis/axes
    if isinstance(ax, list):
        for a in ax:
            a.add_collection(PatchCollection(patches, match_original=True, snap=False))
    else:
        ax.add_collection(PatchCollection(patches, match_original=True, snap=False))