diff --git a/cache.go b/cache.go index ecddbbb..7fec2b6 100644 --- a/cache.go +++ b/cache.go @@ -41,6 +41,7 @@ type cache struct { defaultExpiration time.Duration items map[string]*Item mu sync.RWMutex + onEvicted func(string, interface{}) janitor *janitor } @@ -810,23 +811,54 @@ func (c *cache) DecrementFloat64(k string, n float64) (float64, error) { // Delete an item from the cache. Does nothing if the key is not in the cache. func (c *cache) Delete(k string) { c.mu.Lock() - c.delete(k) + v, evicted := c.delete(k) c.mu.Unlock() + if evicted { + c.onEvicted(k, v) + } } -func (c *cache) delete(k string) { +func (c *cache) delete(k string) (interface{}, bool) { + if c.onEvicted != nil { + if v, found := c.items[k]; found { + delete(c.items, k) + return v.Object, true + } + } delete(c.items, k) + return nil, false +} + +type keyAndValue struct { + key string + value interface{} } // Delete all expired items from the cache. func (c *cache) DeleteExpired() { + var evictedItems []keyAndValue c.mu.Lock() for k, v := range c.items { if v.Expired() { - c.delete(k) + ov, evicted := c.delete(k) + if evicted { + evictedItems = append(evictedItems, keyAndValue{k, ov}) + } } } c.mu.Unlock() + for _, v := range evictedItems { + c.onEvicted(v.key, v.value) + } +} + +// Sets an (optional) function that is called with the key and value when an +// item is evicted from the cache. (Including when it is deleted manually.) +// Set to nil to disable. +func (c *cache) OnEvicted(f func(string, interface{})) { + c.mu.Lock() + defer c.mu.Unlock() + c.onEvicted = f } // Write the cache's items (using Gob) to an io.Writer. diff --git a/cache_test.go b/cache_test.go index 8b308c8..c357cb0 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1224,6 +1224,24 @@ func TestDecrementUnderflowUint(t *testing.T) { } } +func TestOnEvicted(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("foo", 3, DefaultExpiration) + if tc.onEvicted != nil { + t.Fatal("tc.onEvicted is not nil") + } + works := false + tc.OnEvicted(func(k string, v interface{}) { + if k == "foo" && v.(int) == 3 { + works = true + } + }) + tc.Delete("foo") + if !works { + t.Error("works bool not true") + } +} + func TestCacheSerialization(t *testing.T) { tc := New(DefaultExpiration, 0) testFillAndSerialize(t, tc)