diff --git a/src/wrf/cache.py b/src/wrf/cache.py index e005e21..12b2226 100644 --- a/src/wrf/cache.py +++ b/src/wrf/cache.py @@ -4,12 +4,39 @@ from __future__ import (absolute_import, division, print_function, from threading import local from collections import OrderedDict +from .py3compat import py3range from .config import get_cache_size _local_storage = local() + +def _shrink_cache(): + """Shrink the cache if applicable. + + This only applies if a user has modified the cache size, otherwise it + just returns. + + Returns: + + None + + """ + global _local_storage + + try: + cache = _local_storage.cache + except AttributeError: + return + + diff = len(cache) - get_cache_size() + + if diff > 0: + for _ in py3range(diff): + cache.popitem(last=False) + + def cache_item(key, product, value): - """Store an item in the cache. + """Store an item in the threadlocal cache. The cache should be viewed as two nested dictionaries. The outer key is usually the id for the sequence where the cached item was generated. The @@ -43,7 +70,9 @@ def cache_item(key, product, value): """ global _local_storage - if key is None: + _shrink_cache() + + if key is None or get_cache_size() == 0: return try: @@ -64,7 +93,7 @@ def cache_item(key, product, value): def get_cached_item(key, product): - """Return an item from the cache. + """Return an item from the threadlocal cache. The cache should be viewed as two nested dictionaries. The outer key is usually the id for the sequence where the cached item was generated. The @@ -94,7 +123,11 @@ def get_cached_item(key, product): :meth:`cache_item` """ - if key is None: + global _local_storage + + _shrink_cache() + + if key is None or get_cache_size == 0: return None cache = getattr(_local_storage, "cache", None) @@ -102,6 +135,9 @@ def get_cached_item(key, product): if cache is None: return None + if len(cache) == 0: + return None + prod_dict = cache.get(key, None) if prod_dict is None: @@ -111,8 +147,9 @@ def get_cached_item(key, product): return result + def _get_cache(): - """Return the cache. + """Return the threadlocal cache. This is primarily used for testing. @@ -121,6 +158,9 @@ def _get_cache(): :class:`threading.local` """ + global _local_storage + + _shrink_cache() return getattr(_local_storage, "cache", None) diff --git a/src/wrf/config.py b/src/wrf/config.py index 0537949..c7f28b2 100644 --- a/src/wrf/config.py +++ b/src/wrf/config.py @@ -127,7 +127,7 @@ def disable_pyngl(): def set_cache_size(size): - """Set the maximum number of items that the threadsafe cache can retain. + """Set the maximum number of items that the threadlocal cache can retain. This cache is primarily used for coordinate variables. @@ -145,7 +145,7 @@ def set_cache_size(size): def get_cache_size(): - """Return the maximum number of items that the threadsafe cache can retain. + """Return the maximum number of items that the threadlocal cache can retain. Returns: