9#include <cuda_runtime.h>
13#include <unordered_map>
32 virtual ~Stage() =
default;
51 const std::vector<void*>& inputs,
52 const std::vector<void*>& outputs,
53 const std::vector<size_t>& sizes
59 virtual size_t getNumInputs()
const = 0;
60 virtual size_t getNumOutputs()
const = 0;
81 for (
size_t i = 0; i < names.size(); i++) {
82 if (names[i] == name)
return static_cast<int>(i);
93 const std::vector<size_t>& input_sizes
107 if (index < 0 || index >=
static_cast<int>(names.size()))
return 0;
109 auto it = m.find(names[index]);
110 return (it != m.end()) ? it->second : 0;
118 virtual bool isInverse()
const {
return false; }
143 virtual size_t serializeHeader(
size_t output_index, uint8_t* header_buffer,
size_t max_size)
const {
144 (void)output_index; (void)header_buffer; (void)max_size;
150 (void)header_buffer; (void)size;
160 virtual void restoreState() {}
167 virtual void setDims(
const std::array<size_t, 3>& dims) { (void)dims; }
virtual size_t estimateDeviceFootprintBytes(size_t) const
Definition stage.h:193
virtual uint8_t getInputDataType(size_t) const
Definition stage.h:135
virtual std::string getName() const =0
virtual std::vector< std::string > getOutputNames() const
Definition stage.h:74
virtual size_t estimatePinnedFootprintBytes(size_t) const
Definition stage.h:200
virtual void saveState()
Definition stage.h:159
virtual bool isGraphCompatible() const
Definition stage.h:228
virtual void postStreamSync(cudaStream_t stream)
Definition stage.h:208
virtual size_t getActualOutputSize(int index) const
Definition stage.h:105
virtual void onFinalize(size_t, MemoryPool *)
Definition stage.h:186
virtual std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const =0
virtual void setInverse(bool inverse)
Definition stage.h:117
virtual uint16_t getStageTypeId() const =0
virtual size_t serializeHeader(size_t output_index, uint8_t *header_buffer, size_t max_size) const
Definition stage.h:143
virtual uint8_t getOutputDataType(size_t output_index) const =0
virtual void deserializeHeader(const uint8_t *header_buffer, size_t size)
Definition stage.h:149
virtual size_t getRequiredInputAlignment() const
Definition stage.h:68
virtual std::unordered_map< std::string, size_t > getActualOutputSizesByName() const =0
virtual void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes)=0
int getOutputIndex(const std::string &name) const
Definition stage.h:79
virtual size_t estimateScratchBytes(const std::vector< size_t > &input_sizes) const
Definition stage.h:238
virtual size_t getMaxHeaderSize(size_t output_index) const
Definition stage.h:211
virtual void setDims(const std::array< size_t, 3 > &dims)
Definition stage.h:167
Definition fzm_format.h:25
@ UNKNOWN
Byte-transparent stages: skip type checking at finalize()