moved the quadratic error calculation out of the baker method, and added a cubic version too
This commit is contained in:
parent
700ccc8c25
commit
3efe335563
@ -102,6 +102,88 @@ def qlinear_3D(X, R):
|
||||
qlin = sum([q_i * phi_i for q_i, phi_i in zip(R.q, phis)])
|
||||
return phis, qlin
|
||||
|
||||
def get_error_quadratic(phi, R, S):
|
||||
B = [] # baker eq 9
|
||||
w = [] # baker eq 11
|
||||
|
||||
for (s, q) in zip(S.points, S.q):
|
||||
cur_phi, cur_qlin = qlinear(s, R)
|
||||
(phi1, phi2, phi3) = cur_phi
|
||||
|
||||
B.append(
|
||||
[
|
||||
phi1 * phi2,
|
||||
phi2 * phi3,
|
||||
phi3 * phi1,
|
||||
]
|
||||
)
|
||||
w.append(q - cur_qlin)
|
||||
|
||||
B = np.array(B)
|
||||
w = np.array(w)
|
||||
|
||||
A = np.dot(B.T, B)
|
||||
b = np.dot(B.T, w)
|
||||
|
||||
# baker solve eq 10
|
||||
try:
|
||||
(a, b, c) = np.linalg.solve(A,b)
|
||||
except np.linalg.LinAlgError as e:
|
||||
print >> sys.stderr, "warning: run_baker: linear calculation went bad, resorting to np.linalg.pinv", e
|
||||
(a, b, c) = np.dot(np.linalg.pinv(A), b)
|
||||
|
||||
error_term = a * phi[0] * phi[1]\
|
||||
+ b * phi[1] * phi[2]\
|
||||
+ c * phi[2] * phi[0]
|
||||
|
||||
return error_term, a, b, c
|
||||
|
||||
def get_error_cubic(phi, R, S):
|
||||
B = [] # baker eq 9
|
||||
w = [] # baker eq 11
|
||||
|
||||
for (s, q) in zip(S.points, S.q):
|
||||
cur_phi, cur_qlin = qlinear(s, R)
|
||||
(phi1, phi2, phi3) = cur_phi
|
||||
|
||||
# basing this on eq 17
|
||||
B.append(
|
||||
[
|
||||
phi1 * phi2, # a
|
||||
phi1 * phi3, # b
|
||||
phi2 * phi1, # c
|
||||
phi2 * phi3, # d
|
||||
phi3 * phi1, # e
|
||||
phi3 * phi2, # f
|
||||
phi1 * phi2 * phi3, # g
|
||||
]
|
||||
)
|
||||
w.append(q - cur_qlin)
|
||||
|
||||
B = np.array(B)
|
||||
w = np.array(w)
|
||||
|
||||
A = np.dot(B.T, B)
|
||||
b = np.dot(B.T, w)
|
||||
|
||||
# baker solve eq 10
|
||||
try:
|
||||
(a, b, c, d, e, f, g) = np.linalg.solve(A,b)
|
||||
except np.linalg.LinAlgError as e:
|
||||
print >> sys.stderr, "warning: run_baker: linear calculation went bad, resorting to np.linalg.pinv", e
|
||||
(a, b, c, d, e, f, g) = np.dot(np.linalg.pinv(A), b)
|
||||
|
||||
error_term = a * phi[0] * phi[1]\
|
||||
+ b * phi[0] * phi[2]\
|
||||
+ c * phi[1] * phi[0]\
|
||||
+ d * phi[1] * phi[2]\
|
||||
+ e * phi[2] * phi[0]\
|
||||
+ f * phi[2] * phi[1]\
|
||||
+ g * phi[0] * phi[1] * phi[2]\
|
||||
|
||||
return error_term, a, b, c
|
||||
|
||||
|
||||
def run_baker(X, R, S, order=2):
|
||||
"""
|
||||
This is the main function to call to get an interpolation to X from the input meshes
|
||||
@ -135,38 +217,12 @@ def run_baker(X, R, S, order=2):
|
||||
}
|
||||
return answer
|
||||
|
||||
B = [] # baker eq 9
|
||||
w = [] # baker eq 11
|
||||
|
||||
for (s, q) in zip(S.points, S.q):
|
||||
cur_phi, cur_qlin = qlinear(s, R)
|
||||
(phi1, phi2, phi3) = cur_phi
|
||||
|
||||
B.append(
|
||||
[
|
||||
phi1 * phi2,
|
||||
phi2 * phi3,
|
||||
phi3 * phi1,
|
||||
]
|
||||
)
|
||||
w.append(q - cur_qlin)
|
||||
|
||||
B = np.array(B)
|
||||
w = np.array(w)
|
||||
|
||||
A = np.dot(B.T, B)
|
||||
b = np.dot(B.T, w)
|
||||
|
||||
# baker solve eq 10
|
||||
try:
|
||||
(a, b, c) = np.linalg.solve(A,b)
|
||||
except np.linalg.LinAlgError as e:
|
||||
print >> sys.stderr, "warning: run_baker: linear calculation went bad, resorting to np.linalg.pinv", e
|
||||
(a, b, c) = np.dot(np.linalg.pinv(A), b)
|
||||
|
||||
error_term = a * phi[0] * phi[1]\
|
||||
+ b * phi[1] * phi[2]\
|
||||
+ c * phi[2] * phi[0]
|
||||
if order == 2:
|
||||
error_term, a, b, c = get_error_quadratic(phi, R, S)
|
||||
elif order == 3:
|
||||
error_term, a, b, c = get_error_cubic(phi, R, S)
|
||||
else:
|
||||
raise smberror('unacceptable order for baker method')
|
||||
|
||||
q_final = qlin + error_term
|
||||
|
||||
|
79
test/cubic.test.py
Executable file
79
test/cubic.test.py
Executable file
@ -0,0 +1,79 @@
|
||||
#!/usr/bin/python
|
||||
|
||||
import unittest
|
||||
|
||||
import math
|
||||
|
||||
from baker import run_baker
|
||||
|
||||
from grid.DD import grid
|
||||
from grid.simplex import contains
|
||||
|
||||
def exact_func(X):
|
||||
x = X[0]
|
||||
y = X[0]
|
||||
return 1 - math.sin((x-0.5)**2 + (y-0.5)**2)
|
||||
|
||||
class TestSequenceFunctions(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.points = [
|
||||
[ 0.25, 0.40], # 0
|
||||
[ 0.60, 0.80], # 1
|
||||
[ 0.65, 0.28], # 2
|
||||
[ 0.28, 0.65], # 3
|
||||
[ 1.00, 0.75], # 4
|
||||
[ 0.30, 0.95], # 5
|
||||
[ 0.80, 0.50], # 6
|
||||
[ 0.35, 0.15], # 7
|
||||
]
|
||||
self.q = [exact_func(p) for p in self.points]
|
||||
|
||||
self.X = [0.25, 0.4001]
|
||||
self.X = [0.55, 0.45]
|
||||
|
||||
self.g = grid(self.points, self.q)
|
||||
self.g.construct_connectivity()
|
||||
self.R = self.g.create_mesh(range(3))
|
||||
|
||||
self.exact = exact_func(self.X)
|
||||
|
||||
|
||||
self.accuracy = 8
|
||||
|
||||
def test_R_contains_X(self):
|
||||
self.assertTrue(contains(self.X, self.R.points))
|
||||
|
||||
def test_RunBaker_1_extra_point(self, extra=1):
|
||||
S = self.g.create_mesh(range(3, 3 + extra))
|
||||
answer = run_baker(self.X, self.R, S, order=3)
|
||||
lin_err = abs(self.exact - answer['qlin'])
|
||||
final_err = abs(self.exact - answer['final'])
|
||||
self.assertTrue(lin_err >= final_err)
|
||||
def test_RunBaker_2_extra_point(self, extra=2):
|
||||
S = self.g.create_mesh(range(3, 3 + extra))
|
||||
answer = run_baker(self.X, self.R, S, order=3)
|
||||
lin_err = abs(self.exact - answer['qlin'])
|
||||
final_err = abs(self.exact - answer['final'])
|
||||
self.assertTrue(lin_err >= final_err)
|
||||
def test_RunBaker_3_extra_point(self, extra=3):
|
||||
S = self.g.create_mesh(range(3, 3 + extra))
|
||||
answer = run_baker(self.X, self.R, S, order=3)
|
||||
lin_err = abs(self.exact - answer['qlin'])
|
||||
final_err = abs(self.exact - answer['final'])
|
||||
self.assertTrue(lin_err >= final_err)
|
||||
def test_RunBaker_4_extra_point(self, extra=4):
|
||||
S = self.g.create_mesh(range(3, 3 + extra))
|
||||
answer = run_baker(self.X, self.R, S, order=3)
|
||||
lin_err = abs(self.exact - answer['qlin'])
|
||||
final_err = abs(self.exact - answer['final'])
|
||||
self.assertTrue(lin_err >= final_err)
|
||||
def test_RunBaker_5_extra_point(self, extra=5):
|
||||
S = self.g.create_mesh(range(3, 3 + extra))
|
||||
answer = run_baker(self.X, self.R, S, order=3)
|
||||
lin_err = abs(self.exact - answer['qlin'])
|
||||
final_err = abs(self.exact - answer['final'])
|
||||
self.assertTrue(lin_err >= final_err)
|
||||
|
||||
if __name__ == '__main__':
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestSequenceFunctions)
|
||||
unittest.TextTestRunner(verbosity=3).run(suite)
|
Loading…
Reference in New Issue
Block a user