TECHSCORE BLOG

クラウドCRMを提供するシナジーマーケティングのエンジニアブログです。

PythonでASL(Amazon States Language)を自動生成するアイデア


generated by DALL-E3; "Illustration: Two distinct flowcharts in labeled boxes. The 'Python' flowchart on the left has various coding structures connected with arrows. The 'ASL' flowchart on the right depicts a task sequence. A prominent conversion arrow links the two charts."

はじめに

AWS LambdaAWS Step Functions によるサーバレスなバッチアプリケーション開発は素晴らしいのですが、ワークフローを定義するASL: Amazon States Languageのデバッグには苦労させられます。本記事では、Python コードから ASL を自動生成する方法を考察します。

結論

この記事で紹介したいアイデアは以下の3つです。私たちは AWS Lambda の開発言語にPython を選んでいるため、Pythonで実現できる方法を検討しました。

  1. Pydanticによる ASL のスキーマと Python コードの相互変換
  2. Pydantic を用いた JSON ペイロードの型検証
  3. 関数のAST(抽象構文木)解析による ASL 生成

各アイデアの詳細についてはこの後説明します。

Python で ASL(Amazon States Language) を自動生成する

最初に、自動生成しない場合の開発がどういうものかを説明します。

1. Lambda と Step Functions によるサーバレスバッチアプリケーション開発

開発方法

AWS SAM (Serverless Application Model) は、サーバレスアプリケーションの開発を効率化するためのフレームワークです。AWS Lambda と AWS Step Functions を用いたサーバレスなバッチ処理の設定が、CloudFormation のテンプレートを拡張した書式で簡単に行えます。

AWS Step Functions を使用することで、複雑なワークフローを ASL(Amazon States Language) というドメイン固有言語で定義できます。ASL を利用することで、Lambda を逐次的に呼び出す処理や、条件分岐、ループ構造の実装が可能となります。この際、Lambda とASL のTask状態を紐付けるため、ASL 内にはプレースホルダー(例: ${FunctionArn} )を書き、State Machine の DefinitionSubstitutions 属性に FunctionArn: !GetAtt Function.Arn を書きます。

Lambda と Step Functions 開発のつらいところ

AWS SAM によるアプリケーション開発では ASL の記述が重要ですが、ASL は JSON または YAML で記述するものであり検証やテストが難しいという課題があります。例えば以下のようなエラーによく遭遇しますが、こうしたエラーはアプリケーションをデプロイあるいは実行するまで発見されないので、修正するのに時間を要します。

  • ASL内で Lambda ARNの プレースホルダーを書き、SAMのテンプレートで実際の ARN を代入する際に指定漏れやスペルミスがある
  • Step Functions のタスク間で受け渡される JSON 構造と Lambda の入出力が一致しない

ASL をローカルでテストする方法は一応提供されているものの、Lambda と結合したテストを行うのは今のところ難しいと言えます。

2. Python から ASL を自動生成するアイデア

発想

上述した問題を解決するために、テストしにくい ASL を書く代わりにテストしやすいコードから ASL を生成できないものでしょうか。冒頭でも述べましたが、それを実現するのが以下3つのアイデアです。特に3つ目の抽象構文木を使うことが本記事を執筆するきっかけとなりました。

  1. Pydanticによる ASL のスキーマと Python コードの相互変換
  2. Pydantic を用いた JSON ペイロードの型検証
    • Pydantic Model を入出力にとる Python 関数を Lambda に変換すると、Step Functions と Lambda の間で交換されるJSON ペイロードを型検証できる
  3. 関数のAST(抽象構文木)解析による ASL 生成
    • Python の抽象構文木を解析することで、2で説明した関数呼び出し、関数間の制御構造や変数の流れを把握し、ASL に変換する

ステートマシン生成のイメージ

以下の単純なフローを考えます。

これを以下のような Python の関数から生成することを考えます。

class StateMachineInput(BaseModel):
    tenant_id: int
    process_date: date

def main(params: StateMachineInput) -> None:
    fetched_key = Fetch.run(params)
    Process.run(params, fetched_key)
  • このコードはローカルでも実行できるようにします
  • paramsは ステートマシンへの入力と対応します。StateMachineInput は Pydantic のBaseModel を拡張しており、JSON からの変換が可能です
  • FetchProcess は以下のコードでLambda 関数ハンドラーに変換され、Step Functions の Task状態に対応するものとします
handler_fetch = generate_handler(Fetch)
handler_process = generate_handler(Process)

今回考慮しないもの

今回の記事では非常に単純な State Machine の実装を目指しますが、Step Functions ではもっと複雑なワークフローを書くことができます。以下については今回の記事では取り上げません。

  • 条件分岐やループの実装
    • Choice状態を使うと if 文や matchに相当する条件分岐が実現でき、Map状態はfor文に対応するループ構造を書けるはずです
  • エラーハンドリング
    • ASL では特定のエラータイプに基づいて異なる処理を行うことが可能です
  • SQS や SNS、EventBridge によるオーケストレーション
    • Step Functions が直接 Lambda を呼び出す以外に、SQS や SNS を介して Lambda を起動し、また Step Functions に戻ってくるようなワークフローを組むこともできます
  • AWS SDKサービス統合
    • Step Functions は Lambda 以外にも様々な AWS サービスの API、例えば別の State Machine を呼び出すこともできます

3. 実装方法

以下のコードは Python3.10 と Pydantic2.3.0 で動作します。今回は実験なので Google Colaboratory 上で確認しています。

!python --version
!pip install --upgrade pydantic
!pip show pydantic

Python 3.10.12

Name: pydantic Version: 2.3.0 Summary: Data validation using Python type hints

3.1. PydanticによるASLのスキーマとPythonコードの相互変換

まず、Step Functions の ASL を表現する基本的な Pydantic モデルを作成します。Choice, Fail, Map, Pass, Succeed, Task, Wait状態に対応するモデルと State Machine 全体を表すモデルが必要です。今回はTask状態のみを考えます。以下は State MachineとTask状態の実装例です。

from typing import Iterator, Literal, Tuple

from pydantic import BaseModel, ConfigDict, Field


class TaskState(BaseModel, Generic[P]):
    Type: Literal["Task"] = Field(alias="Type", default="Task")
    Resource: str  # Lambda の ARN を書く。プレースホルダー ${Function} を挿入してデプロイ時に置換する方法がよく取られる
    # comment
    # timeoutSeconds
    # heartbeatSeconds
    ResultPath: str | None
    # parameters
    # retry
    # catch
    Next: str | None
    End: bool | None


class StateMachine(BaseModel):
    StartAt: str
    # comment
    States: dict[str, TaskState]

3.2. Pydanticを用いたJSONペイロードの型検証

Lambda 関数ハンドラーの引数および戻り値は JSON と相互変換できるように一定の決まりがあり、そのままでは型の検証が難しいです。Pydantic を使って型検証できるモデルから Lambda 関数ハンドラーを生成することを考えます。以下のコードはその基底クラスです。

from abc import abstractmethod, ABC
from collections import ChainMap

from pydantic import BaseModel

TypedLambdaResult = BaseModel | None


class LambdaHandler(ABC, BaseModel):
    @abstractmethod
    def __call__(self) -> TypedLambdaResult:
        pass

    @classmethod
    def run(cls, *args: BaseModel) -> TypedLambdaResult:
        """ 引数にとった複数のモデルを記述順にマージしてこのハンドラへ入力し実行する
        ※記述順:フィールド名が衝突する場合は、引数の記述順に後に来たもので上書きします。
        """
        # [ChainMap](https://docs.python.org/3.10/library/collections.html#collections.ChainMap)
        # は引数の先頭からキーを検索するので入力を逆順にしておく
        params = dict(ChainMap(*[arg.model_dump() for arg in reversed(args.reverse())))
        handler = cls.model_validate(params)
        return handler()

LambdaHandler のプロパティは Lambda への入力と対応し、__call__メソッドに処理内容を書きます。その戻り値もPydantic のモデルで型宣言することにします。Lambda 関数ハンドラーは本来単なる文字列やlistを入出力できるのですが、これはイベントオブジェクトと戻り値がともにdictであるという制限を加えることになります。

なお、__call__LambdaHandlerのインスタンスを関数のように使うための特別なメソッドです。また、runは無くても良いのですが、State Machine を生成するためのmain関数の見栄えを整える目的で追加しています。

LambdaHandlerの実装例は以下のようになるでしょう。

from datetime import date

from pydantic import BaseModel


class Fetch(LambdaHandler):
    tenant_id: int
    process_date: date

    class Result(BaseModel):
        key: str

    def __call__(self) -> "Fetch.Result":
        print(self.model_dump_json())
        return Result(key="test.csv")


class Process(LambdaHandler):
    tenant_id: int
    process_date: date
    key: str

    def __call__(self) -> None:
        print(self.model_dump_json())

次にこれを Lambda 関数ハンドラーに変換するための関数を作りましょう。入出力との変換は Pydantic で処理できます。

import json
from typing import Any, Callable, Type

from pydantic import BaseModel

LambdaEvent = dict[str, Any]
LambdaContext = Any
LambdaResult = dict[str, Any]


def generate_handler(handler_type: Type[LambdaHandler]) -> Callable[[LambdaEvent, LambdaContext], LambdaResult]:
    def wrapper(event: LambdaEvent, context: LambdaContext) -> LambdaResult:
        handler = handler_type.model_validate(event)
        result = handler()
        if result is None:
            return
        # Pydantic モデルを dict に変換するには通常`model_dump`メソッドを使いますが、
        # モデルのフィールドが `datetime` や `UUID` などの型を持つ場合、
        # Lambda の実行時に "Unable to marshal response" エラーがでます。
        # これを回避するために、一度 JSON を経由します。
        json_ = result.model_dump_json()
        return json.loads(json_)
    return wrapper


## 以下のhandler_fetch や handler_process を SAM テンプレートで`AWS::Serverless::Function` の`Handler`に指定
handler_fetch = generate_handler(Fetch)
handler_process = generate_handler(Process)
※Lambda関数ハンドラーを生成する別案

本記事の例では、main関数の中で処理を呼び出すためにFetch(...).run() のようにオブジェクトを作成してrunメソッドを呼ぶという規約になっていますが、直観ではオブジェクトなしに普通の関数を呼び出すように書きたいと考えていました。例えば普通の Python 関数にデコレータを付けて Lambda 関数ハンドラーに変換する方法が考えられます。試してみるとデコレータ付きの関数からはせっかく用意した型情報が消えてしまい、エディタの入力支援を受けられなくなるためmain関数を書きにくいという印象を持ちました。デコレータを付与した関数から元の関数を取り出せると良いのですが、その方法は見つかりませんでした。

3.3. 関数のAST(抽象構文木)解析によるASL生成

いよいよ、ASL の生成方法を検討します。先に示した例、

def main(params: StateMachineInput) -> None:
    fetched_key = Fetch.run(params)
    Process.run(params, fetched_key)

を解析する方法を考えます。

Task呼び出しを順番に記録する

以下のコードは、main関数から以下の情報を読み取ります。行番号は、Lambda の出現順を保存するためと、同じクラスの Lambda を複数回呼び出した場合に区別するためです。

  • calls
    • *.run を呼び出している行番号
    • FetchProcessなどの名前
    • FetchProcessに入力される Python 変数名
      • TODO: 型情報を読み取る必要がある
  • assignments
    • Fetchを呼び出している行番号
    • Fetchの出力結果を格納する Python 変数名
import ast
import inspect

def extract_tasks(function):
    # ノートブック上で定義した関数定義を文字列として得る
    # 実際にはファイルから読み取る
    source = inspect.getsource(function)
    # print(source)
    assignments = {}
    calls = {}

    class TaskVisitor(ast.NodeVisitor):
        def visit_Call(self, node):
            # print(ast.dump(node))
            method = node.func
            assert isinstance(method, ast.Attribute) and method.attr == "run"
            task = method.value.id
            calls[node.lineno] = (task, [arg.id for arg in node.args])
            # 他ノードの巡回を続ける
            self.generic_visit(node)

        def visit_Assign(self, node):
            # print(ast.dump(node))
            variable = node.targets[0].id
            assert isinstance(node.value, ast.Call)
            method = node.value.func
            assert isinstance(method, ast.Attribute) and method.attr == "run"
            task = method.value.id
            assignments[node.lineno] = (variable, task)
            # 他ノードの巡回を続ける
            self.generic_visit(node)

    visitor = TaskVisitor()
    visitor.visit(ast.parse(source))

    return calls, assignments

calls, assignments = extract_tasks(main)
print(f"{calls=}")
print(f"{assignments=}")

実行結果は以下の通りです。

calls={2: ('Fetch', ['params']), 3: ('Process', ['params', 'fetched_key'])}
assignments={2: ('fetched_key', 'Fetch')}
Task の呼び出し順を ASL に再現するコード例

先程の解析結果を使って ASL を生成してみます。

def generate(calls: dict[int, Tuple[str, list[str]]], assignments: dict[int, Tuple[str, str]]) -> StateMachine:

    def generate_states() -> Iterator[TaskState]:

        next = None
        for line_no, call in sorted(calls.items(), reverse=True):
            task_name, args = call
            result_path = assignments.get(line_no, [None])[0]
            print(result_path, task_name, args)
            yield task_name, TaskState(
                Resource = "${" + task_name + "}",
                # ResultPath のデフォルトは"$"だが、これは入力を消去して出力で置き換える意味。
                # 出力がない場合は null を指定しなければいけない
                ResultPath = "$." + result_path if result_path is not None else None,
                Next = next,
                End = next is None
            )
            next = task_name

    states = list(generate_states())
    states.reverse()
    return StateMachine(
        StartAt = states[0][0],
        States = {name: task for name, task in states}
    )

print(
    generate(calls, assignments)
    .model_dump_json(exclude_none=False, indent=4)
)

出力結果は以下のようになります。Resource は後で Lambda の ARN に置換することを想定しています。各Task状態への入力は暗黙的に決まっており、callsに含まれる引数の情報を使っていません。今回はたまたま出力するASLとコードの意味が一致しますが、もう少し複雑な StateMachine では Parameters を使ってTask状態への入力形式を変換する必要があるでしょう。また、Fetchの出力をfetched_keyというローカル変数に代入しているコードはResultPathの指定に対応させています。

{
    "StartAt": "Fetch",
    "States": {
        "Fetch": {
            "Type": "Task",
            "Resource": "${Fetch}",
            "ResultPath": "$.fetched_key",
            "Next": "Process",
            "End": false
        },
        "Process": {
            "Type": "Task",
            "Resource": "${Process}",
            "ResultPath": null,
            "Next": null,
            "End": true
        }
    }
}

4. 今後の課題

今回の記事では非常に単純な例を取り上げましたが、実用的なバッチ処理を実装するには多くの課題が残っています。

生成ロジックの課題

  1. AST 解析の拡張
    1. Lambda に入力する JSON ペイロードの操作
    2. 条件分岐やループなどの制御構造
    3. 並行処理(Parallel状態Map状態に対応)
    4. 例外処理
  2. ASL 生成ロジックのエラー検知
    • Python コードとして正しいが、ASL に変換できない場合の扱い
  3. 生成前の Python コード(main関数)をより自然に書けるようにする
  4. Lambda 以外のサービスを呼び出せるように生成ロジックの拡張

テスト方法について

main, Fetch, Processは単なる関数とクラスなので、テストコードの中でそのまま実行することでステートマシン全体の動作をシミュレートできます。mainのテスト時はFetchProcessをモックに置き換えても良いでしょう。

デプロイ方法の課題

SAM や CDK を使って Lambda とステートマシンをデプロイできます。この時に注意すべき点が2つあります。ただし筆者は CDK を使ったことがないので CDK では課題といえないかもしれません。

デプロイ課題1. Lambda の ARN と Task のひもづけ

生成した ASL に、Task状態のResourceFetchProcessから生成する Lambda の ARN を設定しなければいけません。2つの方法がありますが、今回のコードでは後者を想定しました。

  1. ASL を生成する時に ARN を決め打ちし、それに合わせて各 LambdaのFunctionName属性を決める
    • FunctionName を決めるとリソースの更新ができなくなる場合があるので注意が必要
  2. ASL にプレースホルダー変数を設定し、DefinitionSubstitution属性にて、生成された Lambda の ARN へマッピングする
    • DefinitionSubstitution を誤って設定してしまうことが多いが、デプロイしないと誤りに気づかないので開発体験が悪い。検査するプログラムを書くか、State Machine のリソース定義もコードで生成したほうが良いかもしれない
デプロイ課題2. ASL 更新の自動化

SAM は当然ながら生成したASLをデプロイするので、生成元のコードを修正した時にはASLを再生成する必要があります。これを何らかの方法で自動化しなければなりません。

おわりに

本記事で検討した方法は、AWS Lambda と Step Functions の実行モデルを Python コードでシミュレートしながら、型安全性を保持することを意図したものです。これにより複数の Lambda を組み合わせるバッチ処理の開発体験とテストしやすさが改善できることが期待できます。純粋なコードでワークフローを書けることから、Lambda では実行困難な処理をAWS Batch などに移行することも容易でしょう。ただし、このアイデアを実用的なツールとするためにはまだまだ多くの課題が残っています。

シナジーマーケティング株式会社では一緒に働く仲間を募集しています。

西尾 義英(ニシオ ヨシヒデ)
開発から離れていた時期もありましたが最近プロダクトづくりに戻ってきました。 データを活用できる製品・基盤づくりがテーマです。