创建自定义 Spring Cloud Gateway 过滤器

工程 | Fredrich Ombico | 2022 年 8 月 27 日 | ...

在本文中,我们将深入了解如何为 Spring Cloud Gateway 编写自定义扩展。在开始之前,让我们先了解 Spring Cloud Gateway 的工作原理。

Spring Cloud Gateway diagram

  1. 首先,客户端向网关发出网络请求。
  2. 网关定义了许多路由,每个路由都带有谓词,用于将请求与路由匹配。例如,您可以根据 URL 的路径段或请求的 HTTP 方法进行匹配。
  3. 匹配后,网关会对应用于路由的每个过滤器执行请求前逻辑。例如,您可能希望向请求添加查询参数。
  4. 代理过滤器将请求路由到代理服务。
  5. 服务执行并返回响应。
  6. 网关接收响应并在返回响应之前对每个过滤器执行请求后逻辑。例如,您可以在返回给客户端之前删除不需要的响应头。

我们的扩展将对请求体进行哈希处理,并将值作为名为 X-Hash 的请求头添加。这对应于上图中的步骤 3。注意:因为我们正在读取请求体,所以网关的内存将受到限制。

首先,我们在 start.spring.io 上创建一个包含 Gateway 依赖项的项目。在本例中,我们将使用 Java 中的 Gradle 项目,JDK 17 和 Spring Boot 2.7.3。下载、解压缩并在您喜欢的 IDE 中打开项目,并运行它以确保您已准备好进行本地开发。

接下来,让我们创建 GatewayFilter Factory,它是一个作用域限定于特定路由的过滤器,允许我们以某种方式修改传入的 HTTP 请求或传出的 HTTP 响应。在我们的例子中,我们将使用额外的标头修改传入的 HTTP 请求。

package com.example.demo;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;

import org.bouncycastle.util.encoders.Hex;
import reactor.core.publisher.Mono;

import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR;

/**
 * This filter hashes the request body, placing the value in the X-Hash header.
 * Note: This causes the gateway to be memory constrained.
 * Sample usage: RequestHashing=SHA-256
 */
@Component
public class RequestHashingGatewayFilterFactory extends
        AbstractGatewayFilterFactory<RequestHashingGatewayFilterFactory.Config> {

    private static final String HASH_ATTR = "hash";
    private static final String HASH_HEADER = "X-Hash";
    private final List<HttpMessageReader<?>> messageReaders =
            HandlerStrategies.withDefaults().messageReaders();

    public RequestHashingGatewayFilterFactory() {
        super(Config.class);
    }

    @Override
    public GatewayFilter apply(Config config) {
        MessageDigest digest = config.getMessageDigest();
        return (exchange, chain) -> ServerWebExchangeUtils
                .cacheRequestBodyAndRequest(exchange, (httpRequest) -> ServerRequest
                    .create(exchange.mutate().request(httpRequest).build(),
                            messageReaders)
                    .bodyToMono(String.class)
                    .doOnNext(requestPayload -> exchange
                            .getAttributes()
                            .put(HASH_ATTR, computeHash(digest, requestPayload)))
                    .then(Mono.defer(() -> {
                        ServerHttpRequest cachedRequest = exchange.getAttribute(
                                CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);
                        Assert.notNull(cachedRequest, 
                                "cache request shouldn't be null");
                        exchange.getAttributes()
                                .remove(CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);

                        String hash = exchange.getAttribute(HASH_ATTR);
                        cachedRequest = cachedRequest.mutate()
                                .header(HASH_HEADER, hash)
                                .build();
                        return chain.filter(exchange.mutate()
                                .request(cachedRequest)
                                .build());
                    })));
    }

    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList("algorithm");
    }

    private String computeHash(MessageDigest messageDigest, String requestPayload) {
        return Hex.toHexString(messageDigest.digest(requestPayload.getBytes()));
    }

    static class Config {

        private MessageDigest messageDigest;

        public MessageDigest getMessageDigest() {
            return messageDigest;
        }

        public void setAlgorithm(String algorithm) throws NoSuchAlgorithmException {
            messageDigest = MessageDigest.getInstance(algorithm);
        }
    }
}

让我们更详细地看一下代码。

  • 我们在类中添加了 @Component 注解。Spring Cloud Gateway 需要能够检测到此类才能使用它。或者,我们可以使用 @Bean 定义一个实例。
  • 在我们的类名中,我们使用 GatewayFilterFactory 作为后缀。在 application.yaml 中添加此过滤器时,我们不包括后缀,只使用 RequestHashing。这是 Spring Cloud Gateway 过滤器命名约定。
  • 我们的类还扩展了 AbstractGatewayFilterFactory,类似于所有其他 Spring Cloud Gateway 过滤器。我们还指定了一个类来配置我们的过滤器,一个名为 Config 的嵌套静态类有助于保持简单。配置类允许我们设置要使用的哈希算法。
  • 覆盖的 apply 方法是所有工作发生的地方。在参数中,我们获得了配置类的实例,在其中我们可以访问用于哈希的 MessageDigest 实例。接下来,我们看到 (exchange, chain),它是返回的 GatewayFilter 接口类的 lambda 表达式。exchange 是 ServerWebExchange 的一个实例,它为 Gateway 过滤器提供了对 HTTP 请求和响应的访问权限。对于我们的情况,我们希望修改 HTTP 请求,这需要我们修改 exchange。
  • 我们需要读取请求体才能生成哈希,但是,由于主体存储在字节缓冲区中,因此在过滤器中只能读取一次。通过使用 ServerWebExchangeUtils,我们将请求缓存为 exchange 中的一个属性。属性提供了一种在整个过滤器链中为特定请求共享数据的方式。我们还将存储计算出的请求体哈希值。
  • 我们使用 exchange 属性获取缓存的请求和计算出的哈希值。然后,我们通过在最终发送到链中的下一个过滤器之前添加哈希头来修改 exchange。
  • shortcutFieldOrder 方法有助于将参数的数量和顺序映射到过滤器。algorithm 字符串与 Config 类中的 setter 匹配。

为了测试代码,我们将使用 WireMock。将依赖项添加到您的 build.gradle 文件中。

testImplementation 'com.github.tomakehurst:wiremock:2.27.2'

这里我们有一个测试检查标头的存在和值,另一个测试检查如果没有请求体,则标头不存在。

package com.example.demo;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
import org.bouncycastle.jcajce.provider.digest.SHA512;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpStatus;
import org.springframework.test.web.reactive.server.WebTestClient;

import static com.example.demo.RequestHashingGatewayFilterFactory.*;
import static com.example.demo.RequestHashingGatewayFilterFactoryTest.*;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;

@SpringBootTest(
        webEnvironment = RANDOM_PORT,
        classes = RequestHashingFilterTestConfig.class)
@AutoConfigureWebTestClient
class RequestHashingGatewayFilterFactoryTest {

    @TestConfiguration
    static class RequestHashingFilterTestConfig {

        @Autowired
        RequestHashingGatewayFilterFactory requestHashingGatewayFilter;

        @Bean(destroyMethod = "stop")
        WireMockServer wireMockServer() {
            WireMockConfiguration options = wireMockConfig().dynamicPort();
            WireMockServer wireMock = new WireMockServer(options);
            wireMock.start();
            return wireMock;
        }

        @Bean
        RouteLocator testRoutes(RouteLocatorBuilder builder, WireMockServer wireMock)
                throws NoSuchAlgorithmException {
            Config config = new Config();
            config.setAlgorithm("SHA-512");

            GatewayFilter gatewayFilter = requestHashingGatewayFilter.apply(config);
            return builder
                    .routes()
                    .route(predicateSpec -> predicateSpec
                            .path("/post")
                            .filters(spec -> spec.filter(gatewayFilter))
                            .uri(wireMock.baseUrl()))
                    .build();
        }
    }

    @Autowired
    WebTestClient webTestClient;

    @Autowired
    WireMockServer wireMockServer;

    @AfterEach
    void afterEach() {
        wireMockServer.resetAll();
    }

    @Test
    void shouldAddHeaderWithComputedHash() {
        MessageDigest messageDigest = new SHA512.Digest();
        String body = "hello world";
        String expectedHash = Hex.toHexString(messageDigest.digest(body.getBytes()));

        wireMockServer.stubFor(WireMock.post("/post").willReturn(WireMock.ok()));

        webTestClient.post().uri("/post")
                .bodyValue(body)
                .exchange()
                .expectStatus()
                .isEqualTo(HttpStatus.OK);

        wireMockServer.verify(postRequestedFor(urlEqualTo("/post"))
                .withHeader("X-Hash", equalTo(expectedHash)));
    }

    @Test
    void shouldNotAddHeaderIfNoBody() {
        wireMockServer.stubFor(WireMock.post("/post").willReturn(WireMock.ok()));

        webTestClient.post().uri("/post")
                .exchange()
                .expectStatus()
                .isEqualTo(HttpStatus.OK);

        wireMockServer.verify(postRequestedFor(urlEqualTo("/post"))
                .withoutHeader("X-Hash"));
    }
}

为了在我们的网关中使用过滤器,我们在 application.yaml 中的路由中添加 RequestHashing 过滤器,使用 SHA-256 作为算法。

spring:
  cloud:
    gateway:
      routes:
        - id: demo
          uri: https://httpbin.org
          predicates:
            - Path=/post/**
          filters:
            - RequestHashing=SHA-256

我们使用 https://httpbin.org,因为它在返回的响应中显示了我们的请求头。运行应用程序并发出 curl 请求以查看结果。

$> curl --request POST 'https://127.0.0.1:8080/post' \
--header 'Content-Type: application/json' \
--data-raw '{
    "data": {
        "hello": "world"
    }
}'

{
  ...
  "data": "{\n    \"data\": {\n        \"hello\": \"world\"\n    }\n}",
  "headers": {
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
        "Content-Length": "48",
        "Content-Type": "application/json",
        "Forwarded": "proto=http;host=\"localhost:8080\";for=\"[0:0:0:0:0:0:0:1]:55647\"",
        "Host": "httpbin.org",
        "User-Agent": "PostmanRuntime/7.29.0",
        "X-Forwarded-Host": "localhost:8080",
        "X-Hash": "1bd93d38735501b5aec7a822f8bc8136d9f1f71a30c2020511bdd5df379772b8"
    },
  ...
}

总之,我们了解了如何为 Spring Cloud Gateway 编写自定义扩展。我们的过滤器读取请求体以生成哈希,我们将其作为请求头添加。我们还使用 WireMock 为过滤器编写了测试,以检查标头值。最后,我们运行了一个带有过滤器的网关来验证结果。

如果您计划在 Kubernetes 集群上部署 Spring Cloud Gateway,请务必查看 VMware Spring Cloud Gateway for Kubernetes。除了支持开源 Spring Cloud Gateway 过滤器和自定义过滤器(例如我们上面编写的过滤器)之外,它还提供了 更多内置过滤器 来操作您的请求和响应。Spring Cloud Gateway for Kubernetes 代表 API 开发团队处理横切关注点,例如:单点登录 (SSO)、访问控制、速率限制、弹性、安全等。

获取 Spring 新闻通讯

与 Spring 新闻通讯保持联系

订阅

领先一步

VMware 提供培训和认证,以加速您的进步。

了解更多

获得支持

Tanzu Spring 在一个简单的订阅中提供对 OpenJDK™、Spring 和 Apache Tomcat® 的支持和二进制文件。

了解更多

即将举行的活动

查看 Spring 社区中所有即将举行的活动。

查看全部