50 std::is_same_v<T, TOut> ||
51 (std::is_integral_v<T> && std::is_signed_v<T> &&
52 std::is_integral_v<TOut> && std::is_unsigned_v<TOut> &&
53 sizeof(T) ==
sizeof(TOut)),
54 "DifferenceStage: TOut must equal T, or T must be a signed integer "
55 "and TOut its unsigned counterpart of the same width (negabinary fusion).");
57 DifferenceStage() : actual_output_size_(0), is_inverse_(
false), chunk_size_(0) {}
59 void setInverse(
bool inverse)
override { is_inverse_ = inverse; }
60 bool isInverse()
const override {
return is_inverse_; }
70 size_t getChunkSize()
const {
return chunk_size_; }
72 return chunk_size_ > 0 ? chunk_size_ : 1;
78 const std::vector<void*>& inputs,
79 const std::vector<void*>& outputs,
80 const std::vector<size_t>& sizes
84 std::string
getName()
const override {
return "Difference"; }
85 size_t getNumInputs()
const override {
return 1; }
86 size_t getNumOutputs()
const override {
return 1; }
89 const std::vector<size_t>& input_sizes
91 return {input_sizes[0]};
95 return {{
"output", actual_output_size_}};
98 return (index == 0) ? actual_output_size_ : 0;
107 return static_cast<uint8_t
>(getOutDataTypeEnum());
111 return static_cast<uint8_t
>(getInDataTypeEnum());
// Writes this stage's header into buf.
//
// Layout produced by the statements visible here:
//   byte 0     : input element type tag  (getInDataTypeEnum() as uint8_t)
//   byte 1     : output element type tag (getOutDataTypeEnum() as uint8_t)
//   bytes 2..5 : chunk_size_ as a uint32_t, copied with memcpy (host byte order)
//
// Returns 0 without touching buf when max_size is below 6.
// output_index is not used by the visible statements.
// The rest of the function (including the success return value) is outside
// this excerpt.
114 size_t serializeHeader(
size_t output_index, uint8_t* buf,
size_t max_size)
const override {
// Not enough room for the 6-byte header: report 0 bytes written.
116 if (max_size < 6)
return 0;
// One byte each for the input and output element type tags.
117 buf[0] =
static_cast<uint8_t
>(getInDataTypeEnum());
118 buf[1] =
static_cast<uint8_t
>(getOutDataTypeEnum());
// NOTE(review): the cast narrows chunk_size_ to 32 bits; if chunk_size_ is a
// wider type, values above UINT32_MAX would wrap silently - confirm callers
// bound it.
119 uint32_t cs =
static_cast<uint32_t
>(chunk_size_);
// memcpy rather than a pointer cast: buf + 2 need not be 4-byte aligned.
120 std::memcpy(buf + 2, &cs,
sizeof(uint32_t));
129 std::memcpy(&cs, buf + 2,
sizeof(uint32_t));
140 saved_chunk_size_ = chunk_size_;
141 saved_actual_output_size_ = actual_output_size_;
// Copies the saved snapshot members back into the live state:
// chunk_size_ <- saved_chunk_size_ and
// actual_output_size_ <- saved_actual_output_size_.
// (The closing brace of this function is outside this excerpt.)
144 void restoreState()
override {
145 chunk_size_ = saved_chunk_size_;
146 actual_output_size_ = saved_actual_output_size_;
// Size reported for this stage's single output; has no in-class initializer
// and is set to 0 by the default constructor.
150 size_t actual_output_size_;
// Snapshot of actual_output_size_, copied back by restoreState().
151 size_t saved_actual_output_size_ = 0;
// Snapshot of chunk_size_, copied back by restoreState().
154 size_t saved_chunk_size_ = 0;
// Maps the input element type T to its DataType tag.
// Each std::is_same_v test is a compile-time constant, so at most one branch
// can match for a given instantiation. Covered types: uint8/16/32/64,
// int8/16/32/64, float, double.
// Any other T falls through to DataType::UINT8.
// (The closing brace of this function is outside this excerpt.)
157 DataType getInDataTypeEnum()
const {
158 if (std::is_same_v<T, uint8_t>)
return DataType::UINT8;
159 if (std::is_same_v<T, uint16_t>)
return DataType::UINT16;
160 if (std::is_same_v<T, uint32_t>)
return DataType::UINT32;
161 if (std::is_same_v<T, uint64_t>)
return DataType::UINT64;
162 if (std::is_same_v<T, int8_t>)
return DataType::INT8;
163 if (std::is_same_v<T, int16_t>)
return DataType::INT16;
164 if (std::is_same_v<T, int32_t>)
return DataType::INT32;
165 if (std::is_same_v<T, int64_t>)
return DataType::INT64;
166 if (std::is_same_v<T, float>)
return DataType::FLOAT32;
167 if (std::is_same_v<T, double>)
return DataType::FLOAT64;
// Fallback for a T that matched none of the types above.
168 return DataType::UINT8;
// Maps the output element type TOut to its DataType tag.
// Each std::is_same_v test is a compile-time constant, so at most one branch
// can match for a given instantiation. Covered types: uint8/16/32/64,
// int8/16/32/64, float, double.
// Any other TOut falls through to DataType::UINT8.
// (The closing brace of this function is outside this excerpt.)
171 DataType getOutDataTypeEnum()
const {
172 if (std::is_same_v<TOut, uint8_t>)
return DataType::UINT8;
173 if (std::is_same_v<TOut, uint16_t>)
return DataType::UINT16;
174 if (std::is_same_v<TOut, uint32_t>)
return DataType::UINT32;
175 if (std::is_same_v<TOut, uint64_t>)
return DataType::UINT64;
176 if (std::is_same_v<TOut, int8_t>)
return DataType::INT8;
177 if (std::is_same_v<TOut, int16_t>)
return DataType::INT16;
178 if (std::is_same_v<TOut, int32_t>)
return DataType::INT32;
179 if (std::is_same_v<TOut, int64_t>)
return DataType::INT64;
180 if (std::is_same_v<TOut, float>)
return DataType::FLOAT32;
181 if (std::is_same_v<TOut, double>)
return DataType::FLOAT64;
// Fallback for a TOut that matched none of the types above.
182 return DataType::UINT8;
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override