TensorFlowの手書き数字認識チュートリアルからざっくりディープラーニングを勉強してみました

手書き数字画像の認識を行うTensorFlowのチュートリアルであるMNIST For ML Beginners[日本語訳]およびDeep MNIST for Experts[日本語訳]を通してディープラーニングについて勉強したので理解した内容をまとめてみました。的な話です。

f:id:nihma:20160702155302p:plain

TensorFlowとディープラーニングの概要

TensorFlow

TensorFlowとはGoogleが公開しているディープラーニングに対応した機械学習のライブラリです。Tensor(多次元配列/行列)のFlow(計算処理)をグラフ構造で定義し、それを元に演算を行います。演算処理が抽象化されているので演算デバイス変更などの対応が比較的容易と思われます。

f:id:nihma:20160702155622p:plain

実装イメージ

TensorFlowでy=x2+bにx=2,b=3を指定し計算する例です。変数はplaceholderとして定義しSessionに演算を指示するタイミングでfeed_dictで実際の値を指定します。

f:id:nihma:20160702155756p:plain

ディープラーニング

ディープラーニングとは、中間層を多層化して深く(Deep)したニューラルネットワークの機械学習(Learning)の事です。中間層が多いほど認識の精度が上がりますが学習が困難になります。学習の手法が考案され、ベンチマークテストで高い性能を示した事から注目されるようになりました。

f:id:nihma:20160702155924p:plain

チュートリアル1:MNIST For ML Beginners

まずはMNIST For ML Beginners[日本語訳]です。強引にまとめるとやりたいのことは次な感じと思います。

f:id:nihma:20160702160524p:plain

なんのこっちゃ??という感じです。

前提知識のキャッチアップ

なので、まずは前提知識のキャッチアップです。(ざっくりイメージだけでも)

f:id:nihma:20160702161132p:plain

単純パーセプトロン

単純パーセプトロンとは、入力層と出力層だけの順伝播型ニューラルネットワークの事です。中間層が無い(深く無い)ネットワークですので、このチュートリアルディープラーニングではありません。(ディープラーニングは次のチュートリアルにお預けです。)

f:id:nihma:20160702161738p:plain

単純パーセプトロンの出力層の値(y)は、結合している入力層の値(x)と重み(W)を掛けた値の総和にバイアス/閾値(b)を加算し活性化関数(f)を適用する事で求めます。

f:id:nihma:20160702161925p:plain

MNIST

MNISTとは画像認識アルゴリズムベンチマークに使われる手書き数字画像データセットです。このデータセットを認識できるように単純パーセプトロンを教師あり学習します。

f:id:nihma:20160702162105p:plain

MNISTには画像と画像が表す数値の組が含まれます。各画像のサイズは28 x 28 ピクセルで各ピクセル値は 0 (白) ~ 255 (黒) です。60000枚の訓練データと10000枚のテストデータに分かれています。

今回は、入力層の値を「ベクトルに変換した画像(28 x 28 = 784次元)」に、出力層の値を「画像が表す数値次元目のみ1で他が0となる10次元ベクトル」として進めます。

f:id:nihma:20160702162314p:plain

ソフトマックス関数

ソフトマックス関数とは「複数ある事象」のうち「ある事象」が起きる確率を求める関数です。多クラス分類問題の場合、出力層の結果を確率分布にしたいため活性化関数として用いられることが多いです。

f:id:nihma:20160702162530p:plain

ソフトマックス関数は、「ある事象」の場合の数をexp(場合の数)として計算するため、場合の数が大きい事象ほど確率を高く際立たせる事ができます。

出力層の値は、出力層全体を事象全体としてソフトマックス関数で求める事により確率値となります。

f:id:nihma:20160702195941p:plain

交差エントロピー

交差エントロピーとは、確率分布間のエントロピーの距離の事です。単純パーセプトロンが導いた結果と正解との差を求めるための損失関数として用います。損失関数は教師あり学習で用います。

f:id:nihma:20160702162945p:plain

この関数は出力層で予測した確率分布と正解の分布が遠ければ差が大きくなります。

単純パーセプトロンの学習では、この損失関数の結果をペナルティとして小さくなるように重み/バイアスを調整します。

f:id:nihma:20160702163217p:plain

確率的勾配降下法

確率的勾配降下法とは、ある関数の極小値を算出する手法です。単純パーセプトロンの重みやバイアスを調整する教師あり学習のために用います。

f:id:nihma:20160702163335p:plain

勾配降下法は、学習データに対するペナルティの総和(E)が小さくなる方向に重み(W)を更新し徐々に理想のWへと近づけていきます。動かす方向は傾き(ΔE/ΔW)が負となる方向になります。傾きに掛け合わせる学習係数(η)により動かす大きさが決まります。(バイアスも同様)

Eを全ての学習データを対象とせずに一部のデータに限定して計算量を抑えた方式確率的勾配降下法です。

f:id:nihma:20160702163536p:plain

TensorFlowで動かしてみる

前提知識を(なんとなく)キャッチアップしたところでいよいよTensorFlowで動かしてみます!

f:id:nihma:20160702163742p:plain

計算グラフ

まず、単純パーセプトロンをTensorFlowで扱えるように計算グラフで表現します。ソフトマックス関数や交差エントロピー確率的勾配降下法などの計算はTensorFlowが用意している関数で行う事ができます。

f:id:nihma:20160702200116p:plain

学習した結果を評価する部分です。

f:id:nihma:20160702200220p:plain

実装内容と実行結果

TensorFlowで数字画像の認識をします。まずはパーセプトロンの計算グラフを構築し初期化します。

f:id:nihma:20160702164151p:plain

次にMNISTデータセットを教師あり学習します。訓練データの100サンプルを1バッチとして1000バッチ分の学習を行います。

f:id:nihma:20160702164307p:plain

学習したモデルにテストデータを与えて結果を評価します。

f:id:nihma:20160702164448p:plain
f:id:nihma:20160702200405p:plain

精度は約91%となりあまり良くないです。(今の最高水準は99.79%くらい?)

f:id:nihma:20160702164639p:plain

コードの全容

下記のとおりです。

import tensorflow as tf

x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,10])
cross_entropy = - tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

print sess.run(accuracy,
           feed_dict={x: mnist.test.images, y_: mnist.test.labels})

チュートリアル2:Deep MNIST for Experts

次はDeep MNIST for Experts[日本語訳]です。強引にまとめるとやりたいのことは次な感じと思います。

f:id:nihma:20160702165006p:plain

良くワカリマセン

前提知識のキャッチアップ

なので、まずは前提知識のキャッチアップです。(ざっくりイメージだけでも)

f:id:nihma:20160702165153p:plain

畳み込みニューラルネットワーク

畳込みニューラルネットワークとは、中間層として畳込み層とプーリング層を有する多層の順伝播型ニューラルネットワークの事です。画像認識の分野でよく用いられます。中間層を持った深いネットワークですのでようやくディープラーニングです。

f:id:nihma:20160702170211p:plain

畳み込み層

畳み込み層とは、入力(x)に対して重みフィルタ(W)をスライドさせながら適用(畳み込み)した結果の集まりである特徴マップ(c)を抽出する層です。cはxの局所的な部分を抽象化した特徴量です。

結果に適用する活性化関数にはReLUを使います。

f:id:nihma:20160702170338p:plain

プーリング層

プーリング層とは、畳み込み層の結果である特徴マップ(c)を縮小する層です。局所的な部分の特徴を維持するような縮小を行うことにより位置変更に対する結果の変化を(若干ですが)抑えることができます。

最大値のみを取り出し縮小する「最大プーリング」などがあります。

f:id:nihma:20160702170533p:plain

全結合層

全結合層とは、隣接の層とユニットが全結合した層です。畳み込みニューラルネットワークでは畳み込み層やプーリング層の結果である2次元の特徴マップを1次元に展開します。

活性化関数にはReLUを使います。

f:id:nihma:20160702170648p:plain

Adam

Adamとは、確率的勾配降下法の更新量を調整し学習の収束性能を高めた手法です。

確率的勾配降下法の(傾き)の部分を(傾きの平均値)/(傾きの標準偏差)とすることにより、始めの更新量は大きく学習が速く進み、理想の値に近づくほど更新量が減少し学習を収束させられる性質があります。

f:id:nihma:20160702170856p:plain

誤差逆伝播

誤差逆伝播法とは、多層ニューラルネットワークにおいて各層の勾配を効率的に求める手法です。出力層から入力層に向かって誤差(ペナルティ)を逆伝播させながら求めていきます。

この勾配を用いて確率的勾配降下法やAdamによる重みの調整を行います。

f:id:nihma:20160702171055p:plain

ReLU

ReLUとは、入力が0以下ならば0を出力し入力が0より大きいならば入力と同じ値を出力する非線形関数です。単純で計算量が小さく、微分すると活性状態なら1となるので誤差逆伝播法で活性状態の勾配が消えない性質がある事から活性化関数としてよく用いられます。

f:id:nihma:20160702200546p:plain

ドロップアウト

ドロップアウトとは、確率的勾配降下法等で多層ネットワークのユニットを確率的に選別して学習する手法です。ユニットの選別確率pを決めておき重み更新のたびに対象のユニットをランダムで選出します。

学習対象に過剰適合する過学習の状態を抑止する効果があります。

f:id:nihma:20160702171342p:plain

TensorFlowで動かしてみる

前提知識を(なんとなく)キャッチアップしたところでいよいよTensorFlowで動かしてみます!

f:id:nihma:20160702171451p:plain

計算グラフ

畳み込みニューラルネットワークを計算グラフで表現します。

f:id:nihma:20160702200703p:plain

ReLUやドロップアウト、Adamや誤差逆伝播法などの計算はTensorFlowが用意している関数で行う事ができます。

f:id:nihma:20160702200757p:plain

学習した結果を評価する部分です。(チュートリアル1と同じです)

f:id:nihma:20160702200857p:plain

実装内容と実行結果

まず共通で使う処理を関数にしておきます。重みとバイアスの初期化、畳み込みとプーリングです。

f:id:nihma:20160702171941p:plain

計算グラフを定義していきます。

f:id:nihma:20160702172100p:plain

引き続き、計算グラフを定義していきます。

f:id:nihma:20160702172159p:plain

各変数のサイズの変化は下図のイメージです。

f:id:nihma:20160702201005p:plain

次にMNIST訓練データの50サンプルを1バッチとして20000バッチ分の学習を行います。ドロップアウト率は0.5にしています。

f:id:nihma:20160702172446p:plain

学習したモデルにテストデータで与えて結果を評価します。精度は約99.2%となりまずまずです。(今の最高水準の99.79%には及ばないけどチュートリアル1の約91%に比べたら)

f:id:nihma:20160702172813p:plain

コードの全容

下記のとおりです。

import tensorflow as tf

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

sess = tf.InteractiveSession()

x = tf.placeholder("float", shape=[None, 784])
x_image = tf.reshape(x, [-1,28,28,1])
W_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
W_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
W_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
W_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
y_ = tf.placeholder("float", shape=[None, 10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

sess.run(tf.initialize_all_variables())

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        feed_dict = {x:batch[0],y_:batch[1],keep_prob:1.0}
        train_accuracy = accuracy.eval(feed_dict=feed_dict)
        print("step %d, training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})

feed_dict={x:mnist.test.images, y_: mnist.test.labels,keep_prob:1.0}
print("test accuracy %g" % accuracy.eval(feed_dict=feed_dict))

良く分からなかったところ

Deep MNIST for Experts[日本語訳]の重みの初期化の説明が良くわかりませんでした。なぜ微量のノイズで初期化するのか?対称性の破れとは何なのか?誰か詳しい人...

f:id:nihma:20160702173342p:plain

参考情報

次の情報を参考にさせていただきました。

情報源 メモ
TensorFlowを算数で理解する y=x2+bの例
What is the class of this image ? MNISTの最高水準の認識率
誤差逆伝播法のノート 勾配消失問題が分かりやすかったです
30分でわかるAdam ざっくり理解できました
初めてのディープラーニング --オープンソース"Caffe"による演習付き とっかかりに良かったです
深層学習 (機械学習プロフェッショナルシリーズ) 数式分かりやすかったです
イラストで学ぶ ディープラーニング (KS情報科学専門書) 深層学習 (機械学習プロフェッショナルシリーズ) で分からなくなった時に参照しました
深層学習 Deep Learning (監修:人工知能学会) 深層学習 (機械学習プロフェッショナルシリーズ) で分からなくなった時に参照しました

まとめ

TensorFlowのMNISTを認識するチュートリアルを通して単純なパーセプトロンと畳み込みニューラルネットワークの概要について調べてディープラーニングに対する理解を少しだけ深める事ができました。

とはいえ、まだまだ数式等の理解が足りないところも多く仕組みを完全に分かるところまでには達する事ができませんでした。もう少し学習が必要に思います。理解が誤っているところも多そうです。

今後は、画像認識だけでなく自然言語処理のためのディープラーニング技術などについても調査して何かに使ってみたいです。

PHPで「1時間以内に解けなければプログラマ失格となってしまう5つの問題」の問題4と5を改めて解いてみました

前回、解いた

1時間以内に解けなければプログラマ失格となってしまう5つの問題が話題に

ですが、さすがにあんまりな内容だったので問題4と5の答えを改めてそれっぽい感じにつくってみました。

主に次のページを参考にしました。
http://qiita.com/tanakh/items/b4069a6d3485ef4278ce
http://techblog.mindpl.co.jp/2014/09/array_combination/

つまりカンニングしました。

問題4

正の整数のリストを与えられたとき、数を並び替えて可能な最大数を返す関数を記述せよ。例えば、[50, 2, 1, 9]が与えられた時、95021が答えとなる(解答例)。

<?php
require_once 'util.php';

function p4(array $xs)
{
    $xs = array_sort($xs, function($a, $b) { return "{$b}{$a}" < "{$a}{$b}"; });
    return implode($xs);
}

echo p4([50, 2, 1, 9]). "\n";
$ php 4.php 
95021

問題5

1,2,…,9の数をこの順序で、”+”、”-“、またはななにもせず結果が100となるあらゆる組合せを出力するプログラムを記述せよ。例えば、1 + 2 + 34 – 5 + 67 – 8 + 9 = 100となる(解答例)

<?php
require_once 'util.php';

function calc($expr)
{
    $f = function(array $acc, $term) {
        if ($term === '+') {
            $acc['x'] = $acc['f']($acc['x'], $acc['y']);
            list($acc['f'], $acc['y']) = ['add', 0];
        }
        else if ($term === '-') {
            $acc['x'] = $acc['f']($acc['x'], $acc['y']);
            list($acc['f'], $acc['y']) = ['sub', 0];
        }
        else if ($term === '=') {
            $acc['x'] = $acc['f']($acc['x'], $acc['y']);
        }
        else if ($term !== '') {
            $acc['y'] = $acc['y'] * 10 + $term;
        }
        return $acc;
    };

    $expr[] = '=';
    $acc = array_reduce($expr, $f, ['x' => 0, 'f' => 'add', 'y' => 0]);
    return $acc['x'];
}

function is_n($n)
{
    return function($expr) use($n) { return calc($expr) === $n; };
}

function p5()
{
    $nums = array_map('wrap_array', range(1, 9));
    $opes = replicate(8, ['+', '-', '']);
    $exprs = array_filter(direct_product(alternate($nums, $opes)), is_n(100));
    return array_map('implode', $exprs);
}

foreach (p5() as $expr) {
    echo $expr. "\n";
}
$ php 5.php 
1+23-4+56+7+8+9
12+3-4+5+67+8+9
1+2+34-5+67-8+9
1+2+3-4+5+6+78+9
123-4-5-6-7+8-9
123+45-67+8-9
1+23-4+5+6+78-9
12-3-4+5-6+7+89
12+3+4+5-6-7+89
123-45-67+89
123+4-5+67-89

util.phpはこんな感じです。

<?php

function array_sort(array $xs, $isAsc)
{
    usort($xs, function($a, $b) use($isAsc) { return $isAsc($a, $b) ? -1 : 1; });
    return $xs;
}

function alternate(array $xs, array $ys)
{
    return $xs === [] ? $ys : array_merge([$xs[0]], alternate($ys, array_slice($xs, 1)));
}

function div_qr($n, $d)
{
    return [intval($n / $d), $n % $d];
}


function replicate($n, $x)
{
    return array_map(function($n) use($x) { return $x; }, range(0, $n - 1));
}

function add($x, $y)
{
    return $x + $y;
}

function sub($x, $y)
{
    return $x - $y;
}

function count_direct_product(array $xss)
{
    $f = function(array $acc, array $xs) {
        $n = count($xs);
        $acc['direct_product'] *= $n;
        $acc['own']++;
        $acc['each'][] = $n;
        return $acc;
    };
    return array_reduce($xss, $f, ['direct_product' => 1, 'own' => 0, 'each' => []]);
}

function direct_product(array $xss)
{
    $count = count_direct_product($xss);

    $direct_product = [];
    for ($i = 0; $i < $count['direct_product']; $i++) {
        $combination = [];
        for ($q = $i, $j = 0; $j < $count['own']; $j++) {
            list($q, $r) = div_qr($q, $count['each'][$j]);
            $combination[] = $xss[$j][$r];
        }
        $direct_product[] = $combination;
    }
    return $direct_product;
}

function wrap_array($n)
{
    return [$n];
}

プログラマになりたいなぁ。

プログラマ失格になりました

プログラマになりたいなぁと思って、

1時間以内に解けなければプログラマ失格となってしまう5つの問題が話題に

解いてみました、PHPで。
トータル1時間半ちょいかかってしまいプログラマ失格になってしまいました。
残念。

解き方が全くスマートで無く、関数名、変数名もむちゃくちゃ。
ちゃんとしたプログラマになりたいなぁ。
向いてないんだろうなぁ。

問題1

forループ、whileループ、および再帰を使用して、リスト内の数字の合計を計算する3つの関数を記述せよ。

<?php

function sum1(array $list)
{
  $length = count($list);
  $sum = 0;
  for ($i = 0; $i < $length; $i++) {
      $sum += $list[$i];
  }
  return $sum;
}

function sum2(array $list)
{
  $length = count($list);
  $sum = 0;
  $i = 0;
  while ($i < $length) {
      $sum += $list[$i];
      $i++;
  }
  return $sum;
}

function sum3($sum, array $list)
{
    if ($list == []) {
        return $sum;
    }
    else {
        return sum3($sum + $list[0], array_slice($list, 1));
    }
}

echo sum1([1,2,3,4,5]). "\n";
echo sum2([1,2,3,4,5]). "\n";
echo sum3(0, [1,2,3,4,5]). "\n";
$ php 1.php 
15
15
15

問題2

交互に要素を取ることで、2つのリストを結合する関数を記述せよ。例えば [a, b, c]と[1, 2, 3]という2つのリストを与えると、関数は [a, 1, b, 2, c, 3]を返す。

<?php

// $list1の長さ分だけやる前提
function hoge(array $list1, array $list2)
{
    $r = [];
    $length1 = count($list1);
    for ($i = 0; $i < $length1; $i++) {
        $r[] = $list1[$i];
        $r[] = $list2[$i];
    }
    return $r;
}

print_r(hoge(['a', 'b', 'c'], [1,2,3]));
$ php 2.php 
Array
(
    [0] => a
    [1] => 1
    [2] => b
    [3] => 2
    [4] => c
    [5] => 3
)

問題3

最初の100個のフィボナッチ数のリストを計算する関数を記述せよ。定義では、フィボナッチ数列の最初の2つの数字は0と1で、次の数は前の2つの合計となる。例えば最初の10個のフィボナッチ数列は、0, 1, 1, 2, 3, 5, 8, 13, 21, 34となる。

<?php

function fib(array $acc, $n, $m)
{
    if (count($acc) == 100) {
        return $acc;
    }
    else {
        $acc[] = $n + $m;
        return fib($acc, $m, $n + $m);
    }
}

print_r(fib([0, 1], 0, 1));
$ php 3.php 
Array
(
    [0] => 0
    [1] => 1
    [2] => 1
    [3] => 2
    [4] => 3
    [5] => 5
    [6] => 8
    [7] => 13
    [8] => 21
    [9] => 34
    [10] => 55
    [11] => 89
    [12] => 144
    [13] => 233
    [14] => 377
    [15] => 610
    [16] => 987
    [17] => 1597
    [18] => 2584
    [19] => 4181
    [20] => 6765
    [21] => 10946
    [22] => 17711
    [23] => 28657
    [24] => 46368
    [25] => 75025
    [26] => 121393
    [27] => 196418
    [28] => 317811
    [29] => 514229
    [30] => 832040
    [31] => 1346269
    [32] => 2178309
    [33] => 3524578
    [34] => 5702887
    [35] => 9227465
    [36] => 14930352
    [37] => 24157817
    [38] => 39088169
    [39] => 63245986
    [40] => 102334155
    [41] => 165580141
    [42] => 267914296
    [43] => 433494437
    [44] => 701408733
    [45] => 1134903170
    [46] => 1836311903
    [47] => 2971215073
    [48] => 4807526976
    [49] => 7778742049
    [50] => 12586269025
    [51] => 20365011074
    [52] => 32951280099
    [53] => 53316291173
    [54] => 86267571272
    [55] => 139583862445
    [56] => 225851433717
    [57] => 365435296162
    [58] => 591286729879
    [59] => 956722026041
    [60] => 1548008755920
    [61] => 2504730781961
    [62] => 4052739537881
    [63] => 6557470319842
    [64] => 10610209857723
    [65] => 17167680177565
    [66] => 27777890035288
    [67] => 44945570212853
    [68] => 72723460248141
    [69] => 117669030460994
    [70] => 190392490709135
    [71] => 308061521170129
    [72] => 498454011879264
    [73] => 806515533049393
    [74] => 1304969544928657
    [75] => 2111485077978050
    [76] => 3416454622906707
    [77] => 5527939700884757
    [78] => 8944394323791464
    [79] => 14472334024676221
    [80] => 23416728348467685
    [81] => 37889062373143906
    [82] => 61305790721611591
    [83] => 99194853094755497
    [84] => 160500643816367088
    [85] => 259695496911122585
    [86] => 420196140727489673
    [87] => 679891637638612258
    [88] => 1100087778366101931
    [89] => 1779979416004714189
    [90] => 2880067194370816120
    [91] => 4660046610375530309
    [92] => 7540113804746346429
    [93] => 1.2200160415122E+19
    [94] => 1.9740274219868E+19
    [95] => 3.194043463499E+19
    [96] => 5.1680708854858E+19
    [97] => 8.3621143489848E+19
    [98] => 1.3530185234471E+20
    [99] => 2.1892299583456E+20
)

問題4

正の整数のリストを与えられたとき、数を並び替えて可能な最大数を返す関数を記述せよ。例えば、[50, 2, 1, 9]が与えられた時、95021が答えとなる(解答例)。

<?php

function hoge(array $list)
{
    $max = 0;
    foreach ($list as $n) {
        if ($max < $n) {
            $max = $n;
        }
    }
    $ketanum = 0;
    if ($max == 0) {
        $ketanum = 1;
    }
    else {
      while ($max > 0) {
        $max = intval($max / 10);
        $ketanum++;
      }
    }

    $nums = [];
    foreach ($list as $num) {
        $num2 = $num;

    $ketanum2 = 0;
    if ($num2 == 0) {
        $ketanum2 = 1;
    }
    else {
      while ($num2 > 0) {
        $num2 = intval($num2 / 10);
        $ketanum2++;
      }
    }
    $need = $ketanum - $ketanum2;
          $nums[] = [$num * pow(10, $need), $ketanum2, $num];
    }

    usort($nums, function($a, $b) {
        if ($a[0] == $b[0]) {
            return 0;
        }
        return $a[0] > $b[0] ? -1 : 1;
    });

    $r = 0;
    foreach ($nums as $num) {
        $r *= pow(10, $num[1]);
        $r += $num[2];
    }
    return $r;
}

echo hoge([0, 50, 2, 1, 9]). "\n";
echo hoge([50, 2, 1, 9]). "\n";
$ php 4.php 
950210
95021

問題5

1,2,…,9の数をこの順序で、”+”、”-“、またはななにもせず結果が100となるあらゆる組合せを出力するプログラムを記述せよ。例えば、1 + 2 + 34 – 5 + 67 – 8 + 9 = 100となる(解答例)

<?php

function foo($index)
{
    $arr = [];
    $now = 0;
    for ($i = $index; $i < 10; $i++) {
        $now *= 10;
        $now += $i;
        $arr[] = ['num' => $now, 'last' => $i];
    }
    return $arr;
}

function bar($index)
{
    $acc = [];
    $hoge = foo($index);
    foreach ($hoge as $h) {
        $acc[] = ['car' => $h['num'], 'cdr' => bar($h['last'] + 1)];
    }
    return $acc;
}

function flat($car, $cdr)
{
    if ($cdr == []) {
        return [[$car]];
    }
    else {
        $ret = [];
        foreach ($cdr as $x) {
        $hh = flat($x['car'], $x['cdr']);
            foreach ($hh as $h) {
                if ($h !== null) {
                    $ret[] = array_merge([$car], $h);
                }
            }

        }
        return $ret;
    }
}

function calc($s, $x, array $rest) {
    if ($rest == []) {
        return [['s' => $s, 'r' => $x]];
    }
    else {
        $acc = calc("{$s} + {$rest[0]}", $x + $rest[0], array_slice($rest, 1));
        $acc = array_merge($acc, calc("{$s} - {$rest[0]}", $x - $rest[0], array_slice($rest, 1)));
        return $acc;
    }
}

$nums = [];

$aaa = bar(1);
foreach ($aaa as $a) {
    $iii = flat($a['car'],$a['cdr']);
    foreach ($iii as $i) {
        $nums[] = $i;
    }
}

foreach ($nums as $n) {
    $cs = calc($n[0], $n[0], array_slice($n, 1));
    foreach ($cs as $c) {
        if ($c['r'] == 100) {
            echo $c['s']. "\n";
        }
    }
}
$ php 5.php 
1 + 2 + 3 - 4 + 5 + 6 + 78 + 9
1 + 2 + 34 - 5 + 67 - 8 + 9
1 + 23 - 4 + 5 + 6 + 78 - 9
1 + 23 - 4 + 56 + 7 + 8 + 9
12 + 3 + 4 + 5 - 6 - 7 + 89
12 - 3 - 4 + 5 - 6 + 7 + 89
12 + 3 - 4 + 5 + 67 + 8 + 9
123 - 4 - 5 - 6 - 7 + 8 - 9
123 + 4 - 5 + 67 - 89
123 + 45 - 67 + 8 - 9
123 - 45 - 67 + 89

D言語でMVarしだした

前に参加したHaskellによる並列・並行プログラミング読書会にて、

で、やってみようと思いました。C言語はリソース管理とか色々めんどくなったのでD言語にしました。バージョンはv2.066.1です。mallocさん、さようなら。
まずは、MVar作りました。例外とかはガン無視です。

MVarはProducer-ConsumerパターンのChannelの事です。

f:id:nihma:20150223225540p:plain

以下コードです。増補改訂版Java言語で学ぶデザインパターン入門マルチスレッド編の5章の”Producer-Consumer - わたしが作り、あなたが使う”を見ながら実装しました。

// MVar.d
module MVar;

import core.sync.mutex;
import core.sync.condition;

class MVar(T)
{
  private Mutex m;
  private Condition c;
  private T[] value;

  this()
  {
    value = null;
    m = new Mutex;
    c = new Condition(m);
  }

  void put(T v)
  {
    synchronized(m) {
      while (value != null) {
        c.wait();
      }
      value = [v];
      c.notifyAll();
    }
  }

  T take()
  {
    synchronized(m) {
      while (value == null) {
        c.wait();
      }
      auto v = value[0];
      value = null;
      c.notifyAll();
      return v;
    }
  }
}

とりあえず、Haskellによる並列・並行プログラミングの7.3の”簡単なチャネルとしてのMVar:ログサービス”をやってみました。

// Logger.d
import std.stdio;
import std.variant;
import core.thread;
import MVar;

class Logger
{
  struct Message { string message; }
  struct Stop    { MVar.MVar!(bool) m; }
  alias  Command = Algebraic!(Message, Stop);

  private MVar.MVar!(Command) m;

  this()
  {
    m = new MVar.MVar!(Command);

    auto t = new Thread(
    {
      while (true) {
        auto cmd = m.take();
        if (cmd.type == typeid(Message)) {
          auto msg = cmd.get!Message.message;
          writeln(msg);
        }
        else if (cmd.type == typeid(Stop)) {
          auto s = cmd.get!Stop.m;
          writeln("logger: stop");
          s.put(true);
          break;
        }
      }
    });
    t.start(); 
  }

  void message(string msg)
  {
    Message message = { message:msg };
    Command cmd = message;
    m.put(cmd);
  }

  void stop()
  {
    auto s = new MVar.MVar!(bool);
    Stop stop = { m:s };
    Command cmd = stop;
    m.put(cmd);
    s.take();
  }
}

void main()
{
  auto l = new Logger;
  l.message("hello");
  l.message("bye");
  l.stop();
}

動いたー

$ dmd Logger.d MVar.d
$ ./Logger 
hello
bye
logger: stop

Pythonにロジスティック回帰で画像を学習させてみました

今回はロジスティック回帰で画像とラベルの対応関係を教師付き学習させて画像分類の精度を検証して遊んでみました。

ちなみに前回は画像を教師なし学習のk-means法でカテゴリ分けしました。

データセット

17 Category Flower Datasetで公開されている花の画像を使いました。

$ wget http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz
$ tar xvfz 17flowers.tgz

イメージファイル名は"image_連番.jpg"の形式で品種との対応は下記みたいです。

連番 (連番 - 1) / 80 品種
0001 - 0080 0 Daffodil ラッパズイセン
0081 - 0160 1 Snowdrop スノードロップ
0161 - 0240 2 LilyValley スズラン
0241 - 0320 3 Bluebell ブルーベル
0321 - 0400 4 Crocus クロッカス
0401 - 0480 5 Iris アヤメ
0481 - 0560 6 Tigerlily オニユリ
0561 - 0640 7 Tulip チューリップ
0641 - 0720 8 Fritillary クロユリ
0721 - 0800 9 Sunflower ヒマワリ
0801 - 0880 10 Daisy ヒナギク
0881 - 0960 11 ColtsFoot フキタンポポ
0961 - 1040 12 Dandelion タンポポ
1041 - 1120 13 Cowslip キバナノクリンザクラ
1121 - 1200 14 Buttercup キンポウゲ
1201 - 1280 15 Windflower アネモネ
1281 - 1360 16 Pansy パンジー

検証

特徴量にはsurfのbag-of-Visual Words, haralick, edginess_sobelを使いました。
それぞれの組み合わせで交差検証して正解率を求めて混合行列ROC曲線で可視化してみました。

surf haralick edginess_sobel 正解率
1 × × 59.0%
2 × × 16.9%
3 × × 9.2%
4 × 60.4%
5 × 59.0%
6 × 18.4%
7 60.4%

60.4%は低い。。。

混合行列とROC曲線1: surfのみ

f:id:nihma:20150119164103p:plain

f:id:nihma:20150119164127p:plain

混合行列とROC曲線2: haralickのみ

f:id:nihma:20150119164403p:plain

f:id:nihma:20150119164425p:plain

混合行列とROC曲線3: edginess_sobelのみ

f:id:nihma:20150119164617p:plain

f:id:nihma:20150119164651p:plain

混合行列とROC曲線4: surf/haralick

f:id:nihma:20150119164725p:plain

f:id:nihma:20150119164757p:plain

混合行列とROC曲線5: surf/edginess_sobel

f:id:nihma:20150119164831p:plain

f:id:nihma:20150119164903p:plain

混合行列とROC曲線6: haralick/edginess_sobel

f:id:nihma:20150119164942p:plain

f:id:nihma:20150119165007p:plain

混合行列とROC曲線7: surf/haralick/edginess_sobel

f:id:nihma:20150119165042p:plain

f:id:nihma:20150119165111p:plain

コード

今回のコードです。

import mahotas as mh
import numpy as np
from glob import glob
from mahotas.features import surf
from sklearn.cluster import KMeans
from sklearn.cross_validation import ShuffleSplit
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.metrics import auc
from collections import defaultdict
from matplotlib import pylab
from sklearn.linear_model import LogisticRegression

def save_plot_confusion_matrix(file_name, cms, y_labels):
  cm_avg = np.mean(cms, axis=0)
  cm_norm = cm_avg / np.sum(cm_avg, axis=0)
  pylab.clf()
  fig = pylab.figure(figsize=(11,10))
  ax = fig.add_subplot(111)
  res = ax.imshow(cm_norm, cmap='Blues', interpolation='nearest')
  for i, cas in enumerate(cm_avg):
      for j, c in enumerate(cas):
          n = c
          if n > 0:
              pylab.text(j-.2, i+.2, n, fontsize=12, color='red')
  cb = fig.colorbar(res)
  ax.set_xticks(range(len(y_labels)))
  ax.xaxis.set_ticks_position("bottom")
  ax.set_xticklabels(y_labels, rotation=270)
  ax.set_yticks(range(len(y_labels)))
  ax.set_yticklabels(y_labels)
  pylab.title('Confusion Matrix')
  pylab.xlabel('Predicted class')
  pylab.ylabel('True class')
  pylab.grid(False)
  pylab.savefig(file_name)

def save_plot_roc(file_name, y_labels, roc_scores, tprs, fprs):
  pylab.clf()
  pylab.figure(figsize=(35, 35))
  column = 4
  row = int(len(y_labels) / column) + 1
  for i, y_label in enumerate(y_labels):
    pylab.subplot(row, column, i+1)
    auc_score = np.mean(roc_scores[y_label])
    label = '%s vs rest' % y_label
    
    pylab.grid(True)
    pylab.plot([0, 1], [0, 1], 'k--')
    # 混合行列に合わせて平均にしたいけどとりあえず全てplotしとく
    for j, tpr in enumerate(tprs[y_label]):
      pylab.plot(fprs[y_label][j], tpr)
      pylab.fill_between(fprs[y_label][j], tpr, alpha=0.5)
    pylab.xlim([0.0, 1.0])
    pylab.ylim([0.0, 1.0])
    pylab.xlabel('False Positive Rate')
    pylab.ylabel('True Positive Rate')
    pylab.title('ROC curve (AUC = %0.2f) / %s' %
                (auc_score, label), verticalalignment="bottom")
    pylab.legend(loc="lower right")
  pylab.savefig(file_name, bbox_inches="tight")

def try_model(X, y, y_labels, confusion_matrix_filename, roc_filename):
  cv = ShuffleSplit(n=len(X), n_iter=10, test_size=0.3, indices=True, random_state=0)
  
  cms = []  # 混合行列
  scores = [] # 正解率
  roc_scores = defaultdict(list)
  tprs = defaultdict(list)
  fprs = defaultdict(list)
  
  for train, test in cv:
    X_train, y_train = X[train], y[train]
    X_test,  y_test  = X[test],  y[test]
    clf = LogisticRegression()
    clf.fit(X_train, y_train)
    scores.append(clf.score(X_test, y_test))
    
    y_pred = clf.predict(X_test)
    cms.append(confusion_matrix(y_test, y_pred, y_labels))
    
    proba = clf.predict_proba(X_test)
    for i, label in enumerate(y_labels):
       y_label_test = np.asarray(y_test == label, dtype=int)
       proba_label = proba[:, i]
       
       fpr, tpr, roc_thresholds = roc_curve(y_label_test, proba_label)
       roc_scores[label].append(auc(fpr, tpr))
       tprs[label].append(tpr)
       fprs[label].append(fpr)
  
  summary = (np.mean(scores), np.std(scores))
  print "accuracy mean:%.3f\tstd:%.3f\t" % summary
  
  # 結果の生成
  save_plot_confusion_matrix(confusion_matrix_filename, np.asarray(cms), y_labels)
  save_plot_roc(roc_filename, y_labels, roc_scores, tprs, fprs)

feature_category_num = 512

images = glob('./image_*.jpg')

# 画像ごとの局所特徴量を取り出す
alldescriptors = []
for im in images:
  im = mh.imread(im, as_grey=True)
  im = im.astype(np.uint8)
  alldescriptors.append(surf.surf(im, descriptor_only=True))

# 局所特徴量からVisual Wordsを決める
concatenated = np.concatenate(alldescriptors)
km = KMeans(feature_category_num)
km.fit(concatenated)

# 各画像のbag-of-Visual Wordsを求める
features = []  # 特徴量1: surf
for d in alldescriptors:
  c = km.predict(d)
  features.append(np.array([np.sum(c == ci) for ci in range(feature_category_num)]))

haralicks = [] # 特徴量2: haralick
edginess_sobels = [] # 特徴量3: edginess_sobel
for im in images:
  im = mh.imread(im, as_grey=True)
  im = im.astype(np.uint8)
  edges = mh.sobel(im, just_filter=True)
  edges = edges.ravel()
  edginess_sobels.append([np.sqrt(np.dot(edges, edges)),1])  # 一応1いれとく
  haralicks.append(mh.features.haralick(im).mean(0))

y_labels = ['Daffodil',   'Snowdrop',  'LilyValley', 'Bluebell',   'Crocus', 
            'Iris',       'Tigerlily', 'Tulip',      'Fritillary', 'Sunflower',
            'Daisy',      'ColtsFoot', 'Dandelion',  'Cowslip',    'Buttercup',
            'Windflower', 'Pansy']
y = []
for image in images:
  n = int(image[len('./image_'):-len('.jpg')])
  n = (n - 1) / 80
  y.append(y_labels[n])

y = np.array(y)

# 1. surf
X = np.array(features)
try_model(np.array(X, dtype=int), y, y_labels, "confusion_matrix1.png", "roc1.png")
# accuracy mean:0.590    std:0.013

# 2. haralick
X = np.array(haralicks)
try_model(np.array(X, dtype=int), y, y_labels, "confusion_matrix2.png", "roc2.png")
# accuracy mean:0.169    std:0.014

# 3. edginess_sobel
X = np.array(edginess_sobels)
try_model(np.array(X, dtype=int), y, y_labels, "confusion_matrix3.png", "roc3.png")
# accuracy mean:0.092    std:0.009

# 4. surf, haralick
X = []
for i, _ in enumerate(y):
  X.append(np.concatenate((features[i], haralicks[i])))
try_model(np.array(X, dtype=int), y, y_labels, "confusion_matrix4.png", "roc4.png")
# accuracy mean:0.604    std:0.014

# 5. surf, edginess_sobel
X = []
for i, _ in enumerate(y):
  X.append(np.concatenate((features[i], edginess_sobels[i])))
try_model(np.array(X, dtype=int), y, y_labels, "confusion_matrix5.png", "roc5.png")
# accuracy mean:0.590    std:0.016

# 6. haralick, edginess_sobel
X = []
for i, _ in enumerate(y):
  X.append(np.concatenate((haralicks[i], edginess_sobels[i])))
try_model(np.array(X, dtype=int), y, y_labels, "confusion_matrix6.png", "roc6.png")
# accuracy mean:0.184    std:0.016

# 7, surf, haralick, edginess_sobel
X = []
for i, _ in enumerate(y):
  X.append(np.concatenate((features[i], haralicks[i], edginess_sobels[i])))
try_model(np.array(X, dtype=int), y, y_labels, "confusion_matrix7.png", "roc7.png")
# accuracy mean:0.604    std:0.013

思ったこととか

もうちょい前処理の工夫や特徴量の追加やパラメータのチューニングなどをすれば精度は上げられるかもしれないです。 あとはロジスティック回帰以外の分類器についても比較したりするのも面白そうに思います。 まだ遊ぶ余地はありそうです。

珠玉のアルゴリズムデザイン19章の数独ソルバー読みました

関数プログラミング 珠玉のアルゴリズムデザインの第19章の"単純な数独ソルバー"のアルゴリズムを追いました。 各効率化の戦略を読み解くのが大変だったので後でまた読むときのために概要だけメモを残します。

お題は数独ソルバー

数独を解くアルゴリズムプログラム運算により効率の良いものにしていきます。言語はHaskellです。
数独は次のようなゲームです。

f:id:nihma:20150111152731p:plain

ルール
・空いているマスに1〜9のいずれかの数字を入れる。
・縦・横の各列及び、太線で囲まれた3×3のブロック内に同じ数字が複数入ってはいけない。

数独 - Wikipedia

最初のアルゴリズム

最初のアルゴリズムは次のような愚直な感じの流れです。
(答えの候補が膨大なので処理が終わらないです。)

  1. 空のマス目に入る候補を1〜9とする。
  2. すべての空のマス目の候補の組み合わせで答えの候補を作る。
  3. 答えの候補から数独ルールに合っているものを取り出して答えとする。

Haskellプログラムの実行イメージは次のような感じです。

solve = filter valid . expand . choices

{-
問題データ:空のマス目は0
["004005700",
 "000009400",
 "360000008",
 "720060000",
 "000402000",
 "000080093",
 "400000056",
 "005300000",
 "006100900"]

↓ choices:空のマス目を候補の数値にする
[["123456789","123456789","4","123456789","123456789","5","7","123456789","123456789"],
 ["123456789","123456789","123456789","123456789","123456789","9","4","123456789","123456789"],
 ["3","6","123456789","123456789","123456789","123456789","123456789","123456789","8"],
 ["7","2","123456789","123456789","6","123456789","123456789","123456789","123456789"],
 ["123456789","123456789","123456789","4","123456789","2","123456789","123456789","123456789"],
 ["123456789","123456789","123456789","123456789","8","123456789","123456789","9","3"],
 ["4","123456789","123456789","123456789","123456789","123456789","123456789","5","6"],
 ["123456789","123456789","5","3","123456789","123456789","123456789","123456789","123456789"],
 ["123456789","123456789","6","1","123456789","123456789","9","123456789","123456789"]]

↓ expand:空のマス目の候補すべて組み合わせで答えの候補を作る
[[["114115711"],  <- 1xxx
  ["111119411"],
  ["361111118"],
  ["721161111"],
  ["111412111"],
  ["111181193"],
  ["411111156"],
  ["115311111"],
  ["116111911"]],
 [["214115711"],  <- 2xxx
  ["111119411"],
  ["361111118"],
  ["721161111"],
  ["111412111"],
  ["111181193"],
  ["411111156"],
  ["115311111"],
  ["116111911"]],
 [["314115711"],   <- 3xxx
  ["111119411"],
  ["361111118"],
  ["721161111"],
  ["111412111"],
  ["111181193"],
  ["411111156"],
  ["115311111"],
  ["116111911"]],
  ・・・,
 [["994995799"],
  ["999999499"],
  ["369999998"],
  ["729969999"],
  ["999492999"],
  ["999989993"],
  ["499999956"],
  ["995399999"],
  ["996199999"]]]
 
↓ (filter valid):数独ルールに合っている答えの候補のみ取り出す
答え:1つとは限らない
[["184625739",
  "572839461",
  "369741528",
  "728963145",
  "953412687",
  "641587293",
  "417298356",
  "295376814",
  "836154972"]]
-}

選択肢行列の枝刈り

最初のアルゴリズムで空のマス目に入る候補を1〜9としていたところから、数独ルールに合わない候補を無くします。
例えば空のマス目の数が30、そこに入る候補の数の平均が4になるとすれば探索範囲数は下記のようにしぼられます。(これでも処理時間は相当かかります。)
 42391158275216203514294433201(9の30乗)→1152921504606846976(4の30乗)

Haskellプログラムの実行イメージは次のような感じです。

solve = filter valid . expand . prune . choices

{-
問題データ:空のマス目は0
↓ choices:空のマス目を候補の数値にする

↓ prune:無効な空のマス目候補を排除する
[["1289","189","4","268","123","5","7","1236","129"],
 ["1258","1578","1278","2678","1237","9","4","1236","125"],
 ["3","6","1279","27","1247","147","125","12","8"],
 ["7","2","1389","59","6","13","158","148","145"],
 ["15689","13589","1389","4","13579","2","1568","1678","157"],
 ["156","145","1","57","8","17","1256","9","3"],
 ["4","13789","123789","2789","279","78","1238","5","6"],
 ["1289","1789","5","3","2479","4678","128","12478","1247"],
 ["28","378","6","1","2457","478","9","23478","247"]]

↓ expand:空のマス目候補すべて組み合わせで答えの候補を作る
↓ (filter valid):数独ルールに合っている答えの候補のみ取り出す
答え:1つとは限らない
-}

単一マス拡張

空のマス目1マスずつ逐次的に"答えの候補の作成"と"無効な空のマス目候補の排除"を行う事で探索範囲数の大幅な削減を期待します。 (相当速くなります。)

Haskellプログラムの実行イメージは次のような感じです。

solve = search . choices
search m | not $ safe m = []
         | complete m' = [map (map head) m']
         | otherwise = concat (map search (expand1 m'))
           where m' = prune m

{-
問題データ:空のマス目は0
↓ choices:空のマス目を候補の数値にする

↓ search:答えを探索する(再帰する)
  ↑↓ prune:無効な空のマス目候補を排除する
  ↑↓ expand1:空のマス目候補の少ないマス目を1つだけ選び答えの候補を作る
  [[["1289","189","4","268","123","5","7","1236","129"],
    ["1258","1578","1278","2678","1237","9","4","1236","125"],
    ["3","6","1279",
                      "2",  <- 
                   "1247","147","125","12","8"],
    ["7","2","1389","59","6","13","158","148","145"],
    ["15689","13589","1389","4","13579","2","1568","1678","157"],
    ["156","145","1","57","8","17","1256","9","3"],
    ["4","13789","123789","2789","279","78","1238","5","6"],
    ["1289","1789","5","3","2479","4678","128","12478","1247"],
    ["28","378","6","1","2457","478","9","23478","247"]],
   [["1289","189","4","268","123","5","7","1236","129"],
    ["1258","1578","1278","2678","1237","9","4","1236","125"],
    ["3","6","1279",
                      "7",  <-
                   "1247","147","125","12","8"],
    ["7","2","1389","59","6","13","158","148","145"],
    ["15689","13589","1389","4","13579","2","1568","1678","157"],
    ["156","145","1","57","8","17","1256","9","3"],
    ["4","13789","123789","2789","279","78","1238","5","6"],
    ["1289","1789","5","3","2479","4678","128","12478","1247"],
    ["28","378","6","1","2457","478","9","23478","247"]]]
  ↓ すべてが not $ safe または complete になったら再帰終了
答え:1つとは限らない
-}

感想等

import Data.List ((\\))

-- 単一要素リストであることを検査する関数
single :: [a] -> Bool
single [_] = True
single _   = False
  • あとp133に下記の誤記があります(って細かい)。(←正誤表に反映していただきました)その他は正誤表が参考になります。
    × expand :: Matrix Choicies -> [Grid]
    ○ expand :: Matrix Choices -> [Grid]

モンティ・ホール問題をGoで確かめてみました

少し前に参加した続・わかりやすいパターン認識読書会モンティ・ホール問題について知ったので”司会が正解を知っているか否かによる結果の違い”について、ちょうど触ってみたかったGoで確かめてみました。(Go触りたかっただけ)

モンティ・ホール問題って?

(1) 3つのドア (A, B, C) に(景品、ヤギ、ヤギ)がランダムに入っている。
(2) プレイヤーはドアを1つ選ぶ。
(3) モンティは残りのドアのうち1つを必ず開ける。
(4) モンティの開けるドアは、必ずヤギの入っているドアである。
(5) モンティはプレーヤーにドアを選びなおしてよいと必ず言う。

モンティ・ホール問題 - Wikipedia

(5)で選び直した方が良いのかどうかという問題です。

ここでは、プレイヤーが最初に選んだドアをA、司会のモンティがあけるドア(はずれ)をB、残っている開けられていないドアをCとします。

司会が正解を知っている場合、Aが正解の確率は1/3、Cが正解の確率は2/3となり、選び直した方が良い事になります。 確率はベイズの定理で求まります。

f:id:nihma:20141229004918j:plain

ここで司会が正解を知らなかった場合、確率が変わります。 Aが正解の確率は1/2、Cが正解の確率は1/2でどちらを選んでも同じ確率です。 こちらもベイズの定理で求まります。

f:id:nihma:20141229005058j:plain

P(Bが選ばれる|Bが当たり)とP(Bが選ばれる|Cが当たり)が違います。

確かめてみました

とりあえずGoで10000000回の試行による確率を求める計算をそれぞれ5回ずつ行ってみました。 大数の法則を信じて10000000回もやれば大丈夫と信じました。

コードは次の通りです。
問題の(4)の条件から、Bが必ずはずれであるのでAorCが当たりとなり、司会が正解を知らない場合にBを選んで当たってしまったケースは試行に含まれないです。
(tryA関数は無い方が速いですがgoroutineしてみたかったのでしました。)

package main

import (
  "fmt"
  "math/rand"
  "time"
  "sync"
  "runtime"
)

func main() {
  // CPU全コアを使いたい
  runtime.GOMAXPROCS(runtime.NumCPU())

  // ランダムの初期化
  rand.Seed(time.Now().UnixNano())

  // それぞれ5回ずつprintしてみる
  label_know := map[string] bool {
    "司会が正解を知っている場合": true,
    "司会が正解を知らない場合":   false,
  }
  n := 5
  printP := func(p_a float64, p_c float64) {
    fmt.Printf("P(A) = %.3f, P(C) = %.3f\n", p_a, p_c)
  }
  for label, know_answer := range label_know {
    fmt.Println(label)
    tryNtimes(know_answer, n, printP)
  }
}

/* n回やってみる
 * 引数1:know_answer 司会が当たりを知っているか?
 * 引数2:n やってみる回数
 * 引数3:f 結果を渡す関数
 */
func tryNtimes(know_answer bool, n int,
               f func(float64, float64)) {
  var wg sync.WaitGroup
  for i := 0; i < n; i++ {
    wg.Add(1)
    go func() {
      var p_a float64 = probabilityA(know_answer)
      var p_c float64 = 1 - p_a
      f(p_a, p_c)
      wg.Done()
    }()
  }
  wg.Wait()
}

/* ドアAが当たりの確率
 * 引数:know_answer 司会が当たりを知っているか?
 */
func probabilityA(know_answer bool) float64 {
  try_num := 10000000  // 試行回数(このくらいやれば収束する?)
  a_num := 0 // ドアAが当たりの回数

  ch := tryA(know_answer)
  for i := 0; i < try_num; i++ {
// こっちの方が速いですが->    if isA(know_answer) {
    if <- ch {
      a_num++
    }
  }
  return float64(a_num) / float64(try_num)
}

/* ドアAが正解であるかの試行を生成
 * 引数:know_answer 司会が当たりを知っているか?
 */
func tryA(know_answer bool) chan bool {
  ch := make(chan bool)
  go func() {
    for {
      ch <- isA(know_answer)
    }
  }()
  return ch
}

/* ドアAが正解であるか?
 * 引数:know_answer 司会が当たりを知っているか?
 */
func isA(know_answer bool) bool {
  /* [答えの候補]
   * 0: ドアA <- 最初に選択
   * 1: ドアB <- 司会が選択(はずれ)
   * 2: ドアC <- こちらに変更できる
   */
  // 最初に答えを決める
  answer := rand.Intn(3)  // 0〜2

  if ! know_answer { // 答えを知らない場合
                     // ドアBが当たりになる試行は無効、捨てる
    for answer == 1 {
      answer = rand.Intn(3)
    }
  }
  return answer == 0 // ドアAが当たりか?ドアCがはずれか?
}

結果は理論通りになったようです。

$ go build MontyHall.go
$ ./MontyHall
司会が正解を知っている場合
P(A) = 0.333, P(C) = 0.667
P(A) = 0.334, P(C) = 0.666
P(A) = 0.334, P(C) = 0.666
P(A) = 0.333, P(C) = 0.667
P(A) = 0.333, P(C) = 0.667
司会が正解を知らない場合
P(A) = 0.500, P(C) = 0.500
P(A) = 0.500, P(C) = 0.500
P(A) = 0.500, P(C) = 0.500
P(A) = 0.500, P(C) = 0.500
P(A) = 0.500, P(C) = 0.500

感想

Goを知った2009年頃に受けたネイティブならD言語でええやんという印象から食わず嫌いでしたが意外と面白かったのでgoroutineでもう少し遊んでみたいです。