Edit on GitHub

pdoc.doc_ast

This module handles all interpretation of the Abstract Syntax Tree (AST) in pdoc.

Parsing the AST is done to extract docstrings, type annotations, and variable declarations from __init__.

  1"""
  2This module handles all interpretation of the *Abstract Syntax Tree (AST)* in pdoc.
  3
  4Parsing the AST is done to extract docstrings, type annotations, and variable declarations from `__init__`.
  5"""
  6from __future__ import annotations
  7
  8import ast
  9from collections.abc import Iterable
 10from collections.abc import Iterator
 11from dataclasses import dataclass
 12import inspect
 13from itertools import tee
 14from itertools import zip_longest
 15import types
 16from typing import TYPE_CHECKING
 17from typing import Any
 18from typing import TypeVar
 19from typing import overload
 20import warnings
 21
 22import pdoc
 23
 24from ._compat import ast_unparse
 25from ._compat import cache
 26
 27if TYPE_CHECKING:
 28    import pdoc.doc_types
 29
 30
 31def get_source(obj: Any) -> str:
 32    """
 33    Returns the source code of the Python object `obj` as a str.
 34    This tries to first unwrap the method if it is wrapped and then calls `inspect.getsource`.
 35
 36    If this fails, an empty string is returned.
 37    """
 38    # Some objects may not be hashable, so we fall back to the non-cached version if that is the case.
 39    try:
 40        return _get_source(obj)
 41    except TypeError:
 42        return _get_source.__wrapped__(obj)
 43
 44
 45@cache
 46def _get_source(obj: Any) -> str:
 47    try:
 48        return inspect.getsource(obj)
 49    except Exception:
 50        return ""
 51
 52
 53@overload
 54def parse(obj: types.ModuleType) -> ast.Module:
 55    ...
 56
 57
 58@overload
 59def parse(obj: types.FunctionType) -> ast.FunctionDef | ast.AsyncFunctionDef:
 60    ...
 61
 62
 63@overload
 64def parse(obj: type) -> ast.ClassDef:
 65    ...
 66
 67
 68def parse(obj):
 69    """
 70    Parse a module, class or function and return the (unwrapped) AST node.
 71    If an object's source code cannot be found, this function returns an empty ast node stub
 72    which can still be walked.
 73    """
 74    src = get_source(obj)
 75    if isinstance(obj, types.ModuleType):
 76        return _parse_module(src)
 77    elif isinstance(obj, type):
 78        return _parse_class(src)
 79    else:
 80        return _parse_function(src)
 81
 82
@cache
def unparse(tree: ast.AST) -> str:
    """
    `ast.unparse`, but cached.

    Turns the AST node `tree` back into source text via the `ast_unparse`
    compatibility helper; repeated calls with the same node reuse the cached result.
    """
    return ast_unparse(tree)
 87
 88
@dataclass
class AstInfo:
    """
    The information extracted from walking the syntax tree.

    Instances are produced by `_walk_tree`. All three mappings are keyed by the
    identifier as it appears in the walked body (the assignment target or the
    function name).
    """

    var_docstrings: dict[str, str]
    """A qualname -> docstring mapping."""
    func_docstrings: dict[str, str]
    """A qualname -> docstring mapping for functions."""
    annotations: dict[str, str | type[pdoc.doc_types.empty]]
    """A qualname -> annotation mapping.
    
    Annotations are not evaluated by this module and only returned as strings."""
101
102
def walk_tree(obj: types.ModuleType | type) -> AstInfo:
    """
    Parse `obj` and walk its abstract syntax tree.

    Returns the `AstInfo` collected from the tree.
    """
    tree = parse(obj)
    return _walk_tree(tree)
108
109
def _string_constant(node: ast.AST | None) -> str | None:
    """Return the cleaned text of `node` if it is a bare string expression statement, else `None`."""
    if (
        isinstance(node, ast.Expr)
        and isinstance(node.value, ast.Constant)
        and isinstance(node.value.value, str)
    ):
        return inspect.cleandoc(node.value.value).strip()
    return None


@cache
def _walk_tree(
    tree: ast.Module | ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef,
) -> AstInfo:
    """
    Collect variable docstrings, function docstrings and annotations from `tree`.

    The nodes of `tree` (as returned by `_nodes`) are visited pairwise so that a
    string expression directly following an assignment can be picked up as that
    variable's docstring.

    - Simple annotated assignments record the unparsed annotation.
    - Plain single-target name assignments are recorded with an empty annotation
      unless an annotation for that name is already known.
    - Function definitions (both `def` and `async def`) contribute their leading
      string expression to `func_docstrings`.
    """
    var_docstrings: dict[str, str] = {}
    func_docstrings: dict[str, str] = {}
    annotations: dict[str, Any] = {}
    for a, b in _pairwise_longest(_nodes(tree)):
        if isinstance(a, ast.AnnAssign) and isinstance(a.target, ast.Name) and a.simple:
            name = a.target.id
            annotations[name] = unparse(a.annotation)
        elif (
            isinstance(a, ast.Assign)
            and len(a.targets) == 1
            and isinstance(a.targets[0], ast.Name)
        ):
            name = a.targets[0].id
            # Make sure that all assignments are picked up, even if there is
            # no annotation or docstring.
            annotations.setdefault(name, pdoc.doc_types.empty)
        elif isinstance(a, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Coroutine functions carry docstrings exactly like regular ones.
            if a.body:
                func_doc = _string_constant(a.body[0])
                if func_doc is not None:
                    func_docstrings[a.name] = func_doc
            continue
        else:
            continue
        # Only reached for assignments: the statement right after may be the docstring.
        var_doc = _string_constant(b)
        if var_doc is not None:
            var_docstrings[name] = var_doc
    return AstInfo(
        var_docstrings,
        func_docstrings,
        annotations,
    )
152
153
154T = TypeVar("T")
155
156
def sort_by_source(
    obj: types.ModuleType | type, sorted: dict[str, T], unsorted: dict[str, T]
) -> tuple[dict[str, T], dict[str, T]]:
    """
    Move items from `unsorted` into `sorted`, following their order of appearance
    in the source code of `obj`.

    `__init__` is the one exception: if present, it is always moved first.

    Names that never show up in the source (for example because they were inherited
    from a superclass) stay in `unsorted` untouched.

    Both dictionaries are modified in place and returned as a `(sorted, not found)` tuple.
    """
    tree = parse(obj)

    if "__init__" in unsorted:
        sorted["__init__"] = unsorted.pop("__init__")

    for node in _nodes(tree):
        # Work out which name (if any) this statement introduces.
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            name = node.name
        elif isinstance(node, ast.AnnAssign):
            if not (isinstance(node.target, ast.Name) and node.simple):
                continue
            name = node.target.id
        elif isinstance(node, ast.Assign):
            if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
                continue
            name = node.targets[0].id
        else:
            continue

        if name not in unsorted:
            continue
        sorted[name] = unsorted.pop(name)
    return sorted, unsorted
192
193
def type_checking_sections(mod: types.ModuleType) -> ast.Module:
    """
    Collect every top-level statement of `mod` that is guarded by a TYPE_CHECKING block.

    Both `if TYPE_CHECKING:` and `if <name>.TYPE_CHECKING:` are recognized.
    The statements are returned as the body of a fresh `ast.Module`.
    """
    guarded = ast.Module(body=[], type_ignores=[])
    for stmt in _parse_module(get_source(mod)).body:
        if not isinstance(stmt, ast.If):
            continue
        test = stmt.test
        if isinstance(test, ast.Name):
            is_guard = test.id == "TYPE_CHECKING"
        elif isinstance(test, ast.Attribute):
            # some folks do "import typing as t", so the module name is deliberately
            # not checked; the accuracy with just TYPE_CHECKING is good enough.
            is_guard = isinstance(test.value, ast.Name) and test.attr == "TYPE_CHECKING"
        else:
            is_guard = False
        if is_guard:
            guarded.body.extend(stmt.body)
    return guarded
217
218
@cache
def _parse_module(source: str) -> ast.Module:
    """
    Parse module source code into an `ast.Module`.

    Empty source yields an empty `ast.Module`.
    """
    module = _parse(source)
    assert isinstance(module, ast.Module)
    return module
229
230
@cache
def _parse_class(source: str) -> ast.ClassDef:
    """
    Parse the AST for the source code of a class and return the ast.ClassDef.

    Returns an empty ast.ClassDef stub if source is empty (or could not be parsed).
    """
    tree = _parse(source)
    assert len(tree.body) <= 1
    if tree.body:
        t = tree.body[0]
        assert isinstance(t, ast.ClassDef)
        return t
    # Build a fully-populated stub: omitting required fields such as `name` is
    # deprecated as of Python 3.13 and leaves the attribute undefined on older versions.
    return ast.ClassDef(
        name="PdocStub",
        bases=[],
        keywords=[],
        body=[],
        decorator_list=[],
    )
245
246
@cache
def _parse_function(source: str) -> ast.FunctionDef | ast.AsyncFunctionDef:
    """
    Parse the AST for the source code of a (async) function and return the matching AST node.

    Returns an empty ast.FunctionDef stub if source is empty, could not be parsed,
    or does not start with a function definition (e.g. a lambda).
    """
    tree = _parse(source)
    assert len(tree.body) <= 1
    if tree.body:
        t = tree.body[0]
        if isinstance(t, (ast.FunctionDef, ast.AsyncFunctionDef)):
            return t
        # Otherwise we have a lambda function;
        # to simplify the API we fall through and return the ast.FunctionDef stub.
    # Build a fully-populated stub: omitting required fields such as `name` and `args`
    # is deprecated as of Python 3.13 and leaves the attributes undefined on older versions.
    return ast.FunctionDef(
        name="pdoc_stub",
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[],
        ),
        body=[],
        decorator_list=[],
    )
265
266
267def _parse(
268    source: str,
269) -> ast.Module | ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef:
270    try:
271        return ast.parse(_dedent(source))
272    except Exception as e:
273        warnings.warn(f"Error parsing source code: {e}\n" f"===\n" f"{source}\n" f"===")
274        return ast.parse("")
275
276
@cache
def _dedent(source: str) -> str:
    """
    Dedent the head of a function or class definition so that it can be parsed by `ast.parse`.
    This is an alternative to `textwrap.dedent`, which does not dedent if there are docstrings
    without indentation. For example, this is valid Python code but would not be dedented with `textwrap.dedent`:

    class Foo:
        def bar(self):
           '''
    this is a docstring
           '''

    Source that is empty or does not start with a space or tab is returned unchanged.
    Only leading lines are stripped, up to and including the first line that starts
    with `def `, `class ` or `async `; the remainder is left as-is.
    """
    if not source or source[0] not in (" ", "\t"):
        return source
    source = source.lstrip()
    # we may have decorators before our function definition, in which case we need to dedent a few more lines.
    # the following heuristic should be good enough to detect if we have reached the definition.
    # it's easy to produce examples where this fails, but this probably is not a problem in practice.
    if source.startswith(("async ", "def ", "class ")):
        return source
    # `partition` (unlike unpacking `split("\n", 1)`) also copes with a final line that has
    # no trailing newline, e.g. an indented one-line lambda assignment.
    first_line, newline, rest = source.partition("\n")
    return first_line + newline + _dedent(rest)
301
302
@cache
def _nodes(tree: ast.Module | ast.ClassDef) -> list[ast.AST]:
    """
    Return all nodes of `tree`'s body as a list, with the body of `__init__` inlined.

    Inlining the constructor makes it possible to detect every variable declared in a
    class, even those that are only assigned inside `__init__`.
    """
    return [*_nodes_iter(tree)]
311
312
313def _nodes_iter(tree: ast.Module | ast.ClassDef) -> Iterator[ast.AST]:
314    for a in tree.body:
315        yield a
316        if isinstance(a, ast.FunctionDef) and a.name == "__init__":
317            yield from _init_nodes(a)
318
319
def _init_nodes(tree: ast.FunctionDef) -> Iterator[ast.AST]:
    """
    Rewrite the body of `__init__` so that it can be inlined into a class definition.

    - Attribute assignments such as "self.foo = 42" (annotated or not) are turned into
      name assignments such as "foo = 42".
    - Bare string expressions are passed through unchanged.
    - Every other statement is replaced by a no-op `pass`.
    """
    for stmt in tree.body:
        # Determine the single assignment target, if this statement has one.
        if isinstance(stmt, ast.AnnAssign):
            target = stmt.target
        elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
            target = stmt.targets[0]
        else:
            target = None

        assigns_to_self = (
            isinstance(target, ast.Attribute)
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
        )

        if assigns_to_self and isinstance(stmt, ast.AnnAssign):
            yield ast.AnnAssign(
                ast.Name(target.attr), stmt.annotation, stmt.value, simple=1
            )
        elif assigns_to_self:
            yield ast.Assign(
                [ast.Name(target.attr)],
                value=stmt.value,
                type_comment=stmt.type_comment,
            )
        elif (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        ):
            yield stmt
        else:
            yield ast.Pass()
356
357
def _pairwise_longest(iterable: Iterable[T]) -> Iterable[tuple[T, T]]:
    """
    Pair every item with its successor; the last item is paired with `None`.

    s -> (s0,s1), (s1,s2), (s2, s3),  ..., (sN, None)
    """
    current, successor = tee(iterable)
    # Advance the second iterator by one so that it runs one item ahead.
    next(successor, None)
    return zip_longest(current, successor)
def get_source(obj: Any) -> str:
32def get_source(obj: Any) -> str:
33    """
34    Returns the source code of the Python object `obj` as a str.
35    This tries to first unwrap the method if it is wrapped and then calls `inspect.getsource`.
36
37    If this fails, an empty string is returned.
38    """
39    # Some objects may not be hashable, so we fall back to the non-cached version if that is the case.
40    try:
41        return _get_source(obj)
42    except TypeError:
43        return _get_source.__wrapped__(obj)

Returns the source code of the Python object obj as a str. This tries to first unwrap the method if it is wrapped and then calls inspect.getsource.

If this fails, an empty string is returned.

def parse(obj):
69def parse(obj):
70    """
71    Parse a module, class or function and return the (unwrapped) AST node.
72    If an object's source code cannot be found, this function returns an empty ast node stub
73    which can still be walked.
74    """
75    src = get_source(obj)
76    if isinstance(obj, types.ModuleType):
77        return _parse_module(src)
78    elif isinstance(obj, type):
79        return _parse_class(src)
80    else:
81        return _parse_function(src)

Parse a module, class or function and return the (unwrapped) AST node. If an object's source code cannot be found, this function returns an empty ast node stub which can still be walked.

@cache
def unparse(tree: ast.AST):
84@cache
85def unparse(tree: ast.AST):
86    """`ast.unparse`, but cached."""
87    return ast_unparse(tree)

ast.unparse, but cached.

@dataclass
class AstInfo:
 90@dataclass
 91class AstInfo:
 92    """The information extracted from walking the syntax tree."""
 93
 94    var_docstrings: dict[str, str]
 95    """A qualname -> docstring mapping."""
 96    func_docstrings: dict[str, str]
 97    """A qualname -> docstring mapping for functions."""
 98    annotations: dict[str, str | type[pdoc.doc_types.empty]]
 99    """A qualname -> annotation mapping.
100    
101    Annotations are not evaluated by this module and only returned as strings."""

The information extracted from walking the syntax tree.

AstInfo( var_docstrings: dict[str, str], func_docstrings: dict[str, str], annotations: dict[str, str | type[inspect._empty]])
var_docstrings: dict[str, str]

A qualname -> docstring mapping.

func_docstrings: dict[str, str]

A qualname -> docstring mapping for functions.

annotations: dict[str, str | type[inspect._empty]]

A qualname -> annotation mapping.

Annotations are not evaluated by this module and only returned as strings.

def walk_tree(obj: module | type) -> AstInfo:
104def walk_tree(obj: types.ModuleType | type) -> AstInfo:
105    """
106    Walks the abstract syntax tree for `obj` and returns the extracted information.
107    """
108    return _walk_tree(parse(obj))

Walks the abstract syntax tree for obj and returns the extracted information.

def sort_by_source( obj: module | type, sorted: dict[str, ~T], unsorted: dict[str, ~T]) -> tuple[dict[str, ~T], dict[str, ~T]]:
158def sort_by_source(
159    obj: types.ModuleType | type, sorted: dict[str, T], unsorted: dict[str, T]
160) -> tuple[dict[str, T], dict[str, T]]:
161    """
162    Takes items from `unsorted` and inserts them into `sorted` in order of appearance in the source code of `obj`.
163    The only exception to this rule is `__init__`, which (if present) is always inserted first.
164
165    Some items may not be found, for example because they've been inherited from a superclass. They are returned as-is.
166
167    Returns a `(sorted, not found)` tuple.
168    """
169    tree = parse(obj)
170
171    if "__init__" in unsorted:
172        sorted["__init__"] = unsorted.pop("__init__")
173
174    for a in _nodes(tree):
175        if (
176            isinstance(a, ast.Assign)
177            and len(a.targets) == 1
178            and isinstance(a.targets[0], ast.Name)
179        ):
180            name = a.targets[0].id
181        elif (
182            isinstance(a, ast.AnnAssign) and isinstance(a.target, ast.Name) and a.simple
183        ):
184            name = a.target.id
185        elif isinstance(a, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
186            name = a.name
187        else:
188            continue
189
190        if name in unsorted:
191            sorted[name] = unsorted.pop(name)
192    return sorted, unsorted

Takes items from unsorted and inserts them into sorted in order of appearance in the source code of obj. The only exception to this rule is __init__, which (if present) is always inserted first.

Some items may not be found, for example because they've been inherited from a superclass. They are returned as-is.

Returns a (sorted, not found) tuple.

def type_checking_sections(mod: module) -> ast.Module:
195def type_checking_sections(mod: types.ModuleType) -> ast.Module:
196    """
197    Walks the abstract syntax tree for `mod` and returns all statements guarded by TYPE_CHECKING blocks.
198    """
199    ret = ast.Module(body=[], type_ignores=[])
200    tree = _parse_module(get_source(mod))
201    for node in tree.body:
202        if (
203            isinstance(node, ast.If)
204            and isinstance(node.test, ast.Name)
205            and node.test.id == "TYPE_CHECKING"
206        ):
207            ret.body.extend(node.body)
208        if (
209            isinstance(node, ast.If)
210            and isinstance(node.test, ast.Attribute)
211            and isinstance(node.test.value, ast.Name)
212            # some folks do "import typing as t", the accuracy with just TYPE_CHECKING is good enough.
213            # and node.test.value.id == "typing"
214            and node.test.attr == "TYPE_CHECKING"
215        ):
216            ret.body.extend(node.body)
217    return ret

Walks the abstract syntax tree for mod and returns all statements guarded by TYPE_CHECKING blocks.