diff --git a/memcache/async_memcache.py b/memcache/async_memcache.py index 46bf10b..9a0f57d 100644 --- a/memcache/async_memcache.py +++ b/memcache/async_memcache.py @@ -101,6 +101,37 @@ async def set( ) await self.execute_meta_command(command) + async def cas( + self, + key: Union[bytes, str], + value: Any, + cas_token: int, + *, + expire: Optional[int] = None, + ) -> None: + """ + Store a value using compare-and-swap operation. + + :param key: The key to store + :param value: The value to store + :param cas_token: The CAS token from a previous gets operation + :param expire: Optional expiration time in seconds + :raises MemcacheError: If the CAS token doesn't match or other error occurs + """ + value, client_flags = self._dump(key, value) + + flags = [b"F%d" % client_flags, b"C%d" % cas_token] + if expire: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", key=key, datalen=len(value), flags=flags, value=value + ) + result = await self.execute_meta_command(command) + + if result.rc != b"HD": + raise MemcacheError("CAS operation failed: token mismatch or other error") + async def get(self, key: Union[bytes, str]) -> Optional[Any]: command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f"]) result = await self.execute_meta_command(command) @@ -112,6 +143,34 @@ async def get(self, key: Union[bytes, str]) -> Optional[Any]: return self._load(key, result.value, client_flags) + async def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: + """ + Get a value and its CAS token from memcached. + + :param key: The key to retrieve + :return: A tuple of (value, cas_token) or None if key doesn't exist + """ + command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f", b"c"]) + result = await self.execute_meta_command(command) + + if result.value is None: + return None + + client_flags = int(result.flags[0][1:]) + value = self._load(key, result.value, client_flags) + + # Find CAS token in flags + cas_token = None + for flag in result.flags[1:]: # Skip the first flag (client_flags) + if flag.startswith(b"c"): + cas_token = int(flag[1:]) + break + + if cas_token is None: + raise MemcacheError("CAS token not found in response") + + return value, cas_token + async def delete(self, key: Union[bytes, str]) -> None: command = MetaCommand(cm=b"md", key=key, flags=[], value=None) await self.execute_meta_command(command) @@ -252,6 +311,36 @@ async def get(self, key: Union[bytes, str]) -> Optional[Any]: async with self._get_connection(key) as connection: return await connection.get(key) + async def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: + """ + Get a value and its CAS token from memcached. + + :param key: The key to retrieve + :return: A tuple of (value, cas_token) or None if key doesn't exist + """ + async with self._get_connection(key) as connection: + return await connection.gets(key) + + async def cas( + self, + key: Union[bytes, str], + value: Any, + cas_token: int, + *, + expire: Optional[int] = None, + ) -> None: + """ + Store a value using compare-and-swap operation. + + :param key: The key to store + :param value: The value to store + :param cas_token: The CAS token from a previous gets operation + :param expire: Optional expiration time in seconds + :raises MemcacheError: If the CAS token doesn't match or other error occurs + """ + async with self._get_connection(key) as connection: + await connection.cas(key, value, cas_token, expire=expire) + async def delete(self, key: Union[bytes, str]) -> None: async with self._get_connection(key) as connection: return await connection.delete(key) diff --git a/memcache/memcache.py b/memcache/memcache.py index 71dee76..c9c3604 100644 --- a/memcache/memcache.py +++ b/memcache/memcache.py @@ -103,6 +103,37 @@ def set( ) self.execute_meta_command(command) + def cas( + self, + key: Union[bytes, str], + value: Any, + cas_token: int, + *, + expire: Optional[int] = None, + ) -> None: + """ + Store a value using compare-and-swap operation. + + :param key: The key to store + :param value: The value to store + :param cas_token: The CAS token from a previous gets operation + :param expire: Optional expiration time in seconds + :raises MemcacheError: If the CAS token doesn't match or other error occurs + """ + value, client_flags = self._dump(key, value) + + flags = [b"F%d" % client_flags, b"C%d" % cas_token] + if expire: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", key=key, datalen=len(value), flags=flags, value=value + ) + result = self.execute_meta_command(command) + + if result.rc != b"HD": + raise MemcacheError("CAS operation failed: token mismatch or other error") + def get(self, key: Union[bytes, str]) -> Optional[Any]: command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f"]) result = self.execute_meta_command(command) @@ -114,6 +145,34 @@ def get(self, key: Union[bytes, str]) -> Optional[Any]: return self._load(key, result.value, client_flags) + def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: + """ + Get a value and its CAS token from memcached. + + :param key: The key to retrieve + :return: A tuple of (value, cas_token) or None if key doesn't exist + """ + command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f", b"c"]) + result = self.execute_meta_command(command) + + if result.value is None: + return None + + client_flags = int(result.flags[0][1:]) + value = self._load(key, result.value, client_flags) + + # Find CAS token in flags + cas_token = None + for flag in result.flags[1:]: # Skip the first flag (client_flags) + if flag.startswith(b"c"): + cas_token = int(flag[1:]) + break + + if cas_token is None: + raise MemcacheError("CAS token not found in response") + + return value, cas_token + def delete(self, key: Union[bytes, str]) -> None: command = MetaCommand(cm=b"md", key=key, flags=[], value=None) self.execute_meta_command(command) @@ -181,6 +240,7 @@ class Memcache: :param username: Memcached ASCII protocol authentication username. :param password: Memcached ASCII protocol authentication password. """ + def __init__( self, addr: Union[Addr, List[Addr], None] = None, @@ -250,6 +310,36 @@ def get(self, key: Union[bytes, str]) -> Optional[Any]: with self._get_connection(key) as connection: return connection.get(key) + def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: + """ + Get a value and its CAS token from memcached. + + :param key: The key to retrieve + :return: A tuple of (value, cas_token) or None if key doesn't exist + """ + with self._get_connection(key) as connection: + return connection.gets(key) + + def cas( + self, + key: Union[bytes, str], + value: Any, + cas_token: int, + *, + expire: Optional[int] = None, + ) -> None: + """ + Store a value using compare-and-swap operation. + + :param key: The key to store + :param value: The value to store + :param cas_token: The CAS token from a previous gets operation + :param expire: Optional expiration time in seconds + :raises MemcacheError: If the CAS token doesn't match or other error occurs + """ + with self._get_connection(key) as connection: + connection.cas(key, value, cas_token, expire=expire) + def delete(self, key: Union[bytes, str]) -> None: with self._get_connection(key) as connection: return connection.delete(key) diff --git a/tests/test_async_client.py b/tests/test_async_client.py index c828d06..50c8f37 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -59,3 +59,55 @@ async def test_pool_timeout(): assert time.time() - start > 1 else: raise ValueError("empty not raised") + + +@pytest.mark.asyncio +async def test_gets(client): + await client.set("test_key", "test_value") + result = await client.gets("test_key") + assert result is not None + value, cas_token = result + assert value == "test_value" + assert isinstance(cas_token, int) + assert cas_token > 0 + + +@pytest.mark.asyncio +async def test_gets_missing_key(client): + await client.delete("nonexistent_key") + assert await client.gets("nonexistent_key") is None + + +@pytest.mark.asyncio +async def test_cas_success(client): + await client.set("cas_key", "initial_value") + _, cas_token = await client.gets("cas_key") + + await client.cas("cas_key", "updated_value", cas_token) + assert await client.get("cas_key") == "updated_value" + + +@pytest.mark.asyncio +async def test_cas_failure(client): + await client.set("cas_key", "initial_value") + _, cas_token = await client.gets("cas_key") + + # Modify the value outside of CAS + await client.set("cas_key", "modified_value") + + # CAS should fail with old token + with pytest.raises(memcache.MemcacheError): + await client.cas("cas_key", "updated_value", cas_token) + assert await client.get("cas_key") == "modified_value" + + +@pytest.mark.asyncio +async def test_cas_with_expire(client): + await client.set("cas_expire_key", "initial_value") + _, cas_token = await client.gets("cas_expire_key") + + await client.cas("cas_expire_key", "updated_value", cas_token, expire=1) + assert await client.get("cas_expire_key") == "updated_value" + + await asyncio.sleep(1.1) + assert await client.get("cas_expire_key") is None diff --git a/tests/test_client.py b/tests/test_client.py index fddcdb5..2866b24 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -55,3 +55,50 @@ def test_pool_timeout(): assert time.time() - start > 1 else: raise ValueError("empty not raised") + + +def test_gets(client): + client.set("test_key", "test_value") + result = client.gets("test_key") + assert result is not None + value, cas_token = result + assert value == "test_value" + assert isinstance(cas_token, int) + assert cas_token > 0 + + +def test_gets_missing_key(client): + client.delete("nonexistent_key") + assert client.gets("nonexistent_key") is None + + +def test_cas_success(client): + client.set("cas_key", "initial_value") + _, cas_token = client.gets("cas_key") + + client.cas("cas_key", "updated_value", cas_token) + assert client.get("cas_key") == "updated_value" + + +def test_cas_failure(client): + client.set("cas_key", "initial_value") + _, cas_token = client.gets("cas_key") + + # Modify the value outside of CAS + client.set("cas_key", "modified_value") + + # CAS should fail with old token + with pytest.raises(memcache.MemcacheError): + client.cas("cas_key", "updated_value", cas_token) + assert client.get("cas_key") == "modified_value" + + +def test_cas_with_expire(client): + client.set("cas_expire_key", "initial_value") + _, cas_token = client.gets("cas_expire_key") + + client.cas("cas_expire_key", "updated_value", cas_token, expire=1) + assert client.get("cas_expire_key") == "updated_value" + + time.sleep(1.1) + assert client.get("cas_expire_key") is None