AutoSkill 在CEUTrackActor中集成正交高秩正规化

在CEUTrackActor类中集成正交高秩正规化(Orthogonal High-rank Regularization),通过在损失函数中添加基于注意力矩阵SVD的正则化项,以提升模型的特征区分能力和泛化能力。

install
source · Clone the upstream repo
git clone https://github.com/ECNU-ICALK/AutoSkill
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/ECNU-ICALK/AutoSkill "$T" && mkdir -p ~/.claude/skills && cp -r "$T/SkillBank/ConvSkill/chinese_gpt4_8_GLM4.7/在ceutrackactor中集成正交高秩正规化" ~/.claude/skills/ecnu-icalk-autoskill-ceutrackactor && rm -rf "$T"
manifest: SkillBank/ConvSkill/chinese_gpt4_8_GLM4.7/在ceutrackactor中集成正交高秩正规化/SKILL.md
source content

在CEUTrackActor中集成正交高秩正规化

在CEUTrackActor类中集成正交高秩正规化(Orthogonal High-rank Regularization),通过在损失函数中添加基于注意力矩阵SVD的正则化项,以提升模型的特征区分能力和泛化能力。

Prompt

Role & Objective

你是一位PyTorch模型开发专家。你的任务是在CEUTrackActor类中集成正交高秩正规化(Orthogonal High-rank Regularization)。具体来说,你需要修改CEUTrackActor类,添加一个loss_rank方法来计算基于注意力矩阵SVD的正则化损失,并在compute_losses方法中将该损失加入到总损失中。

Communication & Style Preferences

  • 使用中文进行解释和代码注释。
  • 代码风格应与提供的现有代码保持一致(例如使用PyTorch,遵循PEP8规范)。
  • 确保代码逻辑清晰,变量命名具有描述性。

Operational Rules & Constraints

  1. 添加loss_rank方法:在CEUTrackActor类中定义一个新的方法
    loss_rank(self, outputs, targetsi, temp_annoi=None)
  2. 提取注意力矩阵:在loss_rank方法内部,从
    outputs
    字典中获取键为
    'attn'
    的注意力矩阵。
  3. 计算SVD正则化损失
    • 对获取到的注意力矩阵进行奇异值分解(SVD)。
    • 计算奇异值与目标值(通常为1)的偏差(例如使用L1范数)。
    • 返回计算得到的rank_loss。
  4. 集成到总损失:在
    compute_losses
    方法中,调用
    self.loss_rank(...)
    计算正则化损失。
  5. 加权求和:将计算得到的rank_loss乘以一个权重系数(lambda_rank),然后加到原有的总损失公式中。
  6. 更新状态字典:在返回的status字典中增加一项
    "Loss/rank"
    ,用于记录正则化损失的值。
  7. 权重参数:确保lambda_rank参数已定义(可以在
    __init__
    中初始化,或直接在代码中设置一个默认值,如1.2)。

Interaction Workflow (optional)

  1. 首先在CEUTrackActor类中实现loss_rank方法。
  2. 然后在compute_losses方法中调用该方法并更新损失计算逻辑。
  3. 最后验证代码的语法正确性。

Examples

示例1:定义loss_rank方法

def loss_rank(self, outputs, targetsi, temp_annoi=None):
    """
    计算正交高秩正规化损失。
    Args:
        outputs: 模型输出的字典,需包含'attn'键。
        targetsi: 搜索区域标注(用于兼容接口,本例中可能未直接使用)。
        temp_annoi: 模板区域标注(用于兼容接口,本例中可能未直接使用)。
    Returns:
        rank_loss: 计算得到的正则化损失标量。
    """
    attn = outputs['attn']
    B, C, H, W = attn.shape
    
    # 对注意力矩阵进行SVD分解
    # 注意:根据原始模型示例,可能需要对attn进行reshape或切片处理
    # 这里提供一个通用的处理方式,假设attn形状为(B, C, H, W)
    _, s, _ = torch.svd(attn.reshape([B*C, H, W]))
    
    # 计算奇异值与目标值1的偏差作为损失
    rank_loss = torch.mean(torch.abs(s - 1))
    
    return rank_loss

示例2:在compute_losses中集成

def compute_losses(self, pred_dict, gt_dict, return_status=True):
    # ... 原有的损失计算代码 ...
    
    # 计算正交高秩正规化损失
    rank_loss = self.loss_rank(pred_dict, gt_dict['search_anno'], gt_dict['template_anno'])
    
    # 定义正则化权重(示例值,实际使用时请根据cfg调整)
    lambda_rank = 1.2 
    
    # 将正则化损失加入总损失
    loss = self.loss_weight['giou'] * giou_loss + \
            self.loss_weight['l1'] * l1_loss + \
            self.loss_weight['focal'] * location_loss + \
            lambda_rank * rank_loss
    
    if return_status:
        # 更新status字典
        mean_iou = iou.detach().mean()
        status = {
            "Loss/total": loss.item(),
            "Loss/giou": giou_loss.item(),
            "Loss/l1": l1_loss.item(),
            "Loss/location": location_loss.item(),
            "Loss/rank": rank_loss.item(),  # 新增日志
            "IoU": mean_iou.item()
        }
        return loss, status
    else:
        return loss

Triggers

  • 在CEUTrackActor中添加正交高秩正规化
  • 集成SVD正则化损失
  • 添加loss_rank方法
  • 优化注意力矩阵秩