Skip to content
Snippets Groups Projects
Commit d8e498fa authored by Martin Bauer's avatar Martin Bauer
Browse files

Workaround for sympy bug in placeholder_function

see https://github.com/sympy/sympy/issues/16662
parent 27a131fb
No related merge requests found
......@@ -2,6 +2,7 @@ import sympy as sp
from typing import List
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.sympyextensions import is_constant
from pystencils.transformations import generic_visit
......@@ -37,11 +38,11 @@ def to_placeholder_function(expr, name):
assignments = [Assignment(sp.Symbol(name), expr)]
assignments += [Assignment(symbol, derivative)
for symbol, derivative in zip(derivative_symbols, derivatives)
if not derivative.is_constant()]
if not is_constant(derivative)]
def fdiff(_, index):
result = derivatives[index - 1]
return result if result.is_constant() else derivative_symbols[index - 1]
return result if is_constant(result) else derivative_symbols[index - 1]
func = type(name, (sp.Function, PlaceholderFunction),
{'fdiff': fdiff,
......
......@@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict,
return visit(expression)
def is_constant(expr):
"""Simple version of checking if a sympy expression is constant.
Works also for piecewise defined functions - sympy's is_constant() has a problem there, see:
https://github.com/sympy/sympy/issues/16662
"""
return len(expr.free_symbols) == 0
def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
required_match_replacement: Optional[Union[int, float]] = 0.5,
required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment