# AniDoc — scripts_infer/anidoc_inference.py
# Source revision: df77222 (verified), "Update scripts_infer/anidoc_inference.py" by fffiloni.
import sys
import os
import types
from pyparsing import col
sys.path.insert(0, ".")
# ---- huggingface_hub compatibility for old diffusers ----
# Newer huggingface_hub releases no longer export `cached_download`; alias it to
# `hf_hub_download` so code importing that name still imports.
# NOTE(review): the two functions do not share a signature; this alias only keeps the
# import working and presumably is never called — confirm for the diffusers version in use.
import huggingface_hub
if not hasattr(huggingface_hub, "cached_download"):
    from huggingface_hub import hf_hub_download
    huggingface_hub.cached_download = hf_hub_download
# ---- torchvision compatibility for old basicsr ----
# Prefer the real `torchvision.transforms.functional_tensor` when this torchvision still
# provides it; only when the import fails, register a stub module that exposes just
# `rgb_to_grayscale` (taken from `torchvision.transforms.functional`). Checking
# `sys.modules` alone would shadow a real-but-not-yet-imported module with the stub.
import torchvision.transforms.functional as TVF
try:
    import torchvision.transforms.functional_tensor  # noqa: F401
except ImportError:
    functional_tensor = types.ModuleType("torchvision.transforms.functional_tensor")
    functional_tensor.rgb_to_grayscale = TVF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
# ---------------------------------------------------------
import argparse
from packaging import version
import glob
from LightGlue.lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
from LightGlue.lightglue.utils import load_image, rbd
from cotracker.predictor import CoTrackerPredictor, sample_trajectories, generate_gassian_heatmap, sample_trajectories_with_ref
import torch
from diffusers.utils.import_utils import is_xformers_available
from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from pipelines.AniDoc import AniDocPipeline
from models_diffusers.controlnet_svd import ControlNetSVDModel
from diffusers.utils import load_image, export_to_video, export_to_gif
import time
from lineart_extractor.annotator.lineart import LineartDetector
import numpy as np
from PIL import Image
from utils import load_images_from_folder,export_gif_with_ref,export_gif_side_by_side,extract_frames_from_video,safe_round,select_multiple_points,generate_point_map,generate_point_map_frames,export_gif_side_by_side_complete,export_gif_side_by_side_complete_ablation
import random
import torchvision.transforms as T
from LightGlue.lightglue import viz2d
import matplotlib.pyplot as plt
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from torchvision.transforms import PILToTensor
def get_args(argv=None):
    """Parse the command-line options of the AniDoc inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # ---- model weights ----
    parser.add_argument(
        "--pretrained_model_name_or_path", type=str,
        default="pretrained_weights/stable-video-diffusion-img2vid-xt",
        help="Path (or hub ID) of the base Stable Video Diffusion pipeline weights.",
    )
    parser.add_argument(
        "--pretrained_unet", type=str,
        default="pretrained_weights/anidoc",
        help="Directory holding the AniDoc UNet weights (loaded from its 'unet' subfolder).",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path", type=str,
        default="pretrained_weights/anidoc/controlnet",
        help="Path to the AniDoc ControlNet weights.",
    )
    # ---- output / generation ----
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Directory for generated results (defaults to 'results').")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--noise_aug", type=float, default=0.02, help="Noise augmentation strength.")
    parser.add_argument("--num_frames", type=int, default=14, help="Number of video frames.")
    parser.add_argument("--width", type=int, default=512, help="Frame width in pixels.")
    parser.add_argument("--height", type=int, default=320, help="Frame height in pixels.")
    # ---- sketch conditioning ----
    parser.add_argument("--all_sketch", action="store_true",
                        help="Extract a line-art sketch from every control frame.")
    parser.add_argument("--not_quant_sketch", action="store_true",
                        help="Keep the raw line-art output instead of binarizing it at threshold 200.")
    parser.add_argument("--repeat_sketch", action="store_true",
                        help="Use the first frame's sketch for the first half of the clip and the "
                             "last frame's sketch for the rest.")
    # ---- keypoint matching / tracking ----
    parser.add_argument("--matching", action="store_true", help="add keypoint matching")
    parser.add_argument("--tracking", action="store_true",
                        help="Track the matched keypoints across frames with CoTracker.")
    parser.add_argument("--repeat_matching", action="store_true",
                        help="Do not track; repeat the first-frame keypoint matches on every frame.")
    parser.add_argument("--tracker_point_init", type=str, default='gaussion',
                        choices=['dift', 'gaussion', 'both'],
                        help="Tracker point initialization strategy.")
    parser.add_argument("--tracker_shift_grid", type=int, default=0, choices=[0, 1],
                        help="shift the grid for the tracker")
    parser.add_argument("--tracker_grid_size", type=int, default=8,
                        help="Regular grid size for the tracker.")
    parser.add_argument("--tracker_grid_query_frame", type=int, default=0,
                        help="Compute dense and grid tracks starting from this frame")
    parser.add_argument("--tracker_backward_tracking", action="store_true",
                        help="Compute tracks in both directions, not only forward")
    # ---- inputs ----
    parser.add_argument("--control_image", type=str, default=None,
                        help="Control input: a folder of frames or an .mp4 file. "
                             "Used only together with --ref_image.")
    parser.add_argument("--ref_image", type=str, default=None,
                        help="Reference image path. Used only together with --control_image.")
    # `--max_point` is kept as an alias of `--max_points` for existing callers.
    parser.add_argument("--max_points", "--max_point", dest="max_points", type=int, default=10,
                        help="Maximum number of tracked points kept when sampling trajectories.")
    return parser.parse_args(argv)
def _load_control_frames(path, size):
    """Load the control frames of one sample and resize each to ``size`` (width, height).

    ``path`` may be a directory (read with ``load_images_from_folder``) or an ``.mp4``
    file (read with ``extract_frames_from_video``).

    Raises:
        ValueError: if ``path`` is neither a directory nor an ``.mp4`` file.
    """
    if os.path.isdir(path):
        frames = load_images_from_folder(path)
    elif path.endswith(".mp4"):
        frames = extract_frames_from_video(path)
    else:
        raise ValueError(
            f"Unsupported control input {path!r}: expected a folder of frames or an .mp4 file."
        )
    return [frame.resize(size) for frame in frames]


def _frame_to_sketch(detector, frame, size, quantize):
    """Return the line-art sketch of PIL ``frame`` as a 3-channel PIL image resized to ``size``.

    When ``quantize`` is true the detector output is binarized: values above 200 become
    255 and everything else 0.
    """
    # assumes the detector returns a 2-D (H, W) map — TODO confirm
    sketch = detector(np.array(frame), coarse=False)
    sketch = np.repeat(sketch[:, :, np.newaxis], 3, axis=2)
    if quantize:
        sketch = (sketch > 200).astype(np.uint8) * 255
    return Image.fromarray(sketch).resize(size)


def _build_sketches(detector, frames, size, quantize, repeat_endpoints):
    """Return one sketch per control frame.

    With ``repeat_endpoints`` false every frame is converted. With it true only the first
    and last frames go through the detector: the first ``len(frames) // 2`` entries reuse
    the first sketch and the remaining entries reuse the last one.
    """
    if not repeat_endpoints:
        return [_frame_to_sketch(detector, frame, size, quantize) for frame in frames]
    half = len(frames) // 2
    first = _frame_to_sketch(detector, frames[0], size, quantize)
    last = _frame_to_sketch(detector, frames[-1], size, quantize)
    return [first] * half + [last] * (len(frames) - half)


def _match_first_frame(extractor, matcher, ref_image, first_sketch, device, max_matches=50):
    """Match keypoints between the reference image and the first sketch.

    Returns:
        ``(ref_tensor, image_shape, points0, points1)``: the reference image as a float32
        tensor in [0, 1] on ``device``; the shape of the sketch tensor; and the matched
        keypoints in the reference (``points0``) and the sketch (``points1``), rounded by
        ``safe_round`` and narrowed by ``select_multiple_points`` to
        ``min(max_matches, number of matches)`` points.
    """
    # NOTE(review): the float16 -> float32 round-trip quantizes pixel values; kept as-is.
    ref_tensor = T.ToTensor()(ref_image).to(device, dtype=torch.float16).to(torch.float32)
    sketch_tensor = T.ToTensor()(first_sketch).to(device, dtype=torch.float16).to(torch.float32)
    feats0 = extractor.extract(ref_tensor)
    feats1 = extractor.extract(sketch_tensor)
    matches01 = matcher({'image0': feats0, 'image1': feats1})
    feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]]
    matches = matches01['matches']
    points0 = safe_round(feats0['keypoints'][matches[..., 0]].cpu().numpy(), sketch_tensor.shape)
    points1 = safe_round(feats1['keypoints'][matches[..., 1]].cpu().numpy(), sketch_tensor.shape)
    num_points = min(max_matches, points0.shape[0])
    points0, points1 = select_multiple_points(points0, points1, num_points)
    return ref_tensor, sketch_tensor.shape, points0, points1


def _tracked_point_maps(tracker, sketch_condition, points0, points1, image_shape, args, device):
    """Track the sketch keypoints through the clip and rasterize them into point maps.

    Returns ``(mask1, mask2)`` from ``generate_point_map_frames``. When there are no
    keypoints, or trajectory sampling returns nothing, it returns an all-zero
    ``(height, width)`` mask and an all-zero ``(num_frames, height, width)`` mask.
    """
    # Undo the [-1, 1] normalization to get 0..255 pixel values for the tracker.
    video_for_tracker = (sketch_condition * 0.5 + 0.5) * 255.
    # Prepend a 0 column to every keypoint (presumably the query frame index).
    queries = np.insert(points1, 0, 0, axis=1)
    queries = torch.from_numpy(queries).to(device, torch.float).unsqueeze(0)

    pred_tracks_sampled = None
    if queries.shape[1] > 0:
        pred_tracks, pred_visibility = tracker(
            video_for_tracker.to(dtype=torch.float32),
            queries=queries,
            grid_size=args.tracker_grid_size,
            grid_query_frame=args.tracker_grid_query_frame,
            backward_tracking=args.tracker_backward_tracking,
        )
        pred_tracks_sampled, pred_visibility_sampled, points0_sampled = sample_trajectories_with_ref(
            pred_tracks.cpu(), pred_visibility.cpu(), torch.from_numpy(points0).unsqueeze(0).cpu(),
            max_points=args.max_points,
            motion_threshold=1,
            vis_threshold=3,
        )

    if pred_tracks_sampled is None:
        return (
            np.zeros((args.height, args.width), dtype=np.uint8),
            np.zeros((args.num_frames, args.height, args.width), dtype=np.uint8),
        )

    pred_tracks_sampled = pred_tracks_sampled.squeeze(0).cpu().numpy()
    pred_visibility_sampled = pred_visibility_sampled.squeeze(0).cpu().numpy()
    points0_sampled = points0_sampled.squeeze(0).cpu().numpy()
    for frame_id in range(args.num_frames):
        pred_tracks_sampled[frame_id] = safe_round(pred_tracks_sampled[frame_id], image_shape)
    points0_sampled = safe_round(points0_sampled, image_shape)
    return generate_point_map_frames(
        size=image_shape,
        coords0=points0_sampled,
        coords1=pred_tracks_sampled,
        visibility=pred_visibility_sampled,
    )


def _build_controlnet_condition(args, sketches, ref_image, extractor, matcher, tracker, device):
    """Assemble the ControlNet conditioning tensor for one sample.

    Per frame, the channels (dim 2 of a ``(1, F, C, H, W)`` tensor) are concatenated as:
    sketch, reference-side point map, sketch-side point map, normalized reference image.
    The point-map part depends on the mode:
      * ``--repeat_matching``: the first-frame point maps are repeated on every frame;
      * ``--tracking``: point maps come from CoTracker trajectories;
      * otherwise: first-frame point maps, followed by all-zero frames.
    """
    sketch_condition = torch.cat([T.ToTensor()(img).unsqueeze(0) for img in sketches], dim=0)
    sketch_condition = sketch_condition.unsqueeze(0).to(device, dtype=torch.float16)
    sketch_condition = (sketch_condition - 0.5) / 0.5  # (1, F, 3, H, W) in [-1, 1]

    with torch.no_grad():
        ref_tensor, image_shape, points0, points1 = _match_first_frame(
            extractor, matcher, ref_image, sketches[0], device)
        ref_condition = ((ref_tensor - 0.5) / 0.5).unsqueeze(0).unsqueeze(0)  # (1, 1, C, H, W)

        if args.repeat_matching or not args.tracking:
            mask1, mask2 = generate_point_map(size=image_shape, coords0=points0, coords1=points1)
            # assumes both masks are (H, W), giving one channel each — TODO confirm
            point_maps = [
                torch.from_numpy(mask)[None, None, None].to(device, dtype=torch.float16)
                for mask in (mask1, mask2)
            ]
            first = torch.cat(point_maps + [ref_condition], dim=2)
            if args.repeat_matching:
                matching = first.repeat(1, args.num_frames, 1, 1, 1)
            else:
                padding = torch.zeros(
                    (first.shape[0], args.num_frames - 1, *first.shape[2:]),
                    device=device, dtype=torch.float16,
                )
                matching = torch.cat((first, padding), dim=1)
        else:
            mask1, mask2 = _tracked_point_maps(
                tracker, sketch_condition, points0, points1, image_shape, args, device)
            point_map1 = (torch.from_numpy(mask1).unsqueeze(0).unsqueeze(0)
                          .repeat(1, args.num_frames, 1, 1, 1).to(device, dtype=torch.float16))
            point_map2 = torch.from_numpy(mask2).unsqueeze(0).unsqueeze(2).to(device, dtype=torch.float16)
            matching = torch.cat(
                [point_map1, point_map2, ref_condition.repeat(1, args.num_frames, 1, 1, 1)], dim=2)

    return torch.cat([sketch_condition, matching], dim=2)


# Script entry point: load the models, then for each (control clip, reference image)
# pair build the sketch + keypoint conditioning, run the pipeline and export a GIF.
if __name__ == "__main__":
    args = get_args()
    # Validate cheap preconditions before loading any weights.
    if not (args.all_sketch or args.repeat_sketch):
        raise ValueError("Pass --all_sketch or --repeat_sketch to choose how sketches are built.")
    if (args.ref_image is None) != (args.control_image is None):
        raise ValueError("--ref_image and --control_image must be given together.")

    device = "cuda"
    dtype = torch.float16
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.pretrained_unet,
        subfolder="unet",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        custom_resume=True,
    )
    unet.to(device, dtype)
    if args.controlnet_model_name_or_path:
        controlnet = ControlNetSVDModel.from_pretrained(args.controlnet_model_name_or_path)
    else:
        # 8 conditioning channels = 3 (sketch) + 1 + 1 (point maps) + 3 (reference image).
        controlnet = ControlNetSVDModel.from_unet(unet, conditioning_channels=8)
    controlnet.to(device, dtype)
    if not is_xformers_available():
        raise ValueError(
            "xformers is not available. Make sure it is installed correctly")
    unet.enable_xformers_memory_efficient_attention()

    pipe = AniDocPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unet,
        controlnet=controlnet,
        low_cpu_mem_usage=False,
        torch_dtype=dtype,
        variant="fp16",
    )
    pipe.to(device)
    detector = LineartDetector(device)
    extractor = SuperPoint(max_num_keypoints=2000).eval().to(device)
    matcher = LightGlue(features='superpoint').eval().to(device)
    tracker = CoTrackerPredictor(
        checkpoint="pretrained_weights/cotracker2.pth",
        shift_grid=args.tracker_shift_grid,
    )
    tracker.requires_grad_(False)
    tracker.to(device, dtype=torch.float32)

    width, height = args.width, args.height
    size = (width, height)
    if args.output_dir is None:
        args.output_dir = "results"
    os.makedirs(args.output_dir, exist_ok=True)

    if args.ref_image is not None:
        ref_image_list = [args.ref_image]
        image_folder_list = [args.control_image]
    else:
        # Built-in demo sample.
        image_folder_list = ['data_test/sample1.mp4']
        ref_image_list = ["data_test/sample1.png"]

    quantize = not args.not_quant_sketch
    # --all_sketch takes precedence when both sketch modes are given.
    repeat_endpoints = args.repeat_sketch and not args.all_sketch

    for val_id, each_sample in enumerate(image_folder_list):
        control_images = _load_control_frames(each_sample, size)
        # Every conditioning branch produces exactly num_frames frames, so the sketch
        # sequence must match or the channel-wise torch.cat fails.
        if len(control_images) != args.num_frames:
            raise ValueError(
                f"{each_sample!r} yields {len(control_images)} frames, "
                f"but --num_frames is {args.num_frames}."
            )
        ref_image = load_image(ref_image_list[val_id]).resize(size)
        controlnet_image = _build_sketches(detector, control_images, size, quantize, repeat_endpoints)
        controlnet_condition = _build_controlnet_condition(
            args, controlnet_image, ref_image, extractor, matcher, tracker, device)

        ref_base_name = os.path.splitext(os.path.basename(ref_image_list[val_id]))[0]
        # normpath drops a trailing separator so folder inputs keep their name.
        sketch_base_name = os.path.splitext(os.path.basename(os.path.normpath(each_sample)))[0]
        supp_dir = os.path.join(args.output_dir, ref_base_name + "_" + sketch_base_name)
        os.makedirs(supp_dir, exist_ok=True)

        generator = torch.manual_seed(args.seed)
        with torch.inference_mode():
            video_frames = pipe(
                ref_image,
                controlnet_condition,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                decode_chunk_size=8,
                motion_bucket_id=127,
                fps=7,
                noise_aug_strength=args.noise_aug,
                generator=generator,
            ).frames[0]

        gif_path = supp_dir + ".gif"
        if args.all_sketch:
            export_gif_side_by_side_complete_ablation(
                ref_image, controlnet_image, video_frames, gif_path, supp_dir, 6)
        else:
            export_gif_with_ref(
                control_images[0], video_frames, controlnet_image[-1], controlnet_image[0], gif_path, 6)