疯狂交互学习的BM3推荐算法(论文复现)
文章目录
多模态推荐系统
示例
对比学习
什么是对比学习?
关键思想
优点
自监督学习
什么是自监督学习?
优点
实现自监督学习的方法
解决方案
框架图
损失函数(每一步都是自监督对比)
图重构损失
模态间对齐损失
相当于Item是标签,这些Text和Image是特征,相互学习的过程,把Text赋予标签信息,然后在Item里面增加更多的Text和Image的特征信息,同时由于Dropout可以保证学习的不崩溃
-
统一性和稳定性
- 项目(item)的嵌入表示相对于用户(user)的嵌入表示更为稳定和统一。用户的行为和兴趣可能会随时间和情境发生变化,而项目的特征相对固定,因此使用项目嵌入可以提供更稳定的对齐基础。
-
多视图一致性
- 多视图特征表示 $ h_m’ $ 是从不同模态(如文本、图像、音频等)中提取的。这些特征通常描述的是项目的不同方面,因此使用项目的嵌入来对齐多视图特征可以确保不同模态下的项目特征一致性。
-
提高泛化能力
- 使用项目嵌入来对齐多视图特征可以帮助模型更好地捕捉项目的多模态特性,从而提高模型在处理多模态推荐任务时的泛化能力。这意味着模型可以更好地理解和推荐多种类型的项目,即使在用户行为发生变化时,模型仍然能够提供有效的推荐。
模态内特征遮蔽损失
实验分析
环境部署
git clone https://github.com/enoche/BM3.git
环境配置
pip install -r requirements.txt
conda install --file requirements.txt
数据集配置
代码运行
cd .\src
python main.py -m BM3 -d baby
运行截图
代码分析
Loss分析
def calculate_loss(self, interactions):
# online network
u_online_ori, i_online_ori = self.forward()
t_feat_online, v_feat_online = None, None
if self.t_feat is not None:
t_feat_online = self.text_trs(self.text_embedding.weight)
if self.v_feat is not None:
v_feat_online = self.image_trs(self.image_embedding.weight)
with torch.no_grad(): # 停止梯度更新,这样在下面的操作中不会计算梯度,节省内存和计算资源
u_target, i_target = u_online_ori.clone(), i_online_ori.clone() # 复制在线用户和物品的原始特征向量
u_target.detach() # 分离用户目标特征向量,使其不参与梯度计算
i_target.detach() # 分离物品目标特征向量,使其不参与梯度计算
u_target = F.dropout(u_target, self.dropout) # 对用户目标特征向量应用Dropout,生成用户对比试图
i_target = F.dropout(i_target, self.dropout) # 对物品目标特征向量应用Dropout,生成物品对比试图
if self.t_feat is not None: # 检查时间特征是否存在
t_feat_target = t_feat_online.clone() # 复制时间特征向量
t_feat_target = F.dropout(t_feat_target, self.dropout) # 对时间特征向量应用Dropout,生成image对比试图
if self.v_feat is not None: # 检查image特征是否存在
v_feat_target = v_feat_online.clone() # 复制image特征
v_feat_target = F.dropout(v_feat_target, self.dropout) # 对image特征向量Dropout,生成text对比试图
# 预测用户和物品的在线特征向量
u_online, i_online = self.predictor(u_online_ori), self.predictor(i_online_ori)
# 获取交互数据中的用户和物品索引
users, items = interactions[0], interactions[1]
# 根据用户和物品索引提取相应的在线特征和目标特征
u_online = u_online[users, :] # 提取在线用户特征
i_online = i_online[items, :] # 提取在线物品特征
u_target = u_target[users, :] # 提取目标用户特征
i_target = i_target[items, :] # 提取目标物品特征
# 初始化各类损失为0
loss_t, loss_v, loss_tv, loss_vt = 0.0, 0.0, 0.0, 0.0
if self.t_feat is not None: # 检查时间特征是否存在
t_feat_online = self.predictor(t_feat_online) # 通过预测器更新在线时间特征
t_feat_online = t_feat_online[items, :] # 提取更新后的在线时间特征
t_feat_target = t_feat_target[items, :] # 提取目标时间特征
# 计算时间特征和物品目标特征的余弦相似度损失
loss_t = 1 - cosine_similarity(t_feat_online, i_target.detach(), dim=-1).mean()
# 计算时间特征和目标时间特征的余弦相似度损失
loss_tv = 1 - cosine_similarity(t_feat_online, t_feat_target.detach(), dim=-1).mean()
if self.v_feat is not None: # 检查视觉特征是否存在
v_feat_online = self.predictor(v_feat_online) # 通过预测器更新在线视觉特征
v_feat_online = v_feat_online[items, :] # 提取更新后的在线视觉特征
v_feat_target = v_feat_target[items, :] # 提取目标视觉特征
# 计算视觉特征和物品目标特征的余弦相似度损失
loss_v = 1 - cosine_similarity(v_feat_online, i_target.detach(), dim=-1).mean()
# 计算视觉特征和目标视觉特征的余弦相似度损失
loss_vt = 1 - cosine_similarity(v_feat_online, v_feat_target.detach(), dim=-1).mean()
# 计算用户在线特征和物品目标特征的余弦相似度损失
loss_ui = 1 - cosine_similarity(u_online, i_target.detach(), dim=-1).mean()
# 计算物品在线特征和用户目标特征的余弦相似度损失
loss_iu = 1 - cosine_similarity(i_online, u_target.detach(), dim=-1).mean()
# 返回总损失,包括余弦相似度损失、正则化损失和对比损失
return (loss_ui + loss_iu).mean() + self.reg_weight * self.reg_loss(u_online_ori, i_online_ori) + \
self.cl_weight * (loss_t + loss_v + loss_tv + loss_vt).mean()
参数代码分析
# 创建ArgumentParser对象用于解析命令行参数
parser = argparse.ArgumentParser()
# 添加命令行参数 --model 或 -m,用于指定模型名称,默认值为 'BM3'
parser.add_argument('--model', '-m', type=str, default='BM3', help='name of models')
# 添加命令行参数 --dataset 或 -d,用于指定数据集名称,默认值为 'baby'
parser.add_argument('--dataset', '-d', type=str, default='baby', help='name of datasets')
# 定义包含GPU配置信息的字典
config_dict = {
'gpu_id': 0,
}
# 解析命令行参数,将结果存储在 args 对象中
args, _ = parser.parse_known_args()
# 调用 quick_start 函数,传递模型名称、数据集名称、配置字典以及是否保存模型的标志
quick_start(model=args.model, dataset=args.dataset, config_dict=config_dict, save_model=True)
相关文件作用分析
BM3/
├── data/ # 数据目录
│ ├── baby/ # 婴儿数据目录
│ ├── clothing/ # 服装数据目录
│ └── sports/ # 运动数据目录
│
├── src/ # 源代码目录
│ ├── common/ # 公共模块目录
│ ├── configs/ # 配置目录
│ ├── log/ # 日志目录
│ ├── models/ # 模型目录
│ │ └── bm3.py # BM3模型代码
│ └── utils/ # 工具目录
│ └── main.py # 主程序代码
│
├── trained-models-logs/ # 训练模型日志目录
│
├── .gitignore # git忽略文件
├── LICENSE # 许可证文件
├── README.md # 项目说明文件
└── requirements.txt # 项目依赖文件
代码算法复现结果
Category | N@10 | N@20 | R@10 | R@20 |
---|---|---|---|---|
Baby | 0.0559 | 0.0880 | 0.0296 | 0.0383 |
Sports | 0.0646 | 0.0978 | 0.0345 | 0.0435 |
Electronics | 0.0434 | 0.0643 | 0.0247 | 0.0301 |