FZGPUModules 2.0
GPU-accelerated modular compression pipelines
Loading...
Searching...
No Matches
stage.h
Go to the documentation of this file.
1
5#pragma once
6
7#include "fzm_format.h"
8#include <array>
9#include <cuda_runtime.h>
10#include <cstdint>
11#include <stdexcept>
12#include <string>
13#include <unordered_map>
14#include <vector>
15
16namespace fz {
17
18// Forward declaration — avoids requiring mempool.h in every stage header
19class MemoryPool;
20
30class Stage {
31public:
32 virtual ~Stage() = default;
33
48 virtual void execute(
49 cudaStream_t stream,
50 MemoryPool* pool,
51 const std::vector<void*>& inputs,
52 const std::vector<void*>& outputs,
53 const std::vector<size_t>& sizes
54 ) = 0;
55
57 virtual std::string getName() const = 0;
58
59 virtual size_t getNumInputs() const = 0;
60 virtual size_t getNumOutputs() const = 0;
61
68 virtual size_t getRequiredInputAlignment() const { return 1; }
69
74 virtual std::vector<std::string> getOutputNames() const {
75 return {"output"};
76 }
77
79 int getOutputIndex(const std::string& name) const {
80 auto names = getOutputNames();
81 for (size_t i = 0; i < names.size(); i++) {
82 if (names[i] == name) return static_cast<int>(i);
83 }
84 return -1;
85 }
86
92 virtual std::vector<size_t> estimateOutputSizes(
93 const std::vector<size_t>& input_sizes
94 ) const = 0;
95
97 virtual std::unordered_map<std::string, size_t> getActualOutputSizesByName() const = 0;
98
105 virtual size_t getActualOutputSize(int index) const {
106 auto names = getOutputNames();
107 if (index < 0 || index >= static_cast<int>(names.size())) return 0;
109 auto it = m.find(names[index]);
110 return (it != m.end()) ? it->second : 0;
111 }
112
117 virtual void setInverse(bool inverse) { (void)inverse; }
118 virtual bool isInverse() const { return false; }
119
121 virtual uint16_t getStageTypeId() const = 0;
122
124 virtual uint8_t getOutputDataType(size_t output_index) const = 0;
125
135 virtual uint8_t getInputDataType(size_t /*input_index*/) const {
136 return static_cast<uint8_t>(DataType::UNKNOWN);
137 }
138
143 virtual size_t serializeHeader(size_t output_index, uint8_t* header_buffer, size_t max_size) const {
144 (void)output_index; (void)header_buffer; (void)max_size;
145 return 0;
146 }
147
149 virtual void deserializeHeader(const uint8_t* header_buffer, size_t size) {
150 (void)header_buffer; (void)size;
151 }
152
159 virtual void saveState() {}
160 virtual void restoreState() {}
161
167 virtual void setDims(const std::array<size_t, 3>& dims) { (void)dims; }
168
186 virtual void onFinalize(size_t /*estimated_inlen*/, MemoryPool* /*pool*/) {}
187
193 virtual size_t estimateDeviceFootprintBytes(size_t /*inlen*/) const { return 0; }
194
200 virtual size_t estimatePinnedFootprintBytes(size_t /*inlen*/) const { return 0; }
201
208 virtual void postStreamSync(cudaStream_t stream) { (void)stream; }
209
211 virtual size_t getMaxHeaderSize(size_t output_index) const {
212 (void)output_index;
213 return 0;
214 }
215
228 virtual bool isGraphCompatible() const { return true; }
229
238 virtual size_t estimateScratchBytes(const std::vector<size_t>& input_sizes) const {
239 (void)input_sizes;
240 return 0;
241 }
242};
243
244} // namespace fz
Definition mempool.h:82
Definition stage.h:30
virtual size_t estimateDeviceFootprintBytes(size_t) const
Definition stage.h:193
virtual uint8_t getInputDataType(size_t) const
Definition stage.h:135
virtual std::string getName() const =0
virtual std::vector< std::string > getOutputNames() const
Definition stage.h:74
virtual size_t estimatePinnedFootprintBytes(size_t) const
Definition stage.h:200
virtual void saveState()
Definition stage.h:159
virtual bool isGraphCompatible() const
Definition stage.h:228
virtual void postStreamSync(cudaStream_t stream)
Definition stage.h:208
virtual size_t getActualOutputSize(int index) const
Definition stage.h:105
virtual void onFinalize(size_t, MemoryPool *)
Definition stage.h:186
virtual std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const =0
virtual void setInverse(bool inverse)
Definition stage.h:117
virtual uint16_t getStageTypeId() const =0
virtual size_t serializeHeader(size_t output_index, uint8_t *header_buffer, size_t max_size) const
Definition stage.h:143
virtual uint8_t getOutputDataType(size_t output_index) const =0
virtual void deserializeHeader(const uint8_t *header_buffer, size_t size)
Definition stage.h:149
virtual size_t getRequiredInputAlignment() const
Definition stage.h:68
virtual std::unordered_map< std::string, size_t > getActualOutputSizesByName() const =0
virtual void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes)=0
int getOutputIndex(const std::string &name) const
Definition stage.h:79
virtual size_t estimateScratchBytes(const std::vector< size_t > &input_sizes) const
Definition stage.h:238
virtual size_t getMaxHeaderSize(size_t output_index) const
Definition stage.h:211
virtual void setDims(const std::array< size_t, 3 > &dims)
Definition stage.h:167
FZM binary file format definitions — structs, enums, and helpers.
Definition fzm_format.h:25
@ UNKNOWN
Byte-transparent stages: skip type checking at finalize()