
February 7, 2024

Fine-Tuning Your Own Custom Stable Diffusion Model with just 4 Images

End-to-End Python Guide For Giving a Stable Diffusion Model Your Own Images for Training and Making Inferences from Text

Stable Diffusion models have gained significant attention for their ability to generate high-quality, diverse images from textual descriptions. However, the one-size-fits-all nature of these pre-trained models might fall short when tasked with generating images of highly specific or personalized subjects.

This is where fine-tuning comes into play. Fine-tuning allows you to teach the model about new, unique subjects based on a relatively small set of images. DreamBooth takes this a step further by enabling the fine-tuning process without compromising the model’s original capabilities.

This guide presents a clear, step-by-step approach to fine-tuning your Stable Diffusion model with DreamBooth. By the end, you will have a fine-tuned model capable of generating images that match your vision and requirements.

We will also cover the theoretical underpinnings of generative models and fine-tuning, delve into the specifics of setting up an environment, and give you the end-to-end code on Google Colab.

This guide is structured as follows:

1. Conceptual Foundations

1.1 Generative Models and Text-to-Image Synthesis

As a class of machine learning models, generative models create new data instances that resemble a given dataset. They capture the underlying data distribution to produce novel samples.

Text-to-image models, a subset of generative models, translate textual descriptions into corresponding visual representations with impressive accuracy and fidelity.

One prevalent architecture for such models is the Transformer, introduced in “Attention Is All You Need” by Vaswani et al. (2017). It uses self-attention mechanisms to process input sequences, effectively handling complex dependencies.

The Stable Diffusion model, a text-to-image latent diffusion model, uses a Transformer-based text encoder (CLIP) and cross-attention layers to condition image generation on textual input.

The generative process in such models can be described as follows:

Given a textual description T, the model aims to generate an image I such that the joint probability P(I, T) is maximized. In practice, the model is trained to maximize the conditional probability P(I | T), ensuring coherence between the generated images and the textual descriptions.
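Written out in standard notation (this display is a reconstruction, not an equation from the original article), the training objective is:

$$\theta^{*} = \arg\max_{\theta}\; \mathbb{E}_{(I,\,T)\sim \mathcal{D}}\big[\log p_{\theta}(I \mid T)\big]$$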

1.2 Fine-Tuning Generative Models

Fine-tuning adjusts a pre-trained model on a new, typically smaller dataset, tailoring the model to specific needs without losing the generalization learned from the original data. This approach is critical in applications where data is scarce or customization is required.

Mathematically, fine-tuning adjusts the parameters θ of the model to optimize a loss function L on the new dataset D_new, while preventing significant deviation from the original parameters θ_orig. This can be formulated as a regularization problem:

$$\theta^{*} = \arg\min_{\theta} \; \mathcal{L}(\theta;\, D_{\text{new}}) + \lambda \,\lVert \theta - \theta_{\text{orig}} \rVert^{2}$$

Equation 1. Minimizing the fine-tuning loss function to balance new data fitting and original model retention.

where λ is a regularization parameter controlling the trade-off between fitting the new data and staying close to the original model.
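In code, this penalty is just an extra term added to the task loss. Below is a minimal PyTorch sketch; the model, data, and λ value are illustrative stand-ins rather than anything from the article:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # stand-in for a pretrained model being fine-tuned
orig_params = [p.detach().clone() for p in model.parameters()]  # frozen copy of the original weights
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lam = 0.01  # regularization strength (lambda in Equation 1)

x, y = torch.randn(8, 10), torch.randn(8, 1)  # stand-in fine-tuning batch
task_loss = nn.functional.mse_loss(model(x), y)
reg = sum(((p - p0) ** 2).sum() for p, p0 in zip(model.parameters(), orig_params))
loss = task_loss + lam * reg  # fit the new data while staying close to the original weights
loss.backward()
optimizer.step()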

1.3 Introducing DreamBooth

DreamBooth proposes a novel fine-tuning methodology that allows the generation of images with specific subjects or objects while maintaining the model’s ability to generate diverse images. 

Unlike traditional fine-tuning, which might lead to overfitting or catastrophic forgetting (forgetting the original data distribution), DreamBooth ensures the model retains its general capabilities.

In this process, subject-specific tokens are trained alongside the original model parameters. This is conceptually similar to adding a new “word” to the model’s vocabulary that represents the new subject. The training objective can be described as:

$$\min_{\theta,\,\phi} \; \mathcal{L}(\theta, \phi;\, D_{\text{new}}) + \lambda_{1} \,\lVert \theta - \theta_{\text{orig}} \rVert^{2} + \lambda_{2} \,\lVert \phi \rVert^{2}$$

Equation 2. Objective function for DreamBooth fine-tuning, optimizing subject-specific parameters while regularizing towards the original model parameters.

where ϕ represents the subject-specific parameters, and λ1 and λ2 are regularization parameters.

For further reading and a deeper understanding, refer to the original papers: “DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation” by Ruiz et al. (2022), and the foundational paper on diffusion models, “Denoising Diffusion Probabilistic Models” by Ho et al. (2020).

1.4 Examples and Practical Considerations

In practical terms, consider a scenario where a Stable Diffusion model, trained on a diverse dataset, generates generic images of cats. To personalize this model using DreamBooth, you would provide a set of images of your specific cat. 

The fine-tuning process adapts the model to generate images of your cat in various contexts, as described by textual prompts, while still being able to generate a wide range of other images.

You must select and tune hyperparameters carefully to avoid overfitting. Techniques like Min-SNR weighting and prior preservation loss can be employed for stable training and to encourage diversity in generated images, respectively; a sketch of the prior-preservation term follows below.
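To make this concrete, here is a minimal sketch of how a prior-preservation loss is typically combined with the instance loss. The function name and batch layout are illustrative, not taken from this article's notebook:

import torch
import torch.nn.functional as F

def dreambooth_loss(model_pred, target, prior_loss_weight=0.5):
    # The batch is assumed to stack instance images first, then class images
    pred_instance, pred_prior = model_pred.chunk(2, dim=0)
    target_instance, target_prior = target.chunk(2, dim=0)
    instance_loss = F.mse_loss(pred_instance, target_instance)  # fit the new subject
    prior_loss = F.mse_loss(pred_prior, target_prior)  # retain the broader class
    return instance_loss + prior_loss_weight * prior_loss

# Example with dummy noise-prediction tensors
pred, tgt = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)
loss = dreambooth_loss(pred, tgt)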

For scenarios demanding high-quality outputs, particularly in generating human faces, it’s beneficial to train the text encoder alongside the model. This additional training step, though resource-intensive, significantly enhances the fidelity and detail of the generated images.

2. Technical Prerequisites

Fine-tuning a generative model like Stable Diffusion using DreamBooth is computationally intensive and requires substantial memory and processing power. 

To ensure the training process is efficient and to avoid potential bottlenecks, it is highly recommended to use a high-performance GPU. Here are the two primary recommendations for the computing environment:

  • High-GPU Local Machine → Ensure that your GPU has sufficient VRAM (Video RAM) to handle the model’s requirements. High-end NVIDIA GPUs, such as those from the Tesla, Titan, or RTX series, are well suited for this task. Remember to install the appropriate drivers and the CUDA toolkit.
  • Google Colab with High-RAM GPU Runtime → Runtimes with a V100 or A100 GPU (the fastest Colab offers at the time of writing) provide plenty of compute power for these workloads. A quick runtime check follows below.
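Before going further, it is worth confirming that a GPU is actually attached to your runtime. A quick check, assuming PyTorch is already installed (as it is on Colab):

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name} | VRAM: {props.total_memory / 1024**3:.1f} GB")
else:
    print("No CUDA GPU detected. Switch to a GPU runtime before continuing.")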

2.1. Library Installations

To get started, the following libraries need to be installed. These libraries provide the necessary functions and classes for model fine-tuning, image processing, and data handling:

  • Diffusers: A library for diffusion models, specifically for fine-tuning and leveraging pre-trained models.
  • Accelerate: A library for distributed training and mixed precision.
  • TensorBoard: For visualizing training progress and metrics.
  • Transformers, FTFY, and Gradio: For model components, text processing, and creating web UIs for model interaction.
  • Bitsandbytes: For memory-efficient and fast training, particularly for optimizing model training on specific GPU architectures.
				
!pip install -U -qq git+https://github.com/huggingface/diffusers.git
!pip install -qq accelerate tensorboard transformers ftfy gradio
!pip install -qq "ipywidgets>=7,<8"
!pip install -qq bitsandbytes
!pip install huggingface_hub

2.2 Optional Installations

For faster and memory-efficient training, especially if using specific types of GPUs (T4, P100, V100, A100), you may optionally install the following (install commands are shown after the list):

  • Xformers: A library that provides memory-efficient attention and other optimized Transformer components.
  • Triton: A language and compiler for writing custom GPU kernels.
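Both are available on PyPI; compatible versions depend on your runtime’s CUDA and PyTorch versions, so treat these commands as a starting point rather than a guaranteed recipe:

!pip install -qq xformers
!pip install -qq triton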

2.3 Code Dependencies

The code utilizes several libraries for its operations. Make sure you import the following libraries:

				
import argparse
import itertools
import math
import os
import random
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

import bitsandbytes as bnb

def image_grid(imgs, rows, cols):
    """Arrange a list of equally sized PIL images into a single rows x cols grid."""
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))  # place image i in its row/column slot
    return grid

3. Dataset Preparation

Proper dataset preparation is a critical step in fine-tuning a Stable Diffusion model with DreamBooth. 

This process involves selecting representative images, preprocessing them, and organizing them into a structured training format.

In this guide, we will be using the following 4 training images as an example.


Figure 1: A quartet of images showcasing a striped cat figurine, used as training data for model fine-tuning.

3.1 Download and Visualize Training Images

We start by downloading the training images. The download_image function retrieves images from a specified list of URLs and converts them into RGB format for consistency.

Each image is processed through this function and stored in a list. A dedicated directory is then created to save these images locally for future reference and training.

				
import requests
from io import BytesIO

urls = [
    "https://huggingface.co/datasets/Entreprenerdly/finetunestablediffusion/resolve/main/2.jpeg",
    "https://huggingface.co/datasets/Entreprenerdly/finetunestablediffusion/resolve/main/3.jpeg",
    "https://huggingface.co/datasets/Entreprenerdly/finetunestablediffusion/resolve/main/5.jpeg",
    "https://huggingface.co/datasets/Entreprenerdly/finetunestablediffusion/resolve/main/6.jpeg",
    # Add additional images here
]

def download_image(url):
    """Fetch an image from a URL and return it as an RGB PIL image, or None on failure."""
    try:
        response = requests.get(url)
        response.raise_for_status()
    except requests.exceptions.RequestException:
        return None
    return Image.open(BytesIO(response.content)).convert("RGB")

# Download all images, dropping any that failed
images = list(filter(None, [download_image(url) for url in urls]))

# Save the images locally for training
save_path = "./my_concept"
os.makedirs(save_path, exist_ok=True)
for i, image in enumerate(images):
    image.save(f"{save_path}/{i}.jpeg")

image_grid(images, 1, len(images))

Figure 2: The four training images of a cat figurine, downloaded and processed for fine-tuning.

3.2 Settings for Newly Created Images

For fine-tuning a generative model with DreamBooth, specific settings are configured to define the new concept effectively. The instance_prompt is crucial, as it includes a descriptive identifier that the model uses to recognize and generate the new concept: in this case, <cat-toy>.

Next, the prior_preservation flag indicates whether the model should retain the broader class attributes during training. This helps in quality and generalization but may extend the training duration. 

Furthermore, accompanying this setting are parameters that determine the number of class images to generate, the batch size for sampling, and the weight of the preservation loss. 

These settings collectively ensure the model learns both the unique features of the new concept and the general characteristics of its category.

Here is the complete code snippet incorporating these settings:

				
instance_prompt = "<cat-toy> toy"  # Descriptive prompt with a unique identifier for the new concept
prior_preservation = False  # Flag for enabling class-characteristics preservation
prior_preservation_class_prompt = "a photo of a cat clay toy"  # Prompt for the concept's broader class

# Parameters for class image generation and loss weighting
num_class_images = 12
sample_batch_size = 2
prior_loss_weight = 0.5
prior_preservation_class_folder = "./class_images"

# Directory and prompt for class images
class_data_root = prior_preservation_class_folder
class_prompt = prior_preservation_class_prompt

3.3 Setup Classes and Image Functions

We define the necessary Python classes and functions that will handle the data loading and preprocessing for fine-tuning the model. The DreamBoothDataset class extends PyTorch’s Dataset and is tailored to manage the image data used to train the model.

It takes care of loading images from the specified directories, applying the required transformations, and encoding the prompts using the provided tokenizer.

Key functionalities of this class include:

  • Initializing with paths to image data and prompts.
  • Transforming images to the required size and applying center or random cropping.
  • Normalizing images for model compatibility.
  • Handling the retrieval of instance and class images and associated prompts when accessed during training.

Another class, PromptDataset, is set up to handle the generation of prompts for the class images. This simple dataset structure stores the prompts and the number of samples to generate.

These classes structure the training data for the DreamBooth fine-tuning process, ensuring that the model receives data in the expected format and with the necessary augmentations.

The relevant snippet from the Colab notebook is as follows:

				
# Initialization of DreamBoothDataset with directory paths and settings
class DreamBoothDataset(Dataset):
    ...
    def __init__(self, instance_data_root, instance_prompt, tokenizer, class_data_root=None, class_prompt=None, size=512, center_crop=False):
        ...
        self.image_transforms = transforms.Compose([...])

    def __getitem__(self, index):
        ...
        return example

# Class for prompt dataset
class PromptDataset(Dataset):
    ...
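For reference, the image_transforms pipeline inside DreamBoothDataset typically looks like the sketch below. This mirrors the standard diffusers DreamBooth example and the constructor defaults above (size=512, center_crop=False); check the notebook for the exact version used:

from torchvision import transforms

size, center_crop = 512, False
image_transforms = transforms.Compose([
    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # scale pixel values to [-1, 1]
])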

The complete code can be found in the Google Colab Notebook here:
