本文将简要介绍如何通过外部知识库 API 将 Dify 平台与 AWS Bedrock 知识库相连接,使得 Dify 平台内的 AI 应用能够直接获取存储在 AWS Bedrock 知识库中的内容,扩展新的信息来源渠道。
from flask import request
from flask_restful import Resource, reqparse
from bedrock.knowledge_service import ExternalDatasetService
class BedrockRetrievalApi(Resource):
    """Flask-RESTful resource implementing the external knowledge retrieval endpoint.

    url : <your-endpoint>/retrieval
    """

    def post(self):
        """Handle a retrieval request.

        Parses ``retrieval_setting``, ``query`` and ``knowledge_id`` from the
        request, checks that the ``Authorization`` header has the form
        ``Bearer <api-key>``, and forwards the query to
        ``ExternalDatasetService.knowledge_retrieval``.

        Returns:
            A ``(body, status)`` tuple: the retrieval result with status 200,
            or an error payload (``error_code`` 1001) with status 403 when the
            Authorization header is missing or malformed.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
        parser.add_argument("query", nullable=False, required=True, type=str)
        parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
        args = parser.parse_args()

        # Authorization check.
        # The header may be absent (None) or lack a token after the scheme;
        # both cases must yield the 403 payload rather than an unhandled
        # TypeError / ValueError.
        auth_header = request.headers.get("Authorization")
        auth_parts = auth_header.split(None, 1) if auth_header else []
        if len(auth_parts) != 2 or auth_parts[0].lower() != "bearer":
            return {
                "error_code": 1001,
                "error_msg": "Invalid Authorization header format. Expected 'Bearer <api-key>' format."
            }, 403

        auth_token = auth_parts[1]
        if auth_token:
            # process your authorization logic here
            # NOTE(review): as written, any non-empty bearer token is accepted;
            # the token must be validated here before production use.
            pass

        # Call the knowledge retrieval service
        result = ExternalDatasetService.knowledge_retrieval(
            args["retrieval_setting"], args["query"], args["knowledge_id"]
        )
        return result, 200
import boto3
class ExternalDatasetService:
    """Retrieves records from an AWS Bedrock knowledge base."""

    @staticmethod
    def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
        """Query a Bedrock knowledge base and return the matching records.

        Args:
            retrieval_setting: Retrieval options. ``top_k`` (optional) caps the
                number of results requested; ``score_threshold`` (optional,
                treated as 0.0 when absent or ``None``) is the minimum score a
                result must reach to be returned.
            query: The text to search for.
            knowledge_id: Identifier passed to Bedrock as ``knowledgeBaseId``.

        Returns:
            ``{"records": [...]}`` where each record has the keys
            ``metadata``, ``score``, ``title`` and ``content``. The list is
            empty when the response status is not 200 or nothing matches.
        """
        # get bedrock client
        # The three string values below are placeholders from the template:
        # replace them with real settings (ideally loaded from configuration
        # or the environment rather than committed in source).
        client = boto3.client(
            "bedrock-agent-runtime",
            aws_secret_access_key="AWS_SECRET_ACCESS_KEY",
            aws_access_key_id="AWS_ACCESS_KEY_ID",
            # example: us-east-1
            region_name="AWS_REGION_NAME",
        )

        # Only send numberOfResults when the caller supplied top_k; passing
        # None would be rejected by boto3 parameter validation.
        vector_search = {"overrideSearchType": "HYBRID"}
        top_k = retrieval_setting.get("top_k")
        if top_k is not None:
            vector_search["numberOfResults"] = top_k

        # fetch external knowledge retrieval
        response = client.retrieve(
            knowledgeBaseId=knowledge_id,
            retrievalConfiguration={"vectorSearchConfiguration": vector_search},
            retrievalQuery={"text": query},
        )

        # parse response
        results = []
        status_code = (response.get("ResponseMetadata") or {}).get("HTTPStatusCode")
        if status_code != 200:
            return {"records": results}

        # An explicit None threshold is treated like a missing one.
        score_threshold = retrieval_setting.get("score_threshold") or 0.0

        for retrieval_result in response.get("retrievalResults") or []:
            # A result without a score is treated as scoring 0.0 so the
            # comparison below cannot fail on None.
            score = retrieval_result.get("score") or 0.0
            # filter out results with score less than threshold
            if score < score_threshold:
                continue

            # metadata / content may be absent; fall back to empty dicts so a
            # partial result does not abort the whole response.
            metadata = retrieval_result.get("metadata") or {}
            content = retrieval_result.get("content") or {}
            results.append(
                {
                    "metadata": metadata,
                    "score": score,
                    "title": metadata.get("x-amz-bedrock-kb-source-uri"),
                    "content": content.get("text"),
                }
            )

        return {
            "records": results
        }