from collections import namedtuple
from functools import wraps
import itertools

import numpy as np

import interp

# When the least-squares system for the error term is singular, fall back to
# the pseudo-inverse (True) or abandon the error term (False).
AGGRESSIVE_ERROR_SOLVE = True
# Only consulted when AGGRESSIVE_ERROR_SOLVE is False and the system was
# singular: raise numpy.linalg.LinAlgError (True) instead of returning a
# linear-only answer (False).
RAISE_PATHOLOGICAL_EXCEPTION = False

__version__ = interp.__version__

# Result of interpolate():
#   qlin  -- the linearly interpolated value
#   final -- qlin + error, or None when no error term was computed
#   error -- the fitted error correction, or None when not computed
#   abc   -- error-term coefficients (in pattern() order); {} when no extra
#            points were given, None when the solve was abandoned
Answer = namedtuple("Answer", ['qlin', 'final', 'error', 'abc'])


def get_phis(X, R):
    r"""
    The get_phis function is used to get barycentric coordinates for a point
    on a triangle or tetrahedron (Equation (*\ref{eq:qlinarea}*))

    in 2D:
    X - the destination point (2D)
        X = [0,0]
    R - the three points that make up the 2-D triangular simplex
        R = [[-1, -1], [0, 2], [1, -1]]

    this will return [0.333, 0.333, 0.333]

    in 3D:
    X - the destination point (3D)
        X = [0,0,0]
    R - the four points that make up the 3-D simplex (tetrahedron)
        R = [
            [ 0.0000,  0.0000,  1.0000],
            [ 0.9428,  0.0000, -0.3333],
            [-0.4714,  0.8165, -0.3333],
            [-0.4714, -0.8165, -0.3333],
        ]

    this will return [0.25, 0.25, 0.25, 0.25]

    Only the first len(X) + 1 vertices of R are used.

    Raises Exception when X is neither 2- nor 3-dimensional.  Errors from
    np.linalg.solve (e.g. numpy.linalg.LinAlgError for a singular system,
    i.e. a degenerate simplex) propagate to the caller.
    """
    # equations (*\ref{eq:lin3d}*) and (*\ref{eq:lin2d}*)
    dimension = len(X)
    if dimension not in (2, 3):
        raise Exception("inapropriate demension on X")

    # One column per vertex: a leading 1 (enforces sum(phi) == 1) followed
    # by that vertex's coordinates; b is [1, X...] to match.
    vertices = [R[k] for k in range(dimension + 1)]
    A = np.array(
        [[1] * (dimension + 1)]
        + [[vertex[axis] for vertex in vertices] for axis in range(dimension)]
    )
    b = np.array([1] + [X[axis] for axis in range(dimension)])

    phi = np.linalg.solve(A, b)
    return phi


def qlinear(X, R, q):
    r"""
    Linearly interpolate q from the simplex vertices R to the point X.

    This is equation (*\ref{eq:qlinbasis}*)

    X -- destination point (2 or 3 coordinates)
    R -- the simplex vertices, in the form accepted by get_phis
    q -- the values of q at the vertices of R, in the same order as R

    Returns (phis, qlin): the barycentric coordinates of X with respect to R
    and the interpolated value sum(q_i * phi_i).
    """
    phis = get_phis(X, R)
    qlin = np.sum([q_i * phi_i for q_i, phi_i in zip(q, phis)])
    return phis, qlin


def get_error(phi, R, R_q, S, S_q, order=2):
    r"""
    Calculate the error approximation terms, solving for the unknowns a, b,
    and c in equation (*\ref{eq:quadratic2d}*).

    phi   -- barycentric coordinates of the destination point w.r.t. R
    R     -- the simplex vertices
    R_q   -- q values at the vertices of R
    S     -- extra sample points used to fit the error term
    S_q   -- q values at the points in S
    order -- number of barycentric factors in each error-term monomial

    For every sample point, the residual of the linear interpolant is fitted
    by a combination of the monomials from pattern(len(phi), order), using
    the normal equations (B^T B) abc = B^T w.

    Returns (error_term, abc): the fitted correction evaluated at phi, and
    the coefficients (one per monomial, in pattern() order).  If the normal
    equations are singular, the pseudo-inverse is used when
    AGGRESSIVE_ERROR_SOLVE is True; otherwise (None, None) is returned.
    """
    cur_pattern = pattern(len(phi), order)

    B = []  # equation (*\ref{eq:B2d}*): one row of monomials per sample
    w = []  # equation (*\ref{eq:w}*): linear-interpolation residual per sample
    for s, cur_q in zip(S, S_q):
        cur_phi, cur_qlin = qlinear(s, R, R_q)
        row = []
        for indices in cur_pattern:
            term = cur_phi[indices[0]]
            for j in indices[1:]:
                term *= cur_phi[j]
            row.append(term)
        B.append(row)
        w.append(cur_q - cur_qlin)

    B = np.array(B)
    w = np.array(w)

    A = np.dot(B.T, B)
    b = np.dot(B.T, w)

    try:
        abc = np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        if not AGGRESSIVE_ERROR_SOLVE:
            return None, None
        abc = np.dot(np.linalg.pinv(A), b)

    # Evaluate sum(coefficient * product of phi[j]) at the destination point.
    error_term = 0.0
    for coefficient, indices in zip(abc, cur_pattern):
        term = coefficient
        for j in indices:
            term *= phi[j]
        error_term += term

    return error_term, abc


def interpolate(X, R, R_q, S=None, S_q=None, order=2):
    """
    This is the main function to call to get an interpolation to X from the
    input meshes.

    X     -- the destination point (2 or 3 coordinates)
    R     -- the simplex vertices (see get_phis)
    R_q   -- q values at R
    S     -- extra points used to fit the error term; None or an empty
             collection means linear interpolation only
    S_q   -- q values at S
    order -- number of barycentric factors in each error-term monomial;
             must be in 2..10

    Returns an Answer(qlin, final, error, abc).  final is None unless an
    error term was computed.

    Raises Exception for an unsupported order, and
    numpy.linalg.LinAlgError("Pathological Vertex Config") when the error
    system is singular, AGGRESSIVE_ERROR_SOLVE is False and
    RAISE_PATHOLOGICAL_EXCEPTION is True.
    """
    error_term = None
    final = None
    abc = {}

    # calculate values only for the simplex triangle
    phi, qlin = qlinear(X, R, R_q)

    # range (not xrange) so the membership test works on Python 2 and 3.
    if order not in range(2, 11):
        raise Exception('unsupported order "%d" for baker method' % order)

    # With no extra points the least-squares system is empty and both
    # np.linalg.solve and np.linalg.pinv reject it, so an empty S is treated
    # like S=None.  list() keeps iterator inputs working.
    if S is not None:
        S = list(S)
    if S:
        error_term, abc = get_error(phi, R, R_q, S, S_q, order)

        # if a pathological vertex configuration was encountered and
        # AGGRESSIVE_ERROR_SOLVE is False, get_error will return (None, None)
        # indicating that only linear interpolation should be performed
        if error_term is None and abc is None:
            if RAISE_PATHOLOGICAL_EXCEPTION:
                raise np.linalg.LinAlgError("Pathological Vertex Config")
        else:
            final = qlin + error_term

    return Answer(qlin=qlin, error=error_term, final=final, abc=abc)


def memoize(f):
    """
    Cache the results of a two-argument function f(simplex_size, nu).

    Results are kept in a dict keyed by the (simplex_size, nu) tuple, so the
    arguments must be hashable.  The cached object itself is returned on
    every hit, so callers must not mutate it.

    For background see http://en.wikipedia.org/wiki/Memoize
    """
    cache = {}

    @wraps(f)
    def memf(simplex_size, nu):
        x = (simplex_size, nu)
        if x not in cache:
            cache[x] = f(simplex_size, nu)
        return cache[x]
    return memf
@memoize
def pattern(simplex_size, nu):
    """
    This function returns the pattern requisite to compose the error
    approximation function, and the matrix B.

    Each entry is a sorted tuple of ``nu`` vertex indices taken from
    range(simplex_size).  It names the monomial phi[i0] * phi[i1] * ...
    Permutations of the same indices are merged, and tuples made of a single
    repeated index (e.g. (0, 0)) are excluded.  For simplex_size=3 and nu=2
    the entries are (0, 1), (0, 2) and (1, 2), in some order.

    The list order is a set's iteration order (not sorted).  get_error uses
    the same list both to build B and to evaluate the error term, and it
    reports its abc coefficients in this order.
    """
    # range (not xrange) so this runs on Python 2 and 3.  The set is fed in
    # itertools.product order.  Changing the generation order (e.g. to
    # combinations_with_replacement) could reorder the result and therefore
    # the abc coefficients.
    monomials = (
        tuple(sorted(combo))
        for combo in itertools.product(range(simplex_size), repeat=nu)
        if len(set(combo)) != 1
    )
    return list(set(monomials))