Open WebUI Enterprise Search Pipe

(compatible with Retrieval Suite Version 1.14.5 onwards)

"""
title: Enterprise Search
author: RheinInsights
author_url: https://www.rheininsights.com
version: 1.0.1
required_open_webui_version: 0.3.30

"""

import asyncio
from typing import List, Union, AsyncGenerator, Iterator, Callable, Awaitable, Optional
from pydantic import BaseModel, Field
import requests
import time
import urllib3
from bs4 import BeautifulSoup

# Silence urllib3's InsecureRequestWarning for the whole process. The HTTP
# calls made by the Pipe class below pass verify=False (TLS certificates are
# not verified), which is what would otherwise trigger this warning.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class Pipe:
    class Valves(BaseModel):
        # Configuration values for this pipe; every field carries a default so
        # the model can be instantiated without arguments.

        # Base URL of the search backend; the API paths are appended to it.
        RheininsightsUrl: str = Field("https://localhost")
        # Sent as the value of the Authorization request header.
        AuthenticationToken: str = Field("secret")
        # Sent as the queryPipelineId query parameter when a search is submitted.
        QueryPipelineId: int = Field(0)

    def __init__(self):
        self.type = "manifold"
        self.id = "engine_search"
        self.name = "engines/"
        self.valves = self.Valves()

    def pipes(self) -> List[dict]:
        enabled_pipes = []
        enabled_pipes.append({"id": "enterprisesearch", "name": "Enterprise Search"})
        return enabled_pipes

    async def pipe(
        self,
        __metadata__: dict,
        body: dict,
        __user__: dict,
        __event_emitter__: Callable[[dict], Awaitable[dict]] = None,
        __event_call__: Callable[[dict], Awaitable[dict]] = None,
    ) -> AsyncGenerator[str, None]:

        output = ""

        user_input = self._extract_user_input(body)
        if not user_input:
            yield "No search query provided"

        occ = user_input.find("### Task:")

        doNotPostUpdates = False
        if occ >= 0:
            yield ""
        else:
            output = await self._search_knowledge(
                user_input,
                body.get("messages", []),
                __user__["email"],
                __event_emitter__,
                doNotPostUpdates,
            )

            await self.postFinalMessage(__event_emitter__, doNotPostUpdates)

            yield output

    def _extract_user_input(self, body: dict) -> str:
        messages = body.get("messages", [])
        if messages:
            last_message = messages[-1]
            if isinstance(last_message.get("content"), list):
                for item in last_message["content"]:
                    if item["type"] == "text":
                        return item["text"]
            else:
                return last_message.get("content", "")
        return ""

    async def _search_knowledge(
        self,
        query: str,
        context: object,
        mail: str,
        __event_emitter__: Callable[[dict], Awaitable[dict]],
        doNotPostUpdates: bool,
    ) -> Optional[str]:

        data = self.submitQuery(query, context, mail)

        if not data:
            return f"No results found for: {query}"

        if not data["threadId"]:
            return f"No results found for: {query}"

        result = await self.waitForFinalization(
            data["threadId"], mail, __event_emitter__, doNotPostUpdates
        )

        if not result:
            return f"No results found for: {query}"

        if result["isFailed"]:
            return f"Something went wrong, please try again or contact your administrator {query}"

        resultResponse = result["result"]

        return self.handleResults(resultResponse)

    async def waitForFinalization(
        self,
        threadId: str,
        mail: str,
        __event_emitter__: Callable[[dict], Awaitable[dict]],
        doNotPostUpdates: bool,
    ) -> any:

        deadline = time.time() + 600
        interval = 1
        lastUpdateAt = 0
        await self.emitEvent("Processing", __event_emitter__, doNotPostUpdates)

        while time.time() < deadline:

            await asyncio.sleep(interval)
            data = self.getStatus(threadId, mail)

            if not data:
                return

            if data["status"] == "UNKNOWN":
                print(f"Is unknown")
                return

            if data["status"] == "FINISHED":
                print(f"Is finished")
                return data

            if not data["message"]:
                continue

            msg = data["message"]

            if msg["lastUpdate"] == lastUpdateAt:
                continue

            await self.emitEvent(msg["message"], __event_emitter__, doNotPostUpdates)
            lastUpdateAt = msg["lastUpdate"]

    def submitQuery(self, query: str, context: object, mail: str) -> Optional[dict]:
        """Submit an asynchronous search request to the backend.

        Sends a blocking HTTP POST to
        ``<RheininsightsUrl>/api/v1/querypipelines/search/async`` with the
        configured pipeline id as ``queryPipelineId`` query parameter.

        Args:
            query: The search query text.
            context: Conversation context, sent as ``"context"``.
            mail: Sent as ``"userPrincipalName"``.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
            requests.RequestException: On connection errors or timeouts.
        """
        # Tolerate a configured base URL with a trailing slash.
        base_url = self.valves.RheininsightsUrl.rstrip("/")

        queryObject = {
            "query": query,
            "userPrincipalName": mail,
            "context": context,
        }

        # NOTE(review): verify=False disables TLS certificate verification;
        # kept as is because existing deployments may rely on it.
        response = requests.post(
            f"{base_url}/api/v1/querypipelines/search/async",
            params={"queryPipelineId": self.valves.QueryPipelineId},
            headers={"Authorization": self.valves.AuthenticationToken},
            json=queryObject,
            verify=False,
            # Without a timeout an unresponsive server blocks the caller forever.
            timeout=30,
        )
        response.raise_for_status()
        return response.json()

    def getStatus(self, threadId: str, mail: str) -> Optional[dict]:
        """Fetch the current status of a previously submitted search.

        Sends a blocking HTTP POST to
        ``<RheininsightsUrl>/api/v1/querypipelines/search/async/status``.

        Args:
            threadId: Identifier of the submitted search, sent as ``"threadId"``.
            mail: Sent as ``"userPrincipalName"``.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
            requests.RequestException: On connection errors or timeouts.
        """
        # Tolerate a configured base URL with a trailing slash.
        base_url = self.valves.RheininsightsUrl.rstrip("/")

        # NOTE(review): verify=False disables TLS certificate verification;
        # kept as is because existing deployments may rely on it.
        response = requests.post(
            f"{base_url}/api/v1/querypipelines/search/async/status",
            headers={"Authorization": self.valves.AuthenticationToken},
            json={"threadId": threadId, "userPrincipalName": mail},
            verify=False,
            # Without a timeout an unresponsive server blocks the caller forever.
            timeout=30,
        )
        response.raise_for_status()
        return response.json()

    async def postFinalMessage(
        self,
        __event_emitter__: Callable[[dict], Awaitable[dict]],
        doNotPostUpdates: bool,
    ):

        if not __event_emitter__:
            return

        if doNotPostUpdates is True:
            return

        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": "Completed successfully", "done": True},
            }
        )

    async def emitEvent(
        self,
        msg: str,
        __event_emitter__: Callable[[dict], Awaitable[dict]],
        doNotPostUpdates: bool,
    ):

        if not __event_emitter__:
            return

        if doNotPostUpdates is True:
            return

        print(f"emitting {msg}")

        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "status": "in_progress",
                    "description": msg,
                    "done": False,
                },
            }
        )

    def handleResults(self, data) -> str:
        try:
            formatted_results = ""
            for i, result in enumerate(data["results"]):

                url = result["url"]
                metadata = result["metadata"]
                answer = result["teaser"]
                title = metadata["title"]

                titleStr = "Untitled"
                if title:
                    titleStr = title[0]

                img = ""
                svgStr = metadata["iconSvg"]
                if svgStr:
                    img = "![Icon](data:image/svg+xml;base64," + svgStr[0] + ")"

                print(img)
                answer = self.replaceTokens(answer)
                titleStr = self.replaceTokens(titleStr)

                formatted_results += f"{img}[{titleStr}]({url})\n{answer}\n\n"

            return formatted_results
        except Exception as e:
            print(e)

    def replaceTokens(self, text: str) -> str:
        text = text.replace("<em>", "_")
        text = text.replace("</em>", "_")

        soup = BeautifulSoup(text)
        text = soup.get_text()

        return text