
Orchestrating ML Training Pipelines with Airflow


1. The Role of Airflow in ML Pipelines

In ML projects, training a model does not end with a call to `model.fit()`. A series of steps (data collection, preprocessing, feature engineering, model training, evaluation, model registry registration, and deployment) must run in a **reproducible and automated manner**. This is where a workflow orchestration tool is needed, and Apache Airflow is one of the most mature open-source projects in this domain.

Airflow vs Other Orchestration Tools

Besides Airflow, other ML pipeline orchestration tools include Kubeflow Pipelines, Prefect, and Dagster. Let us compare their characteristics.

**Apache Airflow** is a general-purpose workflow orchestrator. Although it is not ML-specific, it is widely used for ML pipelines thanks to its rich Operator ecosystem, powerful scheduling, and ability to scale on Kubernetes. Airflow 3.x adds DAG versioning, event-driven scheduling, and multi-language support.

**Kubeflow Pipelines** is a Kubernetes-native ML pipeline tool specialized for ML workflows. It provides integrated GPU scheduling, experiment tracking, and model serving, but requires deep understanding of Kubernetes and is difficult to use outside Kubernetes environments.

**Prefect** is often described as "Airflow, but nicer"; its key advantages are a Python-native interface and quick setup. It offers dynamic flows, hybrid execution, and real-time SLA alerting, which makes it particularly suitable for small teams or PoC stages that need rapid adoption.

**Dagster** is a modern orchestrator that takes an asset-centric approach. With a philosophy of designing data pipelines centered around "the data being produced" rather than "execution steps," it has strengths in data lineage tracking.

| Characteristic | Airflow | Kubeflow Pipelines | Prefect | Dagster |
| --------------------- | --------------------- | ------------------ | ---------------------- | ---------------------- |
| ML Specialization | General-purpose | ML-specialized | General-purpose | Data-centric |
| Learning Curve | Medium | High | Low | Medium |
| Kubernetes Dependency | Optional | Required | Optional | Optional |
| Community/Ecosystem | Very large | Large | Growing | Growing |
| GPU Scheduling | KubernetesPodOperator | Native | Kubernetes integration | Kubernetes integration |
| Scheduling | Very powerful | Limited | Powerful | Powerful |

The biggest reason for choosing Airflow as an ML pipeline orchestrator is its **tool-agnostic** nature. Airflow can integrate with various ML tools such as MLflow, SageMaker, and Vertex AI, and can manage ML and non-ML data pipelines (ETL, data quality checks, etc.) on the same infrastructure.

2. Reviewing Airflow Core Concepts: DAG, Operator, Task, XCom

According to the official Apache Airflow documentation, the core concepts are as follows.

DAG (Directed Acyclic Graph)

A DAG is a collection of all the Tasks you want to run, organized in a way that reflects their relationships and dependencies. DAGs are defined in Python scripts, and the code itself expresses the DAG's structure (its Tasks and their dependencies). A DAG run is an instantiation of a DAG for a specific logical date (`logical_date`, called `execution_date` before Airflow 2.2), and it contains the Task Instances that execute for that date.

```python
from datetime import datetime

from airflow import DAG

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',  # `schedule_interval` is deprecated since 2.4 and removed in Airflow 3
    catchup=False,
    tags=['ml', 'training'],
) as dag:
    # Task definitions go here
    pass
```
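Airflow derives execution order from the dependencies a DAG declares and rejects cycles. The standard library's `graphlib.TopologicalSorter` makes the same idea concrete in plain Python. This illustrates DAG ordering only. It is not how the Airflow scheduler is implemented, and the task IDs are hypothetical names for the stages listed in Section 1.

```python
from graphlib import CycleError, TopologicalSorter

# Hypothetical task IDs; each task maps to the set of its upstream tasks
dependencies = {
    "preprocess": {"collect_data"},
    "feature_engineering": {"preprocess"},
    "train": {"feature_engineering"},
    "evaluate": {"train"},
    "register_model": {"evaluate"},
    "deploy": {"register_model"},
}


def execution_order(deps: dict[str, set[str]]) -> list[str]:
    """Return one valid order in which every task runs after all of its upstreams."""
    return list(TopologicalSorter(deps).static_order())


def is_acyclic(deps: dict[str, set[str]]) -> bool:
    """A DAG must not contain cycles; prepare() raises CycleError if one exists."""
    try:
        TopologicalSorter(deps).prepare()
        return True
    except CycleError:
        return False


print(execution_order(dependencies))
# ['collect_data', 'preprocess', 'feature_engineering', 'train', 'evaluate', 'register_model', 'deploy']
```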

Operator

An Operator is a pre-defined template for a Task. According to the official documentation, "If a DAG describes how to run a workflow, Operators determine what actually gets done by a Task." All Operators inherit from `BaseOperator`, with PythonOperator, BashOperator, and KubernetesPodOperator being representative examples.

Task

A Task is an instance of an Operator. It is an executable unit within a DAG, and a Task Instance represents a specific execution of a Task in a particular DAG run.

XCom (Cross-Communication)

XCom is a mechanism for passing data between Tasks. According to the official documentation, "by default, Tasks are entirely isolated and may run on completely different machines," which is why XCom is needed. XCom is identified by `key`, `task_id`, and `dag_id`.

```python
# Push
task_instance.xcom_push(key="model_accuracy", value=0.95)

# Pull
accuracy = task_instance.xcom_pull(key="model_accuracy", task_ids="evaluate_model")
```

An important point is that XCom is designed for **transferring small amounts of data**. Do not pass large data such as DataFrames through XCom. Write the data to object storage and pass only its path, or configure an object storage XCom backend (covered in Section 6).

3. Running GPU Training Jobs with KubernetesPodOperator

In ML training pipelines, KubernetesPodOperator plays a central role. According to the official documentation, KubernetesPodOperator is "a stand-in replacement for Kubernetes object spec definitions" that can be executed by the Airflow scheduler within the DAG context. In other words, you can define and run Kubernetes Pods directly in Python code without writing separate YAML/JSON files for Pod specs.

The reasons KubernetesPodOperator is particularly useful for GPU-based ML training are as follows:

- **Task-level resource configuration**: CPU, Memory, and GPU can be independently allocated for each Task

- **Custom dependency support**: When packages not on PyPI or a specific CUDA version is needed, a custom Docker image can be specified

- **Language agnosticism**: Training code written in any language, not just Python, can be wrapped in a container and executed

```python
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

train_model = KubernetesPodOperator(
    task_id='train_model',
    name='gpu-training-job',
    namespace='ml-training',
    image='my-registry/ml-trainer:latest',
    cmds=['python'],
    arguments=['train.py', '--epochs', '100', '--batch-size', '64'],
    container_resources=k8s.V1ResourceRequirements(
        requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
        limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
    ),
    get_logs=True,
    # Replaces the deprecated `is_delete_operator_pod=True` in recent provider versions
    on_finish_action='delete_pod',
    do_xcom_push=True,  # push /airflow/xcom/return.json as this Task's XCom
)
```

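With `do_xcom_push=True`, the Pod can return data as XCom. The training code writes JSON to `/airflow/xcom/return.json`, and a sidecar container that KubernetesPodOperator adds to the Pod reads that file and pushes its contents as the Task's `return_value` XCom. This is useful for passing training metrics to the next Task.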

4. Analysis of Key KubernetesPodOperator Parameters

This section analyzes the key parameters of KubernetesPodOperator from an ML training perspective, based on the official documentation.

container_resources

This is the most important parameter for GPU training. Using `kubernetes.client.models.V1ResourceRequirements`, you can request and limit CPU, Memory, and GPU resources.

```python
container_resources = k8s.V1ResourceRequirements(
    requests={
        'cpu': '4',
        'memory': '16Gi',
        'nvidia.com/gpu': '1',
    },
    limits={
        'cpu': '8',
        'memory': '32Gi',
        'nvidia.com/gpu': '1',
    },
)
```

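For GPUs, use the `nvidia.com/gpu` resource key, which is exposed by the NVIDIA device plugin. Kubernetes does not allow extended resources such as GPUs to be overcommitted. If you set both a request and a limit they must be equal, and if you set only the limit it is also used as the request. GPU counts must also be whole numbers, so a Pod cannot ask for a fraction of a GPU unless the cluster partitions GPUs (for example, with MIG).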
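These rules are easy to check before a DAG ever reaches the cluster. The helper below is a hypothetical pre-flight check that is not part of Airflow or the `kubernetes` client. It applies the rules to the plain `requests`/`limits` dicts you would pass to `V1ResourceRequirements`, and it looks only at the GPU entries.

```python
GPU_KEY = "nvidia.com/gpu"


def check_gpu_resources(requests: dict[str, str], limits: dict[str, str]) -> list[str]:
    """Return a list of problems with the GPU entries (an empty list means OK)."""
    problems: list[str] = []
    req, lim = requests.get(GPU_KEY), limits.get(GPU_KEY)
    if req is None and lim is None:
        return problems  # this container does not use GPUs
    if lim is None:
        problems.append("GPU request set without a limit")
    # Extended resources such as GPUs must be whole numbers
    for label, value in (("request", req), ("limit", lim)):
        if value is not None and not str(value).isdigit():
            problems.append(f"GPU {label} must be a whole number, got {value!r}")
    # No overcommit: when both are present they must be equal
    if req is not None and lim is not None and req != lim:
        problems.append(f"GPU request ({req}) must equal limit ({lim})")
    return problems


print(check_gpu_resources({"nvidia.com/gpu": "1"}, {"nvidia.com/gpu": "2"}))
# ['GPU request (1) must equal limit (2)']
```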

tolerations

GPU nodes are usually configured with taints to prevent general workloads from being scheduled. `tolerations` allow Pods to be scheduled on GPU nodes.

```python
tolerations = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]
```
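A toleration matches a taint when the keys and effects match. With `operator='Exists'` the taint's value is ignored, while the default `Equal` operator also compares values. The sketch below is a simplified plain-Python model of those rules. It ignores `tolerationSeconds` and uses dicts in place of `V1Toleration` objects. The node taint `nvidia.com/gpu=present:NoSchedule` is an assumed example, because the actual taint depends on how your GPU nodes were provisioned.

```python
def tolerates(toleration: dict, taint: dict) -> bool:
    """Simplified Kubernetes toleration/taint matching."""
    operator = toleration.get("operator", "Equal")  # Kubernetes defaults to Equal
    key = toleration.get("key", "")
    effect = toleration.get("effect", "")
    # An empty key combined with Exists tolerates every taint
    if not key and operator == "Exists":
        return True
    if key != taint["key"]:
        return False
    # An empty effect matches all effects for this key
    if effect and effect != taint["effect"]:
        return False
    if operator == "Exists":
        return True  # value is ignored
    return toleration.get("value", "") == taint.get("value", "")


def can_schedule(tolerations: list[dict], node_taints: list[dict]) -> bool:
    """Taint-wise, a Pod fits only if every NoSchedule/NoExecute taint is tolerated."""
    return all(
        any(tolerates(t, taint) for t in tolerations)
        for taint in node_taints
        if taint["effect"] in ("NoSchedule", "NoExecute")
    )


gpu_taint = {"key": "nvidia.com/gpu", "value": "present", "effect": "NoSchedule"}  # assumed node taint
gpu_toleration = {"key": "nvidia.com/gpu", "operator": "Exists", "effect": "NoSchedule"}
print(can_schedule([gpu_toleration], [gpu_taint]))  # True
print(can_schedule([], [gpu_taint]))                # False
```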

node_selector

You can specify that training Pods should only run on specific node groups. For example, to run only on nodes equipped with A100 GPUs:

```python
node_selector = {
    'gpu-type': 'a100',
    'node-pool': 'ml-training',
}
```

affinity

When more fine-grained scheduling control is needed, use `V1Affinity`. You can set Node affinity and Pod affinity/anti-affinity to place Pods in specific zones or nodes.

```python
affinity = k8s.V1Affinity(
    node_affinity=k8s.V1NodeAffinity(
        required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
            node_selector_terms=[
                k8s.V1NodeSelectorTerm(
                    match_expressions=[
                        k8s.V1NodeSelectorRequirement(
                            key='gpu-type',
                            operator='In',
                            values=['a100', 'h100'],
                        )
                    ]
                )
            ]
        )
    )
)
```
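The difference between the two mechanisms fits in a few lines. `node_selector` requires every listed label to match exactly, whereas a `match_expressions` entry with `operator='In'` accepts any of the listed values. This simplified model handles a single node selector term and only the `In` operator. Kubernetes also supports `NotIn`, `Exists`, `DoesNotExist`, `Gt`, and `Lt`. The node labels are hypothetical.

```python
def matches_node_selector(node_labels: dict[str, str], selector: dict[str, str]) -> bool:
    """nodeSelector: every key/value pair must be present on the node."""
    return all(node_labels.get(k) == v for k, v in selector.items())


def matches_in_expressions(node_labels: dict[str, str], expressions: list[dict]) -> bool:
    """One nodeSelectorTerm: all expressions must hold (only 'In' is modeled)."""
    for expr in expressions:
        if expr["operator"] != "In":
            raise NotImplementedError(f"operator {expr['operator']!r} not modeled")
        # A missing label yields None, which is never in the allowed values
        if node_labels.get(expr["key"]) not in expr["values"]:
            return False
    return True


h100_node = {"gpu-type": "h100", "node-pool": "ml-training"}  # hypothetical node labels
print(matches_node_selector(h100_node, {"gpu-type": "a100", "node-pool": "ml-training"}))  # False
print(matches_in_expressions(h100_node, [{"key": "gpu-type", "operator": "In", "values": ["a100", "h100"]}]))  # True
```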

volumes / volume_mounts

Used for mounting training datasets or checkpoints via PVC (PersistentVolumeClaim).

```python
volume = k8s.V1Volume(
    name='training-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-dataset-pvc'
    ),
)

volume_mount = k8s.V1VolumeMount(
    name='training-data',
    mount_path='/data',
    read_only=True,
)
```

Other Key Parameters

| Parameter | Description | ML Training Usage |
| ------------------------ | --------------------------- | ----------------------------------------- |
| `image_pull_secrets` | Private registry authentication | Pulling internal ML images |
| `env_vars` | Environment variable list | `CUDA_VISIBLE_DEVICES`, `MLFLOW_TRACKING_URI`, etc. |
| `secrets` | Secret volumes/env vars | API keys (e.g. `WANDB_API_KEY`), credential management |
| `is_delete_operator_pod` | Delete Pod after completion (deprecated) | Superseded by `on_finish_action` |
| `get_logs` | Stream Pod logs to the Airflow UI | Training log monitoring |
| `deferrable` | Asynchronous (deferred) execution | Frees Worker slots during long training |
| `on_finish_action` | Post-completion action (`delete_pod`, `delete_succeeded_pod`, `keep_pod`) | Automatic cleanup with `delete_pod` |

The official documentation recommends using native objects from `kubernetes.client.models` instead of convenience classes for type safety.

5. DAG Design Pattern: Data Preprocessing to Training to Evaluation to Deployment

The general DAG structure for an ML training pipeline includes the following stages:

Data Validation -> Preprocessing -> Feature Engineering -> Training -> Evaluation -> [Conditional] Deployment

Mapping each stage to Airflow Tasks results in the following design pattern:

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
) as dag:
    validate_data = KubernetesPodOperator(
        task_id='validate_data',
        image='ml-tools:latest',
        cmds=['python', 'validate.py'],
        # Only CPU needed (container_resources omitted for brevity)
    )

    preprocess = KubernetesPodOperator(
        task_id='preprocess_data',
        image='ml-tools:latest',
        cmds=['python', 'preprocess.py'],
        # Large memory needed
    )

    train = KubernetesPodOperator(
        task_id='train_model',
        image='ml-trainer:latest',
        cmds=['python', 'train.py'],
        # GPU needed (container_resources / tolerations as in Section 4)
        do_xcom_push=True,
    )

    evaluate = KubernetesPodOperator(
        task_id='evaluate_model',
        image='ml-trainer:latest',
        cmds=['python', 'evaluate.py'],
        do_xcom_push=True,  # evaluate.py writes {"accuracy": ...} to return.json
    )

    def check_model_quality(**context):
        # Return the task_id of the branch to follow; the other branch is skipped
        metrics = context['ti'].xcom_pull(task_ids='evaluate_model')
        if metrics['accuracy'] > 0.90:
            return 'deploy_model'
        return 'notify_failure'

    branch = BranchPythonOperator(
        task_id='check_quality',
        python_callable=check_model_quality,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
    )

    # Every task_id the branch can return must exist directly downstream of it
    notify_failure = EmptyOperator(task_id='notify_failure')

    validate_data >> preprocess >> train >> evaluate >> branch
    branch >> [deploy, notify_failure]
```

The key to this design pattern is **conditional deployment**. The evaluation stage passes model performance metrics via XCom, and through `BranchPythonOperator`, the deployment Task is executed only when quality criteria are met. This prevents low-quality models from being deployed to production.
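Because the branching callable is plain Python, you can unit-test its decision logic without a running Airflow. The sketch below moves it into a standalone function and adds one assumption of its own. A missing or malformed XCom is treated as a failed quality check instead of raising a `TypeError`. This covers the case where `evaluate.py` never wrote `return.json`, so `xcom_pull` returns `None`. The threshold and task IDs come from the DAG above.

```python
from typing import Optional


def choose_branch(metrics: Optional[dict], threshold: float = 0.90) -> str:
    """Return the task_id that BranchPythonOperator should follow."""
    # Assumption: missing or malformed metrics count as a failed quality check
    if not metrics or "accuracy" not in metrics:
        return "notify_failure"
    return "deploy_model" if float(metrics["accuracy"]) > threshold else "notify_failure"


def check_model_quality(**context):
    # Thin Airflow wrapper: pull evaluate_model's return_value and decide
    metrics = context["ti"].xcom_pull(task_ids="evaluate_model")
    return choose_branch(metrics)
```

Keeping the wrapper thin means the interesting logic is covered by ordinary unit tests, and the Airflow-specific part is a single `xcom_pull` call.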

Stage Separation Using TaskGroup

In complex pipelines, TaskGroup can be used to visually separate stages. According to the official documentation, "TaskGroup is used to organize Tasks into hierarchical groups in the Graph View. It is useful for creating repeating patterns and reducing visual complexity."

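```python
from airflow.utils.task_group import TaskGroup

# Inside `with DAG(...)`: a task joins a group only if it is *instantiated* inside
# the group's `with` block, so create the operators there (feature_engineering
# stands for an additional pipeline task not shown above).
# Task IDs inside a group are prefixed by default (e.g. 'model_training.evaluate_model'),
# so xcom_pull(task_ids=...) calls must use the prefixed ID.
with TaskGroup('data_preparation') as data_prep:
    # validate_data = KubernetesPodOperator(...), preprocess = ..., feature_engineering = ...
    validate_data >> preprocess >> feature_engineering

with TaskGroup('model_training') as training:
    # train = KubernetesPodOperator(...), evaluate = KubernetesPodOperator(...)
    train >> evaluate

data_prep >> training >> branch
```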

6. Passing Metrics/Artifacts Between Tasks with XCom

According to the official documentation, XCom is explicitly "pushed" to and "pulled" from storage via the `xcom_push` and `xcom_pull` methods. In ML pipelines, XCom is used for the following purposes:

Passing Training Metrics

```python
# train.py (executed inside KubernetesPodOperator)
import json

metrics = {
    'accuracy': 0.9534,
    'f1_score': 0.9421,
    'loss': 0.0312,
    'model_path': 's3://ml-models/experiment-42/model.pt',
    'run_id': 'exp-42-20260301',
}

# Write JSON to /airflow/xcom/return.json; the XCom sidecar pushes it as return_value
with open('/airflow/xcom/return.json', 'w') as f:
    json.dump(metrics, f)
```
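One practical snag is that metrics computed with NumPy or PyTorch are often scalar objects, such as `numpy.float32` or a 0-d `torch.Tensor`. The standard `json` module refuses to serialize these. Below is a minimal sketch of a helper for the training script. It assumes such objects expose an `.item()` method that returns a Python number. The function name `write_xcom_result` is hypothetical, and the path is a parameter only so the helper can be tested outside a Pod.

```python
import json
import os
from typing import Any

XCOM_RETURN_PATH = "/airflow/xcom/return.json"  # file read by KubernetesPodOperator's XCom sidecar


def _to_builtin(obj: Any) -> Any:
    """json.dump fallback: unwrap scalar objects (e.g. numpy/torch scalars) via .item()."""
    if hasattr(obj, "item"):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def write_xcom_result(result: dict, path: str = XCOM_RETURN_PATH) -> None:
    """Write `result` as JSON where the XCom sidecar expects the Task's return value."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(result, f, default=_to_builtin)
```

Inside `train.py` you would call `write_xcom_result(metrics)` in place of the bare `json.dump`.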

Auto-Push and Return Value

Many Operators and the `@task` decorator automatically push the return value as an XCom with the key `return_value` when `do_xcom_push=True` (default). You can then pull it as follows:

```python
value = task_instance.xcom_pull(task_ids='train_model')
```

Multiple Outputs

When returning a dictionary with `multiple_outputs=True`, each key is stored as an individual XCom.

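```python
from airflow.sdk import task  # Airflow 3.x; on 2.x: from airflow.decorators import task

@task(multiple_outputs=True)
def evaluate_model():
    # Each key becomes its own XCom, in addition to the full dict as return_value
    return {
        'accuracy': 0.95,
        'f1_score': 0.94,
        'model_path': 's3://models/latest.pt',
    }
```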

You can then pull by individual key or pull the entire `return_value`.

Custom XCom Backend

The default XCom Backend, `BaseXCom`, stores XCom in the Airflow database. While this is fine for small amounts of data, a Custom Backend is needed when dealing with model artifacts or large data. You can override the `serialize_value` and `deserialize_value` methods by inheriting from `BaseXCom`.

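To use the Object Storage Backend, set the `xcom_backend` configuration to `airflow.providers.common.io.xcom.backend.XComObjectStorageBackend` from the `apache-airflow-providers-common-io` package. This allows storing XCom data in S3 or GCS. The backend is further configured through `[common.io]` options such as `xcom_objectstorage_path` (the target bucket and prefix) and `xcom_objectstorage_threshold` (the value size above which data is offloaded). Check the provider documentation for their defaults.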

```ini
[core]
xcom_backend = airflow.providers.common.io.xcom.backend.XComObjectStorageBackend
```

XCom Usage Precautions

The key points emphasized in the official documentation are:

- XCom is designed for **transferring small amounts of data**. Do not pass large data such as DataFrames

- When a Task is retried after failure, previous XCom is automatically cleared to ensure idempotent execution

- XCom operations must be performed within the Task Context via `get_current_context()`, and direct DB updates are not supported

As a practical guideline for ML pipelines, metrics (accuracy, loss, etc.) and path information (model_path, artifact_uri, etc.) should be passed via XCom, while actual model files and datasets should be stored in Object Storage like S3/GCS, with only the paths shared via XCom.
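The guideline can be sketched as a task body that writes the heavy artifact to storage and returns only a small, JSON-friendly summary for XCom. A local directory stands in for S3/GCS here, and `save_and_summarize` is a hypothetical name. With real object storage, you would upload through your cloud SDK and return the `s3://` or `gs://` URI instead of a local path.

```python
import hashlib
import json
import os


def save_and_summarize(model_bytes: bytes, metrics: dict, artifact_dir: str, run_id: str) -> dict:
    """Store the model artifact outside XCom; return only metadata suitable for XCom."""
    os.makedirs(artifact_dir, exist_ok=True)
    model_path = os.path.join(artifact_dir, f"{run_id}.pt")
    with open(model_path, "wb") as f:
        f.write(model_bytes)  # the large payload goes to storage, not the Airflow DB
    summary = {
        "run_id": run_id,
        "model_path": model_path,  # downstream tasks fetch the artifact by path
        "sha256": hashlib.sha256(model_bytes).hexdigest(),  # lets consumers verify it
        **{name: float(value) for name, value in metrics.items()},
    }
    json.dumps(summary)  # fail fast if anything is not JSON-serializable
    return summary
```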

7. Hyperparameter Tuning with Dynamic Task Mapping

Dynamic Task Mapping was introduced in Airflow 2.3 and, according to the official documentation, is "a mechanism that allows creating multiple Tasks at runtime based on current data." The DAG author does not need to know in advance how many Tasks are needed.

This feature is very powerful for ML hyperparameter tuning because it enables parallel training with various hyperparameter combinations.

expand() and partial()

`expand()` specifies the parameters to map, while `partial()` specifies fixed parameters common to all Tasks.

```python
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

hyperparams = [
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'adam'},
    {'lr': '0.01', 'batch_size': '64', 'optimizer': 'adam'},
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'sgd'},
    {'lr': '0.0001', 'batch_size': '128', 'optimizer': 'adamw'},
]

# partial(): arguments shared by every mapped instance
train_tasks = KubernetesPodOperator.partial(
    task_id='hyperparam_training',
    image='ml-trainer:latest',
    namespace='ml-training',
    cmds=['python'],  # arguments below start with the script name
    get_logs=True,
    on_finish_action='delete_pod',  # replaces the deprecated is_delete_operator_pod=True
    do_xcom_push=True,
    container_resources=k8s.V1ResourceRequirements(
        requests={'nvidia.com/gpu': '1'},
        limits={'nvidia.com/gpu': '1'},
    ),
).expand(
    # expand(): one mapped instance per element of this list
    arguments=[
        ['train.py', '--lr', hp['lr'], '--batch-size', hp['batch_size'], '--optimizer', hp['optimizer']]
        for hp in hyperparams
    ],
)
```

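This creates four mapped instances of `hyperparam_training` at runtime, one per hyperparameter combination. They can run in parallel, subject to Airflow's concurrency settings and the GPU nodes available in the cluster.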

Dynamically Generating Mapping Data from Tasks

Furthermore, you can dynamically generate hyperparameter combinations from upstream Tasks.

```python
import itertools

@task
def generate_hyperparams():
    """Generate hyperparameter combinations via grid search"""
    learning_rates = [0.001, 0.01, 0.0001]
    batch_sizes = [32, 64, 128]
    optimizers = ['adam', 'sgd']
    combinations = list(itertools.product(learning_rates, batch_sizes, optimizers))  # 3 x 3 x 2 = 18
    return [
        {'lr': str(lr), 'batch_size': str(bs), 'optimizer': opt}
        for lr, bs, opt in combinations
    ]

hp_list = generate_hyperparams()  # an XComArg that can be passed to .expand()
```

According to the official documentation, "mappings generated from a Task prohibit the use of `trigger_rule=TriggerRule.ALWAYS`."

Cross-Product Mapping

Specifying multiple `expand()` parameters generates all combinations (cross product).

```python
@task
def train(lr: float, batch_size: int):
    # Training logic
    pass

train.expand(lr=[0.001, 0.01], batch_size=[32, 64])
# 2 x 2 = 4 Task Instances created
```

Map-Reduce Pattern

A pattern for collecting results from mapped Tasks and selecting the optimal model.

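```python
@task
def train_task(hp: dict) -> dict:
    # Hypothetical training step: train with `hp` and return its metrics,
    # e.g. {'accuracy': ..., 'run_id': ...}. The argument is named `hp` rather
    # than `params`, which is also the name of Airflow's built-in params context variable.
    ...

@task
def select_best_model(results):
    """Select the best among all hyperparameter combination results"""
    best = max(results, key=lambda x: x['accuracy'])
    return best

training_results = train_task.expand(hp=hyperparams)
best = select_best_model(training_results)
```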

According to the official documentation, the collected results are returned as a "lazy proxy" sequence, not an eager list.
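A reduce step only needs to iterate over the results once, so it works the same whether it receives a list or a lazy sequence. The plain-Python sketch below skips mapped instances that returned nothing and keeps the first result when there is a tie. A generator stands in for Airflow's lazy proxy, and the function name `select_best` is hypothetical.

```python
from typing import Iterable, Optional


def select_best(results: Iterable[Optional[dict]], metric: str = "accuracy") -> dict:
    """Reduce step: a single pass over mapped results, ignoring empty ones."""
    best = None
    for result in results:  # one pass works for lists and lazy sequences alike
        if not result or metric not in result:
            continue  # e.g. a mapped instance that pushed no XCom
        if best is None or result[metric] > best[metric]:
            best = result  # strict '>' keeps the first result on ties
    if best is None:
        raise ValueError("no mapped task produced a usable result")
    return best
```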

Constraints

| Item | Setting | Default |
| --------------------------- | ------------------------ | ---------------- |
| Max length of a mapped list/dict pushed via XCom | `[core] max_map_length` | 1024 |
| Parallel running instances of a Task | `max_active_tis_per_dag` (Task argument) | None (no per-task limit) |

Only **list** and **dict** types can be mapped; other types will raise an `UnmappableXComTypePushed` error. Additionally, according to the official documentation, "if a field is marked as templated and is mapped, it will not be template-rendered."
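You can check both constraints in the upstream task before it returns, so a bad payload fails with a clear message in that task. The helper below is a hypothetical sketch and not an Airflow API. The value `1024` mirrors the `[core] max_map_length` default from the table; change it if your deployment overrides that setting.

```python
MAX_MAP_LENGTH = 1024  # mirrors the [core] max_map_length default


def validate_mapping_input(value, max_length: int = MAX_MAP_LENGTH):
    """Raise early if `value` cannot be used to expand a mapped task."""
    if not isinstance(value, (list, dict)):
        raise TypeError(f"only list or dict can be mapped, got {type(value).__name__}")
    if len(value) > max_length:
        raise ValueError(f"{len(value)} items exceeds max_map_length={max_length}")
    return value
```

In a TaskFlow upstream task, you would end with `return validate_mapping_input(combinations)`.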

8. Detecting Data Arrival with Sensors

In ML pipelines, it is often necessary to wait until training data is ready. Airflow's Sensors are special Operators that wait until a specific condition is met.

S3KeySensor

Waits until training data is uploaded to S3.

```python
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_data = S3KeySensor(
    task_id='wait_for_training_data',
    bucket_name='ml-datasets',
    bucket_key='daily/{{ ds }}/training_data.parquet',
    aws_conn_id='aws_default',
    poke_interval=300,  # Check every 5 minutes
    timeout=3600 * 6,   # Wait up to 6 hours
    mode='reschedule',  # Return Worker slot
    deferrable=True,    # Async wait using Triggerer
)
```

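`mode='reschedule'` releases the Worker slot between pokes instead of holding it while the sensor sleeps. `deferrable=True` goes further. The sensor defers to the Triggerer component, which polls asynchronously without occupying a Worker slot at all, so a Triggerer must be running. When `deferrable=True` is set, the deferral path takes over, and `mode` matters only if you turn deferral off.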
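With the settings above, `timeout=3600 * 6` is 21,600 seconds, so a 300-second `poke_interval` allows roughly 72 checks before the sensor fails. The loop below is a simplified model of poke semantics, not Airflow's implementation. It evaluates a condition, then waits `poke_interval` seconds and repeats until the condition holds or the next wait would exceed `timeout`. The clock and sleep functions are injected so the logic can be exercised without real waiting. In this model, a condition that never holds is checked 73 times: at t = 0 and every 300 s up to 21,600 s.

```python
import time
from typing import Callable


class SensorTimeout(Exception):
    pass


def poke_until(condition: Callable[[], bool], poke_interval: float, timeout: float,
               clock: Callable[[], float] = time.monotonic,
               sleep: Callable[[float], None] = time.sleep) -> int:
    """Call `condition` until it returns True; return the number of pokes made."""
    start = clock()
    pokes = 0
    while True:
        pokes += 1
        if condition():
            return pokes
        if clock() - start + poke_interval > timeout:
            raise SensorTimeout(f"condition not met after {pokes} pokes")
        # In mode='reschedule', Airflow frees the worker slot here instead of sleeping
        sleep(poke_interval)
```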

ExternalTaskSensor

Waits until a Task from another DAG completes. For example, a pattern where the training DAG starts after the data preprocessing DAG completes.

```python
from datetime import timedelta

from airflow.sensors.external_task import ExternalTaskSensor

wait_for_preprocessing = ExternalTaskSensor(
    task_id='wait_for_preprocessing',
    external_dag_id='data_preprocessing_pipeline',
    external_task_id='final_validation',
    allowed_states=['success'],
    execution_delta=timedelta(hours=0),
    timeout=3600,
    poke_interval=60,
    mode='reschedule',
)
```

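`allowed_states` lists the external task states that count as success (default `['success']`). `execution_delta` is the offset between this DAG run's logical date and the external run to check. The sensor looks for the external run at `logical_date - execution_delta`. So `timedelta(hours=0)` (equivalent to leaving it unset) means the same logical date, and `timedelta(days=1)` means the previous day's run. Both DAGs therefore need aligned schedules.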
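The direction of the subtraction is easy to get wrong, so here is a minimal sketch with hypothetical dates. It computes the logical date of the external run that the sensor checks. The function name is hypothetical. `ExternalTaskSensor` also accepts `execution_date_fn` for offsets that a fixed delta cannot express.

```python
from datetime import datetime, timedelta, timezone


def external_logical_date(current: datetime, execution_delta: timedelta) -> datetime:
    """Logical date of the external DAG run that ExternalTaskSensor waits for."""
    return current - execution_delta  # a positive delta looks into the past


run = datetime(2026, 3, 1, tzinfo=timezone.utc)
print(external_logical_date(run, timedelta(hours=0)))  # 2026-03-01 00:00:00+00:00
print(external_logical_date(run, timedelta(days=1)))   # 2026-02-28 00:00:00+00:00
```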

Practical Pattern: Starting Training After Data Arrival

```python
wait_for_data >> validate_data >> preprocess >> train >> evaluate >> deploy
```

This pattern loosely couples the data pipeline and ML pipeline while ensuring that training only starts after data is ready.

9. MLflow and Airflow Integration Patterns

The role division between Airflow and MLflow integration is clear. **Airflow manages "when and in what order to execute," while MLflow records "what happened during execution and where the model is."** Through this separation of concerns, data engineers can manage Airflow DAGs without touching ML code, and data scientists can focus on model development without worrying about scheduling infrastructure.

Integration Pattern 1: Direct MLflow Calls Inside Training Tasks

This is the most intuitive pattern. MLflow APIs are called directly inside training containers executed by KubernetesPodOperator.

```python
# train.py (executed inside the container)
import json

import mlflow
import mlflow.pytorch

mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("daily-training")

with mlflow.start_run() as run:
    # Hyperparameter logging
    mlflow.log_params({
        'learning_rate': 0.001,
        'batch_size': 64,
        'epochs': 100,
    })

    # Model training (train_model stands for your own training code)
    model = train_model(...)

    # Metric logging
    mlflow.log_metrics({
        'accuracy': 0.95,
        'f1_score': 0.94,
    })

    # Log the model as a run artifact (registry registration happens in Pattern 2)
    mlflow.pytorch.log_model(model, "model")

# Pass run_id to downstream Tasks via XCom
with open('/airflow/xcom/return.json', 'w') as f:
    json.dump({'run_id': run.info.run_id}, f)
```

Integration Pattern 2: Utilizing MLflow Model Registry from Airflow

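After training completes, the model is registered in the MLflow Model Registry and, depending on the evaluation result, moved to a new stage. Note that MLflow 2.9 deprecated model stages in favor of model version aliases (`MlflowClient.set_registered_model_alias`). The stage-based calls below still work but emit deprecation warnings.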

```python
@task
def register_model(run_id: str, accuracy: float):
    import mlflow
    from mlflow.tracking import MlflowClient

    tracking_uri = "http://mlflow-server:5000"
    mlflow.set_tracking_uri(tracking_uri)  # mlflow.register_model uses the global tracking URI
    client = MlflowClient(tracking_uri)

    # Register model
    model_uri = f"runs:/{run_id}/model"
    mv = mlflow.register_model(model_uri, "production-model")

    # Transition to Production stage if accuracy exceeds threshold
    if accuracy > 0.93:
        client.transition_model_version_stage(
            name="production-model",
            version=mv.version,
            stage="Production",
        )
```

Integration Pattern 3: Querying and Comparing Training History from MLflow

```python
@task
def compare_with_baseline():
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # Query Production model metrics
    prod_versions = client.get_latest_versions("production-model", stages=["Production"])
    if prod_versions:
        prod_run = client.get_run(prod_versions[0].run_id)
        baseline_accuracy = float(prod_run.data.metrics['accuracy'])
        return {'baseline_accuracy': baseline_accuracy}
    return {'baseline_accuracy': 0.0}
```

10. Utilizing the TaskFlow API (@task Decorator)

According to the official documentation, the TaskFlow API is "a functional API that defines DAGs and Tasks using decorators," greatly simplifying data transfer and dependency definition between Tasks.

Key Characteristics

The biggest advantage of TaskFlow is **automatic XCom management** and **automatic dependency calculation**. When a TaskFlow function is called, it is not executed immediately; instead, an `XComArg` object representing the result is returned. Using this as input to downstream Tasks causes dependencies to be calculated automatically.

```python
from datetime import datetime

from airflow.sdk import dag, task  # Airflow 3.x; on 2.x: from airflow.decorators import dag, task

@dag(
    schedule='@daily',
    start_date=datetime(2026, 1, 1),
    catchup=False,
)
def ml_training_taskflow():
    @task
    def extract_data():
        """Data extraction"""
        return {'data_path': 's3://bucket/data/2026-03-01.parquet', 'num_rows': 100000}

    @task(multiple_outputs=True)
    def preprocess(data_info: dict):
        """Data preprocessing"""
        return {
            'train_path': f"{data_info['data_path']}/train.parquet",
            'test_path': f"{data_info['data_path']}/test.parquet",
            'feature_count': 128,
        }

    @task
    def train_model(train_path: str, feature_count: int):
        """Model training"""
        # In practice, GPU training would be run via KubernetesPodOperator
        return {
            'model_path': 's3://models/latest.pt',
            'accuracy': 0.95,
        }

    @task
    def evaluate_and_deploy(model_info: dict):
        """Model evaluation and conditional deployment"""
        if model_info['accuracy'] > 0.90:
            return f"Deployed model from {model_info['model_path']}"
        return "Model quality below threshold, skipping deployment"

    # Automatic dependency wiring: each call returns an XComArg, not the value
    data = extract_data()
    processed = preprocess(data)
    model = train_model(
        train_path=processed['train_path'],
        feature_count=processed['feature_count'],
    )
    evaluate_and_deploy(model)

ml_training_taskflow()
```
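The automatic wiring in this example relies on calls returning placeholders rather than values. The toy model below shows the idea in plain Python; it is not Airflow's implementation, and `Deferred` and `toy_task` are hypothetical names. Calling a decorated function only records its arguments. Any argument that is itself a placeholder becomes an upstream dependency, and nothing executes until `run()` is called.

```python
from typing import Any, Callable


class Deferred:
    """Placeholder returned when a toy task is called (analogous to XComArg)."""

    def __init__(self, name: str, fn: Callable, args: tuple, kwargs: dict):
        self.name, self.fn, self.args, self.kwargs = name, fn, args, kwargs

    @property
    def upstream(self) -> list[str]:
        # Dependencies are inferred from which arguments are other placeholders
        values = list(self.args) + list(self.kwargs.values())
        return [v.name for v in values if isinstance(v, Deferred)]

    def run(self) -> Any:
        """Execute upstream placeholders first, then this function (no caching)."""
        def resolve(value):
            return value.run() if isinstance(value, Deferred) else value
        args = [resolve(a) for a in self.args]
        kwargs = {k: resolve(v) for k, v in self.kwargs.items()}
        return self.fn(*args, **kwargs)


def toy_task(fn: Callable) -> Callable[..., Deferred]:
    def wrapper(*args, **kwargs) -> Deferred:
        return Deferred(fn.__name__, fn, args, kwargs)  # nothing executes yet
    return wrapper


@toy_task
def extract():
    return {"num_rows": 100000}


@toy_task
def count_rows(data: dict):
    return data["num_rows"]


result = count_rows(extract())
print(result.upstream)  # ['extract']
print(result.run())     # 100000
```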

Context Access

Tasks can receive Airflow context variables as keyword arguments.

```python
@task
def log_execution_info(task_instance=None, dag_run=None):
    print(f"Run ID: {task_instance.run_id}")
    print(f"DAG Run: {dag_run.dag_id}")
```

Object Serialization

TaskFlow supports custom object passing via `@dataclass`, `@attr.define`, or custom `serialize()`/`deserialize()` methods. Version management is also possible using `__version__: ClassVar[int]`.
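Here is a minimal sketch of the `serialize()`/`deserialize(data, version)` protocol for a hypothetical `EvalResult` class. Airflow calls these hooks itself when the object crosses a Task boundary. In this sketch they are called by hand to show the round trip and how `__version__` lets a newer class read payloads written by an older one. The field names and the version-1 layout are assumptions for illustration.

```python
from typing import ClassVar


class EvalResult:
    """Hypothetical result object passed between TaskFlow tasks."""

    __version__: ClassVar[int] = 2  # bump whenever the serialized shape changes

    def __init__(self, accuracy: float, model_path: str, f1_score: float = 0.0):
        self.accuracy = accuracy
        self.model_path = model_path
        self.f1_score = f1_score

    def serialize(self) -> dict:
        return {"accuracy": self.accuracy, "model_path": self.model_path, "f1_score": self.f1_score}

    @staticmethod
    def deserialize(data: dict, version: int) -> "EvalResult":
        if version > EvalResult.__version__:
            raise TypeError(f"cannot read version {version} (max {EvalResult.__version__})")
        if version == 1:
            # Assumed version-1 payloads predate f1_score, so fall back to the default
            return EvalResult(data["accuracy"], data["model_path"])
        return EvalResult(data["accuracy"], data["model_path"], data["f1_score"])
```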

11. Production DAG Code Example (Full ML Pipeline)

The following is a production ML training pipeline DAG code that integrates all the concepts covered so far. It includes the entire process from data arrival detection to hyperparameter tuning, optimal model selection, and deployment.

"""

ML Training Pipeline DAG

- S3 data arrival detection

- Data preprocessing (KubernetesPodOperator)

- Hyperparameter tuning with Dynamic Task Mapping

- Optimal model selection and MLflow registration

- Conditional deployment

"""

from __future__ import annotations

from datetime import datetime, timedelta

from airflow import DAG

from airflow.sdk import task

from airflow.operators.python import BranchPythonOperator

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

from airflow.utils.task_group import TaskGroup

from kubernetes.client import models as k8s

-- Common Configuration --

GPU_RESOURCES = k8s.V1ResourceRequirements(

requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},

limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},

)

GPU_TOLERATIONS = [

k8s.V1Toleration(

key='nvidia.com/gpu',

operator='Exists',

effect='NoSchedule',

),

]

GPU_NODE_SELECTOR = {'gpu-type': 'a100', 'node-pool': 'ml-training'}

DATA_VOLUME = k8s.V1Volume(

name='shared-data',

persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(

claim_name='ml-shared-pvc'

),

)

DATA_VOLUME_MOUNT = k8s.V1VolumeMount(

name='shared-data',

mount_path='/shared',

)

default_args = {

'owner': 'ml-team',

'retries': 2,

'retry_delay': timedelta(minutes=5),

'execution_timeout': timedelta(hours=4),

}

with DAG(

dag_id='ml_full_training_pipeline',

default_args=default_args,

start_date=datetime(2026, 1, 1),

schedule_interval='@daily',

catchup=False,

tags=['ml', 'training', 'production'],

max_active_runs=1,

doc_md="""

ML Full Training Pipeline

A pipeline that retrains the model daily with new data

and automatically deploys it if quality criteria are met.

""",

) as dag:

===== Stage 1: Data Arrival Detection =====

wait_for_data = S3KeySensor(

task_id='wait_for_training_data',

bucket_name='ml-datasets',

bucket_key='daily/{{ ds }}/features.parquet',

aws_conn_id='aws_default',

poke_interval=300,

timeout=3600 * 6,

mode='reschedule',

deferrable=True,

)

===== Stage 2: Data Preprocessing =====

with TaskGroup('data_preparation') as data_prep:

validate_data = KubernetesPodOperator(

task_id='validate_data',

name='data-validation',

namespace='ml-training',

image='ml-tools:latest',

cmds=['python', 'validate.py'],

arguments=['--date', '{{ ds }}', '--source', 's3://ml-datasets/daily/{{ ds }}/'],

env_vars=[

k8s.V1EnvVar(name='DATA_DATE', value='{{ ds }}'),

],

container_resources=k8s.V1ResourceRequirements(

requests={'cpu': '2', 'memory': '8Gi'},

limits={'cpu': '4', 'memory': '16Gi'},

),

get_logs=True,

is_delete_operator_pod=True,

do_xcom_push=True,

)

preprocess = KubernetesPodOperator(

task_id='preprocess_data',

name='data-preprocessing',

namespace='ml-training',

image='ml-tools:latest',

cmds=['python', 'preprocess.py'],

arguments=[

'--input', 's3://ml-datasets/daily/{{ ds }}/',

'--output', 's3://ml-processed/{{ ds }}/',

],

container_resources=k8s.V1ResourceRequirements(

requests={'cpu': '4', 'memory': '32Gi'},

limits={'cpu': '8', 'memory': '64Gi'},

),

volumes=[DATA_VOLUME],

volume_mounts=[DATA_VOLUME_MOUNT],

get_logs=True,

is_delete_operator_pod=True,

do_xcom_push=True,

)

validate_data >> preprocess

===== Stage 3: Hyperparameter Generation and Parallel Training =====

@task

def generate_hyperparams():

"""Generate hyperparameter combinations for training"""

learning_rates = [0.001, 0.0001]

batch_sizes = [32, 64]

optimizers = ['adam', 'adamw']

combos = list(itertools.product(learning_rates, batch_sizes, optimizers))

return [

[

'train.py',

'--lr', str(lr),

'--batch-size', str(bs),

'--optimizer', opt,

'--data-path', 's3://ml-processed/{{ ds }}/',

'--experiment-name', 'daily-training-{{ ds }}',

]

for lr, bs, opt in combos

]

hp_args = generate_hyperparams()

with TaskGroup('hyperparameter_tuning') as hp_tuning:

train_tasks = KubernetesPodOperator.partial(

task_id='train_with_hp',

name='gpu-hp-training',

namespace='ml-training',

image='ml-trainer:latest',

cmds=['python'],

container_resources=GPU_RESOURCES,

tolerations=GPU_TOLERATIONS,

node_selector=GPU_NODE_SELECTOR,

volumes=[DATA_VOLUME],

volume_mounts=[DATA_VOLUME_MOUNT],

env_vars=[

k8s.V1EnvVar(name='MLFLOW_TRACKING_URI', value='http://mlflow:5000'),

k8s.V1EnvVar(name='CUDA_VISIBLE_DEVICES', value='0'),

],

get_logs=True,

is_delete_operator_pod=True,

do_xcom_push=True,

startup_timeout_seconds=600,

).expand(arguments=hp_args)

===== Stage 4: Optimal Model Selection =====

@task

def select_best_model(training_results):

"""Select the optimal model from all training results"""

valid_results = [r for r in training_results if r is not None]

if not valid_results:

raise ValueError("No valid training results found")

best = max(valid_results, key=lambda x: x.get('accuracy', 0))

print(f"Best model: accuracy={best['accuracy']}, run_id={best['run_id']}")

return best

best_model = select_best_model(train_tasks)

===== Stage 5: Model Evaluation =====

@task(multiple_outputs=True)

def evaluate_model(model_info: dict):

"""Detailed evaluation of the optimal model"""

from mlflow.tracking import MlflowClient

client = MlflowClient("http://mlflow:5000")

Compare with current Production model

try:

prod_versions = client.get_latest_versions(

"production-model", stages=["Production"]

)

if prod_versions:

prod_run = client.get_run(prod_versions[0].run_id)

baseline_accuracy = float(prod_run.data.metrics.get('accuracy', 0))

else:

baseline_accuracy = 0.0

except Exception:

baseline_accuracy = 0.0

new_accuracy = model_info['accuracy']

improvement = new_accuracy - baseline_accuracy

return {

'should_deploy': improvement > 0.005 and new_accuracy > 0.90,

'new_accuracy': new_accuracy,

'baseline_accuracy': baseline_accuracy,

'improvement': improvement,

'run_id': model_info['run_id'],

'model_path': model_info.get('model_path', ''),

}

eval_result = evaluate_model(best_model)

    # ===== Stage 6: Conditional Deployment =====
    def decide_deployment(**context):
        # evaluate_model uses multiple_outputs=True, so each returned key
        # (e.g. 'should_deploy', 'run_id') is also pushed as its own XCom.
        should_deploy = context['ti'].xcom_pull(
            task_ids='evaluate_model', key='should_deploy'
        )
        if should_deploy:
            return 'deploy_model'
        return 'skip_deployment'

    deployment_branch = BranchPythonOperator(
        task_id='deployment_decision',
        python_callable=decide_deployment,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        name='model-deployment',
        namespace='ml-serving',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
        arguments=[
            '--run-id', "{{ ti.xcom_pull(task_ids='evaluate_model', key='run_id') }}",
            '--model-name', 'production-model',
        ],
        container_resources=k8s.V1ResourceRequirements(
            requests={'cpu': '2', 'memory': '4Gi'},
            limits={'cpu': '4', 'memory': '8Gi'},
        ),
        get_logs=True,
        # Delete the pod when it finishes (replaces the deprecated is_delete_operator_pod=True)
        on_finish_action='delete_pod',
    )

    @task(task_id='skip_deployment')
    def skip_deployment():
        print("Model quality does not meet deployment criteria. Skipping.")

    skip = skip_deployment()

    # ===== Stage 7: Notification =====
    # none_failed_min_one_success lets this task run after whichever branch was taken,
    # even though the other branch is always skipped.
    @task(trigger_rule='none_failed_min_one_success')
    def send_notification(**context):
        """Send training result notification"""
        eval_data = context['ti'].xcom_pull(task_ids='evaluate_model', key='return_value')
        print(f"Pipeline completed. Accuracy: {eval_data.get('new_accuracy')}")
        print(f"Improvement: {eval_data.get('improvement')}")
        # Slack / Email notification logic

    notification = send_notification()

    # ===== DAG Dependency Definition =====
    wait_for_data >> data_prep >> hp_args >> hp_tuning
    hp_tuning >> best_model >> eval_result >> deployment_branch
    deployment_branch >> [deploy, skip] >> notification
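The last two stages depend on a quality gate and a join rule. The sketch below restates both in plain Python so the logic can be checked without installing Airflow. It illustrates the behavior and does not reproduce the DAG's actual code. The thresholds are the ones hard-coded in `evaluate_model`. The `decide_deployment` shown here reads a plain dict, whereas the real callable pulls `key='should_deploy'` from XCom. `none_failed_min_one_success` is a simplified model of the Airflow trigger rule with the same name, not Airflow's implementation.

```python
# Plain-Python sketch of the DAG's quality gate, branch, and join (no Airflow import).
# Thresholds are copied from evaluate_model; everything else is illustrative.

MIN_IMPROVEMENT = 0.005  # required gain over the current Production accuracy
MIN_ACCURACY = 0.90      # absolute accuracy floor for any deployed model


def should_deploy(new_accuracy: float, baseline_accuracy: float) -> bool:
    """Same condition as evaluate_model: beat Production by a margin AND clear the floor."""
    improvement = new_accuracy - baseline_accuracy
    return improvement > MIN_IMPROVEMENT and new_accuracy > MIN_ACCURACY


def decide_deployment(eval_result: dict) -> str:
    """Mirrors the BranchPythonOperator callable: return the task_id to follow."""
    return 'deploy_model' if eval_result['should_deploy'] else 'skip_deployment'


def none_failed_min_one_success(upstream_states: list) -> bool:
    """Simplified model of the trigger rule on send_notification:
    no upstream task failed or upstream_failed, and at least one succeeded."""
    failed_states = {'failed', 'upstream_failed'}
    no_failures = not any(state in failed_states for state in upstream_states)
    return no_failures and 'success' in upstream_states


chosen = decide_deployment({'should_deploy': should_deploy(0.95, 0.90)})
# The branch skips the path it did not choose, so exactly one upstream is 'skipped'.
states = ['success', 'skipped'] if chosen == 'deploy_model' else ['skipped', 'success']
print(chosen, none_failed_min_one_success(states))  # deploy_model True
```

With the default `all_success` rule, `send_notification` would be skipped on every run because one branch is always skipped. That is why the join task needs a more permissive trigger rule.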

This DAG integrates the following Airflow features:

- **S3KeySensor**: Waiting for training data to arrive (deferrable mode)

- **KubernetesPodOperator**: Running GPU training Jobs

- **Dynamic Task Mapping**: Parallel hyperparameter training via `expand()`

- **TaskFlow API**: Defining Tasks as Python functions with the `@task` decorator

- **XCom**: Passing training metrics and model paths between Tasks

- **TaskGroup**: Visually separating pipeline stages

- **BranchPythonOperator**: Conditional deployment based on model quality
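Dynamic Task Mapping is the feature that turns one training task into many. When `expand()` receives several mapped keyword arguments, Airflow creates one task instance per combination (a cross product), and a downstream task such as `select_best_model` receives the mapped results as a sequence. The plain-Python sketch below reproduces that fan-out and fan-in pattern. The search space and result values are hypothetical. Representing a failed run as `None` is also an assumption made for this illustration and is not necessarily what the DAG above does.

```python
from itertools import product

# Hypothetical search space (the DAG builds its own via hp_args).
SEARCH_SPACE = {
    'learning_rate': [1e-3, 1e-4],
    'batch_size': [32, 64],
}


def mapped_kwargs(space: dict) -> list:
    """One kwargs dict per mapped task instance: the cross product of all
    parameter lists, which is how expand() combines several mapped arguments."""
    keys = list(space)
    return [dict(zip(keys, values)) for values in product(*(space[k] for k in keys))]


def select_best(results: list) -> dict:
    """Fan-in step modeled on select_best_model: ignore runs that returned None
    (an assumption for this sketch) and keep the highest accuracy."""
    valid_results = [r for r in results if r is not None]
    if not valid_results:
        raise ValueError('no successful training runs to choose from')
    return max(valid_results, key=lambda x: x.get('accuracy', 0))


combos = mapped_kwargs(SEARCH_SPACE)
print(len(combos))  # 4 mapped task instances

# Hypothetical outputs of the four mapped training tasks.
fake_results = [
    {'accuracy': 0.91, 'run_id': 'run-a'},
    None,
    {'accuracy': 0.94, 'run_id': 'run-b'},
    {'accuracy': 0.92, 'run_id': 'run-c'},
]
print(select_best(fake_results)['run_id'])  # run-b
```

In a DAG, a hypothetical `train` task would be mapped with `train.expand(learning_rate=[...], batch_size=[...])`. If you want hand-picked combinations instead of the full grid, use `expand_kwargs()`.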

For a real production deployment, add error handling, SLA settings, and notification integration to make the pipeline more reliable.
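Retries and a failure callback, two common forms of error handling, can be set once through `default_args`, and every task in the DAG inherits them. The sketch below is an example built on assumptions and does not recommend specific values. The retry count, delays, timeout, the DAG id `ml_training_pipeline`, and the `notify_failure` sender are all hypothetical placeholders. The callback reads only `context['ti']` and `context.get('exception')` from the context that Airflow passes to `on_failure_callback`. SLA configuration is left out because it differs across Airflow versions.

```python
from datetime import timedelta
from types import SimpleNamespace


def format_failure_message(context: dict) -> str:
    """Build alert text from an Airflow callback context (reads only ti and exception)."""
    ti = context['ti']
    error = context.get('exception')
    return f"[FAILED] {ti.dag_id}.{ti.task_id} (try {ti.try_number}): {error}"


def notify_failure(context: dict) -> None:
    """Hypothetical on_failure_callback: swap print() for a Slack/email client."""
    print(format_failure_message(context))


# Placeholder values; pass as DAG(..., default_args=default_args) so all tasks inherit them.
default_args = {
    'retries': 2,                             # re-run a failed task up to twice
    'retry_delay': timedelta(minutes=5),      # base wait between attempts
    'retry_exponential_backoff': True,        # grow the wait on each retry
    'execution_timeout': timedelta(hours=6),  # fail tasks that hang (e.g. a stuck GPU job)
    'on_failure_callback': notify_failure,
}

# Local check with a fake context object (no Airflow needed).
fake_context = {
    'ti': SimpleNamespace(dag_id='ml_training_pipeline', task_id='deploy_model', try_number=3),
    'exception': RuntimeError('pod exited with code 1'),
}
notify_failure(fake_context)  # [FAILED] ml_training_pipeline.deploy_model (try 3): pod exited with code 1
```

`on_failure_callback` fires only after all retries are exhausted. `on_retry_callback` is the hook for the intermediate attempts.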

12. References

The content analyzed in this article is based on the following Apache Airflow official documentation and related resources.

Apache Airflow Official Documentation

- [Airflow Core Concepts - Overview](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/overview.html)

- [Airflow Core Concepts - Operators](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/operators.html)

- [Airflow Core Concepts - XComs](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/xcoms.html)

- [Airflow Core Concepts - TaskFlow](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/taskflow.html)

- [Airflow Tutorial - TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/tutorial/taskflow.html)

- [Dynamic Task Mapping](https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/dynamic-task-mapping.html)

- [KubernetesPodOperator](https://airflow.apache.org/docs/apache-airflow-providers-cncf-kubernetes/stable/operators.html)

- [KubernetesPodOperator API Reference](https://airflow.apache.org/docs/apache-airflow-providers-cncf-kubernetes/stable/_api/airflow/providers/cncf/kubernetes/operators/pod/index.html)

- [S3KeySensor (Amazon Provider)](https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/_api/airflow/providers/amazon/aws/sensors/s3/index.html)

- [Object Storage XCom Backend](https://airflow.apache.org/docs/apache-airflow-providers-common-io/stable/xcom_backend.html)

- [Airflow MLOps Use Cases](https://airflow.apache.org/use-cases/mlops/)

Related Resources

- [Astronomer - Best practices for orchestrating MLOps pipelines with Airflow](https://www.astronomer.io/docs/learn/airflow-mlops)

- [Astronomer - Use the KubernetesPodOperator](https://www.astronomer.io/docs/learn/kubepod-operator)

- [Astronomer - Introduction to the TaskFlow API and Airflow decorators](https://www.astronomer.io/docs/learn/airflow-decorators/)
