PythonでAVL木を競プロ用に実装した(誰か高速化してくれ)

(19/11/22 追記)

このAVL木どうも遅いっぽい

順序付き集合はPythonだと平方分割が速いことが熨斗袋さんによって判明したのでそれを使うといいと思う

熨斗袋さんによる実装
Submission #7482671 - AtCoder Beginner Contest 140

自分の実装(SkipListベース)
Submission #7488620 - AtCoder Regular Contest 033


Pythonで競プロをやるとAtCoderのレートが1970上がることが知られている(※個人の感想であり、効果・効能には個人差がある)が、Pythonには平衡二分探索木がなくて悲しい。なのでAVL木を実装した。

writerの助言もあり、なんとかPyPyでCPSCO2019 Session1 E - Exclusive OR Queriesを通せたが、かなり制限時間ぎりぎり。自分は諦めたので誰かこれをもっと高速にしてほしい(他力本願寺建立)

atcoder.jp

コピペだらけのコードだけど実行速度優先だから多少はね?

class AvlTree:  # std::set
    """Ordered set of distinct values stored in an array-backed AVL tree.

    Nodes are identified by integer indices into parallel lists:

    * ``values[i]`` -- key stored in node ``i``
    * ``left[i]`` / ``right[i]`` -- child indices, ``-1`` when absent
    * ``diff[i]`` -- balance factor: height(left) - height(right)
    * ``size_l[i]`` -- number of nodes in the left subtree of ``i``

    Node ``0`` is a sentinel holding ``-inf``; the real root is ``right[0]``.
    Because every key is compared against that float sentinel, keys must be
    comparable with ``float`` (ints / floats), and ``-inf`` itself cannot be
    stored.  Node slots are never reused after ``remove`` / ``pop``;
    ``idx_new_val`` is the index of the most recently allocated node.
    """

    def __init__(self, values=None, sorted_=False, n=0):
        """Build the set.

        Args:
            values: Optional list of initial keys.  When ``sorted_`` is false
                the list is sorted **in place** (the caller's list is
                modified).  Duplicate keys are dropped.
            sorted_: Pass True when ``values`` is already in ascending order;
                the tree is then built in time linear in ``len(values)``.
            n: Expected number of ``add`` calls.  Used to pre-size the node
                arrays; ``add`` grows them on demand if this is exceeded.
        """
        if values is None:
            unique = []
        else:
            if not sorted_:
                values.sort()
            # A set holds each key once.  Equal neighbours would otherwise be
            # linked inconsistently with size_l / diff and corrupt the tree.
            unique = []
            for v in values:
                if not unique or unique[-1] < v:
                    unique.append(v)

        len_ = self.idx_new_val = len(unique)
        n += len_
        self_left = self.left = [-1] * (n + 1)
        self_right = self.right = [-1] * (n + 1)
        self.values = [-float("inf")] + unique
        self_diff = self.diff = [0] * (n + 1)  # left - right
        self_size_l = self.size_l = [0] * (n + 1)

        if len_ == 0:
            # Nothing to build (also covers values == []).
            return

        # Iterative DFS: each entry is a half-open index range [l, r) of
        # `values` plus the node the range's middle element hangs from.
        st = [(1, len_ + 1, 0)]
        while st:
            l, r, idx_par = st.pop()
            c = (l + r) >> 1  # python -> //2  pypy -> >>1
            # Keys are strictly increasing by index, so position alone tells
            # which side of the parent this node belongs to.  The sentinel
            # (index 0) always receives the root as its right child.
            if c < idx_par:
                self_left[idx_par] = c
            else:
                self_right[idx_par] = c
            siz = r - l
            # With a power-of-two size (other than 1) the left half is one
            # level taller than the right half; otherwise heights are equal.
            if siz & -siz == siz != 1:
                self_diff[c] = 1
            self_size_l[c] = siz_l = c - l
            if siz_l > 0:
                st.append((l, c, c))
                c1 = c + 1
                # An empty left range implies an empty right range, so the
                # right side only needs checking in here.
                if c1 < r:
                    st.append((c1, r, c))

    def rotate_right(self, idx_par, lr):  # lr: 0 if left child of parent
        """Rebalance the left-heavy child of ``idx_par`` (balance factor +2).

        Args:
            idx_par: Index of the parent of the subtree being rotated.
            lr: Truthy if the subtree is the right child of ``idx_par``,
                falsy if it is the left child.

        Returns:
            The balance factor of the subtree's new root (0 after a double
            rotation).  A non-zero result means the subtree height did not
            change.
        """
        self_left = self.left
        self_right = self.right
        self_diff = self.diff
        self_size_l = self.size_l

        lr_container = self_right if lr else self_left
        idx = lr_container[idx_par]
        # assert self_diff[idx] == 2
        idx_l = self_left[idx]
        diff_l = self_diff[idx_l]

        if diff_l == -1:  # double rotation (left-right case)
            idx_lr = self_right[idx_l]
            diff_lr = self_diff[idx_lr]
            if diff_lr == 0:
                self_diff[idx] = 0
                self_diff[idx_l] = 0
            elif diff_lr == 1:
                self_diff[idx] = -1
                self_diff[idx_l] = 0
                self_diff[idx_lr] = 0
            else:  # diff_lr == -1
                self_diff[idx] = 0
                self_diff[idx_l] = 1
                self_diff[idx_lr] = 0

            # Left-subtree sizes: idx_lr gains idx_l and its left subtree;
            # idx keeps only what was idx_lr's right subtree.
            self_size_l[idx_lr] += self_size_l[idx_l] + 1
            self_size_l[idx] -= self_size_l[idx_lr] + 1

            # Relink: idx_lr becomes the subtree root.
            self_right[idx_l] = self_left[idx_lr]
            self_left[idx] = self_right[idx_lr]
            self_left[idx_lr] = idx_l
            self_right[idx_lr] = idx
            lr_container[idx_par] = idx_lr

            return 0

        else:  # single rotation
            if diff_l == 0:
                self_diff[idx] = 1
                nb = self_diff[idx_l] = -1
            else:  # diff_l == 1
                self_diff[idx] = 0
                nb = self_diff[idx_l] = 0

            # idx loses idx_l and idx_l's left subtree from its left side.
            self_size_l[idx] -= self_size_l[idx_l] + 1

            # Relink: idx_l becomes the subtree root.
            self_left[idx] = self_right[idx_l]
            self_right[idx_l] = idx
            lr_container[idx_par] = idx_l

            return nb  # balance factor of the new root

    def rotate_left(self, idx_par, lr):  # lr: 0 if left child of parent
        """Rebalance the right-heavy child of ``idx_par`` (balance factor -2).

        Mirror image of ``rotate_right``.

        Args:
            idx_par: Index of the parent of the subtree being rotated.
            lr: Truthy if the subtree is the right child of ``idx_par``,
                falsy if it is the left child.

        Returns:
            The balance factor of the subtree's new root (0 after a double
            rotation).  A non-zero result means the subtree height did not
            change.
        """
        self_left = self.left
        self_right = self.right
        self_diff = self.diff
        self_size_l = self.size_l

        lr_container = self_right if lr else self_left
        idx = lr_container[idx_par]
        # assert self_diff[idx] == -2
        idx_r = self_right[idx]
        diff_r = self_diff[idx_r]

        if diff_r == 1:  # double rotation (right-left case)
            idx_rl = self_left[idx_r]
            diff_rl = self_diff[idx_rl]
            if diff_rl == 0:
                self_diff[idx] = 0
                self_diff[idx_r] = 0
            elif diff_rl == -1:
                self_diff[idx] = 1
                self_diff[idx_r] = 0
                self_diff[idx_rl] = 0
            else:  # diff_rl == 1
                self_diff[idx] = 0
                self_diff[idx_r] = -1
                self_diff[idx_rl] = 0

            # Left-subtree sizes: idx_r keeps only idx_rl's right subtree on
            # its left; idx_rl gains idx and idx's left subtree.
            self_size_l[idx_r] -= self_size_l[idx_rl] + 1
            self_size_l[idx_rl] += self_size_l[idx] + 1

            # Relink: idx_rl becomes the subtree root.
            self_left[idx_r] = self_right[idx_rl]
            self_right[idx] = self_left[idx_rl]
            self_right[idx_rl] = idx_r
            self_left[idx_rl] = idx
            lr_container[idx_par] = idx_rl

            return 0

        else:  # single rotation
            if diff_r == 0:
                self_diff[idx] = -1
                nb = self_diff[idx_r] = 1
            else:  # diff_r == -1
                self_diff[idx] = 0
                nb = self_diff[idx_r] = 0

            # idx_r gains idx and idx's left subtree on its left side.
            self_size_l[idx_r] += self_size_l[idx] + 1

            # Relink: idx_r becomes the subtree root.
            self_right[idx] = self_left[idx_r]
            self_left[idx_r] = idx
            lr_container[idx_par] = idx_r

            return nb  # balance factor of the new root

    def add(self, x):  # insert
        """Insert ``x``.

        Returns:
            False if ``x`` is already present (the set is unchanged),
            True if it was inserted.
        """
        idx = 0
        path = []
        path_left = []

        self_values = self.values
        self_left = self.left
        self_right = self.right

        while idx != -1:
            path.append(idx)
            value = self_values[idx]
            if x < value:
                # size_l is only bumped once we know x is new, so just
                # remember where we turned left.
                path_left.append(idx)
                idx = self_left[idx]
            elif value < x:
                idx = self_right[idx]
            else:  # x == value
                return False  # duplicates are not allowed

        self.idx_new_val += 1
        idx_new = self.idx_new_val
        self_diff = self.diff
        self_size_l = self.size_l

        # Grow the node arrays (doubling) when the pre-declared capacity is
        # exhausted.  extend() mutates in place, so local aliases stay valid.
        if idx_new >= len(self_left):
            grow = len(self_left)
            self_left.extend([-1] * grow)
            self_right.extend([-1] * grow)
            self_diff.extend([0] * grow)
            self_size_l.extend([0] * grow)

        idx = path[-1]
        if x < value:
            self_left[idx] = idx_new
        else:
            self_right[idx] = idx_new

        self_values.append(x)

        for idx_ in path_left:
            self_size_l[idx_] += 1

        # Walk back up, updating balance factors until the height change is
        # absorbed or fixed by a single rotation step.
        self_diff[idx] += 1 if x < value else -1
        for idx_par in path[-2::-1]:
            diff = self_diff[idx]
            if diff == 0:
                return True
            elif diff == 2:  # rotate right
                self.rotate_right(idx_par, self_right[idx_par] == idx)
                return True
            elif diff == -2:  # rotate left
                self.rotate_left(idx_par, self_right[idx_par] == idx)
                return True
            else:
                self_diff[idx_par] += 1 if self_left[idx_par] == idx else -1
            idx = idx_par
        return True

    def remove(self, x):  # erase
        """Remove ``x`` from the set.

        Returns:
            True if ``x`` was present and has been removed, False if it was
            absent (the set is left untouched).
        """
        idx = 0
        path = []
        path_left = []
        idx_x = -1

        self_values = self.values
        self_left = self.left
        self_right = self.right
        self_diff = self.diff
        self_size_l = self.size_l

        # Descend to x, then keep going to its in-order predecessor (the
        # rightmost node of x's left subtree), which ends up last in `path`.
        while idx != -1:
            path.append(idx)
            value = self_values[idx]
            if value < x:
                idx = self_right[idx]
            elif x < value:
                path_left.append(idx)
                idx = self_left[idx]
            else:  # x == value
                idx_x = idx
                path_left.append(idx)
                idx = self_left[idx]

        # Not found.  idx_x == 0 means x compared equal to the sentinel
        # (e.g. -inf), which is never a stored key.
        if idx_x <= 0:
            return False

        # The removed node is in the left subtree of every node where we
        # turned left.  For x itself that is its predecessor; if x has no
        # left child, x's slot is discarded and its size_l no longer matters.
        for idx_ in path_left:
            self_size_l[idx_] -= 1

        idx_last_par, idx_last = path[-2:]

        if idx_last == idx_x:  # x has no left child
            # Splice x out by linking its parent to x's right child.
            if self_left[idx_last_par] == idx_x:
                self_left[idx_last_par] = self_right[idx_x]
                self_diff[idx_last_par] -= 1
            else:
                self_right[idx_last_par] = self_right[idx_x]
                self_diff[idx_last_par] += 1
        else:  # x has a left child
            # Overwrite x's key with its predecessor's and unlink that node.
            self_values[idx_x] = self_values[idx_last]
            if idx_last_par == idx_x:  # predecessor is x's left child
                self_left[idx_last_par] = self_left[idx_last]
                self_diff[idx_last_par] -= 1
            else:  # predecessor is further down the right spine
                self_right[idx_last_par] = self_left[idx_last]
                self_diff[idx_last_par] += 1

        # Rebalance upwards while the subtree height keeps shrinking.
        self_rotate_left = self.rotate_left
        self_rotate_right = self.rotate_right
        diff = self_diff[idx_last_par]
        idx = idx_last_par
        for idx_par in path[-3::-1]:
            # assert diff == self_diff[idx]
            lr = self_right[idx_par] == idx
            if diff == 0:
                pass  # subtree got shorter: keep propagating
            elif diff == 2:  # rotate right
                diff_ = self_rotate_right(idx_par, lr)
                if diff_ != 0:
                    return True
            elif diff == -2:  # rotate left
                diff_ = self_rotate_left(idx_par, lr)
                if diff_ != 0:
                    return True
            else:
                return True  # height unchanged: done
            diff = self_diff[idx_par] = self_diff[idx_par] + (1 if lr else -1)
            idx = idx_par
        return True

    def pop(self, idx_):
        """Remove and return the ``idx_``-th smallest key (0-indexed).

        Raises:
            IndexError: If there is no key at position ``idx_``; the set is
                left untouched.
        """
        path = [0]
        path_left = []
        idx_x = -1
        x = None

        self_values = self.values
        self_left = self.left
        self_right = self.right
        self_diff = self.diff
        self_size_l = self.size_l

        # Descend by rank to the target, then on to its in-order predecessor
        # (every node in the target's left subtree has a smaller rank, so the
        # walk keeps turning right from there).
        sum_left = 0
        idx = self_right[0]
        while idx != -1:
            path.append(idx)
            c = sum_left + self_size_l[idx]
            if idx_ < c:
                path_left.append(idx)
                idx = self_left[idx]
            elif c < idx_:
                idx = self_right[idx]
                sum_left = c + 1
            else:
                idx_x = idx
                x = self_values[idx]
                # The predecessor is about to be removed from this node's
                # left subtree, so its size_l shrinks as well.
                path_left.append(idx)
                idx = self_left[idx]

        if idx_x == -1:
            raise IndexError

        for idx_left in path_left:
            self_size_l[idx_left] -= 1

        idx_last_par, idx_last = path[-2:]

        if idx_last == idx_x:  # target has no left child
            # Splice it out by linking its parent to its right child.
            if self_left[idx_last_par] == idx_x:
                self_left[idx_last_par] = self_right[idx_x]
                self_diff[idx_last_par] -= 1
            else:
                self_right[idx_last_par] = self_right[idx_x]
                self_diff[idx_last_par] += 1
        else:  # target has a left child
            # Overwrite its key with the predecessor's and unlink that node.
            self_values[idx_x] = self_values[idx_last]
            if idx_last_par == idx_x:  # predecessor is the left child
                self_left[idx_last_par] = self_left[idx_last]
                self_diff[idx_last_par] -= 1
            else:  # predecessor is further down the right spine
                self_right[idx_last_par] = self_left[idx_last]
                self_diff[idx_last_par] += 1

        # Rebalance upwards while the subtree height keeps shrinking.
        self_rotate_left = self.rotate_left
        self_rotate_right = self.rotate_right
        diff = self_diff[idx_last_par]
        idx = idx_last_par
        for idx_par in path[-3::-1]:
            # assert diff == self_diff[idx]
            lr = self_right[idx_par] == idx
            if diff == 0:
                pass  # subtree got shorter: keep propagating
            elif diff == 2:  # rotate right
                diff_ = self_rotate_right(idx_par, lr)
                if diff_ != 0:
                    return x
            elif diff == -2:  # rotate left
                diff_ = self_rotate_left(idx_par, lr)
                if diff_ != 0:
                    return x
            else:
                return x  # height unchanged: done
            diff = self_diff[idx_par] = self_diff[idx_par] + (1 if lr else -1)
            idx = idx_par
        return x

    def __getitem__(self, idx_):
        """Return the ``idx_``-th smallest key (0-indexed).

        Raises:
            IndexError: If no key has that rank (negative indices included).
        """
        self_left = self.left
        self_right = self.right
        self_size_l = self.size_l

        sum_left = 0
        idx = self_right[0]
        while idx != -1:
            c = sum_left + self_size_l[idx]
            if idx_ < c:
                idx = self_left[idx]
            elif c < idx_:
                idx = self_right[idx]
                sum_left = c + 1
            else:  # c == idx_
                return self.values[idx]
        raise IndexError

    def __contains__(self, x):  # count
        """Return True if ``x`` is in the set."""
        self_left = self.left
        self_right = self.right
        self_values = self.values

        idx = self_right[0]
        while idx != -1:
            value = self_values[idx]
            if value < x:
                idx = self_right[idx]
            elif x < value:
                idx = self_left[idx]
            else:
                return True
        return False

    def bisect_left(self, x):  # lower_bound
        """Return the number of keys strictly smaller than ``x``."""
        self_left = self.left
        self_right = self.right
        self_values = self.values
        self_size_l = self.size_l

        idx = self_right[0]
        res = 0
        while idx != -1:
            value = self_values[idx]
            if value < x:
                # This node and its whole left subtree are below x.
                res += self_size_l[idx] + 1
                idx = self_right[idx]
            elif x < value:
                idx = self_left[idx]
            else:  # value == x
                return res + self_size_l[idx]
        return res

    def bisect_right(self, x):  # upper_bound
        """Return the number of keys smaller than or equal to ``x``."""
        self_left = self.left
        self_right = self.right
        self_values = self.values
        self_size_l = self.size_l

        idx = self_right[0]
        res = 0
        while idx != -1:
            value = self_values[idx]
            if value < x:
                # This node and its whole left subtree are below x.
                res += self_size_l[idx] + 1
                idx = self_right[idx]
            elif x < value:
                idx = self_left[idx]
            else:  # value == x
                return res + self_size_l[idx] + 1
        return res

    def print_tree(self, idx=0, depth=0, from_="・"):
        """Print the subtree rooted at ``idx`` in order, one node per line.

        Called with the default ``idx=0`` it prints the whole tree.  Each
        line is indented by depth and shows the node's key, balance factor
        and left-subtree size.  Intended for debugging.
        """
        if idx == 0:
            # Skip the sentinel and start at the real root.
            idx = self.right[idx]
        if idx == -1:
            return
        self.print_tree(self.left[idx], depth + 1, "┏")
        print("\t\t" * depth + from_ + " val=[" + str(self.values[idx]) +
              "] diff=[" + str(self.diff[idx]) +
              "] size_l=[" + str(self.size_l[idx]) + "]")
        self.print_tree(self.right[idx], depth + 1, "┗")

まとめ

# なんで? をつけたコードをブログで公開するやつがいるらしい