Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/specify_cli/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ class StepBase(ABC):

Every step type — built-in or extension-provided — implements this
interface and registers in ``STEP_REGISTRY``.

Thread-safety: ``STEP_REGISTRY`` holds a single shared instance per type, so
a concurrent ``fan-out`` (``max_concurrency > 1``) can invoke ``execute`` on
the same instance from several threads at once. Implementations must be
stateless / thread-safe — derive all per-run state from the ``config`` and
``context`` arguments and never mutate ``self`` in ``execute``. The built-in
steps follow this rule.
"""

#: Matches the ``type:`` value in workflow YAML.
Expand Down
298 changes: 249 additions & 49 deletions src/specify_cli/workflows/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@

from __future__ import annotations

import dataclasses
import json
import os
import re
import tempfile
import threading
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -378,6 +382,15 @@ def __init__(
self.current_step_index = 0
self.current_step_id: str | None = None
self.step_results: dict[str, dict[str, Any]] = {}
# Guards step_results mutation and save() so a concurrent fan-out cannot
# mutate the dict while save() is serializing it (which would raise
# "dictionary changed size during iteration").
self._lock = threading.Lock()
Comment on lines +385 to +388
# Serializes append_log's list append + log.jsonl write so concurrent
# fan-out workers cannot interleave or corrupt log lines. Kept separate
# from _lock so frequent logging never contends with state saves; since
# append_log is never called while _lock is held, the two never nest.
self._log_lock = threading.Lock()
self.inputs: dict[str, Any] = {}
self.created_at = datetime.now(timezone.utc).isoformat()
self.updated_at = self.created_at
Expand All @@ -387,28 +400,72 @@ def __init__(
def runs_dir(self) -> Path:
return self.project_root / ".specify" / "workflows" / "runs" / self.run_id

def record_step_result(self, step_id: str, data: dict[str, Any]) -> None:
"""Record one step's result under the run lock.

Routing the mutation through the lock keeps it from racing a concurrent
``save()`` that is iterating ``step_results`` (e.g. during a concurrent
fan-out). For a sequential run this is an uncontended lock.
"""
with self._lock:
self.step_results[step_id] = data

def set_step_output(self, step_id: str, output: Any) -> None:
"""Replace an already-recorded step's ``output`` under the run lock.

Fan-out updates its parent step's output after the items have run;
routing that nested mutation through the lock keeps it from racing a
``save()`` serializing ``step_results`` — the same invariant
``record_step_result`` provides for the top-level assignment.
"""
with self._lock:
if step_id in self.step_results:
self.step_results[step_id]["output"] = output

def save(self) -> None:
"""Persist current state to disk."""
self.updated_at = datetime.now(timezone.utc).isoformat()
"""Persist current state to disk.

Held under the run lock and written atomically (temp file + ``os.replace``)
so a concurrent fan-out can neither mutate ``step_results`` mid-serialization
nor leave a reader observing a half-written file. Racing writers only
contend to be last; they never corrupt.
"""
runs_dir = self.runs_dir
runs_dir.mkdir(parents=True, exist_ok=True)

state_data = {
"run_id": self.run_id,
"workflow_id": self.workflow_id,
"status": self.status.value,
"current_step_index": self.current_step_index,
"current_step_id": self.current_step_id,
"step_results": self.step_results,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
with open(runs_dir / "state.json", "w", encoding="utf-8") as f:
json.dump(state_data, f, indent=2)

inputs_data = {"inputs": self.inputs}
with open(runs_dir / "inputs.json", "w", encoding="utf-8") as f:
json.dump(inputs_data, f, indent=2)
with self._lock:
# Stamp updated_at inside the lock so the timestamp matches the
# snapshot this thread serializes (concurrent savers don't race it).
self.updated_at = datetime.now(timezone.utc).isoformat()
state_data = {
"run_id": self.run_id,
"workflow_id": self.workflow_id,
"status": self.status.value,
"current_step_index": self.current_step_index,
"current_step_id": self.current_step_id,
"step_results": self.step_results,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
self._atomic_write_json(runs_dir / "state.json", state_data)
self._atomic_write_json(runs_dir / "inputs.json", {"inputs": self.inputs})

@staticmethod
def _atomic_write_json(path: Path, data: dict[str, Any]) -> None:
"""Write *data* as indented JSON to *path* atomically (temp + ``os.replace``)."""
fd, tmp = tempfile.mkstemp(
dir=str(path.parent), prefix=f".{path.name}.", suffix=".tmp"
)
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
os.replace(tmp, path)
except BaseException:
try:
os.unlink(tmp)
except OSError:
pass
raise

@classmethod
def load(cls, run_id: str, project_root: Path) -> RunState:
Expand Down Expand Up @@ -456,14 +513,18 @@ def load(cls, run_id: str, project_root: Path) -> RunState:
return state

def append_log(self, entry: dict[str, Any]) -> None:
"""Append a log entry to the run log."""
entry["timestamp"] = datetime.now(timezone.utc).isoformat()
self.log_entries.append(entry)
"""Append a log entry to the run log.

Held under ``_log_lock`` so concurrent fan-out workers serialize their
list append and ``log.jsonl`` write rather than interleaving lines.
"""
entry["timestamp"] = datetime.now(timezone.utc).isoformat()
runs_dir = self.runs_dir
runs_dir.mkdir(parents=True, exist_ok=True)
with open(runs_dir / "log.jsonl", "a", encoding="utf-8") as f:
f.write(json.dumps(entry) + "\n")
with self._log_lock:
self.log_entries.append(entry)
with open(runs_dir / "log.jsonl", "a", encoding="utf-8") as f:
f.write(json.dumps(entry) + "\n")


# -- Workflow Engine ------------------------------------------------------
Expand Down Expand Up @@ -739,7 +800,7 @@ def _execute_steps(
"status": result.status.value,
}
context.steps[step_id] = step_data
state.step_results[step_id] = step_data
state.record_step_result(step_id, step_data)

state.append_log(
{
Expand Down Expand Up @@ -867,39 +928,29 @@ def _execute_steps(
return
if orig and ns_copy["id"] in context.steps:
context.steps[orig] = context.steps[ns_copy["id"]]
state.step_results[orig] = context.steps[ns_copy["id"]]

# Fan-out: execute nested step template per item with unique IDs
state.record_step_result(
orig, context.steps[ns_copy["id"]]
)

# Fan-out: execute the nested step template once per item. Honors
# max_concurrency — <=1 runs sequentially (default, historical
# behavior); >1 runs up to that many items concurrently. Either way
# results are assembled in item order under the
# parentId:templateId:index id grammar.
if step_type == "fan-out":
items = result.output.get("items", [])
template = result.output.get("step_template", {})
if template and items:
fan_out_results = []
for item_idx, item_val in enumerate(result.output["items"]):
context.item = item_val
# Per-item ID: parentId:templateId:index
item_step = dict(template)
base_id = item_step.get("id", "item")
item_step["id"] = f"{step_id}:{base_id}:{item_idx}"
self._execute_steps(
[item_step], context, state, registry,
step_offset=-1,
)
# Collect per-item result for fan-in
item_result = context.steps.get(item_step["id"], {})
fan_out_results.append(item_result.get("output", {}))
if state.status in (
RunStatus.PAUSED,
RunStatus.FAILED,
RunStatus.ABORTED,
):
break
fan_out_results = self._run_fan_out(
items, template, step_id, context, state, registry,
result.output.get("max_concurrency", 1),
)
Comment on lines +944 to +947
context.item = None
# Preserve original output and add collected results
fan_out_output = dict(result.output)
fan_out_output["results"] = fan_out_results
context.steps[step_id]["output"] = fan_out_output
state.step_results[step_id]["output"] = fan_out_output
state.set_step_output(step_id, fan_out_output)
if state.status in (
RunStatus.PAUSED,
RunStatus.FAILED,
Expand All @@ -910,7 +961,156 @@ def _execute_steps(
# Empty items or no template — normalize output
result.output["results"] = []
context.steps[step_id]["output"] = result.output
state.step_results[step_id]["output"] = result.output
state.set_step_output(step_id, result.output)

def _run_fan_out(
self,
items: list[Any],
template: dict[str, Any],
step_id: str,
context: StepContext,
state: RunState,
registry: dict[str, Any],
max_concurrency: Any,
) -> list[Any]:
"""Run a fan-out template once per item; return per-item outputs in item order.

``max_concurrency`` <= 1 (the default) runs items sequentially, identical
to the historical fan-out behavior. ``max_concurrency`` > 1 runs items on a
bounded thread pool using a sliding submission window of that size: at most
that many items are ever in flight, and no new item is launched once the run
has reached a halting status, so a halt cannot keep starting queued work.

Results are always returned in item order (never completion order). On a
halt (PAUSED/FAILED/ABORTED) the returned prefix is the items up to and
including the first item *in item order* whose own execution halted the run
— identical to the sequential path. Later items that have not yet started
are cancelled; any already running are allowed to finish but their outputs
are ignored. Halt is attributed per item from that item's recorded result
(not the shared run status, which a concurrently-running later item may have
already flipped), so the prefix never drops the actual halting item.

``max_concurrency`` is coerced with ``int()``; a value that cannot be
coerced (``None``, a non-numeric string, …) or that coerces to <= 1 runs
sequentially, while a numeric string like ``"4"`` or a float like ``4.0``
is honored.
"""
halting = (RunStatus.PAUSED, RunStatus.FAILED, RunStatus.ABORTED)
try:
workers = max(1, int(max_concurrency))
except (TypeError, ValueError):
workers = 1
Comment on lines +999 to +1002

base_id = template.get("id", "item")

def item_id(idx: int) -> str:
# Per-item ID grammar: parentId:templateId:index.
return f"{step_id}:{base_id}:{idx}"

def run_item(idx: int, item_ctx: StepContext) -> Any:
item_step = dict(template)
item_step["id"] = item_id(idx)
self._execute_steps(
[item_step], item_ctx, state, registry, step_offset=-1,
)
# Read back through the context that was actually executed against,
# not the outer closure — clearer and robust if StepContext copying
# ever stops sharing the steps dict by reference.
return item_ctx.steps.get(item_step["id"], {}).get("output", {})

# Sequential path — identical to the historical behavior.
if workers <= 1:
results: list[Any] = []
for item_idx, item_val in enumerate(items):
context.item = item_val
results.append(run_item(item_idx, context))
if state.status in halting:
break
return results

# Concurrent path — bounded sliding window; results assembled in item order.
n = len(items)
slots: list[Any] = [None] * n

Comment on lines +1031 to +1034
def run_isolated(idx: int) -> Any:
# Each item runs against its own context copy so context.item is not
# clobbered across threads; the shared steps dict is written only on the
# disjoint parentId:templateId:index key (GIL-safe on distinct keys).
return run_item(idx, dataclasses.replace(context, item=items[idx]))

def item_halted(idx: int) -> bool:
# Did THIS item's own execution halt the run? Decide from the item's
# own recorded result, not the shared run status, so a later item's
# concurrent halt is never misattributed here. Mirrors the sequential
# break condition: PAUSED halts; FAILED halts unless continue_on_error
# routes around it; an abort always halts.
rec = context.steps.get(item_id(idx))
if rec is None:
# Ran but recorded nothing — only happens when the item failed
# before record_step_result (e.g. an unknown step type returns
# early). Every item runs the same template, so the run status is
# this item's own outcome; attribute the halt to it.
return state.status in halting
status = rec.get("status")
if status == StepStatus.PAUSED.value:
return True
if status == StepStatus.FAILED.value:
out = rec.get("output") or {}
if out.get("aborted") or template.get("continue_on_error") is not True:
return True
return False

first_exc: Exception | None = None
halted_at: int | None = None
collected = 0
with ThreadPoolExecutor(max_workers=workers) as pool:
futures: dict[int, Future] = {}
next_submit = 0
for idx in range(n):
Comment on lines +1066 to +1069
# Refill the window: keep <= workers in flight, and stop launching
# new items once the run is halting so a halt cannot keep starting
# queued work. Already-submitted futures are still collected in
# item order below.
while (
next_submit < n
and len(futures) < workers
and state.status not in halting
):
futures[next_submit] = pool.submit(run_isolated, next_submit)
next_submit += 1

fut = futures.pop(idx, None)
if fut is None:
# Safety net: the window submits indices in order and the loop
# breaks at the first halting item, so every collected index has
# an in-flight future. Stop cleanly rather than raise if a future
# change ever breaks that invariant.
break
try:
slots[idx] = fut.result()
except Exception as exc:
# A genuine exception escaping a step (not a normal step
# FAILED, which sets state.status) must not be masked: cancel
# outstanding work and re-raise so the engine marks the run
# failed instead of reporting a vacuous completion.
first_exc = exc
for other in futures.values():
other.cancel()
break
collected = idx + 1
if item_halted(idx):
# First halting item in item order: include it (slots[idx] is
# already set), then cancel everything still pending.
halted_at = idx
for other in futures.values():
other.cancel()
break

if first_exc is not None:
raise first_exc
if halted_at is not None:
return slots[: halted_at + 1]
return slots[:collected]

def _resolve_inputs(
self,
Expand Down
Loading