FZGPUModules 2.0
GPU-accelerated modular compression pipelines
Loading...
Searching...
No Matches
rre_stage.h
Go to the documentation of this file.
1#pragma once
2
28#include "stage/stage.h"
29#include "fzm_format.h"
30#include <cuda_runtime.h>
31#include <cstdint>
32#include <cstring>
33#include <stdexcept>
34#include <string>
35#include <unordered_map>
36#include <vector>
37
38namespace fz {
39
54class RREStage : public Stage {
55public:
56 RREStage()
57 : is_inverse_(false)
58 , chunk_size_(16384)
59 , word_size_(1)
60 , actual_output_size_(0)
61 , cached_orig_bytes_(0)
62 , d_scratch_(nullptr)
63 , d_sizes_dev_(nullptr)
64 , d_clean_dev_(nullptr)
65 , d_dst_off_dev_(nullptr)
66 , scratch_capacity_(0)
67 {}
68
69 ~RREStage() override;
70
71 // ── Stage control ──────────────────────────────────────────────────────
72 void setInverse(bool inv) override { is_inverse_ = inv; }
73 bool isInverse() const override { return is_inverse_; }
74
78 bool isGraphCompatible() const override { return !is_inverse_; }
79
80 void setChunkSize(size_t bytes) { chunk_size_ = static_cast<uint32_t>(bytes); }
81 void setWordSize(size_t bytes) { word_size_ = static_cast<uint8_t>(bytes); }
82
83 size_t getChunkSize() const { return chunk_size_; }
84 size_t getRequiredInputAlignment() const override { return chunk_size_; }
85 int getWordSize() const { return static_cast<int>(word_size_); }
86 uint32_t getCachedOrigBytes() const { return cached_orig_bytes_; }
87
88 // ── Execution ──────────────────────────────────────────────────────────
89 void execute(
90 cudaStream_t stream,
91 MemoryPool* pool,
92 const std::vector<void*>& inputs,
93 const std::vector<void*>& outputs,
94 const std::vector<size_t>& sizes
95 ) override;
96 void postStreamSync(cudaStream_t stream) override;
97
98 // ── Metadata ───────────────────────────────────────────────────────────
99 std::string getName() const override { return "RRE"; }
100 size_t getNumInputs() const override { return 1; }
101 size_t getNumOutputs() const override { return 1; }
102
103 std::vector<size_t> estimateOutputSizes(
104 const std::vector<size_t>& input_sizes
105 ) const override {
106 if (is_inverse_) {
107 if (cached_orig_bytes_ > 0)
108 return {static_cast<size_t>(cached_orig_bytes_)};
109 return {input_sizes.empty() ? 0 : input_sizes[0]};
110 }
111 // Forward: worst case = original data + stream header.
112 const size_t n_bytes = input_sizes.empty() ? 0 : input_sizes[0];
113 const size_t n_chunks = (n_bytes + chunk_size_ - 1) / chunk_size_;
114 const size_t hdr = 4 + 4 + 4 * n_chunks;
115 return {n_bytes + hdr};
116 }
117
118 std::unordered_map<std::string, size_t>
120 size_t getActualOutputSize(int index) const override;
121
131 const std::vector<size_t>& input_sizes
132 ) const override {
133 if (is_inverse_ || input_sizes.empty()) return 0;
134 const size_t in_bytes = input_sizes[0];
135 const size_t n_chunks = (in_bytes + chunk_size_ - 1) / chunk_size_;
136 return n_chunks * (static_cast<size_t>(chunk_size_) + 3 * sizeof(uint32_t));
137 }
138
139 uint16_t getStageTypeId() const override {
140 return static_cast<uint16_t>(StageType::RRE);
141 }
142
143 uint8_t getOutputDataType(size_t) const override {
144 return static_cast<uint8_t>(DataType::UINT8);
145 }
146
147 // ── Serialization ──────────────────────────────────────────────────────
149 size_t output_index, uint8_t* buf, size_t max_size
150 ) const override {
151 (void)output_index;
152 if (max_size < 9) return 0;
153 std::memcpy(buf, &chunk_size_, sizeof(uint32_t));
154 buf[4] = word_size_;
155 std::memcpy(buf + 5, &cached_orig_bytes_, sizeof(uint32_t));
156 return 9;
157 }
158
159 void deserializeHeader(const uint8_t* buf, size_t size) override {
160 if (size >= 4) std::memcpy(&chunk_size_, buf, sizeof(uint32_t));
161 if (size >= 5) word_size_ = buf[4];
162 if (size >= 9) std::memcpy(&cached_orig_bytes_, buf + 5, sizeof(uint32_t));
163 }
164
165 size_t getMaxHeaderSize(size_t) const override { return 9; }
166
167 void saveState() override {
168 saved_chunk_size_ = chunk_size_;
169 saved_word_size_ = word_size_;
170 saved_cached_orig_bytes_ = cached_orig_bytes_;
171 }
172
173 void restoreState() override {
174 chunk_size_ = saved_chunk_size_;
175 word_size_ = saved_word_size_;
176 cached_orig_bytes_ = saved_cached_orig_bytes_;
177 }
178
179private:
180 bool is_inverse_;
181 uint32_t chunk_size_;
182 uint32_t saved_chunk_size_ = 0;
183 uint8_t word_size_;
184 uint8_t saved_word_size_ = 0;
185 size_t actual_output_size_;
186 uint32_t cached_orig_bytes_ = 0;
187 uint32_t saved_cached_orig_bytes_ = 0;
188
189 // ── Persistent forward scratch buffers ───────────────────────────────────
190 uint8_t* d_scratch_;
191 uint32_t* d_sizes_dev_;
192 uint32_t* d_clean_dev_;
193 uint32_t* d_dst_off_dev_;
194 mutable bool tail_readback_pending_ = false;
195 mutable cudaStream_t tail_readback_stream_ = nullptr;
196 mutable uint32_t tail_last_index_ = 0;
197 mutable uint8_t* tail_output_ptr_ = nullptr;
198 size_t scratch_capacity_;
199 MemoryPool* scratch_pool_owner_ = nullptr;
200 bool scratch_from_pool_ = false;
201};
202
203} // namespace fz
Definition mempool.h:82
Definition rre_stage.h:54
void setInverse(bool inv) override
Definition rre_stage.h:72
void postStreamSync(cudaStream_t stream) override
size_t getMaxHeaderSize(size_t) const override
Definition rre_stage.h:165
size_t estimateScratchBytes(const std::vector< size_t > &input_sizes) const override
Definition rre_stage.h:130
size_t getActualOutputSize(int index) const override
size_t getRequiredInputAlignment() const override
Definition rre_stage.h:84
uint8_t getOutputDataType(size_t) const override
Definition rre_stage.h:143
bool isGraphCompatible() const override
Definition rre_stage.h:78
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override
std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const override
Definition rre_stage.h:103
void deserializeHeader(const uint8_t *buf, size_t size) override
Definition rre_stage.h:159
std::string getName() const override
Definition rre_stage.h:99
std::unordered_map< std::string, size_t > getActualOutputSizesByName() const override
size_t serializeHeader(size_t output_index, uint8_t *buf, size_t max_size) const override
Definition rre_stage.h:148
uint16_t getStageTypeId() const override
Definition rre_stage.h:139
void saveState() override
Definition rre_stage.h:167
Definition stage.h:30
FZM binary file format definitions — structs, enums, and helpers.
Definition fzm_format.h:25
@ RRE
Repetition-Reduction Encoding (LC framework lossless component)
Base class interface for all compression stages.