Source code for sphot.parallel

"""Parallel execution of run_scalefit across filters.

Architecture
------------
- One worker process per filter (multiprocessing.spawn).
- Each worker reads the galaxy from a parent-saved temp h5, calls
  `sphot.core.run_scalefit` on its assigned filter, and returns the
  modified CutoutData. The parent merges results back into the galaxy.
- Per-filter progress (Rich) is rendered in the parent via:
    * one PRE-ALLOCATED `Progress` row per filter (the "pre-spaced"
      layout the user asked for — rows do not jump as filters start),
    * dynamically-added child rows for nested tasks (NNLS refit,
      iPSF ladder, calibrate_psf_step, ...).
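- Workers communicate progress events via a plain `multiprocessing`
  Queue created from the spawn context and handed to each worker through
  the Pool initializer (not a `multiprocessing.Manager` queue). A listener
  thread in the parent translates events into real
  `Progress.update / add_task / remove_task` calls.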
- Workers redirect stdout/stderr to per-filter log files so the parent's
  Live display stays clean.

Public API
----------
`parallel_scalefit(galaxy, base_params, filters, blur_psf, ...)`
"""
from __future__ import annotations

import logging
import multiprocessing as mp
import os
import sys
import tempfile
import threading
import time
import warnings
from typing import Sequence

from rich.progress import (BarColumn, Progress, TextColumn,
                           TimeElapsedColumn, TimeRemainingColumn)
from rich.live import Live
from rich.console import Group
from rich.text import Text


# ----- pickleable Rich proxies (workers use these) --------------------

class QueueProgressProxy:
    """Queue-backed stand-in for ``rich.progress.Progress``.

    Exposes only ``add_task``, ``update`` and ``remove_task``. Each call is
    turned into a tuple ``(command, filter_label, ...)`` and put on
    ``queue``; nothing is rendered here. Any exception raised by
    ``queue.put`` is swallowed, so progress reporting never interrupts
    the caller.
    """

    def __init__(self, queue, filter_label):
        self.queue = queue
        self.filter_label = filter_label
        # Last task id handed out by add_task; ids start at 1.
        self._counter = 0

    def _post(self, command, *payload):
        """Put ``(command, filter_label, *payload)`` on the queue, ignoring errors."""
        try:
            self.queue.put((command, self.filter_label, *payload))
        except Exception:
            pass

    def add_task(self, description, total=None, **kw):
        """Allocate the next task id, announce the task, and return the id.

        Extra keyword arguments are accepted and ignored. The id is
        returned even when the queue put fails.
        """
        self._counter += 1
        new_id = self._counter
        self._post('add_task', new_id, description, total)
        return new_id

    def update(self, task_id, advance=None, completed=None, total=None,
               refresh=None, description=None, **kw):
        """Forward an update for ``task_id``.

        ``refresh`` and any extra keyword arguments are accepted but not
        forwarded.
        """
        self._post('update', task_id, advance, completed, total, description)

    def remove_task(self, task_id):
        """Forward a removal request for ``task_id``."""
        self._post('remove_task', task_id)
# ----- worker entry point --------------------------------------------

# Module-global queue handle. Set by _init_worker via Pool's
# initializer (which IS allowed to pass an mp.Queue at spawn time).
# We then read it from worker tasks instead of including it in
# pool.imap args (which goes through pickle and would raise
# "Queue objects should only be shared between processes through
# inheritance").
_WORKER_QUEUE = None


def _init_worker(queue):
    """Pool initializer: store the shared progress queue in this process."""
    global _WORKER_QUEUE
    _WORKER_QUEUE = queue


def _notify(queue, *event):
    """Best-effort ``queue.put(event)``.

    Failures are ignored on purpose, including ``queue`` being ``None``
    because the initializer never ran. A lost progress event must not
    fail the fit.
    """
    try:
        queue.put(event)
    except Exception:
        pass


def _redirect_output_to_log(log_file):
    """Point stdout, stderr and the 'sphot' logger at ``log_file``.

    Best-effort: if anything here fails the worker carries on with
    whatever redirection was already applied. There is nowhere safe to
    report the problem, because writing to the inherited terminal would
    disturb the parent's Live display.
    """
    try:
        # Explicit UTF-8 with errors='replace': this handle becomes
        # sys.stdout/sys.stderr, so with a locale-default encoding a single
        # non-encodable character in a print() would raise
        # UnicodeEncodeError inside the fit and fail the whole filter.
        f = open(log_file, 'w', buffering=1,
                 encoding='utf-8', errors='replace')
        sys.stdout = f
        sys.stderr = f
        # The parent set SPHOT_QUIET_IMPORT=1 around the Pool, so
        # bootstrap imports in this worker (sphot.config.load_config
        # etc.) installed a NullHandler and were silenced — that
        # prevents log lines from leaking to the parent's terminal and
        # clobbering its Live region before this redirect happens.
        # Now that stdout/stderr are pointed at our per-filter log
        # file, restore real logging into that file.
        sphot_logger = logging.getLogger('sphot')
        for h in list(sphot_logger.handlers):
            sphot_logger.removeHandler(h)
        sphot_logger.setLevel(logging.INFO)
        sphot_logger.propagate = True
        sh = logging.StreamHandler(f)
        sh.setFormatter(logging.Formatter(
            "[sphot %(levelname)s] (%(asctime)s): %(message)s "
            "(%(module)s.%(funcName)s)",
            "%m/%d/%y %H:%M:%S"))
        sphot_logger.addHandler(sh)
    except Exception:
        pass


def _scalefit_worker(args):
    """Spawn-mode worker: run scalefit on one filter.

    ``args`` is the 9-tuple ``(galaxy_path, base_params, filtername,
    blur_psf_filt, allow_refit, fit_complex_model, n_mainloop_iter,
    log_file, working_dir)``.

    Reads the galaxy from ``galaxy_path``, runs ``run_scalefit`` on
    ``filtername`` with a queue-backed progress proxy, and returns
    ``(filtername, galaxy.images[filtername])``. If reading the galaxy or
    the fit raises, the traceback is printed, a ``'filter_failed'`` event
    is queued, and ``(filtername, None)`` is returned. Errors importing
    matplotlib or sphot are not caught here and propagate to the caller.
    """
    import traceback

    (galaxy_path, base_params, filtername, blur_psf_filt,
     allow_refit, fit_complex_model, n_mainloop_iter,
     log_file, working_dir) = args
    queue = _WORKER_QUEUE

    if log_file:
        _redirect_output_to_log(log_file)

    if working_dir:
        try:
            os.chdir(working_dir)
        except OSError:
            # Best-effort: keep running from the inherited directory.
            pass

    warnings.filterwarnings('ignore')
    import matplotlib
    matplotlib.use('Agg')

    from sphot.data import read_sphot_h5
    from sphot.core import run_scalefit

    try:
        galaxy = read_sphot_h5(galaxy_path)
    except Exception as e:
        traceback.print_exc()
        _notify(queue, 'filter_failed', filtername, f'read_sphot_h5: {e}')
        return filtername, None

    proxy = QueueProgressProxy(queue, filtername)
    try:
        run_scalefit(galaxy, filtername, base_params,
                     allow_refit=allow_refit,
                     fit_complex_model=fit_complex_model,
                     N_mainloop_iter=n_mainloop_iter,
                     blur_psf=blur_psf_filt,
                     progress=proxy)
        _notify(queue, 'filter_done', filtername)
        return filtername, galaxy.images[filtername]
    except Exception as e:
        traceback.print_exc()
        _notify(queue, 'filter_failed', filtername, str(e))
        return filtername, None


# ----- public API ----------------------------------------------------
def parallel_scalefit(galaxy, base_params, filters: Sequence[str], blur_psf, *,
                      allow_refit: bool = False,
                      fit_complex_model: bool = False,
                      N_mainloop_iter: int = 5,
                      n_workers: int | None = None,
                      working_dir: str | None = None,
                      log_dir: str | None = None,
                      console=None):
    """Run run_scalefit on each filter in `filters` in parallel.

    The galaxy is saved once to a temporary ``.h5`` in `working_dir`; each
    worker reads it, fits one filter, and returns that filter's result. A
    Rich Live display in the parent shows one fixed two-line block per
    filter, driven by progress events the workers put on a shared queue.

    Parameters
    ----------
    galaxy : object providing ``save(path)``; results are written back onto
        it with ``setattr(galaxy, filter, result)``.
    base_params : passed through unchanged to each worker.
    filters : filter names to process. An empty sequence returns ``[]``
        immediately without creating files, processes, or a display.
    blur_psf : mapping indexed by filter name; ``blur_psf[f]`` is passed to
        the worker handling ``f``.
    allow_refit, fit_complex_model, N_mainloop_iter : forwarded to the
        worker's ``run_scalefit`` call.
    n_workers : pool size. Defaults to
        ``min(len(filters), os.cpu_count() or 1, 6)``.
    working_dir : directory the workers chdir into and where the temp file
        is created (should be where sphot_config.toml lives). Defaults to
        the current directory.
    log_dir : directory receiving ``parallel_<filter>.log`` for each
        worker. Defaults to `working_dir`, else the current directory.
    console : Rich console for the display. When None, the first
        ``console`` attribute found on a handler of the 'sphot' logger is
        used, if any.

    Returns
    -------
    list of ``(filter, result)`` tuples in completion order; ``result`` is
    None for a filter whose worker reported failure.
    """
    # Nothing to do. Without this guard the default pool size would be 0
    # and Pool(processes=0) raises ValueError.
    if len(filters) == 0:
        return []

    if n_workers is None:
        n_workers = min(len(filters), os.cpu_count() or 1, 6)
    if log_dir is None:
        log_dir = working_dir or os.getcwd()
    os.makedirs(log_dir, exist_ok=True)
    if working_dir is None:
        working_dir = os.getcwd()

    # Share a single Console between the sphot RichHandler and the Live
    # display so log records emitted during the run render above the live
    # region instead of clobbering it. Without this, RichHandler writes to
    # its own Console while Live owns a different one — both target the
    # same stdout but neither knows the other is there, so a log line
    # printed mid-render breaks the bars (this was the "queued" rows being
    # truncated).
    if console is None:
        for h in logging.getLogger('sphot').handlers:
            h_console = getattr(h, 'console', None)
            if h_console is not None:
                console = h_console
                break

    # One renderable block per filter: headline + (child OR blank line).
    # The blank-line placeholder keeps each block at exactly 2 lines so that
    # children appearing/disappearing don't reflow the layout (which was
    # making a re-appearing child clobber the next filter's headline below).
    def _make_progress():
        return Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            transient=False,
            console=console,
        )

    class _FilterBlock:
        def __init__(self, filter_name):
            self.filter_name = filter_name
            self.headline_progress = _make_progress()
            self.headline_id = self.headline_progress.add_task(
                f'[bold cyan]{filter_name}[/]: scalefit (queued)',
                total=None, start=False)
            self.child_progress = _make_progress()
            self.child_real_id = None   # real task id when a child is active

        def __rich_console__(self, console_, options):
            yield self.headline_progress
            if self.child_real_id is not None:
                yield self.child_progress
            else:
                yield Text("")  # placeholder so block height stays constant

    blocks = {f: _FilterBlock(f) for f in filters}
    live = Live(Group(*[blocks[f] for f in filters]),
                console=console, refresh_per_second=10, transient=False)

    # (filter, worker child_id) -> ('headline'|'child', real task id)
    task_map = {}
    stop_listener = threading.Event()

    def _handle_event(msg):
        """Apply one worker progress event to the matching filter block."""
        cmd = msg[0]
        if cmd == 'add_task':
            _, filt, child_id, desc, total = msg
            blk = blocks[filt]
            if child_id == 1:
                # First add_task per worker = the main loop. Promote
                # it onto the pre-allocated headline so the headline
                # gets a real bar/total and the layout stays stable.
                blk.headline_progress.update(
                    blk.headline_id,
                    description=f'[bold cyan]{filt}[/]: {desc}',
                    total=total)
                blk.headline_progress.start_task(blk.headline_id)
                task_map[(filt, child_id)] = ('headline', blk.headline_id)
            else:
                # If a previous child is still in the slot (worker
                # didn't remove_task before adding a new one), clear
                # it first so the slot only ever holds one task.
                if blk.child_real_id is not None:
                    try:
                        blk.child_progress.remove_task(blk.child_real_id)
                    except Exception:
                        pass
                    blk.child_real_id = None
                real = blk.child_progress.add_task(f'  {desc}', total=total)
                blk.child_real_id = real
                task_map[(filt, child_id)] = ('child', real)
            # Force an immediate render so very short-lived child
            # tasks (e.g. Sersic iNM, which can finish in <250ms)
            # are visible at least once instead of falling between
            # two of the Live's periodic auto-refreshes.
            try:
                live.refresh()
            except Exception:
                pass
        elif cmd == 'update':
            _, filt, child_id, advance, completed, total, descr = msg
            blk = blocks[filt]
            target = task_map.get((filt, child_id))
            if target is None:
                return
            kind, real = target
            p = (blk.headline_progress if kind == 'headline'
                 else blk.child_progress)
            kw = {}
            if advance is not None:
                kw['advance'] = advance
            if completed is not None:
                kw['completed'] = completed
            if total is not None:
                kw['total'] = total
            if descr is not None:
                if kind == 'headline':
                    kw['description'] = f'[bold cyan]{filt}[/]: {descr}'
                else:
                    kw['description'] = f'  {descr}'
            if kw:
                try:
                    p.update(real, **kw)
                except Exception:
                    pass
        elif cmd == 'remove_task':
            _, filt, child_id = msg
            blk = blocks[filt]
            target = task_map.pop((filt, child_id), None)
            if target is None:
                return
            kind, real = target
            if kind == 'headline':
                # Never remove the headline — it stays for the whole
                # run and gets a final 'done'/'FAILED' state below.
                return
            # Don't remove the bar from the slot immediately. iNM
            # (and other short-lived child tasks) finish in <1s,
            # which is below the user's perception threshold; if we
            # remove instantly, the bar would never visibly settle
            # at "complete". Instead mark the task 100% done and
            # leave it in the slot. The next add_task overwrites
            # this slot (existing logic) when the next nested step
            # begins.
            try:
                task = next((t for t in blk.child_progress.tasks
                             if t.id == real), None)
                if task is not None and task.total is not None:
                    blk.child_progress.update(real, completed=task.total)
            except Exception:
                pass
        elif cmd == 'filter_done':
            _, filt = msg
            blk = blocks[filt]
            h = blk.headline_id
            total = next((t.total for t in blk.headline_progress.tasks
                          if t.id == h), None)
            kw = {'description': f'[bold green]{filt}[/]: done'}
            if total is not None:
                kw['completed'] = total
            blk.headline_progress.update(h, **kw)
        elif cmd == 'filter_failed':
            _, filt, err = msg
            short_err = str(err)[:60]
            blk = blocks[filt]
            blk.headline_progress.update(
                blk.headline_id,
                description=f'[bold red]{filt}[/]: FAILED: {short_err}')

    def listener(event_queue):
        """Poll the queue until asked to stop, applying each event."""
        while not stop_listener.is_set():
            try:
                msg = event_queue.get(timeout=0.2)
            except Exception:
                # Normally just the 0.2 s timeout; loop to re-check stop.
                continue
            try:
                _handle_event(msg)
            except Exception:
                # A malformed event must not kill the display thread, but
                # leave a trace for anyone running with DEBUG logging.
                logging.getLogger(__name__).debug(
                    'ignored bad progress event %r', msg, exc_info=True)

    prev_quiet = os.environ.get('SPHOT_QUIET_IMPORT')
    tmp_path = None
    listener_thread = None
    results = []
    # Everything with a side effect that needs undoing (temp file, env
    # var, listener thread) happens inside this try so the finally block
    # cleans up even when setup itself fails (e.g. galaxy.save or the
    # blur_psf[f] lookup raises).
    try:
        # Save the galaxy once for workers to read.
        tmp = tempfile.NamedTemporaryFile(suffix='.h5', delete=False,
                                          dir=working_dir)
        tmp_path = tmp.name
        tmp.close()
        galaxy.save(tmp_path)

        args_list = [
            (tmp_path, base_params, f, blur_psf[f],
             allow_refit, fit_complex_model, N_mainloop_iter,
             os.path.join(log_dir, f'parallel_{f}.log'),
             working_dir)
            for f in filters
        ]

        # Suppress sphot logger output during spawn-bootstrap imports in
        # the Pool workers so log lines emitted at import time don't leak
        # to the parent's terminal and clobber the Live region. Spawn
        # children inherit env vars; sphot/logging.py checks this flag at
        # module load time.
        os.environ['SPHOT_QUIET_IMPORT'] = '1'

        ctx = mp.get_context('spawn')
        # `ctx.Queue()` instead of `ctx.Manager().Queue()`: avoids spawning
        # a separate SyncManager subprocess (which has been flaky on macOS
        # spawn-mode — BrokenPipeError during the manager's address-send
        # handshake). The queue reaches the workers through the Pool
        # initializer below, not through the pickled task arguments.
        queue = ctx.Queue()

        thread = threading.Thread(target=listener, args=(queue,),
                                  daemon=True)
        thread.start()
        listener_thread = thread   # only joined if it actually started

        with live:
            # `initializer` + `initargs` is the standard way to share an
            # mp.Queue with workers in spawn mode (the queue is hooked
            # up at process bootstrap, not via pickle of imap args).
            with ctx.Pool(processes=n_workers, maxtasksperchild=1,
                          initializer=_init_worker,
                          initargs=(queue,)) as pool:
                for r in pool.imap_unordered(_scalefit_worker, args_list):
                    results.append(r)
            # let listener drain trailing messages
            time.sleep(0.5)
    finally:
        if prev_quiet is None:
            os.environ.pop('SPHOT_QUIET_IMPORT', None)
        else:
            os.environ['SPHOT_QUIET_IMPORT'] = prev_quiet
        stop_listener.set()
        if listener_thread is not None:
            listener_thread.join(timeout=2.0)
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    # merge results back into the parent galaxy.
    # NOTE: `galaxy.images` is a @property that builds a fresh dict each
    # access, so `galaxy.images[f] = cd` mutates a throwaway view and
    # the real CutoutData (stored as galaxy.<filtername>) is unchanged.
    # Use setattr so the modification persists.
    for f, cd in results:
        if cd is not None:
            setattr(galaxy, f, cd)
    return results