読者です 読者をやめる 読者になる 読者になる

ヒープソート

はじめに

昔勉強したけどすっかり忘れてたヒープソート
思い出していきましょう。

C言語による最新アルゴリズム事典 (ソフトウェアテクノロジー)

C言語による最新アルゴリズム事典 (ソフトウェアテクノロジー)

ヒープソートは数あるソートアルゴリズムの一つで、 特徴としては以下のことがいえます。

  • 安定ではない(同順位のものの順序関係が保たれるとは限らない)
  • クイックソートより約2倍くらい遅い
  • 最悪の場合でも計算量は  O(n \log{n})
  • 制御用データ(一時配列等)を必要としない

ヒープソートは、
1. ソート対象の配列をヒープ構造に変換(ヒープ構造については後述)
2. ヒープ構造の特性を利用してソートを完了
という2段階のソート手法です。

ヒープ構造

ヒープソートで言うヒープ構造とは、2分木を使った構造で、
「子要素は親要素より常に大きいか等しい(または常に小さいか等しい)」
という制約を持つ木構造です。

f:id:hades-netherworld-service:20160821171152p:plain

上の図で言う添字1の接点が($ a_1 $とする)、5~9の点が($ a_5 ~ a_9 $とする)と呼ばれます。 ヒープ構造では、根に近いほど値が大きい(または小さい)という条件を満たします。

ここで、接点$ a_i $の子は$ a_{2i} $と$ a_{2i + 1} $で、逆に親は$ a_{\lfloor i / 2 \rfloor} $と出来ます。 このとき、ヒープ構造の条件としては$ \eqref{eq:heap} $とすることができます。 $$ \begin{eqnarray} \begin{split} a_i \leqq a_{\lfloor i / 2 \rfloor} (i = 2, 3, 4, \dots, n) \\ (または、a_i \geqq a_{\lfloor i / 2 \rfloor}) \end{split} \label{eq:heap} \end{eqnarray} $$

実際に昇順の(降順の)ヒープ構造に変換するためには、
要素$ a_i $とその子要素 $ a_{2i} $と$ a_{2i + 1} $について、子要素のうち大きい(小さい)方($ a_j $とする)と$ a_i $を比較し、$ a_i $のほうが小さい(大きい)場合、 $ a_i $と$ a_j $を交換します。
そして位置を交換した場合、交換した場所で同様の比較処理をしていき、位置の交換が起こらなくなるか子要素がなくなるまで繰り返します。
この処理を ではない 接点$ a_{\lfloor n / 2 \rfloor}, \dots, a_1 $について行うと、 どの要素も子より大きい(小さい) という関係ができあがり、$ a_1 $が最大値(最小値)となります。

ヒープ構造はその制約上、が最大値(最小値)になるので、その性質を利用してソートを行っていくわけです)

そしてヒープソート

ヒープ構造が得られたらソートしていきます。
まず、ヒープ構造のの要素を配列の最後尾と交換します。
そして、第1要素から第n-1要素目までを対象とした部分木を考えると、第1要素以外はヒープ構造の制約を満たすヒープ構造まで後一歩の状態の木が得られます。
この部分木の第1要素に対してヒープ構造化の処理を行うと、部分木に関するヒープ構造が出来上がります。
つまり、第1要素から第n-1要素までの最大値が部分木の(第1要素)に来ていることがわかります。
あとは処理を部分木の要素数が1になるまで繰り返せば、ソートが完了しています。

コードにしてみる

上記のアルゴリズムをコードにすると、以下のようになります。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import math

def makeHeap(vals, asc=True):
  n = len(vals)
  for i in range(int(n / 2))[::-1]:
    heap(vals, i, n - 1, asc)

def sort(vals, asc=True):
  matrix = vals.copy()
  makeHeap(matrix, asc)
  for n in range(1, len(vals))[::-1]:
    i = 0
    swap(matrix, i, n)
    heap(matrix, i, n, asc)
  return matrix

def heap(vals, start, end, asc=True):
  i = start
  n = end
  x = vals[i]
  while 2 * i + 1 < n:
    j = 2 * i + 1
    if j + 1 < n and ((asc and vals[j] < vals[j + 1]) or (not asc and vals[j] > vals[j + 1])):
      j += 1
    if (asc and x >= vals[j]) or (not asc and x <= vals[j]):
      break
    vals[i] = vals[j]
    i = j
  vals[i] = x

def swap(vals, i, j):
  vals[j] = vals[i] - vals[j]
  vals[i] -= vals[j]
  vals[j] += vals[i]

昇順降順をフラグで切り替えるために、途中のifが若干見づらくなっていますが、動きます。 実際に実行してみると、

a = np.random.randint(0, 100, 10)
print(a)
print(sort(a))
print(sort(a, asc=False))
>>>[61 31 80 91 50 64 47  4 69 11]
>>>[ 4 11 31 47 50 61 64 69 80 91]
>>>[91 80 69 64 61 50 47 31 11  4]

計算量

ヒープソートは、大きく分けて「最初のヒープ構造化」と「木の根を最後尾と交換し、最後尾の要素以外で再度ヒープ構造化」という2つの処理を行っていきます。
後半の処理は、木の根に持ってきた要素に対してヒープ構造化の処理を行いますが、その際、最悪でも木の高さ数分の比較しか起こりませんので、1回のループで高々$ \log{n} $回の計算量となります。
それが$ \log{n - 1}, \log{n - 2}, \dots, \log{1} $と計算量が減っていきますので、 $$ \begin{eqnarray} \begin{split} & \log{n} + \log{n - 1} + \dots + \log{2} + \log{1} \\ &= \bigl(\log{n} + \log{1}\bigr) \frac{n}{2} \\ &= \frac{1}{2} n \log{n} \end{split} \label{eq:order} \end{eqnarray} $$ より、計算量は$ O(n \log{n}) $となります。
また、最初のヒープ構造化に関してですが、この処理に関しては、比較回数は明らかに後半の処理よりも少ないため、計算量を求める上では無視されます。
よって、ヒープソートの計算量は$ O(n \log{n}) $ということができます。

おわりに

上のコードでは、ソート元のデータを破壊しないように関数内でコピーしているので、ここがオーバーヘッドになりますが、 それでも10万個のデータソートでも数秒で終わります。
(後で破壊的な関数にして試してみましたが、あまり変化が見られませんでしたw)

プログラムで書く際、配列の先頭は0スタートなので、子要素のindexの計算を間違えないようにご注意くださいmm