トップページ -> AOJの解答例 -> DSL_2の解答例

DSL_2の解答例(Python)

ある範囲に対して更新と検索を行う問題です. セグメント木を使って効率的に検索・更新できるようにします.
F,G,H,Iは遅延評価セグメント木に関する問題です.

1. DSL_2_A: Range Minimum Query (RMQ)

完全二分木のセグメント木を用意して検索・更新をします.
最下層に元のデータを持っておき,それより上では効率よくクエリの質問に答えられるように範囲の最小値や和を記憶しておくようにします. 以下の外部サイトが非常に分かりやすいです.
【外部サイト】セグメント木を徹底解説!0から遅延評価やモノイドまで

# DSL_2_A: Range Minimum Query (RMQ)
class SegmentTree:
    def __init__(self,n):
        # サイズが2のべき乗になるようにする(綺麗な二分木にしたい)
        self.size = 1
        while self.size < n:
            self.size *= 2
        # 2**31-1で初期化
        self.tree = [2**31-1]*(2*self.size-1)

    def update(self, i, x):
        # 一番下の階層にiを合わせる(i番目のデータの場所を探す)
        i += self.size - 1
        # iを更新
        self.tree[i] = x
        # 親を更新していく
        while i > 0:
            i = (i-1)//2
            self.tree[i] = min(self.tree[2*i+1], self.tree[2*i+2])

    def find_min(self, start, end, idx=0, left=0,right=None):
        if right == None:
            right = self.size
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return float("inf")
        # 区間が範囲内の場合
        if start <= left and end >= right:
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.find_min(start,end,2*idx+1,left,mid)
        vr = self.find_min(start,end,2*idx+2,mid,right)
        return min(vl,vr)


n, q = map(int, input().split())
segment_tree = SegmentTree(n)

for _ in range(q):
    com,x,y = map(int,input().split())
    if com == 0:
        segment_tree.update(x,y)
    else:
        result = segment_tree.find_min(x,y+1)  # yも範囲に含めるので+1
        print(result)

2. DSL_2_B: Range Sum Query (RSQ)

Aを和を求めるように変えただけです.

# DSL_2_B
class SegmentTree:
    def __init__(self,n):
        # サイズが2のべき乗になるようにする(綺麗な二分木にしたい)
        self.size = 1
        while self.size < n:
            self.size *= 2
        # 0で初期化
        self.tree = [0]*(2*self.size-1)

    def update(self, i, x):
        # 一番下の階層にiを合わせる(i番目のデータの場所を探す)
        i += self.size - 1
        # iを更新
        self.tree[i] += x
        # 親を更新していく
        while i > 0:
            i = (i-1)//2
            self.tree[i] = self.tree[2*i+1] + self.tree[2*i+2] # 和に変えた

    # グローバル変数 ans を使って和を計算した
    def find_sum(self, start, end, idx=0, left=0,right=None):
        if right == None:
            right = self.size
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return 0 
        # 区間が範囲内の場合
        if start <= left and end >= right:
            return self.tree[idx] 

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.find_sum(start,end,2*idx+1,left,mid)
        vr = self.find_sum(start,end,2*idx+2,mid,right)
        return vl+vr # 和を返すようにする


n, q = map(int, input().split())
segment_tree = SegmentTree(n)

for _ in range(q):
    com,x,y = map(int,input().split())
    if com == 0:
        segment_tree.update(x-1,y)
    else:
        result = segment_tree.find_sum(x-1,y)  # indexがAとは違うことに注意
        print(result)

3. DSL_2_C: Range Search (kD Tree)


# DSL_2_C
# 編集中

4. DSL_2_D: Range Update Query (RUQ)

RUQに関する問題です. 本来は先ほど紹介したサイトのように遅延評価セグメント木を使うのですが,最後の更新がいつだったかを記録して答えるようにしています. 遅延評価セグメント木についてはF以降で書いています.

# DSL_2_D
class SegmentTree:
    def __init__(self,n):
        # サイズが2のべき乗になるようにする(綺麗な二分木にしたい)
        self.size = 1
        while self.size < n:
            self.size *= 2
        # 2**31-1で初期化
        self.tree = [2**31-1]*(2*self.size-1)
        self.delay = [0]*(2*self.size-1) # 最後の更新を記録

    # left ~ right をnumに変える
    def update(self, start, end, num, delay, idx=0, left=0,right=None):
        if right == None:
            right = self.size
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return float("inf")
        # 区間が範囲内の場合
        if start <= left and end >= right:
            self.tree[idx] = num # ここで更新している
            self.delay[idx] = delay # ここで更新している
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.update(start,end,num,delay,2*idx+1,left,mid)
        vr = self.update(start,end,num,delay,2*idx+2,mid,right)

    # iとiの親で更新が最後のものを探す
    def find(self, i):
        i += self.size - 1
        result = self.tree[i]
        max_delay = self.delay[i]
        while i > 0:
            # 後に更新されていればそれが結果
            i = (i-1)//2
            if self.delay[i] > max_delay:
                result = self.tree[i]
                max_delay = self.delay[i]
        return result

n, q = map(int, input().split())
segment_tree = SegmentTree(n)
delay = 0

for _ in range(q):
    query = list(map(int,input().split()))
    if query[0] == 0:
        _,s,t,x = query
        delay += 1
        segment_tree.update(s,t+1,x,delay)
    else:
        _,i = query 
        result = segment_tree.find(i)  # yも範囲に含めるので+1
        print(result)

5. DSL_2_E: Range Add Query (RAQ)

RAQに関する問題です.

# DSL_2_E
class SegmentTree:
    def __init__(self,n):
        # サイズが2のべき乗になるようにする(綺麗な二分木にしたい)
        self.size = 1
        while self.size < n:
            self.size *= 2
        # 0で初期化
        self.tree = [0]*(2*self.size-1)

    # left ~ right にnumを足す
    def update(self, start, end, num, idx=0, left=0,right=None):
        if right == None:
            right = self.size
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return float("inf")
        # 区間が範囲内の場合
        if start <= left and end >= right:
            self.tree[idx] += num # ここで更新している
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.update(start,end,num,2*idx+1,left,mid)
        vr = self.update(start,end,num,2*idx+2,mid,right)
        
    # 自身から親までを足し合わせればいい
    def find(self, i):
        i += self.size - 1
        result = self.tree[i]
        while i > 0:
            i = (i-1)//2
            result += self.tree[i]
                
        return result
        
n, q = map(int, input().split())
segment_tree = SegmentTree(n)
delay = 0

for _ in range(q):
    query = list(map(int,input().split()))
    if query[0] == 0:
        _,s,t,x = query
        delay += 1
        segment_tree.update(s-1,t,x)
    else:
        _,i = query 
        result = segment_tree.find(i-1)  # yも範囲に含めるので+1
        print(result)

6. DSL_2_F: RMQ and RUQ

遅延評価無しでは誤魔化しが効かないので遅延評価セグメント木を使います. 求めるものが和なのか最小値なのか,更新の方法が加算なのか代入なのかの組み合わせを変えたG,H,Iの問題が出ます.
問題に応じて変更するべき行に # !!! を付けておきました.

# DSL_2_F
# 遅延評価セグメント木
class SegmentTree:
    def __init__(self,n):
        # サイズが2のべき乗になるようにする(綺麗な二分木にしたい)
        self.size = 1
        while self.size < n:
            self.size *= 2
        # 2**31-1で初期化
        self.tree = [2**31-1]*(2*self.size-1) # !!!
        # 子は find のときに lazy を使って一気に更新する(遅延評価)
        self.lazy = [None]*(2*self.size-1) # !!!

    # left ~ right をnumに変える
    def update(self, start, end, num, idx=0, left=0,right=None):
        if right == None:
            right = self.size
            
        # 更新しておく(遅延評価)
        if self.lazy[idx] != None: # !!!
            self.tree[idx] = self.lazy[idx] # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] = self.lazy[idx] # !!!
                self.lazy[2*idx+2] = self.lazy[idx] # !!!
        self.lazy[idx] = None # !!!
            
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return float("inf") # !!!
        # 区間が範囲内の場合
        if start <= left and end >= right:
            self.tree[idx] = num # !!!
            # 遅延評価
            self.lazy[idx] = None # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] = num # !!!
                self.lazy[2*idx+2] = num # !!!
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.update(start,end,num,2*idx+1,left,mid)
        vr = self.update(start,end,num,2*idx+2,mid,right)
        self.tree[idx] = min(self.tree[2 * idx + 1], self.tree[2 * idx + 2]) # !!!

    def find_min(self, start, end, idx=0, left=0,right=None):
        if right == None:
            right = self.size
            
        # 更新しておく(遅延評価)
        if self.lazy[idx] != None: # !!! 
            self.tree[idx] = self.lazy[idx] # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] = self.lazy[idx]
                self.lazy[2*idx+2] = self.lazy[idx]
            self.lazy[idx] = None # !!!
            
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return float("inf") # !!!
        # 区間が範囲内の場合
        if start <= left and end >= right:
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.find_min(start,end,2*idx+1,left,mid)
        vr = self.find_min(start,end,2*idx+2,mid,right)
        return min(vl,vr) # !!!

n, q = map(int, input().split())
segment_tree = SegmentTree(n)

for _ in range(q):
    # print(segment_tree.tree)
    query = list(map(int,input().split()))
    if query[0] == 0:
        _,s,t,x = query
        segment_tree.update(s,t+1,x)
    else:
        _,s,t = query 
        result = segment_tree.find_min(s,t+1)  # yも範囲に含めるので+1
        print(result)

7. DSL_2_G: RSQ and RAQ

遅延評価セグメント木を使います. Fの # !!! が付いている行を変えました. 和を考えるときに(right-left)で幅を考えなければいけないことに注意します.

# DSL_2_G
# 遅延評価セグメント木
class SegmentTree:
    def __init__(self,n):
        # サイズが2のべき乗になるようにする(綺麗な二分木にしたい)
        self.size = 1
        while self.size < n:
            self.size *= 2
        # 0で初期化
        self.tree = [0]*(2*self.size-1) # !!!
        # 子は find のときに lazy を使って一気に更新する(遅延評価)
        self.lazy = [0]*(2*self.size-1) # !!!

    # left ~ right を num を加える 
    def update(self, start, end, num, idx=0, left=0,right=None):
        if right == None:
            right = self.size
            
        # 更新しておく(遅延評価)
        if self.lazy[idx] != 0: # !!!
            self.tree[idx] += self.lazy[idx]*(right-left) # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] += self.lazy[idx] # !!!
                self.lazy[2*idx+2] += self.lazy[idx] # !!!
        self.lazy[idx] = 0 # !!!
            
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return 0 # !!!
        # 区間が範囲内の場合
        if start <= left and end >= right:
            self.tree[idx] += num*(right-left) # !!!
            # 遅延評価
            self.lazy[idx] = 0 # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] += num # !!!
                self.lazy[2*idx+2] += num # !!!
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.update(start,end,num,2*idx+1,left,mid)
        vr = self.update(start,end,num,2*idx+2,mid,right)
        self.tree[idx] = self.tree[2 * idx + 1] + self.tree[2 * idx + 2] # !!!

    def find_sum(self, start, end, idx=0, left=0,right=None):
        if right == None:
            right = self.size
            
        # 更新しておく(遅延評価)
        if self.lazy[idx] != 0: # !!! 
            self.tree[idx] += self.lazy[idx]*(right-left) # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] += self.lazy[idx] # !!!
                self.lazy[2*idx+2] += self.lazy[idx] # !!!
            self.lazy[idx] = 0 # !!!
            
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return 0 # !!!
        # 区間が範囲内の場合
        if start <= left and end >= right:
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.find_sum(start,end,2*idx+1,left,mid)
        vr = self.find_sum(start,end,2*idx+2,mid,right)
        return vl+vr # !!!

n, q = map(int, input().split())
segment_tree = SegmentTree(n)

for _ in range(q):
    query = list(map(int,input().split()))
    if query[0] == 0:
        _,s,t,x = query
        segment_tree.update(s-1,t,x) # 範囲に注意
    else:
        _,s,t = query 
        result = segment_tree.find_sum(s-1,t)  # 範囲に注意
        print(result)

8. DSL_2_H: RMQ and RAQ

遅延評価セグメント木を使います. Fの # !!! が付いている行を変えました.

# DSL_2_H
# 遅延評価セグメント木
class SegmentTree:
    def __init__(self,n):
        # サイズが2のべき乗になるようにする(綺麗な二分木にしたい)
        self.size = 1
        while self.size < n:
            self.size *= 2
        # 0で初期化
        self.tree = [0]*(2*self.size-1) # !!!
        # 子は find のときに lazy を使って一気に更新する(遅延評価)
        self.lazy = [0]*(2*self.size-1) # !!!

    # left ~ right を num を加える 
    def update(self, start, end, num, idx=0, left=0,right=None):
        if right == None:
            right = self.size
            
        # 更新しておく(遅延評価)
        if self.lazy[idx] != 0: # !!!
            self.tree[idx] += self.lazy[idx] # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] += self.lazy[idx] # !!!
                self.lazy[2*idx+2] += self.lazy[idx] # !!!
        self.lazy[idx] = 0 # !!!
            
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return 0 # !!!
        # 区間が範囲内の場合
        if start <= left and end >= right:
            self.tree[idx] += num # !!!
            # 遅延評価
            self.lazy[idx] = 0 # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] += num # !!!
                self.lazy[2*idx+2] += num # !!!
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.update(start,end,num,2*idx+1,left,mid)
        vr = self.update(start,end,num,2*idx+2,mid,right)
        self.tree[idx] = min(self.tree[2 * idx + 1],self.tree[2 * idx + 2]) # !!!

    def find_min(self, start, end, idx=0, left=0,right=None):
        if right == None:
            right = self.size
            
        # 更新しておく(遅延評価)
        if self.lazy[idx] != 0: # !!! 
            self.tree[idx] += self.lazy[idx] # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] += self.lazy[idx] # !!!
                self.lazy[2*idx+2] += self.lazy[idx] # !!!
            self.lazy[idx] = 0 # !!!
            
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return float("inf") # !!!
        # 区間が範囲内の場合
        if start <= left and end >= right:
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.find_min(start,end,2*idx+1,left,mid)
        vr = self.find_min(start,end,2*idx+2,mid,right)
        return min(vl,vr) # !!!

n, q = map(int, input().split())
segment_tree = SegmentTree(n)

for _ in range(q):
    query = list(map(int,input().split()))
    if query[0] == 0:
        _,s,t,x = query
        segment_tree.update(s,t+1,x)
    else:
        _,s,t = query 
        result = segment_tree.find_min(s,t+1)  # yも範囲に含めるので+1
        print(result)

8. DSL_2_I: RSQ and RUQ

遅延評価セグメント木を使います. Fの # !!! が付いている行を変えました.

# DSL_2_I
# 遅延評価セグメント木
class SegmentTree:
    def __init__(self,n):
        # サイズが2のべき乗になるようにする(綺麗な二分木にしたい)
        self.size = 1
        while self.size < n:
            self.size *= 2
        # 0で初期化
        self.tree = [0]*(2*self.size-1) # !!!
        # 子は find のときに lazy を使って一気に更新する(遅延評価)
        self.lazy = [None]*(2*self.size-1) # !!!

    # left ~ right をnumに変える
    def update(self, start, end, num, idx=0, left=0,right=None):
        if right == None:
            right = self.size
            
        # 更新しておく(遅延評価)
        if self.lazy[idx] != None: # !!!
            self.tree[idx] = self.lazy[idx]*(right-left) # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] = self.lazy[idx]
                self.lazy[2*idx+2] = self.lazy[idx]
        self.lazy[idx] = None # !!!
            
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return 0 # !!!
        # 区間が範囲内の場合
        if start <= left and end >= right:
            self.tree[idx] = num*(right-left) # !!!
            # 遅延評価
            self.lazy[idx] = None # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] = num # !!!
                self.lazy[2*idx+2] = num # !!!
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.update(start,end,num,2*idx+1,left,mid)
        vr = self.update(start,end,num,2*idx+2,mid,right)
        self.tree[idx] = self.tree[2 * idx + 1] + self.tree[2 * idx + 2] # !!!

    def find_sum(self, start, end, idx=0, left=0,right=None):
        if right == None:
            right = self.size
            
        # 更新しておく(遅延評価)
        if self.lazy[idx] != None: # !!! 
            self.tree[idx] = self.lazy[idx]*(right-left) # !!!
            if 2*idx+2 < len(self.tree):
                self.lazy[2*idx+1] = self.lazy[idx] # !!!
                self.lazy[2*idx+2] = self.lazy[idx] # !!!
            self.lazy[idx] = None # !!!
            
        # 区間が範囲外の場合
        if end <= left or start >= right:
            return 0 # !!!
        # 区間が範囲内の場合
        if start <= left and end >= right:
            return self.tree[idx]

        # 一部が含まれている場合は区間を二分割して調べる
        mid = (left+right)//2
        vl = self.find_sum(start,end,2*idx+1,left,mid)
        vr = self.find_sum(start,end,2*idx+2,mid,right)
        return vl+vr # !!!

n, q = map(int, input().split())
segment_tree = SegmentTree(n)

for _ in range(q):
    # print(segment_tree.tree)
    query = list(map(int,input().split()))
    if query[0] == 0:
        _,s,t,x = query
        segment_tree.update(s,t+1,x)
    else:
        _,s,t = query 
        result = segment_tree.find_sum(s,t+1)  # yも範囲に含めるので+1
        print(result)

<- 前へ戻る 【目次に戻る】 次へ進む ->