Prototypical Networks for Few-Shot Image Classification and Its Use in the Retail Industry
This project implements a few-shot image classification system using Prototypical Networks. The goal is to enable the model to recognize new classes with only a few examples by leveraging metric-based learning. I then explore its application in the retail industry.
Overview
Few-shot learning aims to train models that can adapt to new tasks with a minimal amount of labeled data. Prototypical Networks achieve this by learning a metric space where classification can be performed by computing distances to prototype representations of each class.
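Concretely, a Prototypical Network averages the embeddings of each class's support examples into a prototype, then classifies a query by its distance to those prototypes. A minimal standalone sketch of that scoring step in PyTorch (illustrative only, independent of the EasyFSL implementation used below):

```python
import torch

def prototypical_scores(support_features, support_labels, query_features):
    """Compute class prototypes as the mean support embedding per class,
    then score queries by negative squared Euclidean distance."""
    classes = torch.unique(support_labels)
    prototypes = torch.stack(
        [support_features[support_labels == c].mean(dim=0) for c in classes]
    )
    # (n_query, n_classes): higher score means closer to that prototype
    return -torch.cdist(query_features, prototypes) ** 2
```

The query is assigned to the class whose prototype scores highest, e.g. `prototypical_scores(s, y, q).argmax(dim=1)`.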
This implementation uses a pre-trained ResNet18 model as the backbone feature extractor and leverages the EasyFSL library for streamlined few-shot learning processes.
Features
- Custom Dataset Support: Easily integrate your own datasets by following simple formatting guidelines.
- Flexible Hyperparameters: Adjust the number of classes (`N_WAY`), support samples (`N_SHOT`), and query samples (`N_QUERY`) per task.
- Easy Training and Evaluation: Train and evaluate the model using simple function calls.
- Inference Support: Perform inference on individual images or live video streams.
- Visualization: Visualize support and query images for better understanding of tasks.
- Modular Design: Clean and modular code structure for easy customization and extension.
Dataset Preparation
To use your own dataset, organize your data as follows:
- Images Directory:
  - Structure your images in subdirectories, where each subdirectory corresponds to a class and contains that class's images.
  - Example:

    ```
    dataset/
    ├── class1/
    │   ├── image1.jpg
    │   └── image2.jpg
    ├── class2/
    │   ├── image3.jpg
    │   └── image4.jpg
    ...
    ```
- CSV File:
  - Create a CSV file mapping class IDs to class names.
  - Example (`classes.csv`):

    ```
    Class ID (int),Class Name (str)
    0,Cat
    1,Dog
    2,Bird
    ```
- Text Files:
  - Create text files listing image paths and the corresponding class IDs for training and testing.
  - Example (`train.txt`):

    ```
    class1/image1.jpg, 0, -
    class2/image3.jpg, 1, -
    ```
Ensure all paths in text files are relative to the root directory of your dataset.
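To illustrate this layout, the two metadata files could be parsed with plain Python as follows. This is a hedged sketch: the helper names `load_class_names` and `load_split` are illustrative and not part of the project's code.

```python
import csv

def load_class_names(csv_path):
    """Read the Class ID -> Class Name mapping from a classes.csv file."""
    with open(csv_path, newline='') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        return {int(row[0]): row[1] for row in reader}

def load_split(txt_path):
    """Read (relative image path, class id) pairs from a split file."""
    samples = []
    with open(txt_path) as f:
        for line in f:
            parts = [p.strip() for p in line.split(',')]
            if len(parts) >= 2 and parts[0]:
                samples.append((parts[0], int(parts[1])))
    return samples
```

Joining `root_dir` with each relative path from `load_split` then yields the absolute image locations.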
Installation
- Install EasyFSL

  ```shell
  pip install easyfsl
  ```

- (Optional) Set up Google Colab
  - If using Google Colab, mount your Google Drive to access datasets and save models.

    ```python
    from google.colab import drive
    drive.mount('/content/drive')
    ```
Usage
Training
- Configure Paths and Hyperparameters
  - Edit the placeholders in the code with the actual paths to your dataset and your desired hyperparameters.

    ```python
    csv_file = 'path/to/classes.csv'
    root_dir = 'path/to/dataset/'
    train_txt_file = 'path/to/train.txt'
    test_txt_file = 'path/to/test.txt'

    N_WAY = 5
    N_SHOT = 5
    N_QUERY = 5
    N_TRAINING_EPISODES = 10000
    ```
- Run Training
  - Execute the training section of the code to start training.

    ```python
    # Training loop
    with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
        for episode_index, (...) in tqdm_train:
            loss_value = fit(...)
    ```
- Save Model
  - After training, save the model's weights for future use.

    ```python
    torch.save(model.state_dict(), 'path/to/save/model.pth')
    ```
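The `fit(...)` call in the training loop is left abstract. One plausible implementation of a single episodic update, sketched here under the assumption that the model exposes EasyFSL's `process_support_set` method and the `(support, labels, query)` forward convention used in the inference example later in this README (the exact forward signature varies across EasyFSL versions):

```python
import torch
from torch import nn

def fit(model, optimizer, criterion,
        support_images, support_labels, query_images, query_labels):
    """One episodic training step: build prototypes from the support set,
    score the query images, and backpropagate the classification loss."""
    optimizer.zero_grad()
    model.process_support_set(support_images, support_labels)
    # Support set was already registered above, hence the None placeholders
    scores = model(None, None, query_images)
    loss = criterion(scores, query_labels)
    loss.backward()
    optimizer.step()
    return loss.item()
```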
Evaluation
- Load Trained Model

  ```python
  model.load_state_dict(torch.load('path/to/saved/model.pth'))
  ```
- Run Evaluation
  - Evaluate the model's performance on test tasks.

    ```python
    evaluate(test_loader)
    ```
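`evaluate` is not spelled out above. A possible implementation, written here to take the model explicitly and assuming EasyFSL's five-element task batches `(support_images, support_labels, query_images, query_labels, class_ids)` and the same forward convention as the inference example below:

```python
import torch

def evaluate(model, data_loader, device="cpu"):
    """Average classification accuracy over all evaluation tasks."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for support_images, support_labels, query_images, query_labels, _ in data_loader:
            # Build prototypes from this task's support set, then score queries
            model.process_support_set(support_images.to(device),
                                      support_labels.to(device))
            predictions = model(None, None, query_images.to(device)).argmax(dim=1)
            correct += (predictions == query_labels.to(device)).sum().item()
            total += len(query_labels)
    accuracy = 100 * correct / total
    print(f"Evaluation on {len(data_loader)} tasks completed. Accuracy: {accuracy:.2f}%")
    return accuracy
```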
Inference on Single Image
- Prepare Support Set
  - Organize a support set directory with a subfolder for each class, each containing a few images.

    ```python
    support_set = SupportSetFolder(
        root='path/to/support/set/',
        transform=transform,
        device=device
    )
    ```
- Load and Preprocess Query Image

  ```python
  query_image_path = 'path/to/query/image.jpg'
  query_image_PIL = Image.open(query_image_path)
  query_image = transform_tensor(query_image_PIL).unsqueeze(0).to(device)
  ```
- Perform Inference

  ```python
  with torch.no_grad():
      model.process_support_set(support_set.get_images(), support_set.get_labels())
      scores = model(None, None, query_image)

  predicted_label = scores.argmax(dim=1).item()
  predicted_class = support_set.classes[predicted_label]
  print(f"Predicted Class: {predicted_class}")
  ```
Real-time Video Classification
- Run Live Classification
  - Ensure you have a webcam connected.

    ```python
    cap = cv2.VideoCapture(0)

    while True:
        ret, frame = cap.read()
        if not ret:  # stop if the camera yields no frame
            break
        class_name = classify_frame(frame)
        cv2.putText(frame, f"Class: {class_name}", (10, 30), ...)
        cv2.imshow('Live Classification', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
    ```
Applications in Retail Industry
Few-shot learning is particularly advantageous in the retail industry, where it can be used to efficiently train models to classify items in a store with minimal labeled data. Traditional machine learning models require large datasets of labeled images for each item, which can be time-consuming and expensive to collect. However, few-shot learning models like Prototypical Networks can recognize new items with only a few examples.
Benefits:
- Efficient Training: Store items can be classified with just a handful of images per class, reducing the overhead of data collection.
- Real-Time Classification: The trained model can be deployed in real-time applications, such as self-checkout systems.
- Theft Reduction: By accurately classifying items at self-checkout kiosks, few-shot learning can help prevent theft by ensuring that items are correctly scanned and billed.
- Scalability: New items can be added to the inventory with minimal additional data collection, making the system scalable and adaptable to changing inventories.
In summary, few-shot learning provides a powerful tool for retailers to improve their operational efficiency and security by enabling rapid, accurate classification of products with minimal data requirements.
Customization
- Backbone Network: Replace `resnet18` with other architectures such as `resnet50`, or with custom models as needed.
- Transformations: Modify the image transformations to include data augmentation techniques such as rotation, flipping, or normalization.
- Hyperparameters: Experiment with different values for `N_WAY`, `N_SHOT`, `N_QUERY`, and the learning rate to improve performance.
- Loss Functions and Optimizers: Explore other loss functions (e.g., `TripletLoss`) and optimizers (e.g., `SGD`) based on your use case.
Dependencies
- Python 3.6+
- PyTorch 1.7+
- torchvision
- EasyFSL
- numpy
- pandas
- Pillow
- OpenCV
- tqdm
Results
After training for a sufficient number of episodes, you should expect reasonable accuracy on novel classes with limited examples. The exact performance will depend on the complexity of the dataset and chosen hyperparameters.
Example Evaluation Output:

```
Evaluation on 100 tasks completed. Accuracy: 85.40%
```