Skip to content

Commit 846b6cf

Browse files
Craigacpnfeybesse
authored andcommitted
Minimising the custom gradient fix.
1 parent dc6ff5f commit 846b6cf

File tree

4 files changed

+0
-184
lines changed

4 files changed

+0
-184
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,9 @@
1717
package org.tensorflow.op;
1818

1919
import java.util.List;
20-
import org.bytedeco.javacpp.PointerPointer;
2120
import org.tensorflow.Operand;
2221
import org.tensorflow.Output;
2322
import org.tensorflow.TensorFlow;
24-
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
25-
import org.tensorflow.internal.c_api.TFJ_GraphId;
26-
import org.tensorflow.internal.c_api.TFJ_Scope;
27-
import org.tensorflow.internal.c_api.TF_Operation;
28-
import org.tensorflow.internal.c_api.TF_Output;
2923

3024
/**
3125
* A custom gradient for ops of type {@link T}. Should be registered using {@link
@@ -53,40 +47,4 @@ public interface CustomGradient<T extends RawOpInputs> {
5347
* @return the gradients of the op's inputs.
5448
*/
5549
List<Operand<?>> call(Ops tf, T op, List<Output<?>> gradInputs);
56-
57-
/**
58-
* Create an adapter for the custom gradient so that it can be used by native code.
59-
*
60-
* <p>You should not be calling this yourself, use {@link TensorFlow#registerCustomGradient(Class,
61-
* CustomGradient)}.
62-
*/
63-
static <T extends RawOpInputs<?>> TFJ_GradFuncAdapter adapter(
64-
CustomGradient<T> gradient, Class<T> opClass) {
65-
66-
final TypedGradientAdapter<T> impl = new TypedGradientAdapter<T>(gradient, opClass);
67-
68-
// IMPORTANT:
69-
// Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function
70-
// pointer thunk for the native side. Some call paths may pass NULL if we return a deeper
71-
// subclass.
72-
return new TFJ_GradFuncAdapter() {
73-
@Override
74-
public int call(
75-
TFJ_GraphId nativeGraphId,
76-
TFJ_Scope nativeScope,
77-
TF_Operation nativeOperation,
78-
TF_Output nativeGradInputs,
79-
int nativeGradInputsLength,
80-
PointerPointer nativeGradOutputsPtr) {
81-
82-
return impl.call(
83-
nativeGraphId,
84-
nativeScope,
85-
nativeOperation,
86-
nativeGradInputs,
87-
nativeGradInputsLength,
88-
nativeGradOutputsPtr);
89-
}
90-
};
91-
}
9250
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,10 @@
1717
package org.tensorflow.op;
1818

1919
import java.util.List;
20-
import org.bytedeco.javacpp.PointerPointer;
2120
import org.tensorflow.GraphOperation;
2221
import org.tensorflow.Operand;
2322
import org.tensorflow.Output;
2423
import org.tensorflow.TensorFlow;
25-
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
26-
import org.tensorflow.internal.c_api.TFJ_GraphId;
27-
import org.tensorflow.internal.c_api.TFJ_Scope;
28-
import org.tensorflow.internal.c_api.TF_Operation;
29-
import org.tensorflow.internal.c_api.TF_Output;
3024

3125
/**
3226
* A custom gradient for an op of unspecified type. Should be registered using {@link
@@ -51,38 +45,4 @@ public interface RawCustomGradient {
5145
* @return the gradients of the op's inputs.
5246
*/
5347
List<Operand<?>> call(Ops tf, GraphOperation op, List<Output<?>> gradInputs);
54-
55-
/**
56-
* Create an adapter for the custom gradient so that it can be used by native code.
57-
*
58-
* <p>You should not be calling this yourself, use {@link
59-
* TensorFlow#registerCustomGradient(String, RawCustomGradient)}.
60-
*/
61-
static TFJ_GradFuncAdapter adapter(RawCustomGradient gradient) {
62-
final RawGradientAdapter impl = new RawGradientAdapter(gradient);
63-
64-
// IMPORTANT:
65-
// Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function
66-
// pointer thunk for the native side. Some call paths may pass NULL if we return a deeper
67-
// subclass.
68-
return new TFJ_GradFuncAdapter() {
69-
@Override
70-
public int call(
71-
TFJ_GraphId nativeGraphId,
72-
TFJ_Scope nativeScope,
73-
TF_Operation nativeOperation,
74-
TF_Output nativeGradInputs,
75-
int nativeGradInputsLength,
76-
PointerPointer nativeGradOutputsPtr) {
77-
78-
return impl.call(
79-
nativeGraphId,
80-
nativeScope,
81-
nativeOperation,
82-
nativeGradInputs,
83-
nativeGradInputsLength,
84-
nativeGradOutputsPtr);
85-
}
86-
};
87-
}
8848
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java

Lines changed: 0 additions & 44 deletions
This file was deleted.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)