FZGPUModules 1.0
GPU-accelerated modular compression pipeline
Loading...
Searching...
No Matches
stage.h
1#pragma once
2
3#include "fzm_format.h"
4#include <array>
5#include <cuda_runtime.h>
6#include <cstdint>
7#include <stdexcept>
8#include <string>
9#include <unordered_map>
10#include <vector>
11
12namespace fz {
13
14// Forward declaration — avoids requiring mempool.h in every stage header
15class MemoryPool;
16
28class Stage {
29public:
30 virtual ~Stage() = default;
31
33 virtual void execute(
34 cudaStream_t stream,
35 MemoryPool* pool,
36 const std::vector<void*>& inputs,
37 const std::vector<void*>& outputs,
38 const std::vector<size_t>& sizes
39 ) = 0;
40
42 virtual std::string getName() const = 0;
43
44 virtual size_t getNumInputs() const = 0;
45 virtual size_t getNumOutputs() const = 0;
46
53 virtual size_t getRequiredInputAlignment() const { return 1; }
54
59 virtual std::vector<std::string> getOutputNames() const {
60 return {"output"};
61 }
62
64 int getOutputIndex(const std::string& name) const {
65 auto names = getOutputNames();
66 for (size_t i = 0; i < names.size(); i++) {
67 if (names[i] == name) return static_cast<int>(i);
68 }
69 return -1;
70 }
71
77 virtual std::vector<size_t> estimateOutputSizes(
78 const std::vector<size_t>& input_sizes
79 ) const = 0;
80
82 virtual std::unordered_map<std::string, size_t> getActualOutputSizesByName() const = 0;
83
90 virtual size_t getActualOutputSize(int index) const {
91 auto names = getOutputNames();
92 if (index < 0 || index >= static_cast<int>(names.size())) return 0;
94 auto it = m.find(names[index]);
95 return (it != m.end()) ? it->second : 0;
96 }
97
102 virtual void setInverse(bool inverse) { (void)inverse; }
103 virtual bool isInverse() const { return false; }
104
106 virtual uint16_t getStageTypeId() const = 0;
107
109 virtual uint8_t getOutputDataType(size_t output_index) const = 0;
110
120 virtual uint8_t getInputDataType(size_t /*input_index*/) const {
121 return static_cast<uint8_t>(DataType::UNKNOWN);
122 }
123
128 virtual size_t serializeHeader(size_t output_index, uint8_t* header_buffer, size_t max_size) const {
129 (void)output_index; (void)header_buffer; (void)max_size;
130 return 0;
131 }
132
134 virtual void deserializeHeader(const uint8_t* header_buffer, size_t size) {
135 (void)header_buffer; (void)size;
136 }
137
144 virtual void saveState() {}
145 virtual void restoreState() {}
146
152 virtual void setDims(const std::array<size_t, 3>& dims) { (void)dims; }
153
160 virtual void postStreamSync(cudaStream_t stream) { (void)stream; }
161
163 virtual size_t getMaxHeaderSize(size_t output_index) const {
164 (void)output_index;
165 return 0;
166 }
167
180 virtual bool isGraphCompatible() const { return true; }
181
190 virtual size_t estimateScratchBytes(const std::vector<size_t>& input_sizes) const {
191 (void)input_sizes;
192 return 0;
193 }
194};
195
196} // namespace fz
Definition mempool.h:62
Definition stage.h:28
virtual uint8_t getInputDataType(size_t) const
Definition stage.h:120
virtual std::string getName() const =0
virtual std::vector< std::string > getOutputNames() const
Definition stage.h:59
virtual void saveState()
Definition stage.h:144
virtual bool isGraphCompatible() const
Definition stage.h:180
virtual void postStreamSync(cudaStream_t stream)
Definition stage.h:160
virtual size_t getActualOutputSize(int index) const
Definition stage.h:90
virtual std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const =0
virtual void setInverse(bool inverse)
Definition stage.h:102
virtual uint16_t getStageTypeId() const =0
virtual size_t serializeHeader(size_t output_index, uint8_t *header_buffer, size_t max_size) const
Definition stage.h:128
virtual uint8_t getOutputDataType(size_t output_index) const =0
virtual void deserializeHeader(const uint8_t *header_buffer, size_t size)
Definition stage.h:134
virtual size_t getRequiredInputAlignment() const
Definition stage.h:53
virtual std::unordered_map< std::string, size_t > getActualOutputSizesByName() const =0
virtual void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes)=0
int getOutputIndex(const std::string &name) const
Definition stage.h:64
virtual size_t estimateScratchBytes(const std::vector< size_t > &input_sizes) const
Definition stage.h:190
virtual size_t getMaxHeaderSize(size_t output_index) const
Definition stage.h:163
virtual void setDims(const std::array< size_t, 3 > &dims)
Definition stage.h:152
FZM binary file format definitions — structs, enums, and helpers.
Definition fzm_format.h:25
@ UNKNOWN
Byte-transparent stages: skip type checking at finalize()