12#include "stage/stage.h"
14#include "transforms/zigzag/zigzag.h"
15#include <cuda_runtime.h>
29template<typename TIn, typename TOut = typename std::make_unsigned<TIn>::type>
31 static_assert(std::is_integral<TIn>::value && std::is_signed<TIn>::value,
32 "ZigzagStage: TIn must be a signed integer type "
33 "(int8_t, int16_t, int32_t, or int64_t).");
34 static_assert(std::is_integral<TOut>::value && std::is_unsigned<TOut>::value,
35 "ZigzagStage: TOut must be an unsigned integer type.");
36 static_assert(
sizeof(TIn) ==
sizeof(TOut),
37 "ZigzagStage: TIn and TOut must have the same byte width.");
40 ZigzagStage() : is_inverse_(
false), actual_output_size_(0) {}
43 void setInverse(
bool inv)
override { is_inverse_ = inv; }
44 bool isInverse()
const override {
return is_inverse_; }
50 const std::vector<void*>& inputs,
51 const std::vector<void*>& outputs,
52 const std::vector<size_t>& sizes
56 std::string
getName()
const override {
return "Zigzag"; }
57 size_t getNumInputs()
const override {
return 1; }
58 size_t getNumOutputs()
const override {
return 1; }
61 const std::vector<size_t>& input_sizes
63 return {input_sizes[0]};
66 std::unordered_map<std::string, size_t>
68 return {{
"output", actual_output_size_}};
71 return (index == 0) ? actual_output_size_ : 0;
82 ?
static_cast<uint8_t
>(dataTypeOf<TIn>())
83 :
static_cast<uint8_t
>(dataTypeOf<TOut>());
89 ?
static_cast<uint8_t
>(dataTypeOf<TOut>())
90 :
static_cast<uint8_t
>(dataTypeOf<TIn>());
95 size_t output_index, uint8_t* buf,
size_t max_size
98 if (max_size < 2)
return 0;
99 buf[0] =
static_cast<uint8_t
>(dataTypeOf<TIn>());
100 buf[1] =
static_cast<uint8_t
>(dataTypeOf<TOut>());
105 (void)buf; (void)size;
114 size_t actual_output_size_;
117 static constexpr DataType dataTypeOf() {
118 if (std::is_same<U, int8_t>::value)
return DataType::INT8;
119 if (std::is_same<U, int16_t>::value)
return DataType::INT16;
120 if (std::is_same<U, int32_t>::value)
return DataType::INT32;
121 if (std::is_same<U, int64_t>::value)
return DataType::INT64;
122 if (std::is_same<U, uint8_t>::value)
return DataType::UINT8;
123 if (std::is_same<U, uint16_t>::value)
return DataType::UINT16;
124 if (std::is_same<U, uint32_t>::value)
return DataType::UINT32;
125 if (std::is_same<U, uint64_t>::value)
return DataType::UINT64;
126 return DataType::UINT8;
130extern template class ZigzagStage<int8_t, uint8_t>;
131extern template class ZigzagStage<int16_t, uint16_t>;
132extern template class ZigzagStage<int32_t, uint32_t>;
133extern template class ZigzagStage<int64_t, uint64_t>;
Definition zigzag_stage.h:30
uint8_t getInputDataType(size_t) const override
Definition zigzag_stage.h:86
std::unordered_map< std::string, size_t > getActualOutputSizesByName() const override
Definition zigzag_stage.h:67
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override
std::string getName() const override
Definition zigzag_stage.h:56
void setInverse(bool inv) override
Definition zigzag_stage.h:43
void deserializeHeader(const uint8_t *buf, size_t size) override
Definition zigzag_stage.h:104
uint16_t getStageTypeId() const override
Definition zigzag_stage.h:74
uint8_t getOutputDataType(size_t output_index) const override
Definition zigzag_stage.h:78
std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const override
Definition zigzag_stage.h:60
size_t serializeHeader(size_t output_index, uint8_t *buf, size_t max_size) const override
Definition zigzag_stage.h:94
size_t getMaxHeaderSize(size_t) const override
Definition zigzag_stage.h:110
size_t getActualOutputSize(int index) const override
Definition zigzag_stage.h:70
Definition fzm_format.h:25
@ ZIGZAG
ZigzagStage — zigzag encode/decode.
DataType
Element data type identifiers used in buffer and stage descriptors.
Definition fzm_format.h:103