FZGPUModules 2.0
GPU-accelerated modular compression pipelines
Loading...
Searching...
No Matches
adm_stage.h
Go to the documentation of this file.
1#pragma once
2
27#include "stage/stage.h"
28#include "fzm_format.h"
29#include <cuda_runtime.h>
30#include <cstdint>
31#include <cstring>
32#include <stdexcept>
33#include <string>
34#include <unordered_map>
35#include <vector>
36
37namespace fz {
38
/// Element width of the unsigned-integer input consumed by ADMStage.
/// The numeric value is written verbatim into byte 0 of the ADM stage header,
/// so existing values must never be renumbered.
enum class ADMDtype : uint8_t {
    U16 = 0,  ///< 16-bit unsigned input elements (maps to DataType::UINT16)
    U32 = 1,  ///< 32-bit unsigned input elements (maps to DataType::UINT32)
};
41
42class ADMStage : public Stage {
43public:
44 ADMStage() = default;
45 ~ADMStage() override = default;
46
47 // ── Configuration ─────────────────────────────────────────────────────────
48
54 void setDtype(ADMDtype dt) { dtype_ = dt; }
55 ADMDtype getDtype() const { return dtype_; }
56
57 // ── Stage control ─────────────────────────────────────────────────────────
58 void setInverse(bool inv) override { is_inverse_ = inv; }
59 bool isInverse() const override { return is_inverse_; }
60
61 // D2H copies occur in compress (payload-size readback) and decompress (header peek).
62 bool isGraphCompatible() const override { return false; }
63
64 // ── Pool lifecycle ────────────────────────────────────────────────────────
65 void onFinalize(size_t estimated_inlen, MemoryPool* pool) override;
66 size_t estimateDeviceFootprintBytes(size_t inlen) const override;
67 size_t estimateScratchBytes(const std::vector<size_t>& input_sizes) const override;
68
69 // ── Execution ─────────────────────────────────────────────────────────────
70 void execute(
71 cudaStream_t stream,
72 MemoryPool* pool,
73 const std::vector<void*>& inputs,
74 const std::vector<void*>& outputs,
75 const std::vector<size_t>& sizes
76 ) override;
77
78 // ── Metadata ──────────────────────────────────────────────────────────────
79 std::string getName() const override { return "ADM"; }
80 size_t getNumInputs() const override { return 1; }
81 size_t getNumOutputs() const override { return 1; }
82
83 std::vector<size_t> estimateOutputSizes(
84 const std::vector<size_t>& input_sizes
85 ) const override;
86
87 std::unordered_map<std::string, size_t>
88 getActualOutputSizesByName() const override {
89 return {{"output", actual_output_size_}};
90 }
91
92 size_t getActualOutputSize(int index) const override {
93 return (index == 0) ? actual_output_size_ : 0;
94 }
95
96 // ── Type system ───────────────────────────────────────────────────────────
97 uint16_t getStageTypeId() const override {
98 return static_cast<uint16_t>(StageType::ADM);
99 }
100
101 // Input type: U16 or U32 (used by pipeline finalize() for type checking).
102 uint8_t getInputDataType(size_t /*idx*/) const override {
103 return dtype_ == ADMDtype::U16
104 ? static_cast<uint8_t>(DataType::UINT16)
105 : static_cast<uint8_t>(DataType::UINT32);
106 }
107
108 // Output is an opaque ADM payload — opt out of downstream type checking.
109 uint8_t getOutputDataType(size_t /*idx*/) const override {
110 return static_cast<uint8_t>(DataType::UNKNOWN);
111 }
112
113 // ── Serialization ─────────────────────────────────────────────────────────
114 size_t serializeHeader(
115 size_t /*output_index*/, uint8_t* buf, size_t max_size
116 ) const override {
117 if (max_size < 12) return 0;
118 buf[0] = static_cast<uint8_t>(dtype_);
119 buf[1] = buf[2] = buf[3] = 0;
120 std::memcpy(buf + 4, &num_elements_, sizeof(uint64_t));
121 return 12;
122 }
123
124 void deserializeHeader(const uint8_t* buf, size_t size) override {
125 if (size >= 1) dtype_ = static_cast<ADMDtype>(buf[0]);
126 if (size >= 12) std::memcpy(&num_elements_, buf + 4, sizeof(uint64_t));
127 }
128
129 size_t getMaxHeaderSize(size_t /*output_index*/) const override { return 12; }
130
131 void saveState() override {
132 saved_dtype_ = dtype_;
133 saved_num_elements_ = num_elements_;
134 saved_output_size_ = actual_output_size_;
135 }
136
137 void restoreState() override {
138 dtype_ = saved_dtype_;
139 num_elements_ = saved_num_elements_;
140 actual_output_size_ = saved_output_size_;
141 }
142
143private:
144 ADMDtype dtype_ = ADMDtype::U16;
145 bool is_inverse_ = false;
146 uint64_t num_elements_ = 0; // set by forward execute; restored by deserializeHeader
147 size_t actual_output_size_ = 0;
148
149 // Capacity (elements) of current scratch allocation. Grow-only.
150 size_t cap_elements_ = 0;
151
152 // Persistent scratch device pointers.
153 int* d_signal_length_ = nullptr; // gsize × sizeof(int)
154 int* d_output_lengths_ = nullptr; // (gsize+1) × sizeof(int)
155 void* d_centers_ = nullptr; // gsize × sizeof(dtype)
156 uint32_t* d_block_flags_ = nullptr; // flags_words × sizeof(uint32_t)
157 uint8_t* d_codes_ = nullptr; // num_elements × 1
158 uint8_t* d_concat_signals_ = nullptr; // num_elements × kMaxSignalBytes
159 uint8_t* d_bit_signals_ = nullptr; // num_elements × kMaxSignalBytes (thrust path)
160 int* d_loc_offset_ = nullptr; // (gsize+1) × sizeof(int)
161 int* d_prefix_state_ = nullptr; // (gsize+1) × sizeof(int)
162 unsigned int* d_overflow_flag_ = nullptr; // 1 word; checked after kernels in debug builds
163
164 // saveState / restoreState snapshots.
165 ADMDtype saved_dtype_ = ADMDtype::U16;
166 uint64_t saved_num_elements_ = 0;
167 size_t saved_output_size_ = 0;
168
169 // Allocates / reallocates all 9 scratch buffers from pool.
170 void initScratch(size_t num_elements, MemoryPool* pool);
171
172 // Returns the number of signal bytes per element for the current dtype.
173 int maxSignalBytes() const;
174 // Returns sizeof of the center element for the current dtype.
175 size_t centerElemBytes() const;
176};
177
178} // namespace fz
FZM binary file format definitions — structs, enums, and helpers.
Definition fzm_format.h:25
ADMDtype
Definition adm_stage.h:40
@ ADM
Adaptive Data Mapping transform (MANS)
DataType
Element data type identifiers used in buffer and stage descriptors.
Definition fzm_format.h:104
@ UNKNOWN
Byte-transparent stages: skip type checking at finalize()
Base class interface for all compression stages.