Skip to content

Commit dc6ff5f

Browse files
committed
Use single native dispatch adapter for custom gradients
Replace the native per-op unordered_map of TFJ_GradFuncAdapter with a single global dispatch adapter. The native layer now registers CustomGradFunc per op type in the GradOpRegistry, but always calls the same TFJ_GradFuncAdapter instance. All opType-based routing is handled on the Java side by DispatchingGradientAdapter. This aligns the native implementation with the intended design: there is only one native function pointer registered, and dispatch logic lives entirely in Java. Also fixes unsafe casting of Scope* to TFJ_Scope* by constructing a temporary TFJ_Scope wrapper instead.
1 parent 767726e commit dc6ff5f

File tree

1 file changed: 21 additions and 15 deletions.
  • tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api

1 file changed: 21 additions and 15 deletions.

tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ limitations under the License.
1919
#include <stdlib.h>
2020

2121
#include <string>
22-
#include <unordered_map>
2322
#include <vector>
2423

2524
#include "tfj_graph.h"
25+
#include "tfj_scope.h"
2626
#include "tsl/platform/errors.h"
2727
#include "tensorflow/c/c_api.h"
2828
#include "tensorflow/c/tf_status.h"
@@ -33,7 +33,8 @@ namespace tensorflow {
3333
using namespace tsl;
3434
using namespace std;
3535

36-
unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;
36+
// Single Java-side dispatcher entry point (no native per-op map).
37+
static TFJ_GradFuncAdapter g_dispatch_adapter = NULL;
3738

3839
/// This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit
3940
///
@@ -53,14 +54,10 @@ namespace tensorflow {
5354
vector<Output>* grad_outputs)
5455
{
5556
const string& op_type = op.node()->type_string();
56-
auto found_adapter = g_grad_func_adapters.find(op_type);
57-
if (found_adapter == g_grad_func_adapters.end()) {
58-
return errors::NotFound("No gradient adapter found for operation ", op_type);
59-
}
6057

61-
TFJ_GradFuncAdapter adapter = found_adapter->second;
58+
TFJ_GradFuncAdapter adapter = g_dispatch_adapter;
6259
if (adapter == NULL) {
63-
return errors::Unknown("Null Java gradient adapter for operation ", op_type);
60+
return errors::Unknown("Null Java dispatch gradient adapter for operation ", op_type);
6461
}
6562

6663
int num_inputs = grad_inputs.size();
@@ -81,9 +78,13 @@ namespace tensorflow {
8178

8279
TF_Output* outputs = NULL;
8380
LOG(INFO) << "Calling Java gradient function for operation of type " << op_type;
81+
82+
// IMPORTANT: TFJ_Scope is a wrapper struct (see tfj_scope_impl.cc). Do not cast Scope* to TFJ_Scope*.
83+
TFJ_Scope tfj_scope{scope};
84+
8485
int num_outputs = adapter(
8586
static_cast<TFJ_GraphId>(scope.graph()),
86-
struct_cast<TFJ_Scope>(const_cast<Scope*>(&scope)),
87+
&tfj_scope,
8788
struct_cast<TF_Operation>(op.node()),
8889
inputs,
8990
num_inputs,
@@ -142,16 +143,21 @@ bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_fu
142143
return false;
143144
}
144145

146+
// Only a single native function pointer is used: it must always be the same dispatch adapter.
147+
if (g_dispatch_adapter == NULL) {
148+
g_dispatch_adapter = grad_func_adapter;
149+
} else if (g_dispatch_adapter != grad_func_adapter) {
150+
LOG(ERROR) << "Refusing to register a different Java dispatch gradient adapter";
151+
return false;
152+
}
153+
145154
if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash
146155
LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type
147156
<< ", which has already a registered function";
148157
return false;
149158
}
150-
bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc);
151-
if (registered) {
152-
g_grad_func_adapters.insert({op_type, grad_func_adapter});
153-
}
154-
return registered;
159+
160+
return GradOpRegistry::Global()->Register(op_type, CustomGradFunc);
155161
}
156162

157163
#else // #ifndef _WIN32
@@ -161,4 +167,4 @@ bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_fu
161167
bool TFJ_HasGradient(const char* op_type) { return true; }
162168
bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { return false; }
163169

164-
#endif // #ifndef _WIN32
170+
#endif // #ifndef _WIN32

0 commit comments

Comments
 (0)