Skip to content

Commit ed7d837

Browse files
committed
Refactor model option handling and validation logic for improved performance and clarity
1 parent 05a4472 commit ed7d837

File tree

17 files changed

+439
-244
lines changed

17 files changed

+439
-244
lines changed

src/agent/runloop/model_picker/dynamic_models/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ impl DynamicModelRegistry {
8686
registry
8787
}
8888

89-
pub(super) fn indexes_for(&self, provider: Provider) -> Vec<usize> {
89+
pub(super) fn indexes_for(&self, provider: Provider) -> &[usize] {
9090
self.provider_models
9191
.get(&provider)
92-
.cloned()
93-
.unwrap_or_default()
92+
.map(Vec::as_slice)
93+
.unwrap_or(&[])
9494
}
9595

9696
pub(super) fn detail(&self, index: usize) -> Option<&SelectionDetail> {

src/agent/runloop/model_picker/interaction.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use vtcode_core::config::types::ReasoningEffortLevel;
66
use vtcode_tui::ui::interactive_list::{SelectionEntry, run_interactive_selection};
77

88
use super::dynamic_models::DynamicModelRegistry;
9-
use super::options::{ModelOption, picker_provider_order};
9+
use super::options::{ModelOption, option_indexes_for_provider, picker_provider_order};
1010
use super::rendering::{
1111
CUSTOM_PROVIDER_SUBTITLE, CUSTOM_PROVIDER_TITLE, KEEP_CURRENT_DESCRIPTION,
1212
dynamic_model_subtitle, static_model_subtitle,
@@ -49,11 +49,10 @@ pub(super) fn select_model_with_ratatui_list(
4949

5050
let mut choices = Vec::new();
5151
for provider in picker_provider_order() {
52-
let provider_models: Vec<&ModelOption> = options
53-
.iter()
54-
.filter(|option| option.provider == provider)
55-
.collect();
56-
for option in &provider_models {
52+
for option_index in option_indexes_for_provider(provider) {
53+
let Some(option) = options.get(*option_index) else {
54+
continue;
55+
};
5756
let description = format!(
5857
"{} • {}",
5958
provider.label(),
@@ -87,7 +86,7 @@ pub(super) fn select_model_with_ratatui_list(
8786
}
8887
} else {
8988
for entry_index in dynamic_indexes {
90-
if let Some(detail) = dynamic_models.detail(entry_index) {
89+
if let Some(detail) = dynamic_models.detail(*entry_index) {
9190
let description = format!(
9291
"{} • {}",
9392
provider.label(),

src/agent/runloop/model_picker/mod.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use interaction::{
1414
ModelSelectionListOutcome, select_model_with_ratatui_list, select_reasoning_with_ratatui,
1515
select_service_tier_with_ratatui,
1616
};
17-
use options::{MODEL_OPTIONS, ModelOption};
17+
use options::{MODEL_OPTIONS, ModelOption, find_option_index};
1818
use rendering::{
1919
CLOSE_THEME_MESSAGE, prompt_api_key_plain, prompt_custom_model_entry, prompt_reasoning_plain,
2020
prompt_service_tier_plain, render_reasoning_inline, render_service_tier_inline,
@@ -425,20 +425,17 @@ impl ModelPickerState {
425425
return None;
426426
}
427427

428-
if let Some((index, _)) = self.options.iter().enumerate().find(|(_, option)| {
429-
option.provider.to_string() == provider_key && option.id.eq_ignore_ascii_case(model_key)
430-
}) {
431-
return Some(InlineListSelection::Model(index));
432-
}
433-
434428
let Ok(provider) = Provider::from_str(provider_key.as_str()) else {
435429
return None;
436430
};
431+
if let Some(index) = find_option_index(provider, model_key) {
432+
return Some(InlineListSelection::Model(index));
433+
}
437434
for entry_index in self.dynamic_models.indexes_for(provider) {
438-
if let Some(detail) = self.dynamic_models.detail(entry_index)
435+
if let Some(detail) = self.dynamic_models.detail(*entry_index)
439436
&& detail.model_id.eq_ignore_ascii_case(model_key)
440437
{
441-
return Some(InlineListSelection::DynamicModel(entry_index));
438+
return Some(InlineListSelection::DynamicModel(*entry_index));
442439
}
443440
}
444441

src/agent/runloop/model_picker/options.rs

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use hashbrown::HashMap;
12
use once_cell::sync::Lazy;
23

34
use vtcode_core::config::models::{ModelId, Provider};
@@ -13,24 +14,63 @@ pub(super) struct ModelOption {
1314
pub(super) reasoning_alternative: Option<ModelId>,
1415
}
1516

17+
static EMPTY_OPTION_INDEXES: [usize; 0] = [];
18+
1619
pub(super) static MODEL_OPTIONS: Lazy<Vec<ModelOption>> = Lazy::new(|| {
17-
let mut options = Vec::new();
18-
for provider in Provider::all_providers() {
19-
for model in ModelId::models_for_provider(provider) {
20-
options.push(ModelOption {
21-
model,
22-
provider,
23-
id: model.as_str(),
24-
display: model.display_name(),
25-
description: model.description(),
26-
supports_reasoning: model.supports_reasoning_effort(),
27-
reasoning_alternative: model.non_reasoning_variant(),
28-
});
29-
}
20+
let models = ModelId::all_models();
21+
let mut options = Vec::with_capacity(models.len());
22+
for model in models {
23+
let provider = model.provider();
24+
options.push(ModelOption {
25+
model,
26+
provider,
27+
id: model.as_str(),
28+
display: model.display_name(),
29+
description: model.description(),
30+
supports_reasoning: model.supports_reasoning_effort(),
31+
reasoning_alternative: model.non_reasoning_variant(),
32+
});
3033
}
3134
options
3235
});
3336

37+
static MODEL_OPTION_INDEXES_BY_PROVIDER: Lazy<HashMap<Provider, Box<[usize]>>> = Lazy::new(|| {
38+
let mut indexes: HashMap<Provider, Vec<usize>> = HashMap::new();
39+
for (index, option) in MODEL_OPTIONS.iter().enumerate() {
40+
indexes.entry(option.provider).or_default().push(index);
41+
}
42+
43+
indexes
44+
.into_iter()
45+
.map(|(provider, provider_indexes)| (provider, provider_indexes.into_boxed_slice()))
46+
.collect()
47+
});
48+
49+
static MODEL_OPTION_INDEX_BY_PROVIDER_MODEL: Lazy<HashMap<Provider, HashMap<String, usize>>> =
50+
Lazy::new(|| {
51+
let mut index = HashMap::new();
52+
for (option_index, option) in MODEL_OPTIONS.iter().enumerate() {
53+
index
54+
.entry(option.provider)
55+
.or_insert_with(HashMap::new)
56+
.insert(option.id.to_ascii_lowercase(), option_index);
57+
}
58+
index
59+
});
60+
61+
pub(super) fn option_indexes_for_provider(provider: Provider) -> &'static [usize] {
62+
MODEL_OPTION_INDEXES_BY_PROVIDER
63+
.get(&provider)
64+
.map(Box::as_ref)
65+
.unwrap_or(&EMPTY_OPTION_INDEXES)
66+
}
67+
68+
pub(super) fn find_option_index(provider: Provider, model_id: &str) -> Option<usize> {
69+
MODEL_OPTION_INDEX_BY_PROVIDER_MODEL
70+
.get(&provider)
71+
.and_then(|provider_models| provider_models.get(&model_id.to_ascii_lowercase()).copied())
72+
}
73+
3474
pub(super) fn picker_provider_order() -> Vec<Provider> {
3575
let mut providers: Vec<Provider> = Provider::all_providers()
3676
.into_iter()

src/agent/runloop/model_picker/rendering.rs

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use hashbrown::HashMap;
2-
31
use anyhow::Result;
42

53
use vtcode_core::config::constants::ui;
@@ -10,7 +8,7 @@ use vtcode_core::ui::{InlineListItem, InlineListSearchConfig, InlineListSelectio
108
use vtcode_core::utils::ansi::{AnsiRenderer, MessageStyle};
119

1210
use super::dynamic_models::DynamicModelRegistry;
13-
use super::options::{ModelOption, picker_provider_order};
11+
use super::options::{ModelOption, option_indexes_for_provider, picker_provider_order};
1412

1513
mod prompts;
1614
pub(super) use prompts::{
@@ -236,20 +234,23 @@ pub(super) fn render_step_one_inline(
236234
) -> Result<()> {
237235
let mut items = Vec::new();
238236
for provider in picker_provider_order() {
239-
let provider_models: Vec<(usize, &ModelOption)> = options
240-
.iter()
241-
.enumerate()
242-
.filter(|(_, candidate)| candidate.provider == provider)
243-
.collect();
237+
let provider_model_indexes = option_indexes_for_provider(provider);
244238
let dynamic_indexes = dynamic_models.indexes_for(provider);
245239
let has_error = dynamic_models.error_for(provider).is_some();
246240
let has_warning = dynamic_models.warning_for(provider).is_some();
247241

248-
if provider_models.is_empty() && dynamic_indexes.is_empty() && !has_error && !has_warning {
242+
if provider_model_indexes.is_empty()
243+
&& dynamic_indexes.is_empty()
244+
&& !has_error
245+
&& !has_warning
246+
{
249247
continue;
250248
}
251249

252-
for (idx, option) in &provider_models {
250+
for idx in provider_model_indexes {
251+
let Some(option) = options.get(*idx) else {
252+
continue;
253+
};
253254
items.push(InlineListItem {
254255
title: option.display.to_string(),
255256
subtitle: Some(static_model_subtitle(
@@ -271,7 +272,7 @@ pub(super) fn render_step_one_inline(
271272
}
272273

273274
if provider.is_dynamic() {
274-
for entry_index in &dynamic_indexes {
275+
for entry_index in dynamic_indexes {
275276
if let Some(detail) = dynamic_models.detail(*entry_index) {
276277
let extra_terms = {
277278
let mut terms = Vec::new();
@@ -329,7 +330,7 @@ pub(super) fn render_step_one_inline(
329330
search_value: Some(format!("{} setup", provider.label().to_ascii_lowercase())),
330331
});
331332
}
332-
} else if provider == Provider::HuggingFace && provider_models.is_empty() {
333+
} else if provider == Provider::HuggingFace && provider_model_indexes.is_empty() {
333334
items.push(InlineListItem {
334335
title: "Custom Hugging Face model".to_string(),
335336
subtitle: Some(
@@ -399,28 +400,25 @@ pub(super) fn render_step_one_plain(
399400
"Type 'refresh' to re-query LM Studio and Ollama servers.",
400401
)?;
401402

402-
let mut grouped: HashMap<Provider, Vec<&ModelOption>> = HashMap::new();
403-
for option in options {
404-
grouped.entry(option.provider).or_default().push(option);
405-
}
406-
407403
let mut first_section = true;
408404
for provider in picker_provider_order() {
405+
let provider_model_indexes = option_indexes_for_provider(provider);
409406
if provider.is_local() {
410407
if !first_section {
411408
renderer.line(MessageStyle::Info, &provider_group_divider_line())?;
412409
}
413410
first_section = false;
414411
renderer.line(MessageStyle::Info, &format!("[{}]", provider.label()))?;
415-
if let Some(list) = grouped.get(&provider) {
416-
for option in list {
417-
renderer.line(MessageStyle::Info, &format!(" {}", option.display))?;
418-
renderer.line(
419-
MessageStyle::Info,
420-
&format!(" {}", static_model_subtitle(option, "", "")),
421-
)?;
422-
renderer.line(MessageStyle::Info, &format!(" {}", option.description))?;
423-
}
412+
for option_index in provider_model_indexes {
413+
let Some(option) = options.get(*option_index) else {
414+
continue;
415+
};
416+
renderer.line(MessageStyle::Info, &format!(" {}", option.display))?;
417+
renderer.line(
418+
MessageStyle::Info,
419+
&format!(" {}", static_model_subtitle(option, "", "")),
420+
)?;
421+
renderer.line(MessageStyle::Info, &format!(" {}", option.description))?;
424422
}
425423

426424
if let Some(warning) = dynamic_models.warning_for(provider) {
@@ -443,7 +441,7 @@ pub(super) fn render_step_one_plain(
443441
}
444442
} else {
445443
for entry_index in dynamic_indexes {
446-
if let Some(detail) = dynamic_models.detail(entry_index) {
444+
if let Some(detail) = dynamic_models.detail(*entry_index) {
447445
renderer
448446
.line(MessageStyle::Info, &format!(" {}", detail.model_display))?;
449447
renderer.line(
@@ -476,26 +474,30 @@ pub(super) fn render_step_one_plain(
476474
MessageStyle::Info,
477475
" Docs: https://huggingface.co/docs/inference-providers",
478476
)?;
479-
if let Some(list) = grouped.get(&provider) {
480-
for option in list {
481-
renderer.line(MessageStyle::Info, &format!(" {}", option.display))?;
482-
renderer.line(
483-
MessageStyle::Info,
484-
&format!(" {}", static_model_subtitle(option, "", "")),
485-
)?;
486-
renderer.line(MessageStyle::Info, &format!(" {}", option.description))?;
487-
}
477+
for option_index in provider_model_indexes {
478+
let Some(option) = options.get(*option_index) else {
479+
continue;
480+
};
481+
renderer.line(MessageStyle::Info, &format!(" {}", option.display))?;
482+
renderer.line(
483+
MessageStyle::Info,
484+
&format!(" {}", static_model_subtitle(option, "", "")),
485+
)?;
486+
renderer.line(MessageStyle::Info, &format!(" {}", option.description))?;
488487
}
489488
} else {
490-
let Some(list) = grouped.get(&provider) else {
489+
if provider_model_indexes.is_empty() {
491490
continue;
492-
};
491+
}
493492
if !first_section {
494493
renderer.line(MessageStyle::Info, &provider_group_divider_line())?;
495494
}
496495
first_section = false;
497496
renderer.line(MessageStyle::Info, &format!("[{}]", provider.label()))?;
498-
for option in list {
497+
for option_index in provider_model_indexes {
498+
let Some(option) = options.get(*option_index) else {
499+
continue;
500+
};
499501
renderer.line(MessageStyle::Info, &format!(" {}", option.display))?;
500502
renderer.line(
501503
MessageStyle::Info,

src/agent/runloop/model_picker/selection.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use vtcode_core::config::constants::reasoning;
66
use vtcode_core::config::models::{ModelId, Provider};
77
use vtcode_core::config::types::ReasoningEffortLevel;
88

9-
use super::options::ModelOption;
9+
use super::options::{ModelOption, find_option_index};
1010

1111
#[derive(Clone)]
1212
pub(super) struct SelectionDetail {
@@ -91,11 +91,9 @@ pub(super) fn parse_model_selection(
9191
let provider_lower = provider_token.to_ascii_lowercase();
9292
let provider_enum = Provider::from_str(&provider_lower).ok();
9393

94-
if let Some(option) = options
95-
.iter()
96-
.find(|candidate| candidate.id.eq_ignore_ascii_case(model_token.trim()))
97-
&& let Some(provider) = provider_enum
98-
&& provider == option.provider
94+
if let Some(provider) = provider_enum
95+
&& let Some(option_index) = find_option_index(provider, model_token.trim())
96+
&& let Some(option) = options.get(option_index)
9997
{
10098
return Ok(selection_from_option(option));
10199
}

src/agent/runloop/model_picker/tests.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use tempfile::tempdir;
55
use vtcode_config::OpenAIServiceTier;
66
use vtcode_core::config::models::ModelId;
77

8+
use self::options::{find_option_index, option_indexes_for_provider};
9+
810
fn has_model(options: &[ModelOption], model: ModelId) -> bool {
911
let id = model.as_str();
1012
let provider = model.provider();
@@ -185,6 +187,20 @@ fn preferred_model_selection_matches_current_static_model() {
185187
assert_eq!(option.id, model_id);
186188
}
187189

190+
#[test]
191+
fn static_picker_indexes_resolve_provider_models() {
192+
let openai_indexes = option_indexes_for_provider(Provider::OpenAI);
193+
assert!(!openai_indexes.is_empty());
194+
195+
let gpt54_index = find_option_index(Provider::OpenAI, "GPT-5.4")
196+
.expect("gpt-5.4 should be indexed case-insensitively");
197+
let option = MODEL_OPTIONS
198+
.get(gpt54_index)
199+
.expect("indexed option should exist");
200+
assert_eq!(option.id, "gpt-5.4");
201+
assert_eq!(option.provider, Provider::OpenAI);
202+
}
203+
188204
#[test]
189205
fn preferred_model_selection_returns_none_for_unknown_model() {
190206
let picker = base_picker_state("anthropic", "does-not-exist");

0 commit comments

Comments
 (0)