Skip to content
Snippets Groups Projects
Commit d93c549c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix typing of floor/ceil

parent e24cf411
Branches suffa/SYCL
1 merge request!406Fix typing of floor/ceil
Pipeline #67656 passed with stages
in 22 minutes and 27 seconds
......@@ -11,6 +11,7 @@ from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom
......@@ -213,7 +214,7 @@ class TypeAdder:
new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction, sp.log)):
HyperbolicFunction, sp.log, RoundFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
......
......@@ -39,7 +39,7 @@ def test_two_arguments(dtype, func, target):
@pytest.mark.parametrize('dtype', ["float64", "float32"])
@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan])
@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan, sp.floor, sp.ceiling])
@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
def test_single_arguments(dtype, func, target):
if target == ps.Target.GPU:
......@@ -58,7 +58,8 @@ def test_single_arguments(dtype, func, target):
ast = ps.create_kernel(up, config=config)
code = ps.get_code_str(ast)
if dtype == 'float32':
assert func.__name__.lower() in code
func_name = func.__name__.lower() if func is not sp.ceiling else "ceil"
assert func_name in code
kernel = ast.compile()
dh.all_to_gpu()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment