
April 26, 2024

Fine-tuning BLIP2 for Image Caption Generation with PEFT

Customizing BLIP2 with LoRA and HuggingFace on the Flickr30k Dataset

Fine-tuning BLIP2, a state-of-the-art open-source vision-language model, can be a game-changer for various business applications. BLIP2 stands out as one of the most powerful models in its class, excelling at tasks that involve both visual and textual understanding.

This article shows how to fine-tune BLIP2 for the specific task of generating captions for the Flickr30k dataset. We leverage a technique called Low-Rank Adaptation (LoRA) and will use HuggingFace’s PEFT for this process. 

With PEFT, we can adapt BLIP2 using minimal computational resources: only a small set of adapter weights is trained instead of the full model, which significantly reduces the extensive resources that traditional fine-tuning typically requires.
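
To make this concrete, here is a brief preview of what wrapping BLIP2 with a LoRA adapter via PEFT can look like. This is only a sketch: the checkpoint name and every hyperparameter below (r, lora_alpha, lora_dropout, target_modules) are illustrative assumptions, not necessarily the exact configuration used later in this article.

from transformers import Blip2ForConditionalGeneration
from peft import LoraConfig, get_peft_model

# Illustrative only: checkpoint and hyperparameters are assumptions, not the article's exact setup
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map="auto"
)

# LoRA injects small trainable low-rank matrices into selected attention projections,
# so only a tiny fraction of the weights is updated during fine-tuning
config = LoraConfig(
    r=16,                                 # rank of the low-rank update matrices
    lora_alpha=32,                        # scaling factor for the LoRA updates
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj"],  # which projection layers receive adapters
)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # reports how few parameters are actually trainable

Calling print_trainable_parameters() shows that only a small fraction of the model's weights will be updated, which is where the resource savings come from.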

[Figure: fine-tuning BLIP2 demo]

This article is structured as follows:

  • Environment and Libraries Setup
  • Data Acquisition and Preparation
  • Model Fine-Tuning with LoRA
  • Using the Model for Inference

1. Environment Setup

To begin, ensure your system has Python installed. Here’s what you need to install to get started:

  1. PyTorch: Install PyTorch by running pip install torch torchvision. This will be our primary framework for model training.
  2. Transformers and Datasets Libraries: Install these HuggingFace libraries with pip install transformers datasets. They provide the pre-trained models and easy access to the dataset.
  3. PEFT and bitsandbytes: To manage memory efficiently and fine-tune effectively, install PEFT and bitsandbytes by running pip install git+https://github.com/huggingface/peft.git bitsandbytes.

# Install libraries
!pip install -q git+https://github.com/huggingface/peft.git transformers bitsandbytes datasets accelerate


2. Fine-tuning Data Preparation

2.1 Load Data and Split

For this project, we start by loading the Flickr30K dataset using the HuggingFace datasets library. We then select a small subset of the data for both training and testing purposes. Specifically, we choose only the first 6 images for training — this limited selection is for illustration purposes only. 

Although we’re using just a few images here, the entire dataset, containing 30,000 images, could be utilized for a more comprehensive training phase. This small split ensures that we have a manageable amount of data for initial tests and quick adjustments.

First, we transform each sample to keep only the first caption for each image. This simplifies the model's training task, allowing it to focus on one caption per image. Next, we explicitly define the features of our datasets.


from datasets import load_dataset, Features, Value, Image

# Load the dataset (the nlphuji/flickr30k release exposes a single 'test' split)
dataset = load_dataset("nlphuji/flickr30k")

# Select the first 6 entries for the training subset
sample_train_dataset = dataset['test'].select(range(6))

# Select 6 further entries (indices 7 through 12) for the testing subset
sample_test_dataset = dataset['test'].select(range(7, 13))

# Transform each sample into the desired format, keeping only the first caption
def transform_sample(sample):
    return {
        "image": sample["image"],
        "text": sample["caption"][0]  # Taking only the first caption
    }

# Apply the transformation to the train and test subsets
transformed_dataset_train = sample_train_dataset.map(transform_sample, remove_columns=sample_train_dataset.column_names)
transformed_dataset_test = sample_test_dataset.map(transform_sample, remove_columns=sample_test_dataset.column_names)

# Define the features explicitly
features = Features({
    'image': Image(decode=True),  # Images are stored encoded and decoded on access
    'text': Value('string')
})

# Cast both subsets to the defined structure
transformed_dataset_train = transformed_dataset_train.cast(features)
transformed_dataset_test = transformed_dataset_test.cast(features)


# Inspect one training example: print its caption and display its image
print(transformed_dataset_train["text"][1])
transformed_dataset_train["image"][1]


 “Several men in hard hats are operating a giant pulley system.”

[Figure: example training image from the Flickr30k subset]

2.2 Helper Function for Data Transformation

To manage the complexities of handling both image and text data, we develop a custom PyTorch dataset class. It ensures that each item is formatted appropriately before it reaches the model.

Here’s how it works:

  1. Initialization: The constructor of the class takes two primary inputs: the dataset itself and the processor. The processor prepares the dataset’s images and their corresponding captions for efficient processing by the model.
  2. Length Determination: Implementing the __len__ method allows us to quickly query the total number of items in the dataset, facilitating operations that require knowledge of the dataset size, such as batching during training.
  3. Item Retrieval and Processing: The __getitem__ method retrieves a data item by its index. It first extracts the image and the first caption from the dataset. The processor then formats the text and applies necessary transformations to the image. These transformations typically include resizing and normalizing the image and tokenizing the text. 
  4. Batch Collation: To facilitate training, we also implement a collate_fn function. It aggregates individual data points into a batch, stacking all image tensors and padding all text data to a uniform length (see the sketch after this list).
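
Below is a minimal sketch of what such a dataset class and collate function can look like when preparing BLIP2 inputs with HuggingFace. The class name ImageCaptioningDataset is illustrative, and the sketch assumes a Blip2Processor instance named processor is already in scope (it is created when the model is loaded for fine-tuning).

import torch
from torch.utils.data import Dataset, DataLoader

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset      # the transformed HuggingFace dataset
        self.processor = processor  # prepares images (and text) for BLIP2

    def __len__(self):
        # Number of image-caption pairs
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Resize and normalize the image into pixel_values, then drop the batch dimension
        encoding = self.processor(images=item["image"], padding="max_length", return_tensors="pt")
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        # Keep the raw caption so the collate function can tokenize the whole batch at once
        encoding["text"] = item["text"]
        return encoding

def collate_fn(batch):
    # Stack image tensors and pad all captions in the batch to a uniform length
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            # `processor` is assumed to be a module-level Blip2Processor
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch

A DataLoader can then wrap the training subset, for example DataLoader(ImageCaptioningDataset(transformed_dataset_train, processor), batch_size=2, shuffle=True, collate_fn=collate_fn); the batch size here is just an illustrative choice.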
