-
Notifications
You must be signed in to change notification settings - Fork 68
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
135 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
# Reformat with black, isort, docformatter, etc. | ||
e40ac28317c61ea90345d3499986957b0e1c9134 | ||
# Reformat with pyupgrade | ||
795d1b03c6d7a9272018413564a8f02eef6fdec6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
..._bench/mlos_bench/storage/sql/alembic/versions/f83fb8ae7fc4_add_trial_runner_id_column.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# | ||
""" | ||
Add trial_runner_id column. | ||
Revision ID: f83fb8ae7fc4 | ||
Revises: d2a708351ba8 | ||
Create Date: 2025-01-03 21:25:48.848196+00:00 | ||
""" | ||
# pylint: disable=no-member | ||
|
||
from collections.abc import Sequence | ||
|
||
import sqlalchemy as sa | ||
from alembic import op | ||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "f83fb8ae7fc4" | ||
down_revision: str | None = "d2a708351ba8" | ||
branch_labels: str | Sequence[str] | None = None | ||
depends_on: str | Sequence[str] | None = None | ||
|
||
|
||
def upgrade() -> None: | ||
"""The schema upgrade script for this revision.""" | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.add_column("trial", sa.Column("trial_runner_id", sa.Integer(), nullable=True, default=None)) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade() -> None: | ||
"""The schema downgrade script for this revision.""" | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_column("trial", "trial_runner_id") | ||
# ### end Alembic commands ### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
mlos_bench/mlos_bench/tests/storage/test_storage_schemas.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# | ||
"""Test sql schemas for mlos_bench storage.""" | ||
|
||
from alembic.migration import MigrationContext | ||
from sqlalchemy import inspect | ||
|
||
from mlos_bench.storage.base_experiment_data import ExperimentData | ||
from mlos_bench.storage.sql.storage import SqlStorage | ||
|
||
# NOTE: This value is hardcoded to the latest revision in the alembic versions directory. | ||
# It could also be obtained programmatically using the "alembic heads" command or heads() API. | ||
# See Also: schema.py for an example of programmatic alembic config access. | ||
CURRENT_ALEMBIC_HEAD = "f83fb8ae7fc4" | ||
|
||
|
||
def test_storage_schemas(storage: SqlStorage) -> None: | ||
"""Test storage schema creation.""" | ||
eng = storage._engine # pylint: disable=protected-access | ||
with eng.connect() as conn: # pylint: disable=protected-access | ||
inspector = inspect(conn) | ||
# Make sure the "trial_runner_id" column exists. | ||
# (i.e., the latest schema has been applied) | ||
assert any( | ||
column["name"] == "trial_runner_id" for column in inspect(conn).get_columns("trial") | ||
) | ||
# Make sure the "alembic_version" table exists and is appropriately stamped. | ||
assert inspector.has_table("alembic_version") | ||
context = MigrationContext.configure(conn) | ||
current_rev = context.get_current_revision() | ||
assert ( | ||
current_rev == CURRENT_ALEMBIC_HEAD | ||
), f"Expected {CURRENT_ALEMBIC_HEAD}, got {current_rev}" | ||
|
||
|
||
# Note: this is a temporary test. It will be removed and replaced with a more | ||
# properly integrated test in #702. | ||
def test_trial_runner_id_default(storage: SqlStorage, exp_data: ExperimentData) -> None: | ||
"""Test that the new trial_runner_id column defaults to None.""" | ||
assert exp_data.trials | ||
eng = storage._engine # pylint: disable=protected-access | ||
schema = storage._schema # pylint: disable=protected-access | ||
with eng.connect() as conn: | ||
trials = conn.execute( | ||
schema.trial_result.select().with_only_columns( | ||
schema.trial.c.trial_runner_id, | ||
) | ||
) | ||
# trial_runner_id is not currently fully implemented | ||
trial_row = trials.fetchone() | ||
assert trial_row | ||
assert trial_row.trial_runner_id is None |