Flutter, TensorFlow Lite e il riconoscimento della scrittura - Parte 2/2

Nella prima parte di questo articolo abbiamo introdotto gli elementi che ci servivano per scrivere una piccola app di riconoscimento cifre. Ora entriamo nel concreto e vedremo l'implementazione di un modello TensorFlow Lite in un progetto Flutter che permette di riconoscere le cifre da 0 a 9 tracciate a mano su un canvas.

Il progetto base 

Creiamo un nuovo progetto e sostituiamo il codice standard nel file main.dart con questo.


import 'package:flutter/material.dart';
import './screens/main_screen.dart';

void main() {
  runApp(const MyApp());
}

class MyApp extends StatelessWidget {
  const MyApp({super.key});

  @override
  Widget build(BuildContext context) {
    return const MaterialApp(
      debugShowCheckedModeBanner: false,
      home: MainScreen(),
    );
  }
}


Creiamo quindi una cartella screens e in questa un nuovo file main_screen.dart con un widget stateful che permette di mostrare sullo schermo oltre alla solita barra del titolo, un canvas su cui è possibile disegnare e un pulsante per cancellare quanto disegnato:


import 'package:flutter/material.dart';
import 'painter_screen.dart';

class MainScreen extends StatefulWidget {
  const MainScreen({super.key});

  @override
  State<MainScreen> createState() => _MainScreenState();
}

class _MainScreenState extends State<MainScreen> {
  final List<Offset?> _points = [];
  static double _canvasSize = 280;
  static double _borderSize = 2;

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Column(
          children: const [
            Text('Digit AI Recognizer'),
            Text(
              '(using MNIST database of handwritten digits)',
              style: TextStyle(
                fontWeight: FontWeight.bold,
                fontSize: 12,
              ),
            ),
          ],
        ),
        centerTitle: true,
      ),
      body: SingleChildScrollView(
        child: Padding(
          padding: const EdgeInsets.all(8.0),
          child: Column(
            children: [
              Row(
                mainAxisAlignment: MainAxisAlignment.start,
                crossAxisAlignment: CrossAxisAlignment.start,
                children: [
                  _tapWidget(),
                  const SizedBox(
                    width: 8,
                  ),
                  ElevatedButton(
                      onPressed: () {
                        _points.clear();
                        setState(() {});
                      },
                      child: const Text("Clear"))
                ],
              ),
            ],
          ),
        ),
      ),
    );
  }

  // qui la definizione del widget  _tapWidget()  

}


Il widget che si occupa di realizzare il canvas su cui disegnare può essere realizzato con questo codice, utilizzando un  GestureDetector inserito in un contenitore più grande appena quanto necessario per contenere anche i bordi:


  Widget _tapWidget() {
    return Container(
      width: _canvasSize + _borderSize * 2,
      height: _canvasSize + _borderSize * 2,
      decoration: BoxDecoration(
        borderRadius: BorderRadius.circular(5),
        border: Border.all(
          color: Colors.black,
          width: _borderSize,
        ),
      ),
      child: GestureDetector(
        onPanUpdate: (DragUpdateDetails details) {
          Offset pos = details.localPosition;
          if (pos.dx.clamp(0, _canvasSize) == pos.dx &&
              pos.dy.clamp(0, _canvasSize) == pos.dy) {
            _points.add(pos);
            setState(() {});
          }
        },
        onPanEnd: (DragEndDetails details) {
          _points.add(null);
        },
        child: CustomPaint(
          painter: Painter(_points),
        ),
      ),
    );
  }


Il widget GestureDetector con i metodi onPanUpdate e onPanEnd permette di riconoscere quando si inizia a disegnare e quando si termina, rilevando le coordinate dei punti interessati. Il codice proposto non fa altro che verificare se il punto è all'interno del riquadro e in caso affermativo aggiunge le coordinate ad una lista di punti _points. Quando il disegno termina, aggiunge un valore null per indicare la fine della linea. Per cancellare il disegno sarà sufficiente svuotare la lista dei punti ed aggiornare lo stato dell'applicazione.

E' quindi ora sufficiente definire un widget personalizzato Painter che si occupi di unire i punti per tracciare il disegno. Nella stessa cartella creiamo un nuovo file painter_screen.dart con questo codice:


import 'package:flutter/material.dart';

class Painter extends CustomPainter {
  final List<Offset?> points;
  static double _strokeWidth = 10;

  Painter(this.points);

  final Paint _paint = Paint()
    ..strokeCap = StrokeCap.round
    ..color = Colors.red
    ..strokeWidth = _strokeWidth;

  @override
  void paint(Canvas canvas, Size size) {
    for (int i = 0; i < points.length - 1; i++) {
      if (points[i] != null && points[i + 1] != null) {
        canvas.drawLine(points[i]!, points[i + 1]!, _paint);
      }
    }
  }

  @override
  bool shouldRepaint(CustomPainter oldDelegate) {
    return true;
  }
}

Questo widget, unisce ogni punto con quello successivo, disegnando quindi una polilinea, fino a quando non trova il terminatore null.

Salvo errori di digitazione, se facciamo girare la nostra app, otteniamo qualcosa del genere:


su cui è possibile disegnare a mano libera nel riquadro e cancellarne il contenuto con il pulsante Clear

Il package e il modello  

E' arrivato il momento di aggiungere un package che semplifica l'utilizzo di TensorFlow Lite. Dal terminale digitiamo il seguente comando:


 flutter pub add flutter_tflite

Creiamo la cartella assets e copiamoci dentro il due file associali al modello MINST introdotto nella prima parte di questo articolo: il modello vero e proprio mnist.tflite e il file delle etichette mnist.txt (li potete trovare nella mia repo GitHub qui). Modifichiamo il file pubspec.yaml per rendere disponibili gli asset appena scaricati nel progetto:


  # To add assets to your application, add an assets section, like this:
  assets:
    - assets/

Il riconoscitore   

Creiamo quindi una cartella services e in questa un nuovo file recognizer.dart in cui, dopo aver importato la libreria scriveremo la logica di gestione del modello MINST: una funzione asincrona che carica il modello (quindi potrebbe impiegare un po' di tempo), un metodo che rilascia il modello quando non occorre più e una funzione (anch'essa asincrona) che applica il modello ad una lista di bytes e ne restituisce le previsioni (lo vedremo meglio in seguito): 


import 'package:flutter/material.dart';
import 'package:flutter_tflite/flutter_tflite.dart';
import 'dart:typed_data';
import 'dart:ui';

class Recognizer {
  Future loadModel() async {
    Tflite.close();

    return Tflite.loadModel(
        model: "assets/mnist.tflite", labels: "assets/mnist.txt");
  }

  dispose() {
    Tflite.close();
  }

  Future _predict(Uint8List bytes) async {
    return Tflite.runModelOnBinary(binary: bytes);
  }

}

Come anticipato nella prima parte, la lista di dati da sottoporre al modello altro non sono che i valori normalizzati dei pixel di una versione ridotta dell'immagine disegnata (ridotta alle dimensioni di 28x28 pixel con linee bianche su fondo nero, come i campioni su cui il modello è stato addestrato). Definiamo quindi una funzione che a  partire dalla lista di punti, nei ricavi un'immagine nella scala 28 a 280 a colori invertiti:


  Picture _pointsToPicture(List<Offset?> points) {
    final recorder = PictureRecorder();
    final canvas = Canvas(
        recorder,
        Rect.fromPoints(
          const Offset(0, 0),
          const Offset(280, 280),
        ))
      ..scale(28 / 280);

    canvas.drawRect(
        const Rect.fromLTWH(0, 0, 280, 280), Paint()..color = Colors.black);

    for (int i = 0; i < points.length - 1; i++) {
      if (points[i] != null && points[i + 1] != null) {
        canvas.drawLine(
            points[i]!,
            points[i + 1]!,
            Paint()
              ..strokeCap = StrokeCap.round
              ..color = Colors.white
              ..strokeWidth = 2);
      }
    }
    return recorder.endRecording();
  }}

Aggiungiamo anche una funzione che, lavorando sull'immagine ridotta, permetta di linearizzarla (da una griglia 2D, passeremo ad una lista di bytes) normalizzando il colore del pixel in un range [0..255]:


  Future<Uint8List> _imageToByteListUint8(Picture pic, int size) async {
    final img = await pic.toImage(size, size);
    final imgBytes = await img.toByteData();
    final resultBytes = Float32List(size * size);
    final buffer = Float32List.view(resultBytes.buffer);

    int index = 0;

    for (int i = 0; i < imgBytes!.lengthInBytes; i += 4) {
      final r = imgBytes.getUint8(i);
      final g = imgBytes.getUint8(i + 1);
      final b = imgBytes.getUint8(i + 2);
      buffer[index++] = (r + g + b) / 3.0 / 255.0;
    }

    return resultBytes.buffer.asUint8List();
  }

Fatto questo, possiamo quindi scrivere una funzione che elabora la lista di punti, la converte in una immagine 28x28 e ricavi la lista di bytes su cui applicare il modello:


  Future recognize(List<Offset?> points) async {
    final Picture picture = _pointsToPicture(points);

    Uint8List bytes =
        await _imageToByteListUint8(picture, 28);
    return _predict(bytes);
  }

Una classe per i risultati   

Creiamo quindi una cartella models e in questa un nuovo file classification.dart in cui definire il modello dei dati di risposta, caratterizzati da un'etichetta, da un valore di confidenza (probabilità) e da un indice :


class Classification {
  final double confidence;
  final int index;
  final String label;

  Classification(
      {required this.confidence, required this.index, required this.label});

  factory Classification.fromJson(Map<dynamic, dynamic> json) {
    return Classification(
      confidence: json['confidence'],
      index: json['index'],
      label: json['label'],
    );
  }
}

Abbiamo aggiunto anche un metodo .fromJson per semplificare la gestione dei valori tornati dal modello MINT.

Utilizziamo il modello    

Torniamo ora sul file main_screen.dart e aggiungiamo i riferimenti agli altri componenti creati, importando quanto necessario


import "../services/recognizer.dart";
import "../models/classification.dart";


aggiungendo il metodo di inizializzazione e di rilascio


class _MainScreenState extends State<MainScreen> {
  final List<Offset?> _points = [];
  List<Classification> _classification = [];
  final _recognizer = Recognizer();
  static double _canvasSize = 280;
  static double _borderSize = 2;

  @override
  void initState() {
    super.initState();
    _initModel();
  }

  @override
  void dispose() {
    _recognizer.dispose();
    super.dispose();
  }

  void _initModel() async {
    await _recognizer.loadModel();
  }

  void _recognize() async {
    List<dynamic> pred = await _recognizer.recognize(_points);
    _classification =
        pred.map((json) => Classification.fromJson(json)).toList();
    setState(() {});
  }
....


e chiamando la funzione di riconoscimento quando si smette di disegnare


...
        onPanEnd: (DragEndDetails details) {
          _points.add(null);
          _recognize();
        },
...

Mostriamo il risultati    

Nella cartella screens aggiungiamo il file result_screen.dart con il widget che useremo per mostrare i risultati. Sarà un semplice pannello con le cifre da 0 a 9. Dopo il riconoscimento, ogni cifra avrà un colore e uno spostamento verso l'alto in base alla classificazione ottenuta:


import 'package:flutter/material.dart';
import '../models/classification.dart';

class ResultWidget extends StatelessWidget {
  final List<Classification> classifications;

  const ResultWidget({required this.classifications, super.key});

  @override
  Widget build(BuildContext context) {
    var values = _getClassificationValues(classifications);

    return Column(
      children: [
        const Text("Classification",
            style: TextStyle(
              fontWeight: FontWeight.bold,
              fontSize: 16,
            )),
        Container(
          decoration: BoxDecoration(
              color: Colors.orange[200],
              borderRadius: BorderRadius.circular(10)),
          child: Column(
            children: [
              Row(
                mainAxisAlignment: MainAxisAlignment.spaceAround,
                children: [
                  for (var i = 0; i < 10; i++)
                    _classificationWidget(i, values[i])
                ],
              ),
            ],
          ),
        ),
      ],
    );
  }

  Widget _classificationWidget(
      int label, Classification? currentClassification) {
    Color classificationColor = _getClassificationColor(currentClassification);
    double classificationConfidence =
        _getClassificationConfidence(currentClassification);

    return Column(
      children: [
        Stack(
          alignment: AlignmentDirectional.bottomEnd,
          children: [
            const SizedBox(
              height: 180,
              width: 5,
            ),
            Column(
              children: [
                Text(
                  '$label',
                  style: TextStyle(
                    fontSize: 50,
                    fontWeight: FontWeight.bold,
                    color: classificationColor,
                  ),
                ),
                Container(
                  height: classificationConfidence,
                  width: 5,
                  decoration: BoxDecoration(
                    borderRadius: BorderRadius.circular(2),
                    color: classificationColor,
                  ),
                ),
                Padding(
                  padding: const EdgeInsets.only(bottom: 2),
                  child: Text(
                    (classificationConfidence == 0)
                        ? ""
                        : "${classificationConfidence.toStringAsFixed(0)}%",
                    style: const TextStyle(
                      fontSize: 12,
                    ),
                  ),
                ),
              ],
            ),
          ],
        ),
      ],
    );
  }

  Color _getClassificationColor(Classification? classification) {
    return classification == null
        ? Colors.black
        : Colors.blue.withOpacity(classification.confidence);
  }

  double _getClassificationConfidence(Classification? classification) {
    return (classification == null) ? 0 : classification.confidence * 100;
  }

  List<dynamic> _getClassificationValues(
      List<Classification>? classifications) {
    List<dynamic> data = List.generate(10, (_) => null);

    if (classifications != null) {
      for (var classification in classifications) {
        data[classification.index] = classification;
      }
    }

    return data;
  }
}

e utilizziamo questo widget nel file main_screen.dart:

...
                      child: const Text("Clear"))
                ],
              ),
              const SizedBox(
                height: 10,
              ),
              ResultWidget(
                classifications: _classification,
              ),
...

ottenendo questo:

Conclusioni    

Il progetto è terminato e la nostra applicazione di esempio è completa. Come per gli articoli precedenti potete trovare questo codice tra le mie repo GitHub (anche con qualche piccola funzione aggiuntiva, come una piccola miniatura visibile sul pannello di disegno): https://github.com/luigimicco/flutter_ai_digit.

Commenti