更新mini支持讯飞语音识别
This commit is contained in:
87
boards/default_src/micropython/origin/build/lib/hmac.py
Normal file
87
boards/default_src/micropython/origin/build/lib/hmac.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
# Implements the hmac module from the Python standard library.
|
||||||
|
|
||||||
|
|
||||||
|
class HMAC:
|
||||||
|
def __init__(self, key, msg=None, digestmod=None):
|
||||||
|
if not isinstance(key, (bytes, bytearray)):
|
||||||
|
raise TypeError("key: expected bytes/bytearray")
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
if digestmod is None:
|
||||||
|
# TODO: Default hash algorithm is now deprecated.
|
||||||
|
digestmod = hashlib.md5
|
||||||
|
|
||||||
|
if callable(digestmod):
|
||||||
|
# A hashlib constructor returning a new hash object.
|
||||||
|
make_hash = digestmod # A
|
||||||
|
elif isinstance(digestmod, str):
|
||||||
|
# A hash name suitable for hashlib.new().
|
||||||
|
make_hash = lambda d=b"": getattr(hashlib, digestmod)(d)
|
||||||
|
else:
|
||||||
|
# A module supporting PEP 247.
|
||||||
|
make_hash = digestmod.new # C
|
||||||
|
|
||||||
|
self._outer = make_hash()
|
||||||
|
self._inner = make_hash()
|
||||||
|
|
||||||
|
self.digest_size = getattr(self._inner, "digest_size", None)
|
||||||
|
# If the provided hash doesn't support block_size (e.g. built-in
|
||||||
|
# hashlib), 64 is the correct default for all built-in hash
|
||||||
|
# functions (md5, sha1, sha256).
|
||||||
|
self.block_size = getattr(self._inner, "block_size", 64)
|
||||||
|
|
||||||
|
# Truncate to digest_size if greater than block_size.
|
||||||
|
if len(key) > self.block_size:
|
||||||
|
key = make_hash(key).digest()
|
||||||
|
|
||||||
|
# Pad to block size.
|
||||||
|
key = key + bytes(self.block_size - len(key))
|
||||||
|
|
||||||
|
self._outer.update(bytes(x ^ 0x5C for x in key))
|
||||||
|
self._inner.update(bytes(x ^ 0x36 for x in key))
|
||||||
|
|
||||||
|
if msg is not None:
|
||||||
|
self.update(msg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return "hmac-" + getattr(self._inner, "name", type(self._inner).__name__)
|
||||||
|
|
||||||
|
def update(self, msg):
|
||||||
|
self._inner.update(msg)
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
if not hasattr(self._inner, "copy"):
|
||||||
|
# Not supported for built-in hash functions.
|
||||||
|
raise NotImplementedError()
|
||||||
|
# Call __new__ directly to avoid the expensive __init__.
|
||||||
|
other = self.__class__.__new__(self.__class__)
|
||||||
|
other.block_size = self.block_size
|
||||||
|
other.digest_size = self.digest_size
|
||||||
|
other._inner = self._inner.copy()
|
||||||
|
other._outer = self._outer.copy()
|
||||||
|
return other
|
||||||
|
|
||||||
|
def _current(self):
|
||||||
|
h = self._outer
|
||||||
|
if hasattr(h, "copy"):
|
||||||
|
# built-in hash functions don't support this, and as a result,
|
||||||
|
# digest() will finalise the hmac and further calls to
|
||||||
|
# update/digest will fail.
|
||||||
|
h = h.copy()
|
||||||
|
h.update(self._inner.digest())
|
||||||
|
return h
|
||||||
|
|
||||||
|
def digest(self):
|
||||||
|
h = self._current()
|
||||||
|
return h.digest()
|
||||||
|
|
||||||
|
def hexdigest(self):
|
||||||
|
import binascii
|
||||||
|
|
||||||
|
return str(binascii.hexlify(self.digest()), "utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def new(key, msg=None, digestmod=None):
|
||||||
|
return HMAC(key, msg, digestmod)
|
||||||
32
boards/default_src/micropython/origin/build/lib/urllib.py
Normal file
32
boards/default_src/micropython/origin/build/lib/urllib.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from ucollections import namedtuple
|
||||||
|
|
||||||
|
URI = namedtuple('URI', ('cheme', 'netloc', 'path', 'params', 'query', 'fragment'))
|
||||||
|
|
||||||
|
def quote(string, safe='_.-~+'):
|
||||||
|
""" A simple implementation of URL quoting"""
|
||||||
|
string = string.replace(' ', '+')
|
||||||
|
result = ""
|
||||||
|
for char in string:
|
||||||
|
if ('a' <= char <= 'z') or ('A' <= char <= 'Z') or ('0' <= char<= '9') or (char in safe):
|
||||||
|
result += char
|
||||||
|
else:
|
||||||
|
result += "%{:02X}".format(ord(char))
|
||||||
|
return result
|
||||||
|
|
||||||
|
def urlencode(query, safe='_.-~+'):
|
||||||
|
"""A simple urlencode function"""
|
||||||
|
return '&'.join('{}={}'.format(quote(k, safe), quote(v, safe)) for k, v in query.items())
|
||||||
|
|
||||||
|
def urlparse(url):
|
||||||
|
"""A simple urlparse (cheme, netloc, path, params, query, fragment)"""
|
||||||
|
parts = [''] * 6
|
||||||
|
for i, sep in enumerate(['://', '#', '?', ';']):
|
||||||
|
if sep in url:
|
||||||
|
left, right = url.split(sep, 1)
|
||||||
|
parts[i], url = left, right
|
||||||
|
if '/' in url:
|
||||||
|
parts[1], parts[2] = url.split('/', 1)
|
||||||
|
parts[2] = '/' + parts[2]
|
||||||
|
else:
|
||||||
|
parts[1] = url
|
||||||
|
return URI(*parts)
|
||||||
229
boards/default_src/micropython/origin/build/lib/websocket.py
Normal file
229
boards/default_src/micropython/origin/build/lib/websocket.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
''''''
|
||||||
|
import usocket as socket
|
||||||
|
import ubinascii as binascii
|
||||||
|
import urandom as random
|
||||||
|
import ustruct as struct
|
||||||
|
import urandom as random
|
||||||
|
from ucollections import namedtuple
|
||||||
|
|
||||||
|
# Opcodes
|
||||||
|
OP_CONT = const(0x0)
|
||||||
|
OP_TEXT = const(0x1)
|
||||||
|
OP_BYTES = const(0x2)
|
||||||
|
OP_CLOSE = const(0x8)
|
||||||
|
OP_PING = const(0x9)
|
||||||
|
OP_PONG = const(0xA)
|
||||||
|
|
||||||
|
# Close codes
|
||||||
|
CLOSE_OK = const(1000)
|
||||||
|
CLOSE_GOING_AWAY = const(1001)
|
||||||
|
CLOSE_PROTOCOL_ERROR = const(1002)
|
||||||
|
CLOSE_DATA_NOT_SUPPORTED = const(1003)
|
||||||
|
CLOSE_BAD_DATA = const(1007)
|
||||||
|
CLOSE_POLICY_VIOLATION = const(1008)
|
||||||
|
CLOSE_TOO_BIG = const(1009)
|
||||||
|
CLOSE_MISSING_EXTN = const(1010)
|
||||||
|
CLOSE_BAD_CONDITION = const(1011)
|
||||||
|
|
||||||
|
URI = namedtuple('URI', ('protocol', 'hostname', 'port', 'path'))
|
||||||
|
|
||||||
|
class NoDataException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ConnectionClosed(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def urlparse(uri):
|
||||||
|
# Split protocol and the rest
|
||||||
|
protocol, rest = uri.split('://', 1)
|
||||||
|
if '/' in rest:
|
||||||
|
hostname_port, path = rest.split('/', 1)
|
||||||
|
path = '/' + path
|
||||||
|
else:
|
||||||
|
hostname_port, path = rest, ''
|
||||||
|
if ':' in hostname_port:
|
||||||
|
hostname, port = hostname_port.rsplit(':', 1)
|
||||||
|
else:
|
||||||
|
hostname, port = hostname_port, None
|
||||||
|
if port is None:
|
||||||
|
port = 443 if protocol == 'wss' else 80
|
||||||
|
return URI(protocol, hostname, port, path)
|
||||||
|
|
||||||
|
class Websocket:
|
||||||
|
"""Basis of the Websocket protocol."""
|
||||||
|
|
||||||
|
is_client = False
|
||||||
|
|
||||||
|
def __init__(self, sock):
|
||||||
|
self.sock = sock
|
||||||
|
self.open = True
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def settimeout(self, timeout):
|
||||||
|
self.sock.settimeout(timeout)
|
||||||
|
|
||||||
|
def read_frame(self, max_size=None):
|
||||||
|
"""Read a frame from the socket"""
|
||||||
|
# Frame header
|
||||||
|
two_bytes = self.sock.read(2)
|
||||||
|
if not two_bytes:
|
||||||
|
raise NoDataException
|
||||||
|
byte1, byte2 = struct.unpack('!BB', two_bytes)
|
||||||
|
# Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4)
|
||||||
|
fin = bool(byte1 & 0x80)
|
||||||
|
opcode = byte1 & 0x0F
|
||||||
|
# Byte 2: MASK(1) LENGTH(7)
|
||||||
|
mask = bool(byte2 & (1 << 7))
|
||||||
|
length = byte2 & 0x7F
|
||||||
|
|
||||||
|
if length == 126: # Magic number, length header is 2 bytes
|
||||||
|
(length,) = struct.unpack('!H', self.sock.read(2))
|
||||||
|
elif length == 127: # Magic number, length header is 8 bytes
|
||||||
|
(length,) = struct.unpack('!Q', self.sock.read(8))
|
||||||
|
if mask: # Mask is 4 bytes
|
||||||
|
mask_bits = self.sock.read(4)
|
||||||
|
try:
|
||||||
|
data = self.sock.read(length)
|
||||||
|
except MemoryError:
|
||||||
|
# We can't receive this many bytes, close the socket
|
||||||
|
self.close(code=CLOSE_TOO_BIG)
|
||||||
|
return True, OP_CLOSE, None
|
||||||
|
if mask:
|
||||||
|
data = bytes(b ^ mask_bits[i % 4] for i, b in enumerate(data))
|
||||||
|
return fin, opcode, data
|
||||||
|
|
||||||
|
def write_frame(self, opcode, data=b''):
|
||||||
|
"""Write a frame to the socket"""
|
||||||
|
fin = True
|
||||||
|
mask = self.is_client # messages sent by client are masked
|
||||||
|
|
||||||
|
length = len(data)
|
||||||
|
# Frame header
|
||||||
|
# Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4)
|
||||||
|
byte1 = 0x80 if fin else 0
|
||||||
|
byte1 |= opcode
|
||||||
|
# Byte 2: MASK(1) LENGTH(7)
|
||||||
|
byte2 = 0x80 if mask else 0
|
||||||
|
if length < 126: # 126 is magic value to use 2-byte length header
|
||||||
|
byte2 |= length
|
||||||
|
self.sock.write(struct.pack('!BB', byte1, byte2))
|
||||||
|
elif length < (1 << 16): # Length fits in 2-bytes
|
||||||
|
byte2 |= 126 # Magic code
|
||||||
|
self.sock.write(struct.pack('!BBH', byte1, byte2, length))
|
||||||
|
elif length < (1 << 64):
|
||||||
|
byte2 |= 127 # Magic code
|
||||||
|
self.sock.write(struct.pack('!BBQ', byte1, byte2, length))
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
if mask: # Mask is 4 bytes
|
||||||
|
mask_bits = struct.pack('!I', random.getrandbits(32))
|
||||||
|
self.sock.write(mask_bits)
|
||||||
|
data = bytes(b ^ mask_bits[i % 4] for i, b in enumerate(data))
|
||||||
|
self.sock.write(data)
|
||||||
|
|
||||||
|
def recv(self):
|
||||||
|
"""Receive data from the websocket"""
|
||||||
|
assert self.open
|
||||||
|
|
||||||
|
while self.open:
|
||||||
|
try:
|
||||||
|
fin, opcode, data = self.read_frame()
|
||||||
|
except NoDataException:
|
||||||
|
return ''
|
||||||
|
except ValueError:
|
||||||
|
self._close()
|
||||||
|
raise ConnectionClosed()
|
||||||
|
|
||||||
|
if not fin:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if opcode == OP_TEXT:
|
||||||
|
return data.decode('utf-8')
|
||||||
|
elif opcode == OP_BYTES:
|
||||||
|
return data
|
||||||
|
elif opcode == OP_CLOSE:
|
||||||
|
self._close()
|
||||||
|
return
|
||||||
|
elif opcode == OP_PONG:
|
||||||
|
# Ignore this frame, keep waiting for a data frame
|
||||||
|
continue
|
||||||
|
elif opcode == OP_PING:
|
||||||
|
# We need to send a pong frame
|
||||||
|
self.write_frame(OP_PONG, data)
|
||||||
|
# And then wait to receive
|
||||||
|
continue
|
||||||
|
elif opcode == OP_CONT:
|
||||||
|
# This is a continuation of a previous frame
|
||||||
|
raise NotImplementedError(opcode)
|
||||||
|
else:
|
||||||
|
raise ValueError(opcode)
|
||||||
|
|
||||||
|
def send(self, buf):
|
||||||
|
"""Send data to the websocket."""
|
||||||
|
assert self.open
|
||||||
|
if isinstance(buf, str):
|
||||||
|
opcode = OP_TEXT
|
||||||
|
buf = buf.encode('utf-8')
|
||||||
|
elif isinstance(buf, bytes):
|
||||||
|
opcode = OP_BYTES
|
||||||
|
else:
|
||||||
|
raise TypeError()
|
||||||
|
self.write_frame(opcode, buf)
|
||||||
|
|
||||||
|
def close(self, code=CLOSE_OK, reason=''):
|
||||||
|
"""Close the websocket."""
|
||||||
|
if not self.open:
|
||||||
|
return
|
||||||
|
buf = struct.pack('!H', code) + reason.encode('utf-8')
|
||||||
|
self.write_frame(OP_CLOSE, buf)
|
||||||
|
self._close()
|
||||||
|
|
||||||
|
def _close(self):
|
||||||
|
self.open = False
|
||||||
|
self.sock.close()
|
||||||
|
|
||||||
|
class WebsocketClient(Websocket):
|
||||||
|
is_client = True
|
||||||
|
|
||||||
|
def connect(uri, headers=None):
|
||||||
|
"""Connect a websocket."""
|
||||||
|
uri = urlparse(uri)
|
||||||
|
assert uri
|
||||||
|
sock = socket.socket()
|
||||||
|
addr = socket.getaddrinfo(uri.hostname, uri.port)
|
||||||
|
sock.connect(addr[0][4])
|
||||||
|
if uri.protocol == 'wss':
|
||||||
|
import ssl as ussl
|
||||||
|
sock = ussl.wrap_socket(sock, server_hostname=uri.hostname)
|
||||||
|
|
||||||
|
def send_header(header, *args):
|
||||||
|
sock.write(header % args + '\r\n')
|
||||||
|
|
||||||
|
# Sec-WebSocket-Key is 16 bytes of random base64 encoded
|
||||||
|
key = binascii.b2a_base64(bytes(random.getrandbits(8) for _ in range(16)))[:-1]
|
||||||
|
|
||||||
|
send_header(b'GET %s HTTP/1.1', uri.path or '/')
|
||||||
|
send_header(b'Host: %s:%s', uri.hostname, uri.port)
|
||||||
|
send_header(b'Connection: Upgrade')
|
||||||
|
send_header(b'Upgrade: websocket')
|
||||||
|
send_header(b'Sec-WebSocket-Key: %s', key)
|
||||||
|
send_header(b'Sec-WebSocket-Version: 13')
|
||||||
|
send_header(b'Origin: http://{hostname}:{port}'.format(hostname=uri.hostname, port=uri.port))
|
||||||
|
if headers: # 注入自定义头
|
||||||
|
for k, v in headers.items():
|
||||||
|
send_header((k + ": " + v).encode())
|
||||||
|
send_header(b'')
|
||||||
|
|
||||||
|
header = sock.readline()[:-2]
|
||||||
|
assert header.startswith(b'HTTP/1.1 101 '), header
|
||||||
|
# We don't (currently) need these headers
|
||||||
|
# FIXME: should we check the return key?
|
||||||
|
while header:
|
||||||
|
header = sock.readline()[:-2]
|
||||||
|
|
||||||
|
return WebsocketClient(sock)
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
"""
|
||||||
|
MINI_XUNFEI
|
||||||
|
|
||||||
|
Micropython library for the MINI_XUNFEI(ASR, LLM)
|
||||||
|
=======================================================
|
||||||
|
@dahanzimin From the Mixly Team
|
||||||
|
|
||||||
|
语音听写(流式版) WebAPI 文档 https://www.xfyun.cn/doc/asr/voicedictation/API.html
|
||||||
|
大模型(Spark4.0 Ultra)WebAPI 文档 https://www.xfyun.cn/doc/spark/Web.html
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
import rtctime
|
||||||
|
import websocket
|
||||||
|
from mixgo_mini import onboard_bot
|
||||||
|
from base64 import b64decode, b64encode
|
||||||
|
from urllib import urlencode, urlparse
|
||||||
|
|
||||||
|
class Ws_Param:
|
||||||
|
def __init__(self, APPID, APIKey, APISecret, Spark_url):
|
||||||
|
self.APPID = APPID
|
||||||
|
self.APIKey = APIKey
|
||||||
|
self.APISecret = APISecret
|
||||||
|
self.url = Spark_url
|
||||||
|
self.urlparse = urlparse(Spark_url)
|
||||||
|
|
||||||
|
def create_url(self):
|
||||||
|
date = rtctime.rfc1123_time()
|
||||||
|
signature_origin = "host: " + self.urlparse.netloc + "\n"
|
||||||
|
signature_origin += "date: " + date + "\n"
|
||||||
|
signature_origin += "GET " + self.urlparse.path + " HTTP/1.1"
|
||||||
|
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest()
|
||||||
|
signature_base64 = b64encode(signature_sha).decode('utf-8')
|
||||||
|
authorization_origin = ('api_key="{}", algorithm="hmac-sha256", headers="host date request-line", signature="{}"'.format(self.APIKey, signature_base64))
|
||||||
|
authorization = b64encode(authorization_origin.encode('utf-8')).decode('utf-8')
|
||||||
|
headers = {"authorization": authorization, "date": date, "host": self.urlparse.netloc}
|
||||||
|
return self.url + '?' + urlencode(headers)
|
||||||
|
|
||||||
|
#语音听写
|
||||||
|
class ASR_WebSocket(Ws_Param):
|
||||||
|
def __init__(self, APPID, APIKey, APISecret, url='ws://iat-api.xfyun.cn/v2/iat'):
|
||||||
|
super().__init__(APPID, APIKey, APISecret, url)
|
||||||
|
self.ws = None
|
||||||
|
self.business = {
|
||||||
|
"domain": "iat",
|
||||||
|
"language": "zh_cn",
|
||||||
|
"accent": "mandarin",
|
||||||
|
"vinfo": 1,
|
||||||
|
"vad_eos": 1000,
|
||||||
|
"nbest": 1,
|
||||||
|
"wbest": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.ws = websocket.connect(self.create_url())
|
||||||
|
self.ws.settimeout(2000)
|
||||||
|
|
||||||
|
def _frame(self, status, buf):
|
||||||
|
return {"status": status, "format": "audio/L16;rate=8000", "audio": str(b64encode(buf), 'utf-8'), "encoding": "raw"}
|
||||||
|
|
||||||
|
def on_message(self, message):
|
||||||
|
result = ""
|
||||||
|
msg = json.loads(message)
|
||||||
|
code = msg["code"]
|
||||||
|
if code != 0:
|
||||||
|
raise AttributeError("On message sid:%s call error:%s code is:%s" % (msg["sid"], msg["message"], code))
|
||||||
|
else:
|
||||||
|
data = msg["data"]["result"]["ws"]
|
||||||
|
for i in data:
|
||||||
|
for w in i["cw"]:
|
||||||
|
result += w["w"]
|
||||||
|
if msg["data"]["status"]== 2:
|
||||||
|
return result, False
|
||||||
|
return result, True
|
||||||
|
|
||||||
|
def receive_messages(self):
|
||||||
|
msg = ""
|
||||||
|
while True:
|
||||||
|
t = self.on_message(self.ws.recv())
|
||||||
|
msg += t[0]
|
||||||
|
if not t[1]:
|
||||||
|
break
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def run(self, seconds=3, ibuf=1600, timeout=2000):
|
||||||
|
try:
|
||||||
|
_state = 0
|
||||||
|
self.connect()
|
||||||
|
_star = time.ticks_ms()
|
||||||
|
_size = int(ibuf * seconds * 10) #100ms/次
|
||||||
|
onboard_bot.pcm_en(True) #PCM开启
|
||||||
|
while _size > 0:
|
||||||
|
if onboard_bot.pcm_any():
|
||||||
|
_size -= ibuf
|
||||||
|
_star = time.ticks_ms()
|
||||||
|
buf = onboard_bot.pcm_read(ibuf)
|
||||||
|
# 第一帧处理
|
||||||
|
if _state == 0:
|
||||||
|
d = {"common": {"app_id": self.APPID}, "business": self.business, "data": self._frame(_state, buf)}
|
||||||
|
_state = 1
|
||||||
|
# 中间帧处理
|
||||||
|
else:
|
||||||
|
d = {"data": self._frame(_state, buf)}
|
||||||
|
self.ws.send(json.dumps(d))
|
||||||
|
#print("------",len(buf), time.ticks_diff(time.ticks_ms(), _star))
|
||||||
|
if time.ticks_diff(time.ticks_ms(), _star) > timeout:
|
||||||
|
raise OSError("Timeout pcm read error")
|
||||||
|
# 最后一帧处理
|
||||||
|
d = {"data": self._frame(2, b'\x00')}
|
||||||
|
self.ws.send(json.dumps(d))
|
||||||
|
onboard_bot.pcm_en(False) #PCM关闭
|
||||||
|
msg = self.receive_messages()
|
||||||
|
self.ws.close()
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
onboard_bot.pcm_en(False) #PCM关闭
|
||||||
|
print("run:%s" % (e))
|
||||||
|
|
||||||
|
#大模型
|
||||||
|
class LLM_WebSocket(Ws_Param):
|
||||||
|
def __init__(self, APPID, APIKey, APISecret, answers=50, url='ws://spark-api.xf-yun.com/v4.0/chat'):
|
||||||
|
super().__init__(APPID, APIKey, APISecret, url)
|
||||||
|
self.ws = None
|
||||||
|
self.answers =answers
|
||||||
|
self._messages = [{
|
||||||
|
"role": "system",
|
||||||
|
"content": "你是知识渊博的助理,习惯简短表达"
|
||||||
|
}]
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.ws = websocket.connect(self.create_url())
|
||||||
|
self.ws.settimeout(1000)
|
||||||
|
|
||||||
|
def _params(self, domain):
|
||||||
|
d = {
|
||||||
|
"header": {"app_id": self.APPID},
|
||||||
|
"parameter": {
|
||||||
|
"chat": {
|
||||||
|
"domain": domain,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"top_k": 5,
|
||||||
|
"auditing": "default"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"payload": {
|
||||||
|
"message": {
|
||||||
|
"text": self._messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.ws.send(json.dumps(d))
|
||||||
|
|
||||||
|
|
||||||
|
def empty_history(self):
|
||||||
|
self._messages = []
|
||||||
|
|
||||||
|
def add_history(self, role, content):
|
||||||
|
self._messages.append({
|
||||||
|
"role": role,
|
||||||
|
"content": content
|
||||||
|
})
|
||||||
|
|
||||||
|
def on_message(self, message):
|
||||||
|
result = ""
|
||||||
|
msg = json.loads(message)
|
||||||
|
code = msg['header']['code']
|
||||||
|
if code != 0:
|
||||||
|
raise AttributeError("On message sid:%s code is:%s" % (msg["header"]["sid"], code))
|
||||||
|
else:
|
||||||
|
choices = msg["payload"]["choices"]
|
||||||
|
result += choices["text"][0]["content"]
|
||||||
|
if choices["status"] == 2:
|
||||||
|
return result, False
|
||||||
|
return result, True
|
||||||
|
|
||||||
|
def receive_messages(self):
|
||||||
|
msg = ""
|
||||||
|
while True:
|
||||||
|
t = self.on_message(self.ws.recv())
|
||||||
|
msg += t[0]
|
||||||
|
if not t[1]:
|
||||||
|
break
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def run(self, question, domain="4.0Ultra"):
|
||||||
|
try:
|
||||||
|
self.connect()
|
||||||
|
self.add_history("user", question)
|
||||||
|
self._params(domain)
|
||||||
|
while self.answers < len(self._messages):
|
||||||
|
del self._messages[0]
|
||||||
|
msg = self.receive_messages()
|
||||||
|
self.ws.close()
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
print("run:%s" % (e))
|
||||||
Reference in New Issue
Block a user