Commit 2d4b64c4 by zhengcheng.wang

Initial commit

parents
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/socket-asr-qianwen-tts.iml" filepath="$PROJECT_DIR$/.idea/socket-asr-qianwen-tts.iml" />
</modules>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
package main
import (
"fmt"
"github.com/sirupsen/logrus"
"net"
"socket-asr-qianwen-tts/inits"
"socket-asr-qianwen-tts/internal"
"time"
)
// main listens on inits.Config.ListenAddr and serves every accepted TCP
// connection on its own goroutine.
//
// A session is created for each connection with internal.NewServerConn. When
// that fails, a JSON failure message is written to the client and the
// connection is closed after a two second grace period.
func main() {
	listener, err := net.Listen("tcp", inits.Config.ListenAddr)
	if err != nil {
		logrus.Fatalln(err)
	}
	logrus.Printf("监听套接字,创建成功, 服务器开始监听`%v`。。。。。。\n", inits.Config.ListenAddr)

	const maxAcceptDelay = time.Second
	var acceptDelay time.Duration
	for {
		conn, acceptErr := listener.Accept()
		if acceptErr != nil {
			logrus.Errorln(acceptErr)
			// Back off before retrying so that a persistent failure (for
			// example running out of file descriptors) does not turn this
			// loop into a busy spin that floods the log.
			if acceptDelay == 0 {
				acceptDelay = 5 * time.Millisecond
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			time.Sleep(acceptDelay)
			continue
		}
		acceptDelay = 0

		go func(c net.Conn) {
			_, connErr := internal.NewServerConn(c, inits.Config.EngineUrl, inits.Config.ContextPath, inits.Config.QianwenUrl, inits.Config.QianwenKey, inits.Config.TtsUrl)
			if connErr == nil {
				return
			}
			logrus.Errorln(connErr)
			// %q quotes the message and escapes embedded quotes and
			// backslashes, so an error text such as `parse "ws://…"` no
			// longer produces malformed JSON.
			reply := fmt.Sprintf(`{"status":"fail","msg":%q}`, connErr.Error())
			if _, writeErr := c.Write([]byte(reply)); writeErr != nil {
				logrus.Errorln(writeErr)
			}
			// Give the client a moment to read the failure before closing.
			time.Sleep(2 * time.Second)
			_ = c.Close()
		}(conn)
	}
}
package cnf
// Conf is the service configuration decoded from the YAML profile.
type Conf struct {
	// ListenAddr is the TCP address the server listens on.
	ListenAddr string `yaml:"listen_addr"`
	// Language is not read by any code visible in this commit.
	Language string `yaml:"language"`

	// EngineUrl is the websocket URL of the ASR engine.
	EngineUrl string `yaml:"engine_url"`
	// ContextPath is sent to the ASR engine as "context_path" in the start message.
	ContextPath string `yaml:"context_path"`

	// QianwenUrl is the HTTP endpoint used for QWen requests.
	QianwenUrl string `yaml:"qianwen_url"`
	// QianwenKey is sent as the Bearer token of QWen requests.
	QianwenKey string `yaml:"qianwen_key"`

	// TtsUrl is the websocket URL of the TTS service.
	TtsUrl string `yaml:"tts_url"`
	// AudioDir is not read by any code visible in this commit.
	AudioDir string `yaml:"audio_dir"`
}
listen_addr: 0.0.0.0:8888
language: 'ZH-CN'
engine_url: 'ws://192.168.0.29:20086'
itn_url: ''
context_path: ''
qianwen_url: 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation'
qianwen_key: 'sk-4c390e22d34447f9b5cb256a0f6066df'
tts_url: 'ws://172.16.5.188:8090'
audio_dir: ''
\ No newline at end of file
module socket-asr-qianwen-tts
go 1.20
require (
github.com/gofrs/uuid v4.4.0+incompatible
github.com/gorilla/websocket v1.5.0
github.com/sirupsen/logrus v1.9.3
gopkg.in/yaml.v3 v3.0.1
)
require golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA=
github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
package inits
import (
"flag"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
"os"
"socket-asr-qianwen-tts/cnf"
)
// Config holds the service configuration; it is populated by init from the
// YAML profile.
var Config cnf.Conf
// init enables caller reporting in logrus, parses the command line and loads
// the YAML profile named by the -f flag (default ./cnf/server.yaml) into
// Config. The process exits if the profile cannot be read or decoded.
func init() {
	profilePath := flag.String("f", "./cnf/server.yaml", "Service profiles")

	logrus.SetReportCaller(true)
	flag.Parse()

	raw, readErr := os.ReadFile(*profilePath)
	if readErr != nil {
		logrus.Fatalln(readErr)
	}
	if decodeErr := yaml.Unmarshal(raw, &Config); decodeErr != nil {
		logrus.Fatalln(decodeErr)
	}
}
package internal
import (
"encoding/json"
"fmt"
"github.com/gofrs/uuid"
"github.com/sirupsen/logrus"
"log"
"net"
"socket-asr-qianwen-tts/utils"
"time"
)
// ServerBaseConn is one client session: it ties the client's TCP connection to
// an ASR engine connection, a QWen client and a TTS connection.
type ServerBaseConn struct {
	// socketConn is the TCP connection to the client.
	socketConn net.Conn

	// ASR engine endpoint, hot-word file and the connection built from them.
	engineUrl, engineHotWordFile string
	engineConn                   *utils.AsrClientConn

	// QWen endpoint, API key and client.
	qWenUrl, qWenKey string
	qWenModel        *utils.QWenBaseConn

	// TTS endpoint and connection.
	ttsUrl   string
	ttsModel *utils.TtsBaseConn

	// resultChan queues the messages that writeClientLoop sends to the client.
	resultChan chan []byte
	// closeChan is closed by close() to stop the session goroutines.
	closeChan chan struct{}
	// isClose is set by close(); the loops read it without synchronization.
	isClose bool
}
// NewServerConn builds a session for the client connection s.
//
// It connects to the ASR engine (engineUrl, with engineHotWordFile as the
// hot-word/context file), creates the QWen client (qWenUrl, qWenKey) and
// connects to the TTS service (ttsUrl). On success the five session goroutines
// are started and the session is returned. On failure socket is nil, err
// describes the failing step and every backend opened so far is closed again;
// the caller remains responsible for closing s.
func NewServerConn(s net.Conn, engineUrl, engineHotWordFile string, qWenUrl, qWenKey string, ttsUrl string) (socket *ServerBaseConn, err error) {
	// NOTE(review): client authentication is currently disabled. The original
	// implementation is kept below for reference.
	//var (
	//	rdData = make([]byte, 10000)
	//	rdLen int
	//	rdStruck = struct {
	//		SecretId string `json:"secret_id"`
	//		Signature string `json:"signature"`
	//		Timestamp string `json:"timestamp"`
	//	}{}
	//)
	// Validate the client.
	//{
	//	if err = s.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
	//		return
	//	}
	//	if rdLen, err = s.Read(rdData); err != nil {
	//		return
	//	}
	//	if err = json.Unmarshal(rdData[:rdLen], &rdStruck); err != nil {
	//		err = errors.New("鉴权数据格式错误,请检查后重试")
	//		return
	//	}
	//	// Check the signature.
	//	if err = utils.CheckAuth(rdStruck.SecretId, rdStruck.Signature, rdStruck.Timestamp); err != nil {
	//		return
	//	}
	//}
	session := &ServerBaseConn{
		socketConn:        s,
		engineUrl:         engineUrl,
		engineHotWordFile: engineHotWordFile,
		qWenUrl:           qWenUrl,
		qWenKey:           qWenKey,
		ttsUrl:            ttsUrl,
		resultChan:        make(chan []byte, 256),
		closeChan:         make(chan struct{}, 1),
	}

	if session.engineConn, err = utils.NewAsrClientConn(engineUrl, engineHotWordFile); err != nil {
		return nil, err
	}
	logrus.Println("asr engine connection established")

	if session.qWenModel, err = utils.NewQWenConn(qWenUrl, qWenKey); err != nil {
		// Release the ASR connection (and its reader goroutine) opened above.
		_ = session.engineConn.Close()
		return nil, err
	}
	logrus.Println("qwen client created")

	if session.ttsModel, err = utils.NewTtsConn(ttsUrl); err != nil {
		_ = session.engineConn.Close()
		return nil, err
	}
	logrus.Println("tts connection established")

	go session.readClientLoop()
	go session.readEngineLoop()
	go session.writeClientLoop()
	go session.readQwenLoop()
	go session.readTtsLoop()
	return session, nil
}
// readClientLoop reads from the client socket and forwards every chunk to the
// ASR engine. A chunk that is exactly "paused" is sent as the end signal
// (WriteEngineMsg(1)); anything else is forwarded as audio (WriteEngineMsg(0)).
//
// The loop ends, and the session is closed, on the first read or forward error
// or once the session has been closed elsewhere. Previously a closed session
// made this goroutine sleep in 250ms steps forever instead of exiting.
func (socket *ServerBaseConn) readClientLoop() {
	buf := make([]byte, 10000)
	for !socket.isClose {
		n, err := socket.socketConn.Read(buf)
		if err != nil {
			break
		}
		// NOTE(review): debug output on every read; kept because it is the
		// only user of the "log" import in this file.
		log.Println("HHHHHHHHHHHHHHHHHHHHHH", n)

		if string(buf[:n]) == "paused" {
			err = socket.engineConn.WriteEngineMsg(1)
		} else {
			err = socket.engineConn.WriteEngineMsg(0, buf[:n])
		}
		if err != nil {
			break
		}
	}
	socket.close()
}
// readEngineLoop takes recognition results from the ASR connection, queues
// each one for the client and hands final results to the QWen model.
//
// The loop ends, and the session is closed, when the engine connection reports
// an error, when closeChan is closed, or when the session is already closed.
func (socket *ServerBaseConn) readEngineLoop() {
	for {
		if socket.isClose {
			time.Sleep(250 * time.Millisecond)
			break
		}
		msg, err := socket.engineConn.ReadEngineMsg()
		if err != nil {
			break
		}
		select {
		case socket.resultChan <- msg:
		case <-socket.closeChan:
			socket.close()
			return
		}
		// The message is passed as an argument: the goroutine must not share
		// the loop's variables, which are overwritten by the next iteration.
		go socket.forwardFinalResult(msg)
	}
	socket.close()
}

// forwardFinalResult sends the text of a final ASR result (is_final == 1 and a
// non-empty "results" string) to the QWen model. Other messages are ignored.
func (socket *ServerBaseConn) forwardFinalResult(msg []byte) {
	var asr map[string]interface{}
	if err := json.Unmarshal(msg, &asr); err != nil {
		logrus.Errorln(err)
		return
	}
	// Checked assertions: an unexpected payload is skipped instead of
	// panicking the whole process.
	if isFinal, ok := asr["is_final"].(float64); !ok || isFinal != 1 {
		return
	}
	logrus.Println(string(msg))
	question, ok := asr["results"].(string)
	if !ok || question == "" {
		return
	}
	// The answer id is the current time in milliseconds followed by "t".
	answerId := fmt.Sprintf("%vt", time.Now().UnixMilli())
	if err := socket.qWenModel.ReqQWen(question, answerId); err != nil {
		logrus.Errorln(err)
	}
}
// writeClientLoop writes to the client socket: first the "start" greeting that
// carries a freshly generated meeting id, then every message taken from
// resultChan. It closes the session when a write fails, when closeChan is
// closed, or when the session is already closed.
func (socket *ServerBaseConn) writeClientLoop() {
	greeted := false
	for {
		if socket.isClose {
			time.Sleep(250 * time.Millisecond)
			break
		}

		var payload []byte
		if greeted {
			select {
			case payload = <-socket.resultChan:
			case <-socket.closeChan:
				socket.close()
				return
			}
		} else {
			greeted = true
			meetingId, _ := uuid.NewV6()
			payload = []byte(fmt.Sprintf(`{"status":"ok","msg":"The connection was successful","meeting_id":"%v","type":"start"}`, meetingId.String()))
		}

		if _, writeErr := socket.socketConn.Write(payload); writeErr != nil {
			break
		}
	}
	socket.close()
}
// readQwenLoop takes answers from the QWen client, queues their JSON encoding
// for the client and then sends the answer text to the TTS connection.
// Encoding and TTS errors are logged and the loop carries on; the loop ends,
// closing the session, once closeChan is closed.
func (socket *ServerBaseConn) readQwenLoop() {
	for {
		var answer utils.Result
		select {
		case answer = <-socket.qWenModel.ResultChan:
		case <-socket.closeChan:
			socket.close()
			return
		}

		encoded, encodeErr := json.Marshal(answer)
		if encodeErr != nil {
			logrus.Errorln(encodeErr)
			continue
		}

		select {
		case socket.resultChan <- encoded:
		case <-socket.closeChan:
			socket.close()
			return
		}

		if ttsErr := socket.ttsModel.WriteMsg(answer.Answer, answer.AnswerId); ttsErr != nil {
			logrus.Errorln(ttsErr)
		}
	}
}
// readTtsLoop moves every message produced by the TTS connection onto
// resultChan. It ends, closing the session, once closeChan is closed.
func (socket *ServerBaseConn) readTtsLoop() {
	for {
		var frame []byte
		select {
		case frame = <-socket.ttsModel.ResultChan:
		case <-socket.closeChan:
			socket.close()
			return
		}

		select {
		case socket.resultChan <- frame:
		case <-socket.closeChan:
			socket.close()
			return
		}
	}
}
// close shuts the session down. On the first call it marks the session closed,
// closes the ASR connection, flushes the messages still queued in resultChan
// to the client and closes closeChan. Every call closes the TTS connection and
// the client socket.
//
// NOTE(review): every session goroutine calls close on exit and isClose is not
// synchronized, so two concurrent first calls could both reach
// close(closeChan). Guarding this needs a sync.Once/mutex field on the struct.
func (socket *ServerBaseConn) close() {
	if !socket.isClose {
		socket.isClose = true
		_ = socket.engineConn.Close()

		// Flush what was queued when close began. The count is taken once:
		// comparing a rising index against the shrinking len(resultChan), as
		// before, stopped after roughly half of the queue. The receive is
		// non-blocking because another goroutine may drain the queue too.
	flush:
		for pending := len(socket.resultChan); pending > 0; pending-- {
			select {
			case data := <-socket.resultChan:
				if _, err := socket.socketConn.Write(data); err != nil {
					// The client is gone; further writes would fail as well.
					break flush
				}
			default:
				break flush
			}
		}
		close(socket.closeChan)
	}
	_ = socket.ttsModel.Close()
	_ = socket.socketConn.Close()
}
File added
File added
## *socket技术接入文档*
## *socket技术接入文档*
### 一、说明
> 说明1:此接口是`socket`接口,可以通过`host`进行连接。在使用过程中尽可能保持长连接状态以完成业务闭环。
> 说明2:此接口的请求的数据量理论上比响应的数据量少得多,为避免阻塞或读取数据不完全的情况,强烈要求使用读写异步处理的模式,也可以根据实际情况使用两个线程分别处理读和写这两个接口。
> 说明3:此接口根据实际情况已经包含了三个模块的功能,分别是`ASR`、`阿里通义千问大模型`和`TTS`。具体业务闭环如下:`ASR`正式结果 --》 `阿里通义千问大模型` --》 流式`TTS`。
### 二、处理步骤
> 说明1:此接口目前只接收三种类型的数据,分别是`json`、`string`、`byte`。
> 说明2:`json`数据里面包含鉴权所需的信息,用于连接有效性鉴定;`byte`这种类型表示二进制音频数据,既可以是实时采集的,也可以是从录音文件读取的,而且要求音频的格式为`pcm`或`wav`,同时要求采样频率为`8000Hz`、单声道;`string`这种数据类型的取值为`paused`,表示当前阶段业务数据已经传输结束,要求服务端及时返回语音识别的正式结果,以便完成业务闭环。
> 说明3:这三种数据类型中,第一帧数据必须是鉴权信息`json`,之后才可以是`byte`和`string`。
> 说明4:接口业务处理逻辑如下:将实时语音转文字得到的正式结果作为问题,向`阿里通义千问大模型`进行带上下文的提问;将`阿里通义千问大模型`返回的结果通过`TTS`流式响应给客户端。整个流程中客户端将收到三种类型的数据:实时语音转文字结果(`type`=`real_asr`标识)、大模型问答结果(`type`=`model_answer`标识)和流式`tts`结果(`type`=`tts_info`标识)。
#### 1、鉴权
| 参数 | 是否必要 | 说明 |
|:------|:---|:---|
| `secret_id` | `yes` | 鉴权码的`secret_id` |
| `signature` | `yes` | 根据规则生成的签名 |
| `timestamp` | `yes` | 参与签名生成的时间戳,精确到秒 |
##### (1)、`signature`生成规则
- 签名生成
i. 由参数`secretID`和`timestamp`拼接生成`baseString`,后面的`MD5`加密和`HMACSHA1`都会以这个字符串`baseString`为基础。
```
secretID={提供的secretID值}&timestamp={参数timestamp的值}
```
ii. 将上面生成的`baseString`进行`MD5`加密,生成小写的字符串`MD5String`
iii. 以授予的一组唯一性鉴权码中的`api_key`作为`HMACSHA1`算法的`key`,对上面`MD5`加密生成的`MD5String`进行加密,生成`signature`
- 生成签名和接口地址示例
**假如授予的一组唯一性鉴权码的`secret_id`=`raisound123456789`、`api_key`=`123456789`,当前时间戳`timestamp`=`1691139410`**
i. 生成`basestring`
```
secretID=raisound123456789&timestamp=1691139410
```
ii. `MD5`加密生成的`MD5String`
```
6feec50a597ba4ae269acc14bfd50fb6
```
iii. 以授予的一组唯一性鉴权码中的`api_key`作为`HMACSHA1`算法的`key`,加密生成的`signature`
```
836ad483521ed97c3fd0cb5200a33130c22b00bf
```
##### (2)、例子
- 请求
```
{
"secret_id": "Q4NPLBLyidvRc4TAUByUhcVfRFM1kQBi",
"signature": "942e7441e80128de6ed36d986820fb102c7b6f58",
"timestamp": "1698721657"
}
```
- 响应
```
{
"status":"ok",
"msg":"The connection was successful",
"meeting_id":"1ee779af-f5cb-66f6-a7b6-a9e5331bb284",
"type":"start"
}
```
#### 2、业务数据
直接传业务数据,即二进制音频数据,每一帧的长度不大于6400。
- 响应例子
- 实时语音转文字(`type`=`real_asr`)---根据`is_final`区分临时结果(`is_final`=`0`)和正式结果(`is_final`=`1`)
- 临时结果
```
{
"is_final":0,
"results":"你",
"status":"ok",
"timestamp":{"end":1260,"start":540},
"type":"real_asr"
}
```
- 正式结果
```
{
"is_final":1,
"results":"你好你好现在是北京时间",
"status":"ok",
"timestamp":{"end":5480,"start":540},
"type":"real_asr"
}
```
- 大模型问答(`type`=`model_answer`)
> 说明:参数`answer_id`会与流式`TTS`的`answer_id`保持一致,以便数据保持统一性。
```
{
"type":"model_answer",
"status":"ok",
"answer_id":"1698721812377t",
"answer":"你好!现在北京时间是2023年2月24日 13:18。",
"question":"你好你好现在是北京时间"
}
```
- 流式`TTS`(`type`=`tts_info`)
> 根据参数`index`确定开始(`index`=`0`)、结束(`index`=`-1`)以及各帧的顺序。`audio`字段是经过`base64`编码的TTS音频数据。
- 开始
```
{
"answer_id":"1698721812377t",
"audio":"6P8OAPn/AQAJANn/0//u/wkACQDm/8z/1/8IAAgA/v/r/9v/z/+9/7P/qf+8/+X/7f/6/w8ADAAO......",
"index":0,
"type":"tts_info"
}
```
- 结束
```
{
"answer_id":"1698721812377t",
"audio":"",
"index":-1,
"type":"tts_info"
}
```
#### 3、中止标识
> 直接发送字符串`paused`。
package utils
import (
"encoding/json"
"errors"
"fmt"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"log"
"sync"
"time"
)
// AsrClientConn is a websocket client for the ASR engine. It re-dials the
// engine when the connection breaks and publishes normalized results on
// resultChan.
type AsrClientConn struct {
	// url is the engine websocket URL; hotWordFile is sent as "context_path".
	url, hotWordFile string
	// engineConn is the current websocket connection to the engine.
	engineConn *websocket.Conn

	// resultChan carries the JSON results returned by ReadEngineMsg.
	resultChan chan []byte
	// closeChan is closed by Close.
	closeChan chan struct{}
	// resetChan is allocated by the constructor; no code visible here uses it.
	resetChan chan int

	// isClose is set by Close.
	isClose bool
	// isReset is set while a reconnect runs and while an end signal is pending.
	isReset bool
	// resetTimes counts consecutive failed connection attempts.
	resetTimes int
	// lock serializes reconnect/close; it is also held between sending the end
	// signal and receiving "speech_end".
	lock sync.Mutex
}
// NewAsrClientConn connects to the ASR engine at url, using hotWordFile as the
// context path of the start message, and starts the goroutine that reads the
// engine's results. It returns a nil connection and the error when the engine
// cannot be reached.
func NewAsrClientConn(url, hotWordFile string) (conn *AsrClientConn, err error) {
	client := &AsrClientConn{
		url:         url,
		hotWordFile: hotWordFile,
		resultChan:  make(chan []byte, 256),
		closeChan:   make(chan struct{}, 1),
		resetChan:   make(chan int, 1),
	}

	ws, dialErr := client.__getConn()
	if dialErr != nil {
		return nil, dialErr
	}
	client.engineConn = ws

	go client.readEngineLoop()
	return client, nil
}
// readEngineLoop reads messages from the engine websocket, converts
// recognition results into "real_asr" JSON messages and publishes them on
// resultChan. It calls conn.Close() when it leaves through the ERR label.
//
// NOTE(review): the type assertions on "status", "type", "nbest" and
// "sentence" are unchecked, so an unexpected engine payload panics.
func (conn *AsrClientConn) readEngineLoop() {
var (
err error
rdData []byte
val interface{}
ok bool
)
for {
var (
rdMap = make(map[string]interface{})
sendData []byte
)
if !conn.isClose {
if _, rdData, err = conn.engineConn.ReadMessage(); err != nil {
logrus.Errorln(err)
if !conn.isClose {
// Reconnect unless a reset is already flagged, then wait two seconds
// and read again.
if !conn.isReset {
if err = conn.__resetEngine(); err != nil {
goto ERR
}
}
time.Sleep(2 * time.Second)
continue
} else {
goto ERR
}
}
log.Println("engine", string(rdData))
if err = json.Unmarshal(rdData, &rdMap); err != nil {
logrus.Errorln(err)
time.Sleep(100 * time.Millisecond)
continue
}
if val, ok = rdMap["status"]; ok {
if val.(string) == "ok" {
if val, ok = rdMap["type"]; ok {
switch val.(string) {
case "wait", "server_ready", "reload_end":
// Messages of these types are not forwarded.
continue
case "partial_result", "final_result":
// Use the first n-best sentence as "results" and mark the message as
// not final (is_final = 0). Empty results are dropped.
bestFormat := rdMap["nbest"].([]interface{})
if len(bestFormat) > 0 {
rdMap["results"] = bestFormat[0].(map[string]interface{})["sentence"].(string)
} else {
rdMap["results"] = ""
}
if rdMap["results"].(string) == "" {
rdMap = nil
continue
}
rdMap["is_final"] = 0
bestFormat = nil
delete(rdMap, "nbest")
case "final_result_second_pass":
// Same conversion as above, but the message is marked final (is_final = 1).
log.Println(rdMap)
bestFormat := rdMap["nbest"].([]interface{})
if len(bestFormat) > 0 {
rdMap["results"] = bestFormat[0].(map[string]interface{})["sentence"].(string)
} else {
rdMap["results"] = ""
}
if rdMap["results"].(string) == "" {
rdMap = nil
continue
}
rdMap["is_final"] = 1
bestFormat = nil
delete(rdMap, "nbest")
case "speech_end":
// "speech_end" is only acted on while isReset is set, which
// WriteEngineMsg(1) does after sending the end signal. Otherwise the
// loop stops.
if !conn.isReset || conn.isClose {
goto ERR
}
_ = conn.engineConn.Close()
conn.isReset = false
// Releases the lock that WriteEngineMsg(1) took in another goroutine;
// __resetEngine then takes it again to reconnect.
conn.lock.Unlock()
if err = conn.__resetEngine(); err != nil {
goto ERR
}
rdMap = nil
continue
default:
logrus.Errorln(rdMap)
time.Sleep(100 * time.Millisecond)
rdMap = nil
continue
}
rdMap["type"] = "real_asr"
if sendData, err = json.Marshal(rdMap); err != nil {
logrus.Errorln(err)
rdMap = nil
// NOTE(review): this return leaves the loop without going through ERR,
// so conn.Close() is not called on this path.
return
}
select {
case conn.resultChan <- sendData:
case <-conn.closeChan:
goto ERR
}
sendData = nil
rdMap = nil
time.Sleep(50 * time.Millisecond)
} else {
logrus.Errorln(rdMap)
time.Sleep(100 * time.Millisecond)
rdMap = nil
continue
}
} else {
logrus.Errorln(rdMap)
time.Sleep(100 * time.Millisecond)
rdMap = nil
continue
}
}
} else {
goto ERR
}
}
ERR:
_ = conn.Close()
}
// ReadEngineMsg blocks until the next recognition result is available and
// returns it. Once the connection has been closed it returns an error instead.
func (conn *AsrClientConn) ReadEngineMsg() (result []byte, err error) {
	select {
	case <-conn.closeChan:
		return nil, errors.New("engine is close")
	case msg := <-conn.resultChan:
		return msg, nil
	}
}
// WriteEngineMsg sends a message to the ASR engine.
//
// dataType 0 sends data[0] as a binary (audio) frame. dataType 1 sends the
// end-of-speech signal, flags the connection as resetting and takes conn.lock;
// the lock is released by readEngineLoop when the engine answers with
// "speech_end". Any other dataType is a no-op.
//
// When a write fails the engine connection is re-established (unless a reset
// is already flagged) and the message is NOT re-sent; the result of the
// reconnect is returned. An error is returned when the connection is closed
// or when dataType 0 is used without a payload.
func (conn *AsrClientConn) WriteEngineMsg(dataType int, data ...[]byte) (err error) {
	if conn.isReset {
		// Wait for the reset in progress to finish.
		conn.lock.Lock()
		conn.lock.Unlock()
	}
	if conn.isClose {
		return errors.New("engine is closed")
	}

	switch dataType {
	case 0:
		if len(data) == 0 {
			// Previously data[0] panicked when no payload was passed.
			return errors.New("audio payload is required")
		}
		if err = conn.engineConn.WriteMessage(websocket.BinaryMessage, data[0]); err != nil {
			return conn.handleWriteError(err)
		}
	case 1:
		if err = conn.engineConn.WriteMessage(websocket.TextMessage, []byte(`{"signal":"end"}`)); err != nil {
			// The end signal was not delivered, so no "speech_end" will
			// arrive to release the lock: return without taking it.
			return conn.handleWriteError(err)
		}
		conn.isReset = true
		conn.lock.Lock()
	}
	return nil
}

// handleWriteError maps a failed engine write to the error reported to the
// caller: a closed connection yields "engine is closed", a write error outside
// a reset triggers a reconnect whose result is returned, and a write error
// during a reset is returned unchanged.
func (conn *AsrClientConn) handleWriteError(writeErr error) error {
	if conn.isClose {
		return errors.New("engine is closed")
	}
	if !conn.isReset {
		return conn.__resetEngine()
	}
	return writeErr
}
// __getConn dials the ASR engine and sends the start message. A failed attempt
// is retried after 500ms while conn.resetTimes has not exceeded 10; the
// counter is reset to zero on success. On failure it returns a nil connection
// and the last error.
func (conn *AsrClientConn) __getConn() (wsConn *websocket.Conn, err error) {
	// Encode the context path as a JSON string so that backslashes or quotes
	// in the path cannot corrupt the start message. Marshalling a string never
	// fails, hence the ignored error.
	contextPath, _ := json.Marshal(conn.hotWordFile)
	loadText := fmt.Sprintf(`{"signal": "start", "continuous_decoding": %v, "nbest": 10, "enable_vad": %v, "chunk_size": 12, "two_pass": %v, "context_update":%v,"context_score": 5.0, "context_path": %s,"input_sr" : 8000}`, true, true, true, true, contextPath)

	// Iterative retry: same attempt budget as the former recursion without
	// growing the stack.
	for {
		if wsConn, _, err = websocket.DefaultDialer.Dial(conn.url, nil); err == nil {
			logrus.Println(loadText)
			if err = wsConn.WriteMessage(websocket.TextMessage, []byte(loadText)); err == nil {
				conn.resetTimes = 0
				return wsConn, nil
			}
			// The socket is unusable; close it instead of leaking it.
			_ = wsConn.Close()
		}
		if conn.resetTimes > 10 {
			return nil, err
		}
		time.Sleep(500 * time.Millisecond)
		conn.resetTimes++
	}
}
// __resetEngine closes the current engine connection and replaces it with a
// fresh one while holding conn.lock. isReset is set for the duration of the
// reconnect. It returns an error when the connection is already closed or when
// reconnecting fails.
func (conn *AsrClientConn) __resetEngine() (err error) {
	conn.lock.Lock()
	defer conn.lock.Unlock()

	if conn.isClose {
		return errors.New("engine connection is closed")
	}

	conn.isReset = true
	// Runs before the deferred Unlock, as in the original ordering.
	defer func() { conn.isReset = false }()

	if conn.engineConn != nil {
		_ = conn.engineConn.Close()
	}
	fresh, dialErr := conn.__getConn()
	if dialErr != nil {
		// Keep the old (closed) connection: later reads and writes fail with
		// an error instead of dereferencing a nil connection.
		if fresh != nil {
			_ = fresh.Close()
		}
		return dialErr
	}
	conn.engineConn = fresh
	return nil
}
// Close marks the connection closed, closes the engine websocket and closes
// closeChan. Calls after the first are no-ops. It always returns nil.
//
// NOTE(review): conn.lock is also held between WriteEngineMsg(1) and the
// engine's "speech_end"; Close blocks for that long if called in between.
func (conn *AsrClientConn) Close() (err error) {
	conn.lock.Lock()
	defer conn.lock.Unlock()

	if conn.isClose {
		return nil
	}
	conn.isClose = true
	// engineConn can be nil after a failed reconnect; guard against the nil
	// dereference.
	if conn.engineConn != nil {
		_ = conn.engineConn.Close()
	}
	close(conn.closeChan)
	return nil
}
package utils
import (
"bytes"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"sync"
)
// reqBody is the JSON body of a QWen request.
type reqBody struct {
	// Model is the model name.
	Model string `json:"model"`
	// Input carries the prompt and the conversation history.
	Input struct {
		Prompt  string        `json:"prompt"`
		History []historyData `json:"history"`
	} `json:"input"`
}

// respBody is the subset of the QWen response that is decoded.
type respBody struct {
	// Output has no JSON tag; it is matched by field name.
	Output struct {
		FinishReason string `json:"finish_reason"`
		Text         string `json:"text"`
	}
	RequestId string `json:"request_id"`
}

// historyData is one question/answer pair of the conversation history.
type historyData struct {
	User string `json:"user"`
	Bot  string `json:"bot"`
}
// QWenBaseConn is a client for the QWen HTTP API that keeps the conversation
// history and publishes answers on ResultChan.
type QWenBaseConn struct {
	// url is the request endpoint; key is sent as the Bearer token.
	url, key string
	// lock serializes ReqQWen calls.
	lock sync.Mutex
	// historySlice holds the answered question/answer pairs sent as history.
	historySlice []historyData

	// ResultChan delivers the answers.
	ResultChan chan Result
	// closeChan is selected on by writer; the constructor does not allocate it.
	closeChan chan struct{}
}
// Result is the answer message published on QWenBaseConn.ResultChan.
type Result struct {
	// Type and Status are set to "model_answer" and "ok" by writer.
	Type   string `json:"type"`
	Status string `json:"status"`

	// AnswerId is the id supplied by the caller of ReqQWen.
	AnswerId string `json:"answer_id"`
	// Answer is the answer text; Question is the prompt that produced it.
	Answer   string `json:"answer"`
	Question string `json:"question"`
}
// NewQWenConn creates a QWen client for the endpoint reqUrl authenticated with
// key. No network access happens here and the returned error is always nil.
//
// NOTE(review): closeChan is left nil, so the close branch in writer can never
// fire.
func NewQWenConn(reqUrl string, key string) (conn *QWenBaseConn, err error) {
	client := &QWenBaseConn{
		url: reqUrl,
		key: key,
		// Non-nil on purpose: an empty history is encoded as [] rather than null.
		historySlice: make([]historyData, 0),
		ResultChan:   make(chan Result, 32),
	}
	return client, nil
}
// appendHistory adds one question/answer pair to the end of the conversation
// history.
func (conn *QWenBaseConn) appendHistory(data historyData) {
	extended := append(conn.historySlice, data)
	conn.historySlice = extended
}
// writer publishes an answer on ResultChan, blocking until it is accepted. It
// returns an error if closeChan is signalled first.
func (conn *QWenBaseConn) writer(data string, question string, anserId string) (err error) {
	answer := Result{
		Type:     "model_answer",
		Status:   "ok",
		AnswerId: anserId,
		Answer:   data,
		Question: question,
	}
	select {
	case conn.ResultChan <- answer:
		return nil
	case <-conn.closeChan:
		return errors.New("close")
	}
}
// ReqQWen asks the QWen model the question text, together with the
// conversation history, and publishes the answer tagged with answerId on
// ResultChan.
//
// A non-empty answer is also appended to the history. An empty answer is
// replaced by a fixed fallback sentence and is not recorded. Calls are
// serialized by conn.lock. An error is returned when the request cannot be
// built or sent, or when the response cannot be read or decoded; nothing is
// published in that case.
func (conn *QWenBaseConn) ReqQWen(text string, answerId string) (err error) {
	// 60 seconds, written as an untyped nanosecond constant because this file
	// does not import "time". Without a timeout a stalled request would hold
	// conn.lock forever and block every later question.
	const requestTimeout = 60_000_000_000

	conn.lock.Lock()
	defer conn.lock.Unlock()

	var payload reqBody
	payload.Model = "qwen-v1"
	payload.Input.Prompt = text
	payload.Input.History = conn.historySlice

	reqData, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	req, err := http.NewRequest("POST", conn.url, bytes.NewBuffer(reqData))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+conn.key)

	client := http.Client{Timeout: requestTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	// Registered right after the error check so the body is also closed when
	// reading it fails.
	defer func() { _ = resp.Body.Close() }()

	respData, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	var parsed respBody
	if err = json.Unmarshal(respData, &parsed); err != nil {
		return err
	}

	// writer only fails once the client is closed, in which case there is
	// nobody left to receive the answer; its error is deliberately ignored.
	if parsed.Output.Text == "" {
		_ = conn.writer("由于我是AI大模型,我无法回答你的问题,请换个别的问题。", text, answerId)
		return nil
	}
	_ = conn.writer(parsed.Output.Text, text, answerId)
	conn.appendHistory(historyData{User: text, Bot: parsed.Output.Text})
	return nil
}
package utils
import (
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"encoding/hex"
"errors"
"fmt"
"strconv"
"time"
)
// NOTE(review): these credentials are compiled into the binary and committed
// to version control; they should come from configuration and be rotated.
const (
	// secret is the only accepted secret_id.
	secret = "Q4NPLBLyidvRc4TAUByUhcVfRFM1kQBi"
	// api is the HMAC-SHA1 key used to sign requests.
	api = "P7vNsYTv"
)

// CheckAuth validates the authentication triple sent by a client.
//
// secretId must equal the configured secret id. timestamp must be a decimal
// Unix time in seconds that is not in the future and not older than 600
// seconds. signature must be the lowercase hex HMAC-SHA1, keyed with the API
// key, of the lowercase hex MD5 of "secretID=<secretId>&timestamp=<timestamp>".
//
// It returns nil when all checks pass and a descriptive error otherwise.
func CheckAuth(secretId, signature, timestamp string) (err error) {
	if secretId != secret {
		return errors.New("参数`secret_id`的值有误")
	}

	ts, parseErr := strconv.ParseInt(timestamp, 10, 64)
	if parseErr != nil {
		return errors.New("参数`timestamp`格式有误")
	}
	// Compare against now-600 rather than computing now-ts, which could
	// overflow for an extreme timestamp.
	if now := time.Now().Unix(); ts > now || ts < now-600 {
		return errors.New("签名无效,请重新生成")
	}

	// MD5 of the base string, hex encoded.
	digest := md5.Sum([]byte(fmt.Sprintf("secretID=%v&timestamp=%v", secretId, timestamp)))

	// HMAC-SHA1 of the MD5 hex string, hex encoded.
	mac := hmac.New(sha1.New, []byte(api))
	mac.Write([]byte(hex.EncodeToString(digest[:])))
	expected := hex.EncodeToString(mac.Sum(nil))

	// Constant-time comparison so the check does not leak how many leading
	// characters of the signature were correct.
	if !hmac.Equal([]byte(signature), []byte(expected)) {
		return errors.New("签名错误,请重新生成")
	}
	return nil
}
package utils
import (
"encoding/json"
"errors"
"fmt"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"sync"
"time"
)
// TtsBaseConn is a websocket client for the TTS service. It publishes the
// service's responses as "tts_info" JSON messages on ResultChan.
type TtsBaseConn struct {
	// url is the TTS websocket URL; wsConn is the current connection.
	url    string
	wsConn *websocket.Conn

	// idx is the index given to the next audio chunk of the current answer.
	idx int
	// ResultChan delivers the JSON messages built by readMsg.
	ResultChan chan []byte
	// answerIdChan queues the answer ids passed to WriteMsg; answerId is the
	// id stamped on outgoing messages.
	answerIdChan chan string
	answerId     string

	// resetTimes counts consecutive failed connection attempts.
	resetTimes int
	// isReset is set while a reconnect is running, guarded by lock.
	isReset bool
	lock    sync.Mutex

	// isClose is set by Close.
	isClose bool
	// closeChan is allocated by the constructor; no code visible here uses it.
	closeChan chan struct{}
}
// NewTtsConn connects to the TTS service at url and starts the goroutine that
// reads its responses. It returns a nil connection and the error when the
// service cannot be reached.
func NewTtsConn(url string) (tts *TtsBaseConn, err error) {
	client := &TtsBaseConn{
		url:          url,
		ResultChan:   make(chan []byte, 256),
		answerIdChan: make(chan string, 128),
		closeChan:    make(chan struct{}, 1),
	}

	ws, dialErr := client.__geConn()
	if dialErr != nil {
		return nil, dialErr
	}
	client.wsConn = ws

	go client.readMsgLoop()
	return client, nil
}
// __geConn dials the TTS service and performs the start handshake. A failed
// attempt is retried after 500ms while tts.resetTimes has not exceeded 10; the
// counter is reset to zero on success. On failure it returns a nil connection
// and the last error.
//
// Previously a handshake rejected on the last retry returned the connection
// with a nil error, and sockets of failed attempts were never closed.
func (tts *TtsBaseConn) __geConn() (conn *websocket.Conn, err error) {
	for {
		if conn, err = tts.dialAndHandshake(); err == nil {
			tts.resetTimes = 0
			return conn, nil
		}
		if tts.resetTimes > 10 {
			return nil, err
		}
		time.Sleep(500 * time.Millisecond)
		tts.resetTimes++
	}
}

// dialAndHandshake opens one websocket connection, sends the start signal and
// checks the reply. The connection is closed again when any step fails.
func (tts *TtsBaseConn) dialAndHandshake() (*websocket.Conn, error) {
	conn, _, err := websocket.DefaultDialer.Dial(tts.url, nil)
	if err != nil {
		return nil, err
	}

	// NOTE(review): the key is spelled "singal" everywhere in this client;
	// presumably the service expects that spelling — confirm before changing.
	err = conn.WriteMessage(websocket.TextMessage, []byte(`{"singal": "start"}`))
	if err == nil {
		var reply []byte
		if _, reply, err = conn.ReadMessage(); err == nil {
			recvMap := make(map[string]interface{})
			if err = json.Unmarshal(reply, &recvMap); err == nil {
				// A reply without "status" is accepted, as before; a status
				// other than the string "ok" is a rejection.
				if status, present := recvMap["status"]; present {
					if text, isString := status.(string); !isString || text != "ok" {
						err = fmt.Errorf("tts handshake rejected: status %v", status)
					}
				}
			}
		}
	}
	if err != nil {
		_ = conn.Close()
		return nil, err
	}
	return conn, nil
}
// __recall re-establishes the TTS connection while holding tts.lock. isReset
// is set for the duration of the reconnect. It returns an error when the
// client is already closed or when reconnecting fails.
func (tts *TtsBaseConn) __recall() (err error) {
	tts.lock.Lock()
	defer tts.lock.Unlock()

	if tts.isClose {
		return errors.New("tts is closed")
	}
	if tts.isReset {
		return nil
	}

	tts.isReset = true
	// Runs before the deferred Unlock, as in the original ordering.
	defer func() { tts.isReset = false }()

	fresh, dialErr := tts.__geConn()
	if dialErr != nil {
		// Keep the previous connection so later reads and writes fail with an
		// error instead of dereferencing a nil connection.
		if fresh != nil {
			_ = fresh.Close()
		}
		return dialErr
	}
	// Close the connection being replaced; it used to be leaked.
	if tts.wsConn != nil {
		_ = tts.wsConn.Close()
	}
	tts.wsConn = fresh
	return nil
}
// WriteMsg asks the TTS service to synthesize txt and records answerId as the
// id of the answer being spoken.
//
// When the write fails the connection is re-established (unless a reconnect is
// already running) and the write is retried after 100ms. An error is returned
// when the client is closed or when reconnecting fails.
func (tts *TtsBaseConn) WriteMsg(txt, answerId string) (err error) {
	// Encode the text as a JSON string: model answers routinely contain
	// quotes and newlines, which corrupted the request when spliced in
	// verbatim. Marshalling a string never fails, hence the ignored error.
	quotedText, _ := json.Marshal(txt)
	signal := []byte(fmt.Sprintf(`{"singal" : "text", "text" : %s, "voice_name" : "xiaomu", "streaming": %v, "sample_rate": %v, "speed": %v}`, quotedText, true, 8000, float64(1)))

	// Iterative retry replaces the former recursion.
	for {
		if tts.isClose {
			return errors.New("tts is close")
		}
		if err = tts.wsConn.WriteMessage(websocket.TextMessage, signal); err == nil {
			break
		}
		if tts.isClose {
			return errors.New("tts is closed")
		}
		if !tts.isReset {
			if err = tts.__recall(); err != nil {
				return err
			}
		}
		time.Sleep(100 * time.Millisecond)
	}

	// NOTE(review): this send blocks once 128 ids are queued.
	tts.answerIdChan <- answerId
	tts.answerId = answerId
	return nil
}
// readMsg reads TTS responses until one can be converted and returns it as a
// "tts_info" JSON message.
//
// Status "internal" yields an audio chunk with a rising index. Status "end" or
// "error" yields the terminating message (index -1, empty audio) and resets
// the index. Frames that cannot be decoded or that carry another status are
// logged and skipped. On a read error the connection is re-established and the
// read retried. An error is returned when the client is closed or when
// reconnecting fails.
func (tts *TtsBaseConn) readMsg() (data []byte, err error) {
	// A loop replaces the former recursion, which grew the stack with every
	// skipped frame and every retry.
	for {
		if tts.isClose {
			return nil, errors.New("tts is close")
		}

		_, recvData, readErr := tts.wsConn.ReadMessage()
		if readErr != nil {
			if tts.isClose {
				return nil, errors.New("tts is closed")
			}
			if !tts.isReset {
				if err = tts.__recall(); err != nil {
					return nil, err
				}
			}
			time.Sleep(100 * time.Millisecond)
			continue
		}

		var recvMap map[string]interface{}
		if err = json.Unmarshal(recvData, &recvMap); err != nil {
			// Skip the frame. Returning it as a nil payload, as before, pushed
			// an empty message through to the client.
			logrus.Errorln(err)
			continue
		}

		// Checked assertion: a missing or non-string status takes the default
		// branch instead of panicking.
		status, _ := recvMap["status"].(string)
		result := map[string]interface{}{
			"type":      "tts_info",
			"answer_id": tts.answerId,
		}
		switch status {
		case "end", "error":
			tts.idx = 0
			result["index"] = -1
			result["audio"] = ""
			// NOTE(review): this goroutine updates answerId without
			// synchronization and waits forever if no id is queued.
			go func() {
				tts.answerId = <-tts.answerIdChan
			}()
		case "internal":
			audio, _ := recvMap["audio"].(string)
			result["index"] = tts.idx
			result["audio"] = audio
			tts.idx++
		default:
			logrus.Errorln(recvMap)
			continue
		}
		return json.Marshal(result)
	}
}
// readMsgLoop publishes every message returned by readMsg on ResultChan until
// readMsg reports an error, then closes the client.
func (tts *TtsBaseConn) readMsgLoop() {
	for {
		frame, readErr := tts.readMsg()
		if readErr != nil {
			break
		}
		tts.ResultChan <- frame
	}
	tts.Close()
}
// Close marks the client closed, sends the end signal to the TTS service and
// closes the websocket. Calls after the first are no-ops. It returns the first
// error encountered: the end-signal write error, otherwise the close error.
func (tts *TtsBaseConn) Close() (err error) {
	if tts.isClose {
		return nil
	}
	tts.isClose = true
	if tts.wsConn == nil {
		return nil
	}

	err = tts.wsConn.WriteMessage(websocket.TextMessage, []byte(`{"singal" : "end"}`))
	// Always close the socket. Returning early when the end signal could not
	// be written, as before, leaked the connection.
	if closeErr := tts.wsConn.Close(); err == nil {
		err = closeErr
	}
	return err
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment