Skip to content
Snippets Groups Projects

Revisions

Merged Markus Holzer requested to merge holzer/pystencils:Revisions into master
Compare and
7 files
+ 120
48
Preferences
Compare changes
Files
7
@@ -443,9 +443,8 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols:
if isinstance(expr.exp, sp.Integer) and (-8 < expr.exp < 8):
raise NotImplementedError("This pow should be simplified already?")
# return self._typed_number(expr.evalf(), get_type_of_expression(expr.base))
return super(CustomSympyPrinter, self)._print_Pow(expr)
# TODO don't print ones in sp.Mul
@@ -508,13 +507,13 @@ class CustomSympyPrinter(CCodePrinter):
else:
return f"(({data_type})({self._print(arg)}))"
elif isinstance(expr, fast_division):
return f"({self._print(expr.args[0] / expr.args[1])})"
raise ValueError("fast_division is only supported for Taget.GPU")
elif isinstance(expr, fast_sqrt):
return f"({self._print(sp.sqrt(expr.args[0]))})"
raise ValueError("fast_sqrt is only supported for Taget.GPU")
elif isinstance(expr, fast_inv_sqrt):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt):
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, sp.Abs):
return f"abs({self._print(expr.args[0])})"
elif isinstance(expr, sp.Mod):
@@ -681,21 +680,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend),
**self._kwargs)
return result
elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr)
if not result:
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]),
**self._kwargs)
return result
elif expr.func == fast_sqrt:
return f"({self._print(sp.sqrt(expr.args[0]))})"
elif expr.func == fast_inv_sqrt:
result = self._scalarFallback('_print_Function', expr)
if not result:
if 'rsqrt' in self.instruction_set:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
else:
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, fast_division):
raise ValueError("fast_division is only supported for Taget.GPU")
elif isinstance(expr, fast_sqrt):
raise ValueError("fast_sqrt is only supported for Taget.GPU")
elif isinstance(expr, fast_inv_sqrt):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
instr = 'any' if isinstance(expr, vec_any) else 'all'
expr_type = get_type_of_expression(expr.args[0])