import os
import time
import json

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
from safetensors.torch import load_file
from accelerate import load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download

from dataloader.creatidesign_dataset_benchmark import DesignDataset, visualize_bbox, collate_fn, tensor_to_pil, make_image_grid_RGB
from modules.flux.transformer_flux_creatidesign import FluxTransformer2DModel
from pipeline.pipeline_flux_creatidesign import FluxPipeline
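
# Benchmark inference: run the CreatiDesign FLUX pipeline over the
# CreatiDesign benchmark set and save both the raw generations and
# side-by-side comparisons against the ground-truth layouts.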
if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.bfloat16

    resolution = 1024
    condition_resolution = 512
    neg_condition_image = 'same'
    background_color = 'gray'
    use_bucket = True
    condition_resolution_scale_ratio = 0.5

    benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark'  # Hugging Face repo of the benchmark
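    # With bucketing enabled, each sample keeps a per-sample target resolution
    # (reported by the dataset as target_width/target_height), and the
    # subject-condition image is scaled to condition_resolution_scale_ratio
    # times the target canvas; presumably this preserves native aspect ratios.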
    datasets = DesignDataset(dataset_name=benchmark_repo,
                             resolution=resolution,
                             condition_resolution=condition_resolution,
                             neg_condition_image=neg_condition_image,
                             background_color=background_color,
                             use_bucket=use_bucket,
                             condition_resolution_scale_ratio=condition_resolution_scale_ratio)
    test_dataloader = DataLoader(datasets, batch_size=1, shuffle=False, num_workers=4, collate_fn=collate_fn)
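    # batch_size=1: with bucketing, every sample can carry its own target
    # resolution, so images are generated one at a time.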
    model_path = "black-forest-labs/FLUX.1-dev"
    ckpt_repo = "HuiZhang0812/CreatiDesign"  # Hugging Face repo of the checkpoint
    ckpt_path = snapshot_download(
        repo_id=ckpt_repo,
        repo_type="model",
        local_dir="./CreatiDesign_checkpoint",
        local_dir_use_symlinks=False
    )
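    # snapshot_download mirrors the checkpoint repo into ./CreatiDesign_checkpoint
    # and returns that local path. NOTE: load_checkpoint_and_dispatch below reads
    # from a path on disk, so model_path is assumed to resolve to a local
    # FLUX.1-dev snapshot (e.g. a pre-downloaded copy at that relative path).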
    # Load the transformer config from the checkpoint
    with open(os.path.join(ckpt_path, "transformer", "config.json"), 'r') as f:
        config = json.load(f)
    transformer = FluxTransformer2DModel(**config)
    transformer = load_checkpoint_and_dispatch(transformer, checkpoint=os.path.join(model_path, "transformer"), device_map=None)
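    # strict=False below lets the CreatiDesign safetensors override only the
    # tensors it provides, while every other parameter keeps the FLUX.1-dev
    # base weights dispatched above; the key printouts make that overlap easy
    # to audit.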
    # Load the LoRA parameters using safetensors
    state_dict = load_file(os.path.join(ckpt_path, "transformer", "model.safetensors"))
    # Load parameters, allowing partial loading
    missing_keys, unexpected_keys = transformer.load_state_dict(state_dict, strict=False)
    print(f"Loaded parameters: {len(state_dict)}", state_dict.keys())
    print(f"Missing keys: {len(missing_keys)}", missing_keys)
    print(f"Unexpected keys: {len(unexpected_keys)}", unexpected_keys)
    transformer = transformer.to(dtype=torch.bfloat16)

    pipe = FluxPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.bfloat16)
    pipe = pipe.to("cuda")
    seed = 42
    num_samples = 1
    true_cfg_scale = 3.5
    guidance_scale = 1.0
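    # true_cfg_scale > 1 enables genuine classifier-free guidance (an extra
    # negative-condition pass), while guidance_scale=1.0 presumably keeps
    # FLUX.1-dev's distilled guidance embedding neutral.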
    if resolution == 512:
        position_delta = [0, -32]
    else:
        position_delta = [0, -64]
    if use_bucket:
        scale_h = 1 / condition_resolution_scale_ratio
        scale_w = 1 / condition_resolution_scale_ratio
    else:
        scale_h = resolution / condition_resolution
        scale_w = resolution / condition_resolution
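    # position_delta shifts the RoPE position ids of the condition tokens; the
    # offset (-32 at 512 px, -64 at 1024 px) matches the packed latent width
    # (resolution / 8 / 2), which presumably places the condition canvas beside
    # rather than on top of the target latent grid. scale_h/scale_w map the
    # condition-token positions back to the target-canvas scale.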
    num_inference_steps = 28
    # Create the save directories from the benchmark repo name
    save_root = os.path.join("outputs", benchmark_repo.split("/")[-1])
    os.makedirs(save_root, exist_ok=True)
    img_save_root = os.path.join(save_root, "images")
    os.makedirs(img_save_root, exist_ok=True)
    img_withgt_save_root = os.path.join(save_root, "images_with_gt")
    os.makedirs(img_withgt_save_root, exist_ok=True)

    total_time = 0
    for i, batch in enumerate(tqdm(test_dataloader)):
        prompts = batch["caption"]
        imgs_id = batch['id']
        objects_boxes = batch["objects_boxes"]
        objects_caption = batch['objects_caption']
        objects_masks = batch['objects_masks']
        condition_img = batch['condition_img']
        neg_condtion_img = batch['neg_condtion_img']
        objects_masks_maps = batch['objects_masks_maps']
        subject_masks_maps = batch['condition_img_masks_maps']
        target_width = batch['target_width'][0]
        target_height = batch['target_height'][0]
        img_info = batch["img_info"][0]
        filename = img_info["img_id"] + '.jpg'
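
        # Each batch carries the global caption plus per-object boxes and
        # captions, an objects_masks validity vector (1 = real object,
        # 0 = padding, judging from the filtering further below), and the
        # positive/negative subject-condition images built by collate_fn.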
        start_time = time.time()
        with torch.no_grad():
            images = pipe(prompt=prompts * num_samples,
                          generator=torch.Generator(device="cuda").manual_seed(seed),
                          num_inference_steps=num_inference_steps,
                          objects_boxes=objects_boxes,
                          objects_caption=objects_caption,
                          objects_masks=objects_masks,
                          objects_masks_maps=objects_masks_maps,
                          condition_img=condition_img,
                          subject_masks_maps=subject_masks_maps,
                          neg_condtion_img=neg_condtion_img,
                          height=target_height,
                          width=target_width,
                          true_cfg_scale=true_cfg_scale,
                          position_delta=position_delta,
                          guidance_scale=guidance_scale,
                          scale_h=scale_h,
                          scale_w=scale_w,
                          use_bucket=use_bucket)
        images = images.images
        use_time = time.time() - start_time
        total_time += use_time
        make_image_grid_RGB(images, rows=1, cols=num_samples).save(os.path.join(img_save_root, filename))
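        # Timing covers only the pipeline call; saving and the bounding-box
        # visualization below are excluded from the reported averages.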
        # Process the original image and bounding boxes
        ori_image = tensor_to_pil(batch['img'][0])
        orig_width, orig_height = ori_image.size
        normalized_boxes = batch['objects_boxes'][0].cpu().numpy()
        denormalized_boxes = []
        for box in normalized_boxes:
            x1, y1, x2, y2 = box
            denorm_box = [
                x1 * orig_width,   # x1
                y1 * orig_height,  # y1
                x2 * orig_width,   # x2
                y2 * orig_height   # y2
            ]
            denormalized_boxes.append(denorm_box)
        objects_result = {
            "boxes": denormalized_boxes,
            "labels": batch['objects_caption'][0],
            "masks": []
        }
        # Keep only the boxes and captions whose mask is 1
        valid_boxes = []
        valid_labels = []
        for box, label, mask in zip(objects_result['boxes'],
                                    objects_result['labels'],
                                    batch['objects_masks'][0]):
            if mask:
                valid_boxes.append(box)
                valid_labels.append(label)
        objects_result['boxes'] = valid_boxes
        objects_result['labels'] = valid_labels
        ori_image_with_bbox = visualize_bbox(ori_image, objects_result)
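
        # Side-by-side comparison strip: [ground truth | subject-condition
        # canvas | generated sample(s)], each panel annotated with the valid boxes.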
        # Concatenate the images
        total_width = ori_image.width + ori_image.width + num_samples * ori_image.width
        max_height = ori_image.height
        # Create a new blank image to hold the concatenated images
        new_image = Image.new('RGB', (total_width, max_height))
        new_image.paste(ori_image_with_bbox, (0, 0))
        # Process the condition image
        condition_img = tensor_to_pil(batch['original_size_condition_img'][0])
        subject_canvas_with_bbox = visualize_bbox(condition_img, objects_result)
        new_image.paste(subject_canvas_with_bbox, (ori_image.width, 0))
        # Paste the generated images
        save_name = os.path.join(img_withgt_save_root, filename)
        for j, image in enumerate(images):
            image_with_bbox = visualize_bbox(image, objects_result)
            new_image.paste(image_with_bbox, (ori_image.width * (j + 2), 0))
        new_image.save(save_name)

    print(f"Total inference time: {total_time:.2f} seconds")
    print(f"Average time per image: {total_time / len(test_dataloader):.2f} seconds")