Streamlit で multiprocessing を使う際の注意点

2024年1月7日 engineering

あけましておめでとうございます。 @kz_morita です。

年末年始休暇に入り面白そうだと思っていた Streamlit で遊んでいました。 その中でコストの高い計算をする際に並列処理を行おうと思い、 multiprocessing で実装を行っていたのですが色々とはまりどころがあったため、まとめます。

やりたいこと

やりたいことは以下のようなコードになります。

(以下は動かないコードです)

import streamlit as st
import time
from multiprocessing import Pool 

# 2倍を計算
def double(x):
    print(f"x = {x}")
    time.sleep(0.01)
    return x * 2

# multiprocess で動かすための関数
def calcurate(count):
    with Pool() as p:
        result = p.map(double, range(count))

    return result

st.write("Hello world")
COUNT = 10000
with st.spinner("Calcurating..."):
    result = calcurate(COUNT)

st.write(result)

streamlit 上で 10000 回計算を並列で回して結果を取得しています。 ローディングを出したい為、spinner なども表示していますが基本的にはシンプルなコードです。

上記のコードはエラーがでて動かないのですがこちらをベースにしてハマった点を書きます。

__main__ 内で実行しないと動かない

上記を実行すると以下のようなエラーが発生します。

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

これは __main__ 内で実行することで解消できます。

import streamlit as st
import time
from multiprocessing import Pool

def double(x):
    print(f"x = {x}")
    time.sleep(0.01)
    return x * 2

def calcurate(count):
    with Pool() as p:
        result = p.map(double, range(count))

    return result

def main():
    # main 関数の中へ
    st.write("Hello world")
    COUNT = 10000
    with st.spinner("Calcurating..."):
        result = calcurate(COUNT)

    st.write(result)

if __name__ == '__main__':
    main()

これで実際に動かすことができます。

結果を dataclass で返したい

ので、以下のようにコードを変更しました。

(以下エラーが出るコードです。)

import streamlit as st
import time
from multiprocessing import Pool
from dataclasses import dataclass

@dataclass(frozen=True)
class Result:
    original: int
    calcurated: int

def double_return_dataclass(x):
    print(f"x = {x}")
    time.sleep(0.01)
    return Result(x, 2 * x)

def calcurate(count):
    with Pool() as p:
        result = p.map(double_return_dataclass, range(count))

    return result

def main():
    st.write("Hello world")
    COUNT = 10000
    with st.spinner("Calcurating..."):
        result = calcurate(COUNT)

    st.write(result)

if __name__ == '__main__':
    main()

Result という dataclass を定義して double_return_dataclass というメソッドに変更しています。

実行してみると以下のようなエラーがでます。

Exception in thread Thread-27 (_handle_results):
Traceback (most recent call last):
  File "/Users/{username}/.pyenv/versions/3.10.13/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/Users/{username}/.pyenv/versions/3.10.13/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/{username}/.pyenv/versions/3.10.13/lib/python3.10/multiprocessing/pool.py", line 579, in _handle_results
    task = get()
  File "/Users/{username}/.pyenv/versions/3.10.13/lib/python3.10/multiprocessing/connection.py", line 251, in recv
    return _ForkingPickler.loads(buf.getbuffer())
AttributeError: Can't get attribute 'Result' on <module '__main__' from '/Users/{username}/Library/Caches/pypoetry/virtualenvs/data-usage-monitor-xA2XThov-py3.10/bin/streamlit'>

エラーをみる限りだと、Result dataclass にアクセスできなくてエラーになっていそうです。

このエラーの治し方としては、2 通り確認できました。

dataclass を別ファイルとして切り出す

spawn した別プロセスから、Result にアクセスできていないのが問題っぽいのですが、app.py は __main__ で実行されないのか別ファイルとして dataclass を切り出すをうまく動きました。

メインの処理が、app.py で新たに result.py を作成して import するようにしています。

app.py

import streamlit as st
import time
from multiprocessing import Pool

from result import Result

def double_return_dataclass(x):
    print(f"x = {x}")
    time.sleep(0.01)
    return Result(x, 2 * x)

def calcurate(count):
    with Pool() as p:
        result = p.map(double_return_dataclass, range(count))
        p.close()
        p.join()

    return result

def main():
    st.write("Hello world")
    COUNT = 10000
    with st.spinner("Calcurating..."):
        result = calcurate(COUNT)

    st.write(result)

if __name__ == '__main__':
    main()

result.py

from dataclasses import dataclass

@dataclass(frozen=True)
class Result:
    original: int
    calcurated: int

ちなみに、streamlit でマルチページの実装をしていると、pages/ ディレクトリ以下に python ファイルをおくことになるかと思いますが、親階層の python ファイルを読み込むにも多少 Tips がありました。

例えば以下のような階層構造を想定しています。

- pages/
  -- page.py
- data/
  -- result.py
- app.py

page.py から result.py を読むためには以下のように sys.path に page.py の親階層を追加する必要があります。

page.py

sys.path.append(str(Path(__file__).resolve().parent.parent))
from data.result import Result

multiprocessing を fork で動かす

他の方法としては、fork で動かすことでも動作確認できました。 set_start_method で fork を指定することができます。

import streamlit as st
import time
from multiprocessing import Pool, set_start_method

from dataclasses import dataclass

@dataclass(frozen=True)
class Result:
    original: int
    calcurated: int

def double_return_dataclass(x):
    print(f"x = {x}")
    time.sleep(0.01)
    return Result(x, 2 * x)

def calcurate(count):
    set_start_method("fork", force=True) # fork にかえる
    with Pool() as p:
        result = p.map(double_return_dataclass, range(count))
        p.close()
        p.join()

    return result

def main():
    st.write("Hello world")
    COUNT = 10000
    with st.spinner("Calcurating..."):
        result = calcurate(COUNT)

    st.write(result)

if __name__ == '__main__':
    main()

spawn でできなかったのは、子プロセスに dataclass の情報を渡せなかったからで、fork にして親プロセスをコピーすることで dataclass へ参照できたということだと理解してます。

まとめ

今回は、streamlit 上で multiprocessing モジュールを使用した時につまづいた点についてまとめました。 streamlit 上というより、multiprocessing モジュール、もっと言えば 並列処理に対しての知識が足りなくつまづいたりしました。 が、重たい処理をする上では避けては通れないので今回良い勉強になりました。

また、streamlit 非常に簡単に UI を作れてとても便利なのでどんどん活用していきたいです。

この記事をシェア