🦩 OpenFlamingo

PyPI version

Blog post | Paper (coming soon)

欢迎使用我们的 DeepMind Flamingo 模型的开源版本!在此存储库中,我们提供了用于训练和评估 OpenFlamingo 模型的 PyTorch 实现。我们还提供了一个初始的 OpenFlamingo 9B 模型,该模型在一个新的 Multimodal C4 数据集上进行了训练。请参阅我们的博客文章了解更多详情。

这个 repo 仍在开发中,我们希望尽快发布性能更好、更大的 OpenFlamingo 模型。如果您有任何问题,请随时打开一个问题。我们也欢迎投稿!

目录

安装

要在现有环境中安装包,请运行

pip install open-flamingo

或者创建运行 OpenFlamingo 的 conda 环境,运行

conda env create -f environment.yml

用法

我们使用 CLIP ViT-Large 视觉编码器和 LLaMA-7B 语言模型提供初始 OpenFlamingo 9B 模型。一般来说,我们支持任何 CLIP 视觉编码器。对于语言模型,我们支持 LLaMA , OPT , [ GPT-Neo ] (https://huggingface.co/models?search=gpt-neo)、[GPT-J](https://huggingface.co/models?search=gptj )和[ Pythia ](https://huggingface . co/models?search=pythia) 模型。

注意:要使用 LLaMA 模型,您需要通过安装最新版本的变压器

pip install git+https://github.com/huggingface/transformers

使用此 脚本 将 LLaMA 权重转换为 HuggingFace 格式。

初始化 OpenFlamingo 模型

from open_flamingo import create_model_and_transforms
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="<path to llama weights in HuggingFace format>",
    tokenizer_path="<path to llama tokenizer in HuggingFace format>",
    cross_attn_every_n_layers=4
)
# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)

生成文本

这是一个以交错图像/文本为条件生成文本的示例,在这种情况下,我们将进行少镜头图像字幕。

from PIL import Image
import requests
"""
第 1 步:加载图像
"""
demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)
demo_image_two = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
        stream=True
    ).raw
)
query_image = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", 
        stream=True
    ).raw
)
"""
第 2 步:预处理图像
详细信息:对于 OpenFlamingo,我们希望图像是形状为火炬的张量
batch_size x num_media x num_frames x channels x height x width。
在这种情况下 batch_size = 1, num_media = 3, num_frames = 1
(对于我们尚不支持的视频,这始终是一种期望),
通道 = 3,高度 = 224,宽度 = 224。
"""
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
"""
第 3 步:预处理文本
详细信息:在文本中,我们希望使用 <image> 特殊标记来指示图像的位置。
我们还期望一个 <|endofchunk|> 特殊标记来指示文本的结尾
与图像关联的部分。
"""
tokenizer.padding_side = "left" # For generation padding tokens should be on the left
lang_x = tokenizer(
    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
    return_tensors="pt",
)
"""
第 4 步:生成文本
"""
generated_text = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=20,
    num_beams=3,
)
print("Generated text: ", tokenizer.decode(generated_text[0]))

方法

OpenFlamingo 是一种多模态语言模型,可用于多种任务。它在大型多模式数据集(例如 Multimodal C4 )上进行训练,可用于生成以交错图像/文本为条件的文本。例如,OpenFlamingo 可用于为图像生成标题,或根据图像和文本段落生成问题。这种方法的好处是我们能够使用上下文训练快速适应新任务。

模型架构

OpenFlamingo 试图融合预训练视觉编码器和使用交叉注意力层的语言模型。模型架构如下图所示。

 OpenFlamingo 架构
图片来源: Flamingo

训练

要训练模型,请修改以下示例命令,该命令使用 OPT 1.3B 作为示例 LM:

torchrun --nnodes=1 --nproc_per_node=4 train.py \
--run_name flamingo3B \
--lm_path facebook/opt-1.3b \
--tokenizer_path facebook/opt-1.3b \
--dataset_resampled \
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
--batch_size_mmc4 4 \
--batch_size_laion 8 \
--train_num_samples_mmc4 125000 \
--train_num_samples_laion 250000 \
--loss_multiplier_laion 0.2 \
--workers=6 \
--num_epochs 250 \
--lr_scheduler constant \
--warmup_steps 5000 \
--use_media_placement_augmentation \
--mmc4_textsim_threshold 30

数据集

我们希望我们所有的训练数据集都是 WebDataset 碎片。
我们在 LAION 2B Multimodal C4 数据集上训练我们的模型。默认情况下,如果使用 img2dataset 工具 下载 LAION 2B 数据集,则其为 WebDataset 格式,并且 Multimodal C4 以 WebDataset 格式打包。

评估

我们目前支持在 COCO VQAv2 OKVQA 上运行评估.org)、 Flickr30k ImageNet 。请注意,目前这些评估是在验证模式下运行的(如 Flamingo 论文中所指定)。我们将在未来添加对在测试模式下运行评估的支持。

在评估模型之前,您需要通过运行以下命令来安装 coco 评估包:

pip install pycocoevalcap

要在 OKVQA 上运行评估,您需要运行以下命令:

import nltk
nltk.download('wordnet')

要评估模型,请运行位于open_flamingo/scripts/run_eval.sh的脚本

作者:Jeebiz  创建时间:2023-12-12 12:36
最后编辑:Jeebiz  更新时间:2025-05-12 09:20