How to use TensorFlow Lite for Text Classification in Jetpack Compose

TensorFlow Lite Text Classifier Android App

Overview

This Android app uses a TensorFlow Lite model to classify social media posts into 11 categories, such as technology, sports, and finance. Inference runs entirely on-device, so the text never leaves the phone and no network round trip is needed. The app is written in Kotlin with a Jetpack Compose UI.


Key Features

  • On-device text classification using TensorFlow Lite
  • Modern Jetpack Compose UI for seamless user interaction
  • Multi-line text input for flexible post entry
  • Displays probability distribution across categories
  • Fast, lightweight inference without server calls
  • Manual tokenizer with vocab.txt for consistent results

Key Features Table

Feature | Description
On-Device Classification | Classifies text locally using model_with_softmax.tflite.
Jetpack Compose UI | Modern, responsive interface for text input and results.
Probability Output | Shows confidence scores for each category (e.g., "Sports: 92.34%").

Implementation Steps

This guide walks through building an Android app that classifies text input with a pre-trained TensorFlow Lite model and shows the results in a Jetpack Compose UI.

Step 1: Set Up Dependencies

Add the TensorFlow Lite runtime and support library to the dependencies block of your app-level build.gradle file.


implementation 'org.tensorflow:tensorflow-lite:2.12.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.3.1'
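
If the module uses the Kotlin DSL (build.gradle.kts) instead of Groovy, the same two dependencies would look roughly like this. The versions are the ones listed above; the file name and the surrounding dependencies block are assumptions about your project layout.

```kotlin
// app/build.gradle.kts (Kotlin DSL equivalent of the Groovy lines above)
dependencies {
    // Core TensorFlow Lite runtime (provides org.tensorflow.lite.Interpreter)
    implementation("org.tensorflow:tensorflow-lite:2.12.0")
    // Support library (provides FileUtil.loadMappedFile, used in Step 3)
    implementation("org.tensorflow:tensorflow-lite-support:0.3.1")
}
```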
    

Step 2: Add Assets

Place the following files in app/src/main/assets/:

  • model_with_softmax.tflite – Pre-trained classification model
  • vocab.txt – Tokenizer vocabulary
  • labels.txt – List of category labels
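
The loaders in Step 3 assume a one-entry-per-line layout for both text files: the line index of a token in vocab.txt is its integer id, and the line index of a label in labels.txt is its position in the model output. The sketch below shows that mapping on made-up contents. The sample tokens and labels are illustrative, not taken from the real assets, and `parseVocab` and `parseLabels` are hypothetical helper names.

```kotlin
// Hypothetical helpers mirroring how the asset files are read: one entry per line.
fun parseVocab(lines: List<String>): Map<String, Int> =
    lines.map { it.trim() }
        .withIndex()
        .associate { (index, token) -> token to index }

fun parseLabels(lines: List<String>): List<String> =
    lines.map { it.trim() }.filter { it.isNotEmpty() }

// Made-up sample contents, for illustration only.
val sampleVocab = parseVocab(listOf("[PAD]", "[UNK]", "the", "match", "stock"))
val sampleLabels = parseLabels(listOf("technology", "sports", "finance", ""))
```

Here "match" maps to id 3 because it sits on the fourth line, and the trailing empty line in the label list is ignored.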

Step 3: Initialize Text Classifier

Load the model, vocabulary, and labels in Kotlin.


package com.example.tfliteapp

import android.content.Context
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil

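object TextClassifier {
    private lateinit var interpreter: Interpreter
    private lateinit var vocab: Map<String, Int>   // token -> integer id
    private lateinit var labels: List<String>      // output index -> category name

    // Call once (for example from MainActivity.onCreate) before predict().
    fun initialize(context: Context) {
        // Memory-map the model from assets and create the interpreter.
        val model = FileUtil.loadMappedFile(context, "model_with_softmax.tflite")
        interpreter = Interpreter(model)
        vocab = loadVocab(context)
        labels = loadLabels(context)
    }

    // Each line of vocab.txt is one token; its line index is the token id.
    private fun loadVocab(context: Context): Map<String, Int> {
        return context.assets.open("vocab.txt").bufferedReader().useLines { lines ->
            lines.withIndex().associate { (index, word) -> word to index }
        }
    }

    // Each line of labels.txt is one category, in model output order.
    // Blank lines are skipped so a trailing newline does not add an empty label.
    private fun loadLabels(context: Context): List<String> {
        return context.assets.open("labels.txt").bufferedReader().readLines()
            .filter { it.isNotBlank() }
    }
}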
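
Before running inference it can help to confirm that the assets agree with each other. `Interpreter.getInputTensor(0).shape()` and `Interpreter.getOutputTensor(0).shape()` report what the loaded model expects. The sketch below compares those shapes with the label count and the sequence length. `describeMismatch` is a hypothetical helper, and the shapes used in the comments ([1, 50] in, [1, 11] out) are assumptions based on the code in this post, not values read from the real model.

```kotlin
// Hypothetical helper: returns null when the shapes are consistent, else a message.
fun describeMismatch(
    inputShape: IntArray,   // e.g. interpreter.getInputTensor(0).shape(), assumed [1, 50]
    outputShape: IntArray,  // e.g. interpreter.getOutputTensor(0).shape(), assumed [1, 11]
    labelCount: Int,        // labels.size
    maxLen: Int = 50        // sequence length used by preprocess() in Step 5
): String? {
    val seqLen = inputShape.lastOrNull()
    val classes = outputShape.lastOrNull()
    return when {
        seqLen != maxLen ->
            "Model expects sequence length $seqLen, but preprocess pads to $maxLen"
        classes != labelCount ->
            "Model outputs $classes scores, but labels.txt has $labelCount entries"
        else -> null
    }
}
```

One way to use it would be to call it at the end of `initialize` and fail fast when it returns a message, rather than hitting a shape error on the first prediction.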
    

Step 4: Build Jetpack Compose UI

Create a UI with a text input field, classify button, and result display.


package com.example.tfliteapp

import android.os.Bundle
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import com.example.tfliteapp.ui.theme.TFliteAppTheme
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

class MainActivity : ComponentActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        enableEdgeToEdge()
        // Load the model, vocabulary, and labels once before the UI is shown.
        TextClassifier.initialize(this)
        setContent {
            TFliteAppTheme {
                Scaffold(modifier = Modifier.fillMaxSize()) { innerPadding ->
                    HomeScreen(modifier = Modifier.padding(innerPadding))
                }
            }
        }
    }
}

@Composable
fun HomeScreen(modifier: Modifier = Modifier) {
    var text by remember { mutableStateOf("") }
    var res by remember { mutableStateOf("") }
    var isLoading by remember { mutableStateOf(false) }
    var error by remember { mutableStateOf("") }
    // Scope tied to this composable; used to run inference off the main thread.
    val scope = rememberCoroutineScope()

    Column(
        modifier = modifier
            .fillMaxSize()
            .padding(16.dp)
            .verticalScroll(rememberScrollState()),
        horizontalAlignment = Alignment.CenterHorizontally,
        verticalArrangement = Arrangement.Center
    ) {
        Card(
            modifier = Modifier
                .fillMaxWidth()
                .padding(8.dp)
        ) {
            TextField(
                value = text,
                onValueChange = { text = it },
                placeholder = { Text("Enter your post here...") },
                modifier = Modifier
                    .fillMaxWidth()
                    .height(150.dp)
                    .padding(8.dp),
                maxLines = 10,
                textStyle = androidx.compose.ui.text.TextStyle(textAlign = TextAlign.Start)
            )
        }

        Spacer(modifier = Modifier.height(16.dp))

        Button(
            // Disabled while a prediction is running, so calls never overlap.
            enabled = !isLoading,
            onClick = {
                if (text.isBlank()) {
                    error = "Please enter some text."
                    return@Button
                }
                error = ""
                isLoading = true
                val input = text
                scope.launch {
                    // Run the model on a background dispatcher so the spinner can render.
                    res = withContext(Dispatchers.Default) { TextClassifier.predict(input) }
                    isLoading = false
                }
            },
            modifier = Modifier
                .fillMaxWidth(0.8f)
                .height(50.dp)
        ) {
            Text("Predict", style = MaterialTheme.typography.bodyLarge)
        }

        Spacer(modifier = Modifier.height(16.dp))

        if (isLoading) {
            CircularProgressIndicator()
        }

        if (error.isNotEmpty()) {
            Text(
                text = error,
                color = MaterialTheme.colorScheme.error,
                style = MaterialTheme.typography.bodyMedium,
                modifier = Modifier.fillMaxWidth()
            )
        }

        Spacer(modifier = Modifier.height(8.dp))

        Text(
            text = "Prediction Result:",
            style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.Bold),
            modifier = Modifier.fillMaxWidth()
        )
        Text(
            text = res,
            style = MaterialTheme.typography.bodyLarge,
            modifier = Modifier.fillMaxWidth()
        )
    }
}
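
The screen tracks three related values (`res`, `isLoading`, `error`) separately, which allows combinations such as an error message and a result being shown together. One possible alternative, offered as a suggestion rather than as part of the original app, is a single sealed type. `ClassifyUiState` and `onPredictClicked` are hypothetical names.

```kotlin
// Hypothetical single-state model for the screen (not part of the original app).
sealed interface ClassifyUiState {
    object Idle : ClassifyUiState
    object Loading : ClassifyUiState
    data class Success(val result: String) : ClassifyUiState
    data class Failure(val message: String) : ClassifyUiState
}

// Decides which state to show right after the Predict button is tapped.
fun onPredictClicked(text: String): ClassifyUiState =
    if (text.isBlank()) ClassifyUiState.Failure("Please enter some text.")
    else ClassifyUiState.Loading
```

The composable would then hold one `mutableStateOf<ClassifyUiState>(ClassifyUiState.Idle)` and render with a `when` over the four cases.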
    

Step 5: Implement Prediction Logic

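Add predict() and preprocess() to the TextClassifier object from Step 3. preprocess() turns the text into a fixed-length array of token ids, and predict() runs the TensorFlow Lite model on that array and formats the resulting probabilities.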


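package com.example.tfliteapp

// These members belong inside the TextClassifier object from Step 3 (they use its
// interpreter, vocab, and labels fields). Do not declare a second object.
object TextClassifier {
    // ... fields, initialize(), loadVocab(), and loadLabels() from Step 3 ...

    fun predict(text: String): String {
        val input = preprocess(text, vocab)
        // One row of scores, one score per label. The model file name suggests
        // softmax is already applied, so the scores are treated as probabilities.
        val output = Array(1) { FloatArray(labels.size) }
        // Input is a batch of one sequence: shape [1, maxLen].
        interpreter.run(arrayOf(input), output)
        val probs = output[0]
        // Pair each label with its score and list the most likely category first.
        val sorted = labels.zip(probs.toList())
            .sortedByDescending { it.second }
        return sorted.joinToString(separator = "\n") { (label, prob) ->
            "$label: ${"%.2f".format(prob.coerceIn(0f, 1f) * 100)}%"
        }
    }

    // Must mirror the preprocessing that was applied when the model was trained.
    private fun preprocess(text: String, vocab: Map<String, Int>, maxLen: Int = 50): IntArray {
        val tokens = text.lowercase()
            .replace(Regex("\\s+"), " ")        // newlines and tabs become spaces, so lines do not merge
            .replace(Regex("[^a-z0-9 ]"), "")   // drop punctuation and symbols
            .split(" ")
            .filter { it.isNotEmpty() }         // ignore empty tokens from repeated spaces
        val unk = vocab["[UNK]"] ?: 0
        val indices = tokens.map { vocab[it] ?: unk }
        // Zero-filled array: shorter inputs are padded with 0, longer ones are truncated.
        val padded = IntArray(maxLen)
        for (i in indices.indices.take(maxLen)) {
            padded[i] = indices[i]
        }
        return padded
    }
}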
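
Printing all 11 categories can be a lot for a small screen. As an illustration of the sort-and-format step in `predict`, the sketch below keeps only the k most likely labels and formats with a fixed locale so the decimal separator does not change with the device language. `topK` is a hypothetical helper, and the labels and probabilities are made-up values.

```kotlin
import java.util.Locale

// Hypothetical helper: pair labels with scores, keep the k highest, format as percentages.
fun topK(labels: List<String>, probs: FloatArray, k: Int = 3): List<String> =
    labels.zip(probs.toList())
        .sortedByDescending { it.second }
        .take(k)
        .map { (label, prob) ->
            "%s: %.2f%%".format(Locale.US, label, prob.coerceIn(0f, 1f) * 100)
        }

// Made-up example values.
val topTwo = topK(
    labels = listOf("technology", "sports", "finance"),
    probs = floatArrayOf(0.05f, 0.9234f, 0.0266f),
    k = 2
)
// topTwo = ["sports: 92.34%", "technology: 5.00%"]
```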
    

FAQ

Which Android versions support TensorFlow Lite?

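Because the UI is built with Jetpack Compose, the app needs a minSdk of 21 (Android 5.0) or higher, which covers most devices in use. The TensorFlow Lite runtime has its own minimum SDK that depends on the library version, so check the release notes for the version you pin.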


Do I need internet for classification?

No, the app runs fully on-device with no server dependency.


Can I add more categories?

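Yes. Retrain the model using the provided Android.ipynb notebook, replace model_with_softmax.tflite with the new export, and update labels.txt so that its line order matches the new model's outputs.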


Why use a manual tokenizer?

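The model only understands the integer ids it saw during training. Tokenizing manually with the same vocab.txt maps each word to the same id on the device as it had at training time, which keeps training and inference consistent.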
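
As a concrete illustration, here is a minimal, self-contained sketch of the lowercase, strip, split, look-up, and pad steps on a tiny made-up vocabulary. The vocabulary, the convention that id 0 is `[PAD]`, and the function name `encode` are assumptions for this example. The real ids come from vocab.txt and must match whatever preprocessing the training notebook applied.

```kotlin
// Minimal sketch of a vocabulary-based tokenizer (illustrative values only).
fun encode(text: String, vocab: Map<String, Int>, maxLen: Int): IntArray {
    val unk = vocab["[UNK]"] ?: 0
    val ids = text.lowercase()
        .replace(Regex("\\s+"), " ")        // treat newlines and tabs as spaces
        .replace(Regex("[^a-z0-9 ]"), "")   // drop punctuation and symbols
        .split(" ")
        .filter { it.isNotEmpty() }         // ignore empty tokens from repeated spaces
        .map { vocab[it] ?: unk }           // unknown words map to [UNK]
    // Copy at most maxLen ids; remaining slots stay 0, assumed here to be [PAD].
    return IntArray(maxLen) { i -> ids.getOrElse(i) { 0 } }
}

val tinyVocab = mapOf("[PAD]" to 0, "[UNK]" to 1, "great" to 2, "match" to 3, "today" to 4)
val encoded = encode("Great match,\ntoday!!  Wow", tinyVocab, maxLen = 6)
// encoded = [2, 3, 4, 1, 0, 0]: three known words, one unknown ("wow"), two pad slots
```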

Quick Notes

TFLite Text Classifier Jetpack Compose App

1. Training the model (from text to tflite file)


2. Model Overview (what the model expects)

3. Plugin Setup (Use latest version)

4. Project Assets (Place model & label files)

5. Kotlin Initialization (Load the classifier)

7. Jetpack Compose UI (Input Text + Run prediction)

8. Input & Output

Source Code

* GitHub: TFLite Text Classifier Android App

* Model: Colab for training the classifier


