-
Notifications
You must be signed in to change notification settings - Fork 223
Custom/gradients dispatch #632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
| unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters; | ||
|
|
||
| // Cast helper (inspired by TF C-API) | ||
| template <typename T, typename U> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you fix this diff to remove all the formatting changes so we can see just the functional changes to CustomGradFunc?
| return false; | ||
| } | ||
|
|
||
| bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please fix the formatting to reduce the diff.
|
This looks like a fairly complicated fix to work around a bug in JavaCPP? Is it not better to fix it there? |
|
Thanks for the question — it’s a fair concern. This change is indeed a workaround for a limitation in JavaCPP (bytedeco/javacpp#1205), where multiple native callbacks of the same kind cannot be reliably registered and invoked. In practice, only the last registered gradient adapter survives, which makes it impossible to support more than one Java custom gradient per process. Fixing this directly in JavaCPP would be ideal in theory, but in practice it is not a viable short- or medium-term option for TensorFlow Java: The issue is deep in JavaCPP’s native callback and lifetime management. TensorFlow Java depends on JavaCPP as an external project, and cannot reasonably block feature development or correctness fixes on changes there. Even with a JavaCPP fix, TensorFlow Java would still need a stable, deterministic way to manage gradient dispatch per op type. For these reasons, this PR follows the same architectural pattern already used by TensorFlow itself. TensorFlow Python does not register one native callback per op. This PR mirrors that design on the Java side: A single native CustomGradFunc is registered with TensorFlow. That function dispatches to the appropriate Java gradient implementation based on op_type. This avoids the JavaCPP limitation entirely, while matching TensorFlow’s own gradient architecture. As a result, the solution is: robust and deterministic, consistent with TensorFlow’s Python design, backward-compatible, and does not require changes to JavaCPP or TensorFlow C++. In short: while the root cause is a JavaCPP limitation, centralizing gradient dispatch is not a hack — it is the same model TensorFlow already uses, adapted to the Java runtime constraints. |
You wanted to register multiple custom gradients in Java using
TensorFlow.registerCustomGradient(...).
Observed symptom:
After registering a few gradients (≈ 5–10),
TFJ_RegisterCustomGradient(opType, adapter) received adapter_ptr = 0 on the C++ side,
which resulted in:
either a refusal to register the gradient,
or a SIGSEGV later during backpropagation.
Key observation:
If the “important” gradient was registered first, it worked.
Subsequent ones failed → this was a cumulative issue, not related to the specific op.
It was not:
a JNI signature bug,
an InfoMap issue,
nor a casting or ABI problem.
👉 The real cause was a limitation in JavaCPP FunctionPointer callbacks:
each TFJ_GradFuncAdapter allocates a native thunk,
after a certain number of such allocations, JavaCPP silently passes a null pointer (0),
the TensorFlow C++ runtime then receives an invalid callback pointer.
👉 Conclusion:
Creating one native callback per gradient is not scalable.
Instead of:
1 gradient = 1 native callback
We switched to:
1 single native callback
with dispatching in Java based on opType
This is exactly how TensorFlow does it in Python on the C++ side.
A. A Single Native Callback (Singleton)
A single TFJ_GradFuncAdapter instance
Registered with TensorFlow C++ for all ops
As a result:
no more adapter_ptr = 0
no practical limit on the number of custom gradients
B. Java-side Dispatch by opType
A Java dispatcher selects the correct gradient during backpropagation:
TensorFlow C++
↓
CustomGradFunc (C++)
↓
TFJ_GradFuncAdapter.call(...)
↓
DispatchingGradientAdapter.apply(...)
↓
CustomGradient / RawCustomGradient for the corresponding op
Problem
NativeScope and Ops have package-private constructors
They are only accessible from org.tensorflow.op
Solution
DispatchingGradientAdapter is package-private and lives in org.tensorflow.op
A public GradientDispatch class acts as a bridge
TensorFlow.java only sees the public TFJ_GradFuncAdapter type
➡️ This strictly respects TensorFlow Java’s internal design, with no hacks.
Problem
Returning null on the Java side caused a NullPointerException
The native code did not correctly support TF_Output.oper == nullptr
Fixes
Java side (AbstractGradientAdapter):
null is now translated into:
TF_Output { oper = nullptr, index = 0 }
C++ side (CustomGradFunc):
out.oper == nullptr is interpreted as NoGradient
No dangerous dereference
No crashes / no SIGSEGV
Applied corrections:
Removed a double loop that was adding gradients twice
Consistent handling of NoGradient
Single, safe memory deallocation (free(outputs))
Preserved defensive hardening:
checks on num_outputs
outputs == nullptr
etc.
What now works
✔ Registering dozens (or hundreds) of custom gradients
✔ Registration order no longer matters
✔ No more adapter_ptr = 0
✔ No JNI crashes / no SIGSEGV
✔ Proper support for partial gradients (NoGradient)
✔ Architecture aligned with native TensorFlow
What was avoided
❌ Fragile JavaCPP patches
❌ Dependency on internal allocation details
❌ Workarounds based on registration order
We replaced a non-scalable architecture (“N gradients = N native callbacks”) with a scalable one (“1 native callback + Java dispatch”), while properly fixing NoGradient handling and strictly respecting TensorFlow Java’s internal constraints.