Dynamic Tasks in Airflow

Overview

With the release of Airflow 2.3, users can write DAGs that dynamically generate parallel tasks at runtime. This feature, known as dynamic task mapping, is a paradigm shift for DAG design in Airflow.

Prior to Airflow 2.3, tasks could only be generated dynamically at the time that the DAG was parsed, meaning you had to change your DAG code if you needed to adjust tasks based on some external factor. With dynamic task mapping, you can easily write DAGs that create tasks based on your current runtime environment.

In this guide, we’ll explain the concept of dynamic task mapping and provide example implementations for common use cases.

Dynamic Task Concepts

Airflow’s dynamic task mapping feature is built on the MapReduce programming model. The map procedure takes a set of inputs and creates a single task for each one. The reduce procedure, which is optional, allows a task to operate on the collected output of a mapped task. In practice, this means your DAG can create an arbitrary number of parallel tasks at runtime based on one or more input parameters (the “map”) and, if needed, have a single task downstream of those parallel mapped tasks that depends on their collective output (the “reduce”).

Airflow tasks have two new functions available to implement the “map” portion of dynamic task mapping. For the task you want to map, all operator parameters must be passed through one of these two functions.

  • expand(): This function passes the parameter or parameters that you want to map on. A separate parallel task will be created for each input.
  • partial(): This function passes any parameters that remain constant across all mapped tasks which are generated by expand().

For example, the following task uses both of these functions to dynamically generate 3 task runs:

@task
def add(x: int, y: int):
    return x + y

added_values = add.partial(y=10).expand(x=[1, 2, 3])

This expand function creates three mapped add tasks, one for each entry in the x input list. The partial function specifies a value for y that remains constant in each task.
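
To implement the optional “reduce” portion, you can pass the output of a mapped task to a single downstream task, which receives the collected results of every mapped instance as a list. Continuing the example above, a minimal sketch (the sum_values task is a hypothetical name) might look like this:

@task
def sum_values(values):
    # values is the collected output of all mapped add tasks, e.g. [11, 12, 13]
    return sum(values)

total = sum_values(added_values)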

There are a couple of things to keep in mind when working with mapped tasks:

  • You can use the results of an upstream task as the input to a mapped task; in fact, this is one of the most powerful aspects of the feature. The upstream task must return its values as a dict or list. If you’re using traditional operators (i.e. not decorated tasks), the mapping values must be available in XCom.
  • You can map over multiple parameters. This results in a cross product, with one task for each combination of parameters (see the sketch after this list).
  • You can use the results of a mapped task as input to a downstream mapped task.
  • You can have a mapped task that results in no task instances, for example if the upstream task that generates the mapping values returns an empty list. In this case, the mapped task is marked skipped, and downstream tasks run according to the trigger rules you set (by default, downstream tasks will also be skipped).
  • Some parameters are not mappable; for example, task_id, pool, and many other BaseOperator arguments must remain constant across all mapped tasks.
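
As an illustration of the first two points, the following minimal sketch maps a decorated task over two parameters to produce a cross product, and maps a traditional BashOperator over the XCom output of an upstream decorated task (the multiply and get_commands tasks here are hypothetical examples):

from airflow.decorators import task
from airflow.operators.bash import BashOperator

@task
def multiply(x: int, y: int):
    return x * y

# Expanding on two parameters creates a cross product: 3 x 2 = 6 mapped instances
products = multiply.expand(x=[1, 2, 3], y=[4, 5])

@task
def get_commands():
    # Hypothetical upstream task; its return value is stored in XCom
    return ["echo 'hello'", "echo 'world'"]

# A traditional (non-decorated) operator can map over the upstream task's XCom output
run_commands = BashOperator.partial(task_id="run_commands").expand(
    bash_command=get_commands()
)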

For more examples of how to apply dynamic task mapping in different use cases, check out the Apache Airflow documentation.

The Airflow UI gives us observability for mapped tasks in both the Graph View and the Grid View.

In the Graph View, any mapped tasks will be indicated by a set of brackets [ ] following the task ID. The number in the brackets will update for each DAG run to reflect how many mapped instances were created.

Mapped Graph

Clicking on the mapped task brings up a new Mapped Instances drop-down, where we can select a specific mapped task run to perform actions on.

Mapped Actions

Selecting one of the mapped instances provides links to the same views you would see for any other Airflow task: Instance Details, Rendered, Log, XCom, and so on.

Mapped Views

Similarly, the Grid View shows task details and history for each mapped task. All mapped tasks will be combined into one row on the grid (shown as load_files_to_snowflake [ ] in the following example). Clicking into that task will provide details on each individual mapped instance.

Mapped Grid

Example Implementations

In this section we’ll show how dynamic task mapping can be implemented for two classic use cases: ELT and ML Ops. The first implementation will use traditional Airflow operators, and the second will use decorated functions and the TaskFlow API.

ELT

For our first example, we’ll implement one of the most common use cases for dynamic tasks: processing files in S3. In this scenario, we will use an ELT framework to extract data from files in S3, load the data into Snowflake, and transform the data using Snowflake’s built-in compute. We assume that files will be dropped daily, but we don’t know how many will arrive each day. We’ll leverage dynamic task mapping to create a unique task for each file at runtime. This gives us the benefit of atomicity, better observability, and easier recovery from failures.

Note: Code for this example can be found in this repo.

The DAG below completes the following steps:

  1. Use a @task-decorated Python function to get the current list of files from S3. The S3 prefix passed to this function is parameterized with ds_nodash so it pulls only the files for the execution date of the DAG run (e.g. for a DAG run on April 12th, we would expect the files to land in a folder named 20220412/).
  2. Using the results of the first task, map an S3ToSnowflakeOperator for each file.
  3. Move the daily folder of processed files into a processed/ folder and, in parallel,
  4. Run a Snowflake query that transforms the data. The query is located in a separate SQL file in our include/ directory.
  5. Delete the original folder of daily files now that its contents have been moved to processed/ for record keeping.

from airflow import DAG
from airflow.decorators import task
from airflow.providers.snowflake.transfers.s3_to_snowflake import S3ToSnowflakeOperator
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator
from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator

from datetime import datetime

@task
def get_s3_files(current_prefix):
    s3_hook = S3Hook(aws_conn_id='s3')
    current_files = s3_hook.list_keys(bucket_name='my-bucket', prefix=current_prefix + "/", start_after_key=current_prefix + "/")
    return [[file] for file in current_files]


with DAG(dag_id='mapping_elt', 
        start_date=datetime(2022, 4, 2),
        catchup=False,
        template_searchpath='/usr/local/airflow/include',
        schedule_interval='@daily') as dag:

    copy_to_snowflake = S3ToSnowflakeOperator.partial(
        task_id='load_files_to_snowflake', 
        stage='MY_STAGE',
        table='COMBINED_HOMES',
        schema='MYSCHEMA',
        file_format="(type = 'CSV',field_delimiter = ',', skip_header=1)",
        snowflake_conn_id='snowflake').expand(s3_keys=get_s3_files(current_prefix="{{ ds_nodash }}"))

    move_s3 = S3CopyObjectOperator(
        task_id='move_files_to_processed',
        aws_conn_id='s3',
        source_bucket_name='my-bucket',
        source_bucket_key="{{ ds_nodash }}/",
        dest_bucket_name='my-bucket',
        dest_bucket_key="processed/{{ ds_nodash }}/"
    )

    delete_landing_files = S3DeleteObjectsOperator(
        task_id='delete_landing_files',
        aws_conn_id='s3',
        bucket='my-bucket',
        prefix="{{ ds_nodash }}/"
    )

    transform_in_snowflake = SnowflakeOperator(
        task_id='run_transformation_query',
        sql='/transformation_query.sql',
        snowflake_conn_id='snowflake'
    )

    copy_to_snowflake >> [move_s3, transform_in_snowflake]
    move_s3 >> delete_landing_files

The Graph View of the DAG looks like this:

ELT Graph

When dynamically mapping tasks, make note of the format needed for the parameter you are mapping on. In the example above, we write our own Python function to get the S3 keys because the S3ToSnowflakeOperator requires each s3_key parameter to be in a list format, while the s3_hook.list_keys function returns a single list with all keys. By writing our own simple function, we turn the hook results into a list of lists that can be used by the downstream operator.
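
As a quick illustration of the shape difference (the file names below are placeholders):

# S3Hook.list_keys returns a flat list of keys
keys = ["20220412/file_1.csv", "20220412/file_2.csv"]

# S3ToSnowflakeOperator expects s3_keys to be a list for each mapped instance,
# so get_s3_files wraps each key in its own list
mapped_input = [[key] for key in keys]
# -> [["20220412/file_1.csv"], ["20220412/file_2.csv"]]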

ML Ops

Dynamic tasks can also be very useful for productionizing machine learning pipelines. ML Ops often includes some sort of dynamic component. The following use cases are common:

  • Training different models: You have a reusable pipeline that you use to create a separate DAG for each data source. For each DAG, you point your generic pipeline to your data source, and to a list of models you want to experiment with in parallel. That list of models might change periodically, but by leveraging dynamic task mapping, you can always have a single task per model without any user intervention when the models change. You also maintain all of the history for any models you have trained in the past, even if they are no longer included in your list.
  • Hyperparameter training a model: You have a single model that you want to hyperparameter tune before publishing results from the best set of parameters. With dynamic task mapping, you can grab your parameters from any external system at runtime, giving you full flexibility and history.
  • Creating a different model for each customer: You have a model that you need to train separately for each individual customer. Your customer list changes frequently, and you need to retain the history of any previous customer models. This can be a tricky use case to implement with dynamic DAGs, because the history of any removed DAGs is not retained in the Airflow UI, and performance issues can arise if the customer list is long. With dynamic tasks, you can maintain a single DAG that updates as needed based on the current list of customers at runtime.

In the example DAG below, we implement the first of these use cases. We also highlight how dynamic task mapping is simple to implement with decorated tasks.

from airflow.decorators import task, dag, task_group
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

from datetime import datetime

import logging
import mlflow

import pandas as pd

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.linear_model import LogisticRegression
import lightgbm as lgb

import include.metrics as metrics
from include.grid_configs import models, params


mlflow.set_tracking_uri('http://host.docker.internal:5000')
try:
   # Create the experiment if it doesn't already exist
   mlflow.create_experiment('census_prediction')
except Exception:
   pass
# Set the active experiment
mlflow.set_experiment('census_prediction')

mlflow.sklearn.autolog()
mlflow.lightgbm.autolog()

@dag(
   start_date=datetime(2022, 1, 1),
   schedule_interval=None,
   catchup=False
)
def mlflow_multimodel_example():

   @task
   def load_data():
       """Pull Census data from Public BigQuery and save as Pandas dataframe in GCS bucket with XCom"""

       bq = BigQueryHook()
       sql = """
       SELECT * FROM `bigquery-public-data.ml_datasets.census_adult_income`
       """

       return bq.get_pandas_df(sql=sql, dialect='standard')


   @task
   def preprocessing(df: pd.DataFrame):
       """Clean Data and prepare for feature engineering
       
       Returns pandas dataframe via Xcom to GCS bucket.

       Keyword arguments:
       df -- Raw data pulled from BigQuery to be processed. 
       """

       df.dropna(inplace=True)
       df.drop_duplicates(inplace=True)

       # Clean Categorical Variables (strings)
       for col in df.columns:
           if df.dtypes[col] == 'object':
               df[col] = df[col].apply(lambda x: x.strip())


       # Replace '?' values with 'Unknown'
       df['workclass'] = df['workclass'].apply(lambda x: 'Unknown' if x == '?' else x)
       df['occupation'] = df['occupation'].apply(lambda x: 'Unknown' if x == '?' else x)
       df['native_country'] = df['native_country'].apply(lambda x: 'Unknown' if x == '?' else x)


       # Drop Extra/Unused Columns
       df.drop(columns=['education_num', 'relationship', 'functional_weight'], inplace=True)

       return df


   @task
   def feature_engineering(df: pd.DataFrame):
       """Feature engineering step
       
       Returns pandas dataframe via XCom to GCS bucket.

       Keyword arguments:
       df -- data from previous step pulled from BigQuery to be processed. 
       """
       # Onehot encoding 
       df = pd.get_dummies(df, prefix='workclass', columns=['workclass'])
       df = pd.get_dummies(df, prefix='education', columns=['education'])
       df = pd.get_dummies(df, prefix='occupation', columns=['occupation'])
       df = pd.get_dummies(df, prefix='race', columns=['race'])
       df = pd.get_dummies(df, prefix='sex', columns=['sex'])
       df = pd.get_dummies(df, prefix='income_bracket', columns=['income_bracket'])
       df = pd.get_dummies(df, prefix='native_country', columns=['native_country'])

       # Bin Ages
       df['age_bins'] = pd.cut(x=df['age'], bins=[16,29,39,49,59,100], labels=[1, 2, 3, 4, 5])

       # Dependent Variable
       df['never_married'] = df['marital_status'].apply(lambda x: 1 if x == 'Never-married' else 0) 

       # Drop redundant column
       df.drop(columns=['income_bracket_<=50K', 'marital_status', 'age'], inplace=True)

       return df

   @task
   def get_models():
       """
       Returns the list of models to train by reading the models dict from the
       include/grid_configs.py file. Each entry contains the model type, the
       model object, and its grid search parameters.
       """
       return [
           {'model_type': name, 'model': config['model'], 'grid_params': config['params']}
           for name, config in models.items()
       ]

   @task()
   def train(df: pd.DataFrame, model_config: dict = None, **kwargs):
       """Train and validate a model using a grid search for the optimal parameter values and five-fold cross validation.

       Returns accuracy score via XCom to GCS bucket.

       Keyword arguments:
       df -- data from the previous step to be used for training.
       model_config -- dict with the model type, model object, and grid search parameters for this mapped instance.
       """
       # Unpack the mapped model configuration for this task instance
       model_type = model_config['model_type']
       model = model_config['model']
       grid_params = model_config['grid_params']

       y = df['never_married']
       X = df.drop(columns=['never_married'])

       X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=55, stratify=y)

       grid_search = GridSearchCV(model, param_grid=grid_params, verbose=1, cv=5, n_jobs=-1)

       with mlflow.start_run(run_name=f'{model_type}_{kwargs["run_id"]}'):

           logging.info('Performing Gridsearch')
           grid_search.fit(X_train, y_train)

           logging.info(f'Best Parameters\n{grid_search.best_params_}')
           best_params = grid_search.best_params_

           if model_type == 'lgbm':

               train_set = lgb.Dataset(X_train, label=y_train)
               test_set = lgb.Dataset(X_test, label=y_test)

               best_params['metric'] = ['auc', 'binary_logloss']

               logging.info('Training model with best parameters')
               clf = lgb.train(
                   train_set=train_set,
                   valid_sets=[train_set, test_set],
                   valid_names=['train', 'validation'],
                   params=best_params,
                   early_stopping_rounds=5
               )

           else:
               logging.info('Training model with best parameters')
               clf = LogisticRegression(penalty=best_params['penalty'], C=best_params['C'], solver=best_params['solver']).fit(X_train, y_train)

           y_pred_class = metrics.test(clf, X_test)

           # Log Classification Report, Confusion Matrix, and ROC Curve
           metrics.log_all_eval_metrics(y_test, y_pred_class)

   df = load_data()
   clean_data = preprocessing(df)
   features = feature_engineering(clean_data)
   train_models = train.partial(df=features).expand(model_config=get_models())
   
dag = mlflow_multimodel_example()

Note that in this example, the model information we map over is pulled from a grid_configs.py file in our include/ directory, which looks like this:

from numpy.random.mtrand import seed
from sklearn.linear_model import LogisticRegression
import lightgbm as lgb


models = {
    'lgbm': {
        'model': lgb.LGBMClassifier(objective='binary', metric=['auc', 'binary_logloss'], seed=55, boosting_type='gbdt'),
        'params': {
            'learning_rate': [0.01, .05, .1], 
            'n_estimators': [50, 100, 150],
            'num_leaves': [31, 40, 80],
            'max_depth': [16, 24, 31, 40]
        }
    },
    'log_reg': {
        'model': LogisticRegression(max_iter=500),
        'params': {
            'penalty': ['l1','l2','elasticnet'],
            'C': [0.001, 0.01, 0.1, 1, 10, 100],
            'solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga']
        }
    }  
}

The model information could come from any external system as well, with an update to the get_models() task. The only requirement is that the resulting map input is a dict or list.
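
For instance, a hypothetical variation of get_models() could read the model list from an Airflow Variable maintained outside the DAG code (the Variable name and its JSON structure here are assumptions for illustration; in this case the train task would instantiate the model object from the configuration it receives):

from airflow.decorators import task
from airflow.models import Variable

@task
def get_models():
    # Hypothetical: a JSON Variable such as
    # [{"model_type": "lgbm", "params": {"learning_rate": [0.01, 0.05]}}, ...]
    return Variable.get("census_model_configs", deserialize_json=True)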

Also note that for decorated tasks like those in the DAG above, the mapped values are passed automatically when one task function is called with the output of another; the TaskFlow API handles XCom behind the scenes, so you never have to reference XCom explicitly.
