FZGPUModules 2.0
GPU-accelerated modular compression pipelines
Loading...
Searching...
No Matches
ans_stage.h
Go to the documentation of this file.
1#pragma once
2
25#include "stage/stage.h"
26#include "fzm_format.h"
27#include <cuda_runtime.h>
28#include <cstdint>
29#include <cstring>
30#include <stdexcept>
31#include <string>
32#include <unordered_map>
33#include <vector>
34
35namespace fz {
36
37class ANSStage : public Stage {
38public:
39 ANSStage() = default;
40 ~ANSStage() override = default;
41
42 // ── Configuration ─────────────────────────────────────────────────────────
43
49 void setProbBits(uint8_t pb) { prob_bits_ = pb; }
50 uint8_t getProbBits() const { return prob_bits_; }
51
52 // ── Stage control ─────────────────────────────────────────────────────────
53 void setInverse(bool inv) override { is_inverse_ = inv; }
54 bool isInverse() const override { return is_inverse_; }
55
56 // D2H copies occur in both encode (header readback) and decode (header peek).
57 bool isGraphCompatible() const override { return false; }
58
59 // dietGPU requires input aligned to 4 bytes (kANSRequiredAlignment).
60 size_t getRequiredInputAlignment() const override { return 4; }
61
62 // ── Pool lifecycle ────────────────────────────────────────────────────────
63
69 void onFinalize(size_t estimated_inlen, MemoryPool* pool) override;
70
71 size_t estimateDeviceFootprintBytes(size_t inlen) const override;
72
73 size_t estimateScratchBytes(const std::vector<size_t>& input_sizes) const override;
74
75 // ── Execution ─────────────────────────────────────────────────────────────
76 void execute(
77 cudaStream_t stream,
78 MemoryPool* pool,
79 const std::vector<void*>& inputs,
80 const std::vector<void*>& outputs,
81 const std::vector<size_t>& sizes
82 ) override;
83
84 // ── Metadata ──────────────────────────────────────────────────────────────
85 std::string getName() const override { return "ANS"; }
86 size_t getNumInputs() const override { return 1; }
87 size_t getNumOutputs() const override { return 1; }
88
89 std::vector<size_t> estimateOutputSizes(
90 const std::vector<size_t>& input_sizes
91 ) const override {
92 if (input_sizes.empty()) return {0};
93 if (!is_inverse_) {
94 // ANS worst-case expansion is ~1.25× per block plus fixed header overhead.
95 // 2× + 8 KiB is a conservative but safe bound for all realistic inputs.
96 return {input_sizes[0] * 2 + 8192};
97 }
98 // original_bytes_ is restored from the serialized FZM header before execute();
99 // fall back to input size if deserializeHeader() has not yet been called.
100 return {original_bytes_ > 0 ? original_bytes_ : input_sizes[0]};
101 }
102
103 std::unordered_map<std::string, size_t>
104 getActualOutputSizesByName() const override {
105 return {{"output", actual_output_size_}};
106 }
107
108 size_t getActualOutputSize(int index) const override {
109 return (index == 0) ? actual_output_size_ : 0;
110 }
111
112 // ── Type system ───────────────────────────────────────────────────────────
113 uint16_t getStageTypeId() const override {
114 return static_cast<uint16_t>(StageType::ANS);
115 }
116
117 // Byte-transparent: opt out of pipeline type-compatibility checking.
118 uint8_t getOutputDataType(size_t /*output_index*/) const override {
119 return static_cast<uint8_t>(DataType::UNKNOWN);
120 }
121 uint8_t getInputDataType(size_t /*input_index*/) const override {
122 return static_cast<uint8_t>(DataType::UNKNOWN);
123 }
124
125 // ── Serialization ─────────────────────────────────────────────────────────
126 size_t serializeHeader(
127 size_t /*output_index*/, uint8_t* buf, size_t max_size
128 ) const override {
129 if (max_size < 12) return 0;
130 buf[0] = prob_bits_;
131 buf[1] = buf[2] = buf[3] = 0;
132 std::memcpy(buf + 4, &original_bytes_, sizeof(uint64_t));
133 return 12;
134 }
135
136 void deserializeHeader(const uint8_t* buf, size_t size) override {
137 if (size >= 1)
138 prob_bits_ = buf[0];
139 if (size >= 12)
140 std::memcpy(&original_bytes_, buf + 4, sizeof(uint64_t));
141 }
142
143 size_t getMaxHeaderSize(size_t /*output_index*/) const override { return 12; }
144
145 void saveState() override {
146 saved_prob_bits_ = prob_bits_;
147 saved_original_bytes_ = original_bytes_;
148 saved_output_size_ = actual_output_size_;
149 }
150
151 void restoreState() override {
152 prob_bits_ = saved_prob_bits_;
153 original_bytes_ = saved_original_bytes_;
154 actual_output_size_ = saved_output_size_;
155 }
156
157private:
158 bool is_inverse_ = false;
159 uint8_t prob_bits_ = 10; // kANSDefaultProbBits
160 uint64_t original_bytes_ = 0; // set by forward execute; used by inverse estimateOutputSizes
161 size_t actual_output_size_ = 0;
162
163 // Capacity (input bytes) of current scratch allocation. Grow-only.
164 size_t cap_bytes_ = 0;
165
166 // Persistent scratch device pointers, sub-allocated from MemoryPool.
167 // All are null until initScratch() is called.
168 uint32_t* d_temp_histogram_ = nullptr; // uint32_t[256]
169 void* d_table_ = nullptr; // uint4[256] (encode table: pdf/cdf/mul/shift)
170 uint8_t* d_compressed_blocks_ = nullptr; // uint8_t[max_blocks * kUncoalescedStride]
171 uint32_t* d_compressed_words_ = nullptr; // uint32_t[max_blocks]
172 uint32_t* d_comp_words_prefix_ = nullptr; // uint32_t[max_blocks]
173 void* d_temp_prefix_sum_ = nullptr; // CUB temp storage (nullptr when blocks ≤ 512)
174 uint32_t* d_decode_table_ = nullptr; // uint32_t[1 << prob_bits_]
175
176 // D2H readback buffer for ANSCoalescedHeader after forward encode.
177 // Stored as raw bytes to avoid pulling the dietgpu headers into this header.
178 // sizeof(ANSCoalescedHeader) == 32.
179 uint8_t last_header_bytes_[32] = {};
180
181 // Histogram launch params — computed in initScratch(), reused every execute().
182 int hist_grid_dim_ = 0;
183 int hist_block_dim_ = 0;
184 int hist_shmem_use_ = 0;
185 int hist_r_per_block_ = 0;
186
187 // saveState / restoreState snapshots
188 uint8_t saved_prob_bits_ = 10;
189 uint64_t saved_original_bytes_ = 0;
190 size_t saved_output_size_ = 0;
191
192 // Allocates all 7 scratch buffers from pool and computes histogram launch
193 // params for the given input capacity. Replaces any previous allocation.
194 void initScratch(size_t inlen, MemoryPool* pool);
195
196 // Per-block scratch stride (bytes): ANSWarpState (128 B) + max raw compressed
197 // block (5120 B = roundUp(4096 + 4096/4, 16)). Derived from dietGPU constants.
198 static constexpr size_t kUncoalescedStride = 128 + 5120; // 5248
199};
200
201} // namespace fz
FZM binary file format definitions — structs, enums, and helpers.
Definition fzm_format.h:25
@ ANS
rANS entropy coder (GPU, via dietGPU)
@ UNKNOWN
Byte-transparent stages: skip type checking at finalize()
Base class interface for all compression stages.