Manage your ML models with Weights and Biases and Airflow

Info

This page has not yet been updated for Airflow 3. The concepts shown are relevant, but some code may need to be updated. If you run any examples, take care to update import statements and watch for any other breaking changes.

Weights and Biases (W&B) is a machine learning platform for model management that includes features like experiment tracking, dataset versioning, and model performance evaluation and visualization. Using W&B with Airflow gives you a powerful ML orchestration stack with first-class features for building, training, and managing your models.

In this tutorial, you’ll learn how to create an Airflow DAG that completes feature engineering, model training, and predictions with the Astro Python SDK and scikit-learn, and registers the model with W&B for evaluation and visualization.

Info

This tutorial was developed in partnership with Weights and Biases. For resources on implementing other use cases with W&B, see Tutorials.

Time to complete

This tutorial takes approximately one hour to complete.

Assumed knowledge

To get the most out of this tutorial, you should be familiar with:

  • The basics of Airflow, including writing DAGs and defining tasks.
  • The Astro Python SDK.
  • Python and SQL fundamentals.

Prerequisites

To complete this tutorial, you need:

  • The Astro CLI.
  • A Weights and Biases account. You can create one for free.

Quickstart

If you have a GitHub account, you can get started quickly by cloning the demo repository. For more detailed instructions for setting up the project, start with Step 1.

  1. Clone the demo repository:

    git clone https://github.com/astronomer/airflow-wandb-demo
    cd airflow-wandb-demo
  2. Update the .env file with your WANDB_API_KEY.

  3. Start Airflow by running:

    astro dev start
  4. Continue with Step 7 below.

Step 1: Configure your Astro project

Use the Astro CLI to create and run an Airflow project locally.

  1. Create a new Astro project:

    $ mkdir astro-wandb-tutorial && cd astro-wandb-tutorial
    $ astro dev init
  2. Add the following lines to the requirements.txt file of your Astro project:

    astro-sdk-python[postgres]==1.5.3
    wandb==0.14.0
    pandas==1.5.3
    numpy==1.24.2
    scikit-learn==1.2.2

    This installs the packages needed to transform the data and run feature engineering, model training, and predictions.

Step 2: Prepare the data

This tutorial will create a model that classifies churn risk based on customer data.

  1. Create a subfolder called data in your Astro project include folder.
  2. Download the demo CSV files from this GitHub directory.
  3. Save the downloaded CSV files in the include/data folder. You should have 5 files in total.
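
If you want to confirm the files downloaded correctly before building the pipeline, you can run a quick sanity check with pandas. This is a minimal sketch, assuming the five file names match the source names used later in the DAG:

    import pandas as pd

    # The five source files expected in include/data
    sources = ["subscription_periods", "util_months", "customers", "orders", "payments"]

    for source in sources:
        df = pd.read_csv(f"include/data/{source}.csv")
        # Print the file name, row count, and column names to confirm the download
        print(source, df.shape, list(df.columns))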

Step 3: Create your SQL transformation scripts

Before feature engineering and training, the data needs to be transformed. This tutorial uses the Astro Python SDK transform_file function to complete several transformations using SQL.
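The {{subscription_periods}}-style placeholders you'll see in these scripts are template parameters: when the DAG calls transform_file, each placeholder is bound to an Astro SDK Table object through the parameters argument, and the query result is written to the table passed as output_table. A minimal sketch of the pattern follows; the file, table, and connection names here are illustrative only, not the ones used later in the tutorial.

    from astro import sql as aql
    from astro.sql.table import Table

    # Bind each {{placeholder}} in the SQL file to a Table, then write the
    # query result to the table passed as output_table.
    aql.transform_file(
        task_id="example_transform",
        file_path="include/my_query.sql",  # SQL file containing {{my_source}}
        parameters={"my_source": Table(name="MY_SOURCE", conn_id="postgres_default")},
        op_kwargs={"output_table": Table(name="MY_RESULT", conn_id="postgres_default")},
    )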

  1. Create a file in your include folder called customer_churn_month.sql and copy the following code into the file.

    with subscription_periods as (
        select subscription_id,
            customer_id,
            cast(start_date as date) as start_date,
            cast(end_date as date) as end_date,
            monthly_amount
        from {{subscription_periods}}
    ),

    months as (
        select cast(date_month as date) as date_month from {{util_months}}
    ),

    customers as (
        select
            customer_id,
            date_trunc('month', min(start_date)) as date_month_start,
            date_trunc('month', max(end_date)) as date_month_end
        from subscription_periods
        group by 1
    ),

    customer_months as (
        select
            customers.customer_id,
            months.date_month
        from customers
        inner join months
            on months.date_month >= customers.date_month_start
            and months.date_month < customers.date_month_end
    ),

    joined as (
        select
            customer_months.date_month,
            customer_months.customer_id,
            coalesce(subscription_periods.monthly_amount, 0) as mrr
        from customer_months
        left join subscription_periods
            on customer_months.customer_id = subscription_periods.customer_id
            and customer_months.date_month >= subscription_periods.start_date
            and (customer_months.date_month < subscription_periods.end_date
                or subscription_periods.end_date is null)
    ),

    customer_revenue_by_month as (
        select
            date_month,
            customer_id,
            mrr,
            mrr > 0 as is_active,
            min(case when mrr > 0 then date_month end) over (
                partition by customer_id
            ) as first_active_month,

            max(case when mrr > 0 then date_month end) over (
                partition by customer_id
            ) as last_active_month,

            case
                when min(case when mrr > 0 then date_month end) over (
                    partition by customer_id
                ) = date_month then true
                else false end as is_first_month,
            case
                when max(case when mrr > 0 then date_month end) over (
                    partition by customer_id
                ) = date_month then true
                else false end as is_last_month
        from joined
    ),

    joined1 as (
        select
            date_month + interval '1 month' as date_month,
            customer_id,
            0::float as mrr,
            false as is_active,
            first_active_month,
            last_active_month,
            false as is_first_month,
            false as is_last_month
        from customer_revenue_by_month
        where is_last_month
    )

    select * from joined1;
  2. Create another file in your include folder called customers.sql and copy the following code into the file.

    with

    customers as (
        select *
        from {{customers_table}}
    ),

    orders as (
        select *
        from {{orders_table}}
    ),

    payments as (
        select *
        from {{payments_table}}
    ),

    customer_orders as (
        select
            customer_id,
            cast(min(order_date) as date) as first_order,
            cast(max(order_date) as date) as most_recent_order,
            count(order_id) as number_of_orders
        from orders
        group by customer_id
    ),

    customer_payments as (
        select
            orders.customer_id,
            sum(amount / 100) as total_amount
        from payments
        left join orders on payments.order_id = orders.order_id
        group by orders.customer_id
    ),

    final as (
        select
            customers.customer_id,
            customers.first_name,
            customers.last_name,
            customer_orders.first_order,
            customer_orders.most_recent_order,
            customer_orders.number_of_orders,
            customer_payments.total_amount as customer_lifetime_value
        from customers
        left join customer_orders on customers.customer_id = customer_orders.customer_id
        left join customer_payments on customers.customer_id = customer_payments.customer_id
    )

    select
    *
    from final

Step 4: Create a W&B API Key

In your W&B account, create an API key that you will use to connect Airflow to W&B. You can create a key by going to the Authorize page or your user settings.
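
The DAG in this tutorial calls wandb.login() without arguments, which reads the WANDB_API_KEY environment variable you will set in the next step. If you want to confirm your key works before running the DAG, a minimal check looks like this (run it anywhere the wandb package is installed and the environment variable is set):

    import wandb

    # wandb.login() with no arguments reads the WANDB_API_KEY environment variable,
    # which you will set in your .env file in the next step.
    if wandb.login():
        print("W&B API key accepted")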

Step 5: Set up your connections and environment variables

You’ll use environment variables to store your W&B API key and to create the Airflow connection to Postgres that the Astro Python SDK uses.

  1. Open the .env file in your Astro project and paste the following code.

    WANDB_API_KEY='<your-wandb-api-key>'
    AIRFLOW_CONN_POSTGRES_DEFAULT='postgresql://postgres:postgres@host.docker.internal:5432/postgres?options=-csearch_path%3Dtmp_astro'
  2. Replace <your-wandb-api-key> with the API key you created in Step 4. No changes are needed for the AIRFLOW_CONN_POSTGRES_DEFAULT environment variable.
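
The AIRFLOW_CONN_POSTGRES_DEFAULT URI defines a connection named postgres_default that both the Astro Python SDK tasks and the SQL transformations use. If you want to verify the connection after starting the project, one option is a quick check with the Postgres provider hook from inside your Airflow environment. This is a minimal sketch, assuming the apache-airflow-providers-postgres package is available (the astro-sdk-python[postgres] extra is expected to install it):

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    # Uses the postgres_default connection defined by AIRFLOW_CONN_POSTGRES_DEFAULT
    hook = PostgresHook(postgres_conn_id="postgres_default")
    print(hook.get_first("SELECT version();"))  # prints the Postgres version if the connection works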

Step 6: Create your DAG

  1. Create a file in your Astro project dags folder called customer_analytics.py and copy the following code into the file:

    from datetime import datetime
    import os

    from astro import sql as aql
    from astro.files import File
    from astro.sql.table import Table
    from airflow.decorators import dag, task_group

    import pandas as pd
    import numpy as np
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier
    import tempfile
    import pickle
    from pathlib import Path

    import wandb
    from wandb.sklearn import plot_precision_recall, plot_feature_importances
    from wandb.sklearn import plot_class_proportions, plot_learning_curve, plot_roc

    _POSTGRES_CONN = "postgres_default"
    wandb_project = "demo"
    wandb_team = "astro-demos"
    local_data_dir = "include/data"
    sources = ["subscription_periods", "util_months", "customers", "orders", "payments"]


    @dag(schedule=None, start_date=datetime(2023, 1, 1), catchup=False)
    def customer_analytics():
        @task_group()
        def extract_and_load(sources: list) -> dict:
            for source in sources:
                aql.load_file(
                    task_id=f"load_{source}",
                    input_file=File(f"{local_data_dir}/{source}.csv"),
                    output_table=Table(
                        name=f"STG_{source.upper()}", conn_id=_POSTGRES_CONN
                    ),
                    if_exists="replace",
                )

        @task_group()
        def transform():
            aql.transform_file(
                task_id="transform_churn",
                file_path=f"{Path(__file__).parent.as_posix()}/../include/customer_churn_month.sql",
                parameters={
                    "subscription_periods": Table(
                        name="STG_SUBSCRIPTION_PERIODS", conn_id=_POSTGRES_CONN
                    ),
                    "util_months": Table(name="STG_UTIL_MONTHS", conn_id=_POSTGRES_CONN),
                },
                op_kwargs={
                    "output_table": Table(
                        name="CUSTOMER_CHURN_MONTH", conn_id=_POSTGRES_CONN
                    )
                },
            )

            aql.transform_file(
                task_id="transform_customers",
                file_path=f"{Path(__file__).parent.as_posix()}/../include/customers.sql",
                parameters={
                    "customers_table": Table(name="STG_CUSTOMERS", conn_id=_POSTGRES_CONN),
                    "orders_table": Table(name="STG_ORDERS", conn_id=_POSTGRES_CONN),
                    "payments_table": Table(name="STG_PAYMENTS", conn_id=_POSTGRES_CONN),
                },
                op_kwargs={"output_table": Table(name="CUSTOMERS", conn_id=_POSTGRES_CONN)},
            )

        @aql.dataframe()
        def features(customer_df: pd.DataFrame, churned_df: pd.DataFrame) -> pd.DataFrame:
            customer_df["customer_id"] = customer_df["customer_id"].apply(str)
            customer_df.set_index("customer_id", inplace=True)

            churned_df["customer_id"] = churned_df["customer_id"].apply(str)
            churned_df.set_index("customer_id", inplace=True)
            churned_df["is_active"] = churned_df["is_active"].astype(int).replace(0, 1)

            df = (
                customer_df[["number_of_orders", "customer_lifetime_value"]]
                .join(churned_df[["is_active"]], how="left")
                .fillna(0)
                .reset_index()
            )

            return df

        @aql.dataframe()
        def train(df: pd.DataFrame) -> dict:
            features = ["number_of_orders", "customer_lifetime_value"]
            target = ["is_active"]

            test_size = 0.3
            X_train, X_test, y_train, y_test = train_test_split(
                df[features], df[target], test_size=test_size, random_state=1883
            )
            X_train = np.array(X_train.values.tolist())
            y_train = np.array(y_train.values.tolist()).reshape(
                len(y_train),
            )
            y_train = y_train.reshape(
                len(y_train),
            )
            X_test = np.array(X_test.values.tolist())
            y_test = np.array(y_test.values.tolist())
            y_test = y_test.reshape(
                len(y_test),
            )

            model = RandomForestClassifier()
            _ = model.fit(X_train, y_train)
            model_params = model.get_params()

            y_pred = model.predict(X_test)
            y_probas = model.predict_proba(X_test)
            importances = model.feature_importances_
            indices = np.argsort(importances)[::-1]

            wandb.login()
            run = wandb.init(
                project=wandb_project,
                config=model_params,
                entity=wandb_team,
                group="wandb-demo",
                name="jaffle_churn",
                dir="include",
                mode="online",
            )

            wandb.config.update(
                {"test_size": test_size, "train_len": len(X_train), "test_len": len(X_test)}
            )
            plot_class_proportions(y_train, y_test, ["not_churned", "churned"])
            plot_learning_curve(model, X_train, y_train)
            plot_roc(y_test, y_probas, ["not_churned", "churned"])
            plot_precision_recall(y_test, y_probas, ["not_churned", "churned"])
            plot_feature_importances(model)

            model_artifact_name = "churn_classifier"

            with tempfile.NamedTemporaryFile(delete=False) as tf:
                pickle.dump(model, tf)
                tf.close()
                artifact = wandb.Artifact(model_artifact_name, type="model")
                artifact.add_file(local_path=tf.name, name=model_artifact_name)
                wandb.log_artifact(artifact)
            os.remove(tf.name)

            wandb.finish()

            return {"run_id": run.id, "artifact_name": model_artifact_name}

        @aql.dataframe()
        def predict(model_info: dict, customer_df: pd.DataFrame) -> pd.DataFrame:
            wandb.login()
            run = wandb.init(
                project=wandb_project,
                entity=wandb_team,
                group="wandb-demo",
                name="jaffle_churn",
                dir="include",
                resume="must",
                id=model_info["run_id"],
            )

            customer_df.fillna(0, inplace=True)

            features = ["number_of_orders", "customer_lifetime_value"]

            artifact = run.use_artifact(
                f"{model_info['artifact_name']}:latest", type="model"
            )

            with tempfile.TemporaryDirectory() as td:
                with open(artifact.file(td), "rb") as mf:
                    model = pickle.load(mf)
                    customer_df["PRED"] = model.predict_proba(
                        np.array(customer_df[features].values.tolist())
                    )[:, 0]

            wandb.finish()

            customer_df.reset_index(inplace=True)

            return customer_df

        _extract_and_load = extract_and_load(sources)

        _transformed = transform()

        _features = features(
            customer_df=Table(name="customers", conn_id=_POSTGRES_CONN),
            churned_df=Table(name="customer_churn_month", conn_id=_POSTGRES_CONN),
        )

        _model_info = train(df=_features)

        _predict_churn = predict(
            model_info=_model_info,
            customer_df=Table(name="customers", conn_id=_POSTGRES_CONN),
            output_table=Table(name="pred_churn", conn_id=_POSTGRES_CONN),
        )

        _extract_and_load >> _transformed >> _features


    customer_analytics()

    This DAG completes the following steps:

    • The extract_and_load task group contains one task for each CSV in your include/data folder that uses the Astro Python SDK load_file function to load the data to Postgres.
    • The transform task group contains two tasks that transform the data using the Astro Python SDK transform_file function and the SQL scripts in your include folder.
    • The features task is a Python function implemented with the Astro Python SDK @dataframe decorator that uses pandas to create the features needed for the model.
    • The train task is a Python function implemented with the Astro Python SDK @dataframe decorator that uses scikit-learn to train a random forest classifier and push the model and evaluation plots to W&B.
    • The predict task pulls the model from W&B to make predictions and stores them in Postgres. A sketch of how to retrieve the logged model outside of Airflow follows this step.
  2. Run the following command to start your project in a local environment:

    astro dev start
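
Because the train task logs the trained model as a W&B artifact, you can also pull the same model outside of Airflow, for example for ad-hoc evaluation. The following is a minimal sketch using the W&B public API; the entity, project, and artifact names match the wandb_team, wandb_project, and model_artifact_name values defined in the DAG, so adjust them if you changed those values:

    import pickle
    import wandb

    api = wandb.Api()
    # entity/project/name:alias matches wandb_team, wandb_project, and the
    # churn_classifier artifact logged by the train task
    artifact = api.artifact("astro-demos/demo/churn_classifier:latest", type="model")
    artifact_dir = artifact.download()

    # The artifact contains a single pickled scikit-learn model file named churn_classifier
    with open(f"{artifact_dir}/churn_classifier", "rb") as f:
        model = pickle.load(f)

    print(model.get_params())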

Step 7: Run your DAG and view results

  1. Open the Airflow UI at http://localhost:8080, unpause the customer_analytics DAG, and trigger it.

  2. The logs for the train and predict tasks contain a link to your W&B project, which shows the plotted results from training and prediction.

    wandb task logs

    Go to one of the links to view the results in W&B.

    wandb results
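
The predict task also writes its output to the pred_churn table through the Astro Python SDK. If you want to inspect the predictions directly, the following is a minimal sketch of a query through the same postgres_default connection configured in Step 5; run it anywhere the Airflow Postgres provider and that connection are available, for example inside an Airflow task:

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    # Read back the predictions written by the predict task, using the same
    # postgres_default connection and search_path configured in Step 5.
    hook = PostgresHook(postgres_conn_id="postgres_default")
    df = hook.get_pandas_df("SELECT * FROM pred_churn LIMIT 10;")
    print(df)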