跳到主要内容

自定义中间件

前言

FastAPI内置的中间件未必满足所有业务场景需求,因此还需要自定义中间件。

基本示例

将中间件的代码放到单独的模块中,编辑middlewares/demo.py

from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request
import time

class TimeCalculate(BaseHTTPMiddleware):
# dispatch必须实现
async def dispatch(self, request: Request, call_next):
print(">>> TimeCalculate Middleware <<<")
start_time = time.time()
resp = await call_next(request)
elapsed_time = round(time.time() - start_time, 4)
print(f"URL: {request.url}, Elapsed time: {elapsed_time}s")
return resp

引入自定义中间件

from middlewares import demo

app.add_middleware(demo.TimeCalculate)

日志追踪链路ID

通过链路追踪,可以在一个请求服务过程中把涉及多个的能源服务或其他第三方请求的日志都关联起来,这样可以快速进行问题定位及排错。要实现链路追踪,就需要为请求打标签。一般的做法是,每进来一个请求,就在当前请求上下文中生成一个链路ID,这个链路ID可以关联第三方请求或其他服务日志,并在整个请求链路上下文中进行传递,直到请求完成响应处理。

import contextvars
import uuid

# 定义一个上下文变量对象,主要用于存储当前请求的上下文信息
request_context = contextvars.ContextVar("request_context")

class TraceID(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
request_context.set(request) # 将当前请求传入当前request_context的上下文变量对象中
request.state.traceid = uuid.uuid4() # 生成traceid,并写入request.state中
resp = await call_next(request)
return resp


# 创建一个视图函数测试
@app.get("/")
async def get_response():
request: Request = request_context.get()
print(f"index-request: {request.state.traceid}")
return {"message": "Hello World"}

# 添加到中间件
app.add_middleware(TraceID)

自定义类实现中间件

除了通过集成BaseHTTPMiddleware类来自定义中间件外,还可以基于自定义类来实现。

from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.requests import HTTPConnection
import typing


# 自定义类实现黑名单IP中间件
class BlackIPMiddleware(BaseHTTPMiddleware):
# ASGIApp对象必须要有,其它参数根据实际需求设置
def __init__(self, app: ASGIApp, denied_ip: typing.Sequence[str] = ()):
self.app = app
self.denied_ip = denied_ip
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope['type'] in ("http", "websocket") and scope["scheme"] in ("http", "ws"):
conn = HTTPConnection(scope=scope)
if self.denied_ip and conn.client.host in self.denied_ip:
resp = JSONResponse({'message': 'IP Denied'}, status_code=403)
await resp(scope, receive, send)
return
await self.app(scope, receive, send)
else:
await self.app(scope, receive, send)

# 注册到app实例
app.add_middleware(BlackIPMiddleware, denied_ip = ["192.168.1.108"])

基于中间件获取响应内容

在某种场景下,需要在中间件中获取对应请求的响应报文,如常见的日志记录场景。

to be continue