import os
from pathlib import Path
from typing import List
import bs4
import chromadb
from chromadb.api.models.Collection import Collection
from dotenv import load_dotenv
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from openai import OpenAI
# Load .env variables and ensure they are set correctly
env_path = Path(".env")
load_dotenv(dotenv_path=env_path)


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*, failing loudly if unset.

    Uses an explicit check instead of ``assert`` so the validation still runs
    when Python is invoked with ``-O`` (which strips assert statements).
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(f"{name} missing from .env")
    return value


MODEL = _require_env("MODEL")
API_KEY = _require_env("GALADRIEL_API_KEY")
API_URL = _require_env("API_URL")
EMBEDDING_MODEL = "gte-large-en-v1.5"
# LLM prompt, where context will be added
PROMPT = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer:"""
QUERY = "What is task decomposition?"
def main():
    """Run the full RAG pipeline end to end.

    Steps:
      1. Download and chunk the source article.
      2. Embed every chunk and store it in an in-memory Chroma collection.
      3. Embed the user query and retrieve the two closest chunks.
      4. Ask the LLM to answer the query using the retrieved context.
    """
    client = OpenAI(
        base_url=API_URL,
        api_key=API_KEY,
    )

    print("Loading dataset...")
    splits = _load_dataset()
    print(f"Got dataset, with {len(splits)} documents.\n")

    collection = _create_vector_db()

    print("Creating embeddings...")
    # Compute the chunk texts once; they are needed both for embedding
    # and for storing alongside the vectors below.
    chunk_texts = [s.page_content for s in splits]
    embeddings = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=chunk_texts,
        encoding_format="float"
    )

    print("Saving embeddings to the vector datastore...\n")
    collection.add(
        documents=chunk_texts,
        embeddings=[e.embedding for e in embeddings.data],
        ids=[f"id{i}" for i in range(len(chunk_texts))],
    )

    print(f"Querying vector datastore with user query: {QUERY}\n")
    # Create embeddings for the query.
    # BUG FIX: the original embedded a hard-coded string
    # ("What are the approaches to Task Decomposition?") instead of QUERY,
    # so the retrieved context did not match the question that is printed
    # above and answered below.
    query_embeddings = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=[QUERY],
        encoding_format="float"
    )

    # Query the vector datastore with the query embeddings
    results = collection.query(
        query_embeddings=query_embeddings.data[0].embedding,
        n_results=2
    )

    # Join the retrieved chunks into a single context string for the prompt.
    context = "\n".join(results["documents"][0])

    # Fill the prompt template. str.replace is used instead of str.format
    # because the retrieved context may itself contain curly braces.
    formatted_prompt = PROMPT \
        .replace("{question}", QUERY) \
        .replace("{context}", context)

    print("Calling LLM with formatted prompt:")
    print(formatted_prompt)
    print("\n")
    response = client.chat.completions.create(
        model=MODEL,
        temperature=0,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": formatted_prompt},
        ],
    )
    print("==== Got LLM response: ====")
    print(response.choices[0].message.content)
def _load_dataset() -> List[Document]:
    """Fetch the source blog post and split it into overlapping text chunks.

    Only the post title, header, and body are parsed out of the page; the
    result is chunked into ~1000-character pieces with 200-character overlap.
    """
    content_filter = bs4.SoupStrainer(
        class_=("post-content", "post-title", "post-header")
    )
    loader = WebBaseLoader(
        web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
        bs_kwargs={"parse_only": content_filter},
    )
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    return splitter.split_documents(loader.load())
def _create_vector_db() -> Collection:
    """Create (or reuse) an in-memory Chroma collection for the embeddings.

    Uses ``get_or_create_collection`` instead of ``create_collection`` so a
    repeated run in the same process does not fail with a
    "collection already exists" error; behavior on first use is identical.
    """
    chroma_client = chromadb.Client()
    return chroma_client.get_or_create_collection(name="my_collection")
def format_docs(docs):
    """Concatenate the page contents of *docs*, separated by blank lines."""
    parts = [doc.page_content for doc in docs]
    return "\n\n".join(parts)
if __name__ == '__main__':
main()