Major: made scripts pass pep8 and pyflakes

This commit is contained in:
Stephen M. McQuay 2011-09-17 15:38:49 -06:00
parent 1bc797a14d
commit 837a72b246
17 changed files with 835 additions and 866 deletions

View File

@ -1,40 +1 @@
import os
import logging
import logging.handlers
import json
LEVELS = {'debug': logging.DEBUG,
'info': logging.INFO,
'warning': logging.WARNING,
'error': logging.ERROR,
'critical': logging.CRITICAL}
default_config = {
'filename': '/tmp/interp.log',
'level': 'debug',
'size' : 102400,
'logbackup': 10,
'pypath': None,
}
try:
with open(os.path.expanduser('~/.config/interp.json')) as config_file:
d = json.load(config_file)
except IOError as e:
d = {}
config = dict(default_config.items() + d.items())
logger = logging.getLogger('interp')
logger.setLevel(LEVELS[config['level']])
my_format = logging.Formatter('%(asctime)s %(levelname)s (%(process)d) %(filename)s %(funcName)s:%(lineno)d %(message)s')
handler = logging.handlers.RotatingFileHandler(
config['filename'], maxBytes = config['size'] * 1024, backupCount = config['logbackup'])
handler.setFormatter(my_format)
logger.addHandler(handler)
__version__ = '0.2' __version__ = '0.2'

View File

@ -1,18 +1,20 @@
import sys
import numpy as np import numpy as np
from functools import wraps from functools import wraps
import itertools import itertools
import interp import interp
import logging
log = logging.getLogger('interp') AGGRESSIVE_ERROR_SOLVE = True
RAISE_PATHOLOGICAL_EXCEPTION = False
__version__ = interp.__version__
def get_phis(X, R): def get_phis(X, R):
""" """
The get_phis function is used to get barycentric coordonites for a The get_phis function is used to get barycentric coordonites for a
point on a triangle or tetrahedron. This is equation (*\ref{eq:qlinarea}*) point on a triangle or tetrahedron (Equation (*\ref{eq:qlinarea}*))
in 2D: in 2D:
@ -41,45 +43,27 @@ def get_phis(X, R):
# equations (*\ref{eq:lin3d}*) and (*\ref{eq:lin2d}*) # equations (*\ref{eq:lin3d}*) and (*\ref{eq:lin2d}*)
if len(X) == 2: if len(X) == 2:
log.debug("running 2D")
A = np.array([ A = np.array([
[ 1, 1, 1], [1, 1, 1],
[R[0][0], R[1][0], R[2][0]], [R[0][0], R[1][0], R[2][0]],
[R[0][1], R[1][1], R[2][1]], [R[0][1], R[1][1], R[2][1]],
]) ])
b = np.array([ 1, b = np.array([1, X[0], X[1]])
X[0],
X[1]
])
elif len(X) == 3: elif len(X) == 3:
log.debug("running 3D")
A = np.array([ A = np.array([
[ 1, 1, 1, 1 ], [1, 1, 1, 1],
[R[0][0], R[1][0], R[2][0], R[3][0]], [R[0][0], R[1][0], R[2][0], R[3][0]],
[R[0][1], R[1][1], R[2][1], R[3][1]], [R[0][1], R[1][1], R[2][1], R[3][1]],
[R[0][2], R[1][2], R[2][2], R[3][2]], [R[0][2], R[1][2], R[2][2], R[3][2]],
]) ])
b = np.array([ 1, b = np.array([1, X[0], X[1], X[2]])
X[0],
X[1],
X[2]
])
else: else:
raise Exception("inapropriate demension on X") raise Exception("inapropriate demension on X")
phi = np.linalg.solve(A, b)
try:
phi = np.linalg.solve(A,b)
except np.linalg.LinAlgError as e:
msg = "calculation of phis yielded a linearly dependant system (%s)" % e
log.error(msg)
# raise Exception(msg)
phi = np.dot(np.linalg.pinv(A), b)
log.debug("phi: %s", phi)
return phi return phi
def qlinear(X, R):
def qlinear(X, R, q):
""" """
this calculates the linear portion of q from R to X this calculates the linear portion of q from R to X
@ -89,15 +73,13 @@ def qlinear(X, R):
R = a inter.grid object; must have R.points and R.q R = a inter.grid object; must have R.points and R.q
""" """
phis = get_phis(X, R.verts) phis = get_phis(X, R)
qlin = np.sum([q_i * phi_i for q_i, phi_i in zip(R.q, phis)]) qlin = np.sum([q_i * phi_i for q_i, phi_i in zip(q, phis)])
log.debug("phis: %s", phis)
log.debug("qlin: %s", qlin)
return phis, qlin return phis, qlin
def get_error(phi, R, S, order = 2):
def get_error(phi, R, R_q, S, S_q, order=2):
""" """
Calculate the error approximation terms, returning the unknowns Calculate the error approximation terms, returning the unknowns
a,b, and c in equation (*\ref{eq:quadratic2d}*). a,b, and c in equation (*\ref{eq:quadratic2d}*).
@ -106,10 +88,9 @@ def get_error(phi, R, S, order = 2):
w = [] # equation ((*\ref{eq:w}*) w = [] # equation ((*\ref{eq:w}*)
cur_pattern = pattern(len(phi), order) cur_pattern = pattern(len(phi), order)
log.info("pattern: %s" % cur_pattern)
for (s,q) in zip(S.verts, S.q): for (s, cur_q) in zip(S, S_q):
cur_phi, cur_qlin = qlinear(s, R) cur_phi, cur_qlin = qlinear(s, R, R_q)
l = [] l = []
for i in cur_pattern: for i in cur_pattern:
cur_sum = cur_phi[i[0]] cur_sum = cur_phi[i[0]]
@ -118,11 +99,7 @@ def get_error(phi, R, S, order = 2):
l.append(cur_sum) l.append(cur_sum)
B.append(l) B.append(l)
w.append(q - cur_qlin) w.append(cur_q - cur_qlin)
log.info("B: %s" % B)
log.info("w: %s" % w)
B = np.array(B) B = np.array(B)
w = np.array(w) w = np.array(w)
@ -131,9 +108,10 @@ def get_error(phi, R, S, order = 2):
b = np.dot(B.T, w) b = np.dot(B.T, w)
try: try:
abc = np.linalg.solve(A,b) abc = np.linalg.solve(A, b)
except np.linalg.LinAlgError as e: except np.linalg.LinAlgError:
log.error("linear calculation went bad, resorting to np.linalg.pinv: %s" % e) if not AGGRESSIVE_ERROR_SOLVE:
return None, None
abc = np.dot(np.linalg.pinv(A), b) abc = np.dot(np.linalg.pinv(A), b)
error_term = 0.0 error_term = 0.0
@ -143,10 +121,10 @@ def get_error(phi, R, S, order = 2):
cur_sum *= phi[j] cur_sum *= phi[j]
error_term += cur_sum error_term += cur_sum
log.debug("error_term: %s" % error_term)
return error_term, abc return error_term, abc
def run_baker(X, R, S, order=2):
def run_baker(X, R, R_q, S, S_q, order=2):
""" """
This is the main function to call to get an interpolation to X from the This is the main function to call to get an interpolation to X from the
input meshes input meshes
@ -156,34 +134,41 @@ def run_baker(X, R, S, order=2):
R = Simplex R = Simplex
S = extra points S = extra points
""" """
log.debug("order = %d" % order)
log.debug("extra points = %d" % len(S.verts))
answer = { answer = {
'qlin': None, 'qlin': None,
'error': None, 'error': None,
'final': None, 'final': None,
} }
# calculate values only for the simplex triangle # calculate values only for the simplex triangle
phi, qlin = qlinear(X, R) phi, qlin = qlinear(X, R, R_q)
if order == 1: if order == 1:
answer['qlin'] = qlin answer['qlin'] = qlin
answer['final'] = qlin answer['final'] = qlin
return answer return answer
elif order in xrange(2,11): elif order in xrange(2, 11):
error_term, abc = get_error(phi, R, S, order) error_term, abc = get_error(phi, R, R_q, S, S_q, order)
# if a pathological vertex configuration was encountered and
# AGGRESSIVE_ERROR_SOLVE is False, get_error will return (None, None)
# indicating that only linear interpolation should be performed
if (error_term is None) and (abc is None):
if RAISE_PATHOLOGICAL_EXCEPTION:
raise np.linalg.LinAlgError("Pathological Vertex Config")
answer['qlin'] = qlin
answer['final'] = qlin
return answer
else: else:
raise Exception('unsupported order "%d" for baker method' % order) raise Exception('unsupported order "%d" for baker method' % order)
q_final = qlin + error_term q_final = qlin + error_term
answer['qlin' ] = qlin answer['qlin'] = qlin
answer['error'] = error_term answer['error'] = error_term
answer['final'] = q_final answer['final'] = q_final
answer['abc' ] = abc answer['abc'] = abc
log.debug(answer)
return answer return answer
@ -194,11 +179,11 @@ def memoize(f):
http://en.wikipedia.org/wiki/Memoize http://en.wikipedia.org/wiki/Memoize
""" """
cache = {} cache = {}
@wraps(f) @wraps(f)
def memf(simplex_size, nu): def memf(simplex_size, nu):
x = (simplex_size, nu) x = (simplex_size, nu)
if x not in cache: if x not in cache:
log.debug("adding to cache: %s", x)
cache[x] = f(simplex_size, nu) cache[x] = f(simplex_size, nu)
return cache[x] return cache[x]
return memf return memf
@ -210,11 +195,10 @@ def pattern(simplex_size, nu):
This function returns the pattern requisite to compose the error This function returns the pattern requisite to compose the error
approximation function, and the matrix B. approximation function, and the matrix B.
""" """
log.debug("pattern: simplex: %d, order: %d" % (simplex_size, nu))
r = [] r = []
for i in itertools.product(xrange(simplex_size), repeat = nu): for i in itertools.product(xrange(simplex_size), repeat=nu):
if len(set(i)) !=1: if len(set(i)) != 1:
r.append(tuple(sorted(i))) r.append(tuple(sorted(i)))
unique_r = list(set(r)) unique_r = list(set(r))
return unique_r return unique_r

View File

@ -5,6 +5,7 @@ import rlcompleter
historyPath = os.path.expanduser("~/.pyhistory") historyPath = os.path.expanduser("~/.pyhistory")
def save_history(historyPath=historyPath): def save_history(historyPath=historyPath):
import readline import readline
readline.write_history_file(historyPath) readline.write_history_file(historyPath)

View File

@ -6,16 +6,18 @@ results_q = Queue.Queue()
minions_q = Queue.Queue() minions_q = Queue.Queue()
master_q = Queue.Queue() master_q = Queue.Queue()
class QueueManager(BaseManager): class QueueManager(BaseManager):
""" """
One QueueManager to rule all network Queues One QueueManager to rule all network Queues
""" """
pass pass
QueueManager.register('get_tasks_q' , callable=lambda:tasks_q ) QueueManager.register('get_tasks_q', callable=lambda: tasks_q)
QueueManager.register('get_results_q', callable=lambda:results_q ) QueueManager.register('get_results_q', callable=lambda: results_q)
QueueManager.register('get_minions_q', callable=lambda:minions_q ) QueueManager.register('get_minions_q', callable=lambda: minions_q)
QueueManager.register('get_master_q' , callable=lambda:master_q ) QueueManager.register('get_master_q', callable=lambda: master_q)
def get_qs(qm): def get_qs(qm):
""" """

19
interp/config.py Normal file
View File

@ -0,0 +1,19 @@
import os
import json
default_config = {
'filename': '/tmp/interp.log',
'level': 'debug',
'size': 102400,
'logbackup': 10,
'pypath': None,
}
try:
with open(os.path.expanduser('~/.config/interp.json')) as config_file:
d = json.load(config_file)
except IOError as e:
d = {}
config = dict(default_config.items() + d.items())

View File

@ -1,10 +1,9 @@
from interp.grid.delaunay import dgrid as basegrid
from interp.tools import baker_exact_2D as exact_func
from itertools import product from itertools import product
import numpy as np import numpy as np
from interp.grid.delaunay import dgrid as basegrid
class rect_grid(basegrid): class rect_grid(basegrid):
def __init__(self, xres = 5, yres = 5): def __init__(self, xres = 5, yres = 5):
xmin = 0.0 xmin = 0.0

View File

@ -1,10 +1,9 @@
from interp.grid.delaunay import dgrid as basegrid
from interp.tools import baker_exact_3D, log
from itertools import product from itertools import product
import numpy as np import numpy as np
from interp.grid.delaunay import dgrid as basegrid
class rect_grid(basegrid): class rect_grid(basegrid):
def __init__(self, xres = 5, yres = 5, zres = 5): def __init__(self, xres = 5, yres = 5, zres = 5):
xmin = 0.0 xmin = 0.0
@ -22,7 +21,6 @@ class rect_grid(basegrid):
zspan = zmaz - zmin zspan = zmaz - zmin
zdel = zspan / float(zres - 1) zdel = zspan / float(zres - 1)
verts = [] verts = []
q = np.zeros(xres * yres * zres) q = np.zeros(xres * yres * zres)
for x in xrange(xres): for x in xrange(xres):
@ -41,8 +39,6 @@ class random_grid(rect_grid):
def __init__(self, num_verts = 100): def __init__(self, num_verts = 100):
verts = [] verts = []
r = np.random
appx_side_res = int(np.power(num_verts, 1/3.0)) appx_side_res = int(np.power(num_verts, 1/3.0))
delta = 1.0 / float(appx_side_res) delta = 1.0 / float(appx_side_res)

View File

@ -1,4 +1,3 @@
import sys
from collections import defaultdict from collections import defaultdict
import pickle import pickle
@ -9,14 +8,19 @@ from scipy.spatial import KDTree
from interp.baker import run_baker from interp.baker import run_baker
from interp.baker import get_phis from interp.baker import get_phis
import interp
import logging import logging
log = logging.getLogger("interp") log = logging.getLogger("interp")
MAX_SEARCH_COUNT = 256 MAX_SEARCH_COUNT = 256
TOL = 1e-8
__version__ = interp.__version__
class grid(object): class grid(object):
def __init__(self, verts = None, q = None): def __init__(self, verts=None, q=None):
""" """
verts = array of arrays (if passed in, will convert to numpy.array) verts = array of arrays (if passed in, will convert to numpy.array)
[ [
@ -60,7 +64,8 @@ class grid(object):
attempts += 1 attempts += 1
if attempts > MAX_SEARCH_COUNT: if attempts > MAX_SEARCH_COUNT:
raise Exception("Is the search becoming exhaustive? (%d attempts)" % attempts) raise Exception("Is the search becoming exhaustive?'\
'(%d attempts)" % attempts)
cur_cell = cells_to_check.pop(0) cur_cell = cells_to_check.pop(0)
checked_cells.append(cur_cell) checked_cells.append(cur_cell)
@ -70,7 +75,8 @@ class grid(object):
continue continue
for neighbor in cur_cell.neighbors: for neighbor in cur_cell.neighbors:
if (neighbor not in checked_cells) and (neighbor not in cells_to_check): if (neighbor not in checked_cells) \
and (neighbor not in cells_to_check):
cells_to_check.append(neighbor) cells_to_check.append(neighbor)
if not simplex: if not simplex:
@ -85,15 +91,15 @@ class grid(object):
def create_mesh(self, indicies): def create_mesh(self, indicies):
""" """
this function takes a list of indicies, and then creates and returns a this function takes a list of indicies, and then creates and
grid object (collection of verts and q). returns a grid object (collection of verts and q).
note: the input is indicies, the grid contains verts note: the input is indicies, the grid contains verts
""" """
return grid(self.verts[indicies], self.q[indicies]) return grid(self.verts[indicies], self.q[indicies])
def get_simplex_and_nearest_points(self, X, extra_points = 3): def get_simplex_and_nearest_points(self, X, extra_points=3):
""" """
this returns two grid objects: R and S. this returns two grid objects: R and S.
@ -130,7 +136,7 @@ class grid(object):
return (r_mesh, s_mesh) return (r_mesh, s_mesh)
def run_baker(self, X, order = 2, extra_points = 3): def run_baker(self, X, order=2, extra_points=3):
(R, S) = self.get_simplex_and_nearest_points(X, extra_points) (R, S) = self.get_simplex_and_nearest_points(X, extra_points)
answer = run_baker(X, R, S, order) answer = run_baker(X, R, S, order)
return answer return answer
@ -140,7 +146,7 @@ class grid(object):
this returns a generator that should be fed into qdelaunay this returns a generator that should be fed into qdelaunay
""" """
yield str(len(self.verts[0])); yield str(len(self.verts[0]))
yield '%d' % len(self.verts) yield '%d' % len(self.verts)
for p in self.verts: for p in self.verts:
@ -159,9 +165,9 @@ class grid(object):
def __str__(self): def __str__(self):
r = '' r = ''
assert( len(self.verts) == len(self.q) ) assert(len(self.verts) == len(self.q))
for c, i in enumerate(zip(self.verts, self.q)): for c, i in enumerate(zip(self.verts, self.q)):
r += "%d vert(%s): q(%0.4f)" % (c,i[0], i[1]) r += "%d vert(%s): q(%0.4f)" % (c, i[0], i[1])
cell_str = ", ".join([str(f.name) for f in self.cells_for_vert[c]]) cell_str = ", ".join([str(f.name) for f in self.cells_for_vert[c]])
r += " cells: [%s]" % cell_str r += " cells: [%s]" % cell_str
r += "\n" r += "\n"
@ -170,18 +176,21 @@ class grid(object):
r += "%s\n" % v r += "%s\n" % v
return r return r
def normalize_q(self, new_max = 0.1): def normalize_q(self, new_max=0.1):
largest_number = np.max(np.abs(self.q)) largest_number = np.max(np.abs(self.q))
self.q *= new_max/largest_number self.q *= new_max / largest_number
def dump_to_blender_files(self,
def dump_to_blender_files(self, pfile = '/tmp/points.p', cfile = '/tmp/cells.p'): pfile='/tmp/points.p', cfile='/tmp/cells.p'):
if len(self.verts[0]) == 2: if len(self.verts[0]) == 2:
pickle.dump([(p[0], p[1], 0.0) for p in self.verts], open(pfile, 'w')) pickle.dump([(p[0], p[1], 0.0) for p in self.verts],
open(pfile, 'w'))
else: else:
pickle.dump([(p[0], p[1], p[2]) for p in self.verts], open(pfile, 'w')) pickle.dump([(p[0], p[1], p[2]) for p in self.verts],
open(pfile, 'w'))
pickle.dump([f.verts for f in self.cells.itervalues()], open(cfile, 'w')) pickle.dump([f.verts for f in self.cells.itervalues()],
open(cfile, 'w'))
def get_xml(self): def get_xml(self):
doc = Document() doc = Document()
@ -193,13 +202,14 @@ class grid(object):
p.setAttribute("x", str(i[0][0])) p.setAttribute("x", str(i[0][0]))
p.setAttribute('y', str(i[0][1])) p.setAttribute('y', str(i[0][1]))
p.setAttribute('z', str(i[0][2])) p.setAttribute('z', str(i[0][2]))
p.setAttribute('q', str(i[1] )) p.setAttribute('q', str(i[1]))
ps.appendChild(p) ps.appendChild(p)
return doc return doc
def toxml(self): def toxml(self):
return self.get_xml().toxml() return self.get_xml().toxml()
def toprettyxml(self): def toprettyxml(self):
return self.get_xml().toprettyxml() return self.get_xml().toprettyxml()
@ -227,9 +237,9 @@ class cell(object):
X = point of interest X = point of interest
G = corrensponding grid object (G.verts) G = corrensponding grid object (G.verts)
because of the way i'm storing things, a cell simply stores indicies, because of the way i'm storing things, a cell simply stores
and so one must pass in a reference to the grid object containing real indicies, and so one must pass in a reference to the grid object
verts. containing real verts.
this simply calls grid.simplex.contains this simply calls grid.simplex.contains
""" """
@ -248,8 +258,6 @@ class cell(object):
__repr__ = __str__ __repr__ = __str__
TOL = 1e-8
def contains(X, R): def contains(X, R):
""" """
tests if X (point) is in R tests if X (point) is in R

View File

@ -1,7 +1,4 @@
import pickle
from itertools import combinations from itertools import combinations
from collections import defaultdict
import numpy as np import numpy as np
from scipy.spatial import KDTree from scipy.spatial import KDTree
@ -36,7 +33,7 @@ class ggrid(grid):
gmsh_file.readline() # $MeshFormat gmsh_file.readline() # $MeshFormat
fmat = gmsh_file.readline() gmsh_file.readline()
gmsh_file.readline() # $EndMeshFormat gmsh_file.readline() # $EndMeshFormat
gmsh_file.readline() # $Nodes gmsh_file.readline() # $Nodes

View File

@ -1,9 +1,5 @@
import os
import numpy as np import numpy as np
import logging
log = logging.getLogger("interp")
def rms(errors): def rms(errors):
""" """
@ -17,64 +13,61 @@ def rms(errors):
# r = np.sqrt(r / len(errors)) # r = np.sqrt(r / len(errors))
# return r # return r
return np.sqrt((errors**2).mean()) return np.sqrt((errors ** 2).mean())
def baker_exact_2D(X): def baker_exact_2D(X):
""" """
the exact function (2D) used from baker's article (for testing, slightly the exact function (2D) used from baker's article (for testing,
modified) slightly modified)
""" """
x ,y = X x, y = X
answer = np.power((np.sin(x * np.pi) * np.cos(y * np.pi)), 2) answer = np.power((np.sin(x * np.pi) * np.cos(y * np.pi)), 2)
log.debug(answer)
return answer return answer
def friendly_exact_2D(X): def friendly_exact_2D(X):
""" """
A friendlier 2D func A friendlier 2D func
""" """
x ,y = X x, y = X
answer = 1.0 + x*x + y*y answer = 1.0 + x * x + y * y
log.debug(answer)
return answer return answer
def baker_exact_3D(X): def baker_exact_3D(X):
""" """
the exact function (3D) used from baker's article (for testing) the exact function (3D) used from baker's article (for testing)
""" """
x = X[0] x, y, z = X
y = X[1] answer = np.power((np.sin(x * np.pi / 2.0) * np.sin(y * np.pi / 2.0) *
z = X[2] np.sin(z * np.pi / 2.0)), 2)
answer = np.power((np.sin(x * np.pi / 2.0) * np.sin(y * np.pi / 2.0) * np.sin(z * np.pi / 2.0)), 2)
log.debug(answer)
return answer return answer
def friendly_exact_3D(X): def friendly_exact_3D(X):
x,y,z = X x, y, z = X
return 1 + x*x + y*y + z*z return 1 + x * x + y * y + z * z
def scipy_exact_2D(X): def scipy_exact_2D(X):
x,y = X x, y = X
return x*(1-x)*np.cos(4*np.pi*x) * np.sin(4*np.pi*y**2)**2 return x * (1 - x) * np.cos(4 * np.pi * x) *\
np.sin(4 * np.pi * y ** 2) ** 2
def improved_answer(answer, exact): def improved_answer(answer, exact):
if not answer['error']: if not answer['error']:
# was probably just a linear interpolation # was probably just a linear interpolation
return False return False
log.debug('qlin: %s' % answer['qlin'])
log.debug('error: %s' % answer['error'])
log.debug('final: %s' % answer['final'])
log.debug('exact: %s' % exact)
if np.abs(answer['final'] - exact) <= np.abs(answer['qlin'] - exact): if np.abs(answer['final'] - exact) <= np.abs(answer['qlin'] - exact):
log.debug(":) improved result")
return True return True
else: else:
log.debug(":( damaged result")
return False return False
def improved(qlin, err, final, exact): def improved(qlin, err, final, exact):
if np.abs(final - exact) <= np.abs(qlin - exact): if np.abs(final - exact) <= np.abs(qlin - exact):
return True return True

View File

@ -3,24 +3,22 @@
import unittest import unittest
from interp import baker from interp import baker
from interp import grid
import numpy as np
import scipy.spatial
class Test(unittest.TestCase): class Test(unittest.TestCase):
def setUp(self): def setUp(self):
self.l = [[-1, 1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]] self.l = [[-1, 1], [-1, 0], [-1, 1], [0, -1],
[0, 0], [0, 1], [1, -1], [1, 0], [1, 1]]
self.all_points = [ self.all_points = [
[ 0, 0], # 0 [0, 0], # 0
[ 1, 0], # 1 [1, 0], # 1
[ 1, 1], # 2 [1, 1], # 2
[ 0, 1], # 3 [0, 1], # 3
[ 1,-1], # 4 [1, -1], # 4
[ 0,-1], # 5 [0, -1], # 5
[-1, 1], # 6 [-1, 1], # 6
[-1, 0], # 7 [-1, 0], # 7
[-1,-1], # 8 [-1, -1], # 8
] ]
self.q = [1, 0, 0, 0, 0, 0, 0, 0, 0] self.q = [1, 0, 0, 0, 0, 0, 0, 0, 0]
self.X = [0.5, 0.25] self.X = [0.5, 0.25]
@ -29,39 +27,42 @@ class Test(unittest.TestCase):
def testImports(self): def testImports(self):
import numpy import numpy
import scipy import scipy
import interp.grid import interp.grid as gv
import interp.baker import interp.baker as bv
numpy.__version__
scipy.__version__
gv, bv
def testGetPhis(self): def testGetPhis(self):
X = [0, 0]
X = [0,0]
r = [[-1, -1], [0, 2], [1, -1]] r = [[-1, -1], [0, 2], [1, -1]]
result = baker.get_phis(X, r) result = baker.get_phis(X, r)
right_answer = [1/3.0, 1/3.0, 1/3.0] right_answer = [1 / 3.0, 1 / 3.0, 1 / 3.0]
for a,b in zip(result, right_answer): for a, b in zip(result, right_answer):
self.assertAlmostEqual(a,b) self.assertAlmostEqual(a, b)
def testGetPhis2(self): def testGetPhis2(self):
X = [0.5, 0.25]
X = [0.5,0.25]
r = [[0, 0], [1, 0], [1, 1]] r = [[0, 0], [1, 0], [1, 1]]
result = baker.get_phis(X, r) result = baker.get_phis(X, r)
right_answer = [0.5, 0.25, 0.25] right_answer = [0.5, 0.25, 0.25]
for a,b in zip(result, right_answer): for a, b in zip(result, right_answer):
self.assertEqual(a,b) self.assertEqual(a, b)
def testQlinear(self): def testQlinear(self):
X = [0.5, 0.25] X = [0.5, 0.25]
r = [[0, 0], [1, 0], [1, 1]] r = [[0, 0], [1, 0], [1, 1]]
q = [1, 0, 0] q = [1, 0, 0]
phi, result = baker.qlinear(X, grid.grid(r,q)) phi, result = baker.qlinear(X, r, q)
right_answer = 0.5 right_answer = 0.5
@ -71,80 +72,79 @@ class Test(unittest.TestCase):
size_of_simplex = 3 size_of_simplex = 3
extra_points = 3 extra_points = 3
R = grid.grid(self.all_points[:size_of_simplex], R, R_q = (self.all_points[:size_of_simplex],
self.q[:size_of_simplex]) self.q[:size_of_simplex])
S = grid.grid(self.all_points[size_of_simplex:size_of_simplex + extra_points], S, S_q = (self.all_points[size_of_simplex:size_of_simplex \
+ extra_points],
self.q[size_of_simplex:size_of_simplex + extra_points]) self.q[size_of_simplex:size_of_simplex + extra_points])
answer = baker.run_baker(self.X, R, R_q, S, S_q)
answer = baker.run_baker(self.X, R, S)
a = answer['abc'][0] a = answer['abc'][0]
b = answer['abc'][1] b = answer['abc'][1]
c = answer['abc'][2] c = answer['abc'][2]
self.assertEqual(sorted((a,b,c)), sorted((0,0.0,1/3.))) self.assertEqual(sorted((a, b, c)), sorted((0, 0.0, 1 / 3.)))
def testRunBaker_2(self): def testRunBaker_2(self):
size_of_simplex = 3 size_of_simplex = 3
extra_points = 4 extra_points = 4
R = grid.grid(self.all_points[:size_of_simplex], R, R_q = (self.all_points[:size_of_simplex], self.q[:size_of_simplex])
self.q[:size_of_simplex])
S = grid.grid(self.all_points[size_of_simplex:size_of_simplex + extra_points], S, S_q = (self.all_points[size_of_simplex:size_of_simplex \
+ extra_points],
self.q[size_of_simplex:size_of_simplex + extra_points]) self.q[size_of_simplex:size_of_simplex + extra_points])
answer = baker.run_baker(self.X, R, S) answer = baker.run_baker(self.X, R, R_q, S, S_q)
a, b, c = sorted(answer['abc']) a, b, c = sorted(answer['abc'])
aa,bb,cc = sorted((2/3.0, 2/3.0, 1/3.0)) aa, bb, cc = sorted((2 / 3.0, 2 / 3.0, 1 / 3.0))
self.assertAlmostEqual(a,aa) self.assertAlmostEqual(a, aa)
self.assertAlmostEqual(b,bb) self.assertAlmostEqual(b, bb)
self.assertAlmostEqual(c,cc) self.assertAlmostEqual(c, cc)
def testRunBaker_3(self): def testRunBaker_3(self):
size_of_simplex = 3 size_of_simplex = 3
extra_points = 5 extra_points = 5
R = grid.grid(self.all_points[:size_of_simplex], R, R_q = (self.all_points[:size_of_simplex], self.q[:size_of_simplex])
self.q[:size_of_simplex])
S = grid.grid(self.all_points[size_of_simplex:size_of_simplex + extra_points], S, S_q = (self.all_points[size_of_simplex:size_of_simplex \
+ extra_points],
self.q[size_of_simplex:size_of_simplex + extra_points]) self.q[size_of_simplex:size_of_simplex + extra_points])
answer = baker.run_baker(self.X, R, R_q, S, S_q)
answer = baker.run_baker(self.X, R, S)
a = answer['abc'][0] a = answer['abc'][0]
b = answer['abc'][1] b = answer['abc'][1]
c = answer['abc'][2] c = answer['abc'][2]
a,b,c = sorted((a,b,c)) a, b, c = sorted((a, b, c))
aa, bb, cc = sorted((13/14., 2/7., 15/14.)) aa, bb, cc = sorted((13 / 14., 2 / 7., 15 / 14.))
self.assertAlmostEqual(a,aa) self.assertAlmostEqual(a, aa)
self.assertAlmostEqual(b,bb) self.assertAlmostEqual(b, bb)
self.assertAlmostEqual(c,cc) self.assertAlmostEqual(c, cc)
def testRunBaker_4(self): def testRunBaker_4(self):
size_of_simplex = 3 size_of_simplex = 3
extra_points = 6 extra_points = 6
R = grid.grid(self.all_points[:size_of_simplex], R, R_q = (self.all_points[:size_of_simplex],
self.q[:size_of_simplex]) self.q[:size_of_simplex])
S, S_q = (self.all_points[size_of_simplex:size_of_simplex \
S = grid.grid(self.all_points[size_of_simplex:size_of_simplex + extra_points], + extra_points],
self.q[size_of_simplex:size_of_simplex + extra_points]) self.q[size_of_simplex:size_of_simplex + extra_points])
answer = baker.run_baker(self.X, R, R_q, S, S_q)
answer = baker.run_baker(self.X, R, S)
a = answer['abc'][0] a = answer['abc'][0]
b = answer['abc'][1] b = answer['abc'][1]
c = answer['abc'][2] c = answer['abc'][2]
a,b,c = sorted((a,b,c)) a, b, c = sorted((a, b, c))
aa,bb,cc = sorted((48/53.0, 15/53.0, 54/53.0)) aa, bb, cc = sorted((48 / 53.0, 15 / 53.0, 54 / 53.0))
self.assertAlmostEqual(a, aa) self.assertAlmostEqual(a, aa)
self.assertAlmostEqual(b, bb) self.assertAlmostEqual(b, bb)

View File

@ -8,23 +8,26 @@ import numpy as np
from interp.grid import contains from interp.grid import contains
def exact_func(point): def exact_func(point):
x = point[0] x = point[0]
y = point[1] y = point[1]
return 0.5 + x*x + y return 0.5 + x * x + y
def calculate_error_term(self, a,b,c,d,e,f):
def calculate_error_term(self, a, b, c, d, e, f):
B = np.array([ B = np.array([
self.p1[a] * self.p1[b], self.p1[c] * self.p1[d], self.p1[e] * self.p1[f], self.p1[a] * self.p1[b], self.p1[c] * self.p1[d], self.p1[e] * self.p1[f],
self.p2[a] * self.p2[b], self.p2[c] * self.p2[d], self.p2[e] * self.p2[f], self.p2[a] * self.p2[b], self.p2[c] * self.p2[d], self.p2[e] * self.p2[f],
self.p3[a] * self.p3[b], self.p3[c] * self.p3[d], self.p3[e] * self.p3[f], self.p3[a] * self.p3[b], self.p3[c] * self.p3[d], self.p3[e] * self.p3[f],
self.p4[a] * self.p4[b], self.p4[c] * self.p4[d], self.p4[e] * self.p4[f], self.p4[a] * self.p4[b], self.p4[c] * self.p4[d], self.p4[e] * self.p4[f],
]) ])
B.shape = (4,3)
B.shape = (4, 3)
A = np.dot(B.T, B) A = np.dot(B.T, B)
rhs = np.dot(B.T, self.w) rhs = np.dot(B.T, self.w)
abc = np.linalg.solve(A,rhs) abc = np.linalg.solve(A, rhs)
err = \ err = \
abc[0] * self.phis[a] * self.phis[b] + \ abc[0] * self.phis[a] * self.phis[b] + \
@ -32,36 +35,35 @@ def calculate_error_term(self, a,b,c,d,e,f):
abc[2] * self.phis[e] * self.phis[f] abc[2] * self.phis[e] * self.phis[f]
return err return err
class Test(unittest.TestCase): class Test(unittest.TestCase):
def setUp(self): def setUp(self):
self.verts = [ self.verts = [
[ 2, 3], # 0 [2, 3], # 0
[ 7, 4], # 1 [7, 4], # 1
[ 4, 8], # 2 [4, 8], # 2
[ 0, 7], # 3, 1 [0, 7], # 3, 1
[ 5, 0], # 4, 2 [5, 0], # 4, 2
[10, 5], # 5, 3 [0, 5], # 5, 3
[ 8, 9], # 6, 4 [8, 9], # 6, 4
] ]
self.q = [exact_func(v) for v in self.verts] self.q = [exact_func(v) for v in self.verts]
self.g = grid(self.verts, self.q) self.g = grid(self.verts, self.q)
self.R = grid(self.verts[:3], self.q[:3]) self.R, self.R_q = (self.verts[:3], self.q[:3])
self.S = grid(self.verts[3:], self.q[3:]) self.S, self.S_q = (self.verts[3:], self.q[3:])
self.p1, self.ql1 = baker.qlinear(self.verts[3], self.R) self.p1, self.ql1 = baker.qlinear(self.verts[3], self.R, self.q)
self.p2, self.ql2 = baker.qlinear(self.verts[4], self.R) self.p2, self.ql2 = baker.qlinear(self.verts[4], self.R, self.q)
self.p3, self.ql3 = baker.qlinear(self.verts[5], self.R) self.p3, self.ql3 = baker.qlinear(self.verts[5], self.R, self.q)
self.p4, self.ql4 = baker.qlinear(self.verts[6], self.R) self.p4, self.ql4 = baker.qlinear(self.verts[6], self.R, self.q)
self.q1 = exact_func(self.verts[3]) self.q1 = exact_func(self.verts[3])
self.q2 = exact_func(self.verts[4]) self.q2 = exact_func(self.verts[4])
self.q3 = exact_func(self.verts[5]) self.q3 = exact_func(self.verts[5])
self.q4 = exact_func(self.verts[6]) self.q4 = exact_func(self.verts[6])
self.w = np.array([ self.w = np.array([
self.q1 - self.ql1, self.q1 - self.ql1,
self.q2 - self.ql2, self.q2 - self.ql2,
@ -69,33 +71,36 @@ class Test(unittest.TestCase):
self.q4 - self.ql4, self.q4 - self.ql4,
]) ])
self.X = [4,5] self.X = [4, 5]
self.g = grid(self.verts, self.q) self.g = grid(self.verts, self.q)
self.phis, self.qlin = baker.qlinear(self.X, self.R) self.phis, self.qlin = baker.qlinear(self.X, self.R, self.q)
self.exact = exact_func(self.X) self.exact = exact_func(self.X)
self.answer = baker.run_baker(self.X,self.R,self.S) self.answer = baker.run_baker(self.X, self.R,
self.R_q, self.S, self.S_q)
def test_R_contains_X(self): def test_R_contains_X(self):
self.assertTrue(contains(self.X, self.R.verts)) self.assertTrue(contains(self.X, self.R))
def test_1(self): def test_1(self):
a,b,c,d,e,f = (0,1, 1,2, 2,0) a, b, c, d, e, f = (0, 1, 1, 2, 2, 0)
err = calculate_error_term(self, a,b,c,d,e,f) err = calculate_error_term(self, a, b, c, d, e, f)
self.assertAlmostEqual(err, self.answer['error']) self.assertAlmostEqual(err, self.answer['error'])
def test_swap_first_elements(self): def test_swap_first_elements(self):
a,b,c,d,e,f = (1,0, 1,2, 2,0) a, b, c, d, e, f = (1, 0, 1, 2, 2, 0)
err = calculate_error_term(self, a,b,c,d,e,f) err = calculate_error_term(self, a, b, c, d, e, f)
self.assertAlmostEqual(err, self.answer['error']) self.assertAlmostEqual(err, self.answer['error'])
def test_swap_two_pairs(self): def test_swap_two_pairs(self):
a,b,c,d,e,f = (1,2, 0,1, 2,0) a, b, c, d, e, f = (1, 2, 0, 1, 2, 0)
err = calculate_error_term(self, a,b,c,d,e,f) err = calculate_error_term(self, a, b, c, d, e, f)
self.assertAlmostEqual(err, self.answer['error']) self.assertAlmostEqual(err, self.answer['error'])
def test_swap_all_pairs(self): def test_swap_all_pairs(self):
a,b,c,d,e,f = (0,2, 0,1, 2,1) a, b, c, d, e, f = (0, 2, 0, 1, 2, 1)
err = calculate_error_term(self, a,b,c,d,e,f) err = calculate_error_term(self, a, b, c, d, e, f)
self.assertAlmostEqual(err, self.answer['error']) self.assertAlmostEqual(err, self.answer['error'])

View File

@ -2,10 +2,7 @@
import unittest import unittest
from interp.baker import get_phis, qlinear from interp.baker import get_phis, qlinear
from interp.grid import grid
import numpy as np
import scipy.spatial
class Test(unittest.TestCase): class Test(unittest.TestCase):
def setUp(self): def setUp(self):
@ -18,17 +15,15 @@ class Test(unittest.TestCase):
] ]
self.q = [0.0, 0.0, 0.0, 4] self.q = [0.0, 0.0, 0.0, 4]
def testGetPhis(self): def testGetPhis(self):
result = get_phis(self.X, self.r) result = get_phis(self.X, self.r)
right_answer = [0.25, 0.25, 0.25, 0.25] right_answer = [0.25, 0.25, 0.25, 0.25]
for a,b in zip(result, right_answer): for a, b in zip(result, right_answer):
self.assertAlmostEqual(a,b) self.assertAlmostEqual(a, b)
def testQlinear(self): def testQlinear(self):
phi, result = qlinear(self.X, grid(self.r, self.q)) phi, result = qlinear(self.X, self.r, self.q)
result = result result = result
right_answer = 1.0 right_answer = 1.0
self.assertAlmostEqual(result, right_answer) self.assertAlmostEqual(result, right_answer)

View File

@ -4,67 +4,72 @@ import unittest
from interp.baker import run_baker from interp.baker import run_baker
from interp.grid import grid
from interp.grid import contains from interp.grid import contains
def exact_func(X): def exact_func(X):
x = X[0] x = X[0]
y = X[0] y = X[0]
return 1 + x + y return 1 + x + y
class Test(unittest.TestCase): class Test(unittest.TestCase):
def setUp(self): def setUp(self):
self.verts = [ self.g = [[0.25, 0.40], # 0
[ 0.25, 0.40], # 0 [0.60, 0.80], # 1
[ 0.60, 0.80], # 1 [0.65, 0.28], # 2
[ 0.65, 0.28], # 2 [0.28, 0.65], # 3
[ 0.28, 0.65], # 3 [1.00, 0.75], # 4
[ 1.00, 0.75], # 4 [0.30, 0.95], # 5
[ 0.30, 0.95], # 5 [0.80, 0.50], # 6
[ 0.80, 0.50], # 6 [0.35, 0.15], # 7
[ 0.35, 0.15], # 7
] ]
self.q = [exact_func(p) for p in self.verts] self.q = [exact_func(p) for p in self.g]
self.X = [0.55, 0.45] self.X = [0.55, 0.45]
self.R = self.g[0:3]
self.g = grid(self.verts, self.q) self.R_q = self.q[0:3]
# self.g.construct_connectivity()
self.R = self.g.create_mesh(range(3))
self.exact = exact_func(self.X) self.exact = exact_func(self.X)
def test_R_contains_X(self): def test_R_contains_X(self):
self.assertTrue(contains(self.X, self.R.verts)) self.assertTrue(contains(self.X, self.R))
def test_RunBaker_1_extra_point(self, extra=1): def test_RunBaker_1_extra_point(self, extra=1):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3:3 + extra]
answer = run_baker(self.X, self.R, S, order=3) S_q = self.q[3:3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
# expected failure ...
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)
def test_RunBaker_2_extra_point(self, extra=2): def test_RunBaker_2_extra_point(self, extra=2):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S, order=3) S_q = self.q[3:3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)
def test_RunBaker_3_extra_point(self, extra=3): def test_RunBaker_3_extra_point(self, extra=3):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S, order=3) S_q = self.q[3:3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)
def test_RunBaker_4_extra_point(self, extra=4): def test_RunBaker_4_extra_point(self, extra=4):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S, order=3) S_q = self.q[3:3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)
def test_RunBaker_5_extra_point(self, extra=5): def test_RunBaker_5_extra_point(self, extra=5):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S, order=3) S_q = self.q[3:3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)

View File

@ -4,48 +4,46 @@ import unittest
from interp.baker import pattern from interp.baker import pattern
class Test(unittest.TestCase): class Test(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
def testImports(self): def testImports(self):
from interp.baker import pattern from interp.baker import pattern as ppp
ppp
def test_baker_eq_8(self): def test_baker_eq_8(self):
b = sorted([tuple(sorted(i)) for i in ((0,1),(1,2),(2,0))]) b = sorted([tuple(sorted(i)) for i in ((0, 1), (1, 2), (2, 0))])
p = sorted(pattern(3,2)) p = sorted(pattern(3, 2))
self.assertEqual(b,p) self.assertEqual(b, p)
def test_baker_eq_17(self): def test_baker_eq_17(self):
b = sorted([tuple(sorted(i)) for i in ((0,1,1), (0,2,2), (1,0,0), (1,2,2), (2,0,0), (2,1,1), (0,1,2))]) b = sorted([tuple(sorted(i)) for i in ((0, 1, 1), (0, 2, 2), (1, 0, 0),
p = sorted(pattern(3,3)) (1, 2, 2), (2, 0, 0), (2, 1, 1), (0, 1, 2))])
self.assertEqual(b,p) p = sorted(pattern(3, 3))
self.assertEqual(b, p)
def test_baker_eq_15(self): def test_baker_eq_15(self):
b = sorted([tuple(sorted(i)) for i in ( b = sorted([tuple(sorted(i)) for i in (
(0,1), (0,2), (0,3), (0, 1), (0, 2), (0, 3),
(1,2), (1,3), (2,3))]) (1, 2), (1, 3), (2, 3))])
p = sorted(pattern(4,2)) p = sorted(pattern(4, 2))
self.assertEqual(b,p) self.assertEqual(b, p)
def test_smcquay_(self): def test_smcquay_(self):
b = sorted([tuple(sorted(i)) for i in ( b = sorted([tuple(sorted(i)) for i in (
(0,1,2), (1,2,3), (0,1,3), (0,2,3), (0, 1, 2), (1, 2, 3), (0, 1, 3), (0, 2, 3),
(0,0,1), (0,1,1), (0, 0, 1), (0, 1, 1),
(1,2,2), (1,1,2), (1, 2, 2), (1, 1, 2),
(0,2,2), (0,0,2), (0, 2, 2), (0, 0, 2),
(1,3,3), (1,1,3), (1, 3, 3), (1, 1, 3),
(2,2,3), (2,3,3), (2, 2, 3), (2, 3, 3),
(0,3,3), (0,0,3))]) (0, 3, 3), (0, 0, 3))])
p = sorted(pattern(4,3))
self.assertEqual(b,p)
p = sorted(pattern(4, 3))
self.assertEqual(b, p)
if __name__ == '__main__': if __name__ == '__main__':
suite = unittest.TestLoader().loadTestsFromTestCase(Test) suite = unittest.TestLoader().loadTestsFromTestCase(Test)

View File

@ -7,69 +7,75 @@ from interp.baker import run_baker
from interp.grid import grid from interp.grid import grid
from interp.grid import contains from interp.grid import contains
def exact_func(X): def exact_func(X):
x = X[0] x = X[0]
y = X[0] y = X[0]
return 1 - x*x + y*y return 1 - x * x + y * y
class Test(unittest.TestCase): class Test(unittest.TestCase):
def setUp(self): def setUp(self):
self.points = [ self.g = [
[ 0.25, 0.40], # 0 [0.25, 0.40], # 0
[ 0.60, 0.80], # 1 [0.60, 0.80], # 1
[ 0.65, 0.28], # 2 [0.65, 0.28], # 2
[ 0.28, 0.65], # 3 [0.28, 0.65], # 3
[ 1.00, 0.75], # 4 [1.00, 0.75], # 4
[ 0.30, 0.95], # 5 [0.30, 0.95], # 5
[ 0.80, 0.50], # 6 [0.80, 0.50], # 6
[ 0.35, 0.15], # 7 [0.35, 0.15], # 7
] ]
self.q = [exact_func(p) for p in self.points] self.q = [exact_func(p) for p in self.g]
self.X = [0.25, 0.4001] self.X = [0.25, 0.4001]
self.X = [0.55, 0.45] self.X = [0.55, 0.45]
self.g = grid(self.points, self.q) self.R = self.g[0:3]
self.R = self.g.create_mesh(range(3)) self.R_q = self.q[0:3]
self.exact = exact_func(self.X) self.exact = exact_func(self.X)
self.accuracy = 8
def test_R_contains_X(self): def test_R_contains_X(self):
self.assertTrue(contains(self.X, self.R.verts)) self.assertTrue(contains(self.X, self.R))
def test_RunBaker_1_extra_point(self, extra=1): def test_RunBaker_1_extra_point(self, extra=1):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S) S_q = self.q[3: 3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
# I expect this one to be bad: #XXX: not sure about this one:
# self.assertTrue(lin_err >= final_err) self.assertEqual(lin_err, final_err)
def test_RunBaker_2_extra_point(self, extra=2): def test_RunBaker_2_extra_point(self, extra=2):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S) S_q = self.q[3: 3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)
def test_RunBaker_3_extra_point(self, extra=3): def test_RunBaker_3_extra_point(self, extra=3):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S) S_q = self.q[3: 3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)
def test_RunBaker_4_extra_point(self, extra=4): def test_RunBaker_4_extra_point(self, extra=4):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S) S_q = self.q[3: 3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)
def test_RunBaker_5_extra_point(self, extra=5): def test_RunBaker_5_extra_point(self, extra=5):
S = self.g.create_mesh(range(3, 3 + extra)) S = self.g[3: 3 + extra]
answer = run_baker(self.X, self.R, S) S_q = self.q[3: 3 + extra]
answer = run_baker(self.X, self.R, self.R_q, S, S_q)
lin_err = abs(self.exact - answer['qlin']) lin_err = abs(self.exact - answer['qlin'])
final_err = abs(self.exact - answer['final']) final_err = abs(self.exact - answer['final'])
self.assertTrue(lin_err >= final_err) self.assertTrue(lin_err >= final_err)