As a fun exercise, I tried to code up a thread-safe key-value store with revisions and TTL, meaning that values can time out and past values can be accessed through a full revision log. There are a bunch of open-source key-value stores out there to learn from and draw inspiration from, such as Redis and etcd.
First try (no threading)
You’ll need at least Python 3.12 and `sortedcontainers` to run this.
from collections import defaultdict
from sortedcontainers import SortedList
from typing import NamedTuple
import bisect
import operator
# Timestamps are plain integers supplied by the caller of every Store method.
type _Timestamp = int
class _LogEntry[V](NamedTuple):
    """One revision of a key: the value that became current at `timestamp`.

    A `value` of `None` marks that the key timed out at `timestamp`
    (such entries are appended by `Store._honor_ttl`).
    """

    timestamp: _Timestamp
    value: V | None
class Store[K: str, V: int]:
"""Key-value store with TTL.
It is required for methods to be called with increasing
`timestamp`.
"""
def __init__(self):
self.log: defaultdict[K, list[_LogEntry[V]]] = defaultdict(list)
self.ttl_map: dict[K, _Timestamp] = {}
# https://peps.python.org/pep-0257/
"""Key `K` will timeout at `_Timestamp`
If `K` is not present, then `K` will have an infinite TTL.
"""
self.ttls = SortedList()
"`SortedList[tuple[_Timestamp, K]]`"
def _honor_ttl(self, timestamp: _Timestamp) -> None:
while self.ttls and self.ttls[0][0] <= timestamp:
ttl, key = self.ttls.pop(0)
if self.ttl_map.get(key, -1) == ttl:
self.ttl_map.pop(key)
self.log[key].append(_LogEntry(ttl, None))
def get(self, timestamp: _Timestamp, key: K) -> V | None:
self._honor_ttl(timestamp)
if (log := self.log.get(key, [])):
return log[-1].value
def set(self, timestamp: _Timestamp, key: K, value: V) -> None:
"""Set key-value pair without TTL.
Meaning that the pair will never time out.
"""
self._honor_ttl(timestamp)
self.log[key].append(_LogEntry(timestamp, value))
try:
ttl = self.ttl_map.pop(key)
self.ttls.remove((ttl, key))
except KeyError:
...
def set_with_ttl(self, timestamp: _Timestamp, key: K, value: V, ttl: int) -> None:
if ttl <= 0:
raise ValueError("TTL should be at least 1.")
self._honor_ttl(timestamp)
self.log[key].append(_LogEntry(timestamp, value))
try:
prev_ttl = self.ttl_map.pop(key)
self.ttls.remove((prev_ttl, key))
except KeyError:
...
self.ttl_map[key] = timestamp + ttl
self.ttls.add((timestamp + ttl, key))
def scan(self, timestamp: _Timestamp) -> list[str]:
"""Get all key-value pairs.
Format for a pair is: `"<key>(<value>)"`
"""
self._honor_ttl(timestamp)
ans = []
for key in self.log:
value = self.get(timestamp, key)
if (value := self.get(timestamp, key)) is not None:
ans.append(f"{key}({value})")
return ans
def get_at(
self,
timestamp: _Timestamp,
key: K,
at_timestamp: _Timestamp,
) -> V | None:
if at_timestamp > timestamp:
raise ValueError("It should hold that: at_timestamp <= timestamp.")
self._honor_ttl(timestamp)
if (log := self.log.get(key)) is None:
return
i = -1 + bisect.bisect_right(
log,
at_timestamp,
key=operator.attrgetter("timestamp")
)
return log[i].value if i >= 0 else None
Some notes about the implementation:

- Implementing revisions was as simple as storing a `list` of entries together with their timestamp. When updating the value of a key, a simple `.append()` call does the trick.
- Since we store a full revision log, there is no need to remove timed-out values as soon as possible; we don’t free up any memory. Instead, every call that accesses the store first has to update all TTL’d values through `_honor_ttl()`.
- Tracking TTL required a `self.ttl_map: dict[K, _Timestamp] = {}` to find the timestamp of a key in order to remove the `(_Timestamp, K)` entry from the `SortedList` whenever the TTL value was updated. Now we need to keep two data structures consistent, and we are relying on a third-party package (don’t get me wrong, `sortedcontainers` is an absolute killer package!).
- Unlike the `heap.go` implementation that etcd uses, Python’s `heapq` module does not have a `.heapremove()` method that runs in O(log n) to allow performant updates. All heap modifications would have to be tracked so that, given a key, its index in the heap can be found; the entry can then be swapped to the end of the list (followed by a sift to restore the heap invariant) for a `.pop()` call. This cannot be implemented on top of `heapq`, as the heap modifications made inside its functions can’t be tracked like they can through the `heap.go` interface.
- Adding new methods or new features to the store will be hard, as they all have to make sure to correctly update the existing structures. As you can see, currently each method is updating both `self.ttl_map` and `self.ttls`; this is bound to introduce issues down the road.
Second try
from collections import defaultdict
from collections.abc import Hashable
from typing import NamedTuple
import bisect
import functools
import operator
import random
import threading
import time
# Timestamps are plain integers supplied by the caller; `_Entry.set_value`
# rejects a timestamp older than the key's latest revision.
type _Timestamp = int
# Allowed value types. Restricted to immutable types so a logged revision
# cannot be changed after it was recorded (see `_LogEntry.value`).
type _Immutable = int | str | bytes
# Number of most recent revisions `_Entry.trim_log` keeps per key.
MAX_LOG_SIZE: int = 5
class _LogEntry[V: _Immutable](NamedTuple):
    """One revision of a key: the value that became current at `timestamp`.

    A `value` of `None` records that the key had expired at `timestamp`
    (`_Entry.set_value` appends such an entry when a write follows an
    expiry).
    """

    timestamp: _Timestamp
    value: V | None
    """Immutable value.
    If the value were mutable, then an entry could be changed
    someplace else and the entry would no longer be the entry
    at time of insertion. So not a true snapshot.
    """
def _behind_mutex(meth):
@functools.wraps(meth)
def wrapper(self, *args, **kwargs):
with self._mutex:
return meth(self, *args, **kwargs)
return wrapper
class _Entry[V: _Immutable]:
def __init__(self, expires_at: _Timestamp | None = None):
self.expires_at: _Timestamp | None = expires_at
"""Timestamp at which the current value expires."""
self._log: list[_LogEntry[V]] = []
"""Revision log of values."""
self._mutex = threading.Lock()
@_behind_mutex
def set_value(
self,
timestamp: _Timestamp,
value: V,
*,
# Has to be passed as keyword argument.
ttl: int | None,
) -> None:
"""Set value for `T >= timestamp` to `value`."""
if self._log and self._log[-1].timestamp > timestamp:
raise RuntimeError("Timestamp values are not increasing.")
if self.expires_at is not None and self.expires_at <= timestamp:
self._log.append(_LogEntry(self.expires_at, None))
if ttl is None:
self.expires_at = None
else:
self.expires_at = timestamp + ttl
self._log.append(_LogEntry(timestamp, value))
@_behind_mutex
def get_value_at(self, timestamp: _Timestamp) -> V | None:
"""Get value at `timestamp`."""
if not self._log:
return None
elif self.expires_at is not None and self.expires_at <= timestamp:
return None
# Quick path for regular get().
elif self._log[-1].timestamp <= timestamp:
return self._log[-1].value
# Revision path.
else:
i = -1 + bisect.bisect_right(
self._log,
timestamp,
key=operator.attrgetter("timestamp"))
return self._log[i].value if i >= 0 else None
@_behind_mutex
def trim_log(self) -> None:
del self._log[:-MAX_LOG_SIZE]
class Store[K: Hashable, V: _Immutable]:
"""Key-value store with TTL.
It is required for methods to be called with increasing
`timestamp`.
"""
def __init__(self, trim: bool = False):
self.map: defaultdict[K, _Entry[V]] = defaultdict(_Entry)
# Duplicate keys so we can pick a number of them at random
# in O(1).
self.keys: list[K] = []
self._mutex = threading.Lock()
"""Mutex to protect adding and removing keys."""
# https://github.com/etcd-io/etcd/blob/main/server/lease/lessor.go#L245
if trim:
self._trimmer_thread = threading.Thread(
target=self._run_trimmer,
daemon=True
)
self._trimmer_thread.start()
def _run_trimmer(self):
"""Periodically trim the revision log of entries.
Entries are selected randomly.
"""
while True:
time.sleep(0.5)
with self._mutex:
to_trim = random.sample(self.keys, k=min(len(self.keys), 3))
for key in to_trim:
with self._mutex:
self.map[key].trim_log()
def get(self, timestamp: _Timestamp, key: K) -> V | None:
# Check __contains__ first to prevent increasing defaultdict
# size.
if key not in self.map:
return None
else:
return self.map[key].get_value_at(timestamp)
def set(self, timestamp: _Timestamp, key: K, value: V) -> None:
"""Set key-value pair without TTL.
Meaning that the value will never time out.
"""
if key in self.map:
# Use lock of underlying _Entry.
self.map[key].set_value(timestamp, value, ttl=None)
else:
with self._mutex:
self.map[key].set_value(timestamp, value, ttl=None)
self.keys.append(key)
def set_with_ttl(self, timestamp: _Timestamp, key: K, value: V, ttl: int) -> None:
if ttl <= 0:
raise ValueError("TTL should be at least 1.")
if key in self.map:
self.map[key].set_value(timestamp, value, ttl=ttl)
else:
with self._mutex:
self.map[key].set_value(timestamp, value, ttl=ttl)
self.keys.append(key)
def scan(self, timestamp: _Timestamp) -> list[str]:
"""Get all key-value pairs.
Format for a pair is: `"<key>(<value>)"`
"""
ans = []
for key, entry in self.map.items():
if (value := entry.get_value_at(timestamp)) is not None:
ans.append(f"{key}({value})")
return ans
def get_at(
self,
timestamp: _Timestamp,
key: K,
at_timestamp: _Timestamp,
) -> V | None:
if at_timestamp > timestamp:
raise ValueError("It should hold that: at_timestamp <= timestamp.")
if key not in self.map:
return None
else:
return self.map[key].get_value_at(at_timestamp)
Some notes about the implementation:

- Introduced a long-running thread that trims each key’s revision log to `MAX_LOG_SIZE` entries. Trimming is done by randomly selecting keys to trim, similar to how Redis randomly samples keys when removing TTL’d values. Of course, both are very different things, but as an exercise I thought it was good practice.
- This version is easier to read and maintain, as TTL adherence is offloaded to the `_Entry` class. This allowed for easily adding locking mechanisms to make the store thread-safe.
- The `trim_log()` method uses `del log[:-N]` syntax to trim older entries, which in CPython ends up in the same function as `log[:-N] = []` when following the bytecodes `DELETE_SUBSCR` and `STORE_SLICE`, respectively. `DELETE_SUBSCR` calls `PyObject_DelItem` and `STORE_SLICE` calls `PyObject_SetItem`, and both dispatch to the list’s `mp_ass_subscript` slot. Both calls can be followed further using something like GitHub’s code search.
- One way to make the trimming faster is by using a FIFO queue (like `collections.deque`) so items can be removed from the front in O(1) instead of O(n). Note, however, that indexing into the middle of a `deque` is O(n), which would slow down the `bisect` lookups used for revision reads.
Test cases
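Besides the unit tests below, a great way to test the thread-safety of the store would be to use the new `InterpreterPoolExecutor` introduced in Python 3.14. It offers true multi-core parallelism by spawning an interpreter for each worker thread (so each has its own GIL). Note, however, that subinterpreters do not share ordinary Python objects, so a single `Store` instance cannot be handed to several of them. Running the threads on the free-threaded build of CPython (PEP 703) is a more direct way to exercise the store’s locks in parallel.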
import unittest
class TestStore(unittest.TestCase):
    """Behavioral tests for `Store`: plain values, TTLs, revisions, trimming."""

    def test_without_ttl(self):
        store = Store[str, int]()
        T = 0
        store.set(T + 1, "1", 1)
        store.set(T + 2, "2", 2)
        self.assertEqual(store.scan(T + 3), ["1(1)", "2(2)"])
        store.set(T + 4, "1", 3)
        self.assertEqual(store.scan(T + 5), ["1(3)", "2(2)"])
        self.assertEqual(store.get_at(T + 6, "1", T + 3), 1)
        self.assertEqual(store.get_at(T + 7, "1", T + 5), 3)

    def test_with_ttl(self):
        store = Store[str, int]()
        T = 0
        store.set(T + 1, "1", 1)
        store.set_with_ttl(T + 2, "2", 2, ttl=1)
        self.assertEqual(store.scan(T + 3), ["1(1)"])
        self.assertEqual(store.get_at(T + 4, "2", T + 1), None)
        self.assertEqual(store.get_at(T + 5, "2", T + 2), 2)
        self.assertEqual(store.get_at(T + 6, "2", T + 3), None)

    def test_mix(self):
        store = Store[str, int]()
        T = 0
        store.set(T + 1, "1", 1)
        store.set_with_ttl(T + 2, "1", 2, 10)
        store.set(T + 20, "1", 3)
        self.assertEqual(store.get_at(T + 27, "1", T + 3), 2)
        self.assertEqual(store.get_at(T + 28, "1", T + 11), 2)
        self.assertEqual(store.get_at(T + 28, "1", T + 12), None)
        self.assertEqual(store.get_at(T + 29, "1", T + 20), 3)
        self.assertEqual(store.get_at(T + 50, "1", T + 40), 3)

    def test_reduced_ttl(self):
        store = Store[str, int]()
        T = 0
        store.set_with_ttl(T + 1, "1", 1, 10)
        store.set_with_ttl(T + 2, "1", 2, 5)
        self.assertEqual(store.get_at(T + 3, "1", T + 2), 2)
        self.assertEqual(store.get_at(T + 7, "1", T + 7), None)
        self.assertEqual(store.get_at(T + 11, "1", T + 11), None)

    def test_trimming(self):
        global MAX_LOG_SIZE
        MAX_LOG_SIZE = 5
        key = "1"
        store = Store[str, int](trim=True)
        for i in range(2 * MAX_LOG_SIZE):
            store.set(0, key, i)
        # The trimmer wakes every 0.5 s. Poll with a deadline instead of one
        # fixed sleep so a slow or loaded machine does not make this flaky.
        deadline = time.monotonic() + 5.0
        while (
            len(store.map[key]._log) > MAX_LOG_SIZE
            and time.monotonic() < deadline
        ):
            time.sleep(0.05)
        self.assertLessEqual(len(store.map[key]._log), MAX_LOG_SIZE)
        # Trimming drops old revisions only; the newest value must survive.
        self.assertEqual(store.get(1, key), 2 * MAX_LOG_SIZE - 1)
# Run the unittest suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()