27#include <cuda_runtime.h>
32#include <unordered_map>
37class ANSStage :
public Stage {
40 ~ANSStage()
override =
default;
49 void setProbBits(uint8_t pb) { prob_bits_ = pb; }
50 uint8_t getProbBits()
const {
return prob_bits_; }
53 void setInverse(
bool inv)
override { is_inverse_ = inv; }
54 bool isInverse()
const override {
return is_inverse_; }
57 bool isGraphCompatible()
const override {
return false; }
60 size_t getRequiredInputAlignment()
const override {
return 4; }
69 void onFinalize(
size_t estimated_inlen, MemoryPool* pool)
override;
71 size_t estimateDeviceFootprintBytes(
size_t inlen)
const override;
73 size_t estimateScratchBytes(
const std::vector<size_t>& input_sizes)
const override;
79 const std::vector<void*>& inputs,
80 const std::vector<void*>& outputs,
81 const std::vector<size_t>& sizes
85 std::string getName()
const override {
return "ANS"; }
86 size_t getNumInputs()
const override {
return 1; }
87 size_t getNumOutputs()
const override {
return 1; }
// Estimates the size of the single output from the given input sizes.
// NOTE(review): parts of this definition are absent from this listing (the end
// of the signature, the condition that selects between the last two returns,
// and the closing braces) -- confirm against the real header.
89 std::vector<size_t> estimateOutputSizes(
90 const std::vector<size_t>& input_sizes
// No inputs at all: report one output of size 0.
92 if (input_sizes.empty())
return {0};
// Twice the first input size plus 8192.
// NOTE(review): presumably the forward (encode) upper bound -- confirm.
96 return {input_sizes[0] * 2 + 8192};
// original_bytes_ when it is non-zero, otherwise the first input size.
// NOTE(review): presumably the inverse (decode) path -- confirm.
100 return {original_bytes_ > 0 ? original_bytes_ : input_sizes[0]};
// Returns the actual output sizes keyed by output name.
// The map holds exactly one entry: "output" -> actual_output_size_.
// NOTE(review): the closing brace of this definition is absent from this listing.
103 std::unordered_map<std::string, size_t>
104 getActualOutputSizesByName()
const override {
105 return {{
"output", actual_output_size_}};
// Returns the actual size of the output at `index`.
// Only index 0 exists; any other index yields 0.
// NOTE(review): the closing brace of this definition is absent from this listing.
108 size_t getActualOutputSize(
int index)
const override {
109 return (index == 0) ? actual_output_size_ : 0;
// Returns this stage's 16-bit type identifier.
// NOTE(review): the body is absent from this listing; presumably it returns the
// ANS stage-type enumerator -- confirm against the real header.
113 uint16_t getStageTypeId()
const override {
// Returns the data-type code of an output; the index parameter is unnamed and
// therefore unused.
// NOTE(review): the body is absent from this listing; presumably it returns the
// UNKNOWN (byte-transparent) data type -- confirm against the real header.
118 uint8_t getOutputDataType(
size_t )
const override {
// Returns the data-type code of an input; the index parameter is unnamed and
// therefore unused.
// NOTE(review): the body is absent from this listing; presumably it returns the
// UNKNOWN (byte-transparent) data type -- confirm against the real header.
121 uint8_t getInputDataType(
size_t )
const override {
// Writes this stage's header into `buf`.
// Visible layout: bytes 1..3 are zeroed and bytes 4..11 receive original_bytes_.
// NOTE(review): the end of the signature, the write to buf[0] and the final
// return are absent from this listing; presumably buf[0] carries prob_bits_ and
// the function returns 12 -- confirm against the real header.
126 size_t serializeHeader(
127 size_t , uint8_t* buf,
size_t max_size
// Refuse to write when the buffer cannot hold the 12 header bytes.
129 if (max_size < 12)
return 0;
131 buf[1] = buf[2] = buf[3] = 0;
// Raw copy of the 8-byte original size, in host byte order.
132 std::memcpy(buf + 4, &original_bytes_,
sizeof(uint64_t));
// Restores stage state from a header previously written into `buf`.
// NOTE(review): the lines between the signature and the copy below, and the
// closing brace, are absent from this listing. Confirm that `size` is checked
// (>= 12) before the read and that buf[0] is consumed.
136 void deserializeHeader(
const uint8_t* buf,
size_t size)
override {
// Raw copy of the 8-byte original size stored at offset 4.
140 std::memcpy(&original_bytes_, buf + 4,
sizeof(uint64_t));
143 size_t getMaxHeaderSize(
size_t )
const override {
return 12; }
// Snapshots prob_bits_, original_bytes_ and actual_output_size_ into their
// saved_* counterparts. is_inverse_ is not part of the snapshot.
// NOTE(review): the closing brace of this definition is absent from this listing.
145 void saveState()
override {
146 saved_prob_bits_ = prob_bits_;
147 saved_original_bytes_ = original_bytes_;
148 saved_output_size_ = actual_output_size_;
// Copies the saved_* snapshot back into prob_bits_, original_bytes_ and
// actual_output_size_ (the same three fields saveState() captures).
// NOTE(review): the closing brace of this definition is absent from this listing.
151 void restoreState()
override {
152 prob_bits_ = saved_prob_bits_;
153 original_bytes_ = saved_original_bytes_;
154 actual_output_size_ = saved_output_size_;
// --- Stage configuration and results ---
// Inverse-mode flag, written by setInverse().
158 bool is_inverse_ =
false;
// Probability-bits setting, written by setProbBits(); defaults to 10.
159 uint8_t prob_bits_ = 10;
// Original size written to / read from header bytes 4..11; 0 means "not set".
160 uint64_t original_bytes_ = 0;
// Size reported by getActualOutputSize(0).
161 size_t actual_output_size_ = 0;
// NOTE(review): presumably the input capacity the scratch buffers were sized
// for -- confirm.
164 size_t cap_bytes_ = 0;
// --- Scratch buffers ---
// NOTE(review): the d_ prefix presumably marks device pointers, allocated in
// initScratch() -- confirm ownership and lifetime against the definition.
168 uint32_t* d_temp_histogram_ =
nullptr;
169 void* d_table_ =
nullptr;
170 uint8_t* d_compressed_blocks_ =
nullptr;
171 uint32_t* d_compressed_words_ =
nullptr;
172 uint32_t* d_comp_words_prefix_ =
nullptr;
173 void* d_temp_prefix_sum_ =
nullptr;
174 uint32_t* d_decode_table_ =
nullptr;
// 32-byte zero-initialized buffer.
// NOTE(review): presumably a copy of the most recent header -- confirm.
179 uint8_t last_header_bytes_[32] = {};
// NOTE(review): presumably the launch configuration of the histogram kernel
// (grid size, block size, shared memory use, rows per block) -- confirm.
182 int hist_grid_dim_ = 0;
183 int hist_block_dim_ = 0;
184 int hist_shmem_use_ = 0;
185 int hist_r_per_block_ = 0;
// --- Snapshot written by saveState() and read by restoreState() ---
188 uint8_t saved_prob_bits_ = 10;
189 uint64_t saved_original_bytes_ = 0;
190 size_t saved_output_size_ = 0;
194 void initScratch(
size_t inlen, MemoryPool* pool);
/**
 * Compile-time stride constant, 128 + 5120 (= 5248).
 *
 * NOTE(review): presumably a fixed 128-unit header plus a 5120-unit payload
 * bound per uncoalesced compressed block -- confirm the meaning and unit of
 * each term against the encoder.
 */
static constexpr size_t kUncoalescedStride = 128 + 5120;
Definition fzm_format.h:25
@ ANS
rANS entropy coder (GPU, via dietGPU)
@ UNKNOWN
Byte-transparent stages: skip type checking at finalize()
Base class interface for all compression stages.