Implementing Automatic Data Type Promotion in OneFlow
1
Introducing the Problem
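import torch

x_tensor = torch.ones((3,), dtype=torch.int8)  # assumed definition: an int8 tensor, consistent with the dtypes reported below
y1_tensor = torch.tensor(1, dtype=torch.float64)
out1 = torch.mul(x_tensor, y1_tensor)
y2_tensor = torch.tensor(1, dtype=torch.int64)
out2 = torch.mul(x_tensor, y2_tensor)
out3 = torch.mul(x_tensor, 1.0)
out4 = torch.mul(x_tensor, 2**63 - 1)  # 2**63 - 1 is the max value of int64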
out2.dtype: torch.int8
out3.dtype: torch.float32
out4.dtype: torch.int8
2
The Python Array API Standard
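Promotion between different data types follows the connections in this diagram.
A dashed line means the behavior for Python scalars is undefined on overflow.
There are no lines connecting bool, int, and float, which means promotion between these mixed kinds is undefined.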
Take int8 and uint8: both ultimately point to int16, which means an operation between the two promotes the result to int16.
Taking the unsigned int and signed int families as an example, the resulting table is:
For more type promotion tables, see the link mentioned earlier.
i1 : 8-bit signed integer (i.e., int8)
i2 : 16-bit signed integer (i.e., int16)
i4 : 32-bit signed integer (i.e., int32)
i8 : 64-bit signed integer (i.e., int64)
The unsigned int types are abbreviated in the same way.
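To make the signed/unsigned rules concrete, here is a minimal pure-Python sketch using the shorthand above. `promote_int` is a hypothetical helper written for this article, not part of any library. The rule it encodes is my reading of the Array API mixed-integer table: the result is a signed type wide enough to hold both operands, and the combination is undefined when no such 64-bit type exists.

```python
def promote_int(a: str, b: str) -> str:
    """Promote two integer dtypes written in shorthand ("i1", "u4", ...)."""
    kind_a, bytes_a = a[0], int(a[1:])
    kind_b, bytes_b = b[0], int(b[1:])
    if kind_a == kind_b:
        # Same family: the wider type wins.
        return f"{kind_a}{max(bytes_a, bytes_b)}"
    # Mixed families: work out which operand is signed and which is unsigned.
    signed, unsigned = (bytes_a, bytes_b) if kind_a == "i" else (bytes_b, bytes_a)
    if unsigned < signed:
        # The signed type can already represent every unsigned value.
        return f"i{signed}"
    if unsigned * 2 > 8:
        # No signed type up to 64 bits can hold the unsigned range.
        raise TypeError(f"promotion of {a} and {b} is undefined")
    # Otherwise take the next wider signed type.
    return f"i{unsigned * 2}"


print(promote_int("i1", "u1"))  # int8 with uint8
print(promote_int("i4", "u2"))  # int32 with uint16
```

The first call reproduces the int8/uint8 case discussed above, which lands on int16 (`i2`).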
Type Promotion Between Python Arrays and Scalars
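If both belong to the same data type family (for example, both are in the int family, which includes int8, int32, and int64), the final data type follows the array's data type.
If they do not belong to the same data type family (for example, one is int32 and the other is float), type promotion is performed.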
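x_tensor = torch.ones((3,), dtype=torch.int16)  # assumed definition, inferred from the int16 result below
out1 = x_tensor + 2    # out1.dtype = torch.int16
out2 = x_tensor + 2.0  # out2.dtype = torch.float32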
y1_tensor = torch.tensor(2)
y2_tensor = torch.tensor(2.0)
out1 = x_tensor + y1_tensor  # out1.dtype = torch.int16
out2 = x_tensor + y2_tensor  # out2.dtype = torch.float32
import numpy as np

x = np.ones((3, 3), dtype=np.int32)
out = x + (2**31 - 1)  # dtype: np.int32
out = x + (2**31)      # dtype: np.int64 (value-based casting; NumPy 1.x behavior)
Personally, I prefer that in type promotion a Scalar has its own distinct behavior, while a Scalar Tensor behaves the same as a Tensor.
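The array-and-scalar rule described in this section can be sketched in a few lines of plain Python. Everything here is illustrative: `FAMILY`, `DEFAULT_DTYPE`, and `result_dtype` are hypothetical names. I also assume that a scalar from a lower family never changes the array's dtype, and that a float scalar defaults to float32, as in the PyTorch examples above.

```python
# Family rank per dtype: bool < int < float (a simplification for this sketch).
FAMILY = {
    "bool": 0,
    "int8": 1, "int16": 1, "int32": 1, "int64": 1,
    "float16": 2, "float32": 2, "float64": 2,
}

# Assumed default dtype for each Python scalar type.
DEFAULT_DTYPE = {bool: "bool", int: "int64", float: "float32"}


def result_dtype(array_dtype: str, scalar) -> str:
    """Return the dtype of `array op scalar` under the rule above."""
    scalar_dtype = DEFAULT_DTYPE[type(scalar)]
    if FAMILY[scalar_dtype] <= FAMILY[array_dtype]:
        # Same (or lower) family: the array's dtype is kept as-is.
        return array_dtype
    # Higher family: promote to the scalar's default dtype.
    return scalar_dtype


print(result_dtype("int16", 2))
print(result_dtype("int16", 2.0))
```

The key design point is that the scalar's value is never inspected, only its family, which is what separates this rule from the value-based NumPy example above.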
3
Other Cases
The two inputs are required to have exactly the same data type, as with torch.dot.
The input has a lowest data type, as with torch.sum: whichever int-family data type is passed in, the output is always torch.int64.
4
How Does PyTorch Do Type Promotion?
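PyTorch inserts a to op that promotes the input tensors before they enter the kernel for the actual computation. The walkthrough below follows the PyTorch source code:
ScalarType.h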
Activation.cpp
Take threshold as an example:
  const Tensor& result = maybe_get_output();
  build(TensorIteratorConfig()
    ...
    .promote_inputs_to_common_dtype(true)
}
The build function accepts a TensorIteratorConfig. This Config class is used to configure various properties; here we can see that promote_inputs_to_common_dtype is called and set to true.
TensorIterator.cpp
Here the compute_types function is called:
  compute_types(config);
  ...
TensorIterator is a container class (NumPy has a similar container, NpyIter) that stores the output and input tensors. Inside, several for loops are used to derive a common_dtype.
If promote_inputs_to_common_dtype_ is true, the current tensor is not an output tensor, and the input's dtype is not equal to the derived common_dtype, a type promotion is performed:
if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
  op.original_tensor = op.tensor;
  op.tensor = c10::MaybeOwned<Tensor>::owned(op.tensor->to(common_dtype_));
  op.current_dtype = common_dtype_;
  op.target_dtype = common_dtype_;
}
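The condition above is easier to see in a stripped-down model. The following is a sketch in plain Python, not PyTorch code: `Operand`, `RANK`, and `promote_inputs` are hypothetical names, and the linear `RANK` ordering is a simplifying assumption that stands in for the much richer rules in compute_types.

```python
from dataclasses import dataclass

# Hypothetical linear ordering used only for this sketch; the higher rank wins.
RANK = {"int8": 0, "int16": 1, "int32": 2, "int64": 3, "float32": 4, "float64": 5}


@dataclass
class Operand:
    name: str
    dtype: str
    is_output: bool = False


def promote_inputs(operands):
    """Derive a common dtype from the inputs and cast mismatched inputs to it."""
    inputs = [op for op in operands if not op.is_output]
    common = max((op.dtype for op in inputs), key=RANK.__getitem__)
    casts = []
    for op in operands:
        # Mirrors the condition above: skip outputs, and only touch inputs
        # whose current dtype differs from the common dtype.
        if not op.is_output and op.dtype != common:
            casts.append((op.name, op.dtype, common))
            op.dtype = common
    return common, casts


ops = [Operand("out", "float32", is_output=True),
       Operand("a", "int8"),
       Operand("b", "float32")]
common, casts = promote_inputs(ops)
print(common, casts)
```

Only operand `a` is recorded as needing a cast; `b` already has the common dtype, and the output operand is left untouched.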
5
OneFlow's Approach
The TensorProcessor class has the following interface:
class TensorProcessor final {
 public:
  TensorProcessor()
      : common_dtype_(DType::InvalidDataType()), promote_inputs_to_common_dtype_(false){};
  TensorProcessor& AddInputs(const TensorTuple& init_list);
  TensorProcessor& AddInputs(const TensorTuple& init_list, Symbol<DType> tensor_lowest_dtype);
  Maybe<void> Apply();
  TensorProcessor& PromoteInputsToCommonDtype(bool is_promote);
  Maybe<TensorTuple&> GetInputs() { return tensor_tuple_; };
 private:
  TensorTuple tensor_tuple_;
  Symbol<DType> common_dtype_;
  std::vector<Symbol<DType>> inputs_lowest_dtype_vec_;
  bool promote_inputs_to_common_dtype_;
};
 public:
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
                           const std::shared_ptr<one::Tensor>& y) const {
    TensorProcessor tensor_processor;
    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply());
    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
    return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple);
    ...
  }
  ...
};
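PromoteInputsToCommonDtype sets the corresponding attribute.
AddInputs adds the tensors that need to take part in type promotion to the container.
Apply runs the actual type promotion logic.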
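tensor_processor.cpp also contains several other functions. Briefly, they do the following:
CheckHasDifferentInputDType iterates over the input tensors and checks whether any of them have different dtypes.
ComputeCommonDType derives a reasonable promoted dtype from the input dtypes.
CastToSameType inserts a Cast op for the input tensors.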
  for (auto& tensor_ptr : tensor_tuple) {
    if (tensor_ptr->dtype() != common_dtype) {
      tensor_ptr = JUST(functional::Cast(tensor_ptr, common_dtype));
    }
  }
  return Maybe<void>::Ok();
}
  if (promote_inputs_to_common_dtype_) {
    bool has_different_input_dtype = CheckHasDifferentInputDType(tensor_tuple_);
    if (has_different_input_dtype) {
      common_dtype_ = ComputeCommonDType(tensor_tuple_);
      JUST(CastToSameType(tensor_tuple_, common_dtype_));
    }
  } else {
    for (int i = 0; i < tensor_tuple_.size(); ++i) {
      // Cast each input to its attribute `lowest_dtype` if the input tensor dtype is lower
      // than attribute `lowest_dtype`.
      Symbol<DType> base_dtype = inputs_lowest_dtype_vec_.at(i);
      if (base_dtype->data_type()
          && DType::priority_order[base_dtype->data_type()]
                 > DType::priority_order[tensor_tuple_.at(i)->dtype()->data_type()]) {
        tensor_tuple_.at(i) = JUST(one::functional::Cast(tensor_tuple_.at(i), base_dtype));
      }
    }
  }
  return Maybe<void>::Ok();
}
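The two branches of this logic can be modeled on plain dtype names. This is a rough Python sketch, not OneFlow code: `PRIORITY` is a made-up ordering rather than the real `DType::priority_order`, `apply` is a hypothetical stand-in for Apply, and taking the highest-priority dtype is my simplification of ComputeCommonDType.

```python
# Hypothetical priority table for this sketch only; higher means "wider".
PRIORITY = {"int8": 0, "int32": 1, "int64": 2, "float32": 3, "float64": 4}


def apply(dtypes, promote=False, lowest=None):
    """Return the dtype of each input after processing.

    promote=True : cast every input to a common dtype if the dtypes differ.
    promote=False: raise each input to its per-input lowest dtype, if one is set.
    """
    if promote:
        if len(set(dtypes)) > 1:  # the inputs have different dtypes
            common = max(dtypes, key=PRIORITY.__getitem__)
            return [common] * len(dtypes)
        return list(dtypes)
    lowest = lowest or [None] * len(dtypes)
    return [
        base if base is not None and PRIORITY[base] > PRIORITY[d] else d
        for d, base in zip(dtypes, lowest)
    ]


print(apply(["int8", "float32"], promote=True))
print(apply(["int8"], lowest=["int64"]))
print(apply(["float32"], lowest=["int64"]))
```

In the last call the input keeps float32 because it already ranks above the lowest dtype, which matches the strict `>` comparison in the C++ code.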
For the sum operator, this is how we set the lowest data type to int64:
 public:
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,
                           const bool& keepdims) const {
    ...
    TensorProcessor tensor_processor;
    JUST(tensor_processor.AddInputs({x}, /*lowest_dtype=*/DType::Int64()).Apply());
    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
  }
  ...
};