first commit

2024-02-23 10:30:02 +00:00
commit ddeb07d0ba
12482 changed files with 1857507 additions and 0 deletions

View File: jupyter_client/__init__.py

@@ -0,0 +1,10 @@
"""Client-side implementations of the Jupyter protocol"""
from ._version import __version__, protocol_version, protocol_version_info, version_info
from .asynchronous import AsyncKernelClient
from .blocking import BlockingKernelClient
from .client import KernelClient
from .connect import * # noqa
from .launcher import * # noqa
from .manager import AsyncKernelManager, KernelManager, run_kernel
from .multikernelmanager import AsyncMultiKernelManager, MultiKernelManager
from .provisioning import KernelProvisionerBase, LocalProvisioner
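Together these exports form the package's public API. A minimal blocking round trip built on them (editor's sketch, not part of the commit; assumes a local "python3" kernelspec is installed):

from jupyter_client import KernelManager

# Start a kernel, build a blocking client for it, and run one statement.
km = KernelManager(kernel_name="python3")
km.start_kernel()
kc = km.client()
kc.start_channels()
try:
    kc.wait_for_ready(timeout=60)
    msg_id = kc.execute("1 + 1")  # returns the request's msg_id immediately
finally:
    kc.stop_channels()
    km.shutdown_kernel()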

View File: jupyter_client/_version.py

@@ -0,0 +1,20 @@
"""The version information for jupyter client."""
import re
from typing import List, Union
__version__ = "8.6.0"
# Build up version_info tuple for backwards compatibility
pattern = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?P<rest>.*)"  # dots escaped to match literally
match = re.match(pattern, __version__)
if match:
parts: List[Union[int, str]] = [int(match[part]) for part in ["major", "minor", "patch"]]
if match["rest"]:
parts.append(match["rest"])
else:
parts = []
version_info = tuple(parts)
protocol_version_info = (5, 3)
protocol_version = "%i.%i" % protocol_version_info
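# Editor's worked example of the parsing above (the pre-release string is hypothetical):
#   "8.6.0"      -> version_info == (8, 6, 0)
#   "8.7.0.dev0" -> version_info == (8, 7, 0, ".dev0")  # non-numeric tail kept as "rest"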

View File: jupyter_client/adapter.py

@@ -0,0 +1,431 @@
"""Adapters for Jupyter msg spec versions."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
import re
from typing import Any, Dict, List, Tuple
from ._version import protocol_version_info
def code_to_line(code: str, cursor_pos: int) -> Tuple[str, int]:
"""Turn a multiline code block and cursor position into a single line
and new cursor position.
    For adapting ``complete_*`` and ``object_info_request`` messages.
"""
if not code:
return "", 0
for line in code.splitlines(True):
n = len(line)
if cursor_pos > n:
cursor_pos -= n
else:
break
return line, cursor_pos
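# Editor's example: for code = "a = 1\nb = 2" and a global cursor_pos of 8,
# the cursor falls on the second line, so code_to_line returns ("b = 2", 2):
# the line's text plus the cursor offset within that line.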
_match_bracket = re.compile(r"\([^\(\)]+\)", re.UNICODE)
_end_bracket = re.compile(r"\([^\(]*$", re.UNICODE)
_identifier = re.compile(r"[a-z_][0-9a-z._]*", re.I | re.UNICODE)
def extract_oname_v4(code: str, cursor_pos: int) -> str:
"""Reimplement token-finding logic from IPython 2.x javascript
for adapting object_info_request from v5 to v4
"""
line, _ = code_to_line(code, cursor_pos)
oldline = line
line = _match_bracket.sub("", line)
while oldline != line:
oldline = line
line = _match_bracket.sub("", line)
# remove everything after last open bracket
line = _end_bracket.sub("", line)
matches = _identifier.findall(line)
if matches:
return matches[-1]
else:
return ""
class Adapter:
"""Base class for adapting messages
Override message_type(msg) methods to create adapters.
"""
msg_type_map: Dict[str, str] = {}
def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Update the header."""
return msg
def update_metadata(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Update the metadata."""
return msg
def update_msg_type(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Update the message type."""
header = msg["header"]
msg_type = header["msg_type"]
if msg_type in self.msg_type_map:
msg["msg_type"] = header["msg_type"] = self.msg_type_map[msg_type]
return msg
def handle_reply_status_error(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""This will be called *instead of* the regular handler
on any reply with status != ok
"""
return msg
def __call__(self, msg: Dict[str, Any]) -> Dict[str, Any]:
msg = self.update_header(msg)
msg = self.update_metadata(msg)
msg = self.update_msg_type(msg)
header = msg["header"]
handler = getattr(self, header["msg_type"], None)
if handler is None:
return msg
# handle status=error replies separately (no change, at present)
if msg["content"].get("status", None) in {"error", "aborted"}:
return self.handle_reply_status_error(msg)
return handler(msg)
def _version_str_to_list(version: str) -> List[int]:
"""convert a version string to a list of ints
non-int segments are excluded
"""
v = []
for part in version.split("."):
try:
v.append(int(part))
except ValueError:
pass
return v
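# Editor's example: _version_str_to_list("5.3.0.dev1") == [5, 3, 0];
# the non-integer "dev1" segment is dropped.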
class V5toV4(Adapter):
"""Adapt msg protocol v5 to v4"""
version = "4.1"
msg_type_map = {
"execute_result": "pyout",
"execute_input": "pyin",
"error": "pyerr",
"inspect_request": "object_info_request",
"inspect_reply": "object_info_reply",
}
def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Update the header."""
msg["header"].pop("version", None)
msg["parent_header"].pop("version", None)
return msg
# shell channel
def kernel_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a kernel info reply."""
v4c = {}
content = msg["content"]
for key in ("language_version", "protocol_version"):
if key in content:
v4c[key] = _version_str_to_list(content[key])
if content.get("implementation", "") == "ipython" and "implementation_version" in content:
v4c["ipython_version"] = _version_str_to_list(content["implementation_version"])
language_info = content.get("language_info", {})
language = language_info.get("name", "")
v4c.setdefault("language", language)
if "version" in language_info:
v4c.setdefault("language_version", _version_str_to_list(language_info["version"]))
msg["content"] = v4c
return msg
def execute_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an execute request."""
content = msg["content"]
content.setdefault("user_variables", [])
return msg
def execute_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an execute reply."""
content = msg["content"]
content.setdefault("user_variables", {})
# TODO: handle payloads
return msg
def complete_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a complete request."""
content = msg["content"]
code = content["code"]
cursor_pos = content["cursor_pos"]
line, cursor_pos = code_to_line(code, cursor_pos)
new_content = msg["content"] = {}
new_content["text"] = ""
new_content["line"] = line
new_content["block"] = None
new_content["cursor_pos"] = cursor_pos
return msg
def complete_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a complete reply."""
content = msg["content"]
cursor_start = content.pop("cursor_start")
cursor_end = content.pop("cursor_end")
match_len = cursor_end - cursor_start
content["matched_text"] = content["matches"][0][:match_len]
content.pop("metadata", None)
return msg
def object_info_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an object info request."""
content = msg["content"]
code = content["code"]
cursor_pos = content["cursor_pos"]
line, _ = code_to_line(code, cursor_pos)
new_content = msg["content"] = {}
new_content["oname"] = extract_oname_v4(code, cursor_pos)
new_content["detail_level"] = content["detail_level"]
return msg
def object_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""inspect_reply can't be easily backward compatible"""
msg["content"] = {"found": False, "oname": "unknown"}
return msg
# iopub channel
def stream(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a stream message."""
content = msg["content"]
content["data"] = content.pop("text")
return msg
def display_data(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a display data message."""
content = msg["content"]
content.setdefault("source", "display")
data = content["data"]
if "application/json" in data:
try:
data["application/json"] = json.dumps(data["application/json"])
except Exception:
# warn?
pass
return msg
# stdin channel
def input_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an input request."""
msg["content"].pop("password", None)
return msg
class V4toV5(Adapter):
"""Convert msg spec V4 to V5"""
version = "5.0"
# invert message renames above
msg_type_map = {v: k for k, v in V5toV4.msg_type_map.items()}
def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Update the header."""
msg["header"]["version"] = self.version
if msg["parent_header"]:
msg["parent_header"]["version"] = self.version
return msg
# shell channel
def kernel_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a kernel info reply."""
content = msg["content"]
for key in ("protocol_version", "ipython_version"):
if key in content:
content[key] = ".".join(map(str, content[key]))
content.setdefault("protocol_version", "4.1")
if content["language"].startswith("python") and "ipython_version" in content:
content["implementation"] = "ipython"
content["implementation_version"] = content.pop("ipython_version")
language = content.pop("language")
language_info = content.setdefault("language_info", {})
language_info.setdefault("name", language)
if "language_version" in content:
language_version = ".".join(map(str, content.pop("language_version")))
language_info.setdefault("version", language_version)
content["banner"] = ""
return msg
def execute_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an execute request."""
content = msg["content"]
user_variables = content.pop("user_variables", [])
user_expressions = content.setdefault("user_expressions", {})
for v in user_variables:
user_expressions[v] = v
return msg
def execute_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an execute reply."""
content = msg["content"]
user_expressions = content.setdefault("user_expressions", {})
user_variables = content.pop("user_variables", {})
if user_variables:
user_expressions.update(user_variables)
# Pager payloads became a mime bundle
for payload in content.get("payload", []):
if payload.get("source", None) == "page" and ("text" in payload):
if "data" not in payload:
payload["data"] = {}
payload["data"]["text/plain"] = payload.pop("text")
return msg
def complete_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a complete request."""
old_content = msg["content"]
new_content = msg["content"] = {}
new_content["code"] = old_content["line"]
new_content["cursor_pos"] = old_content["cursor_pos"]
return msg
def complete_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a complete reply."""
# complete_reply needs more context than we have to get cursor_start and end.
# use special end=null to indicate current cursor position and negative offset
# for start relative to the cursor.
# start=None indicates that start == end (accounts for no -0).
content = msg["content"]
new_content = msg["content"] = {"status": "ok"}
new_content["matches"] = content["matches"]
if content["matched_text"]:
new_content["cursor_start"] = -len(content["matched_text"])
else:
# no -0, use None to indicate that start == end
new_content["cursor_start"] = None
new_content["cursor_end"] = None
new_content["metadata"] = {}
return msg
def inspect_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an inspect request."""
content = msg["content"]
name = content["oname"]
new_content = msg["content"] = {}
new_content["code"] = name
new_content["cursor_pos"] = len(name)
new_content["detail_level"] = content["detail_level"]
return msg
def inspect_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""inspect_reply can't be easily backward compatible"""
content = msg["content"]
new_content = msg["content"] = {"status": "ok"}
found = new_content["found"] = content["found"]
new_content["data"] = data = {}
new_content["metadata"] = {}
if found:
lines = []
for key in ("call_def", "init_definition", "definition"):
if content.get(key, False):
lines.append(content[key])
break
for key in ("call_docstring", "init_docstring", "docstring"):
if content.get(key, False):
lines.append(content[key])
break
if not lines:
lines.append("<empty docstring>")
data["text/plain"] = "\n".join(lines)
return msg
# iopub channel
def stream(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a stream message."""
content = msg["content"]
content["text"] = content.pop("data")
return msg
def display_data(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle display data."""
content = msg["content"]
content.pop("source", None)
data = content["data"]
if "application/json" in data:
try:
data["application/json"] = json.loads(data["application/json"])
except Exception:
# warn?
pass
return msg
# stdin channel
def input_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an input request."""
msg["content"].setdefault("password", False)
return msg
def adapt(msg: Dict[str, Any], to_version: int = protocol_version_info[0]) -> Dict[str, Any]:
"""Adapt a single message to a target version
Parameters
----------
msg : dict
A Jupyter message.
to_version : int, optional
The target major version.
If unspecified, adapt to the current version.
Returns
-------
msg : dict
A Jupyter message appropriate in the new version.
"""
from .session import utcnow
header = msg["header"]
if "date" not in header:
header["date"] = utcnow()
if "version" in header:
from_version = int(header["version"].split(".")[0])
else:
# assume last version before adding the key to the header
from_version = 4
adapter = adapters.get((from_version, to_version), None)
if adapter is None:
return msg
return adapter(msg)
# one adapter per major version from,to
adapters = {
(5, 4): V5toV4(),
(4, 5): V4toV5(),
}
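A usage sketch for the adapt() entry point defined above (editor's illustration, not part of the commit; the message dict is abbreviated, not a complete Jupyter wire message):

from jupyter_client.adapter import adapt

v5_msg = {
    "header": {"msg_type": "execute_result", "version": "5.3",
               "date": "2024-02-23T10:30:02Z"},
    "parent_header": {},
    "metadata": {},
    "content": {"data": {"text/plain": "2"}, "execution_count": 1},
}
v4_msg = adapt(v5_msg, to_version=4)
print(v4_msg["header"]["msg_type"])  # -> "pyout", per the v5-to-v4 rename map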

View File: jupyter_client/asynchronous/__init__.py

@@ -0,0 +1 @@
from .client import AsyncKernelClient # noqa

View File: jupyter_client/asynchronous/client.py

@@ -0,0 +1,75 @@
"""Implements an async kernel client"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import typing as t
import zmq.asyncio
from traitlets import Instance, Type
from ..channels import AsyncZMQSocketChannel, HBChannel
from ..client import KernelClient, reqrep
def wrapped(meth: t.Callable, channel: str) -> t.Callable:
"""Wrap a method on a channel and handle replies."""
def _(self: AsyncKernelClient, *args: t.Any, **kwargs: t.Any) -> t.Any:
reply = kwargs.pop("reply", False)
timeout = kwargs.pop("timeout", None)
msg_id = meth(self, *args, **kwargs)
if not reply:
return msg_id
return self._recv_reply(msg_id, timeout=timeout, channel=channel)
return _
class AsyncKernelClient(KernelClient):
"""A KernelClient with async APIs
``get_[channel]_msg()`` methods wait for and return messages on channels,
raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
"""
context = Instance(zmq.asyncio.Context)
def _context_default(self) -> zmq.asyncio.Context:
self._created_context = True
return zmq.asyncio.Context()
# --------------------------------------------------------------------------
# Channel proxy methods
# --------------------------------------------------------------------------
get_shell_msg = KernelClient._async_get_shell_msg
get_iopub_msg = KernelClient._async_get_iopub_msg
get_stdin_msg = KernelClient._async_get_stdin_msg
get_control_msg = KernelClient._async_get_control_msg
wait_for_ready = KernelClient._async_wait_for_ready
# The classes to use for the various channels
shell_channel_class = Type(AsyncZMQSocketChannel) # type:ignore[arg-type]
iopub_channel_class = Type(AsyncZMQSocketChannel) # type:ignore[arg-type]
stdin_channel_class = Type(AsyncZMQSocketChannel) # type:ignore[arg-type]
hb_channel_class = Type(HBChannel) # type:ignore[arg-type]
control_channel_class = Type(AsyncZMQSocketChannel) # type:ignore[arg-type]
_recv_reply = KernelClient._async_recv_reply
# replies come on the shell channel
execute = reqrep(wrapped, KernelClient.execute)
history = reqrep(wrapped, KernelClient.history)
complete = reqrep(wrapped, KernelClient.complete)
is_complete = reqrep(wrapped, KernelClient.is_complete)
inspect = reqrep(wrapped, KernelClient.inspect)
kernel_info = reqrep(wrapped, KernelClient.kernel_info)
comm_info = reqrep(wrapped, KernelClient.comm_info)
is_alive = KernelClient._async_is_alive
execute_interactive = KernelClient._async_execute_interactive
# replies come on the control channel
shutdown = reqrep(wrapped, KernelClient.shutdown, channel="control")
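A usage sketch for the async client (editor's illustration, not part of the commit; assumes a local "python3" kernelspec):

import asyncio
from jupyter_client import AsyncKernelManager

async def main() -> None:
    km = AsyncKernelManager(kernel_name="python3")
    await km.start_kernel()
    kc = km.client()
    kc.start_channels()
    try:
        await kc.wait_for_ready(timeout=60)
        # reply=True makes the wrapped execute() await the execute_reply.
        reply = await kc.execute("print('hi')", reply=True, timeout=10)
        print(reply["content"]["status"])  # "ok" on success
    finally:
        kc.stop_channels()
        await km.shutdown_kernel()

asyncio.run(main())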

View File: jupyter_client/blocking/__init__.py

@@ -0,0 +1 @@
from .client import BlockingKernelClient # noqa

View File: jupyter_client/blocking/client.py

@@ -0,0 +1,71 @@
"""Implements a fully blocking kernel client.
Useful for test suites and blocking terminal interfaces.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import typing as t
from traitlets import Type
from ..channels import HBChannel, ZMQSocketChannel
from ..client import KernelClient, reqrep
from ..utils import run_sync
def wrapped(meth: t.Callable, channel: str) -> t.Callable:
"""Wrap a method on a channel and handle replies."""
def _(self: BlockingKernelClient, *args: t.Any, **kwargs: t.Any) -> t.Any:
reply = kwargs.pop("reply", False)
timeout = kwargs.pop("timeout", None)
msg_id = meth(self, *args, **kwargs)
if not reply:
return msg_id
return self._recv_reply(msg_id, timeout=timeout, channel=channel)
return _
class BlockingKernelClient(KernelClient):
"""A KernelClient with blocking APIs
``get_[channel]_msg()`` methods wait for and return messages on channels,
raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
"""
# --------------------------------------------------------------------------
# Channel proxy methods
# --------------------------------------------------------------------------
get_shell_msg = run_sync(KernelClient._async_get_shell_msg)
get_iopub_msg = run_sync(KernelClient._async_get_iopub_msg)
get_stdin_msg = run_sync(KernelClient._async_get_stdin_msg)
get_control_msg = run_sync(KernelClient._async_get_control_msg)
wait_for_ready = run_sync(KernelClient._async_wait_for_ready)
# The classes to use for the various channels
shell_channel_class = Type(ZMQSocketChannel) # type:ignore[arg-type]
iopub_channel_class = Type(ZMQSocketChannel) # type:ignore[arg-type]
stdin_channel_class = Type(ZMQSocketChannel) # type:ignore[arg-type]
hb_channel_class = Type(HBChannel) # type:ignore[arg-type]
control_channel_class = Type(ZMQSocketChannel) # type:ignore[arg-type]
_recv_reply = run_sync(KernelClient._async_recv_reply)
# replies come on the shell channel
execute = reqrep(wrapped, KernelClient.execute)
history = reqrep(wrapped, KernelClient.history)
complete = reqrep(wrapped, KernelClient.complete)
inspect = reqrep(wrapped, KernelClient.inspect)
kernel_info = reqrep(wrapped, KernelClient.kernel_info)
comm_info = reqrep(wrapped, KernelClient.comm_info)
is_alive = run_sync(KernelClient._async_is_alive)
execute_interactive = run_sync(KernelClient._async_execute_interactive)
# replies come on the control channel
shutdown = reqrep(wrapped, KernelClient.shutdown, channel="control")
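The reqrep wrapper gives each of these request methods optional ``reply`` and ``timeout`` keyword arguments (editor's sketch; ``kc`` stands for a started, ready BlockingKernelClient):

# msg_id = kc.execute("x = 1")                         # fire-and-forget: returns msg_id
# reply  = kc.execute("x + 1", reply=True, timeout=5)  # blocks for the execute_reply dict
# reply  = kc.shutdown(reply=True, timeout=5)          # reply arrives on the control channel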

View File: jupyter_client/channels.py

@@ -0,0 +1,330 @@
"""Base classes to manage a Client's interaction with a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import atexit
import time
import typing as t
from queue import Empty
from threading import Event, Thread
import zmq.asyncio
from jupyter_core.utils import ensure_async
from ._version import protocol_version_info
from .channelsabc import HBChannelABC
from .session import Session
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
# during garbage collection of threads at exit
# -----------------------------------------------------------------------------
# Constants and exceptions
# -----------------------------------------------------------------------------
major_protocol_version = protocol_version_info[0]
class InvalidPortNumber(Exception): # noqa
"""An exception raised for an invalid port number."""
pass
class HBChannel(Thread):
"""The heartbeat channel which monitors the kernel heartbeat.
Note that the heartbeat channel is paused by default. As long as you start
this channel, the kernel manager will ensure that it is paused and un-paused
as appropriate.
"""
session = None
socket = None
address = None
_exiting = False
time_to_dead: float = 1.0
_running = None
_pause = None
_beating = None
def __init__(
self,
context: t.Optional[zmq.Context] = None,
session: t.Optional[Session] = None,
address: t.Union[t.Tuple[str, int], str] = "",
) -> None:
"""Create the heartbeat monitor thread.
Parameters
----------
context : :class:`zmq.Context`
The ZMQ context to use.
session : :class:`session.Session`
The session to use.
address : zmq url
Standard (ip, port) tuple that the kernel is listening on.
"""
super().__init__()
self.daemon = True
self.context = context
self.session = session
if isinstance(address, tuple):
if address[1] == 0:
message = "The port number for a channel cannot be 0."
raise InvalidPortNumber(message)
address_str = "tcp://%s:%i" % address
else:
address_str = address
self.address = address_str
# running is False until `.start()` is called
self._running = False
self._exit = Event()
# don't start paused
self._pause = False
self.poller = zmq.Poller()
@staticmethod
@atexit.register
def _notice_exit() -> None:
# Class definitions can be torn down during interpreter shutdown.
# We only need to set _exiting flag if this hasn't happened.
if HBChannel is not None:
HBChannel._exiting = True
def _create_socket(self) -> None:
if self.socket is not None:
# close previous socket, before opening a new one
self.poller.unregister(self.socket) # type:ignore[unreachable]
self.socket.close()
assert self.context is not None
self.socket = self.context.socket(zmq.REQ)
self.socket.linger = 1000
assert self.address is not None
self.socket.connect(self.address)
self.poller.register(self.socket, zmq.POLLIN)
async def _async_run(self) -> None:
"""The thread's main activity. Call start() instead."""
self._create_socket()
self._running = True
self._beating = True
assert self.socket is not None
while self._running:
if self._pause:
# just sleep, and skip the rest of the loop
self._exit.wait(self.time_to_dead)
continue
since_last_heartbeat = 0.0
            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM
await ensure_async(self.socket.send(b"ping"))
request_time = time.time()
# Wait until timeout
self._exit.wait(self.time_to_dead)
# poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
self._beating = bool(self.poller.poll(0))
if self._beating:
# the poll above guarantees we have something to recv
await ensure_async(self.socket.recv())
continue
elif self._running:
# nothing was received within the time limit, signal heart failure
since_last_heartbeat = time.time() - request_time
self.call_handlers(since_last_heartbeat)
# and close/reopen the socket, because the REQ/REP cycle has been broken
self._create_socket()
continue
def run(self) -> None:
"""Run the heartbeat thread."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._async_run())
loop.close()
def pause(self) -> None:
"""Pause the heartbeat."""
self._pause = True
def unpause(self) -> None:
"""Unpause the heartbeat."""
self._pause = False
def is_beating(self) -> bool:
"""Is the heartbeat running and responsive (and not paused)."""
if self.is_alive() and not self._pause and self._beating: # noqa
return True
else:
return False
def stop(self) -> None:
"""Stop the channel's event loop and join its thread."""
self._running = False
self._exit.set()
self.join()
self.close()
def close(self) -> None:
"""Close the heartbeat thread."""
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
def call_handlers(self, since_last_heartbeat: float) -> None:
"""This method is called in the ioloop thread when a message arrives.
Subclasses should override this method to handle incoming messages.
        It is important to remember that this method is called in the
        heartbeat thread, so some care must be taken to ensure that
        application-level handlers are invoked in the application thread.
"""
pass
HBChannelABC.register(HBChannel)
class ZMQSocketChannel:
"""A ZMQ socket wrapper"""
def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
"""Create a channel.
Parameters
----------
socket : :class:`zmq.Socket`
The ZMQ socket to use.
session : :class:`session.Session`
The session to use.
loop
Unused here, for other implementations
"""
super().__init__()
self.socket: t.Optional[zmq.Socket] = socket
self.session = session
def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
assert self.socket is not None
msg = self.socket.recv_multipart(**kwargs)
ident, smsg = self.session.feed_identities(msg)
return self.session.deserialize(smsg)
def get_msg(self, timeout: t.Optional[float] = None) -> t.Dict[str, t.Any]:
"""Gets a message if there is one that is ready."""
assert self.socket is not None
if timeout is not None:
timeout *= 1000 # seconds to ms
ready = self.socket.poll(timeout)
if ready:
res = self._recv()
return res
else:
raise Empty
def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
"""Get all messages that are currently ready."""
msgs = []
while True:
try:
msgs.append(self.get_msg())
except Empty:
break
return msgs
def msg_ready(self) -> bool:
"""Is there a message that has been received?"""
assert self.socket is not None
return bool(self.socket.poll(timeout=0))
def close(self) -> None:
"""Close the socket channel."""
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
stop = close
def is_alive(self) -> bool:
"""Test whether the channel is alive."""
return self.socket is not None
def send(self, msg: t.Dict[str, t.Any]) -> None:
"""Pass a message to the ZMQ socket to send"""
assert self.socket is not None
self.session.send(self.socket, msg)
def start(self) -> None:
"""Start the socket channel."""
pass
class AsyncZMQSocketChannel(ZMQSocketChannel):
"""A ZMQ socket in an async API"""
socket: zmq.asyncio.Socket
def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
"""Create a channel.
Parameters
----------
socket : :class:`zmq.asyncio.Socket`
The ZMQ socket to use.
session : :class:`session.Session`
The session to use.
loop
Unused here, for other implementations
"""
if not isinstance(socket, zmq.asyncio.Socket):
msg = "Socket must be asyncio" # type:ignore[unreachable]
raise ValueError(msg)
super().__init__(socket, session)
async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]: # type:ignore[override]
assert self.socket is not None
msg = await self.socket.recv_multipart(**kwargs)
_, smsg = self.session.feed_identities(msg)
return self.session.deserialize(smsg)
async def get_msg( # type:ignore[override]
self, timeout: t.Optional[float] = None
) -> t.Dict[str, t.Any]:
"""Gets a message if there is one that is ready."""
assert self.socket is not None
if timeout is not None:
timeout *= 1000 # seconds to ms
ready = await self.socket.poll(timeout)
if ready:
res = await self._recv()
return res
else:
raise Empty
async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]: # type:ignore[override]
"""Get all messages that are currently ready."""
msgs = []
while True:
try:
msgs.append(await self.get_msg())
except Empty:
break
return msgs
async def msg_ready(self) -> bool: # type:ignore[override]
"""Is there a message that has been received?"""
assert self.socket is not None
return bool(await self.socket.poll(timeout=0))
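A small polling sketch against these channel objects (editor's illustration; ``kc`` is assumed to be a started BlockingKernelClient):

# Drain whatever has already arrived on IOPub without blocking.
for msg in kc.iopub_channel.get_msgs():
    print(msg["header"]["msg_type"])
print(kc.iopub_channel.msg_ready())  # False once the queue is drained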

View File: jupyter_client/channelsabc.py

@@ -0,0 +1,51 @@
"""Abstract base classes for kernel client channels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import abc
class ChannelABC(metaclass=abc.ABCMeta):
"""A base class for all channel ABCs."""
@abc.abstractmethod
def start(self) -> None:
"""Start the channel."""
pass
@abc.abstractmethod
def stop(self) -> None:
"""Stop the channel."""
pass
@abc.abstractmethod
def is_alive(self) -> bool:
"""Test whether the channel is alive."""
pass
class HBChannelABC(ChannelABC):
"""HBChannel ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.channels.HBChannel`
"""
@abc.abstractproperty
def time_to_dead(self) -> float:
pass
@abc.abstractmethod
def pause(self) -> None:
"""Pause the heartbeat channel."""
pass
@abc.abstractmethod
def unpause(self) -> None:
"""Unpause the heartbeat channel."""
pass
@abc.abstractmethod
def is_beating(self) -> bool:
"""Test whether the channel is beating."""
pass
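A concrete channel only has to supply these three methods; a minimal conforming sketch (editor's illustration, not part of the commit):

from jupyter_client.channelsabc import ChannelABC

class NullChannel(ChannelABC):
    """A do-nothing channel that satisfies the ABC."""

    def __init__(self) -> None:
        self._alive = False

    def start(self) -> None:
        self._alive = True

    def stop(self) -> None:
        self._alive = False

    def is_alive(self) -> bool:
        return self._alive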

View File: jupyter_client/client.py

@@ -0,0 +1,827 @@
"""Base class to manage the interaction with a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import inspect
import sys
import time
import typing as t
from functools import partial
from getpass import getpass
from queue import Empty
import zmq.asyncio
from jupyter_core.utils import ensure_async
from traitlets import Any, Bool, Instance, Type
from .channels import major_protocol_version
from .channelsabc import ChannelABC, HBChannelABC
from .clientabc import KernelClientABC
from .connect import ConnectionFileMixin
from .session import Session
# some utilities to validate message structure, these might get moved elsewhere
# if they prove to have more generic utility
def validate_string_dict(dct: t.Dict[str, str]) -> None:
"""Validate that the input is a dict with string keys and values.
Raises ValueError if not."""
for k, v in dct.items():
if not isinstance(k, str):
raise ValueError("key %r in dict must be a string" % k)
if not isinstance(v, str):
raise ValueError("value %r in dict must be a string" % v)
def reqrep(wrapped: t.Callable, meth: t.Callable, channel: str = "shell") -> t.Callable:
wrapped = wrapped(meth, channel)
if not meth.__doc__:
# python -OO removes docstrings,
# so don't bother building the wrapped docstring
return wrapped
basedoc, _ = meth.__doc__.split("Returns\n", 1)
parts = [basedoc.strip()]
if "Parameters" not in basedoc:
parts.append(
"""
Parameters
----------
"""
)
parts.append(
"""
reply: bool (default: False)
Whether to wait for and return reply
timeout: float or None (default: None)
Timeout to use when waiting for a reply
Returns
-------
msg_id: str
The msg_id of the request sent, if reply=False (default)
reply: dict
The reply message for this request, if reply=True
"""
)
wrapped.__doc__ = "\n".join(parts)
return wrapped
class KernelClient(ConnectionFileMixin):
"""Communicates with a single kernel on any host via zmq channels.
There are five channels associated with each kernel:
* shell: for request/reply calls to the kernel.
* iopub: for the kernel to publish results to frontends.
* hb: for monitoring the kernel's heartbeat.
* stdin: for frontends to reply to raw_input calls in the kernel.
* control: for kernel management calls to the kernel.
The messages that can be sent on these channels are exposed as methods of the
client (KernelClient.execute, complete, history, etc.). These methods only
send the message, they don't wait for a reply. To get results, use e.g.
:meth:`get_shell_msg` to fetch messages from the shell channel.
"""
# The PyZMQ Context to use for communication with the kernel.
context = Instance(zmq.Context)
_created_context = Bool(False)
def _context_default(self) -> zmq.Context:
self._created_context = True
return zmq.Context()
# The classes to use for the various channels
shell_channel_class = Type(ChannelABC)
iopub_channel_class = Type(ChannelABC)
stdin_channel_class = Type(ChannelABC)
hb_channel_class = Type(HBChannelABC)
control_channel_class = Type(ChannelABC)
# Protected traits
_shell_channel = Any()
_iopub_channel = Any()
_stdin_channel = Any()
_hb_channel = Any()
_control_channel = Any()
# flag for whether execute requests should be allowed to call raw_input:
allow_stdin: bool = True
def __del__(self) -> None:
"""Handle garbage collection. Destroy context if applicable."""
if (
self._created_context
and self.context is not None # type:ignore[redundant-expr]
and not self.context.closed
):
if self.channels_running:
if self.log:
self.log.warning("Could not destroy zmq context for %s", self)
else:
if self.log:
self.log.debug("Destroying zmq context for %s", self)
self.context.destroy()
try:
super_del = super().__del__ # type:ignore[misc]
except AttributeError:
pass
else:
super_del()
# --------------------------------------------------------------------------
# Channel proxy methods
# --------------------------------------------------------------------------
async def _async_get_shell_msg(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
"""Get a message from the shell channel"""
return await ensure_async(self.shell_channel.get_msg(*args, **kwargs))
async def _async_get_iopub_msg(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
"""Get a message from the iopub channel"""
return await ensure_async(self.iopub_channel.get_msg(*args, **kwargs))
async def _async_get_stdin_msg(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
"""Get a message from the stdin channel"""
return await ensure_async(self.stdin_channel.get_msg(*args, **kwargs))
async def _async_get_control_msg(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
"""Get a message from the control channel"""
return await ensure_async(self.control_channel.get_msg(*args, **kwargs))
async def _async_wait_for_ready(self, timeout: t.Optional[float] = None) -> None:
"""Waits for a response when a client is blocked
        - Sets an absolute deadline based on the timeout
        - Blocks on the shell channel until a message is received
        - Exits if the kernel has died
        - Raises RuntimeError if the client times out before the kernel replies
        - Flushes the IOPub channel
"""
if timeout is None:
timeout = float("inf")
abs_timeout = time.time() + timeout
from .manager import KernelManager
if not isinstance(self.parent, KernelManager):
# This Client was not created by a KernelManager,
# so wait for kernel to become responsive to heartbeats
# before checking for kernel_info reply
while not await self._async_is_alive():
if time.time() > abs_timeout:
raise RuntimeError(
"Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout
)
await asyncio.sleep(0.2)
# Wait for kernel info reply on shell channel
while True:
self.kernel_info()
try:
msg = await ensure_async(self.shell_channel.get_msg(timeout=1))
except Empty:
pass
else:
if msg["msg_type"] == "kernel_info_reply":
# Checking that IOPub is connected. If it is not connected, start over.
try:
await ensure_async(self.iopub_channel.get_msg(timeout=0.2))
except Empty:
pass
else:
self._handle_kernel_info_reply(msg)
break
if not await self._async_is_alive():
msg = "Kernel died before replying to kernel_info"
raise RuntimeError(msg)
            # Give up once the absolute deadline has passed
if time.time() > abs_timeout:
raise RuntimeError("Kernel didn't respond in %d seconds" % timeout)
# Flush IOPub channel
while True:
try:
msg = await ensure_async(self.iopub_channel.get_msg(timeout=0.2))
except Empty:
break
async def _async_recv_reply(
self, msg_id: str, timeout: t.Optional[float] = None, channel: str = "shell"
) -> t.Dict[str, t.Any]:
"""Receive and return the reply for a given request"""
if timeout is not None:
deadline = time.monotonic() + timeout
while True:
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
try:
if channel == "control":
reply = await self._async_get_control_msg(timeout=timeout)
else:
reply = await self._async_get_shell_msg(timeout=timeout)
except Empty as e:
msg = "Timeout waiting for reply"
raise TimeoutError(msg) from e
if reply["parent_header"].get("msg_id") != msg_id:
# not my reply, someone may have forgotten to retrieve theirs
continue
return reply
async def _stdin_hook_default(self, msg: t.Dict[str, t.Any]) -> None:
"""Handle an input request"""
content = msg["content"]
prompt = getpass if content.get("password", False) else input
try:
raw_data = prompt(content["prompt"]) # type:ignore[operator]
except EOFError:
# turn EOFError into EOF character
raw_data = "\x04"
except KeyboardInterrupt:
sys.stdout.write("\n")
return
# only send stdin reply if there *was not* another request
# or execution finished while we were reading.
if not (await self.stdin_channel.msg_ready() or await self.shell_channel.msg_ready()):
self.input(raw_data)
def _output_hook_default(self, msg: t.Dict[str, t.Any]) -> None:
"""Default hook for redisplaying plain-text output"""
msg_type = msg["header"]["msg_type"]
content = msg["content"]
if msg_type == "stream":
stream = getattr(sys, content["name"])
stream.write(content["text"])
elif msg_type in ("display_data", "execute_result"):
sys.stdout.write(content["data"].get("text/plain", ""))
elif msg_type == "error":
sys.stderr.write("\n".join(content["traceback"]))
def _output_hook_kernel(
self,
session: Session,
socket: zmq.sugar.socket.Socket,
parent_header: t.Any,
msg: t.Dict[str, t.Any],
) -> None:
"""Output hook when running inside an IPython kernel
adds rich output support.
"""
msg_type = msg["header"]["msg_type"]
if msg_type in ("display_data", "execute_result", "error"):
session.send(socket, msg_type, msg["content"], parent=parent_header)
else:
self._output_hook_default(msg)
# --------------------------------------------------------------------------
# Channel management methods
# --------------------------------------------------------------------------
def start_channels(
self,
shell: bool = True,
iopub: bool = True,
stdin: bool = True,
hb: bool = True,
control: bool = True,
) -> None:
"""Starts the channels for this kernel.
This will create the channels if they do not exist and then start
them (their activity runs in a thread). If port numbers of 0 are
being used (random ports) then you must first call
:meth:`start_kernel`. If the channels have been stopped and you
call this, :class:`RuntimeError` will be raised.
"""
if iopub:
self.iopub_channel.start()
if shell:
self.shell_channel.start()
if stdin:
self.stdin_channel.start()
self.allow_stdin = True
else:
self.allow_stdin = False
if hb:
self.hb_channel.start()
if control:
self.control_channel.start()
def stop_channels(self) -> None:
"""Stops all the running channels for this kernel.
This stops their event loops and joins their threads.
"""
if self.shell_channel.is_alive():
self.shell_channel.stop()
if self.iopub_channel.is_alive():
self.iopub_channel.stop()
if self.stdin_channel.is_alive():
self.stdin_channel.stop()
if self.hb_channel.is_alive():
self.hb_channel.stop()
if self.control_channel.is_alive():
self.control_channel.stop()
@property
def channels_running(self) -> bool:
"""Are any of the channels created and running?"""
return (
(self._shell_channel and self.shell_channel.is_alive())
or (self._iopub_channel and self.iopub_channel.is_alive())
or (self._stdin_channel and self.stdin_channel.is_alive())
or (self._hb_channel and self.hb_channel.is_alive())
or (self._control_channel and self.control_channel.is_alive())
)
ioloop = None # Overridden in subclasses that use pyzmq event loop
@property
def shell_channel(self) -> t.Any:
"""Get the shell channel object for this kernel."""
if self._shell_channel is None:
url = self._make_url("shell")
self.log.debug("connecting shell channel to %s", url)
socket = self.connect_shell(identity=self.session.bsession)
self._shell_channel = self.shell_channel_class( # type:ignore[call-arg,abstract]
socket, self.session, self.ioloop
)
return self._shell_channel
@property
def iopub_channel(self) -> t.Any:
"""Get the iopub channel object for this kernel."""
if self._iopub_channel is None:
url = self._make_url("iopub")
self.log.debug("connecting iopub channel to %s", url)
socket = self.connect_iopub()
self._iopub_channel = self.iopub_channel_class( # type:ignore[call-arg,abstract]
socket, self.session, self.ioloop
)
return self._iopub_channel
@property
def stdin_channel(self) -> t.Any:
"""Get the stdin channel object for this kernel."""
if self._stdin_channel is None:
url = self._make_url("stdin")
self.log.debug("connecting stdin channel to %s", url)
socket = self.connect_stdin(identity=self.session.bsession)
self._stdin_channel = self.stdin_channel_class( # type:ignore[call-arg,abstract]
socket, self.session, self.ioloop
)
return self._stdin_channel
@property
def hb_channel(self) -> t.Any:
"""Get the hb channel object for this kernel."""
if self._hb_channel is None:
url = self._make_url("hb")
self.log.debug("connecting heartbeat channel to %s", url)
self._hb_channel = self.hb_channel_class( # type:ignore[call-arg,abstract]
self.context, self.session, url
)
return self._hb_channel
@property
def control_channel(self) -> t.Any:
"""Get the control channel object for this kernel."""
if self._control_channel is None:
url = self._make_url("control")
self.log.debug("connecting control channel to %s", url)
socket = self.connect_control(identity=self.session.bsession)
self._control_channel = self.control_channel_class( # type:ignore[call-arg,abstract]
socket, self.session, self.ioloop
)
return self._control_channel
async def _async_is_alive(self) -> bool:
"""Is the kernel process still running?"""
from .manager import KernelManager
if isinstance(self.parent, KernelManager):
# This KernelClient was created by a KernelManager,
# we can ask the parent KernelManager:
return await self.parent._async_is_alive()
if self._hb_channel is not None:
# We don't have access to the KernelManager,
# so we use the heartbeat.
return self._hb_channel.is_beating()
# no heartbeat and not local, we can't tell if it's running,
# so naively return True
return True
async def _async_execute_interactive(
self,
code: str,
silent: bool = False,
store_history: bool = True,
user_expressions: t.Optional[t.Dict[str, t.Any]] = None,
allow_stdin: t.Optional[bool] = None,
stop_on_error: bool = True,
timeout: t.Optional[float] = None,
output_hook: t.Optional[t.Callable] = None,
stdin_hook: t.Optional[t.Callable] = None,
) -> t.Dict[str, t.Any]:
"""Execute code in the kernel interactively
Output will be redisplayed, and stdin prompts will be relayed as well.
If an IPython kernel is detected, rich output will be displayed.
You can pass a custom output_hook callable that will be called
with every IOPub message that is produced instead of the default redisplay.
.. versionadded:: 5.0
Parameters
----------
code : str
A string of code in the kernel's language.
silent : bool, optional (default False)
        If set, the kernel will execute the code as quietly as possible, and
will force store_history to be False.
store_history : bool, optional (default True)
If set, the kernel will store command history. This is forced
to be False if silent is True.
user_expressions : dict, optional
A dict mapping names to expressions to be evaluated in the user's
dict. The expression values are returned as strings formatted using
:func:`repr`.
allow_stdin : bool, optional (default self.allow_stdin)
Flag for whether the kernel can send stdin requests to frontends.
Some frontends (e.g. the Notebook) do not support stdin requests.
If raw_input is called from code executed from such a frontend, a
StdinNotImplementedError will be raised.
stop_on_error: bool, optional (default True)
Flag whether to abort the execution queue, if an exception is encountered.
timeout: float or None (default: None)
Timeout to use when waiting for a reply
output_hook: callable(msg)
Function to be called with output messages.
If not specified, output will be redisplayed.
stdin_hook: callable(msg)
Function or awaitable to be called with stdin_request messages.
If not specified, input/getpass will be called.
Returns
-------
reply: dict
The reply message for this request
"""
if not self.iopub_channel.is_alive():
emsg = "IOPub channel must be running to receive output"
raise RuntimeError(emsg)
if allow_stdin is None:
allow_stdin = self.allow_stdin
if allow_stdin and not self.stdin_channel.is_alive():
emsg = "stdin channel must be running to allow input"
raise RuntimeError(emsg)
msg_id = await ensure_async(
self.execute(
code,
silent=silent,
store_history=store_history,
user_expressions=user_expressions,
allow_stdin=allow_stdin,
stop_on_error=stop_on_error,
)
)
if stdin_hook is None:
stdin_hook = self._stdin_hook_default
# detect IPython kernel
if output_hook is None and "IPython" in sys.modules:
from IPython import get_ipython
ip = get_ipython() # type:ignore[no-untyped-call]
in_kernel = getattr(ip, "kernel", False)
if in_kernel:
output_hook = partial(
self._output_hook_kernel,
ip.display_pub.session,
ip.display_pub.pub_socket,
ip.display_pub.parent_header,
)
if output_hook is None:
# default: redisplay plain-text outputs
output_hook = self._output_hook_default
# set deadline based on timeout
if timeout is not None:
deadline = time.monotonic() + timeout
else:
timeout_ms = None
poller = zmq.Poller()
iopub_socket = self.iopub_channel.socket
poller.register(iopub_socket, zmq.POLLIN)
if allow_stdin:
stdin_socket = self.stdin_channel.socket
poller.register(stdin_socket, zmq.POLLIN)
else:
stdin_socket = None
# wait for output and redisplay it
while True:
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
timeout_ms = int(1000 * timeout)
events = dict(poller.poll(timeout_ms))
if not events:
emsg = "Timeout waiting for output"
raise TimeoutError(emsg)
if stdin_socket in events:
req = await ensure_async(self.stdin_channel.get_msg(timeout=0))
res = stdin_hook(req)
if inspect.isawaitable(res):
await res
continue
if iopub_socket not in events:
continue
msg = await ensure_async(self.iopub_channel.get_msg(timeout=0))
if msg["parent_header"].get("msg_id") != msg_id:
# not from my request
continue
output_hook(msg)
# stop on idle
if (
msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle"
):
break
# output is done, get the reply
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
return await self._async_recv_reply(msg_id, timeout=timeout)
# Methods to send specific messages on channels
def execute(
self,
code: str,
silent: bool = False,
store_history: bool = True,
user_expressions: t.Optional[t.Dict[str, t.Any]] = None,
allow_stdin: t.Optional[bool] = None,
stop_on_error: bool = True,
) -> str:
"""Execute code in the kernel.
Parameters
----------
code : str
A string of code in the kernel's language.
silent : bool, optional (default False)
        If set, the kernel will execute the code as quietly as possible, and
will force store_history to be False.
store_history : bool, optional (default True)
If set, the kernel will store command history. This is forced
to be False if silent is True.
user_expressions : dict, optional
A dict mapping names to expressions to be evaluated in the user's
dict. The expression values are returned as strings formatted using
:func:`repr`.
allow_stdin : bool, optional (default self.allow_stdin)
Flag for whether the kernel can send stdin requests to frontends.
Some frontends (e.g. the Notebook) do not support stdin requests.
If raw_input is called from code executed from such a frontend, a
StdinNotImplementedError will be raised.
stop_on_error: bool, optional (default True)
Flag whether to abort the execution queue, if an exception is encountered.
Returns
-------
The msg_id of the message sent.
"""
if user_expressions is None:
user_expressions = {}
if allow_stdin is None:
allow_stdin = self.allow_stdin
# Don't waste network traffic if inputs are invalid
if not isinstance(code, str):
raise ValueError("code %r must be a string" % code)
validate_string_dict(user_expressions)
# Create class for content/msg creation. Related to, but possibly
# not in Session.
content = {
"code": code,
"silent": silent,
"store_history": store_history,
"user_expressions": user_expressions,
"allow_stdin": allow_stdin,
"stop_on_error": stop_on_error,
}
msg = self.session.msg("execute_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def complete(self, code: str, cursor_pos: t.Optional[int] = None) -> str:
"""Tab complete text in the kernel's namespace.
Parameters
----------
code : str
The context in which completion is requested.
Can be anything between a variable name and an entire cell.
cursor_pos : int, optional
The position of the cursor in the block of code where the completion was requested.
Default: ``len(code)``
Returns
-------
The msg_id of the message sent.
"""
if cursor_pos is None:
cursor_pos = len(code)
content = {"code": code, "cursor_pos": cursor_pos}
msg = self.session.msg("complete_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def inspect(self, code: str, cursor_pos: t.Optional[int] = None, detail_level: int = 0) -> str:
"""Get metadata information about an object in the kernel's namespace.
It is up to the kernel to determine the appropriate object to inspect.
Parameters
----------
code : str
The context in which info is requested.
Can be anything between a variable name and an entire cell.
cursor_pos : int, optional
The position of the cursor in the block of code where the info was requested.
Default: ``len(code)``
detail_level : int, optional
The level of detail for the introspection (0-2)
Returns
-------
The msg_id of the message sent.
"""
if cursor_pos is None:
cursor_pos = len(code)
content = {
"code": code,
"cursor_pos": cursor_pos,
"detail_level": detail_level,
}
msg = self.session.msg("inspect_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def history(
self,
raw: bool = True,
output: bool = False,
hist_access_type: str = "range",
**kwargs: t.Any,
) -> str:
"""Get entries from the kernel's history list.
Parameters
----------
raw : bool
If True, return the raw input.
output : bool
If True, then return the output as well.
hist_access_type : str
'range' (fill in session, start and stop params), 'tail' (fill in n)
or 'search' (fill in pattern param).
session : int
For a range request, the session from which to get lines. Session
numbers are positive integers; negative ones count back from the
current session.
start : int
The first line number of a history range.
stop : int
The final (excluded) line number of a history range.
n : int
The number of lines of history to get for a tail request.
pattern : str
The glob-syntax pattern for a search request.
Returns
-------
The ID of the message sent.
"""
if hist_access_type == "range":
kwargs.setdefault("session", 0)
kwargs.setdefault("start", 0)
content = dict(raw=raw, output=output, hist_access_type=hist_access_type, **kwargs)
msg = self.session.msg("history_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def kernel_info(self) -> str:
"""Request kernel info
Returns
-------
The msg_id of the message sent
"""
msg = self.session.msg("kernel_info_request")
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def comm_info(self, target_name: t.Optional[str] = None) -> str:
"""Request comm info
Returns
-------
The msg_id of the message sent
"""
content = {} if target_name is None else {"target_name": target_name}
msg = self.session.msg("comm_info_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def _handle_kernel_info_reply(self, msg: t.Dict[str, t.Any]) -> None:
"""handle kernel info reply
sets protocol adaptation version. This might
be run from a separate thread.
"""
adapt_version = int(msg["content"]["protocol_version"].split(".")[0])
if adapt_version != major_protocol_version:
self.session.adapt_version = adapt_version
def is_complete(self, code: str) -> str:
"""Ask the kernel whether some code is complete and ready to execute.
Returns
-------
The ID of the message sent.
"""
msg = self.session.msg("is_complete_request", {"code": code})
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def input(self, string: str) -> None:
"""Send a string of raw input to the kernel.
This should only be called in response to the kernel sending an
``input_request`` message on the stdin channel.
Returns
-------
The ID of the message sent.
"""
content = {"value": string}
msg = self.session.msg("input_reply", content)
self.stdin_channel.send(msg)
def shutdown(self, restart: bool = False) -> str:
"""Request an immediate kernel shutdown on the control channel.
Upon receipt of the (empty) reply, client code can safely assume that
the kernel has shut down and it's safe to forcefully terminate it if
it's still alive.
        The kernel will send the reply via a function registered with Python's
        atexit module, so the reply is only sent once the kernel has finished
        all normal operation.
Returns
-------
The msg_id of the message sent
"""
# Send quit message to kernel. Once we implement kernel-side setattr,
# this should probably be done that way, but for now this will do.
msg = self.session.msg("shutdown_request", {"restart": restart})
self.control_channel.send(msg)
return msg["header"]["msg_id"]
KernelClientABC.register(KernelClient)
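A sketch of the interactive execution path above (editor's illustration; ``kc`` is assumed to be a started, ready BlockingKernelClient):

reply = kc.execute_interactive(
    "for i in range(3): print(i)",
    timeout=30,
    # Called with every IOPub message this request produces.
    output_hook=lambda msg: print("iopub:", msg["header"]["msg_type"]),
)
print(reply["content"]["status"])  # "ok" on success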

View File: jupyter_client/clientabc.py

@@ -0,0 +1,100 @@
"""Abstract base class for kernel clients"""
# -----------------------------------------------------------------------------
# Copyright (c) The Jupyter Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import abc
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from .channelsabc import ChannelABC
# -----------------------------------------------------------------------------
# Main kernel client class
# -----------------------------------------------------------------------------
class KernelClientABC(metaclass=abc.ABCMeta):
"""KernelManager ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.client.KernelClient`
"""
@abc.abstractproperty
def kernel(self) -> Any:
pass
@abc.abstractproperty
def shell_channel_class(self) -> type[ChannelABC]:
pass
@abc.abstractproperty
def iopub_channel_class(self) -> type[ChannelABC]:
pass
@abc.abstractproperty
def hb_channel_class(self) -> type[ChannelABC]:
pass
@abc.abstractproperty
def stdin_channel_class(self) -> type[ChannelABC]:
pass
@abc.abstractproperty
def control_channel_class(self) -> type[ChannelABC]:
pass
# --------------------------------------------------------------------------
# Channel management methods
# --------------------------------------------------------------------------
@abc.abstractmethod
def start_channels(
self,
shell: bool = True,
iopub: bool = True,
stdin: bool = True,
hb: bool = True,
control: bool = True,
) -> None:
"""Start the channels for the client."""
pass
@abc.abstractmethod
def stop_channels(self) -> None:
"""Stop the channels for the client."""
pass
@abc.abstractproperty
def channels_running(self) -> bool:
"""Get whether the channels are running."""
pass
@abc.abstractproperty
def shell_channel(self) -> ChannelABC:
pass
@abc.abstractproperty
def iopub_channel(self) -> ChannelABC:
pass
@abc.abstractproperty
def stdin_channel(self) -> ChannelABC:
pass
@abc.abstractproperty
def hb_channel(self) -> ChannelABC:
pass
@abc.abstractproperty
def control_channel(self) -> ChannelABC:
pass

View File: jupyter_client/connect.py

@@ -0,0 +1,725 @@
"""Utilities for connecting to jupyter kernels
The :class:`ConnectionFileMixin` class in this module encapsulates the logic
related to writing and reading connections files.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import errno
import glob
import json
import os
import socket
import stat
import tempfile
import warnings
from getpass import getpass
from typing import TYPE_CHECKING, Any, Dict, Union, cast
import zmq
from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
from traitlets.config import LoggingConfigurable, SingletonConfigurable
from .localinterfaces import localhost
from .utils import _filefind
if TYPE_CHECKING:
from jupyter_client import BlockingKernelClient
from .session import Session
# Define custom type for kernel connection info
KernelConnectionInfo = Dict[str, Union[int, str, bytes]]
def write_connection_file(
fname: str | None = None,
shell_port: int = 0,
iopub_port: int = 0,
stdin_port: int = 0,
hb_port: int = 0,
control_port: int = 0,
ip: str = "",
key: bytes = b"",
transport: str = "tcp",
signature_scheme: str = "hmac-sha256",
kernel_name: str = "",
**kwargs: Any,
) -> tuple[str, KernelConnectionInfo]:
"""Generates a JSON config file, including the selection of random ports.
Parameters
----------
    fname : str, optional
The path to the file to write
shell_port : int, optional
The port to use for ROUTER (shell) channel.
iopub_port : int, optional
The port to use for the SUB channel.
stdin_port : int, optional
The port to use for the ROUTER (raw input) channel.
control_port : int, optional
The port to use for the ROUTER (control) channel.
hb_port : int, optional
The port to use for the heartbeat REP channel.
ip : str, optional
The ip address the kernel will bind to.
    key : bytes, optional
The Session key used for message authentication.
signature_scheme : str, optional
The scheme used for message authentication.
This has the form 'digest-hash', where 'digest'
is the scheme used for digests, and 'hash' is the name of the hash function
used by the digest scheme.
Currently, 'hmac' is the only supported digest scheme,
and 'sha256' is the default hash function.
kernel_name : str, optional
The name of the kernel currently connected to.
"""
if not ip:
ip = localhost()
# default to a temporary connection file
if not fname:
fd, fname = tempfile.mkstemp(".json")
os.close(fd)
# Find open ports as necessary.
ports: list[int] = []
sockets: list[socket.socket] = []
ports_needed = (
int(shell_port <= 0)
+ int(iopub_port <= 0)
+ int(stdin_port <= 0)
+ int(control_port <= 0)
+ int(hb_port <= 0)
)
if transport == "tcp":
for _ in range(ports_needed):
sock = socket.socket()
# struct.pack('ii', (0,0)) is 8 null bytes
sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
sock.bind((ip, 0))
sockets.append(sock)
for sock in sockets:
port = sock.getsockname()[1]
sock.close()
ports.append(port)
else:
N = 1
for _ in range(ports_needed):
while os.path.exists(f"{ip}-{N!s}"):
N += 1
ports.append(N)
N += 1
if shell_port <= 0:
shell_port = ports.pop(0)
if iopub_port <= 0:
iopub_port = ports.pop(0)
if stdin_port <= 0:
stdin_port = ports.pop(0)
if control_port <= 0:
control_port = ports.pop(0)
if hb_port <= 0:
hb_port = ports.pop(0)
cfg: KernelConnectionInfo = {
"shell_port": shell_port,
"iopub_port": iopub_port,
"stdin_port": stdin_port,
"control_port": control_port,
"hb_port": hb_port,
}
cfg["ip"] = ip
cfg["key"] = key.decode()
cfg["transport"] = transport
cfg["signature_scheme"] = signature_scheme
cfg["kernel_name"] = kernel_name
cfg.update(kwargs)
# Only ever write this file as user read/writeable
# This would otherwise introduce a vulnerability as a file has secrets
# which would let others execute arbitrary code as you
with secure_write(fname) as f:
f.write(json.dumps(cfg, indent=2))
if hasattr(stat, "S_ISVTX"):
# set the sticky bit on the parent directory of the file
# to ensure only owner can remove it
runtime_dir = os.path.dirname(fname)
if runtime_dir:
permissions = os.stat(runtime_dir).st_mode
new_permissions = permissions | stat.S_ISVTX
if new_permissions != permissions:
try:
os.chmod(runtime_dir, new_permissions)
except OSError as e:
    if e.errno == errno.EPERM:
        # suppress permission errors setting sticky bit on runtime_dir,
        # which we may not own.
        pass
    else:
        # failing to set the sticky bit is not fatal, but worth surfacing
        warnings.warn(f"Failed to set sticky bit on {runtime_dir!r}: {e}", stacklevel=2)
return fname, cfg
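# A minimal usage sketch for write_connection_file() (illustrative only; the
# key below is a made-up example): write a file with randomly chosen tcp
# ports, then read it back and clean up.
def _example_write_connection_file() -> None:
    fname, cfg = write_connection_file(key=b"example-secret", ip="127.0.0.1")
    with open(fname) as f:
        on_disk = json.load(f)
    # randomly assigned ports are recorded both in cfg and on disk
    assert on_disk["shell_port"] == cfg["shell_port"]
    os.remove(fname)  # connection files contain secrets; remove when done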
def find_connection_file(
filename: str = "kernel-*.json",
path: str | list[str] | None = None,
profile: str | None = None,
) -> str:
"""find a connection file, and return its absolute path.
The current working directory and optional search path
will be searched for the file if it is not given by absolute path.
If the argument does not match an existing file, it will be interpreted as a
fileglob, and the matching file in the profile's security dir with
the latest access time will be used.
Parameters
----------
filename : str
The connection file or fileglob to search for.
path : str or list of str, optional
Paths in which to search for connection files.
Returns
-------
str : The absolute path of the connection file.
"""
if profile is not None:
warnings.warn(
"Jupyter has no profiles. profile=%s has been ignored." % profile, stacklevel=2
)
if path is None:
path = [".", jupyter_runtime_dir()]
if isinstance(path, str):
path = [path]
try:
# first, try explicit name
return _filefind(filename, path)
except OSError:
pass
# not found by full name
if "*" in filename:
# given as a glob already
pat = filename
else:
# accept any substring match
pat = "*%s*" % filename
matches = []
for p in path:
matches.extend(glob.glob(os.path.join(p, pat)))
matches = [os.path.abspath(m) for m in matches]
if not matches:
msg = f"Could not find {filename!r} in {path!r}"
raise OSError(msg)
elif len(matches) == 1:
return matches[0]
else:
# get most recent match, by access time:
return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
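# Illustrative sketch (not part of the module): resolving a connection file.
# An absolute path is returned as-is; a bare name or substring is globbed
# against the search path and the most recently accessed match wins.
# "kernel-1234.json" is a hypothetical file name.
def _example_find_connection_file() -> str:
    return find_connection_file("kernel-1234.json", path=[".", jupyter_runtime_dir()])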
def tunnel_to_kernel(
connection_info: str | KernelConnectionInfo,
sshserver: str,
sshkey: str | None = None,
) -> tuple[Any, ...]:
"""tunnel connections to a kernel via ssh
This will open five SSH tunnels from localhost on this machine to the
ports associated with the kernel. They can be either direct
localhost-localhost tunnels, or if an intermediate server is necessary,
the kernel must be listening on a public IP.
Parameters
----------
connection_info : dict or str (path)
Either a connection dict, or the path to a JSON connection file
sshserver : str
The ssh server to use to tunnel to the kernel. Can be a full
`user@server:port` string. ssh config aliases are respected.
sshkey : str [optional]
Path to file containing ssh key to use for authentication.
Only necessary if your ssh config does not already associate
a keyfile with the host.
Returns
-------
(shell, iopub, stdin, hb, control) : ints
The five ports on localhost that have been forwarded to the kernel.
"""
from .ssh import tunnel
if isinstance(connection_info, str):
# it's a path, unpack it
with open(connection_info) as f:
connection_info = json.loads(f.read())
cf = cast(Dict[str, Any], connection_info)
lports = tunnel.select_random_ports(5)
rports = (
cf["shell_port"],
cf["iopub_port"],
cf["stdin_port"],
cf["hb_port"],
cf["control_port"],
)
remote_ip = cf["ip"]
if tunnel.try_passwordless_ssh(sshserver, sshkey):
password: bool | str = False
else:
password = getpass("SSH Password for %s: " % sshserver)
for lp, rp in zip(lports, rports):
tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
return tuple(lports)
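# Hedged sketch (host and file names are made up): forward a remote kernel's
# five ports to localhost. The returned tuple preserves the documented order.
def _example_tunnel(connection_file: str = "kernel-1234.json") -> tuple:
    shell, iopub, stdin, hb, control = tunnel_to_kernel(
        connection_file, "user@gateway.example.com"
    )
    return shell, iopub, stdin, hb, control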
# -----------------------------------------------------------------------------
# Mixin for classes that work with connection files
# -----------------------------------------------------------------------------
channel_socket_types = {
"hb": zmq.REQ,
"shell": zmq.DEALER,
"iopub": zmq.SUB,
"stdin": zmq.DEALER,
"control": zmq.DEALER,
}
port_names = ["%s_port" % channel for channel in ("shell", "stdin", "iopub", "hb", "control")]
class ConnectionFileMixin(LoggingConfigurable):
"""Mixin for configurable classes that work with connection files"""
data_dir: str | Unicode = Unicode()
def _data_dir_default(self) -> str:
return jupyter_data_dir()
# The addresses for the communication channels
connection_file = Unicode(
"",
config=True,
help="""JSON file in which to store connection info [default: kernel-<pid>.json]
This file will contain the IP, ports, and authentication key needed to connect
clients to this kernel. By default, this file will be created in the security dir
of the current profile, but can be specified by absolute path.
""",
)
_connection_file_written = Bool(False)
transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
kernel_name: str | Unicode = Unicode()
context = Instance(zmq.Context)
ip = Unicode(
config=True,
help="""Set the kernel\'s IP address [default localhost].
If the IP address is something other than localhost, then
Consoles on other machines will be able to connect
to the Kernel, so be careful!""",
)
def _ip_default(self) -> str:
if self.transport == "ipc":
if self.connection_file:
return os.path.splitext(self.connection_file)[0] + "-ipc"
else:
return "kernel-ipc"
else:
return localhost()
@observe("ip")
def _ip_changed(self, change: Any) -> None:
if change["new"] == "*":
self.ip = "0.0.0.0" # noqa
# protected traits
hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")
# names of the ports with random assignment
_random_port_names: list[str] | None = None
@property
def ports(self) -> list[int]:
return [getattr(self, name) for name in port_names]
# The Session to use for communication with the kernel.
session = Instance("jupyter_client.session.Session")
def _session_default(self) -> Session:
from .session import Session
return Session(parent=self)
# --------------------------------------------------------------------------
# Connection and ipc file management
# --------------------------------------------------------------------------
def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
"""Return the connection info as a dict
Parameters
----------
session : bool [default: False]
If True, our session object will be included in the connection info.
If False (default), the configuration parameters of our session object will be included,
rather than the session object itself.
Returns
-------
connect_info : dict
dictionary of connection information.
"""
info = {
"transport": self.transport,
"ip": self.ip,
"shell_port": self.shell_port,
"iopub_port": self.iopub_port,
"stdin_port": self.stdin_port,
"hb_port": self.hb_port,
"control_port": self.control_port,
}
if session:
# add *clone* of my session,
# so that state such as digest_history is not shared.
info["session"] = self.session.clone()
else:
# add session info
info.update(
{
"signature_scheme": self.session.signature_scheme,
"key": self.session.key,
}
)
return info
# factory for blocking clients
blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient")
def blocking_client(self) -> BlockingKernelClient:
"""Make a blocking client connected to my kernel"""
info = self.get_connection_info()
bc = self.blocking_class(parent=self) # type:ignore[operator]
bc.load_connection_info(info)
return bc
def cleanup_connection_file(self) -> None:
"""Cleanup connection file *if we wrote it*
Will not raise if the connection file was already removed somehow.
"""
if self._connection_file_written:
# cleanup connection files on full shutdown of kernel we started
self._connection_file_written = False
try:
os.remove(self.connection_file)
except (OSError, AttributeError):
pass
def cleanup_ipc_files(self) -> None:
"""Cleanup ipc files if we wrote them."""
if self.transport != "ipc":
return
for port in self.ports:
ipcfile = "%s-%i" % (self.ip, port)
try:
os.remove(ipcfile)
except OSError:
pass
def _record_random_port_names(self) -> None:
"""Records which of the ports are randomly assigned.
Records on first invocation, if the transport is tcp.
Does nothing on later invocations."""
if self.transport != "tcp":
return
if self._random_port_names is not None:
return
self._random_port_names = []
for name in port_names:
if getattr(self, name) <= 0:
self._random_port_names.append(name)
def cleanup_random_ports(self) -> None:
"""Forgets randomly assigned port numbers and cleans up the connection file.
Does nothing if no port numbers have been randomly assigned.
In particular, does nothing unless the transport is tcp.
"""
if not self._random_port_names:
return
for name in self._random_port_names:
setattr(self, name, 0)
self.cleanup_connection_file()
def write_connection_file(self, **kwargs: Any) -> None:
"""Write connection info to JSON dict in self.connection_file."""
if self._connection_file_written and os.path.exists(self.connection_file):
return
self.connection_file, cfg = write_connection_file(
self.connection_file,
transport=self.transport,
ip=self.ip,
key=self.session.key,
stdin_port=self.stdin_port,
iopub_port=self.iopub_port,
shell_port=self.shell_port,
hb_port=self.hb_port,
control_port=self.control_port,
signature_scheme=self.session.signature_scheme,
kernel_name=self.kernel_name,
**kwargs,
)
# write_connection_file also sets default ports:
self._record_random_port_names()
for name in port_names:
setattr(self, name, cfg[name])
self._connection_file_written = True
def load_connection_file(self, connection_file: str | None = None) -> None:
"""Load connection info from JSON dict in self.connection_file.
Parameters
----------
connection_file: unicode, optional
Path to connection file to load.
If unspecified, use self.connection_file
"""
if connection_file is None:
connection_file = self.connection_file
self.log.debug("Loading connection file %s", connection_file)
with open(connection_file) as f:
info = json.load(f)
self.load_connection_info(info)
def load_connection_info(self, info: KernelConnectionInfo) -> None:
"""Load connection info from a dict containing connection info.
Typically this data comes from a connection file
and is called by load_connection_file.
Parameters
----------
info: dict
Dictionary containing connection_info.
See the connection_file spec for details.
"""
self.transport = info.get("transport", self.transport)
self.ip = info.get("ip", self._ip_default()) # type:ignore[assignment]
self._record_random_port_names()
for name in port_names:
if getattr(self, name) == 0 and name in info:
# not overridden by config or cl_args
setattr(self, name, info[name])
if "key" in info:
key = info["key"]
if isinstance(key, str):
key = key.encode()
assert isinstance(key, bytes)
self.session.key = key
if "signature_scheme" in info:
self.session.signature_scheme = info["signature_scheme"]
def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
"""Reconciles the connection information returned from the Provisioner.
Because some provisioners (like derivations of LocalProvisioner) may have already
written the connection file, this method needs to ensure that, if the connection
file exists, its contents match that of what was returned by the provisioner. If
the file does exist and its contents do not match, the file will be replaced with
the provisioner information (which is considered the truth).
If the file does not exist, the connection information in 'info' is loaded into the
KernelManager and written to the file.
"""
# Prevent over-writing a file that has already been written with the same
# info. This is to prevent a race condition where the process has
# already been launched but has not yet read the connection file - as is
# the case with LocalProvisioners.
file_exists: bool = False
if os.path.exists(self.connection_file):
with open(self.connection_file) as f:
file_info = json.load(f)
# Prior to the following comparison, we need to adjust the value of "key" to
# be bytes, otherwise the comparison below will fail.
file_info["key"] = file_info["key"].encode()
if not self._equal_connections(info, file_info):
os.remove(self.connection_file) # Contents mismatch - remove the file
self._connection_file_written = False
else:
file_exists = True
if not file_exists:
# Load the connection info and write out file, clearing existing
# port-based attributes so they will be reloaded
for name in port_names:
setattr(self, name, 0)
self.load_connection_info(info)
self.write_connection_file()
# Ensure what is in KernelManager is what we expect.
km_info = self.get_connection_info()
if not self._equal_connections(info, km_info):
msg = (
"KernelManager's connection information already exists and does not match "
"the expected values returned from provisioner!"
)
raise ValueError(msg)
@staticmethod
def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
"""Compares pertinent keys of connection info data. Returns True if equivalent, False otherwise."""
pertinent_keys = [
"key",
"ip",
"stdin_port",
"iopub_port",
"shell_port",
"control_port",
"hb_port",
"transport",
"signature_scheme",
]
return all(conn1.get(key) == conn2.get(key) for key in pertinent_keys)
# --------------------------------------------------------------------------
# Creating connected sockets
# --------------------------------------------------------------------------
def _make_url(self, channel: str) -> str:
"""Make a ZeroMQ URL for a given channel."""
transport = self.transport
ip = self.ip
port = getattr(self, "%s_port" % channel)
if transport == "tcp":
return "tcp://%s:%i" % (ip, port)
else:
return f"{transport}://{ip}-{port}"
def _create_connected_socket(
self, channel: str, identity: bytes | None = None
) -> zmq.sugar.socket.Socket:
"""Create a zmq Socket and connect it to the kernel."""
url = self._make_url(channel)
socket_type = channel_socket_types[channel]
self.log.debug("Connecting to: %s", url)
sock = self.context.socket(socket_type)
# set linger to 1s to prevent hangs at exit
sock.linger = 1000
if identity:
sock.identity = identity
sock.connect(url)
return sock
def connect_iopub(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the IOPub channel"""
sock = self._create_connected_socket("iopub", identity=identity)
sock.setsockopt(zmq.SUBSCRIBE, b"")
return sock
def connect_shell(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the Shell channel"""
return self._create_connected_socket("shell", identity=identity)
def connect_stdin(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the StdIn channel"""
return self._create_connected_socket("stdin", identity=identity)
def connect_hb(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the Heartbeat channel"""
return self._create_connected_socket("hb", identity=identity)
def connect_control(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the Control channel"""
return self._create_connected_socket("control", identity=identity)
class LocalPortCache(SingletonConfigurable):
"""
Used to keep track of local ports in order to prevent race conditions that
can occur between port acquisition and usage by the kernel. All locally-
provisioned kernels should use this mechanism to limit the possibility of
race conditions. Note that this does not preclude other applications from
acquiring a cached but unused port, thereby re-introducing the issue this
class is attempting to resolve (minimize).
See: https://github.com/jupyter/jupyter_client/issues/487
"""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.currently_used_ports: set[int] = set()
def find_available_port(self, ip: str) -> int:
while True:
tmp_sock = socket.socket()
tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
tmp_sock.bind((ip, 0))
port = tmp_sock.getsockname()[1]
tmp_sock.close()
# This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
# We prevent two kernels from having the same ports.
if port not in self.currently_used_ports:
self.currently_used_ports.add(port)
return port
def return_port(self, port: int) -> None:
if port in self.currently_used_ports: # Tolerate uncached ports
self.currently_used_ports.remove(port)
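# Minimal sketch of the cache's intended use (illustrative): reserve ports
# before launching a kernel, then return them once the kernel has bound.
def _example_local_port_cache() -> None:
    cache = LocalPortCache.instance()  # SingletonConfigurable accessor
    ports = [cache.find_available_port("127.0.0.1") for _ in range(5)]
    try:
        pass  # hypothetical: hand `ports` to a kernel provisioner here
    finally:
        for port in ports:
            cache.return_port(port)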
__all__ = [
"write_connection_file",
"find_connection_file",
"tunnel_to_kernel",
"KernelConnectionInfo",
"LocalPortCache",
]

View File

@@ -0,0 +1,374 @@
""" A minimal application base mixin for all ZMQ based IPython frontends.
This is not a complete console app, as subprocesses will not be able to receive
input; there is no real readline support, among other limitations. This is a
refactoring of what used to be the IPython/qt/console/qtconsoleapp.py
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import atexit
import os
import signal
import sys
import typing as t
import uuid
import warnings
from jupyter_core.application import base_aliases, base_flags
from traitlets import CBool, CUnicode, Dict, List, Type, Unicode
from traitlets.config.application import boolean_flag
from . import KernelManager, connect, find_connection_file, tunnel_to_kernel
from .blocking import BlockingKernelClient
from .connect import KernelConnectionInfo
from .kernelspec import NoSuchKernel
from .localinterfaces import localhost
from .restarter import KernelRestarter
from .session import Session
from .utils import _filefind
ConnectionFileMixin = connect.ConnectionFileMixin
# -----------------------------------------------------------------------------
# Aliases and Flags
# -----------------------------------------------------------------------------
flags: dict = {}
flags.update(base_flags)
# the flags that are specific to the frontend
# these must be scrubbed before being passed to the kernel,
# or it will raise an error on unrecognized flags
app_flags: dict = {
"existing": (
{"JupyterConsoleApp": {"existing": "kernel*.json"}},
"Connect to an existing kernel. If no argument specified, guess most recent",
),
}
app_flags.update(
boolean_flag(
"confirm-exit",
"JupyterConsoleApp.confirm_exit",
"""Set to display confirmation dialog on exit. You can always use 'exit' or
'quit', to force a direct exit without any confirmation. This can also
be set in the config file by setting
`c.JupyterConsoleApp.confirm_exit`.
""",
"""Don't prompt the user when exiting. This will terminate the kernel
if it is owned by the frontend, and leave it alive if it is external.
This can also be set in the config file by setting
`c.JupyterConsoleApp.confirm_exit`.
""",
)
)
flags.update(app_flags)
aliases: dict = {}
aliases.update(base_aliases)
# also scrub aliases from the frontend
app_aliases: dict = {
"ip": "JupyterConsoleApp.ip",
"transport": "JupyterConsoleApp.transport",
"hb": "JupyterConsoleApp.hb_port",
"shell": "JupyterConsoleApp.shell_port",
"iopub": "JupyterConsoleApp.iopub_port",
"stdin": "JupyterConsoleApp.stdin_port",
"control": "JupyterConsoleApp.control_port",
"existing": "JupyterConsoleApp.existing",
"f": "JupyterConsoleApp.connection_file",
"kernel": "JupyterConsoleApp.kernel_name",
"ssh": "JupyterConsoleApp.sshserver",
"sshkey": "JupyterConsoleApp.sshkey",
}
aliases.update(app_aliases)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
classes: t.List[t.Type[t.Any]] = [KernelManager, KernelRestarter, Session]
class JupyterConsoleApp(ConnectionFileMixin):
"""The base Jupyter console application."""
name: t.Union[str, Unicode] = "jupyter-console-mixin"
description: t.Union[str, Unicode] = """
The Jupyter Console Mixin.
This class contains the common portions of console client (QtConsole,
ZMQ-based terminal console, etc). It is not a full console, in that
launched terminal subprocesses will not be able to accept input.
The Console using this mixin supports various extra features beyond
the single-process Terminal IPython shell, such as connecting to an
existing kernel, via:
jupyter console <appname> --existing
as well as tunneling via SSH
"""
classes = classes
flags = Dict(flags)
aliases = Dict(aliases)
kernel_manager_class = Type(
default_value=KernelManager,
config=True,
help="The kernel manager class to use.",
)
kernel_client_class = BlockingKernelClient
kernel_argv = List(Unicode())
# connection info:
sshserver = Unicode("", config=True, help="""The SSH server to use to connect to the kernel.""")
sshkey = Unicode(
"",
config=True,
help="""Path to the ssh key to use for logging in to the ssh server.""",
)
def _connection_file_default(self) -> str:
return "kernel-%i.json" % os.getpid()
existing = CUnicode("", config=True, help="""Connect to an already running kernel""")
kernel_name = Unicode(
"python", config=True, help="""The name of the default kernel to start."""
)
confirm_exit = CBool(
True,
config=True,
help="""
Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
to force a direct exit without any confirmation.""",
)
def build_kernel_argv(self, argv: object = None) -> None:
"""build argv to be passed to kernel subprocess
Override in subclasses if any args should be passed to the kernel
"""
self.kernel_argv = self.extra_args # type:ignore[attr-defined]
def init_connection_file(self) -> None:
"""find the connection file, and load the info if found.
The current working directory and the current profile's security
directory will be searched for the file if it is not given by
absolute path.
When attempting to connect to an existing kernel and the `--existing`
argument does not match an existing file, it will be interpreted as a
fileglob, and the matching file in the current profile's security dir
with the latest access time will be used.
After this method is called, self.connection_file contains the *full path*
to the connection file, never just its name.
"""
runtime_dir = self.runtime_dir # type:ignore[attr-defined]
if self.existing:
try:
cf = find_connection_file(self.existing, [".", runtime_dir])
except Exception:
self.log.critical(
"Could not find existing kernel connection file %s", self.existing
)
self.exit(1) # type:ignore[attr-defined]
self.log.debug("Connecting to existing kernel: %s", cf)
self.connection_file = cf
else:
# not existing, check if we are going to write the file
# and ensure that self.connection_file is a full path, not just the shortname
try:
cf = find_connection_file(self.connection_file, [runtime_dir])
except Exception:
# file might not exist
if self.connection_file == os.path.basename(self.connection_file):
# just shortname, put it in security dir
cf = os.path.join(runtime_dir, self.connection_file)
else:
cf = self.connection_file
self.connection_file = cf
try:
self.connection_file = _filefind(self.connection_file, [".", runtime_dir])
except OSError:
self.log.debug("Connection File not found: %s", self.connection_file)
return
# should load_connection_file only be used for existing?
# as it is now, this allows reusing ports if an existing
# file is requested
try:
self.load_connection_file()
except Exception:
self.log.error(
"Failed to load connection file: %r",
self.connection_file,
exc_info=True,
)
self.exit(1) # type:ignore[attr-defined]
def init_ssh(self) -> None:
"""set up ssh tunnels, if needed."""
if not self.existing or (not self.sshserver and not self.sshkey):
return
self.load_connection_file()
transport = self.transport
ip = self.ip
if transport != "tcp":
self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport)
sys.exit(-1)
if self.sshkey and not self.sshserver:
# specifying just the key implies that we are connecting directly
self.sshserver = ip
ip = localhost()
# build connection dict for tunnels:
info: KernelConnectionInfo = {
"ip": ip,
"shell_port": self.shell_port,
"iopub_port": self.iopub_port,
"stdin_port": self.stdin_port,
"hb_port": self.hb_port,
"control_port": self.control_port,
}
self.log.info("Forwarding connections to %s via %s", ip, self.sshserver)
# tunnels return a new set of ports, which will be on localhost:
self.ip = localhost()
try:
newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
except: # noqa
# even catch KeyboardInterrupt
self.log.error("Could not setup tunnels", exc_info=True)
self.exit(1) # type:ignore[attr-defined]
(
self.shell_port,
self.iopub_port,
self.stdin_port,
self.hb_port,
self.control_port,
) = newports
cf = self.connection_file
root, ext = os.path.splitext(cf)
self.connection_file = root + "-ssh" + ext
self.write_connection_file() # write the new connection file
self.log.info("To connect another client via this tunnel, use:")
self.log.info("--existing %s", os.path.basename(self.connection_file))
def _new_connection_file(self) -> str:
cf = ""
while not cf:
# we don't need a 128b id to distinguish kernels, use more readable
# 48b node segment (12 hex chars). Users running more than 32k simultaneous
# kernels can subclass.
ident = str(uuid.uuid4()).split("-")[-1]
runtime_dir = self.runtime_dir # type:ignore[attr-defined]
cf = os.path.join(runtime_dir, "kernel-%s.json" % ident)
# only keep if it's actually new. Protect against unlikely collision
# in 48b random search space
cf = cf if not os.path.exists(cf) else ""
return cf
def init_kernel_manager(self) -> None:
"""Initialize the kernel manager."""
# Don't let Qt or ZMQ swallow KeyboardInterrupts.
if self.existing:
self.kernel_manager = None
return
signal.signal(signal.SIGINT, signal.SIG_DFL)
# Create a KernelManager and start a kernel.
try:
self.kernel_manager = self.kernel_manager_class(
ip=self.ip,
session=self.session,
transport=self.transport,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
connection_file=self.connection_file,
kernel_name=self.kernel_name,
parent=self,
data_dir=self.data_dir,
)
except NoSuchKernel:
self.log.critical("Could not find kernel %s", self.kernel_name)
self.exit(1) # type:ignore[attr-defined]
self.kernel_manager = t.cast(KernelManager, self.kernel_manager)
self.kernel_manager.client_factory = self.kernel_client_class
kwargs = {}
kwargs["extra_arguments"] = self.kernel_argv
self.kernel_manager.start_kernel(**kwargs)
atexit.register(self.kernel_manager.cleanup_ipc_files)
if self.sshserver:
# ssh, write new connection file
self.kernel_manager.write_connection_file()
# in case KM defaults / ssh writing changes things:
km = self.kernel_manager
self.shell_port = km.shell_port
self.iopub_port = km.iopub_port
self.stdin_port = km.stdin_port
self.hb_port = km.hb_port
self.control_port = km.control_port
self.connection_file = km.connection_file
atexit.register(self.kernel_manager.cleanup_connection_file)
def init_kernel_client(self) -> None:
"""Initialize the kernel client."""
if self.kernel_manager is not None:
self.kernel_client = self.kernel_manager.client()
else:
self.kernel_client = self.kernel_client_class(
session=self.session,
ip=self.ip,
transport=self.transport,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
connection_file=self.connection_file,
parent=self,
)
self.kernel_client.start_channels()
def initialize(self, argv: object = None) -> None:
"""
Classes which mix this class in should call:
JupyterConsoleApp.initialize(self,argv)
"""
if getattr(self, "_dispatching", False):
return
self.init_connection_file()
self.init_ssh()
self.init_kernel_manager()
self.init_kernel_client()
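# Hedged sketch of how a frontend mixes this class in (illustrative only;
# "MyConsoleApp" is a made-up name). The JupyterApp side handles config and
# argument parsing; this mixin then wires up the kernel.
def _example_console_app() -> None:
    from jupyter_core.application import JupyterApp

    class MyConsoleApp(JupyterApp, JupyterConsoleApp):
        def initialize(self, argv: object = None) -> None:
            super().initialize(argv)  # JupyterApp initialization
            JupyterConsoleApp.initialize(self, argv)  # kernel wiring

    MyConsoleApp.launch_instance()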
class IPythonConsoleApp(JupyterConsoleApp):
"""An app to manage an ipython console."""
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
"""Initialize the app."""
warnings.warn("IPythonConsoleApp is deprecated. Use JupyterConsoleApp", stacklevel=2)
super().__init__(*args, **kwargs)

View File

@@ -0,0 +1,4 @@
from .manager import AsyncIOLoopKernelManager # noqa
from .manager import IOLoopKernelManager # noqa
from .restarter import AsyncIOLoopKernelRestarter # noqa
from .restarter import IOLoopKernelRestarter # noqa

View File

@@ -0,0 +1,116 @@
"""A kernel manager with a tornado IOLoop"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import typing as t
import zmq
from tornado import ioloop
from traitlets import Instance, Type
from zmq.eventloop.zmqstream import ZMQStream
from ..manager import AsyncKernelManager, KernelManager
from .restarter import AsyncIOLoopKernelRestarter, IOLoopKernelRestarter
def as_zmqstream(f: t.Any) -> t.Callable:
"""Convert a socket to a zmq stream."""
def wrapped(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
save_socket_class = None
# zmqstreams only support sync sockets
if self.context._socket_class is not zmq.Socket:
save_socket_class = self.context._socket_class
self.context._socket_class = zmq.Socket
try:
socket = f(self, *args, **kwargs)
finally:
if save_socket_class:
# restore default socket class
self.context._socket_class = save_socket_class
return ZMQStream(socket, self.loop)
return wrapped
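# Illustrative sketch (not part of the module): once wrapped, the connect_*
# methods below return ZMQStream objects whose messages arrive via callbacks
# on the tornado loop instead of blocking recv() calls.
def _example_iopub_stream(km: "IOLoopKernelManager") -> ZMQStream:
    stream = km.connect_iopub()  # a ZMQStream, courtesy of as_zmqstream
    stream.on_recv(lambda frames: print("iopub message with", len(frames), "frames"))
    return stream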
class IOLoopKernelManager(KernelManager):
"""An io loop kernel manager."""
loop = Instance("tornado.ioloop.IOLoop")
def _loop_default(self) -> ioloop.IOLoop:
return ioloop.IOLoop.current()
restarter_class = Type(
default_value=IOLoopKernelRestarter,
klass=IOLoopKernelRestarter,
help=(
"Type of KernelRestarter to use. "
"Must be a subclass of IOLoopKernelRestarter.\n"
"Override this to customize how kernel restarts are managed."
),
config=True,
)
_restarter: t.Any = Instance("jupyter_client.ioloop.IOLoopKernelRestarter", allow_none=True)
def start_restarter(self) -> None:
"""Start the restarter."""
if self.autorestart and self.has_kernel:
if self._restarter is None:
self._restarter = self.restarter_class(
kernel_manager=self, loop=self.loop, parent=self, log=self.log
)
self._restarter.start()
def stop_restarter(self) -> None:
"""Stop the restarter."""
if self.autorestart and self._restarter is not None:
self._restarter.stop()
connect_shell = as_zmqstream(KernelManager.connect_shell)
connect_control = as_zmqstream(KernelManager.connect_control)
connect_iopub = as_zmqstream(KernelManager.connect_iopub)
connect_stdin = as_zmqstream(KernelManager.connect_stdin)
connect_hb = as_zmqstream(KernelManager.connect_hb)
class AsyncIOLoopKernelManager(AsyncKernelManager):
"""An async ioloop kernel manager."""
loop = Instance("tornado.ioloop.IOLoop")
def _loop_default(self) -> ioloop.IOLoop:
return ioloop.IOLoop.current()
restarter_class = Type(
default_value=AsyncIOLoopKernelRestarter,
klass=AsyncIOLoopKernelRestarter,
help=(
"Type of KernelRestarter to use. "
"Must be a subclass of AsyncIOLoopKernelManager.\n"
"Override this to customize how kernel restarts are managed."
),
config=True,
)
_restarter: t.Any = Instance(
"jupyter_client.ioloop.AsyncIOLoopKernelRestarter", allow_none=True
)
def start_restarter(self) -> None:
"""Start the restarter."""
if self.autorestart and self.has_kernel:
if self._restarter is None:
self._restarter = self.restarter_class(
kernel_manager=self, loop=self.loop, parent=self, log=self.log
)
self._restarter.start()
def stop_restarter(self) -> None:
"""Stop the restarter."""
if self.autorestart and self._restarter is not None:
self._restarter.stop()
connect_shell = as_zmqstream(AsyncKernelManager.connect_shell)
connect_control = as_zmqstream(AsyncKernelManager.connect_control)
connect_iopub = as_zmqstream(AsyncKernelManager.connect_iopub)
connect_stdin = as_zmqstream(AsyncKernelManager.connect_stdin)
connect_hb = as_zmqstream(AsyncKernelManager.connect_hb)

View File

@@ -0,0 +1,102 @@
"""A basic in process kernel monitor with autorestarting.
This watches a kernel's state using KernelManager.is_alive and auto
restarts the kernel if it dies.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import time
import warnings
from typing import Any
from traitlets import Instance
from ..restarter import KernelRestarter
class IOLoopKernelRestarter(KernelRestarter):
"""Monitor and autorestart a kernel."""
loop = Instance("tornado.ioloop.IOLoop")
def _loop_default(self) -> Any:
warnings.warn(
"IOLoopKernelRestarter.loop is deprecated in jupyter-client 5.2",
DeprecationWarning,
stacklevel=4,
)
from tornado import ioloop
return ioloop.IOLoop.current()
_pcallback = None
def start(self) -> None:
"""Start the polling of the kernel."""
if self._pcallback is None:
from tornado.ioloop import PeriodicCallback
self._pcallback = PeriodicCallback(
self.poll,
1000 * self.time_to_dead,
)
self._pcallback.start()
def stop(self) -> None:
"""Stop the kernel polling."""
if self._pcallback is not None:
self._pcallback.stop()
self._pcallback = None
class AsyncIOLoopKernelRestarter(IOLoopKernelRestarter):
"""An async io loop kernel restarter."""
async def poll(self) -> None: # type:ignore[override]
"""Poll the kernel."""
if self.debug:
self.log.debug("Polling kernel...")
is_alive = await self.kernel_manager.is_alive()
now = time.time()
if not is_alive:
self._last_dead = now
if self._restarting:
self._restart_count += 1
else:
self._restart_count = 1
if self._restart_count > self.restart_limit:
self.log.warning("AsyncIOLoopKernelRestarter: restart failed")
self._fire_callbacks("dead")
self._restarting = False
self._restart_count = 0
self.stop()
else:
newports = self.random_ports_until_alive and self._initial_startup
self.log.info(
"AsyncIOLoopKernelRestarter: restarting kernel (%i/%i), %s random ports",
self._restart_count,
self.restart_limit,
"new" if newports else "keep",
)
self._fire_callbacks("restart")
await self.kernel_manager.restart_kernel(now=True, newports=newports)
self._restarting = True
else:
# Since `is_alive` only tests that the kernel process is alive, it does not
# indicate that the kernel has successfully completed startup. To solve this
# correctly, we would need to wait for a kernel info reply, but it is not
# necessarily appropriate to start a kernel client + channels in the
# restarter. Therefore, we use "has been alive continuously for X time" as a
# heuristic for a stable start up.
# See https://github.com/jupyter/jupyter_client/pull/717 for details.
stable_start_time = self.stable_start_time
if self.kernel_manager.provisioner:
stable_start_time = self.kernel_manager.provisioner.get_stable_start_time(
recommended=stable_start_time
)
if self._initial_startup and now - self._last_dead >= stable_start_time:
self._initial_startup = False
if self._restarting and now - self._last_dead >= stable_start_time:
self.log.debug("AsyncIOLoopKernelRestarter: restart apparently succeeded")
self._restarting = False
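# Illustrative sketch: restarters are normally created by the kernel manager
# (see the ioloop manager module), but callbacks can be attached to observe
# restart and death events.
def _example_attach_callbacks(restarter: IOLoopKernelRestarter) -> None:
    restarter.add_callback(lambda: print("kernel restarted"), event="restart")
    restarter.add_callback(lambda: print("kernel died"), event="dead")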

View File

@@ -0,0 +1,192 @@
"""Utilities to manipulate JSON objects."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import math
import numbers
import re
import types
import warnings
from binascii import b2a_base64
from collections.abc import Iterable
from datetime import datetime
from typing import Any, Optional, Union
from dateutil.parser import parse as _dateutil_parse
from dateutil.tz import tzlocal
next_attr_name = "__next__" # Not sure what downstream library uses this, but left it to be safe
# -----------------------------------------------------------------------------
# Globals and constants
# -----------------------------------------------------------------------------
# timestamp formats
ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
ISO8601_PAT = re.compile(
r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d{1,6})?(Z|([\+\-]\d{2}:?\d{2}))?$"
)
# holy crap, strptime is not threadsafe.
# Calling it once at import seems to help.
datetime.strptime("1", "%d") # noqa
# -----------------------------------------------------------------------------
# Classes and functions
# -----------------------------------------------------------------------------
def _ensure_tzinfo(dt: datetime) -> datetime:
"""Ensure a datetime object has tzinfo
If no tzinfo is present, add tzlocal
"""
if not dt.tzinfo:
# No more naïve datetime objects!
warnings.warn(
"Interpreting naive datetime as local %s. Please add timezone info to timestamps." % dt,
DeprecationWarning,
stacklevel=4,
)
dt = dt.replace(tzinfo=tzlocal())
return dt
def parse_date(s: Optional[str]) -> Optional[Union[str, datetime]]:
"""parse an ISO8601 date string
If it is None or not a valid ISO8601 timestamp,
it will be returned unmodified.
Otherwise, it will return a datetime object.
"""
if s is None:
return s
m = ISO8601_PAT.match(s)
if m:
dt = _dateutil_parse(s)
return _ensure_tzinfo(dt)
return s
def extract_dates(obj: Any) -> Any:
"""extract ISO8601 dates from unpacked JSON"""
if isinstance(obj, dict):
new_obj = {} # don't clobber
for k, v in obj.items():
new_obj[k] = extract_dates(v)
obj = new_obj
elif isinstance(obj, (list, tuple)):
obj = [extract_dates(o) for o in obj]
elif isinstance(obj, str):
obj = parse_date(obj)
return obj
def squash_dates(obj: Any) -> Any:
"""squash datetime objects into ISO8601 strings"""
if isinstance(obj, dict):
obj = dict(obj) # don't clobber
for k, v in obj.items():
obj[k] = squash_dates(v)
elif isinstance(obj, (list, tuple)):
obj = [squash_dates(o) for o in obj]
elif isinstance(obj, datetime):
obj = obj.isoformat()
return obj
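# Round-trip sketch (illustrative): extract_dates() and squash_dates() are
# inverses over ISO8601 strings, up to timezone formatting ("Z" vs "+00:00").
def _example_date_roundtrip() -> None:
    msg = {"date": "2024-02-23T10:30:02.000000Z", "content": {"ok": True}}
    with_dates = extract_dates(msg)  # ISO8601 str -> datetime
    assert isinstance(with_dates["date"], datetime)
    back = squash_dates(with_dates)  # datetime -> ISO8601 str
    assert isinstance(back["date"], str)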
def date_default(obj: Any) -> Any:
"""DEPRECATED: Use jupyter_client.jsonutil.json_default"""
warnings.warn(
"date_default is deprecated since jupyter_client 7.0.0."
" Use jupyter_client.jsonutil.json_default.",
stacklevel=2,
)
return json_default(obj)
def json_default(obj: Any) -> Any:
"""default function for packing objects in JSON."""
if isinstance(obj, datetime):
obj = _ensure_tzinfo(obj)
return obj.isoformat().replace("+00:00", "Z")
if isinstance(obj, bytes):
return b2a_base64(obj, newline=False).decode("ascii")
if isinstance(obj, Iterable):
return list(obj)
if isinstance(obj, numbers.Integral):
return int(obj)
if isinstance(obj, numbers.Real):
return float(obj)
raise TypeError("%r is not JSON serializable" % obj)
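# Usage sketch (illustrative): json_default is intended as the `default`
# hook for json.dumps when packing message-style payloads.
def _example_json_default() -> str:
    import json

    payload = {"stamp": datetime(2024, 2, 23, tzinfo=tzlocal()), "raw": b"\x00\x01"}
    return json.dumps(payload, default=json_default)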
# Copy of the old ipykernel's json_clean
# This is temporary, it should be removed when we deprecate support for
# non-valid JSON messages
def json_clean(obj: Any) -> Any:
# types that are 'atomic' and ok in json as-is.
atomic_ok = (str, type(None))
# containers that we need to convert into lists
container_to_list = (tuple, set, types.GeneratorType)
# Since bools are a subtype of Integrals, which are a subtype of Reals,
# we have to check them in that order.
if isinstance(obj, bool):
return obj
if isinstance(obj, numbers.Integral):
# cast int to int, in case subclasses override __str__ (e.g. boost enum, #4598)
return int(obj)
if isinstance(obj, numbers.Real):
# cast out-of-range floats to their reprs
if math.isnan(obj) or math.isinf(obj):
return repr(obj)
return float(obj)
if isinstance(obj, atomic_ok):
return obj
if isinstance(obj, bytes):
# unambiguous binary data is base64-encoded
# (this probably should have happened upstream)
return b2a_base64(obj, newline=False).decode("ascii")
if isinstance(obj, container_to_list) or (
hasattr(obj, "__iter__") and hasattr(obj, next_attr_name)
):
obj = list(obj)
if isinstance(obj, list):
return [json_clean(x) for x in obj]
if isinstance(obj, dict):
# First, validate that the dict won't lose data in conversion due to
# key collisions after stringification. This can happen with keys like
# True and 'true' or 1 and '1', which collide in JSON.
nkeys = len(obj)
nkeys_collapsed = len(set(map(str, obj)))
if nkeys != nkeys_collapsed:
msg = (
"dict cannot be safely converted to JSON: "
"key collision would lead to dropped values"
)
raise ValueError(msg)
# If all OK, proceed by making the new dict that will be json-safe
out = {}
for k, v in obj.items():
out[str(k)] = json_clean(v)
return out
if isinstance(obj, datetime):
return obj.strftime(ISO8601)
# we don't understand it, it's probably an unserializable object
raise ValueError("Can't clean for JSON: %r" % obj)

View File

@@ -0,0 +1,92 @@
"""An application to launch a kernel by name in a local subprocess."""
import os
import signal
import typing as t
import uuid
from jupyter_core.application import JupyterApp, base_flags
from tornado.ioloop import IOLoop
from traitlets import Unicode
from . import __version__
from .kernelspec import NATIVE_KERNEL_NAME, KernelSpecManager
from .manager import KernelManager
class KernelApp(JupyterApp):
"""Launch a kernel by name in a local subprocess."""
version = __version__
description = "Run a kernel locally in a subprocess"
classes = [KernelManager, KernelSpecManager]
aliases = {
"kernel": "KernelApp.kernel_name",
"ip": "KernelManager.ip",
}
flags = {"debug": base_flags["debug"]}
kernel_name = Unicode(NATIVE_KERNEL_NAME, help="The name of a kernel type to start").tag(
config=True
)
def initialize(self, argv: t.Union[str, t.Sequence[str], None] = None) -> None:
"""Initialize the application."""
super().initialize(argv)
cf_basename = "kernel-%s.json" % uuid.uuid4()
self.config.setdefault("KernelManager", {}).setdefault(
"connection_file", os.path.join(self.runtime_dir, cf_basename)
)
self.km = KernelManager(kernel_name=self.kernel_name, config=self.config)
self.loop = IOLoop.current()
self.loop.add_callback(self._record_started)
def setup_signals(self) -> None:
"""Shutdown on SIGTERM or SIGINT (Ctrl-C)"""
if os.name == "nt":
return
def shutdown_handler(signo: int, frame: t.Any) -> None:
self.loop.add_callback_from_signal(self.shutdown, signo)
for sig in [signal.SIGTERM, signal.SIGINT]:
signal.signal(sig, shutdown_handler)
def shutdown(self, signo: int) -> None:
"""Shut down the application."""
self.log.info("Shutting down on signal %d", signo)
self.km.shutdown_kernel()
self.loop.stop()
def log_connection_info(self) -> None:
"""Log the connection info for the kernel."""
cf = self.km.connection_file
self.log.info("Connection file: %s", cf)
self.log.info("To connect a client: --existing %s", os.path.basename(cf))
def _record_started(self) -> None:
"""For tests, create a file to indicate that we've started
Do not rely on this except in our own tests!
"""
fn = os.environ.get("JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE")
if fn is not None:
with open(fn, "wb"):
pass
def start(self) -> None:
"""Start the application."""
self.log.info("Starting kernel %r", self.kernel_name)
try:
self.km.start_kernel()
self.log_connection_info()
self.setup_signals()
self.loop.start()
finally:
self.km.cleanup_resources()
main = KernelApp.launch_instance
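# Typical invocation via the `jupyter kernel` entry point that maps to `main`
# above (the kernel name is only an example; any installed spec works):
#
#     jupyter kernel --kernel python3
#
# The connection file path is logged so other clients can attach with
# `--existing <file>`, as printed by log_connection_info().
if __name__ == "__main__":
    main()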

View File

@@ -0,0 +1,453 @@
"""Tools for managing kernel specs"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import json
import os
import re
import shutil
import typing as t
import warnings
from jupyter_core.paths import SYSTEM_JUPYTER_PATH, jupyter_data_dir, jupyter_path
from traitlets import Bool, CaselessStrEnum, Dict, HasTraits, List, Set, Type, Unicode, observe
from traitlets.config import LoggingConfigurable
from .provisioning import KernelProvisionerFactory as KPF # noqa
pjoin = os.path.join
NATIVE_KERNEL_NAME = "python3"
class KernelSpec(HasTraits):
"""A kernel spec model object."""
argv: List[str] = List()
name = Unicode()
mimetype = Unicode()
display_name = Unicode()
language = Unicode()
env = Dict()
resource_dir = Unicode()
interrupt_mode = CaselessStrEnum(["message", "signal"], default_value="signal")
metadata = Dict()
@classmethod
def from_resource_dir(cls: type[KernelSpec], resource_dir: str) -> KernelSpec:
"""Create a KernelSpec object by reading kernel.json
Pass the path to the *directory* containing kernel.json.
"""
kernel_file = pjoin(resource_dir, "kernel.json")
with open(kernel_file, encoding="utf-8") as f:
kernel_dict = json.load(f)
return cls(resource_dir=resource_dir, **kernel_dict)
def to_dict(self) -> dict[str, t.Any]:
"""Convert the kernel spec to a dict."""
d = {
"argv": self.argv,
"env": self.env,
"display_name": self.display_name,
"language": self.language,
"interrupt_mode": self.interrupt_mode,
"metadata": self.metadata,
}
return d
def to_json(self) -> str:
"""Serialise this kernelspec to a JSON object.
Returns a string.
"""
return json.dumps(self.to_dict())
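# Illustrative sketch (not part of the module): building a spec in memory.
# The argv and display_name values are example placeholders.
def _example_kernel_spec() -> str:
    spec = KernelSpec(
        argv=["python", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
        display_name="Example Python",
        language="python",
    )
    return spec.to_json()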
_kernel_name_pat = re.compile(r"^[a-z0-9._\-]+$", re.IGNORECASE)
def _is_valid_kernel_name(name: str) -> t.Any:
"""Check that a kernel name is valid."""
# quote is not unicode-safe on Python 2
return _kernel_name_pat.match(name)
_kernel_name_description = (
"Kernel names can only contain ASCII letters and numbers and these separators:"
" - . _ (hyphen, period, and underscore)."
)
def _is_kernel_dir(path: str) -> bool:
"""Is ``path`` a kernel directory?"""
return os.path.isdir(path) and os.path.isfile(pjoin(path, "kernel.json"))
def _list_kernels_in(dir: str | None) -> dict[str, str]:
"""Return a mapping of kernel names to resource directories from dir.
If dir is None or does not exist, returns an empty dict.
"""
if dir is None or not os.path.isdir(dir):
return {}
kernels = {}
for f in os.listdir(dir):
path = pjoin(dir, f)
if not _is_kernel_dir(path):
continue
key = f.lower()
if not _is_valid_kernel_name(key):
warnings.warn(
f"Invalid kernelspec directory name ({_kernel_name_description}): {path}",
stacklevel=3,
)
kernels[key] = path
return kernels
class NoSuchKernel(KeyError): # noqa
"""An error raised when there is no kernel of a give name."""
def __init__(self, name: str) -> None:
"""Initialize the error."""
self.name = name
def __str__(self) -> str:
return f"No such kernel named {self.name}"
class KernelSpecManager(LoggingConfigurable):
"""A manager for kernel specs."""
kernel_spec_class = Type(
KernelSpec,
config=True,
help="""The kernel spec class. This is configurable to allow
subclassing of the KernelSpecManager for customized behavior.
""",
)
ensure_native_kernel = Bool(
True,
config=True,
help="""If there is no Python kernelspec registered and the IPython
kernel is available, ensure it is added to the spec list.
""",
)
data_dir = Unicode()
def _data_dir_default(self) -> str:
return jupyter_data_dir()
user_kernel_dir = Unicode()
def _user_kernel_dir_default(self) -> str:
return pjoin(self.data_dir, "kernels")
whitelist = Set(
config=True,
help="""Deprecated, use `KernelSpecManager.allowed_kernelspecs`
""",
)
allowed_kernelspecs = Set(
config=True,
help="""List of allowed kernel names.
By default, all installed kernels are allowed.
""",
)
kernel_dirs: List[str] = List(
help="List of kernel directories to search. Later ones take priority over earlier."
)
_deprecated_aliases = {
"whitelist": ("allowed_kernelspecs", "7.0"),
}
# Method copied from
# https://github.com/jupyterhub/jupyterhub/blob/d1a85e53dccfc7b1dd81b0c1985d158cc6b61820/jupyterhub/auth.py#L143-L161
@observe(*list(_deprecated_aliases))
def _deprecated_trait(self, change: t.Any) -> None:
"""observer for deprecated traits"""
old_attr = change.name
new_attr, version = self._deprecated_aliases[old_attr]
new_value = getattr(self, new_attr)
if new_value != change.new:
# only warn if different
# protects backward-compatible config from warnings
# if they set the same value under both names
self.log.warning(
f"{self.__class__.__name__}.{old_attr} is deprecated in jupyter_client "
f"{version}, use {self.__class__.__name__}.{new_attr} instead"
)
setattr(self, new_attr, change.new)
def _kernel_dirs_default(self) -> list[str]:
dirs = jupyter_path("kernels")
# At some point, we should stop adding .ipython/kernels to the path,
# but the cost of keeping it is very small.
try:
# this should always be valid on IPython 3+
from IPython.paths import get_ipython_dir
dirs.append(os.path.join(get_ipython_dir(), "kernels"))
except ModuleNotFoundError:
pass
return dirs
def find_kernel_specs(self) -> dict[str, str]:
"""Returns a dict mapping kernel names to resource directories."""
d = {}
for kernel_dir in self.kernel_dirs:
kernels = _list_kernels_in(kernel_dir)
for kname, spec in kernels.items():
if kname not in d:
self.log.debug("Found kernel %s in %s", kname, kernel_dir)
d[kname] = spec
if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
try:
from ipykernel.kernelspec import RESOURCES
self.log.debug(
"Native kernel (%s) available from %s",
NATIVE_KERNEL_NAME,
RESOURCES,
)
d[NATIVE_KERNEL_NAME] = RESOURCES
except ImportError:
self.log.warning("Native kernel (%s) is not available", NATIVE_KERNEL_NAME)
if self.allowed_kernelspecs:
# filter if there's an allow list
d = {name: spec for name, spec in d.items() if name in self.allowed_kernelspecs}
return d
# TODO: Caching?
def _get_kernel_spec_by_name(self, kernel_name: str, resource_dir: str) -> KernelSpec:
"""Returns a :class:`KernelSpec` instance for a given kernel_name
and resource_dir.
"""
kspec = None
if kernel_name == NATIVE_KERNEL_NAME:
try:
from ipykernel.kernelspec import RESOURCES, get_kernel_dict
except ImportError:
# It should be impossible to reach this, but let's play it safe
pass
else:
if resource_dir == RESOURCES:
kdict = get_kernel_dict()
kspec = self.kernel_spec_class(resource_dir=resource_dir, **kdict)
if not kspec:
kspec = self.kernel_spec_class.from_resource_dir(resource_dir)
if not KPF.instance(parent=self.parent).is_provisioner_available(kspec):
raise NoSuchKernel(kernel_name)
return kspec
def _find_spec_directory(self, kernel_name: str) -> str | None:
"""Find the resource directory of a named kernel spec"""
for kernel_dir in [kd for kd in self.kernel_dirs if os.path.isdir(kd)]:
files = os.listdir(kernel_dir)
for f in files:
path = pjoin(kernel_dir, f)
if f.lower() == kernel_name and _is_kernel_dir(path):
return path
if kernel_name == NATIVE_KERNEL_NAME:
try:
from ipykernel.kernelspec import RESOURCES
except ImportError:
pass
else:
return RESOURCES
return None
def get_kernel_spec(self, kernel_name: str) -> KernelSpec:
"""Returns a :class:`KernelSpec` instance for the given kernel_name.
Raises :exc:`NoSuchKernel` if the given kernel name is not found.
"""
if not _is_valid_kernel_name(kernel_name):
self.log.warning(
f"Kernelspec name {kernel_name} is invalid: {_kernel_name_description}"
)
resource_dir = self._find_spec_directory(kernel_name.lower())
if resource_dir is None:
self.log.warning("Kernelspec name %s cannot be found!", kernel_name)
raise NoSuchKernel(kernel_name)
return self._get_kernel_spec_by_name(kernel_name, resource_dir)
def get_all_specs(self) -> dict[str, t.Any]:
"""Returns a dict mapping kernel names to kernelspecs.
Returns a dict of the form::
{
'kernel_name': {
'resource_dir': '/path/to/kernel_name',
'spec': {"the spec itself": ...}
},
...
}
"""
d = self.find_kernel_specs()
res = {}
for kname, resource_dir in d.items():
try:
if self.__class__ is KernelSpecManager:
spec = self._get_kernel_spec_by_name(kname, resource_dir)
else:
# avoid calling private methods in subclasses,
# which may have overridden find_kernel_specs
# and get_kernel_spec, but not the newer get_all_specs
spec = self.get_kernel_spec(kname)
res[kname] = {"resource_dir": resource_dir, "spec": spec.to_dict()}
except NoSuchKernel:
pass # The appropriate warning has already been logged
except Exception:
self.log.warning("Error loading kernelspec %r", kname, exc_info=True)
return res
def remove_kernel_spec(self, name: str) -> str:
"""Remove a kernel spec directory by name.
Returns the path that was deleted.
"""
save_native = self.ensure_native_kernel
try:
self.ensure_native_kernel = False
specs = self.find_kernel_specs()
finally:
self.ensure_native_kernel = save_native
spec_dir = specs[name]
self.log.debug("Removing %s", spec_dir)
if os.path.islink(spec_dir):
os.remove(spec_dir)
else:
shutil.rmtree(spec_dir)
return spec_dir
def _get_destination_dir(
self, kernel_name: str, user: bool = False, prefix: str | None = None
) -> str:
if user:
return os.path.join(self.user_kernel_dir, kernel_name)
elif prefix:
return os.path.join(os.path.abspath(prefix), "share", "jupyter", "kernels", kernel_name)
else:
return os.path.join(SYSTEM_JUPYTER_PATH[0], "kernels", kernel_name)
def install_kernel_spec(
self,
source_dir: str,
kernel_name: str | None = None,
user: bool = False,
replace: bool | None = None,
prefix: str | None = None,
) -> str:
"""Install a kernel spec by copying its directory.
If ``kernel_name`` is not given, the basename of ``source_dir`` will
be used.
If ``user`` is False, it will attempt to install into the systemwide
kernel registry. If the process does not have appropriate permissions,
an :exc:`OSError` will be raised.
If ``prefix`` is given, the kernelspec will be installed to
PREFIX/share/jupyter/kernels/KERNEL_NAME. This can be sys.prefix
for installation inside virtual or conda envs.
"""
source_dir = source_dir.rstrip("/\\")
if not kernel_name:
kernel_name = os.path.basename(source_dir)
kernel_name = kernel_name.lower()
if not _is_valid_kernel_name(kernel_name):
msg = f"Invalid kernel name {kernel_name!r}. {_kernel_name_description}"
raise ValueError(msg)
if user and prefix:
msg = "Can't specify both user and prefix. Please choose one or the other."
raise ValueError(msg)
if replace is not None:
warnings.warn(
"replace is ignored. Installing a kernelspec always replaces an existing "
"installation",
DeprecationWarning,
stacklevel=2,
)
destination = self._get_destination_dir(kernel_name, user=user, prefix=prefix)
self.log.debug("Installing kernelspec in %s", destination)
kernel_dir = os.path.dirname(destination)
if kernel_dir not in self.kernel_dirs:
self.log.warning(
"Installing to %s, which is not in %s. The kernelspec may not be found.",
kernel_dir,
self.kernel_dirs,
)
if os.path.isdir(destination):
self.log.info("Removing existing kernelspec in %s", destination)
shutil.rmtree(destination)
shutil.copytree(source_dir, destination)
self.log.info("Installed kernelspec %s in %s", kernel_name, destination)
return destination
def install_native_kernel_spec(self, user: bool = False) -> None:
"""DEPRECATED: Use ipykernel.kernelspec.install"""
warnings.warn(
"install_native_kernel_spec is deprecated. Use ipykernel.kernelspec import install.",
stacklevel=2,
)
from ipykernel.kernelspec import install
install(self, user=user)
def find_kernel_specs() -> dict[str, str]:
"""Returns a dict mapping kernel names to resource directories."""
return KernelSpecManager().find_kernel_specs()
def get_kernel_spec(kernel_name: str) -> KernelSpec:
"""Returns a :class:`KernelSpec` instance for the given kernel_name.
Raises KeyError if the given kernel name is not found.
"""
return KernelSpecManager().get_kernel_spec(kernel_name)
def install_kernel_spec(
source_dir: str,
kernel_name: str | None = None,
user: bool = False,
replace: bool | None = False,
prefix: str | None = None,
) -> str:
"""Install a kernel spec in a given directory."""
return KernelSpecManager().install_kernel_spec(source_dir, kernel_name, user, replace, prefix)
install_kernel_spec.__doc__ = KernelSpecManager.install_kernel_spec.__doc__
def install_native_kernel_spec(user: bool = False) -> None:
"""Install the native kernel spec."""
KernelSpecManager().install_native_kernel_spec(user=user)
install_native_kernel_spec.__doc__ = KernelSpecManager.install_native_kernel_spec.__doc__

View File

@@ -0,0 +1,341 @@
"""Apps for managing kernel specs."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import errno
import json
import os.path
import sys
import typing as t
from jupyter_core.application import JupyterApp, base_aliases, base_flags
from traitlets import Bool, Dict, Instance, List, Unicode
from traitlets.config.application import Application
from . import __version__
from .kernelspec import KernelSpecManager
from .provisioning.factory import KernelProvisionerFactory
class ListKernelSpecs(JupyterApp):
"""An app to list kernel specs."""
version = __version__
description = """List installed kernel specifications."""
kernel_spec_manager = Instance(KernelSpecManager)
json_output = Bool(
False,
help="output spec name and location as machine-readable json.",
config=True,
)
flags = {
"json": (
{"ListKernelSpecs": {"json_output": True}},
"output spec name and location as machine-readable json.",
),
"debug": base_flags["debug"],
}
def _kernel_spec_manager_default(self) -> KernelSpecManager:
return KernelSpecManager(parent=self, data_dir=self.data_dir)
def start(self) -> dict[str, t.Any] | None: # type:ignore[override]
"""Start the application."""
paths = self.kernel_spec_manager.find_kernel_specs()
specs = self.kernel_spec_manager.get_all_specs()
if not self.json_output:
if not specs:
print("No kernels available")
return None
# pad to width of longest kernel name
            name_len = max(len(name) for name in paths)
def path_key(item: t.Any) -> t.Any:
"""sort key function for Jupyter path priority"""
path = item[1]
for idx, prefix in enumerate(self.jupyter_path):
if path.startswith(prefix):
return (idx, path)
# not in jupyter path, artificially added to the front
return (-1, path)
print("Available kernels:")
for kernelname, path in sorted(paths.items(), key=path_key):
print(f" {kernelname.ljust(name_len)} {path}")
else:
print(json.dumps({"kernelspecs": specs}, indent=2))
return specs
class InstallKernelSpec(JupyterApp):
"""An app to install a kernel spec."""
version = __version__
description = """Install a kernel specification directory.
Given a SOURCE DIRECTORY containing a kernel spec,
jupyter will copy that directory into one of the Jupyter kernel directories.
The default is to install kernelspecs for all users.
`--user` can be specified to install a kernel only for the current user.
"""
examples = """
jupyter kernelspec install /path/to/my_kernel --user
"""
usage = "jupyter kernelspec install SOURCE_DIR [--options]"
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self) -> KernelSpecManager:
return KernelSpecManager(data_dir=self.data_dir)
sourcedir = Unicode()
kernel_name = Unicode("", config=True, help="Install the kernel spec with this name")
def _kernel_name_default(self) -> str:
return os.path.basename(self.sourcedir)
user = Bool(
False,
config=True,
help="""
Try to install the kernel spec to the per-user directory instead of
the system or environment directory.
""",
)
prefix = Unicode(
"",
config=True,
help="""Specify a prefix to install to, e.g. an env.
The kernelspec will be installed in PREFIX/share/jupyter/kernels/
""",
)
replace = Bool(False, config=True, help="Replace any existing kernel spec with this name.")
aliases = {
"name": "InstallKernelSpec.kernel_name",
"prefix": "InstallKernelSpec.prefix",
}
aliases.update(base_aliases)
flags = {
"user": (
{"InstallKernelSpec": {"user": True}},
"Install to the per-user kernel registry",
),
"replace": (
{"InstallKernelSpec": {"replace": True}},
"Replace any existing kernel spec with this name.",
),
"sys-prefix": (
{"InstallKernelSpec": {"prefix": sys.prefix}},
"Install to Python's sys.prefix. Useful in conda/virtual environments.",
),
"debug": base_flags["debug"],
}
def parse_command_line(self, argv: None | list[str]) -> None: # type:ignore[override]
"""Parse the command line args."""
super().parse_command_line(argv)
        # accept positional arg as the source directory
if self.extra_args:
self.sourcedir = self.extra_args[0]
else:
print("No source directory specified.", file=sys.stderr)
self.exit(1)
def start(self) -> None:
"""Start the application."""
if self.user and self.prefix:
self.exit("Can't specify both user and prefix. Please choose one or the other.")
try:
self.kernel_spec_manager.install_kernel_spec(
self.sourcedir,
kernel_name=self.kernel_name,
user=self.user,
prefix=self.prefix,
replace=self.replace,
)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
if not self.user:
print("Perhaps you want to install with `sudo` or `--user`?", file=sys.stderr)
self.exit(1)
elif e.errno == errno.EEXIST:
print(f"A kernel spec is already present at {e.filename}", file=sys.stderr)
self.exit(1)
raise
class RemoveKernelSpec(JupyterApp):
"""An app to remove a kernel spec."""
version = __version__
description = """Remove one or more Jupyter kernelspecs by name."""
examples = """jupyter kernelspec remove python2 [my_kernel ...]"""
force = Bool(False, config=True, help="""Force removal, don't prompt for confirmation.""")
spec_names = List(Unicode())
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self) -> KernelSpecManager:
return KernelSpecManager(data_dir=self.data_dir, parent=self)
flags = {
"f": ({"RemoveKernelSpec": {"force": True}}, force.help),
}
flags.update(JupyterApp.flags)
def parse_command_line(self, argv: list[str] | None) -> None: # type:ignore[override]
"""Parse the command line args."""
super().parse_command_line(argv)
        # accept positional args as kernelspec names
if self.extra_args:
self.spec_names = sorted(set(self.extra_args)) # remove duplicates
else:
self.exit("No kernelspec specified.")
def start(self) -> None:
"""Start the application."""
self.kernel_spec_manager.ensure_native_kernel = False
spec_paths = self.kernel_spec_manager.find_kernel_specs()
missing = set(self.spec_names).difference(set(spec_paths))
if missing:
self.exit("Couldn't find kernel spec(s): %s" % ", ".join(missing))
if not (self.force or self.answer_yes):
print("Kernel specs to remove:")
for name in self.spec_names:
path = spec_paths.get(name, name)
print(f" {name.ljust(20)}\t{path.ljust(20)}")
answer = input("Remove %i kernel specs [y/N]: " % len(self.spec_names))
if not answer.lower().startswith("y"):
return
for kernel_name in self.spec_names:
try:
path = self.kernel_spec_manager.remove_kernel_spec(kernel_name)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
print("Perhaps you want sudo?", file=sys.stderr)
self.exit(1)
else:
raise
print(f"Removed {path}")
class InstallNativeKernelSpec(JupyterApp):
"""An app to install the native kernel spec."""
version = __version__
description = """[DEPRECATED] Install the IPython kernel spec directory for this Python."""
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self) -> KernelSpecManager: # pragma: no cover
return KernelSpecManager(data_dir=self.data_dir)
user = Bool(
False,
config=True,
help="""
Try to install the kernel spec to the per-user directory instead of
the system or environment directory.
""",
)
flags = {
"user": (
{"InstallNativeKernelSpec": {"user": True}},
"Install to the per-user kernel registry",
),
"debug": base_flags["debug"],
}
def start(self) -> None: # pragma: no cover
"""Start the application."""
self.log.warning(
"`jupyter kernelspec install-self` is DEPRECATED as of 4.0."
" You probably want `ipython kernel install` to install the IPython kernelspec."
)
try:
from ipykernel import kernelspec
except ModuleNotFoundError:
print("ipykernel not available, can't install its spec.", file=sys.stderr)
self.exit(1)
try:
kernelspec.install(self.kernel_spec_manager, user=self.user)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
if not self.user:
print(
"Perhaps you want to install with `sudo` or `--user`?",
file=sys.stderr,
)
self.exit(1)
self.exit(e) # type:ignore[arg-type]
class ListProvisioners(JupyterApp):
"""An app to list provisioners."""
version = __version__
description = """List available provisioners for use in kernel specifications."""
def start(self) -> None:
"""Start the application."""
kfp = KernelProvisionerFactory.instance(parent=self)
print("Available kernel provisioners:")
provisioners = kfp.get_provisioner_entries()
        # pad to width of longest provisioner name
        name_len = max(len(name) for name in provisioners)
for name in sorted(provisioners):
print(f" {name.ljust(name_len)} {provisioners[name]}")
class KernelSpecApp(Application):
"""An app to manage kernel specs."""
version = __version__
name = "jupyter kernelspec"
description = """Manage Jupyter kernel specifications."""
subcommands = Dict(
{
"list": (ListKernelSpecs, ListKernelSpecs.description.splitlines()[0]),
"install": (
InstallKernelSpec,
InstallKernelSpec.description.splitlines()[0],
),
"uninstall": (RemoveKernelSpec, "Alias for remove"),
"remove": (RemoveKernelSpec, RemoveKernelSpec.description.splitlines()[0]),
"install-self": (
InstallNativeKernelSpec,
InstallNativeKernelSpec.description.splitlines()[0],
),
"provisioners": (ListProvisioners, ListProvisioners.description.splitlines()[0]),
}
)
aliases = {}
flags = {}
def start(self) -> None:
"""Start the application."""
if self.subapp is None:
print("No subcommand specified. Must specify one of: %s" % list(self.subcommands))
print()
self.print_description()
self.print_subcommands()
self.exit(1)
else:
return self.subapp.start()
if __name__ == "__main__":
KernelSpecApp.launch_instance()

View File

@@ -0,0 +1,186 @@
"""Utilities for launching kernels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import sys
import warnings
from subprocess import PIPE, Popen
from typing import Any, Dict, List, Optional
from traitlets.log import get_logger
def launch_kernel(
cmd: List[str],
stdin: Optional[int] = None,
stdout: Optional[int] = None,
stderr: Optional[int] = None,
env: Optional[Dict[str, str]] = None,
independent: bool = False,
cwd: Optional[str] = None,
**kw: Any,
) -> Popen:
"""Launches a localhost kernel, binding to the specified ports.
Parameters
----------
    cmd : list of str
        The command sequence used to launch the kernel, as passed to
        ``subprocess.Popen``.
stdin, stdout, stderr : optional (default None)
Standards streams, as defined in subprocess.Popen.
env: dict, optional
Environment variables passed to the kernel
independent : bool, optional (default False)
If set, the kernel process is guaranteed to survive if this process
dies. If not set, an effort is made to ensure that the kernel is killed
when this process dies. Note that in this case it is still good practice
to kill kernels manually before exiting.
cwd : path, optional
The working dir of the kernel process (default: cwd of this process).
**kw: optional
Additional arguments for Popen
Returns
-------
Popen instance for the kernel subprocess
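    A minimal sketch (the command below is illustrative; real commands come
    from a kernelspec's ``argv``)::

        proc = launch_kernel(["python", "-m", "ipykernel_launcher", "-f", "/tmp/kernel.json"])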
"""
# Popen will fail (sometimes with a deadlock) if stdin, stdout, and stderr
# are invalid. Unfortunately, there is in general no way to detect whether
# they are valid. The following two blocks redirect them to (temporary)
# pipes in certain important cases.
    # If this process has been backgrounded, our stdin is invalid. Since there
    # is no compelling reason for the kernel to inherit our stdin anyway, we'll
    # play it safe and always redirect.
redirect_in = True
_stdin = PIPE if stdin is None else stdin
    # If this process is running under pythonw, we know that stdin, stdout, and
# stderr are all invalid.
redirect_out = sys.executable.endswith("pythonw.exe")
if redirect_out:
blackhole = open(os.devnull, "w") # noqa
_stdout = blackhole if stdout is None else stdout
_stderr = blackhole if stderr is None else stderr
else:
_stdout, _stderr = stdout, stderr
env = env if (env is not None) else os.environ.copy()
kwargs = kw.copy()
main_args = {
"stdin": _stdin,
"stdout": _stdout,
"stderr": _stderr,
"cwd": cwd,
"env": env,
}
kwargs.update(main_args)
# Spawn a kernel.
if sys.platform == "win32":
if cwd:
kwargs["cwd"] = cwd
from .win_interrupt import create_interrupt_event
# Create a Win32 event for interrupting the kernel
# and store it in an environment variable.
interrupt_event = create_interrupt_event()
env["JPY_INTERRUPT_EVENT"] = str(interrupt_event)
# deprecated old env name:
env["IPY_INTERRUPT_EVENT"] = env["JPY_INTERRUPT_EVENT"]
try:
from _winapi import (
CREATE_NEW_PROCESS_GROUP,
DUPLICATE_SAME_ACCESS,
DuplicateHandle,
GetCurrentProcess,
)
        except ImportError:
from _subprocess import (
CREATE_NEW_PROCESS_GROUP,
DUPLICATE_SAME_ACCESS,
DuplicateHandle,
GetCurrentProcess,
)
# create a handle on the parent to be inherited
if independent:
kwargs["creationflags"] = CREATE_NEW_PROCESS_GROUP
else:
pid = GetCurrentProcess()
handle = DuplicateHandle(
pid,
pid,
pid,
0,
True,
DUPLICATE_SAME_ACCESS, # Inheritable by new processes.
)
env["JPY_PARENT_PID"] = str(int(handle))
# Prevent creating new console window on pythonw
if redirect_out:
kwargs["creationflags"] = (
kwargs.setdefault("creationflags", 0) | 0x08000000
) # CREATE_NO_WINDOW
# Avoid closing the above parent and interrupt handles.
# close_fds is True by default on Python >=3.7
# or when no stream is captured on Python <3.7
# (we always capture stdin, so this is already False by default on <3.7)
kwargs["close_fds"] = False
else:
# Create a new session.
# This makes it easier to interrupt the kernel,
# because we want to interrupt the whole process group.
# We don't use setpgrp, which is known to cause problems for kernels starting
# certain interactive subprocesses, such as bash -i.
kwargs["start_new_session"] = True
if not independent:
env["JPY_PARENT_PID"] = str(os.getpid())
try:
        # Allow use of ~/ in the command or its arguments
cmd = [os.path.expanduser(s) for s in cmd]
proc = Popen(cmd, **kwargs) # noqa
except Exception as ex:
try:
msg = "Failed to run command:\n{}\n PATH={!r}\n with kwargs:\n{!r}\n"
# exclude environment variables,
# which may contain access tokens and the like.
without_env = {key: value for key, value in kwargs.items() if key != "env"}
msg = msg.format(cmd, env.get("PATH", os.defpath), without_env)
get_logger().error(msg)
except Exception as ex2: # Don't let a formatting/logger issue lead to the wrong exception
warnings.warn(f"Failed to run command: '{cmd}' due to exception: {ex}", stacklevel=2)
warnings.warn(
f"The following exception occurred handling the previous failure: {ex2}",
stacklevel=2,
)
raise ex
if sys.platform == "win32":
# Attach the interrupt event to the Popen object so it can be used later.
proc.win32_interrupt_event = interrupt_event
# Clean up pipes created to work around Popen bug.
if redirect_in and stdin is None:
assert proc.stdin is not None
proc.stdin.close()
return proc
__all__ = [
"launch_kernel",
]

View File

@@ -0,0 +1,297 @@
"""Utilities for identifying local IP addresses."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import os
import re
import socket
import subprocess
from subprocess import PIPE, Popen
from typing import Any, Callable, Iterable, Sequence
from warnings import warn
LOCAL_IPS: list = []
PUBLIC_IPS: list = []
LOCALHOST: str = ""
def _uniq_stable(elems: Iterable) -> list:
"""uniq_stable(elems) -> list
Return from an iterable, a list of all the unique elements in the input,
maintaining the order in which they first appear.
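    For example, ``_uniq_stable([1, 2, 1, 3, 2])`` returns ``[1, 2, 3]``.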
"""
seen = set()
value = []
for x in elems:
if x not in seen:
value.append(x)
seen.add(x)
return value
def _get_output(cmd: str | Sequence[str]) -> str:
"""Get output of a command, raising IOError if it fails"""
startupinfo = None
if os.name == "nt":
startupinfo = subprocess.STARTUPINFO() # type:ignore[attr-defined]
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # type:ignore[attr-defined]
p = Popen(cmd, stdout=PIPE, stderr=PIPE, startupinfo=startupinfo) # noqa
stdout, stderr = p.communicate()
if p.returncode:
msg = "Failed to run {}: {}".format(cmd, stderr.decode("utf8", "replace"))
raise OSError(msg)
return stdout.decode("utf8", "replace")
def _only_once(f: Callable) -> Callable:
"""decorator to only run a function once"""
f.called = False # type:ignore[attr-defined]
def wrapped(**kwargs: Any) -> Any:
if f.called: # type:ignore[attr-defined]
return
ret = f(**kwargs)
f.called = True # type:ignore[attr-defined]
return ret
return wrapped
def _requires_ips(f: Callable) -> Callable:
"""decorator to ensure load_ips has been run before f"""
def ips_loaded(*args: Any, **kwargs: Any) -> Any:
_load_ips()
return f(*args, **kwargs)
return ips_loaded
# subprocess-parsing ip finders
class NoIPAddresses(Exception): # noqa
pass
def _populate_from_list(addrs: Sequence[str] | None) -> None:
"""populate local and public IPs from flat list of all IPs"""
if not addrs:
raise NoIPAddresses
global LOCALHOST
public_ips = []
local_ips = []
for ip in addrs:
local_ips.append(ip)
if not ip.startswith("127."):
public_ips.append(ip)
elif not LOCALHOST:
LOCALHOST = ip
if not LOCALHOST or LOCALHOST == "127.0.0.1":
LOCALHOST = "127.0.0.1"
local_ips.insert(0, LOCALHOST)
local_ips.extend(["0.0.0.0", ""]) # noqa
LOCAL_IPS[:] = _uniq_stable(local_ips)
PUBLIC_IPS[:] = _uniq_stable(public_ips)
_ifconfig_ipv4_pat = re.compile(r"inet\b.*?(\d+\.\d+\.\d+\.\d+)", re.IGNORECASE)
def _load_ips_ifconfig() -> None:
"""load ip addresses from `ifconfig` output (posix)"""
try:
out = _get_output("ifconfig")
except OSError:
# no ifconfig, it's usually in /sbin and /sbin is not on everyone's PATH
out = _get_output("/sbin/ifconfig")
lines = out.splitlines()
addrs = []
for line in lines:
m = _ifconfig_ipv4_pat.match(line.strip())
if m:
addrs.append(m.group(1))
_populate_from_list(addrs)
def _load_ips_ip() -> None:
"""load ip addresses from `ip addr` output (Linux)"""
out = _get_output(["ip", "-f", "inet", "addr"])
lines = out.splitlines()
addrs = []
for line in lines:
blocks = line.lower().split()
if (len(blocks) >= 2) and (blocks[0] == "inet"):
addrs.append(blocks[1].split("/")[0])
_populate_from_list(addrs)
_ipconfig_ipv4_pat = re.compile(r"ipv4.*?(\d+\.\d+\.\d+\.\d+)$", re.IGNORECASE)
def _load_ips_ipconfig() -> None:
"""load ip addresses from `ipconfig` output (Windows)"""
out = _get_output("ipconfig")
lines = out.splitlines()
addrs = []
for line in lines:
m = _ipconfig_ipv4_pat.match(line.strip())
if m:
addrs.append(m.group(1))
_populate_from_list(addrs)
def _load_ips_netifaces() -> None:
"""load ip addresses with netifaces"""
import netifaces # type: ignore[import-not-found]
global LOCALHOST
local_ips = []
public_ips = []
# list of iface names, 'lo0', 'eth0', etc.
for iface in netifaces.interfaces():
# list of ipv4 addrinfo dicts
ipv4s = netifaces.ifaddresses(iface).get(netifaces.AF_INET, [])
for entry in ipv4s:
addr = entry.get("addr")
if not addr:
continue
if not (iface.startswith("lo") or addr.startswith("127.")):
public_ips.append(addr)
elif not LOCALHOST:
LOCALHOST = addr
local_ips.append(addr)
if not LOCALHOST:
# we never found a loopback interface (can this ever happen?), assume common default
LOCALHOST = "127.0.0.1"
local_ips.insert(0, LOCALHOST)
local_ips.extend(["0.0.0.0", ""]) # noqa
LOCAL_IPS[:] = _uniq_stable(local_ips)
PUBLIC_IPS[:] = _uniq_stable(public_ips)
def _load_ips_gethostbyname() -> None:
"""load ip addresses with socket.gethostbyname_ex
This can be slow.
"""
global LOCALHOST
try:
LOCAL_IPS[:] = socket.gethostbyname_ex("localhost")[2]
except OSError:
# assume common default
LOCAL_IPS[:] = ["127.0.0.1"]
try:
hostname = socket.gethostname()
PUBLIC_IPS[:] = socket.gethostbyname_ex(hostname)[2]
# try hostname.local, in case hostname has been short-circuited to loopback
if not hostname.endswith(".local") and all(ip.startswith("127") for ip in PUBLIC_IPS):
PUBLIC_IPS[:] = socket.gethostbyname_ex(socket.gethostname() + ".local")[2]
except OSError:
pass
finally:
PUBLIC_IPS[:] = _uniq_stable(PUBLIC_IPS)
LOCAL_IPS.extend(PUBLIC_IPS)
# include all-interface aliases: 0.0.0.0 and ''
LOCAL_IPS.extend(["0.0.0.0", ""]) # noqa
LOCAL_IPS[:] = _uniq_stable(LOCAL_IPS)
LOCALHOST = LOCAL_IPS[0]
def _load_ips_dumb() -> None:
"""Fallback in case of unexpected failure"""
global LOCALHOST
LOCALHOST = "127.0.0.1"
LOCAL_IPS[:] = [LOCALHOST, "0.0.0.0", ""] # noqa
PUBLIC_IPS[:] = []
@_only_once
def _load_ips(suppress_exceptions: bool = True) -> None:
"""load the IPs that point to this machine
This function will only ever be called once.
It will use netifaces to do it quickly if available.
Then it will fallback on parsing the output of ifconfig / ip addr / ipconfig, as appropriate.
Finally, it will fallback on socket.gethostbyname_ex, which can be slow.
"""
try:
# first priority, use netifaces
try:
return _load_ips_netifaces()
except ImportError:
pass
# second priority, parse subprocess output (how reliable is this?)
if os.name == "nt":
try:
return _load_ips_ipconfig()
except (OSError, NoIPAddresses):
pass
else:
try:
return _load_ips_ip()
except (OSError, NoIPAddresses):
pass
try:
return _load_ips_ifconfig()
except (OSError, NoIPAddresses):
pass
# lowest priority, use gethostbyname
return _load_ips_gethostbyname()
except Exception as e:
if not suppress_exceptions:
raise
# unexpected error shouldn't crash, load dumb default values instead.
warn("Unexpected error discovering local network interfaces: %s" % e, stacklevel=2)
_load_ips_dumb()
@_requires_ips
def local_ips() -> list[str]:
"""return the IP addresses that point to this machine"""
return LOCAL_IPS
@_requires_ips
def public_ips() -> list[str]:
"""return the IP addresses for this machine that are visible to other machines"""
return PUBLIC_IPS
@_requires_ips
def localhost() -> str:
"""return ip for localhost (almost always 127.0.0.1)"""
return LOCALHOST
@_requires_ips
def is_local_ip(ip: str) -> bool:
"""does `ip` point to this machine?"""
return ip in LOCAL_IPS
@_requires_ips
def is_public_ip(ip: str) -> bool:
"""is `ip` a publicly visible address?"""
return ip in PUBLIC_IPS
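# Example usage (a minimal sketch; actual addresses depend on the host):
#
#     from jupyter_client.localinterfaces import localhost, local_ips
#     print(localhost())   # typically '127.0.0.1'
#     print(local_ips())   # loopback first, then other local addresses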

View File

@@ -0,0 +1,806 @@
"""Base class to manage a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import functools
import os
import re
import signal
import sys
import typing as t
import uuid
import warnings
from asyncio.futures import Future
from concurrent.futures import Future as CFuture
from contextlib import contextmanager
from enum import Enum
import zmq
from jupyter_core.utils import run_sync
from traitlets import (
Any,
Bool,
Dict,
DottedObjectName,
Float,
Instance,
Type,
Unicode,
default,
observe,
observe_compat,
)
from traitlets.utils.importstring import import_item
from . import kernelspec
from .asynchronous import AsyncKernelClient
from .blocking import BlockingKernelClient
from .client import KernelClient
from .connect import ConnectionFileMixin
from .managerabc import KernelManagerABC
from .provisioning import KernelProvisionerBase
from .provisioning import KernelProvisionerFactory as KPF # noqa
class _ShutdownStatus(Enum):
"""
    This is so far used only for testing, in order to track the internal state
    of the shutdown logic and verify which path is taken for which
    misbehavior.
"""
Unset = None
ShutdownRequest = "ShutdownRequest"
SigtermRequest = "SigtermRequest"
SigkillRequest = "SigkillRequest"
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
def _get_future() -> t.Union[Future, CFuture]:
"""Get an appropriate Future object"""
try:
asyncio.get_running_loop()
return Future()
except RuntimeError:
# No event loop running, use concurrent future
return CFuture()
def in_pending_state(method: F) -> F:
"""Sets the kernel to a pending state by
creating a fresh Future for the KernelManager's `ready`
    attribute. Once the method finishes, the Future's result is set.
"""
@t.no_type_check
@functools.wraps(method)
async def wrapper(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Create a future for the decorated method."""
if self._attempted_start or not self._ready:
self._ready = _get_future()
try:
# call wrapped method, await, and set the result or exception.
out = await method(self, *args, **kwargs)
# Add a small sleep to ensure tests can capture the state before done
await asyncio.sleep(0.01)
if self.owns_kernel:
self._ready.set_result(None)
return out
except Exception as e:
self._ready.set_exception(e)
self.log.exception(self._ready.exception())
raise e
return t.cast(F, wrapper)
class KernelManager(ConnectionFileMixin):
"""Manages a single kernel in a subprocess on this host.
This version starts kernels with Popen.
"""
_ready: t.Optional[t.Union[Future, CFuture]]
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
"""Initialize a kernel manager."""
if args:
warnings.warn(
"Passing positional only arguments to "
"`KernelManager.__init__` is deprecated since jupyter_client"
" 8.6, and will become an error on future versions. Positional "
" arguments have been ignored since jupyter_client 7.0",
DeprecationWarning,
stacklevel=2,
)
self._owns_kernel = kwargs.pop("owns_kernel", True)
super().__init__(**kwargs)
self._shutdown_status = _ShutdownStatus.Unset
self._attempted_start = False
self._ready = None
_created_context: Bool = Bool(False)
# The PyZMQ Context to use for communication with the kernel.
context: Instance = Instance(zmq.Context)
@default("context")
def _context_default(self) -> zmq.Context:
self._created_context = True
return zmq.Context()
# the class to create with our `client` method
client_class: DottedObjectName = DottedObjectName(
"jupyter_client.blocking.BlockingKernelClient"
)
client_factory: Type = Type(klass=KernelClient)
@default("client_factory")
def _client_factory_default(self) -> Type:
return import_item(self.client_class)
@observe("client_class")
def _client_class_changed(self, change: t.Dict[str, DottedObjectName]) -> None:
self.client_factory = import_item(str(change["new"]))
kernel_id: t.Union[str, Unicode] = Unicode(None, allow_none=True)
# The kernel provisioner with which this KernelManager is communicating.
# This will generally be a LocalProvisioner instance unless the kernelspec
# indicates otherwise.
provisioner: t.Optional[KernelProvisionerBase] = None
kernel_spec_manager: Instance = Instance(kernelspec.KernelSpecManager)
@default("kernel_spec_manager")
def _kernel_spec_manager_default(self) -> kernelspec.KernelSpecManager:
return kernelspec.KernelSpecManager(data_dir=self.data_dir)
@observe("kernel_spec_manager")
@observe_compat
def _kernel_spec_manager_changed(self, change: t.Dict[str, Instance]) -> None:
self._kernel_spec = None
shutdown_wait_time: Float = Float(
5.0,
config=True,
help="Time to wait for a kernel to terminate before killing it, "
"in seconds. When a shutdown request is initiated, the kernel "
"will be immediately sent an interrupt (SIGINT), followed"
"by a shutdown_request message, after 1/2 of `shutdown_wait_time`"
"it will be sent a terminate (SIGTERM) request, and finally at "
"the end of `shutdown_wait_time` will be killed (SIGKILL). terminate "
"and kill may be equivalent on windows. Note that this value can be"
"overridden by the in-use kernel provisioner since shutdown times may"
"vary by provisioned environment.",
)
kernel_name: t.Union[str, Unicode] = Unicode(kernelspec.NATIVE_KERNEL_NAME)
@observe("kernel_name")
def _kernel_name_changed(self, change: t.Dict[str, str]) -> None:
self._kernel_spec = None
if change["new"] == "python":
self.kernel_name = kernelspec.NATIVE_KERNEL_NAME
_kernel_spec: t.Optional[kernelspec.KernelSpec] = None
@property
def kernel_spec(self) -> t.Optional[kernelspec.KernelSpec]:
if self._kernel_spec is None and self.kernel_name != "":
self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(self.kernel_name)
return self._kernel_spec
cache_ports: Bool = Bool(
False,
config=True,
help="True if the MultiKernelManager should cache ports for this KernelManager instance",
)
@default("cache_ports")
def _default_cache_ports(self) -> bool:
return self.transport == "tcp"
@property
def ready(self) -> t.Union[CFuture, Future]:
"""A future that resolves when the kernel process has started for the first time"""
if not self._ready:
self._ready = _get_future()
return self._ready
@property
def ipykernel(self) -> bool:
return self.kernel_name in {"python", "python2", "python3"}
# Protected traits
_launch_args: t.Optional["Dict[str, Any]"] = Dict(allow_none=True)
_control_socket: Any = Any()
_restarter: Any = Any()
autorestart: Bool = Bool(
True, config=True, help="""Should we autorestart the kernel if it dies."""
)
shutting_down: bool = False
def __del__(self) -> None:
self._close_control_socket()
self.cleanup_connection_file()
# --------------------------------------------------------------------------
# Kernel restarter
# --------------------------------------------------------------------------
def start_restarter(self) -> None:
"""Start the kernel restarter."""
pass
def stop_restarter(self) -> None:
"""Stop the kernel restarter."""
pass
def add_restart_callback(self, callback: t.Callable, event: str = "restart") -> None:
"""Register a callback to be called when a kernel is restarted"""
if self._restarter is None:
return
self._restarter.add_callback(callback, event)
def remove_restart_callback(self, callback: t.Callable, event: str = "restart") -> None:
"""Unregister a callback to be called when a kernel is restarted"""
if self._restarter is None:
return
self._restarter.remove_callback(callback, event)
# --------------------------------------------------------------------------
# create a Client connected to our Kernel
# --------------------------------------------------------------------------
def client(self, **kwargs: t.Any) -> BlockingKernelClient:
"""Create a client configured to connect to our kernel"""
kw: dict = {}
kw.update(self.get_connection_info(session=True))
kw.update(
{
"connection_file": self.connection_file,
"parent": self,
}
)
# add kwargs last, for manual overrides
kw.update(kwargs)
return self.client_factory(**kw)
# --------------------------------------------------------------------------
# Kernel management
# --------------------------------------------------------------------------
def update_env(self, *, env: t.Dict[str, str]) -> None:
"""
        Update the environment of this kernel manager.
        This will take effect only after a kernel restart, when the new env is
        passed to the new kernel.
        This is useful because some of the information in the current kernel
        reflects the state of the session that started it, and that session
        information (such as the attach file path or name) is mutable.
        .. versionadded:: 8.5
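        A minimal sketch (the variable name is illustrative)::

            km.update_env(env={"MY_VAR": "1"})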
"""
        # Mypy thinks this is unreachable as it sees _launch_args as Dict, not t.Dict
if (
isinstance(self._launch_args, dict)
and "env" in self._launch_args
and isinstance(self._launch_args["env"], dict) # type: ignore [unreachable]
):
self._launch_args["env"].update(env) # type: ignore [unreachable]
def format_kernel_cmd(self, extra_arguments: t.Optional[t.List[str]] = None) -> t.List[str]:
"""Replace templated args (e.g. {connection_file})"""
extra_arguments = extra_arguments or []
assert self.kernel_spec is not None
cmd = self.kernel_spec.argv + extra_arguments
if cmd and cmd[0] in {
"python",
"python%i" % sys.version_info[0],
"python%i.%i" % sys.version_info[:2],
}:
# executable is 'python' or 'python3', use sys.executable.
# These will typically be the same,
# but if the current process is in an env
# and has been launched by abspath without
# activating the env, python on PATH may not be sys.executable,
# but it should be.
cmd[0] = sys.executable
# Make sure to use the realpath for the connection_file
        # On Windows, when running with the Microsoft Store Python, the
        # connection_file path is not usable by non-Python kernels because the
        # path is rerouted inside a Store app.
# See this bug here: https://bugs.python.org/issue41196
ns: t.Dict[str, t.Any] = {
"connection_file": os.path.realpath(self.connection_file),
"prefix": sys.prefix,
}
if self.kernel_spec: # type:ignore[truthy-bool]
ns["resource_dir"] = self.kernel_spec.resource_dir
assert isinstance(self._launch_args, dict)
ns.update(self._launch_args)
pat = re.compile(r"\{([A-Za-z0-9_]+)\}")
def from_ns(match: t.Any) -> t.Any:
"""Get the key out of ns if it's there, otherwise no change."""
return ns.get(match.group(1), match.group())
return [pat.sub(from_ns, arg) for arg in cmd]
async def _async_launch_kernel(self, kernel_cmd: t.List[str], **kw: t.Any) -> None:
"""actually launch the kernel
override in a subclass to launch kernel subprocesses differently
Note that provisioners can now be used to customize kernel environments
and
"""
assert self.provisioner is not None
connection_info = await self.provisioner.launch_kernel(kernel_cmd, **kw)
assert self.provisioner.has_process
# Provisioner provides the connection information. Load into kernel manager
# and write the connection file, if not already done.
self._reconcile_connection_info(connection_info)
_launch_kernel = run_sync(_async_launch_kernel)
# Control socket used for polite kernel shutdown
def _connect_control_socket(self) -> None:
if self._control_socket is None:
self._control_socket = self._create_connected_socket("control")
self._control_socket.linger = 100
def _close_control_socket(self) -> None:
if self._control_socket is None:
return
self._control_socket.close()
self._control_socket = None
async def _async_pre_start_kernel(
self, **kw: t.Any
) -> t.Tuple[t.List[str], t.Dict[str, t.Any]]:
"""Prepares a kernel for startup in a separate process.
If random ports (port=0) are being used, this method must be called
before the channels are created.
Parameters
----------
`**kw` : optional
            keyword arguments that are passed down to build the kernel_cmd
            and to launch the kernel (e.g. Popen kwargs).
"""
self.shutting_down = False
self.kernel_id = self.kernel_id or kw.pop("kernel_id", str(uuid.uuid4()))
# save kwargs for use in restart
        # assigning Traitlets Dicts to Dict makes mypy unhappy but is ok
self._launch_args = kw.copy() # type:ignore [assignment]
if self.provisioner is None: # will not be None on restarts
self.provisioner = KPF.instance(parent=self.parent).create_provisioner_instance(
self.kernel_id,
self.kernel_spec,
parent=self,
)
kw = await self.provisioner.pre_launch(**kw)
kernel_cmd = kw.pop("cmd")
return kernel_cmd, kw
pre_start_kernel = run_sync(_async_pre_start_kernel)
async def _async_post_start_kernel(self, **kw: t.Any) -> None:
"""Performs any post startup tasks relative to the kernel.
Parameters
----------
`**kw` : optional
keyword arguments that were used in the kernel process's launch.
"""
self.start_restarter()
self._connect_control_socket()
assert self.provisioner is not None
await self.provisioner.post_launch(**kw)
post_start_kernel = run_sync(_async_post_start_kernel)
@in_pending_state
async def _async_start_kernel(self, **kw: t.Any) -> None:
"""Starts a kernel on this host in a separate process.
If random ports (port=0) are being used, this method must be called
before the channels are created.
Parameters
----------
`**kw` : optional
            keyword arguments that are passed down to build the kernel_cmd
            and to launch the kernel (e.g. Popen kwargs).
"""
self._attempted_start = True
kernel_cmd, kw = await self._async_pre_start_kernel(**kw)
# launch the kernel subprocess
self.log.debug("Starting kernel: %s", kernel_cmd)
await self._async_launch_kernel(kernel_cmd, **kw)
await self._async_post_start_kernel(**kw)
start_kernel = run_sync(_async_start_kernel)
async def _async_request_shutdown(self, restart: bool = False) -> None:
"""Send a shutdown request via control channel"""
content = {"restart": restart}
msg = self.session.msg("shutdown_request", content=content)
# ensure control socket is connected
self._connect_control_socket()
self.session.send(self._control_socket, msg)
assert self.provisioner is not None
await self.provisioner.shutdown_requested(restart=restart)
self._shutdown_status = _ShutdownStatus.ShutdownRequest
request_shutdown = run_sync(_async_request_shutdown)
async def _async_finish_shutdown(
self,
waittime: t.Optional[float] = None,
pollinterval: float = 0.1,
restart: bool = False,
) -> None:
"""Wait for kernel shutdown, then kill process if it doesn't shutdown.
This does not send shutdown requests - use :meth:`request_shutdown`
first.
"""
if waittime is None:
waittime = max(self.shutdown_wait_time, 0)
if self.provisioner: # Allow provisioner to override
waittime = self.provisioner.get_shutdown_wait_time(recommended=waittime)
try:
await asyncio.wait_for(
self._async_wait(pollinterval=pollinterval), timeout=waittime / 2
)
except asyncio.TimeoutError:
self.log.debug("Kernel is taking too long to finish, terminating")
self._shutdown_status = _ShutdownStatus.SigtermRequest
await self._async_send_kernel_sigterm()
try:
await asyncio.wait_for(
self._async_wait(pollinterval=pollinterval), timeout=waittime / 2
)
except asyncio.TimeoutError:
self.log.debug("Kernel is taking too long to finish, killing")
self._shutdown_status = _ShutdownStatus.SigkillRequest
await self._async_kill_kernel(restart=restart)
else:
# Process is no longer alive, wait and clear
if self.has_kernel:
assert self.provisioner is not None
await self.provisioner.wait()
finish_shutdown = run_sync(_async_finish_shutdown)
async def _async_cleanup_resources(self, restart: bool = False) -> None:
"""Clean up resources when the kernel is shut down"""
if not restart:
self.cleanup_connection_file()
self.cleanup_ipc_files()
self._close_control_socket()
self.session.parent = None
if self._created_context and not restart:
self.context.destroy(linger=100)
if self.provisioner:
await self.provisioner.cleanup(restart=restart)
cleanup_resources = run_sync(_async_cleanup_resources)
@in_pending_state
async def _async_shutdown_kernel(self, now: bool = False, restart: bool = False) -> None:
"""Attempts to stop the kernel process cleanly.
        This attempts to shut down the kernel cleanly by:
1. Sending it a shutdown message over the control channel.
2. If that fails, the kernel is shutdown forcibly by sending it
a signal.
Parameters
----------
now : bool
            Should the kernel be forcibly killed *now*? This skips the
first, nice shutdown attempt.
restart: bool
Will this kernel be restarted after it is shutdown. When this
is True, connection files will not be cleaned up.
"""
if not self.owns_kernel:
return
self.shutting_down = True # Used by restarter to prevent race condition
# Stop monitoring for restarting while we shutdown.
self.stop_restarter()
if self.has_kernel:
await self._async_interrupt_kernel()
if now:
await self._async_kill_kernel()
else:
await self._async_request_shutdown(restart=restart)
# Don't send any additional kernel kill messages immediately, to give
# the kernel a chance to properly execute shutdown actions. Wait for at
# most 1s, checking every 0.1s.
await self._async_finish_shutdown(restart=restart)
await self._async_cleanup_resources(restart=restart)
shutdown_kernel = run_sync(_async_shutdown_kernel)
async def _async_restart_kernel(
self, now: bool = False, newports: bool = False, **kw: t.Any
) -> None:
"""Restarts a kernel with the arguments that were used to launch it.
Parameters
----------
now : bool, optional
If True, the kernel is forcefully restarted *immediately*, without
having a chance to do any cleanup action. Otherwise the kernel is
given 1s to clean up before a forceful restart is issued.
            In all cases the kernel is restarted; the only difference is whether
it is given a chance to perform a clean shutdown or not.
newports : bool, optional
If the old kernel was launched with random ports, this flag decides
whether the same ports and connection file will be used again.
If False, the same ports and connection file are used. This is
the default. If True, new random port numbers are chosen and a
new connection file is written. It is still possible that the newly
chosen random port numbers happen to be the same as the old ones.
`**kw` : optional
Any options specified here will overwrite those used to launch the
kernel.
"""
if self._launch_args is None:
msg = "Cannot restart the kernel. No previous call to 'start_kernel'."
raise RuntimeError(msg)
# Stop currently running kernel.
await self._async_shutdown_kernel(now=now, restart=True)
if newports:
self.cleanup_random_ports()
# Start new kernel.
self._launch_args.update(kw)
await self._async_start_kernel(**self._launch_args)
restart_kernel = run_sync(_async_restart_kernel)
@property
def owns_kernel(self) -> bool:
return self._owns_kernel
@property
def has_kernel(self) -> bool:
"""Has a kernel process been started that we are actively managing."""
return self.provisioner is not None and self.provisioner.has_process
async def _async_send_kernel_sigterm(self, restart: bool = False) -> None:
"""similar to _kill_kernel, but with sigterm (not sigkill), but do not block"""
if self.has_kernel:
assert self.provisioner is not None
await self.provisioner.terminate(restart=restart)
_send_kernel_sigterm = run_sync(_async_send_kernel_sigterm)
async def _async_kill_kernel(self, restart: bool = False) -> None:
"""Kill the running kernel.
This is a private method, callers should use shutdown_kernel(now=True).
"""
if self.has_kernel:
assert self.provisioner is not None
await self.provisioner.kill(restart=restart)
# Wait until the kernel terminates.
try:
await asyncio.wait_for(self._async_wait(), timeout=5.0)
except asyncio.TimeoutError:
# Wait timed out, just log warning but continue - not much more we can do.
self.log.warning("Wait for final termination of kernel timed out - continuing...")
pass
else:
# Process is no longer alive, wait and clear
if self.has_kernel:
await self.provisioner.wait()
_kill_kernel = run_sync(_async_kill_kernel)
async def _async_interrupt_kernel(self) -> None:
"""Interrupts the kernel by sending it a signal.
Unlike ``signal_kernel``, this operation is well supported on all
platforms.
"""
if not self.has_kernel and self._ready is not None:
if isinstance(self._ready, CFuture):
ready = asyncio.ensure_future(t.cast(Future[t.Any], self._ready))
else:
ready = self._ready
# Wait for a shutdown if one is in progress.
if self.shutting_down:
await ready
# Wait for a startup.
await ready
if self.has_kernel:
assert self.kernel_spec is not None
interrupt_mode = self.kernel_spec.interrupt_mode
if interrupt_mode == "signal":
await self._async_signal_kernel(signal.SIGINT)
elif interrupt_mode == "message":
msg = self.session.msg("interrupt_request", content={})
self._connect_control_socket()
self.session.send(self._control_socket, msg)
else:
msg = "Cannot interrupt kernel. No kernel is running!"
raise RuntimeError(msg)
interrupt_kernel = run_sync(_async_interrupt_kernel)
async def _async_signal_kernel(self, signum: int) -> None:
"""Sends a signal to the process group of the kernel (this
usually includes the kernel and any subprocesses spawned by
the kernel).
Note that since only SIGTERM is supported on Windows, this function is
only useful on Unix systems.
"""
if self.has_kernel:
assert self.provisioner is not None
await self.provisioner.send_signal(signum)
else:
msg = "Cannot signal kernel. No kernel is running!"
raise RuntimeError(msg)
signal_kernel = run_sync(_async_signal_kernel)
async def _async_is_alive(self) -> bool:
"""Is the kernel process still running?"""
if not self.owns_kernel:
return True
if self.has_kernel:
assert self.provisioner is not None
ret = await self.provisioner.poll()
if ret is None:
return True
return False
is_alive = run_sync(_async_is_alive)
async def _async_wait(self, pollinterval: float = 0.1) -> None:
# Use busy loop at 100ms intervals, polling until the process is
# not alive. If we find the process is no longer alive, complete
# its cleanup via the blocking wait(). Callers are responsible for
# issuing calls to wait() using a timeout (see _kill_kernel()).
while await self._async_is_alive():
await asyncio.sleep(pollinterval)
class AsyncKernelManager(KernelManager):
"""An async kernel manager."""
# the class to create with our `client` method
client_class: DottedObjectName = DottedObjectName(
"jupyter_client.asynchronous.AsyncKernelClient"
)
client_factory: Type = Type(klass="jupyter_client.asynchronous.AsyncKernelClient")
# The PyZMQ Context to use for communication with the kernel.
context: Instance = Instance(zmq.asyncio.Context)
@default("context")
def _context_default(self) -> zmq.asyncio.Context:
self._created_context = True
return zmq.asyncio.Context()
def client( # type:ignore[override]
self, **kwargs: t.Any
) -> AsyncKernelClient:
"""Get a client for the manager."""
return super().client(**kwargs) # type:ignore[return-value]
_launch_kernel = KernelManager._async_launch_kernel # type:ignore[assignment]
start_kernel: t.Callable[..., t.Awaitable] = KernelManager._async_start_kernel # type:ignore[assignment]
pre_start_kernel: t.Callable[..., t.Awaitable] = KernelManager._async_pre_start_kernel # type:ignore[assignment]
post_start_kernel: t.Callable[..., t.Awaitable] = KernelManager._async_post_start_kernel # type:ignore[assignment]
request_shutdown: t.Callable[..., t.Awaitable] = KernelManager._async_request_shutdown # type:ignore[assignment]
finish_shutdown: t.Callable[..., t.Awaitable] = KernelManager._async_finish_shutdown # type:ignore[assignment]
cleanup_resources: t.Callable[..., t.Awaitable] = KernelManager._async_cleanup_resources # type:ignore[assignment]
shutdown_kernel: t.Callable[..., t.Awaitable] = KernelManager._async_shutdown_kernel # type:ignore[assignment]
restart_kernel: t.Callable[..., t.Awaitable] = KernelManager._async_restart_kernel # type:ignore[assignment]
_send_kernel_sigterm = KernelManager._async_send_kernel_sigterm # type:ignore[assignment]
_kill_kernel = KernelManager._async_kill_kernel # type:ignore[assignment]
interrupt_kernel: t.Callable[..., t.Awaitable] = KernelManager._async_interrupt_kernel # type:ignore[assignment]
signal_kernel: t.Callable[..., t.Awaitable] = KernelManager._async_signal_kernel # type:ignore[assignment]
is_alive: t.Callable[..., t.Awaitable] = KernelManager._async_is_alive # type:ignore[assignment]
KernelManagerABC.register(KernelManager)
def start_new_kernel(
startup_timeout: float = 60, kernel_name: str = "python", **kwargs: t.Any
) -> t.Tuple[KernelManager, BlockingKernelClient]:
"""Start a new kernel, and return its Manager and Client"""
km = KernelManager(kernel_name=kernel_name)
km.start_kernel(**kwargs)
kc = km.client()
kc.start_channels()
try:
kc.wait_for_ready(timeout=startup_timeout)
except RuntimeError:
kc.stop_channels()
km.shutdown_kernel()
raise
return km, kc
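# A minimal usage sketch (assumes a "python3" kernelspec is installed):
#
#     km, kc = start_new_kernel(kernel_name="python3")
#     try:
#         kc.execute_interactive("print('hello')")
#     finally:
#         kc.stop_channels()
#         km.shutdown_kernel()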
async def start_new_async_kernel(
startup_timeout: float = 60, kernel_name: str = "python", **kwargs: t.Any
) -> t.Tuple[AsyncKernelManager, AsyncKernelClient]:
"""Start a new kernel, and return its Manager and Client"""
km = AsyncKernelManager(kernel_name=kernel_name)
await km.start_kernel(**kwargs)
kc = km.client()
kc.start_channels()
try:
await kc.wait_for_ready(timeout=startup_timeout)
except RuntimeError:
kc.stop_channels()
await km.shutdown_kernel()
raise
return (km, kc)
@contextmanager
def run_kernel(**kwargs: t.Any) -> t.Iterator[KernelClient]:
"""Context manager to create a kernel in a subprocess.
The kernel is shut down when the context exits.
Returns
-------
kernel_client: connected KernelClient instance
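    A minimal sketch (the kernel name is illustrative)::

        with run_kernel(kernel_name="python3") as kc:
            kc.execute_interactive("1 + 1")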
"""
km, kc = start_new_kernel(**kwargs)
try:
yield kc
finally:
kc.stop_channels()
km.shutdown_kernel(now=True)

View File

@@ -0,0 +1,56 @@
"""Abstract base class for kernel managers."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import abc
from typing import Any
class KernelManagerABC(metaclass=abc.ABCMeta):
"""KernelManager ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.manager.KernelManager`
"""
    @property
    @abc.abstractmethod
def kernel(self) -> Any:
pass
# --------------------------------------------------------------------------
# Kernel management
# --------------------------------------------------------------------------
@abc.abstractmethod
def start_kernel(self, **kw: Any) -> None:
"""Start the kernel."""
pass
@abc.abstractmethod
def shutdown_kernel(self, now: bool = False, restart: bool = False) -> None:
"""Shut down the kernel."""
pass
@abc.abstractmethod
def restart_kernel(self, now: bool = False, **kw: Any) -> None:
"""Restart the kernel."""
pass
    @property
    @abc.abstractmethod
def has_kernel(self) -> bool:
pass
@abc.abstractmethod
def interrupt_kernel(self) -> None:
"""Interrupt the kernel."""
pass
@abc.abstractmethod
def signal_kernel(self, signum: int) -> None:
"""Send a signal to the kernel."""
pass
@abc.abstractmethod
def is_alive(self) -> bool:
"""Test whether the kernel is alive."""
pass

View File

@@ -0,0 +1,624 @@
"""A kernel manager for multiple kernels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import asyncio
import json
import os
import socket
import typing as t
import uuid
from functools import wraps
from pathlib import Path
import zmq
from traitlets import Any, Bool, Dict, DottedObjectName, Instance, Unicode, default, observe
from traitlets.config.configurable import LoggingConfigurable
from traitlets.utils.importstring import import_item
from .connect import KernelConnectionInfo
from .kernelspec import NATIVE_KERNEL_NAME, KernelSpecManager
from .manager import KernelManager
from .utils import ensure_async, run_sync, utcnow
class DuplicateKernelError(Exception):
pass
def kernel_method(f: t.Callable) -> t.Callable:
"""decorator for proxying MKM.method(kernel_id) to individual KMs by ID"""
@wraps(f)
def wrapped(
self: t.Any, kernel_id: str, *args: t.Any, **kwargs: t.Any
) -> t.Callable | t.Awaitable:
# get the kernel
km = self.get_kernel(kernel_id)
method = getattr(km, f.__name__)
# call the kernel's method
r = method(*args, **kwargs)
# last thing, call anything defined in the actual class method
# such as logging messages
f(self, kernel_id, *args, **kwargs)
# return the method result
return r
return wrapped
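# For example, ``mkm.request_shutdown(kernel_id)`` (decorated below) looks up
# the KernelManager registered under ``kernel_id``, forwards the call to that
# manager's ``request_shutdown()``, then runs the decorated body (e.g. for
# logging) and returns the manager's result.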
class MultiKernelManager(LoggingConfigurable):
"""A class for managing multiple kernels."""
default_kernel_name = Unicode(
NATIVE_KERNEL_NAME, help="The name of the default kernel to start"
).tag(config=True)
kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)
kernel_manager_class = DottedObjectName(
"jupyter_client.ioloop.IOLoopKernelManager",
help="""The kernel manager class. This is configurable to allow
subclassing of the KernelManager for customized behavior.
""",
).tag(config=True)
@observe("kernel_manager_class")
def _kernel_manager_class_changed(self, change: t.Any) -> None:
self.kernel_manager_factory = self._create_kernel_manager_factory()
kernel_manager_factory = Any(help="this is kernel_manager_class after import")
@default("kernel_manager_factory")
def _kernel_manager_factory_default(self) -> t.Callable:
return self._create_kernel_manager_factory()
def _create_kernel_manager_factory(self) -> t.Callable:
kernel_manager_ctor = import_item(self.kernel_manager_class)
def create_kernel_manager(*args: t.Any, **kwargs: t.Any) -> KernelManager:
if self.shared_context:
if self.context.closed:
# recreate context if closed
self.context = self._context_default()
kwargs.setdefault("context", self.context)
km = kernel_manager_ctor(*args, **kwargs)
return km
return create_kernel_manager
shared_context = Bool(
True,
help="Share a single zmq.Context to talk to all my kernels",
).tag(config=True)
context = Instance("zmq.Context")
_created_context = Bool(False)
_pending_kernels = Dict()
@property
def _starting_kernels(self) -> dict:
"""A shim for backwards compatibility."""
return self._pending_kernels
@default("context")
def _context_default(self) -> zmq.Context:
self._created_context = True
return zmq.Context()
connection_dir = Unicode("")
external_connection_dir = Unicode(None, allow_none=True)
_kernels = Dict()
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.kernel_id_to_connection_file: dict[str, Path] = {}
def __del__(self) -> None:
"""Handle garbage collection. Destroy context if applicable."""
if self._created_context and self.context and not self.context.closed:
if self.log:
self.log.debug("Destroying zmq context for %s", self)
self.context.destroy()
try:
super_del = super().__del__ # type:ignore[misc]
except AttributeError:
pass
else:
super_del()
def list_kernel_ids(self) -> list[str]:
"""Return a list of the kernel ids of the active kernels."""
if self.external_connection_dir is not None:
external_connection_dir = Path(self.external_connection_dir)
if external_connection_dir.is_dir():
connection_files = [p for p in external_connection_dir.iterdir() if p.is_file()]
# remove kernels (whose connection file has disappeared) from our list
k = list(self.kernel_id_to_connection_file.keys())
v = list(self.kernel_id_to_connection_file.values())
for connection_file in list(self.kernel_id_to_connection_file.values()):
if connection_file not in connection_files:
kernel_id = k[v.index(connection_file)]
del self.kernel_id_to_connection_file[kernel_id]
del self._kernels[kernel_id]
# add kernels (whose connection file appeared) to our list
for connection_file in connection_files:
if connection_file in self.kernel_id_to_connection_file.values():
continue
try:
connection_info: KernelConnectionInfo = json.loads(
connection_file.read_text()
)
except Exception: # noqa: S112
continue
self.log.debug("Loading connection file %s", connection_file)
if not ("kernel_name" in connection_info and "key" in connection_info):
continue
# it looks like a connection file
kernel_id = self.new_kernel_id()
self.kernel_id_to_connection_file[kernel_id] = connection_file
km = self.kernel_manager_factory(
parent=self,
log=self.log,
owns_kernel=False,
)
km.load_connection_info(connection_info)
km.last_activity = utcnow()
km.execution_state = "idle"
km.connections = 1
km.kernel_id = kernel_id
km.kernel_name = connection_info["kernel_name"]
km.ready.set_result(None)
self._kernels[kernel_id] = km
# Create a copy so we can iterate over kernels in operations
# that delete keys.
return list(self._kernels.keys())
def __len__(self) -> int:
"""Return the number of running kernels."""
return len(self.list_kernel_ids())
def __contains__(self, kernel_id: str) -> bool:
return kernel_id in self._kernels
def pre_start_kernel(
self, kernel_name: str | None, kwargs: t.Any
) -> tuple[KernelManager, str, str]:
# kwargs should be mutable, passing it as a dict argument.
kernel_id = kwargs.pop("kernel_id", self.new_kernel_id(**kwargs))
if kernel_id in self:
raise DuplicateKernelError("Kernel already exists: %s" % kernel_id)
if kernel_name is None:
kernel_name = self.default_kernel_name
# kernel_manager_factory is the constructor for the KernelManager
# subclass we are using. It can be configured as any Configurable,
# including things like its transport and ip.
constructor_kwargs = {}
if self.kernel_spec_manager:
constructor_kwargs["kernel_spec_manager"] = self.kernel_spec_manager
km = self.kernel_manager_factory(
connection_file=os.path.join(self.connection_dir, "kernel-%s.json" % kernel_id),
parent=self,
log=self.log,
kernel_name=kernel_name,
**constructor_kwargs,
)
return km, kernel_name, kernel_id
def update_env(self, *, kernel_id: str, env: t.Dict[str, str]) -> None:
"""
        Update the environment of the given kernel.
        Forward the update env request to the corresponding kernel.
        .. versionadded:: 8.5
"""
if kernel_id in self:
self._kernels[kernel_id].update_env(env=env)
async def _add_kernel_when_ready(
self, kernel_id: str, km: KernelManager, kernel_awaitable: t.Awaitable
) -> None:
try:
await kernel_awaitable
self._kernels[kernel_id] = km
self._pending_kernels.pop(kernel_id, None)
except Exception as e:
self.log.exception(e)
async def _remove_kernel_when_ready(
self, kernel_id: str, kernel_awaitable: t.Awaitable
) -> None:
try:
await kernel_awaitable
self.remove_kernel(kernel_id)
self._pending_kernels.pop(kernel_id, None)
except Exception as e:
self.log.exception(e)
def _using_pending_kernels(self) -> bool:
"""Returns a boolean; a clearer method for determining if
this multikernelmanager is using pending kernels or not
"""
return getattr(self, "use_pending_kernels", False)
async def _async_start_kernel(self, *, kernel_name: str | None = None, **kwargs: t.Any) -> str:
"""Start a new kernel.
The caller can pick a kernel_id by passing one in as a keyword arg,
otherwise one will be generated using new_kernel_id().
The kernel ID for the newly started kernel is returned.
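        A minimal sketch (the kernel name is illustrative)::

            mkm = MultiKernelManager()
            kid = mkm.start_kernel(kernel_name="python3")
            mkm.shutdown_kernel(kid)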
"""
km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs)
if not isinstance(km, KernelManager):
self.log.warning( # type:ignore[unreachable]
"Kernel manager class ({km_class}) is not an instance of 'KernelManager'!".format(
                    km_class=self.kernel_manager_class
)
)
kwargs["kernel_id"] = kernel_id # Make kernel_id available to manager and provisioner
starter = ensure_async(km.start_kernel(**kwargs))
task = asyncio.create_task(self._add_kernel_when_ready(kernel_id, km, starter))
self._pending_kernels[kernel_id] = task
# Handling a Pending Kernel
if self._using_pending_kernels():
# If using pending kernels, do not block
# on the kernel start.
self._kernels[kernel_id] = km
else:
await task
# raise an exception if one occurred during kernel startup.
if km.ready.exception():
raise km.ready.exception() # type: ignore[misc]
return kernel_id
start_kernel = run_sync(_async_start_kernel)
async def _async_shutdown_kernel(
self,
kernel_id: str,
now: bool | None = False,
restart: bool | None = False,
) -> None:
"""Shutdown a kernel by its kernel uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel to shutdown.
now : bool
            Should the kernel be shut down forcibly using a signal.
restart : bool
Will the kernel be restarted?
"""
self.log.info("Kernel shutdown: %s", kernel_id)
# If the kernel is still starting, wait for it to be ready.
if kernel_id in self._pending_kernels:
task = self._pending_kernels[kernel_id]
try:
await task
km = self.get_kernel(kernel_id)
await t.cast(asyncio.Future, km.ready)
except asyncio.CancelledError:
pass
except Exception:
self.remove_kernel(kernel_id)
return
km = self.get_kernel(kernel_id)
# If a pending kernel raised an exception, remove it.
if not km.ready.cancelled() and km.ready.exception():
self.remove_kernel(kernel_id)
return
stopper = ensure_async(km.shutdown_kernel(now, restart))
fut = asyncio.ensure_future(self._remove_kernel_when_ready(kernel_id, stopper))
self._pending_kernels[kernel_id] = fut
# Await the kernel if not using pending kernels.
if not self._using_pending_kernels():
await fut
# raise an exception if one occurred during kernel shutdown.
if km.ready.exception():
raise km.ready.exception() # type: ignore[misc]
shutdown_kernel = run_sync(_async_shutdown_kernel)
@kernel_method
def request_shutdown(self, kernel_id: str, restart: bool | None = False) -> None:
"""Ask a kernel to shut down by its kernel uuid"""
@kernel_method
def finish_shutdown(
self,
kernel_id: str,
waittime: float | None = None,
pollinterval: float | None = 0.1,
) -> None:
"""Wait for a kernel to finish shutting down, and kill it if it doesn't"""
self.log.info("Kernel shutdown: %s", kernel_id)
@kernel_method
def cleanup_resources(self, kernel_id: str, restart: bool = False) -> None:
"""Clean up a kernel's resources"""
    def remove_kernel(self, kernel_id: str) -> KernelManager | None:
        """Remove a kernel from our mapping.
Mainly so that a kernel can be removed if it is already dead,
without having to call shutdown_kernel.
The kernel object is returned, or `None` if not found.
"""
return self._kernels.pop(kernel_id, None)
async def _async_shutdown_all(self, now: bool = False) -> None:
"""Shutdown all kernels."""
kids = self.list_kernel_ids()
kids += list(self._pending_kernels)
kms = list(self._kernels.values())
futs = [self._async_shutdown_kernel(kid, now=now) for kid in set(kids)]
await asyncio.gather(*futs)
# If using pending kernels, the kernels will not have been fully shut down.
if self._using_pending_kernels():
for km in kms:
try:
await km.ready
except asyncio.CancelledError:
self._pending_kernels[km.kernel_id].cancel()
except Exception:
# Will have been logged in _add_kernel_when_ready
pass
shutdown_all = run_sync(_async_shutdown_all)
def interrupt_kernel(self, kernel_id: str) -> None:
"""Interrupt (SIGINT) the kernel by its uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel to interrupt.
"""
kernel = self.get_kernel(kernel_id)
if not kernel.ready.done():
msg = "Kernel is in a pending state. Cannot interrupt."
raise RuntimeError(msg)
out = kernel.interrupt_kernel()
self.log.info("Kernel interrupted: %s", kernel_id)
return out
@kernel_method
def signal_kernel(self, kernel_id: str, signum: int) -> None:
"""Sends a signal to the kernel by its uuid.
Note that since only SIGTERM is supported on Windows, this function
is only useful on Unix systems.
Parameters
==========
kernel_id : uuid
The id of the kernel to signal.
signum : int
Signal number to send kernel.
"""
self.log.info("Signaled Kernel %s with %s", kernel_id, signum)
async def _async_restart_kernel(self, kernel_id: str, now: bool = False) -> None:
"""Restart a kernel by its uuid, keeping the same ports.
Parameters
==========
kernel_id : uuid
            The id of the kernel to restart.
now : bool, optional
If True, the kernel is forcefully restarted *immediately*, without
having a chance to do any cleanup action. Otherwise the kernel is
given 1s to clean up before a forceful restart is issued.
In all cases the kernel is restarted, the only difference is whether
it is given a chance to perform a clean shutdown or not.
"""
kernel = self.get_kernel(kernel_id)
if self._using_pending_kernels() and not kernel.ready.done():
msg = "Kernel is in a pending state. Cannot restart."
raise RuntimeError(msg)
await ensure_async(kernel.restart_kernel(now=now))
self.log.info("Kernel restarted: %s", kernel_id)
restart_kernel = run_sync(_async_restart_kernel)
@kernel_method
def is_alive(self, kernel_id: str) -> bool: # type:ignore[empty-body]
"""Is the kernel alive.
This calls KernelManager.is_alive() which calls Popen.poll on the
actual kernel subprocess.
Parameters
==========
kernel_id : uuid
The id of the kernel.
"""
def _check_kernel_id(self, kernel_id: str) -> None:
"""check that a kernel id is valid"""
if kernel_id not in self:
raise KeyError("Kernel with id not found: %s" % kernel_id)
def get_kernel(self, kernel_id: str) -> KernelManager:
"""Get the single KernelManager object for a kernel by its uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel.
"""
self._check_kernel_id(kernel_id)
return self._kernels[kernel_id]
@kernel_method
def add_restart_callback(
self, kernel_id: str, callback: t.Callable, event: str = "restart"
) -> None:
"""add a callback for the KernelRestarter"""
@kernel_method
def remove_restart_callback(
self, kernel_id: str, callback: t.Callable, event: str = "restart"
) -> None:
"""remove a callback for the KernelRestarter"""
@kernel_method
def get_connection_info(self, kernel_id: str) -> dict[str, t.Any]: # type:ignore[empty-body]
"""Return a dictionary of connection data for a kernel.
Parameters
==========
kernel_id : uuid
The id of the kernel.
Returns
=======
connection_dict : dict
A dict of the information needed to connect to a kernel.
This includes the ip address and the integer port
numbers of the different channels (stdin_port, iopub_port,
shell_port, hb_port).
"""
@kernel_method
def connect_iopub( # type:ignore[empty-body]
self, kernel_id: str, identity: bytes | None = None
) -> socket.socket:
"""Return a zmq Socket connected to the iopub channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_shell( # type:ignore[empty-body]
self, kernel_id: str, identity: bytes | None = None
) -> socket.socket:
"""Return a zmq Socket connected to the shell channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_control( # type:ignore[empty-body]
self, kernel_id: str, identity: bytes | None = None
) -> socket.socket:
"""Return a zmq Socket connected to the control channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_stdin( # type:ignore[empty-body]
self, kernel_id: str, identity: bytes | None = None
) -> socket.socket:
"""Return a zmq Socket connected to the stdin channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_hb( # type:ignore[empty-body]
self, kernel_id: str, identity: bytes | None = None
) -> socket.socket:
"""Return a zmq Socket connected to the hb channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
def new_kernel_id(self, **kwargs: t.Any) -> str:
"""
Returns the id to associate with the kernel for this request. Subclasses may override
this method to substitute other sources of kernel ids.
:param kwargs:
:return: string-ized version 4 uuid
"""
return str(uuid.uuid4())
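# A minimal usage sketch (illustrative, not part of this module): driving the
# blocking MultiKernelManager API end to end. It assumes a "python3"
# kernelspec is installed; the variable names are placeholders.
#
#     from jupyter_client.multikernelmanager import MultiKernelManager
#
#     mkm = MultiKernelManager()
#     kid = mkm.start_kernel(kernel_name="python3")
#     assert kid in mkm and mkm.is_alive(kid)
#     print(mkm.get_connection_info(kid)["shell_port"])
#     mkm.shutdown_kernel(kid)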
class AsyncMultiKernelManager(MultiKernelManager):
kernel_manager_class = DottedObjectName(
"jupyter_client.ioloop.AsyncIOLoopKernelManager",
config=True,
help="""The kernel manager class. This is configurable to allow
subclassing of the AsyncKernelManager for customized behavior.
""",
)
use_pending_kernels = Bool(
False,
help="""Whether to make kernels available before the process has started. The
kernel has a `.ready` future which can be awaited before connecting""",
).tag(config=True)
context = Instance("zmq.asyncio.Context")
@default("context")
def _context_default(self) -> zmq.asyncio.Context:
self._created_context = True
return zmq.asyncio.Context()
start_kernel: t.Callable[..., t.Awaitable] = MultiKernelManager._async_start_kernel # type:ignore[assignment]
restart_kernel: t.Callable[..., t.Awaitable] = MultiKernelManager._async_restart_kernel # type:ignore[assignment]
shutdown_kernel: t.Callable[..., t.Awaitable] = MultiKernelManager._async_shutdown_kernel # type:ignore[assignment]
shutdown_all: t.Callable[..., t.Awaitable] = MultiKernelManager._async_shutdown_all # type:ignore[assignment]
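# A hedged sketch of the async variant (illustrative only): with
# ``use_pending_kernels`` enabled, ``start_kernel`` returns before the process
# is fully up and the kernel's ``ready`` future can be awaited later. Assumes
# a "python3" kernelspec is installed.
#
#     import asyncio
#     from jupyter_client.multikernelmanager import AsyncMultiKernelManager
#
#     async def main():
#         mkm = AsyncMultiKernelManager(use_pending_kernels=True)
#         kid = await mkm.start_kernel(kernel_name="python3")  # returns early
#         await mkm.get_kernel(kid).ready  # wait for the launch to complete
#         await mkm.shutdown_all(now=True)
#
#     asyncio.run(main())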

View File

@@ -0,0 +1,3 @@
from .factory import KernelProvisionerFactory # noqa
from .local_provisioner import LocalProvisioner # noqa
from .provisioner_base import KernelProvisionerBase # noqa

View File

@@ -0,0 +1,200 @@
"""Kernel Provisioner Classes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import glob
import sys
from os import getenv, path
from typing import Any, Dict, List
# See compatibility note on `group` keyword in https://docs.python.org/3/library/importlib.metadata.html#entry-points
if sys.version_info < (3, 10): # pragma: no cover
from importlib_metadata import EntryPoint, entry_points # type:ignore[import-not-found]
else: # pragma: no cover
from importlib.metadata import EntryPoint, entry_points
from traitlets.config import SingletonConfigurable, Unicode, default
from .provisioner_base import KernelProvisionerBase
class KernelProvisionerFactory(SingletonConfigurable):
"""
:class:`KernelProvisionerFactory` is responsible for creating provisioner instances.
A singleton instance, `KernelProvisionerFactory` is also used by the :class:`KernelSpecManager`
to validate `kernel_provisioner` references found in kernel specifications to confirm their
availability (in cases where the kernel specification references a kernel provisioner that has
not been installed into the current Python environment).
    Its ``default_provisioner_name`` attribute can be used to specify the default provisioner
    to use when a kernel_spec is found to not reference a provisioner. Its value defaults to
`"local-provisioner"` which identifies the local provisioner implemented by
:class:`LocalProvisioner`.
"""
GROUP_NAME = "jupyter_client.kernel_provisioners"
provisioners: Dict[str, EntryPoint] = {}
default_provisioner_name_env = "JUPYTER_DEFAULT_PROVISIONER_NAME"
default_provisioner_name = Unicode(
config=True,
help="""Indicates the name of the provisioner to use when no kernel_provisioner
entry is present in the kernelspec.""",
)
@default("default_provisioner_name")
def _default_provisioner_name_default(self) -> str:
"""The default provisioner name."""
return getenv(self.default_provisioner_name_env, "local-provisioner")
def __init__(self, **kwargs: Any) -> None:
"""Initialize a kernel provisioner factory."""
super().__init__(**kwargs)
for ep in KernelProvisionerFactory._get_all_provisioners():
self.provisioners[ep.name] = ep
def is_provisioner_available(self, kernel_spec: Any) -> bool:
"""
Reads the associated ``kernel_spec`` to determine the provisioner and returns whether it
exists as an entry_point (True) or not (False). If the referenced provisioner is not
in the current cache or cannot be loaded via entry_points, a warning message is issued
indicating it is not available.
"""
is_available: bool = True
provisioner_cfg = self._get_provisioner_config(kernel_spec)
provisioner_name = str(provisioner_cfg.get("provisioner_name"))
if not self._check_availability(provisioner_name):
is_available = False
self.log.warning(
f"Kernel '{kernel_spec.display_name}' is referencing a kernel "
f"provisioner ('{provisioner_name}') that is not available. "
f"Ensure the appropriate package has been installed and retry."
)
return is_available
def create_provisioner_instance(
self, kernel_id: str, kernel_spec: Any, parent: Any
) -> KernelProvisionerBase:
"""
Reads the associated ``kernel_spec`` to see if it has a `kernel_provisioner` stanza.
If one exists, it instantiates an instance. If a kernel provisioner is not
specified in the kernel specification, a default provisioner stanza is fabricated
and instantiated corresponding to the current value of ``default_provisioner_name`` trait.
The instantiated instance is returned.
If the provisioner is found to not exist (not registered via entry_points),
`ModuleNotFoundError` is raised.
"""
provisioner_cfg = self._get_provisioner_config(kernel_spec)
provisioner_name = str(provisioner_cfg.get("provisioner_name"))
if not self._check_availability(provisioner_name):
msg = f"Kernel provisioner '{provisioner_name}' has not been registered."
raise ModuleNotFoundError(msg)
self.log.debug(
f"Instantiating kernel '{kernel_spec.display_name}' with "
f"kernel provisioner: {provisioner_name}"
)
provisioner_class = self.provisioners[provisioner_name].load()
provisioner_config = provisioner_cfg.get("config")
provisioner: KernelProvisionerBase = provisioner_class(
kernel_id=kernel_id, kernel_spec=kernel_spec, parent=parent, **provisioner_config
)
return provisioner
def _check_availability(self, provisioner_name: str) -> bool:
"""
Checks that the given provisioner is available.
If the given provisioner is not in the current set of loaded provisioners an attempt
is made to fetch the named entry point and, if successful, loads it into the cache.
:param provisioner_name:
:return:
"""
is_available = True
if provisioner_name not in self.provisioners:
try:
ep = self._get_provisioner(provisioner_name)
self.provisioners[provisioner_name] = ep # Update cache
except Exception:
is_available = False
return is_available
def _get_provisioner_config(self, kernel_spec: Any) -> Dict[str, Any]:
"""
Return the kernel_provisioner stanza from the kernel_spec.
Checks the kernel_spec's metadata dictionary for a kernel_provisioner entry.
If found, it is returned, else one is created relative to the DEFAULT_PROVISIONER
and returned.
Parameters
----------
kernel_spec : Any - this is a KernelSpec type but listed as Any to avoid circular import
The kernel specification object from which the provisioner dictionary is derived.
Returns
-------
dict
The provisioner portion of the kernel_spec. If one does not exist, it will contain
the default information. If no `config` sub-dictionary exists, an empty `config`
dictionary will be added.
"""
env_provisioner = kernel_spec.metadata.get("kernel_provisioner", {})
if "provisioner_name" in env_provisioner: # If no provisioner_name, return default
if (
"config" not in env_provisioner
): # if provisioner_name, but no config stanza, add one
env_provisioner.update({"config": {}})
return env_provisioner # Return what we found (plus config stanza if necessary)
return {"provisioner_name": self.default_provisioner_name, "config": {}}
def get_provisioner_entries(self) -> Dict[str, str]:
"""
Returns a dictionary of provisioner entries.
The key is the provisioner name for its entry point. The value is the colon-separated
string of the entry point's module name and object name.
"""
entries = {}
for name, ep in self.provisioners.items():
entries[name] = ep.value
return entries
@staticmethod
def _get_all_provisioners() -> List[EntryPoint]:
"""Wrapper around entry_points (to fetch the set of provisioners) - primarily to facilitate testing."""
return entry_points(group=KernelProvisionerFactory.GROUP_NAME)
def _get_provisioner(self, name: str) -> EntryPoint:
"""Wrapper around entry_points (to fetch a single provisioner) - primarily to facilitate testing."""
eps = entry_points(group=KernelProvisionerFactory.GROUP_NAME, name=name)
if eps:
return eps[0]
# Check if the entrypoint name is 'local-provisioner'. Although this should never
# happen, we have seen cases where the previous distribution of jupyter_client has
# remained which doesn't include kernel-provisioner entrypoints (so 'local-provisioner'
# is deemed not found even though its definition is in THIS package). In such cases,
# the entrypoints package uses what it first finds - which is the older distribution
# resulting in a violation of a supposed invariant condition. To address this scenario,
# we will log a warning message indicating this situation, then build the entrypoint
# instance ourselves - since we have that information.
if name == "local-provisioner":
distros = glob.glob(f"{path.dirname(path.dirname(__file__))}-*")
self.log.warning(
f"Kernel Provisioning: The 'local-provisioner' is not found. This is likely "
f"due to the presence of multiple jupyter_client distributions and a previous "
f"distribution is being used as the source for entrypoints - which does not "
f"include 'local-provisioner'. That distribution should be removed such that "
f"only the version-appropriate distribution remains (version >= 7). Until "
f"then, a 'local-provisioner' entrypoint will be automatically constructed "
f"and used.\nThe candidate distribution locations are: {distros}"
)
return EntryPoint(
"local-provisioner", "jupyter_client.provisioning", "LocalProvisioner"
)
        msg = f"Kernel provisioner '{name}' has not been registered."
        raise ModuleNotFoundError(msg)
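# For illustration (assumed packaging metadata, not part of this file): a
# third-party provisioner becomes discoverable by registering it under the
# "jupyter_client.kernel_provisioners" entry-point group, e.g. in a package's
# pyproject.toml:
#
#     [project.entry-points."jupyter_client.kernel_provisioners"]
#     my-provisioner = "my_package.my_module:MyProvisioner"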

View File

@@ -0,0 +1,242 @@
"""Kernel Provisioner Classes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import os
import signal
import sys
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ..connect import KernelConnectionInfo, LocalPortCache
from ..launcher import launch_kernel
from ..localinterfaces import is_local_ip, local_ips
from .provisioner_base import KernelProvisionerBase
class LocalProvisioner(KernelProvisionerBase): # type:ignore[misc]
"""
:class:`LocalProvisioner` is a concrete class of ABC :py:class:`KernelProvisionerBase`
and is the out-of-box default implementation used when no kernel provisioner is
specified in the kernel specification (``kernel.json``). It provides functional
parity to existing applications by launching the kernel locally and using
:class:`subprocess.Popen` to manage its lifecycle.
This class is intended to be subclassed for customizing local kernel environments
and serve as a reference implementation for other custom provisioners.
"""
process = None
_exit_future = None
pid = None
pgid = None
ip = None
ports_cached = False
@property
def has_process(self) -> bool:
return self.process is not None
async def poll(self) -> Optional[int]:
"""Poll the provisioner."""
ret = 0
if self.process:
ret = self.process.poll() # type:ignore[unreachable]
return ret
async def wait(self) -> Optional[int]:
"""Wait for the provisioner process."""
ret = 0
if self.process:
# Use busy loop at 100ms intervals, polling until the process is
# not alive. If we find the process is no longer alive, complete
# its cleanup via the blocking wait(). Callers are responsible for
# issuing calls to wait() using a timeout (see kill()).
while await self.poll() is None: # type:ignore[unreachable]
await asyncio.sleep(0.1)
# Process is no longer alive, wait and clear
ret = self.process.wait()
# Make sure all the fds get closed.
for attr in ["stdout", "stderr", "stdin"]:
fid = getattr(self.process, attr)
if fid:
fid.close()
self.process = None # allow has_process to now return False
return ret
async def send_signal(self, signum: int) -> None:
"""Sends a signal to the process group of the kernel (this
usually includes the kernel and any subprocesses spawned by
the kernel).
Note that since only SIGTERM is supported on Windows, we will
check if the desired signal is for interrupt and apply the
applicable code on Windows in that case.
"""
if self.process:
if signum == signal.SIGINT and sys.platform == "win32": # type:ignore[unreachable]
from ..win_interrupt import send_interrupt
send_interrupt(self.process.win32_interrupt_event)
return
# Prefer process-group over process
if self.pgid and hasattr(os, "killpg"):
try:
os.killpg(self.pgid, signum)
return
except OSError:
pass # We'll retry sending the signal to only the process below
# If we're here, send the signal to the process and let caller handle exceptions
self.process.send_signal(signum)
return
async def kill(self, restart: bool = False) -> None:
"""Kill the provisioner and optionally restart."""
if self.process:
if hasattr(signal, "SIGKILL"): # type:ignore[unreachable]
# If available, give preference to signalling the process-group over `kill()`.
try:
await self.send_signal(signal.SIGKILL)
return
except OSError:
pass
try:
self.process.kill()
except OSError as e:
LocalProvisioner._tolerate_no_process(e)
async def terminate(self, restart: bool = False) -> None:
"""Terminate the provisioner and optionally restart."""
if self.process:
if hasattr(signal, "SIGTERM"): # type:ignore[unreachable]
# If available, give preference to signalling the process group over `terminate()`.
try:
await self.send_signal(signal.SIGTERM)
return
except OSError:
pass
try:
self.process.terminate()
except OSError as e:
LocalProvisioner._tolerate_no_process(e)
@staticmethod
def _tolerate_no_process(os_error: OSError) -> None:
# In Windows, we will get an Access Denied error if the process
# has already terminated. Ignore it.
if sys.platform == "win32":
if os_error.winerror != 5:
raise
# On Unix, we may get an ESRCH error (or ProcessLookupError instance) if
# the process has already terminated. Ignore it.
else:
from errno import ESRCH
if not isinstance(os_error, ProcessLookupError) or os_error.errno != ESRCH:
raise
async def cleanup(self, restart: bool = False) -> None:
"""Clean up the resources used by the provisioner and optionally restart."""
if self.ports_cached and not restart:
# provisioner is about to be destroyed, return cached ports
lpc = LocalPortCache.instance()
ports = (
self.connection_info["shell_port"],
self.connection_info["iopub_port"],
self.connection_info["stdin_port"],
self.connection_info["hb_port"],
self.connection_info["control_port"],
)
for port in ports:
if TYPE_CHECKING:
assert isinstance(port, int)
lpc.return_port(port)
async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
"""Perform any steps in preparation for kernel process launch.
This includes applying additional substitutions to the kernel launch command and env.
It also includes preparation of launch parameters.
Returns the updated kwargs.
"""
# This should be considered temporary until a better division of labor can be defined.
km = self.parent
if km:
if km.transport == "tcp" and not is_local_ip(km.ip):
msg = (
"Can only launch a kernel on a local interface. "
f"This one is not: {km.ip}."
"Make sure that the '*_address' attributes are "
"configured properly. "
f"Currently valid addresses are: {local_ips()}"
)
raise RuntimeError(msg)
# build the Popen cmd
extra_arguments = kwargs.pop("extra_arguments", [])
# write connection file / get default ports
# TODO - change when handshake pattern is adopted
if km.cache_ports and not self.ports_cached:
lpc = LocalPortCache.instance()
km.shell_port = lpc.find_available_port(km.ip)
km.iopub_port = lpc.find_available_port(km.ip)
km.stdin_port = lpc.find_available_port(km.ip)
km.hb_port = lpc.find_available_port(km.ip)
km.control_port = lpc.find_available_port(km.ip)
self.ports_cached = True
if "env" in kwargs:
jupyter_session = kwargs["env"].get("JPY_SESSION_NAME", "")
km.write_connection_file(jupyter_session=jupyter_session)
else:
km.write_connection_file()
self.connection_info = km.get_connection_info()
kernel_cmd = km.format_kernel_cmd(
extra_arguments=extra_arguments
) # This needs to remain here for b/c
else:
extra_arguments = kwargs.pop("extra_arguments", [])
kernel_cmd = self.kernel_spec.argv + extra_arguments
return await super().pre_launch(cmd=kernel_cmd, **kwargs)
async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
"""Launch a kernel with a command."""
scrubbed_kwargs = LocalProvisioner._scrub_kwargs(kwargs)
self.process = launch_kernel(cmd, **scrubbed_kwargs)
pgid = None
if hasattr(os, "getpgid"):
try:
pgid = os.getpgid(self.process.pid)
except OSError:
pass
self.pid = self.process.pid
self.pgid = pgid
return self.connection_info
@staticmethod
def _scrub_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Remove any keyword arguments that Popen does not tolerate."""
keywords_to_scrub: List[str] = ["extra_arguments", "kernel_id"]
scrubbed_kwargs = kwargs.copy()
for kw in keywords_to_scrub:
scrubbed_kwargs.pop(kw, None)
return scrubbed_kwargs
async def get_provisioner_info(self) -> Dict:
"""Captures the base information necessary for persistence relative to this instance."""
provisioner_info = await super().get_provisioner_info()
provisioner_info.update({"pid": self.pid, "pgid": self.pgid, "ip": self.ip})
return provisioner_info
async def load_provisioner_info(self, provisioner_info: Dict) -> None:
"""Loads the base information necessary for persistence relative to this instance."""
await super().load_provisioner_info(provisioner_info)
self.pid = provisioner_info["pid"]
self.pgid = provisioner_info["pgid"]
self.ip = provisioner_info["ip"]
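# A hedged subclassing sketch (illustrative, not shipped with this package):
# LocalProvisioner is intended to be extended, e.g. to adjust the kernel's
# environment via _finalize_env. MY_FLAG is a hypothetical variable.
#
#     from typing import Dict
#
#     class EnvInjectingProvisioner(LocalProvisioner):
#         def _finalize_env(self, env: Dict[str, str]) -> None:
#             super()._finalize_env(env)  # keep base cleanup (PYTHONEXECUTABLE)
#             env["MY_FLAG"] = "1"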

View File

@@ -0,0 +1,257 @@
"""Kernel Provisioner Classes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
from abc import ABC, ABCMeta, abstractmethod
from typing import Any, Dict, List, Optional, Union
from traitlets.config import Instance, LoggingConfigurable, Unicode
from ..connect import KernelConnectionInfo
class KernelProvisionerMeta(ABCMeta, type(LoggingConfigurable)): # type: ignore[misc]
pass
class KernelProvisionerBase( # type:ignore[misc]
ABC, LoggingConfigurable, metaclass=KernelProvisionerMeta
):
"""
Abstract base class defining methods for KernelProvisioner classes.
A majority of methods are abstract (requiring implementations via a subclass) while
some are optional and others provide implementations common to all instances.
Subclasses should be aware of which methods require a call to the superclass.
Many of these methods model those of :class:`subprocess.Popen` for parity with
previous versions where the kernel process was managed directly.
"""
# The kernel specification associated with this provisioner
kernel_spec: Any = Instance("jupyter_client.kernelspec.KernelSpec", allow_none=True)
kernel_id: Union[str, Unicode] = Unicode(None, allow_none=True)
connection_info: KernelConnectionInfo = {}
@property
@abstractmethod
def has_process(self) -> bool:
"""
Returns true if this provisioner is currently managing a process.
This property is asserted to be True immediately following a call to
the provisioner's :meth:`launch_kernel` method.
"""
pass
@abstractmethod
async def poll(self) -> Optional[int]:
"""
Checks if kernel process is still running.
If running, None is returned, otherwise the process's integer-valued exit code is returned.
This method is called from :meth:`KernelManager.is_alive`.
"""
pass
@abstractmethod
async def wait(self) -> Optional[int]:
"""
Waits for kernel process to terminate.
This method is called from `KernelManager.finish_shutdown()` and
`KernelManager.kill_kernel()` when terminating a kernel gracefully or
immediately, respectively.
"""
pass
@abstractmethod
async def send_signal(self, signum: int) -> None:
"""
Sends signal identified by signum to the kernel process.
This method is called from `KernelManager.signal_kernel()` to send the
kernel process a signal.
"""
pass
@abstractmethod
async def kill(self, restart: bool = False) -> None:
"""
Kill the kernel process.
This is typically accomplished via a SIGKILL signal, which cannot be caught.
This method is called from `KernelManager.kill_kernel()` when terminating
a kernel immediately.
restart is True if this operation will precede a subsequent launch_kernel request.
"""
pass
@abstractmethod
async def terminate(self, restart: bool = False) -> None:
"""
Terminates the kernel process.
This is typically accomplished via a SIGTERM signal, which can be caught, allowing
the kernel provisioner to perform possible cleanup of resources. This method is
called indirectly from `KernelManager.finish_shutdown()` during a kernel's
graceful termination.
        restart is True if this operation will precede a subsequent launch_kernel request.
"""
pass
@abstractmethod
async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
"""
Launch the kernel process and return its connection information.
This method is called from `KernelManager.launch_kernel()` during the
kernel manager's start kernel sequence.
"""
pass
@abstractmethod
async def cleanup(self, restart: bool = False) -> None:
"""
Cleanup any resources allocated on behalf of the kernel provisioner.
This method is called from `KernelManager.cleanup_resources()` as part of
its shutdown kernel sequence.
        restart is True if this operation will precede a subsequent launch_kernel request.
"""
pass
async def shutdown_requested(self, restart: bool = False) -> None:
"""
Allows the provisioner to determine if the kernel's shutdown has been requested.
This method is called from `KernelManager.request_shutdown()` as part of
its shutdown sequence.
This method is optional and is primarily used in scenarios where the provisioner
may need to perform other operations in preparation for a kernel's shutdown.
"""
pass
async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
"""
Perform any steps in preparation for kernel process launch.
This includes applying additional substitutions to the kernel launch command
and environment. It also includes preparation of launch parameters.
NOTE: Subclass implementations are advised to call this method as it applies
environment variable substitutions from the local environment and calls the
provisioner's :meth:`_finalize_env()` method to allow each provisioner the
ability to cleanup the environment variables that will be used by the kernel.
This method is called from `KernelManager.pre_start_kernel()` as part of its
start kernel sequence.
Returns the (potentially updated) keyword arguments that are passed to
:meth:`launch_kernel()`.
"""
env = kwargs.pop("env", os.environ).copy()
env.update(self.__apply_env_substitutions(env))
self._finalize_env(env)
kwargs["env"] = env
return kwargs
async def post_launch(self, **kwargs: Any) -> None:
"""
Perform any steps following the kernel process launch.
This method is called from `KernelManager.post_start_kernel()` as part of its
start kernel sequence.
"""
pass
async def get_provisioner_info(self) -> Dict[str, Any]:
"""
Captures the base information necessary for persistence relative to this instance.
This enables applications that subclass `KernelManager` to persist a kernel provisioner's
relevant information to accomplish functionality like disaster recovery or high availability
by calling this method via the kernel manager's `provisioner` attribute.
NOTE: The superclass method must always be called first to ensure proper serialization.
"""
provisioner_info: Dict[str, Any] = {}
provisioner_info["kernel_id"] = self.kernel_id
provisioner_info["connection_info"] = self.connection_info
return provisioner_info
async def load_provisioner_info(self, provisioner_info: Dict) -> None:
"""
Loads the base information necessary for persistence relative to this instance.
The inverse of `get_provisioner_info()`, this enables applications that subclass
`KernelManager` to re-establish communication with a provisioner that is managing
a (presumably) remote kernel from an entirely different process that the original
provisioner.
NOTE: The superclass method must always be called first to ensure proper deserialization.
"""
self.kernel_id = provisioner_info["kernel_id"]
self.connection_info = provisioner_info["connection_info"]
def get_shutdown_wait_time(self, recommended: float = 5.0) -> float:
"""
Returns the time allowed for a complete shutdown. This may vary by provisioner.
This method is called from `KernelManager.finish_shutdown()` during the graceful
phase of its kernel shutdown sequence.
The recommended value will typically be what is configured in the kernel manager.
"""
return recommended
def get_stable_start_time(self, recommended: float = 10.0) -> float:
"""
Returns the expected upper bound for a kernel (re-)start to complete.
This may vary by provisioner.
The recommended value will typically be what is configured in the kernel restarter.
"""
return recommended
def _finalize_env(self, env: Dict[str, str]) -> None:
"""
Ensures env is appropriate prior to launch.
This method is called from `KernelProvisionerBase.pre_launch()` during the kernel's
start sequence.
NOTE: Subclasses should be sure to call super()._finalize_env(env)
"""
if self.kernel_spec.language and self.kernel_spec.language.lower().startswith("python"):
# Don't allow PYTHONEXECUTABLE to be passed to kernel process.
# If set, it can bork all the things.
env.pop("PYTHONEXECUTABLE", None)
def __apply_env_substitutions(self, substitution_values: Dict[str, str]) -> Dict[str, str]:
"""
Walks entries in the kernelspec's env stanza and applies substitutions from current env.
This method is called from `KernelProvisionerBase.pre_launch()` during the kernel's
start sequence.
Returns the substituted list of env entries.
NOTE: This method is private and is not intended to be overridden by provisioners.
"""
substituted_env = {}
if self.kernel_spec:
from string import Template
# For each templated env entry, fill any templated references
# matching names of env variables with those values and build
# new dict with substitutions.
templated_env = self.kernel_spec.env
for k, v in templated_env.items():
substituted_env.update({k: Template(v).safe_substitute(substitution_values)})
return substituted_env
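# Illustration of the substitution semantics above (not part of the class):
# templated values in the kernelspec's env stanza are filled from the launch
# environment via string.Template.safe_substitute, so unknown names survive.
#
#     from string import Template
#
#     Template("${HOME}/kernels").safe_substitute({"HOME": "/home/me"})
#     # -> "/home/me/kernels"
#     Template("${UNSET}/x").safe_substitute({})
#     # -> "${UNSET}/x" (left intact rather than raising)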

View File

@@ -0,0 +1,162 @@
"""A basic kernel monitor with autorestarting.
This watches a kernel's state using KernelManager.is_alive and auto
restarts the kernel if it dies.
It is an incomplete base class, and must be subclassed.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import time
import typing as t
from traitlets import Bool, Dict, Float, Instance, Integer, default
from traitlets.config.configurable import LoggingConfigurable
class KernelRestarter(LoggingConfigurable):
"""Monitor and autorestart a kernel."""
kernel_manager = Instance("jupyter_client.KernelManager")
debug = Bool(
False,
config=True,
help="""Whether to include every poll event in debugging output.
Has to be set explicitly, because there will be *a lot* of output.
""",
)
time_to_dead = Float(3.0, config=True, help="""Kernel heartbeat interval in seconds.""")
stable_start_time = Float(
10.0,
config=True,
help="""The time in seconds to consider the kernel to have completed a stable start up.""",
)
restart_limit = Integer(
5,
config=True,
help="""The number of consecutive autorestarts before the kernel is presumed dead.""",
)
random_ports_until_alive = Bool(
True,
config=True,
help="""Whether to choose new random ports when restarting before the kernel is alive.""",
)
_restarting = Bool(False)
_restart_count = Integer(0)
_initial_startup = Bool(True)
_last_dead = Float()
@default("_last_dead")
def _default_last_dead(self) -> float:
return time.time()
callbacks = Dict()
def _callbacks_default(self) -> dict[str, list]:
return {"restart": [], "dead": []}
def start(self) -> None:
"""Start the polling of the kernel."""
msg = "Must be implemented in a subclass"
raise NotImplementedError(msg)
def stop(self) -> None:
"""Stop the kernel polling."""
msg = "Must be implemented in a subclass"
raise NotImplementedError(msg)
def add_callback(self, f: t.Callable[..., t.Any], event: str = "restart") -> None:
"""register a callback to fire on a particular event
Possible values for event:
'restart' (default): kernel has died, and will be restarted.
'dead': restart has failed, kernel will be left dead.
"""
self.callbacks[event].append(f)
def remove_callback(self, f: t.Callable[..., t.Any], event: str = "restart") -> None:
"""unregister a callback to fire on a particular event
Possible values for event:
'restart' (default): kernel has died, and will be restarted.
'dead': restart has failed, kernel will be left dead.
"""
try:
self.callbacks[event].remove(f)
except ValueError:
pass
def _fire_callbacks(self, event: t.Any) -> None:
"""fire our callbacks for a particular event"""
for callback in self.callbacks[event]:
try:
callback()
except Exception:
self.log.error(
"KernelRestarter: %s callback %r failed",
event,
callback,
exc_info=True,
)
def poll(self) -> None:
if self.debug:
self.log.debug("Polling kernel...")
if self.kernel_manager.shutting_down:
self.log.debug("Kernel shutdown in progress...")
return
now = time.time()
if not self.kernel_manager.is_alive():
self._last_dead = now
if self._restarting:
self._restart_count += 1
else:
self._restart_count = 1
if self._restart_count > self.restart_limit:
self.log.warning("KernelRestarter: restart failed")
self._fire_callbacks("dead")
self._restarting = False
self._restart_count = 0
self.stop()
else:
newports = self.random_ports_until_alive and self._initial_startup
self.log.info(
"KernelRestarter: restarting kernel (%i/%i), %s random ports",
self._restart_count,
self.restart_limit,
"new" if newports else "keep",
)
self._fire_callbacks("restart")
self.kernel_manager.restart_kernel(now=True, newports=newports)
self._restarting = True
else:
# Since `is_alive` only tests that the kernel process is alive, it does not
# indicate that the kernel has successfully completed startup. To solve this
# correctly, we would need to wait for a kernel info reply, but it is not
# necessarily appropriate to start a kernel client + channels in the
# restarter. Therefore, we use "has been alive continuously for X time" as a
# heuristic for a stable start up.
# See https://github.com/jupyter/jupyter_client/pull/717 for details.
stable_start_time = self.stable_start_time
if self.kernel_manager.provisioner:
stable_start_time = self.kernel_manager.provisioner.get_stable_start_time(
recommended=stable_start_time
)
if self._initial_startup and now - self._last_dead >= stable_start_time:
self._initial_startup = False
if self._restarting and now - self._last_dead >= stable_start_time:
self.log.debug("KernelRestarter: restart apparently succeeded")
self._restarting = False
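# A small usage sketch (illustrative): a concrete restarter instance, such as
# the one a KernelManager creates for its kernel, is assumed to exist as
# ``restarter``; callbacks let applications react to restart outcomes.
#
#     def on_restart():
#         print("kernel died and is being restarted")
#
#     def on_dead():
#         print("restart limit reached; kernel left dead")
#
#     restarter.add_callback(on_restart, event="restart")
#     restarter.add_callback(on_dead, event="dead")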

View File

@@ -0,0 +1,128 @@
"""A Jupyter console app to run files."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import queue
import signal
import sys
import time
import typing as t
from jupyter_core.application import JupyterApp, base_aliases, base_flags
from traitlets import Any, Dict, Float
from traitlets.config import catch_config_error
from . import __version__
from .consoleapp import JupyterConsoleApp, app_aliases, app_flags
OUTPUT_TIMEOUT = 10
# copy flags from mixin:
flags = dict(base_flags)
# start with mixin frontend flags:
frontend_flags_dict = dict(app_flags)
# update full dict with frontend flags:
flags.update(frontend_flags_dict)
# copy flags from mixin
aliases = dict(base_aliases)
# start with mixin frontend flags
frontend_aliases_dict = dict(app_aliases)
# load updated frontend flags into full dict
aliases.update(frontend_aliases_dict)
# get flags&aliases into sets, and remove a couple that
# shouldn't be scrubbed from backend flags:
frontend_aliases = set(frontend_aliases_dict.keys())
frontend_flags = set(frontend_flags_dict.keys())
class RunApp(JupyterApp, JupyterConsoleApp): # type:ignore[misc]
"""An Jupyter Console app to run files."""
version = __version__
name = "jupyter run"
description = """Run Jupyter kernel code."""
flags = Dict(flags) # type:ignore[assignment]
aliases = Dict(aliases) # type:ignore[assignment]
frontend_aliases = Any(frontend_aliases)
frontend_flags = Any(frontend_flags)
kernel_timeout = Float(
60,
config=True,
help="""Timeout for giving up on a kernel (in seconds).
On first connect and restart, the console tests whether the
kernel is running and responsive by sending kernel_info_requests.
This sets the timeout in seconds for how long the kernel can take
before being presumed dead.
""",
)
def parse_command_line(self, argv: list[str] | None = None) -> None:
"""Parse the command line arguments."""
super().parse_command_line(argv)
self.build_kernel_argv(self.extra_args)
self.filenames_to_run = self.extra_args[:]
@catch_config_error
def initialize(self, argv: list[str] | None = None) -> None: # type:ignore[override]
"""Initialize the app."""
self.log.debug("jupyter run: initialize...")
super().initialize(argv)
JupyterConsoleApp.initialize(self)
signal.signal(signal.SIGINT, self.handle_sigint)
self.init_kernel_info()
def handle_sigint(self, *args: t.Any) -> None:
"""Handle SIGINT."""
if self.kernel_manager:
self.kernel_manager.interrupt_kernel()
else:
self.log.error("Cannot interrupt kernels we didn't start.\n")
def init_kernel_info(self) -> None:
"""Wait for a kernel to be ready, and store kernel info"""
timeout = self.kernel_timeout
tic = time.time()
self.kernel_client.hb_channel.unpause()
msg_id = self.kernel_client.kernel_info()
while True:
try:
reply = self.kernel_client.get_shell_msg(timeout=1)
except queue.Empty as e:
if (time.time() - tic) > timeout:
msg = "Kernel didn't respond to kernel_info_request"
raise RuntimeError(msg) from e
else:
if reply["parent_header"].get("msg_id") == msg_id:
self.kernel_info = reply["content"]
return
def start(self) -> None:
"""Start the application."""
self.log.debug("jupyter run: starting...")
super().start()
if self.filenames_to_run:
for filename in self.filenames_to_run:
self.log.debug("jupyter run: executing `%s`", filename)
with open(filename) as fp:
code = fp.read()
reply = self.kernel_client.execute_interactive(code, timeout=OUTPUT_TIMEOUT)
return_code = 0 if reply["content"]["status"] == "ok" else 1
if return_code:
raise Exception("jupyter-run error running '%s'" % filename)
else:
code = sys.stdin.read()
reply = self.kernel_client.execute_interactive(code, timeout=OUTPUT_TIMEOUT)
return_code = 0 if reply["content"]["status"] == "ok" else 1
if return_code:
msg = "jupyter-run error running 'stdin'"
raise Exception(msg)
main = launch_new_instance = RunApp.launch_instance
if __name__ == "__main__":
main()
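# Command-line usage (for illustration; file names are placeholders):
#
#     jupyter run script.py            # execute one file on a fresh kernel
#     jupyter run a.py b.py            # execute several files in order
#     echo "1 + 1" | jupyter run       # with no files, code is read from stdin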

File diff suppressed because it is too large

View File

@@ -0,0 +1 @@
from .tunnel import * # noqa

View File

@@ -0,0 +1,102 @@
"""Sample script showing how to do local port forwarding over paramiko.
This script connects to the requested SSH server and sets up local port
forwarding (the openssh -L option) from a local port through a tunneled
connection to a destination reachable from the SSH server machine.
"""
#
# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
# Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
# Edits Copyright (C) 2010 The IPython Team
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA.
import logging
import select
import socketserver
import typing as t
logger = logging.getLogger("ssh")
class ForwardServer(socketserver.ThreadingTCPServer):
"""A server to use for ssh forwarding."""
daemon_threads = True
allow_reuse_address = True
class Handler(socketserver.BaseRequestHandler):
"""A handle for server requests."""
@t.no_type_check
def handle(self):
"""Handle a request."""
try:
chan = self.ssh_transport.open_channel(
"direct-tcpip",
(self.chain_host, self.chain_port),
self.request.getpeername(),
)
except Exception as e:
logger.debug(
"Incoming request to %s:%d failed: %s" % (self.chain_host, self.chain_port, repr(e))
)
return
if chan is None:
logger.debug(
"Incoming request to %s:%d was rejected by the SSH server."
% (self.chain_host, self.chain_port)
)
return
logger.debug(
"Connected! Tunnel open {!r} -> {!r} -> {!r}".format(
self.request.getpeername(),
chan.getpeername(),
(self.chain_host, self.chain_port),
)
)
while True:
r, w, x = select.select([self.request, chan], [], [])
if self.request in r:
data = self.request.recv(1024)
if len(data) == 0:
break
chan.send(data)
if chan in r:
data = chan.recv(1024)
if len(data) == 0:
break
self.request.send(data)
chan.close()
self.request.close()
logger.debug("Tunnel closed ")
def forward_tunnel(local_port: int, remote_host: str, remote_port: int, transport: t.Any) -> None:
"""Forward an ssh tunnel."""
# this is a little convoluted, but lets me configure things for the Handler
# object. (SocketServer doesn't give Handlers any way to access the outer
# server normally.)
    class SubHandler(Handler):
        chain_host = remote_host
        chain_port = remote_port
        ssh_transport = transport
    ForwardServer(("127.0.0.1", local_port), SubHandler).serve_forever()
__all__ = ["forward_tunnel"]

View File

@@ -0,0 +1,446 @@
"""Basic ssh tunnel utilities, and convenience functions for tunneling
zeromq connections.
"""
# Copyright (C) 2010-2011 IPython Development Team
# Copyright (C) 2011- PyZMQ Developers
#
# Redistributed from IPython under the terms of the BSD License.
from __future__ import annotations
import atexit
import os
import re
import signal
import socket
import sys
import warnings
from getpass import getpass, getuser
from multiprocessing import Process
from typing import Any, cast
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
import paramiko
SSHException = paramiko.ssh_exception.SSHException
except ImportError:
paramiko = None # type:ignore[assignment]
class SSHException(Exception): # type:ignore[no-redef] # noqa
pass
else:
from .forward import forward_tunnel
try:
import pexpect # type: ignore[import-untyped]
except ImportError:
pexpect = None
def select_random_ports(n: int) -> list[int]:
"""Select and return n random ports that are available."""
ports = []
sockets = []
for _ in range(n):
sock = socket.socket()
sock.bind(("", 0))
ports.append(sock.getsockname()[1])
sockets.append(sock)
for sock in sockets:
sock.close()
return ports
# -----------------------------------------------------------------------------
# Check for passwordless login
# -----------------------------------------------------------------------------
_password_pat = re.compile(rb"pass(word|phrase):", re.IGNORECASE)
def try_passwordless_ssh(server: str, keyfile: str | None, paramiko: Any = None) -> Any:
"""Attempt to make an ssh connection without a password.
This is mainly used for requiring password input only once
when many tunnels may be connected to the same server.
If paramiko is None, the default for the platform is chosen.
"""
if paramiko is None:
paramiko = sys.platform == "win32"
f = _try_passwordless_paramiko if paramiko else _try_passwordless_openssh
return f(server, keyfile)
def _try_passwordless_openssh(server: str, keyfile: str | None) -> bool:
"""Try passwordless login with shell ssh command."""
if pexpect is None:
msg = "pexpect unavailable, use paramiko"
raise ImportError(msg)
cmd = "ssh -f " + server
if keyfile:
cmd += " -i " + keyfile
cmd += " exit"
# pop SSH_ASKPASS from env
env = os.environ.copy()
env.pop("SSH_ASKPASS", None)
ssh_newkey = "Are you sure you want to continue connecting"
p = pexpect.spawn(cmd, env=env)
while True:
try:
i = p.expect([ssh_newkey, _password_pat], timeout=0.1)
if i == 0:
msg = "The authenticity of the host can't be established."
raise SSHException(msg)
except pexpect.TIMEOUT:
continue
except pexpect.EOF:
return True
else:
return False
def _try_passwordless_paramiko(server: str, keyfile: str | None) -> bool:
"""Try passwordless login with paramiko."""
if paramiko is None:
msg = "Paramiko unavailable, " # type:ignore[unreachable]
if sys.platform == "win32":
msg += "Paramiko is required for ssh tunneled connections on Windows."
else:
msg += "use OpenSSH."
raise ImportError(msg)
username, server, port = _split_server(server)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
try:
client.connect(server, port, username=username, key_filename=keyfile, look_for_keys=True)
except paramiko.AuthenticationException:
return False
else:
client.close()
return True
def tunnel_connection(
socket: socket.socket,
addr: str,
server: str,
keyfile: str | None = None,
password: str | None = None,
paramiko: Any = None,
timeout: int = 60,
) -> int:
"""Connect a socket to an address via an ssh tunnel.
This is a wrapper for socket.connect(addr), when addr is not accessible
from the local machine. It simply creates an ssh tunnel using the remaining args,
and calls socket.connect('tcp://localhost:lport') where lport is the randomly
selected local port of the tunnel.
"""
new_url, tunnel = open_tunnel(
addr,
server,
keyfile=keyfile,
password=password,
paramiko=paramiko,
timeout=timeout,
)
socket.connect(new_url)
return tunnel
def open_tunnel(
addr: str,
server: str,
keyfile: str | None = None,
password: str | None = None,
paramiko: Any = None,
timeout: int = 60,
) -> tuple[str, int]:
"""Open a tunneled connection from a 0MQ url.
For use inside tunnel_connection.
Returns
-------
(url, tunnel) : (str, object)
The 0MQ url that has been forwarded, and the tunnel object
"""
lport = select_random_ports(1)[0]
_, addr = addr.split("://")
ip, rport = addr.split(":")
rport_int = int(rport)
    # Default to paramiko on Windows; otherwise honor the caller's choice.
    if paramiko is None:
        paramiko = sys.platform == "win32"
    tunnelf = paramiko_tunnel if paramiko else openssh_tunnel
tunnel = tunnelf(
lport,
rport_int,
server,
remoteip=ip,
keyfile=keyfile,
password=password,
timeout=timeout,
)
return "tcp://127.0.0.1:%i" % lport, cast(int, tunnel)
def openssh_tunnel(
lport: int,
rport: int,
server: str,
remoteip: str = "127.0.0.1",
keyfile: str | None = None,
password: str | None | bool = None,
timeout: int = 60,
) -> int:
"""Create an ssh tunnel using command-line ssh that connects port lport
on this machine to localhost:rport on server. The tunnel
will automatically close when not in use, remaining open
for a minimum of timeout seconds for an initial connection.
This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
as seen from `server`.
keyfile and password may be specified, but ssh config is checked for defaults.
Parameters
----------
lport : int
local port for connecting to the tunnel from this machine.
rport : int
port on the remote machine to connect to.
server : str
The ssh server to connect to. The full ssh server string will be parsed.
user@server:port
remoteip : str [Default: 127.0.0.1]
The remote ip, specifying the destination of the tunnel.
Default is localhost, which means that the tunnel would redirect
localhost:lport on this machine to localhost:rport on the *server*.
keyfile : str; path to public key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to the ssh server. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
timeout : int [default: 60]
The time (in seconds) after which no activity will result in the tunnel
closing. This prevents orphaned tunnels from running forever.
"""
if pexpect is None:
msg = "pexpect unavailable, use paramiko_tunnel"
raise ImportError(msg)
ssh = "ssh "
if keyfile:
ssh += "-i " + keyfile
if ":" in server:
server, port = server.split(":")
ssh += " -p %s" % port
cmd = f"{ssh} -O check {server}"
(output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
if not exitstatus:
pid = int(output[output.find(b"(pid=") + 5 : output.find(b")")])
cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % (
ssh,
lport,
remoteip,
rport,
server,
)
(output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
if not exitstatus:
atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1))
return pid
cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % (
ssh,
lport,
remoteip,
rport,
server,
timeout,
)
# pop SSH_ASKPASS from env
env = os.environ.copy()
env.pop("SSH_ASKPASS", None)
ssh_newkey = "Are you sure you want to continue connecting"
tunnel = pexpect.spawn(cmd, env=env)
failed = False
while True:
try:
i = tunnel.expect([ssh_newkey, _password_pat], timeout=0.1)
if i == 0:
msg = "The authenticity of the host can't be established."
raise SSHException(msg)
except pexpect.TIMEOUT:
continue
except pexpect.EOF as e:
tunnel.wait()
if tunnel.exitstatus:
raise RuntimeError("tunnel '%s' failed to start" % (cmd)) from e
else:
return tunnel.pid
else:
if failed:
warnings.warn("Password rejected, try again", stacklevel=2)
password = None
if password is None:
password = getpass("%s's password: " % (server))
tunnel.sendline(password)
failed = True
def _stop_tunnel(cmd: Any) -> None:
pexpect.run(cmd)
def _split_server(server: str) -> tuple[str, str, int]:
if "@" in server:
username, server = server.split("@", 1)
else:
username = getuser()
if ":" in server:
server, port_str = server.split(":")
port = int(port_str)
else:
port = 22
return username, server, port
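# For illustration: _split_server("alice@host:2222") -> ("alice", "host", 2222);
# when the user or port is omitted, the current user and port 22 are assumed.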
def paramiko_tunnel(
lport: int,
rport: int,
server: str,
remoteip: str = "127.0.0.1",
keyfile: str | None = None,
password: str | None = None,
timeout: float = 60,
) -> Process:
"""launch a tunner with paramiko in a subprocess. This should only be used
when shell ssh is unavailable (e.g. Windows).
This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
as seen from `server`.
If you are familiar with ssh tunnels, this creates the tunnel:
ssh server -L localhost:lport:remoteip:rport
keyfile and password may be specified, but ssh config is checked for defaults.
Parameters
----------
lport : int
local port for connecting to the tunnel from this machine.
rport : int
port on the remote machine to connect to.
server : str
The ssh server to connect to. The full ssh server string will be parsed.
user@server:port
remoteip : str [Default: 127.0.0.1]
The remote ip, specifying the destination of the tunnel.
Default is localhost, which means that the tunnel would redirect
localhost:lport on this machine to localhost:rport on the *server*.
keyfile : str; path to public key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to the ssh server. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
timeout : int [default: 60]
The time (in seconds) after which no activity will result in the tunnel
closing. This prevents orphaned tunnels from running forever.
"""
if paramiko is None:
msg = "Paramiko not available" # type:ignore[unreachable]
raise ImportError(msg)
if password is None and not _try_passwordless_paramiko(server, keyfile):
password = getpass("%s's password: " % (server))
p = Process(
target=_paramiko_tunnel,
args=(lport, rport, server, remoteip),
kwargs={"keyfile": keyfile, "password": password},
)
p.daemon = True
p.start()
return p
def _paramiko_tunnel(
lport: int,
rport: int,
server: str,
remoteip: str,
keyfile: str | None = None,
password: str | None = None,
) -> None:
"""Function for actually starting a paramiko tunnel, to be passed
to multiprocessing.Process(target=this), and not called directly.
"""
username, server, port = _split_server(server)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
try:
client.connect(
server,
port,
username=username,
key_filename=keyfile,
look_for_keys=True,
password=password,
)
# except paramiko.AuthenticationException:
# if password is None:
# password = getpass("%s@%s's password: "%(username, server))
# client.connect(server, port, username=username, password=password)
# else:
# raise
except Exception as e:
warnings.warn("*** Failed to connect to %s:%d: %r" % (server, port, e), stacklevel=2)
sys.exit(1)
# Don't let SIGINT kill the tunnel subprocess
signal.signal(signal.SIGINT, signal.SIG_IGN)
try:
forward_tunnel(lport, remoteip, rport, client.get_transport())
except KeyboardInterrupt:
warnings.warn("SIGINT: Port forwarding stopped cleanly", stacklevel=2)
sys.exit(0)
except Exception as e:
warnings.warn("Port forwarding stopped uncleanly: %s" % e, stacklevel=2)
sys.exit(255)
if sys.platform == "win32":
ssh_tunnel = paramiko_tunnel
else:
ssh_tunnel = openssh_tunnel
__all__ = [
"tunnel_connection",
"ssh_tunnel",
"openssh_tunnel",
"paramiko_tunnel",
"try_passwordless_ssh",
]
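# Module-level usage sketch (hypothetical addresses): connect a zmq socket to
# a kernel port that is reachable only from "user@gateway.example.com", using
# the tunnel_connection helper exported above; it connects the socket through
# a freshly opened tunnel and returns the tunnel handle.
#
#     import zmq
#     sock = zmq.Context.instance().socket(zmq.DEALER)
#     tunnel = tunnel_connection(sock, "tcp://10.0.0.5:5555", "user@gateway.example.com")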

View File

@@ -0,0 +1,351 @@
""" Defines a KernelClient that provides thread-safe sockets with async callbacks on message
replies.
"""
import asyncio
import atexit
import time
from concurrent.futures import Future
from functools import partial
from threading import Thread
from typing import Any, Dict, List, Optional
import zmq
from tornado.ioloop import IOLoop
from traitlets import Instance, Type
from traitlets.log import get_logger
from zmq.eventloop import zmqstream
from .channels import HBChannel
from .client import KernelClient
from .session import Session
# Local imports
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
# during garbage collection of threads at exit
from zmq import ZMQError  # noqa
class ThreadedZMQSocketChannel:
"""A ZMQ socket invoking a callback in the ioloop"""
session = None
socket = None
ioloop = None
stream = None
_inspect = None
def __init__(
self,
socket: Optional[zmq.Socket],
session: Optional[Session],
loop: Optional[IOLoop],
) -> None:
"""Create a channel.
Parameters
----------
socket : :class:`zmq.Socket`
The ZMQ socket to use.
session : :class:`session.Session`
The session to use.
loop
A tornado ioloop to connect the socket to using a ZMQStream
"""
super().__init__()
self.socket = socket
self.session = session
self.ioloop = loop
f: Future = Future()
def setup_stream() -> None:
try:
assert self.socket is not None
self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
self.stream.on_recv(self._handle_recv)
except Exception as e:
f.set_exception(e)
else:
f.set_result(None)
assert self.ioloop is not None
self.ioloop.add_callback(setup_stream)
# don't wait forever, raise any errors
f.result(timeout=10)
_is_alive = False
def is_alive(self) -> bool:
"""Whether the channel is alive."""
return self._is_alive
def start(self) -> None:
"""Start the channel."""
self._is_alive = True
def stop(self) -> None:
"""Stop the channel."""
self._is_alive = False
def close(self) -> None:
"""Close the channel."""
if self.stream is not None and self.ioloop is not None:
            # use a concurrent.futures.Future to get a thread-safe result
f: Future = Future()
def close_stream() -> None:
try:
if self.stream is not None:
self.stream.close(linger=0)
self.stream = None
except Exception as e:
f.set_exception(e)
else:
f.set_result(None)
self.ioloop.add_callback(close_stream)
# wait for result
try:
f.result(timeout=5)
except Exception as e:
                log = get_logger()
                log.warning("Error closing stream %s: %s", self.stream, e)
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
def send(self, msg: Dict[str, Any]) -> None:
"""Queue a message to be sent from the IOLoop's thread.
Parameters
----------
msg : message to send
This is threadsafe, as it uses IOLoop.add_callback to give the loop's
thread control of the action.
"""
def thread_send() -> None:
assert self.session is not None
self.session.send(self.stream, msg)
assert self.ioloop is not None
self.ioloop.add_callback(thread_send)
def _handle_recv(self, msg_list: List) -> None:
"""Callback for stream.on_recv.
Unpacks message, and calls handlers with it.
"""
assert self.ioloop is not None
assert self.session is not None
ident, smsg = self.session.feed_identities(msg_list)
msg = self.session.deserialize(smsg)
# let client inspect messages
if self._inspect:
self._inspect(msg) # type:ignore[unreachable]
self.call_handlers(msg)
def call_handlers(self, msg: Dict[str, Any]) -> None:
"""This method is called in the ioloop thread when a message arrives.
Subclasses should override this method to handle incoming messages.
        It is important to remember that this method is called in the ioloop
        thread, so some mechanism must be used to ensure that application-level
        handlers are called in the application thread.
"""
pass
def process_events(self) -> None:
"""Subclasses should override this with a method
processing any pending GUI events.
"""
pass
def flush(self, timeout: float = 1.0) -> None:
"""Immediately processes all pending messages on this channel.
This is only used for the IOPub channel.
Callers should use this method to ensure that :meth:`call_handlers`
has been called for all messages that have been received on the
0MQ SUB socket of this channel.
This method is thread safe.
Parameters
----------
timeout : float, optional
The maximum amount of time to spend flushing, in seconds. The
default is one second.
"""
# We do the IOLoop callback process twice to ensure that the IOLoop
# gets to perform at least one full poll.
stop_time = time.monotonic() + timeout
assert self.ioloop is not None
if self.stream is None or self.stream.closed():
# don't bother scheduling flush on a thread if we're closed
_msg = "Attempt to flush closed stream"
raise OSError(_msg)
def flush(f: Any) -> None:
try:
self._flush()
except Exception as e:
f.set_exception(e)
else:
f.set_result(None)
for _ in range(2):
f: Future = Future()
self.ioloop.add_callback(partial(flush, f))
# wait for async flush, re-raise any errors
            timeout = max(stop_time - time.monotonic(), 0)
            try:
                f.result(timeout)
except TimeoutError:
# flush with a timeout means stop waiting, not raise
return
def _flush(self) -> None:
"""Callback for :method:`self.flush`."""
assert self.stream is not None
self.stream.flush()
self._flushed = True
class IOLoopThread(Thread):
"""Run a pyzmq ioloop in a thread to send and receive messages"""
_exiting = False
ioloop = None
def __init__(self) -> None:
"""Initialize an io loop thread."""
super().__init__()
self.daemon = True
@staticmethod
@atexit.register
def _notice_exit() -> None:
# Class definitions can be torn down during interpreter shutdown.
# We only need to set _exiting flag if this hasn't happened.
if IOLoopThread is not None:
IOLoopThread._exiting = True
def start(self) -> None:
"""Start the IOLoop thread
Don't return until self.ioloop is defined,
which is created in the thread
"""
self._start_future: Future = Future()
Thread.start(self)
# wait for start, re-raise any errors
self._start_future.result(timeout=10)
def run(self) -> None:
"""Run my loop, ignoring EINTR events in the poller"""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def assign_ioloop() -> None:
self.ioloop = IOLoop.current()
loop.run_until_complete(assign_ioloop())
except Exception as e:
self._start_future.set_exception(e)
else:
self._start_future.set_result(None)
loop.run_until_complete(self._async_run())
async def _async_run(self) -> None:
"""Run forever (until self._exiting is set)"""
while not self._exiting:
await asyncio.sleep(1)
def stop(self) -> None:
"""Stop the channel's event loop and join its thread.
This calls :meth:`~threading.Thread.join` and returns when the thread
terminates. :class:`RuntimeError` will be raised if
:meth:`~threading.Thread.start` is called again.
"""
self._exiting = True
self.join()
self.close()
self.ioloop = None
def __del__(self) -> None:
self.close()
def close(self) -> None:
"""Close the io loop thread."""
if self.ioloop is not None:
try:
self.ioloop.close(all_fds=True)
except Exception:
pass
class ThreadedKernelClient(KernelClient):
"""A KernelClient that provides thread-safe sockets with async callbacks on message replies."""
@property
def ioloop(self) -> Optional[IOLoop]: # type:ignore[override]
if self.ioloop_thread:
return self.ioloop_thread.ioloop
return None
ioloop_thread = Instance(IOLoopThread, allow_none=True)
def start_channels(
self,
shell: bool = True,
iopub: bool = True,
stdin: bool = True,
hb: bool = True,
control: bool = True,
) -> None:
"""Start the channels on the client."""
self.ioloop_thread = IOLoopThread()
self.ioloop_thread.start()
if shell:
self.shell_channel._inspect = self._check_kernel_info_reply
super().start_channels(shell, iopub, stdin, hb, control)
def _check_kernel_info_reply(self, msg: Dict[str, Any]) -> None:
"""This is run in the ioloop thread when the kernel info reply is received"""
if msg["msg_type"] == "kernel_info_reply":
self._handle_kernel_info_reply(msg)
self.shell_channel._inspect = None
def stop_channels(self) -> None:
"""Stop the channels on the client."""
super().stop_channels()
if self.ioloop_thread and self.ioloop_thread.is_alive():
self.ioloop_thread.stop()
iopub_channel_class = Type(ThreadedZMQSocketChannel) # type:ignore[arg-type]
shell_channel_class = Type(ThreadedZMQSocketChannel) # type:ignore[arg-type]
stdin_channel_class = Type(ThreadedZMQSocketChannel) # type:ignore[arg-type]
hb_channel_class = Type(HBChannel) # type:ignore[arg-type]
control_channel_class = Type(ThreadedZMQSocketChannel) # type:ignore[arg-type]
def is_alive(self) -> bool:
"""Is the kernel process still running?"""
if self._hb_channel is not None:
# We don't have access to the KernelManager,
# so we use the heartbeat.
return self._hb_channel.is_beating()
# no heartbeat and not local, we can't tell if it's running,
# so naively return True
return True
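# Minimal subclassing sketch (hypothetical connection file): call_handlers
# runs in the ioloop thread, so a real application should hand messages off
# to its own thread (e.g. via a queue) rather than processing them inline.
#
#     class PrintingChannel(ThreadedZMQSocketChannel):
#         def call_handlers(self, msg):
#             print(msg["header"]["msg_type"])
#
#     class PrintingClient(ThreadedKernelClient):
#         iopub_channel_class = Type(PrintingChannel)
#
#     kc = PrintingClient(connection_file="kernel-1234.json")
#     kc.load_connection_file()
#     kc.start_channels()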

View File

@@ -0,0 +1,90 @@
"""
utils:
- provides utility wrappers to run asynchronous functions in a blocking environment.
- vendors functions from ipython_genutils that should be retired at some point.
"""
from __future__ import annotations
import os
from typing import Sequence
from jupyter_core.utils import ensure_async, run_sync  # noqa: F401
from .session import utcnow # noqa
def _filefind(filename: str, path_dirs: str | Sequence[str] | None = None) -> str:
"""Find a file by looking through a sequence of paths.
This iterates through a sequence of paths looking for a file and returns
the full, absolute path of the first occurrence of the file. If no set of
path dirs is given, the filename is tested as is, after running through
    :func:`expandvars` and :func:`expanduser`. Thus a simple call::
        filefind('myfile.txt')
    will find the file in the current working dir, but::
        filefind('~/myfile.txt')
    will find the file in the user's home directory. This function does not
    automatically try any paths, such as the cwd or the user's home directory.
Parameters
----------
filename : str
The filename to look for.
path_dirs : str, None or sequence of str
        The sequence of paths to look for the file in. If None, the filename
        needs to be absolute or in the cwd. If a string, it is put into a
        single-element sequence and then searched. If a sequence, walk through
each element and join with ``filename``, calling :func:`expandvars`
and :func:`expanduser` before testing for existence.
Returns
-------
    Raises :exc:`OSError` or returns the absolute path to the file.
"""
# If paths are quoted, abspath gets confused, strip them...
filename = filename.strip('"').strip("'")
# If the input is an absolute path, just check it exists
if os.path.isabs(filename) and os.path.isfile(filename):
return filename
if path_dirs is None:
path_dirs = ("",)
elif isinstance(path_dirs, str):
path_dirs = (path_dirs,)
for path in path_dirs:
if path == ".":
path = os.getcwd() # noqa
testname = _expand_path(os.path.join(path, filename))
if os.path.isfile(testname):
return os.path.abspath(testname)
msg = f"File {filename!r} does not exist in any of the search paths: {path_dirs!r}"
raise OSError(msg)
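# Behavior sketch (hypothetical paths), assuming ./etc/app.cfg exists:
#
#     _filefind("app.cfg", ["/nonexistent", "./etc"])  # -> absolute path to ./etc/app.cfg
#     _filefind("app.cfg", "/nonexistent")             # raises OSError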
def _expand_path(s: str) -> str:
"""Expand $VARS and ~names in a string, like a shell
:Examples:
In [2]: os.environ['FOO']='test'
In [3]: expand_path('variable FOO is $FOO')
Out[3]: 'variable FOO is test'
"""
    # This is a pretty subtle hack. When os.path.expandvars is given a UNC
    # path on Windows (\\server\share$\%username%), it removes the $ to get
    # (\\server\share\%username%), apparently treating a lone $ as an empty
    # var. But we need the $ to remain there (it indicates a hidden share).
if os.name == "nt":
s = s.replace("$\\", "IPYTHON_TEMP")
s = os.path.expandvars(os.path.expanduser(s))
if os.name == "nt":
s = s.replace("IPYTHON_TEMP", "$\\")
return s

View File

@@ -0,0 +1,45 @@
"""Use a Windows event to interrupt a child process like SIGINT.
The child needs to explicitly listen for this - see
ipykernel.parentpoller.ParentPollerWindows for a Python implementation.
"""
import ctypes
from typing import Any
def create_interrupt_event() -> Any:
"""Create an interrupt event handle.
The parent process should call this to create the
interrupt event that is passed to the child process. It should store
this handle and use it with ``send_interrupt`` to interrupt the child
process.
"""
# Create a security attributes struct that permits inheritance of the
# handle by new processes.
# FIXME: We can clean up this mess by requiring pywin32 for IPython.
class SECURITY_ATTRIBUTES(ctypes.Structure): # noqa
_fields_ = [
("nLength", ctypes.c_int),
("lpSecurityDescriptor", ctypes.c_void_p),
("bInheritHandle", ctypes.c_int),
]
sa = SECURITY_ATTRIBUTES()
sa_p = ctypes.pointer(sa)
sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES)
sa.lpSecurityDescriptor = 0
sa.bInheritHandle = 1
    return ctypes.windll.kernel32.CreateEventA(  # type:ignore[attr-defined]
        sa_p,  # lpEventAttributes
        False,  # bManualReset
        False,  # bInitialState
        "",  # lpName
    )
def send_interrupt(interrupt_handle: Any) -> None:
"""Sends an interrupt event using the specified handle."""
ctypes.windll.kernel32.SetEvent(interrupt_handle) # type:ignore[attr-defined]
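# Usage sketch (parent side, Windows only): create the event before spawning
# the child, hand the inheritable handle to the child (jupyter_client's
# launcher exports it as the JPY_INTERRUPT_EVENT environment variable), then
# signal it in place of SIGINT.
#
#     handle = create_interrupt_event()
#     # ... spawn the child process with handle inheritance enabled ...
#     send_interrupt(handle)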