chat_huggingface 函数功能:

1.任务规划:使用ChatGPT分析用户的请求,了解他们的意图,并将其拆解成可能的可解决任务。

2.模型选择:为解决计划任务,ChatGPT 根据描述选择托管在 Hugging Face 上的专家模型。

3.任务执行:调用并执行每个选定的模型,并将结果返回给 ChatGPT。

4.生成结果: 最后使用ChatGPT整合所有模型的预测,生成response。

函数源码注释如下:

def chat_huggingface(messages):
    start = time.time()
    context = messages[:-1]             #messages列表数据赋值给context
    input = messages[-1]["content"]     #将用户最后一个输入的数据赋值给input
    logger.info("*"*80)
    logger.info(f"input: {input}")

    #1.解析任务:使用ChatGPT分析用户的请求,了解他们的意图,并将其拆解成可能的可解决任务
    task_str = parse_task(context, input).strip()

    #如果是用LLM回复空任务或者对话任务 记录日志,调用openai接口,返回结果
    if task_str == "[]" or "conversational" in task_str:  # using LLM response for empty task or conversational task 
        record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"})  #记录log文件
        response = chitchat(messages)   #调用openai接口,返回结果
        return {"message": response}

    try:
        logger.info(task_str)
        tasks = json.loads(task_str)    #将josn格式的数据转换成python对象
    except Exception as e:
        logger.debug(e)
        response = chitchat(messages)   #调用openai接口,返回结果
        record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op":"chitchat"})
        return {"message": response}

    results = {}
    processes = []      #存储启动的进程号
    tasks = tasks[:]    #复制整个tasks列表
    #3.启动多进程处理任务
    with  multiprocessing.Manager() as manager:
        d = manager.dict()  #用manager创建字典d,可以在其他进程中修改数据
        retry = 0
        while True:
            num_process = len(processes)    # num_process开始为0
            for task in tasks:
                #没有依赖前置资源的任务task["dep"]设置为-1
                if not resource_has_dep(task):
                    task["dep"] = [-1]
                dep = task["dep"]
                
                # 没有依赖前置资源的任务,开始执行该任务,并从任务列表中删除该任务
                if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:  #set(dep).intersection(d.keys()) 返回set(dep)和d.keys()两个集合的交集
                    tasks.remove(task)
                    #执行任务的进程
                    process = multiprocessing.Process(target=run_task, args=(input, task, d))   #此处进程会更新字典d的数据
                    process.start()
                    processes.append(process)   #processes是存储启动的进程列表
            if num_process == len(processes):   #这一步的意图不明确???
                time.sleep(0.5)
                retry += 1
            if retry > 160: #用户等待时间过长终止while循环
                logger.debug("User has waited too long, Loop break.")
                break
            if len(tasks) == 0: #任务列表没有任务时终止循环
                break
        #阻塞主进程,等待所有子进程执行完毕,再继续执行主进程
        for process in processes:
            process.join()  
        
        #此时字典d的结果都易经被所有子进程更新完毕,是最终结果了
        results = d.copy()
    logger.debug(results)
    #4.生成结果:最后使用ChatGPT整合所有模型的预测,生成response。
    response = response_results(input, results).strip()     

    #计算运行时间
    end = time.time()
    during = end - start

    #记录成功日志,返回结果
    answer = {"message": response}
    record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op":"response"})
    logger.info(f"response: {response}")
    return answer

1、调用 parse_task函数, 解析任务:使用ChatGPT分析用户的请求,了解他们的意图,并将其拆解成可能的可解决任务

parse_task函数源码注释如下:

def parse_task(context, input):

    demos_or_presteps = parse_task_demos_or_presteps    #看配置文件读取的是parse_task: demos/demo_parse_task.json
    messages = json.loads(demos_or_presteps)            #将demo_parse_task.json的数据转换为python类型存入messages
    messages.insert(0, {"role": "system", "content": parse_task_tprompt})   # messages首位插入数据

    # cut chat logs 切割聊天记录
    start = 0
    while start <= len(context):
        history = context[start:]
        #格式化一下提示词prompt和history数据
        prompt = replace_slot(parse_task_prompt, {
            "input": input,
            "context": history 
        })
        #将数据添加到messages列表中
        messages.append({"role": "user", "content": prompt})
        #格式化messages中prompt的内容
        history_text = "<im_end>\nuser<im_start>".join([m["content"] for m in messages])
        #获取token的最大长度
        num = count_tokens(LLM_encoding, history_text)
        if get_max_context_length(LLM) - num > 800:
            break
        messages.pop()
        start += 2
    
    logger.debug(messages)
    data = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "logit_bias": {item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids}
    }
    #请求openai的接口
    return send_request(data)

2、调用 record_case 函数,将记录写入log文件,成功写入log_success.jsonl,失败写入log_fail.jsonl

record_case 函数,源码如下:

def record_case(success, **args):
    if success:
        f = open("log_success.jsonl", "a")
    else:
        f = open("log_fail.jsonl", "a")
    log = args
    f.write(json.dumps(log) + "\n")
    f.close()

3、调用replace_slot函数,源码如下:

def replace_slot(text, entries):
    #遍历entries字典,将history字段格式化成字符串
    for key, value in entries.items():
        if not isinstance(value, str):
            value = str(value)
        text = text.replace("{{" + key +"}}", value.replace('"', "'").replace('\n', ""))
    return text

4、调用 count_tokens 和 get_max_context_length 函数,源码如下: 

encodings = {
    "gpt-3.5-turbo": tiktoken.get_encoding("cl100k_base"),
    "gpt-3.5-turbo-0301": tiktoken.get_encoding("cl100k_base"),
    "text-davinci-003": tiktoken.get_encoding("p50k_base"),
    "text-davinci-002": tiktoken.get_encoding("p50k_base"),
    "text-davinci-001": tiktoken.get_encoding("r50k_base"),
    "text-curie-001": tiktoken.get_encoding("r50k_base"),
    "text-babbage-001": tiktoken.get_encoding("r50k_base"),
    "text-ada-001": tiktoken.get_encoding("r50k_base"),
    "davinci": tiktoken.get_encoding("r50k_base"),
    "curie": tiktoken.get_encoding("r50k_base"),
    "babbage": tiktoken.get_encoding("r50k_base"),
    "ada": tiktoken.get_encoding("r50k_base"),
}

max_length = {
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-0301": 4096,
    "text-davinci-003": 4096,
    "text-davinci-002": 4096,
    "text-davinci-001": 2049,
    "text-curie-001": 2049,
    "text-babbage-001": 2049,
    "text-ada-001": 2049,
    "davinci": 2049,
    "curie": 2049,
    "babbage": 2049,
    "ada": 2049
}

def count_tokens(model_name, text):
    return len(encodings[model_name].encode(text))

def get_max_context_length(model_name):
    return max_length[model_name]

5、调用 chitchat 函数,此函数,构建参数数据,调用send_request函数给openai接口发请求

chitchat函数源码注释如下:

def chitchat(messages):
    data = {
        "model": LLM,
        "messages": messages
    }
    #调用openai接口
    return send_request(data)

6、调用 resource_has_dep 判断是否有依赖资源。方法:判断args数据里是否包含<GENERATED>字段,如果有返回值True,否则返回False

resource_has_dep 函数源码注释如下:

def resource_has_dep(command):
    args = command["args"]
    for _, v in args.items():
        if "<GENERATED>" in v:
            return True
    return False

内容太多,未完待续......

更多推荐

JARVIS项目源码分析 - awesome_chat.py代码分析2