smbinterp/interp/baker/__init__.py

263 lines
6.3 KiB
Python
Raw Normal View History

import sys
2010-10-23 12:49:15 -07:00
import numpy as np
from functools import wraps
import itertools
import logging
2011-03-22 15:06:08 -07:00
# Module-level logger used by every function in this file, registered under
# the fixed logger name 'interp'.
log = logging.getLogger(name='interp')
def get_phis(X, R):
    """
    Return the barycentric coordinates (phis) of point X with respect to the
    simplex whose vertices are R (baker eq 7).

    X -- the destination point; its length ``dim`` sets the dimension
    R -- the simplex vertices; only R[0] .. R[dim] are used, and only the
         first ``dim`` coordinates of each vertex

    The phis solve the (dim + 1) x (dim + 1) linear system

        sum_i phi_i           = 1
        sum_i phi_i * R[i][k] = X[k]     for k = 0 .. dim - 1

    in 2D:
        X = [0, 0]
        R = [[-1, -1], [0, 2], [1, -1]]
        returns [0.333, 0.333, 0.333]
    in 3D:
        X = [0, 0, 0]
        R = [
            [0.0, 0.0, 1.0],
            [0.94280904333606508, 0.0, -0.3333333283722672],
            [-0.47140452166803232, 0.81649658244673617, -0.3333333283722672],
            [-0.47140452166803298, -0.81649658244673584, -0.3333333283722672],
        ]
        returns [0.25, 0.25, 0.25, 0.25]

    Any dim >= 1 is accepted. If the system is singular (a degenerate
    simplex) the problem is logged as an error and the Moore-Penrose
    pseudo-inverse solution is returned instead.

    Raises ValueError if X is empty; IndexError if R has fewer than dim + 1
    vertices or a vertex has fewer than dim coordinates.
    """
    dim = len(X)
    if dim < 1:
        raise ValueError("inapropriate demension on X")
    log.debug("running %dD", dim)

    # baker: eq 7 -- the first row enforces sum(phi) == 1; row k + 1 holds the
    # k-th coordinate of every vertex. Indexing R[i][k] explicitly (rather than
    # converting R wholesale) uses exactly dim + 1 vertices and dim
    # coordinates, even if R carries more.
    A = np.ones((dim + 1, dim + 1))
    A[1:, :] = [[R[i][k] for i in range(dim + 1)] for k in range(dim)]
    b = np.array([1.0] + [X[k] for k in range(dim)])

    try:
        phi = np.linalg.solve(A, b)
    except np.linalg.LinAlgError as e:
        msg = "calculation of phis yielded a linearly dependant system (%s)" % e
        log.error(msg)
        # least-squares / minimum-norm answer for a degenerate simplex
        phi = np.dot(np.linalg.pinv(A), b)

    log.debug("phi: %s", phi)
    return phi
def qlinear(X, R):
    """
    Linear part of the interpolated value at X (baker eq 3).

    X -- the destination point
    R -- the simplex object; R.verts (vertex coordinates) is handed to
         get_phis and R.q (the values at those vertices) is weighted by
         the resulting phis

    Returns (phis, qlin): the barycentric coordinates of X in R, and
    np.sum of q_i * phi_i over the pairs produced by zip(R.q, phis).
    """
    weights = get_phis(X, R.verts)
    # zip() pairs values with weights and stops at the shorter sequence
    contributions = [value * weight for value, weight in zip(R.q, weights)]
    linear_value = np.sum(contributions)
    log.debug("phis: %s", weights)
    log.debug("qlin: %s", linear_value)
    return weights, linear_value
def get_error(phi, R, S, order=2):
    """
    Compute Baker's higher-order error term at the destination point.

    phi   -- barycentric coordinates of the destination point in R
    R     -- the simplex object, passed to qlinear for every extra point
    S     -- the extra points; S.verts are the points and S.q their values
    order -- order of the correction, passed to pattern()

    For every extra point s, one row of B holds the phi-monomials listed by
    pattern(len(phi), order) evaluated at s's own phis (baker eq 9), and w
    holds the residual q(s) - qlin(s) (baker eq 11). The coefficients abc
    solve the normal equations (B^T B) abc = B^T w (baker eq 10); if that
    solve raises LinAlgError the pseudo-inverse solution is used instead.

    Returns (error_term, abc), where error_term is the sum over k of
    abc[k] times the k-th monomial evaluated at phi.
    """
    p = pattern(len(phi), order)
    log.info("pattern: %s", p)

    B = []  # baker eq 9
    w = []  # baker eq 11
    for s, q in zip(S.verts, S.q):
        cur_phi, cur_qlin = qlinear(s, R)
        B.append([_monomial(1.0, cur_phi, combo) for combo in p])
        w.append(q - cur_qlin)
    log.info("B: %s", B)
    log.info("w: %s", w)

    # Force B to shape (len(S), len(p)). Without this, no extra points gives a
    # 0-d B whose B^T B is a scalar that both solve() and pinv() reject; with
    # it, B^T B is a len(p) x len(p) zero matrix, the pinv fallback yields
    # abc == 0, and the correction is simply zero.
    B = np.array(B, dtype=float).reshape(len(B), len(p))
    w = np.array(w)
    A = np.dot(B.T, B)
    b = np.dot(B.T, w)
    # baker solve eq 10
    try:
        abc = np.linalg.solve(A, b)
    except np.linalg.LinAlgError as e:
        log.error("linear calculation went bad, resorting to np.linalg.pinv: %s", e)
        abc = np.dot(np.linalg.pinv(A), b)

    error_term = 0.0
    for a, combo in zip(abc, p):
        error_term += _monomial(a, phi, combo)
    log.debug("error_term: %s", error_term)
    return error_term, abc


def _monomial(start, factors, indices):
    """Return start * factors[indices[0]] * factors[indices[1]] * ...

    Multiplies left to right and out of place, so an array-valued ``start``
    (e.g. a row of abc) is never modified.
    """
    result = start
    for idx in indices:
        result = result * factors[idx]
    return result
def run_baker(X, R, S, order=2):
    """
    Main entry point: interpolate a value at X from the input meshes.

    X     -- the destination point, e.g. X = [0, 0]
    R     -- the containing simplex (object with .verts and .q)
    S     -- the extra points (object with .verts and .q)
    order -- 1 for plain linear interpolation; 2..10 to add Baker's
             higher-order error correction computed by get_error

    Returns a dict with keys 'qlin', 'error', 'final' and 'abc'. For
    order 1 only 'qlin' is set and the other three are None.

    Raises ValueError (an Exception subclass) for any other order.
    """
    log.debug("order = %d", order)
    log.debug("extra points = %d", len(S.verts))

    # Validate before doing any work. range() (not xrange) keeps this
    # membership test valid on both Python 2 and 3; %s keeps the message
    # readable even for non-integer orders.
    if order != 1 and order not in range(2, 11):
        raise ValueError('unsupported order "%s" for baker method' % order)

    # every key is always present so callers see a consistent shape
    answer = {
        'qlin': None,
        'error': None,
        'final': None,
        'abc': None,
    }
    # calculate values only for the simplex triangle
    phi, qlin = qlinear(X, R)
    answer['qlin'] = qlin
    if order == 1:
        return answer

    error_term, abc = get_error(phi, R, S, order)
    answer['error'] = error_term
    answer['final'] = qlin + error_term
    answer['abc'] = abc
    log.debug(answer)
    return answer
def memoize(f):
    """
    Decorator that caches f's return value per distinct argument list.

    for more information on what I'm doing here,
    please read:
    http://en.wikipedia.org/wiki/Memoize

    Keyword arguments are part of the cache key, so calls that differ only
    in their keyword arguments are cached separately. All arguments must be
    hashable. The cache is never evicted, and cached results are returned
    as-is (shared between callers).
    """
    cache = {}
    # Unique separator between the positional and keyword parts of a key, so
    # no positional-only call can ever collide with a keyword call.
    kwd_mark = object()

    @wraps(f)
    def memf(*x, **kargs):
        key = x
        if kargs:
            # sort by keyword name so argument order does not matter
            key += (kwd_mark,) + tuple(sorted(kargs.items()))
        if key not in cache:
            log.debug("adding to cache: %s", (x, kargs) if kargs else x)
            cache[key] = f(*x, **kargs)
        return cache[key]
    return memf
def combinations_with_replacement(iterable, r):
    """
    Yield r-length tuples of elements from iterable, allowing an element to
    be repeated, e.g. ('ABC', 2) -> AA AB AC BB BC CC.

    What I really need for the pattern function only exists in python 2.7
    and greater (itertools.combinations_with_replacement); this is the
    pure-Python equivalent given in the itertools documentation. It walks a
    non-decreasing index vector directly, so it costs O(number of results)
    instead of the O(len(iterable) ** r) of filtering itertools.product.
    Results come out in the same (lexicographic) order as itertools.

    Raises ValueError for negative r.

    source:
    http://docs.python.org/library/itertools.html#itertools.combinations_with_replacement
    """
    if r < 0:
        raise ValueError("r must be non-negative")
    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        # nothing to pick from, but something has to be picked
        return
    indices = [0] * r
    yield tuple(pool[k] for k in indices)
    while True:
        # find the rightmost position that can still be incremented
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        # bump it and reset everything to its right to the same value,
        # which keeps the index vector non-decreasing
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[k] for k in indices)
@memoize
def pattern(simplex_size, nu):
    """
    Index patterns for the order-nu terms of Baker's error polynomial.

    Returns every non-decreasing nu-tuple of vertex indices drawn from
    range(simplex_size) (i.e. every multiset of size nu), except the tuples
    whose indices are all identical, e.g.

        pattern(3, 2) -> [(0, 1), (0, 2), (1, 2)]

    get_error treats each tuple (i, j, ...) as the monomial
    phi_i * phi_j * ... . Results are memoized, so the returned list is
    shared between calls and must not be modified.
    """
    log.debug("pattern: simplex: %d, order: %d", simplex_size, nu)
    # range() rather than xrange() so this also runs on Python 3
    return [combo
            for combo in combinations_with_replacement(range(simplex_size), nu)
            if len(set(combo)) != 1]
if __name__ == '__main__':
    # Print "len(pattern) pattern" for a few simplex sizes and orders.
    # Four identical rounds are printed; after the first, every pattern()
    # call is answered from memoize's cache.
    demo_cases = [(3, 2), (4, 2), (3, 3), (4, 3), (3, 4), (4, 4)]
    for _round in range(4):
        for simplex_size, order in demo_cases:
            p = pattern(simplex_size, order)
            # a single pre-formatted argument prints identically on
            # Python 2 (print statement) and Python 3 (print function)
            print("%d %s" % (len(p), p))