有效
一种基于联邦学习的推荐方法、装置、设备和存储介质
祝咏升、王伟、陈国荣、蔡伯根、原笑含、王万齐、陈政、刘敬楷、郝玉蓉、杜飞
中国铁道科学研究院集团有限公司电子计算技术研究所
祝
祝咏升机构 暂无
技术领域 暂无
王
王伟机构 暂无
技术领域 暂无
陈
陈国荣机构 暂无
技术领域 暂无
蔡
蔡伯根机构 暂无
技术领域 暂无
原
原笑含机构 暂无
技术领域 暂无
王
王万齐机构 暂无
技术领域 暂无
陈
陈政机构 暂无
技术领域 暂无
刘
刘敬楷机构 暂无
技术领域 暂无
郝
郝玉蓉机构 暂无
技术领域 暂无
杜
杜飞机构 暂无
技术领域 暂无
摘要
本申请公开了一种基于联邦学习的推荐方法、装置、设备和存储介质,用于提高推荐系统的公平性。本申请接收目标对象上传的目标属性;基于目标属性对目标对象进行分组,得到群体;针对每个群体基于目标属性的属性值对目标对象进行分组,得到子群体;针对每个子群体,将子群体对应的模型参数集发送给子群体中的每个目标对象;以使目标对象根据接收到的模型参数集本地模型进行训练;本地模型用于执行推荐操作。用户可以选择自己期望的敏感属性,根据不同的敏感属性构建不同的群体,并根据群体中的不同取值来构建子群体,针对每个子群体均设置了对应的模型参数集,进而可以保证训练得到的本地模型更加的准确,可以保证公平性的同时提供较高的推荐性能。
1.一种基于联邦学习的推荐方法,其特征在于,应用于服务器,所述方法包括:接收用户基于目标对象上传的目标属性;所述目标属性包括:年龄、性别;所述目标对象为终端设备;基于所述目标属性对所述终端设备进行分组,得到至少一个群体;针对每个群体:基于目标属性的属性值对所述终端设备进行分组,得到至少一个子群体;其中,在所述目标属性为性别时,所述性别对应的属性值为男性对应的数值以及女性对应的数值,在所述目标属性为年龄时,所述年龄对应的属性值为用户输入的年龄对应的数值;针对每个子群体,将所述子群体对应的模型参数集发送给所述子群体中的每个终端设备;以使所述终端设备中的购物平台根据接收到的模型参数集对对应的本地模型进行训练;所述本地模型用于使所述购物平台对用户推荐物品;所述模型参数集中还包括:上一轮更新权重,所述方法还包括:接收每个终端设备发送的公平度量值;基于每个终端设备发送的所述公平度量值,得到公平度量均值;针对每个子群体:基于所述子群体中每个终端设备发送的公平度量值得到群体度量值;基于所述群体度量值以及所述公平度量均值,得到所述子群体对应的目标更新权重;采用所述目标更新权重更新所述上一轮更新权重;其中,所述终端设备是根据以下方法对本地模型进行训练的:所述终端设备根据所述上一轮更新权重以及预设轮次初始值,得到目标训练轮次;采用所述终端设备对应的训练数据集、所述目标训练轮次以及优化目标值对所述本地模型进行训练,得到训练后的本地模型以及所述本地模型的当前模型参数;所述优化目标值是所述终端设备根据所述模型参数集来确定。
2.根据权利要求1所述的方法,其特征在于,所述模型参数集中包括:上一轮的全局模型更新参数,所述方法还包括:接收每个终端设备上传的第一本地参数;接收每个终端设备发送的模型差值;基于每个终端设备对应的第一本地参数得到第二本地参数;对所述第二本地参数、每个所述终端设备对应的第一本地参数、以及每个所述终端设备对应的模型差值进行聚合处理,得到全局模型更新参数;采用所述全局模型更新参数更新存储的上一轮的全局模型更新参数。
3.根据权利要求2所述的方法,其特征在于,所述模型参数集中包括:上一轮的子群体模型更新参数,所述接收每个终端设备发送的模型差值之后,所述方法还包括:针对每个子群体;基于所述子群体中每个终端设备对应的本地数据构建子群体参数;对所述第二本地参数、所述子群体参数、以及所述模型差值进行聚合处理,得到子群体模型更新参数;采用所述子群体模型更新参数更新存储的上一轮的子群体模型更新参数。
4.根据权利要求2所述的方法,其特征在于,所述模型参数集中还包括:上一轮的正则项强度,所述方法还包括:接收每个所述终端设备发送的本地预测误差;针对每个子群体,基于所述子群体中每个终端设备对应的本地数据构建子群体参数;对所述第二本地参数、所述子群体参数、以及所述本地预测误差进行聚合处理,得到目标正则项强度;采用所述目标正则项强度更新存储的上一轮的正则项强度。
5.一种基于联邦学习的推荐方法,其特征在于,应用于终端设备,所述方法包括:从服务器中获取模型参数集;所述模型参数集是所述终端设备所属的子群体对应的模型参数集;所述终端设备所属的子群体是根据目标对象上传的目标属性确定的;所述目标属性包括:年龄、性别;所述目标对象为所述终端设备;根据所述模型参数集来确定优化目标值;根据所述优化目标值来对购物平台对应的本地模型进行训练,得到训练后的本地模型;采用所述训练后的本地模型对用户推荐物品;其中,所述模型参数集中包括:上一轮的全局模型更新参数、上一轮的子群体模型更新参数、上一轮的正则项强度、上一轮更新权重;所述上一轮的正则项强度用于确定目标对象的本地模型是否收敛;所述上一轮更新权重用于确定所述目标对象的本地模型的训练轮次;所述目标对象为终端设备,所述终端设备中设置有购物平台;所述上一轮更新权重是所述服务器根据以下方法得到的:根据接收到的每个终端设备发送的公平度量值;基于每个终端设备发送的所述公平度量值,得到公平度量均值;针对每个子群体:基于所述子群体中每个终端设备发送的公平度量值得到群体度量值;基于所述群体度量值以及所述公平度量均值,得到所述子群体对应的目标更新权重;采用所述目标更新权重更新所述上一轮更新权重;所述根据所述优化目标值来对购物平台对应的本地模型进行训练,得到训练后的本地模型,包括:根据所述上一轮更新权重以及预设轮次初始值,得到目标训练轮次;采用所述终端设备对应的训练数据集、所述目标训练轮次以及所述优化目标值对所述本地模型进行训练,得到训练后的本地模型以及所述本地模型的当前模型参数;获取验证数据集;基于所述训练数据集确定所述终端设备对应的交互记录;基于所述交互记录、所述当前模型参数、以及预设的准确度函数得到准确度;基于所述准确度以及所述验证数据集,得到公平度量值。
6.根据权利要求5所述的方法,其特征在于,所述根据所述模型参数集来确定优化目标值,包括:获取上一轮的初始模型参数;基于所述上一轮的初始模型参数以及所述上一轮的子群体模型更新参数,得到所述本地模型的本地模型参数;基于所述上一轮的初始模型参数以及所述上一轮的全局模型更新参数,得到所述终端设备的全局模型的全局模型参数;基于所述终端设备对应的训练数据集、所述本地模型参数、所述全局模型参数,得到优化目标值。
7.根据权利要求6所述的方法,其特征在于,所述基于所述终端设备对应的训练数据集、所述本地模型参数、所述全局模型参数,得到优化目标值,包括:基于所述训练数据集确定所述终端设备对应的交互记录以及交互数据量;基于所述交互记录、所述交互数据量、所述上一轮的正则项强度、所述本地模型参数、所述全局模型参数以及预设的损失函数,得到所述优化目标值。
8.根据权利要求6所述的方法,其特征在于,所述得到优化目标值之后,所述方法还包括:基于所述训练数据集确定所述终端设备对应的交互记录;基于所述交互记录、损失函数以及所述本地模型参数得到本地损失值;基于所述交互记录、所述损失函数以及所述全局模型参数,得到全局损失值;获取验证数据集;基于所述本地损失值、所述全局损失值以及所述验证数据集,得到本地预测误差;将所述本地预测误差上传至服务器,以使所述服务器根据所述本地预测误差更新上一轮的正则项强度。
9.根据权利要求6所述的方法,其特征在于,所述得到优化目标值之后,所述方法还包括:根据所述当前模型参数以及所述本地模型参数得到模型差值;将所述模型差值上传至服务器,以使所述服务器根据所述模型差值更新模型参数中的更新存储的所述上一轮的全局模型更新参数。
10.一种基于联邦学习的推荐装置,其特征在于,应用于服务器,所述装置包括:接收模块,用于接收用户基于目标对象上传的目标属性;所述目标属性包括:年龄、性别;所述目标对象为终端设备;第一分组模块,用于基于所述目标属性对所述终端设备进行分组,得到至少一个群体;第二分组模块,用于针对每个群体:基于目标属性的属性值对所述终端设备进行分组,得到至少一个子群体;其中,在所述目标属性为性别时,所述性别对应的属性值为男性对应的数值以及女性对应的数值,在所述目标属性为年龄时,所述年龄对应的属性值为用户输入的年龄对应的数值;参数下发模块,用于针对每个子群体,将所述子群体对应的模型参数集发送给所述子群体中的每个终端设备;以使所述终端设备根据接收到的模型参数集对本地模型进行训练;所述本地模型用于使购物平台对用户推荐物品;所述模型参数集中包括:上一轮更新权重,所述参数下发模块还用于:接收每个终端设备发送的公平度量值;基于每个终端设备发送的所述公平度量值,得到公平度量均值;针对每个子群体:基于所述子群体中每个终端设备发送的公平度量值得到群体度量值;基于所述群体度量值以及所述公平度量均值,得到所述子群体对应的目标更新权重;采用所述目标更新权重更新所述上一轮更新权重;其中,所述终端设备是根据以下方法对本地模型进行训练的:所述终端设备根据所述上一轮更新权重以及预设轮次初始值,得到目标训练轮次;采用所述终端设备对应的训练数据集、所述目标训练轮次以及优化目标值对所述本地模型进行训练,得到训练后的本地模型以及所述本地模型的当前模型参数;所述优化目标值是所述终端设备根据所述模型参数集来确定。
11.根据权利要求10所述的装置,其特征在于,所述模型参数集中包括:上一轮的全局模型更新参数,所述参数下发模块还用于:接收每个终端设备上传的第一本地参数;接收每个终端设备发送的模型差值;基于每个终端设备对应的第一本地参数得到第二本地参数;对所述第二本地参数、每个所述终端设备对应的第一本地参数、以及每个所述终端设备对应的模型差值进行聚合处理,得到全局模型更新参数;采用所述全局模型更新参数更新存储的上一轮的全局模型更新参数。
12.根据权利要求11所述的装置,其特征在于,所述模型参数集中包括:上一轮的子群体模型更新参数,所述参数下发模块还用于:针对每个子群体;基于所述子群体中每个终端设备对应的本地数据构建子群体参数;对所述第二本地参数、所述子群体参数、以及所述模型差值进行聚合处理,得到子群体模型更新参数;采用所述子群体模型更新参数更新存储的上一轮的子群体模型更新参数。
13.根据权利要求11所述的装置,其特征在于,所述模型参数集中包括:上一轮的正则项强度,所述参数下发模块还用于:接收每个所述终端设备发送的本地预测误差;针对每个子群体,基于所述子群体中每个终端设备对应的本地数据构建子群体参数;对所述第二本地参数、所述子群体参数、以及所述本地预测误差进行聚合处理,得到目标正则项强度;采用所述目标正则项强度更新存储的上一轮的正则项强度。
14.一种基于联邦学习的推荐装置,其特征在于,应用于终端设备,所述装置包括:参数接收模块,用于从服务器中获取模型参数集;所述模型参数集是所述终端设备所属的子群体对应的模型参数集;所述终端设备所属的子群体是根据目标对象上传的目标属性确定的;所述目标属性包括:年龄、性别;所述目标对象为所述终端设备;优化目标值确定模块,用于根据所述模型参数集来确定优化目标值;训练模块,用于根据所述优化目标值来对购物平台对应的本地模型进行训练,得到训练后的本地模型;推荐模块,用于采用所述训练后的本地模型对用户推荐物品;其中,所述模型参数集中包括:上一轮的全局模型更新参数、上一轮的子群体模型更新参数、上一轮的正则项强度、上一轮更新权重;所述上一轮的正则项强度用于确定目标对象的本地模型是否收敛;所述上一轮更新权重用于确定所述目标对象的本地模型的训练轮次;所述目标对象为终端设备,所述终端设备中设置有购物平台;所述上一轮更新权重是所述服务器根据以下方法得到的:根据接收到的每个终端设备发送的公平度量值;基于每个终端设备发送的所述公平度量值,得到公平度量均值;针对每个子群体:基于所述子群体中每个终端设备发送的公平度量值得到群体度量值;基于所述群体度量值以及所述公平度量均值,得到所述子群体对应的目标更新权重;采用所述目标更新权重更新所述上一轮更新权重;所述训练模块,具体用于:根据所述上一轮更新权重以及预设轮次初始值,得到目标训练轮次;采用所述终端设备对应的训练数据集、所述目标训练轮次以及所述优化目标值对所述本地模型进行训练,得到训练后的本地模型以及所述本地模型的当前模型参数;优化目标值确定模块,还用于:获取验证数据集;基于所述训练数据集确定所述终端设备对应的交互记录;基于所述交互记录、所述当前模型参数、以及预设的准确度函数得到准确度。
15.根据权利要求14所述的装置,其特征在于,所述优化目标值确定模块,用于:获取上一轮的初始模型参数;基于所述上一轮的初始模型参数以及所述上一轮的子群体模型更新参数,得到所述本地模型的本地模型参数;基于所述上一轮的初始模型参数以及所述上一轮的全局模型更新参数,得到所述终端设备的全局模型的全局模型参数;基于所述终端设备对应的训练数据集、所述本地模型参数、所述全局模型参数,得到优化目标值。
16.根据权利要求15所述的装置,其特征在于,所述优化目标值确定模块,用于:基于所述训练数据集确定所述终端设备对应的交互记录以及交互数据量;基于所述交互记录、所述交互数据量、所述上一轮的正则项强度、所述本地模型参数、所述全局模型参数以及预设的损失函数,得到所述优化目标值。
17.根据权利要求16所述的装置,其特征在于,所述优化目标值确定模块还用于:基于所述训练数据集确定所述终端设备对应的交互记录;基于所述交互记录、损失函数以及所述本地模型参数得到本地损失值;基于所述交互记录、所述损失函数以及所述全局模型参数,得到全局损失值;获取验证数据集;基于所述本地损失值、所述全局损失值以及所述验证数据集,得到本地预测误差;将所述本地预测误差上传至服务器,以使所述服务器根据所述本地预测误差更新上一轮的正则项强度。
18.根据权利要求16所述的装置,其特征在于,所述优化目标值确定模块还用于:根据所述当前模型参数以及所述本地模型参数得到模型差值;将所述模型差值上传至服务器,以使所述服务器根据所述模型差值更新模型参数中的更新存储的所述上一轮的全局模型更新参数。
19.一种电子设备,其特征在于,包括用于存储计算机程序指令的存储器和用于执行程序指令的处理器,其中,当该计算机程序指令被所述处理器执行时,触发所述电子设备执行权利要求1-4、5-9中任一项所述的方法。
20.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质包括存储的程序,其中,在所述程序运行时控制所述计算机可读存储介质所在设备执行权利要求1-4、5-9中任意一项所述的方法。



