Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions memcache/async_memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,37 @@ async def set(
)
await self.execute_meta_command(command)

async def cas(
self,
key: Union[bytes, str],
value: Any,
cas_token: int,
*,
expire: Optional[int] = None,
) -> None:
"""
Store a value using compare-and-swap operation.

:param key: The key to store
:param value: The value to store
:param cas_token: The CAS token from a previous gets operation
:param expire: Optional expiration time in seconds
:raises MemcacheError: If the CAS token doesn't match or other error occurs
"""
value, client_flags = self._dump(key, value)

flags = [b"F%d" % client_flags, b"C%d" % cas_token]
if expire:
flags.append(b"T%d" % expire)

command = MetaCommand(
cm=b"ms", key=key, datalen=len(value), flags=flags, value=value
)
result = await self.execute_meta_command(command)

if result.rc != b"HD":
raise MemcacheError("CAS operation failed: token mismatch or other error")

async def get(self, key: Union[bytes, str]) -> Optional[Any]:
command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f"])
result = await self.execute_meta_command(command)
Expand All @@ -112,6 +143,34 @@ async def get(self, key: Union[bytes, str]) -> Optional[Any]:

return self._load(key, result.value, client_flags)

async def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]:
"""
Get a value and its CAS token from memcached.

:param key: The key to retrieve
:return: A tuple of (value, cas_token) or None if key doesn't exist
"""
command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f", b"c"])
result = await self.execute_meta_command(command)

if result.value is None:
return None

client_flags = int(result.flags[0][1:])
value = self._load(key, result.value, client_flags)

# Find CAS token in flags
cas_token = None
for flag in result.flags[1:]: # Skip the first flag (client_flags)
if flag.startswith(b"c"):
cas_token = int(flag[1:])
break

if cas_token is None:
raise MemcacheError("CAS token not found in response")

return value, cas_token

async def delete(self, key: Union[bytes, str]) -> None:
command = MetaCommand(cm=b"md", key=key, flags=[], value=None)
await self.execute_meta_command(command)
Expand Down Expand Up @@ -252,6 +311,36 @@ async def get(self, key: Union[bytes, str]) -> Optional[Any]:
async with self._get_connection(key) as connection:
return await connection.get(key)

async def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]:
"""
Get a value and its CAS token from memcached.

:param key: The key to retrieve
:return: A tuple of (value, cas_token) or None if key doesn't exist
"""
async with self._get_connection(key) as connection:
return await connection.gets(key)

async def cas(
self,
key: Union[bytes, str],
value: Any,
cas_token: int,
*,
expire: Optional[int] = None,
) -> None:
"""
Store a value using compare-and-swap operation.

:param key: The key to store
:param value: The value to store
:param cas_token: The CAS token from a previous gets operation
:param expire: Optional expiration time in seconds
:raises MemcacheError: If the CAS token doesn't match or other error occurs
"""
async with self._get_connection(key) as connection:
await connection.cas(key, value, cas_token, expire=expire)

async def delete(self, key: Union[bytes, str]) -> None:
async with self._get_connection(key) as connection:
return await connection.delete(key)
90 changes: 90 additions & 0 deletions memcache/memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,37 @@ def set(
)
self.execute_meta_command(command)

def cas(
self,
key: Union[bytes, str],
value: Any,
cas_token: int,
*,
expire: Optional[int] = None,
) -> None:
"""
Store a value using compare-and-swap operation.

:param key: The key to store
:param value: The value to store
:param cas_token: The CAS token from a previous gets operation
:param expire: Optional expiration time in seconds
:raises MemcacheError: If the CAS token doesn't match or other error occurs
"""
value, client_flags = self._dump(key, value)

flags = [b"F%d" % client_flags, b"C%d" % cas_token]
if expire:
flags.append(b"T%d" % expire)

command = MetaCommand(
cm=b"ms", key=key, datalen=len(value), flags=flags, value=value
)
result = self.execute_meta_command(command)

if result.rc != b"HD":
raise MemcacheError("CAS operation failed: token mismatch or other error")

def get(self, key: Union[bytes, str]) -> Optional[Any]:
command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f"])
result = self.execute_meta_command(command)
Expand All @@ -114,6 +145,34 @@ def get(self, key: Union[bytes, str]) -> Optional[Any]:

return self._load(key, result.value, client_flags)

def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]:
"""
Get a value and its CAS token from memcached.

:param key: The key to retrieve
:return: A tuple of (value, cas_token) or None if key doesn't exist
"""
command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f", b"c"])
result = self.execute_meta_command(command)

if result.value is None:
return None

client_flags = int(result.flags[0][1:])
value = self._load(key, result.value, client_flags)

# Find CAS token in flags
cas_token = None
for flag in result.flags[1:]: # Skip the first flag (client_flags)
if flag.startswith(b"c"):
cas_token = int(flag[1:])
break

if cas_token is None:
raise MemcacheError("CAS token not found in response")

return value, cas_token

def delete(self, key: Union[bytes, str]) -> None:
command = MetaCommand(cm=b"md", key=key, flags=[], value=None)
self.execute_meta_command(command)
Expand Down Expand Up @@ -181,6 +240,7 @@ class Memcache:
:param username: Memcached ASCII protocol authentication username.
:param password: Memcached ASCII protocol authentication password.
"""

def __init__(
self,
addr: Union[Addr, List[Addr], None] = None,
Expand Down Expand Up @@ -250,6 +310,36 @@ def get(self, key: Union[bytes, str]) -> Optional[Any]:
with self._get_connection(key) as connection:
return connection.get(key)

def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]:
"""
Get a value and its CAS token from memcached.

:param key: The key to retrieve
:return: A tuple of (value, cas_token) or None if key doesn't exist
"""
with self._get_connection(key) as connection:
return connection.gets(key)

def cas(
self,
key: Union[bytes, str],
value: Any,
cas_token: int,
*,
expire: Optional[int] = None,
) -> None:
"""
Store a value using compare-and-swap operation.

:param key: The key to store
:param value: The value to store
:param cas_token: The CAS token from a previous gets operation
:param expire: Optional expiration time in seconds
:raises MemcacheError: If the CAS token doesn't match or other error occurs
"""
with self._get_connection(key) as connection:
connection.cas(key, value, cas_token, expire=expire)

def delete(self, key: Union[bytes, str]) -> None:
with self._get_connection(key) as connection:
return connection.delete(key)
52 changes: 52 additions & 0 deletions tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,55 @@ async def test_pool_timeout():
assert time.time() - start > 1
else:
raise ValueError("empty not raised")


@pytest.mark.asyncio
async def test_gets(client):
await client.set("test_key", "test_value")
result = await client.gets("test_key")
assert result is not None
value, cas_token = result
assert value == "test_value"
assert isinstance(cas_token, int)
assert cas_token > 0


@pytest.mark.asyncio
async def test_gets_missing_key(client):
await client.delete("nonexistent_key")
assert await client.gets("nonexistent_key") is None


@pytest.mark.asyncio
async def test_cas_success(client):
await client.set("cas_key", "initial_value")
_, cas_token = await client.gets("cas_key")

await client.cas("cas_key", "updated_value", cas_token)
assert await client.get("cas_key") == "updated_value"


@pytest.mark.asyncio
async def test_cas_failure(client):
await client.set("cas_key", "initial_value")
_, cas_token = await client.gets("cas_key")

# Modify the value outside of CAS
await client.set("cas_key", "modified_value")

# CAS should fail with old token
with pytest.raises(memcache.MemcacheError):
await client.cas("cas_key", "updated_value", cas_token)
assert await client.get("cas_key") == "modified_value"


@pytest.mark.asyncio
async def test_cas_with_expire(client):
await client.set("cas_expire_key", "initial_value")
_, cas_token = await client.gets("cas_expire_key")

await client.cas("cas_expire_key", "updated_value", cas_token, expire=1)
assert await client.get("cas_expire_key") == "updated_value"

await asyncio.sleep(1.1)
assert await client.get("cas_expire_key") is None
47 changes: 47 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,50 @@ def test_pool_timeout():
assert time.time() - start > 1
else:
raise ValueError("empty not raised")


def test_gets(client):
client.set("test_key", "test_value")
result = client.gets("test_key")
assert result is not None
value, cas_token = result
assert value == "test_value"
assert isinstance(cas_token, int)
assert cas_token > 0


def test_gets_missing_key(client):
client.delete("nonexistent_key")
assert client.gets("nonexistent_key") is None


def test_cas_success(client):
client.set("cas_key", "initial_value")
_, cas_token = client.gets("cas_key")

client.cas("cas_key", "updated_value", cas_token)
assert client.get("cas_key") == "updated_value"


def test_cas_failure(client):
client.set("cas_key", "initial_value")
_, cas_token = client.gets("cas_key")

# Modify the value outside of CAS
client.set("cas_key", "modified_value")

# CAS should fail with old token
with pytest.raises(memcache.MemcacheError):
client.cas("cas_key", "updated_value", cas_token)
assert client.get("cas_key") == "modified_value"


def test_cas_with_expire(client):
client.set("cas_expire_key", "initial_value")
_, cas_token = client.gets("cas_expire_key")

client.cas("cas_expire_key", "updated_value", cas_token, expire=1)
assert client.get("cas_expire_key") == "updated_value"

time.sleep(1.1)
assert client.get("cas_expire_key") is None