rasa会话数据存储

原文链接

背景

Rasa 的会话数据默认存储在内存中。如果要部署多个机器人实例，就需要把会话数据存储在外部数据库中；而官方内置的 Redis 存储方式只能连接单实例 Redis，无法连接 Redis 集群。

因此，这里给出使用 Redis 集群存储会话数据的方法。本文连接的是 Redis 哨兵（Sentinel）集群；参照本示例修改相应代码，也可以扩展为连接 Redis Cluster。
此代码基于 Rasa 3 版本编写。（代码中 TrackerStore / LockStore 的方法均为同步写法；较新的 Rasa 3.x 版本可能已将这些基类接口改为异步，使用前请核对所用版本的基类定义。）

一. 修改endpoints.yml配置文件

lock_store:
    type: 'sentinel_lock_store.RedisSentinelLockStore'
    master: mymaster
    host: 191.161.6.191 # 哨兵节点所在主机地址，示例 127.0.0.1
    port1: 6379
    port2: 6380
    port3: 6381
    db: 0
    password: password  # 输入密码
    key_prefix: rasa
    socket_timeout: 0.5
    # 单位为秒：Redis 操作阻塞超过 0.5 秒就会触发超时异常
    
tracker_store:
    type: 'sentinel_tracker_store.RedisSentinelTrackerStore'
    url: 191.161.6.191 # 哨兵节点所在主机地址，示例 127.0.0.1（自定义 TrackerStore 通过 host 参数接收该值）
    master: mymaster
    port1: 6379
    port2: 6380
    port3: 6381
    db: 5
    password: password  # 输入密码
    key_prefix: rasa
    socket_timeout: 5
    record_exp: 300 # 会话记录在 Redis 中的过期时间，单位为秒

二. 增加类文件

可以看到，endpoints.yml 配置文件中的 type 指定了自定义类（格式为“模块名.类名”），因此需要增加对应的类文件。

在 endpoints.yml 所在目录（即运行 rasa 命令的项目目录，以保证这些模块能被导入）下增加三个 Python 文件，文件内容如下：

# sentinel_lock_store.py 对应sentinel_lock_store.RedisSentinelLockStore
import asyncio
import json
import logging
import os

from async_generator import asynccontextmanager
from typing import Text, Union, Optional, AsyncGenerator

from rasa.shared.exceptions import RasaException, ConnectionException
import rasa.shared.utils.common
from rasa.core.constants import DEFAULT_LOCK_LIFETIME
from rasa.core.lock import TicketLock
from rasa.utils.endpoints import EndpointConfig

# Module logger for the lock store (used when applying the configured key prefix).
logger = logging.getLogger(__name__)

from redis.sentinel import Sentinel

def _get_lock_lifetime() -> int:
    """Return the ticket-lock lifetime to use.

    Reads ``TICKET_LOCK_LIFETIME`` from the environment; when it is unset or
    evaluates to zero, Rasa's ``DEFAULT_LOCK_LIFETIME`` is returned instead.
    """
    configured = int(os.environ.get("TICKET_LOCK_LIFETIME", 0))
    if configured:
        return configured
    return DEFAULT_LOCK_LIFETIME


# Resolved once, at import time, from the environment (see _get_lock_lifetime).
LOCK_LIFETIME: int = _get_lock_lifetime()
# Not referenced by the lock store code shown in this module.
DEFAULT_SOCKET_TIMEOUT_IN_SECONDS: int = 10

# Namespace appended after the optional user-supplied prefix for every lock key.
DEFAULT_REDIS_LOCK_STORE_KEY_PREFIX: str = "lock:"
import redis
from redis.sentinel import Sentinel
from rasa.core.lock_store import LockStore
from redis_sentinel_helper import redisSentinelHelper


class RedisSentinelLockStore(LockStore):
    """Ticket-lock store backed by a Redis master discovered via Redis Sentinel.

    Configured from the ``lock_store`` section of ``endpoints.yml``: the extra
    keys collected in ``endpoint_config.kwargs`` (``host``, ``port1``-``port3``,
    ``master``, ``db``, ``password``, ``socket_timeout``, ``key_prefix``) are
    handed to ``redisSentinelHelper``. Each lock is written under
    ``<key_prefix>:lock:<conversation_id>`` (``lock:<conversation_id>`` without
    a prefix), serialised with ``TicketLock.dumps()`` and read back with
    ``json.loads``.
    """

    def __init__(
        self,
        endpoint_config: EndpointConfig
    ) -> None:
        """Create the store and its Sentinel connection helper.

        Args:
            endpoint_config: Lock-store endpoint configuration; its ``kwargs``
                carry the Sentinel connection settings and optional
                ``key_prefix``.
        """
        self.red = redisSentinelHelper(endpoint_config.kwargs)

        # ``key_prefix`` is optional in endpoints.yml; indexing the dict
        # directly raised KeyError whenever it was left out.
        key_prefix = endpoint_config.kwargs.get("key_prefix")
        self.key_prefix = DEFAULT_REDIS_LOCK_STORE_KEY_PREFIX
        if key_prefix:
            logger.debug(f"Setting non-default redis key prefix: '{key_prefix}'.")
            self._set_key_prefix(key_prefix)
        super().__init__()

    def _set_key_prefix(self, key_prefix: Text) -> None:
        """Prepend an alphanumeric user prefix to the default lock namespace.

        Non-string or non-alphanumeric prefixes are ignored with a warning and
        the current prefix is kept.
        """
        if isinstance(key_prefix, str) and key_prefix.isalnum():
            self.key_prefix = key_prefix + ":" + DEFAULT_REDIS_LOCK_STORE_KEY_PREFIX
        else:
            logger.warning(
                f"Omitting provided non-alphanumeric redis key prefix: '{key_prefix}'. "
                f"Using default '{self.key_prefix}' instead."
            )

    def get_lock(self, conversation_id: Text) -> Optional[TicketLock]:
        """Retrieves lock (see parent docstring for more information).

        Returns ``None`` when no lock is stored for ``conversation_id`` (or no
        Redis master is available).
        """
        serialised_lock = self.red.get_key(self.key_prefix + conversation_id)
        if serialised_lock:
            return TicketLock.from_dict(json.loads(serialised_lock))
        return None

    def delete_lock(self, conversation_id: Text) -> None:
        """Deletes lock for conversation ID."""
        # The helper returns the number of deleted keys, or None without a
        # master; normalise to the bool that ``_log_deletion`` expects.
        deleted = self.red.delete_key(self.key_prefix + conversation_id)
        self._log_deletion(conversation_id, bool(deleted))

    def save_lock(self, lock: TicketLock) -> None:
        """Persist ``lock`` under its conversation ID (no Redis-side expiry)."""
        self.red.set_key(self.key_prefix + lock.conversation_id, lock.dumps())
# sentinel_tracker_store.py 对应type  'sentinel_tracker_store.RedisSentinelTrackerStore'
import contextlib
import itertools
import json
import logging
import os

from time import sleep
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Text,
    Union,
    TYPE_CHECKING,
    Generator,
)

from boto3.dynamodb.conditions import Key
from pymongo.collection import Collection

import rasa.core.utils as core_utils
import rasa.shared.utils.cli
import rasa.shared.utils.common
import rasa.shared.utils.io
from rasa.shared.core.constants import ACTION_LISTEN_NAME
from rasa.core.brokers.broker import EventBroker
from rasa.core.constants import (
    POSTGRESQL_SCHEMA,
    POSTGRESQL_MAX_OVERFLOW,
    POSTGRESQL_POOL_SIZE,
)
from rasa.shared.core.conversation import Dialogue
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import SessionStarted
from rasa.shared.core.trackers import (
    ActionExecuted,
    DialogueStateTracker,
    EventVerbosity,
)
from rasa.shared.exceptions import ConnectionException, RasaException
from rasa.shared.nlu.constants import INTENT_NAME_KEY
from rasa.utils.endpoints import EndpointConfig
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta

if TYPE_CHECKING:
    import boto3.resources.factory.dynamodb.Table
    from sqlalchemy.engine.url import URL
    from sqlalchemy.engine.base import Engine
    from sqlalchemy.orm import Session, Query
    from sqlalchemy import Sequence

# Module logger for the tracker store (used when applying the configured key prefix).
logger = logging.getLogger(__name__)

# Namespace appended after the optional user-supplied prefix for every tracker key.
DEFAULT_REDIS_TRACKER_STORE_KEY_PREFIX = "tracker:"
from rasa.core.tracker_store import TrackerStore
from redis.sentinel import Sentinel
from redis_sentinel_helper import redisSentinelHelper

class RedisSentinelTrackerStore(TrackerStore):
    """Stores conversation history in RedisSentinel.

    Trackers are serialised with ``serialise_tracker`` and written under
    ``<key_prefix>:tracker:<sender_id>`` (``tracker:<sender_id>`` without a
    prefix) on the Redis master that Sentinel reports for ``master``.
    """

    def __init__(
        self,
        domain: Domain,
        host: Text = "localhost",
        master: Text= "mymaster",
        port1: int = 6379,
        port2: int = 6380,
        port3: int = 6381,
        db: int = 0,
        password: Optional[Text] = None,
        key_prefix: Optional[Text] = None,
        socket_timeout: Optional[float] = None,
        record_exp: Optional[float] = None,
        event_broker: Optional[EventBroker] = None,
        **kwargs: Dict[Text, Any],
        ):
        """Create the store.

        Args:
            domain: Domain passed through to ``TrackerStore``.
            host: Host shared by the three Sentinel nodes. NOTE: endpoints.yml
                supplies this as ``url``; Rasa is expected to forward it as
                ``host`` — confirm for your Rasa version.
            master: Sentinel service name of the Redis master.
            port1: Port of the first Sentinel node.
            port2: Port of the second Sentinel node.
            port3: Port of the third Sentinel node.
            db: Redis database number.
            password: Password used for both Sentinel and the master.
            key_prefix: Optional alphanumeric prefix for tracker keys.
            socket_timeout: Socket timeout for Sentinel and master connections.
            record_exp: Default expiry in seconds for stored trackers; ``None``
                keeps trackers without expiry.
            event_broker: Optional broker that receives tracker events on save.
            **kwargs: Forwarded to ``TrackerStore.__init__``.
        """
        config = {
            "host": host,
            "master": master,
            "port1": port1,
            "port2": port2,
            "port3": port3,
            "db": db,
            "password": password,
            "socket_timeout": socket_timeout,
        }
        self.red = redisSentinelHelper(config)

        self.record_exp = record_exp
        self.key_prefix = DEFAULT_REDIS_TRACKER_STORE_KEY_PREFIX
        if key_prefix:
            logger.debug(f"Setting non-default redis key prefix: '{key_prefix}'.")
            self._set_key_prefix(key_prefix)
        super().__init__(domain, event_broker, **kwargs)

    def _set_key_prefix(self, key_prefix: Text) -> None:
        """Prepend an alphanumeric user prefix to the default tracker namespace.

        Non-string or non-alphanumeric prefixes are ignored with a warning.
        """
        if isinstance(key_prefix, str) and key_prefix.isalnum():
            self.key_prefix = key_prefix + ":" + DEFAULT_REDIS_TRACKER_STORE_KEY_PREFIX
        else:
            logger.warning(
                f"Omitting provided non-alphanumeric redis key prefix: '{key_prefix}'. "
                f"Using default '{self.key_prefix}' instead.")

    def _get_key_prefix(self) -> Text:
        """Return the full prefix prepended to every sender ID."""
        return self.key_prefix

    def save(
        self, tracker: DialogueStateTracker, timeout: Optional[float] = None
    ) -> None:
        """Saves the current conversation state.

        Args:
            tracker: Tracker to persist.
            timeout: Expiry in seconds for this write; falls back to
                ``record_exp``. When neither is set the key is stored without
                expiry.
        """
        if self.event_broker:
            self.stream_events(tracker)

        key = self.key_prefix + tracker.sender_id
        serialised_tracker = self.serialise_tracker(tracker)
        expiry = timeout or self.record_exp
        if expiry:
            self.red.setex_key(key, serialised_tracker, ex=expiry)
        else:
            # SETEX needs an expiry; passing None made redis-py reject the
            # write, so store without expiry instead.
            self.red.set_key(key, serialised_tracker)

    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieves tracker for the latest conversation session.
        The Redis key is formed by appending a prefix to sender_id.
        Args:
            sender_id: Conversation ID to fetch the tracker for.
        Returns:
            Tracker containing events from the latest conversation sessions.
        """
        stored = self.red.get_key(self.key_prefix + sender_id)
        if stored is not None:
            return self.deserialise_tracker(sender_id, stored)
        else:
            return None

    def keys(self) -> Iterable[Text]:
        """Returns keys of the Redis Tracker Store.

        The raw keys are returned as Redis reports them, prefix included.
        Returns an empty list when no Redis master is available.
        """
        # The helper's original interface has no pattern lookup (the previous
        # call to ``get_pattern`` raised AttributeError), so query the master
        # client it exposes directly.
        master = getattr(self.red, "master", None)
        if master is None:
            return []
        return master.keys(self.key_prefix + "*")
# redis_sentinel_helper.py 连接redis 哨兵模式的工具类
import asyncio
import json
import logging
import os
import redis
from redis.sentinel import Sentinel

# Module logger for the Sentinel helper (used to report connection setup errors).
logger = logging.getLogger(__name__)


class redisSentinelHelper():
    """Small wrapper around a Redis master client obtained through Redis Sentinel.

    ``endpoint_config`` is a dict-like mapping with the keys ``host``,
    ``port1``-``port3`` (Sentinel ports on that host), ``master`` (Sentinel
    service name), ``db``, ``password`` and ``socket_timeout``. When no master
    client could be created, the data methods return ``None`` (``[]`` for
    ``get_pattern``) instead of raising ``AttributeError``.
    """

    def __init__(self, endpoint_config):
        self.url = endpoint_config.get("host", "localhost")
        self.port1 = endpoint_config.get("port1")
        self.port2 = endpoint_config.get("port2")
        self.port3 = endpoint_config.get("port3")
        # Every configured port is a Sentinel node on the same host; missing
        # ports are skipped instead of raising KeyError.
        self.sentinel_list = [
            (self.url, port)
            for port in (self.port1, self.port2, self.port3)
            if port is not None
        ]

        self.password = endpoint_config.get("password")
        self.socket_timeout = endpoint_config.get("socket_timeout")
        self.db = endpoint_config.get("db", 0)
        self.service_name = endpoint_config.get("master", "mymaster")

        # Assigned up front so the accessors below never hit AttributeError
        # when connection setup fails.
        self.sentinel = None
        self.master = None
        try:
            self.sentinel = Sentinel(
                self.sentinel_list,
                socket_timeout=self.socket_timeout,
                sentinel_kwargs={"password": self.password},
            )
            self.master = self.sentinel.master_for(
                service_name=self.service_name,
                socket_timeout=self.socket_timeout,
                password=self.password,
                db=self.db,
            )
        except redis.RedisError as err:
            # ``redis.ConnectError`` does not exist; RedisError is the base
            # class of redis-py errors, including ConnectionError.
            logger.warning(
                "Could not set up Redis Sentinel master '%s': %s",
                self.service_name,
                err,
            )

    def get_master_redis(self):
        """Return ``Sentinel.discover_master`` for the service, or None without Sentinel."""
        if self.sentinel is None:
            return None
        return self.sentinel.discover_master(self.service_name)

    def get_slave_redis(self):
        """Return ``Sentinel.discover_slaves`` for the service, or [] without Sentinel."""
        if self.sentinel is None:
            return []
        return self.sentinel.discover_slaves(self.service_name)

    def setex_key(self, key, value, ex):
        """Store ``value`` under ``key`` with an expiry of ``ex`` seconds."""
        if self.master is None:
            return None
        # Note the argument order of SETEX: name, time, value.
        return self.master.setex(key, ex, value)

    def set_key(self, key, value):
        """Store ``value`` under ``key`` without expiry."""
        if self.master is None:
            return None
        return self.master.set(key, value)

    def get_key(self, key):
        """Return the stored value for ``key`` (None if absent or no master)."""
        if self.master is None:
            return None
        return self.master.get(key)

    def delete_key(self, key):
        """Delete ``key``; returns the master's delete result, or None without a master."""
        if self.master is None:
            return None
        return self.master.delete(key)

    def get_pattern(self, pattern):
        """Return the keys matching ``pattern`` via the Redis KEYS command.

        Returns an empty list when no master client is available. KEYS scans
        the whole keyspace, so keep its use to maintenance paths.
        """
        if self.master is None:
            return []
        return self.master.keys(pattern)

三. 运行rasa

rasa run --endpoints endpoints.yml

更多推荐

rasa会话数据存储 RedisTrackerStore 连接哨兵