> [!NOTE]
> This example is educational in nature and fixes some arguments to keep things simple. It should serve as a reference for building things out further.
This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization, using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. The steps below summarize the workflow:
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
* `train_dreambooth_lora_flux_miniature.py` takes care of training:
  * Since the text embeddings are precomputed, the text encoders are not loaded during training.
  * The Flux transformer is quantized to the NF4 datatype through `bitsandbytes` and prepared for 4-bit training (see the sketch after this list).
* To run training in a memory-optimized manner, we additionally use:
  * 8-bit Adam (`--use_8bit_adam`)
  * gradient checkpointing (`--gradient_checkpointing`)
  * latent caching (`--cache_latents`)
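For context, here is a simplified sketch of the quantize-and-prepare step mentioned above. This is illustrative rather than the actual training script; the LoRA rank, alpha, and target modules below are assumptions:

```py
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training

# Quantize the Flux transformer to NF4 on load.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)

# Make the 4-bit model trainable (casts norms, enables input grads, etc.).
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)

# Attach LoRA adapters; these target modules (attention projections) are illustrative.
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(lora_config)
```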
We have tested the scripts on a 24GB RTX 4090. They also run on a free-tier Colab Notebook, but training there is extremely slow.
Ensure you have installed the required libraries:
```bash
pip install -U transformers accelerate bitsandbytes peft datasets
pip install git+https://github.com/huggingface/diffusers -U
```
Now, compute the text embeddings:
```bash
python compute_embeddings.py
```
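Conceptually, the script encodes the training prompt with the Flux text encoders and serializes the flattened embeddings. Below is a minimal sketch of that idea, not the actual script; the fixed prompt and the column layout are assumptions:

```py
import pandas as pd
import torch
from diffusers import FluxPipeline

# Load only the text encoders; the transformer and VAE aren't needed here.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    vae=None,
    torch_dtype=torch.bfloat16,
).to("cuda")

with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt="a puppy, yarn art style",  # hypothetical instance prompt
        prompt_2=None,
        max_sequence_length=512,
    )

# Flatten the tensors into lists so they can be stored in a parquet file.
df = pd.DataFrame(
    {
        "prompt_embeds": [prompt_embeds.cpu().float().numpy().flatten().tolist()],
        "pooled_prompt_embeds": [pooled_prompt_embeds.cpu().float().numpy().flatten().tolist()],
        "text_ids": [text_ids.cpu().float().numpy().flatten().tolist()],
    }
)
df.to_parquet("embeddings.parquet")
```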
It should create a file named `embeddings.parquet`. We’re then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
```bash
hf auth login
```
Then launch:
```bash
accelerate launch --config_file=accelerate.yaml \
  train_dreambooth_lora_flux_miniature.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --data_df_path="embeddings.parquet" \
  --output_dir="yarn_art_lora_flux_nf4" \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --weighting_scheme="none" \
  --resolution=1024 \
  --train_batch_size=1 \
  --repeats=1 \
  --learning_rate=1e-4 \
  --guidance_scale=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --cache_latents \
  --rank=4 \
  --max_train_steps=700 \
  --seed="0"
```
We can directly pass a quantized checkpoint path, too:
```diff
+ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
```
Depending on the machine, training time will vary, but in our case it took about 1.5 hours. It may be possible to speed this up by using `torch.bfloat16`.
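For instance, on GPUs with native bfloat16 support you could switch the precision flag in the launch command above (a possible tweak, not something we benchmarked here):

```diff
- --mixed_precision="fp16" \
+ --mixed_precision="bf16" \
```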
We also support training with DeepSpeed ZeRO-2. To use it, first install DeepSpeed:
```bash
pip install -Uq deepspeed
```
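The launch command below expects an `accelerate` config file named `ds2.yaml` with DeepSpeed enabled. A minimal ZeRO-2 configuration (illustrative; adjust `num_processes` and the rest to your setup) might look like:

```yaml
# ds2.yaml (illustrative DeepSpeed ZeRO-2 config for `accelerate`)
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_accumulation_steps: 4
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
use_cpu: false
```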
And then launch:
```bash
accelerate launch --config_file=ds2.yaml \
  train_dreambooth_lora_flux_miniature.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --data_df_path="embeddings.parquet" \
  --output_dir="yarn_art_lora_flux_nf4" \
  --mixed_precision="no" \
  --use_8bit_adam \
  --weighting_scheme="none" \
  --resolution=1024 \
  --train_batch_size=1 \
  --repeats=1 \
  --learning_rate=1e-4 \
  --guidance_scale=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --cache_latents \
  --rank=4 \
  --max_train_steps=700 \
  --seed="0"
```
When loading LoRA params that were obtained on a quantized base model and merging them into that base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging directly into a 4-bit quantized model can lead to rounding errors. Below, we provide an end-to-end example:
```py
from diffusers import FluxPipeline
import torch

# 1. Load the base model in FP16 (i.e., without quantization) and merge the LoRA params into it.
ckpt_id = "black-forest-labs/FLUX.1-dev"
pipeline = FluxPipeline.from_pretrained(
    ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
)
pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()

# Serialize the fused transformer so it can be re-quantized below.
pipeline.transformer.save_pretrained("fused_transformer")
```
```py
from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
import torch

# 2. Quantize the fused transformer back to NF4 and run inference with it.
ckpt_id = "black-forest-labs/FLUX.1-dev"
bnb_4bit_compute_dtype = torch.float16
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "fused_transformer",
    quantization_config=nf4_config,
    torch_dtype=bnb_4bit_compute_dtype,
)
pipeline = AutoPipelineForText2Image.from_pretrained(
    ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
)
pipeline.enable_model_cpu_offload()

image = pipeline(
    "a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
).images[0]
image.save("yarn_merged.png")
```
| Dequantize, merge, quantize | Merging directly into quantized model |
|---|---|
| ![]() | ![]() |
As we can notice, the result in the first column follows the style more closely.