[train][TBench] Cherrypick Terminus integration and use Harbor #637

CharlieFRuan merged 1 commit into main.
Conversation
Code Review
This pull request integrates Terminus training with Harbor, replacing the older sandboxes implementation. The introduction of `TerminalBenchTaskDataset` is a key improvement for handling task data, and the changes are generally well-structured. My review focuses on correctness, robustness, and maintainability: I've identified a potential infinite loop, a type mismatch when creating `GeneratorInput`, opportunities to improve logging, and some leftover code that can be cleaned up.
```python
def __getitem__(self, index: int) -> dict:
    """Get a task path by index as a dictionary with 'prompt', 'env_class', and 'env_extras' keys."""
    if index >= len(self.task_paths):
        raise IndexError(f"Index {index} out of range for dataset of size {len(self.task_paths)}")
    return {
        "prompt": str(self.task_paths[index]),
        "env_class": None,
        "env_extras": {"data_source": str(self.task_paths[index])},
        "uid": str(index),
    }

def __len__(self) -> int:
    """Return the number of tasks in the dataset."""
    return len(self.task_paths)

def __iter__(self):
    """Iterate over all task paths as dictionaries."""
    for index, task_path in enumerate(self.task_paths):
        yield {
            "prompt": str(task_path),
            "env_class": None,
            "env_extras": {"data_source": str(task_path)},
            "uid": str(index),
        }
```
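The item contract above (each element is a dict with `prompt`, `env_class`, `env_extras`, and `uid` keys) can be sanity-checked in isolation. This is a sketch only: `TaskDatasetStub` and the task paths below are hypothetical stand-ins, not the actual `TerminalBenchTaskDataset`.

```python
# Hypothetical minimal stand-in for TerminalBenchTaskDataset (the real class
# discovers task directories; here we just take a list of path strings).
class TaskDatasetStub:
    def __init__(self, task_paths):
        self.task_paths = task_paths

    def __len__(self):
        return len(self.task_paths)

    def __getitem__(self, index):
        if index >= len(self.task_paths):
            raise IndexError(f"Index {index} out of range for dataset of size {len(self.task_paths)}")
        return {
            "prompt": str(self.task_paths[index]),
            "env_class": None,
            "env_extras": {"data_source": str(self.task_paths[index])},
            "uid": str(index),
        }

    def __iter__(self):
        # Same dict shape as __getitem__, duplicated here as in the PR.
        for index, task_path in enumerate(self.task_paths):
            yield {
                "prompt": str(task_path),
                "env_class": None,
                "env_extras": {"data_source": str(task_path)},
                "uid": str(index),
            }

ds = TaskDatasetStub(["tasks/hello-world", "tasks/fix-bug"])
# Iteration and indexing should yield identical dicts.
assert list(ds) == [ds[i] for i in range(len(ds))]
```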
There's code duplication between `__getitem__` and `__iter__` for creating the item dictionary. This can be simplified by having `__iter__` delegate to `__getitem__`.

Additionally, `env_class` is set to `None`, which violates the `List[str]` type hint for `env_classes` in `GeneratorInput` and can cause downstream errors. It's better to return an empty string `""` to conform to the type.

Here's a suggested refactoring that addresses both points:
```python
def __getitem__(self, index: int) -> dict:
    """Get a task path by index as a dictionary with 'prompt', 'env_class', and 'env_extras' keys."""
    if index >= len(self.task_paths):
        raise IndexError(f"Index {index} out of range for dataset of size {len(self.task_paths)}")
    return {
        "prompt": str(self.task_paths[index]),
        "env_class": "",
        "env_extras": {"data_source": str(self.task_paths[index])},
        "uid": str(index),
    }

def __len__(self) -> int:
    """Return the number of tasks in the dataset."""
    return len(self.task_paths)

def __iter__(self):
    """Iterate over all task paths as dictionaries."""
    for i in range(len(self)):
        yield self[i]
```

```diff
 input_batch = GeneratorInput(
-    prompts=["" for _ in range(num_prompts)],
+    prompts=[item["prompt"] for item in self.train_dataset],
     env_classes=None,
     env_extras=None,
     sampling_params=None,
 )
```
`GeneratorInput` is not being populated correctly: `env_classes` and `env_extras` are passed as `None`, which violates the `GeneratorInput` type definition and could lead to runtime errors if the generator's implementation starts using these fields. They should be populated from the dataset:
```diff
-input_batch = GeneratorInput(
-    prompts=[item["prompt"] for item in self.train_dataset],
-    env_classes=None,
-    env_extras=None,
-    sampling_params=None,
-)
+dataset_items = list(self.train_dataset)
+input_batch = GeneratorInput(
+    prompts=[item["prompt"] for item in dataset_items],
+    env_classes=[item["env_class"] for item in dataset_items],
+    env_extras=[item["env_extras"] for item in dataset_items],
+    sampling_params=None,
+)
```
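To make the type-mismatch concrete: a batch input whose per-sample fields are typed as parallel lists breaks as soon as one field is `None`. The dataclass below is a hypothetical mirror of `GeneratorInput` for illustration only (the real definition lives in skyrl-train and may differ); the task paths are likewise made up.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Hypothetical mirror of GeneratorInput, used only to illustrate why every
# per-sample field must be a parallel list of the same length.
@dataclass
class GeneratorInputSketch:
    prompts: List[str]
    env_classes: List[str]
    env_extras: List[Dict[str, Any]]
    sampling_params: Optional[dict] = None

    def __post_init__(self):
        # Passing None (as in the original code) fails this basic invariant.
        assert len(self.prompts) == len(self.env_classes) == len(self.env_extras)

dataset_items = [
    {"prompt": "tasks/hello-world", "env_class": "", "env_extras": {"data_source": "tasks/hello-world"}},
    {"prompt": "tasks/fix-bug", "env_class": "", "env_extras": {"data_source": "tasks/fix-bug"}},
]
batch = GeneratorInputSketch(
    prompts=[item["prompt"] for item in dataset_items],
    env_classes=[item["env_class"] for item in dataset_items],
    env_extras=[item["env_extras"] for item in dataset_items],
)
assert len(batch.prompts) == 2
```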
```diff
 while True:
-    results = await trial.run()
-    reward = results.verifier_result.rewards
-    chat_history = results.agent_result.all_messages
-    if len(chat_history) > 0:
-        break
-    else:
-        print(f"[WARNING] Agent {self.agent_name} did not return a response")
+    try:
+        results = await trial.run()
+        print(f"Results: {results}")
+        if not results.verifier_result:
+            print(f"[WARNING] Exception info: {results.exception_info}")
+            continue
+        reward = results.verifier_result.reward
+        chat_history = results.agent_result.all_messages
+        if len(chat_history) > 0:
+            break
+        else:
+            print(f"[WARNING] Agent {self.agent_name} did not return a response")
+    except Exception as e:
+        print(f"Error running trial: {e}")
+        continue
```
This `while True` loop can spin forever if `trial.run()` consistently fails or returns empty results. It's safer to cap the number of retries so the process cannot get stuck. Also, the `print` statements should be replaced with `logger` calls for better logging practices.

Here's a suggestion that adds a retry limit and uses `logger` (you'll need to add `from loguru import logger` at the top of the file):
```diff
-while True:
-    try:
-        results = await trial.run()
-        print(f"Results: {results}")
-        if not results.verifier_result:
-            print(f"[WARNING] Exception info: {results.exception_info}")
-            continue
-        reward = results.verifier_result.reward
-        chat_history = results.agent_result.all_messages
-        if len(chat_history) > 0:
-            break
-        else:
-            print(f"[WARNING] Agent {self.agent_name} did not return a response")
-    except Exception as e:
-        print(f"Error running trial: {e}")
-        continue
+max_retries = 3
+for attempt in range(max_retries):
+    try:
+        results = await trial.run()
+        logger.debug(f"Results: {results}")
+        if not results.verifier_result:
+            logger.warning(f"Exception info: {results.exception_info}")
+            continue
+        reward = results.verifier_result.reward
+        chat_history = results.agent_result.all_messages
+        if len(chat_history) > 0:
+            break
+        else:
+            logger.warning(f"Agent {self.agent_name} did not return a response on attempt {attempt + 1}")
+    except Exception as e:
+        logger.error(f"Error running trial on attempt {attempt + 1}: {e}")
+        continue
+else:  # no-break
+    raise RuntimeError(f"Failed to get a valid response from trial after {max_retries} attempts.")
```
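The `for`/`else` retry pattern used above is easy to get wrong, so here is a minimal standalone sketch of the same idea. Everything in it is hypothetical (`run_with_retries`, the `flaky` stand-in for `trial.run()`); the point is only that the `else` branch fires when every attempt is exhausted without a successful exit.

```python
def run_with_retries(trial_fn, max_retries=3):
    """Retry trial_fn up to max_retries times; raise if no attempt yields a result."""
    for attempt in range(max_retries):
        try:
            result = trial_fn()
            if result:  # got a non-empty response, stop retrying
                return result
        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {e}")
            continue
    else:  # no-break: every attempt failed or returned nothing
        raise RuntimeError(f"Failed after {max_retries} attempts.")

# A hypothetical flaky trial: fails twice, then succeeds.
attempts = {"n": 0}
def flaky():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient failure")
    return "chat_history"

assert run_with_retries(flaky) == "chat_history"
assert attempts["n"] == 3
```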
```python
# If it's a file, treat it as a single task (files can't be valid task directories)
logger.warning(f"File {source_path} cannot be a valid task directory (missing instruction.md)")
```
The comment here is misleading. It states that a file is treated "as a single task", but the code only logs a warning and then skips the file. The comment should be updated to reflect the actual behavior, which is to ignore files.
```diff
-# If it's a file, treat it as a single task (files can't be valid task directories)
-logger.warning(f"File {source_path} cannot be a valid task directory (missing instruction.md)")
+# Files are not valid task directories, so log a warning and skip.
+logger.warning(f"File {source_path} is not a directory and cannot be a valid task, skipping.")
```
```python
def collate_fn(self, item_list):
    """Collate function for batching task dictionaries."""
    return item_list
```
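Since the dataset yields plain dicts rather than tensors, an identity collate like the one above is all that's needed: each batch is simply the list of task items. A sketch of how such a collate behaves, with a hypothetical batching loop standing in for the trainer's dataloader and made-up task items:

```python
def collate_fn(item_list):
    """Identity collate: task items are plain dicts, so no tensor stacking is needed."""
    return item_list

# Hypothetical items shaped like TerminalBenchTaskDataset's output.
items = [
    {"prompt": f"tasks/task-{i}", "env_class": "", "env_extras": {"data_source": f"tasks/task-{i}"}, "uid": str(i)}
    for i in range(4)
]

# Simple batching loop standing in for a DataLoader configured with this collate_fn.
batch_size = 2
batches = [collate_fn(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]
assert len(batches) == 2
assert batches[0][0]["uid"] == "0"
```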
```shell
NUM_GPUS=1
LOGGER="console"  # change to "console" to print to stdout
TBENCH_CONFIG_DIR="examples/terminal_bench"
SANDBOXES_DIR="sandboxes"  # TODO: For now, `sandboxes` is cloned into SkyRL/skyrl-train.
```
This PR cherry-picks the recent changes for Terminus training from the dev branch at https://github.com/NovaSky-AI/SkyRL/tree/dev/sandboxes, at commit 86de228. Additionally, we rebase to accommodate the changes from Harbor (the new version of Sandboxes).