From b13958d5d78b2add8f24ed1df3d8a09346521166 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 3 Dec 2020 22:26:04 +0100
Subject: [PATCH] Extraction of factor 1/2

---
 .../centeredcumulantmethod.py                 |  4 +-
 .../central_moment_transform.py               | 51 +++++++++++++++----
 2 files changed, 44 insertions(+), 11 deletions(-)

diff --git a/lbmpy/methods/centeredcumulant/centeredcumulantmethod.py b/lbmpy/methods/centeredcumulant/centeredcumulantmethod.py
index 83c23f2..dffebe2 100644
--- a/lbmpy/methods/centeredcumulant/centeredcumulantmethod.py
+++ b/lbmpy/methods/centeredcumulant/centeredcumulantmethod.py
@@ -208,7 +208,9 @@ class CenteredCumulantBasedLbMethod(AbstractLbMethod):
     def get_collision_rule(self, pre_simplification=False):
         """Returns an LbmCollisionRule i.e. an equation collection with a reference to the method.
         This collision rule defines the collision operator."""
-        return self._centered_cumulant_collision_rule(self._cumulant_to_relaxation_info_dict, None, pre_simplification, True)
+        ac = self._centered_cumulant_collision_rule(self._cumulant_to_relaxation_info_dict, None, pre_simplification, True)
+        ac = ac.new_without_unused_subexpressions()
+        return LbmCollisionRule(self, ac.main_assignments, ac.subexpressions)
 
     #   ------------------------------- Internals --------------------------------------------
     
diff --git a/lbmpy/methods/centeredcumulant/central_moment_transform.py b/lbmpy/methods/centeredcumulant/central_moment_transform.py
index 70060ca..4874aac 100644
--- a/lbmpy/methods/centeredcumulant/central_moment_transform.py
+++ b/lbmpy/methods/centeredcumulant/central_moment_transform.py
@@ -91,6 +91,23 @@ class PdfsToCentralMomentsByMatrix(AbstractCentralMomentTransform):
 # end class PdfsToCentralMomentsByMatrix
 
 
+class ExtractOneHalf:
+    """
+    Pseudo-Simplification to instruct the FastCentralMomentTransform to extract
+    the factor 1/2 to a subexpression, to hide it from sympy. Otherwise, sympy will
+    distribute it across the sums, producing unnecessary multiplications.
+    """
+    def __init__(self, one_half_proxy=sp.Symbol('half')):
+        self._symbol = one_half_proxy
+
+    @property
+    def symbol(self):
+        return self._symbol
+
+    def __call__(self, ac):
+        return ac
+
+
 class FastCentralMomentTransform(AbstractCentralMomentTransform):
 
     def __init__(self, stencil, moment_exponents, shift_velocity):
@@ -149,13 +166,18 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
                            simplification=True, subexpression_base='sub_k_to_f'):
         if simplification and not isinstance(simplification, SimplificationStrategy):
             simplification = self._default_simplification
+            simplification.add(ExtractOneHalf())
 
         raw_equations = self.mat_transform.backward_transform(
             pdf_symbols, moment_symbol_base=POST_COLLISION_CENTRAL_MOMENT, simplification=False)
         raw_equations = raw_equations.new_without_subexpressions()
 
         symbol_gen = SymbolGen(subexpression_base)
-        ac = self._split_backward_equations(raw_equations, symbol_gen)
+        if simplification:
+            extract_one_half = next(filter(lambda x: isinstance(x, ExtractOneHalf), simplification.rules))
+        else:
+            extract_one_half = None
+        ac = self._split_backward_equations(raw_equations, symbol_gen, extract_one_half=extract_one_half)
 
         if simplification:
             ac = simplification.apply(ac)
@@ -172,7 +194,8 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
         return simplification
 
     def _split_backward_equations_recursive(self, assignment, all_subexpressions,
-                                           stencil_direction, subexp_symgen, known_coeffs_dict, step=0):
+                                           stencil_direction, subexp_symgen, known_coeffs_dict,
+                                           one_half, step=0):
         #   Base Case
         if step == self.dim:
             return assignment
@@ -181,7 +204,8 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
 
         u = self.shift_velocity[-1 - step]
         d = stencil_direction[-1 - step]
-        one = sp.sympify(1)
+        one = sp.Integer(1)
+        two = sp.Integer(2)
 
         #   Factors to group terms by
         grouping_factors = {
@@ -192,7 +216,7 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
         factors = grouping_factors[d]
 
         #   Common Integer factor to extract from all groups
-        common_factor = one if d == 0 else sp.Integer(2)
+        common_factor = one if d == 0 else two
 
         #   Proxy for factor grouping
         v = sp.Symbol('v')
@@ -224,23 +248,30 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
                 #   Recursively split the coefficient term
                 coeff_assignment = self._split_backward_equations_recursive(
                     coeff_assignment, all_subexpressions, stencil_direction, subexp_symgen,
-                    known_coeffs_dict, step=step + 1)
+                    known_coeffs_dict, one_half, step=step + 1)
                 all_subexpressions.append(coeff_assignment)
 
             new_rhs += factors[k] * coeff_symb
 
-        if common_factor != one:
-            new_rhs = sp.Mul(sp.Rational(1, common_factor), new_rhs, evaluate=False)
+        if common_factor == two:
+            # new_rhs = sp.Mul(sp.Rational(1, common_factor), new_rhs, evaluate=False)
+            new_rhs = one_half * new_rhs
 
         return Assignment(assignment.lhs, new_rhs)
 
-    def _split_backward_equations(self, backward_assignments, subexp_symgen):
-        all_subexpressions = []
+    def _split_backward_equations(self, backward_assignments, subexp_symgen, extract_one_half=None):
+        if extract_one_half is not None:
+            one_half_proxy = extract_one_half.symbol
+            all_subexpressions = [Assignment(one_half_proxy, sp.Rational(1,2))]
+        else:
+            one_half_proxy = sp.Rational(1,2)
+            all_subexpressions = []
+
         split_main_assignments = []
         known_coeffs_dict = dict()
         for asm, stencil_dir in zip(backward_assignments, self.stencil):
             split_asm = self._split_backward_equations_recursive(
-                asm, all_subexpressions, stencil_dir, subexp_symgen, known_coeffs_dict)
+                asm, all_subexpressions, stencil_dir, subexp_symgen, known_coeffs_dict, one_half_proxy)
             split_main_assignments.append(split_asm)
         ac = AssignmentCollection(split_main_assignments, subexpressions=all_subexpressions,
                                   subexpression_symbol_generator=subexp_symgen)
-- 
GitLab