import os from dotenv import load_dotenv from langchain_community.document_loaders import DirectoryLoader, TextLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_openai import OpenAIEmbeddings, ChatOpenAI from langchain_community.vectorstores import Chroma from langchain.chains import RetrievalQA
# Read variables from a .env file into the process environment at import time.
# NOTE(review): presumably this supplies the OpenAI API key used by the
# embedding/chat clients created in RAGSystem.__init__ — confirm.
load_dotenv()
class RAGSystem:
    """Small retrieval-augmented QA pipeline over a directory of ``.txt`` files.

    Typical use is ``RAGSystem(data_dir).build()`` followed by ``query(...)``.
    An index persisted earlier can be reused with ``load_vectorstore()``
    followed by ``setup_qa_chain()`` instead of ``build()``.
    """

    def __init__(self, data_dir: str, persist_dir: str = "./chroma_db") -> None:
        """Store the paths and create the embedding and chat model clients.

        Args:
            data_dir: Directory that is searched recursively for ``*.txt`` files.
            persist_dir: Directory handed to Chroma as its persistence location.
        """
        self.data_dir = data_dir
        self.persist_dir = persist_dir
        self.embeddings = OpenAIEmbeddings()
        self.llm = ChatOpenAI(model="gpt-4o", temperature=0)
        # Both stay None until create_vectorstore()/load_vectorstore() and
        # setup_qa_chain() have been called.
        self.vectorstore = None
        self.qa_chain = None

    def load_documents(self) -> list:
        """Load every ``*.txt`` file below ``data_dir`` and return the documents."""
        loader = DirectoryLoader(
            self.data_dir,
            glob="**/*.txt",
            loader_cls=TextLoader,
            # The platform default encoding is still tried first; detection is
            # only a fallback, so files that loaded before load the same way.
            loader_kwargs={"autodetect_encoding": True},
        )
        documents = loader.load()
        print(f"로드된 문서 수: {len(documents)}")
        return documents

    def split_documents(self, documents: list) -> list:
        """Split documents into chunks of 1000 characters with 200 overlap."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        chunks = text_splitter.split_documents(documents)
        print(f"생성된 청크 수: {len(chunks)}")
        return chunks

    def create_vectorstore(self, chunks: list) -> None:
        """Embed ``chunks`` into a Chroma store located at ``persist_dir``.

        Raises:
            ValueError: If ``chunks`` is empty, i.e. there is nothing to index.
        """
        if not chunks:
            # Fail here with a clear message instead of deep inside the
            # vector store when it is handed an empty batch.
            raise ValueError(
                f"No chunks to index; check that {self.data_dir!r} contains .txt files."
            )
        self.vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=self.embeddings,
            persist_directory=self.persist_dir,
        )
        # Explicit persistence is deprecated in newer Chroma integrations,
        # so only call persist() where the method still exists.
        persist = getattr(self.vectorstore, "persist", None)
        if callable(persist):
            persist()
        print("벡터 스토어 생성 완료")

    def load_vectorstore(self) -> None:
        """Open the Chroma store at ``persist_dir`` without re-indexing."""
        self.vectorstore = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=self.embeddings,
        )
        print("벡터 스토어 로드 완료")

    def setup_qa_chain(self) -> None:
        """Create the RetrievalQA chain on top of the current vector store.

        The retriever is configured with ``k=4`` and the chain is asked to
        return its source documents alongside the answer.

        Raises:
            RuntimeError: If no vector store has been created or loaded yet.
        """
        if self.vectorstore is None:
            raise RuntimeError(
                "Vector store is not initialised; call create_vectorstore() "
                "or load_vectorstore() first."
            )
        retriever = self.vectorstore.as_retriever(search_kwargs={"k": 4})
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
        )
        print("QA 체인 설정 완료")

    def query(self, question: str):
        """Run ``question`` through the QA chain and return the chain's output.

        The caller in this file reads the ``"result"`` and
        ``"source_documents"`` keys of the returned mapping.

        Raises:
            RuntimeError: If the QA chain has not been set up yet.
        """
        if self.qa_chain is None:
            raise RuntimeError(
                "QA chain is not initialised; call build() or setup_qa_chain() first."
            )
        # Calling the chain object directly is deprecated; invoke() is the
        # supported entry point and takes the same input mapping.
        return self.qa_chain.invoke({"query": question})

    def build(self) -> None:
        """Load, split and index the documents, then set up the QA chain.

        NOTE(review): this re-embeds the whole data directory on every call.
        load_vectorstore() exists for reusing an index that is already on disk.
        """
        documents = self.load_documents()
        chunks = self.split_documents(documents)
        self.create_vectorstore(chunks)
        self.setup_qa_chain()
        print("RAG 시스템 구축 완료!")
if __name__ == "__main__":
    # Demo run: index everything under ./data, then ask one sample question.
    system = RAGSystem(data_dir="./data")
    system.build()

    response = system.query("RAG 시스템의 구성 요소는 무엇인가요?")
    print("\n답변:", response["result"])

    # List the source of each retrieved document, numbered from 1.
    print("\n출처:")
    retrieved = response["source_documents"]
    for position, document in enumerate(retrieved, 1):
        origin = document.metadata.get('source', 'Unknown')
        print(f"{position}. {origin}")
|