Fine-tuning BLIP2, a state-of-the-art open-source vision-language model, can be a game-changer for many business applications. One of the most powerful models in its class, BLIP2 excels at tasks that involve both visual and textual understanding.
This article shows how to fine-tune BLIP2 for the specific task of generating captions on the Flickr30k dataset, using Low-Rank Adaptation (LoRA) through HuggingFace's PEFT library.
PEFT lets us adapt BLIP2 with minimal computational resources: instead of updating every weight, it trains small adapter matrices, sharply reducing the memory and compute that traditional full fine-tuning requires.
This article covers installing the required libraries, loading and preparing the Flickr30k data, and building the dataset components used for fine-tuning.
To begin, ensure your system has Python installed. Here’s what you need to install to get started:
- pip install torch torchvision. PyTorch will be our primary framework for model training.
- pip install transformers datasets. These HuggingFace libraries provide the pre-trained model and easy access to the dataset.
- pip install git+https://github.com/huggingface/peft.git bitsandbytes. PEFT supplies the LoRA tooling, and bitsandbytes enables 8-bit model loading.
# Install libraries
!pip install -q git+https://github.com/huggingface/peft.git transformers bitsandbytes datasets accelerate
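With the libraries installed, here is a minimal sketch of how PEFT attaches LoRA adapters to BLIP2. The Salesforce/blip2-opt-2.7b checkpoint, the target modules, and the LoRA hyperparameters (r, lora_alpha, dropout) are illustrative assumptions rather than prescribed values:
import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from peft import LoraConfig, get_peft_model
# Assumed checkpoint; other BLIP2 checkpoints work similarly
checkpoint = "Salesforce/blip2-opt-2.7b"
processor = AutoProcessor.from_pretrained(checkpoint)
# Load in 8-bit (via bitsandbytes) to keep GPU memory usage low
model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint, load_in_8bit=True, device_map="auto"
)
# LoRA injects small low-rank matrices into the attention projections
# and trains only those; the values below are illustrative defaults
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj"],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # typically well under 1% of all parameters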
For this project, we start by loading the Flickr30k dataset with the HuggingFace datasets library. Note that nlphuji/flickr30k ships as a single 'test' split, so both our training and testing subsets are drawn from it. We select only the first 6 images for training; this limited selection is for illustration purposes only.
Although we use just a few images here, the entire dataset of 30,000 images could be used for a more comprehensive training phase. The small split keeps the data manageable for initial tests and quick adjustments.
First, we transform each sample to keep only the first of its captions, which simplifies training by giving the model a single caption per image. Next, we define the features of our datasets explicitly.
from datasets import load_dataset, Dataset, Features, Value, Image
# Load the dataset
dataset = load_dataset("nlphuji/flickr30k")
# The dataset ships a single 'test' split; take its first 6 entries for training
sample_train_dataset = dataset['test'].select(range(6))
# Take the next 6 entries as the testing dataset
sample_test_dataset = dataset['test'].select(range(6, 12))
# Function to transform the data into the desired format, extracting only the first caption
def transform_sample(sample):
    return {
        "image": sample["image"],
        "text": sample["caption"][0]  # Taking only the first caption
    }
# Apply the transformation on train and test
transformed_dataset_train = sample_train_dataset.map(transform_sample, remove_columns=sample_train_dataset.column_names)
transformed_dataset_test = sample_test_dataset.map(transform_sample, remove_columns=sample_test_dataset.column_names)
# Define features explicitly
features = Features({
    'image': Image(decode=True),  # Assuming the images are encoded and need to be decoded
    'text': Value('string')
})
# Set structure
transformed_dataset_train = transformed_dataset_train.cast(features)
transformed_dataset_test = transformed_dataset_test.cast(features)
print(transformed_dataset_train["text"][1])
transformed_dataset_train["image"][1]
“Several men in hard hats are operating a giant pulley system.”
To effectively manage the complexities of handling both image and text data, we develop a custom dataset class. It ensures that each piece of data is appropriately formatted for the neural network.
Here's how it works (a full sketch follows the list):
- The __len__ method lets us quickly query the total number of items in the dataset, which is needed for operations such as batching during training.
- The __getitem__ method retrieves a data item by its index. It extracts the image and the first caption, and the processor then formats the text and applies the necessary transformations to the image, typically resizing and normalizing it.
- The collate_fn function aggregates individual data points into a batch. It handles the different data types within a batch, ensuring that all image tensors are correctly stacked and all text sequences are padded to a uniform length.
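Below is one possible implementation of the class and its collate function, matching the description above. It assumes the processor loaded in the earlier LoRA sketch; the class name ImageCaptioningDataset, the batch size, and the padding strategy are illustrative choices, not the article's definitive code:
import torch
from torch.utils.data import Dataset, DataLoader
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor
    def __len__(self):
        # Total number of (image, caption) pairs, used for batching
        return len(self.dataset)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        # The processor resizes and normalizes the image into pixel_values
        encoding = self.processor(images=item["image"], return_tensors="pt")
        encoding = {k: v.squeeze() for k, v in encoding.items()}  # drop batch dim
        encoding["text"] = item["text"]  # raw caption; tokenized in collate_fn
        return encoding
def collate_fn(batch):
    # Stack image tensors and pad the tokenized captions to a uniform length
    processed = {}
    for key in batch[0]:
        if key == "text":
            tokens = processor.tokenizer(
                [ex["text"] for ex in batch], padding=True, return_tensors="pt"
            )
            processed["input_ids"] = tokens["input_ids"]
            processed["attention_mask"] = tokens["attention_mask"]
        else:
            processed[key] = torch.stack([ex[key] for ex in batch])
    return processed
# Usage: wrap a transformed split and feed the loader to a training loop
train_dataset = ImageCaptioningDataset(transformed_dataset_train, processor)
train_dataloader = DataLoader(
    train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
)
Tokenizing the captions inside collate_fn, rather than in __getitem__, means each batch is only padded to the length of its own longest caption, which keeps the tensors small.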