Skip to content

Commit

Permalink
Merge pull request #5001 from Shopify/catlee/refresh_graphql
Browse files Browse the repository at this point in the history
Refresh session on 401 for graphQL requests
  • Loading branch information
catlee authored Dec 2, 2024
2 parents c202990 + 80ec997 commit 15b2a6f
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 3 deletions.
52 changes: 52 additions & 0 deletions packages/cli-kit/src/private/node/api.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ describe('retryAwareRequest', () => {
url: 'https://example.com',
},
undefined,
undefined,
{
defaultDelayMs: 500,
scheduleDelay: mockScheduleDelayFn,
Expand Down Expand Up @@ -100,6 +101,7 @@ describe('retryAwareRequest', () => {
url: 'https://example.com',
},
undefined,
undefined,
{
limitRetriesTo: 7,
scheduleDelay: mockScheduleDelayFn,
Expand All @@ -110,4 +112,54 @@ describe('retryAwareRequest', () => {
expect(mockRequestFn).toHaveBeenCalledTimes(8)
expect(mockScheduleDelayFn).toHaveBeenCalledTimes(7)
})

test('calls unauthorizedHandler when receiving 401', async () => {
const unauthorizedResponse = {
status: 401,
errors: [
{
extensions: {
code: '401',
},
} as any,
],
headers: new Headers(),
}

const mockRequestFn = vi
.fn()
.mockImplementationOnce(() => {
throw new ClientError(unauthorizedResponse, {query: ''})
})
.mockImplementationOnce(() => {
return Promise.resolve({
status: 200,
data: {hello: 'world!'},
headers: new Headers(),
})
})

const mockUnauthorizedHandler = vi.fn()

await expect(
retryAwareRequest(
{
request: mockRequestFn,
url: 'https://example.com',
},
undefined,
mockUnauthorizedHandler,
{
scheduleDelay: vi.fn((fn) => fn()),
},
),
).resolves.toEqual({
headers: expect.anything(),
status: 200,
data: {hello: 'world!'},
})

expect(mockRequestFn).toHaveBeenCalledTimes(2)
expect(mockUnauthorizedHandler).toHaveBeenCalledTimes(1)
})
})
26 changes: 26 additions & 0 deletions packages/cli-kit/src/private/node/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type VerboseResponse<T> = {
| {status: 'client-error'; clientError: ClientError}
| {status: 'unknown-error'; error: unknown}
| {status: 'can-retry'; clientError: ClientError; delayMs: number | undefined}
| {status: 'unauthorized'; clientError: ClientError; delayMs: number | undefined}
)

async function makeVerboseRequest<T extends {headers: Headers; status: number}>({
Expand Down Expand Up @@ -88,6 +89,16 @@ async function makeVerboseRequest<T extends {headers: Headers; status: number}>(
requestId: responseHeaders['x-request-id'],
delayMs,
}
} else if (err.response.status === 401) {
return {
status: 'unauthorized',
clientError: err,
duration,
sanitizedHeaders,
sanitizedUrl,
requestId: responseHeaders['x-request-id'],
delayMs: 500,
}
}

return {
Expand Down Expand Up @@ -169,12 +180,20 @@ ${result.sanitizedHeaders}
throw result.clientError
}
}
case 'unauthorized': {
if (errorHandler) {
throw errorHandler(result.clientError, result.requestId)
} else {
throw result.clientError
}
}
}
}

export async function retryAwareRequest<T extends {headers: Headers; status: number}>(
{request, url}: RequestOptions<T>,
errorHandler?: (error: unknown, requestId: string | undefined) => unknown,
unauthorizedHandler?: () => Promise<void>,
retryOptions: {
limitRetriesTo?: number
defaultDelayMs?: number
Expand Down Expand Up @@ -211,6 +230,13 @@ ${result.sanitizedHeaders}
} else {
throw result.error
}
} else if (result.status === 'unauthorized') {
if (unauthorizedHandler) {
// eslint-disable-next-line no-await-in-loop
await unauthorizedHandler()
} else {
throw result.clientError
}
}

if (limitRetriesTo <= retriesUsed) {
Expand Down
12 changes: 11 additions & 1 deletion packages/cli-kit/src/public/node/api/admin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,17 @@ export async function adminRequestDoc<TResult, TVariables extends Variables>(
token: session.token,
addedHeaders,
}
const result = graphqlRequestDoc<TResult, TVariables>({...opts, query, variables, responseOptions})
let unauthorizedHandler
if ('refresh' in session) {
unauthorizedHandler = session.refresh as () => Promise<void>
}
const result = graphqlRequestDoc<TResult, TVariables>({
...opts,
query,
variables,
responseOptions,
unauthorizedHandler,
})
return result
}

Expand Down
3 changes: 2 additions & 1 deletion packages/cli-kit/src/public/node/api/graphql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ describe('graphqlRequest', () => {
request: expect.any(Function),
url: mockedAddress,
}
expect(retryAwareRequest).toHaveBeenCalledWith(receivedObject, expect.any(Function))
expect(retryAwareRequest).toHaveBeenCalledWith(receivedObject, expect.any(Function), undefined)
})
})

Expand Down Expand Up @@ -95,6 +95,7 @@ describe('graphqlRequestDoc', () => {
url: mockedAddress,
},
expect.any(Function),
undefined,
)
expect(debugRequest.debugLogRequestInfo).toHaveBeenCalledWith(
'mockApi',
Expand Down
6 changes: 5 additions & 1 deletion packages/cli-kit/src/public/node/api/graphql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@ interface GraphQLRequestBaseOptions<TResult> {
type PerformGraphQLRequestOptions<TResult> = GraphQLRequestBaseOptions<TResult> & {
queryAsString: string
variables?: Variables
unauthorizedHandler?: () => Promise<void>
}

export type GraphQLRequestOptions<T> = GraphQLRequestBaseOptions<T> & {
query: RequestDocument
variables?: Variables
unauthorizedHandler?: () => Promise<void>
}

export type GraphQLRequestDocOptions<TResult, TVariables> = GraphQLRequestBaseOptions<TResult> & {
query: TypedDocumentNode<TResult, TVariables> | TypedDocumentNode<TResult, Exact<{[key: string]: never}>>
variables?: TVariables
unauthorizedHandler?: () => Promise<void>
}

export interface GraphQLResponseOptions<T> {
Expand All @@ -49,7 +52,7 @@ export interface GraphQLResponseOptions<T> {
* @param options - GraphQL request options.
*/
async function performGraphQLRequest<TResult>(options: PerformGraphQLRequestOptions<TResult>) {
const {token, addedHeaders, queryAsString, variables, api, url, responseOptions} = options
const {token, addedHeaders, queryAsString, variables, api, url, responseOptions, unauthorizedHandler} = options
const headers = {
...addedHeaders,
...buildHeaders(token),
Expand All @@ -63,6 +66,7 @@ async function performGraphQLRequest<TResult>(options: PerformGraphQLRequestOpti
const response = await retryAwareRequest(
{request: () => client.rawRequest<TResult>(queryAsString, variables), url},
responseOptions?.handleErrors === false ? undefined : errorHandler(api),
unauthorizedHandler,
)

if (responseOptions?.onResponse) {
Expand Down

0 comments on commit 15b2a6f

Please sign in to comment.