diff --git a/.github/workflows/sync-with-huggingface.yml b/.github/workflows/sync-with-huggingface.yml new file mode 100644 index 0000000..1fd91df --- /dev/null +++ b/.github/workflows/sync-with-huggingface.yml @@ -0,0 +1,50 @@ +name: Sync with Hugging Face + +on: + push: + branches: + - main + paths: + - .github/workflows/sync-with-huggingface.yml + - app/** + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Sync with Hugging Face + uses: nateraw/huggingface-sync-action@v0.0.5 + with: + # The github repo you are syncing from. Required. + github_repo_id: 'myscale/ChatData' + + # The Hugging Face repo id you want to sync to. (ex. 'username/reponame') + # A repo with this name will be created if it doesn't exist. Required. + huggingface_repo_id: 'myscale/ChatData' + + # Hugging Face token with write access. Required. + # Here, we provide a token that we called `HF_TOKEN` when we added the secret to our GitHub repo. + hf_token: ${{ secrets.HF_TOKEN }} + + # The type of repo you are syncing to: model, dataset, or space. + # Defaults to space. + repo_type: 'space' + + # If true and the Hugging Face repo doesn't already exist, it will be created + # as a private repo. + # + # Note: this param has no effect if the repo already exists. + private: false + + # If repo type is space, specify a space_sdk. One of: streamlit, gradio, or static + # + # This option is especially important if the repo has not been created yet. + # It won't really be used if the repo already exists. + space_sdk: 'streamlit' + + # If provided, subdirectory will determine which directory of the repo will be synced. + # By default, this action syncs the entire GitHub repo. + # + # An example using this option can be seen here: + # https://github.com/huggingface/fuego/blob/830ed98/.github/workflows/sync-with-huggingface.yml + subdirectory: app diff --git a/.streamlit/config.toml b/.streamlit/config.toml deleted file mode 100644 index 35fc398..0000000 --- a/.streamlit/config.toml +++ /dev/null @@ -1,6 +0,0 @@ -[theme] -primaryColor="#523EFD" -backgroundColor="#FFFFFF" -secondaryBackgroundColor="#D4CEFF" -textColor="#262730" -font="sans serif" \ No newline at end of file diff --git a/.streamlit/secrets.example.toml b/.streamlit/secrets.example.toml deleted file mode 100644 index 36f4b8a..0000000 --- a/.streamlit/secrets.example.toml +++ /dev/null @@ -1,6 +0,0 @@ -MYSCALE_HOST = "msc-1decbcc9.us-east-1.aws.staging.myscale.cloud" -MYSCALE_PORT = 443 -MYSCALE_USER = "chatdata" -MYSCALE_PASSWORD = "myscale_rocks" -OPENAI_API_BASE = "https://api.openai.com/v1" -OPENAI_API_KEY = "" diff --git a/app.py b/app/app.py similarity index 100% rename from app.py rename to app/app.py diff --git a/callbacks/arxiv_callbacks.py b/app/callbacks/arxiv_callbacks.py similarity index 100% rename from callbacks/arxiv_callbacks.py rename to app/callbacks/arxiv_callbacks.py diff --git a/chains/arxiv_chains.py b/app/chains/arxiv_chains.py similarity index 100% rename from chains/arxiv_chains.py rename to app/chains/arxiv_chains.py diff --git a/chat.py b/app/chat.py similarity index 100% rename from chat.py rename to app/chat.py diff --git a/login.py b/app/login.py similarity index 100% rename from login.py rename to app/login.py diff --git a/prompts/arxiv_prompt.py b/app/prompts/arxiv_prompt.py similarity index 100% rename from prompts/arxiv_prompt.py rename to app/prompts/arxiv_prompt.py diff --git a/requirements.txt b/app/requirements.txt similarity index 100% rename from requirements.txt rename to app/requirements.txt diff --git a/lib/helper.py b/lib/helper.py deleted file mode 100644 index 56ed1be..0000000 --- a/lib/helper.py +++ /dev/null @@ -1,559 +0,0 @@ - -import json -import time -import hashlib -from typing import Dict, Any, List, Tuple -import re -import pandas as pd -from os import environ -import streamlit as st -import datetime -from langchain.schema import BaseRetriever -from langchain.tools import Tool -from langchain.pydantic_v1 import BaseModel, Field - -from sqlalchemy import Column, Text, create_engine, MetaData -from langchain.agents import AgentExecutor -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker -from clickhouse_sqlalchemy import ( - Table, make_session, get_declarative_base, types, engines -) -from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain -from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever -from langchain.utilities.sql_database import SQLDatabase -from langchain.chains import LLMChain -from sqlalchemy import create_engine, MetaData -from langchain.prompts import PromptTemplate, ChatPromptTemplate, \ - SystemMessagePromptTemplate, HumanMessagePromptTemplate -from langchain.prompts.prompt import PromptTemplate -from langchain.chat_models import ChatOpenAI -from langchain.schema import BaseRetriever, Document -from langchain import OpenAI -from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName -from langchain.retrievers.self_query.base import SelfQueryRetriever -from langchain.retrievers.self_query.myscale import MyScaleTranslator -from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings -from langchain.vectorstores import MyScaleSettings -from chains.arxiv_chains import MyScaleWithoutMetadataJson -from langchain.prompts.prompt import PromptTemplate -from langchain.prompts.chat import MessagesPlaceholder -from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory -from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent -from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage,\ - SystemMessage, ChatMessage, ToolMessage -from langchain.memory import SQLChatMessageHistory -from langchain.memory.chat_message_histories.sql import \ - BaseMessageConverter, DefaultMessageConverter -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict -# from langchain.agents.agent_toolkits import create_retriever_tool -from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt -from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain -from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser -from .json_conv import CustomJSONEncoder - -environ['TOKENIZERS_PARALLELISM'] = 'true' -environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE'] - -# query_model_name = "gpt-3.5-turbo-instruct" -query_model_name = "gpt-3.5-turbo-instruct" -chat_model_name = "gpt-3.5-turbo-16k" - - -OPENAI_API_KEY = st.secrets['OPENAI_API_KEY'] -OPENAI_API_BASE = st.secrets['OPENAI_API_BASE'] -MYSCALE_USER = st.secrets['MYSCALE_USER'] -MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD'] -MYSCALE_HOST = st.secrets['MYSCALE_HOST'] -MYSCALE_PORT = st.secrets['MYSCALE_PORT'] -UNSTRUCTURED_API = st.secrets['UNSTRUCTURED_API'] - -COMBINE_PROMPT = ChatPromptTemplate.from_strings( - string_messages=[(SystemMessagePromptTemplate, combine_prompt_template), - (HumanMessagePromptTemplate, '{question}')]) -DEFAULT_SYSTEM_PROMPT = ( - "Do your best to answer the questions. " - "Feel free to use any tools available to look up " - "relevant information. Please keep all details in query " - "when calling search functions." -) - -def hint_arxiv(): - st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n" - "For example: \n\n" - "*If you want to search papers with complex filters*:\n\n" - "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n" - "*If you want to ask questions based on papers in database*:\n\n" - "- What is PageRank?\n" - "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n" - "- Introduce some applications of GANs published around 2019.\n" - "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n" - "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n" - "- Is it possible to synthesize room temperature super conductive material?") - - -def hint_sql_arxiv(): - st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡') - st.markdown('''```sql -CREATE TABLE default.ChatArXiv ( - `abstract` String, - `id` String, - `vector` Array(Float32), - `metadata` Object('JSON'), - `pubdate` DateTime, - `title` String, - `categories` Array(String), - `authors` Array(String), - `comment` String, - `primary_category` String, - VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'), - CONSTRAINT vec_len CHECK length(vector) = 768) -ENGINE = ReplacingMergeTree ORDER BY id -```''') - - -def hint_wiki(): - st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n" - "For example: \n\n" - "- Which company did Elon Musk found?\n" - "- What is Iron Gwazi?\n" - "- What is a Ring in mathematics?\n" - "- 苹果的发源地是那里?\n") - - -def hint_sql_wiki(): - st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡') - st.markdown('''```sql -CREATE TABLE wiki.Wikipedia ( - `id` String, - `title` String, - `text` String, - `url` String, - `wiki_id` UInt64, - `views` Float32, - `paragraph_id` UInt64, - `langs` UInt32, - `emb` Array(Float32), - VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'), - CONSTRAINT emb_len CHECK length(emb) = 768) -ENGINE = ReplacingMergeTree ORDER BY id -```''') - - -sel_map = { - 'Wikipedia': { - "database": "wiki", - "table": "Wikipedia", - "hint": hint_wiki, - "hint_sql": hint_sql_wiki, - "doc_prompt": PromptTemplate( - input_variables=["page_content", "url", "title", "ref_id", "views"], - template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"), - "metadata_cols": [ - AttributeInfo( - name="title", - description="title of the wikipedia page", - type="string", - ), - AttributeInfo( - name="text", - description="paragraph from this wiki page", - type="string", - ), - AttributeInfo( - name="views", - description="number of views", - type="float" - ), - ], - "must_have_cols": ['id', 'title', 'url', 'text', 'views'], - "vector_col": "emb", - "text_col": "text", - "metadata_col": "metadata", - "emb_model": lambda: SentenceTransformerEmbeddings( - model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',), - "tool_desc": ("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages"), - }, - 'ArXiv Papers': { - "database": "default", - "table": "ChatArXiv", - "hint": hint_arxiv, - "hint_sql": hint_sql_arxiv, - "doc_prompt": PromptTemplate( - input_variables=["page_content", "id", "title", "ref_id", - "authors", "pubdate", "categories"], - template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"), - "metadata_cols": [ - AttributeInfo( - name=VirtualColumnName(name="pubdate"), - description="The year the paper is published", - type="timestamp", - ), - AttributeInfo( - name="authors", - description="List of author names", - type="list[string]", - ), - AttributeInfo( - name="title", - description="Title of the paper", - type="string", - ), - AttributeInfo( - name="categories", - description="arxiv categories to this paper", - type="list[string]" - ), - AttributeInfo( - name="length(categories)", - description="length of arxiv categories to this paper", - type="int" - ), - ], - "must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'], - "vector_col": "vector", - "text_col": "abstract", - "metadata_col": "metadata", - "emb_model": lambda: HuggingFaceInstructEmbeddings( - model_name='hkunlp/instructor-xl', - embed_instruction="Represent the question for retrieving supporting scientific papers: "), - "tool_desc": ("search_among_scientific_papers", "Searches among scientific papers from ArXiv and returns research papers"), - } -} - -def build_embedding_model(_sel): - """Build embedding model - """ - with st.spinner("Loading Model..."): - embeddings = sel_map[_sel]["emb_model"]() - return embeddings - - -def build_chains_retrievers(_sel: str) -> Dict[str, Any]: - """build chains and retrievers - - :param _sel: selected knowledge base - :type _sel: str - :return: _description_ - :rtype: Dict[str, Any] - """ - metadata_field_info = sel_map[_sel]["metadata_cols"] - retriever = build_self_query(_sel) - chain = build_qa_chain(_sel, retriever, name="Self Query Retriever") - sql_retriever = build_vector_sql(_sel) - sql_chain = build_qa_chain(_sel, sql_retriever, name="Vector SQL") - - return { - "metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], - "retriever": retriever, - "chain": chain, - "sql_retriever": sql_retriever, - "sql_chain": sql_chain - } - -def build_self_query(_sel: str) -> SelfQueryRetriever: - """Build self querying retriever - - :param _sel: selected knowledge base - :type _sel: str - :return: retriever used by chains - :rtype: SelfQueryRetriever - """ - with st.spinner(f"Connecting DB for {_sel}..."): - myscale_connection = { - "host": MYSCALE_HOST, - "port": MYSCALE_PORT, - "username": MYSCALE_USER, - "password": MYSCALE_PASSWORD, - } - config = MyScaleSettings(**myscale_connection, - database=sel_map[_sel]["database"], - table=sel_map[_sel]["table"], - column_map={ - "id": "id", - "text": sel_map[_sel]["text_col"], - "vector": sel_map[_sel]["vector_col"], - "metadata": sel_map[_sel]["metadata_col"] - }) - doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config, - must_have_cols=sel_map[_sel]['must_have_cols']) - - with st.spinner(f"Building Self Query Retriever for {_sel}..."): - metadata_field_info = sel_map[_sel]["metadata_cols"] - retriever = SelfQueryRetriever.from_llm( - OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0), - doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info, - use_original_query=False, structured_query_translator=MyScaleTranslator()) - return retriever - -def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever: - """Build Vector SQL Database Retriever - - :param _sel: selected knowledge base - :type _sel: str - :return: retriever used by chains - :rtype: VectorSQLDatabaseChainRetriever - """ - with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'): - engine = create_engine( - f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https') - metadata = MetaData(bind=engine) - PROMPT = PromptTemplate( - input_variables=["input", "table_info", "top_k"], - template=_myscale_prompt, - ) - output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings( - model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"]) - sql_query_chain = VectorSQLDatabaseChain.from_llm( - llm=OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0), - prompt=PROMPT, - top_k=10, - return_direct=True, - db=SQLDatabase(engine, None, metadata, max_string_length=1024), - sql_cmd_parser=output_parser, - native_format=True - ) - sql_retriever = VectorSQLDatabaseChainRetriever( - sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"]) - return sql_retriever - -def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str="Self-query") -> ArXivQAwithSourcesChain: - """_summary_ - - :param _sel: selected knowledge base - :type _sel: str - :param retriever: retriever used by chains - :type retriever: BaseRetriever - :param name: display name, defaults to "Self-query" - :type name: str, optional - :return: QA chain interacts with user - :rtype: ArXivQAwithSourcesChain - """ - with st.spinner(f'Building QA Chain with {name} for {_sel}...'): - chain = ArXivQAwithSourcesChain( - retriever=retriever, - combine_documents_chain=ArXivStuffDocumentChain( - llm_chain=LLMChain( - prompt=COMBINE_PROMPT, - llm=ChatOpenAI(model_name=chat_model_name, - openai_api_key=OPENAI_API_KEY, temperature=0.6), - ), - document_prompt=sel_map[_sel]["doc_prompt"], - document_variable_name="summaries", - - ), - return_source_documents=True, - max_tokens_limit=12000, - ) - return chain - -@st.cache_resource -def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]: - """build all resources - - :return: sel_map_obj - :rtype: Dict[str, Any] - """ - sel_map_obj = {} - embeddings = {} - for k in sel_map: - embeddings[k] = build_embedding_model(k) - st.session_state[f'emb_model_{k}'] = embeddings[k] - sel_map_obj[k] = build_chains_retrievers(k) - return sel_map_obj, embeddings - -def create_message_model(table_name, DynamicBase): # type: ignore - """ - Create a message model for a given table name. - - Args: - table_name: The name of the table to use. - DynamicBase: The base class to use for the model. - - Returns: - The model class. - - """ - - # Model decleared inside a function to have a dynamic table name - class Message(DynamicBase): - __tablename__ = table_name - id = Column(types.Float64) - session_id = Column(Text) - user_id = Column(Text) - msg_id = Column(Text, primary_key=True) - type = Column(Text) - addtionals = Column(Text) - message = Column(Text) - __table_args__ = ( - engines.ReplacingMergeTree( - partition_by='session_id', - order_by=('id', 'msg_id')), - {'comment': 'Store Chat History'} - ) - - return Message - -def _message_from_dict(message: dict) -> BaseMessage: - _type = message["type"] - if _type == "human": - return HumanMessage(**message["data"]) - elif _type == "ai": - return AIMessage(**message["data"]) - elif _type == "system": - return SystemMessage(**message["data"]) - elif _type == "chat": - return ChatMessage(**message["data"]) - elif _type == "function": - return FunctionMessage(**message["data"]) - elif _type == "tool": - return ToolMessage(**message["data"]) - elif _type == "AIMessageChunk": - message["data"]["type"] = "ai" - return AIMessage(**message["data"]) - else: - raise ValueError(f"Got unexpected message type: {_type}") - -class DefaultClickhouseMessageConverter(DefaultMessageConverter): - """The default message converter for SQLChatMessageHistory.""" - - def __init__(self, table_name: str): - self.model_class = create_message_model(table_name, declarative_base()) - - def to_sql_model(self, message: BaseMessage, session_id: str) -> Any: - tstamp = time.time() - msg_id = hashlib.sha256(f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest() - user_id, _ = session_id.split("?") - return self.model_class( - id=tstamp, - msg_id=msg_id, - user_id=user_id, - session_id=session_id, - type=message.type, - addtionals=json.dumps(message.additional_kwargs), - message=json.dumps({ - "type": message.type, - "additional_kwargs": {"timestamp": tstamp}, - "data": message.dict()}) - ) - - def from_sql_model(self, sql_message: Any) -> BaseMessage: - msg_dump = json.loads(sql_message.message) - msg = _message_from_dict(msg_dump) - msg.additional_kwargs = msg_dump["additional_kwargs"] - return msg - - def get_sql_model_class(self) -> Any: - return self.model_class - - -def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs): - name = name.replace(" ", "_") - conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}' - chat_memory = SQLChatMessageHistory( - session_id, - connection_string=f'{conn_str}/chat?protocol=https', - custom_message_converter=DefaultClickhouseMessageConverter(name)) - memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory) - - _system_message = SystemMessage( - content=system_prompt - ) - prompt = OpenAIFunctionsAgent.create_prompt( - system_message=_system_message, - extra_prompt_messages=[MessagesPlaceholder(variable_name="history")], - ) - agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) - return AgentExecutor( - agent=agent, - tools=tools, - memory=memory, - verbose=True, - return_intermediate_steps=True, - **kwargs - ) - -class RetrieverInput(BaseModel): - query: str = Field(description="query to look up in retriever") - -def create_retriever_tool( - retriever: BaseRetriever, name: str, description: str -) -> Tool: - """Create a tool to do retrieval of documents. - - Args: - retriever: The retriever to use for the retrieval - name: The name for the tool. This will be passed to the language model, - so should be unique and somewhat descriptive. - description: The description for the tool. This will be passed to the language - model, so should be descriptive. - - Returns: - Tool class to pass to an agent - """ - def wrap(func): - def wrapped_retrieve(*args, **kwargs): - docs: List[Document] = func(*args, **kwargs) - return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder) - return wrapped_retrieve - - return Tool( - name=name, - description=description, - func=wrap(retriever.get_relevant_documents), - coroutine=retriever.aget_relevant_documents, - args_schema=RetrieverInput, - ) - -@st.cache_resource -def build_tools(): - """build all resources - - :return: sel_map_obj - :rtype: Dict[str, Any] - """ - sel_map_obj = {} - for k in sel_map: - if f'emb_model_{k}' not in st.session_state: - st.session_state[f'emb_model_{k}'] = build_embedding_model(k) - if "sel_map_obj" not in st.session_state: - st.session_state["sel_map_obj"] = {} - if k not in st.session_state.sel_map_obj: - st.session_state["sel_map_obj"][k] = {} - if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]: - st.session_state.sel_map_obj[k].update(build_chains_retrievers(k)) - sel_map_obj.update({ - f"{k} + Self Querying": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],), - f"{k} + Vector SQL": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],), - }) - return sel_map_obj - -def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT): - chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature, - openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True, - ) - tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users - sel_tools = [tools[k] for k in tool_names] - agent = create_agent_executor( - "chat_memory", - session_id, - chat_llm, - tools=sel_tools, - system_prompt=system_prompt - ) - return agent - - -def display(dataframe, columns_=None, index=None): - if len(dataframe) > 0: - if index: - dataframe.set_index(index) - if columns_: - st.dataframe(dataframe[columns_]) - else: - st.dataframe(dataframe) - else: - st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True) \ No newline at end of file diff --git a/lib/json_conv.py b/lib/json_conv.py deleted file mode 100644 index 3d92479..0000000 --- a/lib/json_conv.py +++ /dev/null @@ -1,21 +0,0 @@ -import json -import datetime - -class CustomJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, datetime.datetime): - return datetime.datetime.isoformat(obj) - return json.JSONEncoder.default(self, obj) - -class CustomJSONDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) - - def object_hook(self, source): - for k, v in source.items(): - if isinstance(v, str): - try: - source[k] = datetime.datetime.fromisoformat(str(v)) - except: - pass - return source \ No newline at end of file diff --git a/lib/private_kb.py b/lib/private_kb.py deleted file mode 100644 index f2d0ed0..0000000 --- a/lib/private_kb.py +++ /dev/null @@ -1,212 +0,0 @@ -import pandas as pd -import hashlib -import requests -from typing import List, Optional -from datetime import datetime -from langchain.schema.embeddings import Embeddings -from streamlit.runtime.uploaded_file_manager import UploadedFile -from clickhouse_connect import get_client -from multiprocessing.pool import ThreadPool -from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings -from .helper import create_retriever_tool - -parser_url = "https://api.unstructured.io/general/v0/general" - - -def parse_files(api_key, user_id, files: List[UploadedFile]): - def parse_file(file: UploadedFile): - headers = { - "accept": "application/json", - "unstructured-api-key": api_key, - } - data = {"strategy": "auto", "ocr_languages": ["eng"]} - file_hash = hashlib.sha256(file.read()).hexdigest() - file_data = {"files": (file.name, file.getvalue(), file.type)} - response = requests.post( - parser_url, headers=headers, data=data, files=file_data - ) - json_response = response.json() - if response.status_code != 200: - raise ValueError(str(json_response)) - texts = [ - { - "text": t["text"], - "file_name": t["metadata"]["filename"], - "entity_id": hashlib.sha256( - (file_hash + t["text"]).encode() - ).hexdigest(), - "user_id": user_id, - "created_by": datetime.now(), - } - for t in json_response - if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10 - ] - return texts - - with ThreadPool(8) as p: - rows = [] - for r in p.imap_unordered(parse_file, files): - rows.extend(r) - return rows - - -def extract_embedding(embeddings: Embeddings, texts): - if len(texts) > 0: - embs = embeddings.embed_documents([t["text"] for _, t in enumerate(texts)]) - for i, _ in enumerate(texts): - texts[i]["vector"] = embs[i] - return texts - raise ValueError("No texts extracted!") - - -class PrivateKnowledgeBase: - def __init__( - self, - host, - port, - username, - password, - embedding: Embeddings, - parser_api_key, - db="chat", - kb_table="private_kb", - tool_table="private_tool", - ) -> None: - super().__init__() - kb_schema_ = f""" - CREATE TABLE IF NOT EXISTS {db}.{kb_table}( - entity_id String, - file_name String, - text String, - user_id String, - created_by DateTime, - vector Array(Float32), - CONSTRAINT cons_vec_len CHECK length(vector) = 768, - VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine') - ) ENGINE = ReplacingMergeTree ORDER BY entity_id - """ - tool_schema_ = f""" - CREATE TABLE IF NOT EXISTS {db}.{tool_table}( - tool_id String, - tool_name String, - file_names Array(String), - user_id String, - created_by DateTime, - tool_description String - ) ENGINE = ReplacingMergeTree ORDER BY tool_id - """ - self.kb_table = kb_table - self.tool_table = tool_table - config = MyScaleSettings( - host=host, - port=port, - username=username, - password=password, - database=db, - table=kb_table, - ) - client = get_client( - host=config.host, - port=config.port, - username=config.username, - password=config.password, - ) - client.command("SET allow_experimental_object_type=1") - client.command(kb_schema_) - client.command(tool_schema_) - self.parser_api_key = parser_api_key - self.vstore = MyScaleWithoutJSON( - embedding=embedding, - config=config, - must_have_cols=["file_name", "text", "created_by"], - ) - - def list_files(self, user_id, tool_name=None): - query = f""" - SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph, - arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars - FROM {self.vstore.config.database}.{self.kb_table} - WHERE user_id = '{user_id}' GROUP BY file_name - """ - return [r for r in self.vstore.client.query(query).named_results()] - - def add_by_file( - self, user_id, files: List[UploadedFile], **kwargs - ): - data = parse_files(self.parser_api_key, user_id, files) - data = extract_embedding(self.vstore.embeddings, data) - self.vstore.client.insert_df( - self.kb_table, - pd.DataFrame(data), - database=self.vstore.config.database, - ) - - def clear(self, user_id): - self.vstore.client.command( - f"DELETE FROM {self.vstore.config.database}.{self.kb_table} " - f"WHERE user_id='{user_id}'" - ) - query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table} - WHERE user_id = '{user_id}'""" - self.vstore.client.command(query) - - def create_tool( - self, user_id, tool_name, tool_description, files: Optional[List[str]] = None - ): - self.vstore.client.insert_df( - self.tool_table, - pd.DataFrame( - [ - { - "tool_id": hashlib.sha256( - (user_id + tool_name).encode("utf-8") - ).hexdigest(), - "tool_name": tool_name, - "file_names": files, - "user_id": user_id, - "created_by": datetime.now(), - "tool_description": tool_description, - } - ] - ), - database=self.vstore.config.database, - ) - - def list_tools(self, user_id, tool_name=None): - extended_where = f"AND tool_name = '{tool_name}'" if tool_name else "" - query = f""" - SELECT tool_name, tool_description, length(file_names) - FROM {self.vstore.config.database}.{self.tool_table} - WHERE user_id = '{user_id}' {extended_where} - """ - return [r for r in self.vstore.client.query(query).named_results()] - - def remove_tools(self, user_id, tool_names): - tool_names = ",".join([f"'{t}'" for t in tool_names]) - query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table} - WHERE user_id = '{user_id}' AND tool_name IN [{tool_names}]""" - self.vstore.client.command(query) - - def as_tools(self, user_id, tool_name=None): - tools = self.list_tools(user_id=user_id, tool_name=tool_name) - retrievers = { - t["tool_name"]: create_retriever_tool( - self.vstore.as_retriever( - search_kwargs={ - "where_str": ( - f"user_id='{user_id}' " - f"""AND file_name IN ( - SELECT arrayJoin(file_names) FROM ( - SELECT file_names - FROM {self.vstore.config.database}.{self.tool_table} - WHERE user_id = '{user_id}' AND tool_name = '{t['tool_name']}') - )""" - ) - }, - ), - name=t["tool_name"], - description=t["tool_description"], - ) - for t in tools - } - return retrievers diff --git a/lib/schemas.py b/lib/schemas.py deleted file mode 100644 index fefcd4c..0000000 --- a/lib/schemas.py +++ /dev/null @@ -1,52 +0,0 @@ -from sqlalchemy import Column, Text -from clickhouse_sqlalchemy import types, engines - - -def create_message_model(table_name, DynamicBase): # type: ignore - """ - Create a message model for a given table name. - - Args: - table_name: The name of the table to use. - DynamicBase: The base class to use for the model. - - Returns: - The model class. - - """ - - # Model decleared inside a function to have a dynamic table name - class Message(DynamicBase): - __tablename__ = table_name - id = Column(types.Float64) - session_id = Column(Text) - user_id = Column(Text) - msg_id = Column(Text, primary_key=True) - type = Column(Text) - addtionals = Column(Text) - message = Column(Text) - __table_args__ = ( - engines.ReplacingMergeTree( - partition_by='session_id', - order_by=('id', 'msg_id')), - {'comment': 'Store Chat History'} - ) - - return Message - - -def create_session_table(table_name, DynamicBase): # type: ignore - # Model decleared inside a function to have a dynamic table name - class Session(DynamicBase): - __tablename__ = table_name - user_id = Column(Text) - session_id = Column(Text, primary_key=True) - system_prompt = Column(Text) - create_by = Column(types.DateTime) - additionals = Column(Text) - __table_args__ = ( - engines.ReplacingMergeTree( - order_by=('session_id')), - {'comment': 'Store Session and Prompts'} - ) - return Session \ No newline at end of file diff --git a/lib/sessions.py b/lib/sessions.py deleted file mode 100644 index c888ee1..0000000 --- a/lib/sessions.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base -from langchain.schema import BaseChatMessageHistory -from datetime import datetime -from sqlalchemy import Column, Text, orm, create_engine -from clickhouse_sqlalchemy import types, engines -from .schemas import create_message_model, create_session_table - -def get_sessions(engine, model_class, user_id): - with orm.sessionmaker(engine)() as session: - result = ( - session.query(model_class) - .where( - model_class.session_id == user_id - ) - .order_by(model_class.create_by.desc()) - ) - return json.loads(result) - -class SessionManager: - def __init__(self, session_state, host, port, username, password, - db='chat', sess_table='sessions', msg_table='chat_memory') -> None: - conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https' - self.engine = create_engine(conn_str, echo=False) - self.sess_model_class = create_session_table(sess_table, declarative_base()) - self.sess_model_class.metadata.create_all(self.engine) - self.msg_model_class = create_message_model(msg_table, declarative_base()) - self.msg_model_class.metadata.create_all(self.engine) - self.Session = orm.sessionmaker(self.engine) - self.session_state = session_state - - def list_sessions(self, user_id): - with self.Session() as session: - result = ( - session.query(self.sess_model_class) - .where( - self.sess_model_class.user_id == user_id - ) - .order_by(self.sess_model_class.create_by.desc()) - ) - sessions = [] - for r in result: - sessions.append({ - "session_id": r.session_id.split("?")[-1], - "system_prompt": r.system_prompt, - }) - return sessions - - def modify_system_prompt(self, session_id, sys_prompt): - with self.Session() as session: - session.update(self.sess_model_class).where(self.sess_model_class==session_id).value(system_prompt=sys_prompt) - session.commit() - - def add_session(self, user_id, session_id, system_prompt, **kwargs): - with self.Session() as session: - elem = self.sess_model_class( - user_id=user_id, session_id=session_id, system_prompt=system_prompt, - create_by=datetime.now(), additionals=json.dumps(kwargs) - ) - session.add(elem) - session.commit() - - def remove_session(self, session_id): - with self.Session() as session: - session.query(self.sess_model_class).where(self.sess_model_class.session_id==session_id).delete() - # session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete() - if "agent" in self.session_state: - self.session_state.agent.memory.chat_memory.clear() - if "file_analyzer" in self.session_state: - self.session_state.file_analyzer.clear_files() - - \ No newline at end of file