Edit on GitHub

pdoc.doc_ast

This module handles all interpretation of the Abstract Syntax Tree (AST) in pdoc.

Parsing the AST is done to extract docstrings, type annotations, and variable declarations from __init__.

  1"""
  2This module handles all interpretation of the *Abstract Syntax Tree (AST)* in pdoc.
  3
  4Parsing the AST is done to extract docstrings, type annotations, and variable declarations from `__init__`.
  5"""
  6
  7from __future__ import annotations
  8
  9import ast
 10from collections.abc import Iterable
 11from collections.abc import Iterator
 12from dataclasses import dataclass
 13import inspect
 14from itertools import tee
 15from itertools import zip_longest
 16import types
 17from typing import TYPE_CHECKING
 18from typing import Any
 19from typing import TypeVar
 20from typing import overload
 21import warnings
 22
 23import pdoc
 24
 25from ._compat import ast_TypeAlias
 26from ._compat import ast_unparse
 27from ._compat import cache
 28
 29if TYPE_CHECKING:
 30    import pdoc.doc_types
 31
 32
 33def get_source(obj: Any) -> str:
 34    """
 35    Returns the source code of the Python object `obj` as a str.
 36
 37    If this fails, an empty string is returned.
 38    """
 39    # Some objects may not be hashable, so we fall back to the non-cached version if that is the case.
 40    try:
 41        return _get_source(obj)
 42    except TypeError:
 43        return _get_source.__wrapped__(obj)
 44
 45
 46@cache
 47def _get_source(obj: Any) -> str:
 48    try:
 49        return inspect.getsource(obj)
 50    except Exception:
 51        return ""
 52
 53
 54@overload
 55def parse(obj: types.ModuleType) -> ast.Module: ...
 56
 57
 58@overload
 59def parse(obj: types.FunctionType) -> ast.FunctionDef | ast.AsyncFunctionDef: ...
 60
 61
 62@overload
 63def parse(obj: type) -> ast.ClassDef: ...
 64
 65
 66def parse(obj):
 67    """
 68    Parse a module, class or function and return the (unwrapped) AST node.
 69    If an object's source code cannot be found, this function returns an empty ast node stub
 70    which can still be walked.
 71    """
 72    src = get_source(obj)
 73    if isinstance(obj, types.ModuleType):
 74        return _parse_module(src)
 75    elif isinstance(obj, type):
 76        return _parse_class(src)
 77    else:
 78        return _parse_function(src)
 79
 80
 81@cache
 82def unparse(tree: ast.AST):
 83    """`ast.unparse`, but cached."""
 84    return ast_unparse(tree)
 85
 86
 87@dataclass
 88class AstInfo:
 89    """The information extracted from walking the syntax tree."""
 90
 91    var_docstrings: dict[str, str]
 92    """A qualname -> docstring mapping."""
 93    func_docstrings: dict[str, str]
 94    """A qualname -> docstring mapping for functions."""
 95    annotations: dict[str, str | type[pdoc.doc_types.empty]]
 96    """A qualname -> annotation mapping.
 97    
 98    Annotations are not evaluated by this module and only returned as strings."""
 99
100
def walk_tree(obj: types.ModuleType | type) -> AstInfo:
    """
    Parse the source of the module or class `obj` and return the docstrings and
    annotations found in it (see `AstInfo`).
    """
    tree = parse(obj)
    return _walk_tree(tree)
106
107
@cache
def _walk_tree(
    tree: ast.Module | ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef,
) -> AstInfo:
    """
    Collect docstrings and annotations from the statements of `tree`.

    Statements are visited in order (with `__init__` inlined, see `_nodes`), each
    together with the statement that follows it:

    - For assignments, annotated assignments and type alias statements, a string
      literal statement directly following them is recorded as the variable's docstring.
    - Annotated assignments record their annotation as unparsed source text; plain
      assignments record `pdoc.doc_types.empty` unless an annotation was already seen.
    - Function definitions (sync and async) record their own leading docstring.
    """
    var_docstrings = {}
    func_docstrings = {}
    annotations = {}
    for a, b in _pairwise_longest(_nodes(tree)):
        if isinstance(a, ast_TypeAlias):
            name = a.name.id
        elif (
            isinstance(a, ast.AnnAssign) and isinstance(a.target, ast.Name) and a.simple
        ):
            name = a.target.id
            annotations[name] = unparse(a.annotation)
        elif (
            isinstance(a, ast.Assign)
            and len(a.targets) == 1
            and isinstance(a.targets[0], ast.Name)
        ):
            name = a.targets[0].id
            # Make sure that all assignments are picked up, even if there is
            # no annotation or docstring.
            annotations.setdefault(name, pdoc.doc_types.empty)
        elif isinstance(a, (ast.FunctionDef, ast.AsyncFunctionDef)) and a.body:
            # Async functions carry their docstring exactly like regular functions.
            first = a.body[0]
            if (
                isinstance(first, ast.Expr)
                and isinstance(first.value, ast.Constant)
                and isinstance(first.value.value, str)
            ):
                func_docstrings[a.name] = inspect.cleandoc(first.value.value).strip()
            continue
        else:
            continue
        # `b` is None for the last statement (see `_pairwise_longest`).
        if (
            isinstance(b, ast.Expr)
            and isinstance(b.value, ast.Constant)
            and isinstance(b.value.value, str)
        ):
            var_docstrings[name] = inspect.cleandoc(b.value.value).strip()
    return AstInfo(
        var_docstrings,
        func_docstrings,
        annotations,
    )
154
155
T = TypeVar("T")


def sort_by_source(
    obj: types.ModuleType | type, sorted: dict[str, T], unsorted: dict[str, T]
) -> tuple[dict[str, T], dict[str, T]]:
    """
    Move the entries of `unsorted` into `sorted`, ordered by where their names are
    declared in the source code of `obj`. `__init__`, if present, always comes first.

    Names that are not declared in `obj`'s own source (for example members inherited
    from a superclass) are left in `unsorted`. Both dicts are mutated in place.

    Returns a `(sorted, not found)` tuple.
    """
    tree = parse(obj)

    if "__init__" in unsorted:
        sorted["__init__"] = unsorted.pop("__init__")

    for node in _nodes(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            name = node.name
        elif isinstance(node, ast_TypeAlias):
            name = node.name.id
        elif isinstance(node, ast.AnnAssign):
            if not (isinstance(node.target, ast.Name) and node.simple):
                continue
            name = node.target.id
        elif isinstance(node, ast.Assign):
            if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
                continue
            name = node.targets[0].id
        else:
            continue

        if name in unsorted:
            sorted[name] = unsorted.pop(name)
    return sorted, unsorted
196
197
def type_checking_sections(mod: types.ModuleType) -> ast.Module:
    """
    Collect the statements of `mod` that are guarded by a top-level `TYPE_CHECKING` check.

    Both `if TYPE_CHECKING:` and `if <name>.TYPE_CHECKING:` are recognized. The bodies
    of all matching `if` statements are concatenated into a new `ast.Module`;
    `else` branches and nested `if` statements are not inspected.
    """
    guarded = ast.Module(body=[], type_ignores=[])
    for stmt in _parse_module(get_source(mod)).body:
        if not isinstance(stmt, ast.If):
            continue
        test = stmt.test
        is_bare_name = isinstance(test, ast.Name) and test.id == "TYPE_CHECKING"
        # The prefix name is deliberately not checked: some people write
        # `import typing as t` and then `if t.TYPE_CHECKING:`. Matching on the
        # attribute name alone is accurate enough.
        is_attribute = (
            isinstance(test, ast.Attribute)
            and isinstance(test.value, ast.Name)
            and test.attr == "TYPE_CHECKING"
        )
        if is_bare_name or is_attribute:
            guarded.body.extend(stmt.body)
    return guarded
221
222
@cache
def _parse_module(source: str) -> ast.Module:
    """
    Turn the source code of a module into an `ast.Module`.

    An empty `source` yields a module with an empty body.
    """
    module = _parse(source)
    assert isinstance(module, ast.Module)
    return module
233
234
@cache
def _parse_class(source: str) -> ast.ClassDef:
    """
    Parse the AST for the source code of a class and return the ast.ClassDef.

    Returns an empty ast.ClassDef stub if source is empty or could not be parsed.
    """
    tree = _parse(source)
    assert len(tree.body) <= 1
    if tree.body:
        t = tree.body[0]
        assert isinstance(t, ast.ClassDef)
        return t
    # Fill in the required fields explicitly. Python 3.13 deprecates omitting them
    # (e.g. `name`) when constructing AST nodes, and the explicit empty lists keep
    # the stub safe to inspect on older versions.
    return ast.ClassDef(
        name="PdocStub", bases=[], keywords=[], body=[], decorator_list=[]
    )
249
250
@cache
def _parse_function(source: str) -> ast.FunctionDef | ast.AsyncFunctionDef:
    """
    Parse the AST for the source code of a (async) function and return the matching AST node.

    Returns an empty ast.FunctionDef stub if source is empty, could not be parsed,
    or does not start with a function definition (e.g. the source of a lambda).
    """
    tree = _parse(source)
    if tree.body:
        t = tree.body[0]
        if isinstance(t, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # The source of a def statement is exactly one statement.
            assert len(tree.body) == 1
            return t
        # We have a lambda function. Its source is the whole line(s) it is defined on,
        # which may hold several statements (`f = lambda: 0; g = 1`), so the statement
        # count must not be asserted here.
        # To simplify the API, return the ast.FunctionDef stub.
    # Fill in the required fields explicitly: Python 3.13 deprecates omitting them
    # (e.g. `name`, `args`) when constructing AST nodes.
    return ast.FunctionDef(
        name="pdoc_stub",
        args=ast.arguments(
            posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
        ),
        body=[],
        decorator_list=[],
    )
269
270
def _parse(
    source: str,
) -> ast.Module:
    """
    Parse `source` (after `_dedent`) into an `ast.Module`.

    `ast.parse` always produces an `ast.Module`, so that is the return type.
    If parsing fails for any reason, a warning with the error and the offending
    source is emitted and an empty `ast.Module` is returned instead.
    """
    try:
        return ast.parse(_dedent(source))
    except Exception as e:
        # Deliberately best-effort: a parse failure degrades to an empty tree
        # instead of propagating to the caller.
        warnings.warn(f"Error parsing source code: {e}\n===\n{source}\n===")
        return ast.parse("")
279
280
@cache
def _dedent(source: str) -> str:
    """
    Dedent the head of a function or class definition so that it can be parsed by `ast.parse`.
    This is an alternative to `textwrap.dedent`, which does not dedent if there are docstrings
    without indentation. For example, this is valid Python code but would not be dedented with `textwrap.dedent`:

    class Foo:
        def bar(self):
           '''
    this is a docstring
           '''

    Lines before the `def`/`class`/`async` keyword (e.g. decorators) are dedented one by one.
    From the definition line onwards, only the leading whitespace of that first line is
    removed. Source that does not start with a space or tab is returned unchanged.
    """
    head = []
    while source and source[0] in (" ", "\t"):
        source = source.lstrip()
        # Heuristic to detect that we have reached the definition itself. It's easy to
        # produce examples where this fails, but this probably is not a problem in practice.
        if source.startswith(("async ", "def ", "class ")):
            break
        # Not the definition yet (e.g. a decorator): keep the dedented line and go on.
        # `partition` (unlike `split` + unpacking) also copes with a last line that has
        # no trailing newline, or with whitespace-only input.
        first_line, newline, source = source.partition("\n")
        head.append(first_line + newline)
    return "".join(head) + source
305
306
@cache
def _nodes(tree: ast.Module | ast.ClassDef) -> list[ast.AST]:
    """
    List the top-level statements of `tree`, with the (transformed) body of `__init__`
    spliced in directly after the `__init__` definition.

    This lets callers see attributes that are only assigned in the constructor
    as if they were declared in the class body.
    """
    return [*_nodes_iter(tree)]
315
316
def _nodes_iter(tree: ast.Module | ast.ClassDef) -> Iterator[ast.AST]:
    """Yield each statement of `tree.body`, followed by the inlined body of a (sync) `__init__`."""
    for stmt in tree.body:
        yield stmt
        is_constructor = isinstance(stmt, ast.FunctionDef) and stmt.name == "__init__"
        if is_constructor:
            yield from _init_nodes(stmt)
322
323
def _init_nodes(tree: ast.FunctionDef) -> Iterator[ast.AST]:
    """
    Rewrite the body of `__init__` so that it can be read like a class body.

    Each statement of `tree.body` maps to exactly one output node:

    - `self.foo: T = value` becomes `foo: T = value`,
    - `self.foo = value` becomes `foo = value`,
    - string literal statements are passed through unchanged,
    - anything else becomes `pass`, so statement adjacency (which is how a string
      literal gets attached to the preceding assignment) is preserved.
    """

    def self_attribute(target):
        # "foo" for a `self.foo` target, None for anything else.
        if (
            isinstance(target, ast.Attribute)
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
        ):
            return target.attr
        return None

    for stmt in tree.body:
        if isinstance(stmt, ast.AnnAssign):
            attr = self_attribute(stmt.target)
            if attr is not None:
                yield ast.AnnAssign(ast.Name(attr), stmt.annotation, stmt.value, simple=1)
                continue
        elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
            attr = self_attribute(stmt.targets[0])
            if attr is not None:
                yield ast.Assign(
                    [ast.Name(attr)],
                    value=stmt.value,
                    type_comment=stmt.type_comment,
                )
                continue
        elif (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        ):
            yield stmt
            continue
        yield ast.Pass()
360
361
def _pairwise_longest(iterable: Iterable[T]) -> Iterable[tuple[T, T]]:
    """Pair every element with its successor; the last element is paired with None.

    s -> (s0,s1), (s1,s2), (s2, s3),  ..., (sN, None)
    """
    current, successor = tee(iterable)
    next(successor, None)  # shift the second iterator one step ahead
    return zip_longest(current, successor)
def get_source(obj: Any) -> str:
def get_source(obj: Any) -> str:
    """
    Return the source code of `obj` as a string, or an empty string if it cannot be retrieved.

    Lookups normally go through a cache. Objects that cannot serve as cache keys
    (unhashable objects) bypass the cache and are looked up directly.
    """
    try:
        source = _get_source(obj)
    except TypeError:
        # The cache rejected `obj` because it is unhashable: call the undecorated function.
        source = _get_source.__wrapped__(obj)
    return source

Returns the source code of the Python object obj as a str.

If this fails, an empty string is returned.

def parse(obj):
def parse(obj):
    """
    Return the AST node for a module, class or function, dispatching on the kind of `obj`.

    Modules yield an `ast.Module`, classes an `ast.ClassDef`, and everything else is
    treated as a function. When no source code is available, an empty stub node of
    the matching type is returned, so the result can always be walked.
    """
    source = get_source(obj)
    if isinstance(obj, types.ModuleType):
        return _parse_module(source)
    if isinstance(obj, type):
        return _parse_class(source)
    # Anything else (functions, lambdas, ...) is parsed as a function.
    return _parse_function(source)

Parse a module, class or function and return the (unwrapped) AST node. If an object's source code cannot be found, this function returns an empty ast node stub which can still be walked.

@cache
def unparse(tree: ast.AST):
@cache
def unparse(tree: ast.AST):
    """Turn `tree` back into source code (`ast.unparse`), memoizing the result per node."""
    code = ast_unparse(tree)
    return code

ast.unparse, but cached.

@dataclass
class AstInfo:
@dataclass
class AstInfo:
    """The information extracted from walking the syntax tree."""

    var_docstrings: dict[str, str]
    """A name -> docstring mapping for variables.

    Keys are unqualified names as written in the source
    (an assignment to `self.foo` in `__init__` is recorded as `foo`)."""
    func_docstrings: dict[str, str]
    """A name -> docstring mapping for functions (keys are unqualified function names)."""
    annotations: dict[str, str | type[pdoc.doc_types.empty]]
    """A name -> annotation mapping (keys are unqualified names).

    Annotations are not evaluated by this module and only returned as strings.
    Variables that are assigned without an annotation map to `pdoc.doc_types.empty`."""

The information extracted from walking the syntax tree.

AstInfo(var_docstrings: dict[str, str], func_docstrings: dict[str, str], annotations: dict[str, str | type[inspect._empty]])
var_docstrings: dict[str, str]

A name -> docstring mapping for variables (keys are unqualified names as written in the source).

func_docstrings: dict[str, str]

A name -> docstring mapping for functions (keys are unqualified function names).

annotations: dict[str, str | type[inspect._empty]]

A name -> annotation mapping (keys are unqualified names).

Annotations are not evaluated by this module and only returned as strings.

def walk_tree(obj: module | type) -> AstInfo:
def walk_tree(obj: types.ModuleType | type) -> AstInfo:
    """
    Parse the source of the module or class `obj` and return the docstrings and
    annotations found in it (see `AstInfo`).
    """
    tree = parse(obj)
    return _walk_tree(tree)

Walks the abstract syntax tree for obj and returns the extracted information.

def sort_by_source(obj: module | type, sorted: dict[str, ~T], unsorted: dict[str, ~T]) -> tuple[dict[str, ~T], dict[str, ~T]]:
def sort_by_source(
    obj: types.ModuleType | type, sorted: dict[str, T], unsorted: dict[str, T]
) -> tuple[dict[str, T], dict[str, T]]:
    """
    Move the entries of `unsorted` into `sorted`, ordered by where their names are
    declared in the source code of `obj`. `__init__`, if present, always comes first.

    Names that are not declared in `obj`'s own source (for example members inherited
    from a superclass) are left in `unsorted`. Both dicts are mutated in place.

    Returns a `(sorted, not found)` tuple.
    """
    tree = parse(obj)

    if "__init__" in unsorted:
        sorted["__init__"] = unsorted.pop("__init__")

    for node in _nodes(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            name = node.name
        elif isinstance(node, ast_TypeAlias):
            name = node.name.id
        elif isinstance(node, ast.AnnAssign):
            if not (isinstance(node.target, ast.Name) and node.simple):
                continue
            name = node.target.id
        elif isinstance(node, ast.Assign):
            if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
                continue
            name = node.targets[0].id
        else:
            continue

        if name in unsorted:
            sorted[name] = unsorted.pop(name)
    return sorted, unsorted

Takes items from unsorted and inserts them into sorted in order of appearance in the source code of obj. The only exception to this rule is __init__, which (if present) is always inserted first.

Some items may not be found, for example because they've been inherited from a superclass. They are returned as-is.

Returns a (sorted, not found) tuple.

def type_checking_sections(mod: module) -> ast.Module:
def type_checking_sections(mod: types.ModuleType) -> ast.Module:
    """
    Collect the statements of `mod` that are guarded by a top-level `TYPE_CHECKING` check.

    Both `if TYPE_CHECKING:` and `if <name>.TYPE_CHECKING:` are recognized. The bodies
    of all matching `if` statements are concatenated into a new `ast.Module`;
    `else` branches and nested `if` statements are not inspected.
    """
    guarded = ast.Module(body=[], type_ignores=[])
    for stmt in _parse_module(get_source(mod)).body:
        if not isinstance(stmt, ast.If):
            continue
        test = stmt.test
        is_bare_name = isinstance(test, ast.Name) and test.id == "TYPE_CHECKING"
        # The prefix name is deliberately not checked: some people write
        # `import typing as t` and then `if t.TYPE_CHECKING:`. Matching on the
        # attribute name alone is accurate enough.
        is_attribute = (
            isinstance(test, ast.Attribute)
            and isinstance(test.value, ast.Name)
            and test.attr == "TYPE_CHECKING"
        )
        if is_bare_name or is_attribute:
            guarded.body.extend(stmt.body)
    return guarded

Walks the abstract syntax tree for mod and returns all statements guarded by TYPE_CHECKING blocks.