218 lines
7.3 KiB
Python
218 lines
7.3 KiB
Python
|
import ast
|
||
|
import inspect
|
||
|
import textwrap
|
||
|
import tokenize
|
||
|
import types
|
||
|
import warnings
|
||
|
from bisect import bisect_right
|
||
|
from typing import Iterable
|
||
|
from typing import Iterator
|
||
|
from typing import List
|
||
|
from typing import Optional
|
||
|
from typing import overload
|
||
|
from typing import Tuple
|
||
|
from typing import Union
|
||
|
|
||
|
|
||
|
class Source:
    """An immutable object holding a source code fragment.

    When using Source(...), the source lines are deindented.
    """

    def __init__(self, obj: object = None) -> None:
        """Build the list of source lines from ``obj``.

        ``obj`` may be:

        * falsy -- the source is empty;
        * another ``Source`` -- its lines are copied as-is;
        * a tuple or list of lines -- trailing newlines are removed from each
          line and the result is deindented;
        * a string -- split on newlines and deindented;
        * anything else -- its source text is fetched with
          ``inspect.getsource`` and deindented.
        """
        if not obj:
            self.lines: List[str] = []
        elif isinstance(obj, Source):
            # Copy instead of aliasing: two Source objects sharing one list
            # would let an in-place change on one silently alter the other.
            self.lines = list(obj.lines)
        elif isinstance(obj, (tuple, list)):
            self.lines = deindent(x.rstrip("\n") for x in obj)
        elif isinstance(obj, str):
            self.lines = deindent(obj.split("\n"))
        else:
            try:
                rawcode = getrawcode(obj)
                src = inspect.getsource(rawcode)
            except TypeError:
                # No code object could be determined; let inspect try the
                # object itself.
                src = inspect.getsource(obj)  # type: ignore[arg-type]
            self.lines = deindent(src.split("\n"))

    def __eq__(self, other: object) -> bool:
        """Two Source objects are equal when their lines are equal."""
        if not isinstance(other, Source):
            return NotImplemented
        return self.lines == other.lines

    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @overload
    def __getitem__(self, key: int) -> str:
        ...

    @overload
    def __getitem__(self, key: slice) -> "Source":
        ...

    def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:
        """Return a single line for an int key, or a new Source for a slice.

        Raises IndexError for a slice whose step is neither None nor 1.
        """
        if isinstance(key, int):
            return self.lines[key]
        if key.step not in (None, 1):
            raise IndexError("cannot slice a Source with a step")
        newsource = Source()
        newsource.lines = self.lines[key.start : key.stop]
        return newsource

    def __iter__(self) -> Iterator[str]:
        return iter(self.lines)

    def __len__(self) -> int:
        return len(self.lines)

    def strip(self) -> "Source":
        """Return new Source object with trailing and leading blank lines removed."""
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end - 1].strip():
            end -= 1
        source = Source()
        source.lines = self.lines[start:end]
        return source

    def indent(self, indent: str = " " * 4) -> "Source":
        """Return a copy of the source object with all lines indented by the
        given indent-string."""
        newsource = Source()
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno: int) -> "Source":
        """Return the Source statement which contains the given line number
        (counted from 0)."""
        start, end = self.getstatementrange(lineno)
        return self[start:end]

    def getstatementrange(self, lineno: int) -> Tuple[int, int]:
        """Return the (start, end) tuple which spans the minimal statement
        region containing the given lineno.

        Raises IndexError when lineno is outside ``range(len(self))``.
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        # The parsed tree is not needed here; only the line range is returned.
        _, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self) -> "Source":
        """Return a new Source object deindented."""
        newsource = Source()
        newsource.lines = deindent(self.lines)
        return newsource

    def __str__(self) -> str:
        """Return the lines joined with newline characters."""
        return "\n".join(self.lines)
|
||
|
|
||
|
|
||
|
#
|
||
|
# helper functions
|
||
|
#
|
||
|
|
||
|
|
||
|
def findsource(obj) -> Tuple[Optional[Source], int]:
    """Locate the source lines that ``inspect.findsource`` reports for ``obj``.

    Returns a ``(Source, lineno)`` pair where the Source holds the reported
    lines with trailing whitespace removed and ``lineno`` is the line number
    reported by ``inspect.findsource``. If that lookup raises any exception,
    ``(None, -1)`` is returned instead.
    """
    try:
        raw_lines, first_lineno = inspect.findsource(obj)
    except Exception:
        # Best effort: any lookup failure is reported as "no source".
        return None, -1

    found = Source()
    found.lines = [text.rstrip() for text in raw_lines]
    return found, first_lineno
|
||
|
|
||
|
|
||
|
def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
    """Return code object for given function.

    The object's ``__code__`` attribute is returned when it exists. Otherwise,
    if ``trycall`` is true and ``obj`` is not a class, the lookup is retried
    once on the object's ``__call__`` attribute.

    Raises TypeError when no code object can be found.
    """
    try:
        code = obj.__code__  # type: ignore[attr-defined]
    except AttributeError:
        pass
    else:
        return code  # type: ignore[no-any-return]

    if trycall:
        dunder_call = getattr(obj, "__call__", None)
        if dunder_call:
            if not isinstance(obj, type):
                # Recurse with trycall disabled so the retry happens only once.
                return getrawcode(dunder_call, trycall=False)

    raise TypeError(f"could not get code object for {obj!r}")
|
||
|
|
||
|
|
||
|
def deindent(lines: Iterable[str]) -> List[str]:
    """Remove the common leading whitespace from ``lines``.

    The lines are joined with newlines, passed through ``textwrap.dedent``
    and split back into a list with ``str.splitlines``.
    """
    joined = "\n".join(lines)
    dedented = textwrap.dedent(joined)
    return dedented.splitlines()
|
||
|
|
||
|
|
||
|
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
    """Return the 0-based (start, end) line indices around ``lineno``.

    ``start`` is the closest recorded statement start at or before ``lineno``
    and ``end`` is the next recorded start after it, or None when there is no
    later one.
    """
    # Flatten all statements and except handlers into one lineno-list.
    # AST's line numbers start indexing at 1.
    starts: List[int] = []
    for child in ast.walk(node):
        if not isinstance(child, (ast.stmt, ast.ExceptHandler)):
            continue
        # Before Python 3.8, the lineno of a decorated class or function pointed at the decorator.
        # Since Python 3.8, the lineno points to the class/def, so need to include the decorators.
        if isinstance(child, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            starts.extend(deco.lineno - 1 for deco in child.decorator_list)
        starts.append(child.lineno - 1)
        for attr in ("finalbody", "orelse"):
            sub_block: Optional[List[ast.stmt]] = getattr(child, attr, None)
            if sub_block:
                # Treat the finally/orelse part as its own statement: record
                # the line just before its first statement.
                starts.append(sub_block[0].lineno - 1 - 1)
    starts.sort()

    pos = bisect_right(starts, lineno)
    start = starts[pos - 1]
    end = starts[pos] if pos < len(starts) else None
    return start, end
|
||
|
|
||
|
|
||
|
def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
    """Return ``(astnode, start, end)`` for the statement containing ``lineno``.

    ``source`` is parsed with ``ast.parse`` unless ``astnode`` is supplied.
    ``start`` and ``end`` index ``source.lines``. ``assertion`` is accepted
    but not used by this function.
    """
    if astnode is None:
        text = str(source)
        # See #4260:
        # Don't produce duplicate warnings when compiling source to find AST.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(text, "source", "exec")

    start, next_start = get_statement_startend2(lineno, astnode)
    # We need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    end = len(source.lines) if next_start is None else next_start

    if end > start + 1:
        # Make sure we don't span differently indented code blocks
        # by using the BlockFinder helper which inspect.getsource() uses itself.
        finder = inspect.BlockFinder()
        # If we start with an indented line, put blockfinder to "started" mode.
        finder.started = source.lines[start][0].isspace()
        pending = ((text_line + "\n") for text_line in source.lines[start:end])

        def _readline() -> str:
            return next(pending)

        try:
            for token in tokenize.generate_tokens(_readline):
                finder.tokeneater(*token)
        except (inspect.EndOfBlock, IndentationError):
            end = finder.last + start
        except Exception:
            # Any other tokenizing problem: keep the end computed so far.
            pass

    # The end might still point to a comment or empty line, correct it.
    while end:
        stripped = source.lines[end - 1].lstrip()
        if stripped and not stripped.startswith("#"):
            break
        end -= 1
    return astnode, start, end
|