refactor: 补齐草稿与SSE收口

This commit is contained in:
2026-06-04 06:26:09 +08:00
parent c93b8fb570
commit bbb9269bab
8 changed files with 433 additions and 76 deletions

View File

@@ -0,0 +1,51 @@
import { afterEach, describe, expect, it, vi } from 'vitest';
import { streamPlainTextCompletion } from './llmClient';
function createSseResponse(body: string) {
const encoder = new TextEncoder();
const stream = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(encoder.encode(body));
controller.close();
},
});
return new Response(stream, {
headers: {
'Content-Type': 'text/event-stream; charset=utf-8',
},
});
}
describe('llmClient streamPlainTextCompletion', () => {
afterEach(() => {
vi.unstubAllGlobals();
vi.restoreAllMocks();
});
it('reads OpenAI compatible SSE through the shared stream reader', async () => {
const onUpdate = vi.fn();
const fetchMock = vi.fn().mockResolvedValue(
createSseResponse(
[
'data: {"choices":[{"delta":{"content":"溪上"}}]}\r\n\r\n',
'data: not-json\r\n\r\n',
'data: {"choices":[{"delta":{"content":"春风"}}]}\r\n\r\n',
'data: [DONE]\r\n\r\n',
'data: {"choices":[{"delta":{"content":"不应读取"}}]}\r\n\r\n',
].join(''),
),
);
vi.stubGlobal('fetch', fetchMock);
const result = await streamPlainTextCompletion('system', 'user', {
onUpdate,
});
expect(result).toBe('溪上春风');
expect(onUpdate).toHaveBeenNthCalledWith(1, '溪上');
expect(onUpdate).toHaveBeenNthCalledWith(2, '溪上春风');
expect(onUpdate).toHaveBeenCalledTimes(2);
});
});

View File

@@ -1,5 +1,6 @@
import type {TextStreamOptions} from './aiTypes';
import { fetchWithApiAuth } from './apiClient';
import { parseSseJsonObject, readSseStream } from './sseStream';
const ENV: Partial<ImportMetaEnv> = import.meta.env ?? {};
@@ -44,6 +45,26 @@ function resolveHeaders(headers?: HeadersInit) {
return nextHeaders;
}
function readLlmStreamDeltaContent(parsed: Record<string, unknown>) {
const choices = parsed.choices;
if (!Array.isArray(choices)) {
return null;
}
const [firstChoice] = choices;
if (typeof firstChoice !== 'object' || firstChoice === null) {
return null;
}
const delta = (firstChoice as {delta?: unknown}).delta;
if (typeof delta !== 'object' || delta === null) {
return null;
}
const content = (delta as {content?: unknown}).content;
return typeof content === 'string' && content.length > 0 ? content : null;
}
const NODE_ENV = getNodeEnv();
const IS_SERVER_RUNTIME = typeof window === 'undefined';
const SERVER_API_KEY =
@@ -291,48 +312,20 @@ export async function streamPlainTextCompletion(
return fallbackText;
}
const reader = response.body.getReader();
const decoder = new TextDecoder('utf-8');
let buffer = '';
let accumulatedText = '';
for (;;) {
const {done, value} = await reader.read();
if (done) {
break;
await readSseStream(response, ({ data }) => {
if (data === '[DONE]') {
return false;
}
buffer += decoder.decode(value, {stream: true});
while (buffer.includes('\n\n')) {
const boundary = buffer.indexOf('\n\n');
const eventBlock = buffer.slice(0, boundary);
buffer = buffer.slice(boundary + 2);
for (const rawLine of eventBlock.split(/\r?\n/u)) {
const line = rawLine.trim();
if (!line.startsWith('data:')) {
continue;
}
const data = line.slice(5).trim();
if (!data || data === '[DONE]') {
continue;
}
try {
const parsed = JSON.parse(data);
const delta = parsed?.choices?.[0]?.delta?.content;
if (typeof delta === 'string' && delta.length > 0) {
accumulatedText += delta;
options.onUpdate?.(accumulatedText);
}
} catch {
// Ignore malformed SSE frames and continue consuming the stream.
}
}
const parsed = parseSseJsonObject(data);
const delta = parsed ? readLlmStreamDeltaContent(parsed) : null;
if (delta) {
accumulatedText += delta;
options.onUpdate?.(accumulatedText);
}
}
});
return accumulatedText.trim();
} catch (error) {