Skip to content
Published on

AirflowによるML学習パイプラインのオーケストレーション

Authors
  • Name
    Twitter

1. MLパイプラインにおけるAirflowの役割

MLプロジェクトにおいて、モデル学習は単にmodel.fit()を呼び出すだけでは終わらない。データ収集、前処理、Feature Engineering、モデル学習、評価、モデルレジストリ登録、そしてデプロイまで一連のステップを再現可能かつ自動化された方式で実行する必要がある。このときWorkflow Orchestrationツールが必要となり、Apache Airflowはこの領域で最も成熟したオープンソースプロジェクトの一つである。

Airflow vs 他のオーケストレーションツール

MLパイプラインのオーケストレーションツールとしては、Airflow以外にもKubeflow Pipelines、Prefect、Dagsterなどがある。それぞれの特性を比較してみよう。

Apache Airflowは汎用ワークフローオーケストレーターであり、MLに特化したツールではないが、豊富なOperatorエコシステム、強力なスケジューリング機能、そしてKubernetes上での拡張性のおかげでMLパイプラインでも広く使用されている。Airflow 3.xではDAG versioning、event-driven scheduling、multi-language supportなどが追加され、さらに強力になった。

Kubeflow PipelinesはKubernetesネイティブのMLパイプラインツールで、MLワークフローに特化している。GPUスケジューリング、実験追跡、モデルサービングまで統合的に提供するが、Kubernetesへの深い理解が前提となり、Kubernetes以外の環境では使用が難しい。

Prefectは「Airflow, but nicer」という評価を受けるツールで、Python-nativeなインターフェースと素早いセットアップが利点である。Dynamic flow、hybrid execution、リアルタイムSLA alertingなどを提供し、特に小規模チームやPoC段階で素早く適用するのに適している。

Dagsterはasset-centricアプローチを取る現代的なオーケストレーターである。データパイプラインを「実行ステップ」ではなく「生産するデータ」を中心に設計する哲学を持ち、データリネージ追跡に強みがある。

特性AirflowKubeflow PipelinesPrefectDagster
ML特化汎用ML特化汎用Data-centric
学習曲線中程度高い低い中程度
Kubernetes依存性選択的必須選択的選択的
コミュニティ/エコシステム非常に大きい大きい成長中成長中
GPUスケジューリングKubernetesPodOperatorネイティブKubernetes連携Kubernetes連携
スケジューリング非常に強力限定的強力強力

AirflowをMLパイプラインオーケストレーターとして選択する最大の理由はツール非依存性(tool-agnostic) である。AirflowはMLflow、SageMaker、Vertex AIなど様々なMLツールと連携可能であり、ML以外のデータパイプライン(ETL、データ品質チェックなど)と同一のインフラで管理できる。

2. Airflowコア概念の復習:DAG、Operator、Task、XCom

Apache Airflowの公式ドキュメントによると、コア概念は以下の通りである。

DAG (Directed Acyclic Graph)

DAGは実行したい全てのTaskの集合であり、Task間の関係と依存性を反映する形で構成される。DAGはPythonスクリプトで定義され、コード自体がDAGの構造(Taskとその依存性)を表現する。DAG runはDAGのインスタンス化(instantiation)であり、特定のexecution_dateに対して実行されるTask Instanceを含む。

from airflow import DAG
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule_interval='@daily',
    catchup=False,
    tags=['ml', 'training'],
) as dag:
    # Task定義
    pass

Operator

Operatorは事前に定義されたTaskのテンプレートである。公式ドキュメントによると、「DAGがワークフローをどのように実行するかを記述するなら、OperatorはTaskが実際に何を実行するかを決定する。」全てのOperatorはBaseOperatorを継承し、PythonOperator、BashOperator、KubernetesPodOperatorなどが代表的である。

Task

TaskはOperatorのインスタンスである。DAG内で実行可能な単位であり、Task Instanceは特定のDAG runにおけるTaskの具体的な実行を表す。

XCom (Cross-Communication)

XComはTask間でデータをやり取りするメカニズムである。公式ドキュメントによると、「デフォルトではTaskは完全に分離されており、異なるマシンで実行される可能性があるため」XComが必要である。XComはkeytask_iddag_idで識別される。

# Push
task_instance.xcom_push(key="model_accuracy", value=0.95)

# Pull
accuracy = task_instance.xcom_pull(key="model_accuracy", task_ids="evaluate_model")

重要な点は、XComは少量のデータ転送用に設計されていることである。DataFrameのような大容量データをXComで転送してはならず、そのような場合はObject Storage Backendを使用すべきである。

3. KubernetesPodOperatorによるGPU学習Jobの実行

ML学習パイプラインにおいて、KubernetesPodOperatorは中核的な役割を果たす。公式ドキュメントによると、KubernetesPodOperatorは「Kubernetes object spec定義の代替物」であり、DAGコンテキスト内でAirflowスケジューラーによって実行できる。つまり、Podスペックのための別個のYAML/JSONファイルを作成する必要なく、PythonコードでKubernetes Podを直接定義し実行できる。

GPUベースのML学習でKubernetesPodOperatorが特に有用な理由は以下の通りである。

  • Taskレベルのリソース設定:各TaskごとにCPU、Memory、GPUを独立して割り当てられる
  • カスタム依存性サポート:PyPIにないパッケージや特定のCUDAバージョンが必要な場合、カスタムDockerイメージを指定できる
  • 言語非依存性:Pythonだけでなく、どの言語で書かれた学習コードでもコンテナでラップして実行できる
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

train_model = KubernetesPodOperator(
    task_id='train_model',
    name='gpu-training-job',
    namespace='ml-training',
    image='my-registry/ml-trainer:latest',
    cmds=['python'],
    arguments=['train.py', '--epochs', '100', '--batch-size', '64'],
    container_resources=k8s.V1ResourceRequirements(
        requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
        limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
    ),
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
)

do_xcom_push=Trueに設定すると、Pod内で/airflow/xcom/return.jsonファイルにJSONを書き込むことでXComとしてデータを返却できる。学習結果のメトリクスを次のTaskに渡す際に便利である。

4. KubernetesPodOperatorの主要パラメータ分析

公式ドキュメントに基づき、KubernetesPodOperatorの主要パラメータをML学習の観点から分析する。

container_resources

GPU学習で最も重要なパラメータである。kubernetes.client.models.V1ResourceRequirementsを使用してCPU、Memory、GPUリソースをリクエスト(requests)し制限(limits)できる。

container_resources = k8s.V1ResourceRequirements(
    requests={
        'cpu': '4',
        'memory': '16Gi',
        'nvidia.com/gpu': '1'
    },
    limits={
        'cpu': '8',
        'memory': '32Gi',
        'nvidia.com/gpu': '1'
    },
)

GPUの場合はnvidia.com/gpuリソースキーを使用し、一般的にrequestsとlimitsを同一に設定する。GPUは分割割り当てが不可能なためである(MIGを使用しない限り)。

tolerations

GPUノードは通常taintが設定されており、一般的なワークロードがスケジューリングされないように管理される。tolerationsを通じてGPUノードにPodがスケジューリングされることを許可する。

tolerations = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]

node_selector

特定のノードグループでのみ学習Podが実行されるように指定できる。例えば、A100 GPUが搭載されたノードでのみ実行したい場合は以下のように設定する。

node_selector = {
    'gpu-type': 'a100',
    'node-pool': 'ml-training',
}

affinity

より細かいスケジューリング制御が必要な場合はV1Affinityを使用する。Node affinity、Pod affinity/anti-affinityを設定して特定のzoneやノードにPodを配置できる。

affinity = k8s.V1Affinity(
    node_affinity=k8s.V1NodeAffinity(
        required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
            node_selector_terms=[
                k8s.V1NodeSelectorTerm(
                    match_expressions=[
                        k8s.V1NodeSelectorRequirement(
                            key='gpu-type',
                            operator='In',
                            values=['a100', 'h100'],
                        )
                    ]
                )
            ]
        )
    )
)

volumes / volume_mounts

学習データセットやチェックポイントをPVC(PersistentVolumeClaim)でマウントする際に使用する。

volume = k8s.V1Volume(
    name='training-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-dataset-pvc'
    ),
)

volume_mount = k8s.V1VolumeMount(
    name='training-data',
    mount_path='/data',
    read_only=True,
)

その他の主要パラメータ

パラメータ説明ML学習での活用
image_pull_secretsPrivate Registry認証社内MLイメージへのアクセス
env_vars環境変数リストCUDA_VISIBLE_DEVICES、WANDB_API_KEYなど
secretsSecretボリューム/環境変数APIキー、認証情報管理
is_delete_operator_pod完了後Pod削除リソース整理(True推奨)
get_logsログをAirflow UIに表示学習ログモニタリング
deferrable非同期実行モード長時間学習時のWorker効率化
on_finish_action完了後の動作delete_podで自動整理

公式ドキュメントでは、型安全性のためにconvenience classesの代わりにkubernetes.client.modelsのネイティブオブジェクトの使用を推奨している。

5. DAG設計パターン:データ前処理 から 学習 から 評価 から デプロイ

ML学習パイプラインの一般的なDAG構造は以下のようなステージを含む。

データ検証 -> 前処理 -> Feature Engineering -> 学習 -> 評価 -> [条件付き] デプロイ

各ステージをAirflow Taskにマッピングすると以下のような設計パターンになる。

from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from datetime import datetime

with DAG(
    dag_id='ml_training_pipeline',
    start_date=datetime(2026, 1, 1),
    schedule_interval='@daily',
    catchup=False,
) as dag:

    validate_data = KubernetesPodOperator(
        task_id='validate_data',
        image='ml-tools:latest',
        cmds=['python', 'validate.py'],
        # CPUのみ必要
    )

    preprocess = KubernetesPodOperator(
        task_id='preprocess_data',
        image='ml-tools:latest',
        cmds=['python', 'preprocess.py'],
        # 大容量メモリ必要
    )

    train = KubernetesPodOperator(
        task_id='train_model',
        image='ml-trainer:latest',
        cmds=['python', 'train.py'],
        # GPU必要
        do_xcom_push=True,
    )

    evaluate = KubernetesPodOperator(
        task_id='evaluate_model',
        image='ml-trainer:latest',
        cmds=['python', 'evaluate.py'],
        do_xcom_push=True,
    )

    def check_model_quality(**context):
        metrics = context['ti'].xcom_pull(task_ids='evaluate_model')
        if metrics['accuracy'] > 0.90:
            return 'deploy_model'
        return 'notify_failure'

    branch = BranchPythonOperator(
        task_id='check_quality',
        python_callable=check_model_quality,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
    )

    validate_data >> preprocess >> train >> evaluate >> branch
    branch >> [deploy]

この設計パターンの核心は条件付きデプロイである。評価ステージでモデル性能メトリクスをXComで渡し、BranchPythonOperatorを通じて品質基準を満たす場合にのみデプロイTaskが実行される。これにより品質の低いモデルがプロダクションにデプロイされることを防ぐ。

TaskGroupを活用したステージ区分

複雑なパイプラインではTaskGroupを活用して視覚的にステージを区分できる。公式ドキュメントによると、「TaskGroupはGraph ViewでTaskを階層的なグループに組織化するために使用される。繰り返しパターンの作成と視覚的な複雑さの削減に有用である。」

from airflow.utils.task_group import TaskGroup

with TaskGroup('data_preparation') as data_prep:
    validate_data >> preprocess >> feature_engineering

with TaskGroup('model_training') as training:
    train >> evaluate

data_prep >> training >> branch

6. XComによるTask間のメトリクス/アーティファクト転送

公式ドキュメントによると、XComはxcom_pushxcom_pullメソッドを通じて明示的にストレージに「push」および「pull」される。MLパイプラインでXComは以下のような用途で活用される。

学習メトリクスの転送

# train.py(KubernetesPodOperator内部で実行)
import json

metrics = {
    'accuracy': 0.9534,
    'f1_score': 0.9421,
    'loss': 0.0312,
    'model_path': 's3://ml-models/experiment-42/model.pt',
    'run_id': 'exp-42-20260301',
}

# /airflow/xcom/return.jsonにJSON書き込み
with open('/airflow/xcom/return.json', 'w') as f:
    json.dump(metrics, f)

Auto-PushとReturn Value

多くのOperatorと@taskデコレーターはdo_xcom_push=True(デフォルト)の場合、戻り値をreturn_valueというキーのXComとして自動pushする。その後以下のようにpullできる。

value = task_instance.xcom_pull(task_ids='train_model')

Multiple Outputs

ディクショナリを返しmultiple_outputs=Trueを設定すると、各キーが個別のXComとして保存される。

@task(multiple_outputs=True)
def evaluate_model(**context):
    return {
        'accuracy': 0.95,
        'f1_score': 0.94,
        'model_path': 's3://models/latest.pt',
    }

その後、個別キーでpullするか、全体のreturn_valueでpullできる。

Custom XCom Backend

デフォルトのXCom BackendであるBaseXComはAirflowデータベースにXComを保存する。少量のデータには問題ないが、モデルアーティファクトや大容量データを扱う場合はCustom Backendが必要である。BaseXComを継承してserialize_valuedeserialize_valueメソッドをオーバーライドすればよい。

Object Storage Backendを使用するには、xcom_backend設定をairflow.providers.common.io.xcom.backend.XComObjectStorageBackendに指定する。これによりS3やGCSにXComデータを保存できる。

[core]
xcom_backend = airflow.providers.common.io.xcom.backend.XComObjectStorageBackend

XCom使用時の注意事項

公式ドキュメントで強調されている核心事項は以下の通りである。

  • XComは少量のデータ転送用に設計されている。DataFrameなどの大容量データを転送してはならない
  • Taskが失敗後にリトライされる場合、以前のXComは自動的にクリアされ、冪等な実行が保証される
  • XCom操作はget_current_context()を通じたTask Contextで行う必要があり、直接的なDB更新はサポートされていない

MLパイプラインでの実用的なガイドラインとしては、メトリクス(accuracy、lossなど)やパス情報(model_path、artifact_uriなど)はXComで転送し、実際のモデルファイルやデータセットはS3/GCSのようなObject Storageに保存してパスのみをXComで共有するのが望ましい。

7. Dynamic Task Mappingによるハイパーパラメータチューニング

Dynamic Task MappingはAirflow 2.3で導入された機能で、公式ドキュメントによると「ランタイムに現在のデータに基づいて複数のTaskを生成できるメカニズム」である。DAG作成者が事前にいくつのTaskが必要かを知る必要がない。

MLハイパーパラメータチューニングにおいてこの機能は非常に強力である。様々なハイパーパラメータの組み合わせで並列学習を実行できるためである。

expand()とpartial()

expand()はマッピングするパラメータを、partial()は全てのTaskに共通で渡す固定パラメータを指定する。

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

hyperparams = [
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'adam'},
    {'lr': '0.01', 'batch_size': '64', 'optimizer': 'adam'},
    {'lr': '0.001', 'batch_size': '32', 'optimizer': 'sgd'},
    {'lr': '0.0001', 'batch_size': '128', 'optimizer': 'adamw'},
]

train_tasks = KubernetesPodOperator.partial(
    task_id='hyperparam_training',
    image='ml-trainer:latest',
    namespace='ml-training',
    get_logs=True,
    is_delete_operator_pod=True,
    do_xcom_push=True,
    container_resources=k8s.V1ResourceRequirements(
        requests={'nvidia.com/gpu': '1'},
        limits={'nvidia.com/gpu': '1'},
    ),
).expand(
    arguments=[
        ['train.py', '--lr', hp['lr'], '--batch-size', hp['batch_size'], '--optimizer', hp['optimizer']]
        for hp in hyperparams
    ],
)

このコードは4つのハイパーパラメータの組み合わせに対して自動的に4つの並列KubernetesPodOperator Taskを生成する。

Taskからの動的マッピングデータ生成

さらに、上流Taskから動的にハイパーパラメータの組み合わせを生成することもできる。

@task
def generate_hyperparams():
    """Grid SearchまたはRandom Searchでハイパーパラメータ生成"""
    import itertools

    learning_rates = [0.001, 0.01, 0.0001]
    batch_sizes = [32, 64, 128]
    optimizers = ['adam', 'sgd']

    combinations = list(itertools.product(learning_rates, batch_sizes, optimizers))
    return [
        {'lr': str(lr), 'batch_size': str(bs), 'optimizer': opt}
        for lr, bs, opt in combinations
    ]

hp_list = generate_hyperparams()

公式ドキュメントによると、「Taskから生成されたマッピングはtrigger_rule=TriggerRule.ALWAYSの使用を禁止」するという制約がある。

Cross-Product Mapping

複数のexpand()パラメータを指定すると、全ての組み合わせ(cross product)が生成される。

@task
def train(lr: float, batch_size: int):
    # 学習ロジック
    pass

train.expand(lr=[0.001, 0.01], batch_size=[32, 64])
# 2 x 2 = 4つのTask Instance生成

Map-Reduceパターン

マッピングされたTaskの結果を収集し、最適モデルを選択するパターンである。

@task
def select_best_model(results):
    """全ハイパーパラメータ組み合わせの結果から最適なものを選択"""
    best = max(results, key=lambda x: x['accuracy'])
    return best

training_results = train_task.expand(params=hyperparams)
best = select_best_model(training_results)

公式ドキュメントによると、収集された結果は「lazy proxy」シーケンスとして返され、eager listではないことに注意が必要である。

制約事項

項目設定デフォルト
最大マッピングインスタンス数[core] max_map_length1024
Task当たりの並列実行制限max_active_tis_per_dagTask別設定

マッピング可能なデータ型はlistdictのみであり、他の型ではUnmappableXComTypePushedエラーが発生する。また公式ドキュメントによると、「フィールドがtemplatedとマークされておりマッピングされている場合、そのフィールドはテンプレート処理されない。」

8. Sensorを活用したデータ到着検知

MLパイプラインでは、学習データが準備されるまで待機する必要がある場合が多い。AirflowのSensorは特定の条件が満たされるまで待機する特殊なOperatorである。

S3KeySensor

S3に学習データがアップロードされるまで待機する。

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_data = S3KeySensor(
    task_id='wait_for_training_data',
    bucket_name='ml-datasets',
    bucket_key='daily/{{ ds }}/training_data.parquet',
    aws_conn_id='aws_default',
    poke_interval=300,  # 5分ごとに確認
    timeout=3600 * 6,   # 最大6時間待機
    mode='reschedule',  # Workerスロット返却
    deferrable=True,    # Triggererによる非同期待機
)

mode='reschedule'はpoke間にWorkerスロットを返却してリソース効率性を高める。deferrable=Trueを設定するとTriggererコンポーネントが非同期的にpollingを処理し、より効率的なWorker活用が可能になる。

ExternalTaskSensor

他のDAGのTaskが完了するまで待機する。例えば、データ前処理DAGが完了した後に学習DAGを開始するパターンである。

from airflow.sensors.external_task import ExternalTaskSensor

wait_for_preprocessing = ExternalTaskSensor(
    task_id='wait_for_preprocessing',
    external_dag_id='data_preprocessing_pipeline',
    external_task_id='final_validation',
    allowed_states=['success'],
    execution_delta=timedelta(hours=0),
    timeout=3600,
    poke_interval=60,
    mode='reschedule',
)

allowed_statesは許可される状態のリストで、デフォルトは['success']である。execution_deltaは確認する前回実行との時間差を指定する。

実用的パターン:データ到着後の学習開始

wait_for_data >> validate_data >> preprocess >> train >> evaluate >> deploy

このパターンはデータパイプラインとMLパイプラインを疎結合(loosely couple)しつつ、データが準備された後にのみ学習が開始されることを保証する。

9. MLflowとAirflowの連携パターン

AirflowとMLflowの連携における役割分担は明確である。Airflowは「いつ、どの順序で実行するか」を管理し、MLflowは「実行過程で何が起きたか、モデルがどこにあるか」を記録する。 この関心事の分離(separation of concerns)により、データエンジニアはMLコードに触れずにAirflow DAGを管理し、データサイエンティストはスケジューリングインフラを心配せずにモデル開発に集中できる。

連携パターン1:学習Task内部でのMLflow直接呼び出し

最も直感的なパターンである。KubernetesPodOperatorで実行される学習コンテナ内部でMLflow APIを直接呼び出す。

# train.py(コンテナ内部で実行)
import mlflow
import mlflow.pytorch

mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("daily-training")

with mlflow.start_run() as run:
    # ハイパーパラメータのロギング
    mlflow.log_params({
        'learning_rate': 0.001,
        'batch_size': 64,
        'epochs': 100,
    })

    # モデル学習
    model = train_model(...)

    # メトリクスのロギング
    mlflow.log_metrics({
        'accuracy': 0.95,
        'f1_score': 0.94,
    })

    # モデル登録
    mlflow.pytorch.log_model(model, "model")

    # run_idをXComで転送
    import json
    with open('/airflow/xcom/return.json', 'w') as f:
        json.dump({'run_id': run.info.run_id}, f)

連携パターン2:AirflowからのMLflow Model Registry活用

学習完了後、モデル評価結果に基づきMLflow Model Registryにモデルを登録しステージを遷移する。

@task
def register_model(run_id: str, accuracy: float):
    import mlflow
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # モデル登録
    model_uri = f"runs:/{run_id}/model"
    mv = mlflow.register_model(model_uri, "production-model")

    # 精度が基準を超えたらProductionステージに遷移
    if accuracy > 0.93:
        client.transition_model_version_stage(
            name="production-model",
            version=mv.version,
            stage="Production",
        )

連携パターン3:MLflowからの学習履歴照会と比較

@task
def compare_with_baseline():
    from mlflow.tracking import MlflowClient

    client = MlflowClient("http://mlflow-server:5000")

    # Productionモデルのメトリクス照会
    prod_versions = client.get_latest_versions("production-model", stages=["Production"])
    if prod_versions:
        prod_run = client.get_run(prod_versions[0].run_id)
        baseline_accuracy = float(prod_run.data.metrics['accuracy'])
        return {'baseline_accuracy': baseline_accuracy}
    return {'baseline_accuracy': 0.0}

10. TaskFlow APIの活用(@taskデコレーター)

TaskFlow APIは公式ドキュメントによると「デコレーターを使用してDAGとTaskを定義する関数型API」であり、Task間のデータ転送と依存性定義を大幅に簡素化する。

核心的な特性

TaskFlowの最大の利点は自動XCom管理自動依存性計算である。TaskFlow関数を呼び出すと実行されるのではなく、結果を表すXComArgオブジェクトが返される。これをダウンストリームTaskの入力として使用すると依存性が自動的に計算される。

from airflow.sdk import dag, task
from datetime import datetime

@dag(
    schedule='@daily',
    start_date=datetime(2026, 1, 1),
    catchup=False,
)
def ml_training_taskflow():

    @task
    def extract_data():
        """データ抽出"""
        return {'data_path': 's3://bucket/data/2026-03-01.parquet', 'num_rows': 100000}

    @task(multiple_outputs=True)
    def preprocess(data_info: dict):
        """データ前処理"""
        return {
            'train_path': f"{data_info['data_path']}/train.parquet",
            'test_path': f"{data_info['data_path']}/test.parquet",
            'feature_count': 128,
        }

    @task
    def train_model(train_path: str, feature_count: int):
        """モデル学習"""
        # 実際にはKubernetesPodOperatorでGPU学習を実行
        return {
            'model_path': 's3://models/latest.pt',
            'accuracy': 0.95,
        }

    @task
    def evaluate_and_deploy(model_info: dict):
        """モデル評価および条件付きデプロイ"""
        if model_info['accuracy'] > 0.90:
            return f"Deployed model from {model_info['model_path']}"
        return "Model quality below threshold, skipping deployment"

    # 自動依存性接続
    data = extract_data()
    processed = preprocess(data)
    model = train_model(
        train_path=processed['train_path'],
        feature_count=processed['feature_count'],
    )
    evaluate_and_deploy(model)

ml_training_taskflow()

Contextアクセス

TaskはAirflowのcontext変数をキーワード引数として受け取ることができる。

@task
def log_execution_info(task_instance=None, dag_run=None):
    print(f"Run ID: {task_instance.run_id}")
    print(f"DAG Run: {dag_run.dag_id}")

オブジェクトシリアライゼーション

TaskFlowは@dataclass@attr.defineまたはカスタムserialize()/deserialize()メソッドによるカスタムオブジェクト転送をサポートする。__version__: ClassVar[int]を使用したバージョン管理も可能である。

11. 実践DAGコード例(完全ML パイプライン)

ここまで扱った全ての概念を統合した実践ML学習パイプラインDAGコードである。データ到着検知からハイパーパラメータチューニング、最適モデル選択、デプロイまでの全プロセスを含む。

"""
ML Training Pipeline DAG
- S3データ到着検知
- データ前処理(KubernetesPodOperator)
- Dynamic Task Mappingを活用したハイパーパラメータチューニング
- 最適モデル選択およびMLflow登録
- 条件付きデプロイ
"""

from __future__ import annotations

from datetime import datetime, timedelta
from airflow import DAG
from airflow.sdk import task
from airflow.operators.python import BranchPythonOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.utils.task_group import TaskGroup
from kubernetes.client import models as k8s

# -- 共通設定 --
GPU_RESOURCES = k8s.V1ResourceRequirements(
    requests={'cpu': '4', 'memory': '16Gi', 'nvidia.com/gpu': '1'},
    limits={'cpu': '8', 'memory': '32Gi', 'nvidia.com/gpu': '1'},
)

GPU_TOLERATIONS = [
    k8s.V1Toleration(
        key='nvidia.com/gpu',
        operator='Exists',
        effect='NoSchedule',
    ),
]

GPU_NODE_SELECTOR = {'gpu-type': 'a100', 'node-pool': 'ml-training'}

DATA_VOLUME = k8s.V1Volume(
    name='shared-data',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='ml-shared-pvc'
    ),
)

DATA_VOLUME_MOUNT = k8s.V1VolumeMount(
    name='shared-data',
    mount_path='/shared',
)

default_args = {
    'owner': 'ml-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
    'execution_timeout': timedelta(hours=4),
}


with DAG(
    dag_id='ml_full_training_pipeline',
    default_args=default_args,
    start_date=datetime(2026, 1, 1),
    schedule_interval='@daily',
    catchup=False,
    tags=['ml', 'training', 'production'],
    max_active_runs=1,
    doc_md="""
    ## ML Full Training Pipeline
    毎日新しいデータでモデルを再学習し、
    品質基準を満たした場合に自動デプロイするパイプライン。
    """,
) as dag:

    # ===== Stage 1: データ到着検知 =====
    wait_for_data = S3KeySensor(
        task_id='wait_for_training_data',
        bucket_name='ml-datasets',
        bucket_key='daily/{{ ds }}/features.parquet',
        aws_conn_id='aws_default',
        poke_interval=300,
        timeout=3600 * 6,
        mode='reschedule',
        deferrable=True,
    )

    # ===== Stage 2: データ前処理 =====
    with TaskGroup('data_preparation') as data_prep:

        validate_data = KubernetesPodOperator(
            task_id='validate_data',
            name='data-validation',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'validate.py'],
            arguments=['--date', '{{ ds }}', '--source', 's3://ml-datasets/daily/{{ ds }}/'],
            env_vars=[
                k8s.V1EnvVar(name='DATA_DATE', value='{{ ds }}'),
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '2', 'memory': '8Gi'},
                limits={'cpu': '4', 'memory': '16Gi'},
            ),
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )

        preprocess = KubernetesPodOperator(
            task_id='preprocess_data',
            name='data-preprocessing',
            namespace='ml-training',
            image='ml-tools:latest',
            cmds=['python', 'preprocess.py'],
            arguments=[
                '--input', 's3://ml-datasets/daily/{{ ds }}/',
                '--output', 's3://ml-processed/{{ ds }}/',
            ],
            container_resources=k8s.V1ResourceRequirements(
                requests={'cpu': '4', 'memory': '32Gi'},
                limits={'cpu': '8', 'memory': '64Gi'},
            ),
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
        )

        validate_data >> preprocess

    # ===== Stage 3: ハイパーパラメータ生成および並列学習 =====
    @task
    def generate_hyperparams():
        """学習するハイパーパラメータの組み合わせを生成"""
        import itertools
        learning_rates = [0.001, 0.0001]
        batch_sizes = [32, 64]
        optimizers = ['adam', 'adamw']

        combos = list(itertools.product(learning_rates, batch_sizes, optimizers))
        return [
            [
                'train.py',
                '--lr', str(lr),
                '--batch-size', str(bs),
                '--optimizer', opt,
                '--data-path', 's3://ml-processed/{{ ds }}/',
                '--experiment-name', 'daily-training-{{ ds }}',
            ]
            for lr, bs, opt in combos
        ]

    hp_args = generate_hyperparams()

    with TaskGroup('hyperparameter_tuning') as hp_tuning:

        train_tasks = KubernetesPodOperator.partial(
            task_id='train_with_hp',
            name='gpu-hp-training',
            namespace='ml-training',
            image='ml-trainer:latest',
            cmds=['python'],
            container_resources=GPU_RESOURCES,
            tolerations=GPU_TOLERATIONS,
            node_selector=GPU_NODE_SELECTOR,
            volumes=[DATA_VOLUME],
            volume_mounts=[DATA_VOLUME_MOUNT],
            env_vars=[
                k8s.V1EnvVar(name='MLFLOW_TRACKING_URI', value='http://mlflow:5000'),
                k8s.V1EnvVar(name='CUDA_VISIBLE_DEVICES', value='0'),
            ],
            get_logs=True,
            is_delete_operator_pod=True,
            do_xcom_push=True,
            startup_timeout_seconds=600,
        ).expand(arguments=hp_args)

    # ===== Stage 4: 最適モデル選択 =====
    @task
    def select_best_model(training_results):
        """全学習結果から最適モデルを選択"""
        valid_results = [r for r in training_results if r is not None]
        if not valid_results:
            raise ValueError("No valid training results found")

        best = max(valid_results, key=lambda x: x.get('accuracy', 0))
        print(f"Best model: accuracy={best['accuracy']}, run_id={best['run_id']}")
        return best

    best_model = select_best_model(train_tasks)

    # ===== Stage 5: モデル評価 =====
    @task(multiple_outputs=True)
    def evaluate_model(model_info: dict):
        """最適モデルの詳細評価"""
        import mlflow
        from mlflow.tracking import MlflowClient

        client = MlflowClient("http://mlflow:5000")

        # 現在のProductionモデルとの比較
        try:
            prod_versions = client.get_latest_versions(
                "production-model", stages=["Production"]
            )
            if prod_versions:
                prod_run = client.get_run(prod_versions[0].run_id)
                baseline_accuracy = float(prod_run.data.metrics.get('accuracy', 0))
            else:
                baseline_accuracy = 0.0
        except Exception:
            baseline_accuracy = 0.0

        new_accuracy = model_info['accuracy']
        improvement = new_accuracy - baseline_accuracy

        return {
            'should_deploy': improvement > 0.005 and new_accuracy > 0.90,
            'new_accuracy': new_accuracy,
            'baseline_accuracy': baseline_accuracy,
            'improvement': improvement,
            'run_id': model_info['run_id'],
            'model_path': model_info.get('model_path', ''),
        }

    eval_result = evaluate_model(best_model)

    # ===== Stage 6: 条件付きデプロイ =====
    def decide_deployment(**context):
        should_deploy = context['ti'].xcom_pull(
            task_ids='evaluate_model', key='should_deploy'
        )
        if should_deploy:
            return 'deploy_model'
        return 'skip_deployment'

    deployment_branch = BranchPythonOperator(
        task_id='deployment_decision',
        python_callable=decide_deployment,
    )

    deploy = KubernetesPodOperator(
        task_id='deploy_model',
        name='model-deployment',
        namespace='ml-serving',
        image='ml-deployer:latest',
        cmds=['python', 'deploy.py'],
        arguments=[
            '--run-id', "{{ ti.xcom_pull(task_ids='evaluate_model', key='run_id') }}",
            '--model-name', 'production-model',
        ],
        container_resources=k8s.V1ResourceRequirements(
            requests={'cpu': '2', 'memory': '4Gi'},
            limits={'cpu': '4', 'memory': '8Gi'},
        ),
        get_logs=True,
        is_delete_operator_pod=True,
    )

    @task(task_id='skip_deployment')
    def skip_deployment():
        print("Model quality does not meet deployment criteria. Skipping.")

    skip = skip_deployment()

    # ===== Stage 7: 通知 =====
    @task(trigger_rule='none_failed_min_one_success')
    def send_notification(**context):
        """学習結果の通知送信"""
        eval_data = context['ti'].xcom_pull(task_ids='evaluate_model', key='return_value')
        print(f"Pipeline completed. Accuracy: {eval_data.get('new_accuracy')}")
        print(f"Improvement: {eval_data.get('improvement')}")
        # Slack / Email通知ロジック

    notification = send_notification()

    # ===== DAG依存性定義 =====
    wait_for_data >> data_prep >> hp_args >> hp_tuning
    hp_tuning >> best_model >> eval_result >> deployment_branch
    deployment_branch >> [deploy, skip] >> notification

このDAGは以下のAirflow機能を統合的に活用している。

  • S3KeySensor:データ到着待機(deferrableモード)
  • KubernetesPodOperator:GPUベースの学習Job実行
  • Dynamic Task Mappingexpand()を通じたハイパーパラメータ並列学習
  • TaskFlow API@taskデコレーターによるPython関数ベースのTask定義
  • XCom:学習メトリクスとモデルパスの転送
  • TaskGroup:視覚的ステージ区分
  • BranchPythonOperator:モデル品質に基づく条件付きデプロイ

実際のプロダクションでは、エラーハンドリング、SLA設定、通知統合などを追加して安定性を高めるべきである。

12. References

本記事で分析した内容は以下のApache Airflow公式ドキュメントおよび関連資料に基づいて作成された。

Apache Airflow公式ドキュメント

関連資料