End-to-End ACL with Knowledge Base
Knowledge Bases for Amazon Bedrock
Access Control Filtering - End to end notebook
This notebook guides you through creating access controls for Knowledge Bases for Amazon Bedrock.
To demonstrate the access control capabilities enabled by metadata filtering in Knowledge Bases, let's consider a use case where a healthcare provider has a Knowledge Base containing conversation transcripts between doctors and patients. In this scenario, it is crucial to ensure that each doctor can only access and leverage transcripts from their own patient interactions during the search, and not have access to transcripts from other doctors' patient interactions.
To complete this notebook, you need a role with access to the following services: Amazon S3, AWS STS, AWS Lambda, AWS CloudFormation, Amazon Bedrock, Amazon Cognito, Amazon DynamoDB, and Amazon OpenSearch Serverless.
This notebook contains the following sections:
- Base Infrastructure Deployment: In this section you will deploy an AWS CloudFormation template which will create and configure some of the services used for the solution.
- Amazon Cognito: You are going to populate an Amazon Cognito pool with two doctors and three patients. We will use the unique identifiers generated by Cognito for each user to associate transcripts with the respective patients.
- Doctor-patient association in Amazon DynamoDB: You will populate an Amazon DynamoDB table which will store doctor-patient associations.
- Dataset download: For this notebook you will use doctor-patient transcripts from the nazmulkazi/dataset_automated_medical_transcription repository on GitHub.
- Metadata association: You will use the patient identifiers generated by Cognito to create a metadata file for each transcript file.
- Upload the dataset to Amazon S3: You will create an Amazon S3 bucket and upload the dataset and metadata files.
- Create a Knowledge Base for Amazon Bedrock: You will create and sync the Knowledge Base with the transcripts and associated metadata.
- Update AWS Lambda: You will create a Lambda layer that includes the latest SDK, which is needed until the Boto3 version bundled with Lambda supports metadata filtering.
- Create and run a Streamlit Application: You will create a simple interface to showcase access control with metadata filtering using a Streamlit application.
- Clean up: Delete all the resources created during this notebook to avoid unnecessary costs.
!pip install -qU opensearch-py streamlit streamlit-cognito-auth retrying boto3 botocore
Let's import necessary Python modules and libraries, and initialize AWS service clients required for the notebook.
import os
import json
import time
import uuid
import boto3
import requests
import random
from utils import create_base_infrastructure, create_kb_infrastructure, updateDataAccessPolicy, createAOSSIndex, replace_vars
from botocore.exceptions import ClientError
s3_client = boto3.client('s3')
sts_client = boto3.client('sts')
session = boto3.session.Session()
region = session.region_name
lambda_client = boto3.client('lambda')
dynamodb_resource = boto3.resource('dynamodb')
cloudformation = boto3.client('cloudformation')
bedrock_agent_client = boto3.client('bedrock-agent')
bedrock = boto3.client("bedrock",region_name=region)
account_id = sts_client.get_caller_identity()["Account"]
cognito_client = boto3.client('cognito-idp', region_name=region)
identity_arn = session.client('sts').get_caller_identity()['Arn']
bedrock_agent_runtime_client = boto3.client('bedrock-agent-runtime')
0. Base Infrastructure Deployment
We have created an AWS CloudFormation template for you which will automatically set up some of the services needed for this notebook.
This template will automatically create:
- Amazon Cognito User Pool and App Client (user_pool_id, cognito_arn, client_id, client_secret)
- Amazon DynamoDB Table
- Amazon S3 Bucket
- AWS Lambda Function
def short_uuid():
    uuid_str = str(uuid.uuid4())
    return uuid_str[:8]
solution_id = 'KBS{}'.format(short_uuid()).lower()
user_pool_id, user_pool_arn, cognito_arn, client_id, client_secret, dynamo_table, s3_bucket, lambda_function_arn, collection_id = create_base_infrastructure(solution_id)
1. Amazon Cognito User Pool: Doctors and patients
Create doctors and patients in the user pool
We will create doctors and patients to test the use case. Their user IDs are stored for later use when retrieving information. For the notebook to work, you need to replace the placeholders below for two doctors and three patients. These users will be created in the Amazon Cognito user pool, and you will later need their credentials to log in to the web application. This is dummy user creation for test purposes; in production, follow your organization's best practices and guidelines for creating users.
In this example, the first two patients are associated with the first doctor, and the third patient is associated with the second doctor.
Password requirements:
- Minimum length of 8 characters
- Contains at least 1 number
- Contains at least 1 special character
- Contains at least 1 uppercase letter
- Contains at least 1 lowercase letter
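Before filling in the placeholders, it can help to check candidate passwords against the requirements listed above. The following is a minimal sketch, not part of the original notebook: `check_password` is a hypothetical helper, and `string.punctuation` is only an approximation of the special-character set that Cognito accepts.

```python
import string


def check_password(password: str) -> list:
    """Return a list of policy violations (an empty list means the password passes)."""
    problems = []
    if len(password) < 8:
        problems.append("shorter than 8 characters")
    if not any(c.isdigit() for c in password):
        problems.append("no number")
    # Assumption: string.punctuation approximates Cognito's special characters
    if not any(c in string.punctuation for c in password):
        problems.append("no special character")
    if not any(c.isupper() for c in password):
        problems.append("no uppercase letter")
    if not any(c.islower() for c in password):
        problems.append("no lowercase letter")
    return problems


print(check_password("Passw0rd!"))
print(check_password("password"))
```

An empty list means the password satisfies every rule; otherwise each entry names one rule that failed.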
doctors = [
    {
        'name': 'INSERT_DOCTOR_1_NAME',
        'email': 'INSERT_DOCTOR_1_EMAIL',
        'password': 'INSERT_DOCTOR_1_PASSWORD'
    },
    {
        'name': 'INSERT_DOCTOR_2_NAME',
        'email': 'INSERT_DOCTOR_2_EMAIL',
        'password': 'INSERT_DOCTOR_2_PASSWORD'
    }
]
patients = [
    {
        'name': 'INSERT_PATIENT_1_NAME',
        'email': 'INSERT_PATIENT_1_EMAIL',
        'password': 'INSERT_PATIENT_1_PASSWORD'
    },
    {
        'name': 'INSERT_PATIENT_2_NAME',
        'email': 'INSERT_PATIENT_2_EMAIL',
        'password': 'INSERT_PATIENT_2_PASSWORD'
    },
    {
        'name': 'INSERT_PATIENT_3_NAME',
        'email': 'INSERT_PATIENT_3_EMAIL',
        'password': 'INSERT_PATIENT_3_PASSWORD'
    }
]
doctor_ids = []
patient_ids = []
def create_user(user_data, user_type):
    user_ids = []
    for user in user_data:
        response = cognito_client.admin_create_user(
            UserPoolId=user_pool_id,
            Username=user['email'],
            UserAttributes=[
                {'Name': 'name', 'Value': user['name']},
                {'Name': 'email', 'Value': user['email']},
                {'Name': 'email_verified', 'Value': 'true'}
            ],
            ForceAliasCreation=False,
            MessageAction='SUPPRESS'
        )
        cognito_client.admin_set_user_password(
            UserPoolId=user_pool_id,
            Username=user['email'],
            Password=user['password'],
            Permanent=True
        )
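        # Look up the 'sub' attribute by name instead of relying on its position in the list
        user_id = next(attr['Value'] for attr in response['User']['Attributes'] if attr['Name'] == 'sub')
        print(f"{user_type.capitalize()} created:", response['User']['Username'])
        print(f"{user_type.capitalize()} id:", user_id)
        user_ids.append(user_id)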
    return user_ids
doctor_ids = create_user(doctors, 'doctor')
patient_ids = create_user(patients, 'patient')
print("Doctor IDs:", doctor_ids)
print("Patient IDs:", patient_ids)
2. Doctor-patient association in DynamoDB
In this section we will populate the already created DynamoDB table with the doctor-patient associations. This will be useful later on to retrieve the list of patient IDs a doctor is allowed to filter by.
table = dynamodb_resource.Table(dynamo_table)
with table.batch_writer() as batch:
    batch.put_item(
        Item={
            'doctor_id': doctor_ids[0],
            'patient_id_list': patient_ids[:2]  # Assign the first two patients to the first doctor
        }
    )
    batch.put_item(
        Item={
            'doctor_id': doctor_ids[1],
            'patient_id_list': [patient_ids[2]]  # Assign the third patient to the second doctor
        }
    )
print('Data inserted successfully!')
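The items above are written through the DynamoDB resource API using plain Python types. When the same items are read back through the low-level client (as the Streamlit application later in this notebook does), they arrive in DynamoDB's typed attribute-value format, where a string is wrapped as `{'S': ...}` and a list as `{'L': [...]}`. The sketch below shows one way to flatten that shape; `extract_patient_ids` and the sample identifiers are hypothetical and only illustrate the structure.

```python
def extract_patient_ids(items):
    """Flatten the 'patient_id_list' attribute from low-level DynamoDB items."""
    patient_ids = []
    for item in items:
        # A missing attribute is treated as an empty list
        for entry in item.get("patient_id_list", {}).get("L", []):
            patient_ids.append(entry["S"])
    return patient_ids


# Hypothetical item, shaped like a low-level query result for one doctor
sample_items = [
    {
        "doctor_id": {"S": "doctor-sub-1"},
        "patient_id_list": {"L": [{"S": "patient-sub-1"}, {"S": "patient-sub-2"}]},
    }
]
print(extract_patient_ids(sample_items))
```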
3. Dataset download
The dataset that we will be using is hosted in the nazmulkazi/dataset_automated_medical_transcription repository on GitHub. It consists of transcripts of synthetic conversations in PDF format. We will download three specific conversation documents, each of which will later be associated with its respective patient.
dataset_folder = "source_transcripts"
if not os.path.exists(dataset_folder):
    os.makedirs(dataset_folder)
abs_path = os.path.abspath(dataset_folder)
repo_url = 'https://api.github.com/repos/nazmulkazi/dataset_automated_medical_transcription/contents/transcripts/source'
headers = {'Accept': 'application/vnd.github.v3+json'}
response = requests.get(repo_url, headers=headers, timeout=20)
response.raise_for_status()  # Fail early if the GitHub API request was unsuccessful
json_data = response.json()
files_to_download = ['D0421-S1-T01.pdf', 'D0420-S1-T02.pdf', 'D0420-S1-T04.pdf']
list_of_pdfs = [item for item in json_data if item['type'] == 'file' and item['name'] in files_to_download]
query_parameters = {"downloadformat": "pdf"}
transcripts = [pdf_dict['name'] for pdf_dict in list_of_pdfs]
for pdf_dict in list_of_pdfs:
    pdf_name = pdf_dict['name']
    file_url = pdf_dict['download_url']
    r = requests.get(file_url, params=query_parameters, timeout=20)
    with open(os.path.join(dataset_folder, pdf_name), 'wb') as pdf_file:
        pdf_file.write(r.content)
4. Metadata association
These files will need to be uploaded to an Amazon S3 bucket for processing. To use metadata filtering, we need to create a separate metadata JSON file for each transcript file. The metadata file should share the same name as the corresponding PDF file (including the extension). For instance, if the transcript file is named transcript_001.pdf, the metadata file should be named transcript_001.pdf.metadata.json. This nomenclature is crucial for the Knowledge Base to identify the metadata for specific files during the ingestion process.
The metadata JSON file will contain key-value pairs representing the relevant metadata fields associated with the transcript. In our healthcare provider use case, the most important metadata field is patient_id, which will be used to implement access control. We will assign each transcript to a specific patient by including their unique identifier from the Amazon Cognito User Pool in the patient_id field of the metadata file.
import os
import json
file_patient_mapping = {
    'D0421-S1-T01.pdf': patient_ids[0],
    'D0420-S1-T02.pdf': patient_ids[1],
    'D0420-S1-T04.pdf': patient_ids[2]
}
files = os.listdir(dataset_folder)
for file_name in files:
    if file_name in file_patient_mapping:
        patient_id = file_patient_mapping[file_name]
        metadata = json.dumps({"metadataAttributes": {"patient_id": patient_id}})
        with open(os.path.join(dataset_folder, f"{file_name}.metadata.json"), "w") as outfile:
            outfile.write(metadata)
    else:
        print(f"No patient ID assigned for {file_name}")
print("Done!")
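Because the Knowledge Base matches metadata to documents purely by file name, a missing or misnamed metadata file means the transcript is ingested without its `patient_id`. The following sketch is one way to verify the naming convention before uploading; `missing_sidecars` is a hypothetical helper, and the demo runs against a temporary folder rather than the notebook's `dataset_folder`.

```python
import os
import tempfile


def missing_sidecars(folder):
    """Return the PDFs in `folder` that have no '<name>.metadata.json' next to them."""
    names = set(os.listdir(folder))
    return sorted(
        name for name in names
        if name.lower().endswith(".pdf") and f"{name}.metadata.json" not in names
    )


# Demo on a throwaway folder: a.pdf has a metadata file, b.pdf does not
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ("a.pdf", "a.pdf.metadata.json", "b.pdf"):
        open(os.path.join(demo_dir, name), "w").close()
    demo_missing = missing_sidecars(demo_dir)

print(demo_missing)  # ['b.pdf']
```

In the notebook, calling `missing_sidecars(dataset_folder)` should return an empty list once every transcript has its metadata file.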
5. Upload to Amazon S3
Knowledge Bases for Amazon Bedrock currently require data to reside in an Amazon S3 bucket. We will upload both the transcripts and their metadata files.
files = [f.name for f in os.scandir(abs_path) if f.is_file()]
for file in files:
    s3_client.upload_file(f'{abs_path}/{file}', s3_bucket, f'{file}')
6. Create a Knowledge Base for Amazon Bedrock
In this section we will go through all the steps to create and test a Knowledge Base.
indexName = "kb-acl-index-" + solution_id
print("Index name:",indexName)
updateDataAccessPolicy(solution_id) # Adding the current role to the collection's data access policy
time.sleep(60) # Changes to the data access policy might take a bit to update
createAOSSIndex(indexName, region, collection_id) # Create the AOSS index
Create the Knowledge Base
In this section you will create the Knowledge Base. Before creating a new KB we need to define which embeddings model we want it to use. In this case we will be using Amazon Titan Embeddings V2.
embeddingModelArn = "arn:aws:bedrock:{}::foundation-model/amazon.titan-embed-text-v2:0".format(region)
Now we can create our Knowledge Base for Amazon Bedrock. We have created an AWS CloudFormation template which takes care of the configuration needed.
kb_id, datasource_id = create_kb_infrastructure(solution_id, s3_bucket, embeddingModelArn, indexName, region, account_id, collection_id)
Sync the Knowledge Base
Now that we have created the data source and associated it with the Knowledge Base, we can proceed to sync the data.
Each time you add, modify, or remove files from the S3 bucket for a data source, you must sync the data source so that it is re-indexed to the knowledge base. Syncing is incremental, so Amazon Bedrock only processes the objects in your S3 bucket that have been added, modified, or deleted since the last sync.
ingestion_job_response = bedrock_agent_client.start_ingestion_job(
    knowledgeBaseId=kb_id,
    dataSourceId=datasource_id,
    description='Initial Ingestion'
)
status = bedrock_agent_client.get_ingestion_job(
    knowledgeBaseId=ingestion_job_response["ingestionJob"]["knowledgeBaseId"],
    dataSourceId=ingestion_job_response["ingestionJob"]["dataSourceId"],
    ingestionJobId=ingestion_job_response["ingestionJob"]["ingestionJobId"]
)["ingestionJob"]["status"]
print(status)
while status not in ["COMPLETE", "FAILED", "STOPPED"]:
    status = bedrock_agent_client.get_ingestion_job(
        knowledgeBaseId=ingestion_job_response["ingestionJob"]["knowledgeBaseId"],
        dataSourceId=ingestion_job_response["ingestionJob"]["dataSourceId"],
        ingestionJobId=ingestion_job_response["ingestionJob"]["ingestionJobId"]
    )["ingestionJob"]["status"]
    print(status)
    time.sleep(30)
print("Waiting for changes to take place in the vector database")
time.sleep(30) # Wait for all changes to take place
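The loop above waits indefinitely and does not distinguish a failed job from a successful one. A possible variant, sketched below, adds a timeout and raises if the job ends in any state other than `COMPLETE`. `wait_for_job` is a hypothetical helper: it takes a `get_status` callable so it can be tried without AWS access, and in the notebook that callable would wrap the `get_ingestion_job` call shown above.

```python
import time

TERMINAL_STATES = {"COMPLETE", "FAILED", "STOPPED"}


def wait_for_job(get_status, timeout_s=900, interval_s=30, sleep=time.sleep):
    """Poll get_status() until a terminal state is reached or timeout_s elapses."""
    waited = 0
    while True:
        status = get_status()
        if status in TERMINAL_STATES:
            break
        if waited >= timeout_s:
            raise TimeoutError(f"Job still {status} after {waited} s")
        sleep(interval_s)
        waited += interval_s
    if status != "COMPLETE":
        raise RuntimeError(f"Ingestion job ended with status {status}")
    return status


# Demo with a fake status source and no real sleeping
fake_statuses = iter(["STARTING", "IN_PROGRESS", "COMPLETE"])
demo_status = wait_for_job(lambda: next(fake_statuses), sleep=lambda seconds: None)
print(demo_status)  # COMPLETE
```

Injecting `sleep` keeps the helper easy to test; with the default `time.sleep` it behaves like the loop above, except that it stops after roughly `timeout_s` seconds.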
Test the Knowledge Base
Now that the Knowledge Base is available, we can test it out using the retrieve and retrieve_and_generate APIs.
Let's examine a test case with patient 0's transcript, where they mention a cat named Kelly. We'll query the knowledge base using the metadata filter for patient 0 to retrieve information about Kelly. Changing the patient_id will prevent the model from responding accurately. Read through the PDFs for other questions you might want to ask.
In this first example we are going to use the retrieve and generate API. This API queries a knowledge base and generates responses based on the retrieved results, using an LLM.
# retrieve and generate API
response = bedrock_agent_runtime_client.retrieve_and_generate(
    input={
        "text": "Who is Kelly?"
    },
    retrieveAndGenerateConfiguration={
        "type": "KNOWLEDGE_BASE",
        "knowledgeBaseConfiguration": {
            'knowledgeBaseId': kb_id,
            "modelArn": "arn:aws:bedrock:{}::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0".format(region),
            "retrievalConfiguration": {
                "vectorSearchConfiguration": {
                    "numberOfResults":5,
                    "filter": {
                        "equals": {
                            "key": "patient_id",
                            "value": patient_ids[0]
                        }
                    }
                } 
            }
        }
    }
)
print(response['output']['text'],end='\n'*2)
In this second example we are going to use the retrieve API. This API queries the knowledge base and retrieves relevant information from it, but it does not generate a response.
response_ret = bedrock_agent_runtime_client.retrieve(
    knowledgeBaseId=kb_id,
    retrievalConfiguration={
        "vectorSearchConfiguration": {
            "numberOfResults": 3,
            "filter": {
                "equals": {
                    "key": "patient_id",
                    "value": patient_ids[0]
                }
            }
        }
    },
    retrievalQuery={
        'text': 'Who is Kelly?'
    }
)
def response_print(retrieve_resp):
    # 'retrievalResults' is a list of chunks;
    # each chunk has content, location, score, and metadata
    for num, chunk in enumerate(retrieve_resp['retrievalResults'], 1):
        print(f'Chunk {num}: ', chunk['content']['text'], end='\n'*2)
        print(f'Chunk {num} Location: ', chunk['location'], end='\n'*2)
        print(f'Chunk {num} Score: ', chunk['score'], end='\n'*2)
        print(f'Chunk {num} Metadata: ', chunk['metadata'], end='\n'*2)
response_print(response_ret)
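Both examples above filter on a single patient with `equals`. A doctor with several patients needs a filter that matches any of them. The sketch below builds such a filter using the `in` operator of the retrieval filter; `build_patient_filter` is a hypothetical helper, and how the Lambda function deployed by this notebook actually builds its filter is not shown here. Refusing to build a filter for an empty list is a deliberate choice, so that a doctor with no patients never triggers an unfiltered search.

```python
def build_patient_filter(patient_ids, key="patient_id"):
    """Build a vectorSearchConfiguration filter that matches any of the given patients."""
    ids = [pid for pid in patient_ids if pid]
    if not ids:
        # Fail closed: no patients must never mean "no filter"
        raise ValueError("No patient ids: refusing to build an unfiltered query")
    if len(ids) == 1:
        return {"equals": {"key": key, "value": ids[0]}}
    return {"in": {"key": key, "value": ids}}


# Hypothetical identifiers, for illustration only
print(build_patient_filter(["patient-sub-1"]))
print(build_patient_filter(["patient-sub-1", "patient-sub-2"]))
```

The returned dictionary can be passed as the `filter` value inside `vectorSearchConfiguration` in either of the calls above.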
7. Add Lambda Layer
At the time of developing this notebook, the latest Boto3 version available in Lambda with Python 3.12 does not include metadata filtering capabilities. To solve this, we will create and attach an AWS Lambda Layer with the latest Boto3 version.
For this section to run, you will need the zip package to be installed at the system level.
You can check if zip is installed running the following command: !zip
If it is not installed you will need to install it using the appropriate package manager (apt-get for Debian-based systems or yum for RHEL-based systems for example).
!zip
#!sudo apt-get install zip -y # Debian-based systems 
#!sudo yum install zip -y # RHEL-based systems
!mkdir latest-sdk-layer
%cd latest-sdk-layer
!pip install -qU boto3 botocore -t python/lib/python3.12/site-packages/
!zip -rq latest-sdk-layer.zip .
%cd ..
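Lambda only adds a Python layer's libraries to the import path when they sit under a top-level `python/` directory inside the archive, which is why the packages above are installed into `python/lib/python3.12/site-packages/`. The following sketch checks an archive for that layout before publishing; `layer_has_python_root` is a hypothetical helper, and the demo builds two tiny archives in a temporary folder instead of inspecting the real layer.

```python
import os
import tempfile
import zipfile


def layer_has_python_root(zip_path):
    """Return True if the archive contains at least one entry under 'python/'."""
    with zipfile.ZipFile(zip_path) as archive:
        return any(name.startswith("python/") for name in archive.namelist())


with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.zip")
    with zipfile.ZipFile(good, "w") as archive:
        archive.writestr("python/lib/python3.12/site-packages/boto3/__init__.py", "")
    bad = os.path.join(tmp, "bad.zip")
    with zipfile.ZipFile(bad, "w") as archive:
        archive.writestr("boto3/__init__.py", "")  # packages at the root: wrong layout
    good_ok = layer_has_python_root(good)
    bad_ok = layer_has_python_root(bad)

print(good_ok, bad_ok)  # True False
```

In the notebook, the same check could be run against `latest-sdk-layer/latest-sdk-layer.zip`.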
def publish_lambda_layer(layer_name, description, zip_file_path, compatible_runtimes):
    with open(zip_file_path, 'rb') as f:
        response = lambda_client.publish_layer_version(
            LayerName=layer_name,
            Description=description,
            Content={
                'ZipFile': f.read(),
            },
            CompatibleRuntimes=compatible_runtimes
        )
    return response['LayerVersionArn']
layer_name = 'latest-sdk-layer'
description = 'Layer with the latest boto3 version.'
zip_file_path = 'latest-sdk-layer/latest-sdk-layer.zip'
compatible_runtimes = ['python3.12']
layer_version_arn = publish_lambda_layer(layer_name, description, zip_file_path, compatible_runtimes)
print("Layer version ARN:", layer_version_arn)
try:
    # Add the layer to the Lambda function
    lambda_client.update_function_configuration(
        FunctionName=lambda_function_arn,
        Layers=[layer_version_arn]
    )
    print("Layer added to the Lambda function successfully.")
except ClientError as e:
    print(f"Error adding layer to Lambda function: {e.response['Error']['Message']}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
8. Create Streamlit Application
To showcase the interaction between doctors and the Knowledge Base, we can develop a user-friendly web application for testing purposes using Streamlit, a popular open-source Python library for building interactive data apps. Streamlit provides a simple and intuitive way to create custom interfaces that can seamlessly integrate with the various AWS services involved in this solution.
Here is the application. Don't modify the placeholders; we will replace them in the next cell.
%%writefile app.py
import os
import boto3
import json
import requests
import streamlit as st
from streamlit_cognito_auth import CognitoAuthenticator
pool_id = "<<replace_pool_id>>"
app_client_id = "<<replace_app_client_id>>"
app_client_secret = "<<replace_app_client_secret>>"
kb_id = "<<replace_kb_id>>"
lambda_function_arn = '<<replace_lambda_function_arn>>'
dynamo_table = '<<replace_dynamo_table_name>>'
authenticator = CognitoAuthenticator(
    pool_id=pool_id,
    app_client_id=app_client_id,
    app_client_secret= app_client_secret,
    use_cookies=False
)
is_logged_in = authenticator.login()
if not is_logged_in:
    st.stop()
def logout():
    authenticator.logout()
def get_user_sub(user_pool_id, username):
    # Return the Cognito 'sub' (unique user id) for the given user, or None if not found
    cognito_client = boto3.client('cognito-idp')
    try:
        response = cognito_client.admin_get_user(
            UserPoolId=user_pool_id,
            Username=username
        )
        sub = None
        for attr in response['UserAttributes']:
            if attr['Name'] == 'sub':
                sub = attr['Value']
                break
        return sub
    except cognito_client.exceptions.UserNotFoundException:
        print("User not found.")
        return None
def get_patient_ids(doctor_id):
    dynamodb = boto3.client('dynamodb')
    response = dynamodb.query(
        TableName=dynamo_table,
        KeyConditionExpression='doctor_id = :doctor_id',
        ExpressionAttributeValues={
            ':doctor_id': {'S': doctor_id}
        }
    )
    print(response)
    patient_id_list = []  # Initialize the list
    for item in response['Items']:
        patient_ids = item.get('patient_id_list', {}).get('L', [])
        patient_id_list.extend([patient_id['S'] for patient_id in patient_ids])
    return patient_id_list
def search_transcript(doctor_id, kb_id, text, patient_ids):
    # Initialize the Lambda client
    lambda_client = boto3.client('lambda')
    # Payload for the Lambda function
    payload = json.dumps({
        "doctorId": doctor_id,
        "knowledgeBaseId": kb_id,
        "text": text,
        "patientIds": patient_ids
    }).encode('utf-8')
    try:
        # Invoke the Lambda function
        response = lambda_client.invoke(
            FunctionName=lambda_function_arn,
            InvocationType='RequestResponse',
            Payload=payload
        )
        # Process the response
        if response['StatusCode'] == 200:
            response_payload = json.loads(response['Payload'].read().decode('utf-8'))
            return response_payload
        else:
            # Handle error response
            return {'error': 'Failed to fetch data'}
    except Exception as e:
        # Handle exception
        return {'error': str(e)}
sub = get_user_sub(pool_id, authenticator.get_username())
print(sub)
patient_ids = get_patient_ids(sub)
print(patient_ids)
# Application Front
with st.sidebar:
    st.header("User Information")
    st.markdown("## Doctor")
    st.text(authenticator.get_username())
    st.markdown("## Doctor Id")
    st.text(sub)
    selected_patient = st.selectbox("Select a patient (or 'All' for all patients)", ['All'] + patient_ids)
    st.button("Logout", "logout_btn", on_click=logout)
st.header("Transcript Search Tool")
# Text input for the search query
query = st.text_input("Enter your search query:")
if st.button("Search"):
    if query:
        # Perform search
        patient_ids_filter = [selected_patient] if selected_patient != 'All' else patient_ids
        results = search_transcript(sub, kb_id, query, patient_ids_filter)
        print(results)
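        if results and "body" in results:
            st.subheader("Search Results:")
            st.markdown(results["body"], unsafe_allow_html=True)
        elif results and "error" in results:
            # search_transcript returns {'error': ...} when the Lambda call fails
            st.error(results["error"])
        else:
            st.write("No matching results found.")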
    else:
        st.write("Please enter a search query.")
replace_vars("app.py", user_pool_id, client_id, client_secret, kb_id, lambda_function_arn, dynamo_table)
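Note that the application sends the list of patient IDs from the client side. For the filter to act as an access control, the backend should not trust that list as-is: it should re-check it against the doctor's associations stored in DynamoDB. The code of the Lambda function deployed by the CloudFormation template is not shown in this notebook, so the following is only a sketch of what such a check could look like. `authorize_patient_ids` and the identifiers are hypothetical.

```python
def authorize_patient_ids(requested_ids, allowed_ids):
    """Return the requested ids if all are allowed for this doctor, else raise."""
    allowed = set(allowed_ids)
    denied = [pid for pid in requested_ids if pid not in allowed]
    if denied:
        raise PermissionError(f"Not authorized for patient ids: {denied}")
    # dict.fromkeys removes duplicates while keeping the original order
    return list(dict.fromkeys(requested_ids))


# Hypothetical identifiers: the doctor is associated with p1 and p2 only
print(authorize_patient_ids(["p1", "p1", "p2"], ["p1", "p2"]))  # ['p1', 'p2']

try:
    authorize_patient_ids(["p3"], ["p1", "p2"])
except PermissionError as err:
    print(err)
```

Raising instead of silently dropping unauthorized IDs makes tampering attempts visible in logs; dropping them would be an equally valid design if a quieter failure mode is preferred.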
Run the Streamlit application locally
Execute the cell below to run the Streamlit application.
Use the email and password of the doctors you defined at the top of the notebook to access the application.
Once you have logged in, you can filter by a specific patient assigned to you (dropdown in the left panel), or select 'All' to query the knowledge base across all of your patients.
!streamlit run app.py
If you are running this notebook on SageMaker Studio, you can access the Streamlit application at the following URL:
https://<<STUDIOID>>.studio.<<REGION>>.sagemaker.aws/jupyterlab/default/proxy/8501/
If you are running this notebook on a SageMaker Notebook instance, you can access the Streamlit application at the following URL:
https://<<NOTEBOOKID>>.notebook.<<REGION>>.sagemaker.aws/proxy/8501/
9. Clean up
Before running this cell you will need to stop the cell above where the app is running!
Run the following cell to delete the created resources and avoid unnecessary costs. This should take about 2-3 minutes to complete.
# Delete all objects in the bucket
try:
    response = s3_client.list_objects_v2(Bucket=s3_bucket)
    if 'Contents' in response:
        for obj in response['Contents']:
            s3_client.delete_object(Bucket=s3_bucket, Key=obj['Key'])
        print(f"All objects in {s3_bucket} have been deleted.")
except Exception as e:
    print(f"Error deleting objects from {s3_bucket}: {e}")
# Define the stack names to delete
stack_names = ["KB-E2E-KB-{}".format(solution_id),"KB-E2E-Base-{}".format(solution_id)]
# Iterate over the stack names and delete each stack
for stack_name in stack_names:
    try:
        # Retrieve the stack information
        stack_info = cloudformation.describe_stacks(StackName=stack_name)
        stack_status = stack_info['Stacks'][0]['StackStatus']
        # Check if the stack exists and is in a deletable state
        if stack_status != 'DELETE_COMPLETE':
            # Delete the stack
            cloudformation.delete_stack(StackName=stack_name)
            print(f'Deleting stack: {stack_name}')
            # Wait for the stack deletion to complete
            waiter = cloudformation.get_waiter('stack_delete_complete')
            waiter.wait(StackName=stack_name)
            print(f'Stack {stack_name} deleted successfully.')
        else:
            print(f'Stack {stack_name} does not exist or has already been deleted.')
    except cloudformation.exceptions.ClientError as e:
        print(f'Error deleting stack {stack_name}: {e.response["Error"]["Message"]}')
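The bucket clean-up above issues a single `list_objects_v2` call, which returns at most 1,000 keys, and deletes objects one at a time. That is enough for the handful of files in this notebook. For larger buckets, one option is to list with a paginator and delete in batches, since `delete_objects` accepts up to 1,000 keys per request. The sketch below only builds the batch payloads so it can run without AWS access; `delete_batches` is a hypothetical helper.

```python
def delete_batches(keys, batch_size=1000):
    """Yield 'Delete' payloads for s3_client.delete_objects, batch_size keys at a time."""
    for start in range(0, len(keys), batch_size):
        chunk = keys[start:start + batch_size]
        yield {"Objects": [{"Key": key} for key in chunk], "Quiet": True}


# Demo with made-up keys
demo_keys = [f"file-{i}.pdf" for i in range(2500)]
demo_sizes = [len(batch["Objects"]) for batch in delete_batches(demo_keys)]
print(demo_sizes)  # [1000, 1000, 500]

# Usage in the notebook would look roughly like this (not executed here):
# paginator = s3_client.get_paginator('list_objects_v2')
# keys = [obj['Key'] for page in paginator.paginate(Bucket=s3_bucket)
#         for obj in page.get('Contents', [])]
# for batch in delete_batches(keys):
#     s3_client.delete_objects(Bucket=s3_bucket, Delete=batch)
```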