ident_ingest.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from langchain_community.llms import Ollama
  2. from langchain.prompts import ChatPromptTemplate
  3. from langchain_core.output_parsers import StrOutputParser
  4. from langchain_core.runnables import RunnablePassthrough
  5. from langchain_community.vectorstores import FAISS
  6. from langchain_community.embeddings import OllamaEmbeddings
  7. from langchain.text_splitter import CharacterTextSplitter
  8. from langchain.schema import Document
  9. # from langchain.chains.combine_documents import create_stuff_documents_chain
  10. # from langchain.chains import create_retrieval_chain
  11. from sqlalchemy import create_engine, MetaData, Table
  12. from sqlalchemy.orm import sessionmaker
  13. from dotenv import load_dotenv, find_dotenv
  14. import os
# Name of the local Ollama model used by both the extraction and ranking chains.
MODEL = "mistral"
  16. class Ident:
  17. llm = None
  18. prompt = None
  19. retriever = None
  20. output_parser = None
  21. vectorstore = None
  22. chain = None
  23. engine = None
  24. def __init__(self):
  25. load_dotenv(find_dotenv())
  26. self.llm = Ollama(model=MODEL, temperature=1)
  27. self.prompt = ChatPromptTemplate.from_template(
  28. """
  29. You are a world class expert in vehicle appraisal.
  30. When answer to user:
  31. - If you don't know, just say that you don't know.
  32. - If you don't know when you are not sure, ask for clarification.
  33. Avoid mentioning that you obtained the information from the context.
  34. And answer according to the language of the user's question.
  35. Given the context information, answer the query.
  36. Query: {input}
  37. """
  38. )
  39. self.output_parser = StrOutputParser()
  40. self.chain = self.prompt | self.llm | self.output_parser
  41. db_host = os.getenv("DB_PRECIOS_HOST")
  42. db_user = os.getenv("DB_PRECIOS_USER")
  43. db_password = os.getenv("DB_PRECIOS_PASSWORD")
  44. db_schema = os.getenv("DB_PRECIOS_SCHEMA")
  45. self.engine = create_engine(f"mysql://{db_user}:{db_password}@{db_host}/{db_schema}")
  46. def ask(self, query: str):
  47. response = self.chain.invoke({
  48. "input": """From the following text: '[query]'
  49. Get year, brand, model
  50. Answer only with the format and nothing more:
  51. 'year: {year}, brand: {brand}, model: {model}'
  52. """.replace("[query]", query)
  53. })
  54. print(f"Response (1): {response}")
  55. documents = self.get_db_documents(response)
  56. text_splitter = CharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
  57. docs = text_splitter.split_documents(documents)
  58. print(documents)
  59. self.vectorstore = FAISS.from_documents(docs, embedding=OllamaEmbeddings())
  60. print(f"Response (2): {self.vectorstore.similarity_search_with_score(query)}")
  61. self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 1})
  62. self.llm = Ollama(model=MODEL)
  63. self.prompt = ChatPromptTemplate.from_template(
  64. """
  65. Con la siguiente lista de versiones de vehículos:
  66. {context}
  67. Entrega la lista ordena según similitud a la siguiente descripción: '{search}' y describe el criterio utlizado
  68. """
  69. )
  70. # document_chain = create_stuff_documents_chain(self.llm, self.prompt)
  71. # retrieval_chain = create_retrieval_chain(self.retriever, document_chain)
  72. # response = retrieval_chain.invoke({"search": query, "input": query})
  73. chain = (
  74. {"context": self.retriever, "search": RunnablePassthrough()}
  75. | self.prompt
  76. | self.llm
  77. | StrOutputParser()
  78. )
  79. response = chain.invoke(query)
  80. print(f"Response (3): {response}")
  81. document = self.retriever.get_relevant_documents(query)
  82. print(f"Retriever: {document}")
  83. return response + "\n\n\n" + document[0].page_content + ", model_id:" + str(document[0].metadata['model_id'])
  84. # return response["answer"]
  85. def get_db_documents(self, fields: str):
  86. fields = fields.upper().split(",")
  87. key_values = {}
  88. for field in fields:
  89. key, value = field.split(":")
  90. key_values[key.strip()] = value.strip()
  91. year = key_values.get("YEAR", 0)
  92. brand = key_values.get("BRAND", "")
  93. model = key_values.get("MODEL", "")
  94. print(f"Searching: '{year} {brand} {model}'")
  95. Session = sessionmaker(bind=self.engine)
  96. session = Session()
  97. connection = self.engine.connect()
  98. metadata = MetaData()
  99. # Carga la tabla desde la base de datos
  100. bm = Table('20171229_bm', metadata, autoload_with=self.engine)
  101. result = (session.query(bm)
  102. .where(bm.c.ano_auto == year)
  103. .where(bm.c.marca.like(f"%{brand}%"))
  104. .where(bm.c.modelo_comp.like(f"%{model}%"))
  105. .where(bm.c.eliminado == 0)
  106. .all())
  107. connection.close()
  108. documents = []
  109. for row in result:
  110. fuel = "BENCINA" if row.combustible == "BENC" else "DIESEL"
  111. page_content = f"{row.ano_auto} {row.marca} {row.modelo_comp} {row.version_comp} {fuel} {row.traccion} {row.tipo_carroceria}"
  112. documents.append(Document(page_content=page_content, metadata=dict(model_id=row.modelo_id)))
  113. return documents
  114. def clear(self):
  115. self.vectorstore = None
  116. self.retriever = None
  117. self.chain = None