import numpy as np
import warnings
import cv2
import logging
transforms_logger = logging.getLogger(__name__)
from . import dynamics, utils
import itertools # ND tiling
# import omnipose, edt, fastremap
# OMNI_INSTALLED = True
# omnipose (with its edt/fastremap dependencies) is optional; feature flags
# elsewhere in this module check OMNI_INSTALLED before using it.
try:
    import omnipose, edt, fastremap
    OMNI_INSTALLED = True
except Exception:
    # was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt;
    # catch Exception so import failures are handled but interrupts propagate
    OMNI_INSTALLED = False
    print('OMNIPOSE NOT INSTALLED')
def _taper_mask(ly=224, lx=224, sig=7.5):
bsize = max(224, max(ly, lx))
xm = np.arange(bsize)
xm = np.abs(xm - xm.mean())
mask = 1/(1 + np.exp((xm - (bsize/2-20)) / sig))
mask = mask * mask[:, np.newaxis]
mask = mask[bsize//2-ly//2 : bsize//2+ly//2+ly%2,
bsize//2-lx//2 : bsize//2+lx//2+lx%2]
return mask
def unaugment_tiles(y, unet=False):
    """Reverse the test-time flips applied by make_tiles(augment=True).

    Parameters
    ----------
    y: float32
        array that's ntiles_y x ntiles_x x chan x Ly x Lx where chan = (dY, dX, cell prob)
    unet: bool (optional, False)
        unet output has no flow channels, so no sign flips are needed

    Returns
    -------
    y: float32
        tiles with flips (and flow-component signs) undone, modified in place
    """
    ntiles_y, ntiles_x = y.shape[:2]
    for j in range(ntiles_y):
        for i in range(ntiles_x):
            # tiles in odd columns were flipped along Y, odd rows along X
            if i % 2 == 1:
                y[j, i] = y[j, i, :, ::-1, :]
                if not unet:
                    y[j, i, 0] *= -1  # dY changes sign under a Y flip
            if j % 2 == 1:
                y[j, i] = y[j, i, :, :, ::-1]
                if not unet:
                    y[j, i, 1] *= -1  # dX changes sign under an X flip
    return y
def average_tiles(y, ysub, xsub, Ly, Lx):
    """Blend per-tile network outputs back into a full image.

    Each tile is weighted by a tapered window so that overlapping regions
    average smoothly instead of showing seams.

    Parameters
    -------------
    y: float, [ntiles x nclasses x bsize x bsize]
        output of cellpose network for each tile
    ysub : list
        list of arrays with start and end of tiles in Y of length ntiles
    xsub : list
        list of arrays with start and end of tiles in X of length ntiles
    Ly : int
        size of pre-tiled image in Y (may be larger than original image if
        image size is less than bsize)
    Lx : int
        size of pre-tiled image in X (may be larger than original image if
        image size is less than bsize)

    Returns
    -------------
    yf: float32, [nclasses x Ly x Lx]
        network output averaged over tiles
    """
    weight_sum = np.zeros((Ly, Lx))
    yf = np.zeros((y.shape[1], Ly, Lx), np.float32)
    # taper edges of tiles so overlaps blend smoothly
    window = _taper_mask(ly=y.shape[-2], lx=y.shape[-1])
    for tile, (y0, y1), (x0, x1) in zip(y, ysub, xsub):
        yf[:, y0:y1, x0:x1] += tile * window
        weight_sum[y0:y1, x0:x1] += window
    yf /= weight_sum
    return yf
def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1):
    """Cut an image into overlapping tiles for test-time evaluation.

    With augment=True the tiles overlap by half a tile and are flipped in a
    checkerboard pattern (vertical flip on odd columns, horizontal on odd
    rows, both on odd/odd) so augmentations can be averaged out later.

    Parameters
    ----------
    imgi : float32
        array that's nchan x Ly x Lx
    bsize : float (optional, default 224)
        size of tiles
    augment : bool (optional, default False)
        flip tiles and use 50% overlap
    tile_overlap: float (optional, default 0.1)
        fraction of overlap of tiles (clamped to [0.05, 0.5]; ignored when augment)

    Returns
    -------
    IMG : float32
        array that's ntiles_y x ntiles_x x nchan x bsize x bsize
    ysub : list
        list of [start, end] of tiles in Y of length ntiles
    xsub : list
        list of [start, end] of tiles in X of length ntiles
    Ly, Lx : int
        size of the (possibly padded) tiled image
    """
    nchan, Ly, Lx = imgi.shape
    if augment:
        bsize = np.int32(bsize)
        # zero-pad so both spatial dims are at least bsize
        if Ly < bsize:
            imgi = np.concatenate((imgi, np.zeros((nchan, bsize - Ly, Lx))), axis=1)
            Ly = bsize
        if Lx < bsize:
            imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize - Lx))), axis=2)
        Ly, Lx = imgi.shape[-2:]
        # tiles overlap by half of tile size
        ny = max(2, int(np.ceil(2. * Ly / bsize)))
        nx = max(2, int(np.ceil(2. * Lx / bsize)))
        ystart = np.linspace(0, Ly - bsize, ny).astype(int)
        xstart = np.linspace(0, Lx - bsize, nx).astype(int)
        ysub, xsub = [], []
        IMG = np.zeros((len(ystart), len(xstart), nchan, bsize, bsize), np.float32)
        for j, ys in enumerate(ystart):
            for i, xs in enumerate(xstart):
                ysub.append([ys, ys + bsize])
                xsub.append([xs, xs + bsize])
                tile = imgi[:, ys:ys + bsize, xs:xs + bsize]
                # flip tiles so overlapping segments are processed in rotation
                if j % 2 == 0 and i % 2 == 1:
                    tile = tile[:, ::-1, :]
                elif j % 2 == 1 and i % 2 == 0:
                    tile = tile[:, :, ::-1]
                elif j % 2 == 1 and i % 2 == 1:
                    tile = tile[:, ::-1, ::-1]
                IMG[j, i] = tile
    else:
        tile_overlap = min(0.5, max(0.05, tile_overlap))
        bsizeY = np.int32(min(bsize, Ly))
        bsizeX = np.int32(min(bsize, Lx))
        # tiles overlap by tile_overlap fraction of the tile size
        ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
        nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
        ystart = np.linspace(0, Ly - bsizeY, ny).astype(int)
        xstart = np.linspace(0, Lx - bsizeX, nx).astype(int)
        ysub, xsub = [], []
        IMG = np.zeros((len(ystart), len(xstart), nchan, bsizeY, bsizeX), np.float32)
        for j, ys in enumerate(ystart):
            for i, xs in enumerate(xstart):
                ysub.append([ys, ys + bsizeY])
                xsub.append([xs, xs + bsizeX])
                IMG[j, i] = imgi[:, ys:ys + bsizeY, xs:xs + bsizeX]
    return IMG, ysub, xsub, Ly, Lx
from omnipose.utils import get_flip, _taper_mask_ND, unaugment_tiles_ND, average_tiles_ND, make_tiles_ND
# def get_flip(idx):
# """
# ND slices for flipping arrays along particular axes
# based on the tile indices. Used in augment_tiles_ND()
# and unaugment_tiles_ND().
# """
# return tuple([slice(None,None,None) if i%2 else
# slice(None,None,-1) for i in idx])
# def _taper_mask_ND(shape=(224,224), sig=7.5):
# dim = len(shape)
# bsize = max(shape)
# xm = np.arange(bsize)
# xm = np.abs(xm - xm.mean())
# # 1D distribution
# mask = 1/(1 + np.exp((xm - (bsize/2-20)) / sig))
# # extend to ND
# for j in range(dim-1):
# mask = mask * mask[..., np.newaxis]
# slc = tuple([slice(bsize//2-s//2,bsize//2+s//2+s%2) for s in shape])
# mask = mask[slc]
# return mask
# def unaugment_tiles_ND(y, inds, unet=False):
# """ reverse test-time augmentations for averaging
# Parameters
# ----------
# y: float32
# array that's ntiles x chan x Ly x Lx where
# chan = (dY, dX, dist, boundary)
# unet: bool (optional, False)
# whether or not unet output or cellpose output
# Returns
# -------
# y: float32
# """
# dim = len(inds[0])
# for i,idx in enumerate(inds):
# # repeat the flip to undo it
# flip = get_flip(idx)
# # flow field componenets need to be flipped
# factor = np.array([1 if i%2 else -1 for i in idx])
# # apply the flip
# y[i] = y[i][(Ellipsis,)+flip]
# # apply the flow field flip
# if not unet:
# y[i][:dim] = [s*f for s,f in zip(y[i][:dim],factor)]
# return y
# def average_tiles_ND(y,subs,shape):
# """ average results of network over tiles
# Parameters
# -------------
# y: float, [ntiles x nclasses x bsize x bsize]
# output of cellpose network for each tile
# subs : list
# list of slices for each subtile
# shape : int, list or tuple
# shape of pre-tiled image (may be larger than original image if
# image size is less than bsize)
# Returns
# -------------
# yf: float32, [nclasses x Ly x Lx]
# network output averaged over tiles
# """
# Navg = np.zeros(shape)
# yf = np.zeros((y.shape[1],)+shape, np.float32)
# # taper edges of tiles
# mask = _taper_mask_ND(y.shape[-len(shape):])
# for j,slc in enumerate(subs):
# yf[(Ellipsis,)+slc] += y[j] * mask
# Navg[slc] += mask
# yf /= Navg
# return yf
# def make_tiles_ND(imgi, bsize=224, augment=False, tile_overlap=0.1):
# """ make tiles of image to run at test-time
# if augmented, tiles are flipped and tile_overlap=2.
# * original
# * flipped vertically
# * flipped horizontally
# * flipped vertically and horizontally
# Parameters
# ----------
# imgi : float32
# array that's nchan x Ly x Lx
# bsize : float (optional, default 224)
# size of tiles
# augment : bool (optional, default False)
# flip tiles and set tile_overlap=2.
# tile_overlap: float (optional, default 0.1)
# fraction of overlap of tiles
# Returns
# -------
# IMG : float32
# array that's ntiles x nchan x bsize x bsize
# ysub : list
# list of arrays with start and end of tiles in Y of length ntiles
# xsub : list
# list of arrays with start and end of tiles in X of length ntiles
# """
# nchan = imgi.shape[0]
# shape = imgi.shape[1:]
# dim = len(shape)
# inds = []
# if augment:
# bsize = np.int32(bsize)
# # pad if image smaller than bsize
# pad_seq = [(0,0)]+[(0,max(0,bsize-s))for s in shape]
# imgi = np.pad(imgi,pad_seq)
# shape = imgi.shape[-dim:]
# # tiles overlap by half of tile size
# ntyx = [max(2, int(np.ceil(2. * s / bsize))) for s in shape]
# start = [np.linspace(0, s-bsize, n).astype(int) for s,n in zip(shape,ntyx)]
# intervals = [[slice(si,si+bsize) for si in s] for s in start]
# subs = list(itertools.product(*intervals))
# indexes = [np.arange(len(s)) for s in start]
# inds = list(itertools.product(*indexes))
# IMG = []
# # here I flip if the index is odd
# for slc,idx in zip(subs,inds):
# flip = get_flip(idx) # avoid repetition with unaugment
# IMG.append(imgi[(Ellipsis,)+slc][(Ellipsis,)+flip])
# IMG = np.stack(IMG)
# else:
# tile_overlap = min(0.5, max(0.05, tile_overlap))
# # bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx)
# # B = [np.int32(min(b,s)) for s,b in zip(im.shape,bsize)] if bzise variable
# bbox = tuple([np.int32(min(bsize,s)) for s in shape])
# # tiles overlap by 10% tile size by default
# ntyx = [1 if s<=bsize else int(np.ceil((1.+2*tile_overlap) * s / bsize))
# for s in shape]
# start = [np.linspace(0, s-b, n).astype(int) for s,b,n in zip(shape,bbox,ntyx)]
# intervals = [[slice(si,si+bsize) for si in s] for s in start]
# subs = list(itertools.product(*intervals))
# # IMG = np.zeros((len(ystart), len(xstart), nchan, bsizeY, bsizeX), np.float32)
# # IMG = np.zeros(tuple([len(s) for s in start])+(nchan,)+bbox, np.float32)
# # IMG = np.stack([imgi[(Ellipsis,)+slc] for slc in subs])
# print('normalizing each tile')
# IMG = np.stack([normalize99(imgi[(Ellipsis,)+slc],omni=True) for slc in subs])
# return IMG, subs, shape, inds
# needs to have a wider range to avoid weird effects with few cells in frame
# also turns out previous formulation can give negative numbers, messes up log operations etc.
def normalize99(Y, lower=0.01, upper=99.99, omni=False):
    """Normalize an image so 0.0 is the `lower` percentile and 1.0 is the
    `upper` percentile of intensities.

    Parameters
    ----------
    Y: ND-array
        image to normalize (not modified; a copy is returned)
    lower: float (optional, default 0.01)
        lower percentile mapped to 0.0
    upper: float (optional, default 99.99)
        upper percentile mapped to 1.0
    omni: bool (optional, default False)
        delegate to omnipose.utils.normalize99 if omnipose is installed

    Returns
    -------
    X: ND-array, float
        normalized copy of Y
    """
    if omni and OMNI_INSTALLED:
        X = omnipose.utils.normalize99(Y)
    else:
        X = Y.copy()
        # bug fix: the percentiles were hard-coded to 1/99, silently ignoring
        # the lower/upper parameters (and the wider 0.01/99.99 default range
        # the docstring promised)
        x_low = np.percentile(X, lower)
        x_high = np.percentile(X, upper)
        X = (X - x_low) / (x_high - x_low)
    return X
def move_axis(img, m_axis=-1, first=True):
    """Move axis m_axis of img to the first (first=True) or last position."""
    if m_axis == -1:
        m_axis = img.ndim - 1
    # clamp to a valid axis index
    m_axis = min(img.ndim - 1, m_axis)
    destination = 0 if first else -1
    return np.moveaxis(img, m_axis, destination)
# more flexible replacement
def move_axis_new(a, axis, pos):
    """Move ndarray axis to a new position, preserving the order of the others."""
    order = list(range(a.ndim))
    # remove the axis from its current slot and re-insert it at pos
    order.insert(pos, order.pop(axis))
    return np.transpose(a, order)
# This was edited to fix a bug where single-channel images of shape (y,x) would be
# transposed to (x,y) if x<y, making the labels no longer correspond to the data.
def move_min_dim(img, force=False):
    """Move the smallest dimension to the last axis (treating it as channels)
    when it is < 10, or always when force==True.

    2D images are returned unchanged; the heuristic only makes sense when a
    channel axis can exist (3+ dims).
    """
    if img.ndim > 2:
        smallest = min(img.shape)
        if force or smallest < 10:
            # prefer the last axis when it already holds the minimum;
            # otherwise take the first axis matching it
            ax = -1 if img.shape[-1] == smallest else img.shape.index(smallest)
            img = move_axis(img, m_axis=ax, first=False)
    return img
def update_axis(m_axis, to_squeeze, ndim):
    """Remap axis index m_axis after the singleton axes in to_squeeze are removed.

    Returns the new index of m_axis, or None if m_axis itself is squeezed
    away (or cannot be found among the surviving axes).
    """
    if m_axis == -1:
        m_axis = ndim - 1
    if (to_squeeze == m_axis).sum() == 1:
        # the axis itself is being removed
        return None
    keep = np.ones(ndim, bool)
    keep[to_squeeze] = False
    # position of m_axis among the axes that survive the squeeze
    hits = np.nonzero(np.arange(0, ndim)[keep] == m_axis)[0]
    return hits[0] if len(hits) > 0 else None
def convert_image(x, channels, channel_axis=None, z_axis=None,
                  do_3D=False, normalize=True, invert=False,
                  nchan=2, dim=2, omni=False):
    """ return image with z first, channels last and normalized intensities

    Parameters
    ----------
    x: ND-array
        raw image; may contain singleton axes and a channel and/or z axis
        in any position
    channels: list of int of length 2 or None
        channels to segment with (see reshape()); None keeps channels as-is
    channel_axis: int (optional, default None)
        axis of x holding channels; guessed via move_min_dim() if None
    z_axis: int (optional, default None)
        axis of x holding z-planes; moved to the front if given
    do_3D: bool (optional, default False)
        treat x as a volume
    normalize: bool (optional, default True)
        run normalize_img() on the result
    invert: bool (optional, default False)
        invert intensities during normalization
    nchan: int (optional, default 2)
        number of channels the downstream network expects
    dim: int (optional, default 2)
        number of spatial dimensions
    omni: bool (optional, default False)
        use omnipose normalization if installed

    Returns
    -------
    x: ND-array
        image with z first (if any) and channels last, optionally normalized

    Raises
    ------
    ValueError
        for 2D input in 3D mode, 4D input in 2D mode, or malformed channels
    """
    # squeeze image, and if channel_axis or z_axis given, transpose image
    if x.ndim > 3:
        to_squeeze = np.array([int(isq) for isq,s in enumerate(x.shape) if s==1])
        # remove channel axis if number of channels is 1
        if len(to_squeeze) > 0:
            # remap the user-supplied axes to their post-squeeze positions
            channel_axis = update_axis(channel_axis, to_squeeze, x.ndim) if channel_axis is not None else channel_axis
            z_axis = update_axis(z_axis, to_squeeze, x.ndim) if z_axis is not None else z_axis
        x = x.squeeze()

    # put z axis first
    if z_axis is not None and x.ndim > 2:
        x = move_axis(x, m_axis=z_axis, first=True)
        # channel axis shifts right by one once z is in front
        if channel_axis is not None:
            channel_axis += 1
        if x.ndim==3:
            # volume without channels: append a trailing channel axis
            x = x[...,np.newaxis]

    # put channel axis last
    if channel_axis is not None and x.ndim > 2:
        x = move_axis(x, m_axis=channel_axis, first=False)
    elif x.ndim == dim:
        # no channel axis present: add a leading singleton one
        x = x[np.newaxis]

    if do_3D :
        if x.ndim < 3:
            transforms_logger.critical('ERROR: cannot process 2D images in 3D mode')
            raise ValueError('ERROR: cannot process 2D images in 3D mode')
        elif x.ndim<4:
            x = x[...,np.newaxis]

    # if no channel axis was given, guess it as the smallest dimension (<10)
    # and move it last; see move_min_dim()
    if channel_axis is None:
        x = move_min_dim(x)
        channel_axis = -1 # moves to last

    if x.ndim > 3:
        transforms_logger.info('multi-stack tiff read in as having %d planes %d channels'%
                               (x.shape[0], x.shape[-1]))

    if channels is not None:
        # allow channels to arrive wrapped in an extra list, e.g. [[1,2]]
        channels = channels[0] if len(channels)==1 else channels
        if len(channels) < 2:
            transforms_logger.critical('ERROR: two channels not specified')
            raise ValueError('ERROR: two channels not specified')
        x = reshape(x, channels=channels, channel_axis=channel_axis)
    else:
        # code above put channels last, so this makes sure nchan matches below
        if x.shape[-1] > nchan and x.ndim>dim:
            transforms_logger.warning(('WARNING: more than %d channels given, use '
                                       '"channels" input for specifying channels -'
                                       'just using first %d channels to run processing')%(nchan,nchan))
            x = x[...,:nchan]

        if not do_3D and x.ndim>3 and dim==2: # error should only be thrown for 2D mode
            transforms_logger.critical('ERROR: cannot process 4D images in 2D mode')
            raise ValueError('ERROR: cannot process 4D images in 2D mode')

        if x.shape[-1] < nchan:
            # pad missing channels with zeros
            # NOTE(review): np.tile with reps (1,1,nchan-1) appears to assume a
            # 3D (y, x, chan) image and pads nchan-1 zero channels regardless of
            # how many are actually missing — TODO confirm for nchan>2 / volumes
            x = np.concatenate((x,
                                np.tile(np.zeros_like(x), (1,1,nchan-1))),
                               axis=-1)

    if normalize or invert:
        x = normalize_img(x, invert=invert, omni=omni)

    return x
def reshape(data, channels=[0,0], chan_first=False, channel_axis=0):
    """ reshape data using channels

    Parameters
    ----------
    data : numpy array that's (Z x ) Ly x Lx x nchan
        if data.ndim==3 and data.shape[0]<8, assumed to be nchan x Ly x Lx
    channels : list of int of length 2 (optional, default [0,0])
        First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
        Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
        For instance, to train on grayscale images, input [0,0]. To train on images with cells
        in green and nuclei in blue, input [2,3].
    chan_first : bool (optional, default False)
        return data with the channel axis first instead of last
    channel_axis : int, default 0
        the axis that corresponds to channels (usually 0 or -1)

    Returns
    -------
    data : numpy array that's (Z x ) Ly x Lx x nchan (if chan_first==False)
        always has exactly 2 channels on output
    """
    data = data.astype(np.float32)
    if data.ndim < 3:
        # plain 2D images get a trailing channel axis
        data = data[...,np.newaxis]
    elif data.shape[0]<8 and data.ndim==3:
        # assume the stack is nchan x Ly x Lx, so reorder to Ly x Lx x nchan
        # NOTE(review): the <8 channel-count heuristic is arbitrary; relying on
        # channel_axis alone would be safer — confirm before changing
        data = np.transpose(data, (1,2,0))
        channel_axis = -1
    if data.shape[-1]==1:
        # grayscale image: add a second channel of zeros
        data = np.concatenate((data, np.zeros_like(data)), axis=-1)
    else:
        if channels[0]==0:
            # [0,0] takes the mean of all channels and pads with a zero channel
            # NOTE(review): with channel_axis=0 on a 3D volume this averages over
            # z rather than channels (known issue noted by the original author)
            data = data.mean(axis=channel_axis, keepdims=True)
            data = np.concatenate((data, np.zeros_like(data)), axis=-1)
        else:
            # [i,0] takes channel i alone; [i,j>0] takes channels i and j (1-based)
            chanid = [channels[0]-1]
            if channels[1] > 0:
                chanid.append(channels[1]-1)
            data = data[...,chanid]
            # warn when a selected channel carries no signal at all
            for i in range(data.shape[-1]):
                if np.ptp(data[...,i]) == 0.0:
                    if i==0:
                        warnings.warn("chan to seg' has value range of ZERO")
                    else:
                        warnings.warn("'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0")
            if data.shape[-1]==1:
                # only one channel selected: pad with zeros to keep 2 channels
                data = np.concatenate((data, np.zeros_like(data)), axis=-1)
    if chan_first:
        if data.ndim==4:
            data = np.transpose(data, (3,0,1,2))
        else:
            data = np.transpose(data, (2,0,1))
    return data
def normalize_img(img, axis=-1, invert=False, omni=False):
    """ normalize each channel of the image so that 0.0 = 1st percentile
    and 1.0 = 99th percentile of image intensities, with optional inversion

    Parameters
    ------------
    img: ND-array (at least 3 dimensions)
    axis: channel axis to loop over for normalization
    invert: bool (optional, default False)
        invert normalized intensities (1 - value)
    omni: bool (optional, default False)
        use omnipose normalization if installed

    Returns
    ---------------
    img: ND-array, float32
        normalized image of same size

    Raises
    ------
    ValueError
        if img has fewer than 3 dimensions
    """
    if img.ndim < 3:
        error_message = 'Image needs to have at least 3 dimensions'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    # bring channels to the front so each can be normalized independently
    img = np.moveaxis(img.astype(np.float32), axis, 0)
    for k in range(img.shape[0]):
        # skip near-constant channels (percentile spread below 1e-3);
        # a plain ptp check can still give NaNs on weird images
        if np.percentile(img[k], 99) > np.percentile(img[k], 1) + 1e-3:
            img[k] = normalize99(img[k], omni=omni)
            if invert:
                img[k] = 1. - img[k]
    return np.moveaxis(img, 0, axis)
def reshape_train_test(train_data, train_labels, test_data, test_labels, channels, channel_axis=0, normalize=True, dim=2, omni=False):
    """ check sizes and reshape train and test data for training

    Parameters
    ----------
    train_data: list of ND-arrays
        training images
    train_labels: list of ND-arrays
        training labels; must have the same length as train_data
    test_data: list of ND-arrays or None
        optional test images
    test_labels: list of ND-arrays or None
        optional test labels; must match test_data in length to be used
    channels: list of int of length 2 or None
        channels to use (see reshape())
    channel_axis: int (optional, default 0)
        axis of the images holding channels
    normalize: bool (optional, default True)
        normalize data so 0 and 1 are the 1st and 99th intensity percentiles
    dim: int (optional, default 2)
        number of spatial dimensions
    omni: bool (optional, default False)
        use omnipose normalization if installed

    Returns
    -------
    train_data, train_labels, test_data, test_labels, run_test

    Raises
    ------
    ValueError
        if data/label lengths, dimensionalities, or shapes are inconsistent
    """
    nimg = len(train_data)
    # check that arrays are correct size
    # (unreachable `return` statements that followed each raise were removed)
    if nimg != len(train_labels):
        error_message = 'train data and labels not same length'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    if train_labels[0].ndim < 2 or train_data[0].ndim < 2:
        error_message = 'training data or labels are not at least two-dimensional'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    if train_data[0].ndim > 3:
        error_message = 'training data is more than three-dimensional (should be 2D or 3D array)'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    # drop test data unless test_data and test_labels both exist and match in length
    if not (test_data is not None and test_labels is not None and
            len(test_data) > 0 and len(test_data)==len(test_labels)):
        test_data = None

    # debug prints replaced with the module logger (lazy %-formatting)
    transforms_logger.debug('reshape_train_test input: shape=%s channels=%s channel_axis=%s normalize=%s omni=%s',
                            train_data[0].shape, channels, channel_axis, normalize, omni)

    # make data correct shape and normalize it so that 0 and 1 are 1st and 99th percentile of data
    # reshape_and_normalize_data pads train_data with an empty channel axis if it
    # doesn't have one (single channel images/volumes)
    train_data, test_data, run_test = reshape_and_normalize_data(train_data,
                                                                 test_data=test_data,
                                                                 channels=channels,
                                                                 channel_axis=channel_axis,
                                                                 normalize=normalize,
                                                                 omni=omni,
                                                                 dim=dim)

    if train_data is None:
        error_message = 'training data do not all have the same number of channels'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    transforms_logger.debug('reshape_train_test output: shape=%s', train_data[0].shape)

    if not run_test:
        test_data, test_labels = None, None

    if not np.all([dta.shape[-dim:] == lbl.shape[-dim:] for dta, lbl in zip(train_data,train_labels)]):
        error_message = 'training data and labels are not the same shape, must be something wrong with preprocessing assumptions'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    return train_data, train_labels, test_data, test_labels, run_test
def reshape_and_normalize_data(train_data, test_data=None, channels=None, channel_axis=0, normalize=True, omni=False, dim=2):
    """ inputs converted to correct shapes for *training* and rescaled so that 0.0=1st percentile
    and 1.0=99th percentile of image intensities in each channel

    Both lists are modified in place (entries are replaced with reshaped /
    normalized arrays).

    Parameters
    --------------
    train_data: list of ND-arrays, float
        list of training images of size [Ly x Lx], [nchan x Ly x Lx], or [Ly x Lx x nchan]
    test_data: list of ND-arrays, float (optional, default None)
        list of testing images of size [Ly x Lx], [nchan x Ly x Lx], or [Ly x Lx x nchan]
    channels: list of int of length 2 (optional, default None)
        First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
        Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
        For instance, to train on grayscale images, input [0,0]. To train on images with cells
        in green and nuclei in blue, input [2,3].
    channel_axis: int (optional, default 0)
        axis holding channels when channels is None
    normalize: bool (optional, True)
        normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
    omni: bool (optional, default False)
        use omnipose normalization if installed
    dim: int (optional, default 2)
        number of spatial dimensions

    Returns
    -------------
    train_data: list of ND-arrays, float
        list of training images of size [2 x Ly x Lx]
    test_data: list of ND-arrays, float (optional, default None)
        list of testing images of size [2 x Ly x Lx]
    run_test: bool
        whether or not test_data was correct size and is useable during training
    """
    # pass 0 processes train_data; pass 1 processes test_data (if present)
    for test, data in enumerate([train_data, test_data]):
        if data is None:
            # no usable test data; train_data was already processed on pass 0
            return train_data, test_data, False
        nimg = len(data)
        for i in range(nimg):
            if channels is None:
                if channel_axis is not None:
                    # move the stated channel axis to the front
                    data[i] = move_axis_new(data[i], axis=channel_axis, pos=0)
                else:
                    m = f'No channel axis specified. Image shape is {data[i].shape}. Supply channel_axis if incorrect.'
                    transforms_logger.warning(m)
            if channels is not None:
                # reshape() selects/combines channels and puts them first
                # NOTE(review): flagged by the original author as the culprit for
                # 3D issues — confirm behavior for volumetric input
                data[i] = reshape(data[i], channels=channels, chan_first=True, channel_axis=channel_axis)
            # single-channel images/volumes get a leading channel axis so the
            # rest of the pipeline can assume one exists
            if channels is None and data[i].ndim==dim:
                data[i] = data[i][np.newaxis]
            if normalize:
                data[i] = normalize_img(data[i], axis=0, omni=omni)
    return train_data, test_data, True
def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR, no_channels=False):
    """ resize image for computing flows / unresize for computing dynamics

    Parameters
    -------------
    img0: ND-array
        image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X]
    Ly: int, optional
        target height; computed from rsz if None
    Lx: int, optional
        target width; computed from rsz if None
    rsz: float, optional
        resize coefficient(s) for image; if Ly is None then rsz is used
    interpolation: cv2 interp method (optional, default cv2.INTER_LINEAR)
    no_channels: bool (optional, default False)
        treat the leading axis as z-planes rather than channels

    Returns
    --------------
    imgs: ND-array
        image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan]

    Raises
    ------
    ValueError
        if neither a target size (Ly) nor a resize factor (rsz) is given
    """
    if Ly is None and rsz is None:
        error_message = 'must give size to resize to or factor to use for resizing'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    if Ly is None:
        # determine Ly and Lx using rsz
        if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
            rsz = [rsz, rsz]
        if no_channels:
            # last two axes are spatial
            Ly = int(img0.shape[-2] * rsz[-2])
            Lx = int(img0.shape[-1] * rsz[-1])
        else:
            # last axis is channels; spatial axes are -3 and -2
            Ly = int(img0.shape[-3] * rsz[-2])
            Lx = int(img0.shape[-2] * rsz[-1])

    # no_channels useful for z-stacks, so the third dimension is not treated as a channel
    # but if this is called for grayscale images, they first become [Ly,Lx,2] so ndim=3
    if (img0.ndim>2 and no_channels) or (img0.ndim==4 and not no_channels):
        # resize each plane of the stack separately (cv2.resize is 2D-only)
        if no_channels:
            imgs = np.zeros((img0.shape[0], Ly, Lx), np.float32)
        else:
            imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), np.float32)
        for i,img in enumerate(img0):
            imgs[i] = cv2.resize(img, (Lx, Ly), interpolation=interpolation)
    else:
        imgs = cv2.resize(img0, (Lx, Ly), interpolation=interpolation)
    return imgs
def pad_image_ND(img0, div=16, extra=1, dim=2):
    """Reflect-pad an image so each spatial dimension is a multiple of div.

    Parameters
    -------------
    img0: ND-array
        image of size [nchan (x Lz) x Ly x Lx]
    div: int (optional, default 16)
        spatial dimensions of the output are multiples of div
    extra: int (optional, default 1)
        extra*div//2 additional padding added on every spatial edge
    dim: int (optional, default 2)
        number of trailing spatial dimensions to pad

    Returns
    --------------
    I: ND-array
        padded image
    subs: list of arrays, int
        for each spatial dim, the index range in I corresponding to img0
    """
    spatial = img0.shape[-dim:]
    # shortfall to the next multiple of div, per spatial dim
    lpad = [int(div * np.ceil(s / div) - s) for s in spatial]
    base = extra * div // 2
    lower = [base + p // 2 for p in lpad]
    upper = [base + p - p // 2 for p in lpad]
    pads = tuple([[0, 0]] * (img0.ndim - dim)) + tuple(np.stack((lower, upper), axis=1))
    # 'reflect' extends the data naturally and avoids edge artifacts
    # (unlike 'constant' zero padding)
    I = np.pad(img0, pads, mode='reflect')
    subs = [np.arange(lo, lo + s) for lo, s in zip(lower, spatial)]
    return I, subs
def random_rotate_and_resize(X, Y=None, scale_range=1., gamma_range=[.5,4], tyx=None,
                             do_flip=True, rescale=None, unet=False,
                             inds=None, omni=False, dim=2, nchan=1, nclasses=3, device=None):
    """Augment images (and labels) by random rotation, resizing and flipping.

    Thin wrapper that clamps scale_range, fills in default indices, and
    delegates to omnipose.core.random_rotate_and_resize.

    Parameters
    ----------
    X: LIST of ND-arrays, float
        list of image arrays of size [nchan x Ly x Lx] or [Ly x Lx]
    Y: LIST of ND-arrays, float (optional, default None)
        list of label arrays of size [nlabels x Ly x Lx] or [Ly x Lx]
    scale_range: float (optional, default 1.0)
        range of resizing; images are resized by
        (1-scale_range/2) + scale_range * np.random.rand()
    gamma_range: list (optional, default [.5,4])
        images are gamma-adjusted im**gamma for gamma in this range
    tyx: tuple, int (optional, default None)
        size of transformed images to return
    do_flip: bool (optional, default True)
        whether or not to flip images horizontally
    rescale: array, float (optional, default None)
        how much to resize images by before performing augmentations
    unet: bool (optional, default False)
    inds: array, int (optional, default None)
        image indices to use; defaults to all images (debugging aid)
    omni, dim, nchan, nclasses, device:
        kept for interface compatibility; nchan is forwarded to omnipose

    Returns
    -------
    imgi, lbl, scale (see omnipose.core.random_rotate_and_resize)
    """
    # limit overall range to [0, 2], i.e. 1 +/- 1
    scale_range = max(0, min(2, float(scale_range)))
    if inds is None:
        # only relevant when debugging subsets; default to every image
        inds = np.arange(len(X))
    return omnipose.core.random_rotate_and_resize(
        X, Y=Y, scale_range=scale_range, gamma_range=gamma_range,
        tyx=tyx, do_flip=do_flip, rescale=rescale, inds=inds,
        nchan=nchan)
# if omni and OMNI_INSTALLED:
# return omnipose.core.random_rotate_and_resize(X, Y=Y, scale_range=scale_range, gamma_range=gamma_range,
# tyx=tyx, do_flip=do_flip, rescale=rescale, inds=inds,
# nchan=nchan)
# else:
# # backwards compatibility; completely 'stock', no gamma augmentation or any other extra frills.
# # [Y[i][1:] for i in inds] is necessary because the original transform function does not use masks (entry 0).
# # This used to be done in the original function call.
# if tyx is None:
# tyx = (224,)*dim
# print('yoyo',X[0].shape,Y[0].shape)
# return original_random_rotate_and_resize(X, Y=[y[1:] for y in Y] if Y is not None else None,
# scale_range=scale_range, xy=tyx,
# do_flip=do_flip, rescale=rescale, unet=unet)
# I have the omni flag here just in case, but it actually does not affect the tests
def normalize_field(mu, omni=False):
    """Normalize a flow field to unit vectors along the first axis.

    The numpy path modifies mu in place; a small epsilon guards against
    division by zero on zero-magnitude vectors. The omni flag is kept for
    interface parity but does not affect the tests.
    """
    if omni and OMNI_INSTALLED:
        mu = omnipose.utils.normalize_field(mu)
    else:
        # vector magnitude per pixel; 1e-20 avoids division by zero
        magnitude = ((mu**2).sum(axis=0))**0.5
        mu /= (1e-20 + magnitude)
    return mu
def _X2zoom(img, X2=1):
    """ zoom in image by a factor of 2**X2

    Parameters
    ----------
    img : numpy array that's Ly x Lx

    Returns
    -------
    img : numpy array that's (Ly * 2**X2) x (Lx * 2**X2)
    """
    ny, nx = img.shape[:2]
    factor = 2 ** X2
    return cv2.resize(img, (int(nx * factor), int(ny * factor)))
def _image_resizer(img, resize=512, to_uint8=False):
""" resize image
Parameters
----------
img : numpy array that's Ly x Lx
resize : int
max size of image returned
to_uint8 : bool
convert image to uint8
Returns
-------
img : numpy array that's Ly x Lx, Ly,Lx<resize
"""
ny,nx = img.shape[:2]
if to_uint8:
if img.max()<=255 and img.min()>=0 and img.max()>1:
img = img.astype(np.uint8)
else:
img = img.astype(np.float32)
img -= img.min()
img /= img.max()
img *= 255
img = img.astype(np.uint8)
if np.array(img.shape).max() > resize:
if ny>nx:
nx = int(nx/ny * resize)
ny = resize
else:
ny = int(ny/nx * resize)
nx = resize
shape = (nx,ny)
img = cv2.resize(img, shape)
img = img.astype(np.uint8)
return img
def original_random_rotate_and_resize(X, Y=None, scale_range=1., xy = (224,224),
                                      do_flip=True, rescale=None, unet=False):
    """ augmentation by random rotation and resizing

    X and Y are lists or arrays of length nimg, with dims channels x Ly x Lx (channels optional)

    Parameters
    ----------
    X: LIST of ND-arrays, float
        list of image arrays of size [nchan x Ly x Lx] or [Ly x Lx]
    Y: LIST of ND-arrays, float (optional, default None)
        list of image labels of size [nlabels x Ly x Lx] or [Ly x Lx]. The 1st channel
        of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
        If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
        If unet, second channel is dist_to_bound.
    scale_range: float (optional, default 1.0)
        Range of resizing of images for augmentation. Images are resized by
        (1-scale_range/2) + scale_range * np.random.rand()
    xy: tuple, int (optional, default (224,224))
        size of transformed images to return
    do_flip: bool (optional, default True)
        whether or not to flip images horizontally
    rescale: array, float (optional, default None)
        how much to resize images by before performing augmentations
    unet: bool (optional, default False)

    Returns
    -------
    imgi: ND-array, float
        transformed images in array [nimg x nchan x xy[0] x xy[1]]
    lbl: ND-array, float
        transformed labels in array [nimg x nchan x xy[0] x xy[1]]
    scale: array, float
        amount by which each image was resized
    """
    # NOTE(review): leftover debug print — consider removing or routing
    # through transforms_logger
    print('this',X[0].shape)
    # limit overall scale range to [0, 2], i.e. 1 +/- 1
    scale_range = max(0, min(2, float(scale_range)))
    nimg = len(X)
    if X[0].ndim>2:
        nchan = X[0].shape[0]
    else:
        nchan = 1
    imgi = np.zeros((nimg, nchan, xy[0], xy[1]), np.float32)
    lbl = []
    if Y is not None:
        if Y[0].ndim>2:
            nt = Y[0].shape[0]
        else:
            nt = 1
        lbl = np.zeros((nimg, nt, xy[0], xy[1]), np.float32)
    scale = np.zeros(nimg, np.float32)
    for n in range(nimg):
        Ly, Lx = X[n].shape[-2:]
        # generate random augmentation parameters
        flip = np.random.rand()>.5
        theta = np.random.rand() * np.pi * 2
        scale[n] = (1-scale_range/2) + scale_range * np.random.rand()
        if rescale is not None:
            scale[n] *= 1. / rescale[n]
        # random translation within the slack left after scaling to xy
        dxy = np.maximum(0, np.array([Lx*scale[n]-xy[1],Ly*scale[n]-xy[0]]))
        dxy = (np.random.rand(2,) - .5) * dxy
        # create affine transform from three point correspondences:
        # the center plus unit steps along x and y, mapped through
        # rotation by theta, scaling, and the random shift
        cc = np.array([Lx/2, Ly/2])
        cc1 = cc - np.array([Lx-xy[1], Ly-xy[0]])/2 + dxy
        pts1 = np.float32([cc,cc + np.array([1,0]), cc + np.array([0,1])])
        pts2 = np.float32([cc1,
                cc1 + scale[n]*np.array([np.cos(theta), np.sin(theta)]),
                cc1 + scale[n]*np.array([np.cos(np.pi/2+theta), np.sin(np.pi/2+theta)])])
        M = cv2.getAffineTransform(pts1,pts2)
        img = X[n].copy()
        if Y is not None:
            labels = Y[n].copy()
            if labels.ndim<3:
                labels = labels[np.newaxis,:,:]
        if flip and do_flip:
            img = img[..., ::-1]
            if Y is not None:
                labels = labels[..., ::-1]
                if nt > 1 and not unet:
                    # X-flow (channel 2) changes sign under a horizontal flip
                    labels[2] = -labels[2]
        # warp each channel separately (cv2.warpAffine is 2D-only)
        # NOTE(review): for 2D inputs (nchan=1) img[k] indexes a row, not a
        # plane — looks like 2D inputs are expected to be pre-expanded; confirm
        for k in range(nchan):
            I = cv2.warpAffine(img[k], M, (xy[1],xy[0]), flags=cv2.INTER_LINEAR)
            imgi[n,k] = I
        if Y is not None:
            for k in range(nt):
                if k==0:
                    # masks/probability: nearest-neighbor to avoid mixing labels
                    lbl[n,k] = cv2.warpAffine(labels[k], M, (xy[1],xy[0]), flags=cv2.INTER_NEAREST)
                else:
                    lbl[n,k] = cv2.warpAffine(labels[k], M, (xy[1],xy[0]), flags=cv2.INTER_LINEAR)
            if nt > 1 and not unet:
                # rotate the (Y,X) flow vectors by the same angle theta
                v1 = lbl[n,2].copy()
                v2 = lbl[n,1].copy()
                lbl[n,1] = (-v1 * np.sin(-theta) + v2*np.cos(-theta))
                lbl[n,2] = (v1 * np.cos(-theta) + v2*np.sin(-theta))
    return imgi, lbl, scale