Skip to content
Snippets Groups Projects
Commit 5257b4ae authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Also update Field.physical_to_index,index_to_physical for coordinate_transform as a function

parent 8aa4fe9f
No related merge requests found
......@@ -329,10 +329,10 @@ class Field(AbstractField):
self._layout = normalize_layout(layout)
self.shape = shape
self.strides = strides
self.latex_name = None # type: Optional[str]
self.coordinate_origin = sp.Matrix(tuple(
self.latex_name: Optional[str] = None
self.coordinate_origin: tuple[float, sp.Symbol] = sp.Matrix(tuple(
0 for _ in range(self.spatial_dimensions)
)) # type: tuple[float,sp.Symbol]
)) # type
self.coordinate_transform = sp.eye(self.spatial_dimensions)
if field_type == FieldType.STAGGERED:
assert self.staggered_stencil
......@@ -433,7 +433,7 @@ class Field(AbstractField):
return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])])
elif len(index_shape) == 3:
return sp.Matrix([[[self(i, j, k) for k in range(index_shape[2])]
for j in range(index_shape[1])] for i in range(index_shape[0])])
for j in range(index_shape[1])] for i in range(index_shape[0])])
else:
raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
......@@ -454,7 +454,7 @@ class Field(AbstractField):
return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])])
elif self.index_dimensions == 2:
return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
for i in range(self.index_shape[0])])
for i in range(self.index_shape[0])])
else:
raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions")
......@@ -529,7 +529,7 @@ class Field(AbstractField):
prefactor = -1
if neighbor not in self.staggered_stencil:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name))
self.staggered_stencil_name))
offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
......@@ -563,7 +563,7 @@ class Field(AbstractField):
return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
elif self.index_dimensions == 3:
return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
for i in range(self.index_shape[1])])
for i in range(self.index_shape[1])])
else:
raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
......@@ -627,10 +627,23 @@ class Field(AbstractField):
def index_to_physical(self, index_coordinates, staggered=False):
if staggered:
index_coordinates = sp.Matrix([i + 0.5 for i in index_coordinates])
return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
if hasattr(self.coordinate_transform, '__call__'):
return self.coordinate_transform(self.coordinate_origin + index_coordinates)
else:
return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
def physical_to_index(self, physical_coordinates, staggered=False):
rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
if hasattr(self.coordinate_transform, '__call__'):
if hasattr(self.coordinate_transform, 'inv'):
return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin
else:
idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True))
rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx)
assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}'
return rtn
else:
rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
if staggered:
rtn = sp.Matrix([i - 0.5 for i in rtn])
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment