|
| 1 | +"""Async subprocess execution with progress callbacks. |
| 2 | +
|
| 3 | +Async equivalent of :mod:`libvcs._internal.run`. |
| 4 | +
|
| 5 | +Note |
| 6 | +---- |
| 7 | +This is an internal API not covered by versioning policy. |
| 8 | +
|
| 9 | +Examples |
| 10 | +-------- |
| 11 | +- :func:`~async_run`: Async command execution with progress callback. |
| 12 | +
|
| 13 | + Before (sync): |
| 14 | +
|
| 15 | + >>> from libvcs._internal.run import run |
| 16 | + >>> output = run(['echo', 'hello'], check_returncode=True) |
| 17 | +
|
| 18 | + With this (async): |
| 19 | +
|
| 20 | + >>> from libvcs._internal.async_run import async_run |
| 21 | + >>> async def example(): |
| 22 | + ... output = await async_run(['echo', 'hello']) |
| 23 | + ... return output.strip() |
| 24 | + >>> asyncio.run(example()) |
| 25 | + 'hello' |
| 26 | +""" |
| 27 | + |
| 28 | +from __future__ import annotations |
| 29 | + |
| 30 | +import asyncio |
| 31 | +import asyncio.subprocess |
| 32 | +import datetime |
| 33 | +import logging |
| 34 | +import sys |
| 35 | +import typing as t |
| 36 | +from collections.abc import Mapping, Sequence |
| 37 | + |
| 38 | +from libvcs import exc |
| 39 | +from libvcs._internal.types import StrOrBytesPath |
| 40 | + |
| 41 | +from .run import console_to_str |
| 42 | + |
| 43 | +logger = logging.getLogger(__name__) |
| 44 | + |
| 45 | + |
| 46 | +class AsyncProgressCallbackProtocol(t.Protocol): |
| 47 | + """Async callback to report subprocess communication. |
| 48 | +
|
| 49 | + Async equivalent of :class:`~libvcs._internal.run.ProgressCallbackProtocol`. |
| 50 | +
|
| 51 | + Examples |
| 52 | + -------- |
| 53 | + >>> async def my_progress(output: str, timestamp: datetime.datetime) -> None: |
| 54 | + ... print(f"[{timestamp}] {output}", end="") |
| 55 | +
|
| 56 | + See Also |
| 57 | + -------- |
| 58 | + libvcs._internal.run.ProgressCallbackProtocol : Sync equivalent |
| 59 | + wrap_sync_callback : Helper to wrap sync callbacks for async use |
| 60 | + """ |
| 61 | + |
| 62 | + async def __call__(self, output: str, timestamp: datetime.datetime) -> None: |
| 63 | + """Process progress for subprocess communication.""" |
| 64 | + ... |
| 65 | + |
| 66 | + |
| 67 | +def wrap_sync_callback( |
| 68 | + sync_cb: t.Callable[[str, datetime.datetime], None], |
| 69 | +) -> AsyncProgressCallbackProtocol: |
| 70 | + """Wrap a sync callback for use with async APIs. |
| 71 | +
|
| 72 | + This helper allows users with existing sync callbacks to use them |
| 73 | + with async APIs without modification. |
| 74 | +
|
| 75 | + Parameters |
| 76 | + ---------- |
| 77 | + sync_cb : Callable[[str, datetime.datetime], None] |
| 78 | + Synchronous callback function |
| 79 | +
|
| 80 | + Returns |
| 81 | + ------- |
| 82 | + AsyncProgressCallbackProtocol |
| 83 | + Async wrapper that calls the sync callback |
| 84 | +
|
| 85 | + Examples |
| 86 | + -------- |
| 87 | + >>> def my_sync_progress(output: str, timestamp: datetime.datetime) -> None: |
| 88 | + ... print(output, end="") |
| 89 | + >>> async_cb = wrap_sync_callback(my_sync_progress) |
| 90 | + >>> # Now use async_cb with async_run() |
| 91 | + """ |
| 92 | + |
| 93 | + async def wrapper(output: str, timestamp: datetime.datetime) -> None: |
| 94 | + sync_cb(output, timestamp) |
| 95 | + |
| 96 | + return wrapper |
| 97 | + |
| 98 | + |
| 99 | +if sys.platform == "win32": |
| 100 | + _ENV: t.TypeAlias = Mapping[str, str] |
| 101 | +else: |
| 102 | + _ENV: t.TypeAlias = Mapping[bytes, StrOrBytesPath] | Mapping[str, StrOrBytesPath] |
| 103 | + |
| 104 | +_CMD: t.TypeAlias = StrOrBytesPath | Sequence[StrOrBytesPath] |
| 105 | + |
| 106 | + |
| 107 | +def _args_to_list(args: _CMD) -> list[str]: |
| 108 | + """Convert command args to list of strings. |
| 109 | +
|
| 110 | + Parameters |
| 111 | + ---------- |
| 112 | + args : str | bytes | Path | Sequence[str | bytes | Path] |
| 113 | + Command arguments in various forms |
| 114 | +
|
| 115 | + Returns |
| 116 | + ------- |
| 117 | + list[str] |
| 118 | + Normalized list of string arguments |
| 119 | + """ |
| 120 | + from os import PathLike |
| 121 | + |
| 122 | + if isinstance(args, (str, bytes, PathLike)): |
| 123 | + if isinstance(args, bytes): |
| 124 | + return [args.decode()] |
| 125 | + return [str(args)] |
| 126 | + return [arg.decode() if isinstance(arg, bytes) else str(arg) for arg in args] |
| 127 | + |
| 128 | + |
| 129 | +async def async_run( |
| 130 | + args: _CMD, |
| 131 | + *, |
| 132 | + cwd: StrOrBytesPath | None = None, |
| 133 | + env: _ENV | None = None, |
| 134 | + check_returncode: bool = True, |
| 135 | + callback: AsyncProgressCallbackProtocol | None = None, |
| 136 | + timeout: float | None = None, |
| 137 | +) -> str: |
| 138 | + """Run a command asynchronously. |
| 139 | +
|
| 140 | + Run 'args' and return stdout content (non-blocking). Optionally stream |
| 141 | + stderr to a callback for progress reporting. Raises an exception if |
| 142 | + the command exits non-zero (when check_returncode=True). |
| 143 | +
|
| 144 | + This is the async equivalent of :func:`~libvcs._internal.run.run`. |
| 145 | +
|
| 146 | + Parameters |
| 147 | + ---------- |
| 148 | + args : list[str] | str |
| 149 | + The command to run |
| 150 | + cwd : str | Path, optional |
| 151 | + Working directory for the command |
| 152 | + env : Mapping[str, str], optional |
| 153 | + Environment variables for the command |
| 154 | + check_returncode : bool, default True |
| 155 | + If True, raise :class:`~libvcs.exc.CommandError` on non-zero exit |
| 156 | + callback : AsyncProgressCallbackProtocol, optional |
| 157 | + Async callback to receive stderr output in real-time. |
| 158 | + Signature: ``async def callback(output: str, timestamp: datetime) -> None`` |
| 159 | + timeout : float, optional |
| 160 | + Timeout in seconds. Raises :class:`~libvcs.exc.CommandTimeoutError` |
| 161 | + if exceeded. |
| 162 | +
|
| 163 | + Returns |
| 164 | + ------- |
| 165 | + str |
| 166 | + Combined stdout output |
| 167 | +
|
| 168 | + Raises |
| 169 | + ------ |
| 170 | + libvcs.exc.CommandError |
| 171 | + If check_returncode=True and process exits with non-zero code |
| 172 | + libvcs.exc.CommandTimeoutError |
| 173 | + If timeout is exceeded |
| 174 | +
|
| 175 | + Examples |
| 176 | + -------- |
| 177 | + Basic usage: |
| 178 | +
|
| 179 | + >>> async def example(): |
| 180 | + ... output = await async_run(['echo', 'hello']) |
| 181 | + ... return output.strip() |
| 182 | + >>> asyncio.run(example()) |
| 183 | + 'hello' |
| 184 | +
|
| 185 | + With progress callback: |
| 186 | +
|
| 187 | + >>> async def progress(output: str, timestamp: datetime.datetime) -> None: |
| 188 | + ... pass # Handle progress output |
| 189 | + >>> async def clone_example(): |
| 190 | + ... url = f'file://{create_git_remote_repo()}' |
| 191 | + ... output = await async_run(['git', 'clone', url, str(tmp_path / 'cb_repo')]) |
| 192 | + ... return 'Cloning' in output or output == '' |
| 193 | + >>> asyncio.run(clone_example()) |
| 194 | + True |
| 195 | +
|
| 196 | + See Also |
| 197 | + -------- |
| 198 | + libvcs._internal.run.run : Synchronous equivalent |
| 199 | + AsyncSubprocessCommand : Lower-level async subprocess wrapper |
| 200 | + """ |
| 201 | + args_list = _args_to_list(args) |
| 202 | + |
| 203 | + # Create subprocess with pipes (using non-shell exec for security) |
| 204 | + proc = await asyncio.subprocess.create_subprocess_exec( |
| 205 | + *args_list, |
| 206 | + stdout=asyncio.subprocess.PIPE, |
| 207 | + stderr=asyncio.subprocess.PIPE, |
| 208 | + cwd=cwd, |
| 209 | + env=env, |
| 210 | + ) |
| 211 | + |
| 212 | + async def _run_with_callback() -> tuple[bytes, bytes, int]: |
| 213 | + """Run subprocess, streaming stderr to callback.""" |
| 214 | + stdout_data = b"" |
| 215 | + stderr_data = b"" |
| 216 | + |
| 217 | + assert proc.stdout is not None |
| 218 | + assert proc.stderr is not None |
| 219 | + |
| 220 | + # Read stderr line-by-line for progress callback |
| 221 | + if callback is not None: |
| 222 | + # Stream stderr to callback while collecting stdout |
| 223 | + async def read_stderr() -> bytes: |
| 224 | + collected = b"" |
| 225 | + assert proc.stderr is not None |
| 226 | + while True: |
| 227 | + line = await proc.stderr.readline() |
| 228 | + if not line: |
| 229 | + break |
| 230 | + collected += line |
| 231 | + # Call progress callback with decoded line |
| 232 | + await callback( |
| 233 | + output=console_to_str(line), |
| 234 | + timestamp=datetime.datetime.now(), |
| 235 | + ) |
| 236 | + return collected |
| 237 | + |
| 238 | + # Run stdout collection and stderr streaming concurrently |
| 239 | + stdout_task = asyncio.create_task(proc.stdout.read()) |
| 240 | + stderr_task = asyncio.create_task(read_stderr()) |
| 241 | + |
| 242 | + stdout_data, stderr_data = await asyncio.gather(stdout_task, stderr_task) |
| 243 | + |
| 244 | + # Send final carriage return (matching sync behavior) |
| 245 | + await callback(output="\r", timestamp=datetime.datetime.now()) |
| 246 | + else: |
| 247 | + # No callback - just collect both streams |
| 248 | + stdout_data, stderr_data = await proc.communicate() |
| 249 | + |
| 250 | + # Wait for process to complete |
| 251 | + await proc.wait() |
| 252 | + returncode = proc.returncode |
| 253 | + assert returncode is not None |
| 254 | + |
| 255 | + return stdout_data, stderr_data, returncode |
| 256 | + |
| 257 | + try: |
| 258 | + if timeout is not None: |
| 259 | + stdout_bytes, stderr_bytes, returncode = await asyncio.wait_for( |
| 260 | + _run_with_callback(), |
| 261 | + timeout=timeout, |
| 262 | + ) |
| 263 | + else: |
| 264 | + stdout_bytes, stderr_bytes, returncode = await _run_with_callback() |
| 265 | + except asyncio.TimeoutError: |
| 266 | + # Kill process on timeout |
| 267 | + proc.kill() |
| 268 | + await proc.wait() |
| 269 | + raise exc.CommandTimeoutError( |
| 270 | + output="Command timed out", |
| 271 | + returncode=-1, |
| 272 | + cmd=args_list, |
| 273 | + ) from None |
| 274 | + |
| 275 | + # Process stdout: strip and join lines (matching sync behavior) |
| 276 | + if stdout_bytes: |
| 277 | + lines = filter( |
| 278 | + None, |
| 279 | + (line.strip() for line in stdout_bytes.splitlines()), |
| 280 | + ) |
| 281 | + output = console_to_str(b"\n".join(lines)) |
| 282 | + else: |
| 283 | + output = "" |
| 284 | + |
| 285 | + # On error, use stderr content |
| 286 | + if returncode != 0 and stderr_bytes: |
| 287 | + stderr_lines = filter( |
| 288 | + None, |
| 289 | + (line.strip() for line in stderr_bytes.splitlines()), |
| 290 | + ) |
| 291 | + output = console_to_str(b"".join(stderr_lines)) |
| 292 | + |
| 293 | + if returncode != 0 and check_returncode: |
| 294 | + raise exc.CommandError( |
| 295 | + output=output, |
| 296 | + returncode=returncode, |
| 297 | + cmd=args_list, |
| 298 | + ) |
| 299 | + |
| 300 | + return output |
0 commit comments