16#include "stage/stage.h"
18#include "transforms/negabinary/negabinary.h"
19#include <cuda_runtime.h>
33template<typename TIn, typename TOut = typename std::make_unsigned<TIn>::type>
35 static_assert(std::is_integral<TIn>::value && std::is_signed<TIn>::value,
36 "NegabinaryStage: TIn must be a signed integer type "
37 "(int8_t, int16_t, int32_t, or int64_t).");
38 static_assert(std::is_integral<TOut>::value && std::is_unsigned<TOut>::value,
39 "NegabinaryStage: TOut must be an unsigned integer type.");
40 static_assert(
sizeof(TIn) ==
sizeof(TOut),
41 "NegabinaryStage: TIn and TOut must have the same byte width.");
47 void setInverse(
bool inv)
override { is_inverse_ = inv; }
48 bool isInverse()
const override {
return is_inverse_; }
51 std::string
getName()
const override {
return "Negabinary"; }
52 size_t getNumInputs()
const override {
return 1; }
53 size_t getNumOutputs()
const override {
return 1; }
56 const std::vector<size_t>& input_sizes
58 return {input_sizes[0]};
62 return {{
"output", actual_output_size_}};
65 return (index == 0) ? actual_output_size_ : 0;
73 return static_cast<uint8_t
>(getTOutDataTypeEnum());
79 ?
static_cast<uint8_t
>(getTOutDataTypeEnum())
80 :
static_cast<uint8_t
>(getTInDataTypeEnum());
86 if (max_size < 2)
return 0;
87 buf[0] =
static_cast<uint8_t
>(getTInDataTypeEnum());
88 buf[1] =
static_cast<uint8_t
>(getTOutDataTypeEnum());
102 const std::vector<void*>& inputs,
103 const std::vector<void*>& outputs,
104 const std::vector<size_t>& sizes
109 size_t actual_output_size_;
111 DataType getTInDataTypeEnum()
const {
112 if (std::is_same_v<TIn, int8_t>)
return DataType::INT8;
113 if (std::is_same_v<TIn, int16_t>)
return DataType::INT16;
114 if (std::is_same_v<TIn, int32_t>)
return DataType::INT32;
115 if (std::is_same_v<TIn, int64_t>)
return DataType::INT64;
116 return DataType::INT32;
119 DataType getTOutDataTypeEnum()
const {
120 if (std::is_same_v<TOut, uint8_t>)
return DataType::UINT8;
121 if (std::is_same_v<TOut, uint16_t>)
return DataType::UINT16;
122 if (std::is_same_v<TOut, uint32_t>)
return DataType::UINT32;
123 if (std::is_same_v<TOut, uint64_t>)
return DataType::UINT64;
124 return DataType::UINT32;
128extern template class NegabinaryStage<int8_t, uint8_t>;
129extern template class NegabinaryStage<int16_t, uint16_t>;
130extern template class NegabinaryStage<int32_t, uint32_t>;
131extern template class NegabinaryStage<int64_t, uint64_t>;
Definition negabinary_stage.h:34
uint8_t getInputDataType(size_t) const override
Definition negabinary_stage.h:76
size_t getActualOutputSize(int index) const override
Definition negabinary_stage.h:64
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override
uint16_t getStageTypeId() const override
Definition negabinary_stage.h:68
size_t getMaxHeaderSize(size_t) const override
Definition negabinary_stage.h:96
void deserializeHeader(const uint8_t *, size_t) override
Definition negabinary_stage.h:92
std::unordered_map< std::string, size_t > getActualOutputSizesByName() const override
Definition negabinary_stage.h:61
uint8_t getOutputDataType(size_t) const override
Definition negabinary_stage.h:72
std::string getName() const override
Definition negabinary_stage.h:51
size_t serializeHeader(size_t, uint8_t *buf, size_t max_size) const override
Definition negabinary_stage.h:85
std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const override
Definition negabinary_stage.h:55
void setInverse(bool inv) override
Definition negabinary_stage.h:47
Definition fzm_format.h:25
@ NEGABINARY
NegabinaryStage — negabinary encode/decode.
DataType
Element data type identifiers used in buffer and stage descriptors.
Definition fzm_format.h:103