import fnmatch
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List, Optional
# ---------------------------------------------------------------------------
# ANSI color helpers — no dependencies, disabled when not writing to a TTY
# or when the NO_COLOR env-var is set (https://no-color.org).
# ---------------------------------------------------------------------------
# Raw ANSI escape sequences; emitted by _c() only when _supports_color() is true.
_RESET = "\033[0m"  # clear all attributes
_BOLD = "\033[1m"
_DIM = "\033[2m"
_GREEN = "\033[92m"  # bright green
_YELLOW = "\033[93m"  # bright yellow
_RED = "\033[91m"  # bright red
_CYAN = "\033[96m"  # bright cyan
def _supports_color() -> bool:
return sys.stdout.isatty() and not os.environ.get("NO_COLOR")
def _c(text: str, *codes: str) -> str:
    """Wrap *text* in ANSI escape codes when color output is supported.

    Parameters
    ----------
    text : str
        Cell text to colorize. Coerced with ``str`` so non-string inputs
        are handled consistently.
    *codes : str
        ANSI escape sequences to prepend, e.g. ``_BOLD``, ``_GREEN``.

    Returns
    -------
    str
        The text wrapped in *codes* plus a trailing reset, or the plain
        text when colors are disabled.
    """
    # Coerce once up front so BOTH branches return str for non-str input;
    # the original only applied str() on the colored path, so the plain
    # path could leak a non-string value through.
    text = str(text)
    if not _supports_color():
        return text
    return "".join(codes) + text + _RESET
def _pad(plain: str, colored: str, width: int, align: str = "l") -> str:
"""Pad *colored* to *width* visible characters using *plain* for length."""
padding = " " * max(0, width - len(plain))
return (colored + padding) if align == "l" else (padding + colored)
# State-code -> color bucket used by _color_state(); codes in no bucket
# render uncolored.
_GREEN_STATES = {"R", "CG"}  # running / completing
_YELLOW_STATES = {"PD", "CF", "RQ", "RS", "RH", "RF", "S", "ST", "SI", "SO"}  # waiting / transitional
_RED_STATES = {"F", "BF", "NF", "OOM", "TO", "DL", "PR"}  # failure-like states
def _color_state(state_name: str, state_code: str) -> str:
    """Colorize a human-readable state name according to its SLURM code.

    Codes not present in any bucket are returned uncolored.
    """
    buckets = (
        (_GREEN_STATES, _GREEN),
        (_YELLOW_STATES, _YELLOW),
        (_RED_STATES, _RED),
    )
    for codes, color in buckets:
        if state_code in codes:
            return _c(state_name, color)
    return state_name
# SLURM job state codes
# Maps squeue's short %t codes to human-readable names (see SQueueJob.state_name).
JOB_STATES = {
    "BF": "Boot Fail",
    "CA": "Cancelled",
    "CD": "Completed",
    "CF": "Configuring",
    "CG": "Completing",
    "DL": "Deadline",
    "F": "Failed",
    "NF": "Node Fail",
    "OOM": "Out of Memory",
    "PD": "Pending",
    "PR": "Preempted",
    "R": "Running",
    "RD": "Resv Del Hold",
    "RF": "Requeue Fed",
    "RH": "Requeue Hold",
    "RQ": "Requeued",
    "RS": "Resizing",
    "RV": "Revoked",
    "SI": "Signaling",
    "SE": "Special Exit",
    "SO": "Stage Out",
    "ST": "Stopped",
    "S": "Suspended",
    "TO": "Timeout",
}
# States that mean the job is still alive in the queue
ACTIVE_STATES = {"R", "PD", "CG", "CF", "RQ", "RS", "SI", "SO", "ST", "S", "RH", "RF"}
# squeue --format codes and matching field names, in the order SQueue.refresh()
# parses them: %i=job id, %u=user, %j=name, %t=state, %P=partition, %D=nodes,
# %C=cpus, %M=time used, %l=time limit, %r=reason, %Q=priority.
_SEPARATOR = "\x1f"  # ASCII unit separator — won't appear in job fields
_FORMAT_CODES = ["%i", "%u", "%j", "%t", "%P", "%D", "%C", "%M", "%l", "%r", "%Q"]
_FORMAT_STR = _SEPARATOR.join(_FORMAT_CODES)
[docs]
@dataclass
class SQueueJob:
    """A single job entry from the SLURM queue (one parsed ``squeue`` row)."""

    job_id: int
    user: str
    name: str
    state: str  # SLURM short state code, e.g. "R", "PD"
    partition: str
    num_nodes: int
    num_cpus: int
    time_used: str  # elapsed wall-clock string (squeue %M)
    time_limit: str  # requested limit string (squeue %l)
    reason: str  # scheduling/pending reason (squeue %r)
    priority: int

    @property
    def is_running(self) -> bool:
        """True when the job is in the R (Running) state."""
        return "R" == self.state

    @property
    def is_pending(self) -> bool:
        """True when the job is in the PD (Pending) state."""
        return "PD" == self.state

    @property
    def is_active(self) -> bool:
        """True while the job still occupies a slot in the queue."""
        return self.state in ACTIVE_STATES

    @property
    def state_name(self) -> str:
        """Human-readable name for the state code (falls back to the code)."""
        return JOB_STATES.get(self.state, self.state)

    def wait_until_done(
        self,
        poll_interval: float = 30.0,
        timeout: Optional[float] = None,
        verbose: bool = True,
    ) -> None:
        """Block until this specific job leaves the active queue.

        Parameters
        ----------
        poll_interval : float
            Seconds between queue polls. Defaults to 30.
        timeout : float, optional
            Maximum seconds to wait before raising ``TimeoutError``.
        verbose : bool
            Print progress messages. Defaults to True.
        """
        # Delegate to a fresh SQueue so each poll sees the current queue.
        queue = SQueue()
        queue.wait_until_done(
            job_id=self.job_id,
            poll_interval=poll_interval,
            timeout=timeout,
            verbose=verbose,
        )

    def __repr__(self) -> str:
        return (
            f"SQueueJob(job_id={self.job_id}, user={self.user!r}, "
            f"name={self.name!r}, state={self.state!r}({self.state_name}), "
            f"partition={self.partition!r})"
        )
def _parse_int(s: str, default: int = 0) -> int:
try:
return int(s.strip())
except ValueError:
return default
[docs]
class SQueue:
    """Interface to the SLURM job queue via ``squeue``.

    Parameters
    ----------
    user : str, optional
        If given, only fetch jobs belonging to this user by default.

    Examples
    --------
    >>> q = SQueue()
    >>> q.summary()
    {'total_jobs': 42, 'running': 30, 'pending': 12, 'users': {...}, 'by_state': {...}}
    >>> q.wait_until_done(job_name='training_*')
    >>> q.wait_until_done(job_id=12345)
    >>> q.wait_until_done(user='alice')
    """

    def __init__(
        self, user: Optional[str] = None, partition: Optional[str] = None
    ) -> None:
        # Default filters applied to every squeue invocation in refresh().
        self._default_user = user
        self._default_partition = partition
        self._jobs: List[SQueueJob] = []
        # NOTE: runs squeue immediately, so construction requires a working
        # SLURM environment and may raise RuntimeError.
        self.refresh()

    # ------------------------------------------------------------------
    # Fetching
    # ------------------------------------------------------------------
    def refresh(self) -> "SQueue":
        """Re-run ``squeue`` and update the cached job list.

        Returns
        -------
        SQueue
            self, for chaining.

        Raises
        ------
        RuntimeError
            If the ``squeue`` command exits non-zero.
        """
        cmd = ["squeue", f"--format={_FORMAT_STR}", "--noheader"]
        if self._default_user:
            cmd += ["--user", self._default_user]
        if self._default_partition:
            cmd += ["--partition", self._default_partition]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"squeue failed: {result.stderr.strip()}")
        self._jobs = []
        for line in result.stdout.splitlines():
            line = line.strip()
            if not line:
                continue
            # Fields are joined with the \x1f unit separator (see _FORMAT_STR).
            parts = line.split(_SEPARATOR)
            if len(parts) < len(_FORMAT_CODES):
                continue  # malformed / truncated row
            try:
                job = SQueueJob(
                    job_id=_parse_int(parts[0]),
                    user=parts[1].strip(),
                    name=parts[2].strip(),
                    state=parts[3].strip(),
                    partition=parts[4].strip(),
                    num_nodes=_parse_int(parts[5]),
                    num_cpus=_parse_int(parts[6]),
                    time_used=parts[7].strip(),
                    time_limit=parts[8].strip(),
                    reason=parts[9].strip(),
                    priority=_parse_int(parts[10]),
                )
                self._jobs.append(job)
            except (ValueError, IndexError):
                # Best-effort parse: skip bad rows rather than failing refresh.
                continue
        return self

    # ------------------------------------------------------------------
    # Filtering
    # ------------------------------------------------------------------
    def jobs(
        self,
        job_name: Optional[str] = None,
        job_id: Optional[int | str] = None,
        user: Optional[str] = None,
        state: Optional[str] = None,
        partition: Optional[str] = None,
    ) -> List[SQueueJob]:
        """Return jobs matching the given criteria.

        All provided filters combine with AND.

        Parameters
        ----------
        job_name : str, optional
            Job name or glob pattern (e.g. ``'train_*'``).
        job_id : int or str, optional
            Exact job ID.
        user : str, optional
            Username to filter by.
        state : str, optional
            SLURM state code, e.g. ``'R'`` or ``'PD'``.
        partition : str, optional
            Partition name to filter by.

        Returns
        -------
        list of SQueueJob
        """
        result = list(self._jobs)
        if job_id is not None:
            result = [j for j in result if j.job_id == int(job_id)]
        if user is not None:
            result = [j for j in result if j.user == user]
        if state is not None:
            result = [j for j in result if j.state == state]
        if partition is not None:
            result = [j for j in result if j.partition == partition]
        if job_name is not None:
            # fnmatch gives shell-style globbing ('*', '?') on the job name.
            result = [j for j in result if fnmatch.fnmatch(j.name, job_name)]
        return result

    def running_jobs(self) -> List[SQueueJob]:
        """Return all jobs currently in the R (Running) state."""
        return [j for j in self._jobs if j.is_running]

    def pending_jobs(self) -> List[SQueueJob]:
        """Return all jobs currently in the PD (Pending) state."""
        return [j for j in self._jobs if j.is_pending]

    # ------------------------------------------------------------------
    # Waiting
    # ------------------------------------------------------------------
    def wait_until_done(
        self,
        job_name: Optional[str] = None,
        job_id: Optional[int | str] = None,
        user: Optional[str] = None,
        poll_interval: float = 30.0,
        timeout: Optional[float] = None,
        verbose: bool = True,
    ) -> None:
        """Block until all matching jobs leave the active queue.

        Supports glob patterns in *job_name* (``*`` and ``?`` wildcards).
        At least one filter argument must be provided.

        Parameters
        ----------
        job_name : str, optional
            Job name or glob pattern, e.g. ``'train_*'``.
        job_id : int or str, optional
            A specific job ID to wait for.
        user : str, optional
            Wait for all jobs belonging to this user to finish.
        poll_interval : float
            Seconds between queue polls. Defaults to 30.
        timeout : float, optional
            Maximum seconds to wait before raising ``TimeoutError``.
        verbose : bool
            Print progress messages. Defaults to True.

        Raises
        ------
        ValueError
            If no filter is specified.
        TimeoutError
            If *timeout* is exceeded before all jobs finish.
        """
        if job_name is None and job_id is None and user is None:
            raise ValueError("Specify at least one of: job_name, job_id, user")
        # monotonic() is immune to wall-clock adjustments during long waits.
        start = time.monotonic()
        while True:
            self.refresh()
            active = [
                j
                for j in self.jobs(job_name=job_name, job_id=job_id, user=user)
                if j.is_active
            ]
            if not active:
                if verbose:
                    print(_c("✓", _GREEN) + " All matching jobs have finished.")
                return
            # Timeout is checked after each poll, so the actual wait can
            # overshoot by up to one poll_interval.
            if timeout is not None and (time.monotonic() - start) > timeout:
                ids = [j.job_id for j in active]
                raise TimeoutError(
                    f"Timed out after {timeout}s. Still active job IDs: {ids}"
                )
            if verbose:
                ids = [j.job_id for j in active]
                print(
                    _c("~", _YELLOW)
                    + f" Waiting — {_c(str(len(active)), _YELLOW)} job(s) still active {ids}."
                    f" Polling again in {poll_interval}s."
                )
            time.sleep(poll_interval)

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------
    def users(self) -> List[str]:
        """Return a sorted list of unique users with jobs in the queue."""
        return sorted(set(j.user for j in self._jobs))

    def jobs_by_user(self) -> Dict[str, List[SQueueJob]]:
        """Return a mapping of username -> list of their jobs."""
        result: Dict[str, List[SQueueJob]] = {}
        for job in self._jobs:
            result.setdefault(job.user, []).append(job)
        return result

    def jobs_by_state(self) -> Dict[str, List[SQueueJob]]:
        """Return a mapping of state code -> list of jobs in that state."""
        result: Dict[str, List[SQueueJob]] = {}
        for job in self._jobs:
            result.setdefault(job.state, []).append(job)
        return result

    def jobs_by_partition(self) -> Dict[str, List[SQueueJob]]:
        """Return a mapping of partition name -> list of jobs in that partition."""
        result: Dict[str, List[SQueueJob]] = {}
        for job in self._jobs:
            result.setdefault(job.partition, []).append(job)
        return result

    def summary(self) -> dict:
        """Return a summary dict with total counts, per-user counts, and per-state counts.

        Returns
        -------
        dict
            Keys: ``total_jobs``, ``running``, ``pending``,
            ``users`` (dict of user -> job count),
            ``by_state`` (dict of state code -> job count).
        """
        by_state = self.jobs_by_state()
        by_user = self.jobs_by_user()
        return {
            "total_jobs": len(self._jobs),
            "running": len(by_state.get("R", [])),
            "pending": len(by_state.get("PD", [])),
            "users": {u: len(jobs) for u, jobs in sorted(by_user.items())},
            "by_state": {s: len(jobs) for s, jobs in sorted(by_state.items())},
        }

    # ------------------------------------------------------------------
    # Dunder helpers
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        # Number of cached jobs from the last refresh().
        return len(self._jobs)

    def __iter__(self):
        # Iterate the cached jobs from the last refresh().
        return iter(self._jobs)

    def __str__(self) -> str:
        # Render a per-user summary table with a title line and rule bars.
        if not self._jobs:
            return _c("SLURM Queue", _BOLD, _CYAN) + " · " + _c("empty", _DIM)
        total_running = sum(1 for j in self._jobs if j.is_running)
        total_pending = sum(1 for j in self._jobs if j.is_pending)
        # Node/CPU totals count running jobs only.
        total_nodes = sum(j.num_nodes for j in self._jobs if j.is_running)
        total_cpus = sum(j.num_cpus for j in self._jobs if j.is_running)
        # Build per-user stats
        rows = []
        for user, jobs in self.jobs_by_user().items():
            running = [j for j in jobs if j.is_running]
            pending = [j for j in jobs if j.is_pending]
            nodes = sum(j.num_nodes for j in running)
            cpus = sum(j.num_cpus for j in running)
            rows.append((user, len(jobs), len(running), len(pending), nodes, cpus))
        # Heaviest users (by running nodes, then running jobs) first
        rows.sort(key=lambda r: (-r[4], -r[2], -r[1]))
        headers = ["User", "Jobs", "Running", "Pending", "Nodes (R)", "CPUs (R)"]
        totals_plain = [
            "TOTAL",
            str(len(self._jobs)),
            str(total_running),
            str(total_pending),
            str(total_nodes),
            str(total_cpus),
        ]
        # Column widths computed on plain text so ANSI codes don't shift columns
        str_rows = [
            [r[0], str(r[1]), str(r[2]), str(r[3]), str(r[4]), str(r[5])] for r in rows
        ]
        widths = [
            max(
                len(headers[i]),
                len(totals_plain[i]),
                max((len(r[i]) for r in str_rows), default=0),
            )
            for i in range(len(headers))
        ]

        def fmt_header() -> str:
            # First column left-aligned, numeric columns right-aligned.
            cells = [_pad(headers[0], _c(headers[0], _BOLD), widths[0], "l")]
            for i in range(1, len(headers)):
                cells.append(_pad(headers[i], _c(headers[i], _BOLD), widths[i], "r"))
            return " " + " ".join(cells)

        def fmt_data_row(r: list) -> str:
            # Zero counts stay uncolored so non-zero values stand out.
            cells = [_pad(r[0], r[0], widths[0], "l")]
            cells.append(_pad(r[1], r[1], widths[1], "r"))
            run_c = _c(r[2], _GREEN) if r[2] != "0" else r[2]
            cells.append(_pad(r[2], run_c, widths[2], "r"))
            pend_c = _c(r[3], _YELLOW) if r[3] != "0" else r[3]
            cells.append(_pad(r[3], pend_c, widths[3], "r"))
            cells.append(_pad(r[4], r[4], widths[4], "r"))
            cells.append(_pad(r[5], r[5], widths[5], "r"))
            return " " + " ".join(cells)

        def fmt_totals() -> str:
            # Totals row, bold throughout (green/yellow for running/pending).
            p = totals_plain
            cells = [_pad(p[0], _c(p[0], _BOLD), widths[0], "l")]
            cells.append(_pad(p[1], _c(p[1], _BOLD), widths[1], "r"))
            cells.append(_pad(p[2], _c(p[2], _BOLD, _GREEN), widths[2], "r"))
            cells.append(_pad(p[3], _c(p[3], _BOLD, _YELLOW), widths[3], "r"))
            cells.append(_pad(p[4], _c(p[4], _BOLD), widths[4], "r"))
            cells.append(_pad(p[5], _c(p[5], _BOLD), widths[5], "r"))
            return " " + " ".join(cells)

        # Visible table width: columns + per-gap separators + left margin.
        table_width = sum(widths) + 3 * (len(widths) - 1) + 2
        title_plain = (
            f"SLURM Queue \u00b7 {len(self._jobs)} jobs total"
            f" \u00b7 {total_running} running"
            f" \u00b7 {total_pending} pending"
        )
        title = (
            _c("SLURM Queue", _BOLD, _CYAN)
            + " \u00b7 "
            + f"{len(self._jobs)} jobs total"
            + " \u00b7 "
            + _c(f"{total_running} running", _GREEN)
            + " \u00b7 "
            + _c(f"{total_pending} pending", _YELLOW)
        )
        # Rule bars span whichever is wider: the table or the plain title.
        width = max(table_width, len(title_plain))
        bar_heavy = _c("\u2550" * width, _DIM)
        bar_light = _c("\u2500" * width, _DIM)
        lines = [
            title,
            bar_heavy,
            fmt_header(),
            bar_light,
            *[fmt_data_row(r) for r in str_rows],
            bar_light,
            fmt_totals(),
            bar_heavy,
        ]
        return "\n".join(lines)

    def __repr__(self) -> str:
        s = self.summary()
        return (
            f"SQueue(total={s['total_jobs']}, running={s['running']}, "
            f"pending={s['pending']}, users={list(s['users'].keys())})"
        )
# Used by _fmt_job_table()'s _trunc() when the Reason column is shown.
_REASON_MAX = 32  # truncate long scheduling-reason strings to this many characters
def _fmt_job_table(jobs: List[SQueueJob], show_reason: bool = False) -> str:
    """Format a list of jobs as an aligned table string.

    Parameters
    ----------
    jobs : list of SQueueJob
        Rows to display, in the given order.
    show_reason : bool
        Append a (truncated) scheduling-reason column. Defaults to False.
    """
    if not jobs:
        return " (no jobs)"
    headers = [
        "JobID",
        "User",
        "Job Name",
        "State",
        "Partition",
        "Nodes",
        "CPUs",
        "Used",
        "Limit",
    ]
    if show_reason:
        headers.append("Reason")

    def _trunc(s: str) -> str:
        # Cap the reason at _REASON_MAX chars, marking the cut with an ellipsis.
        return s[:_REASON_MAX] + "\u2026" if len(s) > _REASON_MAX else s

    # Plain rows for width calculation; colored rows for display
    rows_plain = [
        [
            str(j.job_id),
            j.user,
            j.name,
            j.state_name,
            j.partition,
            str(j.num_nodes),
            str(j.num_cpus),
            j.time_used,
            j.time_limit,
        ]
        + ([_trunc(j.reason)] if show_reason else [])
        for j in jobs
    ]
    rows_colored = [
        # Only the State cell (index 3) is colorized, by its state code.
        list(plain[:3]) + [_color_state(plain[3], j.state)] + list(plain[4:])
        for j, plain in zip(jobs, rows_plain)
    ]
    widths = [
        max(len(headers[i]), max(len(r[i]) for r in rows_plain))
        for i in range(len(headers))
    ]
    # Right-aligned numeric columns: JobID, Nodes, CPUs.
    right = {0, 5, 6}

    def fmt_header() -> str:
        cells = []
        for i, h in enumerate(headers):
            align = "r" if i in right else "l"
            cells.append(_pad(h, _c(h, _BOLD), widths[i], align))
        return " " + " ".join(cells)

    def fmt_row(plain: list, colored: list) -> str:
        # Pad the colored cell using the plain cell's visible length.
        cells = []
        for i in range(len(plain)):
            align = "r" if i in right else "l"
            cells.append(_pad(plain[i], colored[i], widths[i], align))
        return " " + " ".join(cells)

    bar = _c("─" * (sum(widths) + 3 * (len(widths) - 1) + 2), _DIM)
    lines = [
        fmt_header(),
        bar,
        *[fmt_row(p, c) for p, c in zip(rows_plain, rows_colored)],
    ]
    return "\n".join(lines)
def _fmt_stats_table(q: SQueue) -> str:
    """Format a partition-breakdown and state-breakdown view for the stats subcommand."""
    lines: List[str] = []
    # --- By Partition --------------------------------------------------------
    by_part = q.jobs_by_partition()
    total_jobs = len(q)
    total_running = sum(1 for j in q if j.is_running)
    total_pending = sum(1 for j in q if j.is_pending)
    # Node/CPU totals count running jobs only.
    total_nodes = sum(j.num_nodes for j in q if j.is_running)
    total_cpus = sum(j.num_cpus for j in q if j.is_running)
    p_rows: list = []
    for part, jobs in by_part.items():
        r = [j for j in jobs if j.is_running]
        p = [j for j in jobs if j.is_pending]
        p_rows.append(
            (
                part,
                len(jobs),
                len(r),
                len(p),
                sum(j.num_nodes for j in r),
                sum(j.num_cpus for j in r),
            )
        )
    # Busiest partitions first: running nodes, then running jobs, then name.
    p_rows.sort(key=lambda r: (-r[4], -r[2], r[0]))
    p_headers = ["Partition", "Jobs", "Running", "Pending", "Nodes (R)", "CPUs (R)"]
    p_tot = [
        "TOTAL",
        str(total_jobs),
        str(total_running),
        str(total_pending),
        str(total_nodes),
        str(total_cpus),
    ]
    p_str_rows = [
        [r[0], str(r[1]), str(r[2]), str(r[3]), str(r[4]), str(r[5])] for r in p_rows
    ]
    # Widths computed on plain text so ANSI codes don't shift columns.
    p_widths = [
        max(
            len(p_headers[i]),
            len(p_tot[i]),
            max((len(r[i]) for r in p_str_rows), default=0),
        )
        for i in range(len(p_headers))
    ]

    def _ph() -> str:
        # Header row: first column left-aligned, counters right-aligned.
        cells = [_pad(p_headers[0], _c(p_headers[0], _BOLD), p_widths[0], "l")]
        for i in range(1, len(p_headers)):
            cells.append(_pad(p_headers[i], _c(p_headers[i], _BOLD), p_widths[i], "r"))
        return " " + " ".join(cells)

    def _pr(r: list) -> str:
        # Data row: non-zero running/pending counts get green/yellow.
        cells = [_pad(r[0], r[0], p_widths[0], "l")]
        cells.append(_pad(r[1], r[1], p_widths[1], "r"))
        cells.append(
            _pad(r[2], _c(r[2], _GREEN) if r[2] != "0" else r[2], p_widths[2], "r")
        )
        cells.append(
            _pad(r[3], _c(r[3], _YELLOW) if r[3] != "0" else r[3], p_widths[3], "r")
        )
        cells.append(_pad(r[4], r[4], p_widths[4], "r"))
        cells.append(_pad(r[5], r[5], p_widths[5], "r"))
        return " " + " ".join(cells)

    def _pt() -> str:
        # Totals row, bold throughout.
        cells = [_pad(p_tot[0], _c(p_tot[0], _BOLD), p_widths[0], "l")]
        cells.append(_pad(p_tot[1], _c(p_tot[1], _BOLD), p_widths[1], "r"))
        cells.append(_pad(p_tot[2], _c(p_tot[2], _BOLD, _GREEN), p_widths[2], "r"))
        cells.append(_pad(p_tot[3], _c(p_tot[3], _BOLD, _YELLOW), p_widths[3], "r"))
        cells.append(_pad(p_tot[4], _c(p_tot[4], _BOLD), p_widths[4], "r"))
        cells.append(_pad(p_tot[5], _c(p_tot[5], _BOLD), p_widths[5], "r"))
        return " " + " ".join(cells)

    p_bar = _c("\u2500" * (sum(p_widths) + 3 * (len(p_widths) - 1) + 2), _DIM)
    lines += [
        _c("By Partition", _BOLD),
        p_bar,
        _ph(),
        p_bar,
        *[_pr(r) for r in p_str_rows],
        p_bar,
        _pt(),
        p_bar,
    ]
    # --- By State ------------------------------------------------------------
    by_state = q.jobs_by_state()
    # Most populous states first; rows carry (display name, code, count).
    s_rows = sorted(
        [
            (JOB_STATES.get(code, code), code, len(jobs))
            for code, jobs in by_state.items()
        ],
        key=lambda r: -r[2],
    )
    s_headers = ["State", "Count"]
    s_str_rows = [[r[0], str(r[2])] for r in s_rows]
    s_widths = [
        max(len(s_headers[i]), max((len(r[i]) for r in s_str_rows), default=0))
        for i in range(len(s_headers))
    ]

    def _sh() -> str:
        return (
            " "
            + _pad(s_headers[0], _c(s_headers[0], _BOLD), s_widths[0], "l")
            + " "
            + _pad(s_headers[1], _c(s_headers[1], _BOLD), s_widths[1], "r")
        )

    def _sr(state_name: str, state_code: str, count: str) -> str:
        # Both the name and the count take the state's color bucket.
        name_c = _color_state(state_name, state_code)
        if state_code in _GREEN_STATES:
            count_c = _c(count, _GREEN)
        elif state_code in _YELLOW_STATES:
            count_c = _c(count, _YELLOW)
        elif state_code in _RED_STATES:
            count_c = _c(count, _RED)
        else:
            count_c = count
        return (
            " "
            + _pad(state_name, name_c, s_widths[0], "l")
            + " "
            + _pad(count, count_c, s_widths[1], "r")
        )

    s_bar = _c("\u2500" * (s_widths[0] + s_widths[1] + 5), _DIM)
    lines += [
        "",
        _c("By State", _BOLD),
        s_bar,
        _sh(),
        s_bar,
        *[_sr(s_rows[i][0], s_rows[i][1], r[1]) for i, r in enumerate(s_str_rows)],
        s_bar,
    ]
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# SLURM accounting (sacct)
# ---------------------------------------------------------------------------
# Fields requested from sacct; order must match the parsing in SAcct.refresh().
_SACCT_FIELDS = [
    "JobID",
    "User",
    "JobName",
    "State",
    "Partition",
    "AllocNodes",
    "AllocCPUS",
    "Elapsed",
    "CPUTimeRAW",
    "ExitCode",
]
_SACCT_FORMAT = ",".join(_SACCT_FIELDS)
# sacct state -> color bucket (separate from squeue states)
_SACCT_GREEN = {"COMPLETED"}
_SACCT_YELLOW = {"TIMEOUT", "PREEMPTED", "CANCELLED"}
_SACCT_RED = {"FAILED", "NODE_FAIL", "OUT_OF_MEMORY"}
def _normalize_sacct_state(state: str) -> str:
"""Normalize sacct state strings — e.g. 'CANCELLED by 1234' -> 'CANCELLED'."""
state = state.strip()
if state.startswith("CANCELLED"):
return "CANCELLED"
return state
def _fmt_cpu_hours(hours: float) -> str:
"""Format a CPU-hour value for display."""
h = int(hours)
if h >= 1_000_000:
return f"{h / 1_000_000:.1f}M"
if h >= 10_000:
s = str(h)
# insert thousands separators manually for portability
parts = []
while len(s) > 3:
parts.append(s[-3:])
s = s[:-3]
parts.append(s)
return ",".join(reversed(parts))
return str(h)
def _color_sacct_state(state: str, text: str) -> str:
    """Color *text* according to the bucket of the (normalized) sacct *state*.

    States in no bucket come back uncolored.
    """
    # The buckets are disjoint, so the check order doesn't affect results.
    if state in _SACCT_RED:
        return _c(text, _RED)
    if state in _SACCT_YELLOW:
        return _c(text, _YELLOW)
    if state in _SACCT_GREEN:
        return _c(text, _GREEN)
    return text
[docs]
@dataclass
class SAcctJob:
    """A single job record from SLURM accounting (``sacct``)."""

    job_id: int
    user: str
    name: str
    state: str  # normalized, e.g. "COMPLETED", "FAILED", "CANCELLED"
    partition: str
    num_nodes: int
    num_cpus: int
    elapsed: str  # wall-clock time as HH:MM:SS
    cpu_time_raw: int  # CPU-seconds = AllocCPUS * elapsed_seconds
    exit_code: str  # e.g. "0:0" or "1:0"

    @property
    def cpu_hours(self) -> float:
        """CPU-seconds converted to hours."""
        return self.cpu_time_raw / 3600.0

    @property
    def is_completed(self) -> bool:
        return "COMPLETED" == self.state

    @property
    def is_failed(self) -> bool:
        # Any hard-failure terminal state counts as failed.
        return self.state in {"FAILED", "NODE_FAIL", "OUT_OF_MEMORY"}

    @property
    def is_cancelled(self) -> bool:
        return "CANCELLED" == self.state

    @property
    def is_timeout(self) -> bool:
        return "TIMEOUT" == self.state

    def __repr__(self) -> str:
        return (
            f"SAcctJob(job_id={self.job_id}, user={self.user!r}, "
            f"name={self.name!r}, state={self.state!r}, elapsed={self.elapsed!r})"
        )
[docs]
class SAcct:
    """Interface to SLURM job accounting via ``sacct``.

    Parameters
    ----------
    user : str, optional
        If given, fetch only jobs for this user.
    days : int
        Number of days of history to look back (default: 7).
    partition : str, optional
        If given, filter to this partition.

    Examples
    --------
    >>> a = SAcct(user='alice', days=30)
    >>> a.summary()
    {'total': 42, 'completed': 30, 'failed': 5, ...}
    """

    def __init__(
        self,
        user: Optional[str] = None,
        days: int = 7,
        partition: Optional[str] = None,
    ) -> None:
        self._user = user
        self._days = days
        self._partition = partition
        self._jobs: List[SAcctJob] = []
        # NOTE: runs sacct immediately, so construction requires a working
        # SLURM environment and may raise RuntimeError.
        self.refresh()

    def refresh(self) -> "SAcct":
        """Re-run ``sacct`` and update the cached job list.

        Raises
        ------
        RuntimeError
            If the ``sacct`` command exits non-zero.
        """
        # Look back N whole days, anchored at midnight local time.
        start = (datetime.now() - timedelta(days=self._days)).strftime(
            "%Y-%m-%dT00:00:00"
        )
        cmd = [
            "sacct",
            f"--format={_SACCT_FORMAT}",
            "--noheader",
            "--parsable2",  # '|'-separated fields
            f"--starttime={start}",
            "--allocations",  # main job entries only, no sub-steps
        ]
        if self._user:
            cmd += ["--user", self._user]
        if self._partition:
            cmd += ["--partition", self._partition]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"sacct failed: {result.stderr.strip()}")
        self._jobs = []
        for line in result.stdout.splitlines():
            line = line.strip()
            if not line:
                continue
            parts = line.split("|")
            if len(parts) < len(_SACCT_FIELDS):
                continue  # malformed / truncated row
            job_id_str = parts[0].strip()
            if not job_id_str or "." in job_id_str:
                continue  # skip job steps (e.g. 12345.batch)
            try:
                self._jobs.append(
                    SAcctJob(
                        job_id=_parse_int(job_id_str),
                        user=parts[1].strip(),
                        name=parts[2].strip(),
                        state=_normalize_sacct_state(parts[3]),
                        partition=parts[4].strip(),
                        num_nodes=_parse_int(parts[5]),
                        num_cpus=_parse_int(parts[6]),
                        elapsed=parts[7].strip(),
                        cpu_time_raw=_parse_int(parts[8]),
                        exit_code=parts[9].strip(),
                    )
                )
            except (ValueError, IndexError):
                # Best-effort parse: skip bad rows rather than failing refresh.
                continue
        return self

    def jobs(
        self,
        user: Optional[str] = None,
        state: Optional[str] = None,
        partition: Optional[str] = None,
    ) -> List[SAcctJob]:
        """Return accounting records matching the given criteria.

        Filters combine with AND; *state* matches the normalized state
        (e.g. ``'COMPLETED'``).
        """
        result = list(self._jobs)
        if user is not None:
            result = [j for j in result if j.user == user]
        if state is not None:
            result = [j for j in result if j.state == state]
        if partition is not None:
            result = [j for j in result if j.partition == partition]
        return result

    def jobs_by_user(self) -> Dict[str, List[SAcctJob]]:
        """Return a mapping of username -> list of their historical jobs."""
        result: Dict[str, List[SAcctJob]] = {}
        for job in self._jobs:
            result.setdefault(job.user, []).append(job)
        return result

    def jobs_by_state(self) -> Dict[str, List[SAcctJob]]:
        """Return a mapping of state -> list of jobs in that state."""
        result: Dict[str, List[SAcctJob]] = {}
        for job in self._jobs:
            result.setdefault(job.state, []).append(job)
        return result

    def jobs_by_partition(self) -> Dict[str, List[SAcctJob]]:
        """Return a mapping of partition -> list of jobs in that partition."""
        result: Dict[str, List[SAcctJob]] = {}
        for job in self._jobs:
            result.setdefault(job.partition, []).append(job)
        return result

    def summary(self) -> dict:
        """Return a summary dict of job counts and CPU usage.

        Returns
        -------
        dict
            Keys: ``total``, ``completed``, ``failed``, ``cancelled``,
            ``timeout``, ``cpu_hours``, ``by_state``, ``users``.
        """
        by_state = self.jobs_by_state()
        by_user = self.jobs_by_user()
        return {
            "total": len(self._jobs),
            "completed": len(by_state.get("COMPLETED", [])),
            "failed": sum(1 for j in self._jobs if j.is_failed),
            "cancelled": len(by_state.get("CANCELLED", [])),
            "timeout": len(by_state.get("TIMEOUT", [])),
            "cpu_hours": sum(j.cpu_hours for j in self._jobs),
            "by_state": {s: len(jobs) for s, jobs in sorted(by_state.items())},
            "users": {u: len(jobs) for u, jobs in sorted(by_user.items())},
        }

    def __len__(self) -> int:
        # Number of cached records from the last refresh().
        return len(self._jobs)

    def __iter__(self):
        # Iterate the cached records from the last refresh().
        return iter(self._jobs)

    def __repr__(self) -> str:
        s = self.summary()
        return (
            f"SAcct(total={s['total']}, completed={s['completed']}, "
            f"failed={s['failed']}, cpu_hours={s['cpu_hours']:.1f})"
        )
def _fmt_history_summary(acct: SAcct) -> str:
    """Per-user summary table — shown when no specific user is requested."""
    by_user = acct.jobs_by_user()
    if not by_user:
        return " (no jobs found in the requested time window)"
    headers = ["User", "Jobs", "Done", "Failed", "Timeout", "Cancelled", "CPU-hours"]
    rows = []
    for user, jobs in by_user.items():
        done = sum(1 for j in jobs if j.is_completed)
        failed = sum(1 for j in jobs if j.is_failed)
        timeout = sum(1 for j in jobs if j.is_timeout)
        cancelled = sum(1 for j in jobs if j.is_cancelled)
        cpu_h = _fmt_cpu_hours(sum(j.cpu_hours for j in jobs))
        rows.append((user, len(jobs), done, failed, timeout, cancelled, cpu_h))
    rows.sort(key=lambda r: -r[1])  # heaviest users first
    total_done = sum(1 for j in acct if j.is_completed)
    total_failed = sum(1 for j in acct if j.is_failed)
    total_timeout = sum(1 for j in acct if j.is_timeout)
    total_cancelled = sum(1 for j in acct if j.is_cancelled)
    totals_plain = [
        "TOTAL",
        str(len(acct)),
        str(total_done),
        str(total_failed),
        str(total_timeout),
        str(total_cancelled),
        _fmt_cpu_hours(sum(j.cpu_hours for j in acct)),
    ]
    str_rows = [
        [r[0], str(r[1]), str(r[2]), str(r[3]), str(r[4]), str(r[5]), r[6]]
        for r in rows
    ]
    # Widths computed on plain text so ANSI codes don't shift columns.
    widths = [
        max(
            len(headers[i]),
            len(totals_plain[i]),
            max((len(r[i]) for r in str_rows), default=0),
        )
        for i in range(len(headers))
    ]

    def fmt_header() -> str:
        cells = [_pad(headers[0], _c(headers[0], _BOLD), widths[0], "l")]
        for i in range(1, len(headers)):
            cells.append(_pad(headers[i], _c(headers[i], _BOLD), widths[i], "r"))
        return " " + " ".join(cells)

    def fmt_row(r: list) -> str:
        # Non-zero Done/Failed/Timeout counts get green/red/yellow.
        cells = [_pad(r[0], r[0], widths[0], "l")]
        cells.append(_pad(r[1], r[1], widths[1], "r"))
        cells.append(
            _pad(r[2], _c(r[2], _GREEN) if r[2] != "0" else r[2], widths[2], "r")
        )
        cells.append(
            _pad(r[3], _c(r[3], _RED) if r[3] != "0" else r[3], widths[3], "r")
        )
        cells.append(
            _pad(r[4], _c(r[4], _YELLOW) if r[4] != "0" else r[4], widths[4], "r")
        )
        cells.append(_pad(r[5], r[5], widths[5], "r"))
        cells.append(_pad(r[6], r[6], widths[6], "r"))
        return " " + " ".join(cells)

    def fmt_totals() -> str:
        # Totals row, bold; failed/timeout totals colored only when non-zero.
        p = totals_plain
        cells = [_pad(p[0], _c(p[0], _BOLD), widths[0], "l")]
        cells.append(_pad(p[1], _c(p[1], _BOLD), widths[1], "r"))
        cells.append(_pad(p[2], _c(p[2], _BOLD, _GREEN), widths[2], "r"))
        f_c = _BOLD + _RED if p[3] != "0" else _BOLD
        cells.append(_pad(p[3], _c(p[3], f_c), widths[3], "r"))
        t_c = _BOLD + _YELLOW if p[4] != "0" else _BOLD
        cells.append(_pad(p[4], _c(p[4], t_c), widths[4], "r"))
        cells.append(_pad(p[5], _c(p[5], _BOLD), widths[5], "r"))
        cells.append(_pad(p[6], _c(p[6], _BOLD), widths[6], "r"))
        return " " + " ".join(cells)

    bar = _c("\u2500" * (sum(widths) + 3 * (len(widths) - 1) + 2), _DIM)
    return "\n".join(
        [fmt_header(), bar, *[fmt_row(r) for r in str_rows], bar, fmt_totals(), bar]
    )
def _fmt_history_detail(acct: SAcct) -> str:
    """Detailed breakdown for a single user — shown when --user is given."""
    if not len(acct):
        return " (no jobs found in the requested time window)"
    total = len(acct)
    total_cpu = sum(j.cpu_hours for j in acct)
    lines: List[str] = []
    # --- By State ------------------------------------------------------------
    by_state = acct.jobs_by_state()
    # Most populous states first; rows carry (state, count, cpu-hours).
    s_rows = sorted(
        [
            (s, len(jobs), sum(j.cpu_hours for j in jobs))
            for s, jobs in by_state.items()
        ],
        key=lambda r: -r[1],
    )
    s_tot_plain = ["TOTAL", str(total), "100%", _fmt_cpu_hours(total_cpu)]
    # NOTE: integer-floor percentages — the column may not sum to exactly 100.
    s_str_rows = [
        [r[0], str(r[1]), f"{100 * r[1] // total}%", _fmt_cpu_hours(r[2])]
        for r in s_rows
    ]
    s_headers = ["State", "Jobs", "%", "CPU-hours"]
    s_widths = [
        max(
            len(s_headers[i]),
            len(s_tot_plain[i]),
            max((len(r[i]) for r in s_str_rows), default=0),
        )
        for i in range(len(s_headers))
    ]

    def fmt_s_header() -> str:
        cells = [_pad(s_headers[0], _c(s_headers[0], _BOLD), s_widths[0], "l")]
        for i in range(1, len(s_headers)):
            cells.append(_pad(s_headers[i], _c(s_headers[i], _BOLD), s_widths[i], "r"))
        return " " + " ".join(cells)

    def fmt_s_row(state: str, r: list) -> str:
        # The whole row takes the state's color bucket.
        name_c = _color_sacct_state(state, r[0])
        count_c = _color_sacct_state(state, r[1])
        pct_c = _color_sacct_state(state, r[2])
        cpu_c = _color_sacct_state(state, r[3])
        cells = [
            _pad(r[0], name_c, s_widths[0], "l"),
            _pad(r[1], count_c, s_widths[1], "r"),
            _pad(r[2], pct_c, s_widths[2], "r"),
            _pad(r[3], cpu_c, s_widths[3], "r"),
        ]
        return " " + " ".join(cells)

    def fmt_s_totals() -> str:
        p = s_tot_plain
        cells = [_pad(p[0], _c(p[0], _BOLD), s_widths[0], "l")]
        for i in range(1, len(p)):
            cells.append(_pad(p[i], _c(p[i], _BOLD), s_widths[i], "r"))
        return " " + " ".join(cells)

    s_bar = _c("\u2500" * (sum(s_widths) + 3 * (len(s_widths) - 1) + 2), _DIM)
    lines += [
        _c("By State", _BOLD),
        s_bar,
        fmt_s_header(),
        s_bar,
        *[fmt_s_row(s_rows[i][0], r) for i, r in enumerate(s_str_rows)],
        s_bar,
        fmt_s_totals(),
        s_bar,
    ]
    # --- By Partition --------------------------------------------------------
    by_part = acct.jobs_by_partition()
    # Skip the partition table when every record reports an empty partition.
    if len(by_part) > 1 or list(by_part.keys()) != [""]:
        # Heaviest CPU consumers first.
        p_rows = sorted(
            [
                (p, len(jobs), sum(j.cpu_hours for j in jobs))
                for p, jobs in by_part.items()
            ],
            key=lambda r: -r[2],
        )
        p_str_rows = [[r[0], str(r[1]), _fmt_cpu_hours(r[2])] for r in p_rows]
        p_headers = ["Partition", "Jobs", "CPU-hours"]
        p_widths = [
            max(len(p_headers[i]), max((len(r[i]) for r in p_str_rows), default=0))
            for i in range(len(p_headers))
        ]

        def fmt_p_header() -> str:
            cells = [_pad(p_headers[0], _c(p_headers[0], _BOLD), p_widths[0], "l")]
            for i in range(1, len(p_headers)):
                cells.append(
                    _pad(p_headers[i], _c(p_headers[i], _BOLD), p_widths[i], "r")
                )
            return " " + " ".join(cells)

        def fmt_p_row(r: list) -> str:
            cells = [_pad(r[0], r[0], p_widths[0], "l")]
            for i in range(1, len(r)):
                cells.append(_pad(r[i], r[i], p_widths[i], "r"))
            return " " + " ".join(cells)

        p_bar = _c("\u2500" * (sum(p_widths) + 3 * (len(p_widths) - 1) + 2), _DIM)
        lines += [
            "",
            _c("By Partition", _BOLD),
            p_bar,
            fmt_p_header(),
            p_bar,
            *[fmt_p_row(r) for r in p_str_rows],
            p_bar,
        ]
    return "\n".join(lines)
# Sort-key functions for the `list --sort` option.
# Each callable maps an SQueueJob to the attribute named by the CLI choice.
# NOTE(review): "time" sorts the squeue %M string lexicographically, not by
# actual duration — confirm this is acceptable for the CLI.
_SORT_KEYS = {
    "id": lambda j: j.job_id,
    "user": lambda j: j.user,
    "name": lambda j: j.name,
    "state": lambda j: j.state,
    "partition": lambda j: j.partition,
    "nodes": lambda j: j.num_nodes,
    "cpus": lambda j: j.num_cpus,
    "time": lambda j: j.time_used,
    "priority": lambda j: j.priority,
}
def main() -> None:
    """Entry point for the ``slurm-queue`` command-line tool.

    Sub-commands
    ------------
    show (default)
        Print a per-user queue summary table.
    list
        Print individual jobs, optionally filtered and sorted.
    stats
        Print partition and state breakdown statistics.
    history
        Show job submission history from accounting records (sacct).
    wait
        Block until matching jobs leave the active queue.
    """
    # argparse is only needed by the CLI entry point, so import it lazily.
    # ``sys`` is already imported at module level and is not re-imported here.
    import argparse

    parser = argparse.ArgumentParser(
        prog="slurm-queue",
        description="Inspect and wait on the SLURM job queue.",
    )
    sub = parser.add_subparsers(dest="cmd")

    # ---- show ---------------------------------------------------------------
    p_show = sub.add_parser("show", help="Print per-user queue summary (default).")
    p_show.add_argument(
        "--user", "-u", metavar="USER", default=None, help="Filter to this user."
    )
    p_show.add_argument(
        "--partition",
        "-p",
        metavar="PARTITION",
        default=None,
        help="Filter to this partition.",
    )

    # ---- list ---------------------------------------------------------------
    p_list = sub.add_parser("list", help="List individual jobs.")
    p_list.add_argument("--user", "-u", metavar="USER", default=None)
    p_list.add_argument(
        "--partition",
        "-p",
        metavar="PARTITION",
        default=None,
        help="Filter to this partition.",
    )
    p_list.add_argument(
        "--job-name",
        "-n",
        metavar="PATTERN",
        default=None,
        help="Filter by job name (glob patterns supported, e.g. 'train_*').",
    )
    p_list.add_argument(
        "--job-id",
        "-j",
        metavar="ID",
        type=int,
        default=None,
        help="Filter to a specific job ID.",
    )
    p_list.add_argument(
        "--state",
        "-s",
        metavar="STATE",
        default=None,
        help="Filter by state code, e.g. R, PD, CG.",
    )
    p_list.add_argument(
        "--sort",
        "-S",
        metavar="KEY",
        default=None,
        choices=list(_SORT_KEYS),
        help="Sort by: id, user, name, state, partition, nodes, cpus, time, priority.",
    )
    p_list.add_argument(
        "--reverse", "-r", action="store_true", help="Reverse the sort order."
    )
    p_list.add_argument(
        "--reason",
        action="store_true",
        help="Show the scheduling/pending reason column.",
    )

    # ---- stats --------------------------------------------------------------
    p_stats = sub.add_parser(
        "stats", help="Print partition and state breakdown statistics."
    )
    p_stats.add_argument(
        "--user", "-u", metavar="USER", default=None, help="Filter to this user."
    )
    p_stats.add_argument(
        "--partition",
        "-p",
        metavar="PARTITION",
        default=None,
        help="Filter to this partition.",
    )

    # ---- history ------------------------------------------------------------
    p_hist = sub.add_parser(
        "history", help="Show job submission history from accounting records (sacct)."
    )
    p_hist.add_argument(
        "--user",
        "-u",
        metavar="USER",
        default=None,
        help="Show detailed per-state breakdown for this user; omit for all-users summary.",
    )
    p_hist.add_argument(
        "--days",
        "-d",
        metavar="N",
        type=int,
        default=7,
        help="Number of days to look back (default: 7).",
    )
    p_hist.add_argument(
        "--partition",
        "-p",
        metavar="PARTITION",
        default=None,
        help="Filter to this partition.",
    )

    # ---- wait ---------------------------------------------------------------
    p_wait = sub.add_parser(
        "wait", help="Wait until matching jobs leave the active queue."
    )
    p_wait.add_argument(
        "--job-name",
        "-n",
        metavar="PATTERN",
        default=None,
        help="Job name or glob pattern to wait for (e.g. 'train_*').",
    )
    p_wait.add_argument(
        "--job-id",
        "-j",
        metavar="ID",
        type=int,
        default=None,
        help="Wait for a specific job ID.",
    )
    p_wait.add_argument(
        "--user",
        "-u",
        metavar="USER",
        default=None,
        help="Wait for all jobs belonging to this user.",
    )
    p_wait.add_argument(
        "--poll-interval",
        "-i",
        metavar="SECONDS",
        type=float,
        default=30.0,
        help="Seconds between queue polls (default: 30).",
    )
    p_wait.add_argument(
        "--timeout",
        "-t",
        metavar="SECONDS",
        type=float,
        default=None,
        help="Raise an error if jobs are still running after this many seconds.",
    )
    p_wait.add_argument(
        "--quiet", "-q", action="store_true", help="Suppress progress messages."
    )

    args = parser.parse_args()

    # Default sub-command: show.  When no sub-command is given, argparse does
    # not populate the show-specific attributes, hence the getattr fallbacks.
    if args.cmd is None or args.cmd == "show":
        user = getattr(args, "user", None)
        partition = getattr(args, "partition", None)
        try:
            q = SQueue(user=user, partition=partition)
            print(q)
        except RuntimeError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)
    elif args.cmd == "list":
        try:
            q = SQueue(user=args.user, partition=args.partition)
            jobs = q.jobs(
                job_name=args.job_name,
                job_id=args.job_id,
                state=args.state,
            )
            if args.sort:
                jobs = sorted(jobs, key=_SORT_KEYS[args.sort], reverse=args.reverse)
            print(_fmt_job_table(jobs, show_reason=args.reason))
        except RuntimeError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)
    elif args.cmd == "stats":
        try:
            q = SQueue(user=args.user, partition=args.partition)
            n_running = sum(1 for j in q if j.is_running)
            n_pending = sum(1 for j in q if j.is_pending)
            # A plain (uncolored) copy of the title is kept so the underline
            # width matches the visible characters, not the ANSI codes.
            title_plain = (
                f"SLURM Queue \u00b7 {len(q)} jobs total"
                f" \u00b7 {n_running} running"
                f" \u00b7 {n_pending} pending"
            )
            title = (
                _c("SLURM Queue", _BOLD, _CYAN)
                + " \u00b7 "
                + f"{len(q)} jobs total"
                + " \u00b7 "
                + _c(f"{n_running} running", _GREEN)
                + " \u00b7 "
                + _c(f"{n_pending} pending", _YELLOW)
            )
            print(title)
            print(_c("\u2550" * len(title_plain), _DIM))
            print(_fmt_stats_table(q))
        except RuntimeError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)
    elif args.cmd == "history":
        try:
            acct = SAcct(user=args.user, days=args.days, partition=args.partition)
            n = args.days
            day_s = "day" if n == 1 else "days"
            part_s = f" \u00b7 {args.partition}" if args.partition else ""
            user_s = f" \u00b7 {args.user}" if args.user else ""
            # As in `stats`, the plain title drives the underline width.
            title_plain = f"Job History \u00b7 last {n} {day_s} \u00b7 {len(acct)} jobs{part_s}{user_s}"
            title = (
                _c("Job History", _BOLD, _CYAN)
                + " \u00b7 "
                + _c(f"last {n} {day_s}", _DIM)
                + " \u00b7 "
                + f"{len(acct)} jobs"
                + part_s
                + (" \u00b7 " + _c(args.user, _BOLD) if args.user else "")
            )
            print(title)
            print(_c("\u2550" * len(title_plain), _DIM))
            if args.user:
                print(_fmt_history_detail(acct))
            else:
                print(_fmt_history_summary(acct))
        except RuntimeError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)
    elif args.cmd == "wait":
        # `wait` with no filter would block on the entire cluster; require one.
        if args.job_name is None and args.job_id is None and args.user is None:
            p_wait.error("Specify at least one of: --job-name, --job-id, --user")
        try:
            q = SQueue()
            q.wait_until_done(
                job_name=args.job_name,
                job_id=args.job_id,
                user=args.user,
                poll_interval=args.poll_interval,
                timeout=args.timeout,
                verbose=not args.quiet,
            )
        except TimeoutError as e:
            print(f"Timeout: {e}", file=sys.stderr)
            sys.exit(1)
        except RuntimeError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)
# Allow direct execution of this module as a script.
if __name__ == "__main__":
    main()