Python競プロライブラリの整理をする
(19/11/22 追記)
一応最新のはgithubに上げてる
そのうち整理されるかもしれないので個別のページへのリンクを貼ったりするのはやめといたほうがいいかも
この記事は何
PyCharmに常に貼ってたライブラリが長くなりすぎて整理する必要が出てきた
使わなくなったライブラリを消すのが何となくもったいないので公開してから消す
ついでに競プロライブラリを共有する
前置き
この記事のコードは公開を前提に書いたわけじゃないので、Python競プロライブラリを探しているならまず↓のサイトを見るといいと思う
アルゴリズム [いかたこのたこつぼ] DTMでもよくお世話になりました
ライブラリ整理
拡張ユークリッド互除法・中国剰余定理
拡張ユークリッド互除法
# 拡張ユークリッド互除法 # gcd(a,b) と ax + by = gcd(a,b) の最小整数解を返す def egcd(a, b): if a == 0: return b, 0, 1 else: g, y, x = egcd(b % a, a) return g, x - (b // a) * y, y
AtCoderを初めて間もない頃に誰だったかの解答からコピペしたもの
m1 = 10 m2 = 6 gcd, x, y = egcd(m1, m2) # 10で割ると1余り6で割ると3余る数を求める b1 = 1 b2 = 3 s = (b2 - b1) // gcd # これは必ず割り切れる ans = b1 + m1 * s * x print(ans)
使い方のメモがあったけどこれはいらない、消す
中国剰余定理
def chineseRem(b1, m1, b2, m2): # 中国剰余定理 # x ≡ b1 (mod m1) ∧ x ≡ b2 (mod m2) <=> x ≡ r (mod m) # となる(r. m)を返す # 解無しのとき(0, -1) d, p, q = egcd(m1, m2) if (b2 - b1) % d != 0: return 0, -1 m = m1 * (m2 // d) # m = lcm(m1, m2) tmp = (b2-b1) // d * p % (m2 // d) r = (b1 + m1 * tmp) % m return r, m
なんか効率悪いことしてる気がするけどまあこのままでいいか
乗法のmod逆元・組み合わせ
乗法のmod逆元(mod-2乗)
def modinv(a, mod=10**9+7): return pow(a, mod-2, mod)
こっちを普段使ってる
乗法のmod逆元(拡張ユークリッド互除法)
# mを法とするaの乗法的逆元 def modinv(a, m): g, x, y = egcd(a, m) if g != 1: raise Exception('modular inverse does not exist') else: return x % m
egcdと一緒にパクったものを残してるけどpowでいいこと知ってから使われていない
組み合わせ
# nCr mod m # modinvが必要 # rがn/2に近いと非常に重くなる def combination(n, r, mod=10**9+7): r = min(r, n-r) res = 1 for i in range(r): res = res * (n - i) * modinv(i+1, mod) % mod return res
愚直に求めるものだけどたまに使う
組み合わせ(2)
# nCrをすべてのr(0<=r<=n)について求める # nC0, nC1, nC2, ... , nCn を求める # modinvが必要 def combination_list(n, mod=10**9+7): lst = [1] for i in range(1, n+1): lst.append(lst[-1] * (n+1-i) % mod * modinv(i, mod) % mod) return lst
これは使ってない、消す
階乗のmod逆元のリスト
# 階乗のmod逆元のリストを返す O(n) def facinv_list(n, mod=10**9+7): L = [1] for i in range(1, n+1): L.append(L[i-1] * modinv(i, mod) % mod) return L
これO(n)って書いてるけどO(nlog(mod))では?効率悪いし使ってないので消す
重複組み合わせ
# nHr mod m def H(n, r, mod=10**9+7): return combination(n+r-1, r, mod)
combination()
の部分は必要に応じて下記のCombinationオブジェクトに書き換えて使う
組み合わせ(何回も使う方)
class Combination: """ O(n)の前計算を1回行うことで,O(1)でnCr mod mを求められる n_max = 10**6のとき前処理は約950ms (PyPyなら約340ms, 10**7で約1800ms) 使用例: comb = Combination(1000000) print(comb(5, 3)) # 10 """ def __init__(self, n_max, mod=10**9+7): self.mod = mod self.modinv = self.make_modinv_list(n_max) self.fac, self.facinv = self.make_factorial_list(n_max) def __call__(self, n, r): return self.fac[n] * self.facinv[r] % self.mod * self.facinv[n-r] % self.mod def make_factorial_list(self, n): # 階乗のリストと階乗のmod逆元のリストを返す O(n) # self.make_modinv_list()が先に実行されている必要がある fac = [1] facinv = [1] for i in range(1, n+1): fac.append(fac[i-1] * i % self.mod) facinv.append(facinv[i-1] * self.modinv[i] % self.mod) return fac, facinv def make_modinv_list(self, n): # 0からnまでのmod逆元のリストを返す O(n) modinv = [0] * (n+1) modinv[1] = 1 for i in range(2, n+1): modinv[i] = self.mod - self.mod//i * modinv[self.mod%i] % self.mod return modinv
(久々に中身見たけどこれどうやって動いてるんだっけ)
下2つのメソッドはクラス内から呼び出されることを前提にしてるので_
をつけたほうがいいのかもしれない、知らんけど
素数判定、素因数分解
エラトステネスの篩
# nまでの自然数が素数かどうかを表すリストを返す def makePrimeChecker(n): isPrime = [True] * (n + 1) isPrime[0] = False isPrime[1] = False for i in range(2, int(n ** 0.5) + 1): if isPrime[i]: for j in range(i * i, n + 1, i): isPrime[j] = False return isPrime
若干効率悪く書いてたのに気付いたので修正した
素因数分解
# 素因数分解 def prime_decomposition(n): i = 2 table = [] while i * i <= n: while n % i == 0: n //= i table.append(i) i += 1 if n > 1: table.append(n) return table
素因数分解はなんか高速にやる方法もあったと思うけどそれは持ってない
確率的素数判定もあった方がいい気がしたのでyaketakeさんのものを参考に導入することにした
順列
def full(L): # 並び替えをすべて挙げる,全探索用 # これitertools.permutationsでええやん!!!!!!! if len(L) == 1: return [L] else: L2 = [] for i in range(len(L)): L2.extend([[L[i]] + Lc for Lc in full(L[:i] + L[i+1:])]) return L2
itertools.permutations
を知るまで使っていたもの
思い出的に残してた(消す)
転倒数
# 転倒数 def mergecount(A): cnt = 0 n = len(A) if n>1: A1 = A[:n>>1] A2 = A[n>>1:] cnt += mergecount(A1) cnt += mergecount(A2) i1=0 i2=0 for i in range(n): if i2 == len(A2): A[i] = A1[i1] i1 += 1 elif i1 == len(A1): A[i] = A2[i2] i2 += 1 elif A1[i1] <= A2[i2]: A[i] = A1[i1] i1 += 1 else: A[i] = A2[i2] i2 += 1 cnt += n//2 - i1 return cnt
Spaghetti Source - バブルソートの交換回数 をPythonで書き換えたもの
動作原理は理解してない
転倒数はBinary Indexed Treeで求められるし使ってないけど…一応残しておく
Binary Indexed Tree
一点加算・区間和取得
class Bit: def __init__(self, n): self.size = n self.tree = [0]*(n+1) def __iter__(self): psum = 0 for i in range(self.size): csum = self.sum(i+1) yield csum - psum psum = csum raise StopIteration() def __str__(self): # O(nlogn) return str(list(self)) def sum(self, i): # [0, i) の要素の総和を返す if not (0 <= i <= self.size): raise ValueError("error!") s = 0 while i>0: s += self.tree[i] i -= i & -i return s def add(self, i, x): if not (0 <= i < self.size): raise ValueError("error!") i += 1 while i <= self.size: self.tree[i] += x i += i & -i def __getitem__(self, key): if not (0 <= key < self.size): raise IndexError("error!") return self.sum(key+1) - self.sum(key) def __setitem__(self, key, value): # 足し算と引き算にはaddを使うべき if not (0 <= key < self.size): raise IndexError("error!") self.add(key, value - self[key])
初期値をO(n)で入れられるようにするべきかもしれない
和以外の演算もできるようにすべきかもしれない
# BITで転倒数を求められる A = [3, 10, 1, 8, 5, 5, 1] bit = Bit(max(A)+1) ans = 0 for i, a in enumerate(A): ans += i - bit.sum(a+1) bit.add(a, 1) print(ans)
Binary Indexed Treeの使い方のメモ
区間加算・一点取得
class BitImos: """ ・範囲すべての要素に加算 ・ひとつの値を取得 の2種類のクエリをO(logn)で処理 """ def __init__(self, n): self.bit = Bit(n+1) def add(self, s, t, x): # [s, t)にxを加算 self.bit.add(s, x) self.bit.add(t, -x) def get(self, i): return self[i] def __getitem__(self, key): # 位置iの値を取得 return self.bit.sum(key+1)
いもす法みたいな使い方をしやすいようなBIT
区間加算・区間取得
# 未検証 class Bit2: def __init__(self, n): self.bit0 = Bit(n) self.bit1 = Bit(n) def add(self, l, r, x): # [l, r) に x を足す self.bit0.add(l, -x * (l-1)) self.bit1.add(l, x) self.bit0.add(r, x * (r-1)) self.bit1.add(r, -x) def sum(self, l, r): res = 0 res += self.bit0.sum(r) + self.bit1.sum(r) * (r-1) res -= self.bit0.sum(l) + self.bit1.sum(l) * (l-1) return res
原理わかってない
セグメント木
RMQ
# セグ木 class SegTree: # 根は 1 # 親は node>>1 # 子は node<<1, node<<1 | 1 # 兄弟は node^1 # 値が入ってるのは offset: def __init__(self, n): self.n = n # 要素数 self.INF = float("inf") self.seg = [self.INF] * (n * 4) self.depth = len(bin(n))-2 self.offset = 1<<self.depth def initialize(self, lst): # あとで書く pass def update(self, i, x): node = self.offset + i self.seg[node] = x while node > 1: # 親を更新 self.seg[node >> 1] = min(self.seg[node], self.seg[node ^ 1]) node >>= 1 def getmin(self, a, b, k=1, l=0, r=None): if r is None: r = self.offset # a, b は再帰で変化しない値 # 要求区間[a,b)と今見ている区間[l,r)が交わらない if r <= a or b <= l: return self.INF # 今見ている区間は要求区間に完全に含まれている if a <= l and r <= b: return self.seg[k] # どちらでもなければ子ノードを見る vl = self.getmin(a, b, k << 1, l, (l + r) >> 1) vr = self.getmin(a, b, (k << 1) | 1, (l + r) >> 1, r) return min(vl, vr)
確か プログラミングコンテストでのデータ構造 をPythonで書き換えたもののはず
コピペしてきたもののほうがクオリティが高いので供養
一点更新・区間取得
# https://atcoder.jp/contests/abc014/submissions/3935971 class SegmentTree(object): __slots__ = ["elem_size", "tree", "default", "op"] def __init__(self, a: list, default: int, op): from math import ceil, log real_size = len(a) self.elem_size = elem_size = 1 << ceil(log(real_size, 2)) self.tree = tree = [default] * (elem_size * 2) tree[elem_size:elem_size + real_size] = a self.default = default self.op = op for i in range(elem_size - 1, 0, -1): tree[i] = op(tree[i << 1], tree[(i << 1) + 1]) def get_value(self, x: int, y: int) -> int: # 半開区間 l, r = x + self.elem_size, y + self.elem_size tree, result, op = self.tree, self.default, self.op while l < r: if l & 1: result = op(tree[l], result) l += 1 if r & 1: r -= 1 result = op(tree[r], result) l, r = l >> 1, r >> 1 return result def set_value(self, i: int, value: int) -> None: k = self.elem_size + i self.tree[k] = value self.update(k) def update(self, i: int) -> None: op, tree = self.op, self.tree while i > 1: i >>= 1 tree[i] = op(tree[i << 1], tree[(i << 1) + 1]) """ C = [int(input()) for _ in range(N)] idx = [0] * N for i, c in enumerate(C): idx[c-1] = i seg = SegmentTree([0]*(N+1), 0, min) for i in range(N): idx_ = idx[i] seg.set_value(idx_, seg.get_value(0, idx_)-1) print(seg.get_value(0, N+1)+N) """
htkbさんのセグメント木
神
平衡二分探索木
Treap
# Treap import random class Treap: def __init__(self): self.node = None # root def __str__(self): L = [] def recursiveSearch(node): if node.left is not None: recursiveSearch(node.left) L.append((node.val, node.priority)) if node.right is not None: recursiveSearch(node.right) recursiveSearch(self.node) return str(L) def add(self, x): i = self.bisect(x) self.node = self.insert(self.node, i, x) def remove(self, x): i = self.bisect(x) self.node = self.erase(self.node, i) def bisect(self, x): # bisect_left def recursiveSearch(node): if node is None: return 0 res = 0 if x <= node.val: if node.left is not None: return recursiveSearch(node.left) else: return 0 else: res += self.count(node.left) + 1 if node.right is not None: return res + recursiveSearch(node.right) else: return res return recursiveSearch(self.node) def __getitem__(self, k, t=None): if t is None: t = self.node if k < self.count(t.left): if t.left is None: return t.val return self.__getitem__(k, t.left) elif k == self.count(t.left): return t.val else: if t.right is None: return t.val return self.__getitem__(k - self.count(t.left) - 1, t.right) class Node: def __init__(self, x): self.val = x self.priority = random.randint(0, 1<<30) self.left = None self.right = None self.cnt = 1 def __str__(self): return str(self.val) def count(self, t: Node): return 0 if t is None else t.cnt def update(self, t: Node): t.cnt = self.count(t.left) + self.count(t.right) + 1 return t def merge(self, l: Node, r: Node): if l is None: return r if r is None: return l if l.priority > r.priority: l.right = self.merge(l.right, r) return self.update(l) else: r.left = self.merge(l, r.left) return self.update(r) def split(self, t: Node, k: int) -> (Node, Node): # [0,k),[k,n) if t is None: return None, None if k <= self.count(t.left): s = self.split(t.left, k) t.left = s[1] return s[0], self.update(t) else: s = self.split(t.right, k - self.count(t.left) - 1) t.right = s[0] return self.update(t), s[1] def insert(self, t: Node, k: int, x): l, r = self.split(t, k) return self.merge(self.merge(l, self.Node(x)), r) def erase(self, t: Node, k): l, r = self.split(t, k) _, r = self.split(r, 1) return self.merge(l, r) """ treap = Treap() for i in range(100000): treap.add(i) s.add(i) print(treap[50]) """
重すぎて使い物にならない、供養
AVL木
これ
PythonでAVL木を競プロ用に実装した(誰か高速化してくれ) - 菜
↑はPythonにC++のsetに相当するものが無いから実装したものだけど、クエリ先読みできるならBITでなんとかなるらしい(わかってない)
オフラインクエリのset実装 - Tallfallの日記
ダイクストラ法
from collections import defaultdict import heapq class Dijkstra: # 計算量 O((E+V)logV) # adjはdefaultdictのリスト def dijkstra(self, adj, start, goal=None): num = len(adj) # グラフのノード数 self.dist = [float('inf') for i in range(num)] # 始点から各頂点までの最短距離を格納する self.prev = [float('inf') for i in range(num)] # 最短経路における,その頂点の前の頂点のIDを格納する self.dist[start] = 0 q = [(0, start)] # プライオリティキュー.各要素は,(startからある頂点vまでの仮の距離, 頂点vのID)からなるタプル while len(q) != 0: prov_cost, src = heapq.heappop(q) # pop # プライオリティキューに格納されている最短距離が,現在計算できている最短距離より大きければ,distの更新をする必要はない if self.dist[src] < prov_cost: continue # 探索で辺を見つける場合ここに書く # 他の頂点の探索 for dest, cost in adj[src].items(): if self.dist[dest] > self.dist[src] + cost: self.dist[dest] = self.dist[src] + cost # distの更新 heapq.heappush(q, (self.dist[dest], dest)) # キューに新たな仮の距離の情報をpush self.prev[dest] = src # 前の頂点を記録 if goal is not None: return self.get_path(goal, self.prev) else: return self.dist def get_path(self, goal, prev): path = [goal] # 最短経路 dest = goal # 終点から最短経路を逆順に辿る while prev[dest] != float('inf'): path.append(prev[dest]) dest = prev[dest] # 経路をreverseして出力 return list(reversed(path))
Pythonでダイクストラアルゴリズムを実装した - フツーって言うなぁ! を改造したもの
クラスになってるけどそのまま使うことはほとんどしていなくて、
実装の参考にするために置いている
経路が必要がないならself.prev
は必要ない
Union Find木
# unionfind class Uf: def __init__(self, N): self.p = list(range(N)) self.rank = [0] * N self.size = [1] * N def root(self, x): if self.p[x] != x: self.p[x] = self.root(self.p[x]) return self.p[x] def same(self, x, y): return self.root(x) == self.root(y) def unite(self, x, y): u = self.root(x) v = self.root(y) if u == v: return if self.rank[u] < self.rank[v]: self.p[u] = v self.size[v] += self.size[u] self.size[u] = 0 else: self.p[v] = u self.size[u] += self.size[v] self.size[v] = 0 if self.rank[u] == self.rank[v]: self.rank[u] += 1 def count(self, x): return self.size[self.root(x)]
これどこかからコピペしたんだっけ…?
マス目の幅優先探索
# マス目の幅優先探索 for y, s in enumerate(S): for x, c in enumerate(s): if c=="S": start = (x, y) elif c=="G": goal = (x, y) ans = 0 Open = [start] Close = {start} a = 0 flg = True while flg: a += 1 OpenNext = [] for x, y in Open: for dx, dy in [(1,0), (0,1), (-1,0), (0,-1)]: if (x+dx, y+dy) not in Close and S[y+dy][x+dx] != "#": if (x+dx, y+dy) == goal: flg = False OpenNext.append((x+dx, y+dy)) Close.add((x+dx, y+dy)) Open = OpenNext if len(Open) == 0: a = float("inf") break ans += a
だいぶ前にメモとして置いておいた気がするけど使ってないし微妙に効率悪い、消す
高速ゼータ変換・高速メビウス変換
# 高速ゼータ変換 # 自身を含む集合を全て挙げる方 N = 3 f = [{i} for i in range(1<<N)] for i in range(N): for j in range(1<<N): if not (j & 1<<i): f[j] |= f[j | (1<<i)] # 総和は += # -=にすると逆変換になる print(f) # 部分集合をすべて挙げる方 f = [{i} for i in range(1<<N)] for i in range(N): for j in range(1<<N): if j & 1<<i: f[j] |= f[j ^ (1<<i)] print(f)
メモだけどわかりにくいと思うしどうやればわかりやすく書けるのかわからない
[{0, 1, 2, 3, 4, 5, 6, 7}, {1, 3, 5, 7}, {2, 3, 6, 7}, {3, 7}, {4, 5, 6, 7}, {5, 7}, {6, 7}, {7}] [{0}, {0, 1}, {0, 2}, {0, 1, 2, 3}, {0, 4}, {0, 1, 4, 5}, {0, 2, 4, 6}, {0, 1, 2, 3, 4, 5, 6, 7}]
実行結果
最長回文
# Manacherのアルゴリズム def man(S): i = 0 j = 0 n = len(S) R = [0]*n while i < n: while i-j >= 0 and i+j < n and S[i-j] == S[i+j]: j+=1 R[i] = j k = 1 while i-k >= 0 and i+k < n and k+R[i-k] < j: R[i+k] = R[i-k] k += 1 i += k j -= k return R
フロー
Dinic法
# 最大流問題 from collections import deque INF = float("inf") TO = 0; CAP = 1; REV = 2 class Dinic: def __init__(self, N): self.N = N self.V = [[] for _ in range(N)] # to, cap, rev # 辺 e = V[n][m] の逆辺は V[e[TO]][e[REV]] self.level = [0] * N def add_edge(self, u, v, cap): self.V[u].append([v, cap, len(self.V[v])]) self.V[v].append([u, 0, len(self.V[u])-1]) def add_edge_undirected(self, u, v, cap): # 未検証 self.V[u].append([v, cap, len(self.V[v])]) self.V[v].append([u, cap, len(self.V[u])-1]) def bfs(self, s: int) -> bool: self.level = [-1] * self.N self.level[s] = 0 q = deque() q.append(s) while len(q) != 0: v = q.popleft() for e in self.V[v]: if e[CAP] > 0 and self.level[e[TO]] == -1: # capが1以上で未探索の辺 self.level[e[TO]] = self.level[v] + 1 q.append(e[TO]) return True if self.level[self.g] != -1 else False # 到達可能 def dfs(self, v: int, f) -> int: if v == self.g: return f for i in range(self.ite[v], len(self.V[v])): self.ite[v] = i e = self.V[v][i] if e[CAP] > 0 and self.level[v] < self.level[e[TO]]: d = self.dfs(e[TO], min(f, e[CAP])) if d > 0: # 増加路 e[CAP] -= d # cap を減らす self.V[e[TO]][e[REV]][CAP] += d # 反対方向の cap を増やす return d return 0 def solve(self, s, g): self.g = g flow = 0 while self.bfs(s): # 到達可能な間 self.ite = [0] * self.N f = self.dfs(s, INF) while f > 0: flow += f f = self.dfs(s, INF) return flow
これ何を参考に実装したのか全く思い出せないけど気付いたらあった
最長増加部分列(LIS)
いかたこさんのものを使っています
最近共通祖先(LCA)
class Lca: # 最近共通祖先 def __init__(self, E, root): import sys sys.setrecursionlimit(500000) self.root = root self.E = E # V<V> self.n = len(E) # 頂点数 self.logn = 1 # n < 1<<logn ぴったりはだめ while self.n >= (1<<self.logn): self.logn += 1 # parent[n][v] = ノード v から 1<<n 個親をたどったノード self.parent = [[-1]*self.n for _ in range(self.logn)] self.depth = [0] * self.n self.dfs(root, -1, 0) for k in range(self.logn-1): for v in range(self.n): p_ = self.parent[k][v] if p_ >= 0: self.parent[k+1][v] = self.parent[k][p_] def dfs(self, v, p, dep): # ノード番号、親のノード番号、深さ self.parent[0][v] = p self.depth[v] = dep for e in self.E[v]: if e != p: self.dfs(e, v, dep+1) def get(self, u, v): if self.depth[u] > self.depth[v]: u, v = v, u # self.depth[u] <= self.depth[v] dep_diff = self.depth[v]-self.depth[u] for k in range(self.logn): if dep_diff >> k & 1: v = self.parent[k][v] if u==v: return u for k in range(self.logn-1, -1, -1): if self.parent[k][u] != self.parent[k][v]: u = self.parent[k][u] v = self.parent[k][v] return self.parent[0][u]
これパクったんじゃないっけ…?AOJ見たら自分がverifyしてる痕跡があったけど…
http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=3526639#1
幾何
いくつかあるけどAOJからコピペしたものが多いので割愛
その他細かいメモ
import sys def input(): return sys.stdin.readline()[:-1] input = sys.stdin.readline from functools import lru_cache @lru_cache(maxsize=None) # メモ化再帰したい関数の前につける import sys sys.setrecursionlimit(500000) from operator import itemgetter from collections import defaultdict from itertools import product # 直積 ord("a") - 97 # chr N = int(input()) N, K = map(int, input().split()) L = [int(input()) for _ in range(N)] A = list(map(int, input().split())) S = [list(map(int, input().split())) for _ in range(H)]
itemgetter
とかがどこにあったのか昔は覚えてなかったけど今は大丈夫なのでそのへんは消す
input = sys.stdin.readline は読み込むものが文字列じゃないときに使ってる
まとめ
コピペ多いね…