Haskellでトリボナッチ数列を実装する

背景

Atcoderのこの問題をHaskellでACしたい。

atcoder.jp

問題概要

トリボナッチ数列は、3つ前までの数字を足したもの。

 a_{1} = 0, a_{2} = 0, a_{3} = 1
 a_{n} = a_{n-1} + a_{n-2} + a_{n-3}

この数列の 第n項 mod 10007  (1≦n≦10^{6}) を求める。

愚直に実装

まずは愚直に再帰で実装する。
(問題だと数列が第1項から始まっているが、実装は第0項からにしている)

tribonacchi :: Int -> Int
tribonacchi n
  | n <= 1 = 0
  | n == 2 = 1
  | otherwise = (tribonacchi (n-1) + tribonacchi (n-2) + tribonacchi (n-3)) `mod` 10007

動かしてみるが、案の定めちゃくちゃ遅い。。
このままでは確実にTLEとなってしまうので、改良する必要がある。

Prelude> :set +s
Prelude> tribonacci 30
1468
(39.10 secs, 16,806,765,608 bytes)

リストを利用して関数の呼び出し回数を減らす

愚直な実装では、同じ引数でtribonacci関数が何度も呼び出されてしまう。

tribonacci 5
= tribonacci 2 + tribonacci 3 + tribonacci 4
= 1 + (tribonacci 0 + tribonacci 1 + tribonacci 2) + (tribonacci 1 + tribonacci 2 + tribonacci 3)
= 1 + (0 + 0 + 1) + (0 + 1 + (tribonacci 0 + tribonacci 1 + tribonacci 2))
= 1 + (0 + 0 + 1) + (0 + 1 + (0 + 0 + 1))
= 4

tribonacci関数を何度も呼び出さなくていいように、計算結果をリストにしておき、tribonacci関数を再度計算するコストを減らすことで高速化を図る。

tribonacci :: Int -> Int
tribonacci n
  | n <= 1 = 0
  | n == 2 = 1
  | otherwise = (trib !! (n-1) + trib !! (n-2) + trib !! (n-3)) `mod` 10007

trib :: [Int]
trib = map tribonacci [0..]

だいぶ速くはなったが、 n≦10^{6} の条件をクリアするためにはまだ高速化が必要。

Prelude> tribonacci 100
8082
(0.00 secs, 116,080 bytes)

Prelude> tribonacci 100000
2580
(61.92 secs, 64,960,496 bytes)

zipWithを使う

これ以上は自力ではどうしようもなくなってしまったので、[ Haskell フィボナッチ数列 ] で検索すると、以下のようなコードが見つかった。

fib = 0 : 1 : zipWith (+) fib (tail fib)

このコードでは、フィボナッチ数列の最初の2つの要素に、それをずらした数列2つの要素を足し合わせたリストを連結することで、フィボナッチ数列を生成している。

          [0, 1, 1, 2, 3, ..]
           +  +  +  +  +
          [1, 1, 2, 3, 5, ..]
           ↓  ↓  ↓  ↓  ↓
[0, 1] ++ [1, 2, 3, 5, 8, ..]

これと同じように、トリボナッチ数列の最初の3つの要素に、それをずらした数列3つの要素を足し合わせたリストを連結することで、トリボナッチ数列を生成することができるはず。

             [0, 0, 1, 1, 2, ..]
              +  +  +  +  +
             [0, 1, 1, 2, 4, ..]
              +  +  +  +  +
             [1, 1, 2, 4, 7, ..]
              ↓  ↓  ↓  ↓  ↓
[0, 0, 1] ++ [1, 2, 4, 7, 13, ..]

3つのリストを扱うため、zipWith3 関数を利用して実装する。

tribonacci :: Int -> Int
tribonacci n = trib !! n

trib :: [Int]
trib = 0 : 0 : 1 : zipWith3 (\x y z -> mod (x+y+z) 10007) trib (tail trib) (tail (tail trib))

 10^{6} で試してみる。これなら間に合いそう。

Prelude> tribonacci 1000000
2576
(1.12 secs, 321,735,104 bytes)

ACできたコード

入出力部分を合わせて提出して無事AC。
(この問題では数列が a_{1}から始まることに注意する)

tribonacci :: Int -> Int
tribonacci n = trib !! (n-1)

trib :: [Int]
trib = 0 : 0 : 1 : zipWith3 (\x y z -> mod (x+y+z) 10007) trib (tail trib) (tail (tail trib))

main = do
  n <- readLn
  print $ tribonacci n