
AIP-72: Add support to get Variables in task SDK to author tasks #45458

Open
wants to merge 4 commits into base: main

Conversation

@amoghrajesh (Contributor) commented Jan 7, 2025

closes: #45449

Intent

With AIP-72 coming in, and to extend the Task SDK so it can author "complete" DAGs, we need to be able to interact with Airflow Variables: https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/variables.html

Historically, this was done like this:

from airflow.models import Variable

# Normal call style
foo = Variable.get("foo")

# Auto-deserializes a JSON value
bar = Variable.get("bar", deserialize_json=True)

# Returns the value of default_var (None) if the variable is not set
baz = Variable.get("baz", default_var=None)

This worked either at the task level or at the DAG level. Note that airflow.models is used, which is exactly what we are removing for Airflow 3 so that user code doesn't directly interact with DB models, preventing any potential hazard to the Airflow metadata DB. Instead, user-facing interfaces will be exposed for interacting with Airflow entities, so that we can provide a better DAG-writing experience while staying secure and reducing risk.

The aim here is to be able to write DAGs with "from airflow.sdk import Variable".
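For illustration, a minimal sketch of the intended usage. The key-only call is what the test DAGs below exercise; whether deserialize_json (which the underlying _get_variable helper already accepts) carries over unchanged to the public interface is an assumption at this stage:

from airflow.sdk import Variable

# Same call style as before, but through the user-facing SDK interface
my_var = Variable.get(key="my_var")

# Assumed to mirror the old behaviour; the exact public signature may still change
my_conf = Variable.get(key="my_conf", deserialize_json=True)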

Key changes in the PR

  • In the definitions/variable.py user-facing interface, a "get" method has been introduced to fetch variables.
  • This "get" method piggybacks on _get_variable(key), which was introduced in AIP-72: Allow retrieving Variable from Task Context #45431.
  • In _get_variable, we perform a hypothesis check: we try "from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS". If the import succeeds, we are in the execution context of a task; if it fails, the request comes from DAG-level (parse-time) code, so we attempt "from airflow.dag_processing.processor import COMMS_DECODER" instead (see the sketch after this list).
  • There is effectively a supervisor in DAG processing too, since we extend the same machinery introduced in Swap Dag Parsing to use the TaskSDK machinery. #44972. We rely on that machinery and adjust _handle_requests to pass the VariableResult back to the DAG processing process.
  • In the DAG processor's handle_requests, we interact with the DB directly instead of going through the execution API and its client; this is the core package, which can freely interact with the DB. The advantage is that we won't need an execution API server when testing DAG-level things, like getting variables at the DAG level.
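A simplified sketch of that flow, stitched together from the snippets quoted later in this review (error handling and the VariableResult-to-Variable conversion are omitted, and both channels are assumed to expose the same send_request/get_message interface, which is what aliasing them to COMMS implies):

import structlog

from airflow.sdk.execution_time.comms import GetVariable


def _get_variable_sketch(key: str):
    log = structlog.get_logger(logger_name="task")
    try:
        # Hypothesis: we are inside a running task, supervised by the task runner.
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS as COMMS  # type: ignore
    except ImportError:
        # Otherwise this is DAG-level (parse-time) code running under the DAG processor.
        from airflow.dag_processing.processor import COMMS_DECODER as COMMS  # type: ignore

    # Ask the parent process for the variable and read back the result
    # (a VariableResult, or an ErrorResponse if the key does not exist).
    COMMS.send_request(log=log, msg=GetVariable(key=key))
    return COMMS.get_message()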

Testing

Variable.get at dag level

DAG:

from __future__ import annotations

from airflow.models.baseoperator import BaseOperator
from airflow import DAG
from airflow.sdk import Variable

value = Variable.get(key="my_var")

class CustomOperator(BaseOperator):
    def execute(self, context):
        print(f"Variable defined at top level of dag has value: {value}")


with DAG("example_get_variable_using_task_sdk", schedule=None, catchup=False) as dag:
    CustomOperator(task_id="print_top_level_variable")

Variable:
(screenshot)

When the variable is present:
(screenshot)

When the variable isn't present (the scheduler doesn't crash):

[2025-01-08T14:46:40.831+0000] {scheduler_job_runner.py:244} INFO - Exiting gracefully upon receiving signal 2
[2025-01-08 14:46:40 +0000] [449] [INFO] Handling signal: int
2025-01-08 14:46:40 [error    ] Variable: my_var does not exist [supervisor]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/airflow/airflow/dag_processing/processor.py:252 in _handle_request                          │
│                                                                                                  │
│   249 │   │   │   try:                                                                           │
│   250 │   │   │   │   value = Variable.get(key)                                                  │
│   251 │   │   │   except KeyError:                                                               │
│ ❱ 252 │   │   │   │   log.exception("Variable: %s does not exist", key)                          │
│   253 │   │   │   │   raise                                                                      │
│   254 │   │   │   var_result = VariableResult.from_variable_response(VariableResponse(key=key,   │
│   255 │   │   │   resp = var_result.model_dump_json(exclude_unset=True).encode()                 │
│                                                                                                  │
│ ╭─────────────────────────────────────────── locals ───────────────────────────────────────────╮ │
│ │  key = 'my_var'                                                                              │ │
│ │  log = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None,               │ │
│ │        context_class=None, initial_values={'logger_name': 'supervisor'},                     │ │
│ │        logger_factory_args=())>                                                              │ │
│ │  msg = GetVariable(key='my_var', type='GetVariable')                                         │ │
│ │ self = <DagFileProcessorProcess id=UUID('01944661-c7a1-7733-a5e5-54e427b6db66') pid=7181>    │ │
│ ╰──────────────────────────────────────────────────────────────────────────────────────────────╯ │
│                                                                                                  │
│ /opt/airflow/airflow/models/variable.py:144 in get                                               │
│                                                                                                  │
│   141 │   │   │   if default_var is not cls.__NO_DEFAULT_SENTINEL:                               │
│   142 │   │   │   │   return default_var                                                         │
│   143 │   │   │   else:                                                                          │
│ ❱ 144 │   │   │   │   raise KeyError(f"Variable {key} does not exist")                           │
│   145 │   │   else:                                                                              │
│   146 │   │   │   if deserialize_json:                                                           │
│   147 │   │   │   │   obj = json.loads(var_val)                                                  │
│                                                                                                  │
│ ╭─────────────────────── locals ───────────────────────╮                                         │
│ │      default_var = <object object at 0xffff7fe99af0> │                                         │
│ │ deserialize_json = False                             │                                         │
│ │              key = 'my_var'                          │                                         │
│ │          var_val = None                              │                                         │
│ ╰──────────────────────────────────────────────────────╯                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 'Variable my_var does not exist'

Process ForkProcess-12:
Traceback (most recent call last):
  File "/usr/local/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/airflow/airflow/dag_processing/manager.py", line 192, in _run_processor_manager
    processor_manager.run()
  File "/opt/airflow/airflow/dag_processing/manager.py", line 404, in run
    return self._run_parsing_loop()
  File "/opt/airflow/airflow/dag_processing/manager.py", line 492, in _run_parsing_loop
    self._service_processor_sockets(timeout=poll_time)
  File "/opt/airflow/airflow/dag_processing/manager.py", line 550, in _service_processor_sockets
    need_more = socket_handler(key.fileobj)
  File "/opt/airflow/task_sdk/src/airflow/sdk/execution_time/supervisor.py", line 789, in cb
    gen.send(line)
  File "/opt/airflow/airflow/dag_processing/processor.py", line 243, in handle_requests
    def _handle_request(self, msg: ToParent, log: FilteringBoundLogger) -> None:  # type: ignore[override]
  File "/opt/airflow/airflow/dag_processing/processor.py", line 252, in _handle_request
    log.exception("Variable: %s does not exist", key)
  File "/opt/airflow/airflow/models/variable.py", line 144, in get
    raise KeyError(f"Variable {key} does not exist")
KeyError: 'Variable my_var does not exist'
[2025-01-08 14:46:40 +0000] [450] [INFO] Worker exiting (pid: 450)
[2025-01-08 14:46:40 +0000] [451] [INFO] Worker exiting (pid: 451)
[2025-01-08 14:46:41 +0000] [449] [INFO] Shutting down: Master    

Variable.get at task level

DAG:

from __future__ import annotations

from airflow.models.baseoperator import BaseOperator
from airflow import DAG
from airflow.sdk import Variable



class CustomOperator(BaseOperator):
    def execute(self, context):
        value = Variable.get(key="my_var")
        print(f"Variable defined at task level has value: {value}")


with DAG("example_get_variable_using_task_sdk", schedule=None, catchup=False) as dag:
    CustomOperator(task_id="print_top_level_variable")

Variable present:
(screenshot)

Variable not present:
(screenshot)



@amoghrajesh (Contributor, Author) commented:

This allows me to test something like:

from __future__ import annotations

from airflow.models.baseoperator import BaseOperator
from airflow.sdk import dag
from airflow.sdk import Variable


class CustomOperator(BaseOperator):
    def execute(self, context):
        value = Variable.get(key="my_var")
        print(f"The variable value is: {value}")


@dag()
def var_from_defn():
    CustomOperator(task_id="hello")


var_from_defn()

The advantage is that this can now be used at the task level as well as at the DAG parsing level.

The PR is premature; I will add edge cases etc. once we are OK with the general direction.

@amoghrajesh (Contributor, Author) commented:

Some ideas:

  1. _get_variable and _get_connection are really nice utilities, and it would be nice to use them in definitions as well as in context, but they shouldn't be housed where they are at all. They should probably move somewhere slightly higher level that can be used by both the SDK definitions and the execution_time code: something common to both (see the sketch after this list).
  2. Or we just duplicate that code in the Variable class. Simpler, but that could soon become spaghetti code IMO.
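A rough sketch of what option 1 could look like: a shared helper parameterized by the comms channel, so both the SDK definitions and the execution-time code can reuse it without the import dance. The CommsChannel protocol and fetch_variable names are hypothetical; the send_request/get_message shape comes from the existing helper:

from __future__ import annotations

from typing import Any, Protocol

import structlog


class CommsChannel(Protocol):
    # Hypothetical common interface; SUPERVISOR_COMMS and the DAG-processor
    # channel would both need to satisfy it.
    def send_request(self, log: Any, msg: Any) -> None: ...
    def get_message(self) -> Any: ...


def fetch_variable(comms: CommsChannel, key: str) -> Any:
    # Shared helper: send a GetVariable request over whichever channel the caller has.
    from airflow.sdk.execution_time.comms import GetVariable

    log = structlog.get_logger(logger_name="task")
    comms.send_request(log=log, msg=GetVariable(key=key))
    return comms.get_message()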

@ashb (Member) left a comment:

We will need to support accessing Variable at the top level of dag files at parse time too (not sure if the impl needs to change, likely not, but we should test it)

@amoghrajesh (Contributor, Author) commented:

Tested it out, and yes, that won't be possible because we depend on execution time. So the implementation will have to change:

Broken DAG: [/files/dags/get-variable-from-sdk.py]
Traceback (most recent call last):
  File "/opt/airflow/task_sdk/src/airflow/sdk/definitions/variable.py", line 48, in get
    return _get_variable(key).value
  File "/opt/airflow/task_sdk/src/airflow/sdk/execution_time/context.py", line 76, in _get_variable
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
ImportError: cannot import name 'SUPERVISOR_COMMS' from 'airflow.sdk.execution_time.task_runner' (/opt/airflow/task_sdk/src/airflow/sdk/execution_time/task_runner.py)

@amoghrajesh (Contributor, Author) commented Jan 8, 2025

New update:

Why don't we just use the SDK client instead? We don't really need to rely on the supervisor here, since variables can be retrieved at the top level too. We should also be able to add some control at the API level to allow or reject such requests as forbidden.

When we integrate the token mechanism, we can generate one long-running token for such arbitrary requests.
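A rough sketch of the client-based direction, with the client surface reduced to a hypothetical protocol so that nothing here is presented as the real Task SDK API client (its actual method names and return types may differ):

from __future__ import annotations

import json
from typing import Any, Protocol


class VariablesAPI(Protocol):
    # Hypothetical slice of the SDK API client, for illustration only.
    def get(self, key: str) -> str | None: ...


def get_variable_via_client(client: VariablesAPI, key: str, *, deserialize_json: bool = False) -> Any:
    # Go straight to the execution API client; no supervisor is needed, so this
    # also works for top-level (parse-time) DAG code.
    value = client.get(key)
    if value is not None and deserialize_json:
        return json.loads(value)
    return value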

Testing:

  1. Variable.get at dag level
from __future__ import annotations

from airflow.models.baseoperator import BaseOperator
from airflow.sdk import dag
from airflow.sdk import Variable

value = Variable.get(key="my_var")

class CustomOperator(BaseOperator):
    def execute(self, context):
        print(f"The variable from top level dag is: {value}")


@dag()
def var_from_defn():
    CustomOperator(task_id="hello")


var_from_defn()

(screenshot)

  2. Variable.get inside task
from __future__ import annotations

from airflow.models.baseoperator import BaseOperator
from airflow.sdk import dag
from airflow.sdk import Variable


class CustomOperator(BaseOperator):
    def execute(self, context):
        value = Variable.get(key="my_var")
        print(f"The variable from top level dag is: {value}")


@dag()
def var_from_defn():
    CustomOperator(task_id="hello")


var_from_defn()

(screenshot)

@amoghrajesh amoghrajesh force-pushed the AIP72-get-variable-in-sdk branch from 21aa3e9 to f87beb8 on January 8, 2025 14:31
@amoghrajesh amoghrajesh requested a review from ashb January 8, 2025 14:32
@amoghrajesh amoghrajesh marked this pull request as ready for review January 8, 2025 14:59
Comment on lines 80 to 82
except ImportError:
    # If not, hypothesis is false and this request is from dag level.
    from airflow.dag_processing.processor import COMMS_DECODER as COMMS  # type: ignore
@amoghrajesh (Contributor, Author):

What if this one fails too, raise it?
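For example (hypothetical error message; just one way to make the failure explicit rather than leaking a bare ImportError):

try:
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS as COMMS  # type: ignore
except ImportError:
    try:
        from airflow.dag_processing.processor import COMMS_DECODER as COMMS  # type: ignore
    except ImportError:
        # Neither a task supervisor nor a DAG-processor channel is available.
        raise RuntimeError(
            "Variable.get() requires either a task execution context or the DAG processor"
        ) from None
# ... COMMS is then used to talk to the parent process as before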

Comment on lines -237 to -238
# GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable
super()._handle_request(msg, log)
@amoghrajesh (Contributor, Author):

We won't really need this because, for variables and connections, we will have to interact with the DB model directly. If we go to super()._handle_request, it brings the SDK API client into the picture, which shouldn't be needed for DAG-level stuff.

@amoghrajesh (Contributor, Author) commented:

Interesting that I cannot reproduce the failures locally.

Comment on lines 77 to 82
try:
    # We check the hypothesis if the request for variable came from task.
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS as COMMS  # type: ignore
except ImportError:
    # If not, hypothesis is false and this request is from dag level.
    from airflow.dag_processing.processor import COMMS_DECODER as COMMS  # type: ignore
@amoghrajesh (Contributor, Author):

I am not too happy with this one. Wondering if we can do anything better here..

@amoghrajesh (Contributor, Author):

The challenge here is knowing whether the context is DAG or task level. I don't see a clear distinction to point at and use.

Comment on lines -69 to -86
def _get_variable(key: str, deserialize_json: bool) -> Variable:
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    # or `airflow.sdk.execution_time.variable`
    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    # will make that module depend on Task SDK, which is not ideal because we intend to
    # keep Task SDK as a separate package than execution time mods.
    from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    log = structlog.get_logger(logger_name="task")
    SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
    msg = SUPERVISOR_COMMS.get_message()
    if isinstance(msg, ErrorResponse):
        raise AirflowRuntimeError(msg)

    if TYPE_CHECKING:
        assert isinstance(msg, VariableResult)
    return _convert_variable_result_to_variable(msg, deserialize_json)
@amoghrajesh (Contributor, Author):

I think this is the right time to move these helpers to airflow.sdk.execution_time.variable. We might run into a circular import otherwise.

Successfully merging this pull request may close these issues.

AIP-72: Add support to get Variables in task sdk outside of context