PyO3 で rust-decimal を C API 化する

Python
スポンサーリンク

あけましておめでとうございます。キリンです。今年もよろしくお願いいたします。

作ったもの

rust_decimal::Decimal を Python ライブラリ化

レポジトリ上の pyo3_decimalrust_decimal::Decimal の Python 上の ラッパーライブラリで、Python pyo3_decimal.Decimal というクラスを作成しています。

pyo3_decimal_apirust-numpy のような仕組みのヘッダーファイルのようなライブラリです。

実際の外部からの pyo3_decimal.Decimal の使用例が pyo3_decimal_user になります。

なぜ作ったか

本来であれば Python の decimal.Decimal を Rust 上から使いたいのですが、Python 3.10 で導入予定だった C-API が revert されてしまっています。3.11 でも導入がなされず、今後の API の導入も見通せないので、仕方なく Decimal 系のライブラリを自分で用意することにしました。その過程を少し紹介します。

PyO3 をそのまま使うと、そのライブラリで作成したクラスしか Python と受け渡しができないので、カプセル化を用いて numpy C API のように受け渡しができるようになることを目標とします。

参考リンク

導入予定だった Decimal C-API

Revert された プルリク

numpy のような C API 作成

このあたりは numpy 自体の実装と、revert されてしまった Decimal の実装、 rust-numpy のソースコードを読み込みました。

Python の Native 系のオブジェクトの API は Python 自体に組み込まれているため、 API が用意されていなければ PyO3 のような C Extention 側で直接受け取る術がありません。もちろん、無理やりメモリを参照して受け取るような荒業ができないことはないのですが、メンテナンス性を考えてもやるべきではないでしょう。最初はそれで実装しようとしていたのは内緒….

Python 自体をカスタマイズして、decimal.Decimal の型を C Extention 側で受け取れるような API を作成し、自前でビルドして使うことはできます。このあたりは revert されてしまった _decimal.c の実装例が参考になります。

ただ、そこまでするならば、もう自前で rust-decimal をラップしてしまえというのが私の結論でした。

numpy も Native 系と同様の方法でカプセル化を用いて C API を提供しており、その方法を踏襲したいと思います。

カプセル化

C API の作成にあたり、勉強したのがこのカプセル化です。

ここでのカプセル化とはオブジェクト指向のカプセル化とは意味が異なり、C 側で確保したメモリを他の外部ライブラリから参照できるようにすることです。メモリを Python のカプセルで包むイメージのようです。具体的には、Python のオブジェクトを介さずに C の構造体の受け渡しを可能にしたり、関数のポインタを直接渡しています。Python の C-API はこのようにして作られています。

たとえば、 int クラスや float クラスのような Python の Native のクラスはクラスの定義である C 側で確保されているそれぞれの PyTypeObject を API を介して操作できるようにしています。これで、 C-API で渡されたときの Python のオブジェクトが intfloat であることを確認できるようにしています。

それらのカプセル化されたメモリは Python のモジュールシステムの一部に組み込まれます。カプセルの使用者はそれらのモジュールをインポートすることで使うことができます。

revert された _decimal.c と pydecimal.h の例

#define PyDec_CheckExact(v) Py_IS_TYPE(v, &PyDec_Type)

上の例ではこの PyDec_Type というのが Python 側で確保されている Decimal クラスで使われている PyTypeObject で、このポインタを Py_IS_TYPE という API に渡せば、型が PyDec_Type であることを確認できます。これを使うような API が公開されることになります。

static PyObject *
init_api(void)
{
    /* Simple API */
    _decimal_api[PyDec_TypeCheck_INDEX] = (void *)PyDec_TypeCheck;
    ...
    return PyCapsule_New(_decimal_api, "_decimal._API", NULL);
}

上の例では、カプセル化を経て PyDec_TypeCheck を外部に API として公開しています。これで、decimal.Decimal の型チェックが行えるようになります。API は _decimal._API というモジュール上の場所に格納されます。

このように、カプセル化を使って PyTypeObject を直接的に確認する手段を提供しています。

static int
import_decimal(void)
{
    _decimal_api = (void **)PyCapsule_Import("_decimal._API", 0);
    if (_decimal_api == NULL) {
        return -1;
    }

    return 0;
}

使用者側の pydecimal.h では PyCapsule_Import を使って、カプセル化した C API 一式を取得していました。

rust-decimal の API 化

今回は rust_decimal::Decimal を Python に直接受け渡しすることが目的なので、rust_decimal::Decimal をラップした PyO3 クラスpyo3_decimal.Decimal を作成し、API化してみます。

Python 用のライブラリの作成

use rust_decimal::Decimal;
...
static mut PYO3_DECIMAL_CAPI: *const PyO3Decimal_CAPI = std::ptr::null();
...
#[pyclass(module = "pyo3_decimal", name = "Decimal")]
#[repr(C)]
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
struct PyDecimal(Decimal);
...
/// This module is a python module implemented in Rust.
#[pymodule]
fn pyo3_decimal(py: Python, m: &PyModule) -> PyResult<()> {
    let decimal_type = PyDecimal::type_object_raw(py);
    let mut _decimal_api = PyO3Decimal_CAPI {
        DecimalType: decimal_type,
    };
    let mut _decimal_api = Box::new(_decimal_api);
    unsafe {
        // leak the value, so it will never be dropped or freed
        PYO3_DECIMAL_CAPI = Box::leak(_decimal_api) as *const PyO3Decimal_CAPI;
    }
    unsafe {
        let cap_ptr = PyCapsule_New(
            PYO3_DECIMAL_CAPI as *mut c_void,
            (*PYO3_CAPSULE_API_NAME).as_ptr(),
            None,
        );
        let capsule: &PyCapsule = py.from_owned_ptr_or_err(cap_ptr)?;
        m.add("_API", capsule)?;
    }

    m.add_class::<PyDecimal>()?;

    Ok(())
}

これに関しては PyO3 の参考になる例が見つからなかったので、試行錯誤しながら実装しています。最低限の型の確認する手段と型で受け取る手段だけ欲しかったので、 rust_decimal::Decimal のラッパーである DecimalTypePyTypeObject を外部に公開することにしました。

PyCapsule_NewDecimalType をカプセル化することになるのですが、カプセル化する中身を Rust がメモリ開放してしまわないように Box::leak しています。

この rust_decimal::Decimal のラッパークラスを Python 上では pyo3_decimal.Decimal クラスとして使えるようにしています。

他ライブラリでの rust_decimal::Decimal の受取

あとは、他の PyO3 ライブラリ上で、 pyo3_decimal.DecimalPyTypeObject を適切に使えば良いだけです。

カプセル化した API の受取

unsafe fn PyDecimal_IMPORT() {
    let py_decimal_c_api = {
        let module = CString::new("pyo3_decimal").unwrap();
        let capsule = CString::new("_API").unwrap();
        unsafe {
            let module = ffi::PyImport_ImportModule(module.as_ptr());
            assert!(!module.is_null(), "Failed to import pyo3_decimal module");
            let capsule = ffi::PyObject_GetAttrString(module as _, capsule.as_ptr());
            assert!(
                !capsule.is_null(),
                "Failed to get pyo3_decimal.Decimal API capsule"
            );
            ffi::PyCapsule_GetPointer(capsule, PYO3_CAPSULE_API_NAME.as_ptr())
                as *mut PyO3Decimal_CAPI
        }
    };
    *PyDecimalAPI_impl.0.get() = py_decimal_c_api;
}

上記で紹介した PyCapsule_Import を使った方法では今回の pyo3_decimal の実装ではなぜか動かなかったため、rust-numpy の実装を参考に PyImport_ImportModulePyObject_GetAttrString を使う例を採用しています。

受取用の PyDecimal の実装

...
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct PyDecimal(pub Decimal);

unsafe impl pyo3::type_object::PyTypeInfo for PyDecimal {
    type AsRefTarget = pyo3::PyCell<Self>;
    const NAME: &'static str = "Decimal";
    const MODULE: ::std::option::Option<&'static str> = Some("pyo3_decimal");
    #[inline]
    fn type_object_raw(py: pyo3::Python<'_>) -> *mut pyo3::ffi::PyTypeObject {
        ensure_decimal_api(py).DecimalType
    }
}
...

関係のない pyo3 上のマクロのパースによって得られた結果を説明しているととても大変なので、具体的説明は割愛します。

実際に行ったことは、 cargo expand を使って #[pyclass(module = "pyo3_decimal", name = "Decimal")] のマクロのパースした結果を pyo3-decimal-api 側に貼り付けて、type_object_raw の関数の中身を API で公開した DecimalType を渡すようにするだけです。

この type_object_raw が適切に渡されていれば、この pyo3_decimal::PyDecimal は Python 上で確保した pyo3_decimal.Decimal として今後も内部的に処理されます。

use pyo3;
use pyo3::prelude::*;
use pyo3_decimal_api::PyDecimal;
use rust_decimal::Decimal;

#[pyfunction]
/// Formats the sum of two numbers as string
fn decimal_test(a: PyDecimal) -> PyResult<PyDecimal> {
    Ok(a)
}

#[pyfunction]
/// Formats the sum of two numbers as string
fn cast_decimal(a: &mut PyDecimal) -> PyResult<PyDecimal> {
    a.0 = a.0 + Decimal::new(1, 0);
    Ok(a.0.into())
}

/// This module is a python module implemented in Rust.
#[pymodule]
fn rust_binding(py: Python, m: &PyModule) -> PyResult<()> {
    m.add_wrapped(wrap_pyfunction!(decimal_test))?;

    Ok(())
}

ここまでくれば、使うのは簡単です。受取も受け渡しも問題なく動いていることを確認しました。

あとがき

コツとしては type_object_raw だけ カプセル化 API を使ってしっかり定義してやればいい感じに動いてくれます。慣れてしまえばかなり簡単に思えます。

rust_decimal::Decimal の構造体のメモリだけしっかり受け渡しされていれば、問題なく動いてくれるはずです。rust_decimal::Decimal バージョンの変更で構造体の構成が変わってしまった場合は問題ですが、その問題を避ける方法は難しくありません。

pyo3 的には pyo3 で確保したクラスを他の pyo3 のライブラリで使うことはできないと言っているのをフォーラムで見かけたので、今回の方法も推奨されるものではないかもしれません。

個人的に知って良かった点をまとめたつもりです。内容が結構複雑なので、ここですべてを説明しきるのは難しいですね。わかる人はコードを見てもらったほうが良いかもしれません。

個人的な近況

ブログも久しいですが、生きています。持病の腰痛が悪化し、なかなか厳しい日々を過ごしています。健康が第一と思わされますね。ただ、どうやら生まれつきのものだったようで、どう頑張っても避けられなかった事態なので受け入れるしかないです。最終手段の腰椎固定術を 11 月に受け、回復を祈るばかりなのですが、自分が想定したような回復計画とはいかず、2ヶ月立ってもまともに動けない、長時間座れないなどの状態です。術後の回復療養期間は3ヶ月なので、その時までに状態が回復していればなぁと期待するしかありません。

コメント

タイトルとURLをコピーしました