茨城エンジニアのPython開発日記

茨城のITベンチャー企業ではたらく2年目エンジニア。Pythonで色々なものを作成中。

AIをARに応用するぞ会 → (第5回) ~スマホでAI作動4~


ブログから記事を見つけたい場合はこちら

ブログ地図 - 茨城エンジニアのPython開発日記


こんにちは松原です。
ウマ娘二期をみて泣きました。トウカイテイオーすき。

さて、先週はダミー画像で識別するところまで行きました。
今週は、実際の画像を識別していきます。


結論から言うと、出来上がった画面がこれ。
f:id:tottorisnow33:20210626103601p:plain

後ろにうつってるのが識別をかけた画像です。
ひしゃげてる。

で、肝心の識別処理のMainActivity.javaの中身が下記です。

package com.example.myapplication_test;

import androidx.appcompat.app.AppCompatActivity;


import android.graphics.BitmapFactory;
import android.graphics.Color;
import android.graphics.drawable.BitmapDrawable;
import android.media.MediaPlayer;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;
import android.widget.ImageView;
import android.graphics.Bitmap;


import org.tensorflow.lite.Interpreter;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;



public class MainActivity extends AppCompatActivity {

    private TextView textView;
    private boolean buttonTap = false;
    final int CLASS_NUM = 2;
    final int IMG_WIDTH = 224;
    final int IMG_HEIGHT = 224;
    MediaPlayer mp = null;

    private MappedByteBuffer loadModel(Context context, String modelPath) {
        try {
            AssetFileDescriptor fd = context.getAssets().openFd(modelPath);
            FileInputStream in = new FileInputStream(fd.getFileDescriptor());
            FileChannel fileChannel = in.getChannel();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY,
                    fd.getStartOffset(), fd.getDeclaredLength());
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // ボタンを設定
        Button button1 = findViewById(R.id.button);

        button1.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                /* ここがメイン関数 */

                /*モデル取得*/
//                String ModelPath = "mobilenet_v1_1.0_224_quant.tflite";
                String ModelPath = "model_img_recog_ramen_pan_bn.tflite";
                Context context = MainActivity.this;
                MappedByteBuffer myModel = loadModel(context, ModelPath);


                try (Interpreter interpreter = new Interpreter(myModel)){
                    /*****************************************/
                    /*********** 画像の取得 *******************/
                    /*****************************************/
                    /*画像を配列に格納(128*128*3)*/
                    byte[][][][] byteTestImg = new byte[1][IMG_HEIGHT][IMG_WIDTH][3];

                    InputStream inputStream;
                    Bitmap      bitmap = null;
                    try{
                        inputStream = getAssets().open("test.jpg");
                        bitmap = BitmapFactory.decodeStream(inputStream);
                    }
                    catch (IOException openErr){
                        openErr.printStackTrace();
                    }


                    /* 画像サイズ取得 */
                    int w = bitmap.getWidth();
                    int h = bitmap.getHeight();

                    /* 画像を配列に格納 */
                    for(int i=0;i < w;i++) {
                        for (int j = 0; j < h; j++) {

                            /* 今見てるピクセルが128ピクセルとしては何ピクセル目かを調査 */
                            /* 計算量無駄になってるけど許してほしい */
                            int x = (int)(i*((float)IMG_WIDTH/(float)w));
                            int y = (int)(j*((float)IMG_HEIGHT/(float)h));


                            int color = bitmap.getPixel(i,j);
                            int r = Color.red(color);
                            int b = Color.blue(color);
                            int g = Color.green(color);


                            byteTestImg[0][y][x][0] = (byte) r; // R
                            byteTestImg[0][y][x][1] = (byte) g; // G
                            byteTestImg[0][y][x][2] = (byte) b; // B

                        }
                    }


                    /*****************************************/
                    /*********** 画像の表示 *******************/
                    /*****************************************/
                    ImageView inferenceImageView = (ImageView) findViewById (R.id.image_view_01);

                    /* pixelごとのRGBに値を格納 */
                    int[] pixels = new int [IMG_WIDTH * IMG_HEIGHT];
                    for (int i = 0; i < IMG_WIDTH; i++) {
                        for (int j = 0; j < IMG_HEIGHT; j++) {
                            int c = j + i * IMG_WIDTH;
                            int red   = (int)byteTestImg[0][i][j][0];
                            int green = (int)byteTestImg[0][i][j][1];
                            int blue  = (int)byteTestImg[0][i][j][2];;
                            pixels [c] = Color.argb (255, red, green, blue);
                        }
                    }

                    /* bitmap配列にrgb値を格納 */
                    Bitmap bitmapImage = Bitmap.createBitmap (IMG_WIDTH, IMG_HEIGHT, Bitmap.Config.ARGB_8888);
                    bitmapImage.setPixels (pixels, 0, IMG_WIDTH, 0, 0, IMG_WIDTH, IMG_HEIGHT);

                    /* 画像の表示 */
                    inferenceImageView.setImageBitmap (bitmapImage);

                    /*****************************************/
                    /*********** 推論の実施 *******************/
                    /*****************************************/
                    /* [0,255]の画像を[0.0, 1.0]に正規化 */
                    float[][][][] testImg = new float[1][IMG_HEIGHT][IMG_WIDTH][3];
                    for (int i = 0; i < IMG_WIDTH; i++) {
                        for (int j = 0; j < IMG_HEIGHT; j++) {
                            for (int ch = 0; ch < 3; ch++) {
//                                testImg[0][i][j][ch] = (float) 0.1;
                                testImg[0][i][j][ch] = (float)byteTestImg[0][i][j][ch] / 255.0f;
                            }
                        }
                    }

                    /*推論実施*/
                    float[][] output = new float[1][CLASS_NUM];
                    interpreter.run(testImg, output);


                    /*****************************************/
                    /*********** 結果の出力 *******************/
                    /*****************************************/
                    /* 結果の描画 */
                    String cate1 = "PanCake: " + String.valueOf((int)(output[0][0]*100)) + "%";
                    String cate2 = "Ramen: " + String.valueOf((int)(output[0][1]*100)) + "%";
                    ((TextView) findViewById(R.id.category1)).setText(cate1);
                    ((TextView) findViewById(R.id.category2)).setText(cate2);

                }
                catch (IllegalArgumentException e){
                    System.out.println("IllegalArgumentException!!!!!!!!!!!!!!!!!!!!!!!");
                    System.out.println(e);
                }

            }

        });
    }
}

使っている識別器は下の記事で作ったやつ。
論文読んでAIつくるぞ会(第9.5回) ~ResNetを作ってみた~ - 茨城エンジニアのPython開発日記
今回識別実施してるのは学習に使った画像なので、識別器も自信満々にラーメンと言っています。



上記ニューラルネットワークは入力が224×224なので、今回の画像も224×224にリサイズしています。
リサイズ、かなり無理やりやってるけど意外とうまくいった。

という訳でひとまず、スマホ上で画像識別を行うのは問題なくできました。
うれしい。




次回は「商品化できそうな見栄えのスマホアプリ」を目指して要件定義してみます。
これをそのままARにしていく予定。

それではまた次回。