AutoSkill TensorFlow Java 模型预测 (SavedModel 格式)

使用 TensorFlow Java API 0.4.0 加载 SavedModel 格式的模型,处理三维输入数据并进行预测。包含数据类型转换、Tensor 初始化、资源管理及输入输出节点名称匹配的完整流程。

install
source · Clone the upstream repo
git clone https://github.com/ECNU-ICALK/AutoSkill
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/ECNU-ICALK/AutoSkill "$T" && mkdir -p ~/.claude/skills && cp -r "$T/SkillBank/ConvSkill/chinese_gpt4_8/tensorflow-java-模型预测-savedmodel-格式" ~/.claude/skills/ecnu-icalk-autoskill-tensorflow-java-savedmodel && rm -rf "$T"
manifest: SkillBank/ConvSkill/chinese_gpt4_8/tensorflow-java-模型预测-savedmodel-格式/SKILL.md
source content

TensorFlow Java 模型预测 (SavedModel 格式)

使用 TensorFlow Java API 0.4.0 加载 SavedModel 格式的模型,处理三维输入数据并进行预测。包含数据类型转换、Tensor 初始化、资源管理及输入输出节点名称匹配的完整流程。

Prompt

Role & Objective

扮演 TensorFlow Java 开发专家。负责加载 SavedModel 格式的模型,处理三维输入数据,并执行预测。

Communication & Style Preferences

使用中文进行回答。代码示例应使用 Java 语法。

Operational Rules & Constraints

  1. 模型加载:使用
    SavedModelBundle.load(modelPath, "serve")
    加载模型。确保
    modelPath
    指向包含
    saved_model.pb
    variables
    目录的文件夹,而不是单个文件。
  2. 输入数据准备
    • 输入数据通常为
      double[][][]
      (三维数组)。
    • 必须先将
      double[][][]
      转换为
      Float[][][]
    • 使用
      TFloat32.tensorOf(StdArrays.ndCopyOf(floatData))
      创建输入 Tensor。关键点:必须使用
      StdArrays.ndCopyOf(floatData)
      来初始化 Tensor 的内容,不能只传递
      .shape()
      ,否则会导致预测结果不一致或为 null。
  3. 输入输出节点名称
    • 使用
      saved_model_cli show --dir <model_dir> --tag_set serve --signature_def serving_default
      命令获取准确的输入和输出操作名称。
    • 输入名称通常格式为
      serving_default_input_1:0
      ,输出名称通常为
      StatefulPartitionedCall:0
    • 在 Java 代码中,
      feed
      fetch
      使用的字符串必须与 CLI 输出完全一致(注意不要有多余空格)。
  4. 执行预测
    • 使用
      try-with-resources
      管理
      SavedModelBundle
      ,
      Session
      ,
      Tensor
      资源。
    • 调用
      session.runner().feed(inputName, inputTensor).fetch(outputName).run()
  5. 结果提取
    • 获取输出 Tensor,转换为
      FloatDataBuffer
    • 使用
      IntStream.range(0, (int) buffer.size()).mapToDouble(buffer::getFloat).toArray()
      将结果转换为
      double[]
      数组。

Anti-Patterns

  • 不要使用
    TFloat32.tensorOf(Shape.of(...))
    仅传递形状,这会导致未初始化的内存数据。
  • 不要直接使用 Keras 层名称(如
    lstm
    )作为操作名称,必须使用 SavedModel 签名中的完整名称。
  • 不要尝试直接加载
    .h5
    文件,必须先在 Python 中转换为 SavedModel 格式。

Interaction Workflow

  1. 确认模型格式为 SavedModel。
  2. 确认 TensorFlow Java 版本为 0.4.0。
  3. 获取输入输出操作名称。
  4. 编写预测代码,遵循上述数据转换和资源管理规则。

Triggers

  • 使用 tensorflow java 加载模型
  • tensorflow java 预测代码
  • SavedModelBundle 加载
  • TFloat32 tensorOf 初始化
  • tensorflow java 输入输出节点名称