トップページ -> AOJの解答例 -> ALDS1_14の解答例

ALDS1_14の解答例(Python)

文字列の検索に関する問題です. アルゴリズムについて詳しく知りたい方は参考文献などをご覧ください. C,Dに関しては解き方がよく分からなかったので中途半端な解答になっています.

1. ALDS1_14_A: Naive String Search

与えられた文字列Tの中に文字列Pが含まれる場合 その位置を出力する問題です. 文字列が短いので問題名通り単純に先頭から調べました.

# ALDS1_14_A
T = input()
P = input()

for i in range(len(T)-len(P)+1):
    if T[i:i+len(P)] == P:
        print(i)

2. ALDS1_14_B: String Search

Aの文字列の長さが1000000に伸びたものです. Aと同じコードで通ります.

# ALDS1_14_B
T = input()
P = input()

for i in range(len(T)-len(P)+1):
    if T[i:i+len(P)] == P:
        print(i)

3. ALDS1_14_C: Pattern Search

縦 H × 横 W の文字フィールドの中から、縦 R × 横 C の文字パターンを探しだす問題です. 色々考えても,計算量が膨大になってしまい正攻法で解くことができませんでした. 提出されている解答を見るとどうやらローリングハッシュを二次元に拡張して解けるそうですが,よく分かりませんでした. 色々試した結果,Acceptはされましたがスマートな解答ではありません.

参考サイト

各行に対してローリングハッシュが使えるように下準備をした後に縦方向にR行のハッシュ値が一致しているかを調べます. この際に1行ずつずらすと効率が悪いのでBM法の考え方を利用して,一致しないと分かる場所はスキップするようにします. H×WのパターンからR×Cのパターンを見つけるために最悪で(H-R)×(W-C)回一致するかを判定する必要があります. 一度の判定に縦方向にR回調べる必要があるので計算量は(H-R)×(W-C)×Rです. R=1000, C=1などの場合はそのまま検索するよりもRとCを入れ替えたほうが少なくなるので,RがCより小さい場合は行と列を入れ替えて検索を行うことにすれば, 計算量は(H-R)×(W-C)×min(R,C)になります. 問題の設定の範囲内でH,W=1000 R,C=300で計算量が1.47×10^8となってしまうためこれでは全然間に合いません. しかし,計算量が最悪に(近く)なる場合はBM法によるスキップが(ほとんど)起こらない場合のみ(極端な例は0で埋め尽くされ全てが一致するパターン)です. Pに2つ以上のハッシュ値が含まれていれば最低でもBM法で1つ飛ばしで検索できるため,この場合に限っては7×10^7程度で済みます. Pが1つの文字しか含まない場合は別の処理をしました.

# 参考 https://daeudaeu.com/c-bm-search/#i-6
# ----------------------------ここから---------------------------------- #
def makeTable1(pattern, pattern_len):
    table1 = {}
    for pos in range(pattern_len):
        table1[pattern[pos]] = pattern_len - 1 - pos
    return table1

def makeTable2(pattern, pattern_len):
    table2 = [-1] * pattern_len

    for tail_pos in range(pattern_len - 1):
        eq_len = 0
        while eq_len < tail_pos and pattern[tail_pos - eq_len] == pattern[pattern_len - 1 - eq_len]:
            eq_len += 1

        if eq_len == 0:
            continue

        if pattern[tail_pos - eq_len] != pattern[pattern_len - 1 - eq_len]:
            table2[pattern_len - 1 - eq_len] = pattern_len - 1 - tail_pos + eq_len

    tail_pos = -1
    for pattern_pos in range(pattern_len - 2, -1, -1):
        eq_len = pattern_len - 1 - pattern_pos
        i = 0
        while i < eq_len and pattern[i] == pattern[pattern_pos + 1 + i]:
            i += 1

        if eq_len == i:
            tail_pos = eq_len - 1

        if table2[pattern_pos] == -1:
            if tail_pos != -1:
                table2[pattern_pos] = pattern_len - 1 - tail_pos + pattern_len - 1 - pattern_pos

    for pattern_pos in range(pattern_len - 2, -1, -1):
        if table2[pattern_pos] == -1:
            table2[pattern_pos] = pattern_len + (pattern_len - 1 - pattern_pos)

    return table2

def bmSearch(text, pattern,col):
    text_len = len(text)
    pattern_len = len(pattern)

    table1 = makeTable1(pattern, pattern_len)
    table2 = makeTable2(pattern, pattern_len)

    pattern_pos = pattern_len - 1
    text_pos = pattern_len - 1
    
    pos_list = []

    while text_pos < text_len:
        if text[text_pos] == pattern[pattern_pos]:
            if pattern_pos == 0:
                pos_list.append([text_pos,col])
                text_pos += len(pattern)
                pattern_pos = pattern_len - 1
                continue
            text_pos -= 1
            pattern_pos -= 1
        else:
            if text[text_pos] in table1.keys():
                t1 = table1[text[text_pos]]
            else:
                t1 = len(pattern)
                
            if t1 > table2[pattern_pos]:
                text_pos += t1
            else:
                text_pos += table2[pattern_pos]
            pattern_pos = pattern_len - 1

    return pos_list

# --------------------------------ここまで------------------------------------- #

H,W = map(int,input().split(" "))
str_set = set()
T = []
for _ in range(H):
    s = input()
    str_set = str_set | set(s)
    T.append(s)
    
R,C = map(int,input().split(" "))
P = []
str_set2 = set() # Pに含まれる文字の種類を把握する
for _ in range(R):
    s = input()
    str_set = str_set | set(s)
    str_set2 = str_set2 | set(s)
    P.append(s)
    
def get_value(row,start,end):
    return (rolling_hash[row][end]-rolling_hash[row][start]*num_dict[end-start]) % mod
    
# 1つの文字しか含まない極端な場合は別の処理をする
if len(set(str_set2))==1:
    # 1つの文字以外が含まれている場合 そこは一致しない
    key = list(str_set2)[0]
    ng_list = []
    for i in range(H):
        for j in range(W):
            if T[i][j] != key:
                for k in range(max(i+1-R,0),i+1):
                    for l in range(max(j+1-C,0),j+1):
                        ng_list.append(str(k)+","+str(l))

    ng_set = set(ng_list)
    for i in range(H-R+1):
        for j in range(W-C+1):
            if str(i)+","+str(j) not in ng_set:
                print(i,j)
                
else:
    inversed = False
    
    # 行と列を比較し,列の方が多ければ入れ替える
    if R<=C:
        pass
    else:
        T2 = ["" for i in range(W)]
        for t in T:
            for i in range(W):
                T2[i] += t[i]

        P2 = ["" for i in range(C)]
        for p in P:
            for i in range(C):
                P2[i] += p[i]

        H,W = W,H
        R,C = C,R
        T = T2
        P = P2
        inversed = True


    # ローリングハッシュ用の変数
    base_num = len(str_set)+1
    mod = 2**32-1
    num_dict = {i:(base_num**i)%mod for i in range(W+1)}
    str_dict = {s:i+1 for i,s in enumerate(str_set)} # ordの代わりに使う辞書

    # H×Wの文字列をローリングハッシュに
    rolling_hash = [[0 for _ in range(W+1)] for _ in range(H)]
    for i in range(H):
        hash_value = 0
        for j in range(W):
            num = num_dict[j]
            hash_value = str_dict[T[i][j]]
            rolling_hash[i][j+1] = (rolling_hash[i][j]*base_num + hash_value) % mod

    # R×Cの文字列ををハッシュ化
    hash_value0 = [0 for i in range(R)]
    for i in range(R):
        for j in range(C):
            hash_value0[i] += (str_dict[P[i][j]] * num_dict[C-j-1])

        hash_value0[i] = hash_value0[i] % mod

    # 各列のどこで一致するかをBM法で調べる
    result = []
    pattern = hash_value0
    for col in range(W-C+1):
        text =[get_value(row,col,col+C) for row in range(H)]
        result += bmSearch(text, pattern,col)

    # 答えを行が小さい順に並び替える
    if inversed:
        result.sort(key=lambda x:x[1])
    else:
        result.sort()

    for ans in result:
        if inversed:
            print(ans[1],ans[0])
        else:
            print(*ans)

4. ALDS1_14_D: Multiple String Matching

与えられた文字列Tの中に文字列P_i(i<=10000)が含まれるかどうかを判定する問題です. 文字列の長さが長い上に判定すべき文字列も多いので,なにか特別なことをしないと通らないのだろうと思っていたのですが, 1000文字区切りの文字列をソートして二分探索するだけで間に合います. 真面目(?)に解きたい場合はSuffix Arrayを用いてSA-IS法というものを使うと高速で文字列の検索が行えるようですが,よく分かりませんでした.

# ALDS1_14_D
from bisect import bisect_left
S = input()
n = int(input())

S_list = [S[i:1000+i] for i in range(len(S))]
S_list.sort() # 二分探索が使えるようにソートする

for _ in range(n):
    s = input() 
    idx = bisect_left(S_list,s) # 二分探索で文字列が含まれるかもしれない位置を調べる
    if idx == len(S_list):
        print(0)
    elif s not in S_list[idx]:
        print(0)
    else:
        print(1)

<- 前へ戻る 【目次に戻る】 次へ進む ->