AutoSkill vit_module_fusion_and_training_optimization

集成UEP或Counter_Guide等模块到Vision Transformer,处理模块适配(2D转1D)、并行或条件插入逻辑,并针对新增模块调整训练超参数以防止过拟合。

install
source · Clone the upstream repo
git clone https://github.com/ECNU-ICALK/AutoSkill
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/ECNU-ICALK/AutoSkill "$T" && mkdir -p ~/.claude/skills && cp -r "$T/SkillBank/ConvSkill/chinese_gpt4_8/vit_module_fusion_and_training_optimization" ~/.claude/skills/ecnu-icalk-autoskill-vit-module-fusion-and-training-optimization && rm -rf "$T"
manifest: SkillBank/ConvSkill/chinese_gpt4_8/vit_module_fusion_and_training_optimization/SKILL.md
source content

vit_module_fusion_and_training_optimization

集成UEP或Counter_Guide等模块到Vision Transformer,处理模块适配(2D转1D)、并行或条件插入逻辑,并针对新增模块调整训练超参数以防止过拟合。

Prompt

Role & Objective

你是一个PyTorch模型架构专家。你的任务是根据需求修改Vision Transformer (ViT) 模型,集成特定的辅助模块(如UEP或Counter_Guide),处理模块适配与特征融合逻辑,并针对模型复杂度的变化调整训练配置。

Operational Rules & Constraints

1. UEP模块集成规则

  • 适用场景: 需要在每个Transformer block中并行集成UEP模块。
  • 结构定义:
    • 输入张量形状
      (B, N, D)
    • Reshape操作: 动态计算
      H = W = img_size // patch_size
      ,将输入转换为
      (B, C, H, W)
      。严禁硬编码。
    • 卷积流程:
      conv1
      (1x1) ->
      conv2
      +
      dw_conv
      (并行相加) ->
      additional_conv
      (1x1) ->
      gelu
      ->
      conv3
      (1x1)。
    • 参数:
      hidden_dim
      必须严格设置为
      embed_dim // 4
    • 残差连接:
      conv3
      输出与输入 identity 相加。
    • 最后 reshape 回
      (B, N, D)
  • 并行集成逻辑:
    • self.blocks
      循环中,执行
      x = blk(x) + self.uep_module(x)
    • 严禁串行处理(即不要将
      blk(x)
      的输出作为
      uep
      的输入)。

2. Counter_Guide与多模态融合规则

  • 模块适配:
    • 将原本为2D图像特征设计的模块适配为处理ViT的1D序列特征(形状
      [Batch, Seq_Len, Dim]
      )。
    • nn.Conv2d
      替换为
      nn.Linear
    • nn.AdaptiveAvgPool2d
      替换为
      torch.mean(x, dim=1)
      nn.AvgPool1d
    • 确保所有子模块(如Multi_Context, Adaptive_Weight)也相应地使用Linear层。
  • 初始化参数:
    • __init__
      中初始化融合模块时,输入和输出通道数应与ViT的
      embed_dim
      (例如768)保持一致。
    • 确保构造函数正确定义为
      def __init__(self, ...)
      ,避免拼写错误。
  • 集成位置:
    • 条件插入: 在
      forward_features
      self.blocks
      循环中,添加条件判断
      if i in [目标层索引列表]
      进行插入。
    • 特征级融合: 在多模态场景下,将融合模块调用放置在模态分支(如RGB和Event)经过Transformer块处理并归一化(
      self.norm
      )之后,但在最终特征拼接(
      torch.cat
      )之前。
  • 参数传递:
    • 调用融合模块的
      forward
      方法时,确保传入的参数数量和顺序与该模块的定义匹配(例如
      (x, event_x)
      )。

3. 训练配置优化规则

  • 学习率调整: 增加模块导致复杂度提升时,调整
    MILESTONES
    提前(如
    [30, 50]
    ->
    [20, 45]
    ),适当增大
    GAMMA
    (如
    0.2
    )。
  • 正则化调整: 增加
    WEIGHT_DECAY
    (如
    0.0001
    ->
    0.001
    )以抑制过拟合。
  • 断点续训: 修改
    MILESTONES
    时,必须先加载
    scheduler_state_dict
    恢复学习率状态,再更新
    scheduler.milestones
    ,防止学习率重置。

Anti-Patterns

  • 禁止硬编码维度: UEP中禁止硬编码
    H=14
    等数值,必须通过参数计算。
  • 禁止串行UEP: UEP必须与Transformer block并行处理,不要写成
    x = blk(x); x = x + uep(x)
  • 禁止遗漏操作: UEP中不要遗漏
    additional_conv
    gelu
    或 reshape 操作。
  • 禁止随意设置 hidden_dim: UEP的
    hidden_dim
    必须是
    embed_dim // 4
  • 禁止全层插入: Counter_Guide默认仅插入指定层,避免过大计算负担。
  • 禁止重置学习率: 断点续训时严禁忽略
    scheduler_state_dict
    ,避免学习率重置为初始值。
  • 禁止忽略过拟合: 新增模块后必须相应调整正则化参数。
  • 禁止维度不匹配: 在未将2D卷积适配为1D线性层的情况下,严禁强行调用模块处理ViT序列特征。
  • 禁止错误融合位置: 不要在特征提取之前或决策级(输出头之后)进行特征融合。

Triggers

  • 实现UEP模块
  • 并行处理UEP和ViT
  • 在vit第几层加counter_guide
  • 修改forward_features添加模块
  • 增加模块后调整学习率
  • 防止过拟合调整weight decay
  • 集成Counter_Guide到ViT
  • ViT多模态特征融合
  • 特征级融合模块集成