"""Tools to analyse images
"""
import logging
import numpy as np
import matplotlib.pyplot as plt
import ecpi.common.sky.catalog as cat
from ecpi.process.imag.core.source_finder import MaxSourceFinder
#pylint:disable=W0102
class SkyImageStatistics(object):
    """Compute statistical analysis of a sky image

    The brightest peaks of the SNR image are located, fitted and masked out,
    then the mean, median and standard deviation of the remaining count image
    are returned together with the list of fitted sources.
    """

    # Constant subtracted from every pixel position written to the catalog.
    # NOTE(review): presumably the index of the central pixel of the sky
    # image -- confirm against the catalog conventions.
    _CATALOG_PIX_OFFSET = 99
    # Constant stored in the 'ra' and 'dec' catalog columns.
    # NOTE(review): looks like a placeholder since no sky coordinates are
    # computed in this class -- confirm.
    _CATALOG_RADEC_VALUE = 404

    def __init__(self, sky_count, sky_snr, filter_array=0, reg=None):
        """
        **builder**

        :param sky_count: sky image
        :type sky_count: 2d array
        :param sky_snr: sky snr image
        :type sky_snr: 2d array
        :param filter_array: fov size array set to nan for positions not to be analysed
            (1 elsewhere). Any int (default=0) means "not used": an array of ones
            with the shape of sky_count is used instead.
        :type filter_array: 2d array
        :param reg: image coordinates delimiting the rectangular region of sky to
            be considered [(y_a,z_a),(y_b,z_b)] (inclusive). None (default) or the
            legacy list [(0, 0), (0, 0)] selects the whole image.
        :type reg: [(int, int), (int, int)]
        """
        self.sky_count = sky_count
        self.sky_snr = sky_snr
        if isinstance(filter_array, int):
            self.filter_array = np.ones(sky_count.shape)
        else:
            self.filter_array = filter_array
        # The isinstance() guard keeps the legacy sentinel working while
        # avoiding an ambiguous element-wise comparison when reg is an array.
        if reg is None or (isinstance(reg, list) and reg == [(0, 0), (0, 0)]):
            # The upper corner is one past the last valid index; this is
            # harmless because run() only uses it as a slice bound.
            self.reg = [(0, 0), (len(sky_count), len(sky_count[0]))]
        else:
            self.reg = reg
        # declare logger
        self.logger = logging.getLogger(__name__)
        self.logger.info('creating an instance of SkyImageStatistics')

    def run(self, nb_src=0, snr_limit=0, rad=3, show=False):
        """
        run the statistical routine

        When nb_src is 0 the statistics are computed on the whole sky_count
        image, without applying filter_array nor reg. Otherwise the region
        reg is extracted, multiplied by filter_array, and up to nb_src peaks
        are searched by decreasing snr; each fitted peak is masked with NaN
        before the statistics are computed.

        :param nb_src: number of peak to be found. (default=0, no search)
        :type nb_src: int
        :param snr_limit: only peaks with snr > snr_limit are considered
        :type snr_limit: float
        :param rad: half-width of the square masked around each peak: positions
            within rad rows and rad columns of the peak are set to NaN (default=3)
        :type rad: float
        :param show: flag to show the different sky images along the process (default=False)
        :type show: bool
        :return: image_stat [mean, median, std],
            list_src of [position, src_count, max_snr, *fit_params]
        :rtype: [float, float, float], list
        """
        list_src = []
        if nb_src == 0:
            # TODO: COVERAGE
            return [np.nanmean(self.sky_count),
                    np.nanmedian(self.sky_count),
                    np.nanstd(self.sky_count)], list_src

        (row_min, col_min), (row_max, col_max) = self.reg
        # reg bounds are inclusive, hence the +1 on the upper slice bounds.
        region = (slice(row_min, row_max + 1), slice(col_min, col_max + 1))
        tmp_filter = self.filter_array[region].copy()
        if not np.issubdtype(tmp_filter.dtype, np.floating):
            # NaN cannot be stored in an integer/boolean mask.
            tmp_filter = tmp_filter.astype(float)
        self.used_sky_count = self.sky_count[region] * tmp_filter
        self.used_sky_snr = self.sky_snr[region] * tmp_filter
        if show:
            plt.title('initial image')
            plt.imshow(self.used_sky_count)
            plt.show()

        nb_found_src = 0
        if np.isnan(self.used_sky_snr).all():
            self.logger.warning("max_snr is full of NaNs !")
        max_snr = np.nanmax(self.used_sky_snr)
        # The pixel grid only depends on the region shape: build it once.
        row, col = np.indices(tmp_filter.shape)
        # A NaN max_snr compares False, which ends the search.
        while nb_found_src < nb_src and max_snr > snr_limit:
            peak_rows, peak_cols = np.where(self.used_sky_snr == max_snr)
            if len(peak_rows) == 0:
                # Defensive: max_snr is taken from the image itself.
                break
            # Keep the first position reaching the maximum, as 0-d arrays
            # expressed in the coordinates of the extracted region.
            max_snr_indices = (np.array(peak_rows[0]), np.array(peak_cols[0]))
            # Same position in the coordinates of the full image.
            max_snr_indices_save = (np.array(max_snr_indices[0] + row_min),
                                    np.array(max_snr_indices[1] + col_min))
            src_count = self.used_sky_count[max_snr_indices]
            success, fit_params = self._fit_peak(max_snr_indices, snr_limit)
            if not success:
                break
            # Shift the fitted position back to full-image coordinates.
            fit_params[1] += row_min
            fit_params[2] += col_min
            list_src.append([max_snr_indices_save, src_count, max_snr, *fit_params])

            # Mask the square around the peak so the next maximum is elsewhere.
            around_peak = ((row <= max_snr_indices[0] + rad) &
                           (row >= max_snr_indices[0] - rad) &
                           (col <= max_snr_indices[1] + rad) &
                           (col >= max_snr_indices[1] - rad))
            tmp_filter[around_peak] = np.nan
            self.used_sky_count *= tmp_filter
            self.used_sky_snr *= tmp_filter
            nb_found_src += 1
            max_snr = np.nanmax(self.used_sky_snr)
            # NOTE(review): assumed to be displayed after each found source
            # (per-source title) -- confirm against the original layout.
            if show:
                plt.title(f'{nb_found_src} found source')
                plt.imshow(self.used_sky_count)
                plt.show()

        return [np.nanmean(self.used_sky_count),
                np.nanmedian(self.used_sky_count),
                np.nanstd(self.used_sky_count)], list_src

    def _fit_peak(self, peak_pos, snr_min):
        """fit the peak with 2d gaussian

        Fit parameters are expected in the order
        [height_fit, x_pos_fit, y_pos_fit, sigma_x_fit, sigma_y_fit, offset_fit]

        :param peak_pos: position of the peak to be fitted, if already known. [pix_y, pix_z].
        :type peak_pos: [float, float]
        :param snr_min: snr minimum to do the fit
        :type snr_min: float
        :return: success flag and parameters reported by the source finder
        :rtype: bool, list
        """
        fit_source = MaxSourceFinder(self.used_sky_count, self.used_sky_snr)
        fit_source.fit_source(snr_min, peak_pos, verbose=False)
        return fit_source.success, fit_source.params

    # TODO: called in main_imag: must not be a "protected" method.
    def _build_catalog(self, list_srcs):
        """build a catalog of sources from a list of sources

        Each source is expected in the layout produced by run():
        [position, src_count, max_snr, *fit_params].

        :param list_srcs: list of sources
        :type list_srcs: list
        :return: catalog of sources
        :rtype: CatalogIdentifiedSources
        """
        catalog = cat.CatalogIdentifiedSources()
        offset = self._CATALOG_PIX_OFFSET
        radec = self._CATALOG_RADEC_VALUE
        for src_nb, src in enumerate(list_srcs, start=1):
            # ('sourceID', 'Y', 'Z', 'Y_fit', 'error_Y', 'Z_fit', 'error_Z', 'ra', 'dec', 'errorrad', 'flux', 'pflux', 'errflux', 'snr', 'name', 'dist_src', 'class')
            info_src = ['f' + str(src_nb),
                        src[0][0] - offset, src[0][1] - offset,
                        src[4] - offset, 0,
                        src[5] - offset, 0,
                        radec, radec, 0,
                        src[3], src[1], 0,
                        src[2], '', 0, 0]
            catalog.add_src(info_src)
        return catalog