
Airflow로 ML 학습 파이프라인 오케스트레이션하기

1. ML 파이프라인에서 Airflow의 역할

ML 프로젝트에서 모델 학습은 단순히 model.fit()을 호출하는 것으로 끝나지 않는다. 데이터 수집, 전처리, Feature Engineering, 모델 학습, 평가, 모델 레지스트리 등록, 그리고 배포까지 일련의 단계를 재현 가능하고 자동화된 방식으로 실행해야 한다. 이때 Workflow Orchestration 도구가 필요하며, Apache Airflow는 이 영역에서 가장 성숙한 오픈소스 프로젝트 중 하나다.

Airflow vs 다른 오케스트레이션 도구

ML 파이프라인 오케스트레이션 도구로는 Airflow 외에도 Kubeflow Pipelines, Prefect, Dagster 등이 있다. 각각의 특성을 비교해보자.

Apache Airflow는 범용 워크플로우 오케스트레이터로, ML에 특화된 도구는 아니지만 풍부한 Operator 생태계, 강력한 스케줄링 기능, 그리고 Kubernetes 위에서의 확장성 덕분에 ML 파이프라인에서도 널리 사용된다. Airflow 3.x에서는 DAG versioning, event-driven scheduling, multi-language support 등이 추가되면서 더욱 강력해졌다.

Kubeflow Pipelines는 Kubernetes 네이티브 ML 파이프라인 도구로, ML 워크플로우에 특화되어 있다. GPU 스케줄링, 실험 추적, 모델 서빙까지 통합적으로 제공하지만, Kubernetes에 대한 깊은 이해가 전제되어야 하고 Kubernetes 외의 환경에서는 사용이 어렵다.

Prefect는 "Airflow, but nicer"라는 평가를 받는 도구로, Python-native한 인터페이스와 빠른 셋업이 장점이다. Dynamic flow, hybrid execution, 실시간 SLA alerting 등을 제공하며, 특히 소규모 팀이나 PoC 단계에서 빠르게 적용하기 좋다.

Dagster는 asset-centric 접근 방식을 취하는 현대적인 오케스트레이터다. 데이터 파이프라인을 "실행 단계"가 아닌 "생산하는 데이터"를 중심으로 설계하는 철학을 가지고 있어, 데이터 lineage 추적에 강점이 있다.

| 특성 | Airflow | Kubeflow Pipelines | Prefect | Dagster |
| --- | --- | --- | --- | --- |
| ML 특화 | 범용 | ML 특화 | 범용 | Data-centric |
| 학습 곡선 | 중간 | 높음 | 낮음 | 중간 |
| Kubernetes 의존성 | 선택적 | 필수 | 선택적 | 선택적 |
| 커뮤니티/생태계 | 매우 큼 | 큼 | 성장 중 | 성장 중 |
| GPU 스케줄링 | KubernetesPodOperator | 네이티브 | Kubernetes 연동 | Kubernetes 연동 |
| 스케줄링 | 매우 강력 | 제한적 | 강력 | 강력 |

Airflow를 ML 파이프라인 오케스트레이터로 선택하는 가장 큰 이유는 도구 비종속성(tool-agnostic)이다. Airflow는 MLflow, SageMaker, Vertex AI 등 다양한 ML 도구와 연동 가능하며, ML 이외의 데이터 파이프라인(ETL, 데이터 품질 체크 등)과 동일한 인프라에서 관리할 수 있다.

2. Airflow 핵심 개념 복습: DAG, Operator, Task, XCom

Apache Airflow의 공식 문서에 따르면, 핵심 개념은 다음과 같다.

DAG (Directed Acyclic Graph)

DAG는 실행하고자 하는 모든 Task의 집합이며, Task 간의 관계와 의존성을 반영하는 방식으로 구성된다. DAG는 Python 스크립트로 정의되며, 코드 자체가 DAG의 구조(Task와 그 의존성)를 표현한다. DAG run은 DAG의 인스턴스화(instantiation)이며, 특정 logical_date(과거 명칭은 execution_date)에 대해 실행되는 Task Instance들을 포함한다.

from airflow import DAG
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',  # Airflow 2.4+에서는 schedule 사용 (schedule_interval은 3.0에서 제거)
    catchup=False,
    tags=['ml', 'training'],
) as dag:
    # Task 정의
    pass

Operator

Operator는 미리 정의된 Task의 템플릿이다. 공식 문서에 따르면, "DAG가 워크플로우를 어떻게 실행할지 기술한다면, Operator는 Task가 실제로 무엇을 수행하는지 결정한다." 모든 Operator는 BaseOperator를 상속받으며, PythonOperator, BashOperator, KubernetesPodOperator 등이 대표적이다.

Task

Task는 Operator의 인스턴스이다. DAG 내에서 실행 가능한 단위이며, Task Instance는 특정 DAG run에서 Task의 구체적 실행을 나타낸다.

XCom (Cross-Communication)

XCom은 Task 간 데이터를 주고받는 메커니즘이다. 공식 문서에 따르면, "기본적으로 Task는 완전히 격리되어 있으며 서로 다른 머신에서 실행될 수 있기 때문에" XCom이 필요하다. XCom은 key, task_id, dag_id로 식별된다.

# Push
task_instance.xcom_push(key="model_accuracy", value=0.95)

# Pull
accuracy = task_instance.xcom_pull(key="model_accuracy", task_ids="evaluate_model")

중요한 점은 XCom은 소량의 데이터 전달용으로 설계되었다는 것이다. DataFrame과 같은 대용량 데이터를 XCom으로 전달하면 안 되며, 그런 경우에는 Object Storage Backend를 사용해야 한다.

3. KubernetesPodOperator로 GPU 학습 Job 실행

ML 학습 파이프라인에서 KubernetesPodOperator는 핵심적인 역할을 한다. 공식 문서에 따르면, KubernetesPodOperator는 "Kubernetes object spec 정의의 대체물"로, DAG context 내에서 Airflow 스케줄러에 의해 실행될 수 있다. 즉, Pod 스펙을 위한 별도의 YAML/JSON 파일을 작성할 필요 없이, Python 코드로 직접 Kubernetes Pod을 정의하고 실행할 수 있다.

GPU 기반 ML 학습에서 KubernetesPodOperator가 특히 유용한 이유는 다음과 같다.

  • Task 레벨 리소스 설정: 각 Task마다 CPU, Memory, GPU를 독립적으로 할당할 수 있다
  • 커스텀 의존성 지원: PyPI에 없는 패키지나 특정 CUDA 버전이 필요한 경우, 커스텀 Docker 이미지를 지정할 수 있다
  • 언어 무관성: Python뿐 아니라 어떤 언어로 작성된 학습 코드든 컨테이너로 감싸서 실행할 수 있다
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

train_model = KubernetesPodOperator(
    task_id='train_model',
    name='gpu-training-job',
    namespace='ml-training',
    image='my-registry/ml-trainer:latest',
    cmds=['python'],
    arguments=['train.py', '--epochs', '100', '--batch-size', '64'],
    container_resources=k8s.V1ResourceRequirements(
        requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
        limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
    ),
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
)

do_xcom_push=True로 설정하면, Pod 내에서 /airflow/xcom/return.json 파일에 JSON을 작성하여 XCom으로 데이터를 반환할 수 있다. 학습 결과 메트릭을 다음 Task로 전달할 때 유용하다.
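Pod 내부에서 실행되는 학습 스크립트 쪽의 작성 로직은 대략 다음과 같이 스케치할 수 있다. `/airflow/xcom/return.json` 경로는 KubernetesPodOperator의 실제 규약이지만, `write_xcom_result`라는 함수 이름과 인터페이스는 설명을 위한 가정이다.

```python
import json
import os

def write_xcom_result(metrics: dict, path: str = "/airflow/xcom/return.json") -> None:
    """Pod 안에서 학습 결과를 XCom 파일로 기록한다.

    do_xcom_push=True이면 sidecar 컨테이너가 이 경로의 JSON을 읽어
    해당 Task의 XCom(return_value)으로 push한다.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(metrics, f)

# 사용 예: 학습 스크립트 마지막에 호출
# write_xcom_result({'accuracy': 0.95, 'run_id': 'exp-42'})
```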

4. KubernetesPodOperator 주요 파라미터 분석

공식 문서를 기반으로 KubernetesPodOperator의 주요 파라미터를 ML 학습 관점에서 분석한다.

container_resources

GPU 학습에서 가장 중요한 파라미터다. kubernetes.client.models.V1ResourceRequirements를 사용하여 CPU, Memory, GPU 리소스를 요청(requests)하고 제한(limits)할 수 있다.

container_resources = k8s.V1ResourceRequirements(
    requests={
        'cpu': '4',
        'memory': '16Gi',
        'nvidia.com/gpu': '1'
    },
    limits={
        'cpu': '8',
        'memory': '32Gi',
        'nvidia.com/gpu': '1'
    },
)

GPU의 경우 nvidia.com/gpu 리소스 키를 사용하며, 일반적으로 requests와 limits를 동일하게 설정한다. GPU는 분할 할당이 불가능하기 때문이다(MIG를 사용하지 않는 한).
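"GPU는 requests와 limits를 동일하게"라는 규칙을 코드로 강제하고 싶다면 리소스 딕셔너리를 만들어 주는 작은 헬퍼를 둘 수 있다. `build_gpu_resources`는 설명을 위한 가상의 헬퍼이며, 결과 딕셔너리를 `k8s.V1ResourceRequirements(**result)`에 넘기는 용도를 가정한다.

```python
def build_gpu_resources(gpu: int = 1, cpu_req: str = '4', cpu_lim: str = '8',
                        mem_req: str = '16Gi', mem_lim: str = '32Gi') -> dict:
    """GPU 수량을 requests == limits로 고정한 리소스 딕셔너리를 생성한다.

    GPU는 (MIG를 쓰지 않는 한) 분할 할당이 불가능하므로
    requests와 limits에 항상 같은 값을 넣는다.
    """
    gpu_str = str(gpu)
    return {
        'requests': {'cpu': cpu_req, 'memory': mem_req, 'nvidia.com/gpu': gpu_str},
        'limits': {'cpu': cpu_lim, 'memory': mem_lim, 'nvidia.com/gpu': gpu_str},
    }
```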

tolerations

GPU 노드는 보통 taint가 설정되어 일반 워크로드가 스케줄링되지 않도록 관리된다. tolerations를 통해 GPU 노드에 Pod이 스케줄링될 수 있도록 허용한다.

tolerations = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]

node_selector

특정 노드 그룹에서만 학습 Pod이 실행되도록 지정할 수 있다. 예를 들어, A100 GPU가 장착된 노드에서만 실행하고 싶다면 다음과 같이 설정한다.

node_selector = {
    'gpu-type': 'a100',
    'node-pool': 'ml-training',
}

affinity

더 세밀한 스케줄링 제어가 필요할 때 V1Affinity를 사용한다. Node affinity, Pod affinity/anti-affinity를 설정하여 특정 zone이나 노드에 Pod을 배치할 수 있다.

affinity = k8s.V1Affinity(
    node_affinity=k8s.V1NodeAffinity(
        required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
            node_selector_terms=[
                k8s.V1NodeSelectorTerm(
                    match_expressions=[
                        k8s.V1NodeSelectorRequirement(
                            key='gpu-type',
                            operator='In',
                            values=['a100', 'h100'],
                        )
                    ]
                )
            ]
        )
    )
)

volumes / volume_mounts

학습 데이터셋이나 체크포인트를 PVC(PersistentVolumeClaim)로 마운트할 때 사용한다.

volume = k8s.V1Volume(
    name='training-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-dataset-pvc'
    ),
)

volume_mount = k8s.V1VolumeMount(
    name='training-data',
    mount_path='/data',
    read_only=True,
)

기타 주요 파라미터

| 파라미터 | 설명 | ML 학습 활용 |
| --- | --- | --- |
| image_pull_secrets | Private Registry 인증 | 사내 ML 이미지 접근 |
| env_vars | 환경 변수 목록 | CUDA_VISIBLE_DEVICES, WANDB_API_KEY 등 |
| secrets | Secret 볼륨/환경 변수 | API Key, 인증 정보 관리 |
| is_delete_operator_pod | 완료 후 Pod 삭제 | 리소스 정리(True 권장) |
| get_logs | 로그를 Airflow UI에 표시 | 학습 로그 모니터링 |
| deferrable | 비동기 실행 모드 | 장시간 학습 시 Worker 효율화 |
| on_finish_action | 완료 후 동작 | delete_pod으로 자동 정리 |

공식 문서에서는 type safety를 위해 편의 클래스(convenience classes) 대신 kubernetes.client.models의 네이티브 객체를 사용하도록 권장한다.

5. DAG 설계 패턴: 데이터 전처리 -> 학습 -> 평가 -> 배포

ML 학습 파이프라인의 일반적인 DAG 구조는 다음과 같은 단계를 포함한다.

데이터 검증 -> 전처리 -> Feature Engineering -> 학습 -> 평가 -> [조건부] 배포

각 단계를 Airflow Task로 매핑하면 다음과 같은 설계 패턴이 된다.

from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
) as dag:

    validate_data = KubernetesPodOperator(
        task_id='validate_data',
        image='ml-tools:latest',
        cmds=['python', 'validate.py'],
        # CPU만 필요
    )

    preprocess = KubernetesPodOperator(
        task_id='preprocess_data',
        image='ml-tools:latest',
        cmds=['python', 'preprocess.py'],
        # 대용량 메모리 필요
    )

    train = KubernetesPodOperator(
        task_id='train_model',
        image='ml-trainer:latest',
        cmds=['python', 'train.py'],
        # GPU 필요
        do_xcom_push=True,
    )

    evaluate = KubernetesPodOperator(
        task_id='evaluate_model',
        image='ml-trainer:latest',
        cmds=['python', 'evaluate.py'],
        do_xcom_push=True,
    )

    def check_model_quality(**context):
        metrics = context['ti'].xcom_pull(task_ids='evaluate_model')
        if metrics['accuracy'] > 0.90:
            return 'deploy_model'
        return 'notify_failure'

    branch = BranchPythonOperator(
        task_id='check_quality',
        python_callable=check_model_quality,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
    )

    notify_failure = KubernetesPodOperator(
        task_id='notify_failure',
        image='ml-tools:latest',
        cmds=['python', 'notify.py'],  # 실패 알림 스크립트 (예시)
    )

    validate_data >> preprocess >> train >> evaluate >> branch
    branch >> [deploy, notify_failure]

이 설계 패턴의 핵심은 조건부 배포다. 평가 단계에서 모델 성능 메트릭을 XCom으로 전달하고, BranchPythonOperator를 통해 품질 기준을 충족하는 경우에만 배포 Task가 실행된다. 이는 품질이 낮은 모델이 프로덕션에 배포되는 것을 방지한다.
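분기 판정 로직은 Airflow에 의존하지 않는 순수 함수로 분리해 두면 단위 테스트가 쉬워진다. 아래는 위 `check_model_quality`의 판정 부분만 떼어 낸 스케치로, 임계값 0.90은 본문 예시를 따른 가정이다.

```python
def decide_next_task(metrics: dict, threshold: float = 0.90) -> str:
    """평가 메트릭을 보고 BranchPythonOperator가 따라갈 task_id를 반환한다.

    메트릭이 없거나 기준 미달이면 알림 Task로 분기한다.
    """
    if metrics.get('accuracy', 0.0) > threshold:
        return 'deploy_model'
    return 'notify_failure'
```

DAG 쪽에서는 `python_callable` 내부에서 XCom을 pull한 뒤 이 함수에 위임하는 식으로 쓸 수 있다.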

TaskGroup을 활용한 단계 구분

복잡한 파이프라인에서는 TaskGroup을 활용하여 시각적으로 단계를 구분할 수 있다. 공식 문서에 따르면, "TaskGroup은 Graph View에서 Task를 계층적 그룹으로 조직화하는데 사용된다. 반복 패턴을 만들고 시각적 복잡성을 줄이는 데 유용하다."

from airflow.utils.task_group import TaskGroup

with TaskGroup('data_preparation') as data_prep:
    validate_data >> preprocess >> feature_engineering

with TaskGroup('model_training') as training:
    train >> evaluate

data_prep >> training >> branch

6. XCom으로 태스크 간 메트릭/아티팩트 전달

공식 문서에 따르면, XCom은 xcom_push와 xcom_pull 메서드를 통해 명시적으로 저장소에 "push" 및 "pull" 된다. ML 파이프라인에서 XCom은 다음과 같은 용도로 활용된다.

학습 메트릭 전달

# train.py (KubernetesPodOperator 내부에서 실행)
import json

metrics = {
    'accuracy': 0.9534,
    'f1_score': 0.9421,
    'loss': 0.0312,
    'model_path': 's3://ml-models/experiment-42/model.pt',
    'run_id': 'exp-42-20260301',
}

# /airflow/xcom/return.json에 JSON 작성
with open('/airflow/xcom/return.json', 'w') as f:
    json.dump(metrics, f)

Auto-Push와 Return Value

많은 Operator와 @task 데코레이터는 do_xcom_push=True(기본값)일 때 반환값을 return_value라는 키의 XCom으로 자동 push한다. 이후 다음과 같이 pull할 수 있다.

value = task_instance.xcom_pull(task_ids='train_model')

Multiple Outputs

딕셔너리를 반환하며 multiple_outputs=True를 설정하면, 각 키가 개별 XCom으로 저장된다.

@task(multiple_outputs=True)
def evaluate_model(**context):
    return {
        'accuracy': 0.95,
        'f1_score': 0.94,
        'model_path': 's3://models/latest.pt',
    }

이후 개별 키로 pull하거나 전체 return_value로 pull할 수 있다.

Custom XCom Backend

기본 XCom Backend인 BaseXCom은 Airflow 데이터베이스에 XCom을 저장한다. 소량의 데이터에는 문제가 없지만, 모델 아티팩트나 대용량 데이터를 다룰 때는 Custom Backend가 필요하다. BaseXCom을 상속받아 serialize_value와 deserialize_value 메서드를 오버라이드하면 된다.

Object Storage Backend를 사용하려면 xcom_backend 설정을 airflow.providers.common.io.xcom.backend.XComObjectStorageBackend로 지정한다. 이를 통해 S3나 GCS에 XCom 데이터를 저장할 수 있다.

[core]
xcom_backend = airflow.providers.common.io.xcom.backend.XComObjectStorageBackend

XCom 사용 시 주의사항

공식 문서에서 강조하는 핵심 사항은 다음과 같다.

  • XCom은 소량의 데이터 전달용으로 설계되었다. DataFrame 등 대용량 데이터를 전달하면 안 된다
  • Task가 실패 후 재시도될 때, 이전 XCom은 자동으로 클리어되어 idempotent 실행을 보장한다
  • XCom 작업은 get_current_context()를 통한 Task Context에서 수행해야 하며, 직접 DB 업데이트는 지원되지 않는다

ML 파이프라인에서의 실용적 가이드라인으로는, 메트릭(accuracy, loss 등)과 경로 정보(model_path, artifact_uri 등)는 XCom으로 전달하고, 실제 모델 파일이나 데이터셋은 S3/GCS 같은 Object Storage에 저장한 뒤 경로만 XCom으로 공유하는 것이 바람직하다.
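이 가이드라인을 코드로 옮기면 "큰 아티팩트는 저장소에 쓰고, XCom에는 참조와 메트릭만 남긴다"는 형태가 된다. 아래는 Object Storage를 로컬 디렉터리로 흉내 낸 설명용 스케치이며, `save_artifact`와 `build_xcom_payload`는 가상의 헬퍼 이름이다.

```python
import json
import os

def save_artifact(payload: bytes, base_dir: str, name: str) -> str:
    """아티팩트를 저장소(여기서는 로컬 디렉터리로 가정)에 쓰고 경로를 반환한다."""
    os.makedirs(base_dir, exist_ok=True)
    path = os.path.join(base_dir, name)
    with open(path, 'wb') as f:
        f.write(payload)
    return path

def build_xcom_payload(metrics: dict, artifact_path: str) -> str:
    """XCom에는 메트릭과 경로 같은 소량 데이터만 JSON 문자열로 담는다."""
    return json.dumps({**metrics, 'model_path': artifact_path})
```

실제 환경에서는 `save_artifact`가 S3/GCS 업로드가 되고, 반환된 URI만 XCom으로 흐르게 된다.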

7. Dynamic Task Mapping으로 하이퍼파라미터 튜닝

Dynamic Task Mapping은 Airflow 2.3에서 도입된 기능으로, 공식 문서에 따르면 "런타임에 현재 데이터를 기반으로 여러 Task를 생성할 수 있게 해주는 메커니즘"이다. DAG 작성자가 사전에 몇 개의 Task가 필요한지 알 필요가 없다.

ML 하이퍼파라미터 튜닝에서 이 기능은 매우 강력하다. 다양한 하이퍼파라미터 조합으로 병렬 학습을 실행할 수 있기 때문이다.

expand()와 partial()

expand()는 매핑할 파라미터를, partial()은 모든 Task에 공통으로 전달할 고정 파라미터를 지정한다.

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

hyperparams = [
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'adam'},
    {'lr': '0.01', 'batch_size': '64', 'optimizer': 'adam'},
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'sgd'},
    {'lr': '0.0001', 'batch_size': '128', 'optimizer': 'adamw'},
]

train_tasks = KubernetesPodOperator.partial(
    task_id='hyperparam_training',
    image='ml-trainer:latest',
    namespace='ml-training',
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
    container_resources=k8s.V1ResourceRequirements(
        requests={'nvidia.com/gpu': '1'},
        limits={'nvidia.com/gpu': '1'},
    ),
).expand(
    arguments=[
        ['train.py', '--lr', hp['lr'], '--batch-size', hp['batch_size'], '--optimizer', hp['optimizer']]
        for hp in hyperparams
    ],
)

이 코드는 4개의 하이퍼파라미터 조합에 대해 자동으로 4개의 병렬 KubernetesPodOperator Task를 생성한다.

Task에서 동적으로 매핑 데이터 생성

더 나아가, 상위 Task에서 동적으로 하이퍼파라미터 조합을 생성할 수도 있다.

@task
def generate_hyperparams():
    """Grid Search 또는 Random Search로 하이퍼파라미터 생성"""
    import itertools

    learning_rates = [0.001, 0.01, 0.0001]
    batch_sizes = [32, 64, 128]
    optimizers = ['adam', 'sgd']

    combinations = list(itertools.product(learning_rates, batch_sizes, optimizers))
    return [
        {'lr': str(lr), 'batch_size': str(bs), 'optimizer': opt}
        for lr, bs, opt in combinations
    ]

hp_list = generate_hyperparams()

공식 문서에 따르면, "Task에서 생성된 매핑은 trigger_rule=TriggerRule.ALWAYS 사용을 금지"한다는 제약이 있다.

Cross-Product Mapping

여러 expand() 파라미터를 지정하면 모든 조합(cross product)이 생성된다.

@task
def train(lr: float, batch_size: int):
    # 학습 로직
    pass

train.expand(lr=[0.001, 0.01], batch_size=[32, 64])
# 2 x 2 = 4개의 Task Instance 생성
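이 cross-product 의미는 itertools.product와 같다고 볼 수 있다. 매핑 전에 조합 수를 미리 계산해 max_map_length(기본 1024)를 넘지 않는지 점검하는 용도로 쓸 수 있는데, `expected_map_size`는 설명을 위한 가상의 헬퍼다.

```python
import itertools

def expected_map_size(**expand_kwargs) -> int:
    """expand()에 넘길 인자들의 cross-product 크기를 계산한다.

    매핑 인스턴스 수가 [core] max_map_length를 넘으면 Task가
    실패하므로, DAG 작성 시 미리 검증하는 데 활용할 수 있다.
    """
    size = 1
    for values in expand_kwargs.values():
        size *= len(values)
    return size

# expand(lr=[0.001, 0.01], batch_size=[32, 64])와 같은 2 x 2 = 4개 조합
combos = list(itertools.product([0.001, 0.01], [32, 64]))
```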

Map-Reduce 패턴

매핑된 Task들의 결과를 수집하여 최적 모델을 선택하는 패턴이다.

@task
def select_best_model(results):
    """모든 하이퍼파라미터 조합의 결과 중 최선을 선택"""
    best = max(results, key=lambda x: x['accuracy'])
    return best

training_results = train_task.expand(params=hyperparams)
best = select_best_model(training_results)

공식 문서에 따르면, 수집된 결과는 "lazy proxy" sequence로 반환되며, eager list가 아님에 주의해야 한다.
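lazy sequence라는 특성 때문에, 인덱싱이나 len()이 필요하면 list()로 먼저 실체화하는 편이 안전하다. 아래는 select_best_model의 판정 부분을 generator 입력에서도 동작하도록 스케치한 것으로, None(실패한 매핑 인스턴스)을 걸러내는 처리도 가정으로 넣었다.

```python
def pick_best(results) -> dict:
    """매핑된 학습 결과(lazy sequence일 수 있음)에서 최고 accuracy를 고른다."""
    materialized = [r for r in results if r is not None]  # lazy -> eager 변환
    if not materialized:
        raise ValueError("no valid training results")
    return max(materialized, key=lambda r: r.get('accuracy', 0.0))

# lazy proxy를 흉내 낸 generator 입력 예시
lazy = (r for r in [{'accuracy': 0.91}, None, {'accuracy': 0.95}])
```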

제약 사항

| 항목 | 설정 | 기본값 |
| --- | --- | --- |
| 최대 매핑 인스턴스 수 | [core] max_map_length | 1024 |
| Task당 병렬 실행 제한 | max_active_tis_per_dag | Task별 설정 |

매핑 가능한 데이터 타입은 list와 dict만 가능하며, 다른 타입은 UnmappableXComTypePushed 에러가 발생한다. 또한 공식 문서에 따르면, "필드가 templated로 표시되어 있고 매핑된 경우, 해당 필드는 템플릿 처리되지 않는다."

8. Sensor를 활용한 데이터 도착 감지

ML 파이프라인에서는 학습 데이터가 준비될 때까지 대기해야 하는 경우가 많다. Airflow의 Sensor는 특정 조건이 충족될 때까지 대기하는 특수한 Operator다.

S3KeySensor

S3에 학습 데이터가 업로드될 때까지 대기한다.

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_data = S3KeySensor(
    task_id='wait_for_training_data',
    bucket_name='ml-datasets',
    bucket_key='daily/{{ ds }}/training_data.parquet',
    aws_conn_id='aws_default',
    poke_interval=300,  # 5분마다 확인
    timeout=3600 * 6,   # 최대 6시간 대기
    mode='reschedule',  # Worker slot 반환
    deferrable=True,    # Triggerer를 사용한 비동기 대기
)

mode='reschedule'은 poke 사이에 Worker slot을 반환하여 리소스 효율성을 높인다. deferrable=True를 설정하면 Triggerer 컴포넌트가 비동기적으로 polling을 처리하여 더욱 효율적인 Worker 활용이 가능하다.
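Sensor의 핵심 동작만 떼어 보면 "조건 검사 함수를 주기적으로 호출하되, 한도를 넘으면 timeout"이라는 단순한 루프다. 아래는 S3 API 대신 임의의 predicate로 단순화한 설명용 스케치이며, 실제 S3KeySensor의 내부 구현과는 다르다는 점을 전제한다.

```python
def poke_until(condition, max_pokes: int) -> bool:
    """condition()이 True가 될 때까지 최대 max_pokes번 검사한다.

    실제 Sensor는 poke 사이에 poke_interval만큼 대기하고,
    mode='reschedule'이면 그동안 Worker slot을 반환하며,
    deferrable=True이면 Triggerer가 비동기로 이 루프를 대신한다.
    """
    for _ in range(max_pokes):
        if condition():
            return True
        # 실제로는 여기서 sleep(poke_interval) 또는 reschedule
    return False  # timeout에 해당
```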

ExternalTaskSensor

다른 DAG의 Task가 완료될 때까지 대기한다. 예를 들어, 데이터 전처리 DAG가 완료된 후에 학습 DAG를 시작하는 패턴이다.

from datetime import timedelta

from airflow.sensors.external_task import ExternalTaskSensor

wait_for_preprocessing = ExternalTaskSensor(
    task_id='wait_for_preprocessing',
    external_dag_id='data_preprocessing_pipeline',
    external_task_id='final_validation',
    allowed_states=['success'],
    execution_delta=timedelta(hours=0),
    timeout=3600,
    poke_interval=60,
    mode='reschedule',
)

allowed_states는 허용되는 상태 목록으로, 기본값은 ['success']다. execution_delta는 확인할 이전 실행과의 시간 차이를 지정한다.

실용적 패턴: 데이터 도착 후 학습 시작

wait_for_data >> validate_data >> preprocess >> train >> evaluate >> deploy

이 패턴은 데이터 파이프라인과 ML 파이프라인을 느슨하게 결합(loosely couple)하면서도, 데이터가 준비된 후에만 학습이 시작되도록 보장한다.

9. MLflow와 Airflow 연동 패턴

Airflow와 MLflow의 연동에서 역할 분담은 명확하다. Airflow는 "언제, 어떤 순서로 실행할지"를 관리하고, MLflow는 "실행 과정에서 무엇이 일어났는지, 모델이 어디에 있는지"를 기록한다. 이러한 관심사의 분리(separation of concerns)를 통해, 데이터 엔지니어는 ML 코드를 건드리지 않고 Airflow DAG를 관리하고, 데이터 사이언티스트는 스케줄링 인프라를 걱정하지 않고 모델 개발에 집중할 수 있다.

연동 패턴 1: 학습 Task 내부에서 MLflow 직접 호출

가장 직관적인 패턴이다. KubernetesPodOperator로 실행되는 학습 컨테이너 내부에서 MLflow API를 직접 호출한다.

# train.py (컨테이너 내부에서 실행)
import mlflow
import mlflow.pytorch

mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("daily-training")

with mlflow.start_run() as run:
    # 하이퍼파라미터 로깅
    mlflow.log_params({
        'learning_rate': 0.001,
        'batch_size': 64,
        'epochs': 100,
    })

    # 모델 학습
    model = train_model(...)

    # 메트릭 로깅
    mlflow.log_metrics({
        'accuracy': 0.95,
        'f1_score': 0.94,
    })

    # 모델 등록
    mlflow.pytorch.log_model(model, "model")

    # run_id를 XCom으로 전달
    import json
    with open('/airflow/xcom/return.json', 'w') as f:
        json.dump({'run_id': run.info.run_id}, f)

연동 패턴 2: Airflow에서 MLflow Model Registry 활용

학습 완료 후, 모델 평가 결과에 따라 MLflow Model Registry에 모델을 등록하고 스테이지를 전환한다.

@task
def register_model(run_id: str, accuracy: float):
    import mlflow
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # 모델 등록
    model_uri = f"runs:/{run_id}/model"
    mv = mlflow.register_model(model_uri, "production-model")

    # 정확도가 기준을 넘으면 Production 스테이지로 전환
    if accuracy > 0.93:
        client.transition_model_version_stage(
            name="production-model",
            version=mv.version,
            stage="Production",
        )

연동 패턴 3: MLflow에서 학습 이력 조회 후 비교

@task
def compare_with_baseline():
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # Production 모델의 메트릭 조회
    prod_versions = client.get_latest_versions("production-model", stages=["Production"])
    if prod_versions:
        prod_run = client.get_run(prod_versions[0].run_id)
        baseline_accuracy = float(prod_run.data.metrics['accuracy'])
        return {'baseline_accuracy': baseline_accuracy}
    return {'baseline_accuracy': 0.0}

10. TaskFlow API 활용 (@task 데코레이터)

TaskFlow API는 공식 문서에 따르면 "데코레이터를 사용하여 DAG와 Task를 정의하는 함수형 API"로, Task 간 데이터 전달과 의존성 정의를 크게 단순화한다.

핵심 특성

TaskFlow의 가장 큰 장점은 자동 XCom 관리와 자동 의존성 계산이다. TaskFlow 함수를 호출하면 실행되는 것이 아니라, 결과를 나타내는 XComArg 객체가 반환된다. 이를 다운스트림 Task의 입력으로 사용하면 의존성이 자동으로 계산된다.

from airflow.sdk import dag, task
from datetime import datetime

@dag(
    schedule='@daily',
    start_date=datetime(2026, 1, 1),
    catchup=False,
)
def ml_training_taskflow():

    @task
    def extract_data():
        """데이터 추출"""
        return {'data_path': 's3://bucket/data/2026-03-01.parquet', 'num_rows': 100000}

    @task(multiple_outputs=True)
    def preprocess(data_info: dict):
        """데이터 전처리"""
        return {
            'train_path': f"{data_info['data_path']}/train.parquet",
            'test_path': f"{data_info['data_path']}/test.parquet",
            'feature_count': 128,
        }

    @task
    def train_model(train_path: str, feature_count: int):
        """모델 학습"""
        # 실제로는 KubernetesPodOperator로 GPU 학습 실행
        return {
            'model_path': 's3://models/latest.pt',
            'accuracy': 0.95,
        }

    @task
    def evaluate_and_deploy(model_info: dict):
        """모델 평가 및 조건부 배포"""
        if model_info['accuracy'] > 0.90:
            return f"Deployed model from {model_info['model_path']}"
        return "Model quality below threshold, skipping deployment"

    # 자동 의존성 연결
    data = extract_data()
    processed = preprocess(data)
    model = train_model(
        train_path=processed['train_path'],
        feature_count=processed['feature_count'],
    )
    evaluate_and_deploy(model)

ml_training_taskflow()

Context 접근

Task는 Airflow context 변수를 keyword argument로 받을 수 있다.

@task
def log_execution_info(task_instance=None, dag_run=None):
    print(f"Run ID: {task_instance.run_id}")
    print(f"DAG Run: {dag_run.dag_id}")

객체 직렬화

TaskFlow는 @dataclass, @attr.define 또는 커스텀 serialize()/deserialize() 메서드를 통한 커스텀 객체 전달을 지원한다. __version__: ClassVar[int]를 사용하여 버전 관리도 가능하다.
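이 규약을 흉내 낸 최소 스케치는 다음과 같다. ModelInfo라는 클래스와 serialize/deserialize의 구체 시그니처는 설명을 위한 가정이며, 실제 훅의 정확한 시그니처는 공식 문서를 따라야 한다.

```python
from dataclasses import dataclass
from typing import ClassVar

@dataclass
class ModelInfo:
    """Task 간에 XCom으로 전달할 커스텀 객체의 스케치."""
    __version__: ClassVar[int] = 1  # ClassVar이므로 __init__ 필드에서 제외됨

    model_path: str
    accuracy: float

    def serialize(self) -> dict:
        """XCom에 저장할 수 있는 순수 dict로 변환한다."""
        return {'model_path': self.model_path, 'accuracy': self.accuracy}

    @staticmethod
    def deserialize(data: dict, version: int) -> 'ModelInfo':
        """저장 당시 버전을 받아 객체를 복원한다."""
        if version > ModelInfo.__version__:
            raise TypeError(f"unsupported serialization version: {version}")
        return ModelInfo(**data)
```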

11. 실전 DAG 코드 예시 (전체 ML 파이프라인)

지금까지 다룬 모든 개념을 통합한 실전 ML 학습 파이프라인 DAG 코드다. 데이터 도착 감지부터 하이퍼파라미터 튜닝, 최적 모델 선택, 배포까지 전 과정을 포함한다.

"""
ML Training Pipeline DAG
- S3 데이터 도착 감지
- 데이터 전처리 (KubernetesPodOperator)
- Dynamic Task Mapping을 활용한 하이퍼파라미터 튜닝
- 최적 모델 선택 및 MLflow 등록
- 조건부 배포
"""

from __future__ import annotations

from datetime import datetime, timedelta
from airflow import DAG
from airflow.sdk import task
from airflow.operators.python import BranchPythonOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.utils.task_group import TaskGroup
from kubernetes.client import models as k8s

# -- 공통 설정 --
GPU_RESOURCES = k8s.V1ResourceRequirements(
    requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
    limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
)

GPU_TOLERATIONS = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]

GPU_NODE_SELECTOR = {'gpu-type': 'a100', 'node-pool': 'ml-training'}

DATA_VOLUME = k8s.V1Volume(
    name='shared-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-shared-pvc'
    ),
)

DATA_VOLUME_MOUNT = k8s.V1VolumeMount(
    name='shared-data',
    mount_path='/shared',
)

default_args = {
    'owner': 'ml-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
    'execution_timeout': timedelta(hours=4),
}


with DAG(
    dag_id='ml_full_training_pipeline',
    default_args=default_args,
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
    tags=['ml', 'training', 'production'],
    max_active_runs=1,
    doc_md="""
    ## ML Full Training Pipeline
    매일 새로운 데이터로 모델을 재학습하고,
    품질 기준을 충족하면 자동 배포하는 파이프라인.
    """,
) as dag:

    # ===== Stage 1: 데이터 도착 감지 =====
    wait_for_data = S3KeySensor(
        task_id='wait_for_training_data',
        bucket_name='ml-datasets',
        bucket_key='daily/{{ ds }}/features.parquet',
        aws_conn_id='aws_default',
        poke_interval=300,
        timeout=3600 * 6,
        mode='reschedule',
        deferrable=True,
    )

    # ===== Stage 2: 데이터 전처리 =====
    with TaskGroup('data_preparation') as data_prep:

        validate_data = KubernetesPodOperator(
            task_id='validate_data',
            name='data-validation',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'validate.py'],
            arguments=['--date', '{{ ds }}', '--source', 's3://ml-datasets/daily/{{ ds }}/'],
            env_vars=[
                k8s.V1EnvVar(name='DATA_DATE', value='{{ ds }}'),
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '2', 'memory': '8Gi'},
                limits={'cpu': '4', 'memory': '16Gi'},
            ),
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )

        preprocess = KubernetesPodOperator(
            task_id='preprocess_data',
            name='data-preprocessing',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'preprocess.py'],
            arguments=[
                '--input', 's3://ml-datasets/daily/{{ ds }}/',
                '--output', 's3://ml-processed/{{ ds }}/',
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '4', 'memory': '32Gi'},
                limits={'cpu': '8', 'memory': '64Gi'},
            ),
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )

        validate_data >> preprocess

    # ===== Stage 3: 하이퍼파라미터 생성 및 병렬 학습 =====
    @task
    def generate_hyperparams(ds=None):
        """학습할 하이퍼파라미터 조합 생성 (ds는 Airflow context에서 주입)"""
        import itertools
        learning_rates = [0.001, 0.0001]
        batch_sizes = [32, 64]
        optimizers = ['adam', 'adamw']

        combos = list(itertools.product(learning_rates, batch_sizes, optimizers))
        # TaskFlow 반환값은 Jinja 템플릿 처리가 되지 않으므로
        # '{{ ds }}' 문자열 대신 주입받은 ds 값을 직접 사용한다
        return [
            [
                'train.py',
                '--lr', str(lr),
                '--batch-size', str(bs),
                '--optimizer', opt,
                '--data-path', f's3://ml-processed/{ds}/',
                '--experiment-name', f'daily-training-{ds}',
            ]
            for lr, bs, opt in combos
        ]

    hp_args = generate_hyperparams()

    with TaskGroup('hyperparameter_tuning') as hp_tuning:

        train_tasks = KubernetesPodOperator.partial(
            task_id='train_with_hp',
            name='gpu-hp-training',
            namespace='ml-training',
            image='ml-trainer:latest',
            cmds=['python'],
            container_resources=GPU_RESOURCES,
            tolerations=GPU_TOLERATIONS,
            node_selector=GPU_NODE_SELECTOR,
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            env_vars=[
                k8s.V1EnvVar(name='MLFLOW_TRACKING_URI', value='http://mlflow:5000'),
                k8s.V1EnvVar(name='CUDA_VISIBLE_DEVICES', value='0'),
            ],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
            startup_timeout_seconds=600,
        ).expand(arguments=hp_args)

    # ===== Stage 4: 최적 모델 선택 =====
    @task
    def select_best_model(training_results):
        """모든 학습 결과 중 최적 모델 선택"""
        valid_results = [r for r in training_results if r is not None]
        if not valid_results:
            raise ValueError("No valid training results found")

        best = max(valid_results, key=lambda x: x.get('accuracy', 0))
        print(f"Best model: accuracy={best['accuracy']}, run_id={best['run_id']}")
        return best

    best_model = select_best_model(train_tasks.output)

    # ===== Stage 5: 모델 평가 =====
    @task(multiple_outputs=True)
    def evaluate_model(model_info: dict):
        """최적 모델에 대한 상세 평가"""
        import mlflow
        from mlflow.tracking import MlflowClient

        client = MlflowClient("http://mlflow:5000")

        # 현재 Production 모델과 비교
        try:
            prod_versions = client.get_latest_versions(
                "production-model", stages=["Production"]
            )
            if prod_versions:
                prod_run = client.get_run(prod_versions[0].run_id)
                baseline_accuracy = float(prod_run.data.metrics.get('accuracy', 0))
            else:
                baseline_accuracy = 0.0
        except Exception:
            baseline_accuracy = 0.0

        new_accuracy = model_info['accuracy']
        improvement = new_accuracy - baseline_accuracy

        return {
            'should_deploy': improvement > 0.005 and new_accuracy > 0.90,
            'new_accuracy': new_accuracy,
            'baseline_accuracy': baseline_accuracy,
            'improvement': improvement,
            'run_id': model_info['run_id'],
            'model_path': model_info.get('model_path', ''),
        }

    eval_result = evaluate_model(best_model)

    # ===== Stage 6: 조건부 배포 =====
    def decide_deployment(**context):
        should_deploy = context['ti'].xcom_pull(
            task_ids='evaluate_model', key='should_deploy'
        )
        if should_deploy:
            return 'deploy_model'
        return 'skip_deployment'

    deployment_branch = BranchPythonOperator(
        task_id='deployment_decision',
        python_callable=decide_deployment,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        name='model-deployment',
        namespace='ml-serving',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
        arguments=[
            '--run-id', "{{ ti.xcom_pull(task_ids='evaluate_model', key='run_id') }}",
            '--model-name', 'production-model',
        ],
        container_resources=k8s.V1ResourceRequirements(
            requests={'cpu': '2', 'memory': '4Gi'},
            limits={'cpu': '4', 'memory': '8Gi'},
        ),
        get_logs=True,
        is_delete_operator_pod=True,
    )

    @task(task_id='skip_deployment')
    def skip_deployment():
        print("Model quality does not meet deployment criteria. Skipping.")

    skip = skip_deployment()

    # ===== Stage 7: 알림 =====
    @task(trigger_rule='none_failed_min_one_success')
    def send_notification(**context):
        """학습 결과 알림 전송"""
        eval_data = context['ti'].xcom_pull(task_ids='evaluate_model', key='return_value')
        print(f"Pipeline completed. Accuracy: {eval_data.get('new_accuracy')}")
        print(f"Improvement: {eval_data.get('improvement')}")
        # Slack / Email 알림 로직

    notification = send_notification()

    # ===== DAG 의존성 정의 =====
    wait_for_data >> data_prep >> hp_args >> hp_tuning
    hp_tuning >> best_model >> eval_result >> deployment_branch
    deployment_branch >> [deploy, skip] >> notification

이 DAG는 다음과 같은 Airflow 기능들을 통합적으로 활용한다.

  • S3KeySensor: 데이터 도착 대기 (deferrable 모드)
  • KubernetesPodOperator: GPU 기반 학습 Job 실행
  • Dynamic Task Mapping: expand()를 통한 하이퍼파라미터 병렬 학습
  • TaskFlow API: @task 데코레이터를 통한 Python 함수 기반 Task 정의
  • XCom: 학습 메트릭과 모델 경로 전달
  • TaskGroup: 시각적 단계 구분
  • BranchPythonOperator: 모델 품질 기반 조건부 배포

실제 프로덕션에서는 에러 핸들링, SLA 설정, 알림 통합 등을 추가하여 안정성을 높여야 한다.

12. References

본 글에서 분석한 내용은 다음 Apache Airflow 공식 문서와 관련 자료를 기반으로 작성되었다.

Apache Airflow 공식 문서

관련 자료

Orchestrating ML Training Pipelines with Airflow

1. The Role of Airflow in ML Pipelines

In ML projects, model training does not end simply by calling model.fit(). A series of steps -- data collection, preprocessing, Feature Engineering, model training, evaluation, model registry registration, and deployment -- must be executed in a reproducible and automated manner. This is where a Workflow Orchestration tool is needed, and Apache Airflow is one of the most mature open-source projects in this domain.

Airflow vs Other Orchestration Tools

Besides Airflow, other ML pipeline orchestration tools include Kubeflow Pipelines, Prefect, and Dagster. Let us compare their characteristics.

Apache Airflow is a general-purpose workflow orchestrator. While it is not specialized for ML, it is widely used in ML pipelines thanks to its rich Operator ecosystem, powerful scheduling capabilities, and scalability on Kubernetes. With Airflow 3.x, features such as DAG versioning, event-driven scheduling, and multi-language support have been added, making it even more powerful.

Kubeflow Pipelines is a Kubernetes-native ML pipeline tool specialized for ML workflows. It provides integrated GPU scheduling, experiment tracking, and model serving, but requires deep understanding of Kubernetes and is difficult to use outside Kubernetes environments.

Prefect is often described as "Airflow, but nicer," with its Python-native interface and quick setup being key advantages. It offers dynamic flow, hybrid execution, and real-time SLA alerting, making it particularly suitable for small teams or PoC stages where rapid adoption is needed.

Dagster is a modern orchestrator that takes an asset-centric approach. With a philosophy of designing data pipelines centered around "the data being produced" rather than "execution steps," it has strengths in data lineage tracking.

CharacteristicAirflowKubeflow PipelinesPrefectDagster
ML SpecializationGeneral-purposeML-specializedGeneral-purposeData-centric
Learning CurveMediumHighLowMedium
Kubernetes DependencyOptionalRequiredOptionalOptional
Community/EcosystemVery largeLargeGrowingGrowing
GPU SchedulingKubernetesPodOperatorNativeKubernetes integrationKubernetes integration
SchedulingVery powerfulLimitedPowerfulPowerful

The biggest reason for choosing Airflow as an ML pipeline orchestrator is its tool-agnostic nature. Airflow can integrate with various ML tools such as MLflow, SageMaker, and Vertex AI, and can manage ML and non-ML data pipelines (ETL, data quality checks, etc.) on the same infrastructure.

2. Reviewing Airflow Core Concepts: DAG, Operator, Task, XCom

According to the official Apache Airflow documentation, the core concepts are as follows.

DAG (Directed Acyclic Graph)

A DAG is a collection of all Tasks you want to run, organized in a way that reflects their relationships and dependencies. DAGs are defined as Python scripts, and the code itself expresses the structure of the DAG (Tasks and their dependencies). A DAG run is an instantiation of a DAG, containing Task Instances that execute for a specific execution_date.

from airflow import DAG
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
    tags=['ml', 'training'],
) as dag:
    # Task definitions
    pass

Operator

An Operator is a pre-defined template for a Task. According to the official documentation, "If a DAG describes how to run a workflow, Operators determine what actually gets done by a Task." All Operators inherit from BaseOperator, with PythonOperator, BashOperator, and KubernetesPodOperator being representative examples.

Task

A Task is an instance of an Operator. It is an executable unit within a DAG, and a Task Instance represents a specific execution of a Task in a particular DAG run.

XCom (Cross-Communication)

XCom is a mechanism for passing data between Tasks. According to the official documentation, "by default, Tasks are entirely isolated and may run on completely different machines," which is why XCom is needed. XCom is identified by key, task_id, and dag_id.

# Push
task_instance.xcom_push(key="model_accuracy", value=0.95)

# Pull
accuracy = task_instance.xcom_pull(key="model_accuracy", task_ids="evaluate_model")

An important point is that XCom is designed for transferring small amounts of data. You should not pass large data such as DataFrames through XCom; in such cases, you should use an Object Storage Backend.

3. Running GPU Training Jobs with KubernetesPodOperator

In ML training pipelines, KubernetesPodOperator plays a central role. According to the official documentation, KubernetesPodOperator is "a stand-in replacement for Kubernetes object spec definitions" that can be executed by the Airflow scheduler within the DAG context. In other words, you can define and run Kubernetes Pods directly in Python code without writing separate YAML/JSON files for Pod specs.

The reasons KubernetesPodOperator is particularly useful for GPU-based ML training are as follows:

  • Task-level resource configuration: CPU, Memory, and GPU can be independently allocated for each Task
  • Custom dependency support: When packages not on PyPI or a specific CUDA version is needed, a custom Docker image can be specified
  • Language agnosticism: Training code written in any language, not just Python, can be wrapped in a container and executed

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

train_model = KubernetesPodOperator(
    task_id='train_model',
    name='gpu-training-job',
    namespace='ml-training',
    image='my-registry/ml-trainer:latest',
    cmds=['python'],
    arguments=['train.py', '--epochs', '100', '--batch-size', '64'],
    container_resources=k8s.V1ResourceRequirements(
        requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
        limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
    ),
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
)

By setting do_xcom_push=True, you can return data as XCom by writing JSON to the /airflow/xcom/return.json file inside the Pod. This is useful for passing training result metrics to the next Task.

4. Analysis of Key KubernetesPodOperator Parameters

This section analyzes the key parameters of KubernetesPodOperator from an ML training perspective, based on the official documentation.

container_resources

This is the most important parameter for GPU training. Using kubernetes.client.models.V1ResourceRequirements, you can request and limit CPU, Memory, and GPU resources.

container_resources = k8s.V1ResourceRequirements(
    requests={
        'cpu': '4',
        'memory': '16Gi',
        'nvidia.com/gpu': '1'
    },
    limits={
        'cpu': '8',
        'memory': '32Gi',
        'nvidia.com/gpu': '1'
    },
)

For GPUs, the nvidia.com/gpu resource key is used, and requests and limits must be set to the same value: Kubernetes requires extended resources such as GPUs to have matching requests and limits, and a GPU cannot be fractionally allocated (unless MIG or time-slicing is used).

tolerations

GPU nodes are usually configured with taints to prevent general workloads from being scheduled. tolerations allow Pods to be scheduled on GPU nodes.

tolerations = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]

node_selector

You can specify that training Pods should only run on specific node groups. For example, to run only on nodes equipped with A100 GPUs:

node_selector = {
    'gpu-type': 'a100',
    'node-pool': 'ml-training',
}

affinity

When more fine-grained scheduling control is needed, use V1Affinity. You can set Node affinity and Pod affinity/anti-affinity to place Pods in specific zones or nodes.

affinity = k8s.V1Affinity(
    node_affinity=k8s.V1NodeAffinity(
        required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
            node_selector_terms=[
                k8s.V1NodeSelectorTerm(
                    match_expressions=[
                        k8s.V1NodeSelectorRequirement(
                            key='gpu-type',
                            operator='In',
                            values=['a100', 'h100'],
                        )
                    ]
                )
            ]
        )
    )
)

volumes / volume_mounts

Used for mounting training datasets or checkpoints via PVC (PersistentVolumeClaim).

volume = k8s.V1Volume(
    name='training-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-dataset-pvc'
    ),
)

volume_mount = k8s.V1VolumeMount(
    name='training-data',
    mount_path='/data',
    read_only=True,
)

Other Key Parameters

| Parameter | Description | ML training usage |
| --- | --- | --- |
| image_pull_secrets | Private registry auth | Accessing internal ML images |
| env_vars | Environment variable list | CUDA_VISIBLE_DEVICES, WANDB_API_KEY, etc. |
| secrets | Secret volumes/env vars | API key, credential management |
| is_delete_operator_pod | Delete Pod after completion | Resource cleanup (True recommended) |
| get_logs | Display logs in Airflow UI | Training log monitoring |
| deferrable | Async execution mode | Worker efficiency for long training |
| on_finish_action | Post-completion action | Auto cleanup with delete_pod |

The official documentation recommends using native objects from kubernetes.client.models instead of convenience classes for type safety.

5. DAG Design Pattern: Data Preprocessing to Training to Evaluation to Deployment

The general DAG structure for an ML training pipeline includes the following stages:

Data Validation -> Preprocessing -> Feature Engineering -> Training -> Evaluation -> [Conditional] Deployment

Mapping each stage to Airflow Tasks results in the following design pattern:

from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
) as dag:

    validate_data = KubernetesPodOperator(
        task_id='validate_data',
        image='ml-tools:latest',
        cmds=['python', 'validate.py'],
        # Only CPU needed
    )

    preprocess = KubernetesPodOperator(
        task_id='preprocess_data',
        image='ml-tools:latest',
        cmds=['python', 'preprocess.py'],
        # Large memory needed
    )

    train = KubernetesPodOperator(
        task_id='train_model',
        image='ml-trainer:latest',
        cmds=['python', 'train.py'],
        # GPU needed
        do_xcom_push=True,
    )

    evaluate = KubernetesPodOperator(
        task_id='evaluate_model',
        image='ml-trainer:latest',
        cmds=['python', 'evaluate.py'],
        do_xcom_push=True,
    )

    def check_model_quality(**context):
        metrics = context['ti'].xcom_pull(task_ids='evaluate_model')
        if metrics['accuracy'] > 0.90:
            return 'deploy_model'
        return 'notify_failure'

    branch = BranchPythonOperator(
        task_id='check_quality',
        python_callable=check_model_quality,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
    )

    # Target of the failure branch returned by check_model_quality
    from airflow.operators.empty import EmptyOperator
    notify_failure = EmptyOperator(task_id='notify_failure')

    validate_data >> preprocess >> train >> evaluate >> branch
    branch >> [deploy, notify_failure]

The key to this design pattern is conditional deployment. The evaluation stage passes model performance metrics via XCom, and through BranchPythonOperator, the deployment Task is executed only when quality criteria are met. This prevents low-quality models from being deployed to production.
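The gate itself is just a threshold comparison, so it is worth keeping as a plain function that can be unit-tested outside Airflow. A minimal sketch (the function name and the 0.90 threshold mirror the example above and are illustrative):

```python
def passes_quality_gate(metrics: dict, min_accuracy: float = 0.90) -> bool:
    """True when the trained model meets the deployment threshold."""
    return metrics.get('accuracy', 0.0) > min_accuracy


print(passes_quality_gate({'accuracy': 0.95}))  # True
print(passes_quality_gate({'accuracy': 0.85}))  # False
```

In the DAG, the branch callable would call this helper on the XCom-pulled metrics and map the boolean to a task_id.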

Stage Separation Using TaskGroup

In complex pipelines, TaskGroup can be used to visually separate stages. According to the official documentation, "TaskGroup is used to organize Tasks into hierarchical groups in the Graph View. It is useful for creating repeating patterns and reducing visual complexity."

from airflow.utils.task_group import TaskGroup

with TaskGroup('data_preparation') as data_prep:
    validate_data >> preprocess >> feature_engineering

with TaskGroup('model_training') as training:
    train >> evaluate

data_prep >> training >> branch

6. Passing Metrics/Artifacts Between Tasks with XCom

According to the official documentation, XCom is explicitly "pushed" to and "pulled" from storage via the xcom_push and xcom_pull methods. In ML pipelines, XCom is used for the following purposes:

Passing Training Metrics

# train.py (executed inside KubernetesPodOperator)
import json

metrics = {
    'accuracy': 0.9534,
    'f1_score': 0.9421,
    'loss': 0.0312,
    'model_path': 's3://ml-models/experiment-42/model.pt',
    'run_id': 'exp-42-20260301',
}

# Write JSON to /airflow/xcom/return.json
with open('/airflow/xcom/return.json', 'w') as f:
    json.dump(metrics, f)

Auto-Push and Return Value

Many Operators and the @task decorator automatically push the return value as an XCom with the key return_value when do_xcom_push=True (default). You can then pull it as follows:

value = task_instance.xcom_pull(task_ids='train_model')

Multiple Outputs

When returning a dictionary with multiple_outputs=True, each key is stored as an individual XCom.

@task(multiple_outputs=True)
def evaluate_model(**context):
    return {
        'accuracy': 0.95,
        'f1_score': 0.94,
        'model_path': 's3://models/latest.pt',
    }

You can then pull by individual key or pull the entire return_value.

Custom XCom Backend

The default XCom Backend, BaseXCom, stores XCom in the Airflow database. While this is fine for small amounts of data, a Custom Backend is needed when dealing with model artifacts or large data. You can override the serialize_value and deserialize_value methods by inheriting from BaseXCom.

To use the Object Storage Backend, set the xcom_backend configuration to airflow.providers.common.io.xcom.backend.XComObjectStorageBackend. This allows storing XCom data in S3 or GCS.

[core]
xcom_backend = airflow.providers.common.io.xcom.backend.XComObjectStorageBackend

XCom Usage Precautions

The key points emphasized in the official documentation are:

  • XCom is designed for transferring small amounts of data. Do not pass large data such as DataFrames
  • When a Task is retried after failure, previous XCom is automatically cleared to ensure idempotent execution
  • XCom operations must be performed within the Task Context via get_current_context(), and direct DB updates are not supported

As a practical guideline for ML pipelines, metrics (accuracy, loss, etc.) and path information (model_path, artifact_uri, etc.) should be passed via XCom, while actual model files and datasets should be stored in Object Storage like S3/GCS, with only the paths shared via XCom.

7. Hyperparameter Tuning with Dynamic Task Mapping

Dynamic Task Mapping was introduced in Airflow 2.3 and, according to the official documentation, is "a mechanism that allows creating multiple Tasks at runtime based on current data." The DAG author does not need to know in advance how many Tasks are needed.

This feature is very powerful for ML hyperparameter tuning because it enables parallel training with various hyperparameter combinations.

expand() and partial()

expand() specifies the parameters to map, while partial() specifies fixed parameters common to all Tasks.

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

hyperparams = [
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'adam'},
    {'lr': '0.01', 'batch_size': '64', 'optimizer': 'adam'},
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'sgd'},
    {'lr': '0.0001', 'batch_size': '128', 'optimizer': 'adamw'},
]

train_tasks = KubernetesPodOperator.partial(
    task_id='hyperparam_training',
    image='ml-trainer:latest',
    namespace='ml-training',
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
    container_resources=k8s.V1ResourceRequirements(
        requests={'nvidia.com/gpu': '1'},
        limits={'nvidia.com/gpu': '1'},
    ),
).expand(
    arguments=[
        ['train.py', '--lr', hp['lr'], '--batch-size', hp['batch_size'], '--optimizer', hp['optimizer']]
        for hp in hyperparams
    ],
)

This code automatically creates 4 parallel KubernetesPodOperator Tasks for 4 hyperparameter combinations.

Dynamically Generating Mapping Data from Tasks

Furthermore, you can dynamically generate hyperparameter combinations from upstream Tasks.

@task
def generate_hyperparams():
    """Generate hyperparameters via Grid Search or Random Search"""
    import itertools

    learning_rates = [0.001, 0.01, 0.0001]
    batch_sizes = [32, 64, 128]
    optimizers = ['adam', 'sgd']

    combinations = list(itertools.product(learning_rates, batch_sizes, optimizers))
    return [
        {'lr': str(lr), 'batch_size': str(bs), 'optimizer': opt}
        for lr, bs, opt in combinations
    ]

hp_list = generate_hyperparams()

According to the official documentation, "mappings generated from a Task prohibit the use of trigger_rule=TriggerRule.ALWAYS."

Cross-Product Mapping

Specifying multiple expand() parameters generates all combinations (cross product).

@task
def train(lr: float, batch_size: int):
    # Training logic
    pass

train.expand(lr=[0.001, 0.01], batch_size=[32, 64])
# 2 x 2 = 4 Task Instances created
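The cross product that expand() computes is equivalent to itertools.product, which is a quick way to sanity-check how many Task Instances a mapping will create before running the DAG (a standalone sketch):

```python
import itertools

lrs = [0.001, 0.01]
batch_sizes = [32, 64]

# expand(lr=lrs, batch_size=batch_sizes) creates one instance per pair
instances = list(itertools.product(lrs, batch_sizes))
print(len(instances))  # 4
```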

Map-Reduce Pattern

A pattern for collecting results from mapped Tasks and selecting the optimal model.

@task
def select_best_model(results):
    """Select the best among all hyperparameter combination results"""
    best = max(results, key=lambda x: x['accuracy'])
    return best

training_results = train_task.expand(params=hyperparams)
best = select_best_model(training_results)

According to the official documentation, the collected results are returned as a "lazy proxy" sequence, not an eager list.
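Because the collected value is lazy, reduce-side code should iterate or materialize it rather than assume an eager list. A sketch using a plain generator as a stand-in for the proxy:

```python
def select_best(results):
    """Works for any iterable, including Airflow's lazy XCom proxy."""
    materialized = list(results)  # force a single evaluation pass
    return max(materialized, key=lambda r: r['accuracy'])


lazy = (r for r in [{'accuracy': 0.91}, {'accuracy': 0.95}, {'accuracy': 0.88}])
print(select_best(lazy)['accuracy'])  # 0.95
```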

Constraints

| Item | Setting | Default |
| --- | --- | --- |
| Max mapped instances | [core] max_map_length | 1024 |
| Parallel execution per Task | max_active_tis_per_dag | Per-task setting |

Only list and dict types can be mapped; other types will raise an UnmappableXComTypePushed error. Additionally, according to the official documentation, "if a field is marked as templated and is mapped, it will not be template-rendered."

8. Detecting Data Arrival with Sensors

In ML pipelines, it is often necessary to wait until training data is ready. Airflow's Sensors are special Operators that wait until a specific condition is met.

S3KeySensor

Waits until training data is uploaded to S3.

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_data = S3KeySensor(
    task_id='wait_for_training_data',
    bucket_name='ml-datasets',
    bucket_key='daily/{{ ds }}/training_data.parquet',
    aws_conn_id='aws_default',
    poke_interval=300,  # Check every 5 minutes
    timeout=3600 * 6,   # Wait up to 6 hours
    mode='reschedule',  # Return Worker slot
    deferrable=True,    # Async wait using Triggerer
)

mode='reschedule' returns the Worker slot between pokes to improve resource efficiency. Setting deferrable=True allows the Triggerer component to handle polling asynchronously for even more efficient Worker utilization.
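As a rough sanity check on sensor settings, the maximum number of pokes is simply the timeout divided by the poke interval (plain arithmetic on the values used above):

```python
poke_interval = 300   # seconds between checks (5 minutes)
timeout = 3600 * 6    # give up after 6 hours

max_pokes = timeout // poke_interval
print(max_pokes)  # 72 checks before the sensor times out
```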

ExternalTaskSensor

Waits until a Task from another DAG completes. For example, a pattern where the training DAG starts after the data preprocessing DAG completes.

from datetime import timedelta
from airflow.sensors.external_task import ExternalTaskSensor

wait_for_preprocessing = ExternalTaskSensor(
    task_id='wait_for_preprocessing',
    external_dag_id='data_preprocessing_pipeline',
    external_task_id='final_validation',
    allowed_states=['success'],
    execution_delta=timedelta(hours=0),
    timeout=3600,
    poke_interval=60,
    mode='reschedule',
)

allowed_states is the list of accepted states, defaulting to ['success']. execution_delta is the difference between the current DAG run's logical date and that of the external DAG run to check; timedelta(hours=0) means both DAGs run on the same schedule.

Practical Pattern: Starting Training After Data Arrival

wait_for_data >> validate_data >> preprocess >> train >> evaluate >> deploy

This pattern loosely couples the data pipeline and ML pipeline while ensuring that training only starts after data is ready.

9. MLflow and Airflow Integration Patterns

The role division between Airflow and MLflow integration is clear. Airflow manages "when and in what order to execute," while MLflow records "what happened during execution and where the model is." Through this separation of concerns, data engineers can manage Airflow DAGs without touching ML code, and data scientists can focus on model development without worrying about scheduling infrastructure.

Integration Pattern 1: Direct MLflow Calls Inside Training Tasks

This is the most intuitive pattern. MLflow APIs are called directly inside training containers executed by KubernetesPodOperator.

# train.py (executed inside container)
import mlflow
import mlflow.pytorch

mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("daily-training")

with mlflow.start_run() as run:
    # Hyperparameter logging
    mlflow.log_params({
        'learning_rate': 0.001,
        'batch_size': 64,
        'epochs': 100,
    })

    # Model training
    model = train_model(...)

    # Metric logging
    mlflow.log_metrics({
        'accuracy': 0.95,
        'f1_score': 0.94,
    })

    # Model registration
    mlflow.pytorch.log_model(model, "model")

    # Pass run_id via XCom
    import json
    with open('/airflow/xcom/return.json', 'w') as f:
        json.dump({'run_id': run.info.run_id}, f)

Integration Pattern 2: Utilizing MLflow Model Registry from Airflow

After training is complete, the model is registered in the MLflow Model Registry based on evaluation results and the stage is transitioned.

@task
def register_model(run_id: str, accuracy: float):
    import mlflow
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # Register model
    model_uri = f"runs:/{run_id}/model"
    mv = mlflow.register_model(model_uri, "production-model")

    # Transition to Production stage if accuracy exceeds threshold
    if accuracy > 0.93:
        client.transition_model_version_stage(
            name="production-model",
            version=mv.version,
            stage="Production",
        )

Integration Pattern 3: Querying and Comparing Training History from MLflow

@task
def compare_with_baseline():
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # Query Production model metrics
    prod_versions = client.get_latest_versions("production-model", stages=["Production"])
    if prod_versions:
        prod_run = client.get_run(prod_versions[0].run_id)
        baseline_accuracy = float(prod_run.data.metrics['accuracy'])
        return {'baseline_accuracy': baseline_accuracy}
    return {'baseline_accuracy': 0.0}

10. Utilizing the TaskFlow API (@task Decorator)

According to the official documentation, the TaskFlow API is "a functional API that defines DAGs and Tasks using decorators," greatly simplifying data transfer and dependency definition between Tasks.

Key Characteristics

The biggest advantage of TaskFlow is automatic XCom management and automatic dependency calculation. When a TaskFlow function is called, it is not executed immediately; instead, an XComArg object representing the result is returned. Using this as input to downstream Tasks causes dependencies to be calculated automatically.

from airflow.sdk import dag, task
from datetime import datetime

@dag(
    schedule='@daily',
    start_date=datetime(2026, 1, 1),
    catchup=False,
)
def ml_training_taskflow():

    @task
    def extract_data():
        """Data extraction"""
        return {'data_path': 's3://bucket/data/2026-03-01.parquet', 'num_rows': 100000}

    @task(multiple_outputs=True)
    def preprocess(data_info: dict):
        """Data preprocessing"""
        return {
            'train_path': f"{data_info['data_path']}/train.parquet",
            'test_path': f"{data_info['data_path']}/test.parquet",
            'feature_count': 128,
        }

    @task
    def train_model(train_path: str, feature_count: int):
        """Model training"""
        # In practice, GPU training would be run via KubernetesPodOperator
        return {
            'model_path': 's3://models/latest.pt',
            'accuracy': 0.95,
        }

    @task
    def evaluate_and_deploy(model_info: dict):
        """Model evaluation and conditional deployment"""
        if model_info['accuracy'] > 0.90:
            return f"Deployed model from {model_info['model_path']}"
        return "Model quality below threshold, skipping deployment"

    # Automatic dependency wiring
    data = extract_data()
    processed = preprocess(data)
    model = train_model(
        train_path=processed['train_path'],
        feature_count=processed['feature_count'],
    )
    evaluate_and_deploy(model)

ml_training_taskflow()

Context Access

Tasks can receive Airflow context variables as keyword arguments.

@task
def log_execution_info(task_instance=None, dag_run=None):
    print(f"Run ID: {task_instance.run_id}")
    print(f"DAG Run: {dag_run.dag_id}")

Object Serialization

TaskFlow supports custom object passing via @dataclass, @attr.define, or custom serialize()/deserialize() methods. Version management is also possible using __version__: ClassVar[int].
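A minimal sketch of such a custom object; the class name and fields are illustrative, not part of the Airflow API:

```python
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class ModelArtifact:
    """Custom object passed between TaskFlow tasks via XCom serialization."""

    __version__: ClassVar[int] = 1  # bump when the wire format changes

    path: str
    accuracy: float

    def serialize(self) -> dict:
        # Called when the object is pushed to XCom
        return {'path': self.path, 'accuracy': self.accuracy}

    @staticmethod
    def deserialize(data: dict, version: int) -> 'ModelArtifact':
        # `version` is the __version__ the object was serialized with
        return ModelArtifact(path=data['path'], accuracy=data['accuracy'])
```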

11. Production DAG Code Example (Full ML Pipeline)

The following is a production ML training pipeline DAG code that integrates all the concepts covered so far. It includes the entire process from data arrival detection to hyperparameter tuning, optimal model selection, and deployment.

"""
ML Training Pipeline DAG
- S3 data arrival detection
- Data preprocessing (KubernetesPodOperator)
- Hyperparameter tuning with Dynamic Task Mapping
- Optimal model selection and MLflow registration
- Conditional deployment
"""

from __future__ import annotations

from datetime import datetime, timedelta
from airflow import DAG
from airflow.sdk import task
from airflow.operators.python import BranchPythonOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.utils.task_group import TaskGroup
from kubernetes.client import models as k8s

# -- Common Configuration --
GPU_RESOURCES = k8s.V1ResourceRequirements(
    requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
    limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
)

GPU_TOLERATIONS = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]

GPU_NODE_SELECTOR = {'gpu-type': 'a100', 'node-pool': 'ml-training'}

DATA_VOLUME = k8s.V1Volume(
    name='shared-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-shared-pvc'
    ),
)

DATA_VOLUME_MOUNT = k8s.V1VolumeMount(
    name='shared-data',
    mount_path='/shared',
)

default_args = {
    'owner': 'ml-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
    'execution_timeout': timedelta(hours=4),
}


with DAG(
    dag_id='ml_full_training_pipeline',
    default_args=default_args,
    start_date=datetime(2026, 1, 1),
    schedule='@daily',
    catchup=False,
    tags=['ml', 'training', 'production'],
    max_active_runs=1,
    doc_md="""
    ## ML Full Training Pipeline
    A pipeline that retrains the model daily with new data
    and automatically deploys it if quality criteria are met.
    """,
) as dag:

    # ===== Stage 1: Data Arrival Detection =====
    wait_for_data = S3KeySensor(
        task_id='wait_for_training_data',
        bucket_name='ml-datasets',
        bucket_key='daily/{{ ds }}/features.parquet',
        aws_conn_id='aws_default',
        poke_interval=300,
        timeout=3600 * 6,
        mode='reschedule',
        deferrable=True,
    )

    # ===== Stage 2: Data Preprocessing =====
    with TaskGroup('data_preparation') as data_prep:

        validate_data = KubernetesPodOperator(
            task_id='validate_data',
            name='data-validation',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'validate.py'],
            arguments=['--date', '{{ ds }}', '--source', 's3://ml-datasets/daily/{{ ds }}/'],
            env_vars=[
                k8s.V1EnvVar(name='DATA_DATE', value='{{ ds }}'),
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '2', 'memory': '8Gi'},
                limits={'cpu': '4', 'memory': '16Gi'},
            ),
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )

        preprocess = KubernetesPodOperator(
            task_id='preprocess_data',
            name='data-preprocessing',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'preprocess.py'],
            arguments=[
                '--input', 's3://ml-datasets/daily/{{ ds }}/',
                '--output', 's3://ml-processed/{{ ds }}/',
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '4', 'memory': '32Gi'},
                limits={'cpu': '8', 'memory': '64Gi'},
            ),
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )

        validate_data >> preprocess

    # ===== Stage 3: Hyperparameter Generation and Parallel Training =====
    @task
    def generate_hyperparams(ds=None):
        """Generate hyperparameter combinations for training.

        Jinja templates are not rendered inside a @task body, so the logical
        date comes from the injected `ds` context variable rather than '{{ ds }}'.
        """
        import itertools
        learning_rates = [0.001, 0.0001]
        batch_sizes = [32, 64]
        optimizers = ['adam', 'adamw']

        combos = list(itertools.product(learning_rates, batch_sizes, optimizers))
        return [
            [
                'train.py',
                '--lr', str(lr),
                '--batch-size', str(bs),
                '--optimizer', opt,
                '--data-path', f's3://ml-processed/{ds}/',
                '--experiment-name', f'daily-training-{ds}',
            ]
            for lr, bs, opt in combos
        ]

    hp_args = generate_hyperparams()

    with TaskGroup('hyperparameter_tuning') as hp_tuning:

        train_tasks = KubernetesPodOperator.partial(
            task_id='train_with_hp',
            name='gpu-hp-training',
            namespace='ml-training',
            image='ml-trainer:latest',
            cmds=['python'],
            container_resources=GPU_RESOURCES,
            tolerations=GPU_TOLERATIONS,
            node_selector=GPU_NODE_SELECTOR,
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            env_vars=[
                k8s.V1EnvVar(name='MLFLOW_TRACKING_URI', value='http://mlflow:5000'),
                k8s.V1EnvVar(name='CUDA_VISIBLE_DEVICES', value='0'),
            ],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
            startup_timeout_seconds=600,
        ).expand(arguments=hp_args)

    # ===== Stage 4: Optimal Model Selection =====
    @task
    def select_best_model(training_results):
        """Select the optimal model from all training results"""
        valid_results = [r for r in training_results if r is not None]
        if not valid_results:
            raise ValueError("No valid training results found")

        best = max(valid_results, key=lambda x: x.get('accuracy', 0))
        print(f"Best model: accuracy={best['accuracy']}, run_id={best['run_id']}")
        return best

    best_model = select_best_model(train_tasks)

    # ===== Stage 5: Model Evaluation =====
    @task(multiple_outputs=True)
    def evaluate_model(model_info: dict):
        """Detailed evaluation of the optimal model"""
        import mlflow
        from mlflow.tracking import MlflowClient

        client = MlflowClient("http://mlflow:5000")

        # Compare with current Production model
        try:
            prod_versions = client.get_latest_versions(
                "production-model", stages=["Production"]
            )
            if prod_versions:
                prod_run = client.get_run(prod_versions[0].run_id)
                baseline_accuracy = float(prod_run.data.metrics.get('accuracy', 0))
            else:
                baseline_accuracy = 0.0
        except Exception:
            baseline_accuracy = 0.0

        new_accuracy = model_info['accuracy']
        improvement = new_accuracy - baseline_accuracy

        return {
            'should_deploy': improvement > 0.005 and new_accuracy > 0.90,
            'new_accuracy': new_accuracy,
            'baseline_accuracy': baseline_accuracy,
            'improvement': improvement,
            'run_id': model_info['run_id'],
            'model_path': model_info.get('model_path', ''),
        }

    eval_result = evaluate_model(best_model)

    # ===== Stage 6: Conditional Deployment =====
    def decide_deployment(**context):
        should_deploy = context['ti'].xcom_pull(
            task_ids='evaluate_model', key='should_deploy'
        )
        if should_deploy:
            return 'deploy_model'
        return 'skip_deployment'

    deployment_branch = BranchPythonOperator(
        task_id='deployment_decision',
        python_callable=decide_deployment,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        name='model-deployment',
        namespace='ml-serving',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
        arguments=[
            '--run-id', "{{ ti.xcom_pull(task_ids='evaluate_model', key='run_id') }}",
            '--model-name', 'production-model',
        ],
        container_resources=k8s.V1ResourceRequirements(
            requests={'cpu': '2', 'memory': '4Gi'},
            limits={'cpu': '4', 'memory': '8Gi'},
        ),
        get_logs=True,
        is_delete_operator_pod=True,
    )

    @task(task_id='skip_deployment')
    def skip_deployment():
        print("Model quality does not meet deployment criteria. Skipping.")

    skip = skip_deployment()

    # ===== Stage 7: Notification =====
    @task(trigger_rule='none_failed_min_one_success')
    def send_notification(**context):
        """Send training result notification"""
        eval_data = context['ti'].xcom_pull(task_ids='evaluate_model', key='return_value')
        print(f"Pipeline completed. Accuracy: {eval_data.get('new_accuracy')}")
        print(f"Improvement: {eval_data.get('improvement')}")
        # Slack / Email notification logic

    notification = send_notification()

    # ===== DAG Dependency Definition =====
    wait_for_data >> data_prep >> hp_args >> hp_tuning
    hp_tuning >> best_model >> eval_result >> deployment_branch
    deployment_branch >> [deploy, skip] >> notification

This DAG integrates the following Airflow features:

  • S3KeySensor: waits for data arrival (deferrable mode, freeing the worker slot while waiting)
  • KubernetesPodOperator: runs the GPU-based training Job
  • Dynamic Task Mapping: trains hyperparameter candidates in parallel via expand()
  • TaskFlow API: defines Tasks as Python functions with the @task decorator
  • XCom: passes training metrics and model paths between Tasks
  • TaskGroup: visually separates pipeline stages
  • BranchPythonOperator: deploys conditionally based on model quality

In actual production, you should additionally configure error handling (retries, timeouts, failure callbacks), SLA settings, and alerting integrations to improve reliability.
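These hardening settings can be applied once via `default_args`, which Airflow passes to every Task in the DAG. A minimal sketch using standard `BaseOperator` arguments; the `alert_slack` callback body is a hypothetical placeholder:

```python
# Production hardening via default_args (standard BaseOperator arguments).
# The alert_slack body is a placeholder; wire it to Slack/email as needed.
from datetime import timedelta

def alert_slack(context):
    """Failure callback: Airflow invokes this with the task context on failure."""
    ti = context["ti"]
    print(f"Task {ti.task_id} failed in DAG run {context.get('run_id')}")
    # Post to a Slack webhook / send email here (placeholder).

default_args = {
    "retries": 3,                             # retry transient failures
    "retry_delay": timedelta(minutes=5),      # base delay between retries
    "retry_exponential_backoff": True,        # grow the delay on each retry
    "execution_timeout": timedelta(hours=2),  # kill runaway training tasks
    "on_failure_callback": alert_slack,       # alert the team on failure
}

# Then pass it to the DAG constructor:
# with DAG(dag_id='ml_training_pipeline', default_args=default_args, ...):
```

Individual Tasks can still override any of these (for example, a longer `execution_timeout` on the GPU training task) since operator-level arguments take precedence over `default_args`.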

12. References

The content analyzed in this article is based on the Apache Airflow official documentation and related resources.

Apache Airflow Official Documentation
