Skip to content

Commit 020784a

Browse files
authored
fix(proxy): resolve k8s-era TC regressions (#290)
* fix(shutdown): reset global lifecycle state between app lifespans Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) * fix(proxy): bound balancer retries and bridge reader teardown Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) * test(helm): build chart dependencies before rendering Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) * fix(tests): stabilize bridge errors and postgres rate limits Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) * test(bridge): relax previous-response continuity assertions Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) * fix(proxy): clear stale sticky selections at retry cap Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) * fix(bridge): abort reconnect when reader shutdown stalls Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)
1 parent a269b37 commit 020784a

8 files changed

Lines changed: 379 additions & 78 deletions

File tree

app/core/rate_limiter/db_rate_limiter.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from datetime import UTC, datetime, timedelta
44

5-
from sqlalchemy import delete, func, select, text
5+
from sqlalchemy import delete, func, insert, literal, select, text
66
from sqlalchemy.ext.asyncio import AsyncSession
77

88
from app.core.exceptions import DashboardRateLimitError
@@ -71,22 +71,29 @@ async def check_and_increment(self, key: str, session: AsyncSession) -> None:
7171

7272
now = datetime.now(UTC)
7373
window_start = now - timedelta(seconds=self.window_seconds)
74+
key_column = RateLimitAttempt.__table__.c.key
75+
type_column = RateLimitAttempt.__table__.c.type
76+
attempted_at_column = RateLimitAttempt.__table__.c.attempted_at
7477

7578
raw_result = await session.execute(
76-
text(
77-
"INSERT INTO rate_limit_attempts (key, type, attempted_at) "
78-
"SELECT :key, :type, :now "
79-
"WHERE (SELECT COUNT(*) FROM rate_limit_attempts "
80-
" WHERE key = :key AND type = :type AND attempted_at >= :window_start"
81-
" ) < :max_attempts"
82-
),
83-
{
84-
"key": key,
85-
"type": self.type,
86-
"now": now,
87-
"window_start": window_start,
88-
"max_attempts": self.max_attempts,
89-
},
79+
insert(RateLimitAttempt).from_select(
80+
[key_column.name, type_column.name, attempted_at_column.name],
81+
select(
82+
literal(key, type_=key_column.type),
83+
literal(self.type, type_=type_column.type),
84+
literal(now, type_=attempted_at_column.type),
85+
).where(
86+
select(func.count())
87+
.select_from(RateLimitAttempt)
88+
.where(
89+
RateLimitAttempt.key == key,
90+
RateLimitAttempt.type == self.type,
91+
RateLimitAttempt.attempted_at >= window_start,
92+
)
93+
.scalar_subquery()
94+
< self.max_attempts
95+
),
96+
)
9097
)
9198
inserted = raw_result.rowcount > 0 # type: ignore[union-attr]
9299
await session.commit()

app/core/shutdown.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
_in_flight: int = 0
99

1010

11+
def reset() -> None:
12+
global _draining, _bridge_drain_active, _in_flight
13+
_draining = False
14+
_bridge_drain_active = False
15+
_in_flight = 0
16+
17+
1118
def set_draining(val: bool = True) -> None:
1219
global _draining
1320
_draining = val

app/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,7 @@ async def lifespan(app: FastAPI):
9797
instance_id = None
9898

9999
startup_module._startup_complete = False
100-
shutdown_state.set_draining(False)
101-
shutdown_state.set_bridge_drain_active(False)
100+
shutdown_state.reset()
102101
await get_settings_cache().invalidate()
103102
await get_rate_limit_headers_cache().invalidate()
104103
reload_additional_quota_registry()
@@ -267,6 +266,7 @@ async def _heartbeat_only(svc: RingMembershipService, iid: str) -> None:
267266
except Exception:
268267
logger.exception("Metrics server stopped with an error")
269268
finally:
269+
shutdown_state.reset()
270270
mark_process_dead()
271271
await close_db()
272272

app/modules/proxy/load_balancer.py

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939

4040
logger = logging.getLogger(__name__)
4141

42+
_MAX_SELECTION_ATTEMPTS = 4
43+
4244
_STICKY_GRACE_PERIOD_SECONDS = 10.0
4345
_RECOVERABLE_STATUSES = frozenset(
4446
{
@@ -140,12 +142,13 @@ async def load_selection_inputs() -> _SelectionInputs:
140142
)
141143

142144
selected_snapshot: Account | None = None
143-
selected_version_at_pick: int = 0
144145
error_message: str | None = None
145146
selected_states: list[AccountState] = []
146147
selected_account_map: dict[str, Account] = {}
147148
if sticky_key is None:
149+
attempt = 0
148150
while True:
151+
attempt += 1
149152
self._prune_runtime(selection_inputs.accounts)
150153
states, account_map = _build_states(
151154
accounts=selection_inputs.accounts,
@@ -189,11 +192,20 @@ async def load_selection_inputs() -> _SelectionInputs:
189192
selected_snapshot.status = result.account.status
190193
selected_snapshot.deactivation_reason = result.account.deactivation_reason
191194
selected_snapshot.reset_at = selected_reset_at
192-
selected_version_at_pick = self._runtime.get(selected.id, RuntimeState()).version
193195
else:
194196
error_message = result.error_message
195197

196-
pre_persist_versions = {aid: runtime.version for aid, runtime in self._runtime.items()}
198+
pre_persist_runtime_state = {
199+
aid: (
200+
runtime.reset_at,
201+
runtime.cooldown_until,
202+
runtime.error_count,
203+
runtime.last_error_at,
204+
)
205+
for aid, runtime in self._runtime.items()
206+
}
207+
pre_persist_cache_generation = self._selection_inputs_cache.generation
208+
197209
async with self._repo_factory() as repos:
198210
stale_account_ids = await self._persist_selection_state(
199211
repos.accounts,
@@ -202,6 +214,10 @@ async def load_selection_inputs() -> _SelectionInputs:
202214
)
203215
stale_account_ids = stale_account_ids or set()
204216
if selected_snapshot is not None and selected_snapshot.id in stale_account_ids:
217+
if attempt >= _MAX_SELECTION_ATTEMPTS:
218+
selected_snapshot = None
219+
error_message = None
220+
break
205221
selection_inputs = await load_selection_inputs()
206222
if selection_inputs.error_code is not None and not selection_inputs.accounts:
207223
return AccountSelection(
@@ -215,37 +231,53 @@ async def load_selection_inputs() -> _SelectionInputs:
215231
selected_account_map = {}
216232
continue
217233

218-
if selected_snapshot is not None:
219-
_sel_runtime = self._runtime.get(selected_snapshot.id)
220-
_sel_pre_ver = selected_version_at_pick
221-
if _sel_runtime is not None and _sel_runtime.version != _sel_pre_ver:
234+
if (
235+
selected_snapshot is not None
236+
and self._selection_inputs_cache.generation != pre_persist_cache_generation
237+
and attempt < _MAX_SELECTION_ATTEMPTS
238+
):
239+
selection_inputs = await load_selection_inputs()
240+
if selection_inputs.error_code is not None and not selection_inputs.accounts:
241+
return AccountSelection(
242+
account=None,
243+
error_message=selection_inputs.error_message,
244+
error_code=selection_inputs.error_code,
245+
)
246+
selected_snapshot = None
247+
error_message = None
248+
selected_states = []
249+
selected_account_map = {}
250+
await asyncio.sleep(0)
251+
continue
252+
253+
if selected_snapshot is None and error_message == "No available accounts":
254+
runtime_recovered = any(
255+
self._runtime.get(account_id, RuntimeState()).reset_at != before[0]
256+
or self._runtime.get(account_id, RuntimeState()).cooldown_until != before[1]
257+
or self._runtime.get(account_id, RuntimeState()).error_count != before[2]
258+
or self._runtime.get(account_id, RuntimeState()).last_error_at != before[3]
259+
for account_id, before in pre_persist_runtime_state.items()
260+
)
261+
if runtime_recovered and attempt < _MAX_SELECTION_ATTEMPTS:
222262
selection_inputs = await load_selection_inputs()
223263
if selection_inputs.error_code is not None and not selection_inputs.accounts:
224264
return AccountSelection(
225265
account=None,
226266
error_message=selection_inputs.error_message,
227267
error_code=selection_inputs.error_code,
228268
)
229-
selected_snapshot = None
230269
error_message = None
231270
selected_states = []
232271
selected_account_map = {}
272+
await asyncio.sleep(0)
233273
continue
234274

235-
if selected_snapshot is None and error_message == "No available accounts":
236-
runtime_recovered = any(
237-
self._runtime.get(aid, RuntimeState()).version != pre_persist_versions.get(aid, 0)
238-
for aid in account_map
239-
)
240-
if runtime_recovered:
241-
error_message = None
242-
selected_states = []
243-
selected_account_map = {}
244-
pre_persist_versions = {aid: runtime.version for aid, runtime in self._runtime.items()}
245-
continue
246275
break
276+
247277
else:
278+
attempt = 0
248279
while True:
280+
attempt += 1
249281
self._prune_runtime(selection_inputs.accounts)
250282
states, account_map = _build_states(
251283
accounts=selection_inputs.accounts,
@@ -304,17 +336,20 @@ async def load_selection_inputs() -> _SelectionInputs:
304336
)
305337
stale_account_ids = stale_account_ids or set()
306338
if selected_snapshot is not None and selected_snapshot.id in stale_account_ids:
339+
selected_snapshot = None
340+
error_message = None
341+
selected_states = []
342+
selected_account_map = {}
343+
if attempt >= _MAX_SELECTION_ATTEMPTS:
344+
break
307345
selection_inputs = await load_selection_inputs()
308346
if selection_inputs.error_code is not None and not selection_inputs.accounts:
309347
return AccountSelection(
310348
account=None,
311349
error_message=selection_inputs.error_message,
312350
error_code=selection_inputs.error_code,
313351
)
314-
selected_snapshot = None
315-
error_message = None
316-
selected_states = []
317-
selected_account_map = {}
352+
await asyncio.sleep(0)
318353
continue
319354
break
320355

app/modules/proxy/service.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,26 @@
109109

110110
logger = logging.getLogger(__name__)
111111

112+
_TASK_CANCEL_TIMEOUT_SECONDS = 1.0
113+
114+
115+
async def _await_cancelled_task(
116+
task: asyncio.Task[object] | asyncio.Task[None],
117+
*,
118+
timeout_seconds: float = _TASK_CANCEL_TIMEOUT_SECONDS,
119+
label: str,
120+
) -> bool:
121+
task.cancel()
122+
try:
123+
await asyncio.wait_for(task, timeout=timeout_seconds)
124+
except asyncio.CancelledError:
125+
return True
126+
except TimeoutError:
127+
logger.warning("Timed out waiting for %s cancellation", label)
128+
return False
129+
return True
130+
131+
112132
_TEXT_DELTA_EVENT_TYPES = frozenset({"response.output_text.delta", "response.refusal.delta"})
113133
_TEXT_DONE_CONTENT_PART_TYPES = frozenset({"output_text", "refusal"})
114134
_REQUEST_TRANSPORT_HTTP = "http"
@@ -1007,11 +1027,7 @@ async def proxy_responses_websocket(
10071027
response_create_gate=response_create_gate,
10081028
)
10091029
if upstream_reader is not None:
1010-
upstream_reader.cancel()
1011-
try:
1012-
await upstream_reader
1013-
except asyncio.CancelledError:
1014-
pass
1030+
await _await_cancelled_task(upstream_reader, label="proxy websocket upstream reader")
10151031
upstream_reader = None
10161032
upstream_control = None
10171033
if upstream is not None:
@@ -1024,11 +1040,7 @@ async def proxy_responses_websocket(
10241040
continue
10251041
finally:
10261042
if upstream_reader is not None:
1027-
upstream_reader.cancel()
1028-
try:
1029-
await upstream_reader
1030-
except asyncio.CancelledError:
1031-
pass
1043+
await _await_cancelled_task(upstream_reader, label="proxy websocket upstream reader")
10321044
if upstream is not None:
10331045
try:
10341046
await upstream.close()
@@ -1780,11 +1792,7 @@ async def _close_http_bridge_session(
17801792
else:
17811793
await self._unregister_http_bridge_turn_states(session)
17821794
if session.upstream_reader is not None:
1783-
session.upstream_reader.cancel()
1784-
try:
1785-
await session.upstream_reader
1786-
except asyncio.CancelledError:
1787-
pass
1795+
await _await_cancelled_task(session.upstream_reader, label="http bridge upstream reader")
17881796
try:
17891797
await session.upstream.close()
17901798
except Exception:
@@ -2319,12 +2327,17 @@ async def _reconnect_http_bridge_session(
23192327
old_upstream = session.upstream
23202328
old_reader = session.upstream_reader if restart_reader else None
23212329
if old_reader is not None:
2322-
old_reader.cancel()
23232330
if old_reader is not asyncio.current_task():
2324-
try:
2325-
await old_reader
2326-
except asyncio.CancelledError:
2327-
pass
2331+
cancelled = await _await_cancelled_task(old_reader, label="http bridge upstream reader")
2332+
if not cancelled:
2333+
session.closed = True
2334+
raise ProxyResponseError(
2335+
502,
2336+
openai_error(
2337+
"upstream_unavailable",
2338+
"HTTP responses session bridge reader did not shut down cleanly",
2339+
),
2340+
)
23282341
try:
23292342
await old_upstream.close()
23302343
except Exception:

0 commit comments

Comments
 (0)