"""Construct sparse matrix from a local stencil."""
# pylint: disable=redefined-builtin
import numpy as np
from scipy import sparse
def stencil_grid(S, grid, dtype=None, format=None):
    """Construct a sparse matrix from a local matrix stencil.

    Parameters
    ----------
    S : ndarray
        Matrix stencil stored in an N-d array.  Every stencil dimension
        must have odd length so that the stencil has a well-defined center.
    grid : tuple
        Tuple containing the N grid dimensions.
    dtype : data-type, optional
        Data type of the result.  If None, it is inferred from ``S``.
    format : string, optional
        Sparse matrix format to return, e.g. "csr", "coo", etc.  If None,
        no conversion is done and a DIA matrix is returned.

    Returns
    -------
    A : sparse matrix
        Sparse matrix which represents the operator given by applying
        stencil S at each vertex of a regular grid with given dimensions.

    Raises
    ------
    ValueError
        If any stencil dimension has even length, if the number of stencil
        dimensions differs from the number of grid dimensions, or if any
        grid dimension is smaller than 1.

    Notes
    -----
    The grid vertices are enumerated as arange(prod(grid)).reshape(grid).
    This implies that the last grid dimension cycles fastest, while the
    first dimension cycles slowest.  For example, if grid=(2,3) then the
    grid vertices are ordered as (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).

    This coincides with the ordering used by the NumPy functions
    ndenumerate() and mgrid().

    Stencil entries that would couple a vertex to a point outside the grid
    are dropped, so the grid does not wrap around at its boundaries.

    Examples
    --------
    >>> from pyamg.gallery import stencil_grid
    >>> stencil = [-1,2,-1]  # 1D Poisson stencil
    >>> grid = (5,)          # 1D grid with 5 vertices
    >>> A = stencil_grid(stencil, grid, dtype=float, format='csr')
    >>> A.toarray()
    array([[ 2., -1.,  0.,  0.,  0.],
           [-1.,  2., -1.,  0.,  0.],
           [ 0., -1.,  2., -1.,  0.],
           [ 0.,  0., -1.,  2., -1.],
           [ 0.,  0.,  0., -1.,  2.]])

    >>> stencil = [[0,-1,0],[-1,4,-1],[0,-1,0]]  # 2D Poisson stencil
    >>> grid = (3,3)                             # 2D grid with shape 3x3
    >>> A = stencil_grid(stencil, grid, dtype=float, format='csr')
    >>> A.toarray()
    array([[ 4., -1.,  0., -1.,  0.,  0.,  0.,  0.,  0.],
           [-1.,  4., -1.,  0., -1.,  0.,  0.,  0.,  0.],
           [ 0., -1.,  4.,  0.,  0., -1.,  0.,  0.,  0.],
           [-1.,  0.,  0.,  4., -1.,  0., -1.,  0.,  0.],
           [ 0., -1.,  0., -1.,  4., -1.,  0., -1.,  0.],
           [ 0.,  0., -1.,  0., -1.,  4.,  0.,  0., -1.],
           [ 0.,  0.,  0., -1.,  0.,  0.,  4., -1.,  0.],
           [ 0.,  0.,  0.,  0., -1.,  0., -1.,  4., -1.],
           [ 0.,  0.,  0.,  0.,  0., -1.,  0., -1.,  4.]])

    """
    S = np.asarray(S, dtype=dtype)
    grid = tuple(grid)

    if not (np.asarray(S.shape) % 2 == 1).all():
        raise ValueError('all stencil dimensions must be odd')

    if len(grid) != np.ndim(S):
        # Implicit concatenation: a backslash continuation inside the
        # literal would embed the next line's indentation in the message.
        raise ValueError('stencil dimension must equal number of grid '
                         'dimensions')

    if min(grid) < 1:
        raise ValueError('grid dimensions must be positive')

    N_v = np.prod(grid)  # number of vertices in the mesh
    N_s = (S != 0).sum()  # number of nonzero stencil entries

    # diagonal offsets, one per nonzero stencil entry
    diags = np.zeros(N_s, dtype=int)

    # Linear-index step of each grid dimension, listed from the last
    # (fastest varying) dimension to the first, matching the C-order
    # vertex numbering described in the Notes.
    strides = np.cumprod([1] + list(reversed(grid)))[:-1]

    # Coordinates of the nonzero stencil entries, copied so they can be
    # shifted in place to be relative to the stencil center.
    indices = tuple(i.copy() for i in S.nonzero())
    for i, s in zip(indices, S.shape):
        i -= s // 2

    for stride, coords in zip(strides, reversed(indices)):
        diags += stride * coords

    # One row of DIA data per nonzero stencil entry, replicated per vertex.
    data = S[S != 0].repeat(N_v).reshape(N_s, N_v)

    indices = np.vstack(indices).T  # shape (N_s, ndim): per-entry offsets

    # Zero boundary connections.  In DIA storage data[k, j] holds
    # A[j - diags[k], j], i.e. it is indexed by column.  An offset of +i
    # along dimension n is only valid for columns whose n-th coordinate is
    # >= i; an offset of -i only for columns whose coordinate is
    # < grid[n] - i.  Everything else is zeroed.
    for index, diag in zip(indices, data):
        # data is C-contiguous, so this reshape is a view: writes to diag
        # update data in place.
        diag = diag.reshape(grid)
        for n, i in enumerate(index):
            if i > 0:
                s = [slice(None)] * len(grid)
                s[n] = slice(0, i)
                diag[tuple(s)] = 0
            elif i < 0:
                s = [slice(None)] * len(grid)
                s[n] = slice(i, None)
                diag[tuple(s)] = 0

    # remove diagonals that lie outside matrix
    mask = abs(diags) < N_v
    if not mask.all():
        diags = diags[mask]
        data = data[mask]

    # Sum duplicate diagonals.  Distinct stencil entries can map to the same
    # offset when the stencil extends past a small grid dimension; the
    # boundary zeroing above makes summing them correct.
    unique_diags = np.unique(diags)
    if len(unique_diags) != len(diags):
        new_data = np.zeros((len(unique_diags), data.shape[1]),
                            dtype=data.dtype)
        for dia, dat in zip(diags, data):
            n = np.searchsorted(unique_diags, dia)
            new_data[n, :] += dat
        diags = unique_diags
        data = new_data

    return sparse.dia_matrix((data, diags),
                             shape=(N_v, N_v)).asformat(format)