diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 2c79f4f461a63480b5fbf3eda692bebd982aab23..31d8ea192269a9a9947457814ff5e58d63f61c14 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -1,4 +1,6 @@
 from __future__ import annotations
+
+from abc import ABC, abstractmethod
 from typing import Iterable, Sequence, cast
 from types import NoneType
 
@@ -9,10 +11,35 @@ from ..memory import PsSymbol
 from .util import failing_cast
 
 
-class PsBlock(PsAstNode):
+class PsStructuralNode(PsAstNode, ABC):
+    """Base class for structural nodes in the pystencils AST.
+
+    This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
+    """
+
+    def clone(self):
+        """Clone this structure node.
+
+        .. note::
+            Subclasses of `PsStructuralNode` should not override this method,
+            but implement `_clone_structural` instead.
+            That implementation shall call `clone` on any of its children.
+        """
+        return self._clone_structural()
+
+    @abstractmethod
+    def _clone_structural(self) -> PsStructuralNode:
+        """Implementation of structural node cloning.
+
+        :meta public:
+        """
+        pass
+
+
+class PsBlock(PsStructuralNode):
     __match_args__ = ("statements",)
 
-    def __init__(self, cs: Iterable[PsAstNode]):
+    def __init__(self, cs: Iterable[PsStructuralNode]):
         self._statements = list(cs)
 
     @property
@@ -21,23 +48,23 @@ class PsBlock(PsAstNode):
 
     @children.setter
     def children(self, cs: Sequence[PsAstNode]):
-        self._statements = list(cs)
+        self._statements = list([failing_cast(PsStructuralNode, c) for c in cs])
 
     def get_children(self) -> tuple[PsAstNode, ...]:
         return tuple(self._statements)
 
     def set_child(self, idx: int, c: PsAstNode):
-        self._statements[idx] = c
+        self._statements[idx] = failing_cast(PsStructuralNode, c)
 
-    def clone(self) -> PsBlock:
-        return PsBlock([stmt.clone() for stmt in self._statements])
+    def _clone_structural(self) -> PsBlock:
+        return PsBlock([stmt._clone_structural() for stmt in self._statements])
 
     @property
-    def statements(self) -> list[PsAstNode]:
+    def statements(self) -> list[PsStructuralNode]:
         return self._statements
 
     @statements.setter
-    def statements(self, stm: Sequence[PsAstNode]):
+    def statements(self, stm: Sequence[PsStructuralNode]):
         self._statements = list(stm)
 
     def __repr__(self) -> str:
@@ -45,7 +72,7 @@ class PsBlock(PsAstNode):
         return f"PsBlock( {contents} )"
 
 
-class PsStatement(PsAstNode):
+class PsStatement(PsStructuralNode):
     __match_args__ = ("expression",)
 
     def __init__(self, expr: PsExpression):
@@ -59,7 +86,7 @@ class PsStatement(PsAstNode):
     def expression(self, expr: PsExpression):
         self._expression = expr
 
-    def clone(self) -> PsStatement:
+    def _clone_structural(self) -> PsStatement:
         return PsStatement(self._expression.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -71,7 +98,7 @@ class PsStatement(PsAstNode):
         self._expression = failing_cast(PsExpression, c)
 
 
-class PsAssignment(PsAstNode):
+class PsAssignment(PsStructuralNode):
     __match_args__ = (
         "lhs",
         "rhs",
@@ -101,7 +128,7 @@ class PsAssignment(PsAstNode):
     def rhs(self, expr: PsExpression):
         self._rhs = expr
 
-    def clone(self) -> PsAssignment:
+    def _clone_structural(self) -> PsAssignment:
         return PsAssignment(self._lhs.clone(), self._rhs.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -141,7 +168,7 @@ class PsDeclaration(PsAssignment):
     def declared_symbol(self) -> PsSymbol:
         return cast(PsSymbolExpr, self._lhs).symbol
 
-    def clone(self) -> PsDeclaration:
+    def _clone_structural(self) -> PsDeclaration:
         return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())
 
     def set_child(self, idx: int, c: PsAstNode):
@@ -157,7 +184,7 @@ class PsDeclaration(PsAssignment):
         return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})"
 
 
-class PsLoop(PsAstNode):
+class PsLoop(PsStructuralNode):
     __match_args__ = ("counter", "start", "stop", "step", "body")
 
     def __init__(
@@ -214,13 +241,13 @@ class PsLoop(PsAstNode):
     def body(self, block: PsBlock):
         self._body = block
 
-    def clone(self) -> PsLoop:
+    def _clone_structural(self) -> PsLoop:
         return PsLoop(
             self._ctr.clone(),
             self._start.clone(),
             self._stop.clone(),
             self._step.clone(),
-            self._body.clone(),
+            self._body._clone_structural(),
         )
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -243,7 +270,7 @@ class PsLoop(PsAstNode):
                 assert False, "unreachable code"
 
 
-class PsConditional(PsAstNode):
+class PsConditional(PsStructuralNode):
     """Conditional branch"""
 
     __match_args__ = ("condition", "branch_true", "branch_false")
@@ -282,11 +309,11 @@ class PsConditional(PsAstNode):
     def branch_false(self, block: PsBlock | None):
         self._branch_false = block
 
-    def clone(self) -> PsConditional:
+    def _clone_structural(self) -> PsConditional:
         return PsConditional(
             self._condition.clone(),
-            self._branch_true.clone(),
-            self._branch_false.clone() if self._branch_false is not None else None,
+            self._branch_true._clone_structural(),
+            self._branch_false._clone_structural() if self._branch_false is not None else None,
         )
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -317,7 +344,7 @@ class PsEmptyLeafMixIn:
     pass
 
 
-class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
+class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     """A C/C++ preprocessor pragma.
 
     Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.
@@ -335,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
     def text(self) -> str:
         return self._text
 
-    def clone(self) -> PsPragma:
+    def _clone_structural(self) -> PsPragma:
         return PsPragma(self.text)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
@@ -345,7 +372,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
         return self._text == other._text
 
 
-class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
+class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     __match_args__ = ("lines",)
 
     def __init__(self, text: str) -> None:
@@ -360,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
     def lines(self) -> tuple[str, ...]:
         return self._lines
 
-    def clone(self) -> PsComment:
+    def _clone_structural(self) -> PsComment:
         return PsComment(self._text)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 4fd09f879dd8d98903753c8709543e0bcc3fd3e1..b3ff5aefb525ef311d0e3199c79f60c52617a853 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -26,6 +26,7 @@ from ..ast.structural import (
     PsDeclaration,
     PsExpression,
     PsSymbolExpr,
+    PsStructuralNode,
 )
 from ..ast.expressions import (
     PsBufferAcc,
@@ -107,7 +108,7 @@ class FreezeExpressions:
 
     def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode:
         if isinstance(obj, AssignmentCollection):
-            return PsBlock([self.visit(asm) for asm in obj.all_assignments])
+            return PsBlock([cast(PsStructuralNode, self.visit(asm)) for asm in obj.all_assignments])
         elif isinstance(obj, AssignmentBase):
             return cast(PsAssignment, self.visit(obj))
         elif isinstance(obj, _ExprLike):
diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index 0e6d314acf16a5c16b6cb988f61dfaf4ba8e36f8..bd782422f1fa80b96ec7cf69473fda2b1f45c3d6 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -6,7 +6,7 @@ from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsLoop, PsPragma
+from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralNode
 from ..ast.expressions import PsExpression
 
 
@@ -55,13 +55,12 @@ class InsertPragmasAtLoops:
             self._insertions[ins.loop_nesting_depth].append(ins)
 
     def __call__(self, node: PsAstNode) -> PsAstNode:
-        is_loop = isinstance(node, PsLoop)
-        if is_loop:
+        if isinstance(node, PsLoop):
             node = PsBlock([node])
 
         self.visit(node, Nesting(0))
 
-        if is_loop and len(node.children) == 1:
+        if isinstance(node, PsLoop) and len(node.children) == 1:
             node = node.children[0]
 
         return node
@@ -72,7 +71,7 @@ class InsertPragmasAtLoops:
                 return
 
             case PsBlock(children):
-                new_children: list[PsAstNode] = []
+                new_children: list[PsStructuralNode] = []
                 for c in children:
                     if isinstance(c, PsLoop):
                         nest.has_inner_loops = True
@@ -91,8 +90,8 @@ class InsertPragmasAtLoops:
                 node.children = new_children
 
             case other:
-                for c in other.children:
-                    self.visit(c, nest)
+                for child in other.children:
+                    self.visit(child, nest)
 
 
 class AddOpenMP:
diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py
index ab4401f9ca0142d9cfeec258eeb34fb2a7f6e8eb..c793c424d2417cbbdcc0cf3782e696c7c9226bb6 100644
--- a/src/pystencils/backend/transformations/ast_vectorizer.py
+++ b/src/pystencils/backend/transformations/ast_vectorizer.py
@@ -18,6 +18,7 @@ from ..ast.structural import (
     PsAssignment,
     PsLoop,
     PsEmptyLeafMixIn,
+    PsStructuralNode,
 )
 from ..ast.expressions import (
     PsExpression,
@@ -268,6 +269,18 @@ class AstVectorizer:
         """
         return self.visit(node, vc)
 
+    @overload
+    def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode:
+        pass
+
+    @overload
+    def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression:
+        pass
+
+    @overload
+    def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
+        pass
+
     def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
         """Vectorize a subtree."""
 
diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py
index f098d82df1ce6a748097756aa1616a72e57487b5..69dd1dd11d726e597c15ece772846ba8cd84acba 100644
--- a/src/pystencils/backend/transformations/eliminate_branches.py
+++ b/src/pystencils/backend/transformations/eliminate_branches.py
@@ -1,7 +1,9 @@
+from typing import cast
+
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
 from ..ast.analysis import collect_undefined_symbols
-from ..ast.structural import PsLoop, PsBlock, PsConditional
+from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralNode
 from ..ast.expressions import (
     PsAnd,
     PsCast,
@@ -71,9 +73,9 @@ class EliminateBranches:
                 ec.enclosing_loops.pop()
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
-                    statements_new.append(self.visit(stmt, ec))
+                    statements_new.append(cast(PsStructuralNode, self.visit(stmt, ec)))
                 node.statements = statements_new
 
             case PsConditional():
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index ab1cabc557a88b03766e6a9fb2ab44a84a5711da..3a07cb56fcb8f1c60107b5b1883c679191429e7e 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -6,7 +6,7 @@ import numpy as np
 from ..kernelcreation import KernelCreationContext, Typifier
 
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsDeclaration
+from ..ast.structural import PsBlock, PsDeclaration, PsStructuralNode
 from ..ast.expressions import (
     PsExpression,
     PsConstantExpr,
@@ -36,6 +36,7 @@ from ..ast.expressions import (
 )
 from ..ast.vector import PsVecBroadcast
 from ..ast.util import AstEqWrapper
+from ..exceptions import PsInternalCompilerError
 
 from ..constants import PsConstant
 from ..memory import PsSymbol
@@ -138,6 +139,11 @@ class EliminateConstants:
         node = self.visit(node, ecc)
 
         if ecc.extractions:
+            if not isinstance(node, PsStructuralNode):
+                raise PsInternalCompilerError(
+                    f"Cannot extract constant expressions from outermost node {node}"
+                )
+
             prepend_decls = [
                 PsDeclaration(PsExpression.make(symb), expr)
                 for symb, expr in ecc.extractions
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index f0e4cc9f19f1a046125bb3e8aab5302a9df2790c..f7fe81ad736981bee6f38427fbd4face73f0c455 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -2,7 +2,7 @@ from typing import cast
 
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment
+from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralNode
 from ..ast.expressions import (
     PsExpression,
     PsSymbolExpr,
@@ -99,7 +99,7 @@ class HoistLoopInvariantDeclarations:
                     return temp_block
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
                     if isinstance(stmt, PsLoop):
                         loop = stmt
@@ -153,7 +153,7 @@ class HoistLoopInvariantDeclarations:
                 return
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
                     if isinstance(stmt, PsLoop):
                         loop = stmt
@@ -178,7 +178,7 @@ class HoistLoopInvariantDeclarations:
         This method processes only statements of the given block, and any blocks directly nested inside it.
         It does not descend into control structures like conditionals and nested loops.
         """
-        statements_new: list[PsAstNode] = []
+        statements_new: list[PsStructuralNode] = []
 
         for node in block.statements:
             if isinstance(node, PsDeclaration):
diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py
index 59241c295f42eeaf60f4cd03a5138214fdbd6c50..8dff9e45ec283fc6c3712c2e77ff56a9b2aaeae5 100644
--- a/src/pystencils/backend/transformations/rewrite.py
+++ b/src/pystencils/backend/transformations/rewrite.py
@@ -2,7 +2,7 @@ from typing import overload
 
 from ..memory import PsSymbol
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock
+from ..ast.structural import PsStructuralNode, PsBlock
 from ..ast.expressions import PsExpression, PsSymbolExpr
 
 
@@ -18,6 +18,13 @@ def substitute_symbols(
     pass
 
 
+@overload
+def substitute_symbols(
+    node: PsStructuralNode, subs: dict[PsSymbol, PsExpression]
+) -> PsStructuralNode:
+    pass
+
+
 @overload
 def substitute_symbols(
     node: PsAstNode, subs: dict[PsSymbol, PsExpression]