Prototypical Networks for Few-Shot Image Classification and Its Use in the Retail Industry

This project implements a few-shot image classification system using Prototypical Networks. The goal is to enable the model to recognize new classes with only a few examples by leveraging metric-based learning. I then explore its application in the retail industry.

Overview

Few-shot learning aims to train models that can adapt to new tasks with a minimal amount of labeled data. Prototypical Networks achieve this by learning a metric space where classification can be performed by computing distances to prototype representations of each class.

This implementation uses a pre-trained ResNet18 model as the backbone feature extractor and leverages the EasyFSL library for streamlined few-shot learning processes.
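
For reference, model construction with EasyFSL typically looks like the sketch below. The exact weights argument depends on your torchvision version (older releases use pretrained=True), so treat this as a minimal example rather than the project's exact code.

    # Minimal sketch: a pre-trained ResNet18 whose classification head is replaced
    # by a Flatten layer, wrapped in EasyFSL's Prototypical Networks so that queries
    # are classified by their distance to class prototypes in the embedding space.
    import torch
    from torch import nn
    from torchvision.models import resnet18
    from easyfsl.methods import PrototypicalNetworks

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    backbone = resnet18(weights="IMAGENET1K_V1")  # pre-trained feature extractor
    backbone.fc = nn.Flatten()                    # keep embeddings instead of 1000-way logits

    model = PrototypicalNetworks(backbone).to(device)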

Features

  • Custom Dataset Support: Easily integrate your own datasets by following simple formatting guidelines.
  • Flexible Hyperparameters: Adjust the number of classes (N_WAY), support samples (N_SHOT), and query samples (N_QUERY) per task.
  • Easy Training and Evaluation: Train and evaluate the model using simple function calls.
  • Inference Support: Perform inference on individual images or live video streams.
  • Visualization: Visualize support and query images for better understanding of tasks.
  • Modular Design: Clean and modular code structure for easy customization and extension.

Dataset Preparation

To use your own dataset, organize your data as follows:

  1. Images Directory:

    • Structure your images in subdirectories where each subdirectory corresponds to a class and contains images of that class.
    • Example:
      dataset/
      β”œβ”€β”€ class1/
      β”‚   β”œβ”€β”€ image1.jpg
      β”‚   β”œβ”€β”€ image2.jpg
      β”œβ”€β”€ class2/
      β”‚   β”œβ”€β”€ image3.jpg
      β”‚   β”œβ”€β”€ image4.jpg
      ...
  2. CSV File:

    • Create a CSV file mapping class IDs to class names.
    • Example (classes.csv):
      Class ID (int),Class Name (str)
      0,Cat
      1,Dog
      2,Bird
  3. Text Files:

    • Create text files listing image paths and corresponding class IDs for training and testing.
    • Example (train.txt):
      class1/image1.jpg, 0, -
      class2/image3.jpg, 1, -

Ensure that all paths in the text files are relative to the root directory of your dataset.
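
As an illustration of how this layout can be consumed, a hypothetical PyTorch Dataset is sketched below (the class name FewShotImageDataset and the parsing details are assumptions, not the project's actual code). The get_labels() method is what EasyFSL's task sampler expects when building N-way tasks.

    import os

    import pandas as pd
    from PIL import Image
    from torch.utils.data import Dataset

    class FewShotImageDataset(Dataset):
        """Reads the classes.csv / train.txt format described above."""

        def __init__(self, root_dir, txt_file, csv_file, transform=None):
            self.root_dir = root_dir
            self.transform = transform
            # classes.csv: "Class ID (int),Class Name (str)"
            self.class_names = pd.read_csv(csv_file)["Class Name (str)"].tolist()
            # train.txt / test.txt: "relative/path.jpg, class_id, -"
            self.samples = []
            with open(txt_file) as f:
                for line in f:
                    path, label, _ = [part.strip() for part in line.split(",")]
                    self.samples.append((path, int(label)))

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, index):
            path, label = self.samples[index]
            image = Image.open(os.path.join(self.root_dir, path)).convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, label

        def get_labels(self):
            # Required by EasyFSL's TaskSampler to group images into N-way tasks.
            return [label for _, label in self.samples]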

Installation

  1. Install EasyFSL

    pip install easyfsl
  2. (Optional) Setup Google Colab

    • If using Google Colab, mount your Google Drive to access datasets and save models.
    from google.colab import drive
    drive.mount('/content/drive')

Usage

Training

  1. Configure Paths and Hyperparameters

    • Edit the placeholders in the code with actual paths to your dataset and desired hyperparameters.
    csv_file = 'path/to/classes.csv'
    root_dir = 'path/to/dataset/'
    train_txt_file = 'path/to/train.txt'
    test_txt_file = 'path/to/test.txt'
    N_WAY = 5
    N_SHOT = 5
    N_QUERY = 5
    N_TRAINING_EPISODES = 10000
  2. Run Training

    • Execute the training section of the code to start training (a fuller sketch of the episodic loop appears after this list).
    # Training loop
    with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
        for episode_index, (...) in tqdm_train:
            loss_value = fit(...)
  3. Save Model

    • After training, save the model for future use.
    torch.save(model.state_dict(), 'path/to/save/model.pth')
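
For context, a minimal sketch of the episodic pipeline behind these steps is shown below. It assumes EasyFSL's TaskSampler, a dataset exposing get_labels() (called train_set here), and the model and device from the Overview sketch; the fit helper is an illustrative stand-in for the project's own training step.

    from torch import nn, optim
    from torch.utils.data import DataLoader
    from easyfsl.samplers import TaskSampler
    from tqdm import tqdm

    # One "batch" is a full N-way, K-shot task together with its query set.
    train_sampler = TaskSampler(
        train_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
    )
    train_loader = DataLoader(
        train_set, batch_sampler=train_sampler, collate_fn=train_sampler.episodic_collate_fn
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    def fit(support_images, support_labels, query_images, query_labels):
        """One episode: build prototypes from the support set, then classify the queries."""
        optimizer.zero_grad()
        model.process_support_set(support_images.to(device), support_labels.to(device))
        scores = model(query_images.to(device))
        loss = criterion(scores, query_labels.to(device))
        loss.backward()
        optimizer.step()
        return loss.item()

    model.train()
    with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
        for episode_index, (support_images, support_labels, query_images, query_labels, _) in tqdm_train:
            loss_value = fit(support_images, support_labels, query_images, query_labels)
            tqdm_train.set_postfix(loss=loss_value)

Recent EasyFSL releases score queries with model(query_images) once process_support_set(...) has been called; if your installed version uses a different forward signature, adjust the fit helper accordingly.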

Evaluation

  1. Load Trained Model

    model.load_state_dict(torch.load('path/to/saved/model.pth'))
  2. Run Evaluation

    • Evaluate the model's performance on test tasks.
    evaluate(test_loader)
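
The evaluate helper called above is not spelled out in this README; one possible implementation, consistent with the example output in the Results section and with the episodic loader format used for training, is:

    import torch

    def evaluate(data_loader):
        """Average accuracy of the model over the test tasks in data_loader."""
        total_correct, total_predictions = 0, 0
        model.eval()
        with torch.no_grad():
            for support_images, support_labels, query_images, query_labels, _ in data_loader:
                model.process_support_set(support_images.to(device), support_labels.to(device))
                predictions = model(query_images.to(device)).argmax(dim=1)
                total_correct += (predictions == query_labels.to(device)).sum().item()
                total_predictions += len(query_labels)
        accuracy = 100 * total_correct / total_predictions
        print(f"Evaluation on {len(data_loader)} tasks completed. Accuracy: {accuracy:.2f}%")
        return accuracy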

Inference on Single Image

  1. Prepare Support Set

    • Organize a support set directory with subfolders for each class containing a few images.
    support_set = SupportSetFolder(
        root='path/to/support/set/',
        transform=transform,
        device=device
    )
  2. Load and Preprocess Query Image

    query_image_path = 'path/to/query/image.jpg'
    query_image_PIL = Image.open(query_image_path)
    query_image = transform_tensor(query_image_PIL).unsqueeze(0).to(device)
  3. Perform Inference

    with torch.no_grad():
        model.process_support_set(support_set.get_images(), support_set.get_labels())
        scores = model(None, None, query_image)
        predicted_label = scores.argmax(dim=1).item()
        predicted_class = support_set.classes[predicted_label]

    print(f"Predicted Class: {predicted_class}")
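
The snippets above assume that transform, transform_tensor, and device were defined earlier in the script. A plausible setup is sketched below (the image size and normalization values are assumptions, not necessarily the project's settings). Note that SupportSetFolder is provided by recent EasyFSL releases, whose few-shot classifiers score queries with model(query_image) once process_support_set(...) has been called; adjust the scoring line above if your installed version uses a different forward signature.

    import torch
    from torchvision import transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Preprocessing pipeline applied to both support and query images.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform_tensor = transform  # reuse the same pipeline for the query image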

Real-time Video Classification

  1. Run Live Classification
    • Ensure you have a webcam connected.
    cap = cv2.VideoCapture(0)
    while True:
        ret, frame = cap.read()
        class_name = classify_frame(frame)
        cv2.putText(frame, f"Class: {class_name}", (10, 30), ...)
        cv2.imshow('Live Classification', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()
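
The classify_frame helper used above is sketched below as an assumption about its implementation; it reuses the model, transform, device, and support_set objects from the previous section. Building the prototypes once before the loop avoids reprocessing the support set on every frame.

    import cv2
    import torch
    from PIL import Image

    # Build class prototypes once; only the query changes from frame to frame.
    model.eval()
    model.process_support_set(support_set.get_images(), support_set.get_labels())

    def classify_frame(frame):
        """Classify one OpenCV BGR frame against the current support set."""
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image_tensor = transform(Image.fromarray(rgb_frame)).unsqueeze(0).to(device)
        with torch.no_grad():
            scores = model(image_tensor)
        return support_set.classes[scores.argmax(dim=1).item()]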

Applications in Retail Industry

Few-shot learning is particularly advantageous in the retail industry, where it can be used to efficiently train models to classify items in a store with minimal labeled data. Traditional machine learning models require large datasets of labeled images for each item, which can be time-consuming and expensive to collect. However, few-shot learning models like Prototypical Networks can recognize new items with only a few examples.

Benefits:

  • Efficient Training: Store items can be classified with just a handful of images per class, reducing the overhead of data collection.
  • Real-Time Classification: The trained model can be deployed in real-time applications, such as self-checkout systems.
  • Theft Reduction: By accurately classifying items at self-checkout kiosks, few-shot learning can help prevent theft by ensuring that items are correctly scanned and billed.
  • Scalability: New items can be added to the inventory with minimal additional data collection, making the system scalable and adaptable to changing inventories.

In summary, few-shot learning provides a powerful tool for retailers to improve their operational efficiency and security by enabling rapid, accurate classification of products with minimal data requirements.

Customization

  • Backbone Network: Replace resnet18 with other architectures such as resnet50 or a custom model as needed (see the sketch after this list).
  • Transformations: Modify image transformations to include data augmentation techniques such as rotation, flipping, or normalization.
  • Hyperparameters: Experiment with different values for N_WAY, N_SHOT, N_QUERY, and learning rates to improve performance.
  • Loss Functions and Optimizers: Explore other loss functions (e.g., TripletLoss) and optimizers (e.g., SGD) based on your use case.
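
As a concrete illustration of the first two bullets, a backbone swap and an augmented training transform might look like the sketch below (the weight enum and augmentation choices are examples, not the project's settings; device is assumed to be defined as before).

    from torch import nn
    from torchvision import transforms
    from torchvision.models import resnet50
    from easyfsl.methods import PrototypicalNetworks

    # Heavier backbone: ResNet50 embeddings instead of ResNet18.
    backbone = resnet50(weights="IMAGENET1K_V2")
    backbone.fc = nn.Flatten()
    model = PrototypicalNetworks(backbone).to(device)

    # Training-time augmentation on top of the usual resize/normalize pipeline.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])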

Dependencies

  • Python 3.6+
  • PyTorch 1.7+
  • torchvision
  • EasyFSL
  • numpy
  • pandas
  • Pillow
  • OpenCV
  • tqdm

Results

After training for a sufficient number of episodes, you should expect reasonable accuracy on novel classes with limited examples. The exact performance will depend on the complexity of the dataset and chosen hyperparameters.

Example Evaluation Output:

Evaluation on 100 tasks completed. Accuracy: 85.40%

References

  • Snell, J., Swersky, K., & Zemel, R. (2017). Prototypical Networks for Few-shot Learning. Advances in Neural Information Processing Systems (NeurIPS).
  • EasyFSL: A Few-Shot Learning Library (https://github.com/sicara/easy-few-shot-learning)
  • PyTorch Documentation (https://pytorch.org/docs/)