Exécutez des modèles de machine learning dans le navigateur avec TensorFlow.js : modèles pré-entraînés, prédictions temps réel et webcam interactive.
Pourquoi TensorFlow.js dans le navigateur ?
TensorFlow.js permet d'exécuter des modèles de machine learning directement dans le navigateur ou sous Node.js, sans serveur backend. Les avantages sont nombreux :
- Confidentialité : les données ne quittent jamais l'appareil de l'utilisateur.
- Latence zéro : pas d'aller-retour serveur, les prédictions sont immédiates.
- Pas de coût serveur : le GPU/CPU du client fait le travail.
- Fonctionne hors ligne après le chargement du modèle.
Installation et premiers pas
Via npm pour un projet Node.js ou bundler (Webpack, Vite) :
npm install @tensorflow/tfjs
Ou directement via CDN dans le navigateur :
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
Vérification que TensorFlow.js est opérationnel :
import * as tf from '@tensorflow/tfjs';
// Informations sur le backend utilisé (WebGL, CPU, WASM...)
console.info('Backend:', tf.getBackend());
// Exemple de tenseur simple
const a = tf.tensor([1, 2, 3, 4]);
const b = tf.scalar(2);
const result = a.mul(b);
result.print(); // [2, 4, 6, 8]
Backends disponibles
- WebGL — utilise le GPU du navigateur (le plus rapide).
- WASM — WebAssembly, bon compromis CPU/performance.
- CPU — fallback pur JavaScript, le plus lent.
Utiliser des modèles pré-entraînés
TensorFlow.js propose des modèles pré-entraînés prêts à l'emploi via @tensorflow-models :
npm install @tensorflow-models/mobilenet
npm install @tensorflow-models/coco-ssd
npm install @tensorflow-models/pose-detection
Classification d'image avec MobileNet :
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
async function classifierImage(imageElement) {
// Charger le modèle (mis en cache après le premier chargement)
const model = await mobilenet.load();
// Classer l'image
const predictions = await model.classify(imageElement);
// predictions = [{ className: 'cat', probability: 0.93 }, ...]
return predictions;
}
// Utilisation
const img = document.getElementById('mon-image');
const predictions = await classifierImage(img);
console.info(predictions[0]); // { className: 'tabby cat', probability: 0.932 }
Prédictions en temps réel
Détection d'objets en temps réel depuis la webcam avec COCO-SSD :
import * as cocoSsd from '@tensorflow-models/coco-ssd';
async function demarrerDetection() {
const video = document.getElementById('webcam');
const canvas = document.getElementById('overlay');
const ctx = canvas.getContext('2d');
// Accéder à la webcam
const stream = await navigator.mediaDevices.getUserMedia({ video: true });
video.srcObject = stream;
await video.play();
// Charger le modèle
const model = await cocoSsd.load();
// Boucle de détection
async function detecter() {
const predictions = await model.detect(video);
ctx.clearRect(0, 0, canvas.width, canvas.height);
predictions.forEach(pred => {
const [x, y, width, height] = pred.bbox;
ctx.strokeStyle = '#dd0031';
ctx.lineWidth = 2;
ctx.strokeRect(x, y, width, height);
ctx.fillStyle = '#dd0031';
ctx.fillText(
`${pred.class} (${Math.round(pred.score * 100)}%)`,
x, y - 5
);
});
requestAnimationFrame(detecter);
}
detecter();
}
Entraîner un modèle simple
Entraîner un réseau de neurones pour prédire une valeur (exemple : y = 2x - 1) :
import * as tf from '@tensorflow/tfjs';
// Créer un modèle séquentiel simple
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
// Compiler le modèle
model.compile({
optimizer: 'sgd',
loss: 'meanSquaredError'
});
// Données d'entraînement
const xs = tf.tensor1d([-1, 0, 1, 2, 3, 4]);
const ys = tf.tensor1d([-3, -1, 1, 3, 5, 7]); // y = 2x - 1
// Entraîner
await model.fit(xs, ys, {
epochs: 250,
callbacks: {
onEpochEnd: (epoch, logs) => {
if (epoch % 50 === 0) {
console.info(`Epoch ${epoch}: loss = ${logs.loss.toFixed(4)}`);
}
}
}
});
// Prédire
const prediction = model.predict(tf.tensor1d([5]));
prediction.print(); // ~9 (2*5 - 1)
Transfer learning : classifier vos propres images
Le transfer learning consiste à réutiliser les couches d'un modèle pré-entraîné (ex: MobileNet) comme extracteur de features, et à ajouter uniquement une nouvelle tête de classification entraînée sur vos données.
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
// 1. Charger MobileNet comme extracteur de features
const mobileNet = await mobilenet.load({ version: 2, alpha: 1.0 });
// 2. Extraire la couche d'embedding (avant la classification finale)
const layer = mobileNet.model.getLayer('conv_pw_13_relu');
const featureExtractor = tf.model({
inputs: mobileNet.model.inputs,
outputs: layer.output
});
featureExtractor.trainable = false; // Geler les poids MobileNet
// 3. Ajouter votre propre classifieur
const numClasses = 3; // Ex: chat, chien, oiseau
const classifier = tf.sequential({
layers: [
tf.layers.flatten({ inputShape: featureExtractor.outputShape.slice(1) }),
tf.layers.dense({ units: 128, activation: 'relu' }),
tf.layers.dropout({ rate: 0.3 }),
tf.layers.dense({ units: numClasses, activation: 'softmax' })
]
});
// 4. Modèle complet = extractor + classifier
const fullModel = tf.sequential();
// Note : en pratique, chaîner les prédictions manuellement
// 5. Compiler et entraîner
classifier.compile({
optimizer: tf.train.adam(0.001),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
// Données d'entraînement : extraire les features via MobileNet
async function extractFeatures(imageElement) {
const tensor = tf.browser.fromPixels(imageElement)
.resizeNearestNeighbor([224, 224])
.toFloat()
.div(tf.scalar(127.5))
.sub(tf.scalar(1))
.expandDims(0);
return featureExtractor.predict(tensor);
}
Sauvegarder et charger un modèle
TensorFlow.js permet de sauvegarder un modèle dans le localStorage, IndexedDB, ou comme téléchargement de fichiers, puis de le recharger.
// Sauvegarder le modèle après entraînement
await model.save('localstorage://my-model'); // localStorage (limité à 5MB)
await model.save('indexeddb://my-model'); // IndexedDB (recommandé)
await model.save('downloads://my-model'); // Télécharger les fichiers
// Charger un modèle sauvegardé
const loadedModel = await tf.loadLayersModel('indexeddb://my-model');
// Ou depuis un serveur (fichiers .json + .bin)
const serverModel = await tf.loadLayersModel('https://cdn.example.com/model/model.json');
// Sauvegarder avec métadonnées personnalisées
await model.save({
path: 'indexeddb://emotion-detector-v2',
trainingConfig: {
optimizer: 'adam',
loss: 'categoricalCrossentropy',
epochs: 100
}
});
Service Angular avec TensorFlow.js et Signals
Intégrer TensorFlow.js dans Angular avec un service gérant le cycle de vie du modèle et exposant des signaux réactifs.
// image-classifier.service.ts
import { Injectable, signal, computed } from '@angular/core';
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
export interface Prediction {
className: string;
probability: number;
}
@Injectable({ providedIn: 'root' })
export class ImageClassifierService {
private model: mobilenet.MobileNet | null = null;
private _loading = signal(false);
private _ready = signal(false);
private _predictions = signal<Prediction[]>([]);
private _error = signal<string | null>(null);
readonly loading = this._loading.asReadonly();
readonly ready = this._ready.asReadonly();
readonly predictions = this._predictions.asReadonly();
readonly error = this._error.asReadonly();
readonly topPrediction = computed(() =>
this._predictions()[0] ?? null
);
async initialize(): Promise<void> {
if (this.model) return; // Déjà chargé
this._loading.set(true);
this._error.set(null);
try {
// Forcer WebGL pour les performances
await tf.setBackend('webgl');
await tf.ready();
this.model = await mobilenet.load({ version: 2, alpha: 1.0 });
this._ready.set(true);
} catch (err) {
this._error.set(`Erreur de chargement: ${err}`);
} finally {
this._loading.set(false);
}
}
async classify(imageElement: HTMLImageElement | HTMLVideoElement): Promise<void> {
if (!this.model) throw new Error('Modèle non initialisé');
const results = await tf.tidy(() =>
this.model!.classify(imageElement)
);
this._predictions.set(results as Prediction[]);
}
dispose(): void {
if (this.model) {
this.model = null;
this._ready.set(false);
}
}
}
// Composant d'utilisation
@Component({
selector: 'app-image-classifier',
standalone: true,
template: `
@if (classifier.loading()) {
<p>Chargement du modèle TensorFlow.js...</p>
}
@if (classifier.ready()) {
<input type="file" accept="image/*" (change)="onFileChange($event)">
@if (classifier.topPrediction(); as pred) {
<p class="fw-bold">{{ pred.className }}</p>
<p>Confiance: {{ (pred.probability * 100).toFixed(1) }}%</p>
}
}
@if (classifier.error(); as err) {
<p class="text-danger">{{ err }}</p>
}
`
})
export class ImageClassifierComponent {
classifier = inject(ImageClassifierService);
constructor() {
this.classifier.initialize();
}
async onFileChange(event: Event) {
const file = (event.target as HTMLInputElement).files?.[0];
if (!file) return;
const img = new Image();
img.onload = () => this.classifier.classify(img);
img.src = URL.createObjectURL(file);
}
}
Performances et bonnes pratiques
- Appeler
tf.dispose()outf.tidy()pour libérer la mémoire GPU après chaque prédiction. - Charger le modèle une seule fois et le réutiliser — ne pas le recharger à chaque prédiction.
- Forcer le backend WebGL explicitement avec
await tf.setBackend('webgl')pour les performances. - Utiliser un Web Worker pour les calculs CPU-intensifs et éviter de bloquer le thread principal.
- Afficher un indicateur de chargement pendant le téléchargement du modèle (20-100 MB).
- Mettre le modèle en cache via
indexeddb://model-namepour éviter le re-téléchargement. - Monitorer la mémoire GPU avec
tf.memory()en développement.
// Bonne pratique : utiliser tidy() pour éviter les fuites mémoire
const prediction = tf.tidy(() => {
const input = tf.browser.fromPixels(imageElement)
.resizeNearestNeighbor([224, 224])
.toFloat()
.div(tf.scalar(255.0))
.expandDims(0);
return model.predict(input);
});
prediction.print();
prediction.dispose(); // Libérer explicitement après usage
// Monitorer la mémoire (debugging)
console.info('Tenseurs actifs:', tf.memory().numTensors);
console.info('Bytes GPU:', tf.memory().numBytes);
| Modèle | Taille | Vitesse | Usage principal |
|---|---|---|---|
| MobileNet v2 | ~20 MB | Rapide | Classification images 1000 classes |
| COCO-SSD | ~25 MB | Rapide | Détection d'objets en temps réel |
| BlazeFace | ~1 MB | Très rapide | Détection de visages |
| PoseNet / MoveNet | ~10 MB | Rapide | Estimation de pose humaine |
| Universal Sentence Encoder | ~30 MB | Moyen | Embeddings de texte, similarité |
| Handpose | ~5 MB | Rapide | Détection main et doigts |
tf.tidy() ou dispose(), vous créez des fuites qui ralentissent ou bloquent le navigateur. Vérifiez tf.memory().numTensors avant/après une opération pour détecter les fuites.