Skip to content

Commit 443921a

Browse files
committed
small optimization for team checks
1 parent 3181f77 commit 443921a

2 files changed

Lines changed: 34 additions & 22 deletions

File tree

apps/sim/lib/billing/threshold-billing.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,10 @@ export async function checkAndBillOverageThreshold(userId: string): Promise<void
171171

172172
const stripe = requireStripeClient()
173173
const stripeSubscription = await stripe.subscriptions.retrieve(stripeSubscriptionId)
174-
const customerId = stripeSubscription.customer as string
174+
const customerId =
175+
typeof stripeSubscription.customer === 'string'
176+
? stripeSubscription.customer
177+
: stripeSubscription.customer.id
175178

176179
const periodEnd = userSubscription.periodEnd
177180
? Math.floor(userSubscription.periodEnd.getTime() / 1000)
@@ -337,7 +340,10 @@ export async function checkAndBillOrganizationOverageThreshold(
337340

338341
const stripe = requireStripeClient()
339342
const stripeSubscription = await stripe.subscriptions.retrieve(stripeSubscriptionId)
340-
const customerId = stripeSubscription.customer as string
343+
const customerId =
344+
typeof stripeSubscription.customer === 'string'
345+
? stripeSubscription.customer
346+
: stripeSubscription.customer.id
341347

342348
const periodEnd = orgSubscription.periodEnd
343349
? Math.floor(orgSubscription.periodEnd.getTime() / 1000)

apps/sim/lib/billing/webhooks/invoices.ts

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { db } from '@sim/db'
22
import { member, subscription as subscriptionTable, userStats } from '@sim/db/schema'
3-
import { eq } from 'drizzle-orm'
3+
import { eq, inArray } from 'drizzle-orm'
44
import type Stripe from 'stripe'
55
import { calculateSubscriptionOverage } from '@/lib/billing/core/billing'
66
import { requireStripeClient } from '@/lib/billing/stripe-client'
@@ -36,15 +36,19 @@ export async function getBilledOverageForSubscription(sub: {
3636
.from(member)
3737
.where(eq(member.organizationId, sub.referenceId))
3838

39-
for (const m of members) {
40-
const memberStats = await db
41-
.select({ billedOverageThisPeriod: userStats.billedOverageThisPeriod })
39+
const memberIds = members.map((m) => m.userId)
40+
41+
if (memberIds.length > 0) {
42+
const memberStatsRows = await db
43+
.select({
44+
userId: userStats.userId,
45+
billedOverageThisPeriod: userStats.billedOverageThisPeriod,
46+
})
4247
.from(userStats)
43-
.where(eq(userStats.userId, m.userId))
44-
.limit(1)
48+
.where(inArray(userStats.userId, memberIds))
4549

46-
if (memberStats.length > 0) {
47-
billedOverage += parseDecimal(memberStats[0].billedOverageThisPeriod)
50+
for (const stats of memberStatsRows) {
51+
billedOverage += parseDecimal(stats.billedOverageThisPeriod)
4852
}
4953
}
5054
} else {
@@ -147,16 +151,14 @@ export async function handleInvoicePaymentSucceeded(event: Stripe.Event) {
147151
.select({ userId: member.userId })
148152
.from(member)
149153
.where(eq(member.organizationId, sub.referenceId))
150-
for (const m of membersRows) {
151-
const row = await db
154+
const memberIds = membersRows.map((m) => m.userId)
155+
if (memberIds.length > 0) {
156+
const blockedRows = await db
152157
.select({ blocked: userStats.billingBlocked })
153158
.from(userStats)
154-
.where(eq(userStats.userId, m.userId))
155-
.limit(1)
156-
if (row.length > 0 && row[0].blocked) {
157-
wasBlocked = true
158-
break
159-
}
159+
.where(inArray(userStats.userId, memberIds))
160+
161+
wasBlocked = blockedRows.some((row) => !!row.blocked)
160162
}
161163
} else {
162164
const row = await db
@@ -172,11 +174,13 @@ export async function handleInvoicePaymentSucceeded(event: Stripe.Event) {
172174
.select({ userId: member.userId })
173175
.from(member)
174176
.where(eq(member.organizationId, sub.referenceId))
175-
for (const m of members) {
177+
const memberIds = members.map((m) => m.userId)
178+
179+
if (memberIds.length > 0) {
176180
await db
177181
.update(userStats)
178182
.set({ billingBlocked: false })
179-
.where(eq(userStats.userId, m.userId))
183+
.where(inArray(userStats.userId, memberIds))
180184
}
181185
} else {
182186
await db
@@ -263,11 +267,13 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) {
263267
.select({ userId: member.userId })
264268
.from(member)
265269
.where(eq(member.organizationId, sub.referenceId))
266-
for (const m of members) {
270+
const memberIds = members.map((m) => m.userId)
271+
272+
if (memberIds.length > 0) {
267273
await db
268274
.update(userStats)
269275
.set({ billingBlocked: true })
270-
.where(eq(userStats.userId, m.userId))
276+
.where(inArray(userStats.userId, memberIds))
271277
}
272278
logger.info('Blocked team/enterprise members due to payment failure', {
273279
organizationId: sub.referenceId,

0 commit comments

Comments
 (0)