From b13958d5d78b2add8f24ed1df3d8a09346521166 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 3 Dec 2020 22:26:04 +0100 Subject: [PATCH] Extraction of factor 1/2 --- .../centeredcumulantmethod.py | 4 +- .../central_moment_transform.py | 51 +++++++++++++++---- 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/lbmpy/methods/centeredcumulant/centeredcumulantmethod.py b/lbmpy/methods/centeredcumulant/centeredcumulantmethod.py index 83c23f2..dffebe2 100644 --- a/lbmpy/methods/centeredcumulant/centeredcumulantmethod.py +++ b/lbmpy/methods/centeredcumulant/centeredcumulantmethod.py @@ -208,7 +208,9 @@ class CenteredCumulantBasedLbMethod(AbstractLbMethod): def get_collision_rule(self, pre_simplification=False): """Returns an LbmCollisionRule i.e. an equation collection with a reference to the method. This collision rule defines the collision operator.""" - return self._centered_cumulant_collision_rule(self._cumulant_to_relaxation_info_dict, None, pre_simplification, True) + ac = self._centered_cumulant_collision_rule(self._cumulant_to_relaxation_info_dict, None, pre_simplification, True) + ac = ac.new_without_unused_subexpressions() + return LbmCollisionRule(self, ac.main_assignments, ac.subexpressions) # ------------------------------- Internals -------------------------------------------- diff --git a/lbmpy/methods/centeredcumulant/central_moment_transform.py b/lbmpy/methods/centeredcumulant/central_moment_transform.py index 70060ca..4874aac 100644 --- a/lbmpy/methods/centeredcumulant/central_moment_transform.py +++ b/lbmpy/methods/centeredcumulant/central_moment_transform.py @@ -91,6 +91,23 @@ class PdfsToCentralMomentsByMatrix(AbstractCentralMomentTransform): # end class PdfsToCentralMomentsByMatrix +class ExtractOneHalf: + """ + Pseudo-Simplification to instruct the FastCentralMomentTransform to extract + the factor 1/2 to a subexpression, to hide it from sympy. Otherwise, sympy will + distribute it across the sums, producing unnecessary multiplications. + """ + def __init__(self, one_half_proxy=sp.Symbol('half')): + self._symbol = one_half_proxy + + @property + def symbol(self): + return self._symbol + + def __call__(self, ac): + return ac + + class FastCentralMomentTransform(AbstractCentralMomentTransform): def __init__(self, stencil, moment_exponents, shift_velocity): @@ -149,13 +166,18 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): simplification=True, subexpression_base='sub_k_to_f'): if simplification and not isinstance(simplification, SimplificationStrategy): simplification = self._default_simplification + simplification.add(ExtractOneHalf()) raw_equations = self.mat_transform.backward_transform( pdf_symbols, moment_symbol_base=POST_COLLISION_CENTRAL_MOMENT, simplification=False) raw_equations = raw_equations.new_without_subexpressions() symbol_gen = SymbolGen(subexpression_base) - ac = self._split_backward_equations(raw_equations, symbol_gen) + if simplification: + extract_one_half = next(filter(lambda x: isinstance(x, ExtractOneHalf), simplification.rules)) + else: + extract_one_half = None + ac = self._split_backward_equations(raw_equations, symbol_gen, extract_one_half=extract_one_half) if simplification: ac = simplification.apply(ac) @@ -172,7 +194,8 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): return simplification def _split_backward_equations_recursive(self, assignment, all_subexpressions, - stencil_direction, subexp_symgen, known_coeffs_dict, step=0): + stencil_direction, subexp_symgen, known_coeffs_dict, + one_half, step=0): # Base Case if step == self.dim: return assignment @@ -181,7 +204,8 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): u = self.shift_velocity[-1 - step] d = stencil_direction[-1 - step] - one = sp.sympify(1) + one = sp.Integer(1) + two = sp.Integer(2) # Factors to group terms by grouping_factors = { @@ -192,7 +216,7 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): factors = grouping_factors[d] # Common Integer factor to extract from all groups - common_factor = one if d == 0 else sp.Integer(2) + common_factor = one if d == 0 else two # Proxy for factor grouping v = sp.Symbol('v') @@ -224,23 +248,30 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): # Recursively split the coefficient term coeff_assignment = self._split_backward_equations_recursive( coeff_assignment, all_subexpressions, stencil_direction, subexp_symgen, - known_coeffs_dict, step=step + 1) + known_coeffs_dict, one_half, step=step + 1) all_subexpressions.append(coeff_assignment) new_rhs += factors[k] * coeff_symb - if common_factor != one: - new_rhs = sp.Mul(sp.Rational(1, common_factor), new_rhs, evaluate=False) + if common_factor == two: + # new_rhs = sp.Mul(sp.Rational(1, common_factor), new_rhs, evaluate=False) + new_rhs = one_half * new_rhs return Assignment(assignment.lhs, new_rhs) - def _split_backward_equations(self, backward_assignments, subexp_symgen): - all_subexpressions = [] + def _split_backward_equations(self, backward_assignments, subexp_symgen, extract_one_half=None): + if extract_one_half is not None: + one_half_proxy = extract_one_half.symbol + all_subexpressions = [Assignment(one_half_proxy, sp.Rational(1,2))] + else: + one_half_proxy = sp.Rational(1,2) + all_subexpressions = [] + split_main_assignments = [] known_coeffs_dict = dict() for asm, stencil_dir in zip(backward_assignments, self.stencil): split_asm = self._split_backward_equations_recursive( - asm, all_subexpressions, stencil_dir, subexp_symgen, known_coeffs_dict) + asm, all_subexpressions, stencil_dir, subexp_symgen, known_coeffs_dict, one_half_proxy) split_main_assignments.append(split_asm) ac = AssignmentCollection(split_main_assignments, subexpressions=all_subexpressions, subexpression_symbol_generator=subexp_symgen) -- GitLab