
feat(auth): Pluggable AuthService with abstract base class #10702

Merged: HimavarshaVS merged 188 commits into main from pluggable-auth-service (Feb 6, 2026)

Conversation

ogabrielluiz (Contributor) commented Nov 24, 2025

Summary

Implements a pluggable AuthService that enables alternative authentication implementations (e.g., OIDC) to be registered via the pluggable services system while maintaining full backward compatibility.

Key Changes

New Abstract Base Class (auth/base.py)

  • AuthServiceBase defines the contract for all authentication implementations
  • 20+ abstract methods covering authentication, token management, and user validation
  • Clear separation between required interface and implementation details
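The shape of this contract can be sketched as follows. The three methods shown are drawn from this PR's own descriptions (`create_token`, `authenticate_user`, `encrypt_api_key`); the `DummyAuthService` stub is purely illustrative and not part of the PR:

```python
from abc import ABC, abstractmethod


class AuthServiceBase(ABC):
    """Illustrative subset of the contract; the real base class has 20+ methods."""

    @abstractmethod
    def create_token(self, data: dict, expires_delta) -> str:
        """Create a signed token from claims."""

    @abstractmethod
    async def authenticate_user(self, username: str, password: str, db):
        """Return the user when credentials are valid, otherwise None."""

    @abstractmethod
    def encrypt_api_key(self, api_key: str) -> str:
        """Encrypt an API key for storage."""


class DummyAuthService(AuthServiceBase):
    """Toy implementation showing what a plugin must provide."""

    def create_token(self, data: dict, expires_delta) -> str:
        return f"token-for-{data.get('sub')}"

    async def authenticate_user(self, username: str, password: str, db):
        return None  # no users in this toy implementation

    def encrypt_api_key(self, api_key: str) -> str:
        return api_key[::-1]  # placeholder, not real encryption
```

Because the base class is abstract, a plugin that forgets a required method fails at instantiation rather than at call time.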

Refactored AuthService (auth/service.py)

  • Now extends AuthServiceBase
  • Contains all authentication logic previously spread across utils.py
  • JWT-based implementation remains the default

Thin Delegation Layer (auth/utils.py)

  • All functions now delegate to get_auth_service()
  • Maintains backward compatibility for existing code
  • Simplified function signatures (removed settings_service parameter)
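The delegation pattern can be sketched like this. `_StubAuthService` and the module-level variable are illustrations only; in Langflow the accessor resolves the registered `AuthService` through the service manager:

```python
class _StubAuthService:
    """Stand-in for the real AuthService (illustration only)."""

    def encrypt_api_key(self, api_key: str) -> str:
        return "enc:" + api_key  # stand-in for Fernet encryption


_SERVICE = _StubAuthService()


def _auth_service() -> _StubAuthService:
    # In the real code this returns the registered AuthService instance.
    return _SERVICE


def encrypt_api_key(api_key: str) -> str:
    # Old import path keeps working; the body is now a one-line forward.
    return _auth_service().encrypt_api_key(api_key)
```

Existing callers of `langflow.services.auth.utils.encrypt_api_key` never notice the indirection.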

Service Integration

  • Added get_auth_service() to deps.py
  • Updated AUTH_SERVICE in ServiceType enum
  • Updated mcp_encryption.py to use new API
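The registration and lookup flow can be sketched as follows. The `AUTH_SERVICE` enum value and the `get_auth_service()` accessor follow the PR description; the dict-backed registry is a simplification of the actual service manager:

```python
from enum import Enum


class ServiceType(str, Enum):
    # AUTH_SERVICE is the enum entry added by this PR; other members omitted.
    AUTH_SERVICE = "auth_service"


_registry: dict = {}


def register_service(service_type: ServiceType, instance) -> None:
    """Associate an implementation with a service type."""
    _registry[service_type] = instance


def get_auth_service():
    # deps.py-style accessor: returns whichever implementation was registered.
    return _registry[ServiceType.AUTH_SERVICE]
```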

How to Use a Custom Auth Implementation

```python
# my_auth/oidc_service.py
class OIDCAuthService(AuthServiceBase):
    async def get_current_user(self, token, query_param, header_param, db):
        # OIDC token validation
        ...
```

```toml
# lfx.toml
[services]
auth_service = "my_auth.oidc_service:OIDCAuthService"
```

Test Coverage

  • 61 unit tests for AuthService methods
  • Token creation/validation
  • User authentication flows
  • Password hashing/verification
  • API key encryption/decryption
  • Pluggable service delegation
```shell
uv run pytest src/backend/tests/unit/services/auth/ -v
```

Migration Notes

  • encrypt_api_key(key) and decrypt_api_key(key) no longer accept settings_service
  • All existing imports from langflow.services.auth.utils continue to work
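The signature change can be illustrated with a minimal stand-in (the body here is a stub, not the real encryption):

```python
def encrypt_api_key(api_key: str) -> str:
    # New signature: the auth service resolves its own settings internally.
    return "enc:" + api_key


# Old call sites that still pass settings_service now raise TypeError:
try:
    encrypt_api_key("key", settings_service=object())  # old style
except TypeError:
    encrypted = encrypt_api_key("key")  # migrated call
```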

Summary by CodeRabbit

Release Notes

  • Bug Fixes

    • Added validation to reject null passwords during password reset and update operations.
    • Improved API key encryption and decryption handling for more consistent behavior.
  • Chores

    • Refactored authentication system architecture for enhanced flexibility and maintainability.
    • Updated webhook authentication flow for improved reliability.


coderabbitai bot commented Nov 24, 2025

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


Walkthrough

This PR refactors the authentication system from utility-function-based to a pluggable service-based architecture. It introduces an abstract base class (AuthServiceBase) defining the auth contract, implements a concrete service (AuthService), and refactors utility functions to delegate to the service. API key encryption/decryption operations are simplified by removing settings_service dependencies.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Auth Service Base Architecture<br>src/backend/base/langflow/services/auth/base.py | New abstract base class defining a comprehensive authentication interface with 25+ methods covering token management, user validation, API key encryption, webhook authorization, and MCP-specific flows. |
| Auth Service Implementation<br>src/backend/base/langflow/services/auth/service.py | Concrete implementation extending AuthServiceBase with the full auth logic (token creation/validation, password hashing, Fernet-based key encryption, API key security, webhook flows, and error handling). |
| Auth Utils Refactoring<br>src/backend/base/langflow/services/auth/utils.py | Refactored to delegate all authentication functions to _auth_service(), removing ~510 lines of inline implementations (JWT handling, Fernet operations, webhook logic) in favor of the service-layer abstraction. |
| Auth Service Factory & Deps<br>src/backend/base/langflow/services/auth/factory.py, src/backend/base/langflow/services/deps.py | Factory updated with type hints; new get_auth_service() accessor added for dependency resolution and service retrieval. |
| API Key Encryption Simplification<br>src/backend/base/langflow/services/auth/mcp_encryption.py, src/backend/base/langflow/api/v1/api_key.py, src/backend/base/langflow/api/v1/store.py, src/backend/base/langflow/services/variable/kubernetes.py, src/backend/base/langflow/services/variable/service.py | Removed the settings_service parameter from all encrypt_api_key() and decrypt_api_key() calls; encryption logic is now self-contained in the auth service. |
| User API Improvements<br>src/backend/base/langflow/api/v1/users.py | Added explicit None guards for password fields in patch_user() and reset_password() to enforce password validation earlier. |
| CLI & Main Entry Point<br>src/backend/base/langflow/__main__.py | Updated the token auth flow to use get_current_user_from_access_token instead of the JWT-based approach; aligned exception types and import paths (check_key now comes from the API key CRUD module). |
| Pluggable Services Infrastructure<br>src/lfx/src/lfx/services/schema.py, src/lfx/src/lfx/services/manager.py, src/lfx/PLUGGABLE_SERVICES.md | Added AUTH_SERVICE to the service type enum; introduced circular dependency detection and thread-safe plugin discovery locking in the service manager. |
| Storage Service Base<br>src/lfx/src/lfx/services/storage/service.py | Changed teardown() from an abstract requirement to an optional override with a no-op default. |
| Unit Tests: Auth Service<br>src/backend/tests/unit/services/auth/test_auth_service.py | New comprehensive test suite covering token creation/validation, password hashing, API key encryption, user authentication flows, MCP variants, and error cases (444 lines). |
| Unit Tests: Pluggable Auth & Encryption<br>src/backend/tests/unit/services/auth/test_pluggable_auth.py, src/backend/tests/unit/services/auth/test_mcp_encryption.py | New pluggable service tests with DummyAuthService; encryption tests refactored to use the real AuthService instead of mocked settings. |
| Integration Tests Updated<br>src/backend/tests/unit/test_cli.py, src/backend/tests/unit/test_setup_superuser.py, src/backend/tests/unit/test_webhook.py, src/lfx/tests/unit/services/test_decorator_registration.py, src/lfx/tests/unit/services/test_edge_cases.py | Updated mock paths to align with the refactored auth service; webhook tests now interact with the real get_auth_service() instead of patching settings; service registration tests updated to include a mock session service. |
| Configuration Baseline<br>.secrets.baseline | Updated line number references and timestamp due to code shifts. |
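The circular-dependency detection and discovery locking mentioned for the service manager can be sketched as follows. The class name and `get()` signature here are illustrative, not the actual lfx API:

```python
import threading


class ServiceManager:
    """Sketch of the two safeguards: a creation-in-progress set and a lock."""

    def __init__(self) -> None:
        self._services: dict = {}
        self._creating: set = set()     # circular-dependency guard
        self._lock = threading.RLock()  # thread-safe plugin discovery

    def get(self, name: str, factory):
        with self._lock:
            if name in self._services:
                return self._services[name]  # cached instance
            if name in self._creating:
                raise RuntimeError(f"circular dependency while creating {name!r}")
            self._creating.add(name)
            try:
                # factory may recursively request other services via self
                self._services[name] = factory(self)
            finally:
                self._creating.discard(name)
            return self._services[name]
```

An RLock (rather than a plain Lock) lets a factory re-enter `get()` for its own dependencies without deadlocking, while a factory that loops back to the service being built trips the `_creating` guard.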

Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Client
    participant FastAPI
    participant AuthUtils as auth_utils
    participant AuthService
    participant DB

    Client->>FastAPI: POST /users/login
    FastAPI->>AuthUtils: authenticate_user(username, password)
    AuthUtils->>AuthService: _auth_service().authenticate_user()
    AuthService->>DB: fetch user by username
    DB-->>AuthService: user record
    AuthService->>AuthService: verify_password()
    AuthService->>DB: update last_login_at
    DB-->>AuthService: ✓
    AuthService->>AuthService: create_user_tokens()
    AuthService->>AuthService: create_token(access claims)
    AuthService->>AuthService: create_token(refresh claims)
    AuthService-->>AuthUtils: {access_token, refresh_token}
    AuthUtils-->>FastAPI: token dict
    FastAPI-->>Client: 200 OK {access_token, refresh_token}
```
```mermaid
sequenceDiagram
    autonumber
    participant Client
    participant FastAPI
    participant AuthUtils as auth_utils
    participant AuthService
    participant DB

    Client->>FastAPI: GET /flow/{id} + Authorization: Bearer TOKEN
    FastAPI->>AuthUtils: get_current_user(token, query_param, header_param, db)
    AuthUtils->>AuthService: _auth_service().get_current_user()
    alt Token exists
        AuthService->>AuthService: get_current_user_from_access_token()
        AuthService->>AuthService: decode & validate JWT
        AuthService->>DB: fetch user by id from token
        DB-->>AuthService: user record
    else API Key exists
        AuthService->>AuthService: _api_key_security_impl()
        AuthService->>AuthService: decrypt_api_key()
        AuthService->>DB: lookup user by api key
        DB-->>AuthService: user record
    end
    AuthService-->>AuthUtils: user (User | UserRead)
    AuthUtils-->>FastAPI: user
    FastAPI->>FastAPI: @depends authorization complete
    FastAPI-->>Client: 200 OK {flow data}
```
```mermaid
sequenceDiagram
    autonumber
    participant User as User Code
    participant DepFactory as Dependency Factory
    participant ServiceManager
    participant AuthServiceFactory
    participant AuthService

    User->>DepFactory: get_auth_service()
    DepFactory->>ServiceManager: get(ServiceType.AUTH_SERVICE)
    alt Service cached
        ServiceManager-->>DepFactory: AuthService instance
    else First time
        ServiceManager->>AuthServiceFactory: create(settings_service)
        AuthServiceFactory->>AuthService: __init__()
        AuthService->>AuthService: initialize encryption keys
        AuthService-->>AuthServiceFactory: instance
        AuthServiceFactory-->>ServiceManager: instance
        ServiceManager->>ServiceManager: cache & register
        ServiceManager-->>DepFactory: instance
    end
    DepFactory-->>User: AuthService(AuthServiceBase)
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • mpawlow
  • HimavarshaVS
  • jordanrfrazier
🚥 Pre-merge checks: ✅ 4 passed | ❌ 3 failed

❌ Failed checks (3 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 54.34%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Test Quality And Coverage | ⚠️ Warning | Tests use @pytest.mark.anyio instead of the project standard @pytest.mark.asyncio, creating inconsistency with 67+ files using asyncio. | Replace all @pytest.mark.anyio decorators with @pytest.mark.asyncio throughout the test files to align with project standards and the pytest-asyncio configuration. |
| Test File Naming And Structure | ⚠️ Warning | Test files use @pytest.mark.anyio instead of @pytest.mark.asyncio, violating project conventions configured with pytest-asyncio. | Replace @pytest.mark.anyio with @pytest.mark.asyncio in test_pluggable_auth.py to align with the project's pytest-asyncio configuration used across 67+ test files. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped; CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title 'feat(auth): Pluggable AuthService with abstract base class' clearly summarizes the main change: introducing a pluggable authentication service with an abstract base class. This is the primary feature delivered by this PR. |
| Test Coverage For New Implementations | ✅ Passed | PR includes comprehensive test coverage with new test_auth_service.py (444 lines, 61+ tests), test_pluggable_auth.py (71 lines), and updates to existing integration tests covering token management, user auth flows, encryption, password hashing, and error handling. |
| Excessive Mock Usage Warning | ✅ Passed | Test files demonstrate balanced mock usage: AsyncMocks target I/O dependencies (DB, external services) while real AuthService instances and business logic remain unmocked, with an average of 1-2 mocks per test. |


Copilot AI left a comment

Pull request overview

This pull request implements a pluggable authentication service architecture that enables alternative authentication implementations (e.g., OIDC) to be registered while maintaining full backward compatibility with existing JWT-based authentication.

Key Changes:

  • Introduced AuthServiceBase abstract class defining the contract for all authentication implementations with 19 abstract methods
  • Refactored AuthService to extend AuthServiceBase and consolidated authentication logic from utils.py
  • Converted auth/utils.py into a thin delegation layer that maintains backward compatibility
  • Added AUTH_SERVICE to the ServiceType enum and integrated with the pluggable services system

Reviewed changes

Copilot reviewed 14 out of 15 changed files in this pull request and generated 13 comments.

Show a summary per file

| File | Description |
| --- | --- |
| src/backend/base/langflow/services/auth/base.py | New abstract base class defining the authentication service contract with 19 abstract methods covering authentication, token management, user validation, and API key operations |
| src/backend/base/langflow/services/auth/service.py | Refactored to extend AuthServiceBase, consolidating all authentication logic (500+ lines) including JWT operations, API key validation, and user management |
| src/backend/base/langflow/services/auth/utils.py | Converted to a delegation layer with functions forwarding to get_auth_service(), maintaining backward compatibility while removing direct implementation code |
| src/backend/base/langflow/services/auth/factory.py | Added a type annotation for the create() method to return AuthService |
| src/backend/base/langflow/services/auth/mcp_encryption.py | Updated to call encryption methods through the auth service instead of passing a settings_service parameter |
| src/backend/base/langflow/services/deps.py | Added the get_auth_service() function to retrieve the authentication service instance |
| src/lfx/src/lfx/services/schema.py | Added the AUTH_SERVICE enum value for pluggable service registration |
| src/lfx/PLUGGABLE_SERVICES.md | Updated documentation to include auth_service in the list of pluggable services |
| src/backend/base/langflow/__main__.py | Updated to use the renamed function get_current_user_from_access_token (was get_current_user_by_jwt) |
| src/backend/tests/unit/test_cli.py | Updated mock patch to use the renamed authentication function |
| src/backend/tests/unit/services/auth/test_pluggable_auth.py | New tests verifying pluggable service delegation for authentication methods |
| src/backend/tests/unit/services/auth/test_mcp_encryption.py | Updated to use AuthService directly instead of mocking the settings service, with proper secret annotations |
| src/backend/tests/unit/services/auth/test_auth_service.py | New comprehensive test suite with 61 unit tests covering token creation, validation, authentication flows, and password operations |
| .secrets.baseline | Updated line numbers to reflect the code reorganization |


Comment on lines +163 to +165

```python
    token: str | Coroutine,
    db: AsyncSession,
) -> User:
```

Copilot AI, Nov 24, 2025:

The signature of get_current_user_from_access_token is inconsistent between the utils layer and the implementation:

  • In utils.py:60, it accepts token: str | Coroutine | None
  • In service.py:163, it only accepts token: str | Coroutine (no None)

While internal callers (like get_current_user and get_current_user_for_websocket) check for None before calling this method, the public API exposed through utils allows None to be passed. If a caller uses the utils function directly with None, it would cause issues.

Either:

  1. Update the service implementation to handle None explicitly (and raise an appropriate error), or
  2. Update the utils signature to match the implementation (remove None), or
  3. Document that the utils function should not be called with None

The first option is recommended for robustness.

Suggested change:

```diff
-    token: str | Coroutine,
+    token: str | Coroutine | None,
     db: AsyncSession,
 ) -> User:
+    if token is None:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Missing authentication token.",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
```
Comment on lines +120 to +121

```python
def get_user_id_from_token(token: str):
    return _auth_service().get_user_id_from_token(token)
```

Copilot AI, Nov 24, 2025:

Missing return type hint. Add a return type annotation:

```python
from uuid import UUID

def get_user_id_from_token(token: str) -> UUID:
    return _auth_service().get_user_id_from_token(token)
```

Note: The UUID import should be added at the top of the file if not already present.
Context:

```python
        data={"sub": str(user_id), "type": "access"},
        expires_delta=access_token_expires,
    )

async def create_refresh_token(refresh_token: str, db: AsyncSession):
```

Copilot AI, Nov 24, 2025:

Missing return type hint. Add a return type annotation:

```python
async def create_refresh_token(refresh_token: str, db: AsyncSession) -> dict:
    return await _auth_service().create_refresh_token(refresh_token, db)
```

Suggested change:

```diff
-async def create_refresh_token(refresh_token: str, db: AsyncSession):
+async def create_refresh_token(refresh_token: str, db: AsyncSession) -> dict:
```
Context:

```python
        "token_type": "bearer",
    }

def encrypt_api_key(api_key: str):
```

Copilot AI, Nov 24, 2025:

Missing return type hint. Add a return type annotation:

```python
def encrypt_api_key(api_key: str) -> str:
    return _auth_service().encrypt_api_key(api_key)
```

Suggested change:

```diff
-def encrypt_api_key(api_key: str):
+def encrypt_api_key(api_key: str) -> str:
```
Context:

```python
        )
        return fernet.decrypt(encrypted_api_key).decode()
    return ""

def decrypt_api_key(encrypted_api_key: str):
```

Copilot AI, Nov 24, 2025:

Missing return type hint. Add a return type annotation:

```python
def decrypt_api_key(encrypted_api_key: str) -> str:
    return _auth_service().decrypt_api_key(encrypted_api_key)
```

Suggested change:

```diff
-def decrypt_api_key(encrypted_api_key: str):
+def decrypt_api_key(encrypted_api_key: str) -> str:
```
Comment on lines +100 to +101

```python
def create_token(data: dict, expires_delta):
    return _auth_service().create_token(data, expires_delta)
```

Copilot AI, Nov 24, 2025:

Missing type hints for function parameters and return type. Add type annotations:

```python
from datetime import timedelta

def create_token(data: dict, expires_delta: timedelta) -> str:
    return _auth_service().create_token(data, expires_delta)
```

Note: The timedelta import should be added at the top of the file.
Comment on lines +116 to +117

```python
def create_user_api_key(user_id):
    return _auth_service().create_user_api_key(user_id)
```

Copilot AI, Nov 24, 2025:

Missing type hints for function parameter and return type. Add type annotations:

```python
from uuid import UUID

def create_user_api_key(user_id: UUID) -> dict:
    return _auth_service().create_user_api_key(user_id)
```

Note: The UUID import should be added at the top of the file if not already present.
Comment on lines +124 to +125

```python
async def create_user_tokens(user_id, db: AsyncSession, *, update_last_login: bool = False):
    return await _auth_service().create_user_tokens(user_id, db, update_last_login=update_last_login)
```

Copilot AI, Nov 24, 2025:

Missing type hints for function parameters and return type. Add type annotations:

```python
from uuid import UUID

async def create_user_tokens(user_id: UUID, db: AsyncSession, *, update_last_login: bool = False) -> dict:
    return await _auth_service().create_user_tokens(user_id, db, update_last_login=update_last_login)
```

Note: The UUID import should be added at the top of the file if not already present.
Context:

```python
    # Update: last_login_at
    if update_last_login:
        await update_user_last_login_at(user_id, db)

async def authenticate_user(username: str, password: str, db: AsyncSession):
```

Copilot AI, Nov 24, 2025:

Missing return type hint. Add a return type annotation:

```python
async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None:
    return await _auth_service().authenticate_user(username, password, db)
```

Suggested change:

```diff
-async def authenticate_user(username: str, password: str, db: AsyncSession):
+async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None:
```
Comment on lines +112 to +113

```python
async def create_user_longterm_token(db: AsyncSession):
    return await _auth_service().create_user_longterm_token(db)
```

Copilot AI, Nov 24, 2025:

Missing return type hint. Add a return type annotation:

```python
from uuid import UUID

async def create_user_longterm_token(db: AsyncSession) -> tuple[UUID, dict]:
    return await _auth_service().create_user_longterm_token(db)
```

Note: The UUID import should be added at the top of the file if not already present.
ogabrielluiz (Contributor, Author) commented:

@copilot open a new pull request to apply changes based on the comments in this thread

Copilot AI commented Nov 24, 2025

@ogabrielluiz I've opened a new pull request, #10710, to work on those changes. Once the pull request is ready, I'll request review from you.

@github-actions github-actions bot added enhancement New feature or request and removed enhancement New feature or request labels Nov 25, 2025

@github-actions github-actions bot added enhancement New feature or request and removed enhancement New feature or request labels Nov 27, 2025

github-actions bot commented Nov 27, 2025

Build successful! ✅
Deploying docs draft.
Deploy successful! View draft

Comment on lines +18 to +49

```python
def convert_value(v):
    if v is None:
        return "Not available"
    if isinstance(v, str):
        v_stripped = v.strip().lower()
        if v_stripped in {"null", "nan", "infinity", "-infinity"}:
            return "Not available"
    if isinstance(v, float):
        try:
            if math.isnan(v):
                return "Not available"
        except Exception as e:  # noqa: BLE001
            logger.aexception(f"Error converting value {v} to float: {e}")

    if hasattr(v, "isnat") and getattr(v, "isnat", False):
        return "Not available"
    return v

not_avail = "Not available"
required_fields_set = set(required_fields) if required_fields else set()
result = []
for d in data:
    if not isinstance(d, dict):
        result.append(d)
        continue
    new_dict = {k: convert_value(v) for k, v in d.items()}
    missing = required_fields_set - new_dict.keys()
    if missing:
        for k in missing:
            new_dict[k] = not_avail
    result.append(new_dict)
```

⚡️Codeflash found 46% (0.46x) speedup for replace_none_and_null_with_empty_str in src/backend/base/langflow/agentic/mcp/support.py

⏱️ Runtime: 1.39 milliseconds → 950 microseconds (best of 178 runs)

📝 Explanation and details

The optimized code achieves a 46% speedup by eliminating function call overhead and improving data access patterns. Here are the key optimizations:

What was optimized:

  1. Inlined the convert_value function - The biggest performance gain comes from removing the nested function calls within the dictionary processing loop. The original code called convert_value() for every dictionary value, creating significant function call overhead.

  2. Eliminated dictionary comprehension - Replaced {k: convert_value(v) for k, v in d.items()} with an explicit loop that builds the dictionary incrementally. This avoids closure overhead and allows better control flow with continue statements.

  3. Precomputed constants - Moved null_strings set creation outside the loop and cached result.append as append_result to avoid repeated method lookups.

  4. Conditional required fields processing - Only processes missing required fields when required_fields_set is non-empty, avoiding unnecessary set operations.

Why it's faster:

  • Function call elimination: The original code made 4,200+ calls to convert_value() in the profiler results. Inlining this removes all that overhead.
  • Better branching: Using explicit if/elif/continue statements allows the CPU to exit early from value processing, rather than always executing the full function.
  • Reduced method lookups: Caching result.append and d.items() eliminates repeated attribute access.

Impact on workloads:
Based on the test results, this optimization is particularly effective for:

  • Large-scale data processing (500+ dictionaries with multiple fields)
  • High None/null value density workloads where many values need replacement
  • Scenarios with many required fields that need backfilling

The optimization maintains identical behavior and error handling while significantly reducing the computational overhead of processing dictionary values, making it especially valuable for data cleaning pipelines or ETL operations.
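Under the description above, the optimized shape is roughly as follows. This is a sketch assuming the signature `replace_none_and_null_with_empty_str(data, required_fields=None)`; the actual Codeflash output may differ in detail:

```python
import math


def replace_none_and_null_with_empty_str(data, required_fields=None):
    """Sketch of the optimized version: convert_value inlined into the loop."""
    not_avail = "Not available"
    null_strings = {"null", "nan", "infinity", "-infinity"}  # precomputed once
    required_fields_set = set(required_fields) if required_fields else set()
    result = []
    append_result = result.append  # cached method lookup
    for d in data:
        if not isinstance(d, dict):
            append_result(d)
            continue
        new_dict = {}
        for k, v in d.items():
            if v is None:
                new_dict[k] = not_avail
                continue
            if isinstance(v, str):
                new_dict[k] = not_avail if v.strip().lower() in null_strings else v
                continue
            if isinstance(v, float) and math.isnan(v):
                new_dict[k] = not_avail
                continue
            if getattr(v, "isnat", False):  # NaT-like objects
                new_dict[k] = not_avail
                continue
            new_dict[k] = v
        if required_fields_set:  # only touch required fields when any are given
            for k in required_fields_set - new_dict.keys():
                new_dict[k] = not_avail
        append_result(new_dict)
    return result
```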

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 43 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import math

# imports
import pytest
from langflow.agentic.mcp.support import replace_none_and_null_with_empty_str

# unit tests

# -------------------- Basic Test Cases --------------------

def test_none_and_null_string_replacement():
    # Test that None and 'null' (case-insensitive) are replaced
    data = [
        {"a": None, "b": "null", "c": "NuLl", "d": "something"},
        {"a": "NULL", "b": "not null", "c": None, "d": 123}
    ]
    expected = [
        {"a": "Not available", "b": "Not available", "c": "Not available", "d": "something"},
        {"a": "Not available", "b": "not null", "c": "Not available", "d": 123}
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)

def test_nan_and_infinity_string_replacement():
    # Test that 'NaN', 'Infinity', '-Infinity' strings are replaced
    data = [
        {"x": "NaN", "y": "Infinity", "z": "-Infinity"},
        {"x": "nan", "y": " infinity ", "z": "  -infinity "}
    ]
    expected = [
        {"x": "Not available", "y": "Not available", "z": "Not available"},
        {"x": "Not available", "y": "Not available", "z": "Not available"}
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)

def test_float_nan_and_regular_numbers():
    # Test that float('nan') is replaced, but regular floats are not
    data = [
        {"x": float('nan'), "y": 3.14, "z": 0.0},
        {"x": -1.23, "y": 42.0}
    ]
    expected = [
        {"x": "Not available", "y": 3.14, "z": 0.0},
        {"x": -1.23, "y": 42.0}
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)

def test_no_replacement_needed():
    # Test that normal values are untouched
    data = [{"a": 1, "b": "hello", "c": False}]
    expected = [{"a": 1, "b": "hello", "c": False}]
    codeflash_output = replace_none_and_null_with_empty_str(data)

# -------------------- Edge Test Cases --------------------

def test_required_fields_missing():
    # Test that missing required fields are added with "Not available"
    data = [{"a": 1}, {"b": 2}]
    required_fields = ["a", "b", "c"]
    expected = [
        {"a": 1, "b": "Not available", "c": "Not available"},
        {"a": "Not available", "b": 2, "c": "Not available"}
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields)

def test_required_fields_present():
    # Test that required fields are not overwritten if present
    data = [{"a": None, "b": 2, "c": 3}]
    required_fields = ["a", "b", "c"]
    expected = [{"a": "Not available", "b": 2, "c": 3}]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields)

def test_empty_data_list():
    # Test empty input list
    codeflash_output = replace_none_and_null_with_empty_str([])

def test_non_dict_elements():
    # Test that non-dict elements are returned as-is
    data = [{"a": None}, 123, "hello", None]
    expected = [{"a": "Not available"}, 123, "hello", None]
    codeflash_output = replace_none_and_null_with_empty_str(data)

def test_dict_with_non_string_keys():
    # Test dicts with non-string keys
    data = [{1: None, 2.5: "null", (1,2): "no"}]
    expected = [{1: "Not available", 2.5: "Not available", (1,2): "no"}]
    codeflash_output = replace_none_and_null_with_empty_str(data)

def test_dict_with_nested_dict():
    # Test that nested dicts are not recursively processed
    nested = {"x": None}
    data = [{"a": nested}]
    expected = [{"a": nested}]
    codeflash_output = replace_none_and_null_with_empty_str(data)

def test_custom_object_with_isnat():
    # Test object with isnat attribute True/False
    class Dummy:
        def __init__(self, isnat):
            self.isnat = isnat
    data = [{"x": Dummy(True)}, {"x": Dummy(False)}]
    codeflash_output = replace_none_and_null_with_empty_str(data); result = codeflash_output

def test_required_fields_empty_list():
    # Test required_fields as empty list (should not add any fields)
    data = [{"a": None}]
    expected = [{"a": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields=[])

def test_required_fields_none():
    # Test required_fields as None (should not add any fields)
    data = [{"a": None}]
    expected = [{"a": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields=None)

def test_string_with_spaces():
    # Test that strings with spaces are stripped before checking for null/nan/infinity
    data = [{"a": "  null  ", "b": "  NaN  ", "c": "  Infinity  ", "d": "  -Infinity  "}]
    expected = [{"a": "Not available", "b": "Not available", "c": "Not available", "d": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data)

def test_case_insensitive_matching():
    # Test that matching is case-insensitive
    data = [{"a": "NULL", "b": "nAn", "c": "INFINITY", "d": "-INFINITY"}]
    expected = [{"a": "Not available", "b": "Not available", "c": "Not available", "d": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

# -------------------- Large Scale Test Cases --------------------

def test_large_list_of_dicts():
    # Test with a large list of dicts
    n = 500
    data = [{"a": None if i % 2 == 0 else i, "b": "null" if i % 3 == 0 else i, "c": float('nan') if i % 5 == 0 else i} for i in range(n)]
    codeflash_output = replace_none_and_null_with_empty_str(data); result = codeflash_output
    for i, d in enumerate(result):
        # "a" should be "Not available" if i even, else i
        assert d["a"] == ("Not available" if i % 2 == 0 else i)
        # "b" should be "Not available" if i divisible by 3, else i
        assert d["b"] == ("Not available" if i % 3 == 0 else i)
        # "c" should be "Not available" if i divisible by 5, else i
        assert d["c"] == ("Not available" if i % 5 == 0 else i)

def test_large_required_fields():
    # Test with many required fields
    n = 100
    data = [{} for _ in range(10)]
    required_fields = [f"f{i}" for i in range(n)]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields); result = codeflash_output
    for d in result:
        for f in required_fields:
            assert d[f] == "Not available"

def test_large_dict():
    # Test with a single dict with many keys
    n = 500
    data = [{str(i): None if i % 2 == 0 else "ok" for i in range(n)}]
    codeflash_output = replace_none_and_null_with_empty_str(data); result = codeflash_output
    for i in range(n):
        key = str(i)
        if i % 2 == 0:
            assert result[0][key] == "Not available"
        else:
            assert result[0][key] == "ok"
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import math

# imports
import pytest
from langflow.agentic.mcp.support import replace_none_and_null_with_empty_str

# unit tests

# 1. BASIC TEST CASES

def test_basic_none_and_null_replacement():
    # Test replacing None and 'null' in various cases
    data = [
        {"a": None, "b": "null", "c": "NULL", "d": "Null", "e": "NuLl"},
        {"a": 1, "b": "hello", "c": 0, "d": False, "e": ""},
    ]
    expected = [
        {"a": "Not available", "b": "Not available", "c": "Not available", "d": "Not available", "e": "Not available"},
        {"a": 1, "b": "hello", "c": 0, "d": False, "e": ""},
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_basic_nan_and_infinity_strings():
    # Test replacing 'NaN', 'Infinity', '-Infinity' strings
    data = [
        {"a": "NaN", "b": "Infinity", "c": "-Infinity", "d": "nan", "e": " infinity "},
        {"a": "foo", "b": "bar", "c": "baz", "d": "qux", "e": "quux"},
    ]
    expected = [
        {"a": "Not available", "b": "Not available", "c": "Not available", "d": "Not available", "e": "Not available"},
        {"a": "foo", "b": "bar", "c": "baz", "d": "qux", "e": "quux"},
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_basic_float_nan():
    # Test replacing float('nan')
    data = [
        {"a": float('nan'), "b": 1.0, "c": 0.0, "d": -1.0},
    ]
    expected = [
        {"a": "Not available", "b": 1.0, "c": 0.0, "d": -1.0},
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_basic_no_replacements():
    # Test when nothing should be replaced
    data = [
        {"a": 1, "b": "test", "c": 2.5},
    ]
    expected = [
        {"a": 1, "b": "test", "c": 2.5},
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_basic_required_fields():
    # Test required_fields adds missing keys with "Not available"
    data = [
        {"a": 1, "b": 2},
        {"b": 3},
    ]
    required_fields = ["a", "b", "c"]
    expected = [
        {"a": 1, "b": 2, "c": "Not available"},
        {"b": 3, "a": "Not available", "c": "Not available"},
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields)
    assert codeflash_output == expected

# 2. EDGE TEST CASES

def test_edge_empty_list():
    # Test empty input list
    data = []
    expected = []
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_empty_dict():
    # Test list with empty dict
    data = [{}]
    expected = [{}]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_non_dict_input():
    # Test list containing non-dict elements
    data = [{"a": None}, 42, "string", None, [1,2,3]]
    expected = [{"a": "Not available"}, 42, "string", None, [1,2,3]]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_strip_spaces():
    # Test that strings with leading/trailing whitespace are handled
    data = [
        {"a": " null ", "b": "   NaN", "c": "Infinity   ", "d": " -Infinity "},
    ]
    expected = [
        {"a": "Not available", "b": "Not available", "c": "Not available", "d": "Not available"},
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_nested_dicts():
    # Test that nested dicts are not recursively processed
    data = [
        {"a": {"b": None, "c": "null"}, "d": None}
    ]
    expected = [
        {"a": {"b": None, "c": "null"}, "d": "Not available"}
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_isnat_object():
    # Test object with isnat attribute set to True
    class FakeNat:
        isnat = True
    data = [{"a": FakeNat()}]
    expected = [{"a": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_isnat_false():
    # Test object with isnat attribute set to False
    class FakeNotNat:
        isnat = False
    obj = FakeNotNat()
    data = [{"a": obj}]
    codeflash_output = replace_none_and_null_with_empty_str(data); result = codeflash_output
    assert result[0]["a"] is obj

def test_edge_required_fields_none():
    # Test required_fields is None
    data = [{"a": None}]
    expected = [{"a": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data, None)
    assert codeflash_output == expected

def test_edge_required_fields_empty():
    # Test required_fields is empty list
    data = [{"a": None}]
    expected = [{"a": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data, [])
    assert codeflash_output == expected

def test_edge_missing_all_required_fields():
    # Test all required fields missing
    data = [{}]
    required_fields = ["a", "b"]
    expected = [{"a": "Not available", "b": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields)
    assert codeflash_output == expected

def test_edge_required_field_already_not_available():
    # Test required_fields where key is present but value is None
    data = [{"a": None}]
    required_fields = ["a"]
    expected = [{"a": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields)
    assert codeflash_output == expected

def test_edge_nan_string_with_spaces():
    # Test ' nan ' with spaces is replaced
    data = [{"a": "  nan  "}]
    expected = [{"a": "Not available"}]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_zero_and_false():
    # Test that 0 and False are not replaced
    data = [{"a": 0, "b": False}]
    expected = [{"a": 0, "b": False}]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_multiple_types():
    # Test a mix of types
    data = [{"a": None, "b": "null", "c": float('nan'), "d": 0, "e": "", "f": False}]
    expected = [{"a": "Not available", "b": "Not available", "c": "Not available", "d": 0, "e": "", "f": False}]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_edge_dict_with_extra_keys():
    # Test dict with extra keys not in required_fields
    data = [{"a": 1, "b": 2, "c": 3}]
    required_fields = ["a", "b"]
    expected = [{"a": 1, "b": 2, "c": 3}]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields)
    assert codeflash_output == expected

# 3. LARGE SCALE TEST CASES

def test_large_scale_many_dicts():
    # Test with a large list of dicts
    data = [{"a": None if i % 2 == 0 else i, "b": "null" if i % 3 == 0 else str(i)} for i in range(500)]
    expected = [
        {
            "a": "Not available" if i % 2 == 0 else i,
            "b": "Not available" if i % 3 == 0 else str(i)
        }
        for i in range(500)
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_large_scale_many_fields():
    # Test with dicts with many fields
    fields = [f"f{i}" for i in range(100)]
    data = [
        {field: None if i % 2 == 0 else "null" if i % 3 == 0 else i for i, field in enumerate(fields)}
        for _ in range(10)
    ]
    expected = [
        {field: "Not available" if i % 2 == 0 or i % 3 == 0 else i for i, field in enumerate(fields)}
        for _ in range(10)
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected

def test_large_scale_required_fields():
    # Test with required_fields on large data
    data = [{"a": None}, {"b": 2}]
    required_fields = [f"f{i}" for i in range(50)] + ["a", "b"]
    expected = [
        dict({"a": "Not available"}, **{k: "Not available" for k in required_fields if k != "a"}),
        dict({"b": 2, "a": "Not available"}, **{k: "Not available" for k in required_fields if k not in ("a", "b")}),
    ]
    codeflash_output = replace_none_and_null_with_empty_str(data, required_fields); result = codeflash_output
    # Check all required fields present and correct values
    for res, exp in zip(result, expected):
        for k in required_fields:
            assert res[k] == exp[k]

def test_large_scale_non_dicts():
    # Test with a large list of non-dict elements
    data = [i for i in range(200)] + ["null", None, float('nan')]
    codeflash_output = replace_none_and_null_with_empty_str(data); result = codeflash_output
    # Non-dict elements pass through untouched; NaN compares unequal to
    # itself, so the trailing element is checked with math.isnan
    assert result[:200] == list(range(200))
    assert result[200] == "null" and result[201] is None
    assert math.isnan(result[202])

def test_large_scale_mixed():
    # Mix of dicts and non-dicts
    data = [{"a": None}] * 100 + [None] * 100
    expected = [{"a": "Not available"}] * 100 + [None] * 100
    codeflash_output = replace_none_and_null_with_empty_str(data)
    assert codeflash_output == expected
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run `git merge codeflash/optimize-pr10702-2025-11-27T21.16.33`.

Suggested change
def convert_value(v):
    if v is None:
        return "Not available"
    if isinstance(v, str):
        v_stripped = v.strip().lower()
        if v_stripped in {"null", "nan", "infinity", "-infinity"}:
            return "Not available"
    if isinstance(v, float):
        try:
            if math.isnan(v):
                return "Not available"
        except Exception as e:  # noqa: BLE001
            logger.aexception(f"Error converting value {v} to float: {e}")
    if hasattr(v, "isnat") and getattr(v, "isnat", False):
        return "Not available"
    return v

not_avail = "Not available"
required_fields_set = set(required_fields) if required_fields else set()
result = []
for d in data:
    if not isinstance(d, dict):
        result.append(d)
        continue
    new_dict = {k: convert_value(v) for k, v in d.items()}
    missing = required_fields_set - new_dict.keys()
    if missing:
        for k in missing:
            new_dict[k] = not_avail
    result.append(new_dict)

not_avail = "Not available"
# Precompute for micro-optimization purposes
null_strings = {"null", "nan", "infinity", "-infinity"}
if required_fields:
    required_fields_set = set(required_fields)
else:
    required_fields_set = set()
# Inline "convert_value" to eliminate function calling overhead in the dict comprehension
result = []
append_result = result.append
for d in data:
    if not isinstance(d, dict):
        append_result(d)
        continue
    # Avoid repeated method lookups in loop, and hoist locals
    d_items = d.items()
    new_dict = {}
    for k, v in d_items:
        # Inline of convert_value(v)
        if v is None:
            new_dict[k] = not_avail
            continue
        if isinstance(v, str):
            v_stripped = v.strip().lower()
            if v_stripped in null_strings:
                new_dict[k] = not_avail
                continue
        elif isinstance(v, float):
            # math.isnan is cheap, don't try-except unless necessary for rare types
            # (logger only called on non-float, non-string values)
            try:
                if math.isnan(v):
                    new_dict[k] = not_avail
                    continue
            except Exception as e:  # noqa: BLE001
                logger.aexception(f"Error converting value {v} to float: {e}")
        # is_nat detection, done last for performance
        elif hasattr(v, "isnat"):
            try:
                if v.isnat:
                    new_dict[k] = not_avail
                    continue
            except Exception:
                pass  # Ignore attribute access errors, fallback
        new_dict[k] = v
    # Only check for missing if required_fields are used
    if required_fields_set:
        missing = required_fields_set - new_dict.keys()
        if missing:
            for k in missing:
                new_dict[k] = not_avail
    append_result(new_dict)

Comment on lines +166 to +173
# Check for type changes
if self._has_type_change(call) and phase != MigrationPhase.CONTRACT:
    violations.append(
        Violation("DIRECT_TYPE_CHANGE", "Type changes should use expand-contract pattern", call.lineno)
    )

# Check for nullable changes
if self._changes_nullable_to_false(call) and phase != MigrationPhase.CONTRACT:

⚡️Codeflash found 26% (0.26x) speedup for MigrationValidator._check_alter_column in src/backend/base/langflow/alembic/migration_validator.py

⏱️ Runtime: 475 microseconds → 377 microseconds (best of 133 runs)

📝 Explanation and details

The optimization achieves a 25% speedup by eliminating redundant keyword iterations and reducing method call overhead in the _check_alter_column method.

Key optimizations applied:

  1. Single-pass keyword processing: Instead of calling _has_type_change() and _changes_nullable_to_false() separately (which each iterate over call.keywords), the optimized version combines both checks into a single loop. This reduces keyword iteration from 2 passes to 1 pass.

  2. Early termination: Added break conditions when both has_type_change and nullable_to_false are found, avoiding unnecessary iterations through remaining keywords.

  3. Reduced method call overhead: Eliminated two method calls per invocation by inlining the logic, removing function call overhead and attribute lookups.

  4. Precomputed phase comparison: The is_not_contract variable avoids repeating the phase != MigrationPhase.CONTRACT comparison twice.
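
The single-pass scan with early termination can be sketched as follows (a minimal illustration of the idea; `scan_keywords` is a hypothetical name, not the actual langflow method):

```python
import ast

# Scan call.keywords once, tracking both flags, and stop as soon as
# both checks are resolved instead of iterating twice.
def scan_keywords(call: ast.Call) -> tuple[bool, bool]:
    has_type_change = False
    nullable_to_false = False
    for kw in call.keywords:
        if kw.arg in ("type_", "type"):
            has_type_change = True
        elif (
            kw.arg == "nullable"
            and isinstance(kw.value, ast.Constant)
            and kw.value.value is False
        ):
            nullable_to_false = True
        if has_type_change and nullable_to_false:
            break  # early termination: no need to scan remaining keywords
    return has_type_change, nullable_to_false
```

Both flags come out of one iteration over the keyword list, which is where the original spent nearly 90% of its time across two helper calls.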

Performance impact analysis:

  • The line profiler shows the original code spent 63.1% of time in _changes_nullable_to_false() and 26.5% in _has_type_change() - nearly 90% of execution time was in these helper methods
  • The optimized version eliminates this overhead by processing keywords once in the main method
  • For calls with many keywords (like the 500-keyword test cases), this optimization becomes increasingly beneficial as it scales linearly instead of quadratically

Test case benefits:

  • Small keyword lists: Modest gains from reduced method call overhead
  • Large keyword lists (500+ keywords): Significant speedup from single-pass processing and early termination
  • Mixed workloads: Consistent 25% improvement across various keyword patterns

The optimization maintains identical behavior and API compatibility while substantially improving performance for AST processing workloads, which typically involve many small method calls over structured data.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 169 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import ast

# imports
import pytest
from langflow.alembic.migration_validator import MigrationValidator

# --- Helper Classes and Enums for Testing ---

class MigrationPhase:
    """Enum-like class for migration phases."""
    EXPAND = "EXPAND"
    CONTRACT = "CONTRACT"
    OTHER = "OTHER"

class Violation:
    """Simple Violation class for test purposes."""
    def __init__(self, code, message, lineno):
        self.code = code
        self.message = message
        self.lineno = lineno

    def __eq__(self, other):
        return (
            isinstance(other, Violation)
            and self.code == other.code
            and self.message == other.message
            and self.lineno == other.lineno
        )

    def __repr__(self):
        return f"Violation({self.code!r}, {self.message!r}, {self.lineno!r})"

# --- Unit Tests ---

# Helper to create ast.Call objects for tests
def make_call(keywords, lineno=1):
    return ast.Call(
        func=ast.Name(id="alter_column", ctx=ast.Load()),
        args=[],
        keywords=keywords,
        lineno=lineno
    )

def make_keyword(arg, value):
    return ast.keyword(arg=arg, value=value)

# 1. Basic Test Cases

def test_no_type_or_nullable_change_returns_no_violations():
    # No relevant keywords
    call = make_call([])
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_type_change_in_expand_phase_reports_violation():
    # type_ keyword present, not CONTRACT phase
    call = make_call([make_keyword("type_", ast.Constant(value="Integer"))], lineno=10)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1

def test_type_change_in_contract_phase_no_violation():
    # type_ keyword present, CONTRACT phase
    call = make_call([make_keyword("type_", ast.Constant(value="String"))], lineno=20)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []

def test_nullable_change_to_false_in_expand_phase_reports_violation():
    # nullable=False present, not CONTRACT phase
    call = make_call([make_keyword("nullable", ast.Constant(value=False))], lineno=30)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1

def test_nullable_change_to_false_in_contract_phase_no_violation():
    # nullable=False present, CONTRACT phase
    call = make_call([make_keyword("nullable", ast.Constant(value=False))], lineno=40)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []

def test_nullable_change_to_true_no_violation():
    # nullable=True present, any phase
    call = make_call([make_keyword("nullable", ast.Constant(value=True))], lineno=50)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_type_and_nullable_change_both_violations():
    # Both type_ and nullable=False present, not CONTRACT phase
    call = make_call([
        make_keyword("type_", ast.Constant(value="String")),
        make_keyword("nullable", ast.Constant(value=False))
    ], lineno=60)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 2

# 2. Edge Test Cases

def test_nullable_keyword_not_constant_no_violation():
    # nullable keyword present but value is not ast.Constant
    call = make_call([make_keyword("nullable", ast.Name(id="SOME_VAR", ctx=ast.Load()))], lineno=70)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_type_keyword_with_unusual_name_no_violation():
    # type keyword with different name should not trigger violation
    call = make_call([make_keyword("typex", ast.Constant(value="String"))], lineno=80)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_multiple_irrelevant_keywords_no_violation():
    # Only irrelevant keywords
    call = make_call([
        make_keyword("default", ast.Constant(value=42)),
        make_keyword("server_default", ast.Constant(value=None))
    ], lineno=90)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_nullable_false_and_type_change_in_other_phase():
    # Both violations in a phase that's neither EXPAND nor CONTRACT
    call = make_call([
        make_keyword("type_", ast.Constant(value="Float")),
        make_keyword("nullable", ast.Constant(value=False))
    ], lineno=100)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.OTHER); violations = codeflash_output
    assert len(violations) == 2

def test_no_keywords_returns_no_violations():
    # No keywords at all
    call = make_call([], lineno=110)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []

def test_nullable_false_and_type_in_contract_phase_no_violation():
    # Both changes in CONTRACT phase
    call = make_call([
        make_keyword("type_", ast.Constant(value="Float")),
        make_keyword("nullable", ast.Constant(value=False))
    ], lineno=120)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []

def test_nullable_false_with_multiple_keywords():
    # nullable=False among other irrelevant keywords
    call = make_call([
        make_keyword("default", ast.Constant(value="abc")),
        make_keyword("nullable", ast.Constant(value=False)),
        make_keyword("comment", ast.Constant(value="desc"))
    ], lineno=130)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1

# 3. Large Scale Test Cases

def test_many_irrelevant_keywords_large_scale():
    # Large number of irrelevant keywords
    keywords = [make_keyword(f"irrelevant_{i}", ast.Constant(value=i)) for i in range(500)]
    call = make_call(keywords, lineno=140)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_many_type_changes_large_scale():
    # Many type_ keywords, should trigger only one violation
    keywords = [make_keyword("type_", ast.Constant(value=f"Type{i}")) for i in range(500)]
    call = make_call(keywords, lineno=150)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1

def test_many_nullable_false_large_scale():
    # Many nullable=False keywords, should trigger only one violation
    keywords = [make_keyword("nullable", ast.Constant(value=False)) for _ in range(500)]
    call = make_call(keywords, lineno=160)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1

def test_many_type_and_nullable_false_large_scale():
    # Many type_ and nullable=False keywords, should trigger both violations
    keywords = (
        [make_keyword("type_", ast.Constant(value=f"Type{i}")) for i in range(250)] +
        [make_keyword("nullable", ast.Constant(value=False)) for i in range(250)]
    )
    call = make_call(keywords, lineno=170)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 2

def test_large_scale_contract_phase_no_violation():
    # Large number of type_ and nullable=False keywords, CONTRACT phase
    keywords = (
        [make_keyword("type_", ast.Constant(value=f"Type{i}")) for i in range(500)] +
        [make_keyword("nullable", ast.Constant(value=False)) for i in range(500)]
    )
    call = make_call(keywords, lineno=180)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import ast

# imports
import pytest
from langflow.alembic.migration_validator import MigrationValidator


# Define MigrationPhase and Violation for test environment
class MigrationPhase:
    EXPAND = "EXPAND"
    CONTRACT = "CONTRACT"
    CLEANUP = "CLEANUP"

# Helper function to create ast.Call objects for testing
def make_call(keywords, lineno=1):
    # keywords: list of (arg, value) tuples
    ast_keywords = []
    for arg, value in keywords:
        if isinstance(value, bool) or isinstance(value, int) or isinstance(value, str):
            ast_value = ast.Constant(value=value)
        else:
            ast_value = value  # For more complex AST nodes if needed
        ast_keywords.append(ast.keyword(arg=arg, value=ast_value))
    return ast.Call(func=ast.Name(id="alter_column", ctx=ast.Load()), args=[], keywords=ast_keywords, lineno=lineno)

# unit tests

# --- Basic Test Cases ---

def test_no_violation_with_no_type_or_nullable_change():
    # No type or nullable change, no violation expected
    call = make_call([("name", "foo")], lineno=10)
    validator = MigrationValidator()
    for phase in [MigrationPhase.EXPAND, MigrationPhase.CONTRACT, MigrationPhase.CLEANUP]:
        codeflash_output = validator._check_alter_column(call, phase); violations = codeflash_output
        assert violations == []

def test_type_change_in_expand_phase():
    # Type change in EXPAND phase should trigger DIRECT_TYPE_CHANGE violation
    call = make_call([("type_", "Integer")], lineno=5)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1

def test_type_change_in_contract_phase():
    # Type change in CONTRACT phase should not trigger violation
    call = make_call([("type_", "String")], lineno=6)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []

def test_nullable_false_in_expand_phase():
    # Making nullable=False in EXPAND should trigger BREAKING_ADD_COLUMN violation
    call = make_call([("nullable", False)], lineno=7)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1

def test_nullable_false_in_contract_phase():
    # Making nullable=False in CONTRACT phase should not trigger violation
    call = make_call([("nullable", False)], lineno=8)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []


def test_nullable_true_in_expand_phase():
    # Setting nullable=True should not trigger violation
    call = make_call([("nullable", True)], lineno=11)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_type_change_with_type_keyword():
    # Use "type" instead of "type_"
    call = make_call([("type", "Boolean")], lineno=12)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1

def test_nullable_false_with_non_constant_value():
    # nullable set to non-ast.Constant value (should not be flagged)
    call = ast.Call(
        func=ast.Name(id="alter_column", ctx=ast.Load()),
        args=[],
        keywords=[ast.keyword(arg="nullable", value=ast.Name(id="False", ctx=ast.Load()))],
        lineno=13
    )
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_multiple_keywords_with_irrelevant_args():
    # Irrelevant keywords should not trigger violations
    call = make_call([("server_default", "foo"), ("comment", "bar")], lineno=14)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []

def test_empty_keywords():
    # No keywords at all
    call = make_call([], lineno=15)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []


def test_type_and_nullable_false_in_contract_phase():
    # Both type change and nullable=False in CONTRACT phase should not trigger any violation
    call = make_call([("type_", "Float"), ("nullable", False)], lineno=17)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []

def test_nullable_false_and_type_change_on_different_lines():
    # Test that line numbers are correct for violations
    call = make_call([("type_", "Float"), ("nullable", False)], lineno=123)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 2
    for v in violations:
        assert v.lineno == 123

# --- Large Scale Test Cases ---

def test_many_calls_with_varied_keywords():
    # Test scalability with many calls, each with varied keywords
    validator = MigrationValidator()
    violations_total = 0
    for i in range(100):  # 100 calls, varied
        if i % 3 == 0:
            call = make_call([("type_", "String")], lineno=i)
            phase = MigrationPhase.EXPAND
        elif i % 3 == 1:
            call = make_call([("nullable", False)], lineno=i)
            phase = MigrationPhase.EXPAND
        else:
            call = make_call([("type_", "String"), ("nullable", False)], lineno=i)
            phase = MigrationPhase.CONTRACT
        codeflash_output = validator._check_alter_column(call, phase); violations = codeflash_output
        # Only the first two cases should yield violations
        if i % 3 == 0:
            assert len(violations) == 1
            violations_total += len(violations)
        elif i % 3 == 1:
            assert len(violations) == 1
            violations_total += len(violations)
        else:
            assert violations == []
    # 34 calls hit the first branch and 33 the second
    assert violations_total == 67

def test_large_keywords_list_with_irrelevant_args():
    # Large keyword list, none relevant
    keywords = [(f"irrelevant_{i}", i) for i in range(500)]
    call = make_call(keywords, lineno=200)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert violations == []


def test_large_scale_contract_phase_no_violation():
    # Large keyword list, with type_ and nullable=False, but CONTRACT phase
    keywords = [(f"irrelevant_{i}", i) for i in range(498)]
    keywords += [("type_", "Decimal"), ("nullable", False)]
    call = make_call(keywords, lineno=400)
    validator = MigrationValidator()
    codeflash_output = validator._check_alter_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run `git merge codeflash/optimize-pr10702-2025-11-27T22.15.22`.

Suggested change
# Check for type changes
if self._has_type_change(call) and phase != MigrationPhase.CONTRACT:
violations.append(
Violation("DIRECT_TYPE_CHANGE", "Type changes should use expand-contract pattern", call.lineno)
)
# Check for nullable changes
if self._changes_nullable_to_false(call) and phase != MigrationPhase.CONTRACT:
# Inline the _has_type_change and _changes_nullable_to_false checks for performance
# Extract keywords once for reuse
keywords = call.keywords
# Fast path: Precompute for phase != CONTRACT (avoid attribute fetch unless necessary)
is_not_contract = phase != MigrationPhase.CONTRACT
has_type_change = False
nullable_to_false = False
# Avoid repeated loop over call.keywords by doing both checks at once
for keyword in keywords:
arg = keyword.arg
if not has_type_change and arg in ("type_", "type"):
has_type_change = True
# Early out if both found
if nullable_to_false:
break
if not nullable_to_false and arg == "nullable" and isinstance(keyword.value, ast.Constant):
if keyword.value.value is False:
nullable_to_false = True
# Early out if both found
if has_type_change:
break
if has_type_change and is_not_contract:
violations.append(
Violation("DIRECT_TYPE_CHANGE", "Type changes should use expand-contract pattern", call.lineno)
)
if nullable_to_false and is_not_contract:
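The single-pass, early-exit keyword scan from the suggestion above can be sketched in isolation. `scan_keywords` and the sample keywords below are illustrative stand-ins built with `ast.keyword`, not the project's actual `MigrationValidator` internals:

```python
import ast

def scan_keywords(keywords):
    """Scan call keywords once, flagging type changes and nullable=False together."""
    has_type_change = False
    nullable_to_false = False
    for kw in keywords:
        if not has_type_change and kw.arg in ("type_", "type"):
            has_type_change = True
            if nullable_to_false:
                break  # both conditions found; stop scanning
        if (
            not nullable_to_false
            and kw.arg == "nullable"
            and isinstance(kw.value, ast.Constant)
            and kw.value.value is False
        ):
            nullable_to_false = True
            if has_type_change:
                break  # both conditions found; stop scanning
    return has_type_change, nullable_to_false

kws = [
    ast.keyword(arg="type_", value=ast.Constant("Decimal")),
    ast.keyword(arg="nullable", value=ast.Constant(False)),
]
print(scan_keywords(kws))  # (True, True)
```

Compared with two separate helper calls, one pass over `call.keywords` halves the iteration work in the common case where both conditions must be checked.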

Comment on lines +184 to +196
    violations = []

    if phase != MigrationPhase.CONTRACT:
        violations.append(
            Violation(
                "IMMEDIATE_DROP",
                f"Column drops only allowed in CONTRACT phase (current: {phase.value})",
                call.lineno,
            )
        )

    return violations
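For context, a check like the one quoted above is typically driven by walking a parsed migration module. The sketch below is a hedged, self-contained stand-in: `MigrationPhase` and `find_drop_column_violations` are minimal illustrative names, not the project's actual classes:

```python
import ast
from enum import Enum

class MigrationPhase(Enum):
    EXPAND = "expand"
    CONTRACT = "contract"

def find_drop_column_violations(source: str, phase: MigrationPhase):
    """Flag op.drop_column(...) calls when the migration phase is not CONTRACT."""
    if phase is MigrationPhase.CONTRACT:
        return []  # drops are allowed in the contract phase
    return [
        ("IMMEDIATE_DROP", node.lineno)
        for node in ast.walk(ast.parse(source))
        if isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and node.func.attr == "drop_column"
    ]

src = "op.add_column('users', col)\nop.drop_column('users', 'legacy_flag')\n"
print(find_drop_column_violations(src, MigrationPhase.EXPAND))    # [('IMMEDIATE_DROP', 2)]
print(find_drop_column_violations(src, MigrationPhase.CONTRACT))  # []
```

This mirrors the expand-contract discipline the validator enforces: destructive operations are deferred until the contract phase, after all readers of the old schema are gone.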


⚡️Codeflash found 37% (0.37x) speedup for MigrationValidator._check_drop_column in src/backend/base/langflow/alembic/migration_validator.py

⏱️ Runtime: 4.17 milliseconds → 3.04 milliseconds (best of 103 runs)

📝 Explanation and details

The optimized code achieves a 37% speedup through two key micro-optimizations that reduce redundant operations:

Key Optimizations:

  1. Cached attribute access: phase_value = phase.value eliminates repeated enum attribute lookups. The line profiler shows this single access takes 44% of total time but is only done once, versus the original code accessing phase.value multiple times during violation creation.

  2. Early return pattern: Instead of always creating an empty violations = [] list and conditionally appending to it, the optimized version only creates the list when violations actually occur. This avoids unnecessary list creation and mutation in the success case.

Performance Impact Analysis:

The line profiler reveals the optimization is most effective when violations are generated (non-CONTRACT phases). In the optimized version:

  • Fewer function calls are executed when violations occur (3,517 vs 6,029 hits)
  • The string formatting with phase_value is slightly more efficient than phase.value
  • Memory allocation is reduced by avoiding the empty list creation in success cases

Test Case Performance:

  • CONTRACT phase tests (no violations): Benefit from avoiding empty list creation
  • Non-CONTRACT phase tests: Benefit from both cached attribute access and direct list return
  • Large-scale tests (1000+ calls): The cumulative effect of these micro-optimizations becomes significant

This optimization is particularly valuable in validation workflows where the function may be called frequently during migration analysis, as even small per-call improvements compound significantly at scale.
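The two patterns discussed above can be compared directly with `timeit`. This is a minimal sketch with stand-in names (`Phase`, `check_original`, `check_optimized`); absolute timings will vary by machine and Python version:

```python
import timeit
from enum import Enum

class Phase(Enum):
    EXPAND = "EXPAND"
    CONTRACT = "CONTRACT"

def check_original(phase):
    # Always allocates a list; fetches phase.value inside the f-string
    violations = []
    if phase != Phase.CONTRACT:
        violations.append(("IMMEDIATE_DROP", f"not allowed in {phase.value}"))
    return violations

def check_optimized(phase):
    # Single attribute fetch; no list allocated on the success path
    phase_value = phase.value
    if phase_value != "CONTRACT":
        return [("IMMEDIATE_DROP", f"not allowed in {phase_value}")]
    return []

for fn in (check_original, check_optimized):
    t = timeit.timeit(lambda: fn(Phase.EXPAND), number=200_000)
    print(f"{fn.__name__}: {t:.3f}s")
```

Both functions return identical results for every phase, which is the property the generated regression tests verify before the optimization is accepted.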

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 6057 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import ast
from enum import Enum

# imports
import pytest
from langflow.alembic.migration_validator import MigrationValidator


# Supporting classes and enums for the test environment
class MigrationPhase(Enum):
    EXPAND = "EXPAND"
    CONTRACT = "CONTRACT"
    MIGRATION = "MIGRATION"
    UNKNOWN = "UNKNOWN"

class Violation:
    def __init__(self, code, message, lineno):
        self.code = code
        self.message = message
        self.lineno = lineno

    def __eq__(self, other):
        return (
            isinstance(other, Violation)
            and self.code == other.code
            and self.message == other.message
            and self.lineno == other.lineno
        )

    def __repr__(self):
        return f"Violation({self.code!r}, {self.message!r}, {self.lineno!r})"

# unit tests

# Helper to create a dummy ast.Call node with a given line number
def make_ast_call(lineno=1):
    # ast.Call(func, args, keywords, lineno=...)
    # func can be ast.Name(id='drop_column')
    return ast.Call(
        func=ast.Name(id='drop_column', ctx=ast.Load()),
        args=[],
        keywords=[],
        lineno=lineno
    )

# ------------------------------
# Basic Test Cases
# ------------------------------

def test_drop_column_contract_phase_no_violation():
    """Should NOT raise violation when dropping column in CONTRACT phase."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=10)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.CONTRACT); violations = codeflash_output
    assert violations == []

def test_drop_column_expand_phase_violation():
    """Should raise violation when dropping column in EXPAND phase."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=20)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    assert len(violations) == 1
    v = violations[0]
    assert v.code == "IMMEDIATE_DROP"

def test_drop_column_migration_phase_violation():
    """Should raise violation when dropping column in MIGRATION phase."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=30)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.MIGRATION); violations = codeflash_output
    v = violations[0]

def test_drop_column_unknown_phase_violation():
    """Should raise violation when dropping column in UNKNOWN phase."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=40)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.UNKNOWN); violations = codeflash_output
    v = violations[0]

# ------------------------------
# Edge Test Cases
# ------------------------------

def test_drop_column_contract_phase_edge_linenos():
    """Test edge line numbers in CONTRACT phase (should not raise violation)."""
    validator = MigrationValidator()
    # Test minimum line number
    codeflash_output = validator._check_drop_column(make_ast_call(lineno=1), MigrationPhase.CONTRACT); violations = codeflash_output

    # Test high line number
    codeflash_output = validator._check_drop_column(make_ast_call(lineno=999), MigrationPhase.CONTRACT); violations = codeflash_output

def test_drop_column_non_contract_edge_linenos():
    """Test edge line numbers in non-CONTRACT phases (should raise violation)."""
    validator = MigrationValidator()
    # Minimum line number
    codeflash_output = validator._check_drop_column(make_ast_call(lineno=1), MigrationPhase.EXPAND); violations = codeflash_output

    # High line number
    codeflash_output = validator._check_drop_column(make_ast_call(lineno=999), MigrationPhase.MIGRATION); violations = codeflash_output

def test_drop_column_phase_enum_variations():
    """Test with unexpected phase values (simulate misuse of MigrationPhase)."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=55)
    # Simulate an object that has a 'value' attribute but is not a MigrationPhase
    class FakePhase:
        value = "FAKE"
    codeflash_output = validator._check_drop_column(call, FakePhase()); violations = codeflash_output

def test_drop_column_phase_none():
    """Test with None as phase (should raise AttributeError)."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=60)
    with pytest.raises(AttributeError):
        validator._check_drop_column(call, None)

def test_drop_column_call_missing_lineno():
    """Test with ast.Call missing lineno attribute (should raise AttributeError)."""
    validator = MigrationValidator()
    call = ast.Call(
        func=ast.Name(id='drop_column', ctx=ast.Load()),
        args=[],
        keywords=[]
        # No lineno provided
    )
    # Only fails if phase != CONTRACT
    with pytest.raises(AttributeError):
        validator._check_drop_column(call, MigrationPhase.EXPAND)

def test_drop_column_call_non_int_lineno():
    """Test with ast.Call lineno as non-integer (should still work if attribute present)."""
    validator = MigrationValidator()
    call = make_ast_call(lineno="not_an_int")
    codeflash_output = validator._check_drop_column(call, MigrationPhase.EXPAND); violations = codeflash_output

# ------------------------------
# Large Scale Test Cases
# ------------------------------

def test_drop_column_many_calls_contract_phase():
    """Test scalability: many calls in CONTRACT phase (should not raise violations)."""
    validator = MigrationValidator()
    for lineno in range(1, 1001):  # 1000 calls
        call = make_ast_call(lineno=lineno)
        codeflash_output = validator._check_drop_column(call, MigrationPhase.CONTRACT); violations = codeflash_output

def test_drop_column_many_calls_non_contract_phase():
    """Test scalability: many calls in non-CONTRACT phase (should raise violations)."""
    validator = MigrationValidator()
    for lineno in range(1, 1001):  # 1000 calls
        call = make_ast_call(lineno=lineno)
        codeflash_output = validator._check_drop_column(call, MigrationPhase.EXPAND); violations = codeflash_output
        v = violations[0]

def test_drop_column_mixed_phases_large_scale():
    """Test scalability: calls in mixed phases."""
    validator = MigrationValidator()
    for lineno in range(1, 1001):
        call = make_ast_call(lineno=lineno)
        phase = MigrationPhase.CONTRACT if lineno % 2 == 0 else MigrationPhase.EXPAND
        codeflash_output = validator._check_drop_column(call, phase); violations = codeflash_output
        if phase == MigrationPhase.CONTRACT:
            assert violations == []
        else:
            assert len(violations) == 1

# ------------------------------
# Mutation Testing Guards
# ------------------------------

def test_mutation_guard_contract_vs_non_contract():
    """Mutation guard: ensure only CONTRACT phase is allowed, all others raise violation."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=123)
    # CONTRACT phase: no violation
    codeflash_output = validator._check_drop_column(call, MigrationPhase.CONTRACT)
    assert codeflash_output == []
    # All other phases: violation
    for phase in [MigrationPhase.EXPAND, MigrationPhase.MIGRATION, MigrationPhase.UNKNOWN]:
        codeflash_output = validator._check_drop_column(call, phase); violations = codeflash_output
        assert len(violations) == 1

def test_mutation_guard_violation_message_contains_phase():
    """Mutation guard: violation message should include the actual phase value."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=321)
    for phase in [MigrationPhase.EXPAND, MigrationPhase.MIGRATION, MigrationPhase.UNKNOWN]:
        codeflash_output = validator._check_drop_column(call, phase); violations = codeflash_output
        assert phase.value in violations[0].message
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import ast
from enum import Enum

# imports
import pytest  # used for our unit tests
from langflow.alembic.migration_validator import MigrationValidator

# --- Supporting Classes, Enums, and Types (minimal viable implementation for tests) ---

class MigrationPhase(Enum):
    EXPAND = "EXPAND"
    CONTRACT = "CONTRACT"
    MERGE = "MERGE"

class Violation:
    def __init__(self, code, message, lineno):
        self.code = code
        self.message = message
        self.lineno = lineno

    def __eq__(self, other):
        return (
            isinstance(other, Violation)
            and self.code == other.code
            and self.message == other.message
            and self.lineno == other.lineno
        )

    def __repr__(self):
        return f"Violation({self.code!r}, {self.message!r}, {self.lineno!r})"

# --- Unit Tests ---

# Helper to create a dummy ast.Call node
def make_ast_call(lineno=1):
    # ast.Call signature: ast.Call(func, args, keywords, lineno, col_offset, ...)
    # We'll use minimal viable fields for our tests
    node = ast.Call(
        func=ast.Name(id="drop_column", ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    node.lineno = lineno
    node.col_offset = 0
    return node

# ------------------ BASIC TEST CASES ------------------

def test_drop_column_contract_phase_no_violation():
    """Should NOT raise violation when dropping column in CONTRACT phase."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=10)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.CONTRACT); violations = codeflash_output

def test_drop_column_expand_phase_violation():
    """Should raise IMMEDIATE_DROP violation when dropping column in EXPAND phase."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=20)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.EXPAND); violations = codeflash_output
    v = violations[0]

def test_drop_column_merge_phase_violation():
    """Should raise IMMEDIATE_DROP violation when dropping column in MERGE phase."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=30)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.MERGE); violations = codeflash_output
    v = violations[0]

# ------------------ EDGE TEST CASES ------------------

def test_drop_column_with_unusual_lineno():
    """Should correctly report violation with edge-case lineno values."""
    validator = MigrationValidator()
    # Edge case: lineno = 0 (typically invalid, but should be handled)
    call = make_ast_call(lineno=0)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.EXPAND); violations = codeflash_output

    # Edge case: very high lineno
    call = make_ast_call(lineno=999)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.MERGE); violations = codeflash_output

def test_drop_column_with_missing_lineno_attribute():
    """Should raise AttributeError if ast.Call node lacks lineno."""
    validator = MigrationValidator()
    call = ast.Call(
        func=ast.Name(id="drop_column", ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    # Deliberately do NOT set 'lineno'
    with pytest.raises(AttributeError):
        validator._check_drop_column(call, MigrationPhase.EXPAND)

def test_drop_column_with_non_ast_call_object():
    """Should raise AttributeError if input is not an ast.Call with lineno."""
    validator = MigrationValidator()
    class DummyCall:
        pass
    call = DummyCall()
    with pytest.raises(AttributeError):
        validator._check_drop_column(call, MigrationPhase.EXPAND)

def test_drop_column_with_invalid_phase_type():
    """Should raise AttributeError if phase lacks .value."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=42)
    # Use a string instead of MigrationPhase
    with pytest.raises(AttributeError):
        validator._check_drop_column(call, "CONTRACT")

def test_drop_column_contract_phase_case_insensitive():
    """Should not allow string phases even if value matches."""
    validator = MigrationValidator()
    call = make_ast_call(lineno=55)
    # Use a string 'CONTRACT' instead of MigrationPhase.CONTRACT
    with pytest.raises(AttributeError):
        validator._check_drop_column(call, "CONTRACT")

def test_drop_column_contract_phase_with_strict_mode_false():
    """Should behave identically regardless of strict_mode setting."""
    validator = MigrationValidator(strict_mode=False)
    call = make_ast_call(lineno=60)
    codeflash_output = validator._check_drop_column(call, MigrationPhase.CONTRACT); violations = codeflash_output

# ------------------ LARGE SCALE TEST CASES ------------------

def test_many_drop_column_calls_contract_phase():
    """Should handle large number of drop_column calls in CONTRACT phase efficiently."""
    validator = MigrationValidator()
    calls = [make_ast_call(lineno=i) for i in range(1, 501)]  # 500 calls
    # All should return empty violations
    for call in calls:
        codeflash_output = validator._check_drop_column(call, MigrationPhase.CONTRACT); violations = codeflash_output

def test_many_drop_column_calls_expand_phase():
    """Should handle large number of drop_column calls in EXPAND phase efficiently."""
    validator = MigrationValidator()
    calls = [make_ast_call(lineno=i) for i in range(1, 501)]  # 500 calls
    for i, call in enumerate(calls, 1):
        codeflash_output = validator._check_drop_column(call, MigrationPhase.EXPAND); violations = codeflash_output
        v = violations[0]

def test_many_drop_column_calls_mixed_phases():
    """Should correctly handle a mix of phases in bulk."""
    validator = MigrationValidator()
    calls = [make_ast_call(lineno=i) for i in range(1, 1001)]  # 1000 calls
    # Alternate phases
    for i, call in enumerate(calls, 1):
        phase = MigrationPhase.CONTRACT if i % 2 == 0 else MigrationPhase.EXPAND
        codeflash_output = validator._check_drop_column(call, phase); violations = codeflash_output
        if phase == MigrationPhase.CONTRACT:
            assert violations == []
        else:
            assert len(violations) == 1

def test_drop_column_performance_large_scale():
    """Performance: Should not be slow for 1000 calls."""
    import time
    validator = MigrationValidator()
    calls = [make_ast_call(lineno=i) for i in range(1, 1001)]
    start = time.time()
    for call in calls:
        validator._check_drop_column(call, MigrationPhase.EXPAND)
    elapsed = time.time() - start
    assert elapsed < 2.0, f"1000 calls took {elapsed:.2f}s"
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run: `git merge codeflash/optimize-pr10702-2025-11-27T22.23.16`

Suggested change

Before:

    violations = []
    if phase != MigrationPhase.CONTRACT:
        violations.append(
            Violation(
                "IMMEDIATE_DROP",
                f"Column drops only allowed in CONTRACT phase (current: {phase.value})",
                call.lineno,
            )
        )
    return violations

After:

    # Avoid repeated attribute access for phase.value
    phase_value = phase.value
    if phase_value != "CONTRACT":
        # Only generate violations if necessary
        return [
            Violation(
                "IMMEDIATE_DROP",
                f"Column drops only allowed in CONTRACT phase (current: {phase_value})",
                call.lineno,
            )
        ]
    return []


Labels

community Pull Request from an external contributor enhancement New feature or request lgtm This PR has been approved by a maintainer


9 participants