修改 TensorFlow C API 的代码,使得固定的 saved_model.pb 文件名变为可设置的
我在 2021 年 5 月份发出了一个 SpleeterMsvcExe 开源项目,当时是支持 11kHz 和 16kHz 两种模型,年底的时候又添加了对 22kHz 模型的支持。这三种不同频率上限的模型, variables 目录中的文件是完全相同的,只有 saved_model.pb 不同。但 TensorFlow 源码中这个文件名在 constants.h 中被定义为了固定值:

  1. // SavedModel proto filename.
  2. constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
  3.  
  4. // SavedModel text format proto filename.
  5. constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";

用于加载模型的 ReadMetaGraphDefFromSavedModel() API 函数会调用 reader.cc 中的 ReadSavedModel() 函数:

  1. Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
  2.   LOG(INFO) << "Reading SavedModel from: " << export_dir;
  3.  
  4.   const string saved_model_pb_path =
  5.       io::JoinPath(export_dir, kSavedModelFilenamePb);
  6.   if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
  7.     return ReadBinaryProto(Env::Default(), saved_model_pb_path,
  8.                            saved_model_proto);
  9.   }
  10.   const string saved_model_pbtxt_path =
  11.       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
  12.   if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
  13.     return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
  14.                          saved_model_proto);
  15.   }
  16.   return Status(error::Code::NOT_FOUND,
  17.                 "Could not find SavedModel .pb or .pbtxt at supplied export "
  18.                 "directory path: " +
  19.                     export_dir);
  20. }

可以看到完全没有考虑让这个文件名可以被指定,而且一直到现在最新的 v2.12.0 版本都是这样的。

我之前曾经考虑过从模型下手,试图把三种种不同频率上限的模型合并为一个。搜索过几次 tensorflow saved model merge 之类的关键词,也借助 ChatGPT 修改过 checkpoint 到 saved model 的转换脚本,最终都没成功。而且从把 saved_model.pb 转换为 .pbtext 格式的结果看,整个 protobuf 文件中和频率上限相关的参数非常多,而且看文件里这些参数所在位置,也没法合并。

之前一直想尽量用官方提供的二进制版本,是考虑到这对于杀毒软件比较友好,自己编译的会有全新的 hash 值,有误报的风险。但现在考虑到 SpleeterMsvcExe, 即将发出的 WPF 版 Spleeter GUI 以及 BeatShow Player 程序的易用性,还是打算对 TensorFlow 的源码进行修改,自行编译一个版本来用了。

只是这一点需求,代码还是很好改的。直接在 ReadSavedModel() 中添加一段识别和处理环境变量的代码就可以了:

  1. Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
  2.   LOG(INFO) << "Reading SavedModel from: " << export_dir;
  3.  
  4.   const char* tf_alt_saved_model_pb = getenv("TF_ALT_SAVED_MODEL_PB");
  5.   if (tf_alt_saved_model_pb != nullptr) {
  6.     LOG(INFO) << "Environment variable TF_ALT_SAVED_MODEL_PB is set: " << tf_alt_saved_model_pb;
  7.     const string alt_saved_model_pb_path =
  8.         io::JoinPath(export_dir, tf_alt_saved_model_pb);
  9.     if (Env::Default()->FileExists(alt_saved_model_pb_path).ok()) {
  10.       LOG(INFO) << "Will use " << tf_alt_saved_model_pb << " instead of saved_model.pb";
  11.       return ReadBinaryProto(Env::Default(), alt_saved_model_pb_path,
  12.                              saved_model_proto);
  13.     } else {
  14.       return Status(error::Code::NOT_FOUND,
  15.                     "Could not find the specified .pb file: " + alt_saved_model_pb_path);
  16.     }
  17.   }
  18.  
  19.   const string saved_model_pb_path =
  20.       io::JoinPath(export_dir, kSavedModelFilenamePb);

不用改 API 接口的定义,兼容性和灵活性都比较好。程序中调用 TensorFlow C API 前,设置一下 TF_ALT_SAVED_MODEL_PB 环境变量的值就可以了。

对于 TensorFlow 的编译过程,可以参考上一篇文章:
TensorFlow C API 动态库 v1.15 版本的编译过程

2023-04-22 添加:

修改过的项目已经发到了 GitHub 上了: https://github.com/wudicgi/tensorflow-mod

实际的修改和之前贴的有差异,具体修改可以看 c5cfda2 这个提交。也可以直接下载 release 版本使用: https://github.com/wudicgi/tensorflow-mod/releases/tag/v1.15.5-mod.1
当前语言: 中文 (简体)
请大师动手指导,拒绝低俗  
您的大名
(必填)
电子邮件
(必填,不公开)
个人网站
(可选)
留言
可以使用类似维基标记的语法,点击这里查看说明