I'm not sure this is the correct terminology. I'm trying to take a black-and-white image and first transform it so that white pixels that border black pixels stay white and all other pixels turn black. That part of the program works fine and is done in find_edges. Next, I need to calculate the distance from each pixel in the image to the closest white pixel. Right now I do this by calling find_nearest_edge in a for loop, which is extremely slow. Is there a way to write find_nearest_edge purely with NumPy, without needing a for loop to call it on each element? Thanks.

####

from PIL import Image
import numpy as np
from scipy.ndimage import binary_erosion

####

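def find_nearest_edge(arr, point):
    # arr.shape is (rows, cols), i.e. (height, width)
    h, w = arr.shape
    x, y = point
    # default 'xy' indexing gives grids of shape (h, w), matching arr
    xcoords, ycoords = np.meshgrid(np.arange(w), np.arange(h))

    target = np.sqrt((xcoords - x)**2 + (ycoords - y)**2)
    target[arr == 0] = np.inf

    # zero distances are skipped, so an edge pixel reports its nearest *other* edge pixel
    shortest_distance = np.min(target[target > 0.0])

    return shortest_distance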

def find_edges(img):
    img = img.convert('L')
    img_np = np.array(img)

    kernel = np.ones((3,3))
    edges = img_np - binary_erosion(img_np, kernel)*255

    return edges

a = Image.open('a.png')
x, y = a.size

edges = find_edges(a)

out = Image.fromarray(edges.astype('uint8'), 'L')
out.save('b.png')

dists =[]
for _x in range(x):
    for _y in range(y):
        dist = find_nearest_edge(edges,(_x,_y))
        dists.append(dist)

print(dists)


2 Answers

1

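You can use scipy.spatial.KDTree to compute the distances quickly: build a tree from the coordinates of the white pixels, then query it once with the coordinates of every pixel.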

import numpy as np
import matplotlib.pyplot as plt

from scipy.ndimage import binary_erosion
from scipy.spatial import KDTree


def find_edges(img):
    img_np = np.array(img)

    kernel = np.ones((3,3))
    edges = img_np - binary_erosion(img_np, kernel)*255

    return edges


def find_closest_distance(img):
    # NOTE: assuming input is binary image and white is any non-zero value!
    white_pixel_points = np.array(np.where(img))
    tree = KDTree(white_pixel_points.T)
    img_meshgrid = np.array(np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]))).T
    distances, _ = tree.query(img_meshgrid)
    return distances

test_image = np.zeros((200, 200))
rectangle = np.ones((30, 80))
test_image[20:50, 60:140] = rectangle
test_image[150:180, 60:140] = rectangle
test_image[60:140, 20:50] = rectangle.T
test_image[60:140, 150:180] = rectangle.T
test_image = test_image * 255
edge_image = find_edges(test_image)
distance_image = find_closest_distance(edge_image)


fig, axes = plt.subplots(1, 3, figsize=(12, 5))
axes[0].imshow(test_image, cmap='Greys_r')
axes[1].imshow(edge_image, cmap='Greys_r')
axes[2].imshow(distance_image, cmap='Greys_r')
plt.show()
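
As a possible cross-check on the KDTree result, the same kind of distance map can be sketched with scipy.ndimage.distance_transform_edt. This is a different technique from the KDTree approach above, not part of the original answer. The helper name edge_distance_map and the tiny demo array are illustrative assumptions, and the input is assumed to contain at least one non-zero (edge) pixel.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt


def edge_distance_map(edge_img):
    # distance_transform_edt returns, for every non-zero element, the
    # Euclidean distance to the nearest zero element. Edge pixels are
    # non-zero here, so invert the mask first: edges become the zeros.
    return distance_transform_edt(edge_img == 0)


# Illustrative input (an assumption): one white edge pixel at row 2, column 1.
demo = np.zeros((5, 4))
demo[2, 1] = 255
demo_distances = edge_distance_map(demo)
```

Like the KDTree query, this reports a distance of 0 at the edge pixels themselves, whereas find_nearest_edge in the question skips zero distances.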


1 Comment

Thanks, this was just what I needed. It took less than a minute to complete on my laptop.
0

You can make your code about 25x faster just by changing find_nearest_edge as follows. Many other optimizations are possible, but this function is the biggest bottleneck in your code.

import numpy as np
from numba import njit


@njit
def find_nearest_edge(arr, point):
    x, y = point
    shortest_distance = np.inf
    for i in range(arr.shape[0]):
        for j in range(arr.shape[1]):
            # skip black (non-edge) pixels
            if arr[i, j] == 0:
                continue
            # compare squared distances; take the square root once at the end
            shortest_distance = min(shortest_distance, (i - x)**2 + (j - y)**2)
    return np.sqrt(shortest_distance)

1 Comment

Thanks for the response, but this still requires a for loop to iterate over each point, so it is still prohibitively slow. The find_nearest_edge function itself was fast enough; it was iterating over each point and feeding it to find_nearest_edge that was slow.
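
For reference, the per-point loop this comment describes could in principle be replaced by a single broadcast NumPy computation over all pixels. The following is only a minimal sketch under assumptions: the function name nearest_edge_distances and the chunk_size parameter are hypothetical, any non-zero value counts as an edge pixel, and a short loop over chunks of pixels remains so that the temporary array stays within memory.

```python
import numpy as np


def nearest_edge_distances(edges, chunk_size=1024):
    # Hypothetical helper: distance from every pixel to its nearest
    # non-zero (edge) pixel, computed by broadcasting rather than by
    # calling a function once per pixel.
    edge_points = np.argwhere(edges != 0)                  # (n_edges, 2)
    if edge_points.size == 0:
        # No edge pixels at all: every distance is infinite.
        return np.full(edges.shape, np.inf)
    all_points = np.indices(edges.shape).reshape(2, -1).T  # (n_pixels, 2)
    out = np.empty(len(all_points))
    # Work through the pixels in chunks so that the (chunk, n_edges, 2)
    # temporary array stays small.
    for start in range(0, len(all_points), chunk_size):
        block = all_points[start:start + chunk_size]
        diff = block[:, None, :] - edge_points[None, :, :]
        squared = (diff ** 2).sum(axis=2)
        out[start:start + chunk_size] = np.sqrt(squared.min(axis=1))
    return out.reshape(edges.shape)
```

Unlike find_nearest_edge in the question, this sketch reports 0 for the edge pixels themselves, and its memory use grows with chunk_size times the number of edge pixels.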
