#!/usr/bin/env python
# -*- coding: utf-8 -*-
# *************************************************************************** #
# Copyright © 2022, UChicago Argonne, LLC #
# All Rights Reserved #
# Software Name: Tomocupy #
# By: Argonne National Laboratory #
# #
# OPEN SOURCE LICENSE #
# #
# Redistribution and use in source and binary forms, with or without #
# modification, are permitted provided that the following conditions are met: #
# #
# 1. Redistributions of source code must retain the above copyright notice, #
# this list of conditions and the following disclaimer. #
# 2. Redistributions in binary form must reproduce the above copyright #
# notice, this list of conditions and the following disclaimer in the #
# documentation and/or other materials provided with the distribution. #
# 3. Neither the name of the copyright holder nor the names of its #
# contributors may be used to endorse or promote products derived #
# from this software without specific prior written permission. #
# #
# #
# *************************************************************************** #
# DISCLAIMER #
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS #
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT #
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS #
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT #
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, #
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED #
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR #
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF #
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING #
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS #
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
# *************************************************************************** #
from tomocupy import utils
from tomocupy import logging
from tomocupy.processing import proc_functions
from tomocupy.global_vars import args, params
from ast import literal_eval
from queue import Queue
import cupyx.scipy.ndimage as ndimage
import cupy as cp
import numpy as np
import signal
import cv2
__author__ = "Viktor Nikitin"
__copyright__ = "Copyright (c) 2022, UChicago Argonne, LLC."
__docformat__ = 'restructuredtext en'

# Public API of this module.
__all__ = ['FindCenter']

# Module-wide logger obtained through tomocupy's logging module.
log = logging.getLogger(__name__)
class FindCenter():
    '''
    Find the rotation axis location.

    The method is selected by ``args.rotation_axis_method``:

    * ``'sift'`` - compare projections 0 and 180 degrees apart by matching
      SIFT features (see ``find_center_sift``);
    * ``'vo'`` - Nghia Vo's Fourier-based search on one sinogram
      (see ``find_center_vo``).
    '''

    def __init__(self, cl_reader):
        """
        Parameters
        ----------
        cl_reader :
            Reader object; this class calls its ``read_flat_dark``,
            ``read_pairs`` and ``read_data_chunk_to_queue`` methods.
        """
        # Set ^C interrupt to abort and deallocate memory on GPU
        signal.signal(signal.SIGINT, utils.signal_handler)
        signal.signal(signal.SIGTERM, utils.signal_handler)

        # init tomo functions
        self.cl_proc_func = proc_functions.ProcFunctions()

        # additional refs
        self.cl_reader = cl_reader

    def find_center(self):
        """Return the rotation axis location for the configured method.

        ``find_center_sift`` and ``find_center_vo`` both return the center
        in the pixel grid of the data they read. This method alone applies
        the ``2**args.binning`` factor, so the helpers must not apply it.

        Raises
        ------
        ValueError
            If ``args.rotation_axis_method`` is neither ``'sift'`` nor
            ``'vo'``.
        """
        if args.rotation_axis_method == 'sift':
            center = self.find_center_sift()
        elif args.rotation_axis_method == 'vo':
            center = self.find_center_vo()
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(
                f'Unknown rotation axis method: {args.rotation_axis_method!r}')
        return center*2**args.binning

    def read_data_try(self, data_queue, id_slice):
        """Queue rows ``[id_slice, id_slice + 2**args.binning)`` of all
        projections in ``params.ids_proj`` into ``data_queue``."""
        st_z = id_slice
        end_z = id_slice + 2**args.binning
        self.cl_reader.read_data_chunk_to_queue(
            data_queue, params.ids_proj, st_z, end_z,
            params.st_n, params.end_n, 0, params.in_dtype)

    def find_center_sift(self):
        """Estimate the rotation axis by SIFT matching of projection pairs.

        ``args.rotation_axis_pairs`` is parsed as a Python literal of
        projection indices taken in consecutive pairs. If its first two
        entries are equal, they are replaced by the first and last
        projection (0 and ``params.nproj-1``).

        For each pair, the second projection is mirrored horizontally and
        registered against the first with ``_register_shift_sift``. The
        center then follows from half of the horizontal shift. Rows
        ``args.find_center_start_row`` to ``args.find_center_end_row`` are
        used; an end row of -1 means ``args.end_row``.

        Returns
        -------
        float
            Mean center over all pairs, in the column grid of the data read
            here. The binning factor is NOT applied: ``find_center``
            multiplies by ``2**args.binning``. Applying it here as well
            scaled the result twice whenever binning > 0.
        """
        # list() so that a tuple literal such as "(0, 0)" can be edited below.
        pairs = list(literal_eval(args.rotation_axis_pairs))

        flat, dark = self.cl_reader.read_flat_dark(
            params.st_n, params.end_n)
        st_row = args.find_center_start_row
        end_row = args.find_center_end_row
        if end_row == -1:
            end_row = args.end_row
        flat = flat[:, st_row:end_row]
        dark = dark[:, st_row:end_row]

        if pairs[0] == pairs[1]:
            pairs[0] = 0
            pairs[1] = params.nproj-1

        data = self.cl_reader.read_pairs(
            pairs, st_row, end_row, params.st_n, params.end_n)

        data = cp.array(data)
        flat = cp.array(flat)
        dark = cp.array(dark)

        data = self.cl_proc_func.darkflat_correction(data, dark, flat)
        data = self.cl_proc_func.minus_log(data)
        data = data.get()

        # Even entries are the first projection of each pair; odd entries
        # are flipped along the last (column) axis before registration.
        shifts, nmatches = _register_shift_sift(
            data[::2], data[1::2, :, ::-1], args.rotation_axis_sift_threshold)
        centers = params.n//2-shifts[:, 1]/2+params.st_n
        log.info(f'Number of matched features {nmatches}')
        log.info(
            f'Found centers for projection pairs {centers}, mean: {np.mean(centers)}')
        log.info(
            f'Vertical misalignment {shifts[:, 0]}, mean: {np.mean(shifts[:, 0])}')
        return np.mean(centers)

    def find_center_vo(self, ind=None, smin=-50, smax=50, srad=6, step=0.25, ratio=0.5, drop=20):
        """
        Find rotation axis location using Nghia Vo's method. :cite:`Vo:14`.

        Parameters
        ----------
        ind : int, optional
            Unused; the first entry of ``params.id_slices`` is always used.
        smin, smax : int, optional
            Ignored; the coarse search range is always
            ``[-args.center_search_width, args.center_search_width]``
            relative to the horizontal center of the sinogram.
        srad : float, optional
            Fine search radius.
        step : float, optional
            Ignored; ``args.center_search_step`` is used for the fine search.
        ratio : float, optional
            The ratio between the FOV of the camera and the size of object.
            It's used to generate the mask.
        drop : int, optional
            Drop lines around vertical center of the mask.

        Returns
        -------
        float
            Rotation axis location in the column grid of the data read
            (``find_center`` applies the binning factor).
        """
        # The search range and step come from the command line; the matching
        # keyword arguments are kept only for signature compatibility.
        step = args.center_search_step
        smin = -args.center_search_width
        smax = args.center_search_width

        data_queue = Queue(1)
        self.read_data_try(data_queue, params.id_slices[0])
        item = data_queue.get()

        # copy to gpu
        data = cp.array(item['data'])
        dark = cp.array(item['dark'])
        flat = cp.array(item['flat'])

        data = self.cl_proc_func.darkflat_correction(data, dark, flat)
        data = self.cl_proc_func.minus_log(data)
        # Sinogram of the first row in the chunk; assumes data is ordered
        # (projection, row, column) - TODO confirm against the reader.
        _tomo = data.swapaxes(0, 1)[0]

        # Denoising
        # There's a critical reason to use different window sizes
        # between coarse and fine search.
        _tomo_cs = ndimage.gaussian_filter(_tomo, (3, 1), mode='reflect')
        _tomo_fs = ndimage.gaussian_filter(_tomo, (2, 2), mode='reflect')

        # Coarse and fine searches for finding the rotation center.
        init_cen = _search_coarse(_tomo_cs, smin, smax, ratio, drop)
        fine_cen = _search_fine(_tomo_fs, srad, step, init_cen, ratio, drop)
        # %.2f: the center is fractional (sub-pixel step); %i truncated it.
        log.debug('Rotation center search finished: %.2f', fine_cen)
        return fine_cen
def _find_min_max(data):
    """Return robust per-image intensity bounds taken from a histogram.

    For each image ``data[k]`` a 1000-bin histogram is built. Bins whose
    count exceeds 0.5% of the tallest bin's count are treated as populated.
    The lower bound is the left edge of the first populated bin; the upper
    bound is the right edge of the last one. Sparse outliers at either end
    are therefore ignored.

    Parameters
    ----------
    data : ndarray
        Stack of images indexed along the first axis.

    Returns
    -------
    mmin, mmax : ndarray of float32
        Lower and upper bounds, one entry per image.
    """
    count = data.shape[0]
    lows = np.zeros(count, dtype='float32')
    highs = np.zeros(count, dtype='float32')
    for idx, image in enumerate(data):
        hist, edges = np.histogram(image, 1000)
        populated = np.nonzero(hist > hist.max() * 0.005)[0]
        lows[idx] = edges[populated[0]]
        # Right edge of the last populated bin.
        highs[idx] = edges[populated[-1] + 1]
    return lows, highs
def _register_shift_sift(datap1, datap2, th=0.5):
    """Find shifts between image pairs by matching SIFT features.

    Pair ``k`` is ``(datap1[k], datap2[k])``. Both images are mapped to
    8 bit using the histogram bounds of ``datap1[k]`` from
    ``_find_min_max``. SIFT keypoints are then detected in each image, and
    descriptors are matched with a brute-force k-NN matcher followed by
    Lowe's ratio test.

    Parameters
    ----------
    datap1, datap2 : ndarray
        Image stacks indexed along the first axis; presumably of equal
        shape (TODO confirm with callers).
    th : float, optional
        Ratio-test threshold: a match is kept when its distance is below
        ``th`` times the distance of the second-best candidate.

    Returns
    -------
    shifts : ndarray of float32, shape (npairs, 2)
        Mean ``(row, column)`` displacement of the matched keypoints in
        ``datap2[k]`` relative to those in ``datap1[k]``.
    nmatches : int
        Number of good matches for the last pair (0 when there are no
        pairs).

    Raises
    ------
    RuntimeError
        If a pair yields no good matches.
    """
    def _to_uint8(img, lo, hi):
        # Linear map of [lo, hi] onto [0, 255], clipped, then cast to uint8.
        return np.clip((img - lo) / (hi - lo) * 255, 0, 255).astype('uint8')

    mmin, mmax = _find_min_max(datap1)
    sift = cv2.SIFT_create()
    matcher = cv2.BFMatcher()
    shifts = np.zeros([datap1.shape[0], 2], dtype='float32')
    good = []
    for k in range(datap1.shape[0]):
        img1 = _to_uint8(datap2[k], mmin[k], mmax[k])
        img2 = _to_uint8(datap1[k], mmin[k], mmax[k])

        # find key points
        kp1, des1 = sift.detectAndCompute(img1, None)
        kp2, des2 = sift.detectAndCompute(img2, None)

        good = []
        # detectAndCompute returns None descriptors when no keypoints exist.
        if des1 is not None and des2 is not None:
            for candidates in matcher.knnMatch(des1, des2, k=2):
                # knnMatch may return fewer than 2 candidates per descriptor;
                # the ratio test needs both, so such entries are skipped.
                if (len(candidates) == 2
                        and candidates[0].distance < th * candidates[1].distance):
                    good.append(candidates[0])
        if not good:
            # Raise instead of print() + exit(): exit() ended the whole
            # process with status 0, which hid the failure from callers.
            raise RuntimeError(f'no features found for projection pair {k}')

        src_pts = np.float32([kp1[m.queryIdx].pt for m in good])
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good])
        # Keypoint coordinates are (x, y); reverse to (row, column).
        shifts[k] = np.mean(src_pts - dst_pts, axis=0)[::-1]
    return shifts, len(good)
def _calculate_metric(shift_col, sino1, sino2, sino3, mask):
    """
    Evaluate the Fourier-space metric of Vo's method for one trial shift.

    ``sino2`` is shifted by ``shift_col`` columns. A whole-pixel shift uses
    a cyclic roll; a fractional one uses cubic-spline interpolation. The
    columns shifted in at the leading edge are replaced by the same columns
    of ``sino3``. The result is stacked below ``sino1``, and the mean of
    ``|fftshift(fft2(stack))| * mask`` is returned.
    """
    shift = 1.0 * np.squeeze(shift_col)
    if np.abs(shift - np.floor(shift)) == 0.0:
        shift = int(shift)
        shifted = cp.roll(sino2, shift, axis=1)
        edge = shift
    else:
        shifted = ndimage.shift(
            sino2, (0, shift), order=3, prefilter=True)
        # Round away from zero so every partially-filled column is replaced.
        edge = int(np.ceil(shift)) if shift >= 0 else int(np.floor(shift))

    # edge has the same sign as shift, so it selects the same side to patch.
    if edge >= 0:
        shifted[:, :edge] = sino3[:, :edge]
    else:
        shifted[:, edge:] = sino3[:, edge:]

    stacked = cp.vstack((sino1, shifted))
    spectrum = cp.abs(cp.fft.fftshift(cp.fft.fft2(stacked)))
    return cp.mean(spectrum * mask)
def _search_coarse(sino, smin, smax, ratio, drop):
    """
    Coarse search for the rotation center (first stage of Vo's method).

    Candidate centers from ``ncol // 2 + smin`` to ``ncol // 2 + smax`` are
    tried in 0.5-pixel steps. Each candidate is scored with
    ``_calculate_metric`` against the left-right mirrored sinogram, and the
    candidate with the smallest metric is returned.

    Parameters
    ----------
    sino : 2D array
        Sinogram, indexed (row, column).
    smin, smax : int
        Search range relative to the image center. It is clipped so that
        ``cen + smin`` and ``cen + smax`` stay inside ``[0, ncol - 1]``,
        where ``cen = (ncol - 1) / 2``.
    ratio : float
        Ratio between the camera FOV and the object size (for the mask).
    drop : int
        Lines dropped around the vertical center of the mask.

    Returns
    -------
    float
        Column of the best candidate center.
    """
    (nrow, ncol) = sino.shape
    cen_fliplr = (ncol - 1.0) / 2.0
    # int() truncates toward zero like the former np.int16 cast, without the
    # int16 overflow risk for very wide sinograms.
    smin = int(np.clip(smin + cen_fliplr, 0, ncol - 1) - cen_fliplr)
    smax = int(np.clip(smax + cen_fliplr, 0, ncol - 1) - cen_fliplr)
    start_cor = ncol // 2 + smin
    stop_cor = ncol // 2 + smax
    flip_sino = cp.fliplr(sino)
    comp_sino = cp.flipud(sino)  # Used to avoid local minima
    list_cor = np.arange(start_cor, stop_cor + 0.5, 0.5)
    list_metric = np.zeros(len(list_cor), dtype=np.float32)
    mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
    list_shift = 2.0 * (list_cor - cen_fliplr)
    for k, s in enumerate(list_shift):
        list_metric[k] = _calculate_metric(s, sino, flip_sino, comp_sino, mask)
    minpos = np.argmin(list_metric)
    # An optimum on the edge of the range suggests the true minimum lies
    # outside it. Log at warning level (not debug) so users see the hint.
    if minpos == 0:
        log.warning('WARNING!!!Global minimum is out of searching range')
        log.warning('Please extend smin: %i', smin)
    if minpos == len(list_metric) - 1:
        log.warning('WARNING!!!Global minimum is out of searching range')
        log.warning('Please extend smax: %i', smax)
    return list_cor[minpos]
def _search_fine(sino, srad, step, init_cen, ratio, drop):
    """
    Refine the rotation center around ``init_cen`` (second stage of Vo's
    method).

    ``srad`` is clipped to ``[1, ncol / 4]`` and ``step`` to
    ``[0.1, srad]``. Candidates ``init_cen + arange(-srad, srad + step,
    step)`` are scored with ``_calculate_metric``, and the candidate with
    the smallest metric is returned.
    """
    nrow, ncol = sino.shape
    mirror_center = (ncol - 1.0) / 2.0
    radius = np.clip(np.abs(srad), 1.0, ncol / 4.0)
    step = np.clip(np.abs(step), 0.1, radius)
    # Keep the whole search window inside the image.
    start = np.clip(init_cen, radius, ncol - radius - 1)
    candidates = start + np.arange(-radius, radius + step, step)

    mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
    mirrored = cp.fliplr(sino)
    upside_down = cp.flipud(sino)

    metrics = np.zeros(len(candidates), dtype=np.float32)
    for idx, cor in enumerate(candidates):
        metrics[idx] = _calculate_metric(
            2.0 * (cor - mirror_center), sino, mirrored, upside_down, mask)
    return candidates[np.argmin(metrics)]
def _create_mask(nrow, ncol, radius, drop):
    """
    Make a binary mask to select coefficients outside the double-wedge region.
    Eq.(3) in https://doi.org/10.1364/OE.22.019078

    Parameters
    ----------
    nrow : int
        Image height.
    ncol : int
        Image width.
    radius: int
        Radius of an object, in pixel unit.
    drop : int
        Drop lines around vertical center of the mask.

    Returns
    -------
    2D binary mask (float32, shape ``(nrow, ncol)``) on the device.
    """
    du = 1.0 / ncol
    dv = (nrow - 1.0) / (nrow * 2.0 * np.pi)
    # Python ints: np.int16 overflowed for sinograms taller than 32767 rows
    # (callers pass nrow = 2 * number of projections).
    cen_row = int(np.ceil(nrow / 2.0) - 1)
    cen_col = int(np.ceil(ncol / 2.0) - 1)
    drop = min(drop, int(np.ceil(0.05 * nrow)))

    # Half-width of the double wedge for every row, computed at once.
    rows = np.arange(nrow)
    pos = np.abs(np.ceil(((rows - cen_row) * dv / radius) / du).astype(np.int64))
    lo = np.clip(cen_col - pos, 0, ncol - 1)
    hi = np.clip(cen_col + pos, 0, ncol - 1)
    cols = np.arange(ncol)
    # Build on the host and copy once, instead of one device slice per row.
    mask = ((cols >= lo[:, None]) & (cols <= hi[:, None])).astype('float32')

    mask[cen_row - drop:cen_row + drop + 1, :] = 0.0
    mask[:, cen_col - 1:cen_col + 2] = 0.0
    return cp.asarray(mask)