Omnipose in 3D

You can use the dim (dimension) argument to tell Omnipose to segment your images using a 3D model. This means that an image stack or 3D array is treated as a 3D volume given to a network trained on 3D volumes. This is very different from do_3D in Cellpose, which cleverly leveraged 2D predictions on all 2D slices of a 3D volume to construct a 3D flow field for segmentation. It turns out that the pseudo-ND Cellpose flows are an approximation to the true 3D flows of Omnipose, because the flows in each slice point to a local center of the cell, a.k.a. the cell skeleton to which the Omnipose field points. Thus, it is not recommended to use Omnipose 2D slice predictions with do_3D. Instead, this notebook assumes you have trained a 3D model such as the plant_omni model.

 1# Import dependencies
 2import numpy as np
 3from cellpose_omni import models, core
 4
 5# This checks to see if you have set up your GPU properly.
 6# CPU performance is a lot slower, but not a problem if you 
 7# are only processing a few images.
 8use_GPU = core.use_gpu()
 9
10# for plotting
11import matplotlib as mpl
12import matplotlib.pyplot as plt
13mpl.rcParams['figure.dpi'] = 300
14plt.style.use('dark_background')
15%matplotlib inline

Read in data

Here I am choosing one of the scaled-down volumes of the plant Arabidopsis thaliana dataset we used in the Omnipose paper.

1from pathlib import Path
2import os
3from cellpose_omni import io
4import omnipose
5omnidir = Path(omnipose.__file__).parent.parent
6basedir = os.path.join(omnidir,'docs','test_files_3D')
7files = io.get_image_files(basedir)
8files # this displays the variable if it the last thing in the code block
['/home/kcutler/DataDrive/omnipose/docs/test_files_3D/Movie1_t00004_crop_gt.tif']
 1from cellpose_omni import io, transforms
 2from omnipose.utils import normalize99
 3
 4imgs = [io.imread(f) for f in files]
 5
 6# print some info about the images.
 7for i in imgs:
 8    print('Original image shape:',i.shape)
 9    print('data type:',i.dtype)
10    print('data range:', i.min(),i.max())
11nimg = len(imgs)
12print('number of images:',nimg)
Original image shape: (162, 207, 443)
data type: uint8
data range: 0 247
number of images: 1

Initialize model

plant_omni is the model trained on these plant cell images. (The image we loaded is from the test set, of course.)

 1from cellpose_omni import models
 2model_name = 'plant_omni'
 3
 4dim = 3
 5nclasses = 3 # flow + dist + boundary
 6nchan = 1
 7omni = 1
 8rescale = False
 9diam_mean = 0
10use_GPU = 0 # Most people do not have enough VRAM to run on GPU... 24GB not enough for this image, need nearly 48GB
11model = models.CellposeModel(gpu=use_GPU, model_type=model_name, net_avg=False, 
12                             diam_mean=diam_mean, nclasses=nclasses, dim=dim, nchan=nchan)
2025-02-01 03:02:36,537	[INFO]     cellpose_omni/models.py       __init__....()	 line 436	>>plant_omni<< model set to be used
2025-02-01 03:02:36,538	[INFO]     cellpose_omni/core.py         assi...evice()	 line  72	Using CPU.
2025-02-01 03:02:36,640	[INFO]                                   __init__....()	 line 173	u-net config: ([1, 32, 64, 128, 256], 5, 3)

Run segmentation

 1import torch
 2torch.cuda.empty_cache()
 3mask_threshold = -5 #usually this is -1
 4flow_threshold = 0.
 5diam_threshold = 12
 6net_avg = False
 7cluster = False
 8verbose = 1
 9tile = 0
10chans = None
11compute_masks = 1
12resample=False
13rescale=None
14omni=True
15flow_factor = 10 # multiple to increase flow magnitude, useful in 3D
16transparency = True
17
18nimg = len(imgs)
19masks_om, flows_om = [[]]*nimg,[[]]*nimg
20
21# splitting the images into batches helps manage VRAM use so that memory can get properly released 
22# here we have just one image, but most people will have several to process
23for k in range(nimg):
24    masks_om[k], flows_om[k], _ = model.eval(imgs[k],
25                                             channels=chans,
26                                             rescale=rescale,
27                                             mask_threshold=mask_threshold,
28                                             net_avg=net_avg,
29                                             transparency=transparency,
30                                             flow_threshold=flow_threshold,
31                                             omni=omni,
32                                             resample=resample,
33                                             verbose=verbose,
34                                             diam_threshold=diam_threshold,
35                                             cluster=cluster,
36                                             niter=10,
37                                             tile=tile,
38                                             compute_masks=compute_masks,
39                                             flow_factor=flow_factor) 
2025-02-01 03:02:36,988	[INFO]     cellpose_omni/models.py       eval........()	 line 696	is_grey False, slice_ndim 3, dim 3, nchan 1, is_list False
2025-02-01 03:02:36,988	[INFO]                                               	 line 707	is_image True, is_stack False, is_list False
2025-02-01 03:02:36,989	[INFO]                                               	 line 729	Evaluating with flow_threshold 0.00, mask_threshold -5.00
2025-02-01 03:02:36,989	[INFO]                                               	 line 731	using omni model, cluster False
2025-02-01 03:02:36,989	[INFO]                                               	 line 1130	not using dataparallel
2025-02-01 03:02:36,990	[INFO]                                               	 line 1133	network initialized.
2025-02-01 03:02:37,137	[INFO]                                               	 line 1140	shape before transforms.convert_image(): (162, 207, 443)
2025-02-01 03:02:37,137	[INFO]                                               	 line 1141	model dim: 3
2025-02-01 03:02:37,138	[INFO]     cellpose_omni/transforms.py   conv...image()	 line 506	multi-stack tiff read in as having 162 planes 1 channels
2025-02-01 03:02:37,138	[INFO]     cellpose_omni/models.py       eval........()	 line 1151	shape after transforms.convert_image(): (162, 207, 443, 1)
2025-02-01 03:02:37,138	[INFO]                                               	 line 1157	shape now (1, 162, 207, 443, 1)
2025-02-01 03:03:26,491	[INFO]     omnipose/core.py              comp...masks()	 line 1334	mask_threshold is -5
2025-02-01 03:03:26,491	[INFO]                                               	 line 1346	Using hysteresis threshold.
dP_ times 10 for >2d, still experimenting
2025-02-01 03:03:27,288	[INFO]                                   follow_flows()	 line 2049	niter: 10, interp: True, suppress: True, calc_trace: False
2025-02-01 03:03:27,446	[INFO]                                   steps_batch.()	 line 1941	interp is False, interpolation mode is nearest
2025-02-01 03:03:29,501	[INFO]                                   follow_flows()	 line 2093	done follow_flows
2025-02-01 03:03:29,534	[INFO]                                   get_masks...()	 line 1781	Mean diameter is 25.689877
2025-02-01 03:03:29,663	[INFO]                                               	 line 1795	cluster: False, SKLEARN_ENABLED: True
2025-02-01 03:03:29,949	[INFO]                                               	 line 1881	nclasses: 5, mask.ndim: 3
2025-02-01 03:03:29,958	[INFO]                                               	 line 1885	Using boundary output to split edge defects.
2025-02-01 03:03:30,156	[INFO]                                               	 line 1899	Done finding masks.
2025-02-01 03:03:31,092	[INFO]                                   comp...masks()	 line 1644	compute_masks() execution time: 4.6 sec
2025-02-01 03:03:31,092	[INFO]                                               	 line 1646		execution time per pixel: 3.09738e-07 sec/px
2025-02-01 03:03:31,098	[INFO]                                               	 line 1647		execution time per cell pixel: 8.05205e-07 sec/px

Plot results

3D segmentation is a lot harder to show than 2D. If anyone figures out a good way to use one of the many tools out there (ipyvolume, K3D-Jupyter, itkwidgets, ipygany) for label visualization (not image volumes), please let me know. Few of these are in active development, and my own 3D work requires robust label editing tools anyway, which I do not think any available tools offer. Hence I shall just load in Napari and show you an auto-captured screenshot.

 1%%capture
 2import ncolor
 3mask = masks_om[0]  # label volume from the 3D Omnipose run above
 4mask_nc = ncolor.label(mask,max_depth=20)  # presumably relabels so touching cells get distinct colors -- confirm against ncolor docs
 5
 6import napari
 7viewer = napari.view_labels(mask_nc);  # the trailing ';' has no effect here: an assignment never echoes
 8viewer.dims.ndisplay = 3  # render the labels as a 3D volume
 9viewer.camera.center = [s//2 for s in mask.shape]  # aim the camera at the middle of the volume
10viewer.camera.zoom=1
11viewer.camera.angles=(10.90517458968619, -20.777067798396835, 58.04311170773853)  # hard-coded viewing angles
12viewer.camera.perspective=0.0
13viewer.camera.interactive=True
14
15img = viewer.screenshot(size=(1000,1000),scale=1,canvas_only=True,flash=False)  # image array of the canvas, plotted with matplotlib next
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
1plt.figure(figsize=(3,3),frameon=False)
2plt.imshow(img)
3plt.axis('off')
4plt.show()
../_images/94e244360822632da171677d5e160f2732625af534ef1aaff06676a69b7274fd.png
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)

Plot orthogonal slices

 1from cellpose_omni import plot
 2from omnipose.plot import apply_ncolor
 3
 4mu = flows_om[0][1]
 5T = flows_om[0][2]
 6bd = flows_om[0][4]
 7# mu.shape,T.shape,bd.shape
 8
 9d = mu.shape[0]
10
11from omnipose.utils import rescale
12c = np.array([1]*2+[0]*(d-2))
13# c = np.arange(d)
14def cyclic_perm(a):
15    n = len(a)
16    b = [[a[i - j] for i in range(n)] for j in range(n)]
17    return b
18slices = []
19idx = np.arange(d)
20cmap = mpl.colormaps['magma']
21cmap2 = mpl.colormaps['viridis']
22
23for inds in cyclic_perm(c):
24    slc = tuple([slice(-1) if i else mu.shape[k+1]//2 for i,k in zip(inds,idx)])
25    flow = plot.dx_to_circ(mu[np.where(inds)+slc],transparency=1)/255
26    dist = cmap(rescale(T)[slc])
27    bnds = cmap2(rescale(bd)[slc])
28    msks = apply_ncolor(masks_om[0][slc])
29
30    fig = plt.figure(figsize=[5]*2,frameon=False)
31    plt.imshow(np.hstack((flow,dist,bnds,msks)),interpolation='none')
32    plt.axis('off')
33    plt.show()
34    
../_images/5510d26a9364044c8f19dbae8e5684c76617a303a251b65c6fb6889f0c916d1b.png ../_images/094671c091886973cf5998c41509a260cd44c5d01811e9111cce04904c73aa9d.png ../_images/671b168aa1b7528f22efb89204694e742692348bbc2803e1218c0e72e7544440.png
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)
'-avx512er' is not a recognized feature for this target (ignoring feature)
'-avx512pf' is not a recognized feature for this target (ignoring feature)

Notes on the above

Slices do not always look crisp because we are cutting through boundaries. At these locations, the flow and distance fields darken and the boundary field brightens. This can result in flat and muddled regions that are hard to interpret. Again, interactive 3D visualization tools are needed to properly evaluate the results of the segmentation. In this case, we have cut through the middle of enough cells to confirm that the output looks reasonable.

This small dataset with problematic annotations was sufficient for demonstrating that Omnipose can be used on 3D data, but I again emphasize that any algorithm will only work well after training on well-annotated, representative examples. In this case, small cell clusters were neither well-annotated nor well-represented in the training set, and you can see the negative impact of that in this example.

These 3D models are incredibly VRAM-hungry, so all results in the paper were actually run on an AWS instance. Here I ran them on the CPU, which is much slower but necessary even with a 24GB Titan RTX.

Running Omnipose with do_3D

do_3D is not something you want to use with any Omnipose model, but you might want to use it with a 2D Cellpose model for 3D cells with extended shapes. This is because do_3D computes 2D flow fields from every yx, yz, and xz slice of the image and composites these components into a 3D field. It turns out that the center-seeking flow slices of Cellpose end up pointing roughly toward the local 3D skeleton, i.e. the do_3D Cellpose composite field approximates the true 3D flows of Omnipose. The 2D Omnipose field, on the other hand, cannot be composited into a useful 3D field.

Although the do_3D Cellpose field directs pixels toward the skeleton, the stock Cellpose mask reconstruction algorithm tends to oversegment pixels into clusters along the skeleton. To avoid this, you can use a Cellpose model with Omnipose mask reconstruction by setting omni=True. Here is how to do this.

 1from cellpose_omni import models, core
 2
 3# define cellpose model
 4model_name = 'plant_cp'
 5
 6# this model was trained on 2D slices 
 7dim = 2 
 8nclasses = 2 # cellpose models have no boundary field, just flow and distance 
 9
10# Cellpose defaults to 2 channels; 
11# this is the setup for grayscale in that case
12nchan = 2
13chans = [0,0]
14
15# no rescaling for this model
16diam_mean = 0
17
18
19use_GPU = core.use_gpu()
20model = models.CellposeModel(gpu=use_GPU, model_type=model_name, net_avg=False, 
21                             diam_mean=diam_mean, nclasses=nclasses, dim=dim, nchan=nchan)
22
23
24# segmentation parameters 
25omni = 1
26rescale = False
27mask_threshold = 0 
28net_avg = 0
29verbose = 1
30tile = 0
31compute_masks = 1
32rescale = None
33flow_threshold=0.
34do_3D=True
35flow_factor=10
36
37masks_cp, flows_cp, _ = model.eval(imgs,
38                                   channels=chans,
39                                   rescale=rescale,
40                                   mask_threshold=mask_threshold,
41                                   net_avg=net_avg,
42                                   transparency=True, 
43                                   flow_threshold=flow_threshold,
44                                   verbose=verbose, 
45                                   tile=tile,
46                                   compute_masks=compute_masks, 
47                                   do_3D=True, 
48                                   omni=omni,
49                                   flow_factor=flow_factor)
2024-06-28 18:56:21,718	[INFO]	core    _use...torch()	 line 74	** TORCH GPU version installed and working. **
2024-06-28 18:56:21,718	[INFO]	models  __init__....()	 line 431	>>plant_cp<< model set to be used
2024-06-28 18:56:21,719	[INFO]	core    _use...torch()	 line 74	** TORCH GPU version installed and working. **
2024-06-28 18:56:21,719	[INFO]	        assi...evice()	 line 85	>>>> using GPU
2024-06-28 18:56:21,749	[INFO]	        __init__....()	 line 183	u-net config: ([2, 32, 64, 128, 256], 3, 2)
2024-06-28 18:56:21,773	[INFO]	models  eval........()	 line 691	is_grey True, slice_ndim 3, dim 2, nchan 2, is_list True
2024-06-28 18:56:21,774	[INFO]	                    	 line 702	is_image False, is_stack False, is_list True
2024-06-28 18:56:21,774	[INFO]	                    	 line 1050	Evaluating one image at a time
2024-06-28 18:56:21,774	[INFO]	                    	 line 691	is_grey True, slice_ndim 3, dim 2, nchan 2, is_list False
2024-06-28 18:56:21,775	[INFO]	                    	 line 702	is_image True, is_stack False, is_list False
2024-06-28 18:56:21,775	[INFO]	                    	 line 724	Evaluating with flow_threshold 0.00, mask_threshold 0.00
2024-06-28 18:56:21,775	[INFO]	                    	 line 726	using omni model, cluster False
2024-06-28 18:56:21,775	[INFO]	                    	 line 1110	using dataparallel
2024-06-28 18:56:21,775	[INFO]	                    	 line 1122	network initialized.
2024-06-28 18:56:21,798	[INFO]	                    	 line 1129	shape before transforms.convert_image(): (162, 207, 443)
2024-06-28 18:56:21,798	[INFO]	                    	 line 1130	model dim: 2
multi-stack tiff read in as having 162 planes 1 channels
2024-06-28 18:56:21,825	[INFO]	models  eval........()	 line 1140	shape after transforms.convert_image(): (162, 207, 443, 2)
2024-06-28 18:56:22,234	[INFO]	core    _run_3D.....()	 line 712	running YX: 162 planes of size (207, 443)
0%|          | 0/15 [00:00<?, ?it/s]
7%|6         | 1/15 [00:00<00:01,  9.84it/s]
20%|##        | 3/15 [00:00<00:01, 11.48it/s]
33%|###3      | 5/15 [00:00<00:00, 11.93it/s]
47%|####6     | 7/15 [00:00<00:00, 12.16it/s]
60%|######    | 9/15 [00:00<00:00, 12.24it/s]
73%|#######3  | 11/15 [00:00<00:00, 12.29it/s]
87%|########6 | 13/15 [00:01<00:00, 12.27it/s]
100%|##########| 15/15 [00:01<00:00, 12.39it/s]
100%|##########| 15/15 [00:01<00:00, 12.17it/s]
2024-06-28 18:56:23,780	[INFO]	core    _run_3D.....()	 line 712	running ZY: 207 planes of size (162, 443)
0%|          | 0/19 [00:00<?, ?it/s]
11%|#         | 2/19 [00:00<00:01, 14.53it/s]
21%|##1       | 4/19 [00:00<00:01, 14.62it/s]
32%|###1      | 6/19 [00:00<00:00, 14.53it/s]
42%|####2     | 8/19 [00:00<00:00, 14.49it/s]
53%|#####2    | 10/19 [00:00<00:00, 14.49it/s]
63%|######3   | 12/19 [00:00<00:00, 14.50it/s]
74%|#######3  | 14/19 [00:00<00:00, 14.49it/s]
84%|########4 | 16/19 [00:01<00:00, 14.49it/s]
95%|#########4| 18/19 [00:01<00:00, 14.52it/s]
100%|##########| 19/19 [00:01<00:00, 14.53it/s]
2024-06-28 18:56:25,502	[INFO]	core    _run_3D.....()	 line 712	running ZX: 443 planes of size (162, 207)
0%|          | 0/14 [00:00<?, ?it/s]
14%|#4        | 2/14 [00:00<00:00, 14.49it/s]
29%|##8       | 4/14 [00:00<00:00, 14.34it/s]
43%|####2     | 6/14 [00:00<00:00, 14.34it/s]
57%|#####7    | 8/14 [00:00<00:00, 14.36it/s]
71%|#######1  | 10/14 [00:00<00:00, 14.33it/s]
86%|########5 | 12/14 [00:00<00:00, 14.36it/s]
100%|##########| 14/14 [00:00<00:00, 14.42it/s]
100%|##########| 14/14 [00:00<00:00, 14.37it/s]
2024-06-28 18:56:27,606	[INFO]	models  _run_cp.....()	 line 1317	network run in 5.78s
2024-06-28 18:56:27,607	[INFO]	core    comp...masks()	 line 1322	mask_threshold is 0
2024-06-28 18:56:27,607	[INFO]	                    	 line 1334	Using hysteresis threshold.
dP_ times 10 for >2d, still experimenting
2024-06-28 18:56:28,163	[INFO]	        follow_flows()	 line 2026	niter: None, interp: True, suppress: True, calc_trace: False
2024-06-28 18:56:28,203	[INFO]	        steps_batch.()	 line 1918	interp is False, interpolation mode is nearest
2024-06-28 18:56:28,539	[INFO]	        follow_flows()	 line 2070	done follow_flows
2024-06-28 18:56:28,548	[INFO]	        get_masks...()	 line 1768	Mean diameter is 60.140644
2024-06-28 18:56:28,643	[INFO]	                    	 line 1782	cluster: False, SKLEARN_ENABLED: True
2024-06-28 18:56:28,822	[INFO]	                    	 line 1858	nclasses: 3, mask.ndim: 3
2024-06-28 18:56:28,943	[INFO]	                    	 line 1876	Done finding masks.
2024-06-28 18:56:29,409	[INFO]	        comp...masks()	 line 1631	compute_masks() execution time: 1.8 sec
2024-06-28 18:56:29,409	[INFO]	                    	 line 1633		execution time per pixel: 1.21324e-07 sec/px
2024-06-28 18:56:29,413	[INFO]	                    	 line 1634		execution time per cell pixel: 3.02984e-07 sec/px
2024-06-28 18:56:29,417	[INFO]	models  _run_cp.....()	 line 1443	masks created in 1.81s

Compare masks to ground truth

1from fastremap import unique
2mgt = io.imread(files[0][:-4]+'_masks.tif')
3print('Cellpose do_3D + omni=True: {} masks. \nOmnipose 3D: {} masks. \nGround truth: {} masks'.format(len(unique(masks_cp[0])),
4                                                                                                    len(unique(masks_om[0])), 
5                                                                                                    len(unique(mgt))))
Cellpose do_3D + omni=True: 83 masks. 
Omnipose 3D: 113 masks. 
Ground truth: 67 masks

For what it's worth, pure Cellpose gives ~550 masks in this volume, pure Omnipose gives ~200, and Cellpose model + Omnipose mask reconstruction gives ~50 (these are approximate figures; compare the exact counts printed above). I'm sorry to say that the ground truth for this dataset is quite bad: it contains some undersegmented cells and, more importantly, an entire "ignore" region where many, many cells are unlabeled. So the count of 67 cells in the ground truth refers only to the long cells on the outside of the root. Thus, ~50 cells is a severe under-segmentation of the volume. Let's see why.

Plot results

 1from cellpose_omni import plot
 2from omnipose.plot import apply_ncolor
 3
 4mu = flows_cp[0][1]
 5T = flows_cp[0][2]
 6bd = flows_cp[0][4]
 7# mu.shape,T.shape,bd.shape
 8
 9d = mu.shape[0]
10
11from omnipose.utils import rescale
12c = np.array([1]*2+[0]*(d-2))
13# c = np.arange(d)
14def cyclic_perm(a):
15    n = len(a)
16    b = [[a[i - j] for i in range(n)] for j in range(n)]
17    return b
18slices = []
19idx = np.arange(d)
20cmap = mpl.colormaps['magma']
21cmap2 = mpl.colormaps['viridis']
22
23for inds in cyclic_perm(c):
24    slc = tuple([slice(-1) if i else mu.shape[k+1]//2 for i,k in zip(inds,idx)])
25    flow = plot.dx_to_circ(mu[np.where(inds)+slc],transparency=1)/255
26    dist = cmap(rescale(T)[slc])
27    msks = apply_ncolor(masks_cp[0][slc])
28
29    fig = plt.figure(figsize=[5]*2,frameon=False)
30    plt.imshow(np.hstack((flow,dist,msks)),interpolation='none')
31    
32    plt.axis('off')
33    plt.show()
../_images/65b57d8ff51a430fc73525e3e8acc7250f5d1682bd612eb9343749f0459b731d.png ../_images/d56102821d1841694bb061cd72da521aa3392b1c4d8ae9018c13e2acacec84f4.png ../_images/c7cddc6791502abc91e7f5dbb27b9adfac68e9d91bf60659d1abc98b65f60820.png
 1%%capture
 2import ncolor
 3mask = masks_cp[0]  # label volume from the Cellpose do_3D + omni run above
 4mask_nc = ncolor.label(mask,max_depth=20)  # presumably relabels so touching cells get distinct colors -- confirm against ncolor docs
 5
 6import napari
 7viewer = napari.view_labels(mask_nc);  # the trailing ';' has no effect here: an assignment never echoes
 8viewer.dims.ndisplay = 3  # render the labels as a 3D volume
 9viewer.camera.center = [s//2 for s in mask.shape]  # aim the camera at the middle of the volume
10viewer.camera.zoom=1
11viewer.camera.angles=(10.90517458968619, -20.777067798396835, 58.04311170773853)  # same hard-coded angles as the Omnipose view
12viewer.camera.perspective=0.0
13viewer.camera.interactive=True
14
15img = viewer.screenshot(size=(1000,1000),scale=1,canvas_only=True,flash=False)  # image array of the canvas, plotted with matplotlib next
1plt.figure(figsize=(3,3),frameon=False)
2plt.imshow(img)
3plt.axis('off')
4plt.show()
../_images/7e2047daaa4a03d930bf28b47a37001b536268d7cf7441cb55dcd129cdb36568.png

It appears that omni=True does allow 2D Cellpose models to work in 3D, but the prediction quality - worsened by artifacts introduced by the compositing into 3D - is a limiting factor.