29#include <cuda_runtime.h>
34#include <unordered_map>
/// Element width processed by the ADM stage.
/// The underlying uint8_t value is written verbatim to byte 0 of the stage
/// header by ADMStage::serializeHeader() and cast back by deserializeHeader(),
/// so these numeric values are part of the serialized format and must not change.
enum class ADMDtype : uint8_t {
    U16 = 0,  // 16-bit unsigned input elements
    U32 = 1   // 32-bit unsigned input elements
};
42class ADMStage :
public Stage {
45 ~ADMStage()
override =
default;
54 void setDtype(
ADMDtype dt) { dtype_ = dt; }
55 ADMDtype getDtype()
const {
return dtype_; }
58 void setInverse(
bool inv)
override { is_inverse_ = inv; }
59 bool isInverse()
const override {
return is_inverse_; }
62 bool isGraphCompatible()
const override {
return false; }
65 void onFinalize(
size_t estimated_inlen, MemoryPool* pool)
override;
66 size_t estimateDeviceFootprintBytes(
size_t inlen)
const override;
67 size_t estimateScratchBytes(
const std::vector<size_t>& input_sizes)
const override;
// NOTE(review): orphaned parameter list. The function header (original line
// 72) and the closing of the signature (original lines 76-78) were dropped
// from this listing. Restore them from the real adm_stage.h; presumably this
// is the per-call entry point taking input/output buffers and their sizes.
73 const std::vector<void*>& inputs,
74 const std::vector<void*>& outputs,
75 const std::vector<size_t>& sizes
79 std::string getName()
const override {
return "ADM"; }
80 size_t getNumInputs()
const override {
return 1; }
81 size_t getNumOutputs()
const override {
return 1; }
// NOTE(review): the tail of this declaration (original lines 85-86) is
// missing from this listing; restore it from the real header before use.
83 std::vector<size_t> estimateOutputSizes(
84 const std::vector<size_t>& input_sizes
87 std::unordered_map<std::string, size_t>
88 getActualOutputSizesByName()
const override {
89 return {{
"output", actual_output_size_}};
92 size_t getActualOutputSize(
int index)
const override {
93 return (index == 0) ? actual_output_size_ : 0;
// NOTE(review): the body of getStageTypeId (original lines 98-101) is missing
// from this listing. Presumably it returns the numeric id of the ADM stage
// type; restore it from the real header.
97 uint16_t getStageTypeId()
const override {
102 uint8_t getInputDataType(
size_t )
const override {
103 return dtype_ == ADMDtype::U16
104 ?
static_cast<uint8_t
>(DataType::UINT16)
105 : static_cast<uint8_t>(
DataType::UINT32);
// NOTE(review): the body of getOutputDataType (original lines 110-113) is
// missing from this listing. Presumably it returns the DataType of the single
// output; confirm the exact value against the real header.
109 uint8_t getOutputDataType(
size_t )
const override {
// Serializes this stage's 12-byte header into `buf`:
//   buf[0]      underlying value of dtype_ (ADMDtype)
//   buf[1..3]   zero padding
//   buf[4..11]  num_elements_ copied as a raw uint64_t. The copy is in host
//               byte order, so the header is not portable across endianness.
// Writes nothing and returns 0 when max_size < 12. The first size_t
// parameter is unnamed and therefore unused here.
// NOTE(review): original line 116 (end of the signature) and lines 121-123
// (presumably the success return value and closing brace) are missing from
// this listing; restore them from the real header.
114 size_t serializeHeader(
115 size_t , uint8_t* buf,
size_t max_size
117 if (max_size < 12)
return 0;
118 buf[0] =
static_cast<uint8_t
>(dtype_);
// Bytes 1-3 are explicit zero padding so the element count starts at offset 4.
119 buf[1] = buf[2] = buf[3] = 0;
120 std::memcpy(buf + 4, &num_elements_,
sizeof(uint64_t));
124 void deserializeHeader(
const uint8_t* buf,
size_t size)
override {
125 if (size >= 1) dtype_ =
static_cast<ADMDtype>(buf[0]);
126 if (size >= 12) std::memcpy(&num_elements_, buf + 4,
sizeof(uint64_t));
129 size_t getMaxHeaderSize(
size_t )
const override {
return 12; }
131 void saveState()
override {
132 saved_dtype_ = dtype_;
133 saved_num_elements_ = num_elements_;
134 saved_output_size_ = actual_output_size_;
137 void restoreState()
override {
138 dtype_ = saved_dtype_;
139 num_elements_ = saved_num_elements_;
140 actual_output_size_ = saved_output_size_;
145 bool is_inverse_ =
false;
146 uint64_t num_elements_ = 0;
147 size_t actual_output_size_ = 0;
150 size_t cap_elements_ = 0;
153 int* d_signal_length_ =
nullptr;
154 int* d_output_lengths_ =
nullptr;
155 void* d_centers_ =
nullptr;
156 uint32_t* d_block_flags_ =
nullptr;
157 uint8_t* d_codes_ =
nullptr;
158 uint8_t* d_concat_signals_ =
nullptr;
159 uint8_t* d_bit_signals_ =
nullptr;
160 int* d_loc_offset_ =
nullptr;
161 int* d_prefix_state_ =
nullptr;
162 unsigned int* d_overflow_flag_ =
nullptr;
165 ADMDtype saved_dtype_ = ADMDtype::U16;
166 uint64_t saved_num_elements_ = 0;
167 size_t saved_output_size_ = 0;
170 void initScratch(
size_t num_elements, MemoryPool* pool);
173 int maxSignalBytes()
const;
175 size_t centerElemBytes()
const;
/*
 * Doxygen cross-reference tooltips captured from the rendered listing
 * (documentation residue, not code):
 *   Definition fzm_format.h:25
 *   ADMDtype: Definition adm_stage.h:40
 *   @ ADM: Adaptive Data Mapping transform (MANS)
 *   DataType: Element data type identifiers used in buffer and stage descriptors.
 *     Definition fzm_format.h:104
 *   @ UNKNOWN: Byte-transparent stages: skip type checking at finalize()
 *   Base class interface for all compression stages.
 */