其他
面向隐私AI的TensorFlow深度定制化实践
TensorFlow 的核心概念
Tensor(张量)
Operation(算子)
Graph(计算图)
Session(会话)
TensorFlow 的 codebase 本身还是很复杂的,篇幅所限,难以在此对 TensorFlow 进行深入的介绍,感兴趣的读者可以参考 TensorFlow 社区中其他优秀文章以进一步学习。
TensorFlow 自定义算子库的扩展方法
有时候基于现有 TF 原生算子表达上层自定义逻辑很困难,而在后端实现则更灵活自由; 通过后端 Custom C++ op,可以以更加高效的方式实现自己的逻辑,可以在其中进行更底层的、面向编译器等的各种优化;
REGISTER_OP("RttMatmul")
.Input("x: string")
.Input("y: string")
.Output("res: string")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
.SetIsStateful();
template <typename Device>
class RttMatMulOp : public OpKernel {
public:
explicit RttMatMulOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
}
void Compute(OpKernelContext* context) override {
// Check if the dimensions of the two matrices are valid
const Tensor& x = context->input(0);
const Tensor& y = context->input(1);
// detailed implementation...
}
}
REGISTER_KERNEL_BUILDER(Name("RttMatmul").Device(DEVICE_CPU), RttMatMulOp);
# load librtt_ops.so
_rtt_ops_lib = os.path.dirname(__file__) + '/../../../librtt-ops.so'
rtt_ops = tf.load_op_library(_rtt_ops_lib)
# now, you can use the ops in this library as rtt_ops.rtt_matmul
RttOp 算子库与后端 MPC 隐私计算完全无关的辅助中间层,一系列的“浮标”置位算子,支持自定义 Tensor 类型。其内部默认的实现逻辑是和对应的 TF 原生算子一样的。SecureOp 算子库完整的前后端算子库,注册了对应的梯度函数;在内部实现中调用隐私协议层的抽象算子接口实现和 TF 的对接。
计算图的转换构建过程
引入 rosetta 库时
调用 Session.run 接口时
@ops.RegisterGradient("SecureMatmul")
def SecureMatMulGrad(op, grad):
"""The gradient for the Secure MatMul operator."""
t_a = op.get_attr("transpose_a")
t_b = op.get_attr("transpose_b")
a = op.inputs[0]
b = op.inputs[1]
if not t_a and not t_b:
grad_a = SecureMatMul(grad, b, transpose_b=True)
grad_b = SecureMatMul(a, grad, transpose_a=True)
elif not t_a and t_b:
grad_a = SecureMatMul(grad, b)
grad_b = SecureMatMul(grad, a, transpose_a=True)
elif t_a and not t_b:
grad_a = SecureMatMul(b, grad, transpose_b=True)
grad_b = SecureMatMul(a, grad)
elif t_a and t_b:
grad_a = SecureMatMul(b, grad, transpose_a=True, transpose_b=True)
grad_b = SecureMatMul(grad, a, transpose_a=True, transpose_b=True)
return grad_a, grad_b
补充说明
并非所有的算子都需要转换为SecureOp,这是因为如果一个局部子图中全部的输入都是本地的常量(公开的写定到代码中的数据,无需保护),那么就没有必要将这个子图转换为多方协作的隐私计算方式计算,这样可以减少不必要的计算时间。 转换时,由于此时知道了即将运行的完整子图的信息,比如 DAG 图上有多少了算子需要运行,所以可以在这里进行一些定制化的优化,比如优化底层协议中多方之间的并发通讯。
// call protocol ops
vector<string> outstr(m*n);
ProtocolManager::Instance()->GetProtocol()->GetOps(msg_id().str())->Matmul(in1, in2, outstr, &attrs_);