Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,18 @@ class ConstraintCritic : public CriticFunction
float getMaxVelConstraint() {return max_vel_;}
float getMinVelConstraint() {return min_vel_;}

float getMaxVelXConstraint() {return vx_max_;}
float getMinVelXConstraint() {return vx_min_;}
float getMaxVelYConstraint() {return vy_max_;}

protected:
unsigned int power_{0};
float weight_{0};
float min_vel_;
float max_vel_;
float vx_max_{0};
float vx_min_{0};
float vy_max_{0};
float min_vel_{0};
float max_vel_{0};
};

} // namespace mppi::critics
Expand Down
24 changes: 13 additions & 11 deletions nav2_mppi_controller/src/critics/constraint_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ void ConstraintCritic::initialize()
getParentParam(vx_min, "vx_min", -0.35f);

const float min_sgn = vx_min > 0.0f ? 1.0f : -1.0f;
vx_max_ = vx_max;
vx_min_ = vx_min;
vy_max_ = vy_max;
max_vel_ = sqrtf(vx_max * vx_max + vy_max * vy_max);
min_vel_ = min_sgn * sqrtf(vx_min * vx_min + vy_max * vy_max);
}
Expand All @@ -58,21 +61,20 @@ void ConstraintCritic::score(CriticData & data)
}

// Omnidirectional motion model
// Axis wise violation check
auto omni = dynamic_cast<OmniMotionModel *>(data.motion_model.get());
if (omni != nullptr) {
auto & vx = data.state.vx;
unsigned int n_rows = data.state.vx.rows();
unsigned int n_cols = data.state.vx.cols();
Eigen::ArrayXXf sgn(n_rows, n_cols);
sgn = vx.unaryExpr([](const float x) {return copysignf(1.0f, x);});

auto vel_total = sgn * (data.state.vx.square() + data.state.vy.square()).sqrt();
if (power_ > 1u) {
data.costs += ((((vel_total - max_vel_).max(0.0f) + (min_vel_ - vel_total).
max(0.0f)) * data.model_dt).rowwise().sum().eval() * weight_).pow(power_).eval();
auto & vy = data.state.vy;
auto vx_violation = (vx - vx_max_).max(0.0f) + (vx_min_ - vx).max(0.0f);
auto vy_violation = (vy.abs() - vy_max_).max(0.0f);
auto violation = ((vx_violation + vy_violation) * data.model_dt).rowwise().sum().eval();
auto weighted_violation = violation * weight_;

if(power_ > 1u) {
data.costs += weighted_violation.pow(power_);
} else {
data.costs += ((((vel_total - max_vel_).max(0.0f) + (min_vel_ - vel_total).
max(0.0f)) * data.model_dt).rowwise().sum().eval() * weight_).eval();
data.costs += weighted_violation;
}
return;
}
Expand Down
62 changes: 62 additions & 0 deletions nav2_mppi_controller/test/critics_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,68 @@ TEST(CriticTests, ConstraintsCritic)
EXPECT_GT(costs.sum(), 0);
// 4.0 weight * 0.1 model_dt * (0.2 - 0.4/2.5) * 30 timesteps = 0.48
EXPECT_NEAR(costs(1), 0.48, 0.01);
costs.setZero();

// Now with Holonomic
node->set_parameter(rclcpp::Parameter("mppi.vy_max", 0.3));
critic = ConstraintCritic();
critic.on_configure(node, "mppi", "critic", costmap_ros, &param_handler);
EXPECT_NEAR(critic.getMaxVelYConstraint(), 0.3, 1e-6);

data.motion_model = std::make_shared<OmniMotionModel>();

// reset state
state.vx.setConstant(0.0f);
state.vy.setConstant(0.0f);
state.wz.setConstant(0.0f);

// vx violation check
state.vx.row(999).setConstant(0.60f);
state.vy.setConstant(0.0f);
critic.score(data);
EXPECT_GT(costs.sum(), 0);
// 4.0 weight * 0.1 model_dt * 0.1 error introduced * 30 timesteps = 1.2
EXPECT_NEAR(costs(999), 1.2, 0.01);
costs.setZero();

// vy violation check
state.vx.setConstant(0.0f);
state.vy.row(999).setConstant(0.50f);
critic.score(data);
EXPECT_GT(costs.sum(), 0);
// 4.0 weight * 0.1 model_dt * 0.2 error introduced * 30 timesteps = 2.4
EXPECT_NEAR(costs(999), 2.4, 0.01);
costs.setZero();

// combined check
state.vx.row(999).setConstant(0.6f);
state.vy.row(999).setConstant(-0.5f);
critic.score(data);
EXPECT_GT(costs.sum(), 0);
// vx-violation 4.0 weight * 0.1 model_dt * 0.1 error introduced * 30 timesteps = 1.2
// vy-violation 4.0 weight * 0.1 model_dt * 0.2 error introduced * 30 timesteps = 2.4
// total-violation = 1.2 + 2.4
EXPECT_NEAR(costs(999), 3.6, 0.01);
costs.setZero();

// power > 1u
node->declare_parameter("mppi.critic.cost_power", 2);
node->set_parameter(rclcpp::Parameter("mppi.critic.cost_power", 2));
critic = ConstraintCritic();
critic.on_configure(node, "mppi", "critic", costmap_ros, &param_handler);

// reset state
state.vx.setConstant(0.0f);
state.vy.setConstant(0.0f);
state.wz.setConstant(0.0f);

// vx violation check (no violation so zero cost)
state.vx.row(999).setConstant(0.20f);
state.vy.setConstant(0.0f);
critic.score(data);
EXPECT_NEAR(costs.sum(), 0.0, 1e-6);
EXPECT_NEAR(costs(999), 0.0, 1e-6);
costs.setZero();
}

TEST(CriticTests, ObstacleCriticMisalignedParams) {
Expand Down
Loading