From 9aba6a1ea5323eadcc97656b479bc03086571efb Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 6 Feb 2025 15:56:26 +0100 Subject: [PATCH 01/18] refactor modelling of code entities and elements --- src/pystencilssfg/composer/basic_composer.py | 22 +- src/pystencilssfg/ir/postprocessing.py | 21 +- src/pystencilssfg/ir/source_components.py | 596 +++++++++++-------- src/pystencilssfg/lang/__init__.py | 2 + src/pystencilssfg/lang/expressions.py | 18 + 5 files changed, 392 insertions(+), 267 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index b96d559..db671b9 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod import numpy as np import sympy as sp from functools import reduce +from warnings import warn from pystencils import Field from pystencils.codegen import Kernel @@ -252,7 +253,13 @@ class SfgBasicComposer(SfgIComposer): func = SfgFunction(name, tree) self._ctx.add_function(func) - def function(self, name: str, return_type: UserTypeSpec = void): + def function( + self, + name: str, + returns: UserTypeSpec = void, + inline: bool = False, + return_type: UserTypeSpec | None = None, + ): """Add a function. The syntax of this function adder uses a chain of two calls to mimic C++ syntax: @@ -265,12 +272,23 @@ class SfgBasicComposer(SfgIComposer): The function body is constructed via sequencing (see `make_sequence`). """ + if return_type is not None: + warn( + "The parameter `return_type` to `function()` is deprecated and will be removed by version 0.1. " + "Setting it will override the value of the `returns` parameter. " + "Use `returns` instead.", + FutureWarning, + ) + returns = return_type + if self._ctx.get_function(name) is not None: raise ValueError(f"Function {name} already exists.") def sequencer(*args: SequencerArg): tree = make_sequence(*args) - func = SfgFunction(name, tree, return_type=create_type(return_type)) + func = SfgFunction( + name, tree, return_type=create_type(returns), inline=inline + ) self._ctx.add_function(func) return sequencer diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index aa3cd27..db26a38 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -29,7 +29,6 @@ from ..lang import ( if TYPE_CHECKING: from ..context import SfgContext - from .source_components import SfgClass class FlattenSequences: @@ -65,19 +64,9 @@ class FlattenSequences: class PostProcessingContext: - def __init__(self, enclosing_class: SfgClass | None = None) -> None: - self.enclosing_class: SfgClass | None = enclosing_class + def __init__(self) -> None: self._live_variables: dict[str, SfgVar] = dict() - def is_method(self) -> bool: - return self.enclosing_class is not None - - def get_enclosing_class(self) -> SfgClass: - if self.enclosing_class is None: - raise SfgException("Cannot get the enclosing class of a free function.") - - return self.enclosing_class - @property def live_variables(self) -> set[SfgVar]: return set(self._live_variables.values()) @@ -144,8 +133,7 @@ class PostProcessingResult: class CallTreePostProcessing: - def __init__(self, enclosing_class: SfgClass | None = None): - self._enclosing_class = enclosing_class + def __init__(self): self._flattener = FlattenSequences() def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult: @@ -174,7 +162,7 @@ class CallTreePostProcessing: def get_live_variables(self, node: SfgCallTreeNode) -> set[SfgVar]: match node: case SfgSequence(): - ppc = self._ppc() + ppc = PostProcessingContext() self.handle_sequence(node, ppc) return ppc.live_variables @@ -191,9 +179,6 @@ class CallTreePostProcessing: set(), ) - def _ppc(self) -> PostProcessingContext: - return PostProcessingContext(enclosing_class=self._enclosing_class) - class SfgDeferredNode(SfgCallTreeNode, ABC): """Nodes of this type are inserted as placeholders into the kernel call tree diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index ea43ac8..07ee848 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -2,149 +2,74 @@ from __future__ import annotations from abc import ABC from enum import Enum, auto -from typing import TYPE_CHECKING, Sequence, Generator, TypeVar +from typing import ( + TYPE_CHECKING, + Sequence, + Generator, + Iterable, + TypeVar, + Generic, +) from dataclasses import replace from itertools import chain from pystencils import CreateKernelConfig, create_kernel, Field -from pystencils.codegen import Kernel, Parameter +from pystencils.codegen import Kernel from pystencils.types import PsType, PsCustomType -from ..lang import SfgVar, HeaderFile, void +from ..lang import SfgVar, SfgKernelParamVar, HeaderFile, void from ..exceptions import SfgException if TYPE_CHECKING: from . import SfgCallTreeNode - from ..context import SfgContext -class SfgEmptyLines: - def __init__(self, lines: int): - self._lines = lines +# ========================================================================================================= +# +# SEMANTICAL ENTITIES +# +# These classes model *code entities*, which represent *semantic components* of the generated files. +# +# ========================================================================================================= - @property - def lines(self) -> int: - return self._lines - - -class SfgHeaderInclude: - """Represent ``#include``-directives.""" - - def __init__( - self, header_file: HeaderFile, private: bool = False - ): - self._header_file = header_file - self._private = private - - @property - def file(self) -> str: - return self._header_file.filepath - - @property - def system_header(self): - return self._header_file.system_header - - @property - def private(self): - return self._private - - def __hash__(self) -> int: - return hash((self._header_file, self._private)) - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, SfgHeaderInclude) - and self._header_file == other._header_file - and self._private == other._private - ) +class SfgCodeEntity: + """Base class for code entities. -class SfgKernelNamespace: - """A namespace grouping together a number of kernels.""" + Each code entity has a name and an optional enclosing namespace. + """ - def __init__(self, ctx: SfgContext, name: str): - self._ctx = ctx + def __init__(self, name: str, namespace: SfgNamespace | None) -> None: self._name = name - self._kernel_functions: dict[str, Kernel] = dict() + self._namespace: SfgNamespace | None = namespace @property - def name(self): + def name(self) -> str: + """Name of this entity""" return self._name @property - def kernel_functions(self): - yield from self._kernel_functions.values() - - def get_kernel_function(self, khandle: SfgKernelHandle) -> Kernel: - if khandle.kernel_namespace is not self: - raise ValueError( - f"Kernel handle does not belong to this namespace: {khandle}" - ) - - return self._kernel_functions[khandle.kernel_name] - - def add(self, kernel: Kernel, name: str | None = None): - """Adds an existing pystencils AST to this namespace. - If a name is specified, the AST's function name is changed.""" - if name is not None: - astname = name + def fqname(self) -> str: + """Fully qualified name of this entity""" + if self._namespace is not None: + return self._namespace.fqname + "::" + self._name else: - astname = kernel.name + return self._name - if astname in self._kernel_functions: - raise ValueError( - f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}" - ) - - if name is not None: - kernel.name = name - - self._kernel_functions[astname] = kernel - - for header in kernel.required_headers: - self._ctx.add_include(SfgHeaderInclude(HeaderFile.parse(header), private=True)) - - return SfgKernelHandle(self._ctx, astname, self, kernel.parameters) - - def create( - self, - assignments, - name: str | None = None, - config: CreateKernelConfig | None = None, - ): - """Creates a new pystencils kernel from a list of assignments and a configuration. - This is a wrapper around `pystencils.create_kernel` - with a subsequent call to `add`. - """ - if config is None: - config = CreateKernelConfig() + @property + def namespace(self) -> SfgNamespace | None: + """Parent namespace of this entity""" + return self._namespace - if name is not None: - if name in self._kernel_functions: - raise ValueError( - f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}" - ) - config = replace(config, function_name=name) - # type: ignore - ast = create_kernel(assignments, config=config) - return self.add(ast) +class SfgKernelHandle(SfgCodeEntity): + """Handle to a pystencils kernel.""" + def __init__(self, name: str, namespace: SfgKernelNamespace, kernel: Kernel): + super().__init__(name, namespace) -class SfgKernelHandle: - """A handle that represents a pystencils kernel within a kernel namespace.""" - - def __init__( - self, - ctx: SfgContext, - name: str, - namespace: SfgKernelNamespace, - parameters: Sequence[Parameter], - ): - self._ctx = ctx - self._name = name - self._namespace = namespace - self._parameters = [SfgKernelParamVar(p) for p in parameters] + self._kernel = kernel + self._parameters = [SfgKernelParamVar(p) for p in kernel.parameters] self._scalar_params: set[SfgVar] = set() self._fields: set[Field] = set() @@ -155,82 +80,49 @@ class SfgKernelHandle: else: self._scalar_params.add(param) - @property - def kernel_name(self): - return self._name - - @property - def kernel_namespace(self): - return self._namespace - - @property - def fully_qualified_name(self): - match self._ctx.fully_qualified_namespace: - case None: - return f"{self.kernel_namespace.name}::{self.kernel_name}" - case fqn: - return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}" - @property def parameters(self) -> Sequence[SfgKernelParamVar]: + """Parameters to this kernel""" return self._parameters @property def scalar_parameters(self) -> set[SfgVar]: + """Scalar parameters to this kernel""" return self._scalar_params @property def fields(self): + """Fields accessed by this kernel""" return self._fields - def get_kernel_function(self) -> Kernel: - return self._namespace.get_kernel_function(self) - - -SymbolLike_T = TypeVar("SymbolLike_T", bound=Parameter) - + def get_kernel(self) -> Kernel: + """Underlying pystencils kernel object""" + return self._kernel -class SfgKernelParamVar(SfgVar): - __match_args__ = ("wrapped",) - """Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`.""" - - def __init__(self, param: Parameter): - self._param = param - super().__init__(param.name, param.dtype) - - @property - def wrapped(self) -> Parameter: - return self._param - - def _args(self): - return (self._param,) - - -class SfgFunction: +class SfgFunction(SfgCodeEntity): __match_args__ = ("name", "tree", "parameters") def __init__( self, name: str, + namespace: SfgNamespace | None, tree: SfgCallTreeNode, return_type: PsType = void, - _is_method: bool = False, + inline: bool = False, ): - self._name = name + super().__init__(name, namespace) + self._tree = tree self._return_type = return_type + self._inline = inline self._parameters: set[SfgVar] - if not _is_method: - from .postprocessing import CallTreePostProcessing - param_collector = CallTreePostProcessing() - self._parameters = param_collector(self._tree).function_params + from .postprocessing import CallTreePostProcessing - @property - def name(self) -> str: - return self._name + param_collector = CallTreePostProcessing() + self._parameters = param_collector(self._tree).function_params @property def parameters(self) -> set[SfgVar]: @@ -244,6 +136,10 @@ class SfgFunction: def return_type(self) -> PsType: return self._return_type + @property + def inline(self) -> bool: + return self._inline + class SfgVisibility(Enum): DEFAULT = auto() @@ -308,50 +204,6 @@ class SfgClassMember(ABC): self._vis = vis -class SfgVisibilityBlock: - def __init__(self, visibility: SfgVisibility) -> None: - self._vis = visibility - self._members: list[SfgClassMember] = [] - self._cls: SfgClass | None = None - - @property - def visibility(self) -> SfgVisibility: - return self._vis - - def append_member(self, member: SfgClassMember): - if self._cls is not None: - self._cls._add_member(member, self._vis) - self._members.append(member) - - def members(self) -> Generator[SfgClassMember, None, None]: - yield from self._members - - @property - def is_bound(self) -> bool: - return self._cls is not None - - def _bind(self, cls: SfgClass): - if self._cls is not None: - raise SfgException( - f"Binding visibility block to class {cls.class_name} failed: " - f"was already bound to {self._cls.class_name}" - ) - self._cls = cls - - -class SfgInClassDefinition(SfgClassMember): - def __init__(self, text: str): - SfgClassMember.__init__(self) - self._text = text - - @property - def text(self) -> str: - return self._text - - def __str__(self) -> str: - return self._text - - class SfgMemberVariable(SfgVar, SfgClassMember): def __init__(self, name: str, dtype: PsType): SfgVar.__init__(self, name, dtype) @@ -367,29 +219,16 @@ class SfgMethod(SfgFunction, SfgClassMember): inline: bool = False, const: bool = False, ): - SfgFunction.__init__(self, name, tree, return_type=return_type, _is_method=True) + SfgFunction.__init__(self, name, tree, return_type=return_type, inline=inline) SfgClassMember.__init__(self) - self._inline = inline self._const = const self._parameters: set[SfgVar] = set() - @property - def inline(self) -> bool: - return self._inline - @property def const(self) -> bool: return self._const - def _bind(self, cls: SfgClass, vis: SfgVisibility): - super()._bind(cls, vis) - - from .postprocessing import CallTreePostProcessing - - param_collector = CallTreePostProcessing(enclosing_class=cls) - self._parameters = param_collector(self._tree).function_params - class SfgConstructor(SfgClassMember): __match_args__ = ("parameters", "initializers", "body") @@ -418,7 +257,7 @@ class SfgConstructor(SfgClassMember): return self._body -class SfgClass: +class SfgClass(SfgCodeEntity): """Models a C++ class. ### Adding members to classes @@ -430,23 +269,22 @@ class SfgClass: accessible through the `default` property. To add members with custom visibility, create a new SfgVisibilityBlock, add members to the block, and add the block using `append_visibility_block`. - - A more succinct interface for constructing classes is available through the - [SfgClassComposer][pystencilssfg.composer.SfgClassComposer]. """ __match_args__ = ("class_name",) def __init__( self, - class_name: str, + name: str, + namespace: SfgNamespace | None, class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS, bases: Sequence[str] = (), ): if isinstance(bases, str): raise ValueError("Base classes must be given as a sequence.") - self._class_name = class_name + super().__init__(name, namespace) + self._class_keyword = class_keyword self._bases_classes = tuple(bases) @@ -454,18 +292,14 @@ class SfgClass: self._default_block._bind(self) self._blocks = [self._default_block] - self._definitions: list[SfgInClassDefinition] = [] self._constructors: list[SfgConstructor] = [] self._methods: list[SfgMethod] = [] self._member_vars: dict[str, SfgMemberVariable] = dict() - @property - def class_name(self) -> str: - return self._class_name - @property def src_type(self) -> PsType: - return PsCustomType(self._class_name) + # TODO: Use CppTypeFactory instead + return PsCustomType(self._name) @property def base_classes(self) -> tuple[str, ...]: @@ -504,14 +338,6 @@ class SfgClass: for b in filter(lambda b: b.visibility == visibility, self._blocks) ) - def definitions( - self, visibility: SfgVisibility | None = None - ) -> Generator[SfgInClassDefinition, None, None]: - if visibility is not None: - yield from filter(lambda m: m.visibility == visibility, self._definitions) - else: - yield from self._definitions - def member_variables( self, visibility: SfgVisibility | None = None ) -> Generator[SfgMemberVariable, None, None]: @@ -547,16 +373,11 @@ class SfgClass: self._add_member_variable(member) elif isinstance(member, SfgMethod): self._add_method(member) - elif isinstance(member, SfgInClassDefinition): - self._add_definition(member) else: raise SfgException(f"{member} is not a valid class member.") member._bind(self, vis) - def _add_definition(self, definition: SfgInClassDefinition): - self._definitions.append(definition) - def _add_constructor(self, constr: SfgConstructor): self._constructors.append(constr) @@ -566,7 +387,288 @@ class SfgClass: def _add_member_variable(self, variable: SfgMemberVariable): if variable.name in self._member_vars: raise SfgException( - f"Duplicate field name {variable.name} in class {self._class_name}" + f"Duplicate field name {variable.name} in class {self._name}" ) self._member_vars[variable.name] = variable + + +SourceEntity_T = TypeVar( + "SourceEntity_T", bound=SfgFunction | SfgClassMember | SfgClass, covariant=True +) +"""Source entities that may have declarations and definitions.""" + + +# ========================================================================================================= +# +# SYNTACTICAL ELEMENTS +# +# These classes model *code elements*, which represent the actual syntax objects that populate the output +# files, their namespaces and class bodies. +# +# ========================================================================================================= + + +class SfgEntityDecl(Generic[SourceEntity_T]): + """Declaration of a function, class, method, or constructor""" + + __match_args__ = ("entity",) + + def __init__(self, entity: SourceEntity_T) -> None: + self._entity = entity + + @property + def entity(self) -> SourceEntity_T: + return self._entity + + +class SfgEntityDef(Generic[SourceEntity_T]): + """Definition of a function, class, method, or constructor""" + + __match_args__ = ("entity",) + + def __init__(self, entity: SourceEntity_T) -> None: + self._entity = entity + + @property + def entity(self) -> SourceEntity_T: + return self._entity + + +SfgClassBodyElement = ( + str + | SfgEntityDecl[SfgClassMember] + | SfgEntityDef[SfgClassMember] + | SfgMemberVariable +) +"""Elements that may be placed in the visibility blocks of a class body.""" + + +class SfgVisibilityBlock: + """Visibility-qualified block inside a class definition body. + + Visibility blocks host the code elements placed inside a class body: + method and constructor declarations, + in-class method and constructor definitions, + as well as variable declarations and definitions. + + Args: + visibility: The visibility qualifier of this block + """ + + def __init__(self, visibility: SfgVisibility) -> None: + self._vis = visibility + self._elements: list[SfgClassBodyElement] = [] + self._cls: SfgClass | None = None + + @property + def visibility(self) -> SfgVisibility: + return self._vis + + def append_member(self, element: SfgClassBodyElement): + if isinstance(element, (SfgEntityDecl, SfgEntityDef)): + member = element.entity + elif isinstance(element, SfgClassMember): + member = element + else: + member = None + + if self._cls is not None and member is not None: + self._cls._add_member(member, self._vis) + + self._elements.append(element) + + @property + def elements(self) -> tuple[SfgClassBodyElement, ...]: + return tuple(self._elements) + + def members(self) -> Generator[SfgClassMember, None, None]: + for elem in self._elements: + match elem: + case SfgEntityDecl(entity) | SfgEntityDef(entity): + yield entity + case SfgMemberVariable(): + yield elem + + @property + def is_bound(self) -> bool: + return self._cls is not None + + def _bind(self, cls: SfgClass): + if self._cls is not None: + raise SfgException( + f"Binding visibility block to class {cls.class_name} failed: " + f"was already bound to {self._cls.class_name}" + ) + self._cls = cls + + +class SfgNamespace: + """A C++ namespace. + + Each namespace has a `name` and a `parent`; its fully qualified name is given as + ``<parent.name>::<name>``. + + Args: + name: Local name of this namespace + parent: Parent namespace enclosing this namespace + """ + + def __init__(self, name: str, parent: SfgNamespace | None) -> None: + self._name: str = name + self._parent: SfgNamespace | None = parent + self._elements: list[SfgNamespaceElement] = [] + + @property + def name(self) -> str: + """The name of this namespace""" + return self._name + + @property + def fqname(self) -> str: + """The fully qualified name of this namespace""" + if self._parent is not None: + return self._parent.fqname + "::" + self._name + else: + return self._name + + @property + def elements(self) -> list[SfgNamespaceElement]: + """Sequence of source elements that make up the body of this namespace""" + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgNamespaceElement]): + self._elements = list(elems) + + +class SfgKernelNamespace(SfgNamespace): + """A namespace grouping together a number of kernels.""" + + def __init__(self, name: str, parent: SfgNamespace | None): + super().__init__(name, parent) + self._kernels: dict[str, SfgKernelHandle] = [] + + @property + def name(self): + return self._name + + @property + def kernels(self) -> tuple[SfgKernelHandle, ...]: + return tuple(self._kernels.values()) + + def add(self, kernel: Kernel, name: str | None = None): + """Adds an existing pystencils AST to this namespace. + If a name is specified, the AST's function name is changed.""" + if name is None: + kernel_name = kernel.name + else: + kernel_name = name + + if kernel_name in self._kernels: + raise ValueError( + f"Duplicate kernels: A kernel called {kernel_name} already exists in namespace {self.fqname}" + ) + + if name is not None: + kernel.name = kernel_name + + khandle = SfgKernelHandle(kernel_name, self, kernel) + self._kernels[kernel_name] = khandle + + # TODO: collect includes later + # for header in kernel.required_headers: + # self._ctx.add_include( + # SfgHeaderInclude(HeaderFile.parse(header), private=True) + # ) + + return khandle + + def create( + self, + assignments, + name: str | None = None, + config: CreateKernelConfig | None = None, + ): + """Creates a new pystencils kernel from a list of assignments and a configuration. + This is a wrapper around `pystencils.create_kernel` + with a subsequent call to `add`. + """ + if config is None: + config = CreateKernelConfig() + + if name is not None: + if name in self._kernels: + raise ValueError( + f"Duplicate kernels: A kernel with name {name} already exists in namespace {self.fqname}" + ) + config = replace(config, function_name=name) + + # type: ignore + kernel = create_kernel(assignments, config=config) + return self.add(kernel) + + +SfgNamespaceElement = str | SfgNamespace | SfgEntityDecl | SfgEntityDef +"""Elements that may be placed inside a namespace, including the global namespace.""" + + +class SfgSourceFileType(Enum): + HEADER = auto() + TRANSLATION_UNIT = auto() + + +class SfgSourceFile: + """A C++ source file. + + Args: + name: Name of the file (without parent directories), e.g. ``Algorithms.cpp`` + file_type: Type of the source file (header or translation unit) + prelude: Optionally, text of the prelude comment printed at the top of the file + """ + + def __init__( + self, name: str, file_type: SfgSourceFileType, prelude: str | None = None + ) -> None: + self._name: str = name + self._file_type: SfgSourceFileType = file_type + self._prelude: str | None = prelude + self._includes: list[HeaderFile] = [] + self._elements: list[SfgNamespaceElement] = [] + + @property + def name(self) -> str: + """Name of this source file""" + return self._name + + @property + def file_type(self) -> SfgSourceFileType: + """File type of this source file""" + return self._file_type + + @property + def prelude(self) -> str | None: + """Text of the prelude comment""" + return self._prelude + + @prelude.setter + def prelude(self, text: str | None): + self._prelude = text + + @property + def includes(self) -> list[HeaderFile]: + """Sequence of header files to be included at the top of this file""" + return self._includes + + @includes.setter + def includes(self, incl: Iterable[HeaderFile]): + self._includes = list(incl) + + @property + def elements(self) -> list[SfgNamespaceElement]: + """Sequence of source elements comprising the body of this file""" + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgNamespaceElement]): + self._elements = list(elems) diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 9218ec2..a8de86b 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -2,6 +2,7 @@ from .headers import HeaderFile from .expressions import ( SfgVar, + SfgKernelParamVar, AugExpr, VarLike, _VarLike, @@ -21,6 +22,7 @@ from .types import cpptype, void, Ref, strip_ptr_ref __all__ = [ "HeaderFile", "SfgVar", + "SfgKernelParamVar", "AugExpr", "VarLike", "_VarLike", diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index f86140e..4a1f7e9 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod import sympy as sp from pystencils import TypedSymbol +from pystencils.codegen import Parameter from pystencils.types import PsType, UserTypeSpec, create_type from ..exceptions import SfgException @@ -74,6 +75,23 @@ class SfgVar: return self.name_and_type() +class SfgKernelParamVar(SfgVar): + __match_args__ = ("wrapped",) + + """Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`.""" + + def __init__(self, param: Parameter): + self._param = param + super().__init__(param.name, param.dtype) + + @property + def wrapped(self) -> Parameter: + return self._param + + def _args(self): + return (self._param,) + + class DependentExpression: """Wrapper around a C++ expression code string, annotated with a set of variables and a set of header files this expression depends on. -- GitLab From 2eb93b6e411b7427aa84a820e02cb3b2008de5e0 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 6 Feb 2025 16:07:17 +0100 Subject: [PATCH 02/18] split namespaces into an entity and an element part --- src/pystencilssfg/ir/source_components.py | 225 +++++++++++++--------- 1 file changed, 132 insertions(+), 93 deletions(-) diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 07ee848..d98e2d8 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -39,9 +39,9 @@ class SfgCodeEntity: Each code entity has a name and an optional enclosing namespace. """ - def __init__(self, name: str, namespace: SfgNamespace | None) -> None: + def __init__(self, name: str, parent_namespace: SfgNamespace | None) -> None: self._name = name - self._namespace: SfgNamespace | None = namespace + self._namespace: SfgNamespace | None = parent_namespace @property def name(self) -> str: @@ -57,11 +57,23 @@ class SfgCodeEntity: return self._name @property - def namespace(self) -> SfgNamespace | None: + def parent_namespace(self) -> SfgNamespace | None: """Parent namespace of this entity""" return self._namespace +class SfgNamespace(SfgCodeEntity): + """A C++ namespace. + + Each namespace has a `name` and a `parent`; its fully qualified name is given as + ``<parent.name>::<name>``. + + Args: + name: Local name of this namespace + parent: Parent namespace enclosing this namespace + """ + + class SfgKernelHandle(SfgCodeEntity): """Handle to a pystencils kernel.""" @@ -100,7 +112,76 @@ class SfgKernelHandle(SfgCodeEntity): return self._kernel +class SfgKernelNamespace(SfgNamespace): + """A namespace grouping together a number of kernels.""" + + def __init__(self, name: str, parent: SfgNamespace | None): + super().__init__(name, parent) + self._kernels: dict[str, SfgKernelHandle] = [] + + @property + def name(self): + return self._name + + @property + def kernels(self) -> tuple[SfgKernelHandle, ...]: + return tuple(self._kernels.values()) + + def add(self, kernel: Kernel, name: str | None = None): + """Adds an existing pystencils AST to this namespace. + If a name is specified, the AST's function name is changed.""" + if name is None: + kernel_name = kernel.name + else: + kernel_name = name + + if kernel_name in self._kernels: + raise ValueError( + f"Duplicate kernels: A kernel called {kernel_name} already exists in namespace {self.fqname}" + ) + + if name is not None: + kernel.name = kernel_name + + khandle = SfgKernelHandle(kernel_name, self, kernel) + self._kernels[kernel_name] = khandle + + # TODO: collect includes later + # for header in kernel.required_headers: + # self._ctx.add_include( + # SfgHeaderInclude(HeaderFile.parse(header), private=True) + # ) + + return khandle + + def create( + self, + assignments, + name: str | None = None, + config: CreateKernelConfig | None = None, + ): + """Creates a new pystencils kernel from a list of assignments and a configuration. + This is a wrapper around `pystencils.create_kernel` + with a subsequent call to `add`. + """ + if config is None: + config = CreateKernelConfig() + + if name is not None: + if name in self._kernels: + raise ValueError( + f"Duplicate kernels: A kernel with name {name} already exists in namespace {self.fqname}" + ) + config = replace(config, function_name=name) + + # type: ignore + kernel = create_kernel(assignments, config=config) + return self.add(kernel) + + class SfgFunction(SfgCodeEntity): + """A free function.""" + __match_args__ = ("name", "tree", "parameters") def __init__( @@ -142,6 +223,8 @@ class SfgFunction(SfgCodeEntity): class SfgVisibility(Enum): + """Visibility qualifiers of C++""" + DEFAULT = auto() PRIVATE = auto() PROTECTED = auto() @@ -160,6 +243,8 @@ class SfgVisibility(Enum): class SfgClassKeyword(Enum): + """Class keywords of C++""" + STRUCT = auto() CLASS = auto() @@ -172,6 +257,8 @@ class SfgClassKeyword(Enum): class SfgClassMember(ABC): + """Base class for class member entities""" + def __init__(self) -> None: self._cls: SfgClass | None = None self._visibility: SfgVisibility | None = None @@ -205,25 +292,56 @@ class SfgClassMember(ABC): class SfgMemberVariable(SfgVar, SfgClassMember): + """Variable that is a field of a class""" + def __init__(self, name: str, dtype: PsType): SfgVar.__init__(self, name, dtype) SfgClassMember.__init__(self) -class SfgMethod(SfgFunction, SfgClassMember): +class SfgMethod(SfgClassMember): + """Instance method of a class""" + + __match_args__ = ("name", "tree", "parameters") + def __init__( self, name: str, tree: SfgCallTreeNode, - return_type: PsType = PsCustomType("void"), + return_type: PsType = void, inline: bool = False, const: bool = False, ): - SfgFunction.__init__(self, name, tree, return_type=return_type, inline=inline) - SfgClassMember.__init__(self) + super().__init__() + self._name = name + self._tree = tree + self._return_type = return_type + self._inline = inline self._const = const - self._parameters: set[SfgVar] = set() + + self._parameters: set[SfgVar] + + from .postprocessing import CallTreePostProcessing + + param_collector = CallTreePostProcessing() + self._parameters = param_collector(self._tree).function_params + + @property + def parameters(self) -> set[SfgVar]: + return self._parameters + + @property + def tree(self) -> SfgCallTreeNode: + return self._tree + + @property + def return_type(self) -> PsType: + return self._return_type + + @property + def inline(self) -> bool: + return self._inline @property def const(self) -> bool: @@ -231,6 +349,8 @@ class SfgMethod(SfgFunction, SfgClassMember): class SfgConstructor(SfgClassMember): + """Constructor of a class""" + __match_args__ = ("parameters", "initializers", "body") def __init__( @@ -503,7 +623,7 @@ class SfgVisibilityBlock: self._cls = cls -class SfgNamespace: +class SfgNamespaceDef: """A C++ namespace. Each namespace has a `name` and a `parent`; its fully qualified name is given as @@ -514,24 +634,10 @@ class SfgNamespace: parent: Parent namespace enclosing this namespace """ - def __init__(self, name: str, parent: SfgNamespace | None) -> None: - self._name: str = name - self._parent: SfgNamespace | None = parent + def __init__(self, namespace: SfgNamespace) -> None: + self._namespace = namespace self._elements: list[SfgNamespaceElement] = [] - @property - def name(self) -> str: - """The name of this namespace""" - return self._name - - @property - def fqname(self) -> str: - """The fully qualified name of this namespace""" - if self._parent is not None: - return self._parent.fqname + "::" + self._name - else: - return self._name - @property def elements(self) -> list[SfgNamespaceElement]: """Sequence of source elements that make up the body of this namespace""" @@ -542,74 +648,7 @@ class SfgNamespace: self._elements = list(elems) -class SfgKernelNamespace(SfgNamespace): - """A namespace grouping together a number of kernels.""" - - def __init__(self, name: str, parent: SfgNamespace | None): - super().__init__(name, parent) - self._kernels: dict[str, SfgKernelHandle] = [] - - @property - def name(self): - return self._name - - @property - def kernels(self) -> tuple[SfgKernelHandle, ...]: - return tuple(self._kernels.values()) - - def add(self, kernel: Kernel, name: str | None = None): - """Adds an existing pystencils AST to this namespace. - If a name is specified, the AST's function name is changed.""" - if name is None: - kernel_name = kernel.name - else: - kernel_name = name - - if kernel_name in self._kernels: - raise ValueError( - f"Duplicate kernels: A kernel called {kernel_name} already exists in namespace {self.fqname}" - ) - - if name is not None: - kernel.name = kernel_name - - khandle = SfgKernelHandle(kernel_name, self, kernel) - self._kernels[kernel_name] = khandle - - # TODO: collect includes later - # for header in kernel.required_headers: - # self._ctx.add_include( - # SfgHeaderInclude(HeaderFile.parse(header), private=True) - # ) - - return khandle - - def create( - self, - assignments, - name: str | None = None, - config: CreateKernelConfig | None = None, - ): - """Creates a new pystencils kernel from a list of assignments and a configuration. - This is a wrapper around `pystencils.create_kernel` - with a subsequent call to `add`. - """ - if config is None: - config = CreateKernelConfig() - - if name is not None: - if name in self._kernels: - raise ValueError( - f"Duplicate kernels: A kernel with name {name} already exists in namespace {self.fqname}" - ) - config = replace(config, function_name=name) - - # type: ignore - kernel = create_kernel(assignments, config=config) - return self.add(kernel) - - -SfgNamespaceElement = str | SfgNamespace | SfgEntityDecl | SfgEntityDef +SfgNamespaceElement = str | SfgNamespaceDef | SfgEntityDecl | SfgEntityDef """Elements that may be placed inside a namespace, including the global namespace.""" -- GitLab From fa347fe5934e9adcf0b7404f4de50cd742306f86 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 6 Feb 2025 17:22:30 +0100 Subject: [PATCH 03/18] updating the composer WIP --- src/pystencilssfg/composer/basic_composer.py | 210 +++++++++------ src/pystencilssfg/context.py | 261 +++++-------------- src/pystencilssfg/ir/source_components.py | 21 +- 3 files changed, 215 insertions(+), 277 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index db671b9..c0b420f 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -6,7 +6,7 @@ import sympy as sp from functools import reduce from warnings import warn -from pystencils import Field +from pystencils import Field, CreateKernelConfig, create_kernel from pystencils.codegen import Kernel from pystencils.types import create_type, UserTypeSpec @@ -31,13 +31,15 @@ from ..ir.postprocessing import ( ) from ..ir.source_components import ( SfgFunction, - SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle, SfgClass, SfgConstructor, SfgMemberVariable, SfgClassKeyword, + SfgEntityDecl, + SfgEntityDef, + SfgNamespaceBlock, ) from ..lang import ( VarLike, @@ -61,6 +63,7 @@ from ..exceptions import SfgException class SfgIComposer(ABC): def __init__(self, ctx: SfgContext): self._ctx = ctx + self._cursor = ctx.cursor @property def context(self): @@ -80,6 +83,66 @@ SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder """Valid arguments to `make_sequence` and any sequencer that uses it.""" +class KernelsAdder: + def __init__(self, ctx: SfgContext, loc: SfgNamespaceBlock): + self._ctx = ctx + self._loc = SfgNamespaceBlock + assert isinstance(loc.namespace, SfgKernelNamespace) + self._kernel_namespace = loc.namespace + + def add(self, kernel: Kernel, name: str | None = None): + """Adds an existing pystencils AST to this namespace. + If a name is specified, the AST's function name is changed.""" + if name is None: + kernel_name = kernel.name + else: + kernel_name = name + + if self._kernel_namespace.find_kernel(kernel_name) is not None: + raise ValueError( + f"Duplicate kernels: A kernel called {kernel_name} already exists " + f"in namespace {self._kernel_namespace.fqname}" + ) + + if name is not None: + kernel.name = kernel_name + + khandle = SfgKernelHandle(kernel_name, self._kernel_namespace, kernel) + self._kernel_namespace.add_kernel(khandle) + + for header in kernel.required_headers: + # TODO: Find current source file by traversing namespace blocks upward? + self._ctx.impl_file.includes.append(HeaderFile.parse(header)) + + return khandle + + def create( + self, + assignments, + name: str | None = None, + config: CreateKernelConfig | None = None, + ): + """Creates a new pystencils kernel from a list of assignments and a configuration. + This is a wrapper around `pystencils.create_kernel` + with a subsequent call to `add`. + """ + if config is None: + config = CreateKernelConfig() + + if name is not None: + if self._kernel_namespace.find_kernel(name) is not None: + raise ValueError( + f"Duplicate kernels: A kernel called {name} already exists " + f"in namespace {self._kernel_namespace.fqname}" + ) + + config.function_name = name + + # type: ignore + kernel = create_kernel(assignments, config=config) + return self.add(kernel) + + class SfgBasicComposer(SfgIComposer): """Composer for basic source components, and base class for all composer mix-ins.""" @@ -87,7 +150,7 @@ class SfgBasicComposer(SfgIComposer): ctx: SfgContext = sfg if isinstance(sfg, SfgContext) else sfg.context super().__init__(ctx) - def prelude(self, content: str): + def prelude(self, content: str, end: str = "\n"): """Append a string to the prelude comment, to be printed at the top of both generated files. The string should not contain C/C++ comment delimiters, since these will be added automatically @@ -105,7 +168,11 @@ class SfgBasicComposer(SfgIComposer): */ """ - self._ctx.append_to_prelude(content) + for f in self._ctx.files: + if f.prelude is None: + f.prelude = content + end + else: + f.prelude += content + end def code(self, *code: str): """Add arbitrary lines of code to the generated header file. @@ -126,7 +193,7 @@ class SfgBasicComposer(SfgIComposer): """ for c in code: - self._ctx.add_definition(c) + self._cursor.write_header(c) def define(self, *definitions: str): from warnings import warn @@ -139,34 +206,9 @@ class SfgBasicComposer(SfgIComposer): self.code(*definitions) - def define_once(self, *definitions: str): - """Add unique definitions to the header file. - - Each code string given to `define_once` will only be added if the exact same string - was not already added before. - """ - for definition in definitions: - if all(d != definition for d in self._ctx.definitions()): - self._ctx.add_definition(definition) - def namespace(self, namespace: str): - """Set the inner code namespace. Throws an exception if a namespace was already set. - - :Example: - - After adding the following to your generator script: - - >>> sfg.namespace("codegen_is_awesome") - - All generated code will be placed within that namespace: - - .. code-block:: C++ - - namespace codegen_is_awesome { - /* all generated code */ - } - """ - self._ctx.set_namespace(namespace) + # TODO: Enter into a new namespace context + raise NotImplementedError() def generate(self, generator: CustomGenerator): """Invoke a custom code generator with the underlying context.""" @@ -183,18 +225,16 @@ class SfgBasicComposer(SfgIComposer): sfg.kernels.add(ast, "kernel_name") sfg.kernels.create(assignments, "kernel_name", config) """ - return self._ctx._default_kernel_namespace + return self.kernel_namespace("kernels") def kernel_namespace(self, name: str) -> SfgKernelNamespace: """Return the kernel namespace of the given name, creating it if it does not exist yet.""" - kns = self._ctx.get_kernel_namespace(name) - if kns is None: - kns = SfgKernelNamespace(self._ctx, name) - self._ctx.add_kernel_namespace(kns) + # TODO: Find the default kernel namespace as a child entity of the current + # namespace, or create it if it does not exist + # Then create a new namespace block, place it at the cursor position, and expose + # it to the user via an adder - return kns - - def include(self, header_file: str, private: bool = False): + def include(self, header_file: str | HeaderFile, private: bool = False): """Include a header file. Args: @@ -214,7 +254,14 @@ class SfgBasicComposer(SfgIComposer): #include <vector> #include "custom.h" """ - self._ctx.add_include(SfgHeaderInclude(HeaderFile.parse(header_file), private)) + header_file = HeaderFile.parse(header_file) + + if private: + if self._ctx.impl_file is None: + raise ValueError("Cannot emit a private include since no implementation file is being generated") + self._ctx.impl_file.includes.append(header_file) + else: + self._ctx.header_file.includes.append(header_file) def numpy_struct( self, name: str, dtype: np.dtype, add_constructor: bool = True @@ -224,11 +271,9 @@ class SfgBasicComposer(SfgIComposer): Returns: The created class object """ - if self._ctx.get_class(name) is not None: - raise SfgException(f"Class with name {name} already exists.") - - cls = _struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) - self._ctx.add_class(cls) + cls = self._struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) + self._ctx.add_entity(cls) + self._cursor.write_header(SfgEntityDecl(cls)) return cls def kernel_function( @@ -281,15 +326,18 @@ class SfgBasicComposer(SfgIComposer): ) returns = return_type - if self._ctx.get_function(name) is not None: - raise ValueError(f"Function {name} already exists.") - def sequencer(*args: SequencerArg): tree = make_sequence(*args) func = SfgFunction( - name, tree, return_type=create_type(returns), inline=inline + name, self._cursor.current_namespace, tree, return_type=create_type(returns), inline=inline ) - self._ctx.add_function(func) + self._ctx.add_entity(func) + + if inline: + self._cursor.write_header(SfgEntityDef(func)) + else: + self._cursor.write_header(SfgEntityDecl(func)) + self._cursor.write_impl(SfgEntityDef(func)) return sequencer @@ -482,6 +530,36 @@ class SfgBasicComposer(SfgIComposer): (asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components ] return SfgDeferredVectorMapping(components, rhs) + + def _struct_from_numpy_dtype( + self, struct_name: str, dtype: np.dtype, add_constructor: bool = True + ): + cls = SfgClass(struct_name, self._cursor.current_namespace, class_keyword=SfgClassKeyword.STRUCT) + + fields = dtype.fields + if fields is None: + raise SfgException(f"Numpy dtype {dtype} is not a structured type.") + + constr_params = [] + constr_inits = [] + + for member_name, type_info in fields.items(): + member_type = create_type(type_info[0]) + + member = SfgMemberVariable(member_name, member_type) + + arg = SfgVar(f"{member_name}_", member_type) + + cls.default.append_member(member) + + constr_params.append(arg) + constr_inits.append(f"{member}({arg})") + + if add_constructor: + cls.default.append_member(SfgEntityDef(SfgConstructor(constr_params, constr_inits))) + + return cls + def make_statements(arg: ExprLike) -> SfgStatements: @@ -628,33 +706,3 @@ class SfgSwitchBuilder(SfgNodeBuilder): def resolve(self) -> SfgCallTreeNode: return SfgSwitch(make_statements(self._switch_arg), self._cases, self._default) - - -def _struct_from_numpy_dtype( - struct_name: str, dtype: np.dtype, add_constructor: bool = True -): - cls = SfgClass(struct_name, class_keyword=SfgClassKeyword.STRUCT) - - fields = dtype.fields - if fields is None: - raise SfgException(f"Numpy dtype {dtype} is not a structured type.") - - constr_params = [] - constr_inits = [] - - for member_name, type_info in fields.items(): - member_type = create_type(type_info[0]) - - member = SfgMemberVariable(member_name, member_type) - - arg = SfgVar(f"{member_name}_", member_type) - - cls.default.append_member(member) - - constr_params.append(arg) - constr_inits.append(f"{member}({arg})") - - if add_constructor: - cls.default.append_member(SfgConstructor(constr_params, constr_inits)) - - return cls diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 17537a2..577dcbd 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -1,83 +1,45 @@ -from typing import Generator, Sequence, Any +from __future__ import annotations +from typing import Sequence, Any, Generator from .config import CodeStyle from .ir.source_components import ( - SfgHeaderInclude, + SfgSourceFile, + SfgNamespace, SfgKernelNamespace, - SfgFunction, + SfgNamespaceBlock, + SfgNamespaceElement, + SfgCodeEntity, SfgClass, ) from .exceptions import SfgException class SfgContext: - """Represents a header/implementation file pair in the code generator. - - **Source File Properties and Components** - - The SfgContext collects all properties and components of a header/implementation - file pair (or just the header file, if header-only generation is used). - These are: - - - The code namespace, which is combined from the `outer_namespace` - and the `pystencilssfg.SfgContext.inner_namespace`. The outer namespace is meant to be set - externally e.g. by the project configuration, while the inner namespace is meant to be set by the generator - script. - - The `prelude comment` is a block of text printed as a comment block - at the top of both generated files. Typically, it contains authorship and licence information. - - The set of included header files (`pystencilssfg.SfgContext.includes`). - - Custom `definitions`, which are just arbitrary code strings. - - Any number of kernel namespaces (`pystencilssfg.SfgContext.kernel_namespaces`), within which *pystencils* - kernels are managed. - - Any number of functions (`pystencilssfg.SfgContext.functions`), which are meant to serve as wrappers - around kernel calls. - - Any number of classes (`pystencilssfg.SfgContext.classes`), which can be used to build more extensive wrappers - around kernels. - - **Order of Definitions** - - To honor C/C++ use-after-declare rules, the context preserves the order in which definitions, functions and classes - are added to it. - The header file printers implemented in *pystencils-sfg* will print the declarations accordingly. - The declarations can retrieved in order of definition via `declarations_ordered`. - """ + """Manages context information during the execution of a generator script.""" def __init__( self, + header_file: SfgSourceFile, + impl_file: SfgSourceFile, outer_namespace: str | None = None, codestyle: CodeStyle | None = None, argv: Sequence[str] | None = None, project_info: Any = None, ): - """ - Args: - outer_namespace: Qualified name of the outer code namespace - codestyle: Code style that should be used by the code emitter - argv: The generator script's command line arguments. - Reserved for internal use by the [SourceFileGenerator][pystencilssfg.SourceFileGenerator]. - project_info: Project-specific information provided by a build system. - Reserved for internal use by the [SourceFileGenerator][pystencilssfg.SourceFileGenerator]. - """ self._argv = argv self._project_info = project_info - self._default_kernel_namespace = SfgKernelNamespace(self, "kernels") self._outer_namespace = outer_namespace self._inner_namespace: str | None = None self._codestyle = codestyle if codestyle is not None else CodeStyle() - # Source Components - self._prelude: str = "" - self._includes: list[SfgHeaderInclude] = [] - self._definitions: list[str] = [] - self._kernel_namespaces = { - self._default_kernel_namespace.name: self._default_kernel_namespace - } - self._functions: dict[str, SfgFunction] = dict() - self._classes: dict[str, SfgClass] = dict() + self._header_file = header_file + self._impl_file = impl_file - self._declarations_ordered: list[str | SfgFunction | SfgClass] = list() + self._entities: dict[str, SfgCodeEntity] = dict() + + self._cursor: SfgCursor @property def argv(self) -> Sequence[str]: @@ -100,163 +62,74 @@ class SfgContext: """Outer code namespace. Set by constructor argument `outer_namespace`.""" return self._outer_namespace - @property - def inner_namespace(self) -> str | None: - """Inner code namespace. Set by `set_namespace`.""" - return self._inner_namespace - - @property - def fully_qualified_namespace(self) -> str | None: - """Combined outer and inner namespaces, as `outer_namespace::inner_namespace`.""" - match (self.outer_namespace, self.inner_namespace): - case None, None: - return None - case outer, None: - return outer - case None, inner: - return inner - case outer, inner: - return f"{outer}::{inner}" - case _: - assert False - @property def codestyle(self) -> CodeStyle: """The code style object for this generation context.""" return self._codestyle - # ---------------------------------------------------------------------------------------------- - # Prelude, Includes, Definitions, Namespace - # ---------------------------------------------------------------------------------------------- - @property - def prelude_comment(self) -> str: - """The prelude is a comment block printed at the top of both generated files.""" - return self._prelude - - def append_to_prelude(self, code_str: str): - """Append a string to the prelude comment. - - The string should not contain - C/C++ comment delimiters, since these will be added automatically during - code generation. - """ - if self._prelude: - self._prelude += "\n" - - self._prelude += code_str - - if not code_str.endswith("\n"): - self._prelude += "\n" - - def includes(self) -> Generator[SfgHeaderInclude, None, None]: - """Includes of headers. Public includes are added to the header file, private includes - are added to the implementation file.""" - yield from self._includes - - def add_include(self, include: SfgHeaderInclude): - self._includes.append(include) - - def definitions(self) -> Generator[str, None, None]: - """Definitions are arbitrary custom lines of code.""" - yield from self._definitions - - def add_definition(self, definition: str): - """Add a custom code string to the header file.""" - self._definitions.append(definition) - self._declarations_ordered.append(definition) - - def set_namespace(self, namespace: str): - """Set the inner code namespace. - - Throws an exception if the namespace was already set. - """ - if self._inner_namespace is not None: - raise SfgException("The code namespace was already set.") - - self._inner_namespace = namespace - - # ---------------------------------------------------------------------------------------------- - # Kernel Namespaces - # ---------------------------------------------------------------------------------------------- + def header_file(self) -> SfgSourceFile: + return self._header_file @property - def default_kernel_namespace(self) -> SfgKernelNamespace: - """The default kernel namespace.""" - return self._default_kernel_namespace - - def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: - """Iterator over all registered kernel namespaces.""" - yield from self._kernel_namespaces.values() - - def get_kernel_namespace(self, str) -> SfgKernelNamespace | None: - """Retrieve a kernel namespace by name, or `None` if it does not exist.""" - return self._kernel_namespaces.get(str) - - def add_kernel_namespace(self, namespace: SfgKernelNamespace): - """Adds a new kernel namespace. - - If a kernel namespace of the same name already exists, throws an exception. - """ - if namespace.name in self._kernel_namespaces: - raise ValueError(f"Duplicate kernel namespace: {namespace.name}") + def impl_file(self) -> SfgSourceFile | None: + return self._impl_file - self._kernel_namespaces[namespace.name] = namespace - - # ---------------------------------------------------------------------------------------------- - # Functions - # ---------------------------------------------------------------------------------------------- - - def functions(self) -> Generator[SfgFunction, None, None]: - """Iterator over all registered functions.""" - yield from self._functions.values() - - def get_function(self, name: str) -> SfgFunction | None: - """Retrieve a function by name. Returns `None` if no function of the given name exists.""" - return self._functions.get(name, None) - - def add_function(self, func: SfgFunction): - """Adds a new function. - - If a function or class with the same name exists already, throws an exception. - """ - if func.name in self._functions or func.name in self._classes: - raise SfgException(f"Duplicate function: {func.name}") + @property + def cursor(self) -> SfgCursor: + return self._cursor - self._functions[func.name] = func - self._declarations_ordered.append(func) + @property + def files(self) -> Generator[SfgSourceFile, None, None]: + yield self._header_file + if self._impl_file is not None: + yield self._impl_file - # ---------------------------------------------------------------------------------------------- - # Classes - # ---------------------------------------------------------------------------------------------- + def get_entity(self, fqname: str) -> SfgCodeEntity | None: + # TODO: Only track top-level entities here, traverse namespaces to find qualified entities + return self._entities.get(fqname, None) - def classes(self) -> Generator[SfgClass, None, None]: - """Iterator over all registered classes.""" - yield from self._classes.values() + def add_entity(self, entity: SfgCodeEntity) -> None: + fqname = entity.fqname + if fqname in self._entities: + raise ValueError(f"Another entity with name {fqname} already exists") + self._entities[fqname] = entity - def get_class(self, name: str) -> SfgClass | None: - """Retrieve a class by name, or `None` if the class does not exist.""" - return self._classes.get(name, None) - def add_class(self, cls: SfgClass): - """Add a class. +class SfgCursor: + """Cursor that tracks the current location in the source file(s) during execution of the generator script.""" - Throws an exception if a class or function of the same name exists already. - """ - if cls.class_name in self._classes or cls.class_name in self._functions: - raise SfgException(f"Duplicate class: {cls.class_name}") + def __init__(self, ctx: SfgContext, namespace: str | None = None) -> None: + self._ctx = ctx - self._classes[cls.class_name] = cls - self._declarations_ordered.append(cls) + self._cur_namespace: SfgNamespace | None + if namespace is not None: + self._cur_namespace = ctx.get_namespace(namespace) + else: + self._cur_namespace = None - # ---------------------------------------------------------------------------------------------- - # Declarations in order of addition - # ---------------------------------------------------------------------------------------------- + self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] + for f in self._ctx.files: + if self._cur_namespace is not None: + block = SfgNamespaceBlock(self._cur_namespace) + f.elements.append(block) + self._loc[f] = block.elements + else: + self._loc[f] = f.elements - def declarations_ordered( - self, - ) -> Generator[str | SfgFunction | SfgClass, None, None]: - """All declared definitions, classes and functions in the order they were added. + # TODO: Enter and exit namespace blocks - Awareness about order is necessary due to the C++ declare-before-use rules.""" - yield from self._declarations_ordered + @property + def current_namespace(self) -> SfgNamespace | None: + return self._cur_namespace + + def write_header(self, elem: SfgNamespaceElement) -> None: + self._loc[self._ctx.header_file].append(elem) + + def write_impl(self, elem: SfgNamespaceElement) -> None: + impl_file = self._ctx.impl_file + if impl_file is None: + raise SfgException( + f"Cannot write element {elem} to implemenation file since no implementation file is being generated." + ) + self._loc[impl_file].append(elem) diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index d98e2d8..d8fb0f5 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -73,6 +73,8 @@ class SfgNamespace(SfgCodeEntity): parent: Parent namespace enclosing this namespace """ + # TODO: Namespaces must keep track of their child entities + class SfgKernelHandle(SfgCodeEntity): """Handle to a pystencils kernel.""" @@ -127,6 +129,17 @@ class SfgKernelNamespace(SfgNamespace): def kernels(self) -> tuple[SfgKernelHandle, ...]: return tuple(self._kernels.values()) + def find_kernel(self, name: str) -> SfgKernelHandle | None: + return self._kernels.get(name, None) + + def add_kernel(self, kernel: SfgKernelHandle): + if kernel.name in self._kernels: + raise ValueError( + f"Duplicate kernels: A kernel called {kernel.name} already exists " + f"in namespace {self.fqname}" + ) + self._kernels[kernel.name] = kernel + def add(self, kernel: Kernel, name: str | None = None): """Adds an existing pystencils AST to this namespace. If a name is specified, the AST's function name is changed.""" @@ -623,7 +636,7 @@ class SfgVisibilityBlock: self._cls = cls -class SfgNamespaceDef: +class SfgNamespaceBlock: """A C++ namespace. Each namespace has a `name` and a `parent`; its fully qualified name is given as @@ -638,6 +651,10 @@ class SfgNamespaceDef: self._namespace = namespace self._elements: list[SfgNamespaceElement] = [] + @property + def namespace(self) -> SfgNamespace: + return self._namespace + @property def elements(self) -> list[SfgNamespaceElement]: """Sequence of source elements that make up the body of this namespace""" @@ -648,7 +665,7 @@ class SfgNamespaceDef: self._elements = list(elems) -SfgNamespaceElement = str | SfgNamespaceDef | SfgEntityDecl | SfgEntityDef +SfgNamespaceElement = str | SfgNamespaceBlock | SfgEntityDecl | SfgEntityDef """Elements that may be placed inside a namespace, including the global namespace.""" -- GitLab From fc465a22a38028c982c170e92401fadaf17d07a8 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 6 Feb 2025 20:20:51 +0100 Subject: [PATCH 04/18] simplify class structure modelling & update class composer --- src/pystencilssfg/composer/basic_composer.py | 74 ++++--- src/pystencilssfg/composer/class_composer.py | 199 +++++++++++-------- src/pystencilssfg/composer/mixin.py | 3 +- src/pystencilssfg/context.py | 40 ++-- src/pystencilssfg/ir/analysis.py | 4 - src/pystencilssfg/ir/call_tree.py | 6 +- src/pystencilssfg/ir/source_components.py | 125 +++++------- 7 files changed, 239 insertions(+), 212 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index c0b420f..936d5c5 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -111,7 +111,7 @@ class KernelsAdder: self._kernel_namespace.add_kernel(khandle) for header in kernel.required_headers: - # TODO: Find current source file by traversing namespace blocks upward? + assert self._ctx.impl_file is not None self._ctx.impl_file.includes.append(HeaderFile.parse(header)) return khandle @@ -217,7 +217,7 @@ class SfgBasicComposer(SfgIComposer): generator.generate(SfgComposer(self)) @property - def kernels(self) -> SfgKernelNamespace: + def kernels(self) -> KernelsAdder: """The default kernel namespace. Add kernels like:: @@ -227,12 +227,20 @@ class SfgBasicComposer(SfgIComposer): """ return self.kernel_namespace("kernels") - def kernel_namespace(self, name: str) -> SfgKernelNamespace: + def kernel_namespace(self, name: str) -> KernelsAdder: """Return the kernel namespace of the given name, creating it if it does not exist yet.""" - # TODO: Find the default kernel namespace as a child entity of the current - # namespace, or create it if it does not exist - # Then create a new namespace block, place it at the cursor position, and expose - # it to the user via an adder + kns = self._cursor.get_entity("kernels") + if kns is None: + kns = SfgKernelNamespace("kernels", self._cursor.current_namespace) + self._cursor.add_entity(kns) + elif not isinstance(kns, SfgKernelNamespace): + raise ValueError( + f"The existing entity {kns.fqname} is not a kernel namespace" + ) + + kns_block = SfgNamespaceBlock(kns) + self._cursor.write_impl(kns_block) + return KernelsAdder(self._ctx, kns_block) def include(self, header_file: str | HeaderFile, private: bool = False): """Include a header file. @@ -258,7 +266,9 @@ class SfgBasicComposer(SfgIComposer): if private: if self._ctx.impl_file is None: - raise ValueError("Cannot emit a private include since no implementation file is being generated") + raise ValueError( + "Cannot emit a private include since no implementation file is being generated" + ) self._ctx.impl_file.includes.append(header_file) else: self._ctx.header_file.includes.append(header_file) @@ -271,32 +281,23 @@ class SfgBasicComposer(SfgIComposer): Returns: The created class object """ - cls = self._struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) - self._ctx.add_entity(cls) + cls = self._struct_from_numpy_dtype( + name, dtype, add_constructor=add_constructor + ) + self._cursor.add_entity(cls) self._cursor.write_header(SfgEntityDecl(cls)) return cls - def kernel_function( - self, name: str, ast_or_kernel_handle: Kernel | SfgKernelHandle - ): + def kernel_function(self, name: str, kernel: Kernel | SfgKernelHandle): """Create a function comprising just a single kernel call. Args: ast_or_kernel_handle: Either a pystencils AST, or a kernel handle for an already registered AST. """ - if self._ctx.get_function(name) is not None: - raise ValueError(f"Function {name} already exists.") - - if isinstance(ast_or_kernel_handle, Kernel): - khandle = self._ctx.default_kernel_namespace.add(ast_or_kernel_handle) - tree = SfgKernelCallNode(khandle) - elif isinstance(ast_or_kernel_handle, SfgKernelHandle): - tree = SfgKernelCallNode(ast_or_kernel_handle) - else: - raise TypeError("Invalid type of argument `ast_or_kernel_handle`!") + if isinstance(kernel, Kernel): + kernel = self.kernels.add(kernel, name) - func = SfgFunction(name, tree) - self._ctx.add_function(func) + self.function(name)(self.call(kernel)) def function( self, @@ -329,10 +330,14 @@ class SfgBasicComposer(SfgIComposer): def sequencer(*args: SequencerArg): tree = make_sequence(*args) func = SfgFunction( - name, self._cursor.current_namespace, tree, return_type=create_type(returns), inline=inline + name, + self._cursor.current_namespace, + tree, + return_type=create_type(returns), + inline=inline, ) - self._ctx.add_entity(func) - + self._cursor.add_entity(func) + if inline: self._cursor.write_header(SfgEntityDef(func)) else: @@ -530,11 +535,15 @@ class SfgBasicComposer(SfgIComposer): (asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components ] return SfgDeferredVectorMapping(components, rhs) - + def _struct_from_numpy_dtype( self, struct_name: str, dtype: np.dtype, add_constructor: bool = True ): - cls = SfgClass(struct_name, self._cursor.current_namespace, class_keyword=SfgClassKeyword.STRUCT) + cls = SfgClass( + struct_name, + self._cursor.current_namespace, + class_keyword=SfgClassKeyword.STRUCT, + ) fields = dtype.fields if fields is None: @@ -556,12 +565,13 @@ class SfgBasicComposer(SfgIComposer): constr_inits.append(f"{member}({arg})") if add_constructor: - cls.default.append_member(SfgEntityDef(SfgConstructor(constr_params, constr_inits))) + cls.default.append_member( + SfgEntityDef(SfgConstructor(constr_params, constr_inits)) + ) return cls - def make_statements(arg: ExprLike) -> SfgStatements: return SfgStatements(str(arg), (), depends(arg), includes(arg)) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 489823b..1e3e6a3 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -1,26 +1,28 @@ from __future__ import annotations from typing import Sequence +from itertools import takewhile, dropwhile from pystencils.types import PsCustomType, UserTypeSpec, create_type +from ..context import SfgContext from ..lang import ( - _VarLike, VarLike, ExprLike, asvar, SfgVar, ) +from ..ir.call_tree import SfgCallTreeNode from ..ir.source_components import ( SfgClass, - SfgClassMember, - SfgInClassDefinition, SfgConstructor, SfgMethod, SfgMemberVariable, SfgClassKeyword, SfgVisibility, SfgVisibilityBlock, + SfgEntityDecl, + SfgEntityDef, ) from ..exceptions import SfgException @@ -40,31 +42,87 @@ class SfgClassComposer(SfgComposerMixIn): Its interface is exposed by :class:`SfgComposer`. """ - class VisibilityContext: + class VisibilityBlockSequencer: """Represent a visibility block in the composer syntax. Returned by `private`, `public`, and `protected`. """ def __init__(self, visibility: SfgVisibility): - self._vis_block = SfgVisibilityBlock(visibility) - - def members(self): - yield from self._vis_block.members() + self._visibility = visibility + self._args: tuple[ + SfgClassComposer.MethodSequencer + | SfgClassComposer.ConstructorBuilder + | VarLike + | str, + ..., + ] def __call__( self, *args: ( - SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str + SfgClassComposer.MethodSequencer + | SfgClassComposer.ConstructorBuilder + | VarLike + | str ), ): - for arg in args: - self._vis_block.append_member(SfgClassComposer._resolve_member(arg)) + self._args = args + return self + def _resolve(self, ctx: SfgContext, cls: SfgClass): + vis_block = SfgVisibilityBlock(self._visibility) + for arg in self._args: + match arg: + case ( + SfgClassComposer.MethodSequencer() + | SfgClassComposer.ConstructorBuilder() + ): + arg._resolve(ctx, cls, vis_block) + case str(): + vis_block.elements.append(arg) + case _: + var = asvar(arg) + member_var = SfgMemberVariable(var.name, var.dtype, cls) + cls.add_member(member_var, vis_block.visibility) + vis_block.elements.append(member_var) + + class MethodSequencer: + def __init__( + self, + name: str, + returns: UserTypeSpec = PsCustomType("void"), + inline: bool = False, + const: bool = False, + ) -> None: + self._name = name + self._returns = create_type(returns) + self._inline = inline + self._const = const + self._tree: SfgCallTreeNode + + def __call__(self, *args: SequencerArg): + self._tree = make_sequence(*args) return self - def resolve(self, cls: SfgClass) -> None: - cls.append_visibility_block(self._vis_block) + def _resolve( + self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock + ): + method = SfgMethod( + self._name, + cls, + self._tree, + return_type=self._returns, + inline=self._inline, + const=self._const, + ) + cls.add_member(method, vis_block.visibility) + + if self._inline: + vis_block.elements.append(SfgEntityDef(method)) + else: + vis_block.elements.append(SfgEntityDecl(method)) + ctx._cursor.write_impl(SfgEntityDef(method)) class ConstructorBuilder: """Composer syntax for constructor building. @@ -107,13 +165,19 @@ class SfgClassComposer(SfgComposerMixIn): self._body = body return self - def resolve(self) -> SfgConstructor: - return SfgConstructor( + def _resolve( + self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock + ): + ctor = SfgConstructor( + cls, parameters=self._params, initializers=self._initializers, body=self._body if self._body is not None else "", ) + cls.add_member(ctor, vis_block.visibility) + vis_block.elements.append(SfgEntityDef(ctor)) + def klass(self, class_name: str, bases: Sequence[str] = ()): """Create a class and add it to the underlying context. @@ -133,19 +197,19 @@ class SfgClassComposer(SfgComposerMixIn): return self._class(class_name, SfgClassKeyword.STRUCT, bases) @property - def public(self) -> SfgClassComposer.VisibilityContext: + def public(self) -> SfgClassComposer.VisibilityBlockSequencer: """Create a `public` visibility block in a class body""" - return SfgClassComposer.VisibilityContext(SfgVisibility.PUBLIC) + return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PUBLIC) @property - def protected(self) -> SfgClassComposer.VisibilityContext: + def protected(self) -> SfgClassComposer.VisibilityBlockSequencer: """Create a `protected` visibility block in a class or struct body""" - return SfgClassComposer.VisibilityContext(SfgVisibility.PROTECTED) + return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PROTECTED) @property - def private(self) -> SfgClassComposer.VisibilityContext: + def private(self) -> SfgClassComposer.VisibilityBlockSequencer: """Create a `private` visibility block in a class or struct body""" - return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE) + return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PRIVATE) def constructor(self, *params: VarLike): """In a class or struct body or visibility block, add a constructor. @@ -172,76 +236,55 @@ class SfgClassComposer(SfgComposerMixIn): const: Whether or not the method is const-qualified. """ - def sequencer(*args: SequencerArg): - tree = make_sequence(*args) - return SfgMethod( - name, - tree, - return_type=create_type(returns), - inline=inline, - const=const, - ) - - return sequencer + return SfgClassComposer.MethodSequencer(name, returns, inline, const) # INTERNALS def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]): - if self._ctx.get_class(class_name) is not None: - raise ValueError(f"Class or struct {class_name} already exists.") + if self._cursor.get_entity(class_name) is not None: + raise ValueError( + f"Another entity with name {class_name} already exists in the current namespace." + ) - cls = SfgClass(class_name, class_keyword=keyword, bases=bases) - self._ctx.add_class(cls) + cls = SfgClass( + class_name, + self._cursor.current_namespace, + class_keyword=keyword, + bases=bases, + ) + self._cursor.add_entity(cls) def sequencer( *args: ( - SfgClassComposer.VisibilityContext - | SfgClassMember + SfgClassComposer.VisibilityBlockSequencer + | SfgClassComposer.MethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str ), ): - default_ended = False - - for arg in args: - if isinstance(arg, SfgClassComposer.VisibilityContext): - default_ended = True - arg.resolve(cls) - elif isinstance( - arg, - ( - SfgClassMember, - SfgClassComposer.ConstructorBuilder, - str, - ) - + _VarLike, - ): - if default_ended: - raise SfgException( - "Composer Syntax Error: " - "Cannot add members with default visibility after a visibility block." - ) - else: - cls.default.append_member(self._resolve_member(arg)) + default_vis_sequencer = SfgClassComposer.VisibilityBlockSequencer( + SfgVisibility.DEFAULT + ) + + def argfilter(arg): + return not isinstance(arg, SfgClassComposer.VisibilityBlockSequencer) + + default_vis_args = takewhile( + argfilter, + args, + ) + default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls) # type: ignore + + for arg in dropwhile(argfilter, args): + if isinstance(arg, SfgClassComposer.VisibilityBlockSequencer): + arg._resolve(self._ctx, cls) else: - raise SfgException(f"{arg} is not a valid class member.") + raise SfgException( + "Composer Syntax Error: " + "Cannot add members with default visibility after a visibility block." + ) - return sequencer + self._cursor.write_header(SfgEntityDef(cls)) - @staticmethod - def _resolve_member( - arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str, - ) -> SfgClassMember: - match arg: - case _ if isinstance(arg, _VarLike): - var = asvar(arg) - return SfgMemberVariable(var.name, var.dtype) - case str(): - return SfgInClassDefinition(arg) - case SfgClassComposer.ConstructorBuilder(): - return arg.resolve() - case SfgClassMember(): - return arg - case _: - raise ValueError(f"Invalid class member: {arg}") + return sequencer diff --git a/src/pystencilssfg/composer/mixin.py b/src/pystencilssfg/composer/mixin.py index 3ee8efa..34b1c58 100644 --- a/src/pystencilssfg/composer/mixin.py +++ b/src/pystencilssfg/composer/mixin.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ..context import SfgContext +from ..context import SfgContext, SfgCursor from .basic_composer import SfgBasicComposer @@ -14,6 +14,7 @@ class SfgComposerMixIn: def __init__(self) -> None: self._ctx: SfgContext + self._cursor: SfgCursor @property def _composer(self) -> SfgBasicComposer: diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 577dcbd..48d0720 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -5,11 +5,10 @@ from .config import CodeStyle from .ir.source_components import ( SfgSourceFile, SfgNamespace, - SfgKernelNamespace, SfgNamespaceBlock, SfgNamespaceElement, SfgCodeEntity, - SfgClass, + SfgGlobalNamespace, ) from .exceptions import SfgException @@ -37,9 +36,14 @@ class SfgContext: self._header_file = header_file self._impl_file = impl_file - self._entities: dict[str, SfgCodeEntity] = dict() + self._global_namespace = SfgGlobalNamespace() - self._cursor: SfgCursor + current_ns: SfgNamespace = self._global_namespace + if outer_namespace is not None: + for token in outer_namespace.split("::"): + current_ns = SfgNamespace(token, current_ns) + + self._cursor = SfgCursor(self, current_ns) @property def argv(self) -> Sequence[str]: @@ -85,28 +89,18 @@ class SfgContext: if self._impl_file is not None: yield self._impl_file - def get_entity(self, fqname: str) -> SfgCodeEntity | None: - # TODO: Only track top-level entities here, traverse namespaces to find qualified entities - return self._entities.get(fqname, None) - - def add_entity(self, entity: SfgCodeEntity) -> None: - fqname = entity.fqname - if fqname in self._entities: - raise ValueError(f"Another entity with name {fqname} already exists") - self._entities[fqname] = entity + @property + def global_namespace(self) -> SfgNamespace: + return self._global_namespace class SfgCursor: """Cursor that tracks the current location in the source file(s) during execution of the generator script.""" - def __init__(self, ctx: SfgContext, namespace: str | None = None) -> None: + def __init__(self, ctx: SfgContext, namespace: SfgNamespace) -> None: self._ctx = ctx - self._cur_namespace: SfgNamespace | None - if namespace is not None: - self._cur_namespace = ctx.get_namespace(namespace) - else: - self._cur_namespace = None + self._cur_namespace: SfgNamespace = namespace self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] for f in self._ctx.files: @@ -120,9 +114,15 @@ class SfgCursor: # TODO: Enter and exit namespace blocks @property - def current_namespace(self) -> SfgNamespace | None: + def current_namespace(self) -> SfgNamespace: return self._cur_namespace + def get_entity(self, name: str) -> SfgCodeEntity | None: + return self._cur_namespace.get_entity(name) + + def add_entity(self, entity: SfgCodeEntity): + self._cur_namespace.add_entity(entity) + def write_header(self, elem: SfgNamespaceElement) -> None: self._loc[self._ctx.header_file].append(elem) diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index 0b42594..c550975 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -15,7 +15,6 @@ def collect_includes(obj: Any) -> set[HeaderFile]: SfgClass, SfgConstructor, SfgMemberVariable, - SfgInClassDefinition, ) match obj: @@ -58,9 +57,6 @@ def collect_includes(obj: Any) -> set[HeaderFile]: case SfgMemberVariable(): return includes(obj) - case SfgInClassDefinition(): - return set() - case _: raise SfgException( f"Can't collect includes from object of type {type(obj)}" diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index a5d2c5a..9a29f2f 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -210,7 +210,7 @@ class SfgKernelCallNode(SfgCallTreeLeaf): def get_code(self, ctx: SfgContext) -> str: ast_params = self._kernel_handle.parameters - fnc_name = self._kernel_handle.fully_qualified_name + fnc_name = self._kernel_handle.fqname call_parameters = ", ".join([p.name for p in ast_params]) return f"{fnc_name}({call_parameters});" @@ -228,7 +228,7 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): from pystencils import Target from pystencils.codegen import GpuKernel - func = kernel_handle.get_kernel_function() + func = kernel_handle.get_kernel() if not (isinstance(func, GpuKernel) and func.target == Target.CUDA): raise ValueError( "An `SfgCudaKernelInvocation` node can only call a CUDA kernel." @@ -247,7 +247,7 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): def get_code(self, ctx: SfgContext) -> str: ast_params = self._kernel_handle.parameters - fnc_name = self._kernel_handle.fully_qualified_name + fnc_name = self._kernel_handle.fqname call_parameters = ", ".join([p.name for p in ast_params]) grid_args = [self._num_blocks, self._threads_per_block] diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index d8fb0f5..15b27fb 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -39,9 +39,9 @@ class SfgCodeEntity: Each code entity has a name and an optional enclosing namespace. """ - def __init__(self, name: str, parent_namespace: SfgNamespace | None) -> None: + def __init__(self, name: str, parent_namespace: SfgNamespace) -> None: self._name = name - self._namespace: SfgNamespace | None = parent_namespace + self._namespace: SfgNamespace = parent_namespace @property def name(self) -> str: @@ -51,7 +51,7 @@ class SfgCodeEntity: @property def fqname(self) -> str: """Fully qualified name of this entity""" - if self._namespace is not None: + if not isinstance(self._namespace, SfgGlobalNamespace): return self._namespace.fqname + "::" + self._name else: return self._name @@ -73,7 +73,31 @@ class SfgNamespace(SfgCodeEntity): parent: Parent namespace enclosing this namespace """ - # TODO: Namespaces must keep track of their child entities + def __init__(self, name: str, parent_namespace: SfgNamespace) -> None: + super().__init__(name, parent_namespace) + + self._entities: dict[str, SfgCodeEntity] = dict() + + def get_entity(self, name: str) -> SfgCodeEntity | None: + return self._entities.get(name, None) + + def add_entity(self, entity: SfgCodeEntity): + if entity.name in self._entities: + raise ValueError( + f"Another entity with the name {entity.fqname} already exists" + ) + self._entities[entity.name] = entity + + +class SfgGlobalNamespace(SfgNamespace): + """The C++ global namespace.""" + + def __init__(self) -> None: + super().__init__("", self) + + @property + def fqname(self) -> str: + return "" class SfgKernelHandle(SfgCodeEntity): @@ -117,9 +141,9 @@ class SfgKernelHandle(SfgCodeEntity): class SfgKernelNamespace(SfgNamespace): """A namespace grouping together a number of kernels.""" - def __init__(self, name: str, parent: SfgNamespace | None): + def __init__(self, name: str, parent: SfgNamespace): super().__init__(name, parent) - self._kernels: dict[str, SfgKernelHandle] = [] + self._kernels: dict[str, SfgKernelHandle] = dict() @property def name(self): @@ -200,7 +224,7 @@ class SfgFunction(SfgCodeEntity): def __init__( self, name: str, - namespace: SfgNamespace | None, + namespace: SfgNamespace, tree: SfgCallTreeNode, return_type: PsType = void, inline: bool = False, @@ -272,8 +296,8 @@ class SfgClassKeyword(Enum): class SfgClassMember(ABC): """Base class for class member entities""" - def __init__(self) -> None: - self._cls: SfgClass | None = None + def __init__(self, cls: SfgClass) -> None: + self._cls: SfgClass = cls self._visibility: SfgVisibility | None = None @property @@ -290,26 +314,13 @@ class SfgClassMember(ABC): ) return self._visibility - @property - def is_bound(self) -> bool: - return self._cls is not None - - def _bind(self, cls: SfgClass, vis: SfgVisibility): - if self.is_bound: - raise SfgException( - f"Binding {self} to class {cls.class_name} failed: " - f"{self} was already bound to {self.owning_class.class_name}" - ) - self._cls = cls - self._vis = vis - class SfgMemberVariable(SfgVar, SfgClassMember): """Variable that is a field of a class""" - def __init__(self, name: str, dtype: PsType): + def __init__(self, name: str, dtype: PsType, cls: SfgClass): SfgVar.__init__(self, name, dtype) - SfgClassMember.__init__(self) + SfgClassMember.__init__(self, cls) class SfgMethod(SfgClassMember): @@ -320,12 +331,13 @@ class SfgMethod(SfgClassMember): def __init__( self, name: str, + cls: SfgClass, tree: SfgCallTreeNode, return_type: PsType = void, inline: bool = False, const: bool = False, ): - super().__init__() + super().__init__(cls) self._name = name self._tree = tree @@ -368,11 +380,12 @@ class SfgConstructor(SfgClassMember): def __init__( self, + cls: SfgClass, parameters: Sequence[SfgVar] = (), initializers: Sequence[str] = (), body: str = "", ): - SfgClassMember.__init__(self) + super().__init__(cls) self._parameters = tuple(parameters) self._initializers = tuple(initializers) self._body = body @@ -409,7 +422,7 @@ class SfgClass(SfgCodeEntity): def __init__( self, name: str, - namespace: SfgNamespace | None, + namespace: SfgNamespace, class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS, bases: Sequence[str] = (), ): @@ -422,7 +435,6 @@ class SfgClass(SfgCodeEntity): self._bases_classes = tuple(bases) self._default_block = SfgVisibilityBlock(SfgVisibility.DEFAULT) - self._default_block._bind(self) self._blocks = [self._default_block] self._constructors: list[SfgConstructor] = [] @@ -451,14 +463,10 @@ class SfgClass(SfgCodeEntity): raise SfgException( "Can't add another block with DEFAULT visibility to a class. Use `.default` instead." ) - - block._bind(self) - for m in block.members(): - self._add_member(m, block.visibility) self._blocks.append(block) - def visibility_blocks(self) -> Generator[SfgVisibilityBlock, None, None]: - yield from self._blocks + def visibility_blocks(self) -> tuple[SfgVisibilityBlock, ...]: + return tuple(self._blocks) def members( self, visibility: SfgVisibility | None = None @@ -497,26 +505,16 @@ class SfgClass(SfgCodeEntity): else: yield from self._methods - # PRIVATE - - def _add_member(self, member: SfgClassMember, vis: SfgVisibility): + def add_member(self, member: SfgClassMember, vis: SfgVisibility): if isinstance(member, SfgConstructor): - self._add_constructor(member) + self._constructors.append(member) elif isinstance(member, SfgMemberVariable): self._add_member_variable(member) elif isinstance(member, SfgMethod): - self._add_method(member) + self._methods.append(member) else: raise SfgException(f"{member} is not a valid class member.") - member._bind(self, vis) - - def _add_constructor(self, constr: SfgConstructor): - self._constructors.append(constr) - - def _add_method(self, method: SfgMethod): - self._methods.append(method) - def _add_member_variable(self, variable: SfgMemberVariable): if variable.name in self._member_vars: raise SfgException( @@ -598,22 +596,13 @@ class SfgVisibilityBlock: def visibility(self) -> SfgVisibility: return self._vis - def append_member(self, element: SfgClassBodyElement): - if isinstance(element, (SfgEntityDecl, SfgEntityDef)): - member = element.entity - elif isinstance(element, SfgClassMember): - member = element - else: - member = None - - if self._cls is not None and member is not None: - self._cls._add_member(member, self._vis) - - self._elements.append(element) - @property - def elements(self) -> tuple[SfgClassBodyElement, ...]: - return tuple(self._elements) + def elements(self) -> list[SfgClassBodyElement]: + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgClassBodyElement]): + self._elements = list(elems) def members(self) -> Generator[SfgClassMember, None, None]: for elem in self._elements: @@ -623,18 +612,6 @@ class SfgVisibilityBlock: case SfgMemberVariable(): yield elem - @property - def is_bound(self) -> bool: - return self._cls is not None - - def _bind(self, cls: SfgClass): - if self._cls is not None: - raise SfgException( - f"Binding visibility block to class {cls.class_name} failed: " - f"was already bound to {self._cls.class_name}" - ) - self._cls = cls - class SfgNamespaceBlock: """A C++ namespace. -- GitLab From b90499a37e6ad106b046ea10b317d6d45132e1e1 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 6 Feb 2025 20:44:01 +0100 Subject: [PATCH 05/18] some cleanup - move numpy_struct to class composer and use its sequencers - fix some type errors --- src/pystencilssfg/composer/basic_composer.py | 69 +++----------------- src/pystencilssfg/composer/class_composer.py | 40 ++++++++++++ src/pystencilssfg/ir/__init__.py | 6 -- src/pystencilssfg/ir/source_components.py | 25 ++----- 4 files changed, 56 insertions(+), 84 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 936d5c5..76e2907 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Sequence, TypeAlias from abc import ABC, abstractmethod -import numpy as np import sympy as sp from functools import reduce from warnings import warn @@ -33,10 +32,6 @@ from ..ir.source_components import ( SfgFunction, SfgKernelNamespace, SfgKernelHandle, - SfgClass, - SfgConstructor, - SfgMemberVariable, - SfgClassKeyword, SfgEntityDecl, SfgEntityDef, SfgNamespaceBlock, @@ -86,7 +81,7 @@ SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder class KernelsAdder: def __init__(self, ctx: SfgContext, loc: SfgNamespaceBlock): self._ctx = ctx - self._loc = SfgNamespaceBlock + self._loc = loc assert isinstance(loc.namespace, SfgKernelNamespace) self._kernel_namespace = loc.namespace @@ -110,6 +105,8 @@ class KernelsAdder: khandle = SfgKernelHandle(kernel_name, self._kernel_namespace, kernel) self._kernel_namespace.add_kernel(khandle) + self._loc.elements.append(SfgEntityDef(khandle)) + for header in kernel.required_headers: assert self._ctx.impl_file is not None self._ctx.impl_file.includes.append(HeaderFile.parse(header)) @@ -242,7 +239,7 @@ class SfgBasicComposer(SfgIComposer): self._cursor.write_impl(kns_block) return KernelsAdder(self._ctx, kns_block) - def include(self, header_file: str | HeaderFile, private: bool = False): + def include(self, header: str | HeaderFile, private: bool = False): """Include a header file. Args: @@ -262,7 +259,7 @@ class SfgBasicComposer(SfgIComposer): #include <vector> #include "custom.h" """ - header_file = HeaderFile.parse(header_file) + header_file = HeaderFile.parse(header) if private: if self._ctx.impl_file is None: @@ -273,21 +270,6 @@ class SfgBasicComposer(SfgIComposer): else: self._ctx.header_file.includes.append(header_file) - def numpy_struct( - self, name: str, dtype: np.dtype, add_constructor: bool = True - ) -> SfgClass: - """Add a numpy structured data type as a C++ struct - - Returns: - The created class object - """ - cls = self._struct_from_numpy_dtype( - name, dtype, add_constructor=add_constructor - ) - self._cursor.add_entity(cls) - self._cursor.write_header(SfgEntityDecl(cls)) - return cls - def kernel_function(self, name: str, kernel: Kernel | SfgKernelHandle): """Create a function comprising just a single kernel call. @@ -295,9 +277,11 @@ class SfgBasicComposer(SfgIComposer): ast_or_kernel_handle: Either a pystencils AST, or a kernel handle for an already registered AST. """ if isinstance(kernel, Kernel): - kernel = self.kernels.add(kernel, name) + khandle = self.kernels.add(kernel, name) + else: + khandle = kernel - self.function(name)(self.call(kernel)) + self.function(name)(self.call(khandle)) def function( self, @@ -536,41 +520,6 @@ class SfgBasicComposer(SfgIComposer): ] return SfgDeferredVectorMapping(components, rhs) - def _struct_from_numpy_dtype( - self, struct_name: str, dtype: np.dtype, add_constructor: bool = True - ): - cls = SfgClass( - struct_name, - self._cursor.current_namespace, - class_keyword=SfgClassKeyword.STRUCT, - ) - - fields = dtype.fields - if fields is None: - raise SfgException(f"Numpy dtype {dtype} is not a structured type.") - - constr_params = [] - constr_inits = [] - - for member_name, type_info in fields.items(): - member_type = create_type(type_info[0]) - - member = SfgMemberVariable(member_name, member_type) - - arg = SfgVar(f"{member_name}_", member_type) - - cls.default.append_member(member) - - constr_params.append(arg) - constr_inits.append(f"{member}({arg})") - - if add_constructor: - cls.default.append_member( - SfgEntityDef(SfgConstructor(constr_params, constr_inits)) - ) - - return cls - def make_statements(arg: ExprLike) -> SfgStatements: return SfgStatements(str(arg), (), depends(arg), includes(arg)) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 1e3e6a3..fa7d6f2 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Sequence from itertools import takewhile, dropwhile +import numpy as np from pystencils.types import PsCustomType, UserTypeSpec, create_type @@ -196,6 +197,16 @@ class SfgClassComposer(SfgComposerMixIn): """ return self._class(class_name, SfgClassKeyword.STRUCT, bases) + def numpy_struct( + self, name: str, dtype: np.dtype, add_constructor: bool = True + ): + """Add a numpy structured data type as a C++ struct + + Returns: + The created class object + """ + return self._struct_from_numpy_dtype(name, dtype, add_constructor) + @property def public(self) -> SfgClassComposer.VisibilityBlockSequencer: """Create a `public` visibility block in a class body""" @@ -241,6 +252,8 @@ class SfgClassComposer(SfgComposerMixIn): # INTERNALS def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]): + # TODO: Return a `CppClass` instance representing the generated class + if self._cursor.get_entity(class_name) is not None: raise ValueError( f"Another entity with name {class_name} already exists in the current namespace." @@ -288,3 +301,30 @@ class SfgClassComposer(SfgComposerMixIn): self._cursor.write_header(SfgEntityDef(cls)) return sequencer + + def _struct_from_numpy_dtype( + self, struct_name: str, dtype: np.dtype, add_constructor: bool = True + ): + fields = dtype.fields + if fields is None: + raise SfgException(f"Numpy dtype {dtype} is not a structured type.") + + members: list[SfgClassComposer.ConstructorBuilder | SfgVar] = [] + if add_constructor: + ctor = self.constructor() + members.append(ctor) + + for member_name, type_info in fields.items(): + member_type = create_type(type_info[0]) + + member = SfgVar(member_name, member_type) + members.append(member) + + if add_constructor: + arg = SfgVar(f"{member_name}_", member_type) + ctor.add_param(arg) + ctor.init(member)(arg) + + return self.struct( + struct_name, + )(*members) diff --git a/src/pystencilssfg/ir/__init__.py b/src/pystencilssfg/ir/__init__.py index 8eee39c..f1760b7 100644 --- a/src/pystencilssfg/ir/__init__.py +++ b/src/pystencilssfg/ir/__init__.py @@ -15,8 +15,6 @@ from .call_tree import ( ) from .source_components import ( - SfgHeaderInclude, - SfgEmptyLines, SfgKernelNamespace, SfgKernelHandle, SfgKernelParamVar, @@ -25,7 +23,6 @@ from .source_components import ( SfgClassKeyword, SfgClassMember, SfgVisibilityBlock, - SfgInClassDefinition, SfgMemberVariable, SfgMethod, SfgConstructor, @@ -47,8 +44,6 @@ __all__ = [ "SfgBranch", "SfgSwitchCase", "SfgSwitch", - "SfgHeaderInclude", - "SfgEmptyLines", "SfgKernelNamespace", "SfgKernelHandle", "SfgKernelParamVar", @@ -57,7 +52,6 @@ __all__ = [ "SfgClassKeyword", "SfgClassMember", "SfgVisibilityBlock", - "SfgInClassDefinition", "SfgMemberVariable", "SfgMethod", "SfgConstructor", diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 15b27fb..b4d8aa7 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -404,18 +404,7 @@ class SfgConstructor(SfgClassMember): class SfgClass(SfgCodeEntity): - """Models a C++ class. - - ### Adding members to classes - - Members are never added directly to a class. Instead, they are added to - an [SfgVisibilityBlock][pystencilssfg.source_components.SfgVisibilityBlock] - which defines their syntactic position and visibility modifier in the code. - At the top of every class, there is a default visibility block - accessible through the `default` property. - To add members with custom visibility, create a new SfgVisibilityBlock, - add members to the block, and add the block using `append_visibility_block`. - """ + """A C++ class.""" __match_args__ = ("class_name",) @@ -524,12 +513,6 @@ class SfgClass(SfgCodeEntity): self._member_vars[variable.name] = variable -SourceEntity_T = TypeVar( - "SourceEntity_T", bound=SfgFunction | SfgClassMember | SfgClass, covariant=True -) -"""Source entities that may have declarations and definitions.""" - - # ========================================================================================================= # # SYNTACTICAL ELEMENTS @@ -540,6 +523,12 @@ SourceEntity_T = TypeVar( # ========================================================================================================= +SourceEntity_T = TypeVar( + "SourceEntity_T", bound=SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, covariant=True +) +"""Source entities that may have declarations and definitions.""" + + class SfgEntityDecl(Generic[SourceEntity_T]): """Declaration of a function, class, method, or constructor""" -- GitLab From a87856af43f8600d2e517c8fd780815ec814e30e Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 09:30:42 +0100 Subject: [PATCH 06/18] split class entity and its body; split enitites and syntax into two separate files --- src/pystencilssfg/composer/basic_composer.py | 2 +- src/pystencilssfg/composer/class_composer.py | 4 +- src/pystencilssfg/context.py | 4 +- src/pystencilssfg/extensions/sycl.py | 2 +- src/pystencilssfg/ir/__init__.py | 32 ++- src/pystencilssfg/ir/analysis.py | 2 +- src/pystencilssfg/ir/call_tree.py | 2 +- .../ir/{source_components.py => entities.py} | 214 +--------------- src/pystencilssfg/ir/postprocessing.py | 2 +- src/pystencilssfg/ir/syntax.py | 228 ++++++++++++++++++ 10 files changed, 268 insertions(+), 224 deletions(-) rename src/pystencilssfg/ir/{source_components.py => entities.py} (68%) create mode 100644 src/pystencilssfg/ir/syntax.py diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 76e2907..08422be 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -28,7 +28,7 @@ from ..ir.postprocessing import ( SfgDeferredFieldMapping, SfgDeferredVectorMapping, ) -from ..ir.source_components import ( +from ..ir import ( SfgFunction, SfgKernelNamespace, SfgKernelHandle, diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index fa7d6f2..3ed5c0f 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -13,8 +13,8 @@ from ..lang import ( SfgVar, ) -from ..ir.call_tree import SfgCallTreeNode -from ..ir.source_components import ( +from ..ir import ( + SfgCallTreeNode, SfgClass, SfgConstructor, SfgMethod, diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 48d0720..a129f98 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -2,14 +2,14 @@ from __future__ import annotations from typing import Sequence, Any, Generator from .config import CodeStyle -from .ir.source_components import ( +from .ir import ( SfgSourceFile, SfgNamespace, SfgNamespaceBlock, - SfgNamespaceElement, SfgCodeEntity, SfgGlobalNamespace, ) +from .ir.syntax import SfgNamespaceElement from .exceptions import SfgException diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 88dbc9b..a628f00 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -17,8 +17,8 @@ from ..composer import ( SfgComposerMixIn, make_sequence, ) -from ..ir.source_components import SfgKernelHandle from ..ir import ( + SfgKernelHandle, SfgCallTreeNode, SfgCallTreeLeaf, SfgKernelCallNode, diff --git a/src/pystencilssfg/ir/__init__.py b/src/pystencilssfg/ir/__init__.py index f1760b7..8f03fed 100644 --- a/src/pystencilssfg/ir/__init__.py +++ b/src/pystencilssfg/ir/__init__.py @@ -14,20 +14,32 @@ from .call_tree import ( SfgSwitch, ) -from .source_components import ( +from .entities import ( + SfgCodeEntity, + SfgNamespace, + SfgGlobalNamespace, SfgKernelNamespace, SfgKernelHandle, - SfgKernelParamVar, SfgFunction, SfgVisibility, SfgClassKeyword, SfgClassMember, - SfgVisibilityBlock, SfgMemberVariable, SfgMethod, SfgConstructor, SfgClass, ) + +from .syntax import ( + SfgEntityDecl, + SfgEntityDef, + SfgVisibilityBlock, + SfgNamespaceBlock, + SfgClassBody, + SfgSourceFileType, + SfgSourceFile, +) + from .analysis import collect_includes __all__ = [ @@ -44,17 +56,25 @@ __all__ = [ "SfgBranch", "SfgSwitchCase", "SfgSwitch", + "SfgCodeEntity", + "SfgNamespace", + "SfgGlobalNamespace", "SfgKernelNamespace", "SfgKernelHandle", - "SfgKernelParamVar", "SfgFunction", "SfgVisibility", "SfgClassKeyword", "SfgClassMember", - "SfgVisibilityBlock", "SfgMemberVariable", "SfgMethod", "SfgConstructor", "SfgClass", - "collect_includes" + "SfgEntityDecl", + "SfgEntityDef", + "SfgVisibilityBlock", + "SfgNamespaceBlock", + "SfgClassBody", + "SfgSourceFileType", + "SfgSourceFile", + "collect_includes", ] diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index c550975..c2c3e34 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -10,7 +10,7 @@ from ..lang import HeaderFile, includes def collect_includes(obj: Any) -> set[HeaderFile]: from ..context import SfgContext from .call_tree import SfgCallTreeNode - from .source_components import ( + from .entities import ( SfgFunction, SfgClass, SfgConstructor, diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index 9a29f2f..2e057dd 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Sequence, Iterable, NewType from abc import ABC, abstractmethod -from .source_components import SfgKernelHandle +from .entities import SfgKernelHandle from ..lang import SfgVar, HeaderFile if TYPE_CHECKING: diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/entities.py similarity index 68% rename from src/pystencilssfg/ir/source_components.py rename to src/pystencilssfg/ir/entities.py index b4d8aa7..9d76b39 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/entities.py @@ -6,9 +6,6 @@ from typing import ( TYPE_CHECKING, Sequence, Generator, - Iterable, - TypeVar, - Generic, ) from dataclasses import replace from itertools import chain @@ -17,7 +14,7 @@ from pystencils import CreateKernelConfig, create_kernel, Field from pystencils.codegen import Kernel from pystencils.types import PsType, PsCustomType -from ..lang import SfgVar, SfgKernelParamVar, HeaderFile, void +from ..lang import SfgVar, SfgKernelParamVar, void from ..exceptions import SfgException if TYPE_CHECKING: @@ -423,9 +420,6 @@ class SfgClass(SfgCodeEntity): self._class_keyword = class_keyword self._bases_classes = tuple(bases) - self._default_block = SfgVisibilityBlock(SfgVisibility.DEFAULT) - self._blocks = [self._default_block] - self._constructors: list[SfgConstructor] = [] self._methods: list[SfgMethod] = [] self._member_vars: dict[str, SfgMemberVariable] = dict() @@ -443,30 +437,15 @@ class SfgClass(SfgCodeEntity): def class_keyword(self) -> SfgClassKeyword: return self._class_keyword - @property - def default(self) -> SfgVisibilityBlock: - return self._default_block - - def append_visibility_block(self, block: SfgVisibilityBlock): - if block.visibility == SfgVisibility.DEFAULT: - raise SfgException( - "Can't add another block with DEFAULT visibility to a class. Use `.default` instead." - ) - self._blocks.append(block) - - def visibility_blocks(self) -> tuple[SfgVisibilityBlock, ...]: - return tuple(self._blocks) - def members( self, visibility: SfgVisibility | None = None ) -> Generator[SfgClassMember, None, None]: if visibility is None: - yield from chain.from_iterable(b.members() for b in self._blocks) - else: - yield from chain.from_iterable( - b.members() - for b in filter(lambda b: b.visibility == visibility, self._blocks) + yield from chain( + self._constructors, self._methods, self._member_vars.values() ) + else: + yield from filter(lambda m: m.visibility == visibility, self.members()) def member_variables( self, visibility: SfgVisibility | None = None @@ -511,186 +490,3 @@ class SfgClass(SfgCodeEntity): ) self._member_vars[variable.name] = variable - - -# ========================================================================================================= -# -# SYNTACTICAL ELEMENTS -# -# These classes model *code elements*, which represent the actual syntax objects that populate the output -# files, their namespaces and class bodies. -# -# ========================================================================================================= - - -SourceEntity_T = TypeVar( - "SourceEntity_T", bound=SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, covariant=True -) -"""Source entities that may have declarations and definitions.""" - - -class SfgEntityDecl(Generic[SourceEntity_T]): - """Declaration of a function, class, method, or constructor""" - - __match_args__ = ("entity",) - - def __init__(self, entity: SourceEntity_T) -> None: - self._entity = entity - - @property - def entity(self) -> SourceEntity_T: - return self._entity - - -class SfgEntityDef(Generic[SourceEntity_T]): - """Definition of a function, class, method, or constructor""" - - __match_args__ = ("entity",) - - def __init__(self, entity: SourceEntity_T) -> None: - self._entity = entity - - @property - def entity(self) -> SourceEntity_T: - return self._entity - - -SfgClassBodyElement = ( - str - | SfgEntityDecl[SfgClassMember] - | SfgEntityDef[SfgClassMember] - | SfgMemberVariable -) -"""Elements that may be placed in the visibility blocks of a class body.""" - - -class SfgVisibilityBlock: - """Visibility-qualified block inside a class definition body. - - Visibility blocks host the code elements placed inside a class body: - method and constructor declarations, - in-class method and constructor definitions, - as well as variable declarations and definitions. - - Args: - visibility: The visibility qualifier of this block - """ - - def __init__(self, visibility: SfgVisibility) -> None: - self._vis = visibility - self._elements: list[SfgClassBodyElement] = [] - self._cls: SfgClass | None = None - - @property - def visibility(self) -> SfgVisibility: - return self._vis - - @property - def elements(self) -> list[SfgClassBodyElement]: - return self._elements - - @elements.setter - def elements(self, elems: Iterable[SfgClassBodyElement]): - self._elements = list(elems) - - def members(self) -> Generator[SfgClassMember, None, None]: - for elem in self._elements: - match elem: - case SfgEntityDecl(entity) | SfgEntityDef(entity): - yield entity - case SfgMemberVariable(): - yield elem - - -class SfgNamespaceBlock: - """A C++ namespace. - - Each namespace has a `name` and a `parent`; its fully qualified name is given as - ``<parent.name>::<name>``. - - Args: - name: Local name of this namespace - parent: Parent namespace enclosing this namespace - """ - - def __init__(self, namespace: SfgNamespace) -> None: - self._namespace = namespace - self._elements: list[SfgNamespaceElement] = [] - - @property - def namespace(self) -> SfgNamespace: - return self._namespace - - @property - def elements(self) -> list[SfgNamespaceElement]: - """Sequence of source elements that make up the body of this namespace""" - return self._elements - - @elements.setter - def elements(self, elems: Iterable[SfgNamespaceElement]): - self._elements = list(elems) - - -SfgNamespaceElement = str | SfgNamespaceBlock | SfgEntityDecl | SfgEntityDef -"""Elements that may be placed inside a namespace, including the global namespace.""" - - -class SfgSourceFileType(Enum): - HEADER = auto() - TRANSLATION_UNIT = auto() - - -class SfgSourceFile: - """A C++ source file. - - Args: - name: Name of the file (without parent directories), e.g. ``Algorithms.cpp`` - file_type: Type of the source file (header or translation unit) - prelude: Optionally, text of the prelude comment printed at the top of the file - """ - - def __init__( - self, name: str, file_type: SfgSourceFileType, prelude: str | None = None - ) -> None: - self._name: str = name - self._file_type: SfgSourceFileType = file_type - self._prelude: str | None = prelude - self._includes: list[HeaderFile] = [] - self._elements: list[SfgNamespaceElement] = [] - - @property - def name(self) -> str: - """Name of this source file""" - return self._name - - @property - def file_type(self) -> SfgSourceFileType: - """File type of this source file""" - return self._file_type - - @property - def prelude(self) -> str | None: - """Text of the prelude comment""" - return self._prelude - - @prelude.setter - def prelude(self, text: str | None): - self._prelude = text - - @property - def includes(self) -> list[HeaderFile]: - """Sequence of header files to be included at the top of this file""" - return self._includes - - @includes.setter - def includes(self, incl: Iterable[HeaderFile]): - self._includes = list(incl) - - @property - def elements(self) -> list[SfgNamespaceElement]: - """Sequence of source elements comprising the body of this file""" - return self._elements - - @elements.setter - def elements(self, elems: Iterable[SfgNamespaceElement]): - self._elements = list(elems) diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index db26a38..ca6d9f2 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -15,7 +15,7 @@ from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements -from ..ir.source_components import SfgKernelParamVar +from ..lang.expressions import SfgKernelParamVar from ..lang import ( SfgVar, IFieldExtraction, diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py new file mode 100644 index 0000000..2e924a0 --- /dev/null +++ b/src/pystencilssfg/ir/syntax.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from enum import Enum, auto +from typing import ( + Generator, + Iterable, + TypeVar, + Generic, +) + +from ..lang import HeaderFile + +from .entities import ( + SfgNamespace, + SfgKernelHandle, + SfgFunction, + SfgClassMember, + SfgMemberVariable, + SfgVisibility, + SfgClass, +) + +# ========================================================================================================= +# +# SYNTACTICAL ELEMENTS +# +# These classes model *code elements*, which represent the actual syntax objects that populate the output +# files, their namespaces and class bodies. +# +# ========================================================================================================= + + +SourceEntity_T = TypeVar( + "SourceEntity_T", + bound=SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, + covariant=True, +) +"""Source entities that may have declarations and definitions.""" + + +class SfgEntityDecl(Generic[SourceEntity_T]): + """Declaration of a function, class, method, or constructor""" + + __match_args__ = ("entity",) + + def __init__(self, entity: SourceEntity_T) -> None: + self._entity = entity + + @property + def entity(self) -> SourceEntity_T: + return self._entity + + +class SfgEntityDef(Generic[SourceEntity_T]): + """Definition of a function, class, method, or constructor""" + + __match_args__ = ("entity",) + + def __init__(self, entity: SourceEntity_T) -> None: + self._entity = entity + + @property + def entity(self) -> SourceEntity_T: + return self._entity + + +SfgClassBodyElement = ( + str + | SfgEntityDecl[SfgClassMember] + | SfgEntityDef[SfgClassMember] + | SfgMemberVariable +) +"""Elements that may be placed in the visibility blocks of a class body.""" + + +class SfgVisibilityBlock: + """Visibility-qualified block inside a class definition body. + + Visibility blocks host the code elements placed inside a class body: + method and constructor declarations, + in-class method and constructor definitions, + as well as variable declarations and definitions. + + Args: + visibility: The visibility qualifier of this block + """ + + def __init__(self, visibility: SfgVisibility) -> None: + self._vis = visibility + self._elements: list[SfgClassBodyElement] = [] + self._cls: SfgClass | None = None + + @property + def visibility(self) -> SfgVisibility: + return self._vis + + @property + def elements(self) -> list[SfgClassBodyElement]: + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgClassBodyElement]): + self._elements = list(elems) + + def members(self) -> Generator[SfgClassMember, None, None]: + for elem in self._elements: + match elem: + case SfgEntityDecl(entity) | SfgEntityDef(entity): + yield entity + case SfgMemberVariable(): + yield elem + + +class SfgNamespaceBlock: + """A C++ namespace. + + Each namespace has a `name` and a `parent`; its fully qualified name is given as + ``<parent.name>::<name>``. + + Args: + name: Local name of this namespace + parent: Parent namespace enclosing this namespace + """ + + def __init__(self, namespace: SfgNamespace) -> None: + self._namespace = namespace + self._elements: list[SfgNamespaceElement] = [] + + @property + def namespace(self) -> SfgNamespace: + return self._namespace + + @property + def elements(self) -> list[SfgNamespaceElement]: + """Sequence of source elements that make up the body of this namespace""" + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgNamespaceElement]): + self._elements = list(elems) + + +class SfgClassBody: + """Body of a class definition.""" + + def __init__(self, cls: SfgClass) -> None: + self._cls = cls + self._default_block = SfgVisibilityBlock(SfgVisibility.DEFAULT) + self._blocks = [self._default_block] + + @property + def default(self) -> SfgVisibilityBlock: + return self._default_block + + def append_visibility_block(self, block: SfgVisibilityBlock): + if block.visibility == SfgVisibility.DEFAULT: + raise ValueError( + "Can't add another block with DEFAULT visibility to this class body." + ) + self._blocks.append(block) + + def visibility_blocks(self) -> tuple[SfgVisibilityBlock, ...]: + return tuple(self._blocks) + + +SfgNamespaceElement = str | SfgNamespaceBlock | SfgEntityDecl | SfgEntityDef +"""Elements that may be placed inside a namespace, including the global namespace.""" + + +class SfgSourceFileType(Enum): + HEADER = auto() + TRANSLATION_UNIT = auto() + + +class SfgSourceFile: + """A C++ source file. + + Args: + name: Name of the file (without parent directories), e.g. ``Algorithms.cpp`` + file_type: Type of the source file (header or translation unit) + prelude: Optionally, text of the prelude comment printed at the top of the file + """ + + def __init__( + self, name: str, file_type: SfgSourceFileType, prelude: str | None = None + ) -> None: + self._name: str = name + self._file_type: SfgSourceFileType = file_type + self._prelude: str | None = prelude + self._includes: list[HeaderFile] = [] + self._elements: list[SfgNamespaceElement] = [] + + @property + def name(self) -> str: + """Name of this source file""" + return self._name + + @property + def file_type(self) -> SfgSourceFileType: + """File type of this source file""" + return self._file_type + + @property + def prelude(self) -> str | None: + """Text of the prelude comment""" + return self._prelude + + @prelude.setter + def prelude(self, text: str | None): + self._prelude = text + + @property + def includes(self) -> list[HeaderFile]: + """Sequence of header files to be included at the top of this file""" + return self._includes + + @includes.setter + def includes(self, incl: Iterable[HeaderFile]): + self._includes = list(incl) + + @property + def elements(self) -> list[SfgNamespaceElement]: + """Sequence of source elements comprising the body of this file""" + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgNamespaceElement]): + self._elements = list(elems) -- GitLab From 87d35024dc5bc875f76f297da3a1a11c85656051 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 09:44:59 +0100 Subject: [PATCH 07/18] start writing printer --- src/pystencilssfg/emission/file_printer.py | 57 ++++++++++++++++++++++ src/pystencilssfg/ir/syntax.py | 2 + 2 files changed, 59 insertions(+) create mode 100644 src/pystencilssfg/emission/file_printer.py diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py new file mode 100644 index 0000000..2dc1e7b --- /dev/null +++ b/src/pystencilssfg/emission/file_printer.py @@ -0,0 +1,57 @@ +from __future__ import annotations +from textwrap import indent + +from ..ir import ( + SfgSourceFile, + SfgSourceFileType, + SfgNamespaceBlock, + SfgEntityDecl, + SfgEntityDef, +) +from ..ir.syntax import SfgNamespaceElement +from ..config import CodeStyle + + +class SfgFilePrinter: + def __init__(self, code_style: CodeStyle) -> None: + self._code_style = code_style + + def __call__(self, file: SfgSourceFile) -> str: + code = "" + + if file.file_type == SfgSourceFileType.HEADER: + code += "#pragma once\n" + + if file.prelude: + comment = "/**\n" + comment += indent(file.prelude, " * ") + comment += "\n */\n\n" + + code += comment + + for header in file.includes: + incl = str(header) if header.system_header else f'"{str(header)}"' + code += f"#include {incl}\n" + + if file.includes: + code += "\n" + + # Here begins the actual code + code += "\n\n".join(self.visit(elem) for elem in file.elements) + code += "\n" + return code + + def visit(self, elem: SfgNamespaceElement) -> str: + match elem: + case str(): + return elem + case SfgNamespaceBlock(name, elements): + code = f"namespace {name} {{\n" + code += self._code_style.indent( + "\n\n".join(self.visit(e) for e in elements) + ) + code += f"\n}} // namespace {name}" + case SfgEntityDecl(entity): + code += self.visit_decl(entity) + case SfgEntityDef(entity): + code += self.visit_defin(entity) diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py index 2e924a0..42b705b 100644 --- a/src/pystencilssfg/ir/syntax.py +++ b/src/pystencilssfg/ir/syntax.py @@ -122,6 +122,8 @@ class SfgNamespaceBlock: parent: Parent namespace enclosing this namespace """ + __match_args__ = ("name", "elements",) + def __init__(self, namespace: SfgNamespace) -> None: self._namespace = namespace self._elements: list[SfgNamespaceElement] = [] -- GitLab From c9683d173b1f53a2af0892ccbb53c5ae8da35a8f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 12:08:15 +0100 Subject: [PATCH 08/18] Implement new code printer; remove old emitters and printers. Update generation driver. --- src/pystencilssfg/cli.py | 29 +- src/pystencilssfg/composer/class_composer.py | 19 +- src/pystencilssfg/config.py | 10 +- src/pystencilssfg/context.py | 2 +- src/pystencilssfg/emission/__init__.py | 7 +- src/pystencilssfg/emission/emitter.py | 84 ++---- src/pystencilssfg/emission/file_printer.py | 144 +++++++++- .../emission/header_impl_pair.py | 58 ---- src/pystencilssfg/emission/header_only.py | 38 --- src/pystencilssfg/emission/printers.py | 265 ------------------ src/pystencilssfg/extensions/sycl.py | 27 +- src/pystencilssfg/generator.py | 90 +++--- src/pystencilssfg/ir/analysis.py | 13 +- src/pystencilssfg/ir/call_tree.py | 42 +-- src/pystencilssfg/ir/entities.py | 32 ++- src/pystencilssfg/ir/postprocessing.py | 8 +- src/pystencilssfg/ir/syntax.py | 35 ++- src/pystencilssfg/visitors/__init__.py | 5 - src/pystencilssfg/visitors/dispatcher.py | 69 ----- 19 files changed, 335 insertions(+), 642 deletions(-) delete mode 100644 src/pystencilssfg/emission/header_impl_pair.py delete mode 100644 src/pystencilssfg/emission/header_only.py delete mode 100644 src/pystencilssfg/emission/printers.py delete mode 100644 src/pystencilssfg/visitors/__init__.py delete mode 100644 src/pystencilssfg/visitors/dispatcher.py diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py index 3b321c2..8cd9f4f 100644 --- a/src/pystencilssfg/cli.py +++ b/src/pystencilssfg/cli.py @@ -1,11 +1,12 @@ import sys import os from os import path +from pathlib import Path +from typing import NoReturn from argparse import ArgumentParser, BooleanOptionalAction from .config import CommandLineParameters, SfgConfigException, OutputMode -from .emission import OutputSpec def add_newline_arg(parser): @@ -17,7 +18,7 @@ def add_newline_arg(parser): ) -def cli_main(program="sfg-cli"): +def cli_main(program="sfg-cli") -> NoReturn: parser = ArgumentParser( program, description="pystencilssfg command-line utility for build system integration", @@ -65,7 +66,7 @@ def cli_main(program="sfg-cli"): exit(-1) # should never happen -def version(args): +def version(args) -> NoReturn: from . import __version__ print(__version__, end=os.linesep if args.newline else "") @@ -73,37 +74,43 @@ def version(args): exit(0) -def list_files(args): +def list_files(args) -> NoReturn: cli_params = CommandLineParameters(args) config = cli_params.get_config() _, scriptname = path.split(args.codegen_script) basename = path.splitext(scriptname)[0] - output_spec = OutputSpec.create(config, basename) - output_files = [output_spec.get_header_filepath()] + output_dir: Path = config.get_option("output_directory") + + header_ext = config.extensions.get_option("header") + output_files = [output_dir / f"{basename}.{header_ext}"] if config.output_mode != OutputMode.HEADER_ONLY: - output_files.append(output_spec.get_impl_filepath()) + impl_ext = config.extensions.get_option("impl") + output_files.append(output_dir / f"{basename}.{impl_ext}") - print(args.sep.join(output_files), end=os.linesep if args.newline else "") + print( + args.sep.join(str(of) for of in output_files), + end=os.linesep if args.newline else "", + ) exit(0) -def print_cmake_modulepath(args): +def print_cmake_modulepath(args) -> NoReturn: from .cmake import get_sfg_cmake_modulepath print(get_sfg_cmake_modulepath(), end=os.linesep if args.newline else "") exit(0) -def make_cmake_find_module(args): +def make_cmake_find_module(args) -> NoReturn: from .cmake import make_find_module make_find_module() exit(0) -def abort_with_config_exception(exception: SfgConfigException, source: str): +def abort_with_config_exception(exception: SfgConfigException, source: str) -> NoReturn: print(f"Invalid {source} configuration: {exception.args[0]}.", file=sys.stderr) exit(1) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 3ed5c0f..0a72e80 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -24,6 +24,7 @@ from ..ir import ( SfgVisibilityBlock, SfgEntityDecl, SfgEntityDef, + SfgClassBody, ) from ..exceptions import SfgException @@ -71,7 +72,7 @@ class SfgClassComposer(SfgComposerMixIn): self._args = args return self - def _resolve(self, ctx: SfgContext, cls: SfgClass): + def _resolve(self, ctx: SfgContext, cls: SfgClass) -> SfgVisibilityBlock: vis_block = SfgVisibilityBlock(self._visibility) for arg in self._args: match arg: @@ -86,7 +87,8 @@ class SfgClassComposer(SfgComposerMixIn): var = asvar(arg) member_var = SfgMemberVariable(var.name, var.dtype, cls) cls.add_member(member_var, vis_block.visibility) - vis_block.elements.append(member_var) + vis_block.elements.append(SfgEntityDef(member_var)) + return vis_block class MethodSequencer: def __init__( @@ -133,7 +135,7 @@ class SfgClassComposer(SfgComposerMixIn): def __init__(self, *params: VarLike): self._params = list(asvar(p) for p in params) - self._initializers: list[str] = [] + self._initializers: list[tuple[SfgVar | str, tuple[ExprLike, ...]]] = [] self._body: str | None = None def add_param(self, param: VarLike, at: int | None = None): @@ -152,9 +154,7 @@ class SfgClassComposer(SfgComposerMixIn): member = var if isinstance(var, str) else asvar(var) def init_sequencer(*args: ExprLike): - expr = ", ".join(str(arg) for arg in args) - initializer = f"{member}{{ {expr} }}" - self._initializers.append(initializer) + self._initializers.append((member, args)) return self return init_sequencer @@ -287,18 +287,19 @@ class SfgClassComposer(SfgComposerMixIn): argfilter, args, ) - default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls) # type: ignore + default_block = default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls) # type: ignore + vis_blocks: list[SfgVisibilityBlock] = [] for arg in dropwhile(argfilter, args): if isinstance(arg, SfgClassComposer.VisibilityBlockSequencer): - arg._resolve(self._ctx, cls) + vis_blocks.append(arg._resolve(self._ctx, cls)) else: raise SfgException( "Composer Syntax Error: " "Cannot add members with default visibility after a visibility block." ) - self._cursor.write_header(SfgEntityDef(cls)) + self._cursor.write_header(SfgClassBody(cls, default_block, vis_blocks)) return sequencer diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index aae9dab..7bbcfc6 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -8,9 +8,9 @@ from dataclasses import dataclass from enum import Enum, auto from os import path from importlib import util as iutil +from pathlib import Path - -from pystencils.codegen.config import ConfigBase, BasicOption, Category +from pystencils.codegen.config import ConfigBase, Option, BasicOption, Category class SfgConfigException(Exception): ... # noqa: E701 @@ -166,9 +166,13 @@ class SfgConfig(ConfigBase): ClangFormatOptions.binary """ - output_directory: BasicOption[str] = BasicOption(".") + output_directory: Option[Path, str | Path] = Option(Path(".")) """Directory to which the generated files should be written.""" + @output_directory.validate + def _validate_output_directory(self, pth: str | Path) -> Path: + return Path(pth) + class CommandLineParameters: @staticmethod diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index a129f98..f1a3251 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -19,7 +19,7 @@ class SfgContext: def __init__( self, header_file: SfgSourceFile, - impl_file: SfgSourceFile, + impl_file: SfgSourceFile | None, outer_namespace: str | None = None, codestyle: CodeStyle | None = None, argv: Sequence[str] | None = None, diff --git a/src/pystencilssfg/emission/__init__.py b/src/pystencilssfg/emission/__init__.py index fd66628..1a22aa2 100644 --- a/src/pystencilssfg/emission/__init__.py +++ b/src/pystencilssfg/emission/__init__.py @@ -1,5 +1,4 @@ -from .emitter import AbstractEmitter, OutputSpec -from .header_impl_pair import HeaderImplPairEmitter -from .header_only import HeaderOnlyEmitter +from .emitter import SfgCodeEmitter +from .file_printer import SfgFilePrinter -__all__ = ["AbstractEmitter", "OutputSpec", "HeaderImplPairEmitter", "HeaderOnlyEmitter"] +__all__ = ["SfgCodeEmitter", "SfgFilePrinter"] diff --git a/src/pystencilssfg/emission/emitter.py b/src/pystencilssfg/emission/emitter.py index c32b18a..440344b 100644 --- a/src/pystencilssfg/emission/emitter.py +++ b/src/pystencilssfg/emission/emitter.py @@ -1,71 +1,29 @@ from __future__ import annotations -from typing import Sequence -from abc import ABC, abstractmethod -from dataclasses import dataclass -from os import path +from pathlib import Path -from ..context import SfgContext -from ..config import SfgConfig, OutputMode +from ..config import CodeStyle, ClangFormatOptions +from ..ir import SfgSourceFile +from .file_printer import SfgFilePrinter +from .clang_format import invoke_clang_format -@dataclass -class OutputSpec: - """Name and path specification for files output by the code generator. - Filenames are constructed as `<output_directory>/<basename>.<extension>`.""" +class SfgCodeEmitter: + def __init__( + self, + output_directory: Path, + code_style: CodeStyle, + clang_format: ClangFormatOptions, + ): + self._output_dir = output_directory + self._clang_format_opts = clang_format + self._printer = SfgFilePrinter(code_style) - output_directory: str - """Directory to which the generated files should be written.""" + def emit(self, file: SfgSourceFile): + code = self._printer(file) + code = invoke_clang_format(code, self._clang_format_opts) - basename: str - """Base name for output files.""" - - header_extension: str - """File extension for generated header file.""" - - impl_extension: str - """File extension for generated implementation file.""" - - def get_header_filename(self): - return f"{self.basename}.{self.header_extension}" - - def get_impl_filename(self): - return f"{self.basename}.{self.impl_extension}" - - def get_header_filepath(self): - return path.join(self.output_directory, self.get_header_filename()) - - def get_impl_filepath(self): - return path.join(self.output_directory, self.get_impl_filename()) - - @staticmethod - def create(config: SfgConfig, basename: str) -> OutputSpec: - output_mode = config.get_option("output_mode") - header_extension = config.extensions.get_option("header") - impl_extension = config.extensions.get_option("impl") - - if impl_extension is None: - match output_mode: - case OutputMode.INLINE: - impl_extension = "ipp" - case OutputMode.STANDALONE: - impl_extension = "cpp" - - return OutputSpec( - config.get_option("output_directory"), - basename, - header_extension, - impl_extension, - ) - - -class AbstractEmitter(ABC): - @property - @abstractmethod - def output_files(self) -> Sequence[str]: - pass - - @abstractmethod - def write_files(self, ctx: SfgContext): - pass + self._output_dir.mkdir(parents=True, exist_ok=True) + fpath = self._output_dir / file.name + fpath.write_text(code) diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index 2dc1e7b..ec43468 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -1,20 +1,35 @@ from __future__ import annotations from textwrap import indent +from pystencils.backend.emission import CAstPrinter + from ..ir import ( SfgSourceFile, SfgSourceFileType, SfgNamespaceBlock, SfgEntityDecl, SfgEntityDef, + SfgKernelHandle, + SfgFunction, + SfgClassMember, + SfgMethod, + SfgMemberVariable, + SfgConstructor, + SfgClass, + SfgClassBody, + SfgVisibilityBlock, + SfgVisibility, ) -from ..ir.syntax import SfgNamespaceElement +from ..ir.syntax import SfgNamespaceElement, SfgClassBodyElement from ..config import CodeStyle class SfgFilePrinter: def __init__(self, code_style: CodeStyle) -> None: self._code_style = code_style + self._kernel_printer = CAstPrinter( + indent_width=code_style.get_option("indent_width") + ) def __call__(self, file: SfgSourceFile) -> str: code = "" @@ -41,17 +56,132 @@ class SfgFilePrinter: code += "\n" return code - def visit(self, elem: SfgNamespaceElement) -> str: + def visit( + self, elem: SfgNamespaceElement | SfgClassBodyElement, inclass: bool = False + ) -> str: match elem: case str(): return elem - case SfgNamespaceBlock(name, elements): - code = f"namespace {name} {{\n" + case SfgNamespaceBlock(namespace, elements): + code = f"namespace {namespace.name} {{\n" code += self._code_style.indent( "\n\n".join(self.visit(e) for e in elements) ) - code += f"\n}} // namespace {name}" + code += f"\n}} // namespace {namespace.name}" + return code case SfgEntityDecl(entity): - code += self.visit_decl(entity) + return self.visit_decl(entity, inclass) case SfgEntityDef(entity): - code += self.visit_defin(entity) + return self.visit_defin(entity, inclass) + case _: + assert False, "illegal code element" + + def visit_decl( + self, + declared_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, + inclass: bool = False, + ) -> str: + match declared_entity: + case SfgKernelHandle(kernel): + return self._kernel_printer.print_signature(kernel) + ";" + + case SfgFunction(name, _, params) | SfgMethod(name, _, params): + return self._func_signature(declared_entity, inclass) + ";" + + case SfgConstructor(cls, params): + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in params + ) + return f"{cls.name}({params_str});" + + case SfgMemberVariable(name, dtype): + return f"{dtype.c_string()} {name};" + + case SfgClass(kwd, name): + return f"{str(kwd)} {name};" + + case _: + assert False, f"unsupported declared entity: {declared_entity}" + + def visit_defin( + self, + defined_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClassBody, + inclass: bool = False, + ) -> str: + match defined_entity: + case SfgKernelHandle(kernel): + return self._kernel_printer(kernel) + + case SfgFunction(name, tree, params) | SfgMethod(name, tree, params): + sig = self._func_signature(defined_entity, inclass) + body = tree.get_code(self._code_style) + body = "\n{\n" + self._code_style.indent(body) + "\n}" + return sig + body + + case SfgConstructor(cls, params): + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in params + ) + + code = "" + if not inclass: + code += f"{cls.name}::" + code += f"{cls.name} ({params_str})" + + inits: list[str] = [] + for var, args in defined_entity.initializers: + args_str = ", ".join(str(arg) for arg in args) + inits.append(f"{str(var)}({args_str})") + + if inits: + code += "\n:" + ",\n".join(inits) + + code += "\n{\n" + self._code_style.indent(defined_entity.body) + "\n}" + return code + + case SfgMemberVariable(name, dtype): + code = dtype.c_string() + if not inclass: + code += f" {defined_entity.owning_class.name}::" + code += f" {name}" + if defined_entity.default_init is not None: + args_str = ", ".join(str(expr) for expr in defined_entity.default_init) + code += "{" + args_str + "}" + code += ";" + return code + + case SfgClassBody(cls, vblocks): + code = f"{cls.class_keyword} {cls.name} {{\n" + vblocks_str = [self._visibility_block(b) for b in vblocks] + code += "\n\n".join(vblocks_str) + code += "\n}\n" + return code + + case _: + assert False, f"unsupported defined entity: {defined_entity}" + + def _visibility_block(self, vblock: SfgVisibilityBlock): + prefix = ( + f"{vblock.visibility}:\n" + if vblock.visibility != SfgVisibility.DEFAULT + else "" + ) + elements = [self.visit(elem, inclass=True) for elem in vblock.elements] + return prefix + self._code_style.indent("\n".join(elements)) + + def _func_signature(self, func: SfgFunction | SfgMethod, inclass: bool): + code = "" + if func.inline: + code += "inline " + code += func.return_type.c_string() + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in func.parameters + ) + if isinstance(func, SfgMethod) and not inclass: + code += f"{func.owning_class.name}::" + code += f"{func.name}({params_str})" + + if isinstance(func, SfgMethod) and func.const: + code += " const" + + return code diff --git a/src/pystencilssfg/emission/header_impl_pair.py b/src/pystencilssfg/emission/header_impl_pair.py deleted file mode 100644 index 87ff5f5..0000000 --- a/src/pystencilssfg/emission/header_impl_pair.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Sequence -from os import path, makedirs - -from ..context import SfgContext -from .printers import SfgHeaderPrinter, SfgImplPrinter -from .clang_format import invoke_clang_format -from ..config import ClangFormatOptions - -from .emitter import AbstractEmitter, OutputSpec - - -class HeaderImplPairEmitter(AbstractEmitter): - """Emits a header-implementation file pair.""" - - def __init__( - self, - output_spec: OutputSpec, - inline_impl: bool = False, - clang_format: ClangFormatOptions | None = None, - ): - """Create a `HeaderImplPairEmitter` from an [SfgOutputSpec][pystencilssfg.configuration.SfgOutputSpec].""" - self._basename = output_spec.basename - self._output_directory = output_spec.output_directory - self._header_filename = output_spec.get_header_filename() - self._impl_filename = output_spec.get_impl_filename() - self._inline_impl = inline_impl - - self._ospec = output_spec - self._clang_format = clang_format - - @property - def output_files(self) -> Sequence[str]: - """The files that will be written by `write_files`.""" - return ( - path.join(self._output_directory, self._header_filename), - path.join(self._output_directory, self._impl_filename), - ) - - def write_files(self, ctx: SfgContext): - """Write the code represented by the given [SfgContext][pystencilssfg.SfgContext] to the files - specified by the output specification.""" - header_printer = SfgHeaderPrinter(ctx, self._ospec, self._inline_impl) - impl_printer = SfgImplPrinter(ctx, self._ospec, self._inline_impl) - - header = header_printer.get_code() - impl = impl_printer.get_code() - - if self._clang_format is not None: - header = invoke_clang_format(header, self._clang_format) - impl = invoke_clang_format(impl, self._clang_format) - - makedirs(self._output_directory, exist_ok=True) - - with open(self._ospec.get_header_filepath(), "w") as headerfile: - headerfile.write(header) - - with open(self._ospec.get_impl_filepath(), "w") as cppfile: - cppfile.write(impl) diff --git a/src/pystencilssfg/emission/header_only.py b/src/pystencilssfg/emission/header_only.py deleted file mode 100644 index 7d026da..0000000 --- a/src/pystencilssfg/emission/header_only.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Sequence -from os import path, makedirs - -from ..context import SfgContext -from .printers import SfgHeaderPrinter -from ..config import ClangFormatOptions -from .clang_format import invoke_clang_format - -from .emitter import AbstractEmitter, OutputSpec - - -class HeaderOnlyEmitter(AbstractEmitter): - def __init__( - self, output_spec: OutputSpec, clang_format: ClangFormatOptions | None = None - ): - """Create a `HeaderImplPairEmitter` from an [SfgOutputSpec][pystencilssfg.configuration.SfgOutputSpec].""" - self._basename = output_spec.basename - self._output_directory = output_spec.output_directory - self._header_filename = output_spec.get_header_filename() - - self._ospec = output_spec - self._clang_format = clang_format - - @property - def output_files(self) -> Sequence[str]: - """The files that will be written by `write_files`.""" - return (path.join(self._output_directory, self._header_filename),) - - def write_files(self, ctx: SfgContext): - header_printer = SfgHeaderPrinter(ctx, self._ospec) - header = header_printer.get_code() - if self._clang_format is not None: - header = invoke_clang_format(header, self._clang_format) - - makedirs(self._output_directory, exist_ok=True) - - with open(self._ospec.get_header_filepath(), "w") as headerfile: - headerfile.write(header) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py deleted file mode 100644 index 9d7c97e..0000000 --- a/src/pystencilssfg/emission/printers.py +++ /dev/null @@ -1,265 +0,0 @@ -from __future__ import annotations - -from textwrap import indent -from itertools import chain, repeat, cycle - -from pystencils.codegen import Kernel -from pystencils.backend.emission import emit_code - -from ..context import SfgContext -from ..visitors import visitor -from ..exceptions import SfgException - -from ..ir.source_components import ( - SfgEmptyLines, - SfgHeaderInclude, - SfgKernelNamespace, - SfgFunction, - SfgClass, - SfgInClassDefinition, - SfgConstructor, - SfgMemberVariable, - SfgMethod, - SfgVisibility, - SfgVisibilityBlock, -) - -from .emitter import OutputSpec - - -def interleave(*iters): - try: - for iter in cycle(iters): - yield next(iter) - except StopIteration: - pass - - -class SfgGeneralPrinter: - @visitor - def visit(self, obj: object) -> str: - raise SfgException(f"Can't print object of type {type(obj)}") - - @visit.case(SfgEmptyLines) - def emptylines(self, el: SfgEmptyLines) -> str: - return "\n" * el.lines - - @visit.case(str) - def string(self, s: str) -> str: - return s - - @visit.case(SfgHeaderInclude) - def include(self, incl: SfgHeaderInclude) -> str: - if incl.system_header: - return f"#include <{incl.file}>" - else: - return f'#include "{incl.file}"' - - def prelude(self, ctx: SfgContext) -> str: - if ctx.prelude_comment: - return ( - "/*\n" - + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) - + "*/\n" - ) - else: - return "" - - def param_list(self, func: SfgFunction) -> str: - params = sorted(list(func.parameters), key=lambda p: p.name) - return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) - - -class SfgHeaderPrinter(SfgGeneralPrinter): - def __init__( - self, ctx: SfgContext, output_spec: OutputSpec, inline_impl: bool = False - ): - self._output_spec = output_spec - self._ctx = ctx - self._inline_impl = inline_impl - - def get_code(self) -> str: - return self.visit(self._ctx) - - @visitor - def visit(self, obj: object) -> str: - return super().visit(obj) - - @visit.case(SfgContext) - def frame(self, ctx: SfgContext) -> str: - code = super().prelude(ctx) - - code += "\n#pragma once\n\n" - - includes = filter(lambda incl: not incl.private, ctx.includes()) - code += "\n".join(self.visit(incl) for incl in includes) - code += "\n\n" - - fq_namespace = ctx.fully_qualified_namespace - if fq_namespace is not None: - code += f"namespace {fq_namespace} {{\n\n" - - parts = interleave(ctx.declarations_ordered(), repeat(SfgEmptyLines(1))) - - code += "\n".join(self.visit(p) for p in parts) - - if fq_namespace is not None: - code += f"}} // namespace {fq_namespace}\n" - - if self._inline_impl: - code += f'#include "{self._output_spec.get_impl_filename()}"\n' - - return code - - @visit.case(SfgFunction) - def function(self, func: SfgFunction): - params = sorted(list(func.parameters), key=lambda p: p.name) - param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) - return f"{func.return_type.c_string()} {func.name} ( {param_list} );" - - @visit.case(SfgClass) - def sfg_class(self, cls: SfgClass): - code = f"{cls.class_keyword} {cls.class_name} \n" - - if cls.base_classes: - code += f" : {','.join(cls.base_classes)}\n" - - code += "{\n" - - for block in cls.visibility_blocks(): - code += self.visit(block) + "\n" - - code += "};\n" - - return code - - @visit.case(SfgVisibilityBlock) - def vis_block(self, block: SfgVisibilityBlock) -> str: - code = "" - if block.visibility != SfgVisibility.DEFAULT: - code += f"{block.visibility}:\n" - code += self._ctx.codestyle.indent( - "\n".join(self.visit(m) for m in block.members()) - ) - return code - - @visit.case(SfgInClassDefinition) - def sfg_inclassdef(self, definition: SfgInClassDefinition): - return definition.text - - @visit.case(SfgConstructor) - def sfg_constructor(self, constr: SfgConstructor): - code = f"{constr.owning_class.class_name} (" - code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters) - code += ")\n" - if constr.initializers: - code += " : " + ", ".join(constr.initializers) + "\n" - if constr.body: - code += "{\n" + self._ctx.codestyle.indent(constr.body) + "\n}\n" - else: - code += "{ }\n" - return code - - @visit.case(SfgMemberVariable) - def sfg_member_var(self, var: SfgMemberVariable): - return f"{var.dtype.c_string()} {var.name};" - - @visit.case(SfgMethod) - def sfg_method(self, method: SfgMethod): - code = f"{method.return_type.c_string()} {method.name} ({self.param_list(method)})" - code += "const" if method.const else "" - if method.inline: - code += ( - " {\n" - + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) - + "}\n" - ) - else: - code += ";" - return code - - -class SfgImplPrinter(SfgGeneralPrinter): - def __init__( - self, ctx: SfgContext, output_spec: OutputSpec, inline_impl: bool = False - ): - self._output_spec = output_spec - self._ctx = ctx - self._inline_impl = inline_impl - - def get_code(self) -> str: - return self.visit(self._ctx) - - @visitor - def visit(self, obj: object) -> str: - return super().visit(obj) - - @visit.case(SfgContext) - def frame(self, ctx: SfgContext) -> str: - code = super().prelude(ctx) - - if not self._inline_impl: - code += f'\n#include "{self._output_spec.get_header_filename()}"\n\n' - - includes = filter(lambda incl: incl.private, ctx.includes()) - code += "\n".join(self.visit(incl) for incl in includes) - - code += "\n\n#define FUNC_PREFIX inline\n\n" - - fq_namespace = ctx.fully_qualified_namespace - if fq_namespace is not None: - code += f"namespace {fq_namespace} {{\n\n" - - parts = interleave( - chain( - ctx.kernel_namespaces(), - ctx.functions(), - ctx.classes(), - ), - repeat(SfgEmptyLines(1)), - ) - - code += "\n".join(self.visit(p) for p in parts) - - if fq_namespace is not None: - code += f"}} // namespace {fq_namespace}\n" - - return code - - @visit.case(SfgKernelNamespace) - def kernel_namespace(self, kns: SfgKernelNamespace) -> str: - code = f"namespace {kns.name} {{\n\n" - code += "\n\n".join(self.visit(ast) for ast in kns.kernel_functions) - code += f"\n}} // namespace {kns.name}\n" - return code - - @visit.case(Kernel) - def kernel(self, kfunc: Kernel) -> str: - return emit_code(kfunc) - - @visit.case(SfgFunction) - def function(self, func: SfgFunction) -> str: - inline_prefix = "inline " if self._inline_impl else "" - code = ( - f"{inline_prefix} {func.return_type.c_string()} {func.name} ({self.param_list(func)})" - ) - code += ( - "{\n" + self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + "}\n" - ) - return code - - @visit.case(SfgClass) - def sfg_class(self, cls: SfgClass) -> str: - methods = filter(lambda m: not m.inline, cls.methods()) - return "\n".join(self.visit(m) for m in methods) - - @visit.case(SfgMethod) - def sfg_method(self, method: SfgMethod) -> str: - inline_prefix = "inline " if self._inline_impl else "" - const_qual = "const" if method.const else "" - code = f"{inline_prefix}{method.return_type} {method.owning_class.class_name}::{method.name}" - code += f"({self.param_list(method)}) {const_qual}" - code += ( - " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" - ) - return code diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index a628f00..48f9c08 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -8,6 +8,7 @@ from pystencils import Target from pystencilssfg.composer.basic_composer import SequencerArg +from ..config import CodeStyle from ..exceptions import SfgException from ..context import SfgContext from ..composer import ( @@ -93,11 +94,11 @@ class SyclHandler(AugExpr): if isinstance(range, _VarLike): range = asvar(range) - def check_kernel(kernel: SfgKernelHandle): - kfunc = kernel.get_kernel_function() + def check_kernel(khandle: SfgKernelHandle): + kfunc = khandle.kernel if kfunc.target != Target.SYCL: raise SfgException( - f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}" + f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}" ) id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>") @@ -144,7 +145,7 @@ class SyclGroup(AugExpr): self._ctx = ctx def parallel_for_work_item( - self, range: VarLike | Sequence[int], kernel: SfgKernelHandle + self, range: VarLike | Sequence[int], khandle: SfgKernelHandle ): """Generate a ``parallel_for_work_item` kernel invocation on this group.` @@ -155,10 +156,10 @@ class SyclGroup(AugExpr): if isinstance(range, _VarLike): range = asvar(range) - kfunc = kernel.get_kernel_function() + kfunc = khandle.kernel if kfunc.target != Target.SYCL: raise SfgException( - f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}" + f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}" ) id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>") @@ -169,13 +170,13 @@ class SyclGroup(AugExpr): and id_regex.search(param.dtype.c_string()) is not None ) - id_param = list(filter(filter_id, kernel.scalar_parameters))[0] + id_param = list(filter(filter_id, khandle.scalar_parameters))[0] h_item = SfgVar("item", PsCustomType("sycl::h_item< 3 >")) comp = SfgComposer(self._ctx) tree = comp.seq( comp.set_param(id_param, AugExpr.format("{}.get_local_id()", h_item)), - SfgKernelCallNode(kernel), + SfgKernelCallNode(khandle), ) kernel_lambda = SfgLambda(("=",), (h_item,), tree, None) @@ -229,11 +230,11 @@ class SfgLambda: def required_parameters(self) -> set[SfgVar]: return self._required_params - def get_code(self, ctx: SfgContext): + def get_code(self, cstyle: CodeStyle): captures = ", ".join(self._captures) params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params) - body = self._tree.get_code(ctx) - body = ctx.codestyle.indent(body) + body = self._tree.get_code(cstyle) + body = cstyle.indent(body) rtype = ( f"-> {self._return_type.c_string()} " if self._return_type is not None @@ -300,13 +301,13 @@ class SyclKernelInvoke(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return self._required_params - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: if isinstance(self._range, SfgVar): range_code = self._range.name else: range_code = "{ " + ", ".join(str(r) for r in self._range) + " }" - kernel_code = self._lambda.get_code(ctx) + kernel_code = self._lambda.get_code(cstyle) invoker = str(self._invoker) method = self._invoke_type.method diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index eed06b0..b055ad5 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -1,10 +1,9 @@ -import os -from os import path +from pathlib import Path from .config import SfgConfig, CommandLineParameters, OutputMode, GLOBAL_NAMESPACE from .context import SfgContext from .composer import SfgComposer -from .emission import AbstractEmitter, OutputSpec +from .emission import SfgCodeEmitter from .exceptions import SfgException @@ -41,9 +40,9 @@ class SourceFileGenerator: "without a valid entry point, such as a REPL or a multiprocessing fork." ) - scriptpath = __main__.__file__ - scriptname = path.split(scriptpath)[1] - basename = path.splitext(scriptname)[0] + scriptpath = Path(__main__.__file__) + scriptname = scriptpath.name + basename = scriptname.rsplit(".")[0] from argparse import ArgumentParser @@ -67,47 +66,52 @@ class SourceFileGenerator: cli_params.find_conflicts(sfg_config) config.override(sfg_config) - self._context = SfgContext( - None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace, # type: ignore - config.codestyle, - argv=script_args, - project_info=cli_params.get_project_info(), - ) + self._output_mode: OutputMode = config.get_option("output_mode") + self._output_dir: Path = config.get_option("output_directory") + self._header_ext: str = config.extensions.get_option("header") + self._impl_ext: str = config.extensions.get_option("impl") - from .lang import HeaderFile - from .ir import SfgHeaderInclude + from .ir import SfgSourceFile, SfgSourceFileType - self._context.add_include(SfgHeaderInclude(HeaderFile("cstdint", system_header=True))) - self._context.add_definition("#define RESTRICT __restrict__") - - output_mode = config.get_option("output_mode") - output_spec = OutputSpec.create(config, basename) + self._header_file = SfgSourceFile( + f"{basename}.{self._header_ext}", SfgSourceFileType.HEADER + ) + self._impl_file: SfgSourceFile | None - self._emitter: AbstractEmitter - match output_mode: + match self._output_mode: case OutputMode.HEADER_ONLY: - from .emission import HeaderOnlyEmitter - - self._emitter = HeaderOnlyEmitter( - output_spec, clang_format=config.clang_format + self._impl_file = None + case OutputMode.STANDALONE: + self._impl_file = SfgSourceFile( + f"{basename}.{self._impl_ext}", SfgSourceFileType.TRANSLATION_UNIT ) case OutputMode.INLINE: - from .emission import HeaderImplPairEmitter - - self._emitter = HeaderImplPairEmitter( - output_spec, inline_impl=True, clang_format=config.clang_format + self._impl_file = SfgSourceFile( + f"{basename}.{self._impl_ext}", SfgSourceFileType.HEADER ) - case OutputMode.STANDALONE: - from .emission import HeaderImplPairEmitter - self._emitter = HeaderImplPairEmitter( - output_spec, clang_format=config.clang_format - ) + self._context = SfgContext( + self._header_file, + self._impl_file, + None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace, # type: ignore + config.codestyle, + argv=script_args, + project_info=cli_params.get_project_info(), + ) + + self._emitter = SfgCodeEmitter( + self._output_dir, config.codestyle, config.clang_format + ) def clean_files(self): - for file in self._emitter.output_files: - if path.exists(file): - os.remove(file) + header_path = self._output_dir / self._header_file.name + if header_path.exists(): + header_path.unlink() + + if self._impl_file is not None: + impl_path = self._output_dir / self._impl_file.name + if impl_path.exists(): + impl_path.unlink() def __enter__(self) -> SfgComposer: self.clean_files() @@ -116,8 +120,12 @@ class SourceFileGenerator: def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: # Collect header files for inclusion - from .ir import SfgHeaderInclude, collect_includes - for header in collect_includes(self._context): - self._context.add_include(SfgHeaderInclude(header)) + # from .ir import collect_includes + + # TODO: Collect headers + # for header in collect_includes(self._context): + # self._context.add_include(SfgHeaderInclude(header)) - self._emitter.write_files(self._context) + self._emitter.emit(self._header_file) + if self._impl_file is not None: + self._emitter.emit(self._impl_file) diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index c2c3e34..b88c4f3 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -8,7 +8,6 @@ from ..lang import HeaderFile, includes def collect_includes(obj: Any) -> set[HeaderFile]: - from ..context import SfgContext from .call_tree import SfgCallTreeNode from .entities import ( SfgFunction, @@ -18,15 +17,7 @@ def collect_includes(obj: Any) -> set[HeaderFile]: ) match obj: - case SfgContext(): - headers = set() - for func in obj.functions(): - headers |= collect_includes(func) - - for cls in obj.classes(): - headers |= collect_includes(cls) - - return headers + # TODO case SfgCallTreeNode(): return reduce( @@ -48,7 +39,7 @@ def collect_includes(obj: Any) -> set[HeaderFile]: set(), ) - case SfgConstructor(parameters): + case SfgConstructor(_, parameters): param_headers = reduce( set.union, (includes(p) for p in parameters), set() ) diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index 2e057dd..4cee2f5 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -7,7 +7,7 @@ from .entities import SfgKernelHandle from ..lang import SfgVar, HeaderFile if TYPE_CHECKING: - from ..context import SfgContext + from ..config import CodeStyle class SfgCallTreeNode(ABC): @@ -35,7 +35,7 @@ class SfgCallTreeNode(ABC): """This node's children""" @abstractmethod - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: """Returns the code of this node. By convention, the code block emitted by this function should not contain a trailing newline. @@ -75,7 +75,7 @@ class SfgEmptyNode(SfgCallTreeLeaf): def __init__(self): super().__init__() - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: return "" @@ -122,7 +122,7 @@ class SfgStatements(SfgCallTreeLeaf): def code_string(self) -> str: return self._code_string - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: return self._code_string @@ -167,8 +167,8 @@ class SfgSequence(SfgCallTreeNode): def __setitem__(self, idx: int, c: SfgCallTreeNode): self._children[idx] = c - def get_code(self, ctx: SfgContext) -> str: - return "\n".join(c.get_code(ctx) for c in self._children) + def get_code(self, cstyle: CodeStyle) -> str: + return "\n".join(c.get_code(cstyle) for c in self._children) class SfgBlock(SfgCallTreeNode): @@ -184,8 +184,8 @@ class SfgBlock(SfgCallTreeNode): def children(self) -> Sequence[SfgCallTreeNode]: return (self._seq,) - def get_code(self, ctx: SfgContext) -> str: - seq_code = ctx.codestyle.indent(self._seq.get_code(ctx)) + def get_code(self, cstyle: CodeStyle) -> str: + seq_code = cstyle.indent(self._seq.get_code(cstyle)) return "{\n" + seq_code + "\n}" @@ -208,7 +208,7 @@ class SfgKernelCallNode(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return set(self._kernel_handle.parameters) - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: ast_params = self._kernel_handle.parameters fnc_name = self._kernel_handle.fqname call_parameters = ", ".join([p.name for p in ast_params]) @@ -228,8 +228,8 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): from pystencils import Target from pystencils.codegen import GpuKernel - func = kernel_handle.get_kernel() - if not (isinstance(func, GpuKernel) and func.target == Target.CUDA): + kernel = kernel_handle.kernel + if not (isinstance(kernel, GpuKernel) and kernel.target == Target.CUDA): raise ValueError( "An `SfgCudaKernelInvocation` node can only call a CUDA kernel." ) @@ -245,7 +245,7 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return set(self._kernel_handle.parameters) | self._depends - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: ast_params = self._kernel_handle.parameters fnc_name = self._kernel_handle.fqname call_parameters = ", ".join([p.name for p in ast_params]) @@ -289,14 +289,14 @@ class SfgBranch(SfgCallTreeNode): self._branch_true, ) + ((self.branch_false,) if self.branch_false is not None else ()) - def get_code(self, ctx: SfgContext) -> str: - code = f"if({self.condition.get_code(ctx)}) {{\n" - code += ctx.codestyle.indent(self.branch_true.get_code(ctx)) + def get_code(self, cstyle: CodeStyle) -> str: + code = f"if({self.condition.get_code(cstyle)}) {{\n" + code += cstyle.indent(self.branch_true.get_code(cstyle)) code += "\n}" if self.branch_false is not None: code += "else {\n" - code += ctx.codestyle.indent(self.branch_false.get_code(ctx)) + code += cstyle.indent(self.branch_false.get_code(cstyle)) code += "\n}" return code @@ -327,13 +327,13 @@ class SfgSwitchCase(SfgCallTreeNode): def is_default(self) -> bool: return self._label == SfgSwitchCase.Default - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: code = "" if self._label == SfgSwitchCase.Default: code += "default: {\n" else: code += f"case {self._label}: {{\n" - code += ctx.codestyle.indent(self.body.get_code(ctx)) + code += cstyle.indent(self.body.get_code(cstyle)) code += "\n}" return code @@ -403,8 +403,8 @@ class SfgSwitch(SfgCallTreeNode): else: self._children[idx] = c - def get_code(self, ctx: SfgContext) -> str: - code = f"switch({self._switch_arg.get_code(ctx)}) {{\n" - code += "\n".join(c.get_code(ctx) for c in self._cases) + def get_code(self, cstyle: CodeStyle) -> str: + code = f"switch({self._switch_arg.get_code(cstyle)}) {{\n" + code += "\n".join(c.get_code(cstyle) for c in self._cases) code += "}" return code diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index 9d76b39..2b0e3e5 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -14,7 +14,7 @@ from pystencils import CreateKernelConfig, create_kernel, Field from pystencils.codegen import Kernel from pystencils.types import PsType, PsCustomType -from ..lang import SfgVar, SfgKernelParamVar, void +from ..lang import SfgVar, SfgKernelParamVar, void, ExprLike from ..exceptions import SfgException if TYPE_CHECKING: @@ -100,6 +100,8 @@ class SfgGlobalNamespace(SfgNamespace): class SfgKernelHandle(SfgCodeEntity): """Handle to a pystencils kernel.""" + __match_args__ = ("kernel",) + def __init__(self, name: str, namespace: SfgKernelNamespace, kernel: Kernel): super().__init__(name, namespace) @@ -130,7 +132,8 @@ class SfgKernelHandle(SfgCodeEntity): """Fields accessed by this kernel""" return self._fields - def get_kernel(self) -> Kernel: + @property + def kernel(self) -> Kernel: """Underlying pystencils kernel object""" return self._kernel @@ -315,9 +318,20 @@ class SfgClassMember(ABC): class SfgMemberVariable(SfgVar, SfgClassMember): """Variable that is a field of a class""" - def __init__(self, name: str, dtype: PsType, cls: SfgClass): + def __init__( + self, + name: str, + dtype: PsType, + cls: SfgClass, + default_init: tuple[ExprLike, ...] | None = None, + ): SfgVar.__init__(self, name, dtype) SfgClassMember.__init__(self, cls) + self._default_init = default_init + + @property + def default_init(self) -> tuple[ExprLike, ...] | None: + return self._default_init class SfgMethod(SfgClassMember): @@ -349,6 +363,10 @@ class SfgMethod(SfgClassMember): param_collector = CallTreePostProcessing() self._parameters = param_collector(self._tree).function_params + @property + def name(self) -> str: + return self._name + @property def parameters(self) -> set[SfgVar]: return self._parameters @@ -373,13 +391,13 @@ class SfgMethod(SfgClassMember): class SfgConstructor(SfgClassMember): """Constructor of a class""" - __match_args__ = ("parameters", "initializers", "body") + __match_args__ = ("owning_class", "parameters", "initializers", "body") def __init__( self, cls: SfgClass, parameters: Sequence[SfgVar] = (), - initializers: Sequence[str] = (), + initializers: Sequence[tuple[SfgVar | str, tuple[ExprLike, ...]]] = (), body: str = "", ): super().__init__(cls) @@ -392,7 +410,7 @@ class SfgConstructor(SfgClassMember): return self._parameters @property - def initializers(self) -> tuple[str, ...]: + def initializers(self) -> tuple[tuple[SfgVar | str, tuple[ExprLike, ...]], ...]: return self._initializers @property @@ -403,7 +421,7 @@ class SfgConstructor(SfgClassMember): class SfgClass(SfgCodeEntity): """A C++ class.""" - __match_args__ = ("class_name",) + __match_args__ = ("class_keyword", "name") def __init__( self, diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index ca6d9f2..5563783 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, Iterable +from typing import Sequence, Iterable import warnings from functools import reduce from dataclasses import dataclass @@ -13,6 +13,7 @@ from pystencils.types import deconstify, PsType from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException +from ..config import CodeStyle from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements from ..lang.expressions import SfgKernelParamVar @@ -27,9 +28,6 @@ from ..lang import ( includes, ) -if TYPE_CHECKING: - from ..context import SfgContext - class FlattenSequences: """Flattens any nested sequences occuring in a kernel call tree.""" @@ -198,7 +196,7 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: pass - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: raise SfgException( "Invalid access into deferred node; deferred nodes must be expanded first." ) diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py index 42b705b..574c029 100644 --- a/src/pystencilssfg/ir/syntax.py +++ b/src/pystencilssfg/ir/syntax.py @@ -64,12 +64,7 @@ class SfgEntityDef(Generic[SourceEntity_T]): return self._entity -SfgClassBodyElement = ( - str - | SfgEntityDecl[SfgClassMember] - | SfgEntityDef[SfgClassMember] - | SfgMemberVariable -) +SfgClassBodyElement = str | SfgEntityDecl[SfgClassMember] | SfgEntityDef[SfgClassMember] """Elements that may be placed in the visibility blocks of a class body.""" @@ -122,7 +117,10 @@ class SfgNamespaceBlock: parent: Parent namespace enclosing this namespace """ - __match_args__ = ("name", "elements",) + __match_args__ = ( + "namespace", + "elements", + ) def __init__(self, namespace: SfgNamespace) -> None: self._namespace = namespace @@ -145,10 +143,22 @@ class SfgNamespaceBlock: class SfgClassBody: """Body of a class definition.""" - def __init__(self, cls: SfgClass) -> None: + __match_args__ = ("associated_class", "visibility_blocks") + + def __init__( + self, + cls: SfgClass, + default_block: SfgVisibilityBlock, + vis_blocks: Iterable[SfgVisibilityBlock], + ) -> None: self._cls = cls - self._default_block = SfgVisibilityBlock(SfgVisibility.DEFAULT) - self._blocks = [self._default_block] + assert default_block.visibility == SfgVisibility.DEFAULT + self._default_block = default_block + self._blocks = [self._default_block] + list(vis_blocks) + + @property + def associated_class(self) -> SfgClass: + return self._cls @property def default(self) -> SfgVisibilityBlock: @@ -161,11 +171,14 @@ class SfgClassBody: ) self._blocks.append(block) + @property def visibility_blocks(self) -> tuple[SfgVisibilityBlock, ...]: return tuple(self._blocks) -SfgNamespaceElement = str | SfgNamespaceBlock | SfgEntityDecl | SfgEntityDef +SfgNamespaceElement = ( + str | SfgNamespaceBlock | SfgClassBody | SfgEntityDecl | SfgEntityDef +) """Elements that may be placed inside a namespace, including the global namespace.""" diff --git a/src/pystencilssfg/visitors/__init__.py b/src/pystencilssfg/visitors/__init__.py deleted file mode 100644 index fc7af1b..0000000 --- a/src/pystencilssfg/visitors/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .dispatcher import visitor - -__all__ = [ - "visitor", -] diff --git a/src/pystencilssfg/visitors/dispatcher.py b/src/pystencilssfg/visitors/dispatcher.py deleted file mode 100644 index 85a0f08..0000000 --- a/src/pystencilssfg/visitors/dispatcher.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations -from typing import Callable, TypeVar, Generic -from types import MethodType - -from functools import wraps - -V = TypeVar("V") -R = TypeVar("R") - - -class VisitorDispatcher(Generic[V, R]): - def __init__(self, wrapped_method: Callable[..., R]): - self._dispatch_dict: dict[type, Callable[..., R]] = {} - self._wrapped_method: Callable[..., R] = wrapped_method - - def case(self, node_type: type): - """Decorator for visitor's case handlers.""" - - def decorate(handler: Callable[..., R]): - if node_type in self._dispatch_dict: - raise ValueError(f"Duplicate visitor case {node_type}") - self._dispatch_dict[node_type] = handler - return handler - - return decorate - - def __call__(self, instance: V, node: object, *args, **kwargs) -> R: - for cls in node.__class__.mro(): - if cls in self._dispatch_dict: - return self._dispatch_dict[cls](instance, node, *args, **kwargs) - - return self._wrapped_method(instance, node, *args, **kwargs) - - def __get__(self, obj: V, objtype=None) -> Callable[..., R]: - if obj is None: - return self - return MethodType(self, obj) - - -def visitor(method): - """Decorator to create a visitor using type-based dispatch. - - Use this decorator to convert a method into a visitor, like shown below. - After declaring a method (e.g. `my_method`) a visitor, - its case handlers can be declared using the `my_method.case` decorator, like this: - - ```Python - class DemoVisitor: - @visitor - def visit(self, obj: object): - # fallback case - ... - - @visit.case(str) - def visit_str(self, obj: str): - # code for handling a str - ``` - - When `visit` is later called with some object `x`, the case handler to be executed is - determined according to the method resolution order of `x` (i.e. along its type's inheritance hierarchy). - If no case matches, the fallback code in the original visitor method is executed. - In this example, if `visit` is called with an object of type `str`, the call is dispatched to `visit_str`. - - This visitor dispatch method is primarily designed for traversing abstract syntax tree structures. - The primary visitor method (`visit` in above example) should define the common parent type of all object - types the visitor can handle, with cases declared for all required subtypes. - However, this type relationship is not enforced at runtime. - """ - return wraps(method)(VisitorDispatcher(method)) -- GitLab From 271404a6da71fa30cd238ab1c59d5668cf9d94e8 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 13:30:43 +0100 Subject: [PATCH 09/18] update and fix large parts of the test suite --- conftest.py | 33 +++++++++------ src/pystencilssfg/cli.py | 11 +---- src/pystencilssfg/config.py | 20 +++++++++ src/pystencilssfg/context.py | 2 +- src/pystencilssfg/emission/file_printer.py | 9 ++-- src/pystencilssfg/generator.py | 41 +++++++++++++++---- tests/extensions/test_sycl.py | 12 +++--- tests/generator/test_config.py | 5 ++- .../source/BasicDefinitions.py | 4 +- .../generator_scripts/source/Conditionals.py | 4 +- .../generator_scripts/source/JacobiMdspan.py | 4 +- .../source/MdSpanFixedShapeLayouts.py | 3 +- .../source/MdSpanLbStreaming.py | 3 +- tests/generator_scripts/source/ScaleKernel.py | 3 +- .../source/StlContainers1D.py | 4 +- tests/generator_scripts/source/SyclBuffers.py | 3 +- tests/ir/test_postprocessing.py | 19 ++------- 17 files changed, 104 insertions(+), 76 deletions(-) diff --git a/conftest.py b/conftest.py index 661e722..287ed04 100644 --- a/conftest.py +++ b/conftest.py @@ -2,21 +2,30 @@ import pytest from os import path -@pytest.fixture(autouse=True) -def prepare_doctest_namespace(doctest_namespace): - from pystencilssfg import SfgContext, SfgComposer - from pystencilssfg import lang - - # Place a composer object in the environment for doctests - - sfg = SfgComposer(SfgContext()) - doctest_namespace["sfg"] = sfg - doctest_namespace["lang"] = lang - - DATA_DIR = path.join(path.split(__file__)[0], "tests/data") @pytest.fixture def sample_config_module(): return path.join(DATA_DIR, "project_config.py") + + +@pytest.fixture +def sfg(): + from pystencilssfg import SfgContext, SfgComposer + from pystencilssfg.ir import SfgSourceFile, SfgSourceFileType + + return SfgComposer( + SfgContext( + header_file=SfgSourceFile("", SfgSourceFileType.HEADER), + impl_file=SfgSourceFile("", SfgSourceFileType.TRANSLATION_UNIT), + ) + ) + + +@pytest.fixture(autouse=True) +def prepare_doctest_namespace(doctest_namespace, sfg): + from pystencilssfg import lang + + doctest_namespace["sfg"] = sfg + doctest_namespace["lang"] = lang diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py index 8cd9f4f..d612fbc 100644 --- a/src/pystencilssfg/cli.py +++ b/src/pystencilssfg/cli.py @@ -1,12 +1,11 @@ import sys import os from os import path -from pathlib import Path from typing import NoReturn from argparse import ArgumentParser, BooleanOptionalAction -from .config import CommandLineParameters, SfgConfigException, OutputMode +from .config import CommandLineParameters, SfgConfigException def add_newline_arg(parser): @@ -81,13 +80,7 @@ def list_files(args) -> NoReturn: _, scriptname = path.split(args.codegen_script) basename = path.splitext(scriptname)[0] - output_dir: Path = config.get_option("output_directory") - - header_ext = config.extensions.get_option("header") - output_files = [output_dir / f"{basename}.{header_ext}"] - if config.output_mode != OutputMode.HEADER_ONLY: - impl_ext = config.extensions.get_option("impl") - output_files.append(output_dir / f"{basename}.{impl_ext}") + output_files = config._get_output_files(basename) print( args.sep.join(str(of) for of in output_files), diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index 7bbcfc6..18aaa51 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -173,6 +173,26 @@ class SfgConfig(ConfigBase): def _validate_output_directory(self, pth: str | Path) -> Path: return Path(pth) + def _get_output_files(self, basename: str): + output_dir: Path = self.get_option("output_directory") + + header_ext = self.extensions.get_option("header") + impl_ext = self.extensions.get_option("impl") + output_files = [output_dir / f"{basename}.{header_ext}"] + output_mode = self.get_option("output_mode") + + if impl_ext is None: + match output_mode: + case OutputMode.INLINE: + impl_ext = "ipp" + case OutputMode.STANDALONE: + impl_ext = "cpp" + + if impl_ext is not None: + output_files.append(output_dir / f"{basename}.{impl_ext}") + + return tuple(output_files) + class CommandLineParameters: @staticmethod diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index f1a3251..032c1e4 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -102,7 +102,7 @@ class SfgCursor: self._cur_namespace: SfgNamespace = namespace - self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] + self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] = dict() for f in self._ctx.files: if self._cur_namespace is not None: block = SfgNamespaceBlock(self._cur_namespace) diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index ec43468..f92fba4 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -35,12 +35,12 @@ class SfgFilePrinter: code = "" if file.file_type == SfgSourceFileType.HEADER: - code += "#pragma once\n" + code += "#pragma once\n\n" if file.prelude: comment = "/**\n" comment += indent(file.prelude, " * ") - comment += "\n */\n\n" + comment += " */\n\n" code += comment @@ -54,6 +54,7 @@ class SfgFilePrinter: # Here begins the actual code code += "\n\n".join(self.visit(elem) for elem in file.elements) code += "\n" + return code def visit( @@ -73,6 +74,8 @@ class SfgFilePrinter: return self.visit_decl(entity, inclass) case SfgEntityDef(entity): return self.visit_defin(entity, inclass) + case SfgClassBody(): + return self.visit_defin(elem, inclass) case _: assert False, "illegal code element" @@ -173,7 +176,7 @@ class SfgFilePrinter: code = "" if func.inline: code += "inline " - code += func.return_type.c_string() + code += func.return_type.c_string() + " " params_str = ", ".join( f"{param.dtype.c_string()} {param.name}" for param in func.parameters ) diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index b055ad5..1191ebe 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -1,10 +1,16 @@ from pathlib import Path -from .config import SfgConfig, CommandLineParameters, OutputMode, GLOBAL_NAMESPACE +from .config import ( + SfgConfig, + CommandLineParameters, + OutputMode, + _GlobalNamespace, +) from .context import SfgContext from .composer import SfgComposer from .emission import SfgCodeEmitter from .exceptions import SfgException +from .lang import HeaderFile class SourceFileGenerator: @@ -26,7 +32,10 @@ class SourceFileGenerator: """ def __init__( - self, sfg_config: SfgConfig | None = None, keep_unknown_argv: bool = False + self, + sfg_config: SfgConfig | None = None, + namespace: str | None = None, + keep_unknown_argv: bool = False, ): if sfg_config and not isinstance(sfg_config, SfgConfig): raise TypeError("sfg_config is not an SfgConfiguration.") @@ -68,13 +77,13 @@ class SourceFileGenerator: self._output_mode: OutputMode = config.get_option("output_mode") self._output_dir: Path = config.get_option("output_directory") - self._header_ext: str = config.extensions.get_option("header") - self._impl_ext: str = config.extensions.get_option("impl") + + output_files = config._get_output_files(basename) from .ir import SfgSourceFile, SfgSourceFileType self._header_file = SfgSourceFile( - f"{basename}.{self._header_ext}", SfgSourceFileType.HEADER + output_files[0].name, SfgSourceFileType.HEADER ) self._impl_file: SfgSourceFile | None @@ -83,17 +92,29 @@ class SourceFileGenerator: self._impl_file = None case OutputMode.STANDALONE: self._impl_file = SfgSourceFile( - f"{basename}.{self._impl_ext}", SfgSourceFileType.TRANSLATION_UNIT + output_files[1].name, SfgSourceFileType.TRANSLATION_UNIT + ) + self._impl_file.includes.append( + HeaderFile.parse(self._header_file.name) ) case OutputMode.INLINE: self._impl_file = SfgSourceFile( - f"{basename}.{self._impl_ext}", SfgSourceFileType.HEADER + output_files[1].name, SfgSourceFileType.HEADER ) + outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace") + match (outer_namespace, namespace): + case [_GlobalNamespace(), None]: + namespace = None + case [_GlobalNamespace(), nspace] | [nspace, None]: + namespace = nspace + case [outer, inner]: + namespace = f"{outer}::{inner}" + self._context = SfgContext( self._header_file, self._impl_file, - None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace, # type: ignore + namespace, config.codestyle, argv=script_args, project_info=cli_params.get_project_info(), @@ -119,6 +140,10 @@ class SourceFileGenerator: def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: + if self._output_mode == OutputMode.INLINE: + assert self._impl_file is not None + self._header_file.elements.append(f'#include "{self._impl_file.name}"') + # Collect header files for inclusion # from .ir import collect_includes diff --git a/tests/extensions/test_sycl.py b/tests/extensions/test_sycl.py index db99278..0e067c8 100644 --- a/tests/extensions/test_sycl.py +++ b/tests/extensions/test_sycl.py @@ -5,8 +5,8 @@ import pystencils as ps from pystencilssfg import SfgContext -def test_parallel_for_1_kernels(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_1_kernels(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]") @@ -24,8 +24,8 @@ def test_parallel_for_1_kernels(): ) -def test_parallel_for_2_kernels(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_2_kernels(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]") @@ -43,8 +43,8 @@ def test_parallel_for_2_kernels(): ) -def test_parallel_for_2_kernels_fail(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_2_kernels_fail(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g = ps.fields(f"f,g:{data_type}[{dim}D]") diff --git a/tests/generator/test_config.py b/tests/generator/test_config.py index 4485dc2..250c158 100644 --- a/tests/generator/test_config.py +++ b/tests/generator/test_config.py @@ -1,4 +1,5 @@ import pytest +from pathlib import Path from pystencilssfg.config import ( SfgConfig, @@ -86,7 +87,7 @@ def test_from_commandline(sample_config_module): cli_args = CommandLineParameters(args) cfg = cli_args.get_config() - assert cfg.output_directory == ".out" + assert cfg.output_directory == Path(".out") assert cfg.extensions.header == "h++" assert cfg.extensions.impl == "c++" @@ -100,7 +101,7 @@ def test_from_commandline(sample_config_module): assert cfg.clang_format.code_style == "llvm" assert cfg.clang_format.skip is True assert ( - cfg.output_directory == "gen_sources" + cfg.output_directory == Path("gen_sources") ) # value from config module overridden by commandline assert cfg.outer_namespace == "myproject" assert cfg.extensions.header == "hpp" diff --git a/tests/generator_scripts/source/BasicDefinitions.py b/tests/generator_scripts/source/BasicDefinitions.py index 7cfe352..51ad4d5 100644 --- a/tests/generator_scripts/source/BasicDefinitions.py +++ b/tests/generator_scripts/source/BasicDefinitions.py @@ -4,12 +4,10 @@ from pystencilssfg import SourceFileGenerator, SfgConfig cfg = SfgConfig() cfg.clang_format.skip = True -with SourceFileGenerator(cfg) as sfg: +with SourceFileGenerator(cfg, namespace="awesome") as sfg: sfg.prelude("Expect the unexpected, and you shall never be surprised.") sfg.include("<iostream>") sfg.include("config.h") - sfg.namespace("awesome") - sfg.code("#define PI 3.1415") sfg.code("using namespace std;") diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py index 9016b73..216f95f 100644 --- a/tests/generator_scripts/source/Conditionals.py +++ b/tests/generator_scripts/source/Conditionals.py @@ -1,9 +1,7 @@ from pystencilssfg import SourceFileGenerator from pystencils.types import PsCustomType -with SourceFileGenerator() as sfg: - sfg.namespace("gen") - +with SourceFileGenerator(namespace="gen") as sfg: sfg.include("<iostream>") sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };") diff --git a/tests/generator_scripts/source/JacobiMdspan.py b/tests/generator_scripts/source/JacobiMdspan.py index bbe95ac..b8f1744 100644 --- a/tests/generator_scripts/source/JacobiMdspan.py +++ b/tests/generator_scripts/source/JacobiMdspan.py @@ -7,9 +7,7 @@ from pystencilssfg.lang.cpp.std import mdspan mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") -with SourceFileGenerator() as sfg: - sfg.namespace("gen") - +with SourceFileGenerator(namespace="gen") as sfg: u_src, u_dst, f = fields("u_src, u_dst, f(1) : double[2D]", layout="fzyx") h = sp.Symbol("h") diff --git a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py index c89fe24..9a66b40 100644 --- a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py +++ b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py @@ -5,8 +5,7 @@ from pystencilssfg.lang import strip_ptr_ref std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") -with SourceFileGenerator() as sfg: - sfg.namespace("gen") +with SourceFileGenerator(namespace="gen") as sfg: sfg.include("<cassert>") def check_layout(field: ps.Field, mdspan: std.mdspan): diff --git a/tests/generator_scripts/source/MdSpanLbStreaming.py b/tests/generator_scripts/source/MdSpanLbStreaming.py index 60049a8..ad8a758 100644 --- a/tests/generator_scripts/source/MdSpanLbStreaming.py +++ b/tests/generator_scripts/source/MdSpanLbStreaming.py @@ -43,8 +43,7 @@ def lbm_stream(sfg: SfgComposer, field_layout: str, layout_policy: str): ) -with SourceFileGenerator() as sfg: - sfg.namespace("gen") +with SourceFileGenerator(namespace="gen") as sfg: sfg.include("<cassert>") sfg.include("<array>") diff --git a/tests/generator_scripts/source/ScaleKernel.py b/tests/generator_scripts/source/ScaleKernel.py index 8bcc75f..1d76dc7 100644 --- a/tests/generator_scripts/source/ScaleKernel.py +++ b/tests/generator_scripts/source/ScaleKernel.py @@ -2,7 +2,7 @@ from pystencils import TypedSymbol, fields, kernel from pystencilssfg import SourceFileGenerator -with SourceFileGenerator() as sfg: +with SourceFileGenerator(namespace="gen") as sfg: N = 10 α = TypedSymbol("alpha", "float32") src, dst = fields(f"src, dst: float32[{N}]") @@ -13,7 +13,6 @@ with SourceFileGenerator() as sfg: khandle = sfg.kernels.create(scale) - sfg.namespace("gen") sfg.code(f"constexpr int N = {N};") sfg.klass("Scale")( diff --git a/tests/generator_scripts/source/StlContainers1D.py b/tests/generator_scripts/source/StlContainers1D.py index 3f6ec2c..260a650 100644 --- a/tests/generator_scripts/source/StlContainers1D.py +++ b/tests/generator_scripts/source/StlContainers1D.py @@ -5,9 +5,7 @@ from pystencilssfg import SourceFileGenerator from pystencilssfg.lang.cpp import std -with SourceFileGenerator() as sfg: - sfg.namespace("StlContainers1D::gen") - +with SourceFileGenerator(namespace="StlContainers1D::gen") as sfg: src, dst = ps.fields("src, dst: double[1D]") asms = [ diff --git a/tests/generator_scripts/source/SyclBuffers.py b/tests/generator_scripts/source/SyclBuffers.py index 36234a8..4668b3c 100644 --- a/tests/generator_scripts/source/SyclBuffers.py +++ b/tests/generator_scripts/source/SyclBuffers.py @@ -4,9 +4,8 @@ from pystencilssfg import SourceFileGenerator import pystencilssfg.extensions.sycl as sycl -with SourceFileGenerator() as sfg: +with SourceFileGenerator(namespace="gen") as sfg: sfg = sycl.SyclComposer(sfg) - sfg.namespace("gen") u_src, u_dst, f = ps.fields("u_src, u_dst, f : double[2D]", layout="fzyx") h = sp.Symbol("h") diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 070743a..9d51c8f 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -2,7 +2,6 @@ import sympy as sp from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type from pystencils.types import PsCustomType -from pystencilssfg import SfgContext, SfgComposer from pystencilssfg.composer import make_sequence from pystencilssfg.lang import IFieldExtraction, AugExpr @@ -11,10 +10,7 @@ from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir.postprocessing import CallTreePostProcessing -def test_live_vars(): - ctx = SfgContext() - sfg = SfgComposer(ctx) - +def test_live_vars(sfg): f, g = fields("f, g(2): double[2D]") x, y = [TypedSymbol(n, "double") for n in "xy"] z = sp.Symbol("z") @@ -42,10 +38,7 @@ def test_live_vars(): assert free_vars == expected -def test_find_sympy_symbols(): - ctx = SfgContext() - sfg = SfgComposer(ctx) - +def test_find_sympy_symbols(sfg): f, g = fields("f, g(2): double[2D]") x, y, z = sp.symbols("x, y, z") @@ -94,7 +87,7 @@ class DemoFieldExtraction(IFieldExtraction): return AugExpr.format("{}.stride({})", self.obj, coordinate) -def test_field_extraction(): +def test_field_extraction(sfg): sx, sy, tx, ty = [ TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty") ] @@ -104,8 +97,6 @@ def test_field_extraction(): def set_constant(): f.center @= 13.2 - sfg = SfgComposer(SfgContext()) - khandle = sfg.kernels.create(set_constant) extraction = DemoFieldExtraction("f") @@ -129,7 +120,7 @@ def test_field_extraction(): assert stmt.code_string == line -def test_duplicate_field_shapes(): +def test_duplicate_field_shapes(sfg): N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")] f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) @@ -138,8 +129,6 @@ def test_duplicate_field_shapes(): def set_constant(): f.center @= g.center(0) - sfg = SfgComposer(SfgContext()) - khandle = sfg.kernels.create(set_constant) call_tree = make_sequence( -- GitLab From 01113bf46501ba370d429ff6844a8d241537bc5c Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 14:54:50 +0100 Subject: [PATCH 10/18] implement user-defined include sorting. Improve header collection for pystencils-native types. Fix more tests. --- src/pystencilssfg/__init__.py | 4 +- src/pystencilssfg/config.py | 10 +- src/pystencilssfg/context.py | 14 ++- src/pystencilssfg/emission/clang_format.py | 11 +- src/pystencilssfg/emission/emitter.py | 11 +- src/pystencilssfg/emission/file_printer.py | 8 +- src/pystencilssfg/generator.py | 38 +++++- src/pystencilssfg/ir/analysis.py | 129 ++++++++++++++------- src/pystencilssfg/ir/entities.py | 6 +- src/pystencilssfg/ir/syntax.py | 21 ++-- src/pystencilssfg/lang/expressions.py | 34 ++++-- 11 files changed, 202 insertions(+), 84 deletions(-) diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index b2def3b..fea6f8a 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -1,5 +1,5 @@ -from .config import SfgConfig -from .generator import SourceFileGenerator, GLOBAL_NAMESPACE, OutputMode +from .config import SfgConfig, GLOBAL_NAMESPACE, OutputMode +from .generator import SourceFileGenerator from .composer import SfgComposer from .context import SfgContext from .lang import SfgVar, AugExpr diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index 18aaa51..a94d9ad 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -3,7 +3,7 @@ from __future__ import annotations from argparse import ArgumentParser from types import ModuleType -from typing import Any, Sequence +from typing import Any, Sequence, Callable from dataclasses import dataclass from enum import Enum, auto from os import path @@ -12,6 +12,8 @@ from pathlib import Path from pystencils.codegen.config import ConfigBase, Option, BasicOption, Category +from .lang import HeaderFile + class SfgConfigException(Exception): ... # noqa: E701 @@ -61,6 +63,12 @@ class CodeStyle(ConfigBase): indent_width: BasicOption[int] = BasicOption(2) """The number of spaces successively nested blocks should be indented with""" + includes_sorting_key: BasicOption[Callable[[HeaderFile], Any]] = BasicOption() + """Key function that will be used to sort `#include` statements in generated files. + + Pystencils-sfg will instruct clang-tidy to forego include sorting if this option is set. + """ + # TODO possible future options: # - newline before opening { # - trailing return types diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 032c1e4..24c38e1 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -20,7 +20,7 @@ class SfgContext: self, header_file: SfgSourceFile, impl_file: SfgSourceFile | None, - outer_namespace: str | None = None, + namespace: str | None = None, codestyle: CodeStyle | None = None, argv: Sequence[str] | None = None, project_info: Any = None, @@ -28,7 +28,7 @@ class SfgContext: self._argv = argv self._project_info = project_info - self._outer_namespace = outer_namespace + self._outer_namespace = namespace self._inner_namespace: str | None = None self._codestyle = codestyle if codestyle is not None else CodeStyle() @@ -39,8 +39,8 @@ class SfgContext: self._global_namespace = SfgGlobalNamespace() current_ns: SfgNamespace = self._global_namespace - if outer_namespace is not None: - for token in outer_namespace.split("::"): + if namespace is not None: + for token in namespace.split("::"): current_ns = SfgNamespace(token, current_ns) self._cursor = SfgCursor(self, current_ns) @@ -104,8 +104,10 @@ class SfgCursor: self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] = dict() for f in self._ctx.files: - if self._cur_namespace is not None: - block = SfgNamespaceBlock(self._cur_namespace) + if not isinstance(namespace, SfgGlobalNamespace): + block = SfgNamespaceBlock( + self._cur_namespace, self._cur_namespace.fqname + ) f.elements.append(block) self._loc[f] = block.elements else: diff --git a/src/pystencilssfg/emission/clang_format.py b/src/pystencilssfg/emission/clang_format.py index 1b15e8c..b73d9da 100644 --- a/src/pystencilssfg/emission/clang_format.py +++ b/src/pystencilssfg/emission/clang_format.py @@ -5,14 +5,16 @@ from ..config import ClangFormatOptions from ..exceptions import SfgException -def invoke_clang_format(code: str, options: ClangFormatOptions) -> str: +def invoke_clang_format( + code: str, options: ClangFormatOptions, sort_includes: str | None = None +) -> str: """Call the `clang-format` command-line tool to format the given code string according to the given style arguments. Args: code: Code string to format - codestyle: [SfgCodeStyle][pystencilssfg.configuration.SfgCodeStyle] object - defining the `clang-format` binary and the desired code style. + options: Options controlling the clang-format invocation + sort_includes: Option to be passed on to clang-format's ``--sort-includes`` argument Returns: The formatted code, if `clang-format` was run sucessfully. @@ -31,6 +33,9 @@ def invoke_clang_format(code: str, options: ClangFormatOptions) -> str: force = options.get_option("force") style = options.get_option("code_style") args = [binary, f"--style={style}"] + + if sort_includes is not None: + args += ["--sort-includes", sort_includes] if not shutil.which(binary): if force: diff --git a/src/pystencilssfg/emission/emitter.py b/src/pystencilssfg/emission/emitter.py index 440344b..c1b6e9c 100644 --- a/src/pystencilssfg/emission/emitter.py +++ b/src/pystencilssfg/emission/emitter.py @@ -17,12 +17,21 @@ class SfgCodeEmitter: clang_format: ClangFormatOptions, ): self._output_dir = output_directory + self._code_style = code_style self._clang_format_opts = clang_format self._printer = SfgFilePrinter(code_style) def emit(self, file: SfgSourceFile): code = self._printer(file) - code = invoke_clang_format(code, self._clang_format_opts) + + if self._code_style.get_option("includes_sorting_key") is not None: + sort_includes = "Never" + else: + sort_includes = None + + code = invoke_clang_format( + code, self._clang_format_opts, sort_includes=sort_includes + ) self._output_dir.mkdir(parents=True, exist_ok=True) fpath = self._output_dir / file.name diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index f92fba4..8216a7b 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -63,12 +63,12 @@ class SfgFilePrinter: match elem: case str(): return elem - case SfgNamespaceBlock(namespace, elements): - code = f"namespace {namespace.name} {{\n" + case SfgNamespaceBlock(_, elements, label): + code = f"namespace {label} {{\n" code += self._code_style.indent( "\n\n".join(self.visit(e) for e in elements) ) - code += f"\n}} // namespace {namespace.name}" + code += f"\n}} // namespace {label}" return code case SfgEntityDecl(entity): return self.visit_decl(entity, inclass) @@ -157,7 +157,7 @@ class SfgFilePrinter: code = f"{cls.class_keyword} {cls.name} {{\n" vblocks_str = [self._visibility_block(b) for b in vblocks] code += "\n\n".join(vblocks_str) - code += "\n}\n" + code += "\n};\n" return code case _: diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index 1191ebe..dd9a78c 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -1,5 +1,6 @@ from pathlib import Path +from typing import Callable, Any from .config import ( SfgConfig, CommandLineParameters, @@ -102,11 +103,16 @@ class SourceFileGenerator: output_files[1].name, SfgSourceFileType.HEADER ) + # TODO: Find a way to not hard-code the restrict qualifier in pystencils + self._header_file.elements.append("#define RESTRICT __restrict__") + outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace") match (outer_namespace, namespace): case [_GlobalNamespace(), None]: namespace = None - case [_GlobalNamespace(), nspace] | [nspace, None]: + case [_GlobalNamespace(), nspace] if nspace is not None: + namespace = nspace + case [nspace, None]: namespace = nspace case [outer, inner]: namespace = f"{outer}::{inner}" @@ -124,6 +130,16 @@ class SourceFileGenerator: self._output_dir, config.codestyle, config.clang_format ) + sort_key = config.codestyle.get_option("includes_sorting_key") + if sort_key is None: + + def default_key(h: HeaderFile): + return str(h) + + sort_key = default_key + + self._include_sort_key: Callable[[HeaderFile], Any] = sort_key + def clean_files(self): header_path = self._output_dir / self._header_file.name if header_path.exists(): @@ -144,12 +160,22 @@ class SourceFileGenerator: assert self._impl_file is not None self._header_file.elements.append(f'#include "{self._impl_file.name}"') - # Collect header files for inclusion - # from .ir import collect_includes + from .ir import collect_includes + + header_includes = collect_includes(self._header_file) + self._header_file.includes = list( + set(self._header_file.includes) | header_includes + ) + self._header_file.includes.sort(key=self._include_sort_key) - # TODO: Collect headers - # for header in collect_includes(self._context): - # self._context.add_include(SfgHeaderInclude(header)) + if self._impl_file is not None: + impl_includes = collect_includes(self._impl_file) + # If some header is already included by the generated header file, do not duplicate that inclusion + impl_includes -= header_includes + self._impl_file.includes = list( + set(self._impl_file.includes) | impl_includes + ) + self._impl_file.includes.sort(key=self._include_sort_key) self._emitter.emit(self._header_file) if self._impl_file is not None: diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index b88c4f3..a2bce07 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -1,54 +1,103 @@ from __future__ import annotations -from typing import Any from functools import reduce -from ..exceptions import SfgException from ..lang import HeaderFile, includes +from .syntax import ( + SfgSourceFile, + SfgNamespaceElement, + SfgClassBodyElement, + SfgNamespaceBlock, + SfgEntityDecl, + SfgEntityDef, + SfgClassBody, + SfgVisibilityBlock, +) -def collect_includes(obj: Any) -> set[HeaderFile]: +def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: from .call_tree import SfgCallTreeNode from .entities import ( + SfgCodeEntity, + SfgKernelHandle, SfgFunction, - SfgClass, + SfgMethod, + SfgClassMember, SfgConstructor, SfgMemberVariable, ) - match obj: - # TODO - - case SfgCallTreeNode(): - return reduce( - lambda accu, child: accu | collect_includes(child), - obj.children, - obj.required_includes, - ) - - case SfgFunction(_, tree, parameters): - param_headers: set[HeaderFile] = reduce( - set.union, (includes(p) for p in parameters), set() - ) - return param_headers | collect_includes(tree) - - case SfgClass(): - return reduce( - lambda accu, member: accu | (collect_includes(member)), - obj.members(), - set(), - ) - - case SfgConstructor(_, parameters): - param_headers = reduce( - set.union, (includes(p) for p in parameters), set() - ) - return param_headers - - case SfgMemberVariable(): - return includes(obj) - - case _: - raise SfgException( - f"Can't collect includes from object of type {type(obj)}" - ) + def visit_decl(entity: SfgCodeEntity | SfgClassMember) -> set[HeaderFile]: + match entity: + case ( + SfgKernelHandle(_, parameters) + | SfgFunction(_, _, parameters) + | SfgMethod(_, _, parameters) + | SfgConstructor(_, parameters, _, _) + ): + incls = reduce(set.union, (includes(p) for p in parameters), set()) + if isinstance(entity, (SfgFunction, SfgMethod)): + incls |= includes(entity.return_type) + return incls + + case SfgMemberVariable(): + return includes(entity) + + case _: + assert False, "unexpected entity" + + def walk_syntax( + obj: ( + SfgNamespaceElement + | SfgClassBodyElement + | SfgVisibilityBlock + | SfgCallTreeNode + ), + ) -> set[HeaderFile]: + match obj: + case str(): + return set() + + case SfgCallTreeNode(): + return reduce( + lambda accu, child: accu | walk_syntax(child), + obj.children, + obj.required_includes, + ) + + case SfgEntityDecl(entity): + return visit_decl(entity) + + case SfgEntityDef(entity): + match entity: + case SfgKernelHandle(kernel, _): + return set( + HeaderFile.parse(h) for h in kernel.required_headers + ) | visit_decl(entity) + + case SfgFunction(_, tree, _) | SfgMethod(_, tree, _): + return walk_syntax(tree) | visit_decl(entity) + + case SfgConstructor(): + return visit_decl(entity) + + case SfgMemberVariable(): + return includes(entity) + + case _: + assert False, "unexpected entity" + + case SfgNamespaceBlock(_, elements) | SfgVisibilityBlock(_, elements): + return reduce( + lambda accu, elem: accu | walk_syntax(elem), elements, set() + ) + + case SfgClassBody(_, vblocks): + return reduce( + lambda accu, vblock: accu | walk_syntax(vblock), vblocks, set() + ) + + case _: + assert False, "unexpected syntax element" + + return reduce(lambda accu, elem: accu | walk_syntax(elem), file.elements, set()) diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index 2b0e3e5..90205fe 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -100,7 +100,7 @@ class SfgGlobalNamespace(SfgNamespace): class SfgKernelHandle(SfgCodeEntity): """Handle to a pystencils kernel.""" - __match_args__ = ("kernel",) + __match_args__ = ("kernel", "parameters") def __init__(self, name: str, namespace: SfgKernelNamespace, kernel: Kernel): super().__init__(name, namespace) @@ -219,7 +219,7 @@ class SfgKernelNamespace(SfgNamespace): class SfgFunction(SfgCodeEntity): """A free function.""" - __match_args__ = ("name", "tree", "parameters") + __match_args__ = ("name", "tree", "parameters", "return_type") def __init__( self, @@ -337,7 +337,7 @@ class SfgMemberVariable(SfgVar, SfgClassMember): class SfgMethod(SfgClassMember): """Instance method of a class""" - __match_args__ = ("name", "tree", "parameters") + __match_args__ = ("name", "tree", "parameters", "return_type") def __init__( self, diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py index 574c029..699e7b5 100644 --- a/src/pystencilssfg/ir/syntax.py +++ b/src/pystencilssfg/ir/syntax.py @@ -80,6 +80,8 @@ class SfgVisibilityBlock: visibility: The visibility qualifier of this block """ + __match_args__ = ("visibility", "elements") + def __init__(self, visibility: SfgVisibility) -> None: self._vis = visibility self._elements: list[SfgClassBodyElement] = [] @@ -107,29 +109,34 @@ class SfgVisibilityBlock: class SfgNamespaceBlock: - """A C++ namespace. - - Each namespace has a `name` and a `parent`; its fully qualified name is given as - ``<parent.name>::<name>``. + """A C++ namespace block. Args: - name: Local name of this namespace - parent: Parent namespace enclosing this namespace + namespace: Namespace associated with this block + label: Label printed at the opening brace of this block. + This may be the namespace name, or a compressed qualified + name containing one or more of its parent namespaces. """ __match_args__ = ( "namespace", "elements", + "label", ) - def __init__(self, namespace: SfgNamespace) -> None: + def __init__(self, namespace: SfgNamespace, label: str | None = None) -> None: self._namespace = namespace + self._label = label if label is not None else namespace.name self._elements: list[SfgNamespaceElement] = [] @property def namespace(self) -> SfgNamespace: return self._namespace + @property + def label(self) -> str: + return self._label + @property def elements(self) -> list[SfgNamespaceElement]: """Sequence of source elements that make up the body of this namespace""" diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 4a1f7e9..b3ed18e 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -7,7 +7,7 @@ import sympy as sp from pystencils import TypedSymbol from pystencils.codegen import Parameter -from pystencils.types import PsType, UserTypeSpec, create_type +from pystencils.types import PsType, PsScalarType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile @@ -452,7 +452,7 @@ def depends(expr: ExprLike) -> set[SfgVar]: raise ValueError(f"Invalid expression: {expr}") -def includes(expr: ExprLike) -> set[HeaderFile]: +def includes(obj: ExprLike | PsType) -> set[HeaderFile]: """Determine the set of header files an expression depends on. Args: @@ -465,21 +465,33 @@ def includes(expr: ExprLike) -> set[HeaderFile]: ValueError: If the argument was not a valid variable or expression """ - match expr: + if isinstance(obj, PsType): + obj = strip_ptr_ref(obj) + + match obj: + case CppType(): + return set(obj.includes) + + case PsType(): + headers = set(HeaderFile.parse(h) for h in obj.required_headers) + if isinstance(obj, PsScalarType): + headers.add(HeaderFile.parse("<cstdint>")) + return headers + case SfgVar(_, dtype): - match dtype: - case CppType(): - return set(dtype.includes) - case _: - return set(HeaderFile.parse(h) for h in dtype.required_headers) + return includes(dtype) + case TypedSymbol(): - return includes(asvar(expr)) + return includes(asvar(obj)) + case str(): return set() + case AugExpr(): - return set(expr.includes) + return set(obj.includes) + case _: - raise ValueError(f"Invalid expression: {expr}") + raise ValueError(f"Invalid expression: {obj}") class IFieldExtraction(ABC): -- GitLab From 35764348ad9fd6d5f778d69692b0212fd9111630 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 15:24:22 +0100 Subject: [PATCH 11/18] fix cstdint collection --- src/pystencilssfg/lang/expressions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index b3ed18e..72287ea 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -7,7 +7,7 @@ import sympy as sp from pystencils import TypedSymbol from pystencils.codegen import Parameter -from pystencils.types import PsType, PsScalarType, UserTypeSpec, create_type +from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile @@ -474,7 +474,7 @@ def includes(obj: ExprLike | PsType) -> set[HeaderFile]: case PsType(): headers = set(HeaderFile.parse(h) for h in obj.required_headers) - if isinstance(obj, PsScalarType): + if isinstance(obj, PsIntegerType): headers.add(HeaderFile.parse("<cstdint>")) return headers -- GitLab From 73ead40ad10f47f32bafd3330486552ddffff7ac Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 15:46:46 +0100 Subject: [PATCH 12/18] fix remaining bugs and tests - fix source file listing on cli - properly sort function and method args again --- src/pystencilssfg/config.py | 5 +++-- src/pystencilssfg/emission/clang_format.py | 4 ++-- src/pystencilssfg/generator.py | 3 ++- src/pystencilssfg/ir/analysis.py | 6 ++++-- src/pystencilssfg/ir/entities.py | 16 ++++++++++------ tests/generator_scripts/index.yaml | 2 +- tests/integration/cmake_project/GenTest.py | 4 +--- 7 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index a94d9ad..3b63f4f 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -65,7 +65,7 @@ class CodeStyle(ConfigBase): includes_sorting_key: BasicOption[Callable[[HeaderFile], Any]] = BasicOption() """Key function that will be used to sort `#include` statements in generated files. - + Pystencils-sfg will instruct clang-tidy to forego include sorting if this option is set. """ @@ -196,7 +196,8 @@ class SfgConfig(ConfigBase): case OutputMode.STANDALONE: impl_ext = "cpp" - if impl_ext is not None: + if output_mode != OutputMode.HEADER_ONLY: + assert impl_ext is not None output_files.append(output_dir / f"{basename}.{impl_ext}") return tuple(output_files) diff --git a/src/pystencilssfg/emission/clang_format.py b/src/pystencilssfg/emission/clang_format.py index b73d9da..50c51f1 100644 --- a/src/pystencilssfg/emission/clang_format.py +++ b/src/pystencilssfg/emission/clang_format.py @@ -14,7 +14,7 @@ def invoke_clang_format( Args: code: Code string to format options: Options controlling the clang-format invocation - sort_includes: Option to be passed on to clang-format's ``--sort-includes`` argument + sort_includes: Option to be passed on to clang-format's ``--sort-includes`` argument Returns: The formatted code, if `clang-format` was run sucessfully. @@ -33,7 +33,7 @@ def invoke_clang_format( force = options.get_option("force") style = options.get_option("code_style") args = [binary, f"--style={style}"] - + if sort_includes is not None: args += ["--sort-includes", sort_includes] diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index dd9a78c..471e60b 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -107,12 +107,13 @@ class SourceFileGenerator: self._header_file.elements.append("#define RESTRICT __restrict__") outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace") + match (outer_namespace, namespace): case [_GlobalNamespace(), None]: namespace = None case [_GlobalNamespace(), nspace] if nspace is not None: namespace = nspace - case [nspace, None]: + case [nspace, None] if not isinstance(nspace, _GlobalNamespace): namespace = nspace case [outer, inner]: namespace = f"{outer}::{inner}" diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index a2bce07..ff8331f 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -35,7 +35,9 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: | SfgMethod(_, _, parameters) | SfgConstructor(_, parameters, _, _) ): - incls = reduce(set.union, (includes(p) for p in parameters), set()) + incls: set[HeaderFile] = reduce( + lambda accu, p: accu | includes(p), parameters, set() + ) if isinstance(entity, (SfgFunction, SfgMethod)): incls |= includes(entity.return_type) return incls @@ -90,7 +92,7 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: case SfgNamespaceBlock(_, elements) | SfgVisibilityBlock(_, elements): return reduce( lambda accu, elem: accu | walk_syntax(elem), elements, set() - ) + ) # type: ignore case SfgClassBody(_, vblocks): return reduce( diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index 90205fe..a855155 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -235,15 +235,17 @@ class SfgFunction(SfgCodeEntity): self._return_type = return_type self._inline = inline - self._parameters: set[SfgVar] + self._parameters: tuple[SfgVar, ...] from .postprocessing import CallTreePostProcessing param_collector = CallTreePostProcessing() - self._parameters = param_collector(self._tree).function_params + self._parameters = tuple( + sorted(param_collector(self._tree).function_params, key=lambda p: p.name) + ) @property - def parameters(self) -> set[SfgVar]: + def parameters(self) -> tuple[SfgVar, ...]: return self._parameters @property @@ -356,19 +358,21 @@ class SfgMethod(SfgClassMember): self._inline = inline self._const = const - self._parameters: set[SfgVar] + self._parameters: tuple[SfgVar, ...] from .postprocessing import CallTreePostProcessing param_collector = CallTreePostProcessing() - self._parameters = param_collector(self._tree).function_params + self._parameters = tuple( + sorted(param_collector(self._tree).function_params, key=lambda p: p.name) + ) @property def name(self) -> str: return self._name @property - def parameters(self) -> set[SfgVar]: + def parameters(self) -> tuple[SfgVar, ...]: return self._parameters @property diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml index c78b335..eae4c39 100644 --- a/tests/generator_scripts/index.yaml +++ b/tests/generator_scripts/index.yaml @@ -25,7 +25,7 @@ BasicDefinitions: expect-code: hpp: - regex: >- - #include\s\"config\.h\"\s* + #include\s\"config\.h\"(\s|.)* namespace\s+awesome\s+{\s+.+\s+ #define\sPI\s3\.1415\s+ using\snamespace\sstd\;\s+ diff --git a/tests/integration/cmake_project/GenTest.py b/tests/integration/cmake_project/GenTest.py index 8399e70..093374c 100644 --- a/tests/integration/cmake_project/GenTest.py +++ b/tests/integration/cmake_project/GenTest.py @@ -1,8 +1,6 @@ from pystencilssfg import SourceFileGenerator -with SourceFileGenerator() as sfg: - sfg.namespace("gen") - +with SourceFileGenerator(namespace="gen") as sfg: retval = 42 if sfg.context.project_info is None else sfg.context.project_info sfg.function("getValue", return_type="int")( -- GitLab From b87be486a5030d543eb56e2d69dec54b7661a9ea Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 15:54:15 +0100 Subject: [PATCH 13/18] always include cstdint for kernels --- src/pystencilssfg/ir/analysis.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index ff8331f..5b6d2d6 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -73,9 +73,11 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: case SfgEntityDef(entity): match entity: case SfgKernelHandle(kernel, _): - return set( - HeaderFile.parse(h) for h in kernel.required_headers - ) | visit_decl(entity) + return ( + set(HeaderFile.parse(h) for h in kernel.required_headers) + | {HeaderFile.parse("<cstdint>")} + | visit_decl(entity) + ) case SfgFunction(_, tree, _) | SfgMethod(_, tree, _): return walk_syntax(tree) | visit_decl(entity) -- GitLab From 2dca4f92ebdfdbb8fd72a31d85fcfe011c371c89 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 7 Feb 2025 17:13:49 +0100 Subject: [PATCH 14/18] remove superfluous interfaces; remove debug lines from coverage --- pyproject.toml | 2 ++ src/pystencilssfg/ir/entities.py | 54 +------------------------------- src/pystencilssfg/ir/syntax.py | 10 ------ 3 files changed, 3 insertions(+), 63 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6ac0327..93987d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ parentdir_prefix = "pystencilssfg-" [tool.coverage.run] omit = [ "setup.py", + "noxfile.py", "src/pystencilssfg/_version.py", "integration/*" ] @@ -68,4 +69,5 @@ exclude_also = [ "\\.\\.\\.\n", "if TYPE_CHECKING:", "@(abc\\.)?abstractmethod", + "assert False" ] diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index a855155..6e6597d 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -7,10 +7,9 @@ from typing import ( Sequence, Generator, ) -from dataclasses import replace from itertools import chain -from pystencils import CreateKernelConfig, create_kernel, Field +from pystencils import Field from pystencils.codegen import Kernel from pystencils.types import PsType, PsCustomType @@ -164,57 +163,6 @@ class SfgKernelNamespace(SfgNamespace): ) self._kernels[kernel.name] = kernel - def add(self, kernel: Kernel, name: str | None = None): - """Adds an existing pystencils AST to this namespace. - If a name is specified, the AST's function name is changed.""" - if name is None: - kernel_name = kernel.name - else: - kernel_name = name - - if kernel_name in self._kernels: - raise ValueError( - f"Duplicate kernels: A kernel called {kernel_name} already exists in namespace {self.fqname}" - ) - - if name is not None: - kernel.name = kernel_name - - khandle = SfgKernelHandle(kernel_name, self, kernel) - self._kernels[kernel_name] = khandle - - # TODO: collect includes later - # for header in kernel.required_headers: - # self._ctx.add_include( - # SfgHeaderInclude(HeaderFile.parse(header), private=True) - # ) - - return khandle - - def create( - self, - assignments, - name: str | None = None, - config: CreateKernelConfig | None = None, - ): - """Creates a new pystencils kernel from a list of assignments and a configuration. - This is a wrapper around `pystencils.create_kernel` - with a subsequent call to `add`. - """ - if config is None: - config = CreateKernelConfig() - - if name is not None: - if name in self._kernels: - raise ValueError( - f"Duplicate kernels: A kernel with name {name} already exists in namespace {self.fqname}" - ) - config = replace(config, function_name=name) - - # type: ignore - kernel = create_kernel(assignments, config=config) - return self.add(kernel) - class SfgFunction(SfgCodeEntity): """A free function.""" diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py index 699e7b5..cdbd4c2 100644 --- a/src/pystencilssfg/ir/syntax.py +++ b/src/pystencilssfg/ir/syntax.py @@ -2,7 +2,6 @@ from __future__ import annotations from enum import Enum, auto from typing import ( - Generator, Iterable, TypeVar, Generic, @@ -15,7 +14,6 @@ from .entities import ( SfgKernelHandle, SfgFunction, SfgClassMember, - SfgMemberVariable, SfgVisibility, SfgClass, ) @@ -99,14 +97,6 @@ class SfgVisibilityBlock: def elements(self, elems: Iterable[SfgClassBodyElement]): self._elements = list(elems) - def members(self) -> Generator[SfgClassMember, None, None]: - for elem in self._elements: - match elem: - case SfgEntityDecl(entity) | SfgEntityDef(entity): - yield entity - case SfgMemberVariable(): - yield elem - class SfgNamespaceBlock: """A C++ namespace block. -- GitLab From 45febbc79aecc809f1e2ead641e4add55f0f95f6 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 10 Feb 2025 09:51:29 +0100 Subject: [PATCH 15/18] introduce namespace context manager --- src/pystencilssfg/composer/basic_composer.py | 3 +- src/pystencilssfg/context.py | 32 +++++++++++--- src/pystencilssfg/generator.py | 15 +++---- src/pystencilssfg/ir/entities.py | 43 ++++++++++++++++++- .../source/BasicDefinitions.py | 4 +- .../generator_scripts/source/Conditionals.py | 4 +- .../generator_scripts/source/JacobiMdspan.py | 10 +++-- .../source/MdSpanFixedShapeLayouts.py | 3 +- .../source/MdSpanLbStreaming.py | 3 +- tests/generator_scripts/source/ScaleKernel.py | 4 +- .../source/StlContainers1D.py | 13 +++--- tests/generator_scripts/source/SyclBuffers.py | 3 +- tests/integration/cmake_project/GenTest.py | 3 +- 13 files changed, 103 insertions(+), 37 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 08422be..7c08b67 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -204,8 +204,7 @@ class SfgBasicComposer(SfgIComposer): self.code(*definitions) def namespace(self, namespace: str): - # TODO: Enter into a new namespace context - raise NotImplementedError() + return self._cursor.enter_namespace(namespace) def generate(self, generator: CustomGenerator): """Invoke a custom code generator with the underlying context.""" diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 24c38e1..199c678 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Sequence, Any, Generator +from contextlib import contextmanager from .config import CodeStyle from .ir import ( @@ -38,12 +39,13 @@ class SfgContext: self._global_namespace = SfgGlobalNamespace() - current_ns: SfgNamespace = self._global_namespace + current_namespace: SfgNamespace if namespace is not None: - for token in namespace.split("::"): - current_ns = SfgNamespace(token, current_ns) + current_namespace = self._global_namespace.get_child_namespace(namespace) + else: + current_namespace = self._global_namespace - self._cursor = SfgCursor(self, current_ns) + self._cursor = SfgCursor(self, current_namespace) @property def argv(self) -> Sequence[str]: @@ -113,8 +115,6 @@ class SfgCursor: else: self._loc[f] = f.elements - # TODO: Enter and exit namespace blocks - @property def current_namespace(self) -> SfgNamespace: return self._cur_namespace @@ -135,3 +135,23 @@ class SfgCursor: f"Cannot write element {elem} to implemenation file since no implementation file is being generated." ) self._loc[impl_file].append(elem) + + def enter_namespace(self, qual_name: str): + namespace = self._cur_namespace.get_child_namespace(qual_name) + + outer_locs = self._loc.copy() + + for f in self._ctx.files: + block = SfgNamespaceBlock(namespace, qual_name) + self._loc[f].append(block) + self._loc[f] = block.elements + + @contextmanager + def ctxmgr(): + try: + yield None + finally: + # Have the cursor step back out of the nested namespace blocks + self._loc = outer_locs + + return ctxmgr() diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index 471e60b..f3f67a0 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -35,7 +35,6 @@ class SourceFileGenerator: def __init__( self, sfg_config: SfgConfig | None = None, - namespace: str | None = None, keep_unknown_argv: bool = False, ): if sfg_config and not isinstance(sfg_config, SfgConfig): @@ -108,15 +107,11 @@ class SourceFileGenerator: outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace") - match (outer_namespace, namespace): - case [_GlobalNamespace(), None]: - namespace = None - case [_GlobalNamespace(), nspace] if nspace is not None: - namespace = nspace - case [nspace, None] if not isinstance(nspace, _GlobalNamespace): - namespace = nspace - case [outer, inner]: - namespace = f"{outer}::{inner}" + namespace: str | None + if isinstance(outer_namespace, _GlobalNamespace): + namespace = None + else: + namespace = outer_namespace self._context = SfgContext( self._header_file, diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index 6e6597d..62ae1eb 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -74,8 +74,29 @@ class SfgNamespace(SfgCodeEntity): self._entities: dict[str, SfgCodeEntity] = dict() - def get_entity(self, name: str) -> SfgCodeEntity | None: - return self._entities.get(name, None) + def get_entity(self, qual_name: str) -> SfgCodeEntity | None: + """Find an entity with the given qualified name within this namespace. + + If `qual_name` contains any qualifying delimiters ``::``, + each component but the last is interpreted as a namespace. + """ + tokens = qual_name.split("::", 1) + match tokens: + case [entity_name]: + return self._entities.get(entity_name, None) + case [nspace, remaining_qualname]: + sub_nspace = self._entities.get(nspace, None) + if sub_nspace is not None: + if not isinstance(sub_nspace, SfgNamespace): + raise KeyError( + f"Unable to find entity {qual_name} in namespace {self._name}: " + f"Entity {nspace} is not a namespace." + ) + return sub_nspace.get_entity(remaining_qualname) + else: + return None + case _: + assert False, "unreachable code" def add_entity(self, entity: SfgCodeEntity): if entity.name in self._entities: @@ -84,6 +105,24 @@ class SfgNamespace(SfgCodeEntity): ) self._entities[entity.name] = entity + def get_child_namespace(self, qual_name: str): + if not qual_name: + raise ValueError("Anonymous namespaces are not supported") + + # Find the namespace by qualified lookup ... + namespace = self.get_entity(qual_name) + if namespace is not None: + if not type(namespace) is SfgNamespace: + raise ValueError(f"Entity {qual_name} exists, but is not a namespace") + else: + # ... or create it + tokens = qual_name.split("::") + namespace = self + for tok in tokens: + namespace = SfgNamespace(tok, namespace) + + return namespace + class SfgGlobalNamespace(SfgNamespace): """The C++ global namespace.""" diff --git a/tests/generator_scripts/source/BasicDefinitions.py b/tests/generator_scripts/source/BasicDefinitions.py index 51ad4d5..4453066 100644 --- a/tests/generator_scripts/source/BasicDefinitions.py +++ b/tests/generator_scripts/source/BasicDefinitions.py @@ -4,7 +4,9 @@ from pystencilssfg import SourceFileGenerator, SfgConfig cfg = SfgConfig() cfg.clang_format.skip = True -with SourceFileGenerator(cfg, namespace="awesome") as sfg: +with SourceFileGenerator(cfg) as sfg: + sfg.namespace("awesome") + sfg.prelude("Expect the unexpected, and you shall never be surprised.") sfg.include("<iostream>") sfg.include("config.h") diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py index 216f95f..9016b73 100644 --- a/tests/generator_scripts/source/Conditionals.py +++ b/tests/generator_scripts/source/Conditionals.py @@ -1,7 +1,9 @@ from pystencilssfg import SourceFileGenerator from pystencils.types import PsCustomType -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") + sfg.include("<iostream>") sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };") diff --git a/tests/generator_scripts/source/JacobiMdspan.py b/tests/generator_scripts/source/JacobiMdspan.py index b8f1744..2e0741a 100644 --- a/tests/generator_scripts/source/JacobiMdspan.py +++ b/tests/generator_scripts/source/JacobiMdspan.py @@ -7,13 +7,17 @@ from pystencilssfg.lang.cpp.std import mdspan mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") + u_src, u_dst, f = fields("u_src, u_dst, f(1) : double[2D]", layout="fzyx") h = sp.Symbol("h") @kernel def poisson_jacobi(): - u_dst[0,0] @= (h**2 * f[0, 0] + u_src[1, 0] + u_src[-1, 0] + u_src[0, 1] + u_src[0, -1]) / 4 + u_dst[0, 0] @= ( + h**2 * f[0, 0] + u_src[1, 0] + u_src[-1, 0] + u_src[0, 1] + u_src[0, -1] + ) / 4 poisson_kernel = sfg.kernels.create(poisson_jacobi) @@ -21,5 +25,5 @@ with SourceFileGenerator(namespace="gen") as sfg: sfg.map_field(u_src, mdspan.from_field(u_src, layout_policy="layout_left")), sfg.map_field(u_dst, mdspan.from_field(u_dst, layout_policy="layout_left")), sfg.map_field(f, mdspan.from_field(f, layout_policy="layout_left")), - sfg.call(poisson_kernel) + sfg.call(poisson_kernel), ) diff --git a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py index 9a66b40..c89fe24 100644 --- a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py +++ b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py @@ -5,7 +5,8 @@ from pystencilssfg.lang import strip_ptr_ref std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") sfg.include("<cassert>") def check_layout(field: ps.Field, mdspan: std.mdspan): diff --git a/tests/generator_scripts/source/MdSpanLbStreaming.py b/tests/generator_scripts/source/MdSpanLbStreaming.py index ad8a758..60049a8 100644 --- a/tests/generator_scripts/source/MdSpanLbStreaming.py +++ b/tests/generator_scripts/source/MdSpanLbStreaming.py @@ -43,7 +43,8 @@ def lbm_stream(sfg: SfgComposer, field_layout: str, layout_policy: str): ) -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") sfg.include("<cassert>") sfg.include("<array>") diff --git a/tests/generator_scripts/source/ScaleKernel.py b/tests/generator_scripts/source/ScaleKernel.py index 1d76dc7..2242a3b 100644 --- a/tests/generator_scripts/source/ScaleKernel.py +++ b/tests/generator_scripts/source/ScaleKernel.py @@ -2,7 +2,9 @@ from pystencils import TypedSymbol, fields, kernel from pystencilssfg import SourceFileGenerator -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") + N = 10 α = TypedSymbol("alpha", "float32") src, dst = fields(f"src, dst: float32[{N}]") diff --git a/tests/generator_scripts/source/StlContainers1D.py b/tests/generator_scripts/source/StlContainers1D.py index 260a650..91b2911 100644 --- a/tests/generator_scripts/source/StlContainers1D.py +++ b/tests/generator_scripts/source/StlContainers1D.py @@ -5,24 +5,23 @@ from pystencilssfg import SourceFileGenerator from pystencilssfg.lang.cpp import std -with SourceFileGenerator(namespace="StlContainers1D::gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("StlContainers1D::gen") + src, dst = ps.fields("src, dst: double[1D]") - asms = [ - ps.Assignment(dst[0], sp.Rational(1, 3) * (src[-1] + src[0] + src[1])) - ] + asms = [ps.Assignment(dst[0], sp.Rational(1, 3) * (src[-1] + src[0] + src[1]))] kernel = sfg.kernels.create(asms, "average") sfg.function("averageVector")( sfg.map_field(src, std.vector.from_field(src)), sfg.map_field(dst, std.vector.from_field(dst)), - sfg.call(kernel) + sfg.call(kernel), ) sfg.function("averageSpan")( sfg.map_field(src, std.span.from_field(src)), sfg.map_field(dst, std.span.from_field(dst)), - sfg.call(kernel) + sfg.call(kernel), ) - diff --git a/tests/generator_scripts/source/SyclBuffers.py b/tests/generator_scripts/source/SyclBuffers.py index 4668b3c..36234a8 100644 --- a/tests/generator_scripts/source/SyclBuffers.py +++ b/tests/generator_scripts/source/SyclBuffers.py @@ -4,8 +4,9 @@ from pystencilssfg import SourceFileGenerator import pystencilssfg.extensions.sycl as sycl -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: sfg = sycl.SyclComposer(sfg) + sfg.namespace("gen") u_src, u_dst, f = ps.fields("u_src, u_dst, f : double[2D]", layout="fzyx") h = sp.Symbol("h") diff --git a/tests/integration/cmake_project/GenTest.py b/tests/integration/cmake_project/GenTest.py index 093374c..81aec18 100644 --- a/tests/integration/cmake_project/GenTest.py +++ b/tests/integration/cmake_project/GenTest.py @@ -1,6 +1,7 @@ from pystencilssfg import SourceFileGenerator -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") retval = 42 if sfg.context.project_info is None else sfg.context.project_info sfg.function("getValue", return_type="int")( -- GitLab From 798ba31bfdbebea9c4d5e3edaa26491fb8dd88e1 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 10 Feb 2025 13:29:57 +0100 Subject: [PATCH 16/18] add test for nested namespaces --- tests/generator_scripts/index.yaml | 4 ++++ .../source/NestedNamespaces.harness.cpp | 12 ++++++++++++ .../source/NestedNamespaces.py | 19 +++++++++++++++++++ 3 files changed, 35 insertions(+) create mode 100644 tests/generator_scripts/source/NestedNamespaces.harness.cpp create mode 100644 tests/generator_scripts/source/NestedNamespaces.py diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml index eae4c39..b723799 100644 --- a/tests/generator_scripts/index.yaml +++ b/tests/generator_scripts/index.yaml @@ -48,6 +48,10 @@ Conditionals: - regex: if\s*\(\s*noodle\s==\sNoodles::RIGATONI\s\|\|\snoodle\s==\sNoodles::SPAGHETTI\s*\) count: 1 +NestedNamespaces: + sfg-args: + output-mode: header-only + # Kernel Generation ScaleKernel: diff --git a/tests/generator_scripts/source/NestedNamespaces.harness.cpp b/tests/generator_scripts/source/NestedNamespaces.harness.cpp new file mode 100644 index 0000000..ea7c465 --- /dev/null +++ b/tests/generator_scripts/source/NestedNamespaces.harness.cpp @@ -0,0 +1,12 @@ +#include "NestedNamespaces.hpp" + +static_assert( outer::X == 13 ); +static_assert( outer::inner::Y == 52 ); +static_assert( outer::Z == 41 ); +static_assert( outer::second_inner::W == 91 ); +static_assert( outer::inner::innermost::V == 29 ); +static_assert( GLOBAL == 42 ); + +int main() { + return 0; +} diff --git a/tests/generator_scripts/source/NestedNamespaces.py b/tests/generator_scripts/source/NestedNamespaces.py new file mode 100644 index 0000000..4af7bc7 --- /dev/null +++ b/tests/generator_scripts/source/NestedNamespaces.py @@ -0,0 +1,19 @@ +from pystencilssfg import SourceFileGenerator + +with SourceFileGenerator() as sfg: + + with sfg.namespace("outer"): + sfg.code("constexpr int X = 13;") + + with sfg.namespace("inner"): + sfg.code("constexpr int Y = 52;") + + sfg.code("constexpr int Z = 41;") + + with sfg.namespace("outer::second_inner"): + sfg.code("constexpr int W = 91;") + + with sfg.namespace("outer::inner::innermost"): + sfg.code("constexpr int V = 29;") + + sfg.code("constexpr int GLOBAL = 42;") -- GitLab From f9519e0f127ff90326b075f45d5a24624e9f3700 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 10 Feb 2025 13:59:45 +0100 Subject: [PATCH 17/18] add doc comment to sfg.namespace --- docs/source/conf.py | 2 +- docs/source/usage/generator_scripts.md | 1 - src/pystencilssfg/composer/basic_composer.py | 33 ++++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index da6f4d7..d6aab17 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -59,7 +59,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3.8", None), "numpy": ("https://numpy.org/doc/stable/", None), "sympy": ("https://docs.sympy.org/latest/", None), - "pystencils": ("https://da15siwa.pages.i10git.cs.fau.de/dev-docs/pystencils-nbackend/", None), + "pystencils": ("https://pycodegen.pages.i10git.cs.fau.de/docs/pystencils/2.0dev/", None), } # References diff --git a/docs/source/usage/generator_scripts.md b/docs/source/usage/generator_scripts.md index 4a1f6aa..3141fee 100644 --- a/docs/source/usage/generator_scripts.md +++ b/docs/source/usage/generator_scripts.md @@ -64,7 +64,6 @@ Structure and Verbatim Code: SfgBasicComposer.include SfgBasicComposer.namespace SfgBasicComposer.code - SfgBasicComposer.define_once ``` Kernels and Kernel Namespaces: diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 7c08b67..e75f0e2 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -204,6 +204,39 @@ class SfgBasicComposer(SfgIComposer): self.code(*definitions) def namespace(self, namespace: str): + """Enter a new namespace block. + + Calling `namespace` as a regular function will open a new namespace as a child of the + currently active namespace; this new namespace will then become active instead. + Using `namespace` as a context manager will instead activate the given namespace + only for the length of the ``with`` block. + + Args: + namespace: Qualified name of the namespace + + :Example: + + The following calls will set the current namespace to ``outer::inner`` + for the remaining code generation run: + + .. code-block:: + + sfg.namespace("outer") + sfg.namespace("inner") + + Subsequent calls to `namespace` can only create further nested namespaces. + + To step back out of a namespace, `namespace` can also be used as a context manager: + + .. code-block:: + + with sfg.namespace("detail"): + ... + + This way, code generated inside the ``with`` region is placed in the ``detail`` namespace, + and code after this block will again live in the enclosing namespace. + + """ return self._cursor.enter_namespace(namespace) def generate(self, generator: CustomGenerator): -- GitLab From 3be468c23c822fccb6d58a7f29af55cd25408bba Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 12 Feb 2025 13:15:27 +0100 Subject: [PATCH 18/18] use set().union instead of reduce. remove non-required imports in test_sycl. Add test for include sorting. --- src/pystencilssfg/ir/analysis.py | 22 +++++------------- tests/extensions/test_sycl.py | 2 -- tests/generator_scripts/index.yaml | 11 +++++++++ .../source/TestIncludeSorting.py | 23 +++++++++++++++++++ 4 files changed, 40 insertions(+), 18 deletions(-) create mode 100644 tests/generator_scripts/source/TestIncludeSorting.py diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index 5b6d2d6..4e43eb9 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -1,7 +1,5 @@ from __future__ import annotations -from functools import reduce - from ..lang import HeaderFile, includes from .syntax import ( SfgSourceFile, @@ -35,9 +33,7 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: | SfgMethod(_, _, parameters) | SfgConstructor(_, parameters, _, _) ): - incls: set[HeaderFile] = reduce( - lambda accu, p: accu | includes(p), parameters, set() - ) + incls: set[HeaderFile] = set().union(*(includes(p) for p in parameters)) if isinstance(entity, (SfgFunction, SfgMethod)): incls |= includes(entity.return_type) return incls @@ -61,10 +57,8 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: return set() case SfgCallTreeNode(): - return reduce( - lambda accu, child: accu | walk_syntax(child), - obj.children, - obj.required_includes, + return obj.required_includes.union( + *(walk_syntax(child) for child in obj.children), ) case SfgEntityDecl(entity): @@ -92,16 +86,12 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: assert False, "unexpected entity" case SfgNamespaceBlock(_, elements) | SfgVisibilityBlock(_, elements): - return reduce( - lambda accu, elem: accu | walk_syntax(elem), elements, set() - ) # type: ignore + return set().union(*(walk_syntax(elem) for elem in elements)) # type: ignore case SfgClassBody(_, vblocks): - return reduce( - lambda accu, vblock: accu | walk_syntax(vblock), vblocks, set() - ) + return set().union(*(walk_syntax(vb) for vb in vblocks)) case _: assert False, "unexpected syntax element" - return reduce(lambda accu, elem: accu | walk_syntax(elem), file.elements, set()) + return set().union(*(walk_syntax(elem) for elem in file.elements)) diff --git a/tests/extensions/test_sycl.py b/tests/extensions/test_sycl.py index 0e067c8..71effb6 100644 --- a/tests/extensions/test_sycl.py +++ b/tests/extensions/test_sycl.py @@ -1,8 +1,6 @@ import pytest -from pystencilssfg import SourceFileGenerator import pystencilssfg.extensions.sycl as sycl import pystencils as ps -from pystencilssfg import SfgContext def test_parallel_for_1_kernels(sfg): diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml index b723799..0e08e22 100644 --- a/tests/generator_scripts/index.yaml +++ b/tests/generator_scripts/index.yaml @@ -17,6 +17,17 @@ TestIllegalArgs: extra-args: [--sfg-file-extensionss, ".c++,.h++"] expect-failure: true +TestIncludeSorting: + sfg-args: + output-mode: header-only + expect-code: + hpp: + - regex: >- + #include\s\<memory>\s* + #include\s<vector>\s* + #include\s<array> + strip-whitespace: true + # Basic Composer Functionality BasicDefinitions: diff --git a/tests/generator_scripts/source/TestIncludeSorting.py b/tests/generator_scripts/source/TestIncludeSorting.py new file mode 100644 index 0000000..8a584f6 --- /dev/null +++ b/tests/generator_scripts/source/TestIncludeSorting.py @@ -0,0 +1,23 @@ +from pystencilssfg import SourceFileGenerator, SfgConfig +from pystencilssfg.lang import HeaderFile + + +def sortkey(h: HeaderFile): + try: + return [ + "memory", + "vector", + "array" + ].index(h.filepath) + except ValueError: + return 100 + + +cfg = SfgConfig() +cfg.codestyle.includes_sorting_key = sortkey + + +with SourceFileGenerator(cfg) as sfg: + sfg.include("<array>") + sfg.include("<memory>") + sfg.include("<vector>") -- GitLab