Skip to content

Commit 78dc399

Browse files
committed
Support for multiple images as input
1 parent 96a0f8a commit 78dc399

File tree

6 files changed

+125
-4
lines changed

6 files changed

+125
-4
lines changed

src/models/extra_inputs.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,14 @@ void ExtraInputs::Add(const std::vector<ExtraInput>& extra_inputs, const std::ve
6060
registrar_.Add();
6161
}
6262

63+
bool ExtraInputs::Replace(const std::string& name, const std::shared_ptr<Tensor>& tensor) {
64+
for (size_t i = 0; i < state_.input_names_.size(); ++i) {
65+
if (name == std::string(state_.input_names_[i])) {
66+
state_.inputs_[i] = tensor->ort_tensor_.get();
67+
return true;
68+
}
69+
}
70+
return false;
71+
}
72+
6373
} // namespace Generators

src/models/extra_inputs.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct PresetExtraInputs {
1717
struct ExtraInputs {
1818
ExtraInputs(State& state);
1919
void Add(const std::vector<ExtraInput>& extra_inputs, const std::vector<std::string>& required_input_names = {});
20-
20+
bool Replace(const std::string& name, const std::shared_ptr<Tensor>& tensor);
2121
private:
2222
State& state_;
2323
const Model& model_{state_.model_};

src/models/multi_decoder_pipeline_modal.cpp

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,94 @@ void VisionPipelineState::SetExtraInputs(const std::vector<ExtraInput>& extra_in
118118
image_features_ = std::make_unique<MultiModalFeatures>(*this, MultiModalFeatures::Mode::Output, // model output
119119
model_.config_->model.vision.outputs.image_features,
120120
num_images_, num_image_tokens_);
121-
image_features_->Add();
121+
for (const auto& ei : extra_inputs) {
122+
if (ei.name == "pixel_values") {
123+
pixel_values_tensor_ = ei.tensor;
124+
break;
125+
}
126+
}
122127
extra_inputs_.Add(extra_inputs, model_.vision_session_->GetInputNames());
123128
}
124129

125-
DeviceSpan<float> VisionPipelineState::Run(int current_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
126-
State::Run(*model_.vision_session_);
130+
// Create a [1, C, H, W] tensor and copy the i-th image from a [N, C, H, W] tensor.
131+
static std::shared_ptr<Tensor> MakeSingleImagePixelValues(const std::shared_ptr<Tensor>& full,
132+
int64_t index,
133+
DeviceInterface* device) {
134+
if (!full || !full->GetOrtTensor()) {
135+
throw std::runtime_error("MakeSingleImagePixelValues: source tensor is null");
136+
}
137+
const auto full_shape = full->GetShape(); // expected [N, C, H, W]
138+
if (full_shape.size() != 4) {
139+
throw std::runtime_error("MakeSingleImagePixelValues: expected [N, C, H, W] shape");
140+
}
141+
const int64_t N = full_shape[0];
142+
const int64_t C = full_shape[1];
143+
const int64_t H = full_shape[2];
144+
const int64_t W = full_shape[3];
145+
if (index < 0 || index >= N) {
146+
throw std::runtime_error("MakeSingleImagePixelValues: index out of range");
147+
}
148+
149+
// Destination shape [1, C, H, W]
150+
std::vector<int64_t> dst_shape = {1, C, H, W};
151+
152+
auto dst = std::make_shared<Tensor>(device, full->GetType());
153+
dst->CreateTensor(dst_shape, /*make_static=*/false);
154+
155+
// Compute byte ranges and copy
156+
const size_t elem_size = Ort::SizeOf(full->GetType());
157+
const size_t per_image_bytes = static_cast<size_t>(C) * static_cast<size_t>(H) * static_cast<size_t>(W) * elem_size;
158+
const size_t offset_bytes = static_cast<size_t>(index) * per_image_bytes;
159+
160+
auto src_bytes = full->GetByteSpan();
161+
auto dst_bytes = dst->GetByteSpan();
162+
163+
if (offset_bytes + per_image_bytes > src_bytes.size() || per_image_bytes > dst_bytes.size()) {
164+
throw std::runtime_error("MakeSingleImagePixelValues: copy bounds exceeded");
165+
}
166+
167+
dst_bytes.CopyFrom(src_bytes.subspan(offset_bytes, per_image_bytes));
168+
return dst;
169+
}
170+
171+
// Runs the vision session once per input image, slicing the batched
// pixel_values tensor into [1, C, H, W] chunks, and copies each per-run
// feature output into the matching slice of the aggregated image_features_
// buffer. Returns an empty span: this state produces features, not logits.
// `current_length` / `next_tokens` / `next_indices` are unused here but
// required by the State::Run-style interface.
DeviceSpan<float> VisionPipelineState::Run(int current_length,
                                           DeviceSpan<int32_t>& next_tokens,
                                           DeviceSpan<int32_t> next_indices) {
  // Nothing to do without a vision session, a features buffer, or pixel data.
  if (!model_.vision_session_ || !image_features_ || !pixel_values_tensor_) {
    return {};
  }

  const int64_t total_images = num_images_;
  const size_t bytes_per_image = image_features_->BytesPerImage();

  // Flat destination bytes of the global features buffer.
  auto dst_all_bytes = image_features_->AsByteSpan();

  // Bind a single-image output features object once and reuse it across runs.
  auto run_features = std::make_unique<MultiModalFeatures>(*this,
                                                           MultiModalFeatures::Mode::Output,
                                                           model_.config_->model.vision.outputs.image_features,
                                                           /*batch_size=*/1,
                                                           /*num_feature_tokens=*/num_image_tokens_);
  run_features->Add();

  for (int64_t i = 0; i < total_images; ++i) {
    auto pixel_values_i = MakeSingleImagePixelValues(pixel_values_tensor_, i, model_.p_device_);
    // Fix: the result of Replace() was previously ignored. If "pixel_values"
    // is not a registered input, the session would silently run on stale data
    // for every image; fail loudly instead.
    if (!extra_inputs_.Replace("pixel_values", pixel_values_i)) {
      throw std::runtime_error("VisionPipelineState::Run: 'pixel_values' input not found");
    }

    State::Run(*model_.vision_session_);

    // Copy this run's features into slot i of the aggregated buffer.
    auto src_bytes = run_features->AsByteSpan();
    const size_t dst_offset = static_cast<size_t>(i) * bytes_per_image;
    if (dst_offset + bytes_per_image > dst_all_bytes.size() || bytes_per_image > src_bytes.size()) {
      throw std::runtime_error("VisionPipelineState::Run: features copy out of bounds");
    }
    dst_all_bytes.subspan(dst_offset, bytes_per_image).CopyFrom(src_bytes.subspan(0, bytes_per_image));
  }

  return {};
}
129211

@@ -648,6 +730,7 @@ DeviceSpan<float> MultiModalDecoderPipelineState::Run(int current_length, Device
648730
if (num_audio_tokens_ > 0 && speech_state_) {
649731
speech_state_->Run(current_length, next_tokens, next_indices);
650732
}
733+
vision_state_->image_features_->Add();
651734
if (vision_state_) embedding_state_->image_features_->ReuseFeaturesBuffer(*vision_state_->image_features_);
652735
if (speech_state_) embedding_state_->audio_features_->ReuseFeaturesBuffer(*speech_state_->audio_features_);
653736
embedding_state_->inputs_embeds_.ReuseEmbeddingsBuffer(decoder_pipeline_state_->full_inputs_embeds_);

src/models/multi_decoder_pipeline_modal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ struct VisionPipelineState : State {
4343
const MultiModalPipelineLanguageModel& model_;
4444
int64_t num_image_tokens_;
4545
int64_t num_images_{};
46+
std::shared_ptr<Tensor> pixel_values_tensor_;
4647
ExtraInputs extra_inputs_{*this}; // Model inputs
4748
std::unique_ptr<MultiModalFeatures> image_features_;
4849
};

src/models/multi_modal_features.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,28 @@ void MultiModalFeatures::ReuseFeaturesBuffer(MultiModalFeatures& other) {
7777
state_.inputs_[index_] = other.state_.outputs_[other.index_];
7878
}
7979

80+
// Exposes the allocated features tensor as a raw byte view on its device,
// for slice-wise copies between per-run and aggregated buffers.
// Throws std::runtime_error if the tensor has not been allocated yet.
DeviceSpan<uint8_t> MultiModalFeatures::AsByteSpan() {
  if (features_)
    return ByteWrapTensor(*model_.p_device_, *features_);
  throw std::runtime_error("MultiModalFeatures: features_ not allocated");
}
86+
87+
size_t MultiModalFeatures::BytesPerImage() const {
88+
// Shape can be [B, T, H] or [T, H]. Compute T*H and multiply by element size.
89+
if (shape_.empty()) return 0;
90+
int64_t tokens = 0;
91+
int64_t hidden = 0;
92+
if (shape_.size() == 3) {
93+
tokens = shape_[1];
94+
hidden = shape_[2];
95+
} else if (shape_.size() == 2) {
96+
tokens = shape_[0];
97+
hidden = shape_[1];
98+
} else {
99+
throw std::runtime_error("MultiModalFeatures: unexpected features shape rank");
100+
}
101+
return static_cast<size_t>(tokens) * static_cast<size_t>(hidden) * Ort::SizeOf(type_);
102+
}
103+
80104
} // namespace Generators

src/models/multi_modal_features.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ struct MultiModalFeatures {
1919
void Update(bool is_prompt);
2020
void ReuseFeaturesBuffer(MultiModalFeatures& other);
2121

22+
DeviceSpan<uint8_t> AsByteSpan();
23+
size_t BytesPerImage() const;
24+
2225
auto& GetShape() const { return shape_; }
2326
size_t GetIndex() { return index_; }
2427
OrtValue* Get() { return features_.get(); }

0 commit comments

Comments
 (0)