9#include <cuda_runtime.h>
13#include <unordered_map>
// Defaulted virtual destructor: because it is virtual, deleting a derived
// object through a Stage pointer runs the derived destructor as well.
32 virtual ~Stage() =
default;
38 const std::vector<void*>& inputs,
39 const std::vector<void*>& outputs,
40 const std::vector<size_t>& sizes
// Pure virtual (= 0): every concrete stage must implement this.
// Presumably the number of input buffers the stage consumes -- confirm
// against the implementations.
46 virtual size_t getNumInputs()
const = 0;
// Pure virtual (= 0): every concrete stage must implement this.
// Presumably the number of output buffers the stage produces -- confirm
// against the implementations.
47 virtual size_t getNumOutputs()
const = 0;
68 for (
size_t i = 0; i < names.size(); i++) {
69 if (names[i] == name)
return static_cast<int>(i);
80 const std::vector<size_t>& input_sizes
94 if (index < 0 || index >=
static_cast<int>(names.size()))
return 0;
96 auto it = m.find(names[index]);
97 return (it != m.end()) ? it->second : 0;
// Base-class implementation always returns false. The method is virtual,
// so a derived stage can override it to report a different value.
105 virtual bool isInverse()
const {
return false; }
123 return static_cast<uint8_t
>(DataType::UNKNOWN);
130 virtual size_t serializeHeader(
size_t output_index, uint8_t* header_buffer,
size_t max_size)
const {
131 (void)output_index; (void)header_buffer; (void)max_size;
137 (void)header_buffer; (void)size;
// Base-class implementation is an empty body (does nothing). The method is
// virtual, so a derived stage can override it.
147 virtual void restoreState() {}
// Base-class implementation ignores `dims`; the (void) cast only silences
// the unused-parameter warning. The method is virtual, so a derived stage
// can override it to use the three values.
// NOTE(review): the meaning/order of the three entries is not visible here.
154 virtual void setDims(
const std::array<size_t, 3>& dims) { (void)dims; }
virtual uint8_t getInputDataType(size_t) const
Definition stage.h:122
virtual std::string getName() const =0
virtual std::vector< std::string > getOutputNames() const
Definition stage.h:61
virtual void saveState()
Definition stage.h:146
virtual bool isGraphCompatible() const
Definition stage.h:182
virtual void postStreamSync(cudaStream_t stream)
Definition stage.h:162
virtual size_t getActualOutputSize(int index) const
Definition stage.h:92
virtual std::vector< size_t > estimateOutputSizes(const std::vector< size_t > &input_sizes) const =0
virtual void setInverse(bool inverse)
Definition stage.h:104
virtual uint16_t getStageTypeId() const =0
virtual size_t serializeHeader(size_t output_index, uint8_t *header_buffer, size_t max_size) const
Definition stage.h:130
virtual uint8_t getOutputDataType(size_t output_index) const =0
virtual void deserializeHeader(const uint8_t *header_buffer, size_t size)
Definition stage.h:136
virtual size_t getRequiredInputAlignment() const
Definition stage.h:55
virtual std::unordered_map< std::string, size_t > getActualOutputSizesByName() const =0
virtual void execute(cudaStream_t stream, MemoryPool *pool, const std::vector< void * > &inputs, const std::vector< void * > &outputs, const std::vector< size_t > &sizes)=0
int getOutputIndex(const std::string &name) const
Definition stage.h:66
virtual size_t estimateScratchBytes(const std::vector< size_t > &input_sizes) const
Definition stage.h:192
virtual size_t getMaxHeaderSize(size_t output_index) const
Definition stage.h:165
virtual void setDims(const std::array< size_t, 3 > &dims)
Definition stage.h:154