文本反转(Textual Inversion)

文本反转是一种训练技术,仅需少量示例图像即可个性化图像生成模型。该技术通过学习和更新文本嵌入(新嵌入会绑定到提示中必须使用的特殊词汇)来匹配您提供的示例图像。

如果在显存有限的GPU上训练,建议在训练命令中启用gradient_checkpointingmixed_precision参数。您还可以通过xFormers使用内存高效注意力机制来减少内存占用。JAX/Flax训练也支持在TPU和GPU上进行高效训练,但不支持梯度检查点或xFormers。在配置与PyTorch相同的情况下,Flax训练脚本的速度至少应快70%!

本指南将探索textual_inversion.py脚本,帮助您更熟悉其工作原理,并了解如何根据自身需求进行调整。

运行脚本前,请确保从源码安装库:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

进入包含训练脚本的示例目录,并安装所需依赖:

```bash cd examples/textual_inversion pip install -r requirements.txt ``` ```bash cd examples/textual_inversion pip install -r requirements_flax.txt ```

[!TIP] 🤗 Accelerate 是一个帮助您在多GPU/TPU或混合精度环境下训练的工具库。它会根据硬件和环境自动配置训练设置。查看🤗 Accelerate 快速入门了解更多。

初始化🤗 Accelerate环境:

accelerate config

要设置默认的🤗 Accelerate环境(不选择任何配置):

accelerate config default

如果您的环境不支持交互式shell(如notebook),可以使用:

from accelerate.utils import write_basic_config

write_basic_config()

最后,如果想在自定义数据集上训练模型,请参阅创建训练数据集指南,了解如何创建适用于训练脚本的数据集。

[!TIP] 以下部分重点介绍训练脚本中需要理解的关键修改点,但未涵盖脚本所有细节。如需深入了解,可随时查阅脚本源码,如有疑问欢迎反馈。

脚本参数

训练脚本包含众多参数,便于您定制训练过程。所有参数及其说明都列在parse_args()函数中。Diffusers为每个参数提供了默认值(如训练批次大小和学习率),但您可以通过训练命令自由调整这些值。

例如,将梯度累积步数增加到默认值1以上:

accelerate launch textual_inversion.py \
  --gradient_accumulation_steps=4

其他需要指定的基础重要参数包括:

训练脚本

与其他训练脚本不同,textual_inversion.py包含自定义数据集类TextualInversionDataset,用于创建数据集。您可以自定义图像尺寸、占位符词汇、插值方法、是否裁剪图像等。如需修改数据集创建方式,可调整TextualInversionDataset类。

接下来,在main()函数中可找到数据集预处理代码和训练循环。

脚本首先加载tokenizerscheduler和模型

# 加载tokenizer
if args.tokenizer_name:
    tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

# 加载scheduler和模型
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

随后将特殊占位符词汇加入tokenizer,并调整嵌入层以适配新词汇。

接着,脚本通过TextualInversionDataset创建数据集

train_dataset = TextualInversionDataset(
    data_root=args.train_data_dir,
    tokenizer=tokenizer,
    size=args.resolution,
    placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
    repeats=args.repeats,
    learnable_property=args.learnable_property,
    center_crop=args.center_crop,
    set="train",
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)

最后,训练循环处理从预测噪声残差到更新特殊占位符词汇嵌入权重的所有流程。

如需深入了解训练循环工作原理,请参阅理解管道、模型与调度器教程,该教程解析了去噪过程的基本模式。

启动脚本

完成所有修改或确认默认配置后,即可启动训练脚本!🚀

本指南将下载猫玩具的示例图像并存储在目录中。当然,您也可以创建和使用自己的数据集(参见创建训练数据集指南)。

from huggingface_hub import snapshot_download

local_dir = "./cat"
snapshot_download(
    "diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
)

设置环境变量MODEL_NAME为Hub上的模型ID或本地模型路径,DATA_DIR为刚下载的猫图像路径。脚本会将以下文件保存至您的仓库:

[!WARNING] 在单块V100 GPU上完整训练约需1小时。

启动脚本前还有最后一步。如果想实时观察训练过程,可以定期保存生成图像。在训练命令中添加以下参数:

--validation_prompt="A <cat-toy> train"
--num_validation_images=4
--validation_steps=100
```bash export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5" export DATA_DIR="./cat" accelerate launch textual_inversion.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATA_DIR \ --learnable_property="object" \ --placeholder_token="" \ --initializer_token="toy" \ --resolution=512 \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --max_train_steps=3000 \ --learning_rate=5.0e-04 \ --scale_lr \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --output_dir="textual_inversion_cat" \ --push_to_hub ``` </hfoption> ```bash export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" export DATA_DIR="./cat" python textual_inversion_flax.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATA_DIR \ --learnable_property="object" \ --placeholder_token="" \ --initializer_token="toy" \ --resolution=512 \ --train_batch_size=1 \ --max_train_steps=3000 \ --learning_rate=5.0e-04 \ --scale_lr \ --output_dir="textual_inversion_cat" \ --push_to_hub ``` </hfoption> </hfoptions> 训练完成后,可以像这样使用新模型进行推理: ```py from diffusers import StableDiffusionPipeline import torch pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") pipeline.load_textual_inversion("sd-concepts-library/cat-toy") image = pipeline("A train", num_inference_steps=50).images[0] image.save("cat-train.png") ``` </hfoption> Flax不支持[`~loaders.TextualInversionLoaderMixin.load_textual_inversion`]方法,但textual_inversion_flax.py脚本会在训练后[保存](https://github.com/huggingface/diffusers/blob/c0f058265161178f2a88849e92b37ffdc81f1dcc/examples/textual_inversion/textual_inversion_flax.py#L636C2-L636C2)学习到的嵌入作为模型的一部分。这意味着您可以像使用其他Flax模型一样进行推理: ```py import jax import numpy as np from flax.jax_utils import replicate from flax.training.common_utils import shard from diffusers import FlaxStableDiffusionPipeline model_path = "path-to-your-trained-model" pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16) prompt = "A train" prng_seed = jax.random.PRNGKey(0) num_inference_steps = 50 num_samples = jax.device_count() prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) # 分片输入和随机数生成器 params = replicate(params) prng_seed = jax.random.split(prng_seed, jax.device_count()) prompt_ids = shard(prompt_ids) images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) image.save("cat-train.png") ``` </hfoption> </hfoptions> ## 后续步骤 恭喜您成功训练了自己的文本反转模型!🎉 如需了解更多使用技巧,以下指南可能会有所帮助: - 学习如何[加载文本反转嵌入](../using-diffusers/loading_adapters),并将其用作负面嵌入 - 学习如何将[文本反转](textual_inversion_inference)应用于Stable Diffusion 1/2和Stable Diffusion XL的推理