![]() |
![]() |
![]() |
![]() |
![]() |
MoveNet 是一款超快速且精準的模型,可偵測人體的 17 個關鍵點。此模型在 TF Hub 上提供兩種變體,分別稱為 Lightning 和 Thunder。Lightning 適用於延遲至關重要的應用程式,而 Thunder 適用於需要高精準度的應用程式。這兩種模型在大多數現代桌上型電腦、筆記型電腦和手機上的執行速度都比即時 (30+ FPS) 還快,這對於即時健身、健康和保健應用程式至關重要。
*圖片下載自 Pexels (https://www.pexels.com/)
這個 Colab 將逐步引導您瞭解如何載入 MoveNet,以及如何在以下輸入圖片和影片上執行推論。
使用 MoveNet 進行人體姿勢估計
視覺化程式庫與匯入
pip install -q imageio
pip install -q opencv-python
pip install -q git+https://github.com/tensorflow/docs
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow_docs.vis import embed
import numpy as np
import cv2
# Import matplotlib libraries
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib.patches as patches
# Some modules to display an animation using imageio.
import imageio
from IPython.display import HTML, display
用於視覺化的輔助函式
從 TF Hub 載入模型
model_name = "movenet_lightning"
if "tflite" in model_name:
if "movenet_lightning_f16" in model_name:
!wget -q -O model.tflite https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/float16/4?lite-format=tflite
input_size = 192
elif "movenet_thunder_f16" in model_name:
!wget -q -O model.tflite https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/float16/4?lite-format=tflite
input_size = 256
elif "movenet_lightning_int8" in model_name:
!wget -q -O model.tflite https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite
input_size = 192
elif "movenet_thunder_int8" in model_name:
!wget -q -O model.tflite https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/int8/4?lite-format=tflite
input_size = 256
else:
raise ValueError("Unsupported model name: %s" % model_name)
# Initialize the TFLite interpreter
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
def movenet(input_image):
"""Runs detection on an input image.
Args:
input_image: A [1, height, width, 3] tensor represents the input image
pixels. Note that the height/width should already be resized and match the
expected input resolution of the model before passing into this function.
Returns:
A [1, 1, 17, 3] float numpy array representing the predicted keypoint
coordinates and scores.
"""
# TF Lite format expects tensor type of uint8.
input_image = tf.cast(input_image, dtype=tf.uint8)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], input_image.numpy())
# Invoke inference.
interpreter.invoke()
# Get the model prediction.
keypoints_with_scores = interpreter.get_tensor(output_details[0]['index'])
return keypoints_with_scores
else:
if "movenet_lightning" in model_name:
module = hub.load("https://tfhub.dev/google/movenet/singlepose/lightning/4")
input_size = 192
elif "movenet_thunder" in model_name:
module = hub.load("https://tfhub.dev/google/movenet/singlepose/thunder/4")
input_size = 256
else:
raise ValueError("Unsupported model name: %s" % model_name)
def movenet(input_image):
"""Runs detection on an input image.
Args:
input_image: A [1, height, width, 3] tensor represents the input image
pixels. Note that the height/width should already be resized and match the
expected input resolution of the model before passing into this function.
Returns:
A [1, 1, 17, 3] float numpy array representing the predicted keypoint
coordinates and scores.
"""
model = module.signatures['serving_default']
# SavedModel format expects tensor type of int32.
input_image = tf.cast(input_image, dtype=tf.int32)
# Run model inference.
outputs = model(input_image)
# Output is a [1, 1, 17, 3] tensor.
keypoints_with_scores = outputs['output_0'].numpy()
return keypoints_with_scores
2024-03-09 15:01:44.320490: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
單張圖片範例
此環節示範在單張圖片上執行模型以預測 17 個人體關鍵點的最簡工作範例。
載入輸入圖片
curl -o input_image.jpeg https://images.pexels.com/photos/4384679/pexels-photo-4384679.jpeg --silent
# Read the downloaded JPEG file from disk and decode its contents.
image_path = 'input_image.jpeg'
image = tf.image.decode_jpeg(tf.io.read_file(image_path))
執行推論
# Add a leading batch axis, then resize with padding so the aspect ratio is
# kept and the result is input_size x input_size.
input_image = tf.image.resize_with_pad(
    tf.expand_dims(image, axis=0), input_size, input_size)

# Feed the prepared image to the model.
keypoints_with_scores = movenet(input_image)
# Build a 1280x1280 padded int32 copy of the image for display and hand it,
# together with the predictions, to draw_prediction_on_image.
display_image = tf.image.resize_with_pad(
    tf.expand_dims(image, axis=0), 1280, 1280)
display_image = tf.cast(display_image, dtype=tf.int32)
output_overlay = draw_prediction_on_image(
    np.squeeze(display_image.numpy(), axis=0), keypoints_with_scores)

# Show the result with the axes hidden.
plt.figure(figsize=(5, 5))
plt.imshow(output_overlay)
_ = plt.axis('off')
/tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
影片 (圖片序列) 範例
本節示範當輸入為影格序列時,如何根據前一個影格的偵測結果套用智慧型裁剪。這可讓模型將注意力和資源集中在主要主體上,進而在不犧牲速度的情況下,大幅提升預測品質。
裁剪演算法
載入輸入圖片序列
wget -q -O dance.gif https://github.com/tensorflow/tfjs-models/raw/master/pose-detection/assets/dance_input.gif
# Read the downloaded GIF file from disk and decode it into its frames.
image_path = 'dance.gif'
image = tf.image.decode_gif(tf.io.read_file(image_path))
使用裁剪演算法執行推論
# Run the model on every frame of the decoded sequence. Each frame is
# processed with the current crop region, which is then recomputed from that
# frame's keypoints for use on the next frame.
num_frames, image_height, image_width, _ = image.shape
crop_region = init_crop_region(image_height, image_width)

output_images = []
bar = display(progress(0, num_frames-1), display_id=True)
for frame_idx in range(num_frames):
    # Slice the frame once; it is used for both inference and drawing.
    frame = image[frame_idx, :, :, :]
    keypoints_with_scores = run_inference(
        movenet, frame, crop_region,
        crop_size=[input_size, input_size])
    output_images.append(draw_prediction_on_image(
        frame.numpy().astype(np.int32),
        keypoints_with_scores, crop_region=None,
        close_figure=True, output_image_height=300))
    # Update the crop region from this frame's detections.
    crop_region = determine_crop_region(
        keypoints_with_scores, image_height, image_width)
    bar.update(progress(frame_idx, num_frames-1))

# Prepare gif visualization.
output = np.stack(output_images, axis=0)
to_gif(output, duration=100)
/tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. 
image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. 
image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. 
image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. 
image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. 
image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) /tmpfs/tmp/ipykernel_112701/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead. 
image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)