From faf330f83b8fdec3974b532bd957d81f271369ff Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 9 Jul 2019 14:01:08 +0200
Subject: [PATCH 1/7] Add CudaBackend, CudaSympyPrinter
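
The new backend is selected via generate_c(..., dialect='cuda') or the
generate_cuda() convenience wrapper. A minimal usage sketch (the field names
and kernel body below are illustrative only, not part of this patch):

    import sympy
    import pystencils
    from pystencils.backends.cuda_backend import generate_cuda

    # build a small GPU kernel; printing goes through CudaBackend/CudaSympyPrinter
    x, y = pystencils.fields('x,y: float32 [2d]')
    assignments = pystencils.AssignmentCollection({y.center(): sympy.sqrt(x[0, 0])})
    ast = pystencils.create_kernel(assignments, 'gpu')
    print(generate_cuda(ast))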

---
 pystencils/backends/cbackend.py              |  52 ++--
 pystencils/backends/cuda_backend.py          |  88 ++++++
 pystencils/backends/cuda_known_functions.txt | 293 +++++++++++++++++++
 3 files changed, 406 insertions(+), 27 deletions(-)
 create mode 100644 pystencils/backends/cuda_backend.py
 create mode 100644 pystencils/backends/cuda_known_functions.txt

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 7c4937d1f..cbd75e178 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -32,6 +32,11 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy
 KERNCRAFT_NO_TERNARY_MODE = False
 
 
+class UnsupportedCDialect(Exception):
+    def __init__(self):
+        super(UnsupportedCDialect, self).__init__()
+
+
 def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
     """Prints an abstract syntax tree node as C or CUDA code.
 
@@ -52,9 +57,15 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
             ast_node.global_variables.update(d.symbols_defined)
         else:
             ast_node.global_variables = d.symbols_defined
-    printer = CBackend(signature_only=signature_only,
-                       vector_instruction_set=ast_node.instruction_set,
-                       dialect=dialect)
+
+    if dialect == 'c':
+        printer = CBackend(signature_only=signature_only,
+                           vector_instruction_set=ast_node.instruction_set)
+    elif dialect == 'cuda':
+        from pystencils.backends.cuda_backend import CudaBackend
+        printer = CudaBackend(signature_only=signature_only)
+    else:
+        raise UnsupportedCDialect
     code = printer(ast_node)
     if not signature_only and isinstance(ast_node, KernelFunction):
         code = "\n" + code
@@ -141,9 +152,9 @@ class CBackend:
     def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
         if sympy_printer is None:
             if vector_instruction_set is not None:
-                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
+                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
             else:
-                self.sympy_printer = CustomSympyPrinter(dialect)
+                self.sympy_printer = CustomSympyPrinter()
         else:
             self.sympy_printer = sympy_printer
 
@@ -164,12 +175,12 @@ class CBackend:
             method_name = "_print_" + cls.__name__
             if hasattr(self, method_name):
                 return getattr(self, method_name)(node)
-        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
+        raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + str(type(node)))
 
     def _print_KernelFunction(self, node):
         function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
         launch_bounds = ""
         if self._dialect == 'cuda':
             max_threads = node.indexing.max_threads_per_block()
             if max_threads:
                 launch_bounds = "__launch_bounds__({}) ".format(max_threads)
@@ -241,10 +252,7 @@ class CBackend:
         return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
 
     def _print_SkipIteration(self, _):
-        if self._dialect == 'cuda':
-            return "return;"
-        else:
-            return "continue;"
+        return "continue;"
 
     def _print_CustomCodeNode(self, node):
         return node.get_code(self._dialect, self._vector_instruction_set)
@@ -292,10 +300,9 @@ class CBackend:
 # noinspection PyPep8Naming
 class CustomSympyPrinter(CCodePrinter):
 
-    def __init__(self, dialect):
+    def __init__(self):
         super(CustomSympyPrinter, self).__init__()
         self._float_type = create_type("float32")
-        self._dialect = dialect
         if 'Min' in self.known_functions:
             del self.known_functions['Min']
         if 'Max' in self.known_functions:
@@ -347,22 +354,13 @@ class CustomSympyPrinter(CCodePrinter):
             else:
                 return "((%s)(%s))" % (data_type, self._print(arg))
         elif isinstance(expr, fast_division):
-            if self._dialect == "cuda":
-                return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
-            else:
-                return "({})".format(self._print(expr.args[0] / expr.args[1]))
+            return "({})".format(self._print(expr.args[0] / expr.args[1]))
         elif isinstance(expr, fast_sqrt):
-            if self._dialect == "cuda":
-                return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
-            else:
-                return "({})".format(self._print(sp.sqrt(expr.args[0])))
+            return "({})".format(self._print(sp.sqrt(expr.args[0])))
         elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
             return self._print(expr.args[0])
         elif isinstance(expr, fast_inv_sqrt):
-            if self._dialect == "cuda":
-                return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
-            else:
-                return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
+            return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
         elif expr.func in infix_functions:
             return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
         elif expr.func == int_power_of_2:
@@ -392,8 +390,8 @@ class CustomSympyPrinter(CCodePrinter):
 class VectorizedCustomSympyPrinter(CustomSympyPrinter):
     SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
 
-    def __init__(self, instruction_set, dialect):
-        super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect)
+    def __init__(self, instruction_set):
+        super(VectorizedCustomSympyPrinter, self).__init__()
         self.instruction_set = instruction_set
 
     def _scalarFallback(self, func_name, expr, *args, **kwargs):
diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py
new file mode 100644
index 000000000..e9a78160d
--- /dev/null
+++ b/pystencils/backends/cuda_backend.py
@@ -0,0 +1,88 @@
+
+from os.path import dirname, join
+
+from pystencils.astnodes import Node
+from pystencils.backends.cbackend import (CBackend, CustomSympyPrinter,
+                                          generate_c)
+from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
+                                           fast_sqrt)
+
+# function names the CUDA printer should treat as known device functions
+with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
+    lines = f.readlines()
+    CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l.strip()}
+
+
+def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
+    """Prints an abstract syntax tree node as CUDA code.
+
+    Args:
+        astnode: the AST node (typically a KernelFunction) to print
+        signature_only: if True, only the kernel signature is generated
+
+    Returns:
+        CUDA code for the AST node and its descendants
+    """
+    return generate_c(astnode, signature_only, dialect='cuda')
+
+
+class CudaBackend(CBackend):
+
+    def __init__(self, sympy_printer=None,
+                 signature_only=False):
+        if not sympy_printer:
+            sympy_printer = CudaSympyPrinter()
+
+        super().__init__(sympy_printer, signature_only, dialect='cuda')
+
+    def _print_SharedMemoryAllocation(self, node):
+        code = "__shared__ {dtype} {name}[{num_elements}];"
+        return code.format(dtype=node.symbol.dtype,
+                           name=self.sympy_printer.doprint(node.symbol.name),
+                           num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
+
+    def _print_ThreadBlockSynchronization(self, node):
+        code = "__syncthreads();"
+        return code
+
+    def _print_TextureDeclaration(self, node):
+        code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
+            str(node.texture.field.dtype),
+            node.texture.field.spatial_dimensions,
+            node.texture
+        )
+        return code
+
+    def _print_SkipIteration(self, _):
+        return "return;"
+
+
+class CudaSympyPrinter(CustomSympyPrinter):
+
+    def __init__(self):
+        super(CudaSympyPrinter, self).__init__()
+        self.known_functions = CUDA_KNOWN_FUNCTIONS
+
+    def _print_TextureAccess(self, node):
+
+        if node.texture.cubic_bspline_interpolation:
+            template = "cubicTex%iDSimple<%s>(%s, %s)"
+        else:
+            template = "tex%iD<%s>(%s, %s)"
+
+        code = template % (
+            node.texture.field.spatial_dimensions,
+            str(node.texture.field.dtype),
+            str(node.texture),
+            ', '.join(self._print(o) for o in node.offsets)
+        )
+        return code
+
+    def _print_Function(self, expr):
+        if isinstance(expr, fast_division):
+            return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
+        elif isinstance(expr, fast_sqrt):
+            return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
+        elif isinstance(expr, fast_inv_sqrt):
+            return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
+        return super()._print_Function(expr)
diff --git a/pystencils/backends/cuda_known_functions.txt b/pystencils/backends/cuda_known_functions.txt
new file mode 100644
index 000000000..42cf554ad
--- /dev/null
+++ b/pystencils/backends/cuda_known_functions.txt
@@ -0,0 +1,293 @@
+__prof_trigger
+printf
+
+__syncthreads
+__syncthreads_count
+__syncthreads_and
+__syncthreads_or
+__syncwarp
+__threadfence
+__threadfence_block
+__threadfence_system
+
+atomicAdd
+atomicSub
+atomicExch
+atomicMin
+atomicMax
+atomicInc
+atomicDec
+atomicAnd
+atomicOr
+atomicXor
+atomicCAS
+
+__all_sync
+__any_sync
+__ballot_sync
+__activemask
+
+__shfl_sync
+__shfl_up_sync
+__shfl_down_sync
+__shfl_xor_sync
+
+__match_any_sync
+__match_all_sync
+
+__isGlobal
+__isShared
+__isConstant
+__isLocal
+
+tex1Dfetch
+tex1D
+tex2D
+tex3D
+
+rsqrtf
+cbrtf
+rcbrtf
+hypotf
+rhypotf
+norm3df
+rnorm3df
+norm4df
+rnorm4df
+normf
+rnormf
+expf
+exp2f
+exp10f
+expm1f
+logf
+log2f
+log10f
+log1pf
+sinf
+cosf
+tanf
+sincosf
+sinpif
+cospif
+sincospif
+asinf
+acosf
+atanf
+atan2f
+sinhf
+coshf
+tanhf
+asinhf
+acoshf
+atanhf
+powf
+erff
+erfcf
+erfinvf
+erfcinvf
+erfcxf
+normcdff
+normcdfinvf
+lgammaf
+tgammaf
+fmaf
+frexpf
+ldexpf
+scalbnf
+scalblnf
+logbf
+ilogbf
+j0f
+j1f
+jnf
+y0f
+y1f
+ynf
+cyl_bessel_i0f
+cyl_bessel_i1f
+fmodf
+remainderf
+remquof
+modff
+fdimf
+truncf
+roundf
+rintf
+nearbyintf
+ceilf
+floorf
+lrintf
+lroundf
+llrintf
+llroundf
+
+sqrt
+rsqrt
+cbrt
+rcbrt
+hypot
+rhypot
+norm3d
+rnorm3d
+norm4d
+rnorm4d
+norm
+rnorm
+exp
+exp2
+exp10
+expm1
+log
+log2
+log10
+log1p
+sin
+cos
+tan
+sincos
+sinpi
+cospi
+sincospi
+asin
+acos
+atan
+atan2
+sinh
+cosh
+tanh
+asinh
+acosh
+atanh
+pow
+erf
+erfc
+erfinv
+erfcinv
+erfcx
+normcdf
+normcdfinv
+lgamma
+tgamma
+fma
+frexp
+ldexp
+scalbn
+scalbln
+logb
+ilogb
+j0
+j1
+jn
+y0
+y1
+yn
+cyl_bessel_i0
+cyl_bessel_i1
+fmod
+remainder
+remquo
+modf
+fdim
+trunc
+round
+rint
+nearbyint
+ceil
+floor
+lrint
+lround
+llrint
+llround
+
+__fdividef
+__sinf
+__cosf
+__tanf
+__sincosf
+__logf
+__log2f
+__log10f
+__expf
+__exp10f
+__powf
+
+__fadd_rn
+__fsub_rn
+__fmul_rn
+__fmaf_rn
+__frcp_rn
+__fsqrt_rn
+__frsqrt_rn
+__fdiv_rn
+
+__fadd_rz
+__fsub_rz
+__fmul_rz
+__fmaf_rz
+__frcp_rz
+__fsqrt_rz
+__frsqrt_rz
+__fdiv_rz
+
+__fadd_ru
+__fsub_ru
+__fmul_ru
+__fmaf_ru
+__frcp_ru
+__fsqrt_ru
+__frsqrt_ru
+__fdiv_ru
+
+__fadd_rd
+__fsub_rd
+__fmul_rd
+__fmaf_rd
+__frcp_rd
+__fsqrt_rd
+__frsqrt_rd
+__fdiv_rd
+
+__fdividef
+__expf
+__exp10f
+__logf
+__log2f
+__log10f
+__sinf
+__cosf
+__sincosf
+__tanf
+__powf
+
+__dadd_rn
+__dsub_rn
+__dmul_rn
+__fma_rn
+__ddiv_rn
+__drcp_rn
+__dsqrt_rn
+
+__dadd_rz
+__dsub_rz
+__dmul_rz
+__fma_rz
+__ddiv_rz
+__drcp_rz
+__dsqrt_rz
+
+__dadd_ru
+__dsub_ru
+__dmul_ru
+__fma_ru
+__ddiv_ru
+__drcp_ru
+__dsqrt_ru
+
+__dadd_rd
+__dsub_rd
+__dmul_rd
+__fma_rd
+__ddiv_rd
+__drcp_rd
+__dsqrt_rd
-- 
GitLab


From e5700eb414ce4e6f2d1548c183a6a237c987b7e8 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 12 Jul 2019 17:05:35 +0200
Subject: [PATCH 2/7] Add `get_dummy_symbol`
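
get_dummy_symbol() returns a uniquely named TypedSymbol that can serve as a
throwaway assignment target, e.g. when a function is called only for its side
effect. Rough sketch of the intended use (the atomicAdd example mirrors the
CUDA tests in this series; the field setup is illustrative):

    import sympy
    import pystencils
    from pystencils.astnodes import get_dummy_symbol
    from pystencils.data_types import address_of

    x, y = pystencils.fields('x,y: float32 [2d]')

    # the dummy symbol absorbs the return value of atomicAdd
    assignments = pystencils.AssignmentCollection({
        get_dummy_symbol(): sympy.Function('atomicAdd')(address_of(y.center()), 2),
    })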

---
 pystencils/astnodes.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 2d3174a1a..c0472f984 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -1,3 +1,4 @@
+import uuid
 from typing import Any, List, Optional, Sequence, Set, Union
 
 import jinja2
@@ -700,3 +701,7 @@ class DestructuringBindingsForFieldClass(Node):
 
     def atoms(self, arg_type) -> Set[Any]:
         return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)}
+
+
+def get_dummy_symbol(dtype='bool'):
+    return TypedSymbol('dummy%s' % uuid.uuid4().hex, create_type(dtype))
-- 
GitLab


From dbc890ac4e50e7134d387465e73f6d45ab1e77e6 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 12 Jul 2019 17:38:49 +0200
Subject: [PATCH 3/7] Add test_cuda_known_functions.py

---
 pystencils_tests/test_cuda_known_functions.py | 68 +++++++++++++++++++
 1 file changed, 68 insertions(+)
 create mode 100644 pystencils_tests/test_cuda_known_functions.py

diff --git a/pystencils_tests/test_cuda_known_functions.py b/pystencils_tests/test_cuda_known_functions.py
new file mode 100644
index 000000000..ca0d12053
--- /dev/null
+++ b/pystencils_tests/test_cuda_known_functions.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+Tests that CUDA-specific functions like atomicAdd and rsqrtf are accepted by
+the CUDA backend but not by the C backend.
+"""
+import sympy
+
+import pystencils
+from pystencils.astnodes import get_dummy_symbol
+from pystencils.backends.cuda_backend import CudaSympyPrinter
+from pystencils.data_types import address_of
+
+
+def test_cuda_known_functions():
+    printer = CudaSympyPrinter()
+    print(printer.known_functions)
+
+    x, y = pystencils.fields('x,y: float32 [2d]')
+
+    assignments = pystencils.AssignmentCollection({
+        get_dummy_symbol(): sympy.Function('atomicAdd')(address_of(y.center()), 2),
+        y.center():  sympy.Function('rsqrtf')(x[0, 0])
+    })
+
+    ast = pystencils.create_kernel(assignments, 'gpu')
+    print(pystencils.show_code(ast))
+    kernel = ast.compile()
+    assert kernel is not None
+
+
+def test_cuda_but_not_c():
+    x, y = pystencils.fields('x,y: float32 [2d]')
+
+    assignments = pystencils.AssignmentCollection({
+        get_dummy_symbol(): sympy.Function('atomicAdd')(address_of(y.center()), 2),
+        y.center():  sympy.Function('rsqrtf')(x[0, 0])
+    })
+
+    ast = pystencils.create_kernel(assignments, 'cpu')
+    code = str(pystencils.show_code(ast))
+    assert "Not supported" in code
+
+
+def test_cuda_unknown():
+    x, y = pystencils.fields('x,y: float32 [2d]')
+
+    assignments = pystencils.AssignmentCollection({
+        get_dummy_symbol(): sympy.Function('wtf')(address_of(y.center()), 2),
+    })
+
+    ast = pystencils.create_kernel(assignments, 'gpu')
+    code = str(pystencils.show_code(ast))
+    print(code)
+    assert "Not supported" in code
+
+
+def main():
+    test_cuda_known_functions()
+    test_cuda_but_not_c()
+    test_cuda_unknown()
+
+
+if __name__ == '__main__':
+    main()
-- 
GitLab


From 8a651aa4e3d467f66b1f036dbf83afbcff734728 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 12 Jul 2019 18:01:10 +0200
Subject: [PATCH 4/7] Make CUDA a language (not only a dialect :wink:)

---
 pystencils/backends/cuda_backend.py           | 1 +
 pystencils_tests/test_cuda_known_functions.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py
index e9a78160d..15fbde91d 100644
--- a/pystencils/backends/cuda_backend.py
+++ b/pystencils/backends/cuda_backend.py
@@ -58,6 +58,7 @@ class CudaBackend(CBackend):
 
 
 class CudaSympyPrinter(CustomSympyPrinter):
+    language = "CUDA"
 
     def __init__(self):
         super(CudaSympyPrinter, self).__init__()
diff --git a/pystencils_tests/test_cuda_known_functions.py b/pystencils_tests/test_cuda_known_functions.py
index ca0d12053..c249c144b 100644
--- a/pystencils_tests/test_cuda_known_functions.py
+++ b/pystencils_tests/test_cuda_known_functions.py
@@ -55,7 +55,7 @@ def test_cuda_unknown():
     ast = pystencils.create_kernel(assignments, 'gpu')
     code = str(pystencils.show_code(ast))
     print(code)
-    assert "Not supported" in code
+    assert "Not supported in CUDA" in code
 
 
 def main():
-- 
GitLab


From 19f54169af547d8d33b42b5948b174c750476059 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 15 Jul 2019 09:18:36 +0200
Subject: [PATCH 5/7] Allow custom backends for cpujit and gpujit
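
make_python_function()/compile_and_load() now accept a custom_backend, so a
user-defined printer can replace the default CBackend. Minimal sketch with a
purely illustrative subclass and kernel (not part of this patch):

    import pystencils
    import pystencils.cpu.cpujit
    from pystencils.backends.cbackend import CBackend

    class CommentingBackend(CBackend):
        # illustrative: prefix every generated kernel with a comment
        def _print_KernelFunction(self, node):
            return "// printed by CommentingBackend\n" + super()._print_KernelFunction(node)

    x, y = pystencils.fields('x,y: float32 [2d]')
    assignments = pystencils.AssignmentCollection({y.center(): 2 * x[0, 0]})
    ast = pystencils.create_kernel(assignments, target='cpu')
    kernel = pystencils.cpu.cpujit.make_python_function(ast, custom_backend=CommentingBackend())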

---
 pystencils/backends/cbackend.py |  7 ++++---
 pystencils/cpu/cpujit.py        | 30 ++++++++++++++++--------------
 pystencils/gpucuda/cudajit.py   |  4 ++--
 3 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index cbd75e178..9cab5a26d 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -37,7 +37,7 @@ class UnsupportedCDialect(Exception):
         super(UnsupportedCDialect, self).__init__()
 
 
-def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
+def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
     """Prints an abstract syntax tree node as C or CUDA code.
 
     This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
@@ -57,8 +57,9 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
             ast_node.global_variables.update(d.symbols_defined)
         else:
             ast_node.global_variables = d.symbols_defined
-
-    if dialect == 'c':
+    if custom_backend:
+        printer = custom_backend
+    elif dialect == 'c':
         printer = CBackend(signature_only=signature_only,
                            vector_instruction_set=ast_node.instruction_set)
     elif dialect == 'cuda':
diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index 09128360d..76255c1ec 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -43,28 +43,28 @@ Then 'cl.exe' is used to compile.
   For Windows compilers the qualifier should be ``__restrict``
 
 """
-import os
 import hashlib
 import json
+import os
 import platform
 import shutil
+import subprocess
 import textwrap
+from collections import OrderedDict
+from sysconfig import get_paths
 from tempfile import TemporaryDirectory
 
 import numpy as np
-import subprocess
-from appdirs import user_config_dir, user_cache_dir
-from collections import OrderedDict
+from appdirs import user_cache_dir, user_config_dir
 
-from pystencils.utils import recursive_dict_update
-from sysconfig import get_paths
 from pystencils import FieldType
 from pystencils.backends.cbackend import generate_c, get_headers
-from pystencils.utils import file_handle_for_atomic_write, atomic_file_write
 from pystencils.include import get_pystencils_include_path
+from pystencils.utils import (atomic_file_write, file_handle_for_atomic_write,
+                              recursive_dict_update)
 
 
-def make_python_function(kernel_function_node):
+def make_python_function(kernel_function_node, custom_backend=None):
     """
     Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function
 
@@ -75,7 +75,7 @@ def make_python_function(kernel_function_node):
     :param kernel_function_node: the abstract syntax tree
     :return: kernel functor
     """
-    result = compile_and_load(kernel_function_node)
+    result = compile_and_load(kernel_function_node, custom_backend)
     return result
 
 
@@ -424,11 +424,12 @@ def run_compile_step(command):
 
 
 class ExtensionModuleCode:
-    def __init__(self, module_name='generated'):
+    def __init__(self, module_name='generated', custom_backend=None):
         self.module_name = module_name
 
         self._ast_nodes = []
         self._function_names = []
+        self._custom_backend = custom_backend
 
     def add_function(self, ast, name=None):
         self._ast_nodes.append(ast)
@@ -452,7 +453,7 @@ class ExtensionModuleCode:
         for ast, name in zip(self._ast_nodes, self._function_names):
             old_name = ast.function_name
             ast.function_name = "kernel_" + name
-            print(generate_c(ast), file=file)
+            print(generate_c(ast, custom_backend=self._custom_backend), file=file)
             print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
             ast.function_name = old_name
         print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
@@ -515,10 +516,11 @@ def compile_module(code, code_hash, base_dir):
     return lib_file
 
 
-def compile_and_load(ast):
+def compile_and_load(ast, custom_backend=None):
     cache_config = get_cache_config()
-    code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c').encode()).hexdigest()
-    code = ExtensionModuleCode(module_name=code_hash_str)
+    code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c',
+                                                       custom_backend=custom_backend).encode()).hexdigest()
+    code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend)
     code.add_function(ast, ast.function_name)
 
     if cache_config['object_cache'] is False:
diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py
index 90c23b133..21465433c 100644
--- a/pystencils/gpucuda/cudajit.py
+++ b/pystencils/gpucuda/cudajit.py
@@ -9,7 +9,7 @@ from pystencils.include import get_pystencils_include_path
 USE_FAST_MATH = True
 
 
-def make_python_function(kernel_function_node, argument_dict=None):
+def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None):
     """
     Creates a kernel function from an abstract syntax tree which
     was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel`
@@ -35,7 +35,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
     code = includes + "\n"
     code += "#define FUNC_PREFIX __global__\n"
     code += "#define RESTRICT __restrict__\n\n"
-    code += str(generate_c(kernel_function_node, dialect='cuda'))
+    code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend))
     options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
     if USE_FAST_MATH:
         options.append("-use_fast_math")
-- 
GitLab


From 9bb1e1420f2329a996d3db58ad7c68a28477bc8c Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 15 Jul 2019 10:14:43 +0200
Subject: [PATCH 6/7] Add custom_backend to pystencils.show_code
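
show_code() forwards custom_backend to generate_c(), so the output of a
non-default printer can be inspected without compiling, e.g. in a notebook.
Sketch (the kernel is illustrative):

    import pystencils
    from pystencils.backends.cuda_backend import CudaBackend

    x, y = pystencils.fields('x,y: float32 [2d]')
    assignments = pystencils.AssignmentCollection({y.center(): 2 * x[0, 0]})
    ast = pystencils.create_kernel(assignments, 'gpu')

    # print the code produced by an explicitly chosen backend instance
    print(pystencils.show_code(ast, custom_backend=CudaBackend()))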

---
 pystencils/display_utils.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/pystencils/display_utils.py b/pystencils/display_utils.py
index 55a4720c1..8cdaa4820 100644
--- a/pystencils/display_utils.py
+++ b/pystencils/display_utils.py
@@ -1,5 +1,7 @@
-import sympy as sp
 from typing import Any, Dict, Optional
+
+import sympy as sp
+
 from pystencils.astnodes import KernelFunction
 
 
@@ -32,7 +34,7 @@ def highlight_cpp(code: str):
     return HTML(highlight(code, CppLexer(), HtmlFormatter()))
 
 
-def show_code(ast: KernelFunction):
+def show_code(ast: KernelFunction, custom_backend=None):
     """Returns an object to display generated code (C/C++ or CUDA)
 
     Can either  be displayed as HTML in Jupyter notebooks or printed as normal string.
@@ -45,11 +47,11 @@ def show_code(ast: KernelFunction):
             self.ast = ast_input
 
         def _repr_html_(self):
-            return highlight_cpp(generate_c(self.ast, dialect=dialect)).__html__()
+            return highlight_cpp(generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)).__html__()
 
         def __str__(self):
-            return generate_c(self.ast, dialect=dialect)
+            return generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)
 
         def __repr__(self):
-            return generate_c(self.ast, dialect=dialect)
+            return generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)
     return CodeDisplay(ast)
-- 
GitLab


From 7503c866ef0468ac3699d23f371cd67305987285 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 15 Jul 2019 10:55:51 +0200
Subject: [PATCH 7/7] Add test for custom backends

---
 pystencils_tests/test_custom_backends.py | 60 ++++++++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 pystencils_tests/test_custom_backends.py

diff --git a/pystencils_tests/test_custom_backends.py b/pystencils_tests/test_custom_backends.py
new file mode 100644
index 000000000..f68696f13
--- /dev/null
+++ b/pystencils_tests/test_custom_backends.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+Tests that user-defined backends can be passed to show_code, cpujit and gpujit.
+"""
+from subprocess import CalledProcessError
+
+import pycuda.driver
+import pytest
+import sympy
+
+import pystencils
+import pystencils.cpu.cpujit
+import pystencils.gpucuda.cudajit
+from pystencils.backends.cbackend import CBackend
+from pystencils.backends.cuda_backend import CudaBackend
+
+
+class ScreamingBackend(CBackend):
+
+    def _print(self, node):
+        normal_code = super()._print(node)
+        return normal_code.upper()
+
+
+class ScreamingGpuBackend(CudaBackend):
+
+    def _print(self, node):
+        normal_code = super()._print(node)
+        return normal_code.upper()
+
+
+def test_custom_backends():
+    z, y, x = pystencils.fields("z, y, x: [2d]")
+
+    normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
+        z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
+
+    ast = pystencils.create_kernel(normal_assignments, target='cpu')
+    print(pystencils.show_code(ast, ScreamingBackend()))
+    with pytest.raises(CalledProcessError):
+        pystencils.cpu.cpujit.make_python_function(ast, custom_backend=ScreamingBackend())
+
+    ast = pystencils.create_kernel(normal_assignments, target='gpu')
+    print(pystencils.show_code(ast, ScreamingGpuBackend()))
+    with pytest.raises(pycuda.driver.CompileError):
+        pystencils.gpucuda.cudajit.make_python_function(ast, custom_backend=ScreamingGpuBackend())
+
+
+def main():
+
+    test_custom_backends()
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab