Batch processing

#12
by cora-17 - opened

Can anyone share the batch processing code, or point me to it?

Thanks

Interested as well!

inputs = self.processor.apply_chat_template(
    promote_t,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
    padding=True,
)
# Drop the leading batch dimension of 1; collate_fn rebuilds the batch later
input_ids = inputs.input_ids.squeeze(0)
attention_mask = inputs.attention_mask.squeeze(0)
pixel_values = inputs.pixel_values.squeeze(0) if "pixel_values" in inputs else None
image_grid_thw = inputs.image_grid_thw.squeeze(0) if "image_grid_thw" in inputs else None

yield {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "pixel_values": pixel_values,
    "image_grid_thw": image_grid_thw,
    "meta": (rawid, name, pic_url),
}

The code above preprocesses the data and returns an iterator. `promote_t` just needs to follow the same single-example format as on the official model card (a sketch is below). Note: the images must be resized first so that they are all the same size.
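For reference, a minimal sketch of how `promote_t` might be built. The `build_messages` helper, the question text, and the target size are illustrative assumptions, and this assumes `pic_url` is a local file path (a remote URL would need to be downloaded first); the message structure follows the usual single-example chat format for Qwen2-VL-style processors:

from PIL import Image

TARGET_SIZE = (448, 448)  # example value; any fixed size works, as long as every image uses it

def build_messages(pic_url, question):
    # Resizing to one fixed size guarantees every sample produces the same
    # number of vision patches, which is what lets collate_fn stack them.
    image = Image.open(pic_url).convert("RGB").resize(TARGET_SIZE)
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        }
    ]

promote_t = build_messages("example.jpg", "Describe this image.")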
The batch collate function is below:
import torch

def collate_fn(batch, pad_token_id):
    """
    Left padding: right-align every sample (padding goes on the left), so the
    generated continuation is contiguous for a decoder-only model.
    batch: list of items yielded by the dataset
    returns: dict(input_ids, attention_mask, pixel_values, image_grid_thw, metas)
    """
    batch_size = len(batch)
    lengths = [item["input_ids"].size(0) for item in batch]
    max_len = max(lengths)
    input_ids_padded = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long)
    attention_mask_padded = torch.zeros((batch_size, max_len), dtype=torch.long)
    pixel_values_list = []
    image_grid_thw_list = []
    metas = []
    for i, item in enumerate(batch):
        seq = item["input_ids"]
        L = seq.size(0)
        input_ids_padded[i, max_len - L:] = seq  # right-aligned; pad on the left
        attention_mask_padded[i, max_len - L:] = item["attention_mask"]
        pixel_values_list.append(item["pixel_values"])
        image_grid_thw_list.append(item["image_grid_thw"])
        metas.append(item["meta"])
    # Stacking works because all images were resized to the same size,
    # so every sample has the same number of vision patches.
    pixel_values = torch.stack(pixel_values_list, dim=0)
    image_grid_thw = torch.stack(image_grid_thw_list, dim=0)
    return {
        "input_ids": input_ids_padded,
        "attention_mask": attention_mask_padded,
        "pixel_values": pixel_values,
        "image_grid_thw": image_grid_thw,
        "metas": metas,
    }
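A quick sanity check of the left-padding behaviour on dummy data (all values arbitrary):

import torch

a = {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.ones(3, dtype=torch.long),
     "pixel_values": torch.zeros(4, 8), "image_grid_thw": torch.tensor([1, 2, 2]),
     "meta": ("id0", "n0", "u0")}
b = {"input_ids": torch.tensor([4, 5]), "attention_mask": torch.ones(2, dtype=torch.long),
     "pixel_values": torch.zeros(4, 8), "image_grid_thw": torch.tensor([1, 2, 2]),
     "meta": ("id1", "n1", "u1")}

out = collate_fn([a, b], pad_token_id=0)
print(out["input_ids"])
# tensor([[1, 2, 3],
#         [0, 4, 5]])   <- shorter sequence right-aligned, pad on the left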

The key code of MyIterableDataset has already been provided above; a fuller skeleton is sketched below.
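Since only the yield body was posted, here is a minimal sketch of how it could sit inside an `IterableDataset`. The round-robin sharding across ranks and the JSONL record layout are assumptions, and `build_messages` is the helper sketched earlier:

import json
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, input_file, processor, rank, world_size):
        self.input_file = input_file
        self.processor = processor
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        with open(self.input_file, encoding="utf-8") as f:
            for line_no, line in enumerate(f):
                # Round-robin sharding: each rank processes a disjoint slice
                if line_no % self.world_size != self.rank:
                    continue
                record = json.loads(line)  # assumed: one JSON object per line
                rawid, name, pic_url = record["id"], record["name"], record["url"]
                promote_t = build_messages(pic_url, record["question"])
                inputs = self.processor.apply_chat_template(
                    promote_t,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                )
                yield {
                    "input_ids": inputs.input_ids.squeeze(0),
                    "attention_mask": inputs.attention_mask.squeeze(0),
                    "pixel_values": inputs.pixel_values.squeeze(0),
                    "image_grid_thw": inputs.image_grid_thw.squeeze(0),
                    "meta": (rawid, name, pic_url),
                }

Note that with `num_workers > 0`, per-worker sharding via `torch.utils.data.get_worker_info()` would also be needed, otherwise each worker iterates the rank's full shard and produces duplicates.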

from torch.utils.data import DataLoader

dataset = MyIterableDataset(INPUT_FILE, processor, rank, world_size)
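`pad_token_id` has to match the tokenizer; a common way to obtain it (falling back to EOS when the tokenizer defines no pad token) is a reasonable assumption here:

pad_token_id = processor.tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = processor.tokenizer.eos_token_id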

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    # sampler=sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=lambda b: collate_fn(b, pad_token_id),
    persistent_workers=PERSISTENT_WORKERS if NUM_WORKERS > 0 else False,
)

Inference code (assuming `model`, `processor`, `device_str`, and the generation constants are already set up on each rank):

from tqdm import tqdm

for batch in tqdm(dataloader, desc=f"Rank {rank} infer"):
    # Move inputs to the GPU; non_blocking=True requires pin_memory=True
    input_ids = batch["input_ids"].to(device_str, non_blocking=True)
    attention_mask = batch["attention_mask"].to(device_str, non_blocking=True)
    pixel_values = batch["pixel_values"].to(device_str, non_blocking=True)
    image_grid_thw = batch["image_grid_thw"].to(device_str, non_blocking=True)
    metas = batch["metas"]

    seq_len = input_ids.size(1)  # input length, including the left padding

    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=pad_token_id,
            # use_cache=False,
        )

    # Slice off the prompt: thanks to left padding, every row's generated
    # tokens start at the same offset seq_len
    batch_texts = []
    for i in range(generated.size(0)):
        gen_ids = generated[i, seq_len:].cpu().numpy().tolist()
        batch_texts.append(gen_ids)
    # optionally free the raw output tensor
    # del generated

    # decode (the processor is already available in each process)
    decoded_texts = processor.batch_decode(
        batch_texts, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
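The decoded strings can then be paired back with the per-sample metadata at the end of the loop body; a minimal sketch (the JSONL layout and output file name are assumptions):

import json

    # still inside the batch loop, after batch_decode
    with open(f"results_rank{rank}.jsonl", "a", encoding="utf-8") as out_f:
        for (rawid, name, pic_url), text in zip(metas, decoded_texts):
            record = {"id": rawid, "name": name, "url": pic_url, "answer": text}
            out_f.write(json.dumps(record, ensure_ascii=False) + "\n")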
