- 1. The Role of Airflow in ML Pipelines
- 2. Reviewing Airflow Core Concepts: DAG, Operator, Task, XCom
- 3. Running GPU Training Jobs with KubernetesPodOperator
- 4. Analysis of Key KubernetesPodOperator Parameters
- 5. DAG Design Pattern: Data Preprocessing to Training to Evaluation to Deployment
- 6. Passing Metrics/Artifacts Between Tasks with XCom
- 7. Hyperparameter Tuning with Dynamic Task Mapping
- 8. Detecting Data Arrival with Sensors
- 9. MLflow and Airflow Integration Patterns
- 10. Utilizing the TaskFlow API (@task Decorator)
- 11. Production DAG Code Example (Full ML Pipeline)
- 12. References
1. The Role of Airflow in ML Pipelines
In ML projects, model training does not end simply by calling model.fit(). A series of steps -- data collection, preprocessing, Feature Engineering, model training, evaluation, model registry registration, and deployment -- must be executed in a reproducible and automated manner. This is where a Workflow Orchestration tool is needed, and Apache Airflow is one of the most mature open-source projects in this domain.
Airflow vs Other Orchestration Tools
Besides Airflow, other ML pipeline orchestration tools include Kubeflow Pipelines, Prefect, and Dagster. Let us compare their characteristics.
Apache Airflow is a general-purpose workflow orchestrator. While it is not specialized for ML, it is widely used in ML pipelines thanks to its rich Operator ecosystem, powerful scheduling capabilities, and scalability on Kubernetes. With Airflow 3.x, features such as DAG versioning, event-driven scheduling, and multi-language support have been added, making it even more powerful.
Kubeflow Pipelines is a Kubernetes-native ML pipeline tool specialized for ML workflows. It provides integrated GPU scheduling, experiment tracking, and model serving, but requires deep understanding of Kubernetes and is difficult to use outside Kubernetes environments.
Prefect is often described as "Airflow, but nicer," with its Python-native interface and quick setup being key advantages. It offers dynamic flow, hybrid execution, and real-time SLA alerting, making it particularly suitable for small teams or PoC stages where rapid adoption is needed.
Dagster is a modern orchestrator that takes an asset-centric approach. With a philosophy of designing data pipelines centered around "the data being produced" rather than "execution steps," it has strengths in data lineage tracking.
| Characteristic | Airflow | Kubeflow Pipelines | Prefect | Dagster |
|---|---|---|---|---|
| ML Specialization | General-purpose | ML-specialized | General-purpose | Data-centric |
| Learning Curve | Medium | High | Low | Medium |
| Kubernetes Dependency | Optional | Required | Optional | Optional |
| Community/Ecosystem | Very large | Large | Growing | Growing |
| GPU Scheduling | KubernetesPodOperator | Native | Kubernetes integration | Kubernetes integration |
| Scheduling | Very powerful | Limited | Powerful | Powerful |
The biggest reason for choosing Airflow as an ML pipeline orchestrator is its tool-agnostic nature. Airflow can integrate with various ML tools such as MLflow, SageMaker, and Vertex AI, and can manage ML and non-ML data pipelines (ETL, data quality checks, etc.) on the same infrastructure.
2. Reviewing Airflow Core Concepts: DAG, Operator, Task, XCom
According to the official Apache Airflow documentation, the core concepts are as follows.
DAG (Directed Acyclic Graph)
A DAG is a collection of all Tasks you want to run, organized in a way that reflects their relationships and dependencies. DAGs are defined as Python scripts, and the code itself expresses the structure of the DAG (Tasks and their dependencies). A DAG run is an instantiation of a DAG, containing Task Instances that execute for a specific execution_date.
from airflow import DAG
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',  # 'schedule' replaces 'schedule_interval', which was removed in Airflow 3.x
    catchup=False,
    tags=['ml', 'training'],
) as dag:
    # Task definitions
    pass
Operator
An Operator is a pre-defined template for a Task. According to the official documentation, "If a DAG describes how to run a workflow, Operators determine what actually gets done by a Task." All Operators inherit from BaseOperator, with PythonOperator, BashOperator, and KubernetesPodOperator being representative examples.
Task
A Task is an instance of an Operator. It is an executable unit within a DAG, and a Task Instance represents a specific execution of a Task in a particular DAG run.
XCom (Cross-Communication)
XCom is a mechanism for passing data between Tasks. According to the official documentation, "by default, Tasks are entirely isolated and may run on completely different machines," which is why XCom is needed. XCom is identified by key, task_id, and dag_id.
# Push
task_instance.xcom_push(key="model_accuracy", value=0.95)
# Pull
accuracy = task_instance.xcom_pull(key="model_accuracy", task_ids="evaluate_model")
An important point is that XCom is designed for transferring small amounts of data. You should not pass large data such as DataFrames through XCom; in such cases, you should use an Object Storage Backend.
3. Running GPU Training Jobs with KubernetesPodOperator
In ML training pipelines, KubernetesPodOperator plays a central role. According to the official documentation, KubernetesPodOperator is "a stand-in replacement for Kubernetes object spec definitions" that can be executed by the Airflow scheduler within the DAG context. In other words, you can define and run Kubernetes Pods directly in Python code without writing separate YAML/JSON files for Pod specs.
The reasons KubernetesPodOperator is particularly useful for GPU-based ML training are as follows:
- Task-level resource configuration: CPU, Memory, and GPU can be independently allocated for each Task
- Custom dependency support: when a package is not on PyPI or a specific CUDA version is required, a custom Docker image can be specified
- Language agnosticism: Training code written in any language, not just Python, can be wrapped in a container and executed
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s
train_model = KubernetesPodOperator(
    task_id='train_model',
    name='gpu-training-job',
    namespace='ml-training',
    image='my-registry/ml-trainer:latest',
    cmds=['python'],
    arguments=['train.py', '--epochs', '100', '--batch-size', '64'],
    container_resources=k8s.V1ResourceRequirements(
        requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
        limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
    ),
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
)
By setting do_xcom_push=True, you can return data as XCom by writing JSON to the /airflow/xcom/return.json file inside the Pod. This is useful for passing training result metrics to the next Task.
4. Analysis of Key KubernetesPodOperator Parameters
This section analyzes the key parameters of KubernetesPodOperator from an ML training perspective, based on the official documentation.
container_resources
This is the most important parameter for GPU training. Using kubernetes.client.models.V1ResourceRequirements, you can request and limit CPU, Memory, and GPU resources.
container_resources = k8s.V1ResourceRequirements(
    requests={
        'cpu': '4',
        'memory': '16Gi',
        'nvidia.com/gpu': '1',
    },
    limits={
        'cpu': '8',
        'memory': '32Gi',
        'nvidia.com/gpu': '1',
    },
)
For GPUs, the nvidia.com/gpu resource key is used, and requests and limits are typically set to the same value. This is because GPUs cannot be fractionally allocated (unless MIG is used).
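As an illustration of the MIG exception, a cluster whose administrator has partitioned GPUs into MIG slices exposes them as separate extended resources. The profile name below (`nvidia.com/mig-1g.5gb`) is an assumption that depends entirely on the cluster's MIG configuration; the point is only that requests and limits stay identical, just as with whole GPUs:

```python
# Hypothetical MIG-based fractional GPU request. The resource key
# 'nvidia.com/mig-1g.5gb' is one common profile name, but the actual
# keys depend on how the cluster partitions its GPUs.
mig_resources = {
    'requests': {'cpu': '4', 'memory': '16Gi', 'nvidia.com/mig-1g.5gb': '1'},
    'limits': {'cpu': '4', 'memory': '16Gi', 'nvidia.com/mig-1g.5gb': '1'},
}
```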
tolerations
GPU nodes are usually configured with taints to prevent general workloads from being scheduled. tolerations allow Pods to be scheduled on GPU nodes.
tolerations = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]
node_selector
You can specify that training Pods should only run on specific node groups. For example, to run only on nodes equipped with A100 GPUs:
node_selector = {
    'gpu-type': 'a100',
    'node-pool': 'ml-training',
}
affinity
When more fine-grained scheduling control is needed, use V1Affinity. You can set Node affinity and Pod affinity/anti-affinity to place Pods in specific zones or nodes.
affinity = k8s.V1Affinity(
    node_affinity=k8s.V1NodeAffinity(
        required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
            node_selector_terms=[
                k8s.V1NodeSelectorTerm(
                    match_expressions=[
                        k8s.V1NodeSelectorRequirement(
                            key='gpu-type',
                            operator='In',
                            values=['a100', 'h100'],
                        )
                    ]
                )
            ]
        )
    )
)
volumes / volume_mounts
Used for mounting training datasets or checkpoints via PVC (PersistentVolumeClaim).
volume = k8s.V1Volume(
    name='training-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-dataset-pvc'
    ),
)
volume_mount = k8s.V1VolumeMount(
    name='training-data',
    mount_path='/data',
    read_only=True,
)
Other Key Parameters
| Parameter | Description | ML Training Usage |
|---|---|---|
| image_pull_secrets | Private Registry auth | Accessing internal ML images |
| env_vars | Environment variable list | CUDA_VISIBLE_DEVICES, WANDB_API_KEY, etc. |
| secrets | Secret volumes/env vars | API Key, credential management |
| is_delete_operator_pod | Delete Pod after completion | Resource cleanup (True recommended) |
| get_logs | Display logs in Airflow UI | Training log monitoring |
| deferrable | Async execution mode | Worker efficiency for long training |
| on_finish_action | Post-completion action | Auto cleanup with delete_pod |
The official documentation recommends using native objects from kubernetes.client.models instead of convenience classes for type safety.
5. DAG Design Pattern: Data Preprocessing to Training to Evaluation to Deployment
The general DAG structure for an ML training pipeline includes the following stages:
Data Validation -> Preprocessing -> Feature Engineering -> Training -> Evaluation -> [Conditional] Deployment
Mapping each stage to Airflow Tasks results in the following design pattern:
from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
) as dag:
    validate_data = KubernetesPodOperator(
        task_id='validate_data',
        image='ml-tools:latest',
        cmds=['python', 'validate.py'],
        # Only CPU needed
    )
    preprocess = KubernetesPodOperator(
        task_id='preprocess_data',
        image='ml-tools:latest',
        cmds=['python', 'preprocess.py'],
        # Large memory needed
    )
    train = KubernetesPodOperator(
        task_id='train_model',
        image='ml-trainer:latest',
        cmds=['python', 'train.py'],
        # GPU needed
        do_xcom_push=True,
    )
    evaluate = KubernetesPodOperator(
        task_id='evaluate_model',
        image='ml-trainer:latest',
        cmds=['python', 'evaluate.py'],
        do_xcom_push=True,
    )

    def check_model_quality(**context):
        metrics = context['ti'].xcom_pull(task_ids='evaluate_model')
        if metrics['accuracy'] > 0.90:
            return 'deploy_model'
        return 'notify_failure'

    branch = BranchPythonOperator(
        task_id='check_quality',
        python_callable=check_model_quality,
    )
    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
    )
    # Both branch targets must exist as tasks
    notify_failure = KubernetesPodOperator(
        task_id='notify_failure',
        image='ml-tools:latest',
        cmds=['python', 'notify.py'],
    )

    validate_data >> preprocess >> train >> evaluate >> branch
    branch >> [deploy, notify_failure]
The key to this design pattern is conditional deployment. The evaluation stage passes model performance metrics via XCom, and through BranchPythonOperator, the deployment Task is executed only when quality criteria are met. This prevents low-quality models from being deployed to production.
Stage Separation Using TaskGroup
In complex pipelines, TaskGroup can be used to visually separate stages. According to the official documentation, "TaskGroup is used to organize Tasks into hierarchical groups in the Graph View. It is useful for creating repeating patterns and reducing visual complexity."
from airflow.utils.task_group import TaskGroup
with TaskGroup('data_preparation') as data_prep:
    validate_data >> preprocess >> feature_engineering

with TaskGroup('model_training') as training:
    train >> evaluate

data_prep >> training >> branch
6. Passing Metrics/Artifacts Between Tasks with XCom
According to the official documentation, XCom is explicitly "pushed" to and "pulled" from storage via the xcom_push and xcom_pull methods. In ML pipelines, XCom is used for the following purposes:
Passing Training Metrics
# train.py (executed inside KubernetesPodOperator)
import json

metrics = {
    'accuracy': 0.9534,
    'f1_score': 0.9421,
    'loss': 0.0312,
    'model_path': 's3://ml-models/experiment-42/model.pt',
    'run_id': 'exp-42-20260301',
}
# Write JSON to /airflow/xcom/return.json
with open('/airflow/xcom/return.json', 'w') as f:
    json.dump(metrics, f)
Auto-Push and Return Value
Many Operators and the @task decorator automatically push the return value as an XCom with the key return_value when do_xcom_push=True (default). You can then pull it as follows:
value = task_instance.xcom_pull(task_ids='train_model')
Multiple Outputs
When returning a dictionary with multiple_outputs=True, each key is stored as an individual XCom.
@task(multiple_outputs=True)
def evaluate_model(**context):
    return {
        'accuracy': 0.95,
        'f1_score': 0.94,
        'model_path': 's3://models/latest.pt',
    }
You can then pull by individual key or pull the entire return_value.
Custom XCom Backend
The default XCom Backend, BaseXCom, stores XCom in the Airflow database. While this is fine for small amounts of data, a Custom Backend is needed when dealing with model artifacts or large data. You can override the serialize_value and deserialize_value methods by inheriting from BaseXCom.
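The size-based routing such a backend performs can be sketched in plain Python. This is only an illustration of the idea behind overriding serialize_value/deserialize_value; the 64 KiB threshold and the upload/download callbacks are assumptions, not Airflow defaults:

```python
import json

THRESHOLD_BYTES = 64 * 1024  # illustrative cutoff, not an Airflow default


def serialize_xcom(value, upload):
    """Sketch of a custom backend's serialize_value(): keep small values
    inline, push large ones to object storage, and store only a reference."""
    payload = json.dumps(value)
    if len(payload.encode('utf-8')) <= THRESHOLD_BYTES:
        return payload
    uri = upload(payload)  # e.g. writes the payload to S3/GCS, returns the object URI
    return json.dumps({'__xcom_ref__': uri})


def deserialize_xcom(stored, download):
    """Counterpart deserialize_value(): resolve references transparently."""
    obj = json.loads(stored)
    if isinstance(obj, dict) and '__xcom_ref__' in obj:
        return json.loads(download(obj['__xcom_ref__']))
    return obj
```

In a real backend these functions would be static methods on a class inheriting from BaseXCom, with upload/download implemented against boto3 or gcsfs.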
To use the Object Storage Backend, set the xcom_backend configuration to airflow.providers.common.io.xcom.backend.XComObjectStorageBackend. This allows storing XCom data in S3 or GCS.
[core]
xcom_backend = airflow.providers.common.io.xcom.backend.XComObjectStorageBackend
XCom Usage Precautions
The key points emphasized in the official documentation are:
- XCom is designed for transferring small amounts of data. Do not pass large data such as DataFrames
- When a Task is retried after failure, previous XCom is automatically cleared to ensure idempotent execution
- XCom operations must be performed within the Task Context via get_current_context(); direct database updates are not supported
As a practical guideline for ML pipelines, metrics (accuracy, loss, etc.) and path information (model_path, artifact_uri, etc.) should be passed via XCom, while actual model files and datasets should be stored in Object Storage like S3/GCS, with only the paths shared via XCom.
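That guideline can be sketched as a small helper a training script might use: the artifact goes to object storage, and only metrics plus the artifact's path are returned for XCom. The bucket/key and the `write_object` callback are hypothetical placeholders:

```python
def build_xcom_payload(metrics: dict, model_bytes: bytes, write_object) -> dict:
    """Store the heavy artifact in object storage and return only
    small metrics plus the artifact's path for XCom."""
    model_path = 's3://ml-models/latest/model.pt'  # hypothetical bucket/key
    write_object(model_path, model_bytes)          # e.g. a thin wrapper over boto3 put_object
    return {**metrics, 'model_path': model_path}   # safe to push as XCom
```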
7. Hyperparameter Tuning with Dynamic Task Mapping
Dynamic Task Mapping was introduced in Airflow 2.3 and, according to the official documentation, is "a mechanism that allows creating multiple Tasks at runtime based on current data." The DAG author does not need to know in advance how many Tasks are needed.
This feature is very powerful for ML hyperparameter tuning because it enables parallel training with various hyperparameter combinations.
expand() and partial()
expand() specifies the parameters to map, while partial() specifies fixed parameters common to all Tasks.
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

hyperparams = [
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'adam'},
    {'lr': '0.01', 'batch_size': '64', 'optimizer': 'adam'},
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'sgd'},
    {'lr': '0.0001', 'batch_size': '128', 'optimizer': 'adamw'},
]

train_tasks = KubernetesPodOperator.partial(
    task_id='hyperparam_training',
    image='ml-trainer:latest',
    namespace='ml-training',
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
    container_resources=k8s.V1ResourceRequirements(
        requests={'nvidia.com/gpu': '1'},
        limits={'nvidia.com/gpu': '1'},
    ),
).expand(
    arguments=[
        ['train.py', '--lr', hp['lr'], '--batch-size', hp['batch_size'], '--optimizer', hp['optimizer']]
        for hp in hyperparams
    ],
)
This code automatically creates 4 parallel KubernetesPodOperator Tasks for 4 hyperparameter combinations.
Dynamically Generating Mapping Data from Tasks
Furthermore, you can dynamically generate hyperparameter combinations from upstream Tasks.
@task
def generate_hyperparams():
    """Generate hyperparameters via Grid Search or Random Search"""
    import itertools

    learning_rates = [0.001, 0.01, 0.0001]
    batch_sizes = [32, 64, 128]
    optimizers = ['adam', 'sgd']
    combinations = list(itertools.product(learning_rates, batch_sizes, optimizers))
    return [
        {'lr': str(lr), 'batch_size': str(bs), 'optimizer': opt}
        for lr, bs, opt in combinations
    ]

hp_list = generate_hyperparams()
According to the official documentation, "mappings generated from a Task prohibit the use of trigger_rule=TriggerRule.ALWAYS."
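Since KubernetesPodOperator receives `arguments` as a flat list of strings, an upstream task typically converts each hyperparameter dict into an argv-style list before mapping. A minimal stdlib sketch (the flag names mirror the earlier example and are assumptions about the training script's CLI):

```python
def hp_to_args(hp: dict) -> list[str]:
    """Convert one hyperparameter dict, e.g.
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'adam'},
    into a CLI argument list for the training container."""
    return [
        'train.py',
        '--lr', hp['lr'],
        '--batch-size', hp['batch_size'],
        '--optimizer', hp['optimizer'],
    ]
```

The expanded task would then receive `arguments=[hp_to_args(hp) for hp in hp_list]` (or the equivalent via the mapped argument's `.map()`).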
Cross-Product Mapping
Specifying multiple expand() parameters generates all combinations (cross product).
@task
def train(lr: float, batch_size: int):
    # Training logic
    pass

train.expand(lr=[0.001, 0.01], batch_size=[32, 64])
# 2 x 2 = 4 Task Instances created
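The cross-product semantics can be reproduced with itertools to sanity-check how many mapped instances to expect before running the DAG:

```python
import itertools

lrs = [0.001, 0.01]
batch_sizes = [32, 64]

# expand(lr=lrs, batch_size=batch_sizes) creates one Task Instance per
# element of the cross product of the mapped arguments:
combos = list(itertools.product(lrs, batch_sizes))
# 4 combinations: (0.001, 32), (0.001, 64), (0.01, 32), (0.01, 64)
```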
Map-Reduce Pattern
A pattern for collecting results from mapped Tasks and selecting the optimal model.
@task
def select_best_model(results):
    """Select the best among all hyperparameter combination results"""
    best = max(results, key=lambda x: x['accuracy'])
    return best

training_results = train_task.expand(params=hyperparams)
best = select_best_model(training_results)
According to the official documentation, the collected results are returned as a "lazy proxy" sequence, not an eager list.
Constraints
| Item | Setting | Default |
|---|---|---|
| Max mapped instances | [core] max_map_length | 1024 |
| Parallel execution per mapped Task | max_active_tis_per_dag | None (unlimited) |
Only list and dict types can be mapped; other types will raise an UnmappableXComTypePushed error. Additionally, according to the official documentation, "if a field is marked as templated and is mapped, it will not be template-rendered."
8. Detecting Data Arrival with Sensors
In ML pipelines, it is often necessary to wait until training data is ready. Airflow's Sensors are special Operators that wait until a specific condition is met.
S3KeySensor
Waits until training data is uploaded to S3.
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_data = S3KeySensor(
    task_id='wait_for_training_data',
    bucket_name='ml-datasets',
    bucket_key='daily/{{ ds }}/training_data.parquet',
    aws_conn_id='aws_default',
    poke_interval=300,    # Check every 5 minutes
    timeout=3600 * 6,     # Wait up to 6 hours
    mode='reschedule',    # Return Worker slot between pokes
    deferrable=True,      # Async wait using Triggerer
)
mode='reschedule' returns the Worker slot between pokes to improve resource efficiency. Setting deferrable=True allows the Triggerer component to handle polling asynchronously for even more efficient Worker utilization.
ExternalTaskSensor
Waits until a Task from another DAG completes. For example, a pattern where the training DAG starts after the data preprocessing DAG completes.
from airflow.sensors.external_task import ExternalTaskSensor
from datetime import timedelta

wait_for_preprocessing = ExternalTaskSensor(
    task_id='wait_for_preprocessing',
    external_dag_id='data_preprocessing_pipeline',
    external_task_id='final_validation',
    allowed_states=['success'],
    execution_delta=timedelta(hours=0),  # external DAG runs on the same logical date
    timeout=3600,
    poke_interval=60,
    mode='reschedule',
)
allowed_states is the list of accepted states for the external task, with ['success'] as the default. execution_delta tells the sensor how much earlier the external DAG's logical date is than this DAG's; timedelta(0) means both DAGs run on the same schedule.
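For example, if this training DAG is scheduled daily at 06:00 and the preprocessing DAG at 04:00 (both hypothetical schedules), the sensor must look two hours back:

```python
from datetime import timedelta

training_run_hour = 6       # this DAG's daily run hour (assumption)
preprocessing_run_hour = 4  # external DAG's daily run hour (assumption)

# execution_delta = this DAG's logical time minus the external DAG's,
# so the sensor checks the external run that happened 2 hours earlier.
execution_delta = timedelta(hours=training_run_hour - preprocessing_run_hour)
```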
Practical Pattern: Starting Training After Data Arrival
wait_for_data >> validate_data >> preprocess >> train >> evaluate >> deploy
This pattern loosely couples the data pipeline and ML pipeline while ensuring that training only starts after data is ready.
9. MLflow and Airflow Integration Patterns
The role division between Airflow and MLflow integration is clear. Airflow manages "when and in what order to execute," while MLflow records "what happened during execution and where the model is." Through this separation of concerns, data engineers can manage Airflow DAGs without touching ML code, and data scientists can focus on model development without worrying about scheduling infrastructure.
Integration Pattern 1: Direct MLflow Calls Inside Training Tasks
This is the most intuitive pattern. MLflow APIs are called directly inside training containers executed by KubernetesPodOperator.
# train.py (executed inside container)
import mlflow
import mlflow.pytorch

mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("daily-training")

with mlflow.start_run() as run:
    # Hyperparameter logging
    mlflow.log_params({
        'learning_rate': 0.001,
        'batch_size': 64,
        'epochs': 100,
    })
    # Model training
    model = train_model(...)
    # Metric logging
    mlflow.log_metrics({
        'accuracy': 0.95,
        'f1_score': 0.94,
    })
    # Model registration
    mlflow.pytorch.log_model(model, "model")
    # Pass run_id via XCom
    import json
    with open('/airflow/xcom/return.json', 'w') as f:
        json.dump({'run_id': run.info.run_id}, f)
Integration Pattern 2: Utilizing MLflow Model Registry from Airflow
After training is complete, the model is registered in the MLflow Model Registry based on evaluation results and the stage is transitioned.
@task
def register_model(run_id: str, accuracy: float):
    import mlflow
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")
    # Register model
    model_uri = f"runs:/{run_id}/model"
    mv = mlflow.register_model(model_uri, "production-model")
    # Transition to Production stage if accuracy exceeds threshold
    if accuracy > 0.93:
        client.transition_model_version_stage(
            name="production-model",
            version=mv.version,
            stage="Production",
        )
Integration Pattern 3: Querying and Comparing Training History from MLflow
@task
def compare_with_baseline():
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")
    # Query Production model metrics
    prod_versions = client.get_latest_versions("production-model", stages=["Production"])
    if prod_versions:
        prod_run = client.get_run(prod_versions[0].run_id)
        baseline_accuracy = float(prod_run.data.metrics['accuracy'])
        return {'baseline_accuracy': baseline_accuracy}
    return {'baseline_accuracy': 0.0}
10. Utilizing the TaskFlow API (@task Decorator)
According to the official documentation, the TaskFlow API is "a functional API that defines DAGs and Tasks using decorators," greatly simplifying data transfer and dependency definition between Tasks.
Key Characteristics
The biggest advantage of TaskFlow is automatic XCom management and automatic dependency calculation. When a TaskFlow function is called, it is not executed immediately; instead, an XComArg object representing the result is returned. Using this as input to downstream Tasks causes dependencies to be calculated automatically.
from airflow.sdk import dag, task
from datetime import datetime

@dag(
    schedule='@daily',
    start_date=datetime(2026, 1, 1),
    catchup=False,
)
def ml_training_taskflow():
    @task
    def extract_data():
        """Data extraction"""
        return {'data_path': 's3://bucket/data/2026-03-01.parquet', 'num_rows': 100000}

    @task(multiple_outputs=True)
    def preprocess(data_info: dict):
        """Data preprocessing"""
        return {
            'train_path': f"{data_info['data_path']}/train.parquet",
            'test_path': f"{data_info['data_path']}/test.parquet",
            'feature_count': 128,
        }

    @task
    def train_model(train_path: str, feature_count: int):
        """Model training"""
        # In practice, GPU training would be run via KubernetesPodOperator
        return {
            'model_path': 's3://models/latest.pt',
            'accuracy': 0.95,
        }

    @task
    def evaluate_and_deploy(model_info: dict):
        """Model evaluation and conditional deployment"""
        if model_info['accuracy'] > 0.90:
            return f"Deployed model from {model_info['model_path']}"
        return "Model quality below threshold, skipping deployment"

    # Automatic dependency wiring
    data = extract_data()
    processed = preprocess(data)
    model = train_model(
        train_path=processed['train_path'],
        feature_count=processed['feature_count'],
    )
    evaluate_and_deploy(model)

ml_training_taskflow()
Context Access
Tasks can receive Airflow context variables as keyword arguments.
@task
def log_execution_info(task_instance=None, dag_run=None):
    print(f"Run ID: {task_instance.run_id}")
    print(f"DAG Run: {dag_run.dag_id}")
Object Serialization
TaskFlow supports custom object passing via @dataclass, @attr.define, or custom serialize()/deserialize() methods. Version management is also possible using __version__: ClassVar[int].
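A minimal sketch of such a payload class (the fields below are illustrative, not part of any Airflow API):

```python
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class ModelInfo:
    """Illustrative object passed between TaskFlow tasks. Dataclasses are
    supported by TaskFlow serialization; __version__ lets the schema
    evolve across DAG revisions."""
    __version__: ClassVar[int] = 1

    model_path: str
    accuracy: float
```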
11. Production DAG Code Example (Full ML Pipeline)
The following is a production ML training pipeline DAG code that integrates all the concepts covered so far. It includes the entire process from data arrival detection to hyperparameter tuning, optimal model selection, and deployment.
"""
ML Training Pipeline DAG
- S3 data arrival detection
- Data preprocessing (KubernetesPodOperator)
- Hyperparameter tuning with Dynamic Task Mapping
- Optimal model selection and MLflow registration
- Conditional deployment
"""
from __future__ import annotations
from datetime import datetime, timedelta
from airflow import DAG
from airflow.sdk import task
from airflow.operators.python import BranchPythonOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.utils.task_group import TaskGroup
from kubernetes.client import models as k8s
# -- Common Configuration --
GPU_RESOURCES = k8s.V1ResourceRequirements(
    requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
    limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
)
GPU_TOLERATIONS = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]
GPU_NODE_SELECTOR = {'gpu-type': 'a100', 'node-pool': 'ml-training'}
DATA_VOLUME = k8s.V1Volume(
    name='shared-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-shared-pvc'
    ),
)
DATA_VOLUME_MOUNT = k8s.V1VolumeMount(
    name='shared-data',
    mount_path='/shared',
)
default_args = {
    'owner': 'ml-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
    'execution_timeout': timedelta(hours=4),
}
with DAG(
    dag_id='ml_full_training_pipeline',
    default_args=default_args,
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
    tags=['ml', 'training', 'production'],
    max_active_runs=1,
    doc_md="""
    ## ML Full Training Pipeline
    A pipeline that retrains the model daily with new data
    and automatically deploys it if quality criteria are met.
    """,
) as dag:
    # ===== Stage 1: Data Arrival Detection =====
    wait_for_data = S3KeySensor(
        task_id='wait_for_training_data',
        bucket_name='ml-datasets',
        bucket_key='daily/{{ ds }}/features.parquet',
        aws_conn_id='aws_default',
        poke_interval=300,
        timeout=3600 * 6,
        mode='reschedule',
        deferrable=True,
    )

    # ===== Stage 2: Data Preprocessing =====
    with TaskGroup('data_preparation') as data_prep:
        validate_data = KubernetesPodOperator(
            task_id='validate_data',
            name='data-validation',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'validate.py'],
            arguments=['--date', '{{ ds }}', '--source', 's3://ml-datasets/daily/{{ ds }}/'],
            env_vars=[
                k8s.V1EnvVar(name='DATA_DATE', value='{{ ds }}'),
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '2', 'memory': '8Gi'},
                limits={'cpu': '4', 'memory': '16Gi'},
            ),
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )
        preprocess = KubernetesPodOperator(
            task_id='preprocess_data',
            name='data-preprocessing',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'preprocess.py'],
            arguments=[
                '--input', 's3://ml-datasets/daily/{{ ds }}/',
                '--output', 's3://ml-processed/{{ ds }}/',
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '4', 'memory': '32Gi'},
                limits={'cpu': '8', 'memory': '64Gi'},
            ),
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )
        validate_data >> preprocess

    # ===== Stage 3: Hyperparameter Generation and Parallel Training =====
    @task
    def generate_hyperparams(ds=None):
        """Generate hyperparameter combinations for training.

        Note: mapped (expanded) fields are not template-rendered, so the
        logical date comes from the injected `ds` context variable rather
        than a '{{ ds }}' template string inside the returned lists.
        """
        import itertools
        learning_rates = [0.001, 0.0001]
        batch_sizes = [32, 64]
        optimizers = ['adam', 'adamw']
        combos = list(itertools.product(learning_rates, batch_sizes, optimizers))
        return [
            [
                'train.py',
                '--lr', str(lr),
                '--batch-size', str(bs),
                '--optimizer', opt,
                '--data-path', f's3://ml-processed/{ds}/',
                '--experiment-name', f'daily-training-{ds}',
            ]
            for lr, bs, opt in combos
        ]

    hp_args = generate_hyperparams()

    with TaskGroup('hyperparameter_tuning') as hp_tuning:
        train_tasks = KubernetesPodOperator.partial(
            task_id='train_with_hp',
            name='gpu-hp-training',
            namespace='ml-training',
            image='ml-trainer:latest',
            cmds=['python'],
            container_resources=GPU_RESOURCES,
            tolerations=GPU_TOLERATIONS,
            node_selector=GPU_NODE_SELECTOR,
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            env_vars=[
                k8s.V1EnvVar(name='MLFLOW_TRACKING_URI', value='http://mlflow:5000'),
                k8s.V1EnvVar(name='CUDA_VISIBLE_DEVICES', value='0'),
            ],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
            startup_timeout_seconds=600,
        ).expand(arguments=hp_args)

    # ===== Stage 4: Optimal Model Selection =====
    @task
    def select_best_model(training_results):
        """Select the optimal model from all training results"""
        valid_results = [r for r in training_results if r is not None]
        if not valid_results:
            raise ValueError("No valid training results found")
        best = max(valid_results, key=lambda x: x.get('accuracy', 0))
        print(f"Best model: accuracy={best['accuracy']}, run_id={best['run_id']}")
        return best

    # Pass the mapped operator's XCom output, not the operator itself
    best_model = select_best_model(train_tasks.output)

    # ===== Stage 5: Model Evaluation =====
    @task(multiple_outputs=True)
    def evaluate_model(model_info: dict):
        """Detailed evaluation of the optimal model"""
        from mlflow.tracking import MlflowClient

        client = MlflowClient("http://mlflow:5000")
        # Compare with current Production model
        try:
            prod_versions = client.get_latest_versions(
                "production-model", stages=["Production"]
            )
            if prod_versions:
                prod_run = client.get_run(prod_versions[0].run_id)
                baseline_accuracy = float(prod_run.data.metrics.get('accuracy', 0))
            else:
                baseline_accuracy = 0.0
        except Exception:
            baseline_accuracy = 0.0
        new_accuracy = model_info['accuracy']
        improvement = new_accuracy - baseline_accuracy
        return {
            'should_deploy': improvement > 0.005 and new_accuracy > 0.90,
            'new_accuracy': new_accuracy,
            'baseline_accuracy': baseline_accuracy,
            'improvement': improvement,
            'run_id': model_info['run_id'],
            'model_path': model_info.get('model_path', ''),
        }

    eval_result = evaluate_model(best_model)

    # ===== Stage 6: Conditional Deployment =====
    def decide_deployment(**context):
        should_deploy = context['ti'].xcom_pull(
            task_ids='evaluate_model', key='should_deploy'
        )
        if should_deploy:
            return 'deploy_model'
        return 'skip_deployment'

    deployment_branch = BranchPythonOperator(
        task_id='deployment_decision',
        python_callable=decide_deployment,
    )
    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        name='model-deployment',
        namespace='ml-serving',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
        arguments=[
            '--run-id', "{{ ti.xcom_pull(task_ids='evaluate_model', key='run_id') }}",
            '--model-name', 'production-model',
        ],
        container_resources=k8s.V1ResourceRequirements(
            requests={'cpu': '2', 'memory': '4Gi'},
            limits={'cpu': '4', 'memory': '8Gi'},
        ),
        get_logs=True,
        is_delete_operator_pod=True,
    )

    @task(task_id='skip_deployment')
    def skip_deployment():
        print("Model quality does not meet deployment criteria. Skipping.")

    skip = skip_deployment()

    # ===== Stage 7: Notification =====
    @task(trigger_rule='none_failed_min_one_success')
    def send_notification(**context):
        """Send training result notification"""
        eval_data = context['ti'].xcom_pull(task_ids='evaluate_model', key='return_value')
        print(f"Pipeline completed. Accuracy: {eval_data.get('new_accuracy')}")
        print(f"Improvement: {eval_data.get('improvement')}")
        # Slack / Email notification logic

    notification = send_notification()

    # ===== DAG Dependency Definition =====
    wait_for_data >> data_prep >> hp_args >> hp_tuning
    hp_tuning >> best_model >> eval_result >> deployment_branch
    deployment_branch >> [deploy, skip] >> notification
This DAG integrates the following Airflow features:
- S3KeySensor: Data arrival waiting (deferrable mode)
- KubernetesPodOperator: GPU-based training Job execution
- Dynamic Task Mapping: Parallel hyperparameter training via expand()
- TaskFlow API: Python function-based Task definition via the @task decorator
- XCom: Training metric and model path passing
- TaskGroup: Visual stage separation
- BranchPythonOperator: Conditional deployment based on model quality
In actual production, error handling, SLA settings, and notification integration should be added to improve reliability.
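As a sketch of that hardening, failure callbacks and per-task timeouts can be centralized in default_args; the notification helper below is a hypothetical placeholder for real Slack or e-mail integration:

```python
from datetime import timedelta


def notify_on_failure(context):
    """Hypothetical failure callback: in production this would post to
    Slack or e-mail using the task instance details in `context`."""
    ti = context.get('task_instance')
    print(f"Task failed: {ti}")


# Applied DAG-wide via DAG(default_args=hardened_default_args, ...)
hardened_default_args = {
    'owner': 'ml-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
    'execution_timeout': timedelta(hours=4),
    'on_failure_callback': notify_on_failure,
}
```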
12. References
The content analyzed in this article is based on the following Apache Airflow official documentation and related resources.
Apache Airflow Official Documentation
- Airflow Core Concepts - Overview
- Airflow Core Concepts - Operators
- Airflow Core Concepts - XComs
- Airflow Core Concepts - TaskFlow
- Airflow Tutorial - TaskFlow API
- Dynamic Task Mapping
- KubernetesPodOperator
- KubernetesPodOperator API Reference
- S3KeySensor (Amazon Provider)
- Object Storage XCom Backend
- Airflow MLOps Use Cases