Cool poc
This commit is contained in:
86
script.py
Normal file
86
script.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import chromadb
|
||||
import requests
|
||||
import argparse
|
||||
|
||||
OLLAMA_URL = "http://ollama:11434/api"
|
||||
CHROMA_COLLECTION_NAME = "rag_documents"
|
||||
|
||||
# Connect to ChromaDB client
|
||||
client = chromadb.HttpClient(host="chroma", port=8000)
|
||||
collection = client.get_or_create_collection(name=CHROMA_COLLECTION_NAME)
|
||||
|
||||
|
||||
def get_embedding(text, embedding_model):
|
||||
"""Generate an embedding using Ollama."""
|
||||
response = requests.post(f"{OLLAMA_URL}/embeddings", json={"model": embedding_model, "prompt": text})
|
||||
if response.status_code == 200:
|
||||
return response.json().get("embedding")
|
||||
else:
|
||||
raise Exception(f"Failed to get embedding: {response.text}")
|
||||
|
||||
def seed_database(embedding_model):
|
||||
"""Seed the ChromaDB with example documents."""
|
||||
documents = [
|
||||
"Artificial Intelligence is transforming industries.",
|
||||
"Machine learning helps computers learn from data.",
|
||||
"Deep learning uses neural networks to process information.",
|
||||
"RAG enhances language models with retrieved knowledge."
|
||||
]
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
embedding = get_embedding(doc, embedding_model)
|
||||
collection.add(ids=[str(i)], embeddings=[embedding], metadatas=[{"text": doc}])
|
||||
print(f"Added document {i}: {doc}")
|
||||
|
||||
print("Database seeding complete.")
|
||||
|
||||
def search(query, embedding_model, llm_model, top_k=2):
|
||||
"""Retrieve similar documents and generate an answer using an LLM."""
|
||||
query_embedding = get_embedding(query, embedding_model)
|
||||
results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
|
||||
|
||||
retrieved_docs = [doc["text"] for doc_list in results["metadatas"] for doc in doc_list]
|
||||
|
||||
if not retrieved_docs:
|
||||
print("No relevant documents found.")
|
||||
return
|
||||
|
||||
# Construct the LLM prompt
|
||||
prompt = f"Use the following documents to answer the question:\n\n"
|
||||
for doc in retrieved_docs:
|
||||
prompt += f"- {doc}\n"
|
||||
prompt += f"\nQuestion: {query}\nAnswer:"
|
||||
|
||||
response = requests.post(f"{OLLAMA_URL}/generate", json={"model": llm_model, "prompt": prompt, "stream": False})
|
||||
|
||||
print("RAW RESPONSE:", response.text)
|
||||
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
data = response.json()
|
||||
answer = data.get("response", "No response field found.")
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
answer = response.text # Fallback to raw text
|
||||
print("\nSearch Results:\n")
|
||||
for doc in retrieved_docs:
|
||||
print(f"Retrieved: {doc}")
|
||||
print("\nGenerated Answer:\n", answer)
|
||||
else:
|
||||
print(f"Failed to generate response: {response.status_code} - {response.text}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("command", choices=["seed", "search"], help="Command to run")
|
||||
parser.add_argument("--query", type=str, help="Query text for searching")
|
||||
parser.add_argument("--embedding_model", type=str, default="mxbai-embed-large", help="Embedding model")
|
||||
parser.add_argument("--llm_model", type=str, default="gemma3", help="LLM model for generating responses")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "seed":
|
||||
seed_database(args.embedding_model)
|
||||
elif args.command == "search":
|
||||
if not args.query:
|
||||
print("Please provide a query with --query")
|
||||
else:
|
||||
search(args.query, args.embedding_model, args.llm_model)
|
||||
Reference in New Issue
Block a user