| 1 | """Triggered by Astro DAG trigger notifications to investigate a failed DAG run and post the diagnosis to Slack. |
| 2 | |
| 3 | Required settings. Each can be set as an environment variable OR an Airflow |
| 4 | Variable of the same name (set BEFORE enabling the alert): |
| 5 | - ASTRO_ORGANIZATION_ID (your Astro organization ID) |
| 6 | - ASTRO_DEPLOYMENT_ID (deployment to run diagnoses against) |
| 7 | - ASTRO_API_TOKEN (Astro API token with labs/v1 access) |
| 8 | - SLACK_WEBHOOK_URL (Slack incoming webhook URL) |
| 9 | - SLACK_CHANNEL (optional; overrides the webhook's default channel) |
| 10 | |
| 11 | Set these in the Astro UI under Deployment → Environment (Environment Variables |
| 12 | or Airflow Variables). Missing any required value will raise |
| 13 | `ValueError: Missing required setting ...` at task runtime. |
| 14 | |
| 15 | Expected dag_run.conf keys (sent by the Astro Alert trigger): |
| 16 | - alertId (Astro alert ID; surfaced in the Slack message) |
| 17 | - alertType (DAG_FAILURE — other alert types are ignored) |
| 18 | - dagName (target DAG being triggered — this DAG; not used, the failed source DAG comes from `message`) |
| 19 | - message (free-form alert text; the failed source DAG ID is parsed from `"... for DAG <name>"`) |
| 20 | - airflowDagRunId (the failed source DAG's run ID, e.g. `scheduled__2026-06-03T23:40:00+00:00`) |
| 21 | - logFailureSummaries (dict keyed by failed task ID; if exactly one key, it's used as the failed task ID) |
| 22 | |
| 23 | Example payload (DAG_FAILURE alert from `always_failing_dag`): |
| 24 | { |
| 25 | "alertId": "cmpyms67v000m01pxkfaqoe5t", |
| 26 | "alertType": "DAG_FAILURE", |
| 27 | "dagName": "alert_investigation_agent", |
| 28 | "message": "DAG run failed for DAG always_failing_dag", |
| 29 | "airflowDagRunId": "scheduled__2026-06-03T23:40:00+00:00", |
| 30 | "logFailureSummaries": {"fail_on_purpose": "..."} |
| 31 | } |
| 32 | """ |
| 33 | |
| 34 | from __future__ import annotations |
| 35 | |
| 36 | import json |
| 37 | import logging |
| 38 | import os |
| 39 | import re |
| 40 | from typing import Any |
| 41 | |
| 42 | import requests |
| 43 | from airflow.sdk import Variable, dag, task |
| 44 | from pendulum import datetime |
| 45 | |
# Root of the Astro API; the diagnosis endpoint URLs are built on top of it.
ASTRO_API_BASE = "https://api.astronomer.io/labs/v1"

# Alert types this DAG acts on. Only DAG-failure alerts are investigated;
# any other alert type is logged and skipped.
ALLOWED_ALERT_TYPES = {"DAG_FAILURE"}

# Module-level logger shared by the helpers and the task below.
log = logging.getLogger(__name__)
| 52 | |
| 53 | |
def _get_setting(key: str) -> str:
    """Resolve a setting from an environment variable or an Airflow Variable.

    The environment variable named ``key`` wins when it is set to a non-empty
    value; otherwise the Airflow Variable of the same name is looked up.

    Args:
        key: Name shared by the environment variable and the Airflow Variable.

    Returns:
        The resolved value, or ``""`` when neither source provides one.
    """
    env_value = os.environ.get(key, "")
    if env_value:
        return env_value

    try:
        return Variable.get(key)
    except Exception as exc:
        # Deliberately best-effort: a failed lookup is treated as "not set" so
        # optional settings stay optional. Record why it failed, though, so a
        # backend/connectivity problem is not mistaken for a missing setting.
        logging.getLogger(__name__).info(
            "Airflow Variable %r could not be read (%s: %s); treating it as unset.",
            key,
            type(exc).__name__,
            exc,
        )
        return ""
| 65 | |
| 66 | |
def _get_required_setting(key: str) -> str:
    """Return the setting named ``key``, failing loudly when it is absent.

    Args:
        key: Name of the environment variable / Airflow Variable to resolve.

    Returns:
        The non-empty value resolved by ``_get_setting``.

    Raises:
        ValueError: If ``_get_setting`` yields an empty value for ``key``.
    """
    resolved = _get_setting(key)
    if not resolved:
        raise ValueError(f"Missing required setting: '{key}' (set it as an env var or Airflow Variable)")
    return resolved
| 72 | |
| 73 | |
def _extract_source_dag_id(message: str) -> str:
    """Parse the failed source DAG ID out of a DAG_FAILURE alert message.

    Astro Alert payloads don't include the failed source DAG ID as a top-level
    key (`dagName` in conf is the *target* DAG being triggered, i.e. this
    one), so it is taken from the alert `message`, which is expected to
    contain a phrase such as ``DAG run failed for DAG <dag_id>``. The phrase
    is matched case-insensitively and the ID may be wrapped in quotes.

    Args:
        message: Free-form alert text from ``dag_run.conf["message"]``.

    Returns:
        The DAG ID found after the first "for DAG" phrase.

    Raises:
        ValueError: If the message contains no "for DAG <name>" phrase.
    """
    match = re.search(r"for DAG\s+['\"]?([a-zA-Z0-9_.-]+)(['\"]?)", message, re.IGNORECASE)
    if match:
        dag_id, closing_quote = match.groups()
        if not closing_quote:
            # "." is a legal DAG ID character, so an unquoted ID that ends a
            # sentence ("... for DAG my_dag.") would otherwise swallow the
            # full stop. Quoted IDs are delimited explicitly and kept as-is.
            dag_id = dag_id.rstrip(".") or dag_id
        return dag_id

    raise ValueError(
        "Could not determine the failed source DAG ID from the alert message. "
        f"Expected a 'for DAG <name>' phrase; got message={message!r}"
    )
| 87 | |
| 88 | |
def _extract_run_id(conf: dict[str, Any]) -> str:
    """Return the failed source DAG's run ID from the trigger conf.

    Args:
        conf: The ``dag_run.conf`` payload sent with the alert.

    Returns:
        The value of ``airflowDagRunId`` converted to ``str``.

    Raises:
        ValueError: If ``airflowDagRunId`` is missing or falsy.
    """
    run_id = conf.get("airflowDagRunId")
    if not run_id:
        raise ValueError("Missing required dag_run.conf key: airflowDagRunId")
    return str(run_id)
| 95 | |
| 96 | |
def _extract_task_id(conf: dict[str, Any]) -> str | None:
    """Return the failed task ID when the alert names exactly one task.

    DAG_FAILURE payloads carry a `logFailureSummaries` dict keyed by the
    failed task ID(s). The single key is returned only when that dict has
    exactly one entry.

    Args:
        conf: The ``dag_run.conf`` payload sent with the alert.

    Returns:
        The sole key of ``logFailureSummaries``, or ``None`` when the value is
        not a dict or does not contain exactly one entry.
    """
    failure_summaries = conf.get("logFailureSummaries")
    if not isinstance(failure_summaries, dict) or len(failure_summaries) != 1:
        return None

    # Iterating a dict yields its keys; with one entry this is that one key.
    (only_task_id,) = failure_summaries
    return only_task_id
| 105 | |
| 106 | |
def _call_investigation_agent(
    organization_id: str,
    deployment_id: str,
    api_token: str,
    dag_id: str,
    run_id: str,
    task_id: str | None,
) -> dict[str, Any]:
    """Start a DAG-failure diagnosis run and collect its result from the event stream.

    First POSTs ``dagId`` / ``runId`` (plus ``taskId`` when given) to the
    deployment's ``dag-failure-diagnosis/runs`` endpoint and reads ``runId``
    from the JSON response. Then reads that run's ``/events`` endpoint as a
    server-sent-event stream, handling three event types:

    - ``rca_diagnosis``: JSON payload used as the diagnosis; reading stops
      once a non-empty one has been received.
    - ``text_delta``: JSON payload whose ``text`` is accumulated as a fallback.
    - ``error``: JSON payload whose ``message`` is raised as ``RuntimeError``.

    Args:
        organization_id: Astro organization ID used in the endpoint path.
        deployment_id: Astro deployment ID used in the endpoint path.
        api_token: Token sent as a ``Bearer`` authorization header.
        dag_id: ID of the failed DAG to diagnose.
        run_id: Run ID of the failed DAG run.
        task_id: Failed task ID, or ``None`` to omit it from the request.

    Returns:
        The ``rca_diagnosis`` payload if one arrived; otherwise a dict with
        ``title`` and ``summary`` built from the accumulated ``text_delta``
        chunks.

    Raises:
        requests.HTTPError: If either HTTP call returns an error status.
        RuntimeError: If the stream emits an ``error`` event, or ends without
            any diagnosis or text.
    """
    auth_headers = {"Authorization": f"Bearer {api_token}"}
    deployment_base = (
        f"{ASTRO_API_BASE}/organizations/{organization_id}"
        f"/observability/deployments/{deployment_id}/dag-failure-diagnosis/runs"
    )

    start_response = requests.post(
        deployment_base,
        headers={**auth_headers, "Content-Type": "application/json"},
        json={
            "dagId": dag_id,
            "runId": run_id,
            **({"taskId": task_id} if task_id else {}),
        },
        timeout=30,
    )
    start_response.raise_for_status()
    diagnosis_run_id = start_response.json()["runId"]

    # Parser state for the event currently being assembled.
    event_type: str | None = None
    data_lines: list[str] = []
    text_chunks: list[str] = []
    diagnosis: dict[str, Any] | None = None

    def flush_event() -> None:
        # Dispatch the buffered event (if any) and reset the parser state.
        nonlocal event_type, data_lines, diagnosis
        if not event_type:
            # Data without an explicit event type is discarded.
            data_lines = []
            return

        payload = "\n".join(data_lines)
        if event_type == "rca_diagnosis" and payload:
            diagnosis = json.loads(payload)
        elif event_type == "text_delta" and payload:
            text_chunks.append(json.loads(payload).get("text", ""))
        elif event_type == "error" and payload:
            raise RuntimeError(json.loads(payload).get("message", "Investigation Agent returned an error"))

        event_type = None
        data_lines = []

    # A streamed response holds its connection until it is fully consumed or
    # closed. The loop can stop early (break, or an `error` event raising), so
    # the context manager guarantees the connection is released either way.
    with requests.get(
        f"{deployment_base}/{diagnosis_run_id}/events",
        headers={**auth_headers, "Accept": "text/event-stream"},
        stream=True,
        timeout=(30, 300),
    ) as response:
        response.raise_for_status()

        # Event streams are UTF-8. Without this, requests decodes `text/*`
        # bodies that declare no charset as ISO-8859-1 (garbling non-ASCII
        # text), or yields undecoded bytes when it has no encoding at all.
        response.encoding = "utf-8"

        for raw_line in response.iter_lines(decode_unicode=True):
            line = raw_line.rstrip("\r")
            if not line:
                # A blank line terminates the current event.
                flush_event()
                if diagnosis:
                    break
                continue

            if line.startswith(":"):
                # Comment line in the stream; nothing to record.
                continue
            if line.startswith("event:"):
                event_type = line.split(":", 1)[1].strip()
                continue
            if line.startswith("data:"):
                data_lines.append(line.split(":", 1)[1].lstrip())

    # The stream may end without a trailing blank line; flush what is buffered.
    flush_event()

    if diagnosis:
        return diagnosis
    if text_chunks:
        return {"title": f"Investigation for {dag_id}", "summary": "".join(text_chunks).strip()}

    raise RuntimeError("Investigation Agent returned no diagnosis payload")
| 191 | |
| 192 | |
def _format_slack_message(
    diagnosis: dict[str, Any],
    alert_id: str,
    alert_type: str,
    dag_id: str,
    run_id: str,
    task_id: str | None,
) -> dict[str, Any]:
    """Build the Slack webhook payload describing a diagnosis.

    Reads ``title``, ``summary``, ``suggested_fix``, ``severity``,
    ``priority``, ``root_cause_type`` and ``confidence`` from ``diagnosis``,
    substituting placeholder text for missing or empty values.

    Args:
        diagnosis: Diagnosis dict returned by the investigation agent.
        alert_id: Astro alert ID shown in the message.
        alert_type: Astro alert type shown in the message.
        dag_id: ID of the failed DAG.
        run_id: Run ID of the failed DAG run.
        task_id: Failed task ID; the field is omitted when falsy.

    Returns:
        A dict with a plain ``text`` fallback (at most 4000 characters) and a
        ``blocks`` list holding a header, a fields section, a summary section
        and a suggested-fix section.
    """

    def text_of(key: str, fallback: str) -> str:
        # The diagnosis is decoded JSON whose value types are not validated;
        # str() keeps the slicing below from failing on numbers, lists, etc.
        return str(diagnosis.get(key) or fallback)

    title = text_of("title", f"Investigation for {dag_id}")
    summary = text_of("summary", "No summary returned.")
    suggested_fix = text_of("suggested_fix", "No suggested fix returned.")
    severity = text_of("severity", "UNKNOWN")
    priority = text_of("priority", "UNKNOWN")
    root_cause_type = text_of("root_cause_type", "UNKNOWN")
    confidence = diagnosis.get("confidence")

    fields = [
        {"type": "mrkdwn", "text": f"*Alert ID*\n`{alert_id}`"},
        {"type": "mrkdwn", "text": f"*Alert Type*\n`{alert_type}`"},
        {"type": "mrkdwn", "text": f"*DAG*\n`{dag_id}`"},
        {"type": "mrkdwn", "text": f"*Run ID*\n`{run_id}`"},
        {"type": "mrkdwn", "text": f"*Severity*\n`{severity}`"},
        {"type": "mrkdwn", "text": f"*Root Cause Type*\n`{root_cause_type}`"},
    ]
    # Optional fields are appended only when there is something to show.
    if task_id:
        fields.append({"type": "mrkdwn", "text": f"*Task ID*\n`{task_id}`"})
    if confidence is not None:
        fields.append({"type": "mrkdwn", "text": f"*Confidence*\n`{confidence}`"})
    if priority != "UNKNOWN":
        fields.append({"type": "mrkdwn", "text": f"*Priority*\n`{priority}`"})

    # The truncations below cap each element's size (presumably to stay within
    # Slack's Block Kit limits - confirm against the Block Kit reference).
    blocks: list[dict[str, Any]] = [
        {"type": "header", "text": {"type": "plain_text", "text": title[:150]}},
        {"type": "section", "fields": fields[:10]},
        {"type": "section", "text": {"type": "mrkdwn", "text": f"*Summary*\n{summary[:2900]}"}},
        {"type": "section", "text": {"type": "mrkdwn", "text": f"*Suggested Fix*\n{suggested_fix[:2900]}"}},
    ]

    # Plain-text rendering of the same content, sent alongside the blocks.
    text = (
        f"[{severity}] {title}\n"
        f"Root cause type: {root_cause_type}\n"
        f"Summary: {summary}\n\n"
        f"Suggested fix: {suggested_fix}"
    )

    return {"text": text[:4000], "blocks": blocks}
| 239 | |
| 240 | |
def _post_to_slack(webhook_url: str, channel: str | None, payload: dict[str, Any]) -> None:
    """POST a message payload to a Slack incoming webhook.

    Args:
        webhook_url: Slack incoming webhook URL to post to.
        channel: Optional channel; when truthy it is added to the request body
            as ``channel``.
        payload: Message body; it is copied, so the caller's dict is not
            mutated.

    Raises:
        requests.HTTPError: If Slack responds with a 4xx/5xx status. The
            message carries the status code and the start of Slack's response
            body.
    """
    body = dict(payload)
    if channel:
        body["channel"] = channel

    response = requests.post(webhook_url, json=body, timeout=30)
    if not response.ok:
        # raise_for_status() would embed the full webhook URL - which acts as
        # a credential - in the exception message and hence the task log, and
        # would drop Slack's response body, which states why it rejected the
        # request. Raise the same exception type with a safer, more useful
        # message instead.
        raise requests.HTTPError(
            f"Slack webhook returned HTTP {response.status_code}: {response.text[:500]!r}",
            response=response,
        )
| 248 | |
| 249 | |
@dag(
    dag_id="alert_investigation_agent",
    start_date=datetime(2025, 1, 1),
    schedule=None,
    catchup=False,
    tags=["alerts", "investigation-agent", "slack"],
    default_args={"owner": "Astro", "retries": 0},
    doc_md=__doc__,
)
def alert_investigation_agent():
    # Single-task DAG with no schedule: it reads the alert payload from
    # dag_run.conf, asks the investigation agent for a diagnosis of the failed
    # run, and posts the result to Slack.

    @task
    def handle_alert(**context) -> None:
        # The alert payload arrives as this run's conf; fall back to an empty
        # payload when there is no dag_run or no conf.
        triggering_run = context.get("dag_run")
        conf = {}
        if triggering_run:
            conf = dict(triggering_run.conf or {})

        alert_id = str(conf.get("alertId", "unknown"))
        alert_type = str(conf.get("alertType", "unknown"))
        message = str(conf.get("message", ""))

        # Anything other than an allowed alert type is logged and skipped.
        if alert_type not in ALLOWED_ALERT_TYPES:
            log.info(
                "Ignoring alert: alertType %r is not a DAG failure. "
                "This demo only investigates DAG-failure alerts.",
                alert_type,
            )
            return

        # Identify the failed run before resolving settings, so a malformed
        # payload is reported ahead of any missing-setting error.
        dag_id = _extract_source_dag_id(message)
        run_id = _extract_run_id(conf)
        task_id = _extract_task_id(conf)

        organization_id = _get_required_setting("ASTRO_ORGANIZATION_ID")
        deployment_id = _get_required_setting("ASTRO_DEPLOYMENT_ID")
        api_token = _get_required_setting("ASTRO_API_TOKEN")
        slack_webhook_url = _get_required_setting("SLACK_WEBHOOK_URL")
        # Optional: an empty value is normalised to None.
        slack_channel = _get_setting("SLACK_CHANNEL") or None

        diagnosis = _call_investigation_agent(
            organization_id=organization_id,
            deployment_id=deployment_id,
            api_token=api_token,
            dag_id=dag_id,
            run_id=run_id,
            task_id=task_id,
        )

        # Log the diagnosis with the same placeholders used in the Slack message.
        logged_title = diagnosis.get("title") or f"Investigation for {dag_id}"
        logged_severity = diagnosis.get("severity") or "UNKNOWN"
        logged_root_cause = diagnosis.get("root_cause_type") or "UNKNOWN"
        logged_summary = diagnosis.get("summary") or "No summary returned."
        logged_fix = diagnosis.get("suggested_fix") or "No suggested fix returned."
        log.info(
            "Diagnosis completed for %s run %s: title=%s severity=%s root_cause_type=%s summary=%s suggested_fix=%s",
            dag_id,
            run_id,
            logged_title,
            logged_severity,
            logged_root_cause,
            logged_summary,
            logged_fix,
        )

        slack_payload = _format_slack_message(
            diagnosis=diagnosis,
            alert_id=alert_id,
            alert_type=alert_type,
            dag_id=dag_id,
            run_id=run_id,
            task_id=task_id,
        )
        _post_to_slack(slack_webhook_url, slack_channel, slack_payload)

    handle_alert()


alert_investigation_agent()