diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 2c79f4f461a63480b5fbf3eda692bebd982aab23..31d8ea192269a9a9947457814ff5e58d63f61c14 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -1,4 +1,6 @@ from __future__ import annotations + +from abc import ABC, abstractmethod from typing import Iterable, Sequence, cast from types import NoneType @@ -9,10 +11,35 @@ from ..memory import PsSymbol from .util import failing_cast -class PsBlock(PsAstNode): +class PsStructuralNode(PsAstNode, ABC): + """Base class for structural nodes in the pystencils AST. + + This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. + """ + + def clone(self): + """Clone this structure node. + + .. note:: + Subclasses of `PsStructuralNode` should not override this method, + but implement `_clone_structural` instead. + That implementation shall call `clone` on any of its children. + """ + return self._clone_structural() + + @abstractmethod + def _clone_structural(self) -> PsStructuralNode: + """Implementation of structural node cloning. + + :meta public: + """ + pass + + +class PsBlock(PsStructuralNode): __match_args__ = ("statements",) - def __init__(self, cs: Iterable[PsAstNode]): + def __init__(self, cs: Iterable[PsStructuralNode]): self._statements = list(cs) @property @@ -21,23 +48,23 @@ class PsBlock(PsAstNode): @children.setter def children(self, cs: Sequence[PsAstNode]): - self._statements = list(cs) + self._statements = list([failing_cast(PsStructuralNode, c) for c in cs]) def get_children(self) -> tuple[PsAstNode, ...]: return tuple(self._statements) def set_child(self, idx: int, c: PsAstNode): - self._statements[idx] = c + self._statements[idx] = failing_cast(PsStructuralNode, c) - def clone(self) -> PsBlock: - return PsBlock([stmt.clone() for stmt in self._statements]) + def _clone_structural(self) -> PsBlock: + return PsBlock([stmt._clone_structural() for stmt in self._statements]) @property - def statements(self) -> list[PsAstNode]: + def statements(self) -> list[PsStructuralNode]: return self._statements @statements.setter - def statements(self, stm: Sequence[PsAstNode]): + def statements(self, stm: Sequence[PsStructuralNode]): self._statements = list(stm) def __repr__(self) -> str: @@ -45,7 +72,7 @@ class PsBlock(PsAstNode): return f"PsBlock( {contents} )" -class PsStatement(PsAstNode): +class PsStatement(PsStructuralNode): __match_args__ = ("expression",) def __init__(self, expr: PsExpression): @@ -59,7 +86,7 @@ class PsStatement(PsAstNode): def expression(self, expr: PsExpression): self._expression = expr - def clone(self) -> PsStatement: + def _clone_structural(self) -> PsStatement: return PsStatement(self._expression.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -71,7 +98,7 @@ class PsStatement(PsAstNode): self._expression = failing_cast(PsExpression, c) -class PsAssignment(PsAstNode): +class PsAssignment(PsStructuralNode): __match_args__ = ( "lhs", "rhs", @@ -101,7 +128,7 @@ class PsAssignment(PsAstNode): def rhs(self, expr: PsExpression): self._rhs = expr - def clone(self) -> PsAssignment: + def _clone_structural(self) -> PsAssignment: return PsAssignment(self._lhs.clone(), self._rhs.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -141,7 +168,7 @@ class PsDeclaration(PsAssignment): def declared_symbol(self) -> PsSymbol: return cast(PsSymbolExpr, self._lhs).symbol - def clone(self) -> PsDeclaration: + def _clone_structural(self) -> PsDeclaration: return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone()) def set_child(self, idx: int, c: PsAstNode): @@ -157,7 +184,7 @@ class PsDeclaration(PsAssignment): return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})" -class PsLoop(PsAstNode): +class PsLoop(PsStructuralNode): __match_args__ = ("counter", "start", "stop", "step", "body") def __init__( @@ -214,13 +241,13 @@ class PsLoop(PsAstNode): def body(self, block: PsBlock): self._body = block - def clone(self) -> PsLoop: + def _clone_structural(self) -> PsLoop: return PsLoop( self._ctr.clone(), self._start.clone(), self._stop.clone(), self._step.clone(), - self._body.clone(), + self._body._clone_structural(), ) def get_children(self) -> tuple[PsAstNode, ...]: @@ -243,7 +270,7 @@ class PsLoop(PsAstNode): assert False, "unreachable code" -class PsConditional(PsAstNode): +class PsConditional(PsStructuralNode): """Conditional branch""" __match_args__ = ("condition", "branch_true", "branch_false") @@ -282,11 +309,11 @@ class PsConditional(PsAstNode): def branch_false(self, block: PsBlock | None): self._branch_false = block - def clone(self) -> PsConditional: + def _clone_structural(self) -> PsConditional: return PsConditional( self._condition.clone(), - self._branch_true.clone(), - self._branch_false.clone() if self._branch_false is not None else None, + self._branch_true._clone_structural(), + self._branch_false._clone_structural() if self._branch_false is not None else None, ) def get_children(self) -> tuple[PsAstNode, ...]: @@ -317,7 +344,7 @@ class PsEmptyLeafMixIn: pass -class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): +class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): """A C/C++ preprocessor pragma. Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. @@ -335,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): def text(self) -> str: return self._text - def clone(self) -> PsPragma: + def _clone_structural(self) -> PsPragma: return PsPragma(self.text) def structurally_equal(self, other: PsAstNode) -> bool: @@ -345,7 +372,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): return self._text == other._text -class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): +class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): __match_args__ = ("lines",) def __init__(self, text: str) -> None: @@ -360,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): def lines(self) -> tuple[str, ...]: return self._lines - def clone(self) -> PsComment: + def _clone_structural(self) -> PsComment: return PsComment(self._text) def structurally_equal(self, other: PsAstNode) -> bool: diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 4fd09f879dd8d98903753c8709543e0bcc3fd3e1..b3ff5aefb525ef311d0e3199c79f60c52617a853 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -26,6 +26,7 @@ from ..ast.structural import ( PsDeclaration, PsExpression, PsSymbolExpr, + PsStructuralNode, ) from ..ast.expressions import ( PsBufferAcc, @@ -107,7 +108,7 @@ class FreezeExpressions: def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode: if isinstance(obj, AssignmentCollection): - return PsBlock([self.visit(asm) for asm in obj.all_assignments]) + return PsBlock([cast(PsStructuralNode, self.visit(asm)) for asm in obj.all_assignments]) elif isinstance(obj, AssignmentBase): return cast(PsAssignment, self.visit(obj)) elif isinstance(obj, _ExprLike): diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 0e6d314acf16a5c16b6cb988f61dfaf4ba8e36f8..bd782422f1fa80b96ec7cf69473fda2b1f45c3d6 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -6,7 +6,7 @@ from collections import defaultdict from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsLoop, PsPragma +from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralNode from ..ast.expressions import PsExpression @@ -55,13 +55,12 @@ class InsertPragmasAtLoops: self._insertions[ins.loop_nesting_depth].append(ins) def __call__(self, node: PsAstNode) -> PsAstNode: - is_loop = isinstance(node, PsLoop) - if is_loop: + if isinstance(node, PsLoop): node = PsBlock([node]) self.visit(node, Nesting(0)) - if is_loop and len(node.children) == 1: + if isinstance(node, PsLoop) and len(node.children) == 1: node = node.children[0] return node @@ -72,7 +71,7 @@ class InsertPragmasAtLoops: return case PsBlock(children): - new_children: list[PsAstNode] = [] + new_children: list[PsStructuralNode] = [] for c in children: if isinstance(c, PsLoop): nest.has_inner_loops = True @@ -91,8 +90,8 @@ class InsertPragmasAtLoops: node.children = new_children case other: - for c in other.children: - self.visit(c, nest) + for child in other.children: + self.visit(child, nest) class AddOpenMP: diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py index ab4401f9ca0142d9cfeec258eeb34fb2a7f6e8eb..c793c424d2417cbbdcc0cf3782e696c7c9226bb6 100644 --- a/src/pystencils/backend/transformations/ast_vectorizer.py +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -18,6 +18,7 @@ from ..ast.structural import ( PsAssignment, PsLoop, PsEmptyLeafMixIn, + PsStructuralNode, ) from ..ast.expressions import ( PsExpression, @@ -268,6 +269,18 @@ class AstVectorizer: """ return self.visit(node, vc) + @overload + def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode: + pass + + @overload + def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression: + pass + + @overload + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: + pass + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: """Vectorize a subtree.""" diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index f098d82df1ce6a748097756aa1616a72e57487b5..69dd1dd11d726e597c15ece772846ba8cd84acba 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -1,7 +1,9 @@ +from typing import cast + from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode from ..ast.analysis import collect_undefined_symbols -from ..ast.structural import PsLoop, PsBlock, PsConditional +from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralNode from ..ast.expressions import ( PsAnd, PsCast, @@ -71,9 +73,9 @@ class EliminateBranches: ec.enclosing_loops.pop() case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: - statements_new.append(self.visit(stmt, ec)) + statements_new.append(cast(PsStructuralNode, self.visit(stmt, ec))) node.statements = statements_new case PsConditional(): diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index ab1cabc557a88b03766e6a9fb2ab44a84a5711da..3a07cb56fcb8f1c60107b5b1883c679191429e7e 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -6,7 +6,7 @@ import numpy as np from ..kernelcreation import KernelCreationContext, Typifier from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsDeclaration +from ..ast.structural import PsBlock, PsDeclaration, PsStructuralNode from ..ast.expressions import ( PsExpression, PsConstantExpr, @@ -36,6 +36,7 @@ from ..ast.expressions import ( ) from ..ast.vector import PsVecBroadcast from ..ast.util import AstEqWrapper +from ..exceptions import PsInternalCompilerError from ..constants import PsConstant from ..memory import PsSymbol @@ -138,6 +139,11 @@ class EliminateConstants: node = self.visit(node, ecc) if ecc.extractions: + if not isinstance(node, PsStructuralNode): + raise PsInternalCompilerError( + f"Cannot extract constant expressions from outermost node {node}" + ) + prepend_decls = [ PsDeclaration(PsExpression.make(symb), expr) for symb, expr in ecc.extractions diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index f0e4cc9f19f1a046125bb3e8aab5302a9df2790c..f7fe81ad736981bee6f38427fbd4face73f0c455 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -2,7 +2,7 @@ from typing import cast from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment +from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralNode from ..ast.expressions import ( PsExpression, PsSymbolExpr, @@ -99,7 +99,7 @@ class HoistLoopInvariantDeclarations: return temp_block case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: if isinstance(stmt, PsLoop): loop = stmt @@ -153,7 +153,7 @@ class HoistLoopInvariantDeclarations: return case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: if isinstance(stmt, PsLoop): loop = stmt @@ -178,7 +178,7 @@ class HoistLoopInvariantDeclarations: This method processes only statements of the given block, and any blocks directly nested inside it. It does not descend into control structures like conditionals and nested loops. """ - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralNode] = [] for node in block.statements: if isinstance(node, PsDeclaration): diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py index 59241c295f42eeaf60f4cd03a5138214fdbd6c50..8dff9e45ec283fc6c3712c2e77ff56a9b2aaeae5 100644 --- a/src/pystencils/backend/transformations/rewrite.py +++ b/src/pystencils/backend/transformations/rewrite.py @@ -2,7 +2,7 @@ from typing import overload from ..memory import PsSymbol from ..ast import PsAstNode -from ..ast.structural import PsBlock +from ..ast.structural import PsStructuralNode, PsBlock from ..ast.expressions import PsExpression, PsSymbolExpr @@ -18,6 +18,13 @@ def substitute_symbols( pass +@overload +def substitute_symbols( + node: PsStructuralNode, subs: dict[PsSymbol, PsExpression] +) -> PsStructuralNode: + pass + + @overload def substitute_symbols( node: PsAstNode, subs: dict[PsSymbol, PsExpression]