Grism Modeling

class Grism(im_shape, im_scale=0.031, icenter=5, jcenter=5, wavelength=4.2, wave_space=None, index_min=None, index_max=None, grism_filter='F444W', grism_module='A', grism_pupil='R', PSF=None)

JWST NIRCam grism spectroscopy forward modeling class.

Models the dispersion and convolution of 3D data cubes through a grism, including wavelength-dependent trace, LSF, and PSF effects.

Parameters:
  • im_shape (int) – Size of the image (assumes square)

  • im_scale (float, optional) – Pixel scale of the model image in arcsec (default: 0.031)

  • icenter (int, optional) – i-coordinate of galaxy center in pixels (default: 5)

  • jcenter (int, optional) – j-coordinate of galaxy center in pixels (default: 5)

  • wavelength (float, optional) – Central wavelength in microns (default: 4.2)

  • wave_space (numpy.ndarray, optional) – Wavelength array in microns

  • index_min (int, optional) – Minimum wavelength index

  • index_max (int, optional) – Maximum wavelength index

  • grism_filter (str, optional) – Grism filter name (default: ‘F444W’)

  • grism_module (str, optional) – JWST module ‘A’ or ‘B’ (default: ‘A’)

  • grism_pupil (str, optional) – Grism pupil ‘R’ or ‘C’ (default: ‘R’)

  • PSF (numpy.ndarray, optional) – Point spread function array

Variables:
  • im_shape (int) – Image size

  • im_scale (float) – Model pixel scale

  • detector_scale (float) – Detector pixel scale (0.0629” per pixel for the NIRCam long-wavelength channel)

  • factor (int) – Oversampling factor between model and detector
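The factor variable links the two pixel scales listed above. A minimal sketch, assuming factor is simply the detector-to-model pixel-scale ratio rounded to the nearest integer (the class may compute it differently):

```python
def oversampling_factor(detector_scale: float = 0.0629, im_scale: float = 0.031) -> int:
    """Integer ratio of detector to model pixel scale (rounding is an assumption)."""
    ratio = detector_scale / im_scale  # 0.0629 / 0.031 is about 2.03 for the defaults
    return int(round(ratio))


print(oversampling_factor())  # 2
```

With the default im_scale, each detector pixel therefore spans roughly 2×2 model pixels.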

__init__(im_shape, im_scale=0.031, icenter=5, jcenter=5, wavelength=4.2, wave_space=None, index_min=None, index_max=None, grism_filter='F444W', grism_module='A', grism_pupil='R', PSF=None)

__str__()

String representation of the Grism object.

Returns:

Formatted string with grism configuration details

Return type:

str

init_detector()

Initialize detector coordinate system.

Creates 1D and 2D arrays mapping model pixels to detector coordinates, accounting for oversampling and galaxy center position.

Notes

Sets attributes:
  • detector_xmin, detector_xmax – Detector coordinate bounds

  • detector_space_1d – 1D detector coordinate array

  • detector_space_2d – 2D detector coordinate grid
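A minimal sketch of the kind of mapping init_detector() describes, written in NumPy. It assumes (hypothetically) that model pixel icenter lands on a detector x-coordinate x_center_det and that one model pixel spans 1/factor detector pixels; both x_center_det and its default of 1024 are illustrative assumptions, not values taken from the class.

```python
import numpy as np


def init_detector(im_shape: int, icenter: int, factor: int, x_center_det: float = 1024.0):
    """Hypothetical model-pixel to detector-x mapping (x_center_det is assumed)."""
    # Offset of each model pixel from the galaxy centre, in detector pixels
    offsets = (np.arange(im_shape) - icenter) / factor
    detector_space_1d = x_center_det + offsets
    detector_xmin = detector_space_1d.min()
    detector_xmax = detector_space_1d.max()
    # Same x-coordinates for every row; x varies along axis 1
    detector_space_2d = np.tile(detector_space_1d, (im_shape, 1))
    return detector_xmin, detector_xmax, detector_space_1d, detector_space_2d


xmin, xmax, d1, d2 = init_detector(im_shape=11, icenter=5, factor=2)
print(xmin, xmax, d2.shape)  # 1021.5 1026.5 (11, 11)
```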

load_coefficients()

Load JWST NIRCam grism calibration coefficients.

Loads tracing and dispersion polynomial coefficients from the nircam_grism calibration files for the specified filter, module, and pupil configuration.

Notes

Sets attributes:
  • fit_opt – Trace polynomial coefficients

  • w_opt – Dispersion polynomial coefficients

  • WRANGE – Valid wavelength range for the filter

Calibration files are read from the nircam_grism/FS_grism_config directory.

get_trace()

Compute grism dispersion trace for the central galaxy pixel.

Assuming the galaxy is at the detector center, computes where the central pixel appears on the detector when emitting at each wavelength in wave_space. Uses polynomial coefficients from NIRCam grism calibration.

Returns:

  • dxs (jax.numpy.ndarray) – Uniformly spaced dispersion positions (detector x-coordinates)

  • disp_space (jax.numpy.ndarray) – Dispersion positions for each wavelength in wave_space

Notes

Sets attributes:
  • disp_space – Dispersion positions from the polynomial trace equation

  • dxs – Uniformly spaced array spanning the minimum to maximum dispersion

  • wavs – Wavelengths corresponding to each position in dxs

  • inverse_wave_disp – Interpolator mapping dispersion to wavelength
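The attributes above amount to sampling the trace, resampling it on a uniform grid and inverting it by interpolation. A NumPy sketch under stated assumptions: disp_fn stands in for the calibrated polynomial trace, toy_trace is a hypothetical monotonic example, and linear interpolation via numpy.interp is one possible choice of interpolator.

```python
import numpy as np


def get_trace(wave_space, disp_fn):
    """Uniform dispersion grid plus an interpolated inverse (wavelength from dx)."""
    wave_space = np.asarray(wave_space, dtype=float)
    # Detector x-offset of the central pixel at each wavelength
    disp_space = disp_fn(wave_space)
    # Uniformly spaced positions spanning the full trace
    dxs = np.linspace(disp_space.min(), disp_space.max(), wave_space.size)
    # np.interp requires increasing sample points, so sort by position
    order = np.argsort(disp_space)

    def inverse_wave_disp(dx):
        return np.interp(dx, disp_space[order], wave_space[order])

    wavs = inverse_wave_disp(dxs)
    return dxs, disp_space, wavs, inverse_wave_disp


def toy_trace(w):
    # Hypothetical monotonic trace (detector pixels) around 3.95 microns
    return 1000.0 * (w - 3.95) + 50.0 * (w - 3.95) ** 2


dxs, disp_space, wavs, inverse_wave_disp = get_trace(np.linspace(3.9, 4.4, 501), toy_trace)
print(inverse_wave_disp(0.0))  # approximately 3.95
```

Linear interpolation is only exact at the sampled wavelengths, so a finer wave_space gives a more accurate inverse.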

set_wave_array()

Compute effective wavelength for each model pixel.

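For each model pixel, determines the wavelength at which the central pixel's trace reaches the detector position where that pixel's emission at self.wavelength lands, i.e. the pixel's spatial offset along the dispersion direction expressed as a wavelength. This accounts for the wavelength shift each pixel experiences due to its position.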

Notes

Algorithm:
  1. Disperse each pixel at self.wavelength with zero velocity.

  2. Find the corresponding wavelength in the central pixel's reference frame.

  3. Store the result in self.wave_array.

Sets attribute:
  • wave_array – 2D array of effective wavelengths for each model pixel
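The three algorithm steps can be sketched in NumPy as follows. This assumes a hypothetical, position-independent linear dispersion of 1000 detector pixels per micron (linear_disp) purely for illustration; in the class the per-pixel and central traces come from the NIRCam calibration polynomials.

```python
import numpy as np


def set_wave_array(detector_space_2d, x_center_det, wavelength, pixel_disp, central_disp, wave_space):
    """Effective wavelength of each model pixel in the central pixel's frame."""
    # Step 1: detector position where each pixel's emission at `wavelength` lands
    #         (zero velocity, so the rest wavelength is used directly)
    landing = detector_space_2d + pixel_disp(detector_space_2d, wavelength)
    # Step 2: wavelength at which the central pixel's trace reaches that position
    trace = central_disp(wave_space)
    order = np.argsort(trace)
    wave_array = np.interp(landing - x_center_det, trace[order], wave_space[order])
    # Step 3: the caller would store this as self.wave_array
    return wave_array


def linear_disp(w):
    # Hypothetical toy dispersion: 1000 detector pixels per micron
    return 1000.0 * (w - 3.95)


x_center = 1024.0
xs = x_center + (np.arange(11) - 5) / 2.0  # factor = 2
det2d = np.tile(xs, (11, 1))
wave_array = set_wave_array(det2d, x_center, 4.2,
                            lambda x, w: linear_disp(w), linear_disp,
                            np.linspace(3.9, 4.5, 601))
print(wave_array[0, 5], wave_array[0, 10])  # approximately 4.2 and 4.2025
```

With this toy dispersion, a pixel 2.5 detector pixels from the centre is shifted by 2.5 / 1000 = 0.0025 μm.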

load_poly_factors(a01, a02, a03, a04, a05, a06, b01, b02, b03, b04, b05, b06, c01, c02, c03, d01)

Load polynomial dispersion coefficients.

Parameters:
  • a01-a06 (float) – Constant term polynomial coefficients

  • b01-b06 (float) – Linear (wavelength) term coefficients

  • c01-c03 (float) – Quadratic (wavelength^2) term coefficients

  • d01 (float) – Cubic (wavelength^3) term coefficient

Notes

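These coefficients define the NIRCam grism dispersion polynomial:

dx = a(x, y) + b(x, y)·λ + c(x, y)·λ² + d01·λ³

where a(x, y), b(x, y) and c(x, y) are built from a01–a06, b01–b06 and c01–c03 respectively and depend on the detector position (x, y), and λ is the wavelength (normalized by subtracting 3.95 μm in grism_dispersion()).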

load_poly_coefficients()

Pre-compute position-dependent polynomial coefficients.

Evaluates the polynomial coefficients at each detector position in the 2D grid. This precomputation speeds up dispersion calculations by avoiding repeated polynomial evaluations.

Notes

Assumes horizontal dispersion only (all pixels on the same detector row).

Sets attributes:
  • coef1 – Constant term coefficients (2D array)

  • coef2 – Linear wavelength term coefficients (2D array)

  • coef3 – Quadratic wavelength term coefficients (2D array)

  • coef4 – Cubic wavelength term coefficients (2D array)
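A NumPy sketch of the precomputation. The term ordering is an assumption: a01–a06 and b01–b06 are taken as coefficients of (1, x, y, x², xy, y²), c01–c03 as (1, x, y), and d01 as a constant. Whether x and y are raw or normalized detector coordinates is also not stated here, so poly_xy2 and that ordering are hypothetical.

```python
import numpy as np


def poly_xy2(p, x, y):
    """Second-order 2D polynomial; term order (1, x, y, x^2, xy, y^2) is assumed."""
    return p[0] + p[1] * x + p[2] * y + p[3] * x**2 + p[4] * x * y + p[5] * y**2


def load_poly_coefficients(a, b, c, d01, detector_space_2d, y_row):
    """Evaluate the position-dependent terms once on the detector grid."""
    x = np.asarray(detector_space_2d, dtype=float)
    y = np.full_like(x, float(y_row))   # horizontal dispersion: one detector row
    coef1 = poly_xy2(a, x, y)           # constant term, from a01..a06
    coef2 = poly_xy2(b, x, y)           # linear term, from b01..b06
    coef3 = c[0] + c[1] * x + c[2] * y  # quadratic term, from c01..c03 (assumed 1, x, y)
    coef4 = np.full_like(x, d01)        # cubic term: constant d01 broadcast to the grid
    return coef1, coef2, coef3, coef4


grid = np.array([[0.0, 1.0, 2.0]])
coef1, coef2, coef3, coef4 = load_poly_coefficients(
    a=[1, 2, 0, 0, 0, 0], b=[0, 0, 1, 0, 0, 0], c=[3, 0, 0], d01=0.5,
    detector_space_2d=grid, y_row=2.0)
print(coef1)  # [[1. 3. 5.]]
```

Evaluating once per grid point means each later call to the dispersion only needs a cubic in wavelength.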

grism_dispersion(wave)

Compute grism dispersion offset for given wavelength.

Calculates the x-axis offset (dx) in the grism image for a pixel at wavelength wave using the NIRCam grism dispersion polynomial.

Parameters:

wave (float or array_like) – Wavelength in microns

Returns:

dx – Dispersion offset in detector pixels

Return type:

jax.numpy.ndarray

Notes

Uses pre-computed position-dependent coefficients (coef1-coef4) from load_poly_coefficients(). Wavelength is normalized by subtracting 3.95 μm.
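Given the precomputed coefficient grids, the dispersion reduces to a cubic in the normalized wavelength. A minimal NumPy sketch (the class itself returns a jax.numpy array), using Horner's scheme as an implementation choice:

```python
import numpy as np


def grism_dispersion(wave, coef1, coef2, coef3, coef4, wave_ref=3.95):
    """dx = coef1 + coef2*w + coef3*w**2 + coef4*w**3 with w = wave - wave_ref."""
    w = np.asarray(wave, dtype=float) - wave_ref  # normalized wavelength (microns)
    # Horner's scheme: fewer multiplications than evaluating the powers separately
    return coef1 + w * (coef2 + w * (coef3 + w * coef4))


coefs = [np.full((2, 3), v) for v in (1.0, 2.0, 3.0, 4.0)]
print(grism_dispersion(3.95, *coefs))  # equals coef1 everywhere
```

Because the coefficients are 2D arrays, a scalar wavelength yields one offset per model pixel by broadcasting.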

set_detector_scale(scale)

Set detector pixel scale.

Parameters:

scale (float) – Detector pixel scale in arcsec/pixel

compute_lsf()

Compute Line Spread Function (LSF) for NIRCam grism.

Calculates the spectral resolution and LSF width at self.wavelength using an empirical polynomial fit to the NIRCam grism resolving power.

Returns:

R – Spectral resolving power (R = λ/Δλ)

Return type:

float

Notes

Sets attributes:
  • sigma_lsf – LSF width in wavelength units (microns)

  • sigma_v_lsf – LSF width in velocity units (km/s)

The resolving power R is modeled as a 4th-order polynomial in wavelength. Conversion to sigma assumes a Gaussian profile, FWHM = 2√(2 ln 2)·σ ≈ 2.36·σ, so that sigma_lsf = λ / (2.36 R) and sigma_v_lsf = c / (2.36 R).
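The conversion from R to the two sigma attributes is shown below. This is a sketch that takes R as an input rather than evaluating the class's empirical polynomial, uses the exact Gaussian factor 2√(2 ln 2) ≈ 2.3548 instead of the rounded 2.36, and treats R = 1600 as a placeholder value.

```python
import math

C_KMS = 299792.458                                     # speed of light, km/s
FWHM_PER_SIGMA = 2.0 * math.sqrt(2.0 * math.log(2.0))  # about 2.3548


def lsf_widths(wavelength, R, fwhm_per_sigma=FWHM_PER_SIGMA):
    """Gaussian LSF widths from a resolving power R = lambda / delta_lambda."""
    fwhm_wave = wavelength / R                    # delta_lambda in microns
    sigma_lsf = fwhm_wave / fwhm_per_sigma        # microns
    sigma_v_lsf = C_KMS * sigma_lsf / wavelength  # km/s, equals c / (fwhm_per_sigma * R)
    return sigma_lsf, sigma_v_lsf


sigma_lsf, sigma_v_lsf = lsf_widths(4.2, 1600.0)  # R = 1600 is a placeholder
print(sigma_lsf, sigma_v_lsf)
```

Note that sigma_v_lsf depends only on R, since the wavelength cancels in the velocity conversion.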

compute_lsf_new()

Compute an improved LSF using a double-Gaussian model.

Uses an empirical two-component Gaussian model derived from NIRCam flight data. The LSF is modeled as a weighted sum of two Gaussians with module-dependent parameters (Module A vs. B).

Returns:

lsf_kernel – Normalized 1D LSF convolution kernel

Return type:

jax.numpy.ndarray

Notes

The model parameters (fraction, FWHM) depend on the NIRCam module and wavelength. The kernel is constructed with a 6-sigma width and normalized to unit sum.

References

LSF model from Fengwu Sun based on updated NIRCam grism calibration data.
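A NumPy sketch of a normalized double-Gaussian kernel. The parameter values are placeholders, not the module-dependent values of the model cited above. The "6-sigma width" is interpreted here as ±3σ of the broader component, which is an assumption, and frac1 is treated as a flux (area) fraction.

```python
import numpy as np


def double_gaussian_lsf(frac1, fwhm1, fwhm2, step, n_sigma=3.0):
    """Normalized two-component Gaussian kernel sampled every `step`.

    fwhm1, fwhm2 and step share the same units (e.g. pixels). The kernel
    extends n_sigma times the broader sigma on each side.
    """
    to_sigma = 1.0 / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    s1, s2 = fwhm1 * to_sigma, fwhm2 * to_sigma
    half = int(np.ceil(n_sigma * max(s1, s2) / step))
    x = np.arange(-half, half + 1) * step
    # Dividing by sigma makes frac1 an area fraction rather than a peak ratio
    g1 = np.exp(-0.5 * (x / s1) ** 2) / s1
    g2 = np.exp(-0.5 * (x / s2) ** 2) / s2
    kernel = frac1 * g1 + (1.0 - frac1) * g2
    return kernel / kernel.sum()  # unit sum


kernel = double_gaussian_lsf(frac1=0.7, fwhm1=1.0, fwhm2=3.0, step=0.5)  # placeholder values
```

The kernel has an odd number of samples so that it stays centred when used for convolution.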

compute_PSF(PSF)

Prepare Point Spread Function for grism modeling.

Oversamples the input PSF to match model resolution, crops to 11x11 pixels (9x9 for factor>1 after oversampling), normalizes, and reshapes for 3D convolution with spectral dimension.

Parameters:

PSF (numpy.ndarray or jax.numpy.ndarray) – 2D Point Spread Function at detector resolution

Notes

Sets attributes:
  • oversampled_PSF – PSF oversampled by self.factor and normalized

  • PSF – 3D array with shape (spatial_y, spatial_x, 1) for broadcasting

If factor=1, uses input PSF directly. Otherwise oversamples using bilinear interpolation and crops to central 11x11 pixels.
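The oversample, crop, normalize and reshape steps can be sketched in NumPy. Bilinear interpolation on a regular grid is implemented here as linear interpolation along each axis with numpy.interp. The pixel-centre alignment convention and the crop centring for even-sized arrays are assumptions, and the class may use a different interpolation routine.

```python
import numpy as np


def bilinear_oversample(img, factor):
    """Resample a 2D array onto a grid `factor` times finer (separable linear)."""
    ny, nx = img.shape
    # Fine-pixel centres expressed in coarse-pixel coordinates
    yf = (np.arange(ny * factor) + 0.5) / factor - 0.5
    xf = (np.arange(nx * factor) + 0.5) / factor - 0.5
    rows = np.array([np.interp(xf, np.arange(nx), r) for r in img])      # (ny, nx*f)
    return np.array([np.interp(yf, np.arange(ny), c) for c in rows.T]).T  # (ny*f, nx*f)


def compute_psf(psf, factor, crop=11):
    psf = np.asarray(psf, dtype=float)
    if factor == 1:
        kernel = psf  # use the input PSF directly
    else:
        over = bilinear_oversample(psf, factor)
        cy, cx = over.shape[0] // 2, over.shape[1] // 2
        h = crop // 2
        kernel = over[cy - h:cy + h + 1, cx - h:cx + h + 1]  # central crop x crop pixels
    oversampled_PSF = kernel / kernel.sum()  # unit sum
    return oversampled_PSF, oversampled_PSF[:, :, np.newaxis]  # (y, x, 1)


yy, xx = np.mgrid[-4:5, -4:5]
psf_in = np.exp(-(xx**2 + yy**2) / 4.0)  # toy 9x9 Gaussian PSF
oversampled_PSF, PSF = compute_psf(psf_in, factor=2)
print(PSF.shape)  # (11, 11, 1)
```

The trailing length-1 axis lets the PSF broadcast against a cube whose last axis is wavelength.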

disperse(F, V, D)

Disperse a 3D data cube (flux, velocity, dispersion) through the grism.

Forward models the grism spectroscopy by:
  1. Shifting wavelengths based on the velocity field

  2. Broadening by the velocity dispersion and the LSF

  3. Convolving with the spatial PSF

  4. Collapsing to a 2D grism spectrum

Parameters:
  • F (jax.numpy.ndarray) – 2D flux map (spatial y, spatial x)

  • V (jax.numpy.ndarray) – 2D velocity field in km/s (spatial y, spatial x)

  • D (jax.numpy.ndarray) – 2D velocity dispersion field in km/s (spatial y, spatial x)

Returns:

2D dispersed grism spectrum (spatial y, wavelength)

Return type:

jax.numpy.ndarray

Notes

Uses a Gaussian line profile for the spectral broadening and FFT convolution for the spatial PSF. Velocity is converted to a wavelength shift via the Doppler formula.
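A simplified NumPy sketch of steps 1, 2 and 4 (the PSF convolution of step 3 is omitted). It rests on several assumptions:
  • The non-relativistic Doppler factor (1 + v/c) is applied to wave_array.
  • The intrinsic dispersion D and sigma_v_lsf are added in quadrature into a single Gaussian.
  • wave_space is a wavelength grid in the central pixel's frame, so summing over x yields the (spatial y, wavelength) spectrum.

The class may instead convolve with the compute_lsf_new() kernel.

```python
import numpy as np

C_KMS = 299792.458  # speed of light, km/s


def disperse(F, V, D, wave_array, wave_space, sigma_v_lsf):
    """Collapse flux/velocity/dispersion maps to a 2D (y, wavelength) spectrum."""
    # 1. Doppler shift (non-relativistic) of each pixel's effective wavelength
    mu = wave_array * (1.0 + V / C_KMS)                 # (ny, nx)
    # 2. Broadening: intrinsic dispersion and LSF added in quadrature
    sigma_v = np.sqrt(D ** 2 + sigma_v_lsf ** 2)        # km/s
    sigma_wave = mu * sigma_v / C_KMS                   # microns
    # Gaussian line profile (per unit wavelength) on wave_space: (ny, nx, nw)
    w = wave_space[None, None, :]
    profile = np.exp(-0.5 * ((w - mu[..., None]) / sigma_wave[..., None]) ** 2)
    profile /= sigma_wave[..., None] * np.sqrt(2.0 * np.pi)
    cube = F[..., None] * profile
    # 3. Spatial PSF convolution omitted in this sketch
    # 4. Collapse along x: pixels on the same detector row overlap
    return cube.sum(axis=1)                             # (ny, nw)


ny, nx = 3, 4
F = np.ones((ny, nx))
V = np.zeros((ny, nx))
D = np.full((ny, nx), 50.0)
# Hypothetical wave_array: 0.0005 micron shift per model pixel along x
wave_array = np.tile(4.2 + (np.arange(nx) - 2) / 2.0 / 1000.0, (ny, 1))
wave_space = np.linspace(4.15, 4.25, 2001)
image = disperse(F, V, D, wave_array, wave_space, sigma_v_lsf=80.0)
print(image.shape)  # (3, 2001)
```

Because each profile is normalized per unit wavelength, integrating a row over wave_space recovers the total flux in that row.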