Skip to content

Commit 028c50b

Browse files
authored
feat: IAM support for GDB (#618)
1 parent 58e445b commit 028c50b

20 files changed

Lines changed: 652 additions & 618 deletions

common/lib/authentication/aws_secrets_manager_plugin.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export class AwsSecretsManagerPlugin extends AbstractConnectionPlugin implements
3737
private static readonly TELEMETRY_UPDATE_SECRETS = "fetch credentials";
3838
private static readonly TELEMETRY_FETCH_CREDENTIALS_COUNTER = "secretsManager.fetchCredentials.count";
3939
private static SUBSCRIBED_METHODS: Set<string> = new Set<string>(["connect", "forceConnect"]);
40-
private static SECRETS_ARN_PATTERN: RegExp = new RegExp("^arn:aws:secretsmanager:(?<region>[^:\\n]*):[^:\\n]*:([^:/\\n]*[:/])?(.*)$");
40+
private static SECRETS_ARN_PATTERN: RegExp = new RegExp("^arn:aws[^:]*:secretsmanager:(?<region>[^:\\n]*):[^:\\n]*:([^:/\\n]*[:/])?(.*)$");
4141
private readonly pluginService: PluginService;
4242
private readonly fetchCredentialsCounter;
4343
private readonly expirationSec: number;

common/lib/authentication/iam_authentication_plugin.ts

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,26 @@ import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils";
2626
import { ClientWrapper } from "../client_wrapper";
2727
import { RegionUtils } from "../utils/region_utils";
2828
import { CanReleaseResources } from "../can_release_resources";
29+
import { RdsUrlType } from "../utils/rds_url_type";
30+
import { RdsUtils } from "../utils/rds_utils";
31+
import { GDBRegionUtils } from "../utils/gdb_region_utils";
2932

3033
export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements CanReleaseResources {
3134
private static readonly SUBSCRIBED_METHODS = new Set<string>(["connect", "forceConnect"]);
3235
protected static readonly tokenCache = new Map<string, TokenInfo>();
3336
private readonly telemetryFactory;
3437
private readonly fetchTokenCounter;
35-
private pluginService: PluginService;
38+
private readonly pluginService: PluginService;
39+
private readonly rdsUtils: RdsUtils = new RdsUtils();
40+
protected regionUtils: RegionUtils;
41+
protected readonly iamAuthUtils: IamAuthUtils;
3642

37-
constructor(pluginService: PluginService) {
43+
constructor(pluginService: PluginService, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) {
3844
super();
3945
this.pluginService = pluginService;
4046
this.telemetryFactory = this.pluginService.getTelemetryFactory();
4147
this.fetchTokenCounter = this.telemetryFactory.createCounter("iam.fetchTokenCount");
48+
this.iamAuthUtils = iamAuthUtils;
4249
}
4350

4451
getSubscribedMethods(): Set<string> {
@@ -74,14 +81,22 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements
7481
throw new AwsWrapperError(`${WrapperProperties.USER} is null or empty`);
7582
}
7683

77-
const host = IamAuthUtils.getIamHost(props, hostInfo);
78-
const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);
79-
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);
84+
const host = this.iamAuthUtils.getIamHost(props, hostInfo);
85+
const port = this.iamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);
86+
87+
const type: RdsUrlType = this.rdsUtils.identifyRdsType(host.host);
88+
this.regionUtils = type == RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER ? new GDBRegionUtils() : new RegionUtils();
89+
const region: string | null = await this.regionUtils.getRegion(WrapperProperties.IAM_REGION.name, host, props);
90+
91+
if (!region) {
92+
throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unableToDetermineRegion", WrapperProperties.IAM_REGION.name));
93+
}
94+
8095
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
8196
if (tokenExpirationSec < 0) {
8297
throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero"));
8398
}
84-
const cacheKey: string = IamAuthUtils.getCacheKey(port, user, host, region);
99+
const cacheKey: string = this.iamAuthUtils.getCacheKey(port, user, host.host, region);
85100

86101
const tokenInfo = IamAuthenticationPlugin.tokenCache.get(cacheKey);
87102
const isCachedToken: boolean = tokenInfo !== undefined && !tokenInfo.isExpired();
@@ -91,8 +106,8 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements
91106
WrapperProperties.PASSWORD.set(props, tokenInfo.token);
92107
} else {
93108
const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000;
94-
const token = await IamAuthUtils.generateAuthenticationToken(
95-
host,
109+
const token = await this.iamAuthUtils.generateAuthenticationToken(
110+
host.host,
96111
port,
97112
region,
98113
user,
@@ -118,8 +133,8 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements
118133
// Try to generate a new token and try to connect again
119134

120135
const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000;
121-
const token = await IamAuthUtils.generateAuthenticationToken(
122-
host,
136+
const token = await this.iamAuthUtils.generateAuthenticationToken(
137+
host.host,
123138
port,
124139
region,
125140
user,

common/lib/plugin_manager.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import { TelemetryTraceLevel } from "./utils/telemetry/telemetry_trace_level";
3232
import { ConnectionProvider } from "./connection_provider";
3333
import { ConnectionPluginFactory } from "./plugin_factory";
3434
import { ConfigurationProfile } from "./profile/configuration_profile";
35+
import { BaseSamlAuthPlugin } from "./plugins/federated_auth/saml_auth_plugin";
3536

3637
type PluginFunc<T> = (plugin: ConnectionPlugin, targetFunc: () => Promise<T>) => Promise<T>;
3738

common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,14 @@ import { sleep } from "../../utils/utils";
3131
import { CustomEndpointMonitor, CustomEndpointMonitorImpl } from "./custom_endpoint_monitor_impl";
3232
import { SubscribedMethodHelper } from "../../utils/subscribed_method_helper";
3333
import { CanReleaseResources } from "../../can_release_resources";
34+
import { RdsUrlType } from "../../utils/rds_url_type";
35+
import { GDBRegionUtils } from "../../utils/gdb_region_utils";
3436

3537
export class CustomEndpointPlugin extends AbstractConnectionPlugin implements CanReleaseResources {
3638
private static readonly TELEMETRY_WAIT_FOR_INFO_COUNTER = "customEndpoint.waitForInfo.counter";
3739
private static SUBSCRIBED_METHODS: Set<string> = new Set<string>(SubscribedMethodHelper.NETWORK_BOUND_METHODS);
3840
private static readonly CACHE_CLEANUP_NANOS = BigInt(60_000_000_000);
41+
private static readonly regionUtils: RegionUtils = new RegionUtils();
3942

4043
private static readonly rdsUtils = new RdsUtils();
4144
protected static readonly monitors: SlidingExpirationCache<string, CustomEndpointMonitor> = new SlidingExpirationCache(
@@ -106,7 +109,7 @@ export class CustomEndpointPlugin extends AbstractConnectionPlugin implements Ca
106109
throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.errorParsingEndpointIdentifier", this.customEndpointHostInfo.host));
107110
}
108111

109-
this.region = RegionUtils.getRegion(props.get(WrapperProperties.CUSTOM_ENDPOINT_REGION.name), this.customEndpointHostInfo.host);
112+
this.region = await CustomEndpointPlugin.regionUtils.getRegion(WrapperProperties.CUSTOM_ENDPOINT_REGION.name, this.customEndpointHostInfo, props);
110113
if (!this.region) {
111114
throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.unableToDetermineRegion", WrapperProperties.CUSTOM_ENDPOINT_REGION.name));
112115
}

common/lib/plugins/federated_auth/credentials_provider_factory.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,9 @@
1717
import { AwsCredentialIdentity, AwsCredentialIdentityProvider } from "@smithy/types/dist-types/identity/awsCredentialIdentity";
1818

1919
export interface CredentialsProviderFactory {
20-
getAwsCredentialsProvider(host: string, region: string, props: Map<string, any>): Promise<AwsCredentialIdentity | AwsCredentialIdentityProvider>;
20+
getAwsCredentialsProvider(
21+
host: string,
22+
region: string | null,
23+
props: Map<string, any>
24+
): Promise<AwsCredentialIdentity | AwsCredentialIdentityProvider>;
2125
}

common/lib/plugins/federated_auth/federated_auth_plugin.ts

Lines changed: 5 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -14,118 +14,13 @@
1414
limitations under the License.
1515
*/
1616

17-
import { AbstractConnectionPlugin } from "../../abstract_connection_plugin";
1817
import { PluginService } from "../../plugin_service";
19-
import { RdsUtils } from "../../utils/rds_utils";
20-
import { HostInfo } from "../../host_info";
21-
import { IamAuthUtils, TokenInfo } from "../../utils/iam_auth_utils";
22-
import { WrapperProperties } from "../../wrapper_property";
23-
import { logger } from "../../../logutils";
24-
import { AwsWrapperError } from "../../utils/errors";
25-
import { Messages } from "../../utils/messages";
2618
import { CredentialsProviderFactory } from "./credentials_provider_factory";
27-
import { SamlUtils } from "../../utils/saml_utils";
28-
import { ClientWrapper } from "../../client_wrapper";
29-
import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter";
30-
import { RegionUtils } from "../../utils/region_utils";
31-
import { CanReleaseResources } from "../../can_release_resources";
19+
import { BaseSamlAuthPlugin } from "./saml_auth_plugin";
20+
import { IamAuthUtils } from "../../utils/iam_auth_utils";
3221

33-
export class FederatedAuthPlugin extends AbstractConnectionPlugin implements CanReleaseResources {
34-
protected static readonly tokenCache = new Map<string, TokenInfo>();
35-
protected rdsUtils: RdsUtils = new RdsUtils();
36-
protected pluginService: PluginService;
37-
private static readonly subscribedMethods = new Set<string>(["connect", "forceConnect"]);
38-
private readonly credentialsProviderFactory: CredentialsProviderFactory;
39-
private readonly fetchTokenCounter: TelemetryCounter;
40-
41-
public getSubscribedMethods(): Set<string> {
42-
return FederatedAuthPlugin.subscribedMethods;
43-
}
44-
45-
constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory) {
46-
super();
47-
this.credentialsProviderFactory = credentialsProviderFactory;
48-
this.pluginService = pluginService;
49-
this.fetchTokenCounter = this.pluginService.getTelemetryFactory().createCounter("federatedAuth.fetchToken.count");
50-
}
51-
52-
connect(
53-
hostInfo: HostInfo,
54-
props: Map<string, any>,
55-
isInitialConnection: boolean,
56-
connectFunc: () => Promise<ClientWrapper>
57-
): Promise<ClientWrapper> {
58-
return this.connectInternal(hostInfo, props, connectFunc);
59-
}
60-
61-
forceConnect(
62-
hostInfo: HostInfo,
63-
props: Map<string, any>,
64-
isInitialConnection: boolean,
65-
forceConnectFunc: () => Promise<ClientWrapper>
66-
): Promise<ClientWrapper> {
67-
return this.connectInternal(hostInfo, props, forceConnectFunc);
68-
}
69-
70-
async connectInternal(hostInfo: HostInfo, props: Map<string, any>, connectFunc: () => Promise<ClientWrapper>): Promise<ClientWrapper> {
71-
SamlUtils.checkIdpCredentialsWithFallback(props);
72-
73-
const host = IamAuthUtils.getIamHost(props, hostInfo);
74-
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort());
75-
const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);
76-
77-
const cacheKey = IamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host, region);
78-
const tokenInfo = FederatedAuthPlugin.tokenCache.get(cacheKey);
79-
80-
const isCachedToken: boolean = tokenInfo !== undefined && !tokenInfo.isExpired();
81-
82-
if (isCachedToken && tokenInfo) {
83-
logger.debug(Messages.get("AuthenticationToken.useCachedToken", tokenInfo.token));
84-
WrapperProperties.PASSWORD.set(props, tokenInfo.token);
85-
} else {
86-
await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host);
87-
}
88-
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props));
89-
this.pluginService.updateConfigWithProperties(props);
90-
91-
try {
92-
return await connectFunc();
93-
} catch (e) {
94-
if (!this.pluginService.isLoginError(e as Error) || !isCachedToken) {
95-
throw e;
96-
}
97-
try {
98-
await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host);
99-
return await connectFunc();
100-
} catch (e: any) {
101-
throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unhandledError", e.message));
102-
}
103-
}
104-
}
105-
106-
public async updateAuthenticationToken(hostInfo: HostInfo, props: Map<string, any>, region: string, cacheKey: string, iamHost: string) {
107-
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
108-
if (tokenExpirationSec < 0) {
109-
throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero"));
110-
}
111-
const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000;
112-
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort());
113-
const token = await IamAuthUtils.generateAuthenticationToken(
114-
iamHost,
115-
port,
116-
region,
117-
WrapperProperties.DB_USER.get(props),
118-
await this.credentialsProviderFactory.getAwsCredentialsProvider(hostInfo.host, region, props),
119-
this.pluginService
120-
);
121-
this.fetchTokenCounter.inc();
122-
logger.debug(Messages.get("AuthenticationToken.generatedNewToken", token));
123-
WrapperProperties.PASSWORD.set(props, token);
124-
FederatedAuthPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry));
125-
}
126-
127-
releaseResources(): Promise<void> {
128-
FederatedAuthPlugin.tokenCache.clear();
129-
return;
22+
export class FederatedAuthPlugin extends BaseSamlAuthPlugin {
23+
constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) {
24+
super(pluginService, credentialsProviderFactory, "federatedAuth.fetchToken.count", iamAuthUtils);
13025
}
13126
}

common/lib/plugins/federated_auth/okta_auth_plugin.ts

Lines changed: 5 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -14,120 +14,13 @@
1414
limitations under the License.
1515
*/
1616

17-
import { AbstractConnectionPlugin } from "../../abstract_connection_plugin";
18-
import { HostInfo } from "../../host_info";
19-
import { SamlUtils } from "../../utils/saml_utils";
20-
import { IamAuthUtils, TokenInfo } from "../../utils/iam_auth_utils";
2117
import { PluginService } from "../../plugin_service";
2218
import { CredentialsProviderFactory } from "./credentials_provider_factory";
23-
import { RdsUtils } from "../../utils/rds_utils";
24-
import { WrapperProperties } from "../../wrapper_property";
25-
import { logger } from "../../../logutils";
26-
import { Messages } from "../../utils/messages";
27-
import { AwsWrapperError } from "../../utils/errors";
28-
import { ClientWrapper } from "../../client_wrapper";
29-
import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter";
30-
import { RegionUtils } from "../../utils/region_utils";
31-
import { CanReleaseResources } from "../../can_release_resources";
19+
import { BaseSamlAuthPlugin } from "./saml_auth_plugin";
20+
import { IamAuthUtils } from "../../utils/iam_auth_utils";
3221

33-
export class OktaAuthPlugin extends AbstractConnectionPlugin implements CanReleaseResources {
34-
protected static readonly tokenCache = new Map<string, TokenInfo>();
35-
private static readonly subscribedMethods = new Set<string>(["connect", "forceConnect"]);
36-
protected pluginService: PluginService;
37-
protected rdsUtils = new RdsUtils();
38-
private readonly credentialsProviderFactory: CredentialsProviderFactory;
39-
private readonly fetchTokenCounter: TelemetryCounter;
40-
41-
constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory) {
42-
super();
43-
this.pluginService = pluginService;
44-
this.credentialsProviderFactory = credentialsProviderFactory;
45-
this.fetchTokenCounter = this.pluginService.getTelemetryFactory().createCounter("oktaAuth.fetchToken.count");
46-
}
47-
48-
public getSubscribedMethods(): Set<string> {
49-
return OktaAuthPlugin.subscribedMethods;
50-
}
51-
52-
connect(
53-
hostInfo: HostInfo,
54-
props: Map<string, any>,
55-
isInitialConnection: boolean,
56-
connectFunc: () => Promise<ClientWrapper>
57-
): Promise<ClientWrapper> {
58-
return this.connectInternal(hostInfo, props, connectFunc);
59-
}
60-
61-
forceConnect(
62-
hostInfo: HostInfo,
63-
props: Map<string, any>,
64-
isInitialConnection: boolean,
65-
connectFunc: () => Promise<ClientWrapper>
66-
): Promise<ClientWrapper> {
67-
return this.connectInternal(hostInfo, props, connectFunc);
68-
}
69-
70-
async connectInternal(hostInfo: HostInfo, props: Map<string, any>, connectFunc: () => Promise<ClientWrapper>): Promise<ClientWrapper> {
71-
SamlUtils.checkIdpCredentialsWithFallback(props);
72-
73-
const host = IamAuthUtils.getIamHost(props, hostInfo);
74-
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort());
75-
const region = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);
76-
77-
const cacheKey = IamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host, region);
78-
const tokenInfo = OktaAuthPlugin.tokenCache.get(cacheKey);
79-
80-
const isCachedToken = tokenInfo !== undefined && !tokenInfo.isExpired();
81-
82-
if (isCachedToken) {
83-
logger.debug(Messages.get("AuthenticationToken.useCachedToken", tokenInfo.token));
84-
WrapperProperties.PASSWORD.set(props, tokenInfo.token);
85-
} else {
86-
await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host);
87-
}
88-
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props));
89-
this.pluginService.updateConfigWithProperties(props);
90-
91-
try {
92-
return await connectFunc();
93-
} catch (e: any) {
94-
if (!this.pluginService.isLoginError(e as Error) || !isCachedToken) {
95-
logger.debug(Messages.get("Authentication.connectError", e.message));
96-
throw e;
97-
}
98-
try {
99-
await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host);
100-
return await connectFunc();
101-
} catch (e: any) {
102-
throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unhandledError", e.message));
103-
}
104-
}
105-
}
106-
107-
public async updateAuthenticationToken(hostInfo: HostInfo, props: Map<string, any>, region: string, cacheKey: string, iamHost): Promise<void> {
108-
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
109-
if (tokenExpirationSec < 0) {
110-
throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero"));
111-
}
112-
const tokenExpiry = Date.now() + tokenExpirationSec * 1000;
113-
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort());
114-
this.fetchTokenCounter.inc();
115-
const token = await IamAuthUtils.generateAuthenticationToken(
116-
iamHost,
117-
port,
118-
region,
119-
WrapperProperties.DB_USER.get(props),
120-
await this.credentialsProviderFactory.getAwsCredentialsProvider(hostInfo.host, region, props),
121-
this.pluginService
122-
);
123-
logger.debug(Messages.get("AuthenticationToken.generatedNewToken", token));
124-
WrapperProperties.PASSWORD.set(props, token);
125-
this.pluginService.updateConfigWithProperties(props);
126-
OktaAuthPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry));
127-
}
128-
129-
releaseResources(): Promise<void> {
130-
OktaAuthPlugin.tokenCache.clear();
131-
return;
22+
export class OktaAuthPlugin extends BaseSamlAuthPlugin {
23+
constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) {
24+
super(pluginService, credentialsProviderFactory, "oktaAuth.fetchToken.count", iamAuthUtils);
13225
}
13326
}

0 commit comments

Comments
 (0)