import numpy as np
from scipy.ndimage import binary_dilation, binary_erosion, gaussian_filter
from scipy.ndimage import convolve1d, convolve, affine_transform
from skimage.morphology import remove_small_holes
from skimage.registration import phase_cross_correlation
from scipy.ndimage import shift as im_shift
from skimage import color
import dask
from .plot import sinebow
from skimage import measure
import fastremap
import mahotas as mh
import math
from ncolor import format_labels # just in case I forgot to switch it out elsewhere
from pathlib import Path
import os
import re
import mgen
import fastremap
from numba import njit
import functools
import itertools
# import logging, sys
# LOGGER_FORMAT = "%(asctime)-20s\t[%(levelname)-5s]\t[%(filename)-10s %(lineno)-5d%(funcName)-18s]\t%(message)s"
# logging.basicConfig(
# level=logging.INFO,
# format=LOGGER_FORMAT,
# handlers=[
# logging.StreamHandler(sys.stdout)
# ]
# )
# omnipose_logger = logging.getLogger(__name__)
# logging.getLogger('xmlschema').setLevel(logging.WARNING) # get rid of that annoying xmlschema warning
# # logging.getLogger('qdarktheme').setLevel(logging.WARNING)
import sys
from .logger import setup_logger
# module-level logger for this file, built with the package's setup_logger helper
omnipose_logger = setup_logger('utils')
# No reason to support anything but pytorch for omnipose
# I want it to fail here otherwise, much easier to debug
import torch
# always True because torch is imported unconditionally above (presumably
# checked by other modules — confirm before removing)
TORCH_ENABLED = True
# The following is duplicated from cellpose; importing it from there causes a circular import.
import platform
# The processor check is needed too: the mps backend check alone does not work on Intel macs.
ARM = 'arm' in platform.processor()
try:  # torch.backends.mps is not available in older versions of torch
    ARM = torch.backends.mps.is_available() and ARM
except Exception:  # best effort: any failure here just means "no MPS"
    ARM = False
torch_GPU = torch.device('mps') if ARM else torch.device('cuda')
torch_CPU = torch.device('cpu')
def find_files(directory, suffix, exclude_suffixes=()):
    """Yield paths of files under ``directory`` whose stem ends with ``suffix``.

    The tree is walked recursively with :func:`os.walk`. The comparison uses
    the file name with its last extension removed (``os.path.splitext``), so
    ``suffix='_masks'`` matches ``img_masks.tif``.

    Parameters
    ----------
    directory : str or path-like
        Root directory to search.
    suffix : str or tuple of str
        Required ending of the file stem.
    exclude_suffixes : iterable of str, optional
        Stems ending with any of these are skipped even if they match
        ``suffix``. ``None`` or empty means nothing is excluded.

    Yields
    ------
    str
        ``os.path.join(root, basename)`` for every matching file.
    """
    # str.endswith accepts a tuple, testing all exclusions in one call;
    # an empty tuple never matches
    excluded = tuple(exclude_suffixes or ())
    for root, _dirs, files in os.walk(directory):
        for basename in files:
            name = os.path.splitext(basename)[0]
            if name.endswith(suffix) and not name.endswith(excluded):
                yield os.path.join(root, basename)
def findbetween(s, string1='[', string2=']'):
    """Return the text enclosed between ``string1`` and ``string2`` in ``s``.

    Both delimiters are matched literally (escaped with :func:`re.escape`).
    The capture ``(.*)`` is greedy and does not cross newlines, so for
    ``'a[1]b[2]'`` the result is ``'1]b[2'``. Raises ``IndexError`` when
    nothing matches.
    """
    opener = str(re.escape(string1))
    closer = str(re.escape(string2))
    matches = re.findall(opener + "(.*)" + closer, s)
    return matches[0]
def getname(path, prefix='', suffix='', padding=0):
    """Return the file name of ``path`` without its extension.

    Every occurrence of ``prefix`` and then of ``suffix`` is removed (with
    ``str.replace``, so not only at the ends), and the result is left-padded
    with zeros to at least ``padding`` characters.
    """
    stem, _ext = os.path.splitext(Path(path).name)
    for token in (prefix, suffix):
        stem = stem.replace(token, '')
    return stem.zfill(padding)
def to_16_bit(im):
    """Rescale ``im`` to [0, 2^16-1] with ``rescale`` and cast to uint16 (the cast truncates)."""
    uint16_max = 2**16 - 1
    return np.uint16(uint16_max * rescale(im))
def to_8_bit(im):
    """Rescale ``im`` to [0, 2^8-1] with ``rescale`` and cast to uint8 (the cast truncates)."""
    uint8_max = 2**8 - 1
    return np.uint8(uint8_max * rescale(im))
### This section defines the tiling functions
def get_module(x):
    """Return the array module (``numpy`` or ``torch``) that should handle ``x``.

    numpy arrays, tuples, Python ints/floats, anything ``np.isscalar`` accepts,
    and dask arrays map to numpy; torch tensors map to torch.

    Raises
    ------
    ValueError
        If ``x`` is none of the supported types.
    """
    if isinstance(x, (np.ndarray, tuple, int, float)) or np.isscalar(x):
        return np
    # Look dask.array up lazily. ``import dask`` does not necessarily load the
    # ``dask.array`` submodule, so referencing ``dask.array.Array`` directly can
    # raise AttributeError. If dask.array was never imported, x cannot be a
    # dask array anyway.
    dask_array = sys.modules.get('dask.array')
    if dask_array is not None and isinstance(x, dask_array.Array):
        return np
    if torch.is_tensor(x):
        return torch
    raise ValueError("Input must be a numpy array, a tuple, a torch tensor, an integer, or a float")
def get_flip(idx):
    """Build the slicing tuple that flips tile axes for test-time augmentation.

    Parameters
    ----------
    idx : sequence of int
        Per-axis tile index (as produced by ``make_tiles_ND``). Axes with an
        even index are reversed; axes with an odd index are left unchanged.

    Returns
    -------
    tuple of slice
        ``slice(None, None, -1)`` for even entries and ``slice(None)`` for odd
        ones; apply it as ``arr[(Ellipsis,) + get_flip(idx)]``.
    """
    return tuple(slice(None) if i % 2 else slice(None, None, -1) for i in idx)
def _taper_mask_ND(shape=(224,224), sig=7.5):
dim = len(shape)
bsize = max(shape)
xm = np.arange(bsize)
xm = np.abs(xm - xm.mean())
# 1D distribution
mask = 1/(1 + np.exp((xm - (bsize/2-20)) / sig))
# extend to ND
for j in range(dim-1):
mask = mask * mask[..., np.newaxis]
slc = tuple([slice(bsize//2-s//2,bsize//2+s//2+s%2) for s in shape])
mask = mask[slc]
return mask
def unaugment_tiles_ND(y, inds, unet=False):
    """ reverse test-time augmentations for averaging

    Undoes the per-tile flips applied by ``make_tiles_ND(augment=True)``.
    Tile ``t`` is flipped back along every spatial axis ``k`` whose index
    ``inds[t][k]`` is even. Unless ``unet``, channel ``k`` (the flow component
    along that axis) is also negated.

    Parameters
    ----------
    y: float32 ndarray or torch tensor
        array of shape (ntiles, nchan, *DIMS)
        where nchan = (*DP,distance) (and boundary if nlasses=3)
    inds: list of tuples
        per-tile index tuples returned by ``make_tiles_ND``; ``len(inds[0])``
        is taken as the number of spatial dimensions
    unet: bool (optional, False)
        whether or not unet output or cellpose output

    Returns
    -------
    y: float32
        the same object, modified in place
    """
    dim = len(inds[0])
    is_torch = torch.is_tensor(y)
    for t, idx in enumerate(inds):
        # even tile index along an axis <=> that axis was flipped
        flipped = [k for k, n in enumerate(idx) if not n % 2]
        if not flipped:
            continue
        # spatial axis k is the (k - dim)-th axis counted from the end of a tile
        axes = tuple(k - dim for k in flipped)
        if is_torch:
            # torch rejects negative-step slices (and has no torch.array)
            y[t] = torch.flip(y[t], dims=axes)
        else:
            y[t] = np.flip(y[t], axis=axes)
        if not unet:
            for k in flipped:
                if k < y[t].shape[0]:
                    y[t][k] *= -1
    return y
def average_tiles_ND(y, subs, shape):
    """ average results of network over tiles

    Each tile is weighted by ``_taper_mask_ND`` and accumulated into its
    slice; the sum is then divided by the accumulated weights.

    Parameters
    -------------
    y: float, [ntiles x nclasses x bsize x bsize]
        output of cellpose network for each tile (ndarray or torch tensor)
    subs : list
        list of slices for each subtile
    shape : int, list or tuple
        shape of pre-tiled image (may be larger than original image if
        image size is less than bsize)

    Returns
    -------------
    yf: float32, [nclasses x Ly x Lx]
        network output averaged over tiles
    """
    module = get_module(y)
    is_torch = module.__name__ == 'torch'
    # accept an int, list, tuple or torch.Size as documented
    shape = (shape,) if isinstance(shape, (int, np.integer)) else tuple(shape)
    if is_torch:
        params = {'device': y.device, 'dtype': torch.float32}
    else:
        params = {'dtype': np.float32}
    Navg = module.zeros(shape, **params)
    yf = module.zeros((y.shape[1],) + shape, **params)
    mask = _taper_mask_ND(y.shape[-len(shape):])
    if is_torch:
        # _taper_mask_ND returns float64, which MPS cannot hold; the
        # accumulators are float32 anyway
        mask = torch.as_tensor(mask, dtype=torch.float32, device=y.device)
    for j, slc in enumerate(subs):
        yf[(Ellipsis,) + slc] += y[j] * mask
        Navg[slc] += mask
    yf /= Navg
    return yf
def make_tiles_ND(imgi, bsize=224, augment=False, tile_overlap=0.1,
                  normalize=True, return_tiles=True):
    """ make tiles of image to run at test-time

    if augmented, tile_overlap is ignored. The image is padded so every axis
    is at least bsize long, each axis gets at least 2 tiles spaced about
    bsize/2 apart, and every tile is flipped along the axes where its tile
    index is even (see ``get_flip``). In 2D this yields the four variants
        * original
        * flipped vertically
        * flipped horizontally
        * flipped vertically and horizontally

    Parameters
    ----------
    imgi : float32
        array that's nchan x Ly x Lx (x ...)
    bsize : int or float (optional, default 224)
        size of tiles; cast to int
    augment : bool (optional, default False)
        flip tiles as described above
    tile_overlap: float (optional, default 0.1)
        fraction of overlap of tiles, clipped to [0.05, 0.5]; unused if augment
    normalize: bool (optional, default True)
        only without augment: normalize each tile separately with
        normalize99(dim=0); otherwise rescale the stack as a whole
    return_tiles: bool (optional, default True)
        only without augment: if False, only the tiling is computed

    Returns
    -------
    IMG : float32 or None
        tensor of shape ntiles,nchan,*tile_shape (None if return_tiles is
        False and not augmenting)
    subs : list
        list of slices for each subtile
    shape : tuple
        spatial shape of the tiled image (after padding when augment=True)
    inds : list
        per-tile index tuples when augment=True (input to
        unaugment_tiles_ND), otherwise an empty list
    """
    module = get_module(imgi)
    # slice bounds must be integers; a float bsize used to crash the
    # non-augmented branch at slicing time
    bsize = int(bsize)
    shape = imgi.shape[1:]
    dim = len(shape)
    inds = []
    if augment:
        pad_seq = [(0, 0)] + [(0, max(0, bsize - s)) for s in shape]
        imgi = module.pad(imgi, pad_seq)
        shape = imgi.shape[-dim:]
        ntyx = [max(2, int(module.ceil(2. * s / bsize))) for s in shape]
        start = [module.linspace(0, s - bsize, n).astype(int) for s, n in zip(shape, ntyx)]
        intervals = [[slice(si, si + bsize) for si in s] for s in start]
        subs = list(itertools.product(*intervals))
        indexes = [module.arange(len(s)) for s in start]
        inds = list(itertools.product(*indexes))
        IMG = module.stack([imgi[(Ellipsis,) + slc][(Ellipsis,) + get_flip(idx)]
                            for slc, idx in zip(subs, inds)])
    else:
        tile_overlap = min(0.5, max(0.05, tile_overlap))
        # tile extent per axis: a full tile, or the whole axis if shorter
        bbox = tuple(int(min(bsize, s)) for s in shape)
        ntyx = [1 if s <= bsize else int(np.ceil((1. + 2 * tile_overlap) * s / bsize))
                for s in shape]
        start = [np.linspace(0, s - b, n).astype(int) for s, b, n in zip(shape, bbox, ntyx)]
        intervals = [[slice(si, si + bsize) for si in s] for s in start]
        subs = list(itertools.product(*intervals))
        if return_tiles:
            IMG = module.stack([imgi[(Ellipsis,) + slc] for slc in subs])
            if normalize:
                omnipose_logger.info('Running on tiles. Now normalizing each tile separately.')
                IMG = normalize99(IMG, dim=0)
            else:
                omnipose_logger.info('rescaling stack as a whole')
                IMG = rescale(IMG)
        else:
            IMG = None
    return IMG, subs, shape, inds
def generate_slices(image_shape, crop_size):
    """Split the first two axes of ``image_shape`` into a grid of crops.

    Returns ``(slices, num_crops)``. ``slices[r][c]`` is a
    ``(row_slice, col_slice)`` pair of length at most ``crop_size``, clipped at
    the image border. ``num_crops`` lists ``ceil(s / crop_size)`` for every
    entry of ``image_shape``.
    """
    num_crops = [math.ceil(s / crop_size) for s in image_shape]
    n_rows, n_cols = num_crops[0], num_crops[1]
    height, width = image_shape[0], image_shape[1]

    def _span(k, limit):
        # k-th window along an axis of length `limit`
        return slice(k * crop_size, min((k + 1) * crop_size, limit))

    slices = [[(_span(r, height), _span(c, width)) for c in range(n_cols)]
              for r in range(n_rows)]
    return slices, num_crops
def shifts_to_slice(shifts, shape):
    """
    Find the minimal crop box from time lapse registraton shifts.

    For each axis the crop starts at the (truncated) largest shift, floored
    at 0, and stops at ``size`` minus the (truncated) smallest shift, capped
    at ``size``.

    Parameters
    ----------
    shifts : array-like, shape (T, ndim)
    shape : sequence of int, spatial shape

    Returns
    -------
    tuple of slice
    """
    min_shift = np.min(shifts, axis=0)
    max_shift = np.max(shifts, axis=0)
    bounds = []
    for lo_shift, hi_shift, size in zip(min_shift, max_shift, shape):
        start = np.maximum(0, int(hi_shift))
        stop = np.minimum(size, size - int(lo_shift))
        bounds.append(slice(start, stop))
    return tuple(bounds)
def make_unique(masks):
    """Relabel stack of label matrices such that there is no repeated label across slices.

    Each slice of a uint32 copy is renumbered in place with
    ``fastremap.renumber`` (presumably to consecutive labels starting at 1,
    with 0 kept as background). Its nonzero labels are then offset by the
    largest label assigned so far.

    Parameters
    ----------
    masks : ndarray
        stack of label images; the first axis indexes slices

    Returns
    -------
    ndarray, uint32
        relabeled copy; the input is not modified
    """
    # astype always returns a new array, so no extra .copy() is needed
    masks = np.asarray(masks).astype(np.uint32)
    offset = 0
    for t in range(len(masks)):
        fastremap.renumber(masks[t], in_place=True)
        masks[t][masks[t] > 0] += offset
        # keep the running maximum: an empty slice has max() == 0 and used to
        # reset the offset, so later slices reused earlier labels
        offset = max(offset, int(masks[t].max()))
    return masks
# import imreg_dft
def cross_reg(imstack, upsample_factor=100, order=1,
              target_image=None,
              normalization=None, cval=None,
              prefilter=True, reverse=True):
    """
    Find the translations that register every frame of a time series.

    Frames are processed from last to first when ``reverse`` (the default),
    otherwise first to last. Unless ``target_image`` is given, each frame is
    compared with the previously registered frame, so all shifts are relative
    to the first frame processed.

    Parameters
    ----------
    imstack : ndarray
        time series; first axis is time, remaining axes are spatial
    upsample_factor : int
        subpixel precision passed to ``phase_cross_correlation``
    order : int
        spline order for ``scipy.ndimage.shift``
    target_image : ndarray, optional
        fixed reference used for every frame
    normalization : None or str
        passed to ``phase_cross_correlation``
    cval : float, optional
        fill value (mode='constant'); if None, mode='nearest' is used
    prefilter : bool
        passed to ``scipy.ndimage.shift``
    reverse : bool
        process frames from last to first

    Returns
    -------
    shifts : ndarray, shape (T, ndim)
        shift of each frame, in the original frame order
    regstack : ndarray
        registered stack (same dtype as imstack), in the original frame order
    """
    # this is a super important preprocessing step for registration to work:
    # each frame is divided by a Gaussian-blurred (sigma=5) copy of itself
    im_to_reg = np.stack([i / gaussian_filter(i, 5) for i in imstack])
    n_frames = len(imstack)
    dim = imstack.ndim - 1  # dim is spatial, assume first dimension is t
    regstack = np.zeros_like(imstack)
    shifts = np.zeros((n_frames, dim))
    # Use true frame indices throughout. The previous version iterated a
    # reversed view with a forward index, so with reverse=True it shifted
    # (and chained references off) the wrong frames.
    frame_order = range(n_frames - 1, -1, -1) if reverse else range(n_frames)
    prev = None
    for k in frame_order:
        im = im_to_reg[k]
        if target_image is not None:
            ref = target_image
        elif prev is None:
            ref = im  # first processed frame is compared with itself
        else:
            # NOTE(review): the reference is the shifted *raw* previous frame,
            # while `im` is background-normalized; kept as before — confirm.
            ref = regstack[prev]
        shift = phase_cross_correlation(ref, im,
                                        upsample_factor=upsample_factor,
                                        normalization=normalization)[0]
        shifts[k] = shift
        regstack[k] = im_shift(imstack[k], shift, order=order, prefilter=prefilter,
                               mode='nearest' if cval is None else 'constant',
                               cval=np.nanmean(imstack[k]) if cval is None else cval)
        prev = k
    return shifts, regstack
def shift_stack(imstack, shifts, order=1, cval=None):
    """
    Shift each time slice of imstack according to list of 2D shifts.

    Parameters
    ----------
    imstack : ndarray
        stack whose first axis indexes time
    shifts : sequence
        one shift per slice; only the first ``len(shifts)`` slices are shifted,
        any remaining slices stay zero
    order : int
        spline order for ``scipy.ndimage.shift``
    cval : float, optional
        fill value (mode='constant'); if None, mode='nearest' is used

    Returns
    -------
    ndarray, same shape and dtype as imstack
    """
    # TODO: parallelize this on dask or GPU (previously printed on every call)
    regstack = np.zeros_like(imstack)
    for i, shift in enumerate(shifts):
        regstack[i] = im_shift(imstack[i], shift, order=order,
                               mode='nearest' if cval is None else 'constant',
                               cval=np.nanmean(imstack[i]) if cval is None else cval)
    return regstack
# GPU version
import torch
import torch.fft
# def phase_cross_correlation_GPU(target, moving_images):
# # Assuming target is a 2D tensor [height, width]
# # and moving_images is a 3D tensor [num_images, height, width]
# # Expand dims of target to match moving_images
# target = target.unsqueeze(0)
# # print(target.shape,moving_images.shape)
# # Compute FFT of images
# target_fft = torch.fft.fftn(target, dim=[-2, -1])
# moving_fft = torch.fft.fftn(moving_images, dim=[-2, -1])
# # print(target_fft.shape,moving_fft.shape)
# # Compute cross-correlation by multiplying with complex conjugate
# cross_corr = torch.fft.ifftn(target_fft * moving_fft.conj(), dim=[-2, -1]).real
# # Find peak in cross-correlation
# max_indices = torch.argmax(cross_corr.view(cross_corr.shape[0], -1), dim=1)
# # Convert flat indices to 2D indices
# height = cross_corr.shape[-2]
# width = cross_corr.shape[-1]
# shifts_y = max_indices // width
# shifts_x = max_indices % width
# # Adjust shifts to fall within the correct range
# # make sure shift vector points in the right direction
# shifts_y = height // 2 - (shifts_y + height // 2) % height
# shifts_x = width // 2 - (shifts_x + width // 2) % width
# # Combine shifts along both dimensions into a single tensor
# shifts = torch.stack([shifts_y, shifts_x], dim=-1)
# return shifts
# def phase_cross_correlation_GPU(image_stack, target_index):
# # Assuming image_stack is a 3D tensor [num_images, height, width]
# # and target_index is an integer
# target_image = image_stack[target_index].unsqueeze(0)
# moving_images = torch.cat([image_stack[:target_index], image_stack[target_index+1:]])
# target_fft = torch.fft.fftn(target_image, dim=[-2, -1])
# moving_fft = torch.fft.fftn(moving_images, dim=[-2, -1])
# cross_corr = torch.fft.ifftn(target_fft * moving_fft.conj(), dim=[-2, -1]).real
# max_indices = torch.argmax(cross_corr.view(cross_corr.shape[0], -1), dim=1)
# height = cross_corr.shape[-2]
# width = cross_corr.shape[-1]
# shifts_y = max_indices // width
# shifts_x = max_indices % width
# shifts_y = height // 2 - (shifts_y + height // 2) % height
# shifts_x = width // 2 - (shifts_x + width // 2) % width
# shifts = torch.stack([shifts_y, shifts_x], dim=-1)
# # Insert a zero shift at the target index
# zero_shift = torch.zeros(1, 2, device=image_stack.device)
# shifts = torch.cat([shifts[:target_index], zero_shift, shifts[target_index:]])
# return shifts.long()
# import torch.nn.functional as F
# def phase_cross_correlation_GPU(image_stack, target_index, upsample_factor=1):
# # Assuming image_stack is a 3D tensor [num_images, height, width]
# # and target_index is an integer
# target_image = image_stack[target_index].unsqueeze(0)
# moving_images = torch.cat([image_stack[:target_index], image_stack[target_index+1:]])
# target_fft = torch.fft.fftn(target_image, dim=[-2, -1])
# moving_fft = torch.fft.fftn(moving_images, dim=[-2, -1])
# cross_corr = torch.fft.ifftn(target_fft * moving_fft.conj(), dim=[-2, -1]).real
# # Upsample cross correlation to achieve subpixel precision
# if upsample_factor > 1:
# cross_corr = cross_corr.unsqueeze(1)
# print('cc',cross_corr.shape)
# cross_corr = F.interpolate(cross_corr, scale_factor=upsample_factor,
# mode='bilinear', align_corners=False)
# print('cc',cross_corr.shape)
# cross_corr = cross_corr.squeeze(1)
# max_indices = torch.argmax(cross_corr.view(cross_corr.shape[0], -1), dim=1)
# height = cross_corr.shape[-2]
# width = cross_corr.shape[-1]
# shifts_y = max_indices // width
# shifts_x = max_indices % width
# shifts_y = height // 2 - (shifts_y + height // 2) % height
# shifts_x = width // 2 - (shifts_x + width // 2) % width
# # Convert shifts back to original pixel grid
# shifts_y = shifts_y / upsample_factor
# shifts_x = shifts_x / upsample_factor
# shifts = torch.stack([shifts_y, shifts_x], dim=-1)
# # Insert a zero shift at the target index
# zero_shift = torch.zeros(1, 2, device=image_stack.device)
# shifts = torch.cat([shifts[:target_index], zero_shift, shifts[target_index:]])
# return shifts
# def phase_cross_correlation_GPU(image_stack, target_index, upsample_factor=10):
# # Assuming image_stack is a 3D tensor [num_images, height, width]
# # and target_index is an integer
# # Upsample the images
# image_stack = F.interpolate(image_stack.unsqueeze(1).float(), scale_factor=upsample_factor, mode='bilinear', align_corners=False).squeeze(1)
# target_image = image_stack[target_index].unsqueeze(0)
# moving_images = torch.cat([image_stack[:target_index], image_stack[target_index+1:]])
# target_fft = torch.fft.fftn(target_image, dim=[-2, -1])
# moving_fft = torch.fft.fftn(moving_images, dim=[-2, -1])
# cross_corr = torch.fft.ifftn(target_fft * moving_fft.conj(), dim=[-2, -1]).real
# max_indices = torch.argmax(cross_corr.view(cross_corr.shape[0], -1), dim=1)
# height = cross_corr.shape[-2]
# width = cross_corr.shape[-1]
# shifts_y = max_indices // width
# shifts_x = max_indices % width
# shifts_y = height // 2 - (shifts_y + height // 2) % height
# shifts_x = width // 2 - (shifts_x + width // 2) % width
# # Convert shifts back to original pixel grid
# shifts_y = shifts_y / upsample_factor
# shifts_x = shifts_x / upsample_factor
# shifts = torch.stack([shifts_y, shifts_x], dim=-1)
# # Insert a zero shift at the target index
# zero_shift = torch.zeros(1, 2, device=image_stack.device)
# shifts = torch.cat([shifts[:target_index], zero_shift, shifts[target_index:]])
# return shifts.float()
def gaussian_kernel(size: int, sigma: float):
    """Build a normalized 2D Gaussian kernel centered in the middle of the grid.

    Args:
        size (int): Kernel width and height; use an odd number so the peak
            sits on the center pixel.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        torch.Tensor: ``(size, size)`` float32 kernel; it is the outer product
        of a 1D Gaussian whose entries sum to 1.
    """
    offsets = torch.arange(size).float() - size // 2
    profile = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    profile = profile / profile.sum()
    return torch.outer(profile, profile)
def apply_gaussian_blur(image, kernel_size, sigma):
    """Applies a Gaussian blur to a 2D image with reflect padding.

    Args:
        image (torch.Tensor): 2D image to blur. Non-floating images are
            converted to float32.
        kernel_size (int): The size of the Gaussian kernel (odd keeps the
            output the same size as the input).
        sigma (float): The standard deviation of the Gaussian distribution.

    Returns:
        torch.Tensor: The blurred image, on the same device as ``image``.
    """
    import torch.nn.functional as F  # local import: no module-level alias is relied on
    if not image.is_floating_point():
        image = image.float()  # conv2d has no integer kernels
    # build the kernel on the image's device/dtype; conv2d rejects mixing
    # a CPU kernel with a GPU image
    kernel = gaussian_kernel(kernel_size, sigma).to(device=image.device, dtype=image.dtype)
    kernel = kernel.unsqueeze(0).unsqueeze(0)  # (out_ch, in_ch, k, k)
    image = image.unsqueeze(0).unsqueeze(0)    # (batch, ch, H, W)
    # Apply 'reflect' padding to the image
    padding_size = kernel_size // 2
    image = F.pad(image, (padding_size, padding_size, padding_size, padding_size), mode='reflect')
    # Perform the convolution without additional padding
    blurred = F.conv2d(image, kernel, padding=0)
    return blurred.squeeze(0).squeeze(0)
def phase_cross_correlation_GPU_old(image_stack, target_index=None, upsample_factor=10,
                                    reverse=False, normalize=False):
    """Earlier (``_old``) FFT registration routine for a stack of 2D images.

    Each image is divided by a Gaussian-blurred copy (kernel 5, sigma 1), the
    stack is bilinearly upsampled by ``upsample_factor``, and for every image
    ``i >= 1`` the peak of the circular cross-correlation with its target
    gives a shift on the upsampled grid, divided by ``upsample_factor``.

    Parameters
    ----------
    image_stack : torch.Tensor
        assumed (num_images, height, width) — TODO confirm
    target_index : int or None
        fixed target index; None registers each image to its neighbour (the
        previous image, or the next one when ``reverse``)
    upsample_factor : int
    reverse : bool
    normalize : bool
        divide the cross-power spectrum by its magnitude (phase correlation)

    Returns
    -------
    torch.Tensor
        float tensor of shape (num_images, 2). Row ``k`` holds the shift found
        for image ``k + 1``, the last row is zero, and everything is
        multiplied by -2.
        NOTE(review): the -2 factor and the trailing (not leading) zero row are
        not explained in the code — confirm before relying on this function.
    """
    import torch.nn.functional as F  # local import: no module-level alias is relied on
    im_to_reg = torch.stack([i / apply_gaussian_blur(i, 5, 1) for i in image_stack])
    # Upsample the images
    image_stack = F.interpolate(im_to_reg.unsqueeze(1).float(),
                                scale_factor=upsample_factor, mode='bilinear',
                                align_corners=False).squeeze(1)
    n_images = len(image_stack)
    shifts = []
    for i in range(1, n_images):
        if target_index is not None:
            target_image = image_stack[target_index]
        elif reverse:
            # the last image has no successor and is compared with itself
            target_image = image_stack[i + 1] if i < n_images - 1 else image_stack[i]
        else:
            # i >= 1 here, so the previous image always exists
            target_image = image_stack[i - 1]
        moving_image = image_stack[i]
        target_fft = torch.fft.fftn(target_image, dim=[-2, -1])
        moving_fft = torch.fft.fftn(moving_image, dim=[-2, -1])
        # Compute the cross-power spectrum
        cross_power_spectrum = target_fft * moving_fft.conj()
        if normalize:
            cross_power_spectrum /= torch.abs(cross_power_spectrum)
        cross_corr = torch.abs(torch.fft.ifftn(cross_power_spectrum, dim=[-2, -1]))
        max_index = torch.argmax(cross_corr.view(-1))
        height = cross_corr.shape[-2]
        width = cross_corr.shape[-1]
        shift_y = max_index // width
        shift_x = max_index % width
        # signed, negated peak offset wrapped into (-size/2, size/2]
        shift_y = height // 2 - (shift_y + height // 2) % height
        shift_x = width // 2 - (shift_x + width // 2) % width
        # Convert shifts back to original pixel grid
        shift_y = shift_y / upsample_factor
        shift_x = shift_x / upsample_factor
        shifts.append([shift_y, shift_x])
    shifts.append([0, 0])
    shifts = torch.tensor(shifts, device=image_stack.device) * (-2)
    return shifts.float()
# return accumulated_shifts
def phase_cross_correlation_GPU(image_stack,
                                upsample_factor=10,
                                # normalization='phase'
                                normalization=None,
                                ):
    """Estimate per-frame shifts of an image stack from consecutive-frame cross-correlation.

    Parameters
    ----------
    image_stack : torch.Tensor
        assumed (num_images, height, width) — TODO confirm
    upsample_factor : int
        the cross-correlation surface is bilinearly upsampled by this factor
        before its peak is located
    normalization : None or 'phase'
        'phase' divides the cross-power spectrum by its magnitude

    Returns
    -------
    torch.Tensor
        (num_images, 2) float shifts (y, x). Consecutive-pair peak positions
        are doubled and divided by ``upsample_factor``; a zero row is appended;
        the rows are summed cumulatively from the end of the stack; and the
        result is mean-centered.
    """
    # divide each frame by a 9x9, sigma-3 Gaussian-blurred copy of itself
    im_to_reg = torch.stack([i / apply_gaussian_blur(i, 9, 3) for i in image_stack.float()])
    norm = 'backward'
    image_fft = torch.fft.fft2(im_to_reg, norm=norm)
    # cross-power spectrum of each consecutive pair (frame k vs frame k+1)
    cross_power_spectrum = image_fft[:-1] * image_fft[1:].conj()
    if normalization == 'phase':
        cross_power_spectrum /= torch.abs(cross_power_spectrum)
    cross_corr = torch.abs(torch.fft.ifft2(cross_power_spectrum, norm=norm))
    m = torch.nn.Upsample(scale_factor=upsample_factor, mode='bilinear')
    cross_corr = m(cross_corr.unsqueeze(1)).squeeze(1)
    # Keep the flat peak index integral. Casting it to float32 (as before) is
    # inexact above 2**24 elements, which an upsampled surface reaches quickly
    # (e.g. 512x512 at upsample_factor=10 is ~26M elements).
    width = cross_corr.shape[-1]
    max_indices = torch.argmax(cross_corr.reshape(cross_corr.shape[0], -1), dim=-1)
    shifts_y = torch.div(max_indices, width, rounding_mode='floor')
    shifts_x = max_indices % width
    # NOTE(review): peak indices are not wrapped to negative offsets and are
    # doubled; this appears tuned to the scaling used by apply_shifts — confirm
    # before changing either function.
    shifts = 2 * torch.stack([shifts_y, shifts_x]).T
    zero_shift = torch.zeros(1, 2, dtype=shifts.dtype, device=shifts.device)
    shifts = torch.cat([shifts, zero_shift], dim=0) / upsample_factor
    # Accumulate the shifts - SUPER important and was the cause of the bug
    shifts = torch.cumsum(shifts.flip(dims=[0]), dim=0).flip(dims=[0])
    # Subtract the average shift from all shifts to minimize the total shift
    shifts -= shifts.mean(dim=0)
    return shifts
def pairwise_registration(image_stack, upsample_factor=10):
    """Register every pair of images in a stack and combine the pairwise shifts.

    Images are divided by a Gaussian-blurred copy (kernel 5, sigma 5) and
    bilinearly upsampled by ``upsample_factor``. For each pair ``i < j`` the
    cross-correlation peak gives ``shifts[i, j]``, with ``shifts[j, i] =
    -shifts[i, j]``. The pairwise shifts are then chained with
    ``compute_final_shifts``.

    Parameters
    ----------
    image_stack : torch.Tensor
        assumed (num_images, height, width) — TODO confirm
    upsample_factor : int

    Returns
    -------
    torch.Tensor
        (num_images, 2) shifts: the ``compute_final_shifts`` result,
        cumulatively summed over images
    """
    import torch.nn.functional as F  # local import: no module-level alias is relied on
    im_to_reg = torch.stack([i / apply_gaussian_blur(i, 5, 5) for i in image_stack])
    # Upsample the images
    image_stack = F.interpolate(im_to_reg.unsqueeze(1).float(), scale_factor=upsample_factor,
                                mode='bilinear', align_corners=False).squeeze(1)
    num_images = len(image_stack)
    shifts = torch.zeros((num_images, num_images, 2), device=image_stack.device)
    for i in range(num_images):
        for j in range(i + 1, num_images):
            target_fft = torch.fft.fftn(image_stack[i].unsqueeze(0), dim=[-2, -1])
            moving_fft = torch.fft.fftn(image_stack[j].unsqueeze(0), dim=[-2, -1])
            cross_corr = torch.fft.ifftn(target_fft * moving_fft.conj(), dim=[-2, -1]).real
            max_index = torch.argmax(cross_corr.view(-1))
            height = cross_corr.shape[-2]
            width = cross_corr.shape[-1]
            shift_y = max_index // width
            shift_x = max_index % width
            shift_y = height // 2 - (shift_y + height // 2) % height
            shift_x = width // 2 - (shift_x + width // 2) % width
            # Convert shifts back to original pixel grid
            shift_y = shift_y / upsample_factor
            shift_x = shift_x / upsample_factor
            # stack keeps the values on-device instead of round-tripping via torch.tensor
            shifts[i, j] = torch.stack([shift_y, shift_x])
            shifts[j, i] = -shifts[i, j]  # Reverse shift for the opposite direction
    final_shifts = compute_final_shifts(shifts)
    # NOTE(review): compute_final_shifts already accumulates shifts along the
    # spanning tree; this extra cumsum may double-count — confirm.
    final_shifts = torch.cumsum(final_shifts, dim=0)
    return final_shifts
import networkx as nx
def compute_final_shifts(pairwise_shifts):
    """Chain pairwise shifts into per-image shifts relative to image 0.

    A complete graph over the images is weighted by
    ``||pairwise_shifts[i, j]||``. Its minimum spanning tree is walked
    depth-first from image 0, and for each tree edge (parent -> child),
    ``final[child] = final[parent] + shift(parent -> child)``.

    Parameters
    ----------
    pairwise_shifts : torch.Tensor
        (num_images, num_images, 2); entry ``[i, j]`` for ``i < j`` is taken as
        the shift from image i to image j, and the reverse direction as its
        negation.

    Returns
    -------
    torch.Tensor
        (num_images, 2) float shifts; image 0 has zero shift.
    """
    num_images = pairwise_shifts.shape[0]
    final_shifts = torch.zeros((num_images, 2), device=pairwise_shifts.device)
    if num_images < 2:
        # no pairs, so nothing to chain (and an edgeless graph has no node 0)
        return final_shifts
    G = nx.Graph()
    for i in range(num_images):
        for j in range(i + 1, num_images):
            # plain float weight so networkx compares numbers, not tensors
            G.add_edge(i, j, weight=float(torch.norm(pairwise_shifts[i, j])))
    mst = nx.minimum_spanning_tree(G)
    # dfs_edges yields (parent, child). The shift must be taken in that
    # direction; the old code always used the i<j shift, giving the wrong
    # sign whenever the walk went from a higher to a lower index.
    for parent, child in nx.dfs_edges(mst, source=0):
        if parent < child:
            step = pairwise_shifts[parent, child]
        else:
            step = -pairwise_shifts[child, parent]
        final_shifts[child] = final_shifts[parent] + step
    return final_shifts
def apply_shifts(moving_images, shifts):
    """Translate each image of a stack by its own (y, x) shift with ``grid_sample``.

    Parameters
    ----------
    moving_images : torch.Tensor
        (num_images, height, width); non-floating input is converted to float32
    shifts : torch.Tensor
        (num_images, 2) in (y, x) order. Output pixel p samples the input at
        p + shift * (size - 1) / (2 * size), i.e. roughly half of ``shifts``
        in pixels (phase_cross_correlation_GPU appears to compensate with its
        factor of 2 — confirm before changing the scaling).

    Returns
    -------
    torch.Tensor
        (num_images, height, width) resampled images; samples outside the
        input are zero.
    """
    import torch.nn.functional as F  # local import: no module-level alias is relied on
    N, H, W = moving_images.shape
    device = shifts.device
    images = moving_images if moving_images.is_floating_point() else moving_images.float()
    # Normalize the shifts by the image size
    shifts = shifts / torch.tensor([H, W], device=device)
    # pixel centers mapped to [-1, 1] with the align_corners=True convention
    ys = 2.0 * torch.arange(H, device=device, dtype=images.dtype) / (H - 1) - 1.0
    xs = 2.0 * torch.arange(W, device=device, dtype=images.dtype) / (W - 1) - 1.0
    grid_y = ys[None, :, None] + shifts[:, 0][:, None, None]
    grid_x = xs[None, None, :] + shifts[:, 1][:, None, None]
    grid_y, grid_x = torch.broadcast_tensors(grid_y, grid_x)
    # [N, H, W, 2] grid in (x, y) order as grid_sample expects
    grid = torch.stack([grid_x, grid_y], dim=-1).to(images.dtype)
    # align_corners must match the normalization above; align_corners=False
    # (as before) stretched the image and dimmed the border even for zero shift
    resampled = F.grid_sample(images.unsqueeze(1), grid, align_corners=True)
    return resampled.squeeze(1)
def unravel_index(index, shape):
    """Convert a flat (row-major) index into a tuple of coordinates for ``shape``.

    Only ``%`` and ``//`` are used, so ``index`` may be a Python int, a numpy
    integer or a torch tensor.
    """
    coords = []
    remaining = index
    for size in reversed(shape):
        coords.insert(0, remaining % size)
        remaining = remaining // size
    return tuple(coords)
def normalize_field(mu, use_torch=False, cutoff=0):
    """ normalize all nonzero field vectors to magnitude 1

    Parameters
    ----------
    mu: ndarray or tensor, float
        component array of length N by L1 by L2 by ... by LN; components
        along the first axis
    use_torch: bool
        use the torch code path
    cutoff: float
        vectors whose magnitude is not above this are not normalized

    Returns
    --------------
    normalized component array of identical size. The numpy path (NaN-aware
    magnitude, via safe_divide) sets vectors at or below the cutoff to 0; the
    torch path returns them unchanged.
    """
    if not use_torch:
        magnitude = np.sqrt(np.nansum(mu ** 2, axis=0))
        return safe_divide(mu, magnitude, cutoff)
    magnitude = torch_norm(mu, dim=0)
    return torch.where(magnitude > cutoff, mu / magnitude, mu)
# @torch.jit.script
def torch_norm(a, dim=0, keepdim=False):
    """Euclidean (L2) norm of tensor ``a`` along ``dim``.

    When ``ARM`` is set, the norm is computed by hand because
    torch.linalg.norm is not implemented on MPS; otherwise torch.linalg.norm
    is used.
    """
    if not ARM:
        return torch.linalg.norm(a, dim=dim, keepdim=keepdim)
    # fastest manual form tested on MPS, though still comparatively slow
    return a.square().sum(dim=dim, keepdim=keepdim).sqrt()
# @njit
# def safe_divide(num,den,cutoff=0):
# """ Division ignoring zeros and NaNs in the denominator."""
# # module = get_module(num) # assume num and den are the same type
# # return np.divide(num, den, out=np.zeros_like(num),
# # where=np.logical_and(den>cutoff,~np.isnan(den)))
# if isinstance(num, np.ndarray):
# return np.divide(num, den, out=np.zeros_like(num),
# where=np.logical_and(den>cutoff,~np.isnan(den)))
# elif isinstance(num, torch.Tensor):
# return torch.where((den > cutoff) & torch.isfinite(den), num / den, torch.zeros_like(num))
# else:
# raise TypeError("num must be a numpy array or a PyTorch tensor")
# def safe_divide(num, den, cutoff=0):
# """ Division ignoring zeros and NaNs in the denominator."""
# module = get_module(num)
# return module.where((den > cutoff) & module.isfinite(den) & (den>0),
# module.divide(num, den),
# module.zeros_like(num))
# def safe_divide(num, den, cutoff=0):
# """ Division ignoring zeros and NaNs in the denominator."""
# module = get_module(num)
# valid_den = (den > cutoff) & module.isfinite(den) #isfinite catches both nan and inf
# return module.divide(num, den, out=module.zeros_like(num,dtype=module.float64), where=valid_den, rounding_mode='trunc')
def safe_divide(num, den, cutoff=0):
    """ Division ignoring zeros and NaNs in the denominator.

    Returns ``num / den`` where ``den > cutoff`` and ``den`` is finite, and 0
    elsewhere. The numpy path writes into a float32 array shaped like
    ``num``; the torch path fills invalid entries from a float32 zero tensor.
    """
    module = get_module(num)
    # isfinite rejects NaN as well as +/-inf
    ok = (den > cutoff) & module.isfinite(den)
    if module == torch:
        return torch.where(ok, num / den, torch.zeros_like(num, dtype=torch.float32))
    if module == np:
        out = np.zeros_like(num, dtype=np.float32)
        return np.divide(num, den, out=out, where=ok)
    raise TypeError("num must be a numpy array or a PyTorch tensor")
def normalize99(Y, lower=0.01, upper=99.99, contrast_limits=None, dim=None):
    """ normalize array/tensor so 0.0 is 0.01st percentile and 1.0 is 99.99th percentile
    Upper and lower percentile ranges configurable.

    Parameters
    ----------
    Y: ndarray/tensor, float
        Input array/tensor.
    upper: float
        upper percentile above which pixels are sent to 1.0
    lower: float
        lower percentile below which pixels are sent to 0.0
    contrast_limits: list, float (optional, override computation)
        list of two floats, lower and upper contrast limits
    dim: int (optional)
        if given, percentiles are computed separately for every index along
        this axis (e.g. per tile); negative values count from the end

    Returns
    --------------
    normalized array/tensor with a minimum of 0 and maximum of 1
    """
    module = get_module(Y)
    if contrast_limits is None:
        quantiles = np.array([lower, upper]) / 100
        if module == torch:
            quantiles = torch.tensor(quantiles, dtype=Y.dtype, device=Y.device)
        if dim is not None:
            dim = dim % Y.ndim
            n = Y.shape[dim]
            # Move `dim` to the front before flattening so each row holds exactly
            # one slice; a bare reshape(n, -1) only groups correctly for dim=0.
            Y_flattened = module.moveaxis(Y, dim, 0).reshape(n, -1)
            lower_val, upper_val = module.quantile(Y_flattened, quantiles, axis=-1)
            # shape (1, ..., n, ..., 1) so the limits broadcast along `dim`
            bshape = [1] * Y.ndim
            bshape[dim] = n
            lower_val = lower_val.reshape(bshape)
            upper_val = upper_val.reshape(bshape)
        else:
            try:
                lower_val, upper_val = module.quantile(Y, quantiles)
            except RuntimeError:
                # torch.quantile refuses very large inputs; fall back to chunks
                lower_val, upper_val = auto_chunked_quantile(Y, quantiles)
    else:
        if module == np:
            contrast_limits = np.array(contrast_limits)
        elif module == torch:
            contrast_limits = torch.tensor(contrast_limits)
        lower_val, upper_val = contrast_limits
    # lower_val is a percentile, not the minimum, so Y - lower_val can be
    # negative (and likewise slightly > 1 at the top); clip to [0, 1]
    return module.clip(safe_divide(Y - lower_val, upper_val - lower_val), 0, 1)
def auto_chunked_quantile(tensor, q, max_elements=16e6 - 1):
    """Approximate ``torch.quantile`` for tensors too large for it.

    The tensor is flattened and split into roughly equal chunks of at most
    ``max_elements`` elements, and the per-chunk quantiles are averaged. This
    is an approximation; it generally differs from the exact quantile of the
    whole tensor.

    Parameters
    ----------
    tensor : torch.Tensor
        input of any shape
    q : torch.Tensor
        quantiles in [0, 1]
    max_elements : int or float
        largest chunk handed to ``torch.quantile`` (default just below 16M)

    Returns
    -------
    torch.Tensor
        one value per entry of ``q``
    """
    # Flatten first: torch.chunk splits along dim 0 only, so a tensor with a
    # short leading axis could otherwise never be cut into small enough pieces.
    flat = tensor.reshape(-1)
    n_chunks = max(1, math.ceil(flat.nelement() / max_elements))
    chunks = torch.chunk(flat, n_chunks)
    return torch.stack([torch.quantile(chunk, q) for chunk in chunks]).mean(dim=0)
# # Usage:
# large_tensor = torch.rand(int(2000e4))
# q = torch.tensor([0.2, 0.9])
# quantiles = auto_chunked_quantile(large_tensor, q)
def normalize_image(im, mask, target=0.5, foreground=False, iterations=1, scale=1, channel_axis=0):
    """
    Normalize image by rescaling from 0 to 1 and then adjusting gamma to bring
    average background to specified value (0.5 by default).

    Parameters
    ----------
    im: ndarray, float
        input image or volume
    mask: ndarray, int or bool, or list of these
        input labels or foreground mask (one per channel if a list)
    target: float
        target background/foreground value in the range 0-1
    foreground: bool
        measure the mean over the foreground (mask > 0) instead of the
        background (mask == 0)
    iterations: int
        binary erosion iterations applied to the measured region
    scale: float
        factor applied after rescaling
    channel_axis: int
        the axis that contains the channels (used when im.ndim > 2)

    Returns
    --------------
    gamma-normalized array with a minimum of 0 and maximum of 1. Channels whose
    measured mean is NaN (empty region), <= 0 or exactly 1 are left at their
    rescaled values, since gamma log(target)/log(mean) is undefined or
    degenerate there.
    """
    im = rescale(im) * scale
    if im.ndim > 2:
        # move channels (channel_axis) last, then split into a list of channels
        im = np.moveaxis(im, channel_axis, -1)
        im = [im[..., i] for i in range(im.shape[-1])]
    else:
        im = [im]
    if not isinstance(mask, list):
        mask = [mask] * len(im)
    for k in range(len(mask)):
        region = mask[k] > 0 if foreground else mask[k] == 0
        bin0 = binary_erosion(region, iterations=iterations)
        source_target = np.mean(im[k][bin0])
        # NaN (empty region), 0 or 1 would yield NaN, all-ones or inf channels
        if not np.isfinite(source_target) or source_target <= 0 or source_target == 1:
            continue
        im[k] = im[k] ** (np.log(target) / np.log(source_target))
    return np.stack(im, axis=channel_axis).squeeze()
import ncolor
def mask_outline_overlay(img, masks, outlines, mono=None):
    """
    Color-overlay a label matrix and its outlines on a grayscale image.

    Parameters
    ----------
    img: ndarray
        background image. A 3-dimensional input is converted with
        ``skimage.color.rgb2gray`` and rescaled; any other input is used as-is.
    masks: ndarray, int
        label matrix.
    outlines: ndarray
        outline map; it adds 0.75 (before the /3 below) to the overlay opacity.
    mono: color or None
        if None, labels are recolored with ``ncolor.label`` and painted with
        ``sinebow`` colors; otherwise every label uses this single color.

    Returns
    -------
    RGB overlay produced by ``skimage.color.label2rgb``.
    """
    if mono is not None:
        colors = mono
        m = masks > 0
    else:
        m, n = ncolor.label(masks, max_depth=20, return_n=True)
        palette = sinebow(n)
        # skip the first sinebow entry (presumably the background color)
        colors = np.array(list(palette.values()))[1:]

    im = rescale(color.rgb2gray(img)) if img.ndim == 3 else img

    opacity = ((m > 0) * 1. + outlines * 0.75) / 3
    alpha = np.stack([opacity] * 3, axis=-1)
    return color.label2rgb(m, im, colors, bg_label=0, alpha=alpha)
def mono_mask_bd(masks, outlines, color=(1, 0, 0), a=0.25):
    """
    Build an RGBA-style overlay that paints every mask in one color.

    Parameters
    ----------
    masks: ndarray
        label matrix or binary mask; every pixel > 0 is painted.
    outlines: ndarray
        outline map; outline pixels get extra opacity ``1 - a``.
    color: sequence of float
        per-channel color values (default red). A tuple default avoids a
        shared mutable default argument.
    a: float
        opacity of mask pixels that are not on an outline.

    Returns
    -------
    ndarray of shape ``masks.shape + (len(color) + 1,)``: the color channels
    followed by an alpha channel.
    """
    m = masks > 0
    alpha = m * a + outlines * (1 - a)
    return np.stack([m * c for c in color] + [alpha], axis=-1)
def moving_average(x, w):
    """
    Centered moving average of width ``w`` along the first axis of ``x``.

    Implemented as a 1D convolution with a uniform kernel using
    ``scipy.ndimage.convolve1d``'s default 'reflect' boundary mode.
    """
    kernel = np.full(w, 1.0 / w)
    return convolve1d(x, kernel, axis=0)
# def rescale(T,floor=None,ceiling=None):
# """Rescale array between 0 and 1"""
# if ceiling is None:
# ceiling = T[:].max()
# if floor is None:
# floor = T[:].min()
# T = np.interp(T, (floor, ceiling), (0, 1))
# return T
def rescale(T, floor=None, ceiling=None, dim=None):
    """
    Rescale data between 0 and 1.

    Computes ``(T - floor) / (ceiling - floor)`` through ``safe_divide`` (defined
    elsewhere in this module; presumably it guards against a zero range —
    confirm). The result is not clipped, so values can fall outside [0, 1] when
    an explicit ``floor``/``ceiling`` is given.

    Parameters
    ----------
    T: ndarray or torch.Tensor
        data to rescale; the array module is chosen with ``get_module``.
    floor, ceiling: scalar or array, optional
        lower/upper reference values. Defaults are the min/max of ``T``.
    dim: int, optional
        if given, min/max are computed separately for every index along this
        axis (all other axes are reduced). Negative values count from the end,
        e.g. ``dim=-1`` for a channels-last array.

    Returns
    -------
    Rescaled array of the same type as ``T``.
    """
    module = get_module(T)
    if dim is not None:
        # normalize negative axes; otherwise `i != dim` is always true and the
        # statistics would silently be taken over the whole array
        dim = dim % T.ndim
        axes = tuple(i for i in range(T.ndim) if i != dim)
        # shape that broadcasts a per-`dim` statistic back over T
        stat_shape = [1 if i != dim else -1 for i in range(T.ndim)]
    else:
        axes = None

    if ceiling is None:
        ceiling = module.amax(T, axis=axes)
        if dim is not None:
            ceiling = ceiling.reshape(*stat_shape)
    if floor is None:
        floor = module.amin(T, axis=axes)
        if dim is not None:
            floor = floor.reshape(*stat_shape)

    return safe_divide(T - floor, ceiling - floor)
def normalize_stack(vol, mask, bg=0.5, bright_foreground=None,
                    subtractive=False, iterations=1, equalize_foreground=1,
                    quantiles=(0.01, 0.99)):
    """
    Adjust image stacks so that background is
    (1) consistent in brightness and
    (2) brought to an even average via semantic gamma normalization.

    Parameters
    ----------
    vol: ndarray
        image stack; the first axis indexes the slices/frames. The
        dark-foreground branch takes quantiles over the last two axes.
    mask: ndarray, int or bool
        label/foreground stack matching ``vol``; 0 marks background.
    bg: float
        target mean of the eroded background of every slice after gamma correction.
    bright_foreground: bool or None
        whether objects are brighter than background; inferred when None.
    subtractive: bool
        if True, subtract each slice's background mean and divide by the
        smallest slice background. Otherwise scale each slice by
        (smallest background / slice background).
    iterations: int
        binary-erosion iterations for the per-slice background mask.
    equalize_foreground: bool
        if truthy, stretch every slice using foreground/background statistics.
    quantiles: pair of float
        (lower, upper) quantiles used in the dark-foreground branch. A tuple
        default avoids a shared mutable default argument.

    Returns
    -------
    ndarray: the normalized stack.
    """
    vol = vol.copy()
    # binarize background mask, recede from foreground, slice-wise to not erode in time
    kwargs = {'iterations': iterations} if iterations > 1 else {}
    bg_mask = [binary_erosion(m == 0, **kwargs) for m in mask]
    # mean background of each slice
    bg_real = [np.nanmean(v[m]) for v, m in zip(vol, bg_mask)]

    # automatically determine if foreground objects are bright or dark
    if bright_foreground is None:
        bright_foreground = np.mean(vol[bg_mask]) < np.mean(vol[mask > 0])

    # normalize every slice with respect to the dimmest background
    bg_min = np.min(bg_real)
    if subtractive:
        vol = np.stack([safe_divide(v - bg_r, bg_min) for v, bg_r in zip(vol, bg_real)])
    else:
        vol = np.stack([v * safe_divide(bg_min, bg_r) for v, bg_r in zip(vol, bg_real)])

    # equalize foreground signal
    if equalize_foreground:
        q1, q2 = quantiles
        if bright_foreground:
            fg_real = [np.percentile(v[m > 0], 99.99) for v, m in zip(vol, mask)]
            floor = np.percentile(vol[bg_mask], 0.01)
            vol = [rescale(v, ceiling=f, floor=floor) for v, f in zip(vol, fg_real)]
        else:
            fg_real = [np.quantile(v[m > 0], q1) for v, m in zip(vol, mask)]
            ceiling = np.quantile(vol, q2, axis=(-2, -1))
            vol = [np.interp(v, (f, c), (0, 1)) for v, f, c in zip(vol, fg_real, ceiling)]
        vol = np.stack(vol)

    # gamma-normalize so that the mean eroded background of each slice becomes `bg`
    vol = np.stack([v ** (np.log(bg) / np.log(np.mean(v[bg_m]))) for v, bg_m in zip(vol, bg_mask)])
    return vol
def is_integer(var):
    """
    Return True if ``var`` is an integer scalar or an integer-typed torch tensor.

    Python ints (including bool, which subclasses int) and NumPy integer scalars
    qualify. A ``torch.Tensor`` qualifies when its dtype is neither floating
    point, complex, nor bool. ``torch.Tensor`` has no ``is_integer()`` method,
    so the dtype is inspected instead.
    """
    if isinstance(var, (int, np.integer)):
        return True
    if isinstance(var, torch.Tensor):
        dtype = var.dtype
        return not (dtype.is_floating_point or dtype.is_complex or dtype == torch.bool)
    return False
def bbox_to_slice(bbox, shape, pad=0, im_pad=0):
    """
    Convert a skimage.measure-style bounding box into a tuple of slices.

    Parameters
    ----------
    bbox: sequence of numbers
        bounding box laid out as all start coordinates followed by all stop
        coordinates, e.g. [y0, x0, y1, x1].
    shape: sequence of int
        shape of the array that will be sliced.
    pad: int or sequence of int
        padding added on both sides of the box, either one value for every axis
        or one value per axis (an N-volume needs N values).
    im_pad: int or sequence of int
        margin along the image edges that the slices may not enter.

    Returns
    -------
    tuple of slice
        one slice per axis. Each start is at least ``im_pad`` and each stop at
        most ``shape - im_pad``.
    """
    ndim = len(shape)
    pads = [pad] * ndim if is_integer(pad) else pad
    margins = [im_pad] * ndim if is_integer(im_pad) else im_pad

    slices = []
    for axis in range(len(bbox) // 2):
        lo = max(margins[axis], bbox[axis] - pads[axis])
        hi = min(bbox[axis + ndim] + pads[axis], shape[axis] - margins[axis])
        slices.append(slice(int(lo), int(hi)))
    return tuple(slices)
def crop_bbox(mask, pad=10, iterations=3, im_pad=0, area_cutoff=0,
              max_dim=np.inf, get_biggest=False, binary=False):
    """Take a label matrix and return a list of bounding boxes identifying clusters of labels.

    Parameters
    --------------
    mask: matrix of integer labels (any dimension)
    pad: amount of space in pixels to add around each cluster. It is shrunk,
        equally on all sides, to the largest value that stays inside the image
        along every axis.
    iterations: number of dilation iterations to merge labels separated by this number of pixels or less
    im_pad: amount of space to subtract off the label matrix edges
    area_cutoff: label clusters with area at or below this value (in pixels) are ignored
    max_dim: currently unused; kept for backward compatibility
    get_biggest: if True, return only the slice of the largest cluster
    binary: if True, merge all slices into a single bounding slice tuple

    Returns
    ---------------
    slices: list of bounding box slice tuples with padding, or a single slice
        tuple when ``binary`` is True.

    Raises
    ------
    ValueError: if ``get_biggest`` or ``binary`` is requested but no cluster was found.
    """
    bw = binary_dilation(mask > 0, iterations=iterations) if iterations > 0 else mask > 0
    clusters = measure.label(bw)
    regions = measure.regionprops(clusters)
    sz = mask.shape
    d = mask.ndim

    def _clamped_pad(bbx):
        # largest padding <= pad that fits before and after the box on every axis;
        # regionprops bboxes are (start_0, ..., start_{d-1}, stop_0, ..., stop_{d-1})
        room_before = [bbx[i] for i in range(d)]
        room_after = [sz[i] - bbx[i + d] for i in range(d)]
        return min([pad] + room_before + room_after)

    if get_biggest:
        biggest = regions[int(np.argmax([props.area for props in regions]))]
        bbx = biggest.bbox
        slices = [bbox_to_slice(bbx, sz, pad=_clamped_pad(bbx), im_pad=im_pad)]
    else:
        slices = [bbox_to_slice(props.bbox, sz, pad=_clamped_pad(props.bbox), im_pad=im_pad)
                  for props in regions if props.area > area_cutoff]

    # merge into a single slice
    if binary:
        starts = np.min([[slc[i].start for i in range(d)] for slc in slices], axis=0)
        stops = np.max([[slc[i].stop for i in range(d)] for slc in slices], axis=0)
        slices = tuple(slice(start, stop) for start, stop in zip(starts, stops))
    return slices
def get_boundary(mask):
    """
    Return the inner boundary of an ND binary mask.

    The mask is eroded once with ``mahotas.morph.erode`` (its default
    structuring element), and the boundary is the logical XOR of the mask and
    that erosion, i.e. the pixels the erosion removed.

    Parameters
    ----------
    mask: ND array, bool
        binary mask

    Returns
    -------
    ND array, bool: boundary map
    """
    eroded = mh.morph.erode(mask)
    return np.logical_xor(mask, eroded)
# Omnipose version of remove_edge_masks, need to merge (this one is more flexible)
def clean_boundary(labels, boundary_thickness=3, area_thresh=30, cutoff=0.5):
    """Delete boundary masks below a given size threshold within a certain distance from the boundary.

    Parameters
    ----------
    labels: ND array, int
        label matrix; 0 is background.
    boundary_thickness: int
        labels within a stripe of this thickness along the boundary will be candidates for removal.
    area_thresh: int
        labels with area below this value will be removed.
    cutoff: float
        Fraction from 0 to 1 of the overlap with the boundary before the mask is removed. Default 0.5.
        Set to 0 if you want any mask touching the boundary to be removed.

    Returns
    --------------
    label matrix with small edge labels removed (a copy; ``labels`` is not modified).
    """
    # stripe of `boundary_thickness` pixels along every edge: dilate an empty mask
    # while treating everything outside the array as foreground
    border_mask = binary_dilation(np.zeros(labels.shape, dtype=bool), border_value=1,
                                  iterations=boundary_thickness)
    clean_labels = np.copy(labels)
    border_ids = np.unique(labels[border_mask])
    # drop background explicitly; slicing off the first entry would discard a
    # real label whenever no background pixel lies inside the stripe
    for cell_ID in border_ids[border_ids != 0]:
        mask = labels == cell_ID
        area = np.count_nonzero(mask)
        overlap = np.count_nonzero(np.logical_and(mask, border_mask))
        # only remove small cells that are `cutoff` or more edge pixels
        if overlap > 0 and area < area_thresh and overlap / area >= cutoff:
            clean_labels[mask] = 0
    return clean_labels
# This function takes a few milliseconds for a typical image
def get_edge_masks(labels, dists):
    """Finds and returns masks that are largely cut off by the edge of the image.

    This function loops over all masks touching the image boundary and compares the
    maximum value of the distance field along the boundary to the top quartile of distance
    within the mask. Regions whose edges just skim the image edge will not be classified as
    an "edge mask" by this criterion, whereas masks cut off in their center (where distance is high)
    will be returned as part of this output.

    Parameters
    ----------
    labels: ND array, int
        label matrix; 0 is background.
    dists: ND array, float
        distance field (calculated with reflection padding of labels)

    Returns
    --------------
    clean_labels: ND array, int
        label matrix of all cells qualifying as 'edge masks' (other pixels are 0)
    """
    # one-pixel stripe along every edge of the array
    border_mask = binary_dilation(np.zeros(labels.shape, dtype=bool), border_value=1, iterations=1)
    clean_labels = np.zeros_like(labels)
    border_ids = np.unique(labels[border_mask])
    # drop background explicitly; slicing off the first entry would discard a
    # real label whenever no background pixel lies on the border
    for cell_ID in border_ids[border_ids != 0]:
        mask = labels == cell_ID
        max_dist = np.max(dists[np.logical_and(mask, border_mask)])
        # top 25% of the distance inside the mask: a rough proxy for the skeleton
        dist_thresh = np.percentile(dists[mask], 75)
        # keep only cells whose distance at the boundary is not too small
        if max_dist >= dist_thresh:
            clean_labels[mask] = cell_ID
    return clean_labels
def border_indices(tyx):
    """
    Return flat (C-order) indices of the border pixels of an array of shape ``tyx``.

    For every axis, the indices of the first and the last hyperplane are
    appended in ascending order. Pixels on several borders (corners, edges)
    appear once per border they belong to. Use via ``A.flat[border_indices]``.
    """
    grids = np.meshgrid(*(np.arange(size) for size in tyx), indexing='ij')
    flat_coords = [grid.ravel() for grid in grids]
    pieces = [
        np.nonzero(flat_coords[axis] == edge)[0]
        for axis, size in enumerate(tyx)
        for edge in (0, size - 1)
    ]
    return np.concatenate(pieces)
# @njit
# def get_neighbors(coords, steps, dim, shape, edges=None, pad=0):
# print('this version actually a lot slower than below ')
# if edges is None:
# edges = [np.array([-1,s]) for s in shape]
# npix = coords[0].shape[-1]
# neighbors = np.empty((dim, len(steps), npix), dtype=np.int64)
# for d in range(dim):
# for i, s in enumerate(steps):
# for j in range(npix):
# if ((coords[d][j] + s[d]) in edges[d]) and ((coords[d][j] + 2*s[d]) not in edges[d]):
# neighbors[d,i,j] = coords[d][j]
# else:
# neighbors[d,i,j] = coords[d][j] + s[d]
# return neighbors
# much faster
# @njit
# def isin_numba(x, y):
# result = np.zeros(x.shape, dtype=np.bool_)
# for i in range(x.size):
# result[i] = x[i] in y
# return result
# @njit
# def get_neighbors(coords, steps, dim, shape, edges=None):
# if edges is None:
# edges = [np.array([-1,s]) for s in shape]
# npix = coords[0].shape[-1]
# neighbors = np.empty((dim, len(steps), npix), dtype=np.int64)
# for d in range(dim):
# for i, s in enumerate(steps):
# X = coords[d] + s[d]
# mask = np.logical_and(isin_numba(X, edges[d]), ~isin_numba(X+s[d], edges[d]))
# neighbors[d,i] = np.where(mask, coords[d], X)
# return neighbors
# slightly faster than the jit code!
def get_neighbors(coords, steps, dim, shape, edges=None, pad=0):
    """
    Get the coordinates of all neighbor pixels.

    For each axis ``d`` and each step ``s``, the neighbor coordinate is
    ``coords[d] + s[d]``. When that value equals one of ``edges[d]`` (by
    default ``pad - 1`` and ``shape[d] - pad``), the pixel's own coordinate is
    used instead, so out-of-bounds neighbors point back to the pixel itself.

    Parameters
    ----------
    coords: sequence of 1D int arrays
        ``dim`` arrays holding the coordinates of the npix pixels.
    steps: ndarray, int
        (nsteps, dim) array of offsets.
    dim: int
        number of spatial dimensions.
    shape: sequence of int
        image shape, only used to build the default ``edges``.
    edges: list of arrays, optional
        per-axis coordinate values that count as out of bounds.
    pad: int
        padding used when building the default ``edges``.

    Returns
    -------
    ndarray, int64 of shape (dim, nsteps, npix)
    """
    if edges is None:
        edges = [np.array([pad - 1, size - pad]) for size in shape]
    n_pixels = coords[0].shape[-1]
    neighbors = np.empty((dim, len(steps), n_pixels), dtype=np.int64)
    for axis in range(dim):
        origin = coords[axis]
        shifted = origin + steps[:, axis, np.newaxis]
        hits_edge = np.isin(shifted, edges[axis])
        neighbors[axis] = np.where(hits_edge, np.broadcast_to(origin, shifted.shape), shifted)
    return neighbors
def get_neighbors_torch(input, steps):
    """
    Get clamped neighbor coordinates for every pixel of a batched ND grid.

    This version is not yet used/tested elsewhere in the code.

    Parameters
    ----------
    input: torch.Tensor
        tensor of shape (B, D, *DIMS). Only its shape and device are used;
        D is expected to equal ``len(DIMS)`` (one channel per spatial axis).
    steps: torch.Tensor, int
        (nsteps, D) offsets; moved to ``input``'s device.

    Returns
    -------
    torch.Tensor of shape (B, nsteps, D, *DIMS) holding, for every step, the
    shifted pixel coordinates clamped to ``[0, DIMS[d] - 1]`` along each axis.
    """
    B, D, *DIMS = input.shape
    nsteps = steps.shape[0]
    device = input.device

    # coordinate grid (len(DIMS), *DIMS), built on the same device as the input
    grids = torch.meshgrid(*[torch.arange(size, device=device) for size in DIMS], indexing='ij')
    grid = torch.stack(grids, dim=0)
    coordinates = grid.unsqueeze(0).expand(B, *grid.shape)

    # (nsteps, D, 1, ..., 1) broadcasts over any number of spatial dimensions
    steps = steps.to(device).reshape(nsteps, D, *[1] * len(DIMS))
    shifted_coordinates = coordinates.unsqueeze(1) + steps.unsqueeze(0)

    # clamp each coordinate axis to the valid index range, in place
    for d in range(D):
        shifted_coordinates[:, :, d].clamp_(min=0, max=DIMS[d] - 1)
    return shifted_coordinates
# this version works without padding, should ultimately replace the other one in core
# @njit
def get_neigh_inds(neighbors, coords, shape, background_reflect=False):
    """
    For L pixels and S steps, find the neighboring pixel indexes 0,1,...,L-1
    for each step. Neighbors that are not one of the L pixels get index -1.

    Parameters
    ----------
    neighbors: ndarray or sequence, int
        neighbor coordinates, <dim>x<S>x<L> (e.g. from get_neighbors).
    coords: tuple, int
        coordinates of nonzero pixels, <dim>x<L>
    shape: tuple, int
        shape of the image array
    background_reflect: bool
        if True, neighbor references to -1 are redirected to the pixel itself,
        and ``ind_matrix`` is updated at those neighbor coordinates as well.
        Not the default because later steps rely on -1 to detect background.

    Returns
    -------
    indexes: 1D array
        list of pixel indexes 0,1,...,L-1
    neigh_inds: 2D array
        SxL array corresponding to affinity graph
    ind_matrix: ND array
        indexes inserted into the ND image volume (-1 elsewhere)
    """
    neighbor_idx = tuple(neighbors)  # accept an ndarray as well as a tuple
    n_pixels = neighbor_idx[0].shape[-1]
    indexes = np.arange(n_pixels)

    ind_matrix = np.full(shape, -1, dtype=int)
    ind_matrix[tuple(coords)] = indexes
    neigh_inds = ind_matrix[neighbor_idx]

    if background_reflect:
        # (step positions, pixel positions) of every reference to background
        oob = np.nonzero(neigh_inds == -1)
        neigh_inds[oob] = indexes[oob[1]]  # point back to the pixel itself
        ind_matrix[neighbor_idx] = neigh_inds
    return indexes, neigh_inds, ind_matrix
# This might need some reflection added in for it to work
# also might need generalization to include cleaned mask pixels getting dropped
def subsample_affinity(augmented_affinity, slc, mask):
    """
    Helper function to subsample an affinity graph according to an image crop slice
    and a foreground selection mask.

    Parameters
    ----------
    augmented_affinity: NDarray, int64
        Stacked neighbor coordinate array and affinity graph. For dimension d,
        augmented_affinity[:d] are the neighbor coordinates of shape (d,3**d,npix)
        and augmented_affinity[d] is the affinity graph of shape (3**d,npix).
    slc: tuple, slice
        tuple of slices along each dimension defining the crop window
        (starts and stops must be set).
    mask: NDarray, bool
        foreground selection mask, in the image space of the original graph
        (i.e., not already sliced)

    Returns
    --------
    Augmented affinity graph corresponding to the cropped/masked region, with
    coordinates shifted into the crop frame. Neighbors falling just outside the
    crop window point back to the pixel itself. When no pixel survives, an empty
    array of shape (d+1, 3**d, 0) is returned, so the return type is always an
    ndarray (previously a ``(array, [])`` tuple was returned in that case).
    """
    nstep = augmented_affinity.shape[1]
    dim = len(slc)
    neighbors = augmented_affinity[:dim]
    affinity_graph = augmented_affinity[dim]
    # the central (zero) step of the kernel holds each pixel's own coordinates
    idx = nstep // 2
    coords = neighbors[:, idx]

    in_bounds = np.all(np.vstack([[c < s.stop, c >= s.start] for c, s in zip(coords, slc)]), axis=0)
    in_mask = mask[tuple(coords)] > 0
    inds_crop = np.nonzero(np.logical_and(in_bounds, in_mask))[0]

    if not len(inds_crop):
        return np.empty((dim + 1, nstep, 0), dtype=augmented_affinity.dtype)

    crop_neighbors = neighbors[:, :, inds_crop]  # fancy indexing: a copy, input untouched
    affinity_crop = affinity_graph[:, inds_crop]

    # shift coordinates back according to the lower bound of the slice and
    # reflect neighbors that leave the new bounding box back onto the pixel
    edges = [np.array([-1, s.stop - s.start]) for s in slc]
    steps = get_steps(dim)
    for d in range(dim):
        crop_coords = coords[d, inds_crop] - slc[d].start
        shifted = crop_coords + steps[:, d].reshape(-1, 1)
        edgemask = np.isin(shifted, edges[d])
        crop_neighbors[d] = np.where(edgemask, np.broadcast_to(crop_coords, shifted.shape), shifted)

    return np.vstack((crop_neighbors, affinity_crop[np.newaxis]))
@functools.lru_cache(maxsize=None)
def get_steps(dim):
    """
    Get a symmetrical list of all 3**N points in a hypercube represented
    by a list of all possible sequences of -1, 0, and 1 in ND.

    1D: [[-1],[0],[1]]

    2D: [[-1, -1],
         [-1,  0],
         [-1,  1],
         [ 0, -1],
         [ 0,  0],
         [ 0,  1],
         [ 1, -1],
         [ 1,  0],
         [ 1,  1]]

    The opposite pixel at index i is always found at index -(i+1). The number
    of possible face, edge, vertex, etc. connections grows exponentially with
    dimension: 3 steps in 1D, 9 steps in 2D, 27 steps in 3D, 3**N in ND.

    The result is cached per ``dim`` and the same array object is returned on
    every call, so callers must not modify it in place.
    """
    # np.indices enumerates the 3x3x...x3 grid in C order (last axis fastest),
    # the same ordering as a cartesian product of [-1, 0, 1] with itself
    grid = np.indices((3,) * dim).reshape(dim, -1).T
    return grid - 1
# @functools.lru_cache(maxsize=None)
def steps_to_indices(steps):
    """
    Group kernel steps by the kind of m-face they reach on the central n-cube.

    The "sign" of a step is its number of nonzero components: 0 for the center,
    1 for faces (cardinal directions), 2 for edges (ordinals), and so on. E.g.
    in 2D with get_steps() ordering the groups are [4] (central), [1,3,5,7]
    (cardinal), [0,2,6,8] (ordinal).

    Returns
    -------
    inds: list of 1D int arrays
        step indices of each group, one group per unique sign value.
    fact: ndarray, float
        sqrt of each unique sign value. For steps with entries in {-1, 0, 1}
        this is the Euclidean length of the steps in that group.
    sign: 1D array, int
        sign of every step.
    """
    sign = np.abs(steps).sum(axis=1)
    uniq = fastremap.unique(sign)
    inds = [np.flatnonzero(sign == value) for value in uniq]
    fact = np.sqrt(uniq)
    return inds, fact, sign
# [steps[:idx],steps[idx+1:]] can give the other steps
@functools.lru_cache(maxsize=None)
def kernel_setup(dim):
    """
    Get relevant kernel information for the hypercube of interest.
    Calls get_steps(), steps_to_indices(). Results are cached per ``dim``.

    Parameters
    ----------
    dim: int
        dimension (usually 2 or 3, but can be any positive integer)

    Returns
    -------
    A tuple ``(steps, inds, idx, fact, sign)``:

    steps: ndarray, int
        list of steps to each kernel point
        see get_steps()
    inds: list of ndarray, int
        indices of kernel points grouped by type (center, faces, edges, ...)
        see steps_to_indices()
    idx: int
        index of the central point within the step list
        this is always (3**dim)//2
    fact: ndarray, float
        list of face/edge/vertex/... distances
        see steps_to_indices()
    sign: 1D array, int
        signature distinguishing each kind of m-face via the number of steps
        see steps_to_indices()
    """
    steps = get_steps(dim)
    inds, fact, sign = steps_to_indices(steps)
    idx = inds[0][0]  # the central point is always first (it is the only step with sign 0)
    return steps, inds, idx, fact, sign
# not actually used in the code; steps_to_indices etc. are typically used instead
def cubestats(n):
    """
    Count the m-dimensional faces of an n-cube, for m = 0 ... n.

    An n-cube has 2**(n-m) * C(n, m) faces of dimension m.

    Parameters
    ----------
    n: int
        dimension of hypercube

    Returns
    -------
    List with one entry per face dimension (vertex, edge, face, ..., the cube
    itself), each giving how many faces of that dimension there are.
    E.g., a square is n=2, so cubestats returns [4, 4, 1]: four points (m=0),
    four edges (m=1), and one face (the original square, m=n=2).
    """
    return [2 ** (n - m) * math.comb(n, m) for m in range(n + 1)]
def curve_filter(im, filterWidth=1.5):
    r"""
    Calculate curvature maps of a 2D image from Gaussian-derivative filters.

    The image is convolved (``mode='nearest'``) with second derivatives of a
    Gaussian with standard deviation ``filterWidth``. The kernel is square with
    an odd side of about ``7 * filterWidth`` pixels.

    Parameters
    ----------
    im : 2D ndarray
        image to be filtered
    filterWidth : float
        standard deviation of the Gaussian

    Returns
    -------
    A tuple, in this order:

    M_ : Mean curvature of the image with negative values set to 0
    G_ : Gaussian curvature of the image with negative values set to 0
    C1_ : Principal curvature 1 of the image with negative values set to 0
    C2_ : Principal curvature 2 of the image with negative values set to 0
    M : Mean curvature of the image, -(I_xx + I_yy) / 2
    G : Gaussian curvature of the image, I_xx * I_yy - I_xy^2
    C1 : Principal curvature 1 of the image, M - sqrt(|M^2 - G|)
    C2 : Principal curvature 2 of the image, M + sqrt(|M^2 - G|)
    im_xx : \partial^2 I / \partial x^2
    im_yy : \partial^2 I / \partial y^2
    im_xy : \partial^2 I / \partial x \partial y
    """
    # odd kernel size, minor modification of round(7*sigma)
    size = np.floor(7 * filterWidth) // 2 * 2 + 1
    m = n = (size - 1.) / 2.
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    v = filterWidth ** 2
    gau = 1 / (2 * np.pi * v) * np.exp(-(x ** 2 + y ** 2) / (2. * v))
    f_xx = ((x / v) ** 2 - 1 / v) * gau
    f_yy = ((y / v) ** 2 - 1 / v) * gau
    f_xy = y * x * gau / v ** 2

    im_xx = convolve(im, f_xx, mode='nearest')
    im_yy = convolve(im, f_yy, mode='nearest')
    im_xy = convolve(im, f_xy, mode='nearest')

    G = im_xx * im_yy - im_xy ** 2   # gaussian curvature
    M = -(im_xx + im_yy) / 2         # mean curvature

    # principal curvatures; abs() guards against slightly negative discriminants
    disc = np.sqrt(np.abs(M ** 2 - G))
    C1 = M - disc
    C2 = M + disc

    # copies with negative values removed
    G_ = np.where(G < 0, 0, G)
    M_ = np.where(M < 0, 0, M)
    C1_ = np.where(C1 < 0, 0, C1)
    C2_ = np.where(C2 < 0, 0, C2)

    return M_, G_, C1_, C2_, M, G, C1, C2, im_xx, im_yy, im_xy
def rotate(V, theta, order=1, output_shape=None, center=None):
    """
    Rotate an array in the plane of its last two axes with scipy's affine_transform.

    The rotation matrix comes from ``mgen.rotation_from_angle_and_plane`` with
    angle ``pi/2 - theta`` (radians) in the plane spanned by the last two axes.
    The input point ``center`` (default: the array center) is mapped to the
    center of the output array.

    Parameters
    ----------
    V: ndarray
        array with at least two dimensions.
    theta: float
        angle in radians.
    order: int
        spline interpolation order passed to ``affine_transform``.
    output_shape: tuple, optional
        shape of the output (default: same as ``V``).
    center: array-like, optional
        rotation center in input coordinates.

    Returns
    -------
    ndarray: the rotated array.
    """
    dim = V.ndim
    # unit vectors along the last and second-to-last axes span the rotation plane
    last_axis = np.array([0] * (dim - 1) + [1])
    second_last_axis = np.array([0] * (dim - 2) + [1, 0])
    out_shape = V.shape if output_shape is None else output_shape

    rotation = mgen.rotation_from_angle_and_plane(np.pi / 2 - theta, second_last_axis, last_axis)
    inverse = np.linalg.inv(rotation)

    center_in = 0.5 * np.array(V.shape) if center is None else center
    center_out = 0.5 * np.array(out_shape)
    offset = center_in - np.dot(inverse, center_out)
    return affine_transform(V, inverse, offset=offset, order=order, output_shape=output_shape)
# make a list of all sprues
from sklearn.utils.extmath import cartesian
from scipy.ndimage import binary_hit_or_miss
import mahotas as mh
def get_spruepoints(bw):
    """
    Find "sprue" points of a binary image with hit-or-miss filters.

    For every nonzero step s of the 3**d neighborhood a kernel is built that
    requires the center pixel and the pixel at ``center - s`` to be set, and
    every pixel at a nonzero offset t with ``dot(s, t) >= 0`` to be unset.
    A pixel is reported if it matches the kernel for at least one s
    (presumably end points of thin structures such as skeletons — confirm).

    Parameters
    ----------
    bw: ND array, bool
        binary image.

    Returns
    -------
    ND array, bool: True at matching pixels.
    """
    ndim = bw.ndim
    offsets = cartesian([[-1, 0, 1]] * ndim)  # all step sequences in ND
    n_moves = np.sum(np.abs(offsets), axis=1)  # 0 only for the zero step
    center = tuple([1] * ndim)  # kernel is 3 wide in every axis
    hits = np.zeros_like(bw)

    for step in offsets[n_moves != 0]:
        alignment = np.array([np.dot(step, t) for t in offsets])
        sprue = np.zeros([3] * ndim, dtype=int)
        sprue[tuple(center - step)] = 1
        sprue[center] = 1
        miss = np.zeros([3] * ndim, dtype=int)
        for k in np.argwhere(np.logical_and(alignment >= 0, n_moves != 0)).flatten():
            miss[tuple(offsets[k] + 1)] = 1
        # mahotas pattern: 1 = must be set, 0 = must be unset, 2 = don't care
        hitmiss = 2 - 2 * miss - sprue
        # mahotas hitmiss is far faster than ndimage
        hits = hits + mh.morph.hitmiss(bw, hitmiss)
    return hits > 0
def localnormalize(im, sigma1=2, sigma2=20):
    """
    Local contrast normalization.

    The image is normalized with ``normalize99``. A local mean (Gaussian blur
    with ``sigma1``) is subtracted, and the residual is divided by a local
    standard deviation (square root of a Gaussian blur with ``sigma2`` of the
    squared residual). The result is normalized again with ``normalize99``.

    The small epsilon belongs in the denominator, where it prevents division by
    zero in perfectly flat regions. Adding it after the division, as the code
    previously did, had no effect after the final normalization.
    """
    im = normalize99(im)
    blur1 = gaussian_filter(im, sigma=sigma1)
    num = im - blur1
    blur2 = gaussian_filter(num * num, sigma=sigma2)
    den = np.sqrt(blur2)
    return normalize99(num / (den + 1e-8))
import torchvision.transforms.functional as TF
def localnormalize_GPU(im, sigma1=2, sigma2=20):
    """
    Torch version of ``localnormalize`` using torchvision's ``gaussian_blur``.

    ``im`` is presumably a tensor of shape [..., H, W], as required by
    ``torchvision.transforms.functional.gaussian_blur`` (confirm with callers).
    Kernel sizes are ``round(6 * sigma)``, bumped to the next odd number.
    The epsilon is added to the denominator to avoid division by zero in
    flat regions.
    """
    im = normalize99(im)
    kernel_size1 = round(sigma1 * 6)
    kernel_size1 += kernel_size1 % 2 == 0  # gaussian_blur needs an odd kernel size
    blur1 = TF.gaussian_blur(im, kernel_size1, sigma1)
    num = im - blur1
    kernel_size2 = round(sigma2 * 6)
    kernel_size2 += kernel_size2 % 2 == 0
    blur2 = TF.gaussian_blur(num * num, kernel_size2, sigma2)
    den = torch.sqrt(blur2)
    return normalize99(num / (den + 1e-8))
# from https://stackoverflow.com/questions/47370718/indexing-numpy-array-by-a-numpy-array-of-coordinates
def ravel_index(b, shp):
    """
    Convert ND coordinates into flat (C-order) indices for an array of shape ``shp``.

    ``b`` holds one coordinate per axis along its first dimension, with shape
    (ndim,) or (ndim, N). The result is the dot product of the row-major strides
    of ``shp`` with ``b``. No bounds checking is done.
    """
    trailing = np.asarray(shp[1:])[::-1]
    strides = np.append(np.cumprod(trailing)[::-1], 1)
    return strides.dot(b)
# https://stackoverflow.com/questions/31544129/extract-separate-non-zero-blocks-from-array
def find_nonzero_runs(a):
    """
    Return [start, stop) index pairs of consecutive nonzero runs in a 1D sequence.

    The result is an (n_runs, 2) integer array; ``a[start:stop]`` is entirely
    nonzero for each row.
    """
    flags = (np.asarray(a) != 0).view(np.int8)
    # zero sentinels on both ends give every run exactly one rise and one fall
    bounded = np.concatenate(([0], flags, [0]))
    transitions = np.flatnonzero(np.diff(bounded))
    return transitions.reshape(-1, 2)
# @njit
# def remap_pairs(pairs: set[tuple[int, int]], mapping: dict[int, int]) -> set[tuple[int, int]]:
# remapped_pairs = set()
# for x, y in pairs:
# remapped_x = mapping.get(x, x)
# remapped_y = mapping.get(y, y)
# remapped_pairs.add((remapped_x, remapped_y))
# return remapped_pairs
# from numba import jit
# @jit(nopython=True)
# def remap_pairs(pairs, mapping):
# remapped_pairs = set()
# for x, y in pairs:
# remapped_x = mapping.get(x, x)
# remapped_y = mapping.get(y, y)
# remapped_pairs.add((remapped_x, remapped_y))
# return remapped_pairs
from numba import njit
from numba import types
@njit
def remap_pairs(pairs, replacements):
    """
    Apply label replacement rules to both members of every pair.

    Parameters
    ----------
    pairs: iterable of 2-tuples
        pairs (x, y) to remap.
    replacements: iterable of 2-tuples
        rules (a, b) meaning "replace a with b".

    Returns
    -------
    set of remapped (x, y) tuples; pairs that become identical collapse into one.

    Rules are applied in the given order and each rule sees the result of the
    previous ones, so chained rules such as (1, 2), (2, 3) map 1 all the way to 3.
    Compiled with numba's ``njit``, so inputs must be numba-typable
    (presumably lists of integer tuples — confirm against callers).
    """
    remapped_pairs = set()
    for x, y in pairs:
        for a, b in replacements:
            # sequential application: later rules act on already-replaced values
            if x == a:
                x = b
            if y == a:
                y = b
        remapped_pairs.add((x, y))
    return remapped_pairs
# need to comment out any njit code that I do not use...
@njit
def add_gaussian_noise(image, mean=0, var=0.01):
    """
    Add Gaussian noise with the given mean and variance, then clip to [0, 1].

    Compiled with numba's ``njit``; the noise has the same shape as ``image``
    and a standard deviation of ``sqrt(var)``.
    """
    std = var ** 0.5
    noise = np.random.normal(mean, std, image.shape)
    return np.clip(image + noise, 0, 1)
def add_poisson_noise(image, peak=None):
    """
    Add Poisson (shot) noise to an image and clip the result to [0, 1].

    Parameters
    ----------
    image: ndarray, float
        input image, expected in [0, 1].
    peak: float or None
        expected photon count of a pixel with value 1. If None (default, the
        original behavior), ``image`` itself is used as the Poisson rate. For
        inputs in [0, 1] this yields integer counts that become mostly 0/1 after
        clipping. If given, counts are drawn with rate ``image * peak`` and
        divided by ``peak``, which keeps the input's scale. Larger ``peak``
        values give less noise.

    Returns
    -------
    ndarray clipped to [0, 1].
    """
    if peak is None:
        noisy_image = np.random.poisson(image)
    else:
        noisy_image = np.random.poisson(image * peak) / peak
    return np.clip(noisy_image, 0, 1)
from scipy.ndimage import gaussian_filter, convolve
def thin_skeleton(image):
    """
    Strip interior pixels from a binary image.

    Each pass removes every foreground pixel whose full 3x3(x3...) neighborhood
    is foreground (pixels outside the array count as background). Passes repeat
    until no such pixel remains. Because all interior pixels go in one pass, the
    result is the one-pixel-wide inner outline of each object (pixels that touch
    background or the array edge), not a medial-axis skeleton.

    Parameters
    ----------
    image: ndarray, bool or int
        binary image of any dimension.

    Returns
    -------
    ndarray: ``image`` unchanged if nothing was removed, otherwise a boolean array.
    """
    dimensions = image.ndim
    neighbors = np.ones((3,) * dimensions, dtype=bool)
    neighbors[(1,) * dimensions] = False
    n_neighbors = np.sum(neighbors)
    kernel = neighbors.astype(np.int32)

    while True:
        # count foreground neighbors in an integer dtype so the count can reach
        # n_neighbors even when `image` is boolean
        counts = convolve(image.astype(np.int32), kernel, mode='constant')
        # only foreground pixels are candidates: a background pixel enclosed by
        # foreground (a one-pixel hole) would otherwise be marked forever and
        # the loop would never terminate
        marker = np.logical_and(image, counts == n_neighbors)
        if not marker.any():
            break
        image = np.logical_and(image, np.logical_not(marker))
    return image
def save_nested_list(file_path, nested_list):
    """
    Helper function to save affinity graphs.

    Every element of ``nested_list`` is written, in order, to a compressed
    ``.npz`` archive with ``np.savez_compressed`` under the default keys
    ``arr_0``, ``arr_1``, ...
    """
    arrays = [*nested_list]
    np.savez_compressed(file_path, *arrays)
def load_nested_list(file_path):
    """
    Helper function to load affinity graphs saved with ``save_nested_list``.

    Returns the arrays of the ``.npz`` archive as a list, in the archive's key
    order. The archive is closed before returning.

    ``allow_pickle=True`` is needed for object arrays (e.g. ragged lists), but it
    means loading can execute arbitrary code embedded in the file: only use this
    on trusted files.
    """
    # NpzFile keeps the archive open; reading every entry inside the context
    # manager loads the arrays into memory before the file is closed
    with np.load(file_path, allow_pickle=True) as loaded_data:
        return [loaded_data[key] for key in loaded_data.keys()]
import torch.nn.functional as F
def hysteresis_threshold(image, low, high):
    """
    Pytorch implementation of skimage.filters.apply_hysteresis_threshold().

    Pixels above ``high`` seed the result. The result then grows, one
    neighborhood step at a time, into connected pixels above ``low`` until it
    stops changing. The neighborhood kernel is a full 3x3 (or 3x3x3) block of
    ones, so diagonal neighbors count as connected.
    Discrepancies occur for very high thresholds/thin objects.

    Parameters
    ----------
    image: torch.Tensor or array-like
        shape (B, 1, H, W) or (B, 1, D, H, W); converted with ``torch.tensor``
        if it is not already a tensor. Any numeric dtype is accepted.
    low, high: float
        lower and upper thresholds (strict ``>`` comparisons).

    Returns
    -------
    torch.Tensor, bool, same shape as ``image``.

    Raises
    ------
    ValueError: if the number of spatial dimensions is not 2 or 3.
    """
    # Ensure the image is a torch tensor
    if not isinstance(image, torch.Tensor):
        image = torch.tensor(image)

    mask_low = image > low
    mask_high = image > high

    spatial_dims = image.ndim - 2
    if spatial_dims == 2:
        conv = F.conv2d
    elif spatial_dims == 3:
        conv = F.conv3d
    else:
        raise ValueError(f'Unsupported number of spatial dimensions: {spatial_dims}')

    # float32 kernel to match the float32 mask fed to the convolution,
    # independent of the image dtype (float64/int images would otherwise fail)
    kernel = torch.ones([1, 1] + [3] * spatial_dims, device=image.device, dtype=torch.float32)

    # start from the high-threshold seeds; starting from mask_low made the
    # iteration a no-op that simply returned mask_low
    thresholded = mask_high.clone()
    while True:
        neighborhood = conv(thresholded.float(), kernel, padding=1)
        grown = ((neighborhood > 0) & mask_low) | mask_high
        if torch.equal(grown, thresholded):
            break
        thresholded = grown
    return thresholded.bool()
def correct_illumination(img, sigma=5):
    """
    Flatten uneven illumination by subtracting a Gaussian-blurred background.

    Parameters
    ----------
    img: ndarray
        input image. Non-floating inputs are converted to float64 first, so the
        blur is not truncated to integers and the subtraction cannot wrap around.
    sigma: float
        standard deviation of the Gaussian used as background estimate.

    Returns
    -------
    ndarray, float: ``(img - blurred) / std(blurred)``. If the blurred image is
    perfectly flat (std of 0), the difference is returned unscaled instead of
    dividing by zero.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    # Apply a Gaussian blur to the image
    blurred = gaussian_filter(img, sigma=sigma)
    # Normalize the image
    scale = np.std(blurred)
    if scale == 0:
        return img - blurred
    return (img - blurred) / scale