Press n or j to go to the next uncovered block, b, p or k for the previous block.
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 | 6x 6x 6x 6x 6x 6x 6x 6x 33x 33x 33x 33x 33x 33x 33x 33x 33x 33x 33x 2x 2x 33x 63x 63x 63x 63x 61x 61x 61x 2x 2x 1x 7x 12x 12x 12x 12x 11x 10x 10x 50x 50x 6x 6x 2x 7x 100x 1x 15x 15x 1x 1x 14x 14x 14x 14x 14x 14x 14x 14x 14x 14x 28x 28x 14x 14x 14x 23x 23x 1x 1x 22x 22x 22x 22x 13x 13x 9x 4x 5x 1x 9x 9x 7x 2x 2x 1x 39x 39x 1x 1x 38x 38x 38x 33x 33x 22x 22x 22x 22x 22x 22x 20x 2x 2x 22x 11x 5x 16x 16x | import { randomUUID } from 'crypto';
import { Request, Response } from 'express';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { Injectable, Logger, OnModuleDestroy, OnModuleInit } from '@nestjs/common';
import { ConfigService } from '../config/config.service';
import { RedactionService } from '../redaction/redaction.service';
import { MCPClientWrapper } from './mcp-client-wrapper';
/**
 * Runtime state kept for one proxied MCP server: the SDK `Server` exposed to
 * upstream clients, the wrapper around the downstream client, and the live
 * transports keyed by session id.
 */
export interface IMCPServerInstance {
  /** Key of this entry in `config.mcpServers`; also used in the SSE message endpoint path. */
  name: string;
  /** Proxy proxy-side wrapper around the downstream MCP client. */
  clientWrapper: MCPClientWrapper;
  /** SDK server obtained from `clientWrapper.getServer()` and connected to each transport. */
  server: Server;
  /** Proxy transport mode taken from `config.mcpProxy.type` (falls back to `'sse'`). */
  serverType: 'sse' | 'streamable-http' | 'stdio';
  /** Currently registered transports, keyed by their session id. */
  transports: Map<string, SSEServerTransport | StreamableHTTPServerTransport>;
}

/** Alias of {@link IMCPServerInstance} used throughout the service. */
export type MCPServerInstance = IMCPServerInstance;
/**
 * NestJS service that hosts one proxy MCP `Server` per configured downstream
 * MCP server and bridges incoming HTTP traffic (legacy SSE and Streamable HTTP)
 * to it. Live transports are tracked per server, keyed by session id.
 */
@Injectable()
export class MCPServerService implements OnModuleInit, OnModuleDestroy {
  private readonly logger = new Logger(MCPServerService.name);
  private servers: Map<string, MCPServerInstance> = new Map();

  constructor(
    private configService: ConfigService,
    private redactionService: RedactionService
  ) {}

  /**
   * Creates and initializes an `MCPClientWrapper` for each configured MCP server
   * and registers the proxy `Server` it returns.
   *
   * In `stdio` proxy mode, `--stdio-target <name>` on the command line restricts
   * initialization to that single entry to avoid slow startup.
   *
   * @throws The entry's initialization error when its `options.panicIfInvalid`
   *   is set; otherwise the failure is logged and the entry is skipped.
   */
  async onModuleInit() {
    const config = this.configService.getConfig();
    const proxyType = config.mcpProxy.type || 'sse';
    const allEntries = Object.entries(config.mcpServers);

    let entriesToInit = allEntries;
    if (proxyType === 'stdio') {
      const argv = process.argv;
      const targetFlagIdx = argv.indexOf('--stdio-target');
      const targetFromArg = targetFlagIdx >= 0 ? argv[targetFlagIdx + 1] : undefined;
      // Without an explicit target, initialize all entries; main.ts enforces
      // correctness when running in stdio mode.
      if (targetFromArg) {
        entriesToInit = allEntries.filter(([ n ]) => n === targetFromArg);
      }
    }

    // Initialize MCP clients and create proxy servers
    for (const [ name, clientConfig ] of entriesToInit) {
      try {
        this.logger.log(`<${ name }> Initializing client...`);
        const clientWrapper = new MCPClientWrapper(name, clientConfig, this.redactionService);
        await clientWrapper.initialize();
        const server = await clientWrapper.getServer();
        this.servers.set(name, {
          name,
          server,
          clientWrapper,
          transports: new Map(),
          serverType: proxyType
        });
        this.logger.log(`<${ name }> Client initialized successfully`);
      } catch (error) {
        this.logger.error([
          '<',
          name,
          '> Failed to initialize client: ',
          String(error)
        ].join(''));
        if (clientConfig.options?.panicIfInvalid) {
          throw error;
        }
      }
    }
  }

  /**
   * Shuts down every registered server. For each one it closes all transports,
   * closes the downstream client, and then makes a best-effort call to the SDK
   * server's lifecycle methods. Errors are logged per server and never rethrown,
   * so one failing server does not block the others.
   */
  async onModuleDestroy() {
    for (const [ name, instance ] of this.servers) {
      try {
        this.logger.log([ '<', name, '> Shutting down...' ].join(''));
        // Close each transport independently so one failure does not skip the
        // rest. Iterate over a snapshot because transport close handlers may
        // delete entries from the map.
        for (const transport of Array.from(instance.transports.values())) {
          try {
            await transport.close();
          } catch (error) {
            this.logger.warn([
              '<',
              name,
              '> Failed to close transport: ',
              String(error)
            ].join(''));
          }
        }
        instance.transports.clear();
        try {
          await instance.clientWrapper.close();
        } finally {
          // Runs even if the client fails to close; never throws.
          await this.closeSdkServer(name, instance.server);
        }
      } catch (error) {
        this.logger.error([
          '<',
          name,
          '> Error during shutdown: ',
          String(error)
        ].join(''));
      }
    }
    this.servers.clear();
  }

  /**
   * Best-effort shutdown of an SDK `Server`: invokes each known lifecycle method
   * it exposes, with the server as `this`. Failures are logged at debug level and
   * never propagated.
   */
  private async closeSdkServer(name: string, server: Server): Promise<void> {
    const srv = server as unknown as Record<string, unknown>;
    for (const method of [ 'close', 'shutdown', 'stop', 'dispose', 'terminate' ]) {
      const candidate = srv?.[method];
      if (typeof candidate !== 'function') {
        continue;
      }
      try {
        // Bind the receiver: an unbound call fails for any method that reads
        // instance state (e.g. the connected transport).
        await candidate.call(server);
      } catch (error) {
        this.logger.debug([
          '<',
          name,
          '> Server ',
          method,
          '() failed: ',
          String(error)
        ].join(''));
      }
    }
  }

  /** Returns the registered instance for `name`, or `undefined` if none exists. */
  getServer(name: string): MCPServerInstance | undefined {
    return this.servers.get(name);
  }

  /** Returns the live map of all registered instances (not a copy). */
  getAllServers(): Map<string, MCPServerInstance> {
    return this.servers;
  }

  /**
   * Opens a legacy SSE stream for server `name`. It creates an
   * `SSEServerTransport` whose message endpoint is `<protocol>://<host>/<name>/message`,
   * connects it to the shared server, and registers it under its session id
   * until the request or response closes. Responds 404 if `name` is unknown.
   */
  async handleSSERequest(
    name: string,
    req: Request,
    res: Response
  ): Promise<void> {
    const instance = this.getServer(name);
    if (!instance) {
      res.status(404).json({ error: 'MCP server not found' });
      return;
    }
    // Derive baseURL from the incoming request instead of config
    // This ensures it works correctly in test environments with dynamic ports
    const protocol = (req as any).protocol || 'http';
    const hostHeader = typeof (req as any).get === 'function' ?
      (req as any).get('host') :
      (req.headers?.host as string | undefined);
    const host = hostHeader || 'localhost';
    const baseURL = `${ protocol }://${ host }`;
    const endpoint = `${ baseURL }/${ name }/message`;
    const transport = new SSEServerTransport(endpoint, res);
    // Do not call start() here; Server.connect() will start the transport.
    // NOTE(review): every connection calls connect() on the same shared Server;
    // confirm the SDK routes responses correctly with concurrent sessions.
    await instance.server.connect(transport);
    // Register the transport after connect so the sessionId matches the one used by the SDK
    const sessionId = transport.sessionId;
    instance.transports.set(sessionId, transport);
    // req 'close', res 'close' and res 'finish' can all fire for one connection:
    // tear down only once, and never leave the close() promise unhandled.
    let cleanedUp = false;
    const cleanup = () => {
      if (cleanedUp) {
        return;
      }
      cleanedUp = true;
      instance.transports.delete(sessionId);
      void transport.close().catch((error: unknown) => {
        this.logger.debug([
          '<',
          name,
          '> Error closing SSE transport ',
          sessionId,
          ': ',
          String(error)
        ].join(''));
      });
    };
    req.on('close', cleanup);
    (res as any).on?.('close', cleanup);
    (res as any).on?.('finish', cleanup);
  }

  /**
   * Handles a Streamable HTTP request (GET/POST/DELETE) for server `name`.
   * A request whose session id (from the `mcp-session-id` header or the
   * `sessionId` query parameter) matches a registered Streamable HTTP transport
   * is delegated to that transport. Otherwise a new stateful transport is
   * created and connected, and it registers itself once its session is
   * initialized. Responds 404 for an unknown `name` and 500 on unexpected
   * errors (if headers are not yet sent).
   */
  async handleStreamableHTTPRequest(
    name: string,
    req: Request,
    res: Response
  ): Promise<void> {
    const instance = this.getServer(name);
    if (!instance) {
      res.status(404).json({ error: 'MCP server not found' });
      return;
    }
    // Stateful handling: reuse transport by sessionId when provided.
    // Node.js lowercases incoming header names, so a single lookup suffices.
    const sessionIdHeader =
      (req.headers['mcp-session-id'] as string) ||
      (req.query.sessionId as string);
    const existing = sessionIdHeader ? instance.transports.get(sessionIdHeader) : undefined;
    try {
      if (existing instanceof StreamableHTTPServerTransport) {
        await existing.handleRequest(req as any, res, req.body);
        return;
      }
      // Create a new stateful Streamable HTTP server transport. It is added to the
      // map only once the SDK reports the session as initialized.
      const transport = new StreamableHTTPServerTransport({
        sessionIdGenerator: () => randomUUID(),
        enableJsonResponse: false,
        onsessioninitialized: async (sid: string) => {
          instance.transports.set(sid, transport);
        },
        onsessionclosed: async (sid: string) => {
          instance.transports.delete(sid);
        }
      });
      // Drop the session entry however the transport ends up closed (e.g. during
      // module shutdown), not only on an explicit DELETE.
      transport.onclose = () => {
        const sid = transport.sessionId;
        if (sid && instance.transports.get(sid) === transport) {
          instance.transports.delete(sid);
        }
      };
      // Connect this transport to the shared server instance
      await instance.server.connect(transport);
      // Handle the request (GET/POST/DELETE)
      await transport.handleRequest(req as any, res, req.body);
    } catch (error) {
      this.logger.error([
        'Error in streamable HTTP handler for ',
        name,
        ': ',
        String(error)
      ].join(''));
      if (!res.headersSent) {
        res.status(500).json({ error: 'Internal server error' });
      }
    }
  }

  /**
   * Handles a client POST for server `name`. If the session id (from the
   * `mcp-session-id` header or the `sessionId` query parameter) identifies an
   * SSE transport, the message is forwarded to it. Otherwise the request falls
   * back to the Streamable HTTP handler. Responds 404 for an unknown `name`.
   */
  async handlePostMessage(
    name: string,
    req: Request,
    res: Response
  ): Promise<void> {
    const instance = this.getServer(name);
    if (!instance) {
      res.status(404).json({ error: 'MCP server not found' });
      return;
    }
    // Get session ID from header or query
    const sessionId = req.headers['mcp-session-id'] as string || req.query.sessionId as string;
    this.logger.debug([
      '<',
      name,
      '> handlePostMessage - sessionId: ',
      String(sessionId),
      ', transports: ',
      Array.from(instance.transports.keys()).join(', ')
    ].join(''));
    if (sessionId) {
      const transport = instance.transports.get(sessionId);
      if (transport && 'handlePostMessage' in transport) {
        this.logger.debug([
          '<',
          name,
          '> Found SSE transport for session ',
          String(sessionId),
          ', delegating to handlePostMessage'
        ].join(''));
        // The request stream may already have been consumed by a body parser, so:
        // pass the parsed body for JSON requests, else rawBody (when enabled),
        // else let the transport read the stream itself.
        // Node.js lowercases incoming header names.
        const contentType = req.headers['content-type'] as string | undefined;
        const isJson = typeof contentType === 'string' && contentType.includes('application/json');
        const hasParsedBody = typeof (req as any).body !== 'undefined' && (req as any).body !== null;
        const rawBody = (req as any).rawBody;
        const sseTransport = transport as SSEServerTransport;
        if (isJson && hasParsedBody) {
          await sseTransport.handlePostMessage(req as any, res, (req as any).body);
        } else if (typeof rawBody !== 'undefined') {
          await sseTransport.handlePostMessage(req as any, res, rawBody);
        } else {
          await sseTransport.handlePostMessage(req as any, res);
        }
        return;
      }
      this.logger.warn([
        '<',
        name,
        '> SessionId ',
        String(sessionId),
        ' provided but no matching transport found or transport doesn\'t support handlePostMessage'
      ].join(''));
    } else {
      this.logger.warn([ '<', name, '> No sessionId provided in POST message request' ].join(''));
    }
    // If no session ID or transport doesn't support handlePostMessage,
    // fall back to streamable HTTP handler
    this.logger.debug([ '<', name, '> Falling back to streamable HTTP handler' ].join(''));
    await this.handleStreamableHTTPRequest(name, req, res);
  }
}
|