@@ -4377,7 +4377,6 @@ function defaultFactory (origin, opts) {
43774377
43784378class Agent extends DispatcherBase {
43794379 constructor ({ factory = defaultFactory, maxRedirections = 0, connect, ...options } = {}) {
4380- super()
43814380
43824381 if (typeof factory !== 'function') {
43834382 throw new InvalidArgumentError('factory must be a function.')
@@ -4391,6 +4390,8 @@ class Agent extends DispatcherBase {
43914390 throw new InvalidArgumentError('maxRedirections must be a positive number')
43924391 }
43934392
4393+ super(options)
4394+
43944395 if (connect && typeof connect !== 'function') {
43954396 connect = { ...connect }
43964397 }
@@ -6939,9 +6940,10 @@ class Client extends DispatcherBase {
69396940 autoSelectFamilyAttemptTimeout,
69406941 // h2
69416942 maxConcurrentStreams,
6942- allowH2
6943+ allowH2,
6944+ webSocket
69436945 } = {}) {
6944- super()
6946+ super({ webSocket } )
69456947
69466948 if (keepAlive !== undefined) {
69476949 throw new InvalidArgumentError('unsupported keepAlive, use pipelining=0 instead')
@@ -7473,15 +7475,23 @@ const { kDestroy, kClose, kClosed, kDestroyed, kDispatch, kInterceptors } = __nc
74737475const kOnDestroyed = Symbol('onDestroyed')
74747476const kOnClosed = Symbol('onClosed')
74757477const kInterceptedDispatch = Symbol('Intercepted Dispatch')
7478+ const kWebSocketOptions = Symbol('webSocketOptions')
74767479
74777480class DispatcherBase extends Dispatcher {
7478- constructor () {
7481+ constructor (opts ) {
74797482 super()
74807483
74817484 this[kDestroyed] = false
74827485 this[kOnDestroyed] = null
74837486 this[kClosed] = false
74847487 this[kOnClosed] = []
7488+ this[kWebSocketOptions] = opts?.webSocket ?? {}
7489+ }
7490+
7491+ get webSocketOptions () {
7492+ return {
7493+ maxPayloadSize: this[kWebSocketOptions].maxPayloadSize ?? 128 * 1024 * 1024
7494+ }
74857495 }
74867496
74877497 get destroyed () {
@@ -8041,8 +8051,8 @@ const kRemoveClient = Symbol('remove client')
80418051const kStats = Symbol('stats')
80428052
80438053class PoolBase extends DispatcherBase {
8044- constructor () {
8045- super()
8054+ constructor (opts ) {
8055+ super(opts )
80468056
80478057 this[kQueue] = new FixedQueue()
80488058 this[kClients] = []
@@ -8301,8 +8311,6 @@ class Pool extends PoolBase {
83018311 allowH2,
83028312 ...options
83038313 } = {}) {
8304- super()
8305-
83068314 if (connections != null && (!Number.isFinite(connections) || connections < 0)) {
83078315 throw new InvalidArgumentError('invalid connections')
83088316 }
@@ -8327,6 +8335,8 @@ class Pool extends PoolBase {
83278335 })
83288336 }
83298337
8338+ super(options)
8339+
83308340 this[kInterceptors] = options.interceptors?.Pool && Array.isArray(options.interceptors.Pool)
83318341 ? options.interceptors.Pool
83328342 : []
@@ -26081,40 +26091,35 @@ const tail = Buffer.from([0x00, 0x00, 0xff, 0xff])
2608126091const kBuffer = Symbol('kBuffer')
2608226092const kLength = Symbol('kLength')
2608326093
26084- // Default maximum decompressed message size: 4 MB
26085- const kDefaultMaxDecompressedSize = 4 * 1024 * 1024
26086-
2608726094class PerMessageDeflate {
2608826095 /** @type {import('node:zlib').InflateRaw} */
2608926096 #inflate
2609026097
2609126098 #options = {}
2609226099
26093- /** @type {boolean} */
26094- #aborted = false
26095-
26096- /** @type {Function|null} */
26097- #currentCallback = null
26100+ #maxPayloadSize = 0
2609826101
2609926102 /**
2610026103 * @param {Map<string, string>} extensions
2610126104 */
26102- constructor (extensions) {
26105+ constructor (extensions, options ) {
2610326106 this.#options.serverNoContextTakeover = extensions.has('server_no_context_takeover')
2610426107 this.#options.serverMaxWindowBits = extensions.get('server_max_window_bits')
26108+
26109+ this.#maxPayloadSize = options.maxPayloadSize
2610526110 }
2610626111
26112+ /**
26113+ * Decompress a compressed payload.
26114+ * @param {Buffer} chunk Compressed data
26115+ * @param {boolean} fin Final fragment flag
26116+ * @param {Function} callback Callback function
26117+ */
2610726118 decompress (chunk, fin, callback) {
2610826119 // An endpoint uses the following algorithm to decompress a message.
2610926120 // 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
2611026121 // payload of the message.
2611126122 // 2. Decompress the resulting data using DEFLATE.
26112-
26113- if (this.#aborted) {
26114- callback(new MessageSizeExceededError())
26115- return
26116- }
26117-
2611826123 if (!this.#inflate) {
2611926124 let windowBits = Z_DEFAULT_WINDOWBITS
2612026125
@@ -26137,23 +26142,12 @@ class PerMessageDeflate {
2613726142 this.#inflate[kLength] = 0
2613826143
2613926144 this.#inflate.on('data', (data) => {
26140- if (this.#aborted) {
26141- return
26142- }
26143-
2614426145 this.#inflate[kLength] += data.length
2614526146
26146- if (this.#inflate[kLength] > kDefaultMaxDecompressedSize ) {
26147- this.#aborted = true
26147+ if (this.#maxPayloadSize > 0 && this.# inflate[kLength] > this.#maxPayloadSize ) {
26148+ callback(new MessageSizeExceededError())
2614826149 this.#inflate.removeAllListeners()
26149- this.#inflate.destroy()
2615026150 this.#inflate = null
26151-
26152- if (this.#currentCallback) {
26153- const cb = this.#currentCallback
26154- this.#currentCallback = null
26155- cb(new MessageSizeExceededError())
26156- }
2615726151 return
2615826152 }
2615926153
@@ -26166,22 +26160,20 @@ class PerMessageDeflate {
2616626160 })
2616726161 }
2616826162
26169- this.#currentCallback = callback
2617026163 this.#inflate.write(chunk)
2617126164 if (fin) {
2617226165 this.#inflate.write(tail)
2617326166 }
2617426167
2617526168 this.#inflate.flush(() => {
26176- if (this.#aborted || !this.#inflate) {
26169+ if (!this.#inflate) {
2617726170 return
2617826171 }
2617926172
2618026173 const full = Buffer.concat(this.#inflate[kBuffer], this.#inflate[kLength])
2618126174
2618226175 this.#inflate[kBuffer].length = 0
2618326176 this.#inflate[kLength] = 0
26184- this.#currentCallback = null
2618526177
2618626178 callback(null, full)
2618726179 })
@@ -26216,6 +26208,7 @@ const {
2621626208const { WebsocketFrameSend } = __nccwpck_require__(3264)
2621726209const { closeWebSocketConnection } = __nccwpck_require__(6897)
2621826210const { PerMessageDeflate } = __nccwpck_require__(9469)
26211+ const { MessageSizeExceededError } = __nccwpck_require__(8707)
2621926212
2622026213// This code was influenced by ws released under the MIT license.
2622126214// Copyright (c) 2011 Einar Otto Stangvik <einaros@gmail.com>
@@ -26224,6 +26217,7 @@ const { PerMessageDeflate } = __nccwpck_require__(9469)
2622426217
2622526218class ByteParser extends Writable {
2622626219 #buffers = []
26220+ #fragmentsBytes = 0
2622726221 #byteOffset = 0
2622826222 #loop = false
2622926223
@@ -26235,18 +26229,23 @@ class ByteParser extends Writable {
2623526229 /** @type {Map<string, PerMessageDeflate>} */
2623626230 #extensions
2623726231
26232+ /** @type {number} */
26233+ #maxPayloadSize
26234+
2623826235 /**
2623926236 * @param {import('./websocket').WebSocket} ws
2624026237 * @param {Map<string, string>|null} extensions
26238+ * @param {{ maxPayloadSize?: number }} [options]
2624126239 */
26242- constructor (ws, extensions) {
26240+ constructor (ws, extensions, options = {} ) {
2624326241 super()
2624426242
2624526243 this.ws = ws
2624626244 this.#extensions = extensions == null ? new Map() : extensions
26245+ this.#maxPayloadSize = options.maxPayloadSize ?? 0
2624726246
2624826247 if (this.#extensions.has('permessage-deflate')) {
26249- this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions))
26248+ this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions, options ))
2625026249 }
2625126250 }
2625226251
@@ -26262,6 +26261,19 @@ class ByteParser extends Writable {
2626226261 this.run(callback)
2626326262 }
2626426263
26264+ #validatePayloadLength () {
26265+ if (
26266+ this.#maxPayloadSize > 0 &&
26267+ !isControlFrame(this.#info.opcode) &&
26268+ this.#info.payloadLength > this.#maxPayloadSize
26269+ ) {
26270+ failWebsocketConnection(this.ws, 'Payload size exceeds maximum allowed size')
26271+ return false
26272+ }
26273+
26274+ return true
26275+ }
26276+
2626526277 /**
2626626278 * Runs whenever a new chunk is received.
2626726279 * Callback is called whenever there are no more chunks buffering,
@@ -26350,6 +26362,10 @@ class ByteParser extends Writable {
2635026362 if (payloadLength <= 125) {
2635126363 this.#info.payloadLength = payloadLength
2635226364 this.#state = parserStates.READ_DATA
26365+
26366+ if (!this.#validatePayloadLength()) {
26367+ return
26368+ }
2635326369 } else if (payloadLength === 126) {
2635426370 this.#state = parserStates.PAYLOADLENGTH_16
2635526371 } else if (payloadLength === 127) {
@@ -26374,6 +26390,10 @@ class ByteParser extends Writable {
2637426390
2637526391 this.#info.payloadLength = buffer.readUInt16BE(0)
2637626392 this.#state = parserStates.READ_DATA
26393+
26394+ if (!this.#validatePayloadLength()) {
26395+ return
26396+ }
2637726397 } else if (this.#state === parserStates.PAYLOADLENGTH_64) {
2637826398 if (this.#byteOffset < 8) {
2637926399 return callback()
@@ -26396,6 +26416,10 @@ class ByteParser extends Writable {
2639626416
2639726417 this.#info.payloadLength = lower
2639826418 this.#state = parserStates.READ_DATA
26419+
26420+ if (!this.#validatePayloadLength()) {
26421+ return
26422+ }
2639926423 } else if (this.#state === parserStates.READ_DATA) {
2640026424 if (this.#byteOffset < this.#info.payloadLength) {
2640126425 return callback()
@@ -26408,42 +26432,53 @@ class ByteParser extends Writable {
2640826432 this.#state = parserStates.INFO
2640926433 } else {
2641026434 if (!this.#info.compressed) {
26411- this.#fragments.push(body)
26435+ this.writeFragments(body)
26436+
26437+ if (this.#maxPayloadSize > 0 && this.#fragmentsBytes > this.#maxPayloadSize) {
26438+ failWebsocketConnection(this.ws, new MessageSizeExceededError().message)
26439+ return
26440+ }
2641226441
2641326442 // If the frame is not fragmented, a message has been received.
2641426443 // If the frame is fragmented, it will terminate with a fin bit set
2641526444 // and an opcode of 0 (continuation), therefore we handle that when
2641626445 // parsing continuation frames, not here.
2641726446 if (!this.#info.fragmented && this.#info.fin) {
26418- const fullMessage = Buffer.concat(this.#fragments)
26419- websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage)
26420- this.#fragments.length = 0
26447+ websocketMessageReceived(this.ws, this.#info.binaryType, this.consumeFragments())
2642126448 }
2642226449
2642326450 this.#state = parserStates.INFO
2642426451 } else {
26425- this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => {
26426- if (error) {
26427- failWebsocketConnection(this.ws, error.message)
26428- return
26429- }
26452+ this.#extensions.get('permessage-deflate').decompress(
26453+ body,
26454+ this.#info.fin,
26455+ (error, data) => {
26456+ if (error) {
26457+ failWebsocketConnection(this.ws, error.message)
26458+ return
26459+ }
2643026460
26431- this.#fragments.push(data)
26461+ this.writeFragments(data)
26462+
26463+ if (this.#maxPayloadSize > 0 && this.#fragmentsBytes > this.#maxPayloadSize) {
26464+ failWebsocketConnection(this.ws, new MessageSizeExceededError().message)
26465+ return
26466+ }
26467+
26468+ if (!this.#info.fin) {
26469+ this.#state = parserStates.INFO
26470+ this.#loop = true
26471+ this.run(callback)
26472+ return
26473+ }
26474+
26475+ websocketMessageReceived(this.ws, this.#info.binaryType, this.consumeFragments())
2643226476
26433- if (!this.#info.fin) {
26434- this.#state = parserStates.INFO
2643526477 this.#loop = true
26478+ this.#state = parserStates.INFO
2643626479 this.run(callback)
26437- return
2643826480 }
26439-
26440- websocketMessageReceived(this.ws, this.#info.binaryType, Buffer.concat(this.#fragments))
26441-
26442- this.#loop = true
26443- this.#state = parserStates.INFO
26444- this.#fragments.length = 0
26445- this.run(callback)
26446- })
26481+ )
2644726482
2644826483 this.#loop = false
2644926484 break
@@ -26495,6 +26530,26 @@ class ByteParser extends Writable {
2649526530 return buffer
2649626531 }
2649726532
26533+ writeFragments (fragment) {
26534+ this.#fragmentsBytes += fragment.length
26535+ this.#fragments.push(fragment)
26536+ }
26537+
26538+ consumeFragments () {
26539+ const fragments = this.#fragments
26540+
26541+ if (fragments.length === 1) {
26542+ this.#fragmentsBytes = 0
26543+ return fragments.shift()
26544+ }
26545+
26546+ const output = Buffer.concat(fragments, this.#fragmentsBytes)
26547+ this.#fragments = []
26548+ this.#fragmentsBytes = 0
26549+
26550+ return output
26551+ }
26552+
2649826553 parseCloseBody (data) {
2649926554 assert(data.length !== 1)
2650026555
@@ -27526,7 +27581,11 @@ class WebSocket extends EventTarget {
2752627581 // once this happens, the connection is open
2752727582 this[kResponse] = response
2752827583
27529- const parser = new ByteParser(this, parsedExtensions)
27584+ const maxPayloadSize = this[kController]?.dispatcher?.webSocketOptions?.maxPayloadSize
27585+
27586+ const parser = new ByteParser(this, parsedExtensions, {
27587+ maxPayloadSize
27588+ })
2753027589 parser.on('drain', onParserDrain)
2753127590 parser.on('error', onParserError.bind(this))
2753227591
0 commit comments