# smbinterp/interp/baker/__init__.py
#
# 260 lines
# 6.3 KiB
# Python

import sys
import numpy as np
from functools import wraps
import itertools
import logging
# Shared logger for this module: the functions below emit their debug, info
# and error messages through the 'interp' logging channel.
log = logging.getLogger('interp')
def get_phis(X, R):
    """
    Return the barycentric coordinates of a point with respect to a
    triangle (2D) or a tetrahedron (3D).

    X -- the destination point; a sequence of length 2 or 3
    R -- the vertices of the containing simplex; the first len(X) + 1
         entries are used (three points in 2D, four points in 3D)

    2D example:
        X = [0, 0]
        R = [[-1, -1], [0, 2], [1, -1]]
      returns approximately [0.333, 0.333, 0.333]

    3D example:
        X = [0, 0, 0]
        R = [
            [0.0, 0.0, 1.0],
            [0.94280904333606508, 0.0, -0.3333333283722672],
            [-0.47140452166803232, 0.81649658244673617, -0.3333333283722672],
            [-0.47140452166803298, -0.81649658244673584, -0.3333333283722672],
        ]
      returns approximately [0.25, 0.25, 0.25, 0.25]

    Raises Exception when len(X) is neither 2 nor 3.  When the linear
    system is singular the problem is logged as an error and the
    pseudo-inverse of the matrix is used instead of an exact solve.
    """
    # baker: eq 7
    # TODO: perhaps also test len(R[0]) .. ?
    dimension = len(X)
    if dimension == 2:
        log.debug("running 2D")
    elif dimension == 3:
        log.debug("running 3D")
    else:
        raise Exception("inapropriate demension on X")

    vertex_count = dimension + 1

    # The leading row of ones (paired with the leading 1 in b) forces the
    # coordinates to sum to one; every following row holds one spatial
    # component of each vertex, matched against that component of X.
    rows = [[1] * vertex_count]
    for axis in range(dimension):
        rows.append([R[vertex][axis] for vertex in range(vertex_count)])
    A = np.array(rows)
    b = np.array([1] + [X[axis] for axis in range(dimension)])

    try:
        phi = np.linalg.solve(A, b)
    except np.linalg.LinAlgError as e:
        msg = "calculation of phis yielded a linearly dependant system (%s)" % e
        log.error(msg)
        # degenerate simplex: settle for the pseudo-inverse solution
        phi = np.dot(np.linalg.pinv(A), b)

    log.debug("phi: %s", phi)
    return phi
def qlinear(X, R):
    """
    Evaluate the linear part of q at X from the simplex R (baker eq 3).

    X -- destination point
    R -- simplex object; this function reads R.verts (the vertex
         coordinates handed to get_phis) and R.q (one value per vertex)

    Returns the tuple (phis, qlin): the barycentric coordinates of X with
    respect to R.verts, and the sum of each R.q entry weighted by its
    matching coordinate.
    """
    phis = get_phis(X, R.verts)

    weighted = []
    for q_i, phi_i in zip(R.q, phis):
        weighted.append(q_i * phi_i)
    qlin = np.sum(weighted)

    log.debug("phis: %s", phis)
    log.debug("qlin: %s", qlin)
    return phis, qlin
def get_error(phi, R, S, order = 2):
    """
    Estimate the higher-order correction to the linear interpolation.

    phi   -- barycentric coordinates of the destination point
    R     -- simplex object, passed through to qlinear
    S     -- extra points; S.verts and S.q are iterated in lockstep
    order -- forwarded to pattern() to choose the index combinations

    For every extra point the products of its barycentric coordinates over
    each index combination form one row of B (baker eq 9), and the
    difference between its q and its linear estimate forms one entry of w
    (baker eq 11).  The coefficients abc are obtained from the normal
    equations (B^T B) abc = B^T w (baker eq 10); if that solve fails the
    pseudo-inverse is used instead.

    Returns the tuple (error_term, abc), where error_term is the sum of
    each coefficient times the product of phi over its index combination.
    """
    def _running_product(start, values, indices):
        # multiply start by values[k] for each k in indices, left to right
        result = start
        for k in indices:
            result *= values[k]
        return result

    terms = pattern(len(phi), order)
    log.info("pattern: %s", terms)

    B = [] # baker eq 9
    w = [] # baker eq 11
    for (vert, q_vert) in zip(S.verts, S.q):
        vert_phi, vert_qlin = qlinear(vert, R)
        row = []
        for term in terms:
            row.append(_running_product(vert_phi[term[0]], vert_phi, term[1:]))
        B.append(row)
        w.append(q_vert - vert_qlin)

    log.info("B: %s", B)
    log.info("w: %s", w)

    B = np.array(B)
    w = np.array(w)
    lhs = np.dot(B.T, B)
    rhs = np.dot(B.T, w)

    # baker solve eq 10
    try:
        abc = np.linalg.solve(lhs, rhs)
    except np.linalg.LinAlgError as e:
        log.error("linear calculation went bad, resorting to np.linalg.pinv: %s", e)
        abc = np.dot(np.linalg.pinv(lhs), rhs)

    error_term = 0.0
    for (coeff, term) in zip(abc, terms):
        error_term += _running_product(coeff, phi, term)

    log.debug("error_term: %s", error_term)
    return error_term, abc
def run_baker(X, R, S, order=2):
    """
    This is the main function to call to get an interpolation to X from the
    input meshes.

    X     -- the destination point, e.g. X = [0, 0] in 2D
    R     -- the simplex containing X (needs R.verts and R.q)
    S     -- the extra points (needs S.verts and S.q)
    order -- 1 for the linear estimate only, 2 through 10 to also compute
             the error correction

    Returns a dict.  For order == 1 only 'qlin' is filled in and 'error'
    and 'final' stay None.  For higher orders 'qlin', 'error' and 'final'
    (qlin + error) are set, and an additional 'abc' key carries the fitted
    coefficients from get_error.

    Raises Exception for any order outside 1..10.
    """
    log.debug("order = %d" % order)
    log.debug("extra points = %d" % len(S.verts))

    answer = {
        'qlin': None,
        'error': None,
        'final': None,
    }

    # calculate values only for the simplex triangle
    phi, qlin = qlinear(X, R)

    if order == 1:
        answer['qlin'] = qlin
        return answer
    elif order in range(2, 11):
        # range() instead of xrange(): same membership test, but the name
        # exists on both Python 2 and Python 3
        error_term, abc = get_error(phi, R, S, order)
    else:
        raise Exception('unsupported order "%d" for baker method' % order)

    q_final = qlin + error_term

    answer['qlin' ] = qlin
    answer['error'] = error_term
    answer['final'] = q_final
    answer['abc' ] = abc

    log.debug(answer)
    return answer
def memoize(f):
    """
    Decorator that caches the results of f, keyed on its call arguments.

    For more information on the technique, please read:
    http://en.wikipedia.org/wiki/Memoize

    Both positional and keyword arguments take part in the cache key, so
    two calls that differ only in a keyword argument are cached separately.
    All argument values must be hashable.  The cached object itself is
    returned on every hit, so callers should not mutate it.
    """
    cache = {}

    @wraps(f)
    def memf(*x, **kargs):
        # Keyword names are unique, so sorting the items never has to
        # compare the values; it only makes the key order-independent.
        key = (x, tuple(sorted(kargs.items())))
        if key not in cache:
            cache[key] = f(*x, **kargs)
        return cache[key]
    return memf
def combinations_with_replacement(iterable, r):
    """
    Yield the r-length combinations of the elements of iterable, allowing
    individual elements to be repeated, as tuples in lexicographic order of
    the element positions.

    itertools.combinations_with_replacement only exists in python 2.7 and
    greater.  When it is available it is used directly; otherwise the
    equivalent implementation suggested by the docs is used, which filters
    every product of positions and is therefore much slower for large r.
    see: http://docs.python.org/library/itertools.html#itertools.combinations_with_replacement
    """
    native = getattr(itertools, 'combinations_with_replacement', None)
    if native is not None:
        for combo in native(iterable, r):
            yield combo
        return

    # fallback for interpreters without the itertools function
    pool = tuple(iterable)
    n = len(pool)
    for indices in itertools.product(range(n), repeat=r):
        # keep only the non-decreasing index tuples
        if sorted(indices) == list(indices):
            yield tuple(pool[i] for i in indices)
@memoize
def pattern(simplex_size, nu):
    """
    List the index combinations used for the error terms.

    simplex_size -- number of indices to draw from (0 .. simplex_size - 1)
    nu           -- length of each combination

    Returns a list of tuples: every nu-length combination with repetition
    of the indices, except those that repeat one single index throughout.
    For example pattern(3, 2) is [(0, 1), (0, 2), (1, 2)].

    The result is cached by the memoize decorator, so repeated calls with
    the same arguments return the same list object.
    """
    log.debug("pattern: simplex: %d, order: %d" % (simplex_size, nu))
    # range() rather than xrange() so the function runs on Python 2 and 3
    return [i for i in combinations_with_replacement(range(simplex_size), nu) if len(set(i)) != 1]
if __name__ == '__main__':
    # Print the size and contents of several patterns.  The same six cases
    # are printed four times over; after the first pass every result comes
    # out of the memoize cache.
    cases = [(3, 2), (4, 2), (3, 3), (4, 3), (3, 4), (4, 4)]
    for _ in range(4):
        for simplex_size, nu in cases:
            p = pattern(simplex_size, nu)
            # a single pre-formatted string prints identically under the
            # Python 2 print statement and the Python 3 print function
            print("%d %s" % (len(p), p))