53 std::is_same_v<T, TOut> ||
54 (std::is_integral_v<T> && std::is_signed_v<T> &&
55 std::is_integral_v<TOut> && std::is_unsigned_v<TOut> &&
56 sizeof(T) ==
sizeof(TOut)),
57 "DifferenceStage: TOut must equal T, or T must be a signed integer "
58 "and TOut its unsigned counterpart of the same width (negabinary fusion).");
60 DifferenceStage() : actual_output_size_(0), is_inverse_(
false), chunk_size_(0) {}
62 void setInverse(
bool inverse)
override { is_inverse_ = inverse; }
63 bool isInverse()
const override {
return is_inverse_; }
73 size_t getChunkSize()
const {
return chunk_size_; }
75 return chunk_size_ > 0 ? chunk_size_ : 1;
81 const std::vector<void*>& inputs,
82 const std::vector<void*>& outputs,
83 const std::vector<size_t>& sizes
87 std::string
getName()
const override {
return "Difference"; }
88 size_t getNumInputs()
const override {
return 1; }
89 size_t getNumOutputs()
const override {
return 1; }
92 const std::vector<size_t>& input_sizes
94 return {input_sizes[0]};
98 return {{
"output", actual_output_size_}};
101 return (index == 0) ? actual_output_size_ : 0;
105 return static_cast<uint16_t
>(StageType::DIFFERENCE);
110 return static_cast<uint8_t
>(getOutDataTypeEnum());
114 return static_cast<uint8_t
>(getInDataTypeEnum());
/// Writes this stage's header fields into buf.
///
/// Layout written by the visible statements:
///   buf[0]    input element type  (getInDataTypeEnum() narrowed to one byte)
///   buf[1]    output element type (getOutDataTypeEnum() narrowed to one byte)
///   buf[2..5] chunk_size_ narrowed to uint32_t, copied as raw bytes
///
/// @param output_index Not referenced by the statements visible here.
/// @param buf          Destination buffer; at least 6 bytes are written.
/// @param max_size     Capacity of buf in bytes.
/// @return 0 (nothing written) when max_size < 6.
/// NOTE(review): the end of this function (including its success return
/// value) is not visible in this excerpt.
117 size_t serializeHeader(
size_t output_index, uint8_t* buf,
size_t max_size)
const override {
// 2 type bytes + sizeof(uint32_t) chunk-size bytes = 6 bytes required.
119 if (max_size < 6)
return 0;
120 buf[0] =
static_cast<uint8_t
>(getInDataTypeEnum());
121 buf[1] =
static_cast<uint8_t
>(getOutDataTypeEnum());
// chunk_size_ is narrowed to 32 bits before being stored.
122 uint32_t cs =
static_cast<uint32_t
>(chunk_size_);
// memcpy copies the value's native object representation (host byte order)
// and does not require buf + 2 to be aligned.
123 std::memcpy(buf + 2, &cs,
sizeof(uint32_t));
132 std::memcpy(&cs, buf + 2,
sizeof(uint32_t));
143 saved_chunk_size_ = chunk_size_;
144 saved_actual_output_size_ = actual_output_size_;
/// Overwrites chunk_size_ and actual_output_size_ with the values held in
/// their saved_* counterparts.
147 void restoreState()
override {
148 chunk_size_ = saved_chunk_size_;
149 actual_output_size_ = saved_actual_output_size_;
// Recorded size of the single output; set to 0 by the default constructor.
153 size_t actual_output_size_;
// Copy of actual_output_size_ that restoreState() writes back.
154 size_t saved_actual_output_size_ = 0;
// Copy of chunk_size_ that restoreState() writes back.
157 size_t saved_chunk_size_ = 0;
/// Maps the input element type T to its DataType tag.
///
/// Each std::is_same_v test is a compile-time constant and requires an exact
/// type match, so at most one of the checks is true for a given instantiation.
/// @return The DataType matching T, or DataType::UINT8 when T is none of the
///         ten types listed below.
160 DataType getInDataTypeEnum()
const {
// Unsigned integers.
161 if (std::is_same_v<T, uint8_t>)
return DataType::UINT8;
162 if (std::is_same_v<T, uint16_t>)
return DataType::UINT16;
163 if (std::is_same_v<T, uint32_t>)
return DataType::UINT32;
164 if (std::is_same_v<T, uint64_t>)
return DataType::UINT64;
// Signed integers.
165 if (std::is_same_v<T, int8_t>)
return DataType::INT8;
166 if (std::is_same_v<T, int16_t>)
return DataType::INT16;
167 if (std::is_same_v<T, int32_t>)
return DataType::INT32;
168 if (std::is_same_v<T, int64_t>)
return DataType::INT64;
// Floating point.
169 if (std::is_same_v<T, float>)
return DataType::FLOAT32;
170 if (std::is_same_v<T, double>)
return DataType::FLOAT64;
// Fallback: T matched none of the types above.
171 return DataType::UINT8;
/// Maps the output element type TOut to its DataType tag.
///
/// Each std::is_same_v test is a compile-time constant and requires an exact
/// type match, so at most one of the checks is true for a given instantiation.
/// @return The DataType matching TOut, or DataType::UINT8 when TOut is none of
///         the ten types listed below.
174 DataType getOutDataTypeEnum()
const {
// Unsigned integers.
175 if (std::is_same_v<TOut, uint8_t>)
return DataType::UINT8;
176 if (std::is_same_v<TOut, uint16_t>)
return DataType::UINT16;
177 if (std::is_same_v<TOut, uint32_t>)
return DataType::UINT32;
178 if (std::is_same_v<TOut, uint64_t>)
return DataType::UINT64;
// Signed integers.
179 if (std::is_same_v<TOut, int8_t>)
return DataType::INT8;
180 if (std::is_same_v<TOut, int16_t>)
return DataType::INT16;
181 if (std::is_same_v<TOut, int32_t>)
return DataType::INT32;
182 if (std::is_same_v<TOut, int64_t>)
return DataType::INT64;
// Floating point.
183 if (std::is_same_v<TOut, float>)
return DataType::FLOAT32;
184 if (std::is_same_v<TOut, double>)
return DataType::FLOAT64;
// Fallback: TOut matched none of the types above.
185 return DataType::UINT8;
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override