forked from bton/matekasse
203 lines
7.6 KiB
Python
203 lines
7.6 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import re
|
||
|
import typing as t
|
||
|
from dataclasses import dataclass
|
||
|
from dataclasses import field
|
||
|
|
||
|
from .converters import ValidationError
|
||
|
from .exceptions import NoMatch
|
||
|
from .exceptions import RequestAliasRedirect
|
||
|
from .exceptions import RequestPath
|
||
|
from .rules import Rule
|
||
|
from .rules import RulePart
|
||
|
|
||
|
|
||
|
class SlashRequired(Exception):
    """Signal that a path matches only once a trailing slash is appended.

    Raised from inside the recursive matcher when a ``strict_slashes`` rule
    would match the path with an extra empty (trailing-slash) part; the
    public ``StateMachineMatcher.match`` converts it into a ``RequestPath``
    carrying the path with ``"/"`` appended.
    """
|
||
|
|
||
|
|
||
|
@dataclass
class State:
    """One node of the matcher's state tree.

    A state holds the rules whose parts end at it, together with the
    transitions leading onward. ``static`` maps an exact part string to the
    next state. ``dynamic`` lists ``(RulePart, State)`` pairs whose part
    content is tried as a regular expression against the incoming part.
    """

    # Pattern-based transitions, tried in list order while matching.
    dynamic: list[tuple[RulePart, State]] = field(default_factory=list)
    # Rules whose final part leads to this state.
    rules: list[Rule] = field(default_factory=list)
    # Exact-text transitions keyed by the part's content.
    static: dict[str, State] = field(default_factory=dict)
|
||
|
|
||
|
|
||
|
class StateMachineMatcher:
    """Match requests against rules using a tree of :class:`State` nodes.

    Rules are inserted with :meth:`add`, which walks (and extends) the tree
    one rule part at a time. :meth:`update` sorts dynamic transitions by
    weight, which fixes the order :meth:`match` tries them in; it is
    presumably meant to be called once all rules are added.
    """

    def __init__(self, merge_slashes: bool) -> None:
        """Create an empty matcher.

        :param merge_slashes: If true, a path that fails to match is retried
            with runs of consecutive slashes collapsed to one, and a match on
            the collapsed path is reported as a ``RequestPath`` redirect.
        """
        self._root = State()
        self.merge_slashes = merge_slashes

    def add(self, rule: Rule) -> None:
        """Insert *rule* into the state tree.

        Static parts follow (or create) the ``State.static`` entry keyed by
        the part's content. Dynamic parts reuse an existing transition whose
        part compares equal, otherwise a new transition is appended. The rule
        is attached to the state reached after its last part.
        """
        state = self._root
        for part in rule._parts:
            if part.static:
                state = state.static.setdefault(part.content, State())
            else:
                for test_part, new_state in state.dynamic:
                    if test_part == part:
                        state = new_state
                        break
                else:
                    new_state = State()
                    state.dynamic.append((part, new_state))
                    state = new_state
        state.rules.append(rule)

    def update(self) -> None:
        """Sort every state's dynamic transitions by ascending part weight.

        :meth:`match` tries dynamic transitions in list order, so this sets
        their precedence. Static transitions are dict lookups and unaffected.
        """

        def _update_state(state: State) -> None:
            state.dynamic.sort(key=lambda entry: entry[0].weight)
            for new_state in state.static.values():
                _update_state(new_state)
            for _, new_state in state.dynamic:
                _update_state(new_state)

        _update_state(self._root)

    @staticmethod
    def _collapse_slashes(path: str) -> str:
        """Return *path* with every run of two or more slashes replaced by one."""
        # Greedy on purpose: a lazy ``/{2,}?`` consumes exactly two slashes
        # per substitution, turning ``"///"`` into ``"//"`` instead of ``"/"``.
        return re.sub("/{2,}", "/", path)

    @staticmethod
    def _converter_values(match: re.Match[str]) -> list[str]:
        """Return the ``__werkzeug_<n>`` group values of *match*, ordered by n.

        Named groups without the ``__werkzeug_`` prefix are ignored.
        """
        named = [
            (key, value)
            for key, value in match.groupdict().items()
            if key[:11] == "__werkzeug_"
        ]
        # Every kept key shares the same prefix, so ordering by (length, key)
        # is the numeric order of the suffix; a plain string sort would put
        # "__werkzeug_10" before "__werkzeug_2".
        named.sort(key=lambda entry: (len(entry[0]), entry[0]))
        return [value for _, value in named]

    def match(
        self, domain: str, path: str, method: str, websocket: bool
    ) -> tuple[Rule, t.MutableMapping[str, t.Any]]:
        """Find the rule matching a request and convert its values.

        Matching starts at the root state with the parts
        ``[domain, *path.split("/")]``. It tries static transitions before
        dynamic ones and backtracks when a branch fails.

        :param domain: Consumed as the first part.
        :param path: Split on ``"/"`` into the remaining parts.
        :param method: Must be in ``rule.methods`` unless that is ``None``.
        :param websocket: Must equal ``rule.websocket``.
        :return: The matched rule and a dict of converted values, updated
            with ``rule.defaults`` when it is set.
        :raises RequestPath: If the path matches a strict-slashes rule only
            with a trailing slash added, or (with ``merge_slashes``) only
            after repeated slashes are collapsed.
        :raises RequestAliasRedirect: If the rule is an alias and its map has
            ``redirect_defaults`` enabled.
        :raises NoMatch: If nothing matches or a converter raises
            ``ValidationError``. Carries the methods of rules rejected only by
            method, and whether a websocket mismatch was seen.
        """
        have_match_for: set[str] = set()
        websocket_mismatch = False

        def _match(
            state: State, parts: list[str], values: list[str]
        ) -> tuple[Rule, list[str]] | None:
            # Match the head of *parts* against the transitions out of
            # *state*. Returns the rule plus raw (unconverted) converter
            # values, or None so the caller can try its next transition.
            nonlocal websocket_mismatch

            # Base case: all parts consumed. Return the first rule here whose
            # methods and websocket flag fit, recording why others failed.
            if not parts:
                for rule in state.rules:
                    if rule.methods is not None and method not in rule.methods:
                        have_match_for.update(rule.methods)
                    elif rule.websocket != websocket:
                        websocket_mismatch = True
                    else:
                        return rule, values

                # The path may match with one more, empty, part (a trailing
                # slash). Strict-slashes rules demand a redirect to that
                # path; other rules match as-is.
                if "" in state.static:
                    for rule in state.static[""].rules:
                        if websocket == rule.websocket and (
                            rule.methods is None or method in rule.methods
                        ):
                            if rule.strict_slashes:
                                raise SlashRequired()
                            return rule, values
                return None

            part = parts[0]
            # Static transitions take priority over dynamic ones.
            if part in state.static:
                rv = _match(state.static[part], parts[1:], values)
                if rv is not None:
                    return rv

            for test_part, new_state in state.dynamic:
                target = part
                remaining = parts[1:]
                # A final part consumes all remaining parts (rejoined with
                # "/"), transitioning to a final state.
                if test_part.final:
                    target = "/".join(parts)
                    remaining = []
                found = re.compile(test_part.content).match(target)
                if found is None:
                    continue
                if test_part.suffixed:
                    # A part_isolating=False part may capture a trailing
                    # slash in its last group; turn it into an empty part so
                    # the trailing-slash handling above applies.
                    suffix = found.groups()[-1]
                    if suffix == "/":
                        remaining = [""]
                rv = _match(
                    new_state, remaining, values + self._converter_values(found)
                )
                if rv is not None:
                    return rv

            # Nothing matched and only a trailing slash ("") remains: rules
            # without strict_slashes match here as well.
            if parts == [""]:
                for rule in state.rules:
                    if rule.strict_slashes:
                        continue
                    if rule.methods is not None and method not in rule.methods:
                        have_match_for.update(rule.methods)
                    elif rule.websocket != websocket:
                        websocket_mismatch = True
                    else:
                        return rule, values

            return None

        try:
            rv = _match(self._root, [domain, *path.split("/")], [])
        except SlashRequired:
            raise RequestPath(f"{path}/") from None

        if rv is None and self.merge_slashes:
            # Retry with repeated slashes collapsed. A hit is reported as a
            # redirect to the collapsed path rather than returned. If nothing
            # collapsed, the retry would repeat the failed walk, so skip it.
            merged = self._collapse_slashes(path)
            if merged != path:
                try:
                    rv = _match(self._root, [domain, *merged.split("/")], [])
                except SlashRequired:
                    raise RequestPath(f"{merged}/") from None
                if rv is not None:
                    raise RequestPath(merged)

        if rv is None:
            raise NoMatch(have_match_for, websocket_mismatch)

        rule, values = rv

        result: dict[str, t.Any] = {}
        for name, value in zip(rule._converters.keys(), values):
            try:
                value = rule._converters[name].to_python(value)
            except ValidationError:
                raise NoMatch(have_match_for, websocket_mismatch) from None
            result[str(name)] = value
        if rule.defaults:
            result.update(rule.defaults)

        if rule.alias and rule.map.redirect_defaults:
            raise RequestAliasRedirect(result, rule.endpoint)

        return rule, result
|