Orchestrating ML Training Pipelines with Airflow

1. The Role of Airflow in ML Pipelines

In ML projects, model training does not end with a call to model.fit(). A series of steps, from data collection, preprocessing, and feature engineering through model training, evaluation, model registry registration, and deployment, must be executed in a reproducible, automated manner. This is where a workflow orchestration tool comes in, and Apache Airflow is one of the most mature open-source projects in this domain.

Airflow vs Other Orchestration Tools

Besides Airflow, other ML pipeline orchestration tools include Kubeflow Pipelines, Prefect, and Dagster. Let us compare their characteristics.

Apache Airflow is a general-purpose workflow orchestrator. While it is not specialized for ML, it is widely used in ML pipelines thanks to its rich Operator ecosystem, powerful scheduling capabilities, and scalability on Kubernetes. With Airflow 3.x, features such as DAG versioning, event-driven scheduling, and multi-language support have been added, making it even more powerful.

Kubeflow Pipelines is a Kubernetes-native ML pipeline tool specialized for ML workflows. It provides integrated GPU scheduling, experiment tracking, and model serving, but requires deep understanding of Kubernetes and is difficult to use outside Kubernetes environments.

Prefect is often described as "Airflow, but nicer," with its Python-native interface and quick setup being key advantages. It offers dynamic flow, hybrid execution, and real-time SLA alerting, making it particularly suitable for small teams or PoC stages where rapid adoption is needed.

Dagster is a modern orchestrator that takes an asset-centric approach. With a philosophy of designing data pipelines centered around "the data being produced" rather than "execution steps," it has strengths in data lineage tracking.

| Characteristic | Airflow | Kubeflow Pipelines | Prefect | Dagster |
|---|---|---|---|---|
| ML specialization | General-purpose | ML-specialized | General-purpose | Data-centric |
| Learning curve | Medium | High | Low | Medium |
| Kubernetes dependency | Optional | Required | Optional | Optional |
| Community/ecosystem | Very large | Large | Growing | Growing |
| GPU scheduling | KubernetesPodOperator | Native | Kubernetes integration | Kubernetes integration |
| Scheduling | Very powerful | Limited | Powerful | Powerful |

The biggest reason for choosing Airflow as an ML pipeline orchestrator is its tool-agnostic nature. Airflow can integrate with various ML tools such as MLflow, SageMaker, and Vertex AI, and can manage ML and non-ML data pipelines (ETL, data quality checks, etc.) on the same infrastructure.

2. Reviewing Airflow Core Concepts: DAG, Operator, Task, XCom

According to the official Apache Airflow documentation, the core concepts are as follows.

DAG (Directed Acyclic Graph)

A DAG is a collection of all Tasks you want to run, organized in a way that reflects their relationships and dependencies. DAGs are defined as Python scripts, and the code itself expresses the structure of the DAG (Tasks and their dependencies). A DAG run is an instantiation of a DAG, containing Task Instances that execute for a specific execution_date.

from airflow import DAG
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
    tags=['ml', 'training'],
) as dag:
    # Task definitions
    pass

Operator

An Operator is a pre-defined template for a Task. According to the official documentation, "If a DAG describes how to run a workflow, Operators determine what actually gets done by a Task." All Operators inherit from BaseOperator, with PythonOperator, BashOperator, and KubernetesPodOperator being representative examples.

Task

A Task is an instance of an Operator. It is an executable unit within a DAG, and a Task Instance represents a specific execution of a Task in a particular DAG run.

XCom (Cross-Communication)

XCom is a mechanism for passing data between Tasks. According to the official documentation, "by default, Tasks are entirely isolated and may run on completely different machines," which is why XCom is needed. XCom is identified by key, task_id, and dag_id.

# Push
task_instance.xcom_push(key="model_accuracy", value=0.95)

# Pull
accuracy = task_instance.xcom_pull(key="model_accuracy", task_ids="evaluate_model")

An important point is that XCom is designed for transferring small amounts of data. You should not pass large data such as DataFrames through XCom; in such cases, you should use an Object Storage Backend.

3. Running GPU Training Jobs with KubernetesPodOperator

In ML training pipelines, KubernetesPodOperator plays a central role. According to the official documentation, KubernetesPodOperator is "a stand-in replacement for Kubernetes object spec definitions" that can be executed by the Airflow scheduler within the DAG context. In other words, you can define and run Kubernetes Pods directly in Python code without writing separate YAML/JSON files for Pod specs.

The reasons KubernetesPodOperator is particularly useful for GPU-based ML training are as follows:

  • Task-level resource configuration: CPU, Memory, and GPU can be independently allocated for each Task
  • Custom dependency support: When packages unavailable on PyPI or a specific CUDA version are needed, a custom Docker image can be specified
  • Language agnosticism: Training code written in any language, not just Python, can be wrapped in a container and executed

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

train_model = KubernetesPodOperator(
    task_id='train_model',
    name='gpu-training-job',
    namespace='ml-training',
    image='my-registry/ml-trainer:latest',
    cmds=['python'],
    arguments=['train.py', '--epochs', '100', '--batch-size', '64'],
    container_resources=k8s.V1ResourceRequirements(
        requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
        limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
    ),
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
)

By setting do_xcom_push=True, you can return data as XCom by writing JSON to the /airflow/xcom/return.json file inside the Pod. This is useful for passing training result metrics to the next Task.

4. Analysis of Key KubernetesPodOperator Parameters

This section analyzes the key parameters of KubernetesPodOperator from an ML training perspective, based on the official documentation.

container_resources

This is the most important parameter for GPU training. Using kubernetes.client.models.V1ResourceRequirements, you can request and limit CPU, Memory, and GPU resources.

container_resources = k8s.V1ResourceRequirements(
    requests={
        'cpu': '4',
        'memory': '16Gi',
        'nvidia.com/gpu': '1'
    },
    limits={
        'cpu': '8',
        'memory': '32Gi',
        'nvidia.com/gpu': '1'
    },
)

For GPUs, the nvidia.com/gpu resource key is used, and requests and limits are typically set to the same value. This is because GPUs cannot be fractionally allocated (unless MIG is used).

tolerations

GPU nodes are usually configured with taints to prevent general workloads from being scheduled. tolerations allow Pods to be scheduled on GPU nodes.

tolerations = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]

node_selector

You can specify that training Pods should only run on specific node groups. For example, to run only on nodes equipped with A100 GPUs:

node_selector = {
    'gpu-type': 'a100',
    'node-pool': 'ml-training',
}

affinity

When more fine-grained scheduling control is needed, use V1Affinity. You can set Node affinity and Pod affinity/anti-affinity to place Pods in specific zones or nodes.

affinity = k8s.V1Affinity(
    node_affinity=k8s.V1NodeAffinity(
        required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
            node_selector_terms=[
                k8s.V1NodeSelectorTerm(
                    match_expressions=[
                        k8s.V1NodeSelectorRequirement(
                            key='gpu-type',
                            operator='In',
                            values=['a100', 'h100'],
                        )
                    ]
                )
            ]
        )
    )
)

volumes / volume_mounts

Used for mounting training datasets or checkpoints via PVC (PersistentVolumeClaim).

volume = k8s.V1Volume(
    name='training-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-dataset-pvc'
    ),
)

volume_mount = k8s.V1VolumeMount(
    name='training-data',
    mount_path='/data',
    read_only=True,
)

Other Key Parameters

| Parameter | Description | ML Training Usage |
|---|---|---|
| image_pull_secrets | Private registry auth | Accessing internal ML images |
| env_vars | Environment variable list | CUDA_VISIBLE_DEVICES, WANDB_API_KEY, etc. |
| secrets | Secret volumes/env vars | API key, credential management |
| is_delete_operator_pod | Delete Pod after completion | Resource cleanup (True recommended) |
| get_logs | Display logs in Airflow UI | Training log monitoring |
| deferrable | Async execution mode | Worker efficiency for long training |
| on_finish_action | Post-completion action | Auto cleanup with delete_pod |

The official documentation recommends using native objects from kubernetes.client.models instead of convenience classes for type safety.

5. DAG Design Pattern: Data Preprocessing to Training to Evaluation to Deployment

The general DAG structure for an ML training pipeline includes the following stages:

Data Validation -> Preprocessing -> Feature Engineering -> Training -> Evaluation -> [Conditional] Deployment

Mapping each stage to Airflow Tasks results in the following design pattern:

from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
) as dag:

    validate_data = KubernetesPodOperator(
        task_id='validate_data',
        image='ml-tools:latest',
        cmds=['python', 'validate.py'],
        # Only CPU needed
    )

    preprocess = KubernetesPodOperator(
        task_id='preprocess_data',
        image='ml-tools:latest',
        cmds=['python', 'preprocess.py'],
        # Large memory needed
    )

    train = KubernetesPodOperator(
        task_id='train_model',
        image='ml-trainer:latest',
        cmds=['python', 'train.py'],
        # GPU needed
        do_xcom_push=True,
    )

    evaluate = KubernetesPodOperator(
        task_id='evaluate_model',
        image='ml-trainer:latest',
        cmds=['python', 'evaluate.py'],
        do_xcom_push=True,
    )

    def check_model_quality(**context):
        metrics = context['ti'].xcom_pull(task_ids='evaluate_model')
        if metrics['accuracy'] > 0.90:
            return 'deploy_model'
        return 'notify_failure'

    branch = BranchPythonOperator(
        task_id='check_quality',
        python_callable=check_model_quality,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
    )

    # The branch must be able to reach both of its possible targets
    notify_failure = KubernetesPodOperator(
        task_id='notify_failure',
        image='ml-tools:latest',
        cmds=['python', 'notify.py'],
    )

    validate_data >> preprocess >> train >> evaluate >> branch
    branch >> [deploy, notify_failure]

The key to this design pattern is conditional deployment. The evaluation stage passes model performance metrics via XCom, and through BranchPythonOperator, the deployment Task is executed only when quality criteria are met. This prevents low-quality models from being deployed to production.
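Because the gate decides what ships to production, it is worth unit-testing it outside Airflow. A minimal sketch, factoring the branching rule from check_model_quality into a plain function (the function name and threshold here are illustrative, not part of the Airflow API):

```python
def choose_next_task(metrics: dict, threshold: float = 0.90) -> str:
    """Return the task_id a BranchPythonOperator callable should follow,
    based on the evaluation metrics pulled from XCom."""
    if metrics.get('accuracy', 0.0) > threshold:
        return 'deploy_model'
    return 'notify_failure'

# The Airflow callable then reduces to pulling XCom and delegating:
#   return choose_next_task(context['ti'].xcom_pull(task_ids='evaluate_model'))
print(choose_next_task({'accuracy': 0.95}))  # deploy_model
print(choose_next_task({'accuracy': 0.85}))  # notify_failure
```

Keeping the rule as a pure function means the deployment criterion can be covered by ordinary unit tests without standing up a scheduler.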

Stage Separation Using TaskGroup

In complex pipelines, TaskGroup can be used to visually separate stages. According to the official documentation, "TaskGroup is used to organize Tasks into hierarchical groups in the Graph View. It is useful for creating repeating patterns and reducing visual complexity." Note that a Task belongs to a group only if it is instantiated inside the TaskGroup context; wiring previously created Tasks inside the with block merely adds dependencies.

from airflow.utils.task_group import TaskGroup

with TaskGroup('data_preparation') as data_prep:
    validate_data >> preprocess >> feature_engineering

with TaskGroup('model_training') as training:
    train >> evaluate

data_prep >> training >> branch

6. Passing Metrics/Artifacts Between Tasks with XCom

According to the official documentation, XCom is explicitly "pushed" to and "pulled" from storage via the xcom_push and xcom_pull methods. In ML pipelines, XCom is used for the following purposes:

Passing Training Metrics

# train.py (executed inside KubernetesPodOperator)
import json

metrics = {
    'accuracy': 0.9534,
    'f1_score': 0.9421,
    'loss': 0.0312,
    'model_path': 's3://ml-models/experiment-42/model.pt',
    'run_id': 'exp-42-20260301',
}

# Write JSON to /airflow/xcom/return.json
with open('/airflow/xcom/return.json', 'w') as f:
    json.dump(metrics, f)

Auto-Push and Return Value

Many Operators and the @task decorator automatically push the return value as an XCom with the key return_value when do_xcom_push=True (default). You can then pull it as follows:

value = task_instance.xcom_pull(task_ids='train_model')

Multiple Outputs

When returning a dictionary with multiple_outputs=True, each key is stored as an individual XCom.

@task(multiple_outputs=True)
def evaluate_model(**context):
    return {
        'accuracy': 0.95,
        'f1_score': 0.94,
        'model_path': 's3://models/latest.pt',
    }

You can then pull by individual key or pull the entire return_value.

Custom XCom Backend

The default XCom Backend, BaseXCom, stores XCom in the Airflow database. While this is fine for small amounts of data, a Custom Backend is needed when dealing with model artifacts or large data. You can override the serialize_value and deserialize_value methods by inheriting from BaseXCom.
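The core idea behind such a backend can be sketched without Airflow at all: serialize_value stores large payloads externally and returns only a reference, which deserialize_value resolves on the way back. Everything below is illustrative, assuming a size threshold and an in-memory stand-in for object storage (neither is an Airflow setting):

```python
import json

XCOM_MAX_INLINE_BYTES = 4096  # illustrative threshold, not an Airflow config key
_object_store = {}            # stand-in for S3/GCS

def serialize_value(value, run_id: str, task_id: str, key: str) -> str:
    """Mimics a custom BaseXCom.serialize_value: small values are stored
    inline in the metadata DB, large ones are uploaded and replaced by a
    reference URI."""
    payload = json.dumps(value)
    if len(payload.encode()) <= XCOM_MAX_INLINE_BYTES:
        return payload
    uri = f"s3://xcom-bucket/{run_id}/{task_id}/{key}.json"
    _object_store[uri] = payload          # would be an S3/GCS upload in practice
    return json.dumps({"__xcom_ref__": uri})

def deserialize_value(stored: str):
    """Mimics the matching deserialize_value: resolve references, pass
    inline values through."""
    obj = json.loads(stored)
    if isinstance(obj, dict) and "__xcom_ref__" in obj:
        return json.loads(_object_store[obj["__xcom_ref__"]])
    return obj

small = serialize_value({'accuracy': 0.95}, 'run-1', 'train', 'metrics')
print(deserialize_value(small))  # {'accuracy': 0.95}
```

A real implementation would override the corresponding methods on a BaseXCom subclass and upload to actual object storage, but the inline-vs-reference decision is the essence of the pattern.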

To use the Object Storage Backend, set the xcom_backend configuration to airflow.providers.common.io.xcom.backend.XComObjectStorageBackend. This allows storing XCom data in S3 or GCS.

[core]
xcom_backend = airflow.providers.common.io.xcom.backend.XComObjectStorageBackend

XCom Usage Precautions

The key points emphasized in the official documentation are:

  • XCom is designed for transferring small amounts of data. Do not pass large data such as DataFrames
  • When a Task is retried after failure, previous XCom is automatically cleared to ensure idempotent execution
  • XCom operations must be performed within the Task Context via get_current_context(), and direct DB updates are not supported

As a practical guideline for ML pipelines, metrics (accuracy, loss, etc.) and path information (model_path, artifact_uri, etc.) should be passed via XCom, while actual model files and datasets should be stored in Object Storage like S3/GCS, with only the paths shared via XCom.
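This guideline can be captured in a small helper: the training step persists the heavy artifact and returns only an XCom-sized summary. A sketch under illustrative assumptions (the helper name and local-directory storage stand in for a real S3/GCS upload):

```python
import os

def persist_and_summarize(model_bytes: bytes, metrics: dict, artifact_dir: str) -> dict:
    """Store the heavy artifact on (object) storage and return an
    XCom-sized summary: metrics plus the artifact's path, never the bytes."""
    os.makedirs(artifact_dir, exist_ok=True)
    model_path = os.path.join(artifact_dir, 'model.pt')
    with open(model_path, 'wb') as f:
        f.write(model_bytes)               # would be an S3/GCS upload in practice
    return {**metrics, 'model_path': model_path}
```

The returned dict is what the training Task pushes via XCom; downstream Tasks load the artifact themselves from model_path.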

7. Hyperparameter Tuning with Dynamic Task Mapping

Dynamic Task Mapping was introduced in Airflow 2.3 and, according to the official documentation, is "a mechanism that allows creating multiple Tasks at runtime based on current data." The DAG author does not need to know in advance how many Tasks are needed.

This feature is very powerful for ML hyperparameter tuning because it enables parallel training with various hyperparameter combinations.

expand() and partial()

expand() specifies the parameters to map, while partial() specifies fixed parameters common to all Tasks.

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

hyperparams = [
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'adam'},
    {'lr': '0.01', 'batch_size': '64', 'optimizer': 'adam'},
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'sgd'},
    {'lr': '0.0001', 'batch_size': '128', 'optimizer': 'adamw'},
]

train_tasks = KubernetesPodOperator.partial(
    task_id='hyperparam_training',
    image='ml-trainer:latest',
    namespace='ml-training',
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
    container_resources=k8s.V1ResourceRequirements(
        requests={'nvidia.com/gpu': '1'},
        limits={'nvidia.com/gpu': '1'},
    ),
).expand(
    arguments=[
        ['train.py', '--lr', hp['lr'], '--batch-size', hp['batch_size'], '--optimizer', hp['optimizer']]
        for hp in hyperparams
    ],
)

This code automatically creates 4 parallel KubernetesPodOperator Tasks for 4 hyperparameter combinations.

Dynamically Generating Mapping Data from Tasks

Furthermore, you can dynamically generate hyperparameter combinations from upstream Tasks.

@task
def generate_hyperparams():
    """Generate hyperparameters via Grid Search or Random Search"""
    import itertools

    learning_rates = [0.001, 0.01, 0.0001]
    batch_sizes = [32, 64, 128]
    optimizers = ['adam', 'sgd']

    combinations = list(itertools.product(learning_rates, batch_sizes, optimizers))
    return [
        {'lr': str(lr), 'batch_size': str(bs), 'optimizer': opt}
        for lr, bs, opt in combinations
    ]

hp_list = generate_hyperparams()

According to the official documentation, "mappings generated from a Task prohibit the use of trigger_rule=TriggerRule.ALWAYS."

Cross-Product Mapping

Specifying multiple expand() parameters generates all combinations (cross product).

@task
def train(lr: float, batch_size: int):
    # Training logic
    pass

train.expand(lr=[0.001, 0.01], batch_size=[32, 64])
# 2 x 2 = 4 Task Instances created

Map-Reduce Pattern

A pattern for collecting results from mapped Tasks and selecting the optimal model.

@task
def select_best_model(results):
    """Select the best among all hyperparameter combination results"""
    best = max(results, key=lambda x: x['accuracy'])
    return best

# `params` is a reserved BaseOperator argument, so map over a different name
training_results = train_task.expand(hp=hyperparams)
best = select_best_model(training_results)

According to the official documentation, the collected results are returned as a "lazy proxy" sequence, not an eager list.

Constraints

| Item | Setting | Default |
|---|---|---|
| Max mapped instances | [core] max_map_length | 1024 |
| Parallel execution per Task | max_active_tis_per_dag | Per-task setting |

Only list and dict types can be mapped; other types will raise an UnmappableXComTypePushed error. Additionally, according to the official documentation, "if a field is marked as templated and is mapped, it will not be template-rendered."
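Since a cross-product grid grows multiplicatively, it is worth guarding against the mapped-instance ceiling before calling expand(). A minimal sketch; the limit constant mirrors the [core] max_map_length default, and the helper name is illustrative:

```python
import itertools

MAX_MAP_LENGTH = 1024  # mirrors the [core] max_map_length default

def build_grid(max_len: int = MAX_MAP_LENGTH, **param_lists) -> list:
    """Cross-product of hyperparameter lists, refusing grids that would
    exceed the mapped-task-instance limit."""
    names = list(param_lists)
    combos = [dict(zip(names, values))
              for values in itertools.product(*param_lists.values())]
    if len(combos) > max_len:
        raise ValueError(f"{len(combos)} mapped tasks exceeds limit {max_len}")
    return combos

grid = build_grid(lr=[0.001, 0.01], batch_size=[32, 64, 128])
print(len(grid))  # 6
```

Failing fast in the task that generates the mapping data gives a clearer error than letting the scheduler reject the oversized DAG run.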

8. Detecting Data Arrival with Sensors

In ML pipelines, it is often necessary to wait until training data is ready. Airflow's Sensors are special Operators that wait until a specific condition is met.

S3KeySensor

Waits until training data is uploaded to S3.

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_data = S3KeySensor(
    task_id='wait_for_training_data',
    bucket_name='ml-datasets',
    bucket_key='daily/{{ ds }}/training_data.parquet',
    aws_conn_id='aws_default',
    poke_interval=300,  # Check every 5 minutes
    timeout=3600 * 6,   # Wait up to 6 hours
    mode='reschedule',  # Return Worker slot
    deferrable=True,    # Async wait using Triggerer
)

mode='reschedule' frees the Worker slot between pokes, improving resource efficiency. Alternatively, setting deferrable=True hands the waiting off to the Triggerer component, which polls asynchronously; for a deferrable sensor the mode setting no longer applies, and Worker utilization becomes even more efficient.

ExternalTaskSensor

Waits until a Task from another DAG completes. For example, a pattern where the training DAG starts after the data preprocessing DAG completes.

from datetime import timedelta

from airflow.sensors.external_task import ExternalTaskSensor

wait_for_preprocessing = ExternalTaskSensor(
    task_id='wait_for_preprocessing',
    external_dag_id='data_preprocessing_pipeline',
    external_task_id='final_validation',
    allowed_states=['success'],
    execution_delta=timedelta(hours=0),
    timeout=3600,
    poke_interval=60,
    mode='reschedule',
)

allowed_states is the list of permitted terminal states, defaulting to ['success']. execution_delta specifies the difference between the current DAG run's logical date and the logical date of the external DAG run to wait for; timedelta(hours=0) assumes both DAGs run on the same schedule.
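How execution_delta resolves can be checked with plain datetime arithmetic: the sensor looks for the external DAG run whose logical date equals the current logical date minus the delta. A small illustration (the helper name and dates are made up for the example):

```python
from datetime import datetime, timedelta

def external_logical_date(logical_date: datetime, execution_delta: timedelta) -> datetime:
    """Logical date of the external DAG run that ExternalTaskSensor waits on."""
    return logical_date - execution_delta

# Training DAG runs at 06:00; the preprocessing DAG runs two hours earlier
run_at = datetime(2026, 3, 1, 6, 0)
print(external_logical_date(run_at, timedelta(hours=2)))  # 2026-03-01 04:00:00
print(external_logical_date(run_at, timedelta(hours=0)))  # 2026-03-01 06:00:00
```

In the ExternalTaskSensor above, timedelta(hours=0) therefore only works if both DAGs share the same schedule; schedules offset by two hours would need execution_delta=timedelta(hours=2).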

Practical Pattern: Starting Training After Data Arrival

wait_for_data >> validate_data >> preprocess >> train >> evaluate >> deploy

This pattern loosely couples the data pipeline and ML pipeline while ensuring that training only starts after data is ready.

9. MLflow and Airflow Integration Patterns

The role division between Airflow and MLflow integration is clear. Airflow manages "when and in what order to execute," while MLflow records "what happened during execution and where the model is." Through this separation of concerns, data engineers can manage Airflow DAGs without touching ML code, and data scientists can focus on model development without worrying about scheduling infrastructure.

Integration Pattern 1: Direct MLflow Calls Inside Training Tasks

This is the most intuitive pattern. MLflow APIs are called directly inside training containers executed by KubernetesPodOperator.

# train.py (executed inside container)
import mlflow
import mlflow.pytorch

mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("daily-training")

with mlflow.start_run() as run:
    # Hyperparameter logging
    mlflow.log_params({
        'learning_rate': 0.001,
        'batch_size': 64,
        'epochs': 100,
    })

    # Model training
    model = train_model(...)

    # Metric logging
    mlflow.log_metrics({
        'accuracy': 0.95,
        'f1_score': 0.94,
    })

    # Model registration
    mlflow.pytorch.log_model(model, "model")

    # Pass run_id via XCom
    import json
    with open('/airflow/xcom/return.json', 'w') as f:
        json.dump({'run_id': run.info.run_id}, f)

Integration Pattern 2: Utilizing MLflow Model Registry from Airflow

After training is complete, the model is registered in the MLflow Model Registry based on evaluation results and the stage is transitioned.

@task
def register_model(run_id: str, accuracy: float):
    import mlflow
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # Register model
    model_uri = f"runs:/{run_id}/model"
    mv = mlflow.register_model(model_uri, "production-model")

    # Transition to Production stage if accuracy exceeds threshold
    if accuracy > 0.93:
        client.transition_model_version_stage(
            name="production-model",
            version=mv.version,
            stage="Production",
        )

Integration Pattern 3: Querying and Comparing Training History from MLflow

@task
def compare_with_baseline():
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # Query Production model metrics
    prod_versions = client.get_latest_versions("production-model", stages=["Production"])
    if prod_versions:
        prod_run = client.get_run(prod_versions[0].run_id)
        baseline_accuracy = float(prod_run.data.metrics['accuracy'])
        return {'baseline_accuracy': baseline_accuracy}
    return {'baseline_accuracy': 0.0}

10. Utilizing the TaskFlow API (@task Decorator)

According to the official documentation, the TaskFlow API is "a functional API that defines DAGs and Tasks using decorators," greatly simplifying data transfer and dependency definition between Tasks.

Key Characteristics

The biggest advantage of TaskFlow is automatic XCom management and automatic dependency calculation. When a TaskFlow function is called, it is not executed immediately; instead, an XComArg object representing the result is returned. Using this as input to downstream Tasks causes dependencies to be calculated automatically.

from airflow.sdk import dag, task
from datetime import datetime

@dag(
    schedule='@daily',
    start_date=datetime(2026, 1, 1),
    catchup=False,
)
def ml_training_taskflow():

    @task
    def extract_data():
        """Data extraction"""
        return {'data_path': 's3://bucket/data/2026-03-01.parquet', 'num_rows': 100000}

    @task(multiple_outputs=True)
    def preprocess(data_info: dict):
        """Data preprocessing"""
        return {
            'train_path': f"{data_info['data_path']}/train.parquet",
            'test_path': f"{data_info['data_path']}/test.parquet",
            'feature_count': 128,
        }

    @task
    def train_model(train_path: str, feature_count: int):
        """Model training"""
        # In practice, GPU training would be run via KubernetesPodOperator
        return {
            'model_path': 's3://models/latest.pt',
            'accuracy': 0.95,
        }

    @task
    def evaluate_and_deploy(model_info: dict):
        """Model evaluation and conditional deployment"""
        if model_info['accuracy'] > 0.90:
            return f"Deployed model from {model_info['model_path']}"
        return "Model quality below threshold, skipping deployment"

    # Automatic dependency wiring
    data = extract_data()
    processed = preprocess(data)
    model = train_model(
        train_path=processed['train_path'],
        feature_count=processed['feature_count'],
    )
    evaluate_and_deploy(model)

ml_training_taskflow()

Context Access

Tasks can receive Airflow context variables as keyword arguments.

@task
def log_execution_info(task_instance=None, dag_run=None):
    print(f"Run ID: {task_instance.run_id}")
    print(f"DAG Run: {dag_run.dag_id}")

Object Serialization

TaskFlow supports custom object passing via @dataclass, @attr.define, or custom serialize()/deserialize() methods. Version management is also possible using __version__: ClassVar[int].
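A minimal sketch of such a custom-serializable object; the serialize()/deserialize() method names and the __version__ class variable follow the convention described above, while the ModelInfo class itself is purely illustrative:

```python
from dataclasses import dataclass
from typing import ClassVar

@dataclass
class ModelInfo:
    """Illustrative object passed between TaskFlow tasks via XCom."""
    __version__: ClassVar[int] = 1

    model_path: str
    accuracy: float

    def serialize(self) -> dict:
        # Reduce the object to a plain, JSON-friendly dict
        return {'model_path': self.model_path, 'accuracy': self.accuracy}

    @staticmethod
    def deserialize(data: dict, version: int) -> 'ModelInfo':
        # Reject payloads written by a newer schema than this class knows
        if version > ModelInfo.__version__:
            raise TypeError(f"unsupported serialization version {version}")
        return ModelInfo(**data)

info = ModelInfo('s3://models/latest.pt', 0.95)
restored = ModelInfo.deserialize(info.serialize(), ModelInfo.__version__)
print(restored == info)  # True
```

Bumping __version__ whenever the dict layout changes lets deserialize() evolve the payload (or fail loudly) instead of silently misreading old XComs.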

11. Production DAG Code Example (Full ML Pipeline)

The following is a production ML training pipeline DAG code that integrates all the concepts covered so far. It includes the entire process from data arrival detection to hyperparameter tuning, optimal model selection, and deployment.

"""
ML Training Pipeline DAG
- S3 data arrival detection
- Data preprocessing (KubernetesPodOperator)
- Hyperparameter tuning with Dynamic Task Mapping
- Optimal model selection and MLflow registration
- Conditional deployment
"""

from __future__ import annotations

from datetime import datetime, timedelta
from airflow import DAG
from airflow.sdk import task
from airflow.operators.python import BranchPythonOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.utils.task_group import TaskGroup
from kubernetes.client import models as k8s

# -- Common Configuration --
GPU_RESOURCES = k8s.V1ResourceRequirements(
    requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
    limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
)

GPU_TOLERATIONS = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]

GPU_NODE_SELECTOR = {'gpu-type': 'a100', 'node-pool': 'ml-training'}

DATA_VOLUME = k8s.V1Volume(
    name='shared-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-shared-pvc'
    ),
)

DATA_VOLUME_MOUNT = k8s.V1VolumeMount(
    name='shared-data',
    mount_path='/shared',
)

default_args = {
    'owner': 'ml-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
    'execution_timeout': timedelta(hours=4),
}


with DAG(
    dag_id='ml_full_training_pipeline',
    default_args=default_args,
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
    tags=['ml', 'training', 'production'],
    max_active_runs=1,
    doc_md="""
    ## ML Full Training Pipeline
    A pipeline that retrains the model daily with new data
    and automatically deploys it if quality criteria are met.
    """,
) as dag:

    # ===== Stage 1: Data Arrival Detection =====
    wait_for_data = S3KeySensor(
        task_id='wait_for_training_data',
        bucket_name='ml-datasets',
        bucket_key='daily/{{ ds }}/features.parquet',
        aws_conn_id='aws_default',
        poke_interval=300,
        timeout=3600 * 6,
        mode='reschedule',
        deferrable=True,
    )

    # ===== Stage 2: Data Preprocessing =====
    with TaskGroup('data_preparation') as data_prep:

        validate_data = KubernetesPodOperator(
            task_id='validate_data',
            name='data-validation',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'validate.py'],
            arguments=['--date', '{{ ds }}', '--source', 's3://ml-datasets/daily/{{ ds }}/'],
            env_vars=[
                k8s.V1EnvVar(name='DATA_DATE', value='{{ ds }}'),
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '2', 'memory': '8Gi'},
                limits={'cpu': '4', 'memory': '16Gi'},
            ),
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )

        preprocess = KubernetesPodOperator(
            task_id='preprocess_data',
            name='data-preprocessing',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'preprocess.py'],
            arguments=[
                '--input', 's3://ml-datasets/daily/{{ ds }}/',
                '--output', 's3://ml-processed/{{ ds }}/',
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '4', 'memory': '32Gi'},
                limits={'cpu': '8', 'memory': '64Gi'},
            ),
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )

        validate_data >> preprocess

    # ===== Stage 3: Hyperparameter Generation and Parallel Training =====
    @task
    def generate_hyperparams(ds=None):
        """Generate hyperparameter combinations for training"""
        # Jinja templates are not rendered in values returned from a task
        # (see the mapped-templating caveat above), so the logical date is
        # taken from the task context instead of '{{ ds }}'
        import itertools
        learning_rates = [0.001, 0.0001]
        batch_sizes = [32, 64]
        optimizers = ['adam', 'adamw']

        combos = list(itertools.product(learning_rates, batch_sizes, optimizers))
        return [
            [
                'train.py',
                '--lr', str(lr),
                '--batch-size', str(bs),
                '--optimizer', opt,
                '--data-path', f's3://ml-processed/{ds}/',
                '--experiment-name', f'daily-training-{ds}',
            ]
            for lr, bs, opt in combos
        ]

    hp_args = generate_hyperparams()

    with TaskGroup('hyperparameter_tuning') as hp_tuning:

        train_tasks = KubernetesPodOperator.partial(
            task_id='train_with_hp',
            name='gpu-hp-training',
            namespace='ml-training',
            image='ml-trainer:latest',
            cmds=['python'],
            container_resources=GPU_RESOURCES,
            tolerations=GPU_TOLERATIONS,
            node_selector=GPU_NODE_SELECTOR,
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            env_vars=[
                k8s.V1EnvVar(name='MLFLOW_TRACKING_URI', value='http://mlflow:5000'),
                k8s.V1EnvVar(name='CUDA_VISIBLE_DEVICES', value='0'),
            ],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
            startup_timeout_seconds=600,
        ).expand(arguments=hp_args)

    # ===== Stage 4: Optimal Model Selection =====
    @task
    def select_best_model(training_results):
        """Select the optimal model from all training results"""
        valid_results = [r for r in training_results if r is not None]
        if not valid_results:
            raise ValueError("No valid training results found")

        best = max(valid_results, key=lambda x: x.get('accuracy', 0))
        print(f"Best model: accuracy={best['accuracy']}, run_id={best['run_id']}")
        return best

    best_model = select_best_model(train_tasks.output)

    # ===== Stage 5: Model Evaluation =====
    @task(multiple_outputs=True)
    def evaluate_model(model_info: dict):
        """Detailed evaluation of the optimal model"""
        import mlflow
        from mlflow.tracking import MlflowClient

        client = MlflowClient("http://mlflow:5000")

        # Compare with current Production model
        try:
            prod_versions = client.get_latest_versions(
                "production-model", stages=["Production"]
            )
            if prod_versions:
                prod_run = client.get_run(prod_versions[0].run_id)
                baseline_accuracy = float(prod_run.data.metrics.get('accuracy', 0))
            else:
                baseline_accuracy = 0.0
        except Exception:
            baseline_accuracy = 0.0

        new_accuracy = model_info['accuracy']
        improvement = new_accuracy - baseline_accuracy

        return {
            'should_deploy': improvement > 0.005 and new_accuracy > 0.90,
            'new_accuracy': new_accuracy,
            'baseline_accuracy': baseline_accuracy,
            'improvement': improvement,
            'run_id': model_info['run_id'],
            'model_path': model_info.get('model_path', ''),
        }

    eval_result = evaluate_model(best_model)

    # ===== Stage 6: Conditional Deployment =====
    def decide_deployment(**context):
        should_deploy = context['ti'].xcom_pull(
            task_ids='evaluate_model', key='should_deploy'
        )
        if should_deploy:
            return 'deploy_model'
        return 'skip_deployment'

    deployment_branch = BranchPythonOperator(
        task_id='deployment_decision',
        python_callable=decide_deployment,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        name='model-deployment',
        namespace='ml-serving',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
        arguments=[
            '--run-id', "{{ ti.xcom_pull(task_ids='evaluate_model', key='run_id') }}",
            '--model-name', 'production-model',
        ],
        container_resources=k8s.V1ResourceRequirements(
            requests={'cpu': '2', 'memory': '4Gi'},
            limits={'cpu': '4', 'memory': '8Gi'},
        ),
        get_logs=True,
        is_delete_operator_pod=True,
    )

    @task(task_id='skip_deployment')
    def skip_deployment():
        print("Model quality does not meet deployment criteria. Skipping.")

    skip = skip_deployment()

    # ===== Stage 7: Notification =====
    @task(trigger_rule='none_failed_min_one_success')
    def send_notification(**context):
        """Send training result notification"""
        eval_data = context['ti'].xcom_pull(task_ids='evaluate_model', key='return_value')
        print(f"Pipeline completed. Accuracy: {eval_data.get('new_accuracy')}")
        print(f"Improvement: {eval_data.get('improvement')}")
        # Slack / Email notification logic

    notification = send_notification()

    # ===== DAG Dependency Definition =====
    wait_for_data >> data_prep >> hp_args >> hp_tuning
    hp_tuning >> best_model >> eval_result >> deployment_branch
    deployment_branch >> [deploy, skip] >> notification

This DAG integrates the following Airflow features:

  • S3KeySensor: Data arrival waiting (deferrable mode)
  • KubernetesPodOperator: GPU-based training Job execution
  • Dynamic Task Mapping: Parallel hyperparameter training via expand()
  • TaskFlow API: Python function-based Task definition via @task decorator
  • XCom: Training metric and model path passing
  • TaskGroup: Visual stage separation
  • BranchPythonOperator: Conditional deployment based on model quality

In actual production, error handling, SLA settings, and notification integration should be added to improve reliability.

12. References

The content analyzed in this article is based on the following Apache Airflow official documentation and related resources.

Apache Airflow Official Documentation