FZGPUModules 2.0
GPU-accelerated modular compression pipelines
Loading...
Searching...
No Matches
bitpack_stage.h
Go to the documentation of this file.
1#pragma once
2
25#include "stage/stage.h"
26#include "fzm_format.h"
27#include <cuda_runtime.h>
28#include <cstdint>
29#include <cstring>
30#include <stdexcept>
31#include <string>
32#include <type_traits>
33#include <unordered_map>
34#include <vector>
35
36namespace fz {
37
46template<typename T>
47class BitpackStage : public Stage {
48 static_assert(
49 std::is_same_v<T, uint8_t> ||
50 std::is_same_v<T, uint16_t> ||
51 std::is_same_v<T, uint32_t>,
52 "BitpackStage: T must be uint8_t, uint16_t, or uint32_t.");
53
54public:
55 BitpackStage() = default;
56
57 // ── Stage control ──────────────────────────────────────────────────────────
58 void setInverse(bool inv) override { is_inverse_ = inv; }
59 bool isInverse() const override { return is_inverse_; }
60
61 // ── Configuration ──────────────────────────────────────────────────────────
62
74 void setNBits(uint8_t nbits) {
75 if (nbits == 0 || nbits > 8 * sizeof(T) || (nbits & (nbits - 1)) != 0)
76 throw std::invalid_argument(
77 "BitpackStage::setNBits: nbits must be a power of two "
78 "in [1, " + std::to_string(8 * sizeof(T)) + "], got "
79 + std::to_string(nbits));
80 nbits_ = nbits;
81 }
82 uint8_t getNBits() const { return nbits_; }
83
97 void setAutoDetect(bool enable) { auto_detect_ = enable; }
98 bool isAutoDetect() const { return auto_detect_; }
99
100 // ── Execution ──────────────────────────────────────────────────────────────
102 cudaStream_t stream,
103 MemoryPool* pool,
104 const std::vector<void*>& inputs,
105 const std::vector<void*>& outputs,
106 const std::vector<size_t>& sizes
107 ) override;
108
109 // ── Metadata ───────────────────────────────────────────────────────────────
110 std::string getName() const override { return "Bitpack"; }
111 size_t getNumInputs() const override { return 1; }
112 size_t getNumOutputs() const override { return 1; }
113
114 std::vector<size_t> estimateOutputSizes(
115 const std::vector<size_t>& input_sizes
116 ) const override {
117 if (input_sizes.empty()) return {0};
118 if (!is_inverse_) {
119 if (auto_detect_) {
120 // nbits is unknown until execute() scans the data; return worst
121 // case (full-width, no compression) so PREALLOCATE has enough room.
122 return {input_sizes[0]};
123 }
124 // Forward: packed output is ceil(n * nbits / 8) bytes.
125 const size_t n = input_sizes[0] / sizeof(T);
126 return {(n * nbits_ + 7) / 8};
127 } else {
128 // Inverse: worst case — every packed bit expands to a full element.
129 // input_sizes[0] is the packed byte count; max elements = bytes * (8/nbits).
130 const size_t max_elems = (input_sizes[0] * 8 + nbits_ - 1) / nbits_;
131 return {max_elems * sizeof(T)};
132 }
133 }
134
135 std::unordered_map<std::string, size_t>
136 getActualOutputSizesByName() const override {
137 return {{"output", actual_output_size_}};
138 }
139
140 size_t getActualOutputSize(int index) const override {
141 return (index == 0) ? actual_output_size_ : 0;
142 }
143
144 // ── Type system ────────────────────────────────────────────────────────────
145
146 uint16_t getStageTypeId() const override {
147 return static_cast<uint16_t>(StageType::BITPACK);
148 }
149
150 // Packed byte stream has no meaningful element type; opt out of type checking.
151 uint8_t getOutputDataType(size_t /*output_index*/) const override {
152 return static_cast<uint8_t>(DataType::UNKNOWN);
153 }
154 uint8_t getInputDataType(size_t /*input_index*/) const override {
155 return static_cast<uint8_t>(DataType::UNKNOWN);
156 }
157
158 // ── Serialization ──────────────────────────────────────────────────────────
159
161 size_t /*output_index*/, uint8_t* buf, size_t max_size
162 ) const override {
163 if (max_size < 10) return 0;
164 buf[0] = static_cast<uint8_t>(dataTypeOf<T>());
165 buf[1] = nbits_;
166 std::memcpy(buf + 2, &num_elements_, sizeof(uint64_t));
167 return 10;
168 }
169
170 void deserializeHeader(const uint8_t* buf, size_t size) override {
171 // buf[0] (DataType) is used by the factory to pick the right instantiation.
172 // We only need nbits and num_elements here.
173 if (size >= 2) nbits_ = buf[1];
174 if (size >= 10) std::memcpy(&num_elements_, buf + 2, sizeof(uint64_t));
175 }
176
177 size_t getMaxHeaderSize(size_t /*output_index*/) const override { return 10; }
178
179 // saveState/restoreState: deserializeHeader (called during decompression
180 // setup) overwrites num_elements with the value from the file header.
181 // Save the forward-pass values so they can be restored afterward.
182 void saveState() override {
183 saved_nbits_ = nbits_;
184 saved_num_elements_ = num_elements_;
185 saved_output_size_ = actual_output_size_;
186 }
187
188 void restoreState() override {
189 nbits_ = saved_nbits_;
190 num_elements_ = saved_num_elements_;
191 actual_output_size_ = saved_output_size_;
192 }
193
194 // Auto-detect requires a D2H sync to read the scanned max, so it cannot
195 // be recorded inside a CUDA Graph.
196 bool isGraphCompatible() const override { return !auto_detect_; }
197
198private:
199 bool is_inverse_ = false;
200 bool auto_detect_ = false;
201 uint8_t nbits_ = 8 * sizeof(T); // default: keep all bits (identity)
202 uint64_t num_elements_ = 0; // set by forward execute; used by inverse
203 size_t actual_output_size_ = 0;
204
205 // saveState snapshots
206 uint8_t saved_nbits_ = 8 * sizeof(T);
207 uint64_t saved_num_elements_ = 0;
208 size_t saved_output_size_ = 0;
209
210 template<typename U>
211 static constexpr DataType dataTypeOf() {
212 if (std::is_same_v<U, uint8_t>) return DataType::UINT8;
213 if (std::is_same_v<U, uint16_t>) return DataType::UINT16;
214 if (std::is_same_v<U, uint32_t>) return DataType::UINT32;
215 return DataType::UINT8; // unreachable
216 }
217};
218
219extern template class BitpackStage<uint8_t>;
220extern template class BitpackStage<uint16_t>;
221extern template class BitpackStage<uint32_t>;
222
223} // namespace fz
fz::BitpackStage — Definition bitpack_stage.h:47
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override
std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const override
Definition bitpack_stage.h:114
void saveState() override
Definition bitpack_stage.h:182
bool isGraphCompatible() const override
Definition bitpack_stage.h:196
size_t getActualOutputSize(int index) const override
Definition bitpack_stage.h:140
uint8_t getInputDataType(size_t) const override
Definition bitpack_stage.h:154
std::string getName() const override
Definition bitpack_stage.h:110
void setAutoDetect(bool enable)
Definition bitpack_stage.h:97
uint16_t getStageTypeId() const override
Definition bitpack_stage.h:146
std::unordered_map< std::string, size_t > getActualOutputSizesByName() const override
Definition bitpack_stage.h:136
size_t getMaxHeaderSize(size_t) const override
Definition bitpack_stage.h:177
void setInverse(bool inv) override
Definition bitpack_stage.h:58
uint8_t getOutputDataType(size_t) const override
Definition bitpack_stage.h:151
void deserializeHeader(const uint8_t *buf, size_t size) override
Definition bitpack_stage.h:170
size_t serializeHeader(size_t, uint8_t *buf, size_t max_size) const override
Definition bitpack_stage.h:160
void setNBits(uint8_t nbits)
Definition bitpack_stage.h:74
Definition mempool.h:66
Definition stage.h:30
FZM binary file format definitions — structs, enums, and helpers.
DataType
Element data type identifiers used in buffer and stage descriptors.
Definition fzm_format.h:102
Base class interface for all compression stages.