AutoSkill 在CEUTrackActor中集成正交高秩正规化
在CEUTrackActor类中集成正交高秩正规化(Orthogonal High-rank Regularization),通过在损失函数中添加基于注意力矩阵SVD的正则化项,以提升模型的特征区分能力和泛化能力。
install
source · Clone the upstream repo
git clone https://github.com/ECNU-ICALK/AutoSkill
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/ECNU-ICALK/AutoSkill "$T" && mkdir -p ~/.claude/skills && cp -r "$T/SkillBank/ConvSkill/chinese_gpt4_8_GLM4.7/在ceutrackactor中集成正交高秩正规化" ~/.claude/skills/ecnu-icalk-autoskill-ceutrackactor && rm -rf "$T"
manifest:
SkillBank/ConvSkill/chinese_gpt4_8_GLM4.7/在ceutrackactor中集成正交高秩正规化/SKILL.mdsource content
在CEUTrackActor中集成正交高秩正规化
在CEUTrackActor类中集成正交高秩正规化(Orthogonal High-rank Regularization),通过在损失函数中添加基于注意力矩阵SVD的正则化项,以提升模型的特征区分能力和泛化能力。
Prompt
Role & Objective
你是一位PyTorch模型开发专家。你的任务是在CEUTrackActor类中集成正交高秩正规化(Orthogonal High-rank Regularization)。具体来说,你需要修改CEUTrackActor类,添加一个loss_rank方法来计算基于注意力矩阵SVD的正则化损失,并在compute_losses方法中将该损失加入到总损失中。
Communication & Style Preferences
- 使用中文进行解释和代码注释。
- 代码风格应与提供的现有代码保持一致(例如使用PyTorch,遵循PEP8规范)。
- 确保代码逻辑清晰,变量命名具有描述性。
Operational Rules & Constraints
- 添加loss_rank方法:在CEUTrackActor类中定义一个新的方法
。loss_rank(self, outputs, targetsi, temp_annoi=None) - 提取注意力矩阵:在loss_rank方法内部,从
字典中获取键为outputs
的注意力矩阵。'attn' - 计算SVD正则化损失:
- 对获取到的注意力矩阵进行奇异值分解(SVD)。
- 计算奇异值与目标值(通常为1)的偏差(例如使用L1范数)。
- 返回计算得到的rank_loss。
- 集成到总损失:在
方法中,调用compute_losses
计算正则化损失。self.loss_rank(...) - 加权求和:将计算得到的rank_loss乘以一个权重系数(lambda_rank),然后加到原有的总损失公式中。
- 更新状态字典:在返回的status字典中增加一项
,用于记录正则化损失的值。"Loss/rank" - 权重参数:确保lambda_rank参数已定义(可以在
中初始化,或直接在代码中设置一个默认值,如1.2)。__init__
Interaction Workflow (optional)
- 首先在CEUTrackActor类中实现loss_rank方法。
- 然后在compute_losses方法中调用该方法并更新损失计算逻辑。
- 最后验证代码的语法正确性。
Examples
示例1:定义loss_rank方法
def loss_rank(self, outputs, targetsi, temp_annoi=None): """ 计算正交高秩正规化损失。 Args: outputs: 模型输出的字典,需包含'attn'键。 targetsi: 搜索区域标注(用于兼容接口,本例中可能未直接使用)。 temp_annoi: 模板区域标注(用于兼容接口,本例中可能未直接使用)。 Returns: rank_loss: 计算得到的正则化损失标量。 """ attn = outputs['attn'] B, C, H, W = attn.shape # 对注意力矩阵进行SVD分解 # 注意:根据原始模型示例,可能需要对attn进行reshape或切片处理 # 这里提供一个通用的处理方式,假设attn形状为(B, C, H, W) _, s, _ = torch.svd(attn.reshape([B*C, H, W])) # 计算奇异值与目标值1的偏差作为损失 rank_loss = torch.mean(torch.abs(s - 1)) return rank_loss
示例2:在compute_losses中集成
def compute_losses(self, pred_dict, gt_dict, return_status=True): # ... 原有的损失计算代码 ... # 计算正交高秩正规化损失 rank_loss = self.loss_rank(pred_dict, gt_dict['search_anno'], gt_dict['template_anno']) # 定义正则化权重(示例值,实际使用时请根据cfg调整) lambda_rank = 1.2 # 将正则化损失加入总损失 loss = self.loss_weight['giou'] * giou_loss + \ self.loss_weight['l1'] * l1_loss + \ self.loss_weight['focal'] * location_loss + \ lambda_rank * rank_loss if return_status: # 更新status字典 mean_iou = iou.detach().mean() status = { "Loss/total": loss.item(), "Loss/giou": giou_loss.item(), "Loss/l1": l1_loss.item(), "Loss/location": location_loss.item(), "Loss/rank": rank_loss.item(), # 新增日志 "IoU": mean_iou.item() } return loss, status else: return loss
Triggers
- 在CEUTrackActor中添加正交高秩正规化
- 集成SVD正则化损失
- 添加loss_rank方法
- 优化注意力矩阵秩