2023-04-20 22:32:10 · Author: 五帝 · Tagged with: Hack, Computer Software, C/C++, Open Source | Comments
我在 2021 年 5 月份发出了一个 SpleeterMsvcExe 开源项目,当时是支持 11kHz 和 16kHz 两种模型,年底的时候又添加了对 22kHz 模型的支持。这三种不同频率上限的模型, variables 目录中的文件是完全相同的,只有 saved_model.pb 不同。但 TensorFlow 源码中这个文件名在 constants.h 中被定义为了固定值:
用于加载模型的 ReadMetaGraphDefFromSavedModel() API 函数会调用 reader.cc 中的 ReadSavedModel() 函数:
可以看到完全没有考虑让这个文件名可以被指定,而且一直到现在最新的 v2.12.0 版本都是这样的。
我之前曾经考虑过从模型下手,试图把三种种不同频率上限的模型合并为一个。搜索过几次 tensorflow saved model merge 之类的关键词,也借助 ChatGPT 修改过 checkpoint 到 saved model 的转换脚本,最终都没成功。而且从把 saved_model.pb 转换为 .pbtext 格式的结果看,整个 protobuf 文件中和频率上限相关的参数非常多,而且看文件里这些参数所在位置,也没法合并。
之前一直想尽量用官方提供的二进制版本,是考虑到这对于杀毒软件比较友好,自己编译的会有全新的 hash 值,有误报的风险。但现在考虑到 SpleeterMsvcExe, 即将发出的 WPF 版 Spleeter GUI 以及 BeatShow Player 程序的易用性,还是打算对 TensorFlow 的源码进行修改,自行编译一个版本来用了。
只是这一点需求,代码还是很好改的。直接在 ReadSavedModel() 中添加一段识别和处理环境变量的代码就可以了:
不用改 API 接口的定义,兼容性和灵活性都比较好。程序中调用 TensorFlow C API 前,设置一下 TF_ALT_SAVED_MODEL_PB 环境变量的值就可以了。
对于 TensorFlow 的编译过程,可以参考上一篇文章:
《TensorFlow C API 动态库 v1.15 版本的编译过程》
2023-04-22 添加:
修改过的项目已经发到了 GitHub 上了: https://github.com/wudicgi/tensorflow-mod
实际的修改和之前贴的有差异,具体修改可以看 c5cfda2 这个提交。也可以直接下载 release 版本使用: https://github.com/wudicgi/tensorflow-mod/releases/tag/v1.15.5-mod.1
- // SavedModel proto filename.
- constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
- // SavedModel text format proto filename.
- constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
用于加载模型的 ReadMetaGraphDefFromSavedModel() API 函数会调用 reader.cc 中的 ReadSavedModel() 函数:
- Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
- LOG(INFO) << "Reading SavedModel from: " << export_dir;
- const string saved_model_pb_path =
- io::JoinPath(export_dir, kSavedModelFilenamePb);
- if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
- return ReadBinaryProto(Env::Default(), saved_model_pb_path,
- saved_model_proto);
- }
- const string saved_model_pbtxt_path =
- io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
- if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
- return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
- saved_model_proto);
- }
- return Status(error::Code::NOT_FOUND,
- "Could not find SavedModel .pb or .pbtxt at supplied export "
- "directory path: " +
- export_dir);
- }
可以看到完全没有考虑让这个文件名可以被指定,而且一直到现在最新的 v2.12.0 版本都是这样的。
我之前曾经考虑过从模型下手,试图把三种种不同频率上限的模型合并为一个。搜索过几次 tensorflow saved model merge 之类的关键词,也借助 ChatGPT 修改过 checkpoint 到 saved model 的转换脚本,最终都没成功。而且从把 saved_model.pb 转换为 .pbtext 格式的结果看,整个 protobuf 文件中和频率上限相关的参数非常多,而且看文件里这些参数所在位置,也没法合并。
之前一直想尽量用官方提供的二进制版本,是考虑到这对于杀毒软件比较友好,自己编译的会有全新的 hash 值,有误报的风险。但现在考虑到 SpleeterMsvcExe, 即将发出的 WPF 版 Spleeter GUI 以及 BeatShow Player 程序的易用性,还是打算对 TensorFlow 的源码进行修改,自行编译一个版本来用了。
只是这一点需求,代码还是很好改的。直接在 ReadSavedModel() 中添加一段识别和处理环境变量的代码就可以了:
- Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
- LOG(INFO) << "Reading SavedModel from: " << export_dir;
- const char* tf_alt_saved_model_pb = getenv("TF_ALT_SAVED_MODEL_PB");
- if (tf_alt_saved_model_pb != nullptr) {
- LOG(INFO) << "Environment variable TF_ALT_SAVED_MODEL_PB is set: " << tf_alt_saved_model_pb;
- const string alt_saved_model_pb_path =
- io::JoinPath(export_dir, tf_alt_saved_model_pb);
- if (Env::Default()->FileExists(alt_saved_model_pb_path).ok()) {
- LOG(INFO) << "Will use " << tf_alt_saved_model_pb << " instead of saved_model.pb";
- return ReadBinaryProto(Env::Default(), alt_saved_model_pb_path,
- saved_model_proto);
- } else {
- return Status(error::Code::NOT_FOUND,
- "Could not find the specified .pb file: " + alt_saved_model_pb_path);
- }
- }
- const string saved_model_pb_path =
- io::JoinPath(export_dir, kSavedModelFilenamePb);
不用改 API 接口的定义,兼容性和灵活性都比较好。程序中调用 TensorFlow C API 前,设置一下 TF_ALT_SAVED_MODEL_PB 环境变量的值就可以了。
对于 TensorFlow 的编译过程,可以参考上一篇文章:
《TensorFlow C API 动态库 v1.15 版本的编译过程》
2023-04-22 添加:
修改过的项目已经发到了 GitHub 上了: https://github.com/wudicgi/tensorflow-mod
实际的修改和之前贴的有差异,具体修改可以看 c5cfda2 这个提交。也可以直接下载 release 版本使用: https://github.com/wudicgi/tensorflow-mod/releases/tag/v1.15.5-mod.1
Current language: 中文 (简体)