How to create a RAG System
Recently, I have been learning about LLMs, including RAG (Retrieval-Augmented Generation) applications. In this blog post, I will discuss some key concepts behind RAG, how to implement a RAG system, and how to use Redis to cache its responses.
What is a RAG?
RAG (Retrieval-Augmented Generation) is a technique that makes an LLM's answers more accurate by grounding them in an external knowledge base, such as a vector database. The system captures the semantic meaning of the user's query, retrieves related information from the knowledge base, and the LLM then generates an answer based on that information.
What are the benefits of a RAG?
By combining the power of LLMs with an external knowledge base, we can produce more accurate answers and reduce hallucinations. A RAG system also uses semantic search. Because text, images, and other multimedia files are all represented as vectors, semantic search lets us compare this kind of unstructured data. The RAG approach is particularly effective for tasks that require a deep understanding of context and the ability to draw on multiple sources of information.
How does a RAG work?
To understand the user's query, the system embeds it using an embedding model such as mxbai-embed-large. These models are neural networks trained on large amounts of data so that texts with similar meanings are mapped to nearby vectors. This is how they capture semantic meaning.
📎 A simplified way to picture an embedding model is as a neural network with its last layer removed. Instead of getting a specific label for an input, we get a vector embedding.
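Here is a toy sketch in plain Python that makes the 📎 note concrete. It assumes a tiny two-layer classifier with made-up weights. W1, W2, hidden_layer, classify and embed are hypothetical names used only for illustration. The full network ends in one score per class. If we stop at the hidden layer instead, we get a vector. Real embedding models such as mxbai-embed-large are far larger and are trained specifically to produce useful embeddings, so treat this only as an illustration of the idea.

```python
# Hypothetical weights for a tiny classifier: 3 input features -> 2 hidden units -> 2 classes
W1 = [[0.5, -0.2, 0.1],
      [0.3, 0.8, -0.5]]
W2 = [[1.0, -1.0],
      [-0.5, 0.7]]

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(xs):
    return [max(0.0, v) for v in xs]

def hidden_layer(x):
    return relu(matvec(W1, x))

def classify(x):
    """Full network: the last layer maps the hidden vector to one score per class."""
    return matvec(W2, hidden_layer(x))

def embed(x):
    """Same network with the last layer removed: the hidden vector acts as the 'embedding'."""
    return hidden_layer(x)

x = [1.0, 2.0, 0.5]
print(classify(x))  # two class scores
print(embed(x))     # a 2-dimensional vector
```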
After creating the embedding, the system searches its knowledge base (in this case, a vector database) for the entries most similar to the query. But how does it do that?
Vectors let us search for similar objects: a vector search algorithm finds the stored vectors that are closest to the query vector. Coordinates are a good analogy. The River Plate Monumental Stadium is located at {34° 32' 42.82" S, 58° 26' 58.74" W}. With these coordinates, we can tell whether another location is nearby. For example, the River Plate Museum is at {34° 32' 47.40" S, 58° 26' 54.61" W}, and comparing the two sets of coordinates shows that the two places are close to each other.
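Below is a minimal brute-force version of this search. It uses small hypothetical 3-dimensional vectors in place of real embeddings. The sketch ranks every stored vector by cosine similarity to the query and keeps the top k. The document ids and vector values are made up for illustration.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means they point the same way."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_k(query, vectors, k=2):
    """Return the ids of the k stored vectors most similar to the query (brute force)."""
    scored = [(cosine_similarity(query, vec), doc_id) for doc_id, vec in vectors.items()]
    scored.sort(reverse=True)  # highest similarity first
    return [doc_id for _, doc_id in scored[:k]]

# Hypothetical 3-dimensional "embeddings"; real ones have hundreds of dimensions
store = {
    "python-backend": [0.9, 0.1, 0.0],
    "react-frontend": [0.1, 0.9, 0.1],
    "data-pipeline": [0.8, 0.0, 0.3],
}
query = [0.85, 0.05, 0.1]
print(top_k(query, store))  # ['python-backend', 'data-pipeline']
```

Scanning every vector is fine for a handful of documents. Vector databases typically use approximate nearest-neighbor indexes instead (ChromaDB, for example, uses HNSW), so they don't have to compare the query against every stored vector.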
After this search, we pass the retrieved information as context to our LLM, which generates the answer for the user.
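```python
import ollama

# `document`, `link`, and `prompt` come from the retrieval step and the user's request
context = f"Based on this info: {document[:100]}... Link: {link}"
prompt_for_ollama = context + f" Provide a brief response to: {prompt}"

# Ask llama3.2 for a short answer (num_predict caps the number of generated tokens)
generation_response = ollama.generate(
    model="llama3.2",
    prompt=prompt_for_ollama,
    options={"num_predict": 50},
)
```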
To summarize:
- Embed the user's request (transform it into a vector that captures its semantic meaning).
- Search the knowledge base using semantic search (e.g. a vector database like ChromaDB).
- Generate a response with an LLM, using the retrieved information as context.
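The following self-contained sketch wires the three steps together. A hypothetical bag-of-words `embed` function and a `generate` stub stand in for mxbai-embed-large and llama3.2, so it runs without Ollama or ChromaDB. The vocabulary and documents are made up. Only the shape of the pipeline is meant to carry over.

```python
import math
import re

def embed(text):
    """Hypothetical toy embedding: word counts over a tiny fixed vocabulary.
    A real system would call an embedding model such as mxbai-embed-large."""
    vocab = ["python", "api", "frontend", "react", "data", "database"]
    words = re.findall(r"[a-z]+", text.lower())
    return [float(words.count(term)) for term in vocab]

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def generate(prompt):
    """Hypothetical stand-in for an LLM call; it just echoes the prompt it received."""
    return f"[LLM answer based on] {prompt}"

documents = [
    "Use Python with FastAPI to build an API",
    "Use React for the frontend",
    "Use PostgreSQL as the database for data storage",
]
doc_vectors = [embed(doc) for doc in documents]

def rag_answer(query, k=1):
    # 1. Embed the user's request
    query_vector = embed(query)
    # 2. Semantic search: rank stored documents by similarity to the query
    ranked = sorted(
        zip(documents, doc_vectors),
        key=lambda pair: cosine_similarity(query_vector, pair[1]),
        reverse=True,
    )
    context = " ".join(doc for doc, _ in ranked[:k])
    # 3. Generate a response using the retrieved context
    return generate(f"Based on this info: {context} Provide a brief response to: {query}")

print(rag_answer("Which frontend framework should I use?"))
```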
Semantic Cache
One problem is that LLM inference can be expensive. To address this, we can use a caching technique called “Semantic Cache.” We store the meaning of a user's query (its embedding) together with its answer. If a new query has the same semantic meaning, we can answer it without having the LLM run a new inference. This also makes responses faster, so we get two benefits: lower costs and reduced latency.
Redis and the Key-Value Mechanism
Redis is an in-memory database that stores data as key-value pairs. We will store the embedding of each query together with its response, so that we don't need to generate a new response with the LLM for future similar queries. The system then works as follows:
- Embed the user's query.
- Check the cache for a semantically similar query.
- If there is a match, return the cached response. Otherwise, search the vector database and pass the results as context to the LLM to generate a response.
- Once the response is generated, save the query's embedding and the response in the cache.
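The cache flow above can be sketched with an in-memory stand-in for Redis. `SemanticCache`, `answer`, and the fake embeddings are hypothetical. The 0.7 threshold mirrors the value the endpoint uses later. The usage at the end counts LLM calls. The second query is close enough to the first to be served from the cache, so only two of the three queries reach the LLM.

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

class SemanticCache:
    """Hypothetical in-memory stand-in for the Redis cache: (embedding, response) pairs."""

    def __init__(self, threshold=0.7):
        self.threshold = threshold
        self.entries = []

    def lookup(self, embedding):
        """Return the most similar cached response, or None if nothing is above the threshold."""
        best_response, best_similarity = None, -1.0
        for stored_embedding, response in self.entries:
            similarity = cosine_similarity(embedding, stored_embedding)
            if similarity > best_similarity:
                best_response, best_similarity = response, similarity
        return best_response if best_similarity > self.threshold else None

    def store(self, embedding, response):
        self.entries.append((embedding, response))

def answer(query, embed, generate, cache):
    embedding = embed(query)            # 1. embed the query
    cached = cache.lookup(embedding)    # 2. look for a semantically similar query
    if cached is not None:
        return cached                   # 3a. cache hit: no LLM call
    response = generate(query)          # 3b. cache miss: retrieve context and call the LLM
    cache.store(embedding, response)    # 4. cache the new embedding and answer
    return response

# Usage with hypothetical embeddings and a fake LLM that records each call
calls = []
def fake_generate(query):
    calls.append(query)
    return f"answer to: {query}"

fake_embeddings = {
    "best stack for a web API?": [0.9, 0.1, 0.0],
    "which stack for building an API?": [0.85, 0.15, 0.05],
    "how do I train a model?": [0.0, 0.2, 0.95],
}
cache = SemanticCache(threshold=0.7)
for q in fake_embeddings:
    answer(q, fake_embeddings.get, fake_generate, cache)
print(len(calls))  # 2
```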
Let's Start with the App
This will be a simple app that returns a tech stack based on the information in our database, into which we will load this data.
To start, we will create a Python app and use FastAPI to build our API. We will also need Ollama, Redis, and ChromaDB.
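First, install the necessary dependencies. numpy is listed explicitly because the code below imports it. Ollama itself is installed separately, and the mxbai-embed-large and llama3.2 models can be downloaded with ollama pull.
```bash
pip install fastapi chromadb ollama pydantic redis numpy
```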
We will run Redis with Docker. Save the following as docker-compose.yaml and run docker-compose up:
```yaml
# docker-compose.yaml
version: "3.8"
services:
  redis:
    image: redis:latest  # Use the latest official Redis image
    container_name: redis-local
    ports:
      - "6379:6379"  # Map port 6379 of the container to local port 6379
    volumes:
      - ./data:/data  # Mount a volume for data persistence
    command: ["redis-server", "--appendonly", "yes"]  # Enables persistence on disk (optional)
```
Next, we create a function called lifespan that creates the “tech_stacks” collection, embeds our data, and inserts it into the collection. The collection is then stored in the application state so that the endpoint can use it.
We will use a model called mxbai-embed-large to do the embeddings and add the results to the database.
```python
import ollama
import chromadb
from contextlib import asynccontextmanager
from info_list import info  # your data: a list of dicts with "description" and "links"

@asynccontextmanager
async def lifespan(app):
    # In-memory ChromaDB client: the collection is rebuilt on every startup
    client = chromadb.Client()
    collection = client.create_collection("tech_stacks")

    for i, item in enumerate(info):
        description = item["description"]
        links = " ".join(item["links"])
        combined_text = f"{description} {links}"

        # Embed the combined text with mxbai-embed-large
        response = ollama.embeddings(model="mxbai-embed-large", prompt=combined_text)
        embedding = response["embedding"]

        # Store the embedding, the raw text, and its metadata
        collection.add(
            ids=[str(i)],
            embeddings=[embedding],
            documents=[combined_text],
            metadatas=[{"description": description, "links": links}],
        )

    # Make the collection available to request handlers via app.state
    app.state.collection = collection
    yield
```
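The post doesn't show info_list.py. Judging only by the keys lifespan reads, each item needs a "description" string and a "links" list. The entries below are hypothetical placeholders. to_document is a hypothetical helper that reproduces the text lifespan embeds for each item.

```python
# info_list.py (hypothetical contents; the real file is not shown in this post)
info = [
    {
        "description": "Python with FastAPI for REST APIs, PostgreSQL for storage",
        "links": ["https://fastapi.tiangolo.com", "https://www.postgresql.org"],
    },
    {
        "description": "Next.js with TypeScript for server-rendered React apps",
        "links": ["https://nextjs.org"],
    },
]

def to_document(item):
    """Same rule lifespan uses: the description followed by the space-joined links."""
    links = " ".join(item["links"])
    return f"{item['description']} {links}"

for i, item in enumerate(info):
    print(str(i), to_document(item))
```

The endpoint later calls metadata["links"].split()[0], so every item should have at least one link. An empty "links" list would make that line raise an IndexError.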
In our main.py file, we use this function:
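```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Hypothetical module names: import these from wherever you defined them
from lifespan import lifespan
from ollama_routes import ollama_router

app = FastAPI(lifespan=lifespan)
# Allow the Next.js frontend on localhost:3000 to call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(ollama_router, prefix="/api")
```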
Cache response
We will create two functions. get_cached_embedding searches the embeddings stored in Redis and returns the one most similar to the new query's embedding, together with its similarity score. cache_embedding saves a new query's embedding and its response in the cache.
```python
import json

import numpy as np
import redis

# Connect to the Redis container started with docker-compose
redis_client = redis.Redis(host="localhost", port=6379, db=0)

def cosine_similarity(vecA, vecB):
    # Cosine of the angle between the vectors: 1.0 means same direction
    dot_product = np.dot(vecA, vecB)
    normA = np.linalg.norm(vecA)
    normB = np.linalg.norm(vecB)
    return dot_product / (normA * normB)

def get_cached_embedding(embedding_vector):
    """Return (best cached entry, its similarity), or (None, -1) if the cache is empty."""
    best_match = None
    best_similarity = -1
    # scan_iter walks the keys incrementally instead of blocking Redis like KEYS does
    for key in redis_client.scan_iter(match="embedding:*"):
        raw = redis_client.get(key)
        if raw is None:  # the key may have been deleted after it was listed
            continue
        data = json.loads(raw)
        stored_embedding = np.array(data["embedding"])
        # Compare using cosine similarity
        similarity = cosine_similarity(embedding_vector, stored_embedding)
        if similarity > best_similarity:
            best_similarity = similarity
            best_match = data
    return best_match, best_similarity

def cache_embedding(prompt, embedding_vector, response):
    # Store the new embedding (as a JSON-serializable list) and its response in Redis
    key = f"embedding:{prompt}"
    redis_data = {
        "embedding": embedding_vector.tolist(),
        "response": response,
    }
    redis_client.set(key, json.dumps(redis_data))
```
Creating the endpoint
Finally, we will create an endpoint that we will use to receive the user's query.
```python
from typing import Union

import numpy as np
import ollama
from fastapi import APIRouter, HTTPException, Request

from redis_operations import get_cached_embedding, cache_embedding

ollama_router = APIRouter()

@ollama_router.get("/ollama")
async def ask_ollama(prompt: Union[str, None] = None, request: Request = None):
    if prompt is None:
        raise HTTPException(status_code=400, detail="Prompt not provided")

    # Embed the user's query
    embed_response = ollama.embeddings(model="mxbai-embed-large", prompt=prompt)
    embedding_vector = np.array(embed_response["embedding"])

    # Retrieve the 3 most similar documents from ChromaDB
    collection = request.app.state.collection
    results = collection.query(
        query_embeddings=[embed_response["embedding"]],
        n_results=3,
        include=["metadatas", "documents"],
    )
    if not results["ids"] or not results["metadatas"] or not results["documents"]:
        raise HTTPException(status_code=404, detail="No relevant results found")

    # Return the cached answer if a semantically similar query was seen before
    best_match, best_similarity = get_cached_embedding(embedding_vector)
    THRESHOLD = 0.7
    if best_similarity > THRESHOLD:
        return {"results": best_match["response"]}

    # Otherwise, generate a short answer for each retrieved document
    responses = []
    for metadata, document in zip(results["metadatas"][0], results["documents"][0]):
        link = metadata["links"].split()[0]
        context = f"Based on this info: {document[:100]}... Link: {link}"
        prompt_for_ollama = context + f" Provide a brief response to: {prompt}"
        generation_response = ollama.generate(
            model="llama3.2",
            prompt=prompt_for_ollama,
            options={"num_predict": 50},
        )
        response_text = generation_response["response"].strip()
        responses.append({"link": link, "short_response": response_text})

    # Cache the query embedding together with the generated answers
    cache_embedding(prompt, embedding_vector, responses)
    return {"results": responses}
```
What does this function do?

Initial Setup:
```python
ollama_router = APIRouter()
```
This creates a FastAPI router (ollama_router) that groups the API's routes.

The ask_ollama Function:
```python
@ollama_router.get("/ollama")
async def ask_ollama(prompt: Union[str, None] = None, request: Request = None):
```
This function responds to GET requests at /ollama. Because main.py mounts the router with the /api prefix, the full path is /api/ollama. It takes two parameters:
- prompt: the input text provided by the user, read from the query string.
- request: the request object, which gives access to FastAPI app attributes such as app.state.

prompt Validation:
```python
if prompt is None:
    raise HTTPException(status_code=400, detail="Prompt not provided")
```
If prompt is not provided, the function returns a 400 error (Bad Request).

Generate Prompt Embedding:
```python
embed_response = ollama.embeddings(model="mxbai-embed-large", prompt=prompt)
embedding_vector = np.array(embed_response["embedding"])
```
Using ollama, the function creates an embedding (a numerical representation) of the prompt with the mxbai-embed-large model.

Querying the Collection:
```python
collection = request.app.state.collection
results = collection.query(
    query_embeddings=[embed_response["embedding"]],
    n_results=3,
    include=["metadatas", "documents"],
)
```
The collection, accessed from request.app.state.collection, is queried using the prompt's embedding. The query retrieves the 3 most relevant documents, including their metadata and content.

Results Verification:
```python
if not results["ids"] or not results["metadatas"] or not results["documents"]:
    raise HTTPException(status_code=404, detail="No relevant results found")
```
If no relevant results are found, the function returns a 404 error.

Embedding Cache Check:
```python
best_match, best_similarity = get_cached_embedding(embedding_vector)
THRESHOLD = 0.7
if best_similarity > THRESHOLD:
    return {"results": best_match["response"]}
```
This part checks whether a similar embedding exists in the cache. If the best match has a cosine similarity above the 0.7 threshold, the function returns its cached response without calling the LLM.

Generating and Caching Responses:
```python
responses = []
for metadata, document in zip(results["metadatas"][0], results["documents"][0]):
    link = metadata["links"].split()[0]
    context = f"Based on this info: {document[:100]}... Link: {link}"
    prompt_for_ollama = context + f" Provide a brief response to: {prompt}"
    generation_response = ollama.generate(
        model="llama3.2",
        prompt=prompt_for_ollama,
        options={"num_predict": 50},
    )
    response_text = generation_response["response"].strip()
    responses.append({"link": link, "short_response": response_text})
cache_embedding(prompt, embedding_vector, responses)
```
For each retrieved document, the function builds a brief context from the first 100 characters of the document plus its first link, then generates a response with the llama3.2 model. The generated responses are collected in responses and cached together with the query embedding.

Returning Results:
```python
return {"results": responses}
```
Finally, the function returns the generated responses.
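A minimal Python client sketch can call the endpoint from another program, which is the role the Next.js app plays in the demo. It assumes the API is served locally on port 8000, uvicorn's default. The post does not say how the API is run. build_url and ask are hypothetical helpers.

```python
import json
from urllib.parse import urlencode
from urllib.request import urlopen

# Assumption: the API runs locally on port 8000 and the router is mounted under /api
BASE_URL = "http://localhost:8000/api/ollama"

def build_url(prompt):
    """URL-encode the prompt as the `prompt` query parameter that ask_ollama reads."""
    return f"{BASE_URL}?{urlencode({'prompt': prompt})}"

def ask(prompt):
    """Call the endpoint and return the list stored under "results".
    Requires the API, Redis, and Ollama to be running."""
    with urlopen(build_url(prompt)) as resp:
        return json.load(resp)["results"]

print(build_url("stack for a REST API"))
```

Cache hits and fresh generations return the same shape: a list of objects with "link" and "short_response" keys under "results". A client therefore does not need to know whether an answer came from the cache.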
Demo
Finally, we connect this API to an app. In this case, I created a simple search bar with Next.js.
You may notice that the first query is not cached, so it takes longer. When you then make a second, similar query, the response is faster because it is served from the cache.
Final
Thank you for reading; I hope it has been useful. If you want to see the code, here is the API code and here is the web code.