first comit
This commit is contained in:
@@ -0,0 +1,242 @@
|
||||
"""Kernel Provisioner Classes"""
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from ..connect import KernelConnectionInfo, LocalPortCache
|
||||
from ..launcher import launch_kernel
|
||||
from ..localinterfaces import is_local_ip, local_ips
|
||||
from .provisioner_base import KernelProvisionerBase
|
||||
|
||||
|
||||
class LocalProvisioner(KernelProvisionerBase): # type:ignore[misc]
|
||||
"""
|
||||
:class:`LocalProvisioner` is a concrete class of ABC :py:class:`KernelProvisionerBase`
|
||||
and is the out-of-box default implementation used when no kernel provisioner is
|
||||
specified in the kernel specification (``kernel.json``). It provides functional
|
||||
parity to existing applications by launching the kernel locally and using
|
||||
:class:`subprocess.Popen` to manage its lifecycle.
|
||||
|
||||
This class is intended to be subclassed for customizing local kernel environments
|
||||
and serve as a reference implementation for other custom provisioners.
|
||||
"""
|
||||
|
||||
process = None
|
||||
_exit_future = None
|
||||
pid = None
|
||||
pgid = None
|
||||
ip = None
|
||||
ports_cached = False
|
||||
|
||||
@property
|
||||
def has_process(self) -> bool:
|
||||
return self.process is not None
|
||||
|
||||
async def poll(self) -> Optional[int]:
|
||||
"""Poll the provisioner."""
|
||||
ret = 0
|
||||
if self.process:
|
||||
ret = self.process.poll() # type:ignore[unreachable]
|
||||
return ret
|
||||
|
||||
async def wait(self) -> Optional[int]:
|
||||
"""Wait for the provisioner process."""
|
||||
ret = 0
|
||||
if self.process:
|
||||
# Use busy loop at 100ms intervals, polling until the process is
|
||||
# not alive. If we find the process is no longer alive, complete
|
||||
# its cleanup via the blocking wait(). Callers are responsible for
|
||||
# issuing calls to wait() using a timeout (see kill()).
|
||||
while await self.poll() is None: # type:ignore[unreachable]
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Process is no longer alive, wait and clear
|
||||
ret = self.process.wait()
|
||||
# Make sure all the fds get closed.
|
||||
for attr in ["stdout", "stderr", "stdin"]:
|
||||
fid = getattr(self.process, attr)
|
||||
if fid:
|
||||
fid.close()
|
||||
self.process = None # allow has_process to now return False
|
||||
return ret
|
||||
|
||||
async def send_signal(self, signum: int) -> None:
|
||||
"""Sends a signal to the process group of the kernel (this
|
||||
usually includes the kernel and any subprocesses spawned by
|
||||
the kernel).
|
||||
|
||||
Note that since only SIGTERM is supported on Windows, we will
|
||||
check if the desired signal is for interrupt and apply the
|
||||
applicable code on Windows in that case.
|
||||
"""
|
||||
if self.process:
|
||||
if signum == signal.SIGINT and sys.platform == "win32": # type:ignore[unreachable]
|
||||
from ..win_interrupt import send_interrupt
|
||||
|
||||
send_interrupt(self.process.win32_interrupt_event)
|
||||
return
|
||||
|
||||
# Prefer process-group over process
|
||||
if self.pgid and hasattr(os, "killpg"):
|
||||
try:
|
||||
os.killpg(self.pgid, signum)
|
||||
return
|
||||
except OSError:
|
||||
pass # We'll retry sending the signal to only the process below
|
||||
|
||||
# If we're here, send the signal to the process and let caller handle exceptions
|
||||
self.process.send_signal(signum)
|
||||
return
|
||||
|
||||
async def kill(self, restart: bool = False) -> None:
|
||||
"""Kill the provisioner and optionally restart."""
|
||||
if self.process:
|
||||
if hasattr(signal, "SIGKILL"): # type:ignore[unreachable]
|
||||
# If available, give preference to signalling the process-group over `kill()`.
|
||||
try:
|
||||
await self.send_signal(signal.SIGKILL)
|
||||
return
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
self.process.kill()
|
||||
except OSError as e:
|
||||
LocalProvisioner._tolerate_no_process(e)
|
||||
|
||||
async def terminate(self, restart: bool = False) -> None:
|
||||
"""Terminate the provisioner and optionally restart."""
|
||||
if self.process:
|
||||
if hasattr(signal, "SIGTERM"): # type:ignore[unreachable]
|
||||
# If available, give preference to signalling the process group over `terminate()`.
|
||||
try:
|
||||
await self.send_signal(signal.SIGTERM)
|
||||
return
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
self.process.terminate()
|
||||
except OSError as e:
|
||||
LocalProvisioner._tolerate_no_process(e)
|
||||
|
||||
@staticmethod
|
||||
def _tolerate_no_process(os_error: OSError) -> None:
|
||||
# In Windows, we will get an Access Denied error if the process
|
||||
# has already terminated. Ignore it.
|
||||
if sys.platform == "win32":
|
||||
if os_error.winerror != 5:
|
||||
raise
|
||||
# On Unix, we may get an ESRCH error (or ProcessLookupError instance) if
|
||||
# the process has already terminated. Ignore it.
|
||||
else:
|
||||
from errno import ESRCH
|
||||
|
||||
if not isinstance(os_error, ProcessLookupError) or os_error.errno != ESRCH:
|
||||
raise
|
||||
|
||||
async def cleanup(self, restart: bool = False) -> None:
|
||||
"""Clean up the resources used by the provisioner and optionally restart."""
|
||||
if self.ports_cached and not restart:
|
||||
# provisioner is about to be destroyed, return cached ports
|
||||
lpc = LocalPortCache.instance()
|
||||
ports = (
|
||||
self.connection_info["shell_port"],
|
||||
self.connection_info["iopub_port"],
|
||||
self.connection_info["stdin_port"],
|
||||
self.connection_info["hb_port"],
|
||||
self.connection_info["control_port"],
|
||||
)
|
||||
for port in ports:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(port, int)
|
||||
lpc.return_port(port)
|
||||
|
||||
async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Perform any steps in preparation for kernel process launch.
|
||||
|
||||
This includes applying additional substitutions to the kernel launch command and env.
|
||||
It also includes preparation of launch parameters.
|
||||
|
||||
Returns the updated kwargs.
|
||||
"""
|
||||
|
||||
# This should be considered temporary until a better division of labor can be defined.
|
||||
km = self.parent
|
||||
if km:
|
||||
if km.transport == "tcp" and not is_local_ip(km.ip):
|
||||
msg = (
|
||||
"Can only launch a kernel on a local interface. "
|
||||
f"This one is not: {km.ip}."
|
||||
"Make sure that the '*_address' attributes are "
|
||||
"configured properly. "
|
||||
f"Currently valid addresses are: {local_ips()}"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
# build the Popen cmd
|
||||
extra_arguments = kwargs.pop("extra_arguments", [])
|
||||
|
||||
# write connection file / get default ports
|
||||
# TODO - change when handshake pattern is adopted
|
||||
if km.cache_ports and not self.ports_cached:
|
||||
lpc = LocalPortCache.instance()
|
||||
km.shell_port = lpc.find_available_port(km.ip)
|
||||
km.iopub_port = lpc.find_available_port(km.ip)
|
||||
km.stdin_port = lpc.find_available_port(km.ip)
|
||||
km.hb_port = lpc.find_available_port(km.ip)
|
||||
km.control_port = lpc.find_available_port(km.ip)
|
||||
self.ports_cached = True
|
||||
if "env" in kwargs:
|
||||
jupyter_session = kwargs["env"].get("JPY_SESSION_NAME", "")
|
||||
km.write_connection_file(jupyter_session=jupyter_session)
|
||||
else:
|
||||
km.write_connection_file()
|
||||
self.connection_info = km.get_connection_info()
|
||||
|
||||
kernel_cmd = km.format_kernel_cmd(
|
||||
extra_arguments=extra_arguments
|
||||
) # This needs to remain here for b/c
|
||||
else:
|
||||
extra_arguments = kwargs.pop("extra_arguments", [])
|
||||
kernel_cmd = self.kernel_spec.argv + extra_arguments
|
||||
|
||||
return await super().pre_launch(cmd=kernel_cmd, **kwargs)
|
||||
|
||||
async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
|
||||
"""Launch a kernel with a command."""
|
||||
scrubbed_kwargs = LocalProvisioner._scrub_kwargs(kwargs)
|
||||
self.process = launch_kernel(cmd, **scrubbed_kwargs)
|
||||
pgid = None
|
||||
if hasattr(os, "getpgid"):
|
||||
try:
|
||||
pgid = os.getpgid(self.process.pid)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
self.pid = self.process.pid
|
||||
self.pgid = pgid
|
||||
return self.connection_info
|
||||
|
||||
@staticmethod
|
||||
def _scrub_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Remove any keyword arguments that Popen does not tolerate."""
|
||||
keywords_to_scrub: List[str] = ["extra_arguments", "kernel_id"]
|
||||
scrubbed_kwargs = kwargs.copy()
|
||||
for kw in keywords_to_scrub:
|
||||
scrubbed_kwargs.pop(kw, None)
|
||||
return scrubbed_kwargs
|
||||
|
||||
async def get_provisioner_info(self) -> Dict:
|
||||
"""Captures the base information necessary for persistence relative to this instance."""
|
||||
provisioner_info = await super().get_provisioner_info()
|
||||
provisioner_info.update({"pid": self.pid, "pgid": self.pgid, "ip": self.ip})
|
||||
return provisioner_info
|
||||
|
||||
async def load_provisioner_info(self, provisioner_info: Dict) -> None:
|
||||
"""Loads the base information necessary for persistence relative to this instance."""
|
||||
await super().load_provisioner_info(provisioner_info)
|
||||
self.pid = provisioner_info["pid"]
|
||||
self.pgid = provisioner_info["pgid"]
|
||||
self.ip = provisioner_info["ip"]
|
||||
Reference in New Issue
Block a user