27#include <cuda_runtime.h>
33#include <unordered_map>
49 std::is_same_v<T, uint8_t> ||
50 std::is_same_v<T, uint16_t> ||
51 std::is_same_v<T, uint32_t>,
52 "BitpackStage: T must be uint8_t, uint16_t, or uint32_t.");
58 void setInverse(
bool inv)
override { is_inverse_ = inv; }
59 bool isInverse()
const override {
return is_inverse_; }
75 if (nbits == 0 || nbits > 8 *
sizeof(T) || (nbits & (nbits - 1)) != 0)
76 throw std::invalid_argument(
77 "BitpackStage::setNBits: nbits must be a power of two "
78 "in [1, " + std::to_string(8 *
sizeof(T)) +
"], got "
79 + std::to_string(nbits));
82 uint8_t getNBits()
const {
return nbits_; }
98 bool isAutoDetect()
const {
return auto_detect_; }
104 const std::vector<void*>& inputs,
105 const std::vector<void*>& outputs,
106 const std::vector<size_t>& sizes
110 std::string
getName()
const override {
return "Bitpack"; }
111 size_t getNumInputs()
const override {
return 1; }
112 size_t getNumOutputs()
const override {
return 1; }
115 const std::vector<size_t>& input_sizes
117 if (input_sizes.empty())
return {0};
122 return {input_sizes[0]};
125 const size_t n = input_sizes[0] /
sizeof(T);
126 return {(n * nbits_ + 7) / 8};
130 const size_t max_elems = (input_sizes[0] * 8 + nbits_ - 1) / nbits_;
131 return {max_elems *
sizeof(T)};
135 std::unordered_map<std::string, size_t>
137 return {{
"output", actual_output_size_}};
141 return (index == 0) ? actual_output_size_ : 0;
147 return static_cast<uint16_t
>(StageType::BITPACK);
152 return static_cast<uint8_t
>(DataType::UNKNOWN);
155 return static_cast<uint8_t
>(DataType::UNKNOWN);
161 size_t , uint8_t* buf,
size_t max_size
163 if (max_size < 10)
return 0;
164 buf[0] =
static_cast<uint8_t
>(dataTypeOf<T>());
166 std::memcpy(buf + 2, &num_elements_,
sizeof(uint64_t));
173 if (size >= 2) nbits_ = buf[1];
174 if (size >= 10) std::memcpy(&num_elements_, buf + 2,
sizeof(uint64_t));
183 saved_nbits_ = nbits_;
184 saved_num_elements_ = num_elements_;
185 saved_output_size_ = actual_output_size_;
188 void restoreState()
override {
189 nbits_ = saved_nbits_;
190 num_elements_ = saved_num_elements_;
191 actual_output_size_ = saved_output_size_;
199 bool is_inverse_ =
false;
200 bool auto_detect_ =
false;
201 uint8_t nbits_ = 8 *
sizeof(T);
202 uint64_t num_elements_ = 0;
203 size_t actual_output_size_ = 0;
206 uint8_t saved_nbits_ = 8 *
sizeof(T);
207 uint64_t saved_num_elements_ = 0;
208 size_t saved_output_size_ = 0;
211 static constexpr DataType dataTypeOf() {
212 if (std::is_same_v<U, uint8_t>)
return DataType::UINT8;
213 if (std::is_same_v<U, uint16_t>)
return DataType::UINT16;
214 if (std::is_same_v<U, uint32_t>)
return DataType::UINT32;
215 return DataType::UINT8;
219extern template class BitpackStage<uint8_t>;
220extern template class BitpackStage<uint16_t>;
221extern template class BitpackStage<uint32_t>;
Definition bitpack_stage.h:47
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override
std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const override
Definition bitpack_stage.h:114
void saveState() override
Definition bitpack_stage.h:182
bool isGraphCompatible() const override
Definition bitpack_stage.h:196
size_t getActualOutputSize(int index) const override
Definition bitpack_stage.h:140
uint8_t getInputDataType(size_t) const override
Definition bitpack_stage.h:154
std::string getName() const override
Definition bitpack_stage.h:110
void setAutoDetect(bool enable)
Definition bitpack_stage.h:97
uint16_t getStageTypeId() const override
Definition bitpack_stage.h:146
std::unordered_map< std::string, size_t > getActualOutputSizesByName() const override
Definition bitpack_stage.h:136
size_t getMaxHeaderSize(size_t) const override
Definition bitpack_stage.h:177
void setInverse(bool inv) override
Definition bitpack_stage.h:58
uint8_t getOutputDataType(size_t) const override
Definition bitpack_stage.h:151
void deserializeHeader(const uint8_t *buf, size_t size) override
Definition bitpack_stage.h:170
size_t serializeHeader(size_t, uint8_t *buf, size_t max_size) const override
Definition bitpack_stage.h:160
void setNBits(uint8_t nbits)
Definition bitpack_stage.h:74
Base class interface for all compression stages.