fixed the pattern function and the memoization thereof

This commit is contained in:
Stephen Mardson McQuay 2011-05-04 21:09:15 -06:00
parent 08f898d83a
commit 2e7391f573
2 changed files with 78 additions and 43 deletions

View File

@ -2,6 +2,9 @@ import sys
import numpy as np import numpy as np
from functools import wraps
import itertools
import logging import logging
log = logging.getLogger('interp') log = logging.getLogger('interp')
@ -99,7 +102,7 @@ def get_error(phi, R, S, order = 2):
B = [] # baker eq 9 B = [] # baker eq 9
w = [] # baker eq 11 w = [] # baker eq 11
p = pattern(order, len(phi), offset = -1) p = pattern(len(phi), order)
log.info("pattern: %s" % p) log.info("pattern: %s" % p)
for (s,q) in zip(S.verts, S.q): for (s,q) in zip(S.verts, S.q):
@ -183,42 +186,74 @@ def run_baker(X, R, S, order=2):
return answer return answer
def _boxings(n, k):
"""\
source for this function:
http://old.nabble.com/Simple-combinatorics-with-Numpy-td20086915.html
http://old.nabble.com/Re:-Simple-combinatorics-with-Numpy-p20099736.html
"""
seq, i = [n] * k + [0], k
while i:
yield tuple(seq[i] - seq[i+1] for i in xrange(k))
i = seq.index(0) - 1
seq[i:k] = [seq[i] - 1] * (k-i)
def _samples_ur(items, k, offset = 0):
"""Returns k unordered samples (with replacement) from items."""
n = len(items)
for sample in _boxings(k, n):
selections = [[items[i]]*count for i,count in enumerate(sample)]
yield tuple([x + offset for sel in selections for x in sel])
def memoize(f): def memoize(f):
""" """
I only cache on power and phicount; I figure that one should stick to a for more information on what I'm doing here,
particular offset throughout one's codebase. please read:
http://en.wikipedia.org/wiki/Memoize
""" """
cache = {} cache = {}
@wraps(f)
def memf(*x, **kargs): def memf(*x, **kargs):
if x not in cache: if x not in cache:
cache[x] = f(*x, **kargs) cache[x] = f(*x, **kargs)
return cache[x] return cache[x]
return memf return memf
def combinations_with_replacement(iterable, r):
"""
What I really need for the pattern function only
exists in python 2.7 and greater. The docs suggest
the implementation in this function as a
replacement.
see: http://docs.python.org/library/itertools.html#itertools.combinations_with_replacement
"""
pool = tuple(iterable)
n = len(pool)
for indices in itertools.product(range(n), repeat=r):
if sorted(indices) == list(indices):
yield tuple(pool[i] for i in indices)
@memoize @memoize
def pattern(power, phicount, offset = 0): def pattern(simplex_size, nu):
log.debug("(power = %s, phicount = %s)" % (power, phicount)) """
r = [] my useful docstring
for i in _samples_ur(range(1, phicount + 1), power, offset): """
if not len(set(i)) == 1: log.debug("pattern: simplex: %d, order: %d" % (simplex_size, nu))
r.append(i) return [i for i in combinations_with_replacement(xrange(simplex_size), nu) if len(set(i)) != 1]
return r
if __name__ == '__main__':
print len(pattern(3, 2)), pattern(3, 2)
print len(pattern(4, 2)), pattern(4, 2)
print len(pattern(3, 3)), pattern(3, 3)
print len(pattern(4, 3)), pattern(4, 3)
print len(pattern(3, 4)), pattern(3, 4)
print len(pattern(4, 4)), pattern(4, 4)
print len(pattern(3, 2)), pattern(3, 2)
print len(pattern(4, 2)), pattern(4, 2)
print len(pattern(3, 3)), pattern(3, 3)
print len(pattern(4, 3)), pattern(4, 3)
print len(pattern(3, 4)), pattern(3, 4)
print len(pattern(4, 4)), pattern(4, 4)
print len(pattern(3, 2)), pattern(3, 2)
print len(pattern(4, 2)), pattern(4, 2)
print len(pattern(3, 3)), pattern(3, 3)
print len(pattern(4, 3)), pattern(4, 3)
print len(pattern(3, 4)), pattern(3, 4)
print len(pattern(4, 4)), pattern(4, 4)
print len(pattern(3, 2)), pattern(3, 2)
print len(pattern(4, 2)), pattern(4, 2)
print len(pattern(3, 3)), pattern(3, 3)
print len(pattern(4, 3)), pattern(4, 3)
print len(pattern(3, 4)), pattern(3, 4)
print len(pattern(4, 4)), pattern(4, 4)

View File

@ -13,35 +13,35 @@ class Test(unittest.TestCase):
from interp.baker import pattern from interp.baker import pattern
def test_baker_eq_8(self): def test_baker_eq_8(self):
b = sorted([tuple(sorted(i)) for i in ((1,2),(2,3),(3,1))]) b = sorted([tuple(sorted(i)) for i in ((0,1),(1,2),(2,0))])
p = sorted(pattern(power = 2, phicount = 3)) p = sorted(pattern(3,2))
self.assertEqual(b,p) self.assertEqual(b,p)
def test_baker_eq_17(self): def test_baker_eq_17(self):
b = sorted([tuple(sorted(i)) for i in ((1,2,2), (1,3,3), (2,1,1), (2,3,3), (3,1,1), (3,2,2), (1,2,3))]) b = sorted([tuple(sorted(i)) for i in ((0,1,1), (0,2,2), (1,0,0), (1,2,2), (2,0,0), (2,1,1), (0,1,2))])
p = sorted(pattern(power = 3, phicount = 3)) p = sorted(pattern(3,3))
self.assertEqual(b,p) self.assertEqual(b,p)
def test_baker_eq_15(self): def test_baker_eq_15(self):
b = sorted([tuple(sorted(i)) for i in ( b = sorted([tuple(sorted(i)) for i in (
(1,2), (1,3), (1,4), (0,1), (0,2), (0,3),
(2,3), (2,4), (3,4))]) (1,2), (1,3), (2,3))])
p = sorted(pattern(power = 2, phicount = 4)) p = sorted(pattern(4,2))
self.assertEqual(b,p) self.assertEqual(b,p)
def test_smcquay_(self): def test_smcquay_(self):
b = sorted([tuple(sorted(i)) for i in ( b = sorted([tuple(sorted(i)) for i in (
(1,2,3), (2,3,4), (1,2,4), (1,3,4), (0,1,2), (1,2,3), (0,1,3), (0,2,3),
(1,1,2), (1,2,2), (0,0,1), (0,1,1),
(2,3,3), (2,2,3), (1,2,2), (1,1,2),
(0,2,2), (0,0,2),
(1,3,3), (1,1,3), (1,3,3), (1,1,3),
(2,4,4), (2,2,4), (2,2,3), (2,3,3),
(3,3,4), (3,4,4), (0,3,3), (0,0,3))])
(1,4,4), (1,1,4))])
p = sorted(pattern(power = 3, phicount = 4)) p = sorted(pattern(4,3))
self.assertEqual(b,p) self.assertEqual(b,p)