Skip to content
Snippets Groups Projects
Commit c1ca74ad authored by Jan Hönig's avatar Jan Hönig Committed by Markus Holzer
Browse files

Kernel decorator fix

parent 157c4f11
Branches
Tags
No related merge requests found
...@@ -20,7 +20,9 @@ pystencils/boundaries/createindexlistcython.*.so ...@@ -20,7 +20,9 @@ pystencils/boundaries/createindexlistcython.*.so
pystencils_tests/tmp pystencils_tests/tmp
pystencils_tests/kerncraft_inputs/.2d-5pt.c_kerncraft/ pystencils_tests/kerncraft_inputs/.2d-5pt.c_kerncraft/
pystencils_tests/kerncraft_inputs/.3d-7pt.c_kerncraft/ pystencils_tests/kerncraft_inputs/.3d-7pt.c_kerncraft/
report.xml
coverage_report/
# macOS # macOS
**/.DS_Store **/.DS_Store
\ No newline at end of file
...@@ -7,7 +7,7 @@ from .data_types import TypedSymbol ...@@ -7,7 +7,7 @@ from .data_types import TypedSymbol
from .datahandling import create_data_handling from .datahandling import create_data_handling
from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .field import Field, FieldType, fields from .field import Field, FieldType, fields
from .kernel_decorator import kernel from .kernel_decorator import kernel, kernel_config
from .kernelcreation import ( from .kernelcreation import (
CreateKernelConfig, create_domain_kernel, create_indexed_kernel, create_kernel, create_staggered_kernel) CreateKernelConfig, create_domain_kernel, create_indexed_kernel, create_kernel, create_staggered_kernel)
from .simp import AssignmentCollection from .simp import AssignmentCollection
...@@ -34,7 +34,7 @@ __all__ = ['Field', 'FieldType', 'fields', ...@@ -34,7 +34,7 @@ __all__ = ['Field', 'FieldType', 'fields',
'assignment_from_stencil', 'assignment_from_stencil',
'SymbolCreator', 'SymbolCreator',
'create_data_handling', 'create_data_handling',
'kernel', 'kernel', 'kernel_config',
'x_', 'y_', 'z_', 'x_', 'y_', 'z_',
'x_staggered', 'y_staggered', 'z_staggered', 'x_staggered', 'y_staggered', 'z_staggered',
'x_vector', 'x_staggered_vector', 'x_vector', 'x_staggered_vector',
......
import ast import ast
import inspect import inspect
import textwrap import textwrap
from typing import Callable, Union, List, Dict from typing import Callable, Union, List, Dict, Tuple
import sympy as sp import sympy as sp
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator from pystencils.sympyextensions import SymbolCreator
from pystencils.kernelcreation import CreateKernelConfig
__all__ = ['kernel'] __all__ = ['kernel', 'kernel_config']
def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) -> Union[List[Assignment], Dict]: def _kernel(func: Callable[..., None], **kwargs) -> Tuple[List[Assignment], str]:
"""Decorator to simplify generation of pystencils Assignments. """
Convenient function for kernel decorator to prevent code duplication
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment Args:
in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a func: decorated function
sympy Piecewise. **kwargs: kwargs for the function
Returns:
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies assignments, function_name
a SymbolCreator()
func: the decorated function
return_config: Specify whether to return the list with assignments, or a dictionary containing additional settings
like func_name
Examples:
>>> import pystencils as ps
>>> @kernel
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
""" """
source = inspect.getsource(func) source = inspect.getsource(func)
source = textwrap.dedent(source) source = textwrap.dedent(source)
...@@ -55,10 +42,74 @@ def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) -> ...@@ -55,10 +42,74 @@ def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) ->
if 's' in args and 's' not in kwargs: if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator() kwargs['s'] = SymbolCreator()
func(**kwargs) func(**kwargs)
if return_config: return assignments, func.__name__
return {'assignments': assignments, 'function_name': func.__name__}
else:
return assignments def kernel(func: Callable[..., None], **kwargs) -> List[Assignment]:
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
Args:
func: decorated function
**kwargs: kwargs for the function
Examples:
>>> import pystencils as ps
>>> @kernel
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
"""
assignments, _ = _kernel(func, **kwargs)
return assignments
def kernel_config(config: CreateKernelConfig, **kwargs) -> Callable[..., Dict]:
"""Decorator to simplify generation of pystencils Assignments, which takes a configuration
and updates the function name accordingly.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
Args:
config: Specify whether to return the list with assignments, or a dictionary containing additional settings
like func_name
Returns:
decorator with config
Examples:
>>> import pystencils as ps
>>> config = ps.CreateKernelConfig()
>>> @kernel_config(config)
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel['assignments'][0].rhs == f[0,1] + f[1,0]
"""
def decorator(func: Callable[..., None]) -> Union[List[Assignment], Dict]:
"""
Args:
func: decorated function
Returns:
Dict for unpacking into create_kernel
"""
assignments, func_name = _kernel(func, **kwargs)
config.function_name = func_name
return {'assignments': assignments, 'config': config}
return decorator
# noinspection PyMethodMayBeStatic # noinspection PyMethodMayBeStatic
......
import numpy as np
import pystencils as ps import pystencils as ps
...@@ -15,3 +16,14 @@ def test_create_kernel_config(): ...@@ -15,3 +16,14 @@ def test_create_kernel_config():
c = ps.CreateKernelConfig(backend=ps.Backend.CUDA) c = ps.CreateKernelConfig(backend=ps.Backend.CUDA)
assert c.target == ps.Target.CPU assert c.target == ps.Target.CPU
assert c.backend == ps.Backend.CUDA assert c.backend == ps.Backend.CUDA
def test_kernel_decorator_config():
config = ps.CreateKernelConfig()
a, b, c = ps.fields(a=np.ones(100), b=np.ones(100), c=np.ones(100))
@ps.kernel_config(config)
def test():
a[0] @= b[0] + c[0]
ps.create_kernel(**test)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment