# fitter.py
import numpy as np
from scipy.optimize import minimize, dual_annealing
from scipy.ndimage import zoom
from petrofit import PSFConvolvedModel2D, model_to_image
[docs]
class SphotModel(PSFConvolvedModel2D):
    '''
    PSF-convolved petrofit model bundled with the cutout data it is fit to.

    In addition to what ``PSFConvolvedModel2D`` provides, this class keeps
    track of which parameters are free and which are fixed, lets every
    ``x_0``/``y_0``-like parameter share a single centre, and supplies
    per-parameter bounds and optional inequality conditions to the fitters.
    '''

    def __init__(self, model, cutoutdata, resample_psf=True, **kwargs):
        '''
        A wrapper class for the petrofit model.

        Args:
            model (astropy FittableModel): model to fit.
            cutoutdata (CutoutData): the data to fit. Its ``psf``,
                ``psf_oversample`` and ``data`` attributes are read.
            resample_psf (bool): if True, the PSF is zoomed by
                ``1/psf_oversample``, renormalized to unit sum, and the
                model is built with an oversampling factor of 1.
                Otherwise the PSF and its oversampling factor are used as-is.
            **kwargs: accepted for call compatibility; currently unused.
        '''
        psf = cutoutdata.psf
        psf_oversample = cutoutdata.psf_oversample
        if resample_psf:
            # zoom() returns a new array, so the in-place normalization
            # below does not touch cutoutdata.psf.
            psf = zoom(psf, 1/psf_oversample)
            psf /= psf.sum()  # normalize
            psf_oversample = 1

        super().__init__(model,
                         psf=psf,
                         psf_oversample=int(psf_oversample),
                         oversample=int(psf_oversample))

        self.data = cutoutdata.data
        # Until set_fixed_params() is called every parameter is free.
        self.free_params = self._param_names
        self.fixed_params = {}

    def parse_params(self, theta):
        ''' a helper function to parse coordinates to all relevant sub-models.

        Args:
            theta (sequence): values of the free parameters, in the order
                of ``self.free_params``.

        Returns:
            list: one value per entry of ``self._param_names``. Every
            parameter whose name contains ``x_0`` (``y_0``) receives the
            shared ``'x_0'`` (``'y_0'``) value when one is available among
            the free/fixed parameters; otherwise it keeps its own value.

        Raises:
            KeyError: if a parameter has neither a shared-centre value nor
                an entry of its own among the free/fixed parameters.
        '''
        full_params = dict(self.fixed_params)
        # Free values take precedence over fixed ones on a name clash.
        full_params.update(zip(self.free_params, theta))

        parsed_params = []
        for key in self._param_names:
            if 'x_0' in key:
                shared_key = 'x_0'
            elif 'y_0' in key:
                shared_key = 'y_0'
            else:
                shared_key = key
            # Without a shared centre (e.g. a compound model whose centres
            # are named 'x_0_0', 'x_0_1', ... and set_fixed_params() was
            # never called) fall back to the parameter's own entry instead
            # of failing on the missing 'x_0'/'y_0' key.
            if shared_key in full_params:
                parsed_params.append(full_params[shared_key])
            else:
                parsed_params.append(full_params[key])
        return parsed_params

    def list_to_params(self, theta):
        ''' convert a list of free parameters to the full model parameter array '''
        return self.parse_params(theta)

    def set_fixed_params(self, fixed_params):
        '''
        Declare which parameters are held fixed; all others become free.

        Centre parameters (names containing ``x_0`` or ``y_0``) are never
        listed individually: unless ``'x_0'`` is a key of ``fixed_params``,
        a single shared ``'x_0'``/``'y_0'`` pair is placed at the front of
        the free-parameter list, initialized from the first matching model
        parameter.

        Sets ``self.free_params`` (list of names), ``self.fixed_params``
        (the dict passed in) and ``self.x0_physical`` (current model values
        of the free parameters, in the same order as ``free_params``).

        Args:
            fixed_params (dict): parameter name -> fixed value.
        '''
        names = self._param_names

        def current_value(name):
            return self.parameters[names.index(name)]

        free_params = [name for name in names
                       if name not in fixed_params
                       and 'x_0' not in name
                       and 'y_0' not in name]
        x0_physical = [current_value(name) for name in free_params]

        if 'x_0' not in fixed_params:
            x_0 = [current_value(name) for name in names if 'x_0' in name][0]
            y_0 = [current_value(name) for name in names if 'y_0' in name][0]
            free_params = ['x_0', 'y_0', *free_params]
            x0_physical = [x_0, y_0, *x0_physical]

        self.free_params = free_params    # array of names
        self.fixed_params = fixed_params  # dict
        self.x0_physical = x0_physical

    def get_bounds(self):
        '''
        Build ``[lower, upper]`` bounds for every free parameter.

        The parameter kind is recognized by substring, tested in the order
        listed below; the first match wins.

        Returns:
            numpy.ndarray: one ``[lower, upper]`` row per free parameter.

        Raises:
            ValueError: if a free parameter matches none of the known kinds.
        '''
        data = self.data
        # Factories are evaluated lazily so that e.g. np.nanmax(data) only
        # runs when an amplitude parameter is actually free.
        rules = (
            ('x_0', lambda: [0, data.shape[1]]),
            ('y_0', lambda: [0, data.shape[0]]),
            ('amplitude', lambda: [0, np.nanmax(data)]),
            # Strictly positive: downstream EllipticalAperture
            # construction (sigma_clip_outside_aperture) requires
            # r_eff > 0. Sub-pixel r_eff has no physical meaning
            # for galaxy cutouts either way.
            ('r_eff', lambda: [0.5, data.shape[0]]),
            # NOTE(review): plain substring test -- any otherwise unknown
            # parameter whose name contains the letter 'n' also lands here.
            ('n', lambda: [0.1, 10]),
            ('ellip', lambda: [0.01, 1]),
            ('theta', lambda: [0, np.pi]),
            ('psf_pa', lambda: [0, np.pi]),
        )

        bounds = []
        for key in self.free_params:
            for token, make_bounds in rules:
                if token in key:
                    bounds.append(make_bounds())
                    break
            else:
                raise ValueError(f'Unknown parameter {key}')
        return np.array(bounds)

    def set_conditions(self, list_of_conditions):
        '''Set condition functions.

        Stores ``self.condition_func``, a callable taking the free-parameter
        values (ordered as ``self.free_params``) and returning a bool. The
        values are compared exactly as passed in, so they must be on the
        same scale as any numerical thresholds in the conditions.

        Args:
            list_of_conditions (list of 2-tuple): list of conditions, each as a 2-tuple. Each tuple ``(a, b)`` is evaluated as ``a >= b``. Either entry of the tuple can be a parameter name or a numerical value. For example, ``[('r_eff', 10), ('r_eff_1', 'r_eff_0')]`` returns True iff ``r_eff >= 10`` AND ``r_eff_1 >= r_eff_0``. Names may refer to free parameters or to entries of ``self.fixed_params``.
        '''
        def condition_func(theta):
            # Fixed parameters are included so a condition can compare a
            # free parameter against a fixed one; free values win on clash.
            values = dict(self.fixed_params)
            values.update(zip(self.free_params, theta))

            def resolve(entry):
                return values[entry] if isinstance(entry, str) else entry

            for condition in list_of_conditions:
                if resolve(condition[0]) < resolve(condition[1]):
                    return False
            return True

        self.condition_func = condition_func
[docs]
class ModelFitter():
    '''
    A fitter class to perform Sersic model fitting to data.

    The optimizers work on standardized parameters (each free parameter
    rescaled to [0, 1] using the model's physical bounds); results are
    converted back to the physical scale before being stored on the
    cutout data.
    '''

    def __init__(self, model, cutoutdata, **kwargs):
        '''
        Args:
            model: model to fit; ``_param_names``, ``free_params``,
                ``x0_physical``, ``get_bounds()`` and ``parse_params()``
                are used by this class.
            cutoutdata: object holding the image(s) to fit; also receives
                the fit results.
            **kwargs: accepted for call compatibility; currently unused.
        '''
        self.cutoutdata = cutoutdata
        self.shape = self.cutoutdata.data.shape
        self.model = model
        self.param_names = model._param_names

    def standardize_params(self, params):
        '''
        normalize parameters to be between 0 and 1.
        '''
        lower_bounds, upper_bounds = self.bounds_physical.T
        return (params - lower_bounds) / (upper_bounds - lower_bounds)

    def unstandardize_params(self, params):
        ''' convert back normalized parameters to the physical scale '''
        lower_bounds, upper_bounds = self.bounds_physical.T
        return params * (upper_bounds - lower_bounds) + lower_bounds

    def eval_model(self, standardized_params):
        ''' render the model image based on the given parameters

        Side effect: the model's ``parameters`` are overwritten with the
        values being rendered. Requires ``self.data`` (set by ``fit``).
        '''
        params = self.unstandardize_params(standardized_params)
        self.model.parameters = self.model.parse_params(params)
        # petrofit.model_to_image takes size as (x_size, y_size); numpy
        # shape is (rows, cols) = (y, x), so reverse before passing.
        ny, nx = self.data.shape
        return model_to_image(self.model, size=(nx, ny))

    def calc_chi2(self,
                  standardized_params,
                  iterinfo='', print_val=False, chi2_min_allowed=1e-10):
        '''
        Objective function: chi-square per finite data pixel.

        Args:
            standardized_params: free parameters on the [0, 1] scale.
            iterinfo (str): label appended to the progress print-out.
            print_val (bool): if True, print the current value in place.
            chi2_min_allowed (float): values at or below this threshold are
                rejected (``inf`` is returned instead).

        Returns:
            float: the normalized chi-square, or ``np.inf`` when the model's
            conditions are violated, the rendered image has no finite
            pixel, or the value is at or below ``chi2_min_allowed``.
        '''
        # parameter sanity check
        condition_func = getattr(self.model, 'condition_func', None)
        if condition_func:
            # Conditions are written in terms of physical values
            # (e.g. r_eff >= 10), so they must be checked on the physical
            # scale rather than on the [0, 1] standardized values.
            physical_params = self.unstandardize_params(standardized_params)
            if not condition_func(physical_params):
                return np.inf

        # evaluate model
        model_img = self.eval_model(standardized_params)

        # sanity check
        if not np.isfinite(model_img).any():
            return np.inf

        residual = self.data - model_img
        if self.err is not None:
            residual = residual / self.err
        chi2 = np.nansum(residual**2)
        chi2 /= np.isfinite(self.data).sum()

        if chi2 <= chi2_min_allowed:
            return np.inf
        if print_val:
            print(f'\r {chi2:.4e} {iterinfo} ', end='', flush=True)
        return chi2

    def fit(self, method='iterative_NM', fit_to='data', **kwargs):
        '''
        Fit the model and store the results on the cutout data.

        Args:
            method (str): one of ``'iterative_NM'``, ``'dual_annealing'``,
                ``'BFGS'`` or ``'lbfgsb_polish'``.
            fit_to (str): name of the cutoutdata attribute holding the image
                to fit. The attribute ``fit_to + '_error'`` is used as the
                per-pixel error when it exists.
            **kwargs: for ``'iterative_NM'``, overrides passed on to
                ``iterative_NM``. For ``'dual_annealing'``, the keys
                ``progress``, ``progress_text`` and ``da_maxiter`` are
                consumed. Ignored otherwise.

        Returns:
            The cutoutdata object, updated with ``sersic_params_physical``,
            ``sersic_params``, ``sersic_modelimg``, ``sersic_residual``,
            ``residual``, ``residual_masked``, ``galaxy_size_sersic`` and
            one attribute per best-fit parameter.

        Raises:
            ValueError: if ``method`` is not recognized.
        '''
        self.data = getattr(self.cutoutdata, fit_to)
        self.err = getattr(self.cutoutdata, fit_to + '_error', None)

        # pre-compute the bounds
        self.bounds_physical = self.model.get_bounds()
        n_free = len(self.model.free_params)
        self.bounds = np.vstack([np.zeros(n_free), np.ones(n_free)]).T

        # A previous fit leaves its solution in model.x0; reuse it if present.
        if not hasattr(self.model, 'x0'):
            self.model.x0 = self.standardize_params(self.model.x0_physical)

        if method == 'iterative_NM':
            iNM_kwargs = dict(rtol_init=1e-3, rtol_iter=1e-3,
                              rtol_convergence=1e-6, xrtol=1, max_iter=10)
            iNM_kwargs.update(kwargs)
            result, success = iterative_NM(self.calc_chi2, (), self.model.x0,
                                           self.bounds, **iNM_kwargs)
        elif method == 'dual_annealing':
            # Iter-1 global escape. Pre-bracket with a short iNM, then run a
            # single scipy.dual_annealing pass. Output is routed through the
            # provided Rich `progress` task so it stays inside the existing
            # progress block instead of streaming raw prints.
            progress = kwargs.pop('progress', None)
            progress_text = kwargs.pop('progress_text',
                                       'dual_annealing global escape')
            da_maxiter = int(kwargs.pop('da_maxiter', 10))
            pre_iNM_kwargs = dict(rtol_init=1e-3, rtol_iter=1e-3,
                                  rtol_convergence=1e-6, xrtol=1, max_iter=5)
            result, _ = iterative_NM(self.calc_chi2, (), self.model.x0,
                                     self.bounds, **pre_iNM_kwargs)
            if progress is not None:
                da_task = progress.add_task(progress_text, total=da_maxiter)

                def _da_callback(x, f, context):
                    progress.update(da_task, advance=1, refresh=True)
                    return False  # never request early termination
            else:
                _da_callback = None
            result = dual_annealing(
                self.calc_chi2,
                bounds=self.bounds,
                x0=result.x,
                maxiter=da_maxiter,
                initial_temp=10.0,
                seed=0,
                callback=_da_callback,
                minimizer_kwargs=dict(method='L-BFGS-B', bounds=self.bounds,
                                      options=dict(eps=1e-4, maxfun=1000)),
            )
            if progress is not None:
                progress.update(da_task, completed=da_maxiter, refresh=True)
                progress.remove_task(da_task)
            success = True
        elif method == 'BFGS':
            result = minimize(self.calc_chi2, self.model.x0,
                              bounds=self.bounds,
                              method='L-BFGS-B',
                              options=dict(eps=1e-4, maxfun=1000))
            success = result.success
        elif method == 'lbfgsb_polish':
            # Tight L-BFGS-B polish meant to run after iNM/dual_annealing
            # has placed us inside the basin; tightens chi^2 to gradient ~ 0.
            result = minimize(self.calc_chi2, self.model.x0,
                              bounds=self.bounds,
                              method='L-BFGS-B',
                              options=dict(eps=1e-5, ftol=1e-12, gtol=1e-10,
                                           maxfun=2000))
            success = result.success
        else:
            raise ValueError('method not recognized')

        # Keep the optimizer's verdict available to callers.
        self.success = success

        # results
        bestfit_sersic_params_physical = dict(
            zip(self.model.free_params, self.unstandardize_params(result.x)))
        bestfit_img = self.eval_model(result.x)
        rawdata = self.cutoutdata._rawdata
        sersic_residual = rawdata - bestfit_img  # always take residual from raw data

        # update the total residual
        psf_modelimg = getattr(self.cutoutdata, 'psf_modelimg', 0)
        residual_img = rawdata - psf_modelimg - bestfit_img
        residual_masked = residual_img.copy()  # no sigma clipping

        # save all information in cutoutdata
        self.cutoutdata.sersic_params_physical = bestfit_sersic_params_physical
        self.cutoutdata.sersic_params = result.x
        self.cutoutdata.sersic_modelimg = bestfit_img
        self.cutoutdata.sersic_residual = sersic_residual
        self.cutoutdata.residual = residual_img
        self.cutoutdata.residual_masked = residual_masked

        # Dynamic galaxy-size attribute consumed by the PSF calibrator's
        # centre mask. Distinct from cd.galaxy_size (the prep-time
        # Gaussian σ guess) so we don't break anything that depends on
        # the static initial value.
        # The static value is only read when r_eff was not fit, so a
        # cutout without `galaxy_size` does not fail needlessly.
        if 'r_eff' in bestfit_sersic_params_physical:
            galaxy_size = bestfit_sersic_params_physical['r_eff']
        else:
            galaxy_size = self.cutoutdata.galaxy_size
        self.cutoutdata.galaxy_size_sersic = float(galaxy_size)

        self.model.x0 = result.x
        save_bestfit_params(self.cutoutdata, bestfit_sersic_params_physical)
        return self.cutoutdata
[docs]
class ModelScaleFitter(ModelFitter):
    '''
    Fit a single flux scale applied to a fixed set of base parameters.

    Only the entries of ``base_params`` that correspond to amplitude
    parameters are multiplied by the fitted scale; every other parameter
    keeps its base value.
    '''

    def __init__(self, model, cutoutdata, base_params=None,
                 **kwargs):
        '''
        Args:
            model: model to scale (see ``ModelFitter``).
            cutoutdata: the data to fit (see ``ModelFitter``).
            base_params: free-parameter values to be scaled, ordered as
                ``model.free_params``. They are passed to ``eval_model``
                and ``unstandardize_params``, i.e. they are treated as
                standardized values.
            **kwargs: forwarded to ``ModelFitter.__init__``.

        Raises:
            ValueError: if ``base_params`` is not given.
        '''
        if base_params is None:
            raise ValueError('ModelScaleFitter requires the base parameter to be initialized.')
        super().__init__(model, cutoutdata, **kwargs)
        self.base_params = base_params
        self.bounds_physical = self.model.get_bounds()

    def scale_params(self, flux_scale):
        ''' a helper function to scale the parameters based on the flux scale '''
        flux_scale = np.squeeze(flux_scale)
        scaled_params = self.base_params.copy()
        for idx, param_name in enumerate(self.model.free_params):
            if 'amplitude' in param_name:
                scaled_params[idx] *= flux_scale
        return scaled_params

    def calc_chi2(self, flux_scale,
                  iterinfo='', print_val=False, chi2_min_allowed=1e-10):
        '''
        Objective function: chi-square per finite data pixel for a flux scale.

        Args:
            flux_scale: multiplicative factor applied to the amplitudes.
            iterinfo (str): label appended to the progress print-out.
            print_val (bool): if True, print the current value in place.
            chi2_min_allowed (float): values at or below this threshold are
                rejected (``inf`` is returned instead).

        Returns:
            float: the normalized chi-square, or ``np.inf`` when the model's
            conditions are violated, the rendered image has no finite
            pixel, or the value is at or below ``chi2_min_allowed``.
        '''
        scaled_modelparams = self.scale_params(flux_scale)

        # parameter sanity check
        condition_func = getattr(self.model, 'condition_func', None)
        if condition_func:
            # Conditions are expressed in physical values, so convert the
            # scaled parameters to the physical scale before checking them.
            physical_params = self.unstandardize_params(scaled_modelparams)
            if not condition_func(physical_params):
                return np.inf

        # evaluate model
        model_img = self.eval_model(scaled_modelparams)

        # sanity check
        if not np.isfinite(model_img).any():
            return np.inf

        residual = self.data - model_img
        if self.err is not None:
            residual = residual / self.err
        chi2 = np.nansum(residual**2)
        chi2 /= np.isfinite(self.data).sum()

        if chi2 <= chi2_min_allowed:
            return np.inf
        if print_val:
            print(f'\r {chi2:.4e} {iterinfo} ', end='', flush=True)
        return chi2

    def fit(self, method='iterative_NM', fit_to='data', **kwargs):
        '''
        Fit the flux scale and store the results on the cutout data.

        Args:
            method (str): only ``'iterative_NM'`` is supported.
            fit_to (str): name of the cutoutdata attribute holding the image
                to fit. The attribute ``fit_to + '_error'`` is used as the
                per-pixel error when it exists.
            **kwargs: overrides passed on to ``iterative_NM``.

        Returns:
            The cutoutdata object, updated with the scaled best-fit
            parameters, model image and residuals.

        Raises:
            ValueError: if ``method`` is not recognized.
        '''
        self.data = getattr(self.cutoutdata, fit_to)
        self.err = getattr(self.cutoutdata, fit_to + '_error', None)

        # pre-compute the bounds
        self.bounds = [[0, 10]]
        # NOTE(review): model.x0 is the same attribute ModelFitter.fit
        # stores its full solution vector in -- confirm that a model
        # handed to this class never carries such an x0.
        if not hasattr(self.model, 'x0'):
            self.model.x0 = [1]

        # run fitting
        if method == 'iterative_NM':
            iNM_kwargs = dict(rtol_init=1e-3, rtol_iter=1e-3,
                              rtol_convergence=1e-6, xrtol=1, max_iter=10)
            iNM_kwargs.update(kwargs)
            result, success = iterative_NM(self.calc_chi2, (), self.model.x0,
                                           self.bounds, **iNM_kwargs)
        else:
            raise ValueError('method not recognized')

        # Keep the optimizer's verdict available to callers.
        self.success = success

        # parse results
        scaled_modelparams = self.scale_params(result.x)
        bestfit_sersic_params_physical = dict(
            zip(self.model.free_params,
                self.unstandardize_params(scaled_modelparams)))
        bestfit_img = self.eval_model(scaled_modelparams)
        rawdata = self.cutoutdata._rawdata
        sersic_residual = rawdata - bestfit_img  # always take residual from raw data

        # update the total residual
        psf_modelimg = getattr(self.cutoutdata, 'psf_modelimg', 0)
        residual_img = rawdata - psf_modelimg - bestfit_img
        residual_masked = residual_img.copy()  # no sigma clipping

        # save all information in cutoutdata
        self.cutoutdata.sersic_params_physical = bestfit_sersic_params_physical
        self.cutoutdata.sersic_params = scaled_modelparams
        self.cutoutdata.sersic_modelimg = bestfit_img
        self.cutoutdata.sersic_residual = sersic_residual
        self.cutoutdata.residual = residual_img
        self.cutoutdata.residual_masked = residual_masked

        # The static galaxy_size is only read when r_eff is not among the
        # fitted parameters, so its absence does not fail needlessly.
        if 'r_eff' in bestfit_sersic_params_physical:
            galaxy_size = bestfit_sersic_params_physical['r_eff']
        else:
            galaxy_size = self.cutoutdata.galaxy_size
        self.cutoutdata.galaxy_size_sersic = float(galaxy_size)

        self.model.x0 = result.x
        save_bestfit_params(self.cutoutdata, bestfit_sersic_params_physical)
        return self.cutoutdata
[docs]
def iterative_NM(func,args,x0,bounds,
                 rtol_init=1e-3,rtol_iter=1e-4,
                 rtol_convergence=1e-6,xrtol=1,max_iter=20,
                 maxfev_eachiter=100,
                 progress=None,
                 progress_text='Running iNM...',**kwargs):
    ''' Iterative Nelder-Mead minimization.
    The original implementation by Scipy tends to miss the global minimum.
    Rather than setting the tolerance to be small, the success rate tends to be higher
    when the tolerance is set to be larger and the minimization is run multiple times.

    Args:
        func (callable): objective, called as ``func(x, *args)`` for the
            starting value and as ``func(x, *args, label)`` by the
            optimizer, where ``label`` is an iteration-info string.
        args (tuple): extra positional arguments for ``func``.
        x0 (array-like): starting point.
        bounds: bounds passed to ``scipy.optimize.minimize``.
        rtol_init (float): ``fatol`` of the first run, relative to the
            objective value at ``x0``.
        rtol_iter (float): ``fatol`` of each restart, relative to the
            objective value at the restart point.
        rtol_convergence (float): relative tolerance used to decide that
            two consecutive runs returned the same objective value.
        xrtol (float): ``xatol`` of each run, relative to the largest
            absolute component of its starting point.
        max_iter (int): maximum number of restarts after the first run.
        maxfev_eachiter (int): ``maxfev`` of each run.
        progress: optional object with ``add_task``/``update``/
            ``remove_task`` methods (e.g. a Rich progress bar).
        progress_text (str): description of the progress task.
        **kwargs: accepted for call compatibility; unused.

    Returns:
        tuple: ``(result, convergence)`` -- the ``OptimizeResult`` of the
        last run and whether two consecutive runs agreed on a finite value.
    '''
    def run_once(start, fatol, label):
        xatol = xrtol * max(np.abs(start))
        return minimize(func, start, bounds=bounds,
                        method='Nelder-Mead',
                        args=(*args, label),
                        options=dict(maxfev=maxfev_eachiter,
                                     fatol=fatol, xatol=xatol))

    ## initial fit
    chi2_init = func(x0, *args)
    chi2_vals = [chi2_init]
    convergence = False
    # The starting value is already known, no need to evaluate it again.
    result = run_once(x0, rtol_init * chi2_init, '(iter=0)')
    chi2_vals.append(result.fun)

    # run minimization multiple times
    progress_task = None
    if progress is not None:
        progress_task = progress.add_task(progress_text, total=max_iter)
    try:
        for i in range(max_iter):
            # result.fun is the objective at result.x, the restart point.
            fatol = rtol_iter * result.fun
            result = run_once(result.x, fatol,
                              f'(iter={i+1}: fatol={fatol:.2e})')
            chi2_vals.append(result.fun)
            if progress is not None:
                progress.update(progress_task, advance=1, refresh=True)
            if np.allclose(chi2_vals[-2:], chi2_vals[-1], rtol=rtol_convergence):
                # Two equal non-finite values (e.g. inf) do not count as
                # convergence; keep restarting in that case.
                if np.isfinite(chi2_vals[-1]):
                    convergence = True
                    break
    finally:
        # Always release the progress task, even if the objective raised.
        if progress is not None:
            progress.remove_task(progress_task)
    return result, convergence
[docs]
def save_bestfit_params(cutoutdata,bestfit_sersic_params_physical,):
    ''' Copy every best-fit parameter onto ``cutoutdata`` as an attribute.

    Args:
        cutoutdata: object receiving the attributes.
        bestfit_sersic_params_physical (dict): parameter name -> value; each
            name becomes an attribute of ``cutoutdata``.
    '''
    for name in bestfit_sersic_params_physical:
        setattr(cutoutdata, name, bestfit_sersic_params_physical[name])