123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- from langchain_community.llms import Ollama
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.runnables import RunnablePassthrough
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import OllamaEmbeddings
- from langchain.text_splitter import CharacterTextSplitter
- from langchain.schema import Document
- # from langchain.chains.combine_documents import create_stuff_documents_chain
- # from langchain.chains import create_retrieval_chain
- from sqlalchemy import create_engine, MetaData, Table
- from sqlalchemy.orm import sessionmaker
- from dotenv import load_dotenv, find_dotenv
- import os
# Name of the Ollama model used by Ident for both the extraction and ranking LLM calls.
MODEL: str = "mistral"
class Ident:
    """Identify a vehicle version from a free-text description.

    ``ask`` runs three steps:

    1. An Ollama LLM extracts year, brand and model from the user's text.
    2. Matching, non-deleted rows are read from the ``20171229_bm`` table.
    3. The rows are embedded into a FAISS index; the closest one is retrieved
       and the LLM ranks it against the original description.

    Database settings are read from the ``DB_PRECIOS_*`` environment
    variables after loading a ``.env`` file if one is found.
    """

    llm = None  # __init__: extraction LLM; ask() replaces it with the ranking LLM
    prompt = None  # __init__: extraction prompt; ask() replaces it with the ranking prompt
    retriever = None  # top-1 FAISS retriever built by the last ask()
    output_parser = None  # string output parser used by the extraction chain
    vectorstore = None  # FAISS index built by the last ask()
    chain = None  # extraction chain: prompt | llm | output_parser
    engine = None  # SQLAlchemy engine built from the DB_PRECIOS_* settings

    # Prompt template for the extraction step; {input} receives the request below.
    _EXTRACTION_TEMPLATE = """
You are a world class expert in vehicle appraisal.
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.

Given the context information, answer the query.

Query: {input}
"""

    # Extraction request; "[query]" is replaced with the user's text. It is
    # passed as the *value* of {input}, so {year}/{brand}/{model} reach the
    # LLM verbatim as a format example.
    _EXTRACTION_REQUEST = """From the following text: '[query]'
Get year, brand, model
Answer only with the format and nothing more:
'year: {year}, brand: {brand}, model: {model}'
"""

    # Ranking prompt (Spanish): {context} gets the retrieved documents and
    # {search} the user's original query.
    _RANKING_TEMPLATE = """
Con la siguiente lista de versiones de vehículos:
{context}

Entrega la lista ordena según similitud a la siguiente descripción: '{search}' y describe el criterio utlizado
"""

    def __init__(self):
        """Load ``.env``, build the extraction chain and create the DB engine."""
        load_dotenv(find_dotenv())
        self.chain = self._build_extraction_chain()
        self.engine = create_engine(self._database_url())

    def _build_extraction_chain(self):
        """Create the year/brand/model extraction chain.

        Also sets ``self.llm``, ``self.prompt`` and ``self.output_parser`` to
        the extraction components.
        """
        self.llm = Ollama(model=MODEL, temperature=1)
        self.prompt = ChatPromptTemplate.from_template(self._EXTRACTION_TEMPLATE)
        self.output_parser = StrOutputParser()
        return self.prompt | self.llm | self.output_parser

    @staticmethod
    def _database_url():
        """Build the MySQL URL from the ``DB_PRECIOS_*`` environment variables.

        User name and password are percent-encoded so characters such as
        ``@``, ``:`` or ``/`` in them cannot corrupt the URL. Unset variables
        are rendered as ``None``.
        """
        from urllib.parse import quote_plus

        db_host = os.getenv("DB_PRECIOS_HOST")
        db_user = os.getenv("DB_PRECIOS_USER")
        db_password = os.getenv("DB_PRECIOS_PASSWORD")
        db_schema = os.getenv("DB_PRECIOS_SCHEMA")
        user = quote_plus(db_user) if db_user else db_user
        password = quote_plus(db_password) if db_password else db_password
        return f"mysql://{user}:{password}@{db_host}/{db_schema}"

    @staticmethod
    def _parse_vehicle_fields(fields):
        """Extract ``(year, brand, model)`` from the extraction LLM's answer.

        The answer is expected to look like
        ``year: 2015, brand: Toyota, model: Corolla``. It may be wrapped in
        quotes or surrounded by other lines. The text is upper-cased, and each
        value ends at the first comma or line break. Missing keys default to
        ``0`` for the year and ``""`` for brand and model.
        """
        import re

        pairs = re.findall(r"\b(YEAR|BRAND|MODEL)[ \t]*:[ \t]*([^,\n]*)", fields.upper())
        # Strip whitespace and any quote the LLM copied from the format example.
        found = {key: value.strip().strip("'\"").strip() for key, value in pairs}
        return found.get("YEAR", 0), found.get("BRAND", ""), found.get("MODEL", "")

    def ask(self, query: str) -> str:
        """Identify the vehicle version that best matches *query*.

        Returns the LLM's ranking answer, followed by the closest matching
        version description and its ``model_id``. If the database has no rows
        for the extracted year/brand/model, returns a short "not found"
        message instead.
        """
        if self.chain is None:
            # clear() drops the extraction chain. Rebuild it so ask() keeps
            # working instead of failing on None.invoke().
            self.chain = self._build_extraction_chain()

        # Step 1: extract "year: ..., brand: ..., model: ..." from the free text.
        response = self.chain.invoke(
            {"input": self._EXTRACTION_REQUEST.replace("[query]", query)}
        )
        print(f"Response (1): {response}")

        # Step 2: load candidate versions from the database.
        documents = self.get_db_documents(response)
        text_splitter = CharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)
        print(documents)
        if not docs:
            # Nothing matched, so there is nothing to index, retrieve or rank.
            return f"No vehicle versions found for: {response}"

        self.vectorstore = FAISS.from_documents(docs, embedding=OllamaEmbeddings())
        print(f"Response (2): {self.vectorstore.similarity_search_with_score(query)}")
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 1})

        # Step 3: let the LLM rank the retrieved version(s) against the query.
        self.llm = Ollama(model=MODEL)
        self.prompt = ChatPromptTemplate.from_template(self._RANKING_TEMPLATE)
        ranking_chain = (
            {"context": self.retriever, "search": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )
        response = ranking_chain.invoke(query)
        print(f"Response (3): {response}")

        document = self.retriever.get_relevant_documents(query)
        print(f"Retriever: {document}")
        best = document[0]
        return f"{response}\n\n\n{best.page_content}, model_id:{best.metadata['model_id']}"

    def get_db_documents(self, fields: str) -> "list[Document]":
        """Load non-deleted versions matching the extracted year/brand/model.

        *fields* is the extraction LLM's answer (parsed by
        ``_parse_vehicle_fields``). The year must match exactly, while brand
        and model are matched as substrings with ``LIKE``. Each row becomes a
        ``Document`` whose metadata holds the row's ``modelo_id`` as
        ``model_id``.
        """
        year, brand, model = self._parse_vehicle_fields(fields)
        print(f"Searching: '{year} {brand} {model}'")

        # Reflect the table definition from the database.
        bm = Table("20171229_bm", MetaData(), autoload_with=self.engine)
        Session = sessionmaker(bind=self.engine)
        # The context manager closes the session once the rows are fetched.
        with Session() as session:
            result = (
                session.query(bm)
                .where(bm.c.ano_auto == year)
                .where(bm.c.marca.like(f"%{brand}%"))
                .where(bm.c.modelo_comp.like(f"%{model}%"))
                .where(bm.c.eliminado == 0)
                .all()
            )

        documents = []
        for row in result:
            # Only the "BENC" code is labelled BENCINA; every other code becomes DIESEL.
            fuel = "BENCINA" if row.combustible == "BENC" else "DIESEL"
            page_content = f"{row.ano_auto} {row.marca} {row.modelo_comp} {row.version_comp} {fuel} {row.traccion} {row.tipo_carroceria}"
            documents.append(Document(page_content=page_content, metadata=dict(model_id=row.modelo_id)))
        return documents

    def clear(self):
        """Drop the per-query FAISS index and retriever, and the extraction chain.

        ``ask`` rebuilds the extraction chain on its next call.
        """
        self.vectorstore = None
        self.retriever = None
        self.chain = None
|