Introduction
In this project, I built a retrieval-augmented generation (RAG) pipeline in Python to assist with PCI DSS compliance.
The Payment Card Industry Data Security Standard (PCI DSS) outlines security requirements for organizations that store, process, or transmit payment card data. The standard runs to hundreds of pages, making it difficult to determine which sections are relevant to a specific compliance question. Querying a general-purpose large language model often yields vague answers, and while a fine-tuned model can be more precise, businesses usually lack the properly formatted data required for training. If you read my Tax Law GPT project, you'll recall that one of the biggest challenges I had was converting nearly 10,000 sections of the Internal Revenue Code into trainable data.
This is where RAG comes in. Unlike fine-tuning, RAG retrieves documents the model was never trained on and incorporates them into its output at query time. Here’s a rundown of how it works:
- Documents are converted into numerical representations (embeddings) and stored in a vector database.
- When a user submits a query, it’s also converted into a vector.
- Since both the query and the documents live in the same high-dimensional vector space, their similarity (typically cosine similarity) can be computed to determine which documents are most relevant; see the small sketch below.
- These relevant documents are appended to the prompt before being fed into the LLM, enhancing the response.
This approach enables LLMs to answer questions about large, complex documents, like the PCI DSS, even if they weren’t trained on them.
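As a rough illustration of the similarity step (not part of the project code; the vectors below are made up, and real embeddings have thousands of dimensions), cosine similarity scores a query vector against each document vector and the highest-scoring documents win:

import math

def cosine_similarity(a, b):
    # Cosine of the angle between two vectors: closer to 1.0 means more similar
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical 3-dimensional embeddings
query_vec = [0.9, 0.1, 0.3]
doc_vecs = {
    "firewall_requirements.txt": [0.8, 0.2, 0.4],
    "employee_handbook.txt": [0.1, 0.9, 0.2],
}
scores = {name: cosine_similarity(query_vec, vec) for name, vec in doc_vecs.items()}
print(max(scores, key=scores.get))  # firewall_requirements.txt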
Implementation in Python
import os
from dotenv import load_dotenv
import chromadb
from openai import OpenAI
from chromadb.utils import embedding_functions
# Load API keys and environment variables
load_dotenv()
OPENAI_KEY = os.getenv("OPENAI_KEY")
client = OpenAI(api_key=OPENAI_KEY)

Import the required libraries and initialize the OpenAI API client.
# Initialize OpenAI embedding function
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=OPENAI_KEY, model_name="text-embedding-3-large"
)

Use OpenAI’s text-embedding-3-large model for embeddings.
# Set up ChromaDB collection
chroma_client = chromadb.PersistentClient(path="chroma_persistent_storage")
collection = chroma_client.get_or_create_collection(
    name="PCI_collection", embedding_function=openai_ef
)

Create a ChromaDB collection as the vector database.
# Load text documents from a directory
def load_documents_from_directory(directory_path):
    documents = []
    for filename in os.listdir(directory_path):
        if filename.endswith(".txt"):
            with open(os.path.join(directory_path, filename), "r", encoding="utf-8") as file:
                documents.append({"id": filename, "text": file.read()})
    return documents
# Split text into overlapping chunks
def split_text(text, chunk_size=1000, chunk_overlap=20):
    chunks, start = [], 0
    while start < len(text):
        end = start + chunk_size
        chunks.append(text[start:end])
        start = end - chunk_overlap
    return chunks
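To see how the overlap behaves, here is a toy call with exaggerated parameters (illustrative only, not part of the pipeline):

# Each chunk repeats the last character of the previous one
print(split_text("abcdefghij", chunk_size=4, chunk_overlap=1))
# ['abcd', 'defg', 'ghij', 'j']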
# Load and preprocess documents
directory_path = "path/to/documents/folder"
documents = load_documents_from_directory(directory_path)
chunked_documents = []
for doc in documents:
    chunks = split_text(doc["text"])
    for i, chunk in enumerate(chunks):
        chunked_documents.append({"id": f"{doc['id']}_chunk{i+1}", "text": chunk})
# Generate OpenAI embeddings
def get_openai_embedding(text):
    response = client.embeddings.create(input=text, model="text-embedding-3-large")
    return response.data[0].embedding
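As a quick sanity check (not part of the original pipeline), the returned embedding is a plain list of floats; text-embedding-3-large vectors are 3,072-dimensional by default:

sample = get_openai_embedding("cardholder data environment")
print(len(sample))  # 3072 by default for text-embedding-3-large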
# Embed and store document chunks in ChromaDB
for doc in chunked_documents:
    doc["embedding"] = get_openai_embedding(doc["text"])
    collection.upsert(
        ids=[doc["id"]], documents=[doc["text"]], embeddings=[doc["embedding"]]
    )

Load external documents, split them into chunks, convert them into vectors, and store them in the database. In this project, the documents include information on the rules and requirements of the PCI DSS.
# Query ChromaDB for relevant document chunks
def query_documents(question, n_results=2):
    results = collection.query(query_texts=question, n_results=n_results)
    return [doc for sublist in results["documents"] for doc in sublist]
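Chroma nests its results one inner list per query text, which is why the return line flattens them. A quick way to inspect the shape (illustrative; assumes the collection built above):

raw = collection.query(query_texts="What is the Cardholder Data Environment?", n_results=2)
print(list(raw.keys()))          # includes 'ids', 'documents', 'distances', 'metadatas'
print(len(raw["documents"]))     # 1 -> one inner list per query text
print(len(raw["documents"][0]))  # 2 -> the n_results chunks for that query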
# Generate response using retrieved context
def generate_response(question, relevant_chunks):
    context = "\n\n".join(relevant_chunks)
    prompt = (
        "You are a question-answering assistant. Use only the provided context to generate a concise response in three sentences or fewer. If the answer is not in the context, respond with 'I don't know.'"
        "\n\nContext:\n" + context + "\n\nQuestion:\n" + question
    )
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": question},
        ],
    )
    return response.choices[0].message.content
# Example query and response
document_question = "What are network diagrams required to contain, according to PCI DSS Version 4.0.1?"
relevant_chunks = query_documents(document_question)
answer = generate_response(document_question, relevant_chunks)
print(answer)

When a query is made, the LLM generates a response based on both the query and the retrieved documents.
We get the following output:
Network diagrams are required to show all connections between the Cardholder Data Environment (CDE) and other networks, including any wireless networks.
Interestingly, even without external documents, GPT-4o still produces fairly specific PCI DSS responses.