All files / src/mcp mcp-server.service.ts

97.39% Statements 112/115
94.73% Branches 54/57
92.85% Functions 13/14
98.21% Lines 110/112

Press n or j to go to the next uncovered block, b, p or k for the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 2896x       6x 6x   6x   6x 6x 6x                       6x 33x 33x     33x 33x       33x     33x 33x 33x 33x 33x 33x   2x       2x         33x 63x 63x   63x 63x   61x   61x               61x   2x             2x 1x             7x 12x 12x   12x 12x   11x   10x 10x 50x 50x 6x 6x         2x               7x       100x       1x               15x 15x 1x 1x         14x 14x     14x 14x 14x   14x   14x     14x 14x     14x 28x 28x   14x 14x 14x               23x 23x 1x 1x         22x     22x   22x 22x 13x 13x       9x 4x     5x     1x         9x     9x 7x   2x           2x 1x                   39x 39x 1x 1x       38x   38x                 38x 33x 33x 22x                     22x   22x 22x 22x 22x 20x 2x     2x   22x   11x                 5x         16x 16x        
import { randomUUID } from 'crypto';
 
import { Request, Response } from 'express';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
 
import { Injectable, Logger, OnModuleDestroy, OnModuleInit } from '@nestjs/common';
 
import { ConfigService } from '../config/config.service';
import { RedactionService } from '../redaction/redaction.service';
import { MCPClientWrapper } from './mcp-client-wrapper';
 
/**
 * A proxy MCP server registered under the name of one configured downstream
 * MCP client.
 */
export interface IMCPServerInstance {
  /** Name of the downstream entry this proxy fronts (its key in `mcpServers`). */
  name: string;
  /** Proxy transport mode the instance was created for. */
  serverType: 'sse' | 'streamable-http' | 'stdio';
  /** SDK server exposed to upstream clients, obtained from `clientWrapper`. */
  server: Server;
  /** Wrapper around the downstream MCP client backing `server`. */
  clientWrapper: MCPClientWrapper;
  /** Upstream transports currently attached to `server`, keyed by session id. */
  transports: Map<string, SSEServerTransport | StreamableHTTPServerTransport>;
}
/** Un-prefixed alias of {@link IMCPServerInstance}. */
export type MCPServerInstance = IMCPServerInstance;
 
/**
 * Owns one proxy MCP server per configured downstream MCP client and routes
 * upstream HTTP traffic to it: SSE streams, SSE POST messages, and
 * Streamable HTTP requests.
 *
 * Downstream clients are created in {@link onModuleInit} and torn down in
 * {@link onModuleDestroy}.
 */
@Injectable()
export class MCPServerService implements OnModuleInit, OnModuleDestroy {
  private readonly logger = new Logger(MCPServerService.name);
  private readonly servers: Map<string, MCPServerInstance> = new Map();

  constructor(
    private configService: ConfigService,
    private redactionService: RedactionService
  ) {}

  /**
   * Initializes a client wrapper for each selected `mcpServers` entry and
   * registers the resulting proxy server under the entry's name.
   *
   * In stdio mode, only the entry named by `--stdio-target <name>` is
   * initialized, to avoid slow startup. Without that flag, every entry is.
   *
   * @throws The initialization error of an entry whose config sets
   *   `options.panicIfInvalid`. Other entries' failures are logged and skipped.
   */
  async onModuleInit() {
    const config = this.configService.getConfig();

    const proxyType = config.mcpProxy.type || 'sse';
    const allEntries = Object.entries(config.mcpServers);
    const entriesToInit = proxyType === 'stdio' ?
      this.selectStdioTargetEntries(allEntries) :
      allEntries;

    // Initialize MCP clients and create proxy servers
    for (const [ name, clientConfig ] of entriesToInit) {
      try {
        this.logger.log(`<${ name }> Initializing client...`);

        const clientWrapper = new MCPClientWrapper(name, clientConfig, this.redactionService);
        await clientWrapper.initialize();

        const server = await clientWrapper.getServer();

        this.servers.set(name, {
          name,
          server,
          clientWrapper,
          transports: new Map(),
          serverType: proxyType
        });

        this.logger.log(`<${ name }> Client initialized successfully`);
      } catch (error) {
        this.logger.error([
          '<',
          name,
          '> Failed to initialize client: ',
          String(error)
        ].join(''));

        if (clientConfig.options?.panicIfInvalid) {
          throw error;
        }
      }
    }
  }

  /**
   * Narrows the configured entries to the one named by `--stdio-target`, if given.
   *
   * With no explicit target, all entries are returned; main.ts enforces
   * correctness when running in stdio mode. A target that matches no entry
   * yields an empty list and a warning, instead of starting silently with nothing.
   */
  private selectStdioTargetEntries<T>(entries: Array<[ string, T ]>): Array<[ string, T ]> {
    const argv = process.argv;
    const flagIdx = argv.indexOf('--stdio-target');
    const target = flagIdx >= 0 ? argv[flagIdx + 1] : undefined;
    if (!target) {
      return entries;
    }

    const selected = entries.filter(([ n ]) => n === target);
    if (selected.length === 0) {
      this.logger.warn([
        '--stdio-target "',
        target,
        '" does not match any configured MCP server'
      ].join(''));
    }
    return selected;
  }

  /**
   * Shuts down every registered instance, then forgets them all.
   *
   * For each instance, this closes its upstream transports, then the
   * downstream client, then the SDK server (best-effort). A failure while
   * closing one transport is logged and does not stop the rest.
   */
  async onModuleDestroy() {
    for (const [ name, instance ] of this.servers) {
      try {
        this.logger.log([ '<', name, '> Shutting down...' ].join(''));

        // Snapshot first: closing a transport may trigger cleanup callbacks
        // that delete entries from the map.
        for (const transport of Array.from(instance.transports.values())) {
          try {
            await transport.close();
          } catch (error) {
            this.logger.warn([
              '<',
              name,
              '> Error closing transport: ',
              String(error)
            ].join(''));
          }
        }
        instance.transports.clear();

        await instance.clientWrapper.close();
        await this.closeSdkServer(instance.server);
      } catch (error) {
        this.logger.error([
          '<',
          name,
          '> Error during shutdown: ',
          String(error)
        ].join(''));
      }
    }
    this.servers.clear();
  }

  /**
   * Best-effort invocation of whichever lifecycle methods the SDK server exposes.
   *
   * Each method is called with `this` bound to the server. Calling an
   * extracted method reference unbound would lose `this`, and the resulting
   * error would be swallowed here. Errors are ignored because the server may
   * already have been closed along with its transports.
   */
  private async closeSdkServer(server: Server): Promise<void> {
    const srv = server as unknown as Record<string, unknown>;
    for (const method of [ 'close', 'shutdown', 'stop', 'dispose', 'terminate' ]) {
      const candidate = srv?.[method];
      if (typeof candidate === 'function') {
        try {
          await (candidate as (this: unknown) => Promise<void> | void).call(srv);
        } catch {
          // Intentionally ignored: shutdown is best-effort.
        }
      }
    }
  }

  /** Returns the registered instance for `name`, or `undefined` if none. */
  getServer(name: string): MCPServerInstance | undefined {
    return this.servers.get(name);
  }

  /** Returns the live map of all registered instances, keyed by name. */
  getAllServers(): Map<string, MCPServerInstance> {
    return this.servers;
  }

  /**
   * Opens an SSE stream for the named server.
   *
   * The endpoint advertised for follow-up POSTs is
   * `<protocol>://<host>/<name>/message`. It is derived from the incoming
   * request so it also works with dynamic ports. Responds 404 if the server
   * is unknown.
   */
  async handleSSERequest(
    name: string,
    req: Request,
    res: Response
  ): Promise<void> {
    const instance = this.getServer(name);
    if (!instance) {
      res.status(404).json({ error: 'MCP server not found' });
      return;
    }

    // Derive baseURL from the incoming request instead of config
    // This ensures it works correctly in test environments with dynamic ports
    const protocol = (req as any).protocol || 'http';
    const hostHeader = typeof (req as any).get === 'function' ?
      (req as any).get('host') :
      (req.headers?.host as string | undefined);
    const host = hostHeader || 'localhost';
    const baseURL = `${ protocol }://${ host }`;
    const endpoint = `${ baseURL }/${ name }/message`;

    const transport = new SSEServerTransport(endpoint, res);
    // Do not call start() here; Server.connect() will start the transport.
    // NOTE(review): every session calls connect() on the same shared
    // `instance.server`. If the SDK Server only tracks one active transport,
    // earlier sessions may stop receiving responses. Confirm against the SDK
    // version in use.
    await instance.server.connect(transport);

    // Register the transport after connect so the sessionId matches the one used by the SDK
    const sessionId = transport.sessionId;
    instance.transports.set(sessionId, transport);

    // req 'close', res 'close' and res 'finish' can all fire for one
    // connection, so clean up only once.
    let cleanedUp = false;
    const cleanup = () => {
      if (cleanedUp) {
        return;
      }
      cleanedUp = true;
      instance.transports.delete(sessionId);
      // Event handlers cannot await; log instead of leaving a floating rejection.
      void Promise.resolve(transport.close()).catch((error: unknown) => {
        this.logger.warn([
          '<',
          name,
          '> Error closing SSE transport ',
          sessionId,
          ': ',
          String(error)
        ].join(''));
      });
    };
    req.on('close', cleanup);
    (res as any).on?.('close', cleanup);
    (res as any).on?.('finish', cleanup);
  }

  /**
   * Handles a Streamable HTTP request (GET/POST/DELETE) for the named server.
   *
   * If the `mcp-session-id` header, or the `sessionId` query parameter,
   * matches a registered Streamable HTTP transport, the request goes to that
   * transport. Otherwise a new stateful transport is created and connected.
   * It registers itself once the SDK initializes a session, and unregisters
   * when the session or transport closes.
   *
   * Responds 404 for an unknown server. Responds 500 on unexpected errors if
   * no headers were sent yet.
   */
  async handleStreamableHTTPRequest(
    name: string,
    req: Request,
    res: Response
  ): Promise<void> {
    const instance = this.getServer(name);
    if (!instance) {
      res.status(404).json({ error: 'MCP server not found' });
      return;
    }

    // Stateful handling: reuse the transport by session id when provided.
    // Node lowercases incoming header names, so one lookup is enough.
    const sessionIdHeader =
      (req.headers['mcp-session-id'] as string) ||
      (req.query.sessionId as string);
    const existing = sessionIdHeader ? instance.transports.get(sessionIdHeader) : undefined;

    try {
      if (existing && existing instanceof StreamableHTTPServerTransport) {
        await existing.handleRequest(req as any, res, req.body);
        return;
      }

      // Create a new stateful Streamable HTTP server transport
      const transport = new StreamableHTTPServerTransport({
        sessionIdGenerator: () => randomUUID(),
        enableJsonResponse: false,
        onsessioninitialized: async (sid: string) => {
          instance.transports.set(sid, transport);
        },
        onsessionclosed: async (sid: string) => {
          instance.transports.delete(sid);
        }
      });

      // onsessionclosed only covers an explicit session DELETE. Also drop
      // the entry when the transport closes for any other reason. This is set
      // before connect() so the SDK can chain its own close handling onto it.
      transport.onclose = () => {
        const sid = transport.sessionId;
        if (sid && instance.transports.get(sid) === transport) {
          instance.transports.delete(sid);
        }
      };

      // Connect this transport to the shared server instance
      await instance.server.connect(transport);

      // Handle the request (GET/POST/DELETE)
      await transport.handleRequest(req as any, res, req.body);
      return;
    } catch (error) {
      this.logger.error([
        'Error in streamable HTTP handler for ',
        name,
        ': ',
        String(error)
      ].join(''));
      if (!res.headersSent) {
        res.status(500).json({ error: 'Internal server error' });
      }
    }
  }

  /**
   * Handles a POST to `/<name>/message`.
   *
   * If the session id (`mcp-session-id` header or `sessionId` query) maps to
   * an SSE transport, the message is forwarded to it. Otherwise the request
   * falls back to {@link handleStreamableHTTPRequest}. Responds 404 for an
   * unknown server.
   */
  async handlePostMessage(
    name: string,
    req: Request,
    res: Response
  ): Promise<void> {
    const instance = this.getServer(name);
    if (!instance) {
      res.status(404).json({ error: 'MCP server not found' });
      return;
    }

    // Get session ID from header or query
    const sessionId = req.headers['mcp-session-id'] as string || req.query.sessionId as string;

    this.logger.debug([
      '<',
      name,
      '> handlePostMessage - sessionId: ',
      String(sessionId),
      ', transports: ',
      Array.from(instance.transports.keys()).join(', ')
    ].join(''));

    if (sessionId) {
      const transport = instance.transports.get(sessionId);
      if (transport && 'handlePostMessage' in transport) {
        this.logger.debug([
          '<',
          name,
          '> Found SSE transport for session ',
          String(sessionId),
          ', delegating to handlePostMessage'
        ].join(''));
        // Forward the original request/response to the SSE transport. The
        // request stream may no longer be readable after body parsing, so:
        // pass the parsed body for JSON requests; otherwise pass rawBody when
        // it is enabled; otherwise let the transport read the stream.
        // The 'Content-Type' lookup is only defensive, for hand-built request
        // objects; Node itself lowercases header names.
        const contentType =
          (req.headers['content-type'] as string | undefined) ||
          (req.headers['Content-Type'] as unknown as string | undefined);
        const isJson = typeof contentType === 'string' && contentType.includes('application/json');
        const hasParsedBody = typeof (req as any).body !== 'undefined' && (req as any).body !== null;
        const rawBody = (req as any).rawBody;
        const sseTransport = transport as SSEServerTransport;
        if (isJson && hasParsedBody) {
          await sseTransport.handlePostMessage(req as any, res, (req as any).body);
        } else if (typeof rawBody !== 'undefined') {
          await sseTransport.handlePostMessage(req as any, res, rawBody);
        } else {
          await sseTransport.handlePostMessage(req as any, res);
        }
        return;
      } else {
        this.logger.warn([
          '<',
          name,
          '> SessionId ',
          String(sessionId),
          ' provided but no matching transport found or transport doesn\'t support handlePostMessage'
        ].join(''));
      }
    } else {
      this.logger.warn([ '<', name, '> No sessionId provided in POST message request' ].join(''));
    }

    // If no session ID or transport doesn't support handlePostMessage,
    // fall back to streamable HTTP handler
    this.logger.debug([ '<', name, '> Falling back to streamable HTTP handler' ].join(''));
    await this.handleStreamableHTTPRequest(name, req, res);
  }
}