[プログラミング]分割数

数論には「分割数」という概念があって、与えられた任意の自然数NをN以下の自然数の和で表現したものを言うんだそうだ。
例えば、4の分割数は下記の5種類となる。

4 = 4 
  = 3 + 1 
  = 2 + 2 
  = 2 + 1 + 1 
  = 1 + 1 + 1 + 1

ここで自然数nを分割する関数をp(n)とした場合、突っ込みを承知で我流の記法を用いて書くと、次のような再帰関数として定義できる。

p(n) = { n,
         concat(n - 1, p(1)),
         concat(n - 2, p(2)),
         concat(n - 3, p(3)),
         ...
         concat(2, p(n - 2)),
         concat(1, p(n - 1)) }

ちなみにconcat(a, S)は集合Sに要素aを追加した集合を返す関数。
ちなみにconcat(a, S)は自然数の並びSの先頭に要素aを追加した並びを返す関数。ただし、Sの任意の先頭の要素bが a < b の場合、空集合を返す。(上手い説明の仕方が分からん...)このp(n)、Schemeで書くとこんな感じになる。

(use gauche.collection)
(use srfi-1)

(define (partition-f n)
  (if (<= n 0)
      '()
      (cons (list n)
	    (fold (lambda (k s1)
		    (append (fold (lambda (ls s2)
				    (if (>= k (car ls))
					(cons (cons k ls) s2)
					s2))
				  '()
				  (partition-f (- n k)))
			    s1))
		  '()
		  (iota (- n 1) 1)))))

見ての通り、p(n)の結果はp(1), p(2) ... p(n - 1)によって決まる。なので、正直に実行すると、Nが大きくなるに従って同じ計算を何回も繰り返すことになり非常に遅い。そこで、計算結果をメモ化すれば計算量を大幅に削減することができるんじゃないかと思い実験してみた。次が上のpartition-fを最適化したfast-partition-f。

(define (memoize proc)
  (let ((cache (make-hash-table 'equal?)))
    (lambda (arg . rest)
      (let ((args (cons arg rest)))
	(let ((computed-result (hash-table-get cache args #f)))
	  (or computed-result
	      (let ((result (apply proc args)))
		(hash-table-put! cache args result)
		result)))))))

(define fast-partition-f
  (memoize
   (lambda (n)
     (if (<= n 0)
	 '()
	 (cons (list n)
	       (fold (lambda (k s1)
		       (append (fold (lambda (ls s2)
				       (if (>= k (car ls))
					   (cons (cons k ls) s2)
					   s2))
				     '()
				     (fast-partition-f (- n k)))
			       s1))
		     '()
		     (iota (- n 1) 1)))))))

(define (main args)
  (let ((n (if (pair? (cdr args))
	       (string->number (cadr args))
	       10)))
    (begin
      (time (partition-f n))
      (time (fast-partition-f n)))))

N < 10 だと両者の差はあまりないけど、N >= 10 あたりから計算時間が大きく開いてくる。N = 20 になると結果は一目瞭然。1000倍以上の速度差が出た。ちょっと驚き。

% gosh partition-numbers.scm 20
;(time (partition-f n))
; real  14.433
; user  14.060
; sys    0.210
;(time (fast-partition-f n))
; real   0.011
; user   0.010
; sys    0.000

以下、余談。

fast-partition-f の戻り値の表示がなぜかおかしい。ハッシュテーブルにキャッシュされた値が関係してるっぽいのだけど、はっきりとした原因は分からず。

gosh> (fast-partition-f 3)
CALL fast-partition-f 3
 CALL fast-partition-f 2
  CALL fast-partition-f 1
  RETN fast-partition-f ((1))
 RETN fast-partition-f ((2) (1 1))
 CALL fast-partition-f 1
 RETN fast-partition-f ((1))
RETN fast-partition-f ((3) (2 1) (1 1 1))
((3) (2 . #0=(1)) (1 1 . #0#))

memoize に渡す関数って、再帰呼び出し部分が違うだけで元の関数(partition-f)とほとんど変わらないんだけど重複を回避する方法ないかなあ。もしかしてY-Combinatorって、こういう場合に使うのか?今までそのありがたみがよく分からなかったけど。