Omnipose in 3D
You can use the dim (dimension) argument to tell Omnipose to segment your images with a 3D model. An image stack or 3D array is then treated as a 3D volume and passed to a network trained on 3D volumes. This is very different from do_3D in Cellpose, which cleverly leveraged 2D predictions on all 2D slices of a 3D volume to construct a 3D flow field for segmentation. It turns out that the pseudo-ND Cellpose flows approximate the true 3D flows of Omnipose, because the flows in each slice point to a local center of the cell, which lies on the cell skeleton that the Omnipose field points to. 2D Omnipose predictions, by contrast, do not composite into a useful 3D field, so it is not recommended to use Omnipose 2D slice predictions with do_3D. Instead, this notebook assumes you have trained a 3D model such as the plant_omni model.
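Here is a toy check of the slice-versus-skeleton claim. It is a minimal sketch under simplifying assumptions: the cell is a straight cylinder along x, its skeleton is the cylinder axis, and Cellpose's diffusion-based flows are replaced by straight vectors pointing to the centroid of each cross-section. Only the two in-plane components of each ZY cross-section are compared; the variable names are made up for illustration.

```python
import numpy as np

# Toy cell: a solid cylinder of radius R whose long axis runs along x.
Z, Y, X, R = 21, 21, 40, 6
zz, yy, xx = np.mgrid[:Z, :Y, :X]
cz, cy = Z // 2, Y // 2
mask = (zz - cz) ** 2 + (yy - cy) ** 2 <= R ** 2

# The skeleton of this cylinder is its axis (z = cz, y = cy). These are the (z, y)
# components of the vector from each voxel to its nearest skeleton point.
to_skel = np.stack([cz - zz, cy - yy]).astype(float)

# Cellpose-like stand-in (assumption): in each ZY cross-section (fixed x),
# point every voxel straight at the centroid of that 2D cross-section.
to_center = np.zeros_like(to_skel)
for x in range(X):
    zc, yc = np.argwhere(mask[:, :, x]).mean(axis=0)
    to_center[0, :, :, x] = zc - zz[:, :, x]
    to_center[1, :, :, x] = yc - yy[:, :, x]

# Inside the cell, the per-slice center-seeking vectors match the skeleton-seeking ones.
print(np.allclose(to_center[:, mask], to_skel[:, mask]))
```

For a straight cylinder the match is exact. Real cells bend and taper, and the slices along the long axis pull toward the middle of the cell rather than the skeleton, so for real data the agreement is only approximate.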
```python
# Import dependencies
import numpy as np
from cellpose_omni import models, core

# This checks to see if you have set up your GPU properly.
# CPU performance is a lot slower, but not a problem if you
# are only processing a few images.
use_GPU = core.use_gpu()

# for plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.dpi'] = 300
plt.style.use('dark_background')
%matplotlib inline
```
Read in data
Here I am choosing one of the scaled-down volumes of the plant Arabidopsis thaliana dataset we used in the Omnipose paper.
```python
from pathlib import Path
import os
from cellpose_omni import io
import omnipose
omnidir = Path(omnipose.__file__).parent.parent
basedir = os.path.join(omnidir,'docs','test_files_3D')
files = io.get_image_files(basedir)
files # this displays the variable if it is the last thing in the code block
```
['/home/kcutler/DataDrive/omnipose/docs/test_files_3D/Movie1_t00004_crop_gt.tif']
```python
from cellpose_omni import io, transforms
from omnipose.utils import normalize99

imgs = [io.imread(f) for f in files]

# print some info about the images.
for i in imgs:
    print('Original image shape:',i.shape)
    print('data type:',i.dtype)
    print('data range:', i.min(),i.max())
nimg = len(imgs)
print('number of images:',nimg)
```
Original image shape: (162, 207, 443)
data type: uint8
data range: 0 247
number of images: 1
Initialize model
plant_omni is the model trained on these plant cell images. (The image we loaded is from the test set, of course.)
```python
from cellpose_omni import models
model_name = 'plant_omni'

dim = 3
nclasses = 3 # flow + dist + boundary
nchan = 1
omni = 1
rescale = False
diam_mean = 0
use_GPU = 0 # Most people do not have enough VRAM to run on GPU... 24GB not enough for this image, need nearly 48GB
model = models.CellposeModel(gpu=use_GPU, model_type=model_name, net_avg=False,
                             diam_mean=diam_mean, nclasses=nclasses, dim=dim, nchan=nchan)
```
2025-02-01 03:02:36,537 [INFO] cellpose_omni/models.py __init__....() line 436 >>plant_omni<< model set to be used
2025-02-01 03:02:36,538 [INFO] cellpose_omni/core.py assi...evice() line 72 Using CPU.
2025-02-01 03:02:36,640 [INFO] __init__....() line 173 u-net config: ([1, 32, 64, 128, 256], 5, 3)
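The u-net config line above reports ([1, 32, 64, 128, 256], 5, 3) for this 3D model with nclasses = 3, and the 2D plant_cp model later in this notebook reports ([2, 32, 64, 128, 256], 3, 2) with nclasses = 2. Read together, the two logs suggest that the number of output channels is the dim flow components plus one channel for each remaining class, and that the first list entry matches nchan. This relation is inferred from these two log lines only, not taken from the cellpose_omni source; n_output_channels is a hypothetical helper.

```python
def n_output_channels(dim, nclasses):
    """Inferred from the logs: the flow field contributes `dim` channels and
    each of the other (nclasses - 1) fields (distance, boundary) contributes one."""
    return dim + (nclasses - 1)

print(n_output_channels(dim=3, nclasses=3))  # plant_omni: 3 flow + dist + boundary
print(n_output_channels(dim=2, nclasses=2))  # plant_cp: 2 flow + dist
```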
Run segmentation
```python
import torch
torch.cuda.empty_cache()
mask_threshold = -5 #usually this is -1
flow_threshold = 0.
diam_threshold = 12
net_avg = False
cluster = False
verbose = 1
tile = 0
chans = None
compute_masks = 1
resample=False
rescale=None
omni=True
flow_factor = 10 # multiple to increase flow magnitude, useful in 3D
transparency = True

nimg = len(imgs)
masks_om, flows_om = [[]]*nimg,[[]]*nimg

# splitting the images into batches helps manage VRAM use so that memory can get properly released
# here we have just one image, but most people will have several to process
for k in range(nimg):
    masks_om[k], flows_om[k], _ = model.eval(imgs[k],
                                             channels=chans,
                                             rescale=rescale,
                                             mask_threshold=mask_threshold,
                                             net_avg=net_avg,
                                             transparency=transparency,
                                             flow_threshold=flow_threshold,
                                             omni=omni,
                                             resample=resample,
                                             verbose=verbose,
                                             diam_threshold=diam_threshold,
                                             cluster=cluster,
                                             niter=10,
                                             tile=tile,
                                             compute_masks=compute_masks,
                                             flow_factor=flow_factor)
```
2025-02-01 03:02:36,988 [INFO] cellpose_omni/models.py eval........() line 696 is_grey False, slice_ndim 3, dim 3, nchan 1, is_list False
2025-02-01 03:02:36,988 [INFO] line 707 is_image True, is_stack False, is_list False
2025-02-01 03:02:36,989 [INFO] line 729 Evaluating with flow_threshold 0.00, mask_threshold -5.00
2025-02-01 03:02:36,989 [INFO] line 731 using omni model, cluster False
2025-02-01 03:02:36,989 [INFO] line 1130 not using dataparallel
2025-02-01 03:02:36,990 [INFO] line 1133 network initialized.
2025-02-01 03:02:37,137 [INFO] line 1140 shape before transforms.convert_image(): (162, 207, 443)
2025-02-01 03:02:37,137 [INFO] line 1141 model dim: 3
2025-02-01 03:02:37,138 [INFO] cellpose_omni/transforms.py conv...image() line 506 multi-stack tiff read in as having 162 planes 1 channels
2025-02-01 03:02:37,138 [INFO] cellpose_omni/models.py eval........() line 1151 shape after transforms.convert_image(): (162, 207, 443, 1)
2025-02-01 03:02:37,138 [INFO] line 1157 shape now (1, 162, 207, 443, 1)
2025-02-01 03:03:26,491 [INFO] omnipose/core.py comp...masks() line 1334 mask_threshold is -5
2025-02-01 03:03:26,491 [INFO] line 1346 Using hysteresis threshold.
dP_ times 10 for >2d, still experimenting
2025-02-01 03:03:27,288 [INFO] follow_flows() line 2049 niter: 10, interp: True, suppress: True, calc_trace: False
2025-02-01 03:03:27,446 [INFO] steps_batch.() line 1941 interp is False, interpolation mode is nearest
2025-02-01 03:03:29,501 [INFO] follow_flows() line 2093 done follow_flows
2025-02-01 03:03:29,534 [INFO] get_masks...() line 1781 Mean diameter is 25.689877
2025-02-01 03:03:29,663 [INFO] line 1795 cluster: False, SKLEARN_ENABLED: True
2025-02-01 03:03:29,949 [INFO] line 1881 nclasses: 5, mask.ndim: 3
2025-02-01 03:03:29,958 [INFO] line 1885 Using boundary output to split edge defects.
2025-02-01 03:03:30,156 [INFO] line 1899 Done finding masks.
2025-02-01 03:03:31,092 [INFO] comp...masks() line 1644 compute_masks() execution time: 4.6 sec
2025-02-01 03:03:31,092 [INFO] line 1646 execution time per pixel: 3.09738e-07 sec/px
2025-02-01 03:03:31,098 [INFO] line 1647 execution time per cell pixel: 8.05205e-07 sec/px
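The log shows the flow field being multiplied by 10 ("dP_ times 10 for >2d") before follow_flows runs with niter: 10. A toy 1D Euler integration illustrates why a multiplier like flow_factor matters when the number of iterations is small: with weak flows, pixels barely move in 10 steps, while the scaled flow lets them gather at the cell center. follow_flows_toy is a hypothetical simplification with nearest-neighbor flow lookup, not Omnipose's follow_flows implementation, and the flow values are made up.

```python
import numpy as np

def follow_flows_toy(dP, mask, niter=10, flow_factor=1.0):
    """Toy Euler integration: every foreground pixel repeatedly steps along the
    scaled flow sampled at its current position (rounded to the nearest pixel)."""
    coords = np.argwhere(mask).astype(float)       # (N, ndim) starting positions
    upper = np.array(mask.shape) - 1
    for _ in range(niter):
        idx = tuple(np.clip(np.rint(coords), 0, upper).astype(int).T)
        step = np.stack([d[idx] for d in dP], axis=1)   # flow at each position
        coords = np.clip(coords + flow_factor * step, 0, upper)
    return coords

# Usage: a 1D "cell" whose weak flow (magnitude 0.1) points toward x = 20.
x = np.arange(41)
dP = (0.1 * np.sign(20 - x))[None]      # shape (1, 41)
mask = np.ones(41, dtype=bool)

slow = follow_flows_toy(dP, mask, niter=10, flow_factor=1.0)
fast = follow_flows_toy(dP, mask, niter=10, flow_factor=10.0)
print("spread without scaling:", slow.max() - slow.min())
print("spread with flow_factor=10:", fast.max() - fast.min())
```

Without scaling, each pixel moves at most one pixel in 10 iterations; with flow_factor = 10, every pixel that starts within 10 pixels of the center reaches it.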
Plot results
3D segmentation is a lot harder to show than 2D. If anyone figures out a good way to use one of the many tools out there (ipyvolume, K3D-Jupyter, itkwidgets, ipygany) for label visualization (not image volumes), please let me know. Few of these are in active development, and my own 3D work requires robust label editing tools anyway, which I do not think any available tools offer. Hence I shall just load in Napari and show you an auto-captured screenshot.
```python
%%capture
import ncolor
mask = masks_om[0]
# relabel the masks with a small set of labels so that touching cells get different colors
mask_nc = ncolor.label(mask,max_depth=20)

import napari
# open the labels in napari, switch to 3D rendering, and fix the camera view
viewer = napari.view_labels(mask_nc);
viewer.dims.ndisplay = 3
viewer.camera.center = [s//2 for s in mask.shape]
viewer.camera.zoom=1
viewer.camera.angles=(10.90517458968619, -20.777067798396835, 58.04311170773853)
viewer.camera.perspective=0.0
viewer.camera.interactive=True

# capture the canvas as an image array for plotting with matplotlib
img = viewer.screenshot(size=(1000,1000),scale=1,canvas_only=True,flash=False)
```
```python
plt.figure(figsize=(3,3),frameon=False)
plt.imshow(img)
plt.axis('off')
plt.show()
```
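ncolor.label is used in the napari cell to reduce the masks to a small set of labels so that touching cells receive different colors. As a rough illustration of that idea (not ncolor's actual algorithm), here is a greedy graph coloring of a label array: find which labels touch along any axis, then give each label the smallest color not already used by a colored neighbor. label_adjacency and greedy_recolor are hypothetical helpers.

```python
import numpy as np

def label_adjacency(labels):
    """Set of (a, b) label pairs that touch across a face along any axis (0 = background)."""
    pairs = set()
    for axis in range(labels.ndim):
        a = np.moveaxis(labels, axis, 0)
        lo, hi = a[:-1].ravel(), a[1:].ravel()
        touch = (lo != hi) & (lo > 0) & (hi > 0)
        for u, v in zip(lo[touch], hi[touch]):
            pairs.add((int(min(u, v)), int(max(u, v))))
    return pairs

def greedy_recolor(labels):
    """Map each label to the smallest color index not used by an already-colored neighbor."""
    neighbors = {}
    for u, v in label_adjacency(labels):
        neighbors.setdefault(u, set()).add(v)
        neighbors.setdefault(v, set()).add(u)
    color = {}
    for lab in (int(l) for l in np.unique(labels)):
        if lab == 0:
            continue  # keep the background as 0
        used = {color[n] for n in neighbors.get(lab, ()) if n in color}
        c = 1
        while c in used:
            c += 1
        color[lab] = c
    out = np.zeros_like(labels)
    for lab, c in color.items():
        out[labels == lab] = c
    return out

# Usage: five touching stripes need only two colors.
stripes = np.repeat(np.arange(1, 6), 2)[None, :].repeat(3, axis=0)
print(greedy_recolor(stripes))
```

The same functions work unchanged on 3D label volumes, since adjacency is checked along every axis.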
Plot orthogonal slices
```python
from cellpose_omni import plot
from omnipose.plot import apply_ncolor

mu = flows_om[0][1] # flow field
T = flows_om[0][2]  # distance field
bd = flows_om[0][4] # boundary field
# mu.shape,T.shape,bd.shape

d = mu.shape[0]

from omnipose.utils import rescale
c = np.array([1]*2+[0]*(d-2))
# c = np.arange(d)
def cyclic_perm(a):
    # all cyclic shifts of a, e.g. [1,1,0] -> [1,1,0], [0,1,1], [1,0,1]
    n = len(a)
    b = [[a[i - j] for i in range(n)] for j in range(n)]
    return b
slices = []
idx = np.arange(d)
cmap = mpl.colormaps['magma']
cmap2 = mpl.colormaps['viridis']

# each permutation keeps two axes whole (1) and takes the middle plane of the third (0);
# slice(None) keeps the entire axis (slice(-1) would drop the last pixel)
for inds in cyclic_perm(c):
    slc = tuple([slice(None) if i else mu.shape[k+1]//2 for i,k in zip(inds,idx)])
    flow = plot.dx_to_circ(mu[np.where(inds)+slc],transparency=1)/255
    dist = cmap(rescale(T)[slc])
    bnds = cmap2(rescale(bd)[slc])
    msks = apply_ncolor(masks_om[0][slc])

    fig = plt.figure(figsize=[5]*2,frameon=False)
    plt.imshow(np.hstack((flow,dist,bnds,msks)),interpolation='none')
    plt.axis('off')
    plt.show()
```
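The slicing logic in the orthogonal-slice cell is compact. This sketch isolates it for a plain spatial shape such as (Z, Y, X), so you can check which central plane each row of panels shows and which two flow components are passed to dx_to_circ. It keeps whole axes with slice(None); ortho_slices is a hypothetical helper name.

```python
import numpy as np

def cyclic_perm(a):
    # Same helper as in the notebook: all cyclic shifts of a list.
    n = len(a)
    return [[a[i - j] for i in range(n)] for j in range(n)]

def ortho_slices(shape):
    """For a spatial shape such as (Z, Y, X), return (flow_components, index) pairs.
    index keeps two whole axes (flag 1) and takes the middle plane of the remaining
    axis (flag 0); flow_components are the two in-plane axes."""
    d = len(shape)
    c = [1] * 2 + [0] * (d - 2)
    out = []
    for inds in cyclic_perm(c):
        comps = tuple(k for k, keep in enumerate(inds) if keep)
        index = tuple(slice(None) if keep else shape[k] // 2 for k, keep in enumerate(inds))
        out.append((comps, index))
    return out

# Usage with the plant volume shape printed earlier in this notebook.
for comps, index in ortho_slices((162, 207, 443)):
    print(comps, index)
```

For this volume the three rows are a ZY plane at x = 221, a YX plane at z = 81, and a ZX plane at y = 103.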
Notes on the above
Slices do not always look crisp because we are cutting through boundaries. At these locations, the flow and distance fields darken and the boundary field brightens, which can produce flat, muddled regions that are hard to interpret. Again, interactive 3D visualization tools are needed to properly evaluate the results of the segmentation. In this case, we have cut through the middle of enough cells to confirm that the output looks reasonable.
This small dataset with problematic annotations was sufficient for demonstrating that Omnipose can be used on 3D data, but I again emphasize that any algorithm will only work well after training on well-annotated, representative examples. In this case, small cell clusters were neither well-annotated nor well-represented in the training set, and you can see the negative impact of that in this example.
These 3D models are incredibly VRAM-hungry, so all results in the paper were actually run on an AWS instance. Here I ran the model on the CPU, which is much slower but was necessary even with a 24GB Titan RTX.
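To put "VRAM-hungry" in numbers, here is back-of-the-envelope arithmetic, assuming float32 activations and taking the 32 first-level channels from the logged u-net config ([1, 32, 64, 128, 256], 5, 3). It estimates a single term, one full-resolution feature map, and is not a profile of the actual model.

```python
# Back-of-the-envelope activation memory for the plant volume.
Z, Y, X = 162, 207, 443          # volume shape printed earlier in this notebook
voxels = Z * Y * X

bytes_per_value = 4              # assumption: float32 activations
channels = 32                    # first-level width from the logged u-net config
one_map = voxels * channels * bytes_per_value

print(f"{voxels:,} voxels")
print(f"{one_map / 1e9:.2f} GB for one 32-channel full-resolution feature map")
print(f"{24e9 / one_map:.1f} such maps fill a 24 GB card")
```

A single such map already takes about 1.9 GB, and a U-Net keeps several maps per level (plus skip connections and the deeper levels) in memory at the same time, so a dozen or so full-resolution maps alive at once would exhaust a 24 GB card.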
Running Omnipose with do_3D
do_3D is not something you want to use with any Omnipose model, but you might want to use it with a 2D Cellpose model for 3D cells with extended shapes. This is because do_3D computes 2D flow fields from every yx, yz, and xz slice of the image and composites these components into a 3D field. It turns out that the center-seeking flow slices of Cellpose end up pointing roughly toward the local 3D skeleton, i.e. the do_3D Cellpose composite field approximates the true 3D flows of Omnipose. The 2D Omnipose field, on the other hand, cannot be composited into a useful 3D field.
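The do_3D compositing described above can be sketched in a few lines of NumPy. predict_2d is a hypothetical stand-in for a 2D network (it simply returns the image gradient of each plane), and summing the two estimates of each axis component is a choice made for this sketch; a constant factor does not change the direction of the composite field. Because np.gradient acts on each axis independently, the composite of this stand-in equals twice the 3D gradient, which makes the axis bookkeeping easy to verify.

```python
import numpy as np

def predict_2d(plane):
    """Hypothetical stand-in for a 2D network: a (2, H, W) flow estimate per plane."""
    return np.stack(np.gradient(plane.astype(float)))

def composite_3d(vol, predict=predict_2d):
    """Sketch of do_3D-style compositing for a (Z, Y, X) volume: each axis component
    is estimated from the two slice orientations containing that axis, then summed."""
    Z, Y, X = vol.shape
    f_z = np.stack([predict(vol[z]) for z in range(Z)], axis=1)        # YX planes: (dy, dx)
    f_y = np.stack([predict(vol[:, y]) for y in range(Y)], axis=2)     # ZX planes: (dz, dx)
    f_x = np.stack([predict(vol[:, :, x]) for x in range(X)], axis=3)  # ZY planes: (dz, dy)
    dz = f_y[0] + f_x[0]
    dy = f_z[0] + f_x[1]
    dx = f_z[1] + f_y[1]
    return np.stack([dz, dy, dx])  # (3, Z, Y, X)

# Plane counts for the plant volume, one batch of 2D predictions per orientation.
shape = (162, 207, 443)
for axis, name in enumerate("zyx"):
    plane = tuple(s for i, s in enumerate(shape) if i != axis)
    print(f"{shape[axis]} planes perpendicular to {name}, each of size {plane}")
```

The printed plane counts (162, 207 and 443) match the three _run_3D passes in the log of the do_3D run.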
Although the do_3D Cellpose field directs pixels toward the skeleton, the stock Cellpose mask reconstruction algorithm tends to oversegment, splitting the pixels of one cell into several clusters along the skeleton. To avoid this, you can use a Cellpose model with Omnipose mask reconstruction by setting omni=True. Here is how to do this.
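A 1D cartoon of this oversegmentation: after flow following, the pixels of an elongated cell end up spread along its skeleton rather than piled on one point. Seeding one mask per local maximum of the histogram of final positions (loosely modeled on Cellpose-style seed finding) then yields several seeds for a single cell, while grouping the occupied positions by connectivity (closer in spirit to what omni=True is meant to provide) yields one group. The histogram values and helper names are made up for illustration and do not reproduce either library's code.

```python
import numpy as np

def peak_seeds(hist):
    """Simplified peak-based seeding: bins that are strict local maxima in a 3-wide window."""
    padded = np.pad(hist, 1)
    left, mid, right = padded[:-2], padded[1:-1], padded[2:]
    return np.flatnonzero((mid > left) & (mid > right))

def connected_runs(hist):
    """Connectivity-based grouping: number of runs of consecutive occupied bins."""
    occ = (hist > 0).astype(int)
    return int(np.sum(np.diff(np.concatenate(([0], occ))) == 1))

# Final pixel positions of one elongated cell, binned along its skeleton.
counts = np.array([0, 0, 3, 5, 4, 6, 2, 7, 3, 5, 0, 0])
print("peak seeds:", peak_seeds(counts))          # several seeds -> oversegmentation
print("connected groups:", connected_runs(counts))  # one group -> one cell
```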
```python
from cellpose_omni import models, core

# define cellpose model
model_name = 'plant_cp'

# this model was trained on 2D slices
dim = 2
nclasses = 2 # cellpose models have no boundary field, just flow and distance

# Cellpose defaults to 2 channels;
# this is the setup for grayscale in that case
nchan = 2
chans = [0,0]

# no rescaling for this model
diam_mean = 0

use_GPU = core.use_gpu()
model = models.CellposeModel(gpu=use_GPU, model_type=model_name, net_avg=False,
                             diam_mean=diam_mean, nclasses=nclasses, dim=dim, nchan=nchan)

# segmentation parameters
omni = 1 # use Omnipose mask reconstruction on the Cellpose flows
rescale = None
mask_threshold = 0
net_avg = 0
verbose = 1
tile = 0
compute_masks = 1
flow_threshold = 0.
do_3D = True # run the 2D model on YX, ZY, and ZX slices and composite the flows
flow_factor = 10

masks_cp, flows_cp, _ = model.eval(imgs,
                                   channels=chans,
                                   rescale=rescale,
                                   mask_threshold=mask_threshold,
                                   net_avg=net_avg,
                                   transparency=True,
                                   flow_threshold=flow_threshold,
                                   verbose=verbose,
                                   tile=tile,
                                   compute_masks=compute_masks,
                                   do_3D=do_3D,
                                   omni=omni,
                                   flow_factor=flow_factor)
```
2024-06-28 18:56:21,718 [INFO] core _use...torch() line 74 ** TORCH GPU version installed and working. **
2024-06-28 18:56:21,718 [INFO] models __init__....() line 431 >>plant_cp<< model set to be used
2024-06-28 18:56:21,719 [INFO] core _use...torch() line 74 ** TORCH GPU version installed and working. **
2024-06-28 18:56:21,719 [INFO] assi...evice() line 85 >>>> using GPU
2024-06-28 18:56:21,749 [INFO] __init__....() line 183 u-net config: ([2, 32, 64, 128, 256], 3, 2)
2024-06-28 18:56:21,773 [INFO] models eval........() line 691 is_grey True, slice_ndim 3, dim 2, nchan 2, is_list True
2024-06-28 18:56:21,774 [INFO] line 702 is_image False, is_stack False, is_list True
2024-06-28 18:56:21,774 [INFO] line 1050 Evaluating one image at a time
2024-06-28 18:56:21,774 [INFO] line 691 is_grey True, slice_ndim 3, dim 2, nchan 2, is_list False
2024-06-28 18:56:21,775 [INFO] line 702 is_image True, is_stack False, is_list False
2024-06-28 18:56:21,775 [INFO] line 724 Evaluating with flow_threshold 0.00, mask_threshold 0.00
2024-06-28 18:56:21,775 [INFO] line 726 using omni model, cluster False
2024-06-28 18:56:21,775 [INFO] line 1110 using dataparallel
2024-06-28 18:56:21,775 [INFO] line 1122 network initialized.
2024-06-28 18:56:21,798 [INFO] line 1129 shape before transforms.convert_image(): (162, 207, 443)
2024-06-28 18:56:21,798 [INFO] line 1130 model dim: 2
multi-stack tiff read in as having 162 planes 1 channels
2024-06-28 18:56:21,825 [INFO] models eval........() line 1140 shape after transforms.convert_image(): (162, 207, 443, 2)
2024-06-28 18:56:22,234 [INFO] core _run_3D.....() line 712 running YX: 162 planes of size (207, 443)
2024-06-28 18:56:23,780 [INFO] core _run_3D.....() line 712 running ZY: 207 planes of size (162, 443)
2024-06-28 18:56:25,502 [INFO] core _run_3D.....() line 712 running ZX: 443 planes of size (162, 207)
2024-06-28 18:56:27,606 [INFO] models _run_cp.....() line 1317 network run in 5.78s
2024-06-28 18:56:27,607 [INFO] core comp...masks() line 1322 mask_threshold is 0
2024-06-28 18:56:27,607 [INFO] line 1334 Using hysteresis threshold.
dP_ times 10 for >2d, still experimenting
2024-06-28 18:56:28,163 [INFO] follow_flows() line 2026 niter: None, interp: True, suppress: True, calc_trace: False
2024-06-28 18:56:28,203 [INFO] steps_batch.() line 1918 interp is False, interpolation mode is nearest
2024-06-28 18:56:28,539 [INFO] follow_flows() line 2070 done follow_flows
2024-06-28 18:56:28,548 [INFO] get_masks...() line 1768 Mean diameter is 60.140644
2024-06-28 18:56:28,643 [INFO] line 1782 cluster: False, SKLEARN_ENABLED: True
2024-06-28 18:56:28,822 [INFO] line 1858 nclasses: 3, mask.ndim: 3
2024-06-28 18:56:28,943 [INFO] line 1876 Done finding masks.
2024-06-28 18:56:29,409 [INFO] comp...masks() line 1631 compute_masks() execution time: 1.8 sec
2024-06-28 18:56:29,409 [INFO] line 1633 execution time per pixel: 1.21324e-07 sec/px
2024-06-28 18:56:29,413 [INFO] line 1634 execution time per cell pixel: 3.02984e-07 sec/px
2024-06-28 18:56:29,417 [INFO] models _run_cp.....() line 1443 masks created in 1.81s
Compare masks to ground truth
```python
from fastremap import unique
mgt = io.imread(files[0][:-4]+'_masks.tif')
# note: unique() also returns the background label 0, so each count below
# is one higher than the number of cells
print('Cellpose do_3D + omni=True: {} masks. \nOmnipose 3D: {} masks. \nGround truth: {} masks'.format(len(unique(masks_cp[0])),
                                                                                                        len(unique(masks_om[0])),
                                                                                                        len(unique(mgt))))
```
Cellpose do_3D + omni=True: 83 masks.
Omnipose 3D: 113 masks.
Ground truth: 67 masks
For what it's worth, pure Cellpose gives ~550 masks in this volume, pure Omnipose gives ~200, and a Cellpose model + Omnipose mask reconstruction gives ~50. Unfortunately, the ground truth for this dataset is quite bad: it contains some undersegmented cells and, more importantly, an entire "ignore" region in which many cells are unlabeled. The 67 ground-truth cells therefore refer only to the long cells on the outside of the root. Thus, 55 cells is a severe undersegmentation of the volume. Let's see why.
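If you want counts that are more comparable to this ground truth, one option (not how the numbers above were produced) is to exclude the background label and to count only predicted masks that overlap at least one annotated ground-truth voxel, so that cells inside the unlabeled "ignore" region are not counted. Both helpers are hypothetical and use plain NumPy.

```python
import numpy as np

def count_cells(labels):
    """Number of distinct nonzero labels (0 is treated as background)."""
    return int(np.count_nonzero(np.unique(labels)))

def count_in_annotated_region(pred, gt):
    """Count predicted masks that touch at least one ground-truth-labeled voxel."""
    return int(np.count_nonzero(np.unique(pred[gt > 0])))

# Toy example: two annotated cells, one of them split in two by the prediction,
# plus one predicted cell in an unannotated region.
gt = np.zeros((10, 10), dtype=int)
gt[0:3, 0:3] = 1
gt[0:3, 5:8] = 2
pred = np.zeros_like(gt)
pred[0:3, 0:3] = 1
pred[0:3, 5:7] = 2
pred[0:3, 7:8] = 3
pred[7:9, 7:9] = 4
print(count_cells(gt), count_cells(pred), count_in_annotated_region(pred, gt))
```

On the notebook's data this would be called as count_in_annotated_region(masks_cp[0], mgt), assuming the unlabeled region of mgt is stored as 0.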
Plot results
```python
from cellpose_omni import plot
from omnipose.plot import apply_ncolor

mu = flows_cp[0][1] # flow field
T = flows_cp[0][2]  # distance field
bd = flows_cp[0][4]
# mu.shape,T.shape,bd.shape

d = mu.shape[0]

from omnipose.utils import rescale
c = np.array([1]*2+[0]*(d-2))
# c = np.arange(d)
def cyclic_perm(a):
    # all cyclic shifts of a, e.g. [1,1,0] -> [1,1,0], [0,1,1], [1,0,1]
    n = len(a)
    b = [[a[i - j] for i in range(n)] for j in range(n)]
    return b
slices = []
idx = np.arange(d)
cmap = mpl.colormaps['magma']
cmap2 = mpl.colormaps['viridis']

# each permutation keeps two axes whole (1) and takes the middle plane of the third (0);
# slice(None) keeps the entire axis (slice(-1) would drop the last pixel)
for inds in cyclic_perm(c):
    slc = tuple([slice(None) if i else mu.shape[k+1]//2 for i,k in zip(inds,idx)])
    flow = plot.dx_to_circ(mu[np.where(inds)+slc],transparency=1)/255
    dist = cmap(rescale(T)[slc])
    msks = apply_ncolor(masks_cp[0][slc])

    fig = plt.figure(figsize=[5]*2,frameon=False)
    plt.imshow(np.hstack((flow,dist,msks)),interpolation='none')
    plt.axis('off')
    plt.show()
```
```python
%%capture
import ncolor
mask = masks_cp[0]
# relabel the masks with a small set of labels so that touching cells get different colors
mask_nc = ncolor.label(mask,max_depth=20)

import napari
# open the labels in napari, switch to 3D rendering, and fix the camera view
viewer = napari.view_labels(mask_nc);
viewer.dims.ndisplay = 3
viewer.camera.center = [s//2 for s in mask.shape]
viewer.camera.zoom=1
viewer.camera.angles=(10.90517458968619, -20.777067798396835, 58.04311170773853)
viewer.camera.perspective=0.0
viewer.camera.interactive=True

# capture the canvas as an image array for plotting with matplotlib
img = viewer.screenshot(size=(1000,1000),scale=1,canvas_only=True,flash=False)
```
```python
plt.figure(figsize=(3,3),frameon=False)
plt.imshow(img)
plt.axis('off')
plt.show()
```
It appears that omni=True does allow 2D Cellpose models to work in 3D, but the prediction quality, worsened by artifacts introduced when compositing into 3D, is a limiting factor.