ident.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. from langchain_community.llms import Ollama
  2. from langchain.prompts import ChatPromptTemplate
  3. from langchain_core.output_parsers import StrOutputParser
  4. from langchain_core.runnables import RunnablePassthrough
  5. from langchain_community.vectorstores import Chroma
  6. from langchain_community.embeddings import OllamaEmbeddings
  7. from langchain.text_splitter import CharacterTextSplitter
  8. from langchain.schema import Document
  9. from langchain.chains.combine_documents import create_stuff_documents_chain
  10. from langchain.chains import create_retrieval_chain
  11. from sqlalchemy import create_engine, MetaData, Table
  12. from sqlalchemy.orm import sessionmaker
  13. from dotenv import load_dotenv, find_dotenv
  14. import os
  15. MODEL = "mistral"
  16. class Ident:
  17. llm = None
  18. prompt = None
  19. retriever = None
  20. output_parser = None
  21. vectorstore = None
  22. chain = None
  23. engine = None
  24. def __init__(self):
  25. load_dotenv(find_dotenv())
  26. self.llm = Ollama(model=MODEL, temperature=1)
  27. self.prompt = ChatPromptTemplate.from_template(
  28. """
  29. You are a world class expert vehicle appraiser
  30. When answer to user:
  31. - If you don't know, just say that you don't know.
  32. - If you don't know when you are not sure, ask for clarification.
  33. Avoid mentioning that you obtained the information from the context.
  34. And answer according to the language of the user's question.
  35. Given the context information, answer the query.
  36. Query: {input}
  37. """
  38. )
  39. self.output_parser = StrOutputParser()
  40. self.chain = self.prompt | self.llm | self.output_parser
  41. db_host = os.getenv("DB_PRECIOS_HOST")
  42. db_user = os.getenv("DB_PRECIOS_USER")
  43. db_password = os.getenv("DB_PRECIOS_PASSWORD")
  44. db_schema = os.getenv("DB_PRECIOS_SCHEMA")
  45. self.engine = create_engine(f"mysql://{db_user}:{db_password}@{db_host}/{db_schema}")
  46. def ask(self, query: str):
  47. identified_fields = self.chain.invoke({
  48. "input": """From the following text: '[query]'
  49. Try get:
  50. - year
  51. - make/brand
  52. - model
  53. - transmission (mt, manual, at, automatic, cvt)
  54. - variant
  55. - engine size (in CC or LT. Ex. 1.2 or 2.0)
  56. - fuel (bencina, diesel, hybrid)
  57. - power train (4x2, 4x4, AWD, FWD)
  58. - doors no.
  59. - price
  60. Answer only with the format and nothing more:
  61. 'year: {year}, brand: {brand}, model: {model}, transmission: {transmission}, variant: {variant}, submodel: {submodel}, engine size: {engine}, fuel: {fuel}, power train: {traction}, doors: {doors}, price: {price}'
  62. If one field is not to be found assing 'N/A' as default value
  63. """.replace("[query]", query)
  64. })
  65. print(f"Response (1): {identified_fields}")
  66. documents = self.get_db_documents(identified_fields)
  67. identified_fields = identified_fields.replace(", ", "\n")
  68. print(f"New query: {identified_fields}")
  69. text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
  70. docs = text_splitter.split_documents(documents)
  71. # print(documents)
  72. self.vectorstore = Chroma.from_documents(docs, embedding=OllamaEmbeddings(model="nomic-embed-text"))
  73. # print(f"Response (2): {self.vectorstore.similarity_search_with_score(identified_fields)}")
  74. self.retriever = self.vectorstore.as_retriever()
  75. self.prompt = ChatPromptTemplate.from_template(
  76. """
  77. You are a world class expert vehicle appraiser
  78. Consider the following vehicle versions:
  79. {context}
  80. Utiliza los siguientes criterios de comparación (ordenandos por relevancia):
  81. - Year
  82. - Brand
  83. - Model
  84. - Transmission
  85. - Variant
  86. - Engine size
  87. - Fuel
  88. - Power train
  89. - Body type
  90. - Doors
  91. - Price (range 10%)
  92. Si un criterio no puede ser evaluado continuar con el siguiente
  93. Compara la siguiente descripción: '{search}'
  94. Omite características como ubicación y kilometraje
  95. Responde con las características del mejor match
  96. """
  97. )
  98. # document_chain = create_stuff_documents_chain(self.llm, self.prompt)
  99. # retrieval_chain = create_retrieval_chain(self.retriever, document_chain)
  100. # response = retrieval_chain.invoke({"search": query, "input": query})
  101. """
  102. self.chain = (
  103. {"context": self.retriever | self.format_docs, "search": RunnablePassthrough()}
  104. | self.prompt
  105. | self.llm
  106. | StrOutputParser()
  107. )
  108. response = self.chain.invoke(identified_fields)
  109. """
  110. # print(f"Response (3): {response}")
  111. relevant_documents = self.retriever.get_relevant_documents(identified_fields)
  112. print(f"Retriever: {relevant_documents}")
  113. # return response
  114. return relevant_documents[0].page_content + "\nmodel_id: " + str(relevant_documents[0].metadata['modelo_id'])
  115. def get_db_documents(self, fields: str):
  116. fields = fields.upper().split(",")
  117. key_values = {}
  118. for field in fields:
  119. try:
  120. key, value = field.split(":")
  121. except Exception as e:
  122. print(f"Error: {e}")
  123. key = ''
  124. value = ''
  125. key_values[key.strip()] = value.strip()
  126. year = key_values.get("YEAR", 0)
  127. brand = key_values.get("BRAND", "")
  128. model = key_values.get("MODEL", "")
  129. print(f"Searching: '{year} {brand} {model}'")
  130. Session = sessionmaker(bind=self.engine)
  131. session = Session()
  132. connection = self.engine.connect()
  133. metadata = MetaData()
  134. # Carga la tabla desde la base de datos
  135. bm = Table('20171229_bm', metadata, autoload_with=self.engine)
  136. result = (session.query(bm)
  137. .where(bm.c.ano_auto == year)
  138. .where(bm.c.marca.like(f"%{brand}%"))
  139. .where(bm.c.modelo_comp.like(f"{model}%"))
  140. .where(bm.c.eliminado == 0)
  141. .all())
  142. connection.close()
  143. documents = []
  144. for row in result:
  145. if row.combustible == "BENC":
  146. fuel = "BENCINA"
  147. elif row.combustible == "DIES":
  148. fuel = "DIESEL"
  149. elif row.combustible == "HIB":
  150. fuel = "HIBRIDO"
  151. elif row.combustible == "ELEC":
  152. fuel = "ELECTRICO"
  153. page_content = f"""
  154. year: {row.ano_auto}
  155. brand: {row.marca}
  156. model: {row.modelo}
  157. transmission: {row.transmision}
  158. variant: {row.version_m}
  159. engine size: {row.motor} | {row.cilindrada}
  160. fuel: {fuel}
  161. power train: {row.traccion}
  162. body type: {row.tipo_carroceria}
  163. doors: {row.puertas}
  164. price: {row.tasacion}
  165. """
  166. documents.append(Document(page_content=page_content, metadata=dict(modelo_id=row.modelo_id)))
  167. return documents
  168. def format_docs(self, docs):
  169. return "\n".join(doc.page_content for doc in docs)
  170. def clear(self):
  171. self.vectorstore = None
  172. self.retriever = None
  173. self.chain = None