ABC306 回想・解説(Python)

前提

AtCoder Biginner Contest306 の Python の解答コードです。

雑談

ABC306 も原因不明であるが、ジャッジが正常に動かず、unrated となった。 SortedMultiset を実践で始めて使い、初の水パフォだっただけに残念である。

A. Echo

問題文の通り実装する。

N = int(input())
S = input()

ans = ""
for s in S:
ans += s \* 2
print(ans)

B. Base 2

問題文の通り実装する。

A = list(map(int, input().split()))

ans = 0
for i, a in enumerate(A):
    ans += a * 2**i
print(ans)

C - Centers

愚直になると TLE なので、工夫が必要。単純に 2 回目に現れた通りに数を出力したらよい。

from collections import defaultdict

N = int(input())
A = list(map(int, input().split()))

d = defaultdict(int)
ans = []
for a in A:
d[a] += 1
if d[a] == 2:
ans.append(a)
print(\*ans)

D - Poisonous Full-Course

dpを使う。もらうDPで解こうとしたら、毒の状態をうまく持てなかったので、汚い実装になった。 解説の配るDPはスッと理解できたので、そちらのほうがよさそう。

N = int(input())
XY = [0] + [list(map(int, input().split())) for _ in range(N)]

dp = [[0] * (N + 2) for _ in range(4)]
INF = 10**18

for i in range(1, N + 1):
    x, y = XY[i]
    if x == 1:
        dp[0][i] = -INF
        dp[1][i] = max(dp[0][i - 1], dp[1][i - 1])
        dp[2][i] = max(dp[0][i - 1], dp[1][i - 1]) + y
        dp[3][i] = max(dp[2][i - 1], dp[3][i - 1])
    else:
        dp[0][i] = max(dp[0][i - 1], dp[1][i - 1], dp[2][i - 1], dp[3][i - 1]) + y
        dp[1][i] = max(dp[0][i - 1], dp[1][i - 1])
        dp[2][i] = -INF
        dp[3][i] = max(dp[2][i - 1], dp[3][i - 1])
print(max(dp[0][N], dp[1][N], dp[2][N], dp[3][N]))


E - Best Performances

いちいちソートしていると日がくれるので、工夫をする。 K 個の値を持つ SortedMultiset と N-K 個の値を持つ SortedMultiset を用意する。 更新前の値を更新後の値がどちらに入るかを判定し、それぞれ場合によって値を入れ替える。 合計値は別でもっておき、K 個の値を持つ SortedMultiset に操作をした場合に計算しなおす。

# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
import math
from bisect import bisect_left, bisect_right, insort
from typing import Generic, Iterable, Iterator, TypeVar, Union, List

T = TypeVar("T")

class SortedMultiset(Generic[T]):
BUCKET_RATIO = 50
REBUILD_RATIO = 170

    def _build(self, a=None) -> None:
        """Evenly divide `a` into buckets."""
        if a is None:
            a = list(self)
        size = self.size = len(a)
        bucket_size = int(math.ceil(math.sqrt(size / self.BUCKET_RATIO)))
        self.a = [
            a[size * i // bucket_size : size * (i + 1) // bucket_size]
            for i in range(bucket_size)
        ]

    def __init__(self, a: Iterable[T] = []) -> None:
        """Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)"""
        a = list(a)
        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        self._build(a)

    def __iter__(self) -> Iterator[T]:
        for i in self.a:
            for j in i:
                yield j

    def __reversed__(self) -> Iterator[T]:
        for i in reversed(self.a):
            for j in reversed(i):
                yield j

    def __len__(self) -> int:
        return self.size

    def __repr__(self) -> str:
        return "SortedMultiset" + str(self.a)

    def __str__(self) -> str:
        s = str(list(self))
        return "[" + s[1 : len(s) - 1] + "]"

    def _find_bucket(self, x: T) -> List[T]:
        """Find the bucket which should contain x. self must not be empty."""
        for a in self.a:
            if x <= a[-1]:
                return a
        return a

    def __contains__(self, x: T) -> bool:
        if self.size == 0:
            return False
        a = self._find_bucket(x)
        i = bisect_left(a, x)
        return i != len(a) and a[i] == x

    def count(self, x: T) -> int:
        """Count the number of x."""
        return self.index_right(x) - self.index(x)

    def add(self, x: T) -> None:
        """Add an element. / O(√N)"""
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a = self._find_bucket(x)
        insort(a, x)
        self.size += 1
        if len(a) > len(self.a) * self.REBUILD_RATIO:
            self._build()

    def discard(self, x: T) -> bool:
        """Remove an element and return True if removed. / O(√N)"""
        if self.size == 0:
            return False
        a = self._find_bucket(x)
        i = bisect_left(a, x)
        if i == len(a) or a[i] != x:
            return False
        a.pop(i)
        self.size -= 1
        if len(a) == 0:
            self._build()
        return True

    def lt(self, x: T) -> Union[T, None]:
        """Find the largest element < x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x: T) -> Union[T, None]:
        """Find the largest element <= x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x: T) -> Union[T, None]:
        """Find the smallest element > x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x: T) -> Union[T, None]:
        """Find the smallest element >= x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]

    def __getitem__(self, x: int) -> T:
        """Return the x-th element, or IndexError if it doesn't exist."""
        if x < 0:
            x += self.size
        if x < 0:
            raise IndexError
        for a in self.a:
            if x < len(a):
                return a[x]
            x -= len(a)
        raise IndexError

    def index(self, x: T) -> int:
        """Count the number of elements < x."""
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        """Count the number of elements <= x."""
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans

N, K, Q = map(int, input().split())
XY = [list(map(int, input().split())) for _ in range(Q)]
A = [0] _ N
high = SortedMultiset([0] _ K)
if N == K:
low = SortedMultiset([0])
else:
low = SortedMultiset([0] \* (N - K))

ans = 0
for x, y in XY:
bf = A[x - 1]
A[x - 1] = y
hl = high.**getitem**(0)
lh = low.**getitem**(N - K - 1)
if bf >= hl:
high.discard(bf)
if y >= lh:
high.add(y)
ans += y - bf
else:
low.discard(lh)
high.add(lh)
low.add(y)
ans += lh - bf
else:
low.discard(bf)
if y <= hl:
low.add(y)
ans = ans
else:
high.discard(hl)
low.add(hl)
high.add(y)
ans += y - hl
print(ans)