from . import core, utils
from numba import njit
import matplotlib as mpl
import matplotlib.pyplot as plt
import types
import numpy as np
from matplotlib.backend_bases import GraphicsContextBase, RendererBase
from matplotlib.collections import LineCollection
from matplotlib.patches import Rectangle
from matplotlib.transforms import Bbox
from matplotlib.path import Path
class GC(GraphicsContextBase):
    """Graphics context that draws lines with round caps.

    ``custom_new_gc`` returns instances of this class, and ``plot_edges``
    installs that function as ``RendererBase.new_gc`` so edge segments are
    drawn with round caps.
    """

    def __init__(self):
        super().__init__()
        # Go through the public setter instead of assigning ``self._capstyle``
        # directly. Per the Matplotlib GraphicsContextBase API, newer versions
        # store a CapStyle enum there (and ``get_capstyle`` reads ``.name``),
        # so a bare string can break rendering; ``set_capstyle`` stores the
        # correct type on both old and new versions.
        self.set_capstyle('round')
def custom_new_gc(self):
    """Stand-in for ``RendererBase.new_gc`` that yields a round-capped ``GC``.

    ``self`` exists only so this function can be bound as a method; it is
    not used.
    """
    graphics_context = GC()
    return graphics_context
def plot_edges(shape, affinity_graph, neighbors, coords,
               figsize=1, fig=None, ax=None, extent=None, slc=None, pic=None,
               edgecol=[.75]*3+[.5], linewidth=0.15, step_inds=None,
               cmap='inferno', origin='lower', bounds=None):
    """Draw the edges of an affinity graph as line segments over an image.

    Parameters
    ----------
    shape : tuple of int
        Shape of the image grid that ``coords`` index into.
    affinity_graph : ndarray
        2-D array unpacked as ``(nstep, npix)``. It is cast to bool before
        edges are extracted with ``core.affinity_to_edges``, and summed over
        steps to build the default background.
    neighbors : array-like
        Neighbor definition, passed (as a tuple) to ``utils.get_neigh_inds``.
    coords : sequence of ndarray
        One index array per image axis (used directly as a fancy index).
    figsize : float or (float, float), optional
        Size of a newly created figure; a scalar gives a square figure.
    fig, ax : optional
        Existing figure / axes to draw into. A new figure is created (and
        shown) only when both are ``None``; if only ``fig`` is given, its
        current axes are used.
    extent : array-like, optional
        ``imshow`` extent; defaults to ``[0, shape[1], 0, shape[0]]``.
    slc : slice or tuple of slices, optional
        Crop applied to the background image before display.
    pic : ndarray, optional
        Background image. When omitted, the per-pixel sum of
        ``affinity_graph`` over steps is color-mapped and shown.
    edgecol : color, optional
        Color of the edge segments.
    linewidth : float, optional
        Width of the edge segments.
    step_inds : array-like, optional
        Steps of ``affinity_graph`` to draw; defaults to all of them.
    cmap : str, optional
        Colormap name for the computed background.
    origin : str, optional
        ``imshow`` origin.
    bounds : optional
        Unused; kept for API compatibility.

    Returns
    -------
    tuple or None
        ``(summed_affinity, affinity_cmap)`` when ``pic`` was not given,
        otherwise ``None``.
    """
    nstep, npix = affinity_graph.shape
    coords = tuple(coords)
    _, neigh_inds, _ = utils.get_neigh_inds(tuple(neighbors), coords, shape)
    if step_inds is None:
        step_inds = np.arange(nstep)
    px_inds = np.arange(npix)
    edge_list = core.affinity_to_edges(affinity_graph.astype(bool),
                                       neigh_inds,
                                       step_inds,
                                       px_inds)

    # (npix, ndim) coordinates with the axis order reversed for plotting;
    # +0.5 shifts endpoints to pixel centres for the default extent.
    flipped = np.array(coords).T[:, ::-1]
    if len(edge_list):
        segments = np.stack([[flipped[e] + 0.5 for e in edge] for edge in edge_list])
    else:
        # np.stack raises on an empty sequence; an empty collection draws nothing.
        segments = np.zeros((0, 2, flipped.shape[1]))

    # Make renderers hand out GC instances (round line caps).
    # NOTE: this patches RendererBase process-wide and persists after the call.
    RendererBase.new_gc = types.MethodType(custom_new_gc, RendererBase)

    newfig = fig is None and ax is None
    if newfig:
        # isinstance covers both list and tuple; any other value is treated
        # as a scalar and expanded to a square figure.
        if not isinstance(figsize, (list, tuple)):
            figsize = (figsize, figsize)
        fig, ax = plt.subplots(figsize=figsize)
    elif ax is None:
        ax = fig.gca()

    if extent is None:
        extent = np.array([0, shape[1], 0, shape[0]])

    nopic = pic is None
    if nopic:
        summed_affinity = np.zeros(shape, dtype=int)
        summed_affinity[coords] = np.sum(affinity_graph, axis=0)
        colors = mpl.colormaps.get_cmap(cmap).reversed()(np.linspace(0, 1, 9))
        # Prepend transparent black as the first lookup-table entry.
        colors = np.vstack((np.array([0]*4), colors))
        affinity_cmap = mpl.colors.ListedColormap(colors)
        pic = affinity_cmap(summed_affinity)

    ax.imshow(pic[slc] if slc is not None else pic, extent=extent, origin=origin)
    line_segments = LineCollection(segments, color=edgecol, linewidths=linewidth)
    ax.add_collection(line_segments)

    if newfig:
        ax.axis('off')
        ax.invert_yaxis()
        plt.show()

    if nopic:
        return summed_affinity, affinity_cmap
def sinebow(N, bg_color=None, offset=0):
    """Generate a color dictionary for visualizing N-colored labels.

    Colors are sampled at evenly spaced phases of a "sinebow": red, green and
    blue are raised cosines offset from each other by a third of a turn.

    Parameters
    ----------
    N : int
        Number of distinct colors to generate (excluding background).
    bg_color : ndarray, list, or tuple of length 4, optional
        RGBA value stored under key 0. Defaults to transparent black
        ``[0, 0, 0, 0]``; a new list is created on every call, so mutating
        the returned background entry does not affect later calls.
    offset : number, optional
        Phase shift in units of ``1/N`` of a turn: color ``j + 1`` uses the
        angle ``(j + offset) * 2 * pi / N``.

    Returns
    -------
    dict
        ``{0: bg_color, 1: [r, g, b, 1], ..., N: [r, g, b, 1]}`` mapping
        integer labels to RGBA colors.
    """
    if bg_color is None:
        # Fresh list per call: a shared list default would be mutable state.
        bg_color = [0, 0, 0, 0]
    colordict = {0: bg_color}
    for j in range(N):
        angle = (j + offset) * 2 * np.pi / N
        r = (np.cos(angle) + 1) / 2
        g = (np.cos(angle + 2 * np.pi / 3) + 1) / 2
        b = (np.cos(angle + 4 * np.pi / 3) + 1) / 2
        colordict[j + 1] = [r, g, b, 1]
    return colordict
# @njit
# def colorize(im,colors=None,color_weights=None,offset=0):
# N = len(im)
# if colors is None:
# angle = np.arange(0,1,1/N)*2*np.pi+offset
# angles = np.stack((angle,angle+2*np.pi/3,angle+4*np.pi/3),axis=-1)
# colors = (np.cos(angles)+1)/2
# if color_weights is not None:
# colors *= color_weights
# rgb = np.zeros((im.shape[1], im.shape[2], 3))
# for i in range(N):
# for j in range(3):
# rgb[..., j] += im[i] * colors[i, j]
# rgb /= N
# return rgb
# @njit
def colorize(im, colors=None, color_weights=None, offset=0, channel_axis=-1):
    """Blend a stack of channels into one color image.

    Each slice ``im[i]`` is multiplied by its color ``colors[i]`` and the
    results are averaged over the stack axis.

    Parameters
    ----------
    im : array-like
        Stack with the channel/label axis first, shape ``(N, *spatial)``.
    colors : array-like, optional
        ``(N, C)`` colors, one row per slice. Defaults to ``N`` evenly spaced
        sinebow RGB colors (``C == 3``).
    color_weights : array-like, optional
        Length-``N`` scale factors applied to the rows of ``colors``. The
        caller's ``colors`` array is never modified.
    offset : float, optional
        Phase offset in radians added to the default colors.
    channel_axis : int, optional
        NOTE(review): currently ignored; the result is always channel-last.
        Kept for API compatibility.

    Returns
    -------
    ndarray
        Array of shape ``(*spatial, C)``.
    """
    im = np.asarray(im)
    N = im.shape[0]
    if colors is None:
        # np.arange(N) / N yields exactly N phases; a float step
        # (np.arange(0, 1, 1/N)) can round to N + 1 values.
        angle = np.arange(N) / N * 2 * np.pi + offset
        angles = np.stack((angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3), axis=-1)
        colors = (np.cos(angles) + 1) / 2
    else:
        colors = np.asarray(colors)
    if color_weights is not None:
        # Out-of-place so caller-supplied colors are left untouched.
        colors = colors * np.expand_dims(np.asarray(color_weights), -1)
    # Broadcast (N, *spatial, 1) against (N, 1, ..., 1, C), then average over N;
    # works for any number of spatial dimensions.
    color_shape = (colors.shape[0],) + (1,) * (im.ndim - 1) + (colors.shape[1],)
    return (np.expand_dims(im, axis=-1) * colors.reshape(color_shape)).mean(axis=0)
import torch
def colorize_GPU(im, colors=None, color_weights=None, offset=0, channel_axis=-1):
    """Torch counterpart of ``colorize``: blend a channel stack into colors.

    Parameters
    ----------
    im : torch.Tensor
        Stack with the channel/label axis first, shape ``(N, *spatial)``.
    colors : torch.Tensor, optional
        ``(N, C)`` colors. Defaults to ``N`` evenly spaced sinebow RGB colors
        created on ``im.device``.
    color_weights : torch.Tensor, optional
        Length-``N`` scale factors applied to the rows of ``colors``; the
        caller's ``colors`` tensor is never modified.
    offset : float, optional
        Phase offset in radians added to the default colors.
    channel_axis : int, optional
        NOTE(review): currently ignored; the result is always channel-last.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(*spatial, C)``.
    """
    N = im.shape[0]
    device = im.device
    if colors is None:
        # arange(N) / N excludes the endpoint. linspace(0, 1, N) would include
        # both 0 and 1 (the same angle), giving the first and last channels
        # identical colors.
        angle = torch.arange(N, device=device) / N * 2 * np.pi + offset
        angles = torch.stack((angle, angle + 2 * np.pi / 3, angle + 4 * np.pi / 3), dim=-1)
        colors = (torch.cos(angles) + 1) / 2
    if color_weights is not None:
        # Out-of-place so caller-supplied colors are left untouched.
        colors = colors * color_weights.unsqueeze(-1)
    # Broadcast (N, *spatial, 1) against (N, 1, ..., 1, C), then average over N.
    color_shape = (colors.shape[0],) + (1,) * (im.dim() - 1) + (colors.shape[1],)
    return (im.unsqueeze(-1) * colors.reshape(color_shape)).mean(dim=0)
import ncolor
def apply_ncolor(masks, offset=0, cmap=None, max_depth=20, expand=True):
    """Relabel ``masks`` with ``ncolor.label`` and convert the result to colors.

    Parameters
    ----------
    masks : ndarray
        Label image passed to ``ncolor.label`` (called with ``conn=2``).
        Presumably ncolor reuses a small set of labels so that neighboring
        objects differ — see the ncolor documentation.
    offset : number, optional
        Phase offset forwarded to ``sinebow`` for the default palette.
    cmap : callable, optional
        Colormap applied to ``utils.rescale`` of the relabeled image. When
        omitted, a ``ListedColormap`` built from ``sinebow`` (background
        entry first) is applied to the raw relabeled integers.
    max_depth : int, optional
        Forwarded to ``ncolor.label``.
    expand : bool, optional
        Forwarded to ``ncolor.label``.

    Returns
    -------
    ndarray
        Output of the colormap applied to the relabeled image.
    """
    relabeled, n_colors = ncolor.label(masks,
                                       max_depth=max_depth,
                                       return_n=True,
                                       conn=2,
                                       expand=expand)
    if cmap is not None:
        return cmap(utils.rescale(relabeled))
    palette = np.array(list(sinebow(n_colors, offset=offset).values()))
    return mpl.colors.ListedColormap(palette)(relabeled)
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.pyplot as plt
def imshow(imgs, figsize=2, ax=None, hold=False, titles=None, title_size=None, spacing=0.05,
           textcolor=[0.5]*3, dpi=300, **kwargs):
    """Display one image, or a row of images, with axes hidden.

    Parameters
    ----------
    imgs : ndarray or list of ndarray
        A single image, or a list shown side by side on an ``ImageGrid``.
    figsize : float or (float, float)
        Single image: a scalar gives a square figure. List: a scalar used as
        the figure height and as the width per panel.
    ax : matplotlib Axes, optional
        Axes for a single image; when given, ``plt.show()`` is not called.
    hold : bool
        If True, do not call ``plt.show()``.
    titles : str or list of str, optional
        Title for the single image, or one title per image in list mode.
    title_size : float, optional
        Title font size; derived from ``figsize`` when omitted.
    spacing : float
        Padding between panels in list mode (``ImageGrid`` ``axes_pad``).
    textcolor : color
        Title color.
    dpi : int
        dpi of a newly created single-image figure.
        NOTE(review): not applied in list mode.
    **kwargs
        Forwarded to ``Axes.imshow``.
    """
    text_scale = 10
    if isinstance(imgs, list):
        n_imgs = len(imgs)
        if titles is None:
            titles = [None] * n_imgs
        if title_size is None:
            title_size = figsize / n_imgs * text_scale
        fig = plt.figure(figsize=(figsize * n_imgs, figsize), frameon=False, facecolor=[0]*4)
        grid = ImageGrid(fig, 111, nrows_ncols=(1, n_imgs), axes_pad=spacing, share_all=True)
        for ax, img, title in zip(grid, imgs, titles):
            ax.imshow(img, **kwargs)
            ax.axis("off")
            ax.set_frame_on(False)
            ax.set_facecolor([0]*4)
            if title is not None:
                ax.set_title(title, fontsize=title_size, color=textcolor)
    else:
        # isinstance covers both list and tuple; ``(list or tuple)`` evaluates
        # to just ``list``, which would re-wrap a tuple figsize.
        if not isinstance(figsize, (list, tuple)):
            figsize = (figsize, figsize)
        if title_size is None:
            title_size = figsize[0] * text_scale
        if ax is None:
            fig, ax = plt.subplots(frameon=False, figsize=figsize, facecolor=[0]*4, dpi=dpi)
        else:
            # Drawing into a caller's axes: leave showing to the caller.
            hold = True
        ax.imshow(imgs, **kwargs)
        ax.axis("off")
        if titles is not None:
            ax.set_title(titles, fontsize=title_size, color=textcolor)
    if not hold:
        plt.show()
# def get_cmap(masks):
# lut = ncolor.get_lut(masks)
# c = sinebow(lut.max())
# colors = [c[l] for l in lut]
# cmap = mpl.colors.ListedColormap(colors)
# return cmap
# @njit()
# def rgb_flow(dP,transparency=False,mask=None,norm=False):
# """ dP is 2 x Y x X => 'optic' flow representation
# Parameters
# -------------
# dP: NDarray, float
# Flow field component stack [B,dy,dx]
# transparency: bool, default False
# magnitude of flow controls opacity, not lightness (clear background)
# mask: 2D array
# Multiplies each RGB component to suppress noise
# """
# mag = np.sqrt(np.sum(dP**2,axis=1))
# if norm:
# mag = np.clip(utils.normalize99(mag), 0, 1.).astype(np.float32)
# angles = np.arctan2(dP[:,1], dP[:,0])+np.pi
# a = 2
# r = ((np.cos(angles)+1)/a)
# g = ((np.cos(angles+2*np.pi/3)+1)/a)
# b = ((np.cos(angles+4*np.pi/3)+1)/a)
# if transparency:
# im = np.stack((r,g,b,mag),axis=-1)
# else:
# im = np.stack((r*mag,g*mag,b*mag),axis=-1)
# if mask is not None and transparency and dP.shape[0]<3:
# im[...,-1] *= mask
# im = (np.clip(im, 0, 1) * 255).astype(np.uint8)
# return im
# from numba import jit
# @jit(nopython=True)
# @njit()
# def rgb_flow(dP, transparency=True, mask=None, norm=True):
# mag = np.sqrt(np.sum(dP**2,axis=1)).reshape(1, -1)
# vecs = dP[:,0] + dP[:,1]*1j
# roots = np.exp(1j * np.pi * (2 * np.arange(3) / 3 +1))
# rgb = (np.real(roots * vecs.reshape(-1, 1) / np.max(mag)).T + 1 ) / 2
# if norm:
# # mag = np.clip(utils.normalize99(mag), 0, 1.).astype(np.float32)
# mag -= np.min(mag)
# mag /= np.max(mag)
# shape = dP.shape
# newshape = (shape[0], shape[3], shape[2], 3+transparency)
# # newshape = (shape[0], shape[2], shape[3], 3+transparency)
# if transparency:
# im = np.concatenate((rgb, mag), axis=0)
# else:
# im = rgb * mag
# im = (np.clip(im.T.reshape(newshape), 0, 1) * 255).astype(np.uint8)
# # im = np.swapaxes(im,1,2)
# return im
# @njit()
# def rgb_flow(dP, transparency=True, mask=None, norm=True):
# mag = np.sqrt(np.sum(dP**2,axis=1))
# vecs = dP[:,0] + dP[:,1]*1j
# roots = np.exp(1j * np.pi * (2 * np.arange(3) / 3 +1)).reshape((1, 1, 1, -1))
# rgb = (np.real(vecs[...,None]*roots / np.max(mag)) + 1 ) / 2
# if norm:
# mag -= np.min(mag)
# mag /= np.max(mag)
# shape = dP.shape
# newshape = (shape[0], shape[2], shape[3], 3+transparency)
# print(rgb.shape,newshape, mag.shape, vecs.shape)
# if transparency:
# im = np.empty(newshape)
# im[..., :3] = rgb
# im[..., 3] = mag
# else:
# im = rgb * mag
# im = (np.clip(im, 0, 1) * 255).astype(np.uint8)
# return im
import torch
[docs]def rgb_flow(dP, transparency=True, mask=None, norm=True, device=torch.device('cpu')):
"""Meant for stacks of dP, unsqueeze if using on a single plane."""
if isinstance(dP,torch.Tensor):
device = dP.device
else:
dP = torch.from_numpy(dP).to(device)
mag = utils.torch_norm(dP,dim=1)
vecs = dP[:,0] + dP[:,1]*1j
roots = torch.exp(1j * np.pi * (2 * torch.arange(3, device=device) / 3 +1))
rgb = (torch.real(vecs.unsqueeze(-1)*roots.view(1, 1, 1, -1) / torch.max(mag)) + 1 ) / 2
# f = 1.5
# rgb /= f
# rgb += (1-1/f)/2
if norm:
mag -= torch.min(mag)
mag /= torch.max(mag)
if transparency:
im = torch.cat((rgb, mag[..., None]), dim=-1)
else:
im = rgb * mag[..., None]
im = (torch.clamp(im, 0, 1) * 255).type(torch.uint8)
return im
from skimage import img_as_ubyte
from skimage import color
from skimage.segmentation import find_boundaries
def create_colormap(image, labels):
    """
    Create a colormap based on the average color of each label in the image.

    Parameters
    ----------
    image: ndarray
        Image whose leading dimensions equal ``labels.shape``: either
        single-channel (same shape as ``labels``) or with a trailing channel
        axis (e.g. RGB or RGBA).
    labels: ndarray
        Non-negative integer labels corresponding to the image.

    Returns
    -------
    colormap: ndarray
        uint8 array of shape ``(labels.max() + 1, C)``. Row ``k`` is the mean
        color of label ``k`` (truncated to an integer); rows for labels that
        do not occur are zero. ``C`` is the number of image channels, or 3
        for a single-channel image (the gray value is replicated).

    Raises
    ------
    IndexError
        If the image's leading dimensions do not match ``labels``.
    """
    # Ensure the image is in the range 0-255
    image = img_as_ubyte(image)
    if image.shape[:labels.ndim] != labels.shape:
        raise IndexError(f"image shape {image.shape} does not match labels shape {labels.shape}")
    # int() avoids overflow of e.g. uint8 255 + 1.
    n_labels = int(labels.max()) + 1
    flat_labels = labels.ravel().astype(np.intp)
    # One row per pixel, one column per channel.
    pixels = image.reshape(flat_labels.size, -1)
    n_channels = pixels.shape[1]
    # Per-label pixel counts and per-channel sums in one pass each, instead of
    # a full-size boolean mask per label (O(pixels) rather than
    # O(labels * pixels)). Sums of uint8 values are exact in float64, so the
    # means equal those from ``image[mask].mean(axis=0)``.
    counts = np.bincount(flat_labels, minlength=n_labels)
    sums = np.stack([np.bincount(flat_labels, weights=pixels[:, c], minlength=n_labels)
                     for c in range(n_channels)], axis=-1)
    colormap = np.zeros((n_labels, 3 if n_channels == 1 else n_channels), dtype=np.uint8)
    present = counts > 0
    # Assigning float means into the uint8 array truncates toward zero.
    colormap[present] = sums[present] / counts[present, None]
    return colormap
def color_from_RGB(im, rgb, m, bd=None, mode='inner', connectivity=2):
    """Overlay the labels of ``m`` on ``im``, colored by their mean color in ``rgb``.

    Parameters
    ----------
    im : ndarray
        Image passed to ``skimage.color.label2rgb`` as the base image.
    rgb : ndarray
        Image from which each label's mean color is computed.
    m : ndarray
        Integer label image; label 0 is treated as background.
    bd : ndarray of bool, optional
        Boundary mask; computed with ``find_boundaries`` when omitted.
    mode, connectivity :
        Forwarded to ``find_boundaries``.

    Returns
    -------
    ndarray
        The ``label2rgb`` result, using alpha 0.5 inside labels, 1 on
        boundaries and 0 elsewhere.
    """
    if bd is None:
        bd = find_boundaries(m, mode=mode, connectivity=connectivity)
    # Per-pixel opacity, replicated across three channels.
    opacity = np.where(m > 0, 0.5, 0.0)
    opacity[bd] = 1
    opacity = np.repeat(opacity[..., np.newaxis], 3, axis=-1)

    relabeled = ncolor.format_labels(m)
    # Drop the background row, then rescale the label colors.
    label_colors = utils.rescale(create_colormap(rgb, relabeled)[1:])
    return color.label2rgb(relabeled, im, label_colors, bg_label=0, alpha=opacity)
def image_grid(images, column_titles=None, row_titles=None,
               xticks=[], yticks=[],
               outline=False, outline_color=[0.5]*3,
               padding=0.05,
               fontsize=10, fontcolor=[0.5]*3,
               fig_scale=6, dpi=300,
               order='ij',
               **kwargs):
    """Display a grid of images with uniform spacing.

    Parameters
    ----------
    images : list of lists of ndarray
        ``images[j][k]`` is drawn in cell ``(j, k)``. Cell proportions come
        from the shape of the first image in each row.
    column_titles, row_titles : list of str, optional
        Text placed above the first row / left of the first column.
    xticks, yticks : list
        Tick positions set on every axes (none by default).
    outline : bool
        Draw a 1-pt frame in ``outline_color`` instead of hiding the axes.
    padding : float
        Spacing between images, in units of the normalized image size.
    fontsize, fontcolor :
        Title text style.
    fig_scale : float
        Figure width in inches; the height follows the grid's aspect ratio.
    dpi : int
        Figure dpi.
    order : {'ij', 'ji'}
        Grid layout order.
    **kwargs
        Forwarded to ``Axes.imshow``.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``order`` is not ``'ij'`` or ``'ji'``.
    """
    # Validate first, so a bad order raises ValueError (not a KeyError in the
    # layout lookup) and no figure is created.
    if order not in ('ij', 'ji'):
        raise ValueError('order must be "ij" or "ji"')
    ij = order == 'ij'
    ji = order == 'ji'

    # get the dimensions of the grid
    grid_dims = [len(images), len(images[0])]
    n, m = grid_dims
    # Get the shapes of the images
    image_shapes = np.stack([i[0].shape for i in images])
    # Padding between images
    p = padding
    # normalize dimension along row or column
    a = list(image_shapes[:, 0] / image_shapes[:, 1]) if ij else list(image_shapes[:, 1] / image_shapes[:, 0])
    b = np.ones_like(a)
    # Cumulative dimension
    ca = np.cumsum(a)
    start_a = np.array([[0]*m] + [[(ca[i] + (i+1)*p)]*m for i in range(n-1)]).flatten().astype(float)
    start_b = np.array([[(bi + p)*i for i in range(m)] for bi in b]).flatten().astype(float)
    # Calculate the positions and sizes of the images in the grid
    da = np.array([[ai]*m for ai in a]).flatten().astype(float)
    db = np.array([[bi]*m for bi in b]).flatten().astype(float)
    # Map the variables to their values
    variables = {'ji': (start_a, start_b, da, db), 'ij': (start_b, start_a, db, da)}
    left, bottom, width, height = variables[order]
    # Normalize the positions and sizes to figure fractions
    max_w = left[-1] + width[-1]
    max_h = bottom[-1] + height[-1]
    left /= max_w
    bottom /= max_h
    width /= max_w
    height /= max_h
    # Create the figure
    fig = plt.figure(figsize=(fig_scale, fig_scale*max_h/max_w), frameon=False, dpi=dpi)
    # here m and n need to represent the actual grid layout rather than indexing
    if ij:
        n, m = grid_dims
    else:
        m, n = grid_dims
    # Add the subplots; ``bottom`` is measured from the top, so flip it for
    # Matplotlib's bottom-up figure coordinates.
    axes = []
    for i in range(n*m):
        ax = fig.add_axes([left[i], 1 - bottom[i] - height[i], width[i], height[i]])
        axes.append(ax)
    # add outline around each image, remove ticks
    for i, ax in enumerate(axes):
        if outline:
            for s in ax.spines.values():
                s.set_color(outline_color)
                s.set_linewidth(1)
        else:
            ax.axis('off')
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.patch.set_alpha(0)
        # Display the image
        j, k = np.unravel_index(i, grid_dims)
        ax.imshow(images[j][k], **kwargs)
        # Set the column titles
        if column_titles is not None:
            if ij and i < m:
                idx = i
            elif ji and i % n == 0:
                idx = i // n
            else:
                idx = None
            if idx is not None:
                ax.text(0.5, 1+p, column_titles[idx], rotation=0, fontsize=fontsize, color=fontcolor,
                        va='bottom', ha='center', transform=ax.transAxes)
        # Set the row titles
        if row_titles is not None:
            if ij and i % m == 0:
                idx = i // m
            elif ji and i < n:
                idx = i
            else:
                idx = None
            if idx is not None:
                ax.text(-p, 0.5, row_titles[idx], rotation=0, fontsize=fontsize, color=fontcolor,
                        va='center', ha='right', transform=ax.transAxes)
    return fig
# from https://stackoverflow.com/a/63530703/13326811
import numpy as np
from matplotlib.collections import LineCollection as lc
from mpl_toolkits.mplot3d.art3d import Line3DCollection as lc3d
from scipy.interpolate import interp1d
from matplotlib.colors import colorConverter
def colored_line_segments(xs, ys, zs=None, color='k', mid_colors=False):
    """Split a polyline into one segment per point, each with its own color.

    Parameters
    ----------
    xs, ys : sequence of float
        Point coordinates.
    zs : sequence of float, optional
        z coordinates; when given, segments are 3-D.
    color : str or array-like
        A Matplotlib color name used for every point, or one RGB triple per
        point (shape ``(len(xs), 3)``).
    mid_colors : bool
        If True, each segment's color is the average of the colors of its two
        endpoints; otherwise the color of its end point is used.

    Returns
    -------
    segs : list
        ``len(xs)`` segments ``[start, end]``; the first one starts and ends
        at the first point.
    colors : list of tuple
        One RGBA color ``(r, g, b, 1)`` per segment.
    """
    if isinstance(color, str):
        rgb = colorConverter.to_rgba(color)[:-1]
        color = np.array([rgb for _ in range(len(xs))])
    segs = []
    seg_colors = []
    last_color = [color[0][0], color[0][1], color[0][2]]
    end = [xs[0], ys[0]]
    if zs is not None:
        end.append(zs[0])
    else:
        zs = [None] * len(xs)
    for x, y, z, c in zip(xs, ys, zs, color):
        if mid_colors:
            seg_colors.append([(chan + last_chan) * .5 for chan, last_chan in zip(c, last_color)])
        else:
            seg_colors.append(c)
        # Keep all three RGB channels for the next average (``c[:-1]`` would
        # drop blue and produce malformed colors).
        last_color = c[:3]
        if z is not None:
            start = [end[0], end[1], end[2]]
            end = [x, y, z]
        else:
            start = [end[0], end[1]]
            end = [x, y]
        segs.append([start, end])
    colors = [(*seg_color, 1) for seg_color in seg_colors]
    return segs, colors
def segmented_resample(xs, ys, zs=None, color='k', n_resample=100, mid_colors=False):
    """Linearly resample a polyline into short segments with interpolated colors.

    Each of the ``len(xs) - 1`` input segments is split into
    ``n_resample - 1`` pieces, with positions and RGB colors linearly
    interpolated along it.

    Parameters
    ----------
    xs, ys : sequence of float
        Point coordinates.
    zs : sequence of float, optional
        z coordinates; when given, segments are 3-D.
    color : str or array-like
        A Matplotlib color name used for every point, or one RGB triple per
        point (shape ``(len(xs), 3)``).
    n_resample : int
        Samples per input segment, including both endpoints.
    mid_colors : bool
        If True, each output segment is colored with the average of its two
        endpoint colors; otherwise with its end point's color.

    Returns
    -------
    segs : list
        ``(len(xs) - 1) * (n_resample - 1)`` segments ``[start, end]``.
    colors : list of tuple
        One RGBA color ``(r, g, b, 1)`` per segment, aligned with ``segs``.
    data : list
        Resampled coordinates ``[xs, ys]``, plus ``zs`` when given.
    """
    n_points = len(xs)
    if isinstance(color, str):
        rgb = colorConverter.to_rgba(color)[:-1]
        color = np.array([rgb for _ in range(n_points)])
    t = np.linspace(0, 1, n_resample)
    segs = []
    seg_colors = []
    hi_res_xs = [xs[0]]
    hi_res_ys = [ys[0]]
    if zs is not None:
        hi_res_zs = [zs[0]]
    # Channel-major view of the per-point colors: RGB[c][i].
    RGB = np.asarray(color).swapaxes(0, 1)
    for i in range(n_points - 1):
        x_hi = interp1d([0, 1], xs[i:i+2])(t)
        y_hi = interp1d([0, 1], ys[i:i+2])(t)
        # extend() instead of list + list avoids quadratic copying.
        hi_res_xs.extend(x_hi[1:])
        hi_res_ys.extend(y_hi[1:])
        r_hi = interp1d([0, 1], RGB[0][i:i+2])(t)
        g_hi = interp1d([0, 1], RGB[1][i:i+2])(t)
        b_hi = interp1d([0, 1], RGB[2][i:i+2])(t)
        last_color = [r_hi[0], g_hi[0], b_hi[0]]
        end = [x_hi[0], y_hi[0]]
        if zs is not None:
            z_hi = interp1d([0, 1], zs[i:i+2])(t)
            hi_res_zs.extend(z_hi[1:])
            end.append(z_hi[0])
        else:
            z_hi = [None] * len(x_hi)
        # Exactly one color per segment: no extra entry is emitted for the
        # starting sample, so ``colors`` stays aligned with ``segs``.
        for x, y, z, r, g, b in zip(x_hi[1:], y_hi[1:], z_hi[1:], r_hi[1:], g_hi[1:], b_hi[1:]):
            if mid_colors:
                seg_colors.append([(chan + last_chan) * .5 for chan, last_chan in zip((r, g, b), last_color)])
            else:
                seg_colors.append([r, g, b])
            last_color = [r, g, b]
            if z is not None:
                start = [end[0], end[1], end[2]]
                end = [x, y, z]
            else:
                start = [end[0], end[1]]
                end = [x, y]
            segs.append([start, end])
    colors = [(*seg_color, 1) for seg_color in seg_colors]
    data = [hi_res_xs, hi_res_ys]
    if zs is not None:
        data.append(hi_res_zs)
    return segs, colors, data
def faded_segment_resample(xs, ys, zs=None, color='k', fade_len=20, n_resample=100, direction='Head'):
    """Resample a polyline with ``segmented_resample`` and fade its alpha.

    Parameters
    ----------
    xs, ys, zs, color, n_resample :
        Forwarded to ``segmented_resample``.
    fade_len : int
        Number of segments over which alpha ramps; capped at the number of
        segments.
    direction : str
        ``'Head'``: the last ``fade_len`` segments ramp from alpha 0 to 1 and
        all earlier ones get alpha 0. Any other value: the first
        ``fade_len`` segments ramp from 1 to 0 and the rest get alpha 0.

    Returns
    -------
    segs, colors, hiResData
        As from ``segmented_resample``, with each color's last (alpha)
        component replaced by the fade value.
    """
    segs, colors, hi_res_data = segmented_resample(xs, ys, zs, color, n_resample)
    n_segs = len(segs)
    fade_len = min(fade_len, n_segs)
    n_hidden = n_segs - fade_len
    if direction == 'Head':
        alphas = np.concatenate((np.zeros(n_hidden), np.linspace(0, 1, fade_len)))
    else:
        alphas = np.concatenate((np.linspace(1, 0, fade_len), np.zeros(n_hidden)))
    faded = [(*rgba[:-1], alpha) for rgba, alpha in zip(colors, alphas)]
    return segs, faded, hi_res_data
# https://stackoverflow.com/a/27537018/13326811
def _get_perp_line(current_seg, out_of_page, linewidth):
    """Return the 2-D vector of length ``linewidth`` perpendicular to a segment.

    The direction is the x/y part of the cross product
    ``current_seg x out_of_page``.
    """
    in_plane = np.cross(current_seg, out_of_page)[:2]
    return _get_unit_vector(in_plane) * linewidth
def _get_unit_vector(vector):
vector_size = (vector[0]**2 + vector[1]**2)**0.5
vector_unit = vector / vector_size
return vector_unit[0:2]
def colored_line(x, y, ax, z=None, line_width=1, MAP='jet'):
    """Draw a thick polyline whose color follows the distance along it.

    The line is drawn with ``pcolormesh`` as a two-row strip. Row 0 holds the
    input points; row 1 holds points offset by ``line_width`` perpendicular
    to the line, mitred at interior vertices. The strip is colored by
    cumulative distance from the first point.

    Parameters
    ----------
    x, y : sequence of float
        Polyline vertices (at least two points).
    ax : matplotlib Axes
        Axes to draw into.
    z : optional
        Unused; kept for API compatibility.
    line_width : float
        Strip width in data units.
    MAP : str
        Colormap name.
    """
    num_pts = len(x)
    xs = np.zeros((num_pts, 2))
    ys = np.zeros((num_pts, 2))
    zs = np.zeros((num_pts, 2))
    dist = 0
    out_of_page = [0, 0, 1]
    for i in range(num_pts):
        # set the colors and the x,y locations of the source line
        xs[i][0] = x[i]
        ys[i][0] = y[i]
        if i > 0:
            x_delta = x[i] - x[i-1]
            y_delta = y[i] - y[i-1]
            dist += (x_delta**2 + y_delta**2)**0.5
            zs[i] = [dist, dist]
        # Segment leaving this point (entering it, at the last point); arrays
        # so the vector helpers never receive plain lists.
        if i == num_pts - 1:
            current_seg = np.array([x[i]-x[i-1], y[i]-y[i-1], 0], dtype=float)
        else:
            current_seg = np.array([x[i+1]-x[i], y[i+1]-y[i], 0], dtype=float)
        current_seg_perp = _get_perp_line(current_seg, out_of_page, line_width)
        if i == 0 or i == num_pts - 1:
            xs[i][1] = xs[i][0] + current_seg_perp[0]
            ys[i][1] = ys[i][0] + current_seg_perp[1]
            continue
        # Interior vertex: intersect the two offset lines (miter join).
        current_pt = np.array([x[i], y[i]], dtype=float)
        current_seg_unit = _get_unit_vector(current_seg)
        previous_seg = np.array([x[i]-x[i-1], y[i]-y[i-1], 0], dtype=float)
        previous_seg_perp = _get_perp_line(previous_seg, out_of_page, line_width)
        previous_seg_unit = _get_unit_vector(previous_seg)
        # Solve  current_pt + previous_seg_perp + s * previous_seg_unit
        #      = current_pt + current_seg_perp  - s * current_seg_unit
        # for s. Both components give the same s, so use the one with the
        # larger denominator; always taking component 0 gives 0/0 = nan on
        # vertical runs.
        numer = current_seg_perp - previous_seg_perp
        denom = previous_seg_unit + current_seg_unit
        k = np.argmax(np.abs(denom))
        scalar = numer[k] / denom[k]
        new_pt = current_pt + previous_seg_perp + scalar * previous_seg_unit
        xs[i][1] = new_pt[0]
        ys[i][1] = new_pt[1]
    cm = plt.get_cmap(MAP)
    ax.pcolormesh(xs, ys, zs, shading='gouraud', cmap=cm)
def plot_color_swatches(colors, figsize=0.5, dpi=100):
    """Show each color as a small square swatch, side by side.

    Parameters
    ----------
    colors : array-like
        One color or a sequence of colors given as float RGB or RGBA values
        (e.g. the values of ``sinebow``'s dictionary).
    figsize : float
        Forwarded to ``imshow``.
    dpi : int
        Forwarded to ``imshow``.
    """
    # Convert colors to a numpy array; a single color becomes one row.
    colors = np.array(colors)
    if colors.ndim == 1:
        colors = colors.reshape(1, -1)
    # RGB and RGBA colors keep their own channel count; any other length is
    # broadcast into an RGB swatch. A fixed 3-channel buffer rejected RGBA.
    n_channels = colors.shape[-1] if colors.shape[-1] in (3, 4) else 3
    swatches = [np.full((1, 1, n_channels), color, dtype=np.float32) for color in colors]
    imshow(swatches, figsize=figsize, dpi=dpi)