Browse Source

replaced answer with a namedtuple

Stephen M. McQuay 8 years ago
parent
commit
0577356cd7
8 changed files with 94 additions and 82 deletions
  1. 19
    24
      interp/baker/__init__.py
  2. 3
    3
      interp/grid/__init__.py
  3. 7
    0
      interp/tools.py
  4. 1
    1
      test/all.py
  5. 26
    14
      test/baker2d.py
  6. 5
    5
      test/baker2dorder.py
  7. 17
    18
      test/cubic2d.py
  8. 16
    17
      test/quadratic2d.py

+ 19
- 24
interp/baker/__init__.py View File

@@ -1,8 +1,9 @@
1
-import numpy as np
2
-
1
+from collections import namedtuple
3 2
 from functools import wraps
4 3
 import itertools
5 4
 
5
+import numpy as np
6
+
6 7
 import interp
7 8
 
8 9
 AGGRESSIVE_ERROR_SOLVE = True
@@ -10,6 +11,8 @@ RAISE_PATHOLOGICAL_EXCEPTION = False
10 11
 
11 12
 __version__ = interp.__version__
12 13
 
14
+Answer = namedtuple("Answer", ['qlin', 'final', 'error', 'abc'])
15
+
13 16
 
14 17
 def get_phis(X, R):
15 18
     """
@@ -124,7 +127,7 @@ def get_error(phi, R, R_q, S, S_q, order=2):
124 127
     return error_term, abc
125 128
 
126 129
 
127
-def run_baker(X, R, R_q, S, S_q, order=2):
130
+def interpolate(X, R, R_q, S=None, S_q=None, order=2):
128 131
     """
129 132
         This is the main function to call to get an interpolation to X from the
130 133
         input meshes
@@ -132,23 +135,22 @@ def run_baker(X, R, R_q, S, S_q, order=2):
132 135
         X -- the destination point
133 136
 
134 137
         R = Simplex
138
+        R_q = q values at R
135 139
         S = extra points
140
+        S_q = q values at S
141
+
142
+        order - order of interpolation - 1
136 143
     """
137 144
 
138
-    answer = {
139
-                         'qlin': None,
140
-                         'error': None,
141
-                         'final': None,
142
-                        }
145
+    qlin=None
146
+    error_term=None
147
+    final=None
148
+    abc={}
143 149
 
144 150
     # calculate values only for the simplex triangle
145 151
     phi, qlin = qlinear(X, R, R_q)
146 152
 
147
-    if order == 1:
148
-        answer['qlin'] = qlin
149
-        answer['final'] = qlin
150
-        return answer
151
-    elif order in xrange(2, 11):
153
+    if order in xrange(2, 11) and S:
152 154
         error_term, abc = get_error(phi, R, R_q, S, S_q, order)
153 155
 
154 156
         # if a pathological vertex configuration was encountered and
@@ -157,20 +159,13 @@ def run_baker(X, R, R_q, S, S_q, order=2):
157 159
         if (error_term is None) and (abc is None):
158 160
             if RAISE_PATHOLOGICAL_EXCEPTION:
159 161
                 raise np.linalg.LinAlgError("Pathological Vertex Config")
160
-            answer['qlin'] = qlin
161
-            answer['final'] = qlin
162
-            return answer
163
-    else:
162
+        else:
163
+            final = qlin + error_term
164
+    elif order not in xrange(2,11):
164 165
         raise Exception('unsupported order "%d" for baker method' % order)
165 166
 
166
-    q_final = qlin + error_term
167
-
168
-    answer['qlin'] = qlin
169
-    answer['error'] = error_term
170
-    answer['final'] = q_final
171
-    answer['abc'] = abc
172 167
 
173
-    return answer
168
+    return Answer(qlin=qlin, error=error_term, final=final, abc=abc)
174 169
 
175 170
 
176 171
 def memoize(f):

+ 3
- 3
interp/grid/__init__.py View File

@@ -6,7 +6,7 @@ from xml.dom.minidom import Document
6 6
 import numpy as np
7 7
 from scipy.spatial import KDTree
8 8
 
9
-from interp.baker import run_baker
9
+from interp.baker import interpolate
10 10
 from interp.baker import get_phis
11 11
 import interp
12 12
 
@@ -136,9 +136,9 @@ class grid(object):
136 136
 
137 137
         return (r_mesh, s_mesh)
138 138
 
139
-    def run_baker(self, X, order=2, extra_points=3):
139
+    def interpolate(self, X, order=2, extra_points=3):
140 140
         (R, S) = self.get_simplex_and_nearest_points(X, extra_points)
141
-        answer = run_baker(X, R, S, order)
141
+        answer = interpolate(X, R, S, order)
142 142
         return answer
143 143
 
144 144
     def for_qhull_generator(self):

+ 7
- 0
interp/tools.py View File

@@ -45,6 +45,10 @@ def baker_exact_3D(X):
45 45
                             np.sin(z * np.pi / 2.0)), 2)
46 46
     return answer
47 47
 
48
+def exact_me(X, f):
49
+    a = np.array([f(i) for i in X])
50
+    return a
51
+
48 52
 
49 53
 def friendly_exact_3D(X):
50 54
     x, y, z = X
@@ -67,6 +71,9 @@ def improved_answer(answer, exact):
67 71
     else:
68 72
         return False
69 73
 
74
+def identical_points(a,b):
75
+    return all(set(j[i] for j in a) \
76
+             == set(j[i] for j in b) for i in xrange(len(a[0])))
70 77
 
71 78
 def improved(qlin, err, final, exact):
72 79
     if np.abs(final - exact) <= np.abs(qlin - exact):

+ 1
- 1
test/all.py View File

@@ -21,4 +21,4 @@ if __name__ == '__main__':
21 21
     ]
22 22
 
23 23
     for test in tests:
24
-        unittest.TextTestRunner(verbosity=3).run(test)
24
+        unittest.TextTestRunner(verbosity=1).run(test)

+ 26
- 14
test/baker2d.py View File

@@ -3,6 +3,7 @@
3 3
 import unittest
4 4
 
5 5
 from interp import baker
6
+from interp.baker import Answer
6 7
 
7 8
 
8 9
 class Test(unittest.TestCase):
@@ -68,6 +69,17 @@ class Test(unittest.TestCase):
68 69
 
69 70
         self.assertAlmostEqual(result, right_answer)
70 71
 
72
+    def testRunBaker_linear(self):
73
+        size_of_simplex = 3
74
+
75
+        R, R_q = (self.all_points[:size_of_simplex],
76
+                                 self.q[:size_of_simplex])
77
+
78
+
79
+        answer = baker.interpolate(self.X, R, R_q)
80
+        good_answer = Answer(qlin=0.5, final=None, error=None, abc={})
81
+        self.assertEqual(answer, good_answer)
82
+
71 83
     def testRunBaker_1(self):
72 84
         size_of_simplex = 3
73 85
         extra_points = 3
@@ -79,11 +91,11 @@ class Test(unittest.TestCase):
79 91
                                                  + extra_points],
80 92
                         self.q[size_of_simplex:size_of_simplex + extra_points])
81 93
 
82
-        answer = baker.run_baker(self.X, R, R_q, S, S_q)
94
+        answer = baker.interpolate(self.X, R, R_q, S, S_q)
83 95
 
84
-        a = answer['abc'][0]
85
-        b = answer['abc'][1]
86
-        c = answer['abc'][2]
96
+        a = answer.abc[0]
97
+        b = answer.abc[1]
98
+        c = answer.abc[2]
87 99
 
88 100
         self.assertEqual(sorted((a, b, c)), sorted((0, 0.0, 1 / 3.)))
89 101
 
@@ -97,9 +109,9 @@ class Test(unittest.TestCase):
97 109
                                         + extra_points],
98 110
                         self.q[size_of_simplex:size_of_simplex + extra_points])
99 111
 
100
-        answer = baker.run_baker(self.X, R, R_q, S, S_q)
112
+        answer = baker.interpolate(self.X, R, R_q, S, S_q)
101 113
 
102
-        a, b, c = sorted(answer['abc'])
114
+        a, b, c = sorted(answer.abc)
103 115
         aa, bb, cc = sorted((2 / 3.0, 2 / 3.0, 1 / 3.0))
104 116
 
105 117
         self.assertAlmostEqual(a, aa)
@@ -115,11 +127,11 @@ class Test(unittest.TestCase):
115 127
         S, S_q = (self.all_points[size_of_simplex:size_of_simplex \
116 128
                                     + extra_points],
117 129
                         self.q[size_of_simplex:size_of_simplex + extra_points])
118
-        answer = baker.run_baker(self.X, R, R_q, S, S_q)
130
+        answer = baker.interpolate(self.X, R, R_q, S, S_q)
119 131
 
120
-        a = answer['abc'][0]
121
-        b = answer['abc'][1]
122
-        c = answer['abc'][2]
132
+        a = answer.abc[0]
133
+        b = answer.abc[1]
134
+        c = answer.abc[2]
123 135
 
124 136
         a, b, c = sorted((a, b, c))
125 137
         aa, bb, cc = sorted((13 / 14., 2 / 7., 15 / 14.))
@@ -137,11 +149,11 @@ class Test(unittest.TestCase):
137 149
         S, S_q = (self.all_points[size_of_simplex:size_of_simplex \
138 150
                                             + extra_points],
139 151
                         self.q[size_of_simplex:size_of_simplex + extra_points])
140
-        answer = baker.run_baker(self.X, R, R_q, S, S_q)
152
+        answer = baker.interpolate(self.X, R, R_q, S, S_q)
141 153
 
142
-        a = answer['abc'][0]
143
-        b = answer['abc'][1]
144
-        c = answer['abc'][2]
154
+        a = answer.abc[0]
155
+        b = answer.abc[1]
156
+        c = answer.abc[2]
145 157
 
146 158
         a, b, c = sorted((a, b, c))
147 159
         aa, bb, cc = sorted((48 / 53.0, 15 / 53.0, 54 / 53.0))

+ 5
- 5
test/baker2dorder.py View File

@@ -77,7 +77,7 @@ class Test(unittest.TestCase):
77 77
 
78 78
         self.phis, self.qlin = baker.qlinear(self.X, self.R, self.q)
79 79
         self.exact = exact_func(self.X)
80
-        self.answer = baker.run_baker(self.X, self.R,
80
+        self.answer = baker.interpolate(self.X, self.R,
81 81
                                        self.R_q, self.S, self.S_q)
82 82
 
83 83
     def test_R_contains_X(self):
@@ -86,22 +86,22 @@ class Test(unittest.TestCase):
86 86
     def test_1(self):
87 87
         a, b, c, d, e, f = (0, 1, 1, 2, 2, 0)
88 88
         err = calculate_error_term(self, a, b, c, d, e, f)
89
-        self.assertAlmostEqual(err, self.answer['error'])
89
+        self.assertAlmostEqual(err, self.answer.error)
90 90
 
91 91
     def test_swap_first_elements(self):
92 92
         a, b, c, d, e, f = (1, 0, 1, 2, 2, 0)
93 93
         err = calculate_error_term(self, a, b, c, d, e, f)
94
-        self.assertAlmostEqual(err, self.answer['error'])
94
+        self.assertAlmostEqual(err, self.answer.error)
95 95
 
96 96
     def test_swap_two_pairs(self):
97 97
         a, b, c, d, e, f = (1, 2, 0, 1, 2, 0)
98 98
         err = calculate_error_term(self, a, b, c, d, e, f)
99
-        self.assertAlmostEqual(err, self.answer['error'])
99
+        self.assertAlmostEqual(err, self.answer.error)
100 100
 
101 101
     def test_swap_all_pairs(self):
102 102
         a, b, c, d, e, f = (0, 2, 0, 1, 2, 1)
103 103
         err = calculate_error_term(self, a, b, c, d, e, f)
104
-        self.assertAlmostEqual(err, self.answer['error'])
104
+        self.assertAlmostEqual(err, self.answer.error)
105 105
 
106 106
 
107 107
 if __name__ == '__main__':

+ 17
- 18
test/cubic2d.py View File

@@ -2,9 +2,8 @@
2 2
 
3 3
 import unittest
4 4
 
5
-from interp.baker import run_baker
6
-
7
-from interp.grid    import contains
5
+from interp.baker import interpolate
6
+from interp.grid import contains
8 7
 
9 8
 
10 9
 def exact_func(X):
@@ -36,42 +35,42 @@ class Test(unittest.TestCase):
36 35
     def test_RunBaker_1_extra_point(self, extra=1):
37 36
         S = self.g[3:3 + extra]
38 37
         S_q = self.q[3:3 + extra]
39
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
40
-        lin_err = abs(self.exact - answer['qlin'])
41
-        final_err = abs(self.exact - answer['final'])
38
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q, order=3)
39
+        lin_err = abs(self.exact - answer.qlin)
40
+        final_err = abs(self.exact - answer.final)
42 41
         # expected failure ...
43 42
         self.assertTrue(lin_err >= final_err)
44 43
 
45 44
     def test_RunBaker_2_extra_point(self, extra=2):
46 45
         S = self.g[3: 3 + extra]
47 46
         S_q = self.q[3:3 + extra]
48
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
49
-        lin_err = abs(self.exact - answer['qlin'])
50
-        final_err = abs(self.exact - answer['final'])
47
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q, order=3)
48
+        lin_err = abs(self.exact - answer.qlin)
49
+        final_err = abs(self.exact - answer.final)
51 50
         self.assertTrue(lin_err >= final_err)
52 51
 
53 52
     def test_RunBaker_3_extra_point(self, extra=3):
54 53
         S = self.g[3: 3 + extra]
55 54
         S_q = self.q[3:3 + extra]
56
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
57
-        lin_err = abs(self.exact - answer['qlin'])
58
-        final_err = abs(self.exact - answer['final'])
55
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q, order=3)
56
+        lin_err = abs(self.exact - answer.qlin)
57
+        final_err = abs(self.exact - answer.final)
59 58
         self.assertTrue(lin_err >= final_err)
60 59
 
61 60
     def test_RunBaker_4_extra_point(self, extra=4):
62 61
         S = self.g[3: 3 + extra]
63 62
         S_q = self.q[3:3 + extra]
64
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
65
-        lin_err = abs(self.exact - answer['qlin'])
66
-        final_err = abs(self.exact - answer['final'])
63
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q, order=3)
64
+        lin_err = abs(self.exact - answer.qlin)
65
+        final_err = abs(self.exact - answer.final)
67 66
         self.assertTrue(lin_err >= final_err)
68 67
 
69 68
     def test_RunBaker_5_extra_point(self, extra=5):
70 69
         S = self.g[3: 3 + extra]
71 70
         S_q = self.q[3:3 + extra]
72
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q, order=3)
73
-        lin_err = abs(self.exact - answer['qlin'])
74
-        final_err = abs(self.exact - answer['final'])
71
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q, order=3)
72
+        lin_err = abs(self.exact - answer.qlin)
73
+        final_err = abs(self.exact - answer.final)
75 74
         self.assertTrue(lin_err >= final_err)
76 75
 
77 76
 if __name__ == '__main__':

+ 16
- 17
test/quadratic2d.py View File

@@ -2,8 +2,7 @@
2 2
 
3 3
 import unittest
4 4
 
5
-from interp.baker import run_baker
6
-
5
+from interp.baker import interpolate
7 6
 from interp.grid import grid
8 7
 from interp.grid import contains
9 8
 
@@ -41,9 +40,9 @@ class Test(unittest.TestCase):
41 40
     def test_RunBaker_1_extra_point(self, extra=1):
42 41
         S = self.g[3: 3 + extra]
43 42
         S_q = self.q[3: 3 + extra]
44
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q)
45
-        lin_err = abs(self.exact - answer['qlin'])
46
-        final_err = abs(self.exact - answer['final'])
43
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q)
44
+        lin_err = abs(self.exact - answer.qlin)
45
+        final_err = abs(self.exact - answer.final)
47 46
 
48 47
         #XXX: not sure about this one:
49 48
         self.assertEqual(lin_err,  final_err)
@@ -51,33 +50,33 @@ class Test(unittest.TestCase):
51 50
     def test_RunBaker_2_extra_point(self, extra=2):
52 51
         S = self.g[3: 3 + extra]
53 52
         S_q = self.q[3: 3 + extra]
54
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q)
55
-        lin_err = abs(self.exact - answer['qlin'])
56
-        final_err = abs(self.exact - answer['final'])
53
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q)
54
+        lin_err = abs(self.exact - answer.qlin)
55
+        final_err = abs(self.exact - answer.final)
57 56
         self.assertTrue(lin_err >= final_err)
58 57
 
59 58
     def test_RunBaker_3_extra_point(self, extra=3):
60 59
         S = self.g[3: 3 + extra]
61 60
         S_q = self.q[3: 3 + extra]
62
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q)
63
-        lin_err = abs(self.exact - answer['qlin'])
64
-        final_err = abs(self.exact - answer['final'])
61
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q)
62
+        lin_err = abs(self.exact - answer.qlin)
63
+        final_err = abs(self.exact - answer.final)
65 64
         self.assertTrue(lin_err >= final_err)
66 65
 
67 66
     def test_RunBaker_4_extra_point(self, extra=4):
68 67
         S = self.g[3: 3 + extra]
69 68
         S_q = self.q[3: 3 + extra]
70
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q)
71
-        lin_err = abs(self.exact - answer['qlin'])
72
-        final_err = abs(self.exact - answer['final'])
69
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q)
70
+        lin_err = abs(self.exact - answer.qlin)
71
+        final_err = abs(self.exact - answer.final)
73 72
         self.assertTrue(lin_err >= final_err)
74 73
 
75 74
     def test_RunBaker_5_extra_point(self, extra=5):
76 75
         S = self.g[3: 3 + extra]
77 76
         S_q = self.q[3: 3 + extra]
78
-        answer = run_baker(self.X, self.R, self.R_q, S, S_q)
79
-        lin_err = abs(self.exact - answer['qlin'])
80
-        final_err = abs(self.exact - answer['final'])
77
+        answer = interpolate(self.X, self.R, self.R_q, S, S_q)
78
+        lin_err = abs(self.exact - answer.qlin)
79
+        final_err = abs(self.exact - answer.final)
81 80
         self.assertTrue(lin_err >= final_err)
82 81
 
83 82
 if __name__ == '__main__':

Loading…
Cancel
Save