Processingで正規分布の(手動)最尤推定ツール

はじめに

Processing Advent Calendar 2014 17日目の記事です。

最近は便利な統計解析ツールがたくさん出てきているので、最尤推定によってパラメータを求めるのはとても手軽にできるようになりましたが、確率分布の形をグリグリ動かしながら対数尤度が変化する様を見ることができたら面白いんじゃないかと思い、こんなツールを作ってみました。

(手動)最尤推定ツール

以下のデータが観測されたとします。

1.364 0.235 -0.846 -0.285 -1.646

lで平均(mu)が+0.01
jで平均(mu)が-0.01
k標準偏差(sigma)が+0.01
i標準偏差(sigma)が-0.01
正規分布のパラメータが変化するので、対数尤度(Log likelihood)が一番大きくなるポイントを探ってみましょう。
(グラフ上をクリックしてからキーボード操作してみて下さい)

コード

コードは以下のようになります。

ArrayList myPoints;
float[] myData = {1.364, 0.235, -0.846, -0.285, -1.646};

float mu = 0;
float sigma = 1;
float margin = 20;

void setup(){
    size(540, 350);
    smooth();
    noStroke();

    myPoints = new ArrayList();

    for (int i=0; i<myData.length;i++){
        myPoint mp = new myPoint(myData[i]);
        myPoints.add(mp);
    }
}

void draw(){
    background(255);
    stroke(0);
    noFill();
    rect(margin, margin, width - margin * 2, height - margin * 2);

    textSize(12);
    for(int i=-5; i<=5; i++){
        float xl = map(i, -5, 5, margin, width - margin);
        float yl = height - 5;
        text(str(i), xl, yl);
    }

    text("0", 5, height - margin);
    text("0.5", 1, height / 2);
    text("1", 5, margin);

    for(int i=0; i<width*10; i++){
        float mi1 = map(i, 0, width*10, -5, 5);
        float l = (1 / (sqrt(2 * PI) * sigma)) * exp(-1/(2 * pow(sigma, 2)) * pow(mi1 - mu, 2));
        float ml = map(l, 0, 1, height - margin , 0 + margin);
        float mi2 = map(mi1, -5, 5, 0 + margin, width - margin);
        noStroke();
        fill(100);
        ellipse(mi2, ml, 1, 1);
    }

    float lh = 0;
    for(int i=0; i<myPoints.size();i++){
        myPoint mp = (myPoint) myPoints.get(i);
        mp.update();
        lh += mp.lh;
    }

    fill(100);
    textSize(15);
    text("mu = " +  nf(mu, 1, 2), 420, 50);
    text("sigma = " +  nf(sigma, 1, 2), 400, 70);
    text("Log likelihood = " +  nf(lh, 1, 2), 350, 90);
}

class myPoint{
    float xPos, yPos, lh;

    myPoint(float x){
        xPos = x;
    }

    void update(){
        yPos = (1 / (sqrt(2 * PI) * sigma)) * exp(-1/(2 * pow(sigma, 2)) * pow(xPos - mu, 2));
        float mx = map(xPos, -5, 5, 0 + margin, width - margin);
        float my = map(yPos, 0, 1, height - margin , 0 + margin);
        stroke(80, 180, 90);
        line(mx, my, mx, height - margin);
        noStroke();
        fill(80, 180, 90);
        ellipse(mx, my, 4, 4);
        ellipse(mx, height - margin, 6, 6);
        lh = log(yPos);
    }
}

void keyPressed(){
    if(key=='i'){
        sigma -= 0.01;
        if(sigma<=0){
            sigma = 0;
        }
    }else if(key=='k'){
        sigma += 0.01;
    }else if(key=='j'){
        mu -= 0.01;
    }else if(key=='l'){
        mu += 0.01;
    }
}