smbinterp/interp/baker/__init__.py

262 lines
6.2 KiB
Python
Raw Normal View History

import sys
2010-10-23 12:49:15 -07:00
import numpy as np
from functools import wraps
import itertools
2011-05-04 20:36:42 -07:00
import interp
import logging
2011-03-22 15:06:08 -07:00
# Module-wide logger used by the functions in this module; all messages are
# emitted on the logger named 'interp'.
log = logging.getLogger('interp')
def get_phis(X, R):
    """
    Return the barycentric coordinates (phis) of point X in simplex R.

    This is baker eq 7: the phis solve the linear system

        [ 1        1        ...  1        ]   [phi_0]   [ 1      ]
        [ R0[0]    R1[0]    ...  Rd[0]    ] * [phi_1] = [ X[0]   ]
        [ ...                             ]   [ ...  ]   [ ...    ]
        [ R0[d-1]  R1[d-1]  ...  Rd[d-1]  ]   [phi_d]   [ X[d-1] ]

    i.e. they sum to one and reproduce X as a weighted sum of the vertices.

    X -- the destination point; its length d (d >= 1) sets the dimension.
    R -- the vertices of the containing simplex. Only R[0] .. R[d] are used,
         and only the first d coordinates of each vertex are read.

    in 2D:
        X = [0, 0]
        R = [[-1, -1], [0, 2], [1, -1]]
        returns [0.333, 0.333, 0.333]
    in 3D:
        X = [0, 0, 0]
        R = [
            [0.0, 0.0, 1.0],
            [0.94280904333606508, 0.0, -0.3333333283722672],
            [-0.47140452166803232, 0.81649658244673617, -0.3333333283722672],
            [-0.47140452166803298, -0.81649658244673584, -0.3333333283722672],
        ]
        returns [0.25, 0.25, 0.25, 0.25]

    If the system is singular (a degenerate simplex), an error is logged and
    the np.linalg.pinv solution is returned instead of raising.

    Raises Exception if X is empty, and IndexError if R has fewer than
    len(X) + 1 vertices.
    """
    # baker: eq 7
    dim = len(X)
    if dim < 1:
        raise Exception("inapropriate demension on X")
    log.debug("running %dD", dim)

    # Index (rather than slice) R so any indexable container works and a
    # short R raises IndexError, exactly as the hard-coded 2D/3D code did.
    verts = [R[i] for i in range(dim + 1)]

    # First row forces sum(phi) == 1; row k + 1 holds coordinate k of every
    # vertex, so A . phi == [1, X[0], ..., X[d-1]].
    A = np.array([[1] * (dim + 1)] +
                 [[v[k] for v in verts] for k in range(dim)])
    b = np.array([1] + [X[k] for k in range(dim)])

    try:
        phi = np.linalg.solve(A, b)
    except np.linalg.LinAlgError as e:
        # Deliberate best-effort: fall back to the pseudo-inverse rather
        # than failing the whole interpolation.
        msg = "calculation of phis yielded a linearly dependant system (%s)" % e
        log.error(msg)
        phi = np.dot(np.linalg.pinv(A), b)

    log.debug("phi: %s", phi)
    return phi
def qlinear(X, R):
    """
    Compute the linear part of the interpolated value at X (baker eq 3).

    X -- destination point
    R -- the simplex; must provide R.verts (vertex coordinates, handed to
         get_phis) and R.q (one value per vertex)

    Returns the pair (phis, qlin): the barycentric coordinates of X with
    respect to R.verts, and np.sum of each R.q value times its phi.
    """
    phis = get_phis(X, R.verts)
    # zip stops at the shorter sequence, so surplus R.q entries are ignored.
    weighted_terms = [value * weight for value, weight in zip(R.q, phis)]
    qlin = np.sum(weighted_terms)
    log.debug("phis: %s", phis)
    log.debug("qlin: %s", qlin)
    return phis, qlin
def _monomial(phis, term, coeff=1.0):
    """Return coeff * phis[term[0]] * phis[term[1]] * ..., multiplied left to right."""
    value = coeff
    for j in term:
        # rebind instead of *= so an array coefficient is never modified in place
        value = value * phis[j]
    return value


def get_error(phi, R, S, order = 2):
    """
    Fit Baker's higher-order error term and evaluate it at the destination point.

    phi   -- barycentric coordinates of the destination point in R
             (as returned by qlinear / get_phis)
    R     -- the simplex; passed to qlinear for every extra point
    S     -- the extra points; must provide S.verts and S.q
    order -- polynomial order of the error term (selects the pattern)

    For each extra point s, the row of B holds, for every index tuple of
    pattern(len(phi), order), the product of s's phis at those indices, and
    w holds the residual q(s) - qlin(s). The coefficients abc solve the
    normal equations (B^T B) abc = B^T w. If that system is singular, the
    np.linalg.pinv solution is used instead.

    Returns (error_term, abc), where
        error_term = sum_k abc[k] * prod(phi[j] for j in pattern[k]).
    With no extra points there is nothing to fit, so (0.0, zeros) is returned.
    """
    #TODO: change the equation names in the comments
    cur_pattern = pattern(len(phi), order)
    log.info("pattern: %s", cur_pattern)

    B = []  # baker eq 9
    w = []  # baker eq 11
    for s, q in zip(S.verts, S.q):
        cur_phi, cur_qlin = qlinear(s, R)
        B.append([_monomial(cur_phi, term) for term in cur_pattern])
        w.append(q - cur_qlin)

    log.info("B: %s", B)
    log.info("w: %s", w)

    if not B:
        # np.array([]) is 1-D, so both solve() and the pinv() fallback would
        # raise; with no data the least-squares fit is simply all zeros.
        log.warning("no extra points supplied; error term is zero")
        return 0.0, np.zeros(len(cur_pattern))

    B = np.array(B)
    w = np.array(w)
    A = np.dot(B.T, B)
    b = np.dot(B.T, w)

    # baker solve eq 10
    try:
        abc = np.linalg.solve(A, b)
    except np.linalg.LinAlgError as e:
        log.error("linear calculation went bad, resorting to np.linalg.pinv: %s", e)
        abc = np.dot(np.linalg.pinv(A), b)

    error_term = 0.0
    for a, term in zip(abc, cur_pattern):
        error_term += _monomial(phi, term, a)

    log.debug("error_term: %s", error_term)
    return error_term, abc
def run_baker(X, R, S, order=2):
    """
    This is the main function to call to get an interpolation to X from the
    input meshes.

    X     -- the destination point (same dimension as the simplex vertices),
             e.g. X = [0, 0]
    R     -- the simplex containing X (needs R.verts and R.q)
    S     -- extra points used to fit the error term (needs S.verts and S.q)
    order -- 1 for plain linear interpolation; 2 through 10 adds Baker's
             error term of that order

    Returns a dict with the keys:
        'qlin'  -- the linear interpolant at X
        'error' -- the fitted error term (None when order == 1)
        'final' -- qlin + error (just qlin when order == 1)
        'abc'   -- the fitted error coefficients (None when order == 1)

    Raises Exception for any other order.
    """
    log.debug("order = %d", order)
    log.debug("extra points = %d", len(S.verts))

    # every key is present up front so the dict has the same shape for all orders
    answer = {
        'qlin': None,
        'error': None,
        'final': None,
        'abc': None,
    }

    # calculate values only for the simplex triangle
    phi, qlin = qlinear(X, R)

    if order == 1:
        answer['qlin'] = qlin
        answer['final'] = qlin
        return answer
    elif order in range(2, 11):
        # range (not xrange) keeps the same membership test on Python 2 and 3
        error_term, abc = get_error(phi, R, S, order)
    else:
        raise Exception('unsupported order "%d" for baker method' % order)

    answer['qlin'] = qlin
    answer['error'] = error_term
    answer['final'] = qlin + error_term
    answer['abc'] = abc
    log.debug(answer)
    return answer
2011-04-03 10:09:03 -07:00
def memoize(f):
    """
    Decorator that caches f's return value per distinct set of call arguments.

    for more information on what I'm doing here,
    please read:
    http://en.wikipedia.org/wiki/Memoize

    All arguments must be hashable. The cache is unbounded and lives as long
    as the decorated function. Cached results are returned as-is (not copied),
    so callers must not mutate them.
    """
    cache = {}
    # Unique marker separating positional from keyword arguments in a key,
    # so f(1, ('a', 2)) and f(1, a=2) can never share a cache entry.
    kw_mark = object()

    @wraps(f)
    def memf(*args, **kwargs):
        kw_items = tuple(sorted(kwargs.items()))
        key = args + (kw_mark,) + kw_items if kw_items else args
        if key not in cache:
            log.debug("adding to cache: %s", args + kw_items)
            cache[key] = f(*args, **kwargs)
        return cache[key]
    return memf
2011-04-03 10:09:03 -07:00
@memoize
def pattern(simplex_size, nu):
    """
    This function returns the pattern requisite to compose the error
    approximation function, and the matrix B.

    The pattern is every sorted tuple of nu indices drawn (with repetition)
    from range(simplex_size), excluding the tuples that repeat one index nu
    times. Each tuple names the phis multiplied together in one term of the
    error function, i.e. one column of B.

    e.g. pattern(3, 2) == [(0, 1), (0, 2), (1, 2)]

    The tuples are returned in lexicographic order, so the column order of B
    and of the fitted coefficients is reproducible. Results are memoized, so
    callers must not mutate the returned list.
    """
    log.debug("pattern: simplex: %d, order: %d", simplex_size, nu)
    # combinations_with_replacement yields each sorted index tuple exactly
    # once, instead of building all simplex_size ** nu products, sorting each
    # one and deduplicating through a set.
    return [combo
            for combo in itertools.combinations_with_replacement(range(simplex_size), nu)
            if len(set(combo)) != 1]
2011-05-04 20:36:42 -07:00
if __name__ == '__main__':
    # Print the size and contents of the pattern for simplices of 3 and 4
    # vertices at orders 2-4. The listing is produced four times; pattern is
    # memoized, so after the first pass every result comes from the cache.
    for _pass in range(4):
        for nu in (2, 3, 4):
            for simplex_size in (3, 4):
                cur_pattern = pattern(simplex_size, nu)
                # A single-argument print(...) produces the same output on
                # Python 2 and 3, unlike the old print statement.
                print("%d %s" % (len(cur_pattern), cur_pattern))