FZGPUModules 2.0
GPU-accelerated modular compression pipelines
Loading...
Searching...
No Matches
ginterp_stage.h
Go to the documentation of this file.
1#pragma once
2
13#include "stage/stage.h"
14#include "fzm_format.h"
15#include "fused/lorenzo_quant/lorenzo_quant.h" // for ErrorBoundMode
16
17#include <cuda_runtime.h>
18#include <array>
19#include <cmath>
20#include <cstdint>
21#include <cstring>
22#include <string>
23#include <stdexcept>
24#include <type_traits>
25#include <unordered_map>
26#include <vector>
27
28// Forward-declared in the public header so callers don't pull in the
29// cuSZ-Hi type subset. The full definition lives in cusz_type_subset.h,
30// included only inside the stage TU.
32
33namespace fz {
34
40 // ── identity / dims ─────────────────────────────────────────────────────
41 double error_bound;
43 uint32_t quant_radius;
44 uint32_t num_elements;
45 uint32_t outlier_count;
48 uint8_t ndim;
49 uint8_t eb_mode;
50 uint32_t dim_x;
51 uint32_t dim_y;
52 uint32_t dim_z;
53 uint32_t anchor_dim_x;
54 uint32_t anchor_dim_y;
55 uint32_t anchor_dim_z;
56 float user_eb;
57 float value_base;
58
59 // ── resolved INTERPOLATION_PARAMS (phase 2 auto-tune output, 39 B) ──────
60 // Both encoder and decoder must use the exact same intp_param values, so
61 // the resolved params are written here on compress and consumed on decompress.
62 // Layout chosen to mirror `INTERPOLATION_PARAMS` field-for-field; we don't
63 // memcpy the struct directly because its padding is implementation-defined.
64 double intp_alpha;
65 double intp_beta;
66 uint8_t intp_use_md[6];
67 uint8_t intp_use_natural[6];
68 uint8_t intp_reverse[6];
70 uint8_t pad[8];
71
72 // ~96 bytes. Comfortable margin under FZM_STAGE_CONFIG_SIZE (128 B);
73 // the static_assert below is the source of truth.
74
77 input_type(DataType::FLOAT32), code_type(DataType::UINT16),
78 ndim(3), eb_mode(0),
79 dim_x(0), dim_y(1), dim_z(1),
81 user_eb(0.0f), value_base(0.0f),
82 intp_alpha(1.75), intp_beta(4.0),
83 intp_use_md{1, 1, 0, 0, 0, 0},
84 intp_use_natural{0, 0, 0, 0, 0, 0},
85 intp_reverse{0, 0, 0, 0, 0, 0},
86 auto_tuning_mode(0), pad{} {}
87};
88static_assert(sizeof(GInterpConfig) <= FZM_STAGE_CONFIG_SIZE,
89 "GInterpConfig must fit in FZM_STAGE_CONFIG_SIZE");
90
203template <typename TInput = float, typename TCode = uint16_t>
204class GInterpStage : public Stage {
205public:
206 struct Config {
207 float error_bound = 1e-3f;
219 int quant_radius = 0;
220 float outlier_capacity = 0.10f;
224 std::array<size_t, 3> dims = {0, 0, 0};
228 float precomputed_value_base = 0.0f;
248 uint8_t auto_tuning_mode = 0;
249
253 double manual_alpha = 0.0;
255 double manual_beta = 0.0;
256
257 Config() = default;
258 };
259
260 explicit GInterpStage(const Config& cfg = Config()) : config_(cfg) {
261 actual_output_sizes_.resize(4, 0);
262 }
263 ~GInterpStage() override;
264
265 // ── Stage interface ──────────────────────────────────────────────────────
267 cudaStream_t stream,
268 MemoryPool* pool,
269 const std::vector<void*>& inputs,
270 const std::vector<void*>& outputs,
271 const std::vector<size_t>& sizes
272 ) override;
273
274 void postStreamSync(cudaStream_t stream) override;
275
281 void onFinalize(size_t estimated_inlen, MemoryPool* pool) override;
282
283 size_t estimateDeviceFootprintBytes(size_t /*estimated_inlen*/) const override {
284 return needsProfilingScratch() ? kProfilingErrCount * sizeof(float) : 0;
285 }
286 size_t estimatePinnedFootprintBytes(size_t /*estimated_inlen*/) const override {
287 return needsProfilingScratch() ? kProfilingErrCount * sizeof(float) : 0;
288 }
289
290 std::string getName() const override { return "GInterp"; }
291 size_t getNumInputs() const override { return is_inverse_ ? 4 : 1; }
292 size_t getNumOutputs() const override { return is_inverse_ ? 1 : 4; }
293
294 std::vector<std::string> getOutputNames() const override {
295 return {"codes", "anchor", "outlier_vals", "outlier_idxs"};
296 }
297
298 std::vector<size_t> estimateOutputSizes(
299 const std::vector<size_t>& input_sizes
300 ) const override;
301
302 std::unordered_map<std::string, size_t> getActualOutputSizesByName() const override {
303 auto names = getOutputNames();
304 std::unordered_map<std::string, size_t> r;
305 for (size_t i = 0; i < names.size() && i < actual_output_sizes_.size(); i++)
306 r[names[i]] = actual_output_sizes_[i];
307 return r;
308 }
309 size_t getActualOutputSize(int index) const override {
310 return (index >= 0 && index < static_cast<int>(actual_output_sizes_.size()))
311 ? actual_output_sizes_[index] : 0;
312 }
313
314 void saveState() override { saved_output_sizes_ = actual_output_sizes_; }
315 void restoreState() override { actual_output_sizes_ = saved_output_sizes_; }
316
317 // ── Setters ──────────────────────────────────────────────────────────────
318 void setErrorBound(float eb) { config_.error_bound = eb; }
319 void setQuantRadius(int radius) { config_.quant_radius = radius; }
320 void setOutlierCapacity(float cap) { config_.outlier_capacity = cap; }
324 void setErrorBoundMode(ErrorBoundMode m) { config_.eb_mode = m; }
325 void setValueBase(float v) { config_.precomputed_value_base = v; }
330 void setAutoTuning(uint8_t mode) { config_.auto_tuning_mode = mode; }
335 void setManualAlphaBeta(double alpha, double beta) {
336 config_.manual_alpha = alpha;
337 config_.manual_beta = beta;
338 }
339 void setDims(const std::array<size_t, 3>& dims) override;
340 void setDims(size_t x, size_t y, size_t z) {
341 setDims(std::array<size_t, 3>{x, y, z});
342 }
343
344 float getErrorBound() const { return config_.error_bound; }
345 int getQuantRadius() const { return config_.quant_radius; }
346 float getOutlierCapacity() const { return config_.outlier_capacity; }
347 ErrorBoundMode getErrorBoundMode() const { return config_.eb_mode; }
348 float getValueBase() const { return config_.precomputed_value_base; }
349 uint8_t getAutoTuningMode() const { return config_.auto_tuning_mode; }
350 std::array<size_t, 3> getDims() const { return config_.dims; }
351
352 void setInverse(bool inv) override { is_inverse_ = inv; }
353 bool isInverse() const override { return is_inverse_; }
354
358 int ndim() const {
359 if (config_.dims[2] > 1) return 3;
360 if (config_.dims[1] > 1) return 2;
361 return 1;
362 }
363
364 // ── Type / Serialization ─────────────────────────────────────────────────
365 uint16_t getStageTypeId() const override {
366 return static_cast<uint16_t>(StageType::G_INTERP);
367 }
368
369 uint8_t getOutputDataType(size_t output_index) const override {
370 switch (output_index) {
371 case 0: return static_cast<uint8_t>(codeDataType()); // codes
372 case 1: return static_cast<uint8_t>(inputDataType()); // anchor
373 case 2: return static_cast<uint8_t>(inputDataType()); // outlier_vals
374 case 3: return static_cast<uint8_t>(DataType::UINT32); // outlier_idxs
375 default: return static_cast<uint8_t>(DataType::UINT8);
376 }
377 }
378 uint8_t getInputDataType(size_t /*input_index*/) const override {
379 return static_cast<uint8_t>(inputDataType());
380 }
381
382 size_t serializeHeader(size_t output_index, uint8_t* buf, size_t max_size) const override;
383 void deserializeHeader(const uint8_t* buf, size_t size) override;
384 size_t getMaxHeaderSize(size_t /*output_index*/) const override {
385 return sizeof(GInterpConfig);
386 }
387
404 bool isGraphCompatible() const override {
405 const bool tune_ok = (config_.auto_tuning_mode == 0 ||
406 config_.auto_tuning_mode == 5);
407 const bool radius_ok = config_.quant_radius > 0;
408 const bool eb_ok = (config_.eb_mode == ErrorBoundMode::ABS) ||
409 (config_.precomputed_value_base > 0.0f);
410 const bool mode5_alpha_ok = (config_.auto_tuning_mode != 5) ||
411 (config_.manual_alpha > 0.0);
412 return tune_ok && radius_ok && eb_ok && mode5_alpha_ok;
413 }
414
415private:
416 Config config_;
417 std::vector<size_t> actual_output_sizes_;
418 std::vector<size_t> saved_output_sizes_;
419
420 bool is_inverse_ = false;
421 size_t num_elements_ = 0;
422 uint32_t actual_outlier_count_ = 0;
429 uint32_t* d_outlier_count_scratch_ = nullptr;
430
433 TInput computed_abs_eb_ = 0;
436 float computed_value_base_ = 0.0f;
437
439 std::array<size_t, 3> anchor_dims_ = {0, 0, 0};
440
441 // ── Phase 2: auto-tuning state ──────────────────────────────────────────
444 static constexpr size_t kProfilingErrCount = 36;
445 float* d_profiling_errors_ = nullptr;
446 float* h_profiling_errors_ = nullptr;
450 MemoryPool* persistent_pool_ = nullptr;
451
456 double resolved_alpha_ = 1.75;
457 double resolved_beta_ = 4.0;
458 uint8_t resolved_use_md_[6] = {1, 1, 0, 0, 0, 0};
459 uint8_t resolved_use_natural_[6] = {0, 0, 0, 0, 0, 0};
460 uint8_t resolved_reverse_[6] = {0, 0, 0, 0, 0, 0};
461
464 bool needsProfilingScratch() const {
465 const uint8_t m = config_.auto_tuning_mode;
466 return m == 1 || m == 2 || m == 3 || m == 4;
467 }
468
471 void initProfilingScratch(MemoryPool* pool);
473 void initOutlierCountScratch(MemoryPool* pool);
477 INTERPOLATION_PARAMS buildIntpParam() const;
478
479 static DataType inputDataType() {
480 if (std::is_same<TInput, float>::value) return DataType::FLOAT32;
481 if (std::is_same<TInput, double>::value) return DataType::FLOAT64;
482 return DataType::FLOAT32;
483 }
484 static DataType codeDataType() {
485 if (std::is_same<TCode, uint8_t>::value) return DataType::UINT8;
486 if (std::is_same<TCode, uint16_t>::value) return DataType::UINT16;
487 if (std::is_same<TCode, uint32_t>::value) return DataType::UINT32;
488 return DataType::UINT16;
489 }
490 size_t getMaxOutlierCount(size_t n) const {
491 return static_cast<size_t>(std::ceil(n * config_.outlier_capacity));
492 }
493};
494
495extern template class GInterpStage<float, uint8_t>;
496extern template class GInterpStage<float, uint16_t>;
497extern template class GInterpStage<float, uint32_t>;
498extern template class GInterpStage<double, uint8_t>;
499extern template class GInterpStage<double, uint16_t>;
500extern template class GInterpStage<double, uint32_t>;
501
502} // namespace fz
Definition ginterp_stage.h:204
size_t serializeHeader(size_t output_index, uint8_t *buf, size_t max_size) const override
size_t getActualOutputSize(int index) const override
Definition ginterp_stage.h:309
void setDims(const std::array< size_t, 3 > &dims) override
uint8_t getInputDataType(size_t) const override
Definition ginterp_stage.h:378
uint8_t getOutputDataType(size_t output_index) const override
Definition ginterp_stage.h:369
std::unordered_map< std::string, size_t > getActualOutputSizesByName() const override
Definition ginterp_stage.h:302
size_t getMaxHeaderSize(size_t) const override
Definition ginterp_stage.h:384
void setAutoTuning(uint8_t mode)
Definition ginterp_stage.h:330
void onFinalize(size_t estimated_inlen, MemoryPool *pool) override
bool isGraphCompatible() const override
Definition ginterp_stage.h:404
void setManualAlphaBeta(double alpha, double beta)
Definition ginterp_stage.h:335
size_t estimateDeviceFootprintBytes(size_t) const override
Definition ginterp_stage.h:283
std::vector< std::string > getOutputNames() const override
Definition ginterp_stage.h:294
void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes) override
void postStreamSync(cudaStream_t stream) override
void saveState() override
Definition ginterp_stage.h:314
uint16_t getStageTypeId() const override
Definition ginterp_stage.h:365
int ndim() const
Definition ginterp_stage.h:358
void setErrorBoundMode(ErrorBoundMode m)
Definition ginterp_stage.h:324
void setInverse(bool inv) override
Definition ginterp_stage.h:352
std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const override
void deserializeHeader(const uint8_t *buf, size_t size) override
size_t estimatePinnedFootprintBytes(size_t) const override
Definition ginterp_stage.h:286
std::string getName() const override
Definition ginterp_stage.h:290
Definition mempool.h:82
Definition stage.h:30
FZM binary file format definitions — structs, enums, and helpers.
Fused Lorenzo predictor and quantizer stage.
Definition fzm_format.h:25
ErrorBoundMode
Definition lorenzo_quant.h:30
@ ABS
Absolute error bound.
constexpr size_t FZM_STAGE_CONFIG_SIZE
Per-stage serialized config slot (bytes)
Definition fzm_format.h:65
@ G_INTERP
Spline interpolation predictor + quantizer (cuSZ-Hi G-Interp)
DataType
Element data type identifiers used in buffer and stage descriptors.
Definition fzm_format.h:109
Base class interface for all compression stages.
Definition cusz_type_subset.h:32
Definition ginterp_stage.h:39
float user_eb
Original user-specified bound (before mode conversion).
Definition ginterp_stage.h:56
uint8_t eb_mode
ErrorBoundMode cast to uint8_t.
Definition ginterp_stage.h:49
uint8_t auto_tuning_mode
0=off, 1=cheap, 3=full, 4=full+alpha sweep, 5+=manual α/β. Mode 2 is not described here, although needsProfilingScratch() treats modes 1–4 alike.
Definition ginterp_stage.h:69
float value_base
value_range (NOA) / max(|data|) (REL) used in conversion.
Definition ginterp_stage.h:57
uint8_t intp_use_md[6]
Resolved use_md[level], booleans as u8.
Definition ginterp_stage.h:66
uint8_t pad[8]
Reserved for future fields (alignment also).
Definition ginterp_stage.h:70
uint32_t dim_x
X (fast) dimension.
Definition ginterp_stage.h:50
uint32_t anchor_dim_x
Anchor grid X extent.
Definition ginterp_stage.h:53
double error_bound
Definition ginterp_stage.h:41
uint32_t dim_z
Z dimension.
Definition ginterp_stage.h:52
double intp_alpha
Resolved alpha (auto-tuned from rel_eb or fixed).
Definition ginterp_stage.h:64
uint32_t num_elements
Total element count (= dim_x*dim_y*dim_z).
Definition ginterp_stage.h:44
uint32_t anchor_dim_y
Anchor grid Y extent.
Definition ginterp_stage.h:54
uint32_t outlier_count
Actual outlier count (post-execute).
Definition ginterp_stage.h:45
uint32_t anchor_dim_z
Anchor grid Z extent.
Definition ginterp_stage.h:55
double intp_beta
Resolved beta (default 4.0).
Definition ginterp_stage.h:65
uint32_t quant_radius
Quantization radius (codes lie in [0, 2*radius)).
Definition ginterp_stage.h:43
DataType code_type
Quant code type (1 B).
Definition ginterp_stage.h:47
uint8_t ndim
Spatial dimensionality (3 in MVP).
Definition ginterp_stage.h:48
uint32_t dim_y
Y dimension.
Definition ginterp_stage.h:51
DataType input_type
Float input type (1 B).
Definition ginterp_stage.h:46