Nella prima parte di questo articolo abbiamo introdotto gli elementi che ci servivano per scrivere una piccola app di riconoscimento cifre. Ora entriamo nel concreto e vedremo l'implementazione di un modello TensorFlow Lite in un progetto Flutter che permette di riconoscere le cifre da 0 a 9 tracciate a mano su un canvas.
Il progetto base
Creiamo un nuovo progetto e sostituiamo il codice standard nel file main.dart con questo.
import 'package:flutter/material.dart';import 'painter_screen.dart';
class MainScreen extends StatefulWidget { const MainScreen({super.key});
@override State<MainScreen> createState() => _MainScreenState();}
class _MainScreenState extends State<MainScreen> { final List<Offset?> _points = []; static double _canvasSize = 280; static double _borderSize = 2;
@override Widget build(BuildContext context) { return Scaffold( appBar: AppBar( title: Column( children: const [ Text('Digit AI Recognizer'), Text( '(using MNIST database of handwritten digits)', style: TextStyle( fontWeight: FontWeight.bold, fontSize: 12, ), ), ], ), centerTitle: true, ), body: SingleChildScrollView( child: Padding( padding: const EdgeInsets.all(8.0), child: Column( children: [ Row( mainAxisAlignment: MainAxisAlignment.start, crossAxisAlignment: CrossAxisAlignment.start, children: [ _tapWidget(), const SizedBox( width: 8, ), ElevatedButton( onPressed: () { _points.clear(); setState(() {}); }, child: const Text("Clear")) ], ), ], ), ), ), ); }
// qui la definizione del widget _tapWidget()
}
Il widget che si occupa di realizzare il canvas su cui disegnare può essere realizzato con questo codice, utilizzando un GestureDetector inserito in un contenitore più grande appena quanto necessario per contenere anche i bordi:
Widget _tapWidget() { return Container( width: _canvasSize + _borderSize * 2, height: _canvasSize + _borderSize * 2, decoration: BoxDecoration( borderRadius: BorderRadius.circular(5), border: Border.all( color: Colors.black, width: _borderSize, ), ), child: GestureDetector( onPanUpdate: (DragUpdateDetails details) { Offset pos = details.localPosition; if (pos.dx.clamp(0, _canvasSize) == pos.dx && pos.dy.clamp(0, _canvasSize) == pos.dy) { _points.add(pos); setState(() {}); } }, onPanEnd: (DragEndDetails details) { _points.add(null); }, child: CustomPaint( painter: Painter(_points), ), ), ); }
Il widget GestureDetector con i metodi onPanUpdate e onPanEnd permette di riconoscere quando si inizia a disegnare e quando si termina, rilevando le coordinate dei punti interessati. Il codice proposto non fa altro che verificare se il punto è all'interno del riquadro e in caso affermativo aggiunge le coordinate ad una lista di punti _points. Quando il disegno termina, aggiunge un valore null per indicare la fine della linea. Per cancellare il disegno sarà sufficiente svuotare la lista dei punti ed aggiornare lo stato dell'applicazione.
E' quindi ora sufficiente definire un widget personalizzato Painter che si occupi di unire i punti per tracciare il disegno. Nella stessa cartella creiamo un nuovo file painter_screen.dart con questo codice:
import 'package:flutter/material.dart';
class Painter extends CustomPainter { final List<Offset?> points; static double _strokeWidth = 10;
Painter(this.points);
final Paint _paint = Paint() ..strokeCap = StrokeCap.round ..color = Colors.red ..strokeWidth = _strokeWidth;
@override void paint(Canvas canvas, Size size) { for (int i = 0; i < points.length - 1; i++) { if (points[i] != null && points[i + 1] != null) { canvas.drawLine(points[i]!, points[i + 1]!, _paint); } } }
@override bool shouldRepaint(CustomPainter oldDelegate) { return true; }}
Questo widget, unisce ogni punto con quello successivo, disegnando quindi una polilinea, fino a quando non trova il terminatore null.
Salvo errori di digitazione, se facciamo girare la nostra app, otteniamo qualcosa del genere:
su cui è possibile disegnare a mano libera nel riquadro e cancellarne il contenuto con il pulsante Clear.
Il package e il modello
E' arrivato il momento di aggiungere un package che semplifica l'utilizzo di TensorFlow Lite. Dal terminale digitiamo il seguente comando:
Creiamo la cartella assets e copiamoci dentro il due file associali al modello MINST introdotto nella prima parte di questo articolo: il modello vero e proprio mnist.tflite e il file delle etichette mnist.txt (li potete trovare nella mia repo GitHub qui). Modifichiamo il file pubspec.yaml per rendere disponibili gli asset appena scaricati nel progetto:
Il riconoscitore
Creiamo quindi una cartella services e in questa un nuovo file recognizer.dart in cui, dopo aver importato la libreria scriveremo la logica di gestione del modello MINST: una funzione asincrona che carica il modello (quindi potrebbe impiegare un po' di tempo), un metodo che rilascia il modello quando non occorre più e una funzione (anch'essa asincrona) che applica il modello ad una lista di bytes e ne restituisce le previsioni (lo vedremo meglio in seguito):
import 'package:flutter/material.dart';import 'package:flutter_tflite/flutter_tflite.dart';import 'dart:typed_data';import 'dart:ui';
class Recognizer { Future loadModel() async { Tflite.close();
return Tflite.loadModel( model: "assets/mnist.tflite", labels: "assets/mnist.txt"); }
dispose() { Tflite.close(); }
Future _predict(Uint8List bytes) async { return Tflite.runModelOnBinary(binary: bytes); }
}
Come anticipato nella prima parte, la lista di dati da sottoporre al modello altro non sono che i valori normalizzati dei pixel di una versione ridotta dell'immagine disegnata (ridotta alle dimensioni di 28x28 pixel con linee bianche su fondo nero, come i campioni su cui il modello è stato addestrato). Definiamo quindi una funzione che a partire dalla lista di punti, nei ricavi un'immagine nella scala 28 a 280 a colori invertiti:
Aggiungiamo anche una funzione che, lavorando sull'immagine ridotta, permetta di linearizzarla (da una griglia 2D, passeremo ad una lista di bytes) normalizzando il colore del pixel in un range [0..255]:
Fatto questo, possiamo quindi scrivere una funzione che elabora la lista di punti, la converte in una immagine 28x28 e ricavi la lista di bytes su cui applicare il modello:
Una classe per i risultati
Creiamo quindi una cartella models e in questa un nuovo file classification.dart in cui definire il modello dei dati di risposta, caratterizzati da un'etichetta, da un valore di confidenza (probabilità) e da un indice :
class Classification { final double confidence; final int index; final String label;
Classification( {required this.confidence, required this.index, required this.label});
factory Classification.fromJson(Map<dynamic, dynamic> json) { return Classification( confidence: json['confidence'], index: json['index'], label: json['label'], ); }}
Abbiamo aggiunto anche un metodo .fromJson per semplificare la gestione dei valori tornati dal modello MINT.
Commenti
Posta un commento