Source code for drepr.utils.udf

from __future__ import annotations

import ast
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class SourceTree:
    """A piece of source text (``node``) and the trees nested under it (``children``)."""

    node: str
    children: list[SourceTree]

    def get_simplified_dict(self):
        """Convert this tree into nested built-in containers.

        A child without children of its own is represented by its ``node``
        string; a child with children is converted recursively.

        Returns:
            The list of converted children when ``node`` is the empty string,
            otherwise a dict with the keys ``"node"`` and ``"children"``.
        """
        simplified = []
        for child in self.children:
            if child.children:
                simplified.append(child.get_simplified_dict())
            else:
                simplified.append(child.node)

        if self.node != "":
            return {"node": self.node, "children": simplified}
        return simplified
@dataclass
class UDFParsedResult:
    """Outcome of parsing the source code of a UDF.

    Attributes:
        imports: The import statements.
        monitor_variables: Set of variables that we opt to monitor during
            parsing (e.g., we want to detect a variable named context).
        source_tree: The source code of the UDF -- without import statements.
    """

    imports: list[str]
    monitor_variables: set[str]
    source_tree: SourceTree
class UDFParser:
    """Parse the source code of a UDF into import statements and a :class:`SourceTree`.

    The code is first dedented (see :meth:`remove_prefix_spaces`) and parsed with
    :func:`ast.parse`. :meth:`parse` then walks the statements and rebuilds their
    text from the original source lines.
    """

    # Statements without a nested block: each becomes a single leaf node whose
    # text is copied verbatim from the source.
    _SIMPLE_STATEMENTS = (
        ast.Expr,
        ast.Return,
        ast.Yield,
        ast.YieldFrom,
        ast.Assign,
        ast.Assert,
        ast.AugAssign,
        ast.AnnAssign,
        ast.Raise,
        ast.Pass,
        ast.Delete,
        ast.Global,
        ast.Nonlocal,
    )

    def __init__(self, source_code: str):
        """Dedent and parse ``source_code``.

        Raises:
            ValueError: if the code is empty/blank or inconsistently indented.
            SyntaxError: if the dedented code is not valid Python.
        """
        if source_code.strip() == "":
            raise ValueError("Cannot parse an empty code")

        self.source_code = self.remove_prefix_spaces(source_code)
        self.source_code_lines = self.source_code.split("\n")
        self.source_tree = ast.parse(self.source_code)

    def parse(self, monitor_vars: Optional[list[str]] = None) -> UDFParsedResult:
        """Split the code into import statements and a tree of the other statements.

        Args:
            monitor_vars: names to look for; every one of them that occurs as a
                name expression (``ast.Name``) anywhere in the code is reported.

        Returns:
            The imports (in source order), the subset of ``monitor_vars`` that
            was found, and the source tree rooted at a node with empty text.

        Raises:
            NotImplementedError: if the code contains an unsupported statement.
        """
        imports: list[str] = []
        tree = SourceTree("", self._parse_body(self.source_tree.body, imports))

        found_vars: set[str] = set()
        if monitor_vars:
            wanted = set(monitor_vars)
            for node in ast.walk(self.source_tree):
                if isinstance(node, ast.Name) and node.id in wanted:
                    found_vars.add(node.id)

        return UDFParsedResult(
            imports=imports,
            monitor_variables=found_vars,
            source_tree=tree,
        )

    def _parse_body(self, stmts: list[ast.stmt], imports: list[str]) -> list[SourceTree]:
        """Convert a sequence of statements, concatenating their trees in order."""
        out: list[SourceTree] = []
        for stmt in stmts:
            out.extend(self._parse_ast(stmt, imports))
        return out

    def _parse_block_with_else(
        self, header: str, tree: ast.AST, imports: list[str]
    ) -> list[SourceTree]:
        """Convert a compound statement that has a ``body`` and an optional ``else``."""
        out = [SourceTree(header, self._parse_body(tree.body, imports))]
        if tree.orelse:
            # an ``elif`` is stored as a lone ``If`` inside ``orelse`` and is
            # therefore rendered as ``else:`` with a nested ``if``
            out.append(SourceTree("else:", self._parse_body(tree.orelse, imports)))
        return out

    def _parse_ast(self, tree: ast.AST, imports: list[str]) -> list[SourceTree]:
        """Convert one statement into zero or more sibling trees.

        Import statements are appended to ``imports`` and produce no tree.
        Compound statements produce one tree per clause (e.g. ``if`` and ``else``).

        Raises:
            NotImplementedError: for statement types that are not supported.
        """
        if isinstance(tree, (ast.Import, ast.ImportFrom)):
            imports.append(self._get_node_code(tree))
            return []

        if isinstance(tree, self._SIMPLE_STATEMENTS):
            return [SourceTree(self._get_node_code(tree), [])]

        if isinstance(tree, ast.Continue):
            return [SourceTree("continue", [])]

        if isinstance(tree, ast.Break):
            return [SourceTree("break", [])]

        if isinstance(tree, (ast.If, ast.While)):
            keyword = "if" if isinstance(tree, ast.If) else "while"
            header = f"{keyword} {self._get_node_code(tree.test)}:"
            return self._parse_block_with_else(header, tree, imports)

        if isinstance(tree, ast.For):
            header = f"for {self._get_node_code(tree.target)} in {self._get_node_code(tree.iter)}:"
            return self._parse_block_with_else(header, tree, imports)

        if isinstance(tree, ast.Try):
            out = [SourceTree("try:", self._parse_body(tree.body, imports))]
            for handler in tree.handlers:
                out.append(
                    SourceTree(
                        self._get_except_header(handler),
                        self._parse_body(handler.body, imports),
                    )
                )
            if tree.orelse:
                out.append(SourceTree("else:", self._parse_body(tree.orelse, imports)))
            if tree.finalbody:
                out.append(SourceTree("finally:", self._parse_body(tree.finalbody, imports)))
            return out

        if isinstance(tree, ast.FunctionDef):
            fnargs = self._get_arguments_code(tree.args)
            if tree.returns is not None:
                returns = f" -> {self._get_node_code(tree.returns)}"
            else:
                returns = ""
            header = f"def {tree.name}({fnargs}){returns}:"
            return [SourceTree(header, self._parse_body(tree.body, imports))]

        raise NotImplementedError(type(tree))

    def _get_except_header(self, handler: ast.ExceptHandler) -> str:
        """Return the ``except ...:`` line of an exception handler."""
        header = "except"
        if handler.type is not None:
            header += " " + self._get_node_code(handler.type)
            if handler.name is not None:
                # the bound name must be introduced by ``as``
                header += f" as {handler.name}"
        return header + ":"

    def _get_arguments_code(self, args: ast.arguments) -> str:
        """Return the parameter list of a function definition as source text.

        Covers positional-only, regular, ``*args``, keyword-only and ``**kwargs``
        parameters together with their annotations and default values.
        """

        def render(arg: ast.arg, default: Optional[ast.expr]) -> str:
            code = self._get_node_code(arg)
            if default is None:
                return code
            sep = " = " if arg.annotation is not None else "="
            return f"{code}{sep}{self._get_node_code(default)}"

        posonly = list(getattr(args, "posonlyargs", []))
        positional = posonly + list(args.args)
        # ``defaults`` belongs to the *last* len(defaults) positional parameters
        defaults = [None] * (len(positional) - len(args.defaults)) + list(args.defaults)

        parts: list[str] = []
        for i, (arg, default) in enumerate(zip(positional, defaults), start=1):
            parts.append(render(arg, default))
            if i == len(posonly):
                parts.append("/")

        if args.vararg is not None:
            parts.append("*" + self._get_node_code(args.vararg))
        elif args.kwonlyargs:
            # a bare ``*`` separates keyword-only parameters when there is no ``*args``
            parts.append("*")

        # ``kw_defaults`` has one entry per keyword-only parameter (None = required)
        parts.extend(
            render(arg, default)
            for arg, default in zip(args.kwonlyargs, args.kw_defaults)
        )

        if args.kwarg is not None:
            parts.append("**" + self._get_node_code(args.kwarg))

        return ", ".join(parts)

    def _get_node_code(self, node: ast.AST) -> str:
        """Return the source text that ``node`` spans (may cover several lines)."""
        # ``col_offset``/``end_col_offset`` are UTF-8 byte offsets, so the slicing
        # has to be done on the encoded lines to be correct for non-ASCII code.
        lines = [
            line.encode("utf-8")
            for line in self.source_code_lines[node.lineno - 1 : node.end_lineno]
        ]
        # trim the end first so that a single-line node is cut on both sides
        lines[-1] = lines[-1][: node.end_col_offset]
        lines[0] = lines[0][node.col_offset :]
        return b"\n".join(lines).decode("utf-8")

    def remove_prefix_spaces(self, code: str) -> str:
        """Strip the indentation of the first non-empty line from every line.

        Trailing whitespace of each line and leading empty lines are removed as well.

        Raises:
            ValueError: if the code has no non-empty line, or if a non-blank line
                does not start with the indentation of the first line.
        """
        lines = [x.rstrip() for x in code.splitlines()]
        first = next((i for i, line in enumerate(lines) if line != ""), None)
        if first is None:
            raise ValueError("Cannot remove prefix spaces from an empty code")
        lines = lines[first:]

        m = re.match(r"^([ \t]*)", lines[0])
        assert m is not None  # the pattern also matches the empty prefix
        indentation = m.group(1)

        if not all(x.startswith(indentation) or x.strip() == "" for x in lines):
            raise ValueError(
                f"The code has inconsistent prefix spaces. The first line has {indentation} spaces, but the following lines do not have the same prefix spaces"
            )
        return "\n".join(x[len(indentation) :] for x in lines)