Numba のコンパイルが通らなかった時の対処

Numba はいいぞ

この記事は何

ふつうの Python なら動くけど Numba では動かないようなコードを列挙して、対処法を書いたもの
主に AtCoder 目的だけどそれ以外でも役に立つはず

Numba のバージョン 0.48.0 くらいの情報なので将来的にいろいろ変わってくると思うので注意(2020 年 8 月現在で AtCoder に入ってるのも 0.48.0)

先に読んでおくといいかもしれない記事

qiita.com

ikatakos.com

Numba で使えないもの

2 次元以上の ndarray のイテレーション

できない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(5, 2)  # 5x2 array
    for a in array:  # コンパイルエラー
        ...
エラーメッセージ
Direct iteration is not supported for arrays with dimension > 1. Try using indexing instead.
[1] During: typing of intrinsic-call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)

File "untitled.py", line 7:
def solve():
    <source elided>
    array = np.random.rand(5, 2)  # 5x2 array
    for a in array:  # コンパイルエラー
    ^
対処

range で回す

@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(5, 2)  # 5x2 array
    for i in range(len(array)):
        a = array[i]
        ...

変な(?)方法での空のリスト作成

Numba が型を推測できないとエラーが出る
list 以外に dictset でも起こる
結構ありがちでエラーメッセージもわかりにくかったりするので、とりあえず型を明示しておくのがいいかもしれない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    lst = [[] for _ in range(10)]  # コンパイルエラー
    lst[0].append(0)
エラーメッセージ
Undecided type $26load_method.11 := <undecided>
[1] During: resolving caller type: $26load_method.11
[2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)


File "untitled.py", line 7:
def solve():
    <source elided>
    lst = [[] for _ in range(10)]  # コンパイルエラー
    lst[0].append(0)
    ^
対処

どうにかして型を Numba に教える

@numba.njit("void()", cache=True)
def solve():
    lst = [[0] * 0 for _ in range(10)]  # 型を明示する
    lst[0].append(0)

dict の場合はいらない要素を入れておくとか
set の場合は {0}-{0} とかすれば Numba くんはわかってくれる

辞書の内包表記

dictset は内包表記で生成できない
あと dict はリストからの生成とかもできない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
    inv_fib = {v: i for i, v in enumerate(fib)}  # コンパイルエラー
エラーメッセージ
Use of unsupported opcode (MAP_ADD) found

File "untitled.py", line 7:
def solve():
    <source elided>
    fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
    inv_fib = {v: i for i, v in enumerate(fib)}  # コンパイルエラー
    ^
対処

ひとつずつ入れる

@numba.njit("void()", cache=True)
def solve():
    fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
    inv_fib = {}
    for i, v in enumerate(fib):
        inv_fib[v] = i

リストを値に取る辞書

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    dictionary = {3023: [0, 1, 2], 4006: [3, 4, 5]}  # コンパイルエラー
    d = dictionary[3023]
エラーメッセージ
list(int64) as value is forbidden
[1] During: typing of dict at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)

File "untitled.py", line 7:
def solve():
    dictionary = {3023: [0, 1, 2], 4006: [3, 4, 5]}  # コンパイルエラー
    ^
対処

numpy.ndarray でもいいならそれを使う
そうでないなら、値は別にリストか何かで持っておいて、辞書にはそのインデックスを入れる

@numba.njit("void()", cache=True)
def solve():
    dictionary = {3023: 0, 4006: 1}
    container = [[0, 1, 2], [3, 4, 5]]
    d = container[dictionary[3023]]

pow の第 3 引数

使えない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    mod = 10 ** 9 + 7
    inv10 = pow(10, mod-2, mod)  # コンパイルエラー
エラーメッセージ
Invalid use of Function(<built-in function pow>) with argument(s) of type(s): (Literal[int](10), int64, Literal[int](1000000007))
Known signatures:
 * (int64, int64) -> int64
 * (int64, uint64) -> int64
 * (uint64, int64) -> int64
 * (uint64, uint64) -> uint64
 * (float32, int32) -> float32
 * (float32, int64) -> float32
 * (float32, uint64) -> float32
 * (float64, int32) -> float64
 * (float64, int64) -> float64
 * (float64, uint64) -> float64
 * (float32, float32) -> float32
 * (float64, float64) -> float64
 * (complex64, complex64) -> complex64
 * (complex128, complex128) -> complex128
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function pow>)
[2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)


File "untitled.py", line 7:
def solve():
    <source elided>
    mod = 10 ** 9 + 7
    inv10 = pow(10, mod-2, mod)  # コンパイルエラー
    ^
対処

代わりのものを作っておく

@numba.njit("i8(i8,i8,i8)", cache=True)
def pow_mod(base, exp, mod):
    exp %= mod - 1
    res = 1
    while exp:
        if exp & 1:
            res = res * base % mod
        base = base * base % mod
        exp >>= 1
    return res

@numba.njit("void()", cache=True)
def solve():
    mod = 10 ** 9 + 7
    inv10 = pow_mod(10, mod-2, mod)

built-in の sum 関数

max は使えるのに sum は何故か使えない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    a = np.random.rand(5)
    s = sum(a)  # コンパイルエラー
エラーメッセージ
Untyped global name 'sum': cannot determine Numba type of <class 'builtin_function_or_method'>

File "untitled.py", line 7:
def solve():
    <source elided>
    a = np.random.rand(5)
    s = sum(a)  # コンパイルエラー
    ^
対処

numpy.sumnumpy.ndarray.sum を使う
リストの場合は numpy.sum でもエラーになるのでそのときは numpy.ndarray に変換するとかひとつずつ足すとかする

@numba.njit("void()", cache=True)
def solve():
    a = np.random.rand(5)
    s = np.sum(a)  # s = a.sum() でもいい

numpy.max とか numpy.ndarray.max とかの axis

numpy.sumaxis が使えるのに numpy.maxaxis が使えない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(4, 5)
    m = array.max(1)
エラーメッセージ
[1] During: resolving callee type: BoundFunction(array.max for array(float64, 2d, C))
[2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)

Enable logging at debug level for details.

File "untitled.py", line 7:
def solve():
    <source elided>
    array = np.random.rand(4, 5)
    m = array.max(1)
    ^
対処

for を回す

@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(4, 5)
    m = np.empty(4, dtype=array.dtype)
    for i in range(4):
        m[i] = array[i].max()

2 次元以上の ndarray の boolean indexing

できない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(4, 5)
    array[array < 0.5] = 0  # コンパイルエラー
エラーメッセージ
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), Literal[int](0))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
    raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 9:
    TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
    raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of setitem at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)

File "untitled.py", line 7:
def solve():
    <source elided>
    array = np.random.rand(4, 5)
    array[array < 0.5] = 0
    ^
対処

numpy.where を使う

@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(4, 5)
    array = np.where(array < 0.5, 0, array)

ndarray の None による次元の追加

できない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    a = np.random.rand(4, 5)
    a = a[:, None, :]  # コンパイルエラー
    assert a.shape == (4, 1, 5)
エラーメッセージ
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), Tuple(slice<a:b>, none, slice<a:b>))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
In definition 10:
    All templates rejected with literals.
In definition 11:
    All templates rejected without literals.
In definition 12:
    TypeError: unsupported array index type none in Tuple(slice<a:b>, none, slice<a:b>)
    raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 13:
    TypeError: unsupported array index type none in Tuple(slice<a:b>, none, slice<a:b>)
    raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
[2] During: typing of static-get-item at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)

File "untitled.py", line 7:
def solve():
    <source elided>
    a = np.random.rand(4, 5)
    a = a[:, None, :]  # コンパイルエラー
    ^
対処

reshapeexpand_dims を使う
ただし expand_dims の第 2 引数に tuple は使えない

@numba.njit("void()", cache=True)
def solve():
    a = np.random.rand(4, 5)
    a = np.expand_dims(a, 1)
    assert a.shape == (4, 1, 5)

int.bit_length

使えない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    b = (998244353).bit_length()
エラーメッセージ
Unknown attribute 'bit_length' of type Literal[int](998244353)

File "untitled.py", line 6:
def solve():
    b = (998244353).bit_length()
    ^

[1] During: typing of get attribute at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (6)

File "untitled.py", line 6:
def solve():
    b = (998244353).bit_length()
    ^
対処

np.log2 とかでなんとかする(249-1 以上は誤差に注意、あと 0 も正しく動かない)

@numba.njit("void()", cache=True)
def solve():
    b = int(np.log2(998244353)) + 1
    print(b)

collections

ほぼ使えない

対処

defaultdict -> 値があるか自力で確認する
deque -> リングバッファみたいなのを適当に実装する
Counter -> 自力で数える

itertools

使えない

対処

itertools を使う部分はコンパイルしないように切り分けるか、あらかじめ代わりになりそうなものを用意しておく?

@numba.jit("i8[:,:](i8[:],i8)", cache=True)
def combinaions(arr, r):
    n = len(arr)
    assert 0 <= r <= n
    res_length = 1
    for i in range(r):
        res_length = res_length * (n-i) // (1+i)
    res = np.empty((res_length, r), dtype=arr.dtype)
    idxs_arr = np.arange(r)
    for idx_res in range(res_length):
        res[idx_res] = arr[idxs_arr]
        i = 1
        while idxs_arr[r-i] == n-i:
            i += 1
        idxs_arr[r-i] += 1
        for j in range(r-i+1, r):
            idxs_arr[j] = idxs_arr[j-1] + 1
    return res

↑の実装はジェネレータじゃないので少し探索してやめるような場合には効率が悪くなってしまう
これを嫌うなら C++next_permutation みたいなのを用意しておくと汎用性も高くて良さそう

string が返る関数

str とか bin とか format とか "%d" % 42 とかは使えない

エラーになるコード
@numba.njit("void()", cache=True)
def solve():
    popcnt = bin(4047).count("1")  # コンパイルエラー
エラーメッセージ
Untyped global name 'bin': cannot determine Numba type of <class 'builtin_function_or_method'>

File "untitled.py", line 6:
def solve():
    popcnt = bin(4047).count("1")  # コンパイルエラー
    ^
対処

Numba で文字列を扱おうとしない

popcount についてはあらかじめ用意しておく(参考: Python 3でpopcountを計算する - にせねこメモ

@numba.njit("u8(u8)", cache=True)
def popcount(n):
    n = (n & 0x5555555555555555) + (n>>1 & 0x5555555555555555)
    n = (n & 0x3333333333333333) + (n>>2 & 0x3333333333333333)
    n = (n & 0x0f0f0f0f0f0f0f0f) + (n>>4 & 0x0f0f0f0f0f0f0f0f)
    n = (n & 0x00ff00ff00ff00ff) + (n>>8 & 0x00ff00ff00ff00ff)
    n = (n & 0x0000ffff0000ffff) + (n>>16 & 0x0000ffff0000ffff)
    n = (n & 0x00000000ffffffff) + (n>>32 & 0x00000000ffffffff)
    return n

@numba.njit("void()", cache=True)
def solve():
    popcnt = popcount(4047)

関数外の変数の書き換え

array = np.array([1, 2, 3])

@numba.njit("void()", cache=True)
def solve():
    array[0] = 4  # コンパイルエラー
エラーメッセージ
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (readonly array(int32, 1d, C), Literal[int](0), Literal[int](4))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    TypeError: Cannot modify value of type readonly array(int32, 1d, C)
    raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:179
In definition 9:
    TypeError: Cannot modify value of type readonly array(int32, 1d, C)
    raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:179
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (8)

File "untitled.py", line 8:
def solve():
    array[0] = 4  # コンパイルエラー
    ^
対処

引数で渡す

@numba.njit("void(i4[:])", cache=True)
def solve(array):
    array[0] = 4

標準入力

できないので諦める

関数を返す関数

できないので諦める

関数内の関数の再帰

できないので諦める

他色々

できないと思って諦める

まとめ

Numba はいいぞ