FZGPUModules 2.0
GPU-accelerated modular compression pipelines
Loading...
Searching...
No Matches
merge_stage.h
Go to the documentation of this file.
1#pragma once
2
45#include "stage/stage.h"
46#include "fzm_format.h"
47#include <cuda_runtime.h>
48#include <cstdint>
49#include <cstring>
50#include <numeric>
51#include <stdexcept>
52#include <string>
53#include <unordered_map>
54#include <vector>
55
56namespace fz {
57
68class MergeStage : public Stage {
69public:
70 MergeStage() = default;
71 ~MergeStage() override = default;
72
73 // ── Configuration ───────────────────────────────────────────────────────
75 void setSegmentNames(const std::vector<std::string>& names) {
76 if (names.empty())
77 throw std::runtime_error("MergeStage: at least one segment name required");
78 if (names.size() > kMaxSegments)
79 throw std::runtime_error("MergeStage: too many segments (max "
80 + std::to_string(kMaxSegments) + ")");
81 segment_names_ = names;
82 segment_sizes_.assign(names.size(), 0);
83 }
84 const std::vector<std::string>& getSegmentNames() const { return segment_names_; }
85 size_t getNumSegments() const { return segment_names_.size(); }
86
87 // ── Stage control ───────────────────────────────────────────────────────
88 void setInverse(bool inv) override { is_inverse_ = inv; }
89 bool isInverse() const override { return is_inverse_; }
90
92 bool isGraphCompatible() const override { return true; }
93
94 // ── Port model ──────────────────────────────────────────────────────────
95 // Forward: N inputs → 1 output ("output").
96 // Inverse: 1 input → N outputs (segment_names_).
97 size_t getNumInputs() const override { return is_inverse_ ? 1 : segment_names_.size(); }
98 size_t getNumOutputs() const override { return is_inverse_ ? segment_names_.size() : 1; }
99 std::vector<std::string> getOutputNames() const override {
100 return is_inverse_ ? segment_names_ : std::vector<std::string>{"output"};
101 }
102
103 std::string getName() const override { return "Merge"; }
104
105 uint16_t getStageTypeId() const override {
106 return static_cast<uint16_t>(StageType::MERGE);
107 }
108
109 // Byte-transparent — opt out of finalize() type checking on every port.
110 uint8_t getOutputDataType(size_t) const override {
111 return static_cast<uint8_t>(DataType::UNKNOWN);
112 }
113 uint8_t getInputDataType(size_t) const override {
114 return static_cast<uint8_t>(DataType::UNKNOWN);
115 }
116
117 // ── Execution ───────────────────────────────────────────────────────────
119 cudaStream_t stream,
120 MemoryPool* pool,
121 const std::vector<void*>& inputs,
122 const std::vector<void*>& outputs,
123 const std::vector<size_t>& sizes
124 ) override;
125
126 // ── Size estimation ───────────────────────────────────────────────────
127 std::vector<size_t> estimateOutputSizes(
128 const std::vector<size_t>& input_sizes
129 ) const override {
130 if (is_inverse_) {
131 // 1 input (merged blob) → N segment outputs, sizes from the config
132 // header (restored via deserializeHeader) or cached from forward.
133 return segment_sizes_;
134 }
135 // N inputs → 1 output = pure concatenation (no in-stream header).
136 size_t total = 0;
137 for (size_t s : input_sizes) total += s;
138 return {total};
139 }
140
141 std::unordered_map<std::string, size_t>
142 getActualOutputSizesByName() const override {
143 if (is_inverse_) {
144 std::unordered_map<std::string, size_t> m;
145 for (size_t i = 0; i < segment_names_.size(); i++)
146 m[segment_names_[i]] = (i < segment_sizes_.size()) ? segment_sizes_[i] : 0;
147 return m;
148 }
149 return {{"output", merged_total_}};
150 }
151
152 size_t getActualOutputSize(int index) const override {
153 if (is_inverse_) {
154 if (index < 0 || index >= (int)segment_sizes_.size()) return 0;
155 return segment_sizes_[index];
156 }
157 return (index == 0) ? merged_total_ : 0;
158 }
159
160 // ── Serialization ─────────────────────────────────────────────────────
161 size_t serializeHeader(size_t, uint8_t* buf, size_t max_size) const override {
162 const size_t N = segment_names_.size();
163 size_t need = 1 + 4 * N;
164 for (const auto& nm : segment_names_) need += 1 + nm.size();
165 if (need > max_size) return 0;
166 size_t off = 0;
167 buf[off++] = static_cast<uint8_t>(N);
168 for (size_t i = 0; i < N; i++) {
169 uint32_t sz = (i < segment_sizes_.size()) ? static_cast<uint32_t>(segment_sizes_[i]) : 0u;
170 std::memcpy(buf + off, &sz, sizeof(uint32_t));
171 off += sizeof(uint32_t);
172 }
173 for (const auto& nm : segment_names_) {
174 buf[off++] = static_cast<uint8_t>(nm.size());
175 std::memcpy(buf + off, nm.data(), nm.size());
176 off += nm.size();
177 }
178 return off;
179 }
180
181 void deserializeHeader(const uint8_t* buf, size_t size) override {
182 if (size < 1) return;
183 size_t off = 0;
184 const size_t N = buf[off++];
185 segment_sizes_.assign(N, 0);
186 for (size_t i = 0; i < N && off + 4 <= size; i++) {
187 uint32_t sz = 0;
188 std::memcpy(&sz, buf + off, sizeof(uint32_t));
189 off += sizeof(uint32_t);
190 segment_sizes_[i] = sz;
191 }
192 segment_names_.assign(N, "");
193 for (size_t i = 0; i < N && off < size; i++) {
194 const size_t len = buf[off++];
195 if (off + len > size) break;
196 segment_names_[i].assign(reinterpret_cast<const char*>(buf + off), len);
197 off += len;
198 }
199 }
200
201 size_t getMaxHeaderSize(size_t) const override { return FZM_STAGE_CONFIG_SIZE; }
202
203 void saveState() override {
204 saved_names_ = segment_names_;
205 saved_sizes_ = segment_sizes_;
206 }
207 void restoreState() override {
208 segment_names_ = saved_names_;
209 segment_sizes_ = saved_sizes_;
210 }
211
212private:
213 static constexpr size_t kMaxSegments = 16;
214
215 bool is_inverse_ = false;
216 std::vector<std::string> segment_names_;
217 std::vector<size_t> segment_sizes_; // per-segment byte sizes (concat order)
218 size_t merged_total_ = 0;
219
220 std::vector<std::string> saved_names_;
221 std::vector<size_t> saved_sizes_;
222};
223
224} // namespace fz
Definition mempool.h:82
Definition merge_stage.h:68
size_t serializeHeader(size_t, uint8_t *buf, size_t max_size) const override
Definition merge_stage.h:161
void setSegmentNames(const std::vector< std::string > &names)
Define the N segments (concatenation order = inverse output order).
Definition merge_stage.h:75
bool isGraphCompatible() const override
Pure stream-ordered D2D memcpy in both directions — no host sync.
Definition merge_stage.h:92
void saveState() override
Definition merge_stage.h:203
size_t getActualOutputSize(int index) const override
Definition merge_stage.h:152
std::unordered_map< std::string, size_t > getActualOutputSizesByName() const override
Definition merge_stage.h:142
size_t getMaxHeaderSize(size_t) const override
Definition merge_stage.h:201
std::string getName() const override
Definition merge_stage.h:103
void deserializeHeader(const uint8_t *buf, size_t size) override
Definition merge_stage.h:181
uint16_t getStageTypeId() const override
Definition merge_stage.h:105
uint8_t getInputDataType(size_t) const override
Definition merge_stage.h:113
std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const override
Definition merge_stage.h:127
uint8_t getOutputDataType(size_t) const override
Definition merge_stage.h:110
std::vector< std::string > getOutputNames() const override
Definition merge_stage.h:99
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override
void setInverse(bool inv) override
Definition merge_stage.h:88
Definition stage.h:30
FZM binary file format definitions — structs, enums, and helpers.
Definition fzm_format.h:25
constexpr size_t FZM_STAGE_CONFIG_SIZE
Per-stage serialized config slot (bytes)
Definition fzm_format.h:65
@ UNKNOWN
Byte-transparent stages: skip type checking at finalize()
Base class interface for all compression stages.