RAG with Structured and Unstructured Data

This notebook demonstrates how to leverage multiple data sources, including structured data from a database and unstructured data (such as PDF and TXT files), to answer questions using Retrieval Augmented Generation (RAG). Specifically, we'll show how to integrate a knowledge base and a database to retrieve relevant information and generate comprehensive natural language responses.

We'll set up a MultiRetrievalQAChain that can answer queries by retrieving information from an Amazon Bedrock knowledge base and a database (using Text-to-SQL as a retriever), and then generating responses using the Claude 3 Sonnet language model. The MultiRetrievalQAChain can intelligently determine the appropriate data source for a given question, fetch relevant information, maintain conversation context, and synthesize the retrieved data into a coherent natural language answer. Here's a diagram illustrating the workflow:

[Image: Multiple Retrievers in a QA Chain]

Background

The MultiRetrievalQAChain is an advanced implementation of the Retrieval Augmented Generation (RAG) approach, which combines the strengths of retrieval-based and generation-based language models. By integrating multiple retrievers, each specialized in a different data source, the chain can leverage diverse information sources to generate comprehensive and accurate responses. In this notebook, we'll demonstrate how to use structured data from a database (retrieved using Text-to-SQL as a retriever) as well as a text-based knowledge base to power a RAG application.

Note: This notebook uses a custom module designed specifically for Amazon Athena, but you can easily adapt it for other databases like Amazon Redshift and Amazon RDS by using their respective data APIs.
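
As a rough, hedged sketch of that adaptation using the Amazon Redshift Data API (the boto3 "redshift-data" client; the workgroup and database names below are placeholders), the retriever's query-execution step would look something like this:

import time
import boto3

redshift_data = boto3.client("redshift-data")

# submit the SQL statement (use ClusterIdentifier/DbUser instead of
# WorkgroupName when targeting a provisioned cluster)
resp = redshift_data.execute_statement(
    WorkgroupName="<Your-Workgroup-Name>",
    Database="<Your-Database-Name>",
    Sql="SELECT 1",
)

# the Data API is asynchronous: poll until the statement finishes
while redshift_data.describe_statement(Id=resp["Id"])["Status"] not in ("FINISHED", "FAILED", "ABORTED"):
    time.sleep(1)

result = redshift_data.get_statement_result(Id=resp["Id"])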

Prerequisites

Note: This notebook assumes that you have (1) created a knowledge base for Amazon Bedrock using unstructured data and (2) data available for querying via SQL in Amazon Athena.

If you haven't met these prerequisites, please follow these steps:

  1. Create a knowledge base and ingest your documents by following the 01_create_ingest_documents_test_kb_multi_ds.ipynb notebook.
  2. Note down the knowledge base ID, as you'll need it later in this notebook.
  3. If you need synthetic data for testing, refer to this link to get synthetic text data that you can use to create your knowledge base for Amazon Bedrock.
  4. To create synthetic structured data, run the 0-create-dummy-structured-data.ipynb notebook and then use the 1_create_sql_dataset_optional.ipynb notebook to create a database and table in Amazon Athena.

Note: The custom_database_retriever.py file currently uses the table schema for a retail order website, generated by the 0-create-dummy-structured-data.ipynb notebook. If you choose to use a different dataset, please update the table schema and table information inside the custom_database_retriever.py file.

Setup

Set up the custom retriever for `Text-to-SQL`

The code for the custom module can be found in custom_database_retriever.py.

The provided code defines a retriever class called AmazonAthenaRetriever that retrieves relevant data from an Amazon Athena database using SQL queries generated by Amazon Bedrock. The retriever interacts with Athena through the AWS boto3 SDK, which allows running SQL queries on data stored in Amazon S3. It generates SQL queries from natural language input, executes them on Athena, and returns the results as a list of documents formatted for a LangChain RetrievalQA chain.
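
For orientation, a custom LangChain retriever generally takes the shape sketched below. This is a minimal illustration only, not the actual internals of custom_database_retriever.py; the placeholder steps are spelled out in comments:

from typing import Any, List
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

class AthenaRetrieverSketch(BaseRetriever):
    """Minimal sketch of the pattern used by AmazonAthenaRetriever."""
    athena_client: Any
    bedrock_client: Any
    database: str
    RESULT_OUTPUT_LOCATION: str
    model_id: str

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        # 1) prompt the LLM (via bedrock_client) to turn the question into SQL,
        #    guided by the table schema embedded in the prompt
        # 2) run the SQL via athena_client.start_query_execution(...), poll
        #    get_query_execution(...) until it succeeds, then fetch the rows
        sql = "SELECT ..."    # placeholder for the generated SQL
        rows: List[str] = []  # placeholder for the fetched Athena rows
        docs = [Document(page_content=row) for row in rows]
        # the real module appends a final Document whose metadata carries
        # the Athena execution ID and the SQL query
        docs.append(Document(page_content=sql, metadata={"sql_query": sql}))
        return docs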

import boto3
from custom_database_retriever import AmazonAthenaRetriever  # custom Text-to-SQL retriever

# clients for Athena (SQL execution), Bedrock runtime (LLM calls),
# and Bedrock Agent runtime (knowledge base retrieval)
athena_client = boto3.client("athena")
bedrock_client = boto3.client("bedrock-runtime")
bedrock_agent_client = boto3.client("bedrock-agent-runtime")

Retrieve the stored Glue database name from 0_create_sql_dataset_optional.ipynb. Comment this out if you did not run the 0_create_sql_dataset_optional.ipynb notebook to set up the Amazon Athena database.

%store -r glue_database_name

Configure variables

# define the model of your choice; defaults to Claude 3 Sonnet
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

# define the Athena output location:
RESULT_OUTPUT_LOCATION = "s3://<Your-Bucket-Name>/"

# define the database you are using in Amazon Athena:
# database = '<Your-database-name>'
# fall back to 'default' if glue_database_name was not restored above
database = glue_database_name if globals().get("glue_database_name") else "default"

# define the knowledge base ID that you have already prepared in Amazon Bedrock
kb_id = '<Your-Knowledgebase-Id>'

# configure how many chunks you want for model response generation;
# these chunks are retrieved from the knowledge base
numberOfResults = 3

Set up the custom retriever AmazonAthenaRetriever, which, when invoked, uses the user input to write and execute a SQL query, and finally provides the data back as LangChain documents.

# Configure the SQL retriever:
sql_retriever = AmazonAthenaRetriever(
    athena_client=athena_client,
    bedrock_client=bedrock_client,
    database=database,
    RESULT_OUTPUT_LOCATION=RESULT_OUTPUT_LOCATION,
    model_id=model_id
)

Test the SQL retriever

Note: The executed SQL is included at the end of the returned documents, inside the metadata: the last Document's metadata is configured to always contain the Athena query execution ID and the SQL query.

%%time
query = "Top 5 customers that spend most amount?"
response = sql_retriever.get_relevant_documents(query)

# check the response of the SQL retriever
print(response)
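
The generated SQL and execution ID can then be pulled from that last document's metadata (the exact key names are defined in custom_database_retriever.py; inspect the dict rather than guessing):

# query provenance lives in the last Document's metadata
last_doc = response[-1]
print(last_doc.metadata)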

Configure the Knowledge Base Retriever

Now, let's use the publicly available AmazonKnowledgeBasesRetriever from LangChain to query Knowledge Bases for Amazon Bedrock. This implementation makes it easy to use your knowledge base in a LangChain application. Under the hood, the module calls the Knowledge Bases Retrieve API on Amazon Bedrock via the boto3 SDK.

# Configure the knowledge base retriever:
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever

kb_retriever = AmazonKnowledgeBasesRetriever(
    client=bedrock_agent_client,
    knowledge_base_id=kb_id,
    retrieval_config={"vectorSearchConfiguration":
                      {"numberOfResults": numberOfResults}}
)

Test the knowledge base retriever

Note: You do not have to use LangChain in order to use Knowledge Bases for Amazon Bedrock; this is just one of the ways you can use a knowledge base when you are using LangChain for your project/application.

%%time
query = "By what percentage did AWS revenue grow year-over-year in 2022?"
kb_retriever.get_relevant_documents(query)
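
For comparison, here is the equivalent retrieval without LangChain, calling the Knowledge Bases Retrieve API directly through boto3 (reusing the client and variables defined above):

raw = bedrock_agent_client.retrieve(
    knowledgeBaseId=kb_id,
    retrievalQuery={"text": query},
    retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": numberOfResults}},
)

# each result carries the chunk text, its source location, and a relevance score
for item in raw["retrievalResults"]:
    print(item["content"]["text"][:200])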

Configure the MultiRetrievalQAChain

The following code sets up a MultiRetrievalQAChain using the LangChain library. This chain can retrieve information from multiple data sources and employ a large language model (LLM) hosted on Amazon Bedrock to respond to user queries in a natural, contextual manner.

The MultiRetrievalQAChain defines two retrievers: one for the knowledge base and another for the database. Each retriever entry includes a name and a description that help the MultiRetrievalQAChain determine which retriever is most appropriate for answering a given query.

A default conversation chain manages the back-and-forth dialogue between the user and the system. It utilizes the LLM, a custom prompt, and a memory buffer to track the conversation's context.

The central component is the MultiRetrievalQAChain itself, which combines the knowledge base retriever, database retriever, and default conversation chain. Based on the user's query, this chain determines the appropriate retriever, retrieves relevant information from the corresponding data source, and then employs the LLM to generate a contextual response informed by the conversation history.

This sophisticated system can handle a diverse range of queries, from general questions to complex analytical database queries, providing natural language responses by synthesizing information from multiple sources.

# from langchain.chains import RetrievalQA
from langchain_community.chat_models import BedrockChat
from langchain.chains.router.multi_retrieval_qa import MultiRetrievalQAChain
from langchain.chains import ConversationChain
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

boto3_bedrock = boto3.client('bedrock-runtime')
inference_params = {"max_tokens": 4096,
                    "temperature": 0.01,
                    "top_k": 250,
                    "top_p": 0.01,
                    "stop_sequences": ["\n\nHuman"]
                   }
verbose = False

Let's define the two retrievers that will be used by the MultiRetrievalQAChain later on in this notebook. The retriever_infos component is crucial for correct intent classification as it contains the description of the retrievers. If for some reason you see an incorrect classification for your use case, this is probably the first place to begin your troubleshooting. You want the retriever description to be simple, concise, and clear.

Note: In this workshop, we are using the MultiRetrievalQAChain with only two retrievers but this chain can use more than two retrievers if needed.

# Create retriever names and descriptions
retriever_names = ["kb_retriever", "sql_retriever"]

# Create a list of retrievers
retrievers = [kb_retriever, sql_retriever]

# Retriever information used by the router to determine which retriever to use for the question:
retriever_infos = [
    {
        "name": "kb_retriever",
        "description": "Suitable for answering questions related to Amazon business, services, and latest launches based on the shareholder letter by the CEO.",
        "retriever": kb_retriever
    },
    {
        "name": "sql_retriever",
        "description": "Designed for handling analytical queries and generating SQL code to retrieve and analyze data from databases about products purchased, payments made, refunds, customer reviews, etc. This retriever is ideal for answering questions that require data retrieval, aggregation, filtering, or sorting based on specific criteria such as device/client status, usage statistics, counts, extremes (highest/lowest), and much more. It can return numerical or short string results, or sets of relevant documents, to answer users' questions.",
        "retriever": sql_retriever
    }
]

Now, we will create a default chain that will be used by the MultiRetrievalQAChain later on in this notebook. Note how you can customize it with your own prompt. This is another area you can optimize if you see unexpected behavior, specifically from the default chain.

# define the Bedrock chat model; Claude 3 is only supported through BedrockChat:
llm = BedrockChat(model_id=model_id,
                  model_kwargs=inference_params,
                  streaming=True,
                  callbacks=[StreamingStdOutCallbackHandler()],
                  client=boto3_bedrock
                 )

# Custom prompt for the default chain:
default_chain_prompt = PromptTemplate(
    input_variables=["query", "history"],
    template="""You are a helpful assistant who answers user queries using the
    contexts provided. If the question cannot be answered using the information
    provided say "I don't know".
    {history}
    Question: {query}"""
)

# Default chain:
default_chain = ConversationChain(llm=llm, 
                                  verbose=verbose, 
                                  prompt=default_chain_prompt, 
                                  input_key='query',
                                  output_key='result')

Now, we will configure a memory buffer that will be used by the MultiRetrievalQAChain later on in this notebook. It provides historical context to the MultiRetrievalQAChain.

# add a memory buffer:
memory = ConversationBufferMemory(memory_key="MultiRetrievers",
                                  return_messages=False,
                                  input_key="input",
                                  output_key="result")

Finally, we will use all the resources created above to configure a MultiRetrievalQAChain. It combines the retrievers and their information (retriever_infos), the default chain (default_chain), and the memory buffer (memory). You can optionally provide a default retriever and a default prompt to further tune the behavior of the MultiRetrievalQAChain.

# Create the multi-retriever chain
multi_retrieval_qa_chain = MultiRetrievalQAChain.from_retrievers(
    llm=llm,
    retriever_infos=retriever_infos,
    default_chain=default_chain,
    memory=memory,
    verbose=verbose
    # default_retriever: optional
    # default_prompt: optional (see the cell below for more information)
)

Before we move forward, it is important to understand how the MultiRetrievalQAChain decides which retriever to use to answer the user's question.

It uses the prompt below along with the descriptions of the retrievers, i.e., the retriever_infos list in the cell above. So, if you are facing issues with incorrect routing, you may want to optimize and simplify the descriptions you have for the retrievers.

# Check the prompt used to route the question to the relevant retriever:
from langchain.chains.router.multi_retrieval_prompt import (
    MULTI_RETRIEVAL_ROUTER_TEMPLATE,
)

print(MULTI_RETRIEVAL_ROUTER_TEMPLATE)
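
For illustration, the router instructs the LLM to reply with a JSON snippet naming a destination retriever (or "DEFAULT") and a possibly revised input, along these lines:

{
    "destination": "sql_retriever",
    "next_inputs": "What is the total spending by all customers?"
}

If the destination is "DEFAULT", the query falls through to the default chain configured above.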

Test the MultiRetrievalQAChain

Using Knowledge Bases for Amazon Bedrock

Let's ask a question that we know can be answered by the associated knowledge base and see if the multiretrieval chain can route the question to the correct retriever.

%%time
# Test the chain
query = "By what percentage did AWS revenue grow year-over-year in 2022?"
result = multi_retrieval_qa_chain({"input": query})
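
The chain returns a dict; given the output key configured earlier ("result"), you can read the final answer like this:

# the routed answer is stored under the chain's output key
print(result["result"])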

Using Custom SQL Retriever

You have already explored Knowledge Bases for Amazon Bedrock quite a bit in this workshop, so let's focus on the SQL side of RAG here. We will ask questions that we know can be answered using SQL, and test whether the MultiRetrievalQAChain routes each question to the correct retriever.

Notice how quickly the total request completes; this is competitive with many Text-to-SQL solutions.

%%time
query = "What is the total spending by all customers?"
result = multi_retrieval_qa_chain({"input": query})
# sql_retriever.get_relevant_documents(query)

%%time
query = "What is the most ordered item?"
result = multi_retrieval_qa_chain({"input": query})
# sql_retriever.get_relevant_documents(query)

%%time
query = "How many customers wrote a review about a product?"
result = multi_retrieval_qa_chain({"input": query})
# sql_retriever.get_relevant_documents(query)

Test the default chain

The MultiRetrievalQAChain also has a default chain, which uses the LLM's own knowledge to answer a question when the router determines that none of the retrievers associated with the chain can answer it. You can also define a customized prompt to adjust the responses of the default chain.

Now, let's ask a question that will trigger the default chain in our case.

%%time
# This question is completely unrelated to any of the information provided in the retrievers' configuration.
query = "What is going on in the world?"
result = multi_retrieval_qa_chain({"input": query})
# sql_retriever.get_relevant_documents(query)

Finally, let's look at the current memory buffer

Here is the complete context that the application currently has, should a new question be asked.

# current memory buffer
multi_retrieval_qa_chain.memory.buffer
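
If you want to start a fresh conversation, reset the buffer (ConversationBufferMemory exposes a clear() method):

# reset the conversation context
multi_retrieval_qa_chain.memory.clear()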

Cleanup

Please make sure to delete all the resources that were created, as you will otherwise incur costs for storing documents in the OpenSearch Serverless (OSS) index.

To delete the knowledge base resources, check the cleanup steps here.

To delete the Glue database and table, run the cell below.

glue_client = boto3.client('glue')
# deleting the Glue database also removes the tables it contains
glue_client.delete_database(Name=database)

What's Next?

Now that you have a good understanding of custom retrievers, you may want to optimize the SQL prompts, the retriever information and descriptions, and the default prompt to tailor the model responses to your business or project needs.

If you need even faster responses, you can pre-prepare your data so that it is more easily accessible and does not require joins or complex conditions. You could also optimize the time delay used to check and fetch results after a query has completed successfully (see the sketch below).
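
As a hedged sketch of that last idea, polling Athena with a short exponential backoff instead of a fixed sleep can shave latency off each query; the function below is illustrative and would replace the wait logic inside the custom retriever:

import time

def wait_for_athena_query(athena_client, query_execution_id, initial_delay=0.25, max_delay=2.0):
    # poll get_query_execution with exponential backoff until the query finishes
    delay = initial_delay
    while True:
        state = athena_client.get_query_execution(
            QueryExecutionId=query_execution_id
        )["QueryExecution"]["Status"]["State"]
        if state in ("SUCCEEDED", "FAILED", "CANCELLED"):
            return state
        time.sleep(delay)
        delay = min(delay * 2, max_delay)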