Numba のコンパイルが通らなかった時の対処
Numba はいいぞ
この記事は何
ふつうの Python なら動くけど Numba では動かないようなコードを列挙して、対処法を書いたもの
主に AtCoder 目的だけどそれ以外でも役に立つはず
Numba のバージョン 0.48.0 くらいの情報なので将来的にいろいろ変わってくると思うので注意(2020 年 8 月現在で AtCoder に入ってるのも 0.48.0)
先に読んでおくといいかもしれない記事
Numba で使えないもの
2 次元以上の ndarray のイテレーション
できない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): array = np.random.rand(5, 2) # 5x2 array for a in array: # コンパイルエラー ...
エラーメッセージ
Direct iteration is not supported for arrays with dimension > 1. Try using indexing instead. [1] During: typing of intrinsic-call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7) File "untitled.py", line 7: def solve(): <source elided> array = np.random.rand(5, 2) # 5x2 array for a in array: # コンパイルエラー ^
対処
range
で回す
@numba.njit("void()", cache=True) def solve(): array = np.random.rand(5, 2) # 5x2 array for i in range(len(array)): a = array[i] ...
変な(?)方法での空のリスト作成
Numba が型を推測できないとエラーが出る
list
以外に dict
と set
でも起こる
結構ありがちでエラーメッセージもわかりにくかったりするので、とりあえず型を明示しておくのがいいかもしれない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): lst = [[] for _ in range(10)] # コンパイルエラー lst[0].append(0)
エラーメッセージ
Undecided type $26load_method.11 := <undecided> [1] During: resolving caller type: $26load_method.11 [2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7) File "untitled.py", line 7: def solve(): <source elided> lst = [[] for _ in range(10)] # コンパイルエラー lst[0].append(0) ^
対処
どうにかして型を Numba に教える
@numba.njit("void()", cache=True) def solve(): lst = [[0] * 0 for _ in range(10)] # 型を明示する lst[0].append(0)
dict
の場合はいらない要素を入れておくとか
set
の場合は {0}-{0}
とかすれば Numba くんはわかってくれる
辞書の内包表記
dict
と set
は内包表記で生成できない
あと dict
はリストからの生成とかもできない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89] inv_fib = {v: i for i, v in enumerate(fib)} # コンパイルエラー
エラーメッセージ
Use of unsupported opcode (MAP_ADD) found File "untitled.py", line 7: def solve(): <source elided> fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89] inv_fib = {v: i for i, v in enumerate(fib)} # コンパイルエラー ^
対処
ひとつずつ入れる
@numba.njit("void()", cache=True) def solve(): fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89] inv_fib = {} for i, v in enumerate(fib): inv_fib[v] = i
リストを値に取る辞書
エラーになるコード
@numba.njit("void()", cache=True) def solve(): dictionary = {3023: [0, 1, 2], 4006: [3, 4, 5]} # コンパイルエラー d = dictionary[3023]
エラーメッセージ
list(int64) as value is forbidden [1] During: typing of dict at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7) File "untitled.py", line 7: def solve(): dictionary = {3023: [0, 1, 2], 4006: [3, 4, 5]} # コンパイルエラー ^
対処
numpy.ndarray
でもいいならそれを使う
そうでないなら、値は別にリストか何かで持っておいて、辞書にはそのインデックスを入れる
@numba.njit("void()", cache=True) def solve(): dictionary = {3023: 0, 4006: 1} container = [[0, 1, 2], [3, 4, 5]] d = container[dictionary[3023]]
pow の第 3 引数
使えない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): mod = 10 ** 9 + 7 inv10 = pow(10, mod-2, mod) # コンパイルエラー
エラーメッセージ
Invalid use of Function(<built-in function pow>) with argument(s) of type(s): (Literal[int](10), int64, Literal[int](1000000007)) Known signatures: * (int64, int64) -> int64 * (int64, uint64) -> int64 * (uint64, int64) -> int64 * (uint64, uint64) -> uint64 * (float32, int32) -> float32 * (float32, int64) -> float32 * (float32, uint64) -> float32 * (float64, int32) -> float64 * (float64, int64) -> float64 * (float64, uint64) -> float64 * (float32, float32) -> float32 * (float64, float64) -> float64 * (complex64, complex64) -> complex64 * (complex128, complex128) -> complex128 In definition 0: All templates rejected with literals. In definition 1: All templates rejected without literals. This error is usually caused by passing an argument of a type that is unsupported by the named function. [1] During: resolving callee type: Function(<built-in function pow>) [2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7) File "untitled.py", line 7: def solve(): <source elided> mod = 10 ** 9 + 7 inv10 = pow(10, mod-2, mod) # コンパイルエラー ^
対処
代わりのものを作っておく
@numba.njit("i8(i8,i8,i8)", cache=True) def pow_mod(base, exp, mod): exp %= mod - 1 res = 1 while exp: if exp & 1: res = res * base % mod base = base * base % mod exp >>= 1 return res @numba.njit("void()", cache=True) def solve(): mod = 10 ** 9 + 7 inv10 = pow_mod(10, mod-2, mod)
built-in の sum 関数
max
は使えるのに sum
は何故か使えない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): a = np.random.rand(5) s = sum(a) # コンパイルエラー
エラーメッセージ
Untyped global name 'sum': cannot determine Numba type of <class 'builtin_function_or_method'> File "untitled.py", line 7: def solve(): <source elided> a = np.random.rand(5) s = sum(a) # コンパイルエラー ^
対処
numpy.sum
か numpy.ndarray.sum
を使う
リストの場合は numpy.sum
でもエラーになるのでそのときは numpy.ndarray
に変換するとかひとつずつ足すとかする
@numba.njit("void()", cache=True) def solve(): a = np.random.rand(5) s = np.sum(a) # s = a.sum() でもいい
numpy.max とか numpy.ndarray.max とかの axis
numpy.sum
は axis
が使えるのに numpy.max
は axis
が使えない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): array = np.random.rand(4, 5) m = array.max(1)
エラーメッセージ
[1] During: resolving callee type: BoundFunction(array.max for array(float64, 2d, C)) [2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7) Enable logging at debug level for details. File "untitled.py", line 7: def solve(): <source elided> array = np.random.rand(4, 5) m = array.max(1) ^
対処
for を回す
@numba.njit("void()", cache=True) def solve(): array = np.random.rand(4, 5) m = np.empty(4, dtype=array.dtype) for i in range(4): m[i] = array[i].max()
2 次元以上の ndarray の boolean indexing
できない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): array = np.random.rand(4, 5) array[array < 0.5] = 0 # コンパイルエラー
エラーメッセージ
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), Literal[int](0)) * parameterized In definition 0: All templates rejected with literals. In definition 1: All templates rejected without literals. In definition 2: All templates rejected with literals. In definition 3: All templates rejected without literals. In definition 4: All templates rejected with literals. In definition 5: All templates rejected without literals. In definition 6: All templates rejected with literals. In definition 7: All templates rejected without literals. In definition 8: TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)] raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71 In definition 9: TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)] raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71 This error is usually caused by passing an argument of a type that is unsupported by the named function. [1] During: typing of setitem at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7) File "untitled.py", line 7: def solve(): <source elided> array = np.random.rand(4, 5) array[array < 0.5] = 0 ^
対処
numpy.where
を使う
@numba.njit("void()", cache=True) def solve(): array = np.random.rand(4, 5) array = np.where(array < 0.5, 0, array)
ndarray の None による次元の追加
できない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): a = np.random.rand(4, 5) a = a[:, None, :] # コンパイルエラー assert a.shape == (4, 1, 5)
エラーメッセージ
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), Tuple(slice<a:b>, none, slice<a:b>)) * parameterized In definition 0: All templates rejected with literals. In definition 1: All templates rejected without literals. In definition 2: All templates rejected with literals. In definition 3: All templates rejected without literals. In definition 4: All templates rejected with literals. In definition 5: All templates rejected without literals. In definition 6: All templates rejected with literals. In definition 7: All templates rejected without literals. In definition 8: All templates rejected with literals. In definition 9: All templates rejected without literals. In definition 10: All templates rejected with literals. In definition 11: All templates rejected without literals. In definition 12: TypeError: unsupported array index type none in Tuple(slice<a:b>, none, slice<a:b>) raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71 In definition 13: TypeError: unsupported array index type none in Tuple(slice<a:b>, none, slice<a:b>) raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71 This error is usually caused by passing an argument of a type that is unsupported by the named function. [1] During: typing of intrinsic-call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7) [2] During: typing of static-get-item at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7) File "untitled.py", line 7: def solve(): <source elided> a = np.random.rand(4, 5) a = a[:, None, :] # コンパイルエラー ^
対処
reshape
か expand_dims
を使う
ただし expand_dims
の第 2 引数に tuple
は使えない
@numba.njit("void()", cache=True) def solve(): a = np.random.rand(4, 5) a = np.expand_dims(a, 1) assert a.shape == (4, 1, 5)
int.bit_length
使えない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): b = (998244353).bit_length()
エラーメッセージ
Unknown attribute 'bit_length' of type Literal[int](998244353) File "untitled.py", line 6: def solve(): b = (998244353).bit_length() ^ [1] During: typing of get attribute at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (6) File "untitled.py", line 6: def solve(): b = (998244353).bit_length() ^
対処
np.log2
とかでなんとかする(249-1 以上は誤差に注意、あと 0 も正しく動かない)
@numba.njit("void()", cache=True) def solve(): b = int(np.log2(998244353)) + 1 print(b)
collections
ほぼ使えない
対処
defaultdict
-> 値があるか自力で確認する
deque
-> リングバッファみたいなのを適当に実装する
Counter
-> 自力で数える
itertools
使えない
対処
itertools
を使う部分はコンパイルしないように切り分けるか、あらかじめ代わりになりそうなものを用意しておく?
@numba.jit("i8[:,:](i8[:],i8)", cache=True) def combinaions(arr, r): n = len(arr) assert 0 <= r <= n res_length = 1 for i in range(r): res_length = res_length * (n-i) // (1+i) res = np.empty((res_length, r), dtype=arr.dtype) idxs_arr = np.arange(r) for idx_res in range(res_length): res[idx_res] = arr[idxs_arr] i = 1 while idxs_arr[r-i] == n-i: i += 1 idxs_arr[r-i] += 1 for j in range(r-i+1, r): idxs_arr[j] = idxs_arr[j-1] + 1 return res
↑の実装はジェネレータじゃないので少し探索してやめるような場合には効率が悪くなってしまう
これを嫌うなら C++ の next_permutation
みたいなのを用意しておくと汎用性も高くて良さそう
string が返る関数
str
とか bin
とか format
とか "%d" % 42
とかは使えない
エラーになるコード
@numba.njit("void()", cache=True) def solve(): popcnt = bin(4047).count("1") # コンパイルエラー
エラーメッセージ
Untyped global name 'bin': cannot determine Numba type of <class 'builtin_function_or_method'> File "untitled.py", line 6: def solve(): popcnt = bin(4047).count("1") # コンパイルエラー ^
対処
Numba で文字列を扱おうとしない
popcount についてはあらかじめ用意しておく(参考: Python 3でpopcountを計算する - にせねこメモ)
@numba.njit("u8(u8)", cache=True) def popcount(n): n = (n & 0x5555555555555555) + (n>>1 & 0x5555555555555555) n = (n & 0x3333333333333333) + (n>>2 & 0x3333333333333333) n = (n & 0x0f0f0f0f0f0f0f0f) + (n>>4 & 0x0f0f0f0f0f0f0f0f) n = (n & 0x00ff00ff00ff00ff) + (n>>8 & 0x00ff00ff00ff00ff) n = (n & 0x0000ffff0000ffff) + (n>>16 & 0x0000ffff0000ffff) n = (n & 0x00000000ffffffff) + (n>>32 & 0x00000000ffffffff) return n @numba.njit("void()", cache=True) def solve(): popcnt = popcount(4047)
関数外の変数の書き換え
array = np.array([1, 2, 3]) @numba.njit("void()", cache=True) def solve(): array[0] = 4 # コンパイルエラー
エラーメッセージ
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (readonly array(int32, 1d, C), Literal[int](0), Literal[int](4)) * parameterized In definition 0: All templates rejected with literals. In definition 1: All templates rejected without literals. In definition 2: All templates rejected with literals. In definition 3: All templates rejected without literals. In definition 4: All templates rejected with literals. In definition 5: All templates rejected without literals. In definition 6: All templates rejected with literals. In definition 7: All templates rejected without literals. In definition 8: TypeError: Cannot modify value of type readonly array(int32, 1d, C) raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:179 In definition 9: TypeError: Cannot modify value of type readonly array(int32, 1d, C) raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:179 This error is usually caused by passing an argument of a type that is unsupported by the named function. [1] During: typing of staticsetitem at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (8) File "untitled.py", line 8: def solve(): array[0] = 4 # コンパイルエラー ^
対処
引数で渡す
@numba.njit("void(i4[:])", cache=True) def solve(array): array[0] = 4
標準入力
できないので諦める
関数を返す関数
できないので諦める
関数内の関数の再帰
できないので諦める
他色々
できないと思って諦める
まとめ
Numba はいいぞ