def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf):
    """Run model inference and visualize results.

    Two modes, selected by ``self.camera``:
      * ``'0'`` — live webcam loop: grab frames, detect, draw boxes, show in an
        OpenCV window until 'q' is pressed.
      * otherwise — iterate ``self.img_paths``: detect on each image, optionally
        write YOLO-format label .txt files (``save_txt``) and annotated images
        (``save_img``) under ``save_dir``.

    Args:
        conf_thres: confidence threshold for NMS.
        iou_thres: IoU threshold for NMS.
        classes: optional class-id filter passed to NMS.
        agnostic_nms: class-agnostic NMS flag.
        max_det: maximum detections per image.
        save_dir: output directory for images and ``labels/`` txt files.
        save_txt: write one ``labels/<stem>.txt`` per image (file mode only).
        save_img: write/draw annotated output.
        hide_labels: draw boxes without any text label.
        hide_conf: draw class name only, without the confidence value.
    """
    if self.camera == '0':
        print("开始调用摄像头...")
        cap = cv2.VideoCapture(0)
        while True:
            grabbed, img_src = cap.read()
            if not grabbed:
                # Frame grab failed (camera unplugged / end of stream):
                # stop cleanly instead of crashing inside letterbox(None).
                break
            image = letterbox(img_src, self.img_size, stride=self.stride)[0]
            # HWC BGR -> CHW RGB; copy makes the reversed view contiguous again.
            image = image.transpose((2, 0, 1))[::-1]
            image = torch.from_numpy(np.ascontiguousarray(image))
            image = image.half() if self.half else image.float()
            image /= 255  # normalize 0-255 -> 0.0-1.0
            img = image.to(self.device)
            if len(img.shape) == 3:
                img = img[None]  # add batch dimension
            pred_results = self.model(img)
            det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
            img_ori = img_src
            assert img_ori.data.contiguous, 'Image needs to be contiguous. Please apply to input images with np.ascontiguousarray(im).'
            self.font_check()
            if len(det):
                # Map boxes from letterboxed model space back to the original frame.
                det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    class_num = int(cls)
                    label = None if hide_labels else (self.class_names[class_num] if hide_conf else f'{self.class_names[class_num]} {conf:.2f}')
                    self.plot_box_and_label(img_ori, max(round(sum(img_ori.shape) / 2 * 0.003), 2), xyxy, label, color=self.generate_colors(class_num, True))
            img_src = np.asarray(img_ori)
            cv2.namedWindow('test', cv2.WINDOW_AUTOSIZE)
            cv2.imshow('test', img_src)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        cap.release()
        cv2.destroyAllWindows()
    else:
        for img_path in tqdm(self.img_paths):
            img, img_src = self.precess_image(img_path, self.img_size, self.stride, self.half)
            img = img.to(self.device)
            if len(img.shape) == 3:
                img = img[None]  # add batch dimension
            pred_results = self.model(img)
            det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
            save_path = osp.join(save_dir, osp.basename(img_path))
            # NOTE(review): assumes save_dir/labels already exists when save_txt
            # is set — presumably created by the caller; verify.
            txt_path = osp.join(save_dir, 'labels', osp.splitext(osp.basename(img_path))[0])
            # whwh tensor used to normalize xywh boxes for the label file.
            gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]]
            img_ori = img_src
            assert img_ori.data.contiguous, 'Image needs to be contiguous. Please apply to input images with np.ascontiguousarray(im).'
            self.font_check()
            if len(det):
                # Map boxes from letterboxed model space back to the original image.
                det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:
                        # YOLO label format: class x_center y_center w h (normalized).
                        xywh = (self.box_convert(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                        line = (cls, *xywh, conf)
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
                    if save_img:
                        class_num = int(cls)
                        label = None if hide_labels else (self.class_names[class_num] if hide_conf else f'{self.class_names[class_num]} {conf:.2f}')
                        self.plot_box_and_label(img_ori, max(round(sum(img_ori.shape) / 2 * 0.003), 2), xyxy, label, color=self.generate_colors(class_num, True))
            img_src = np.asarray(img_ori)
            if save_img:
                cv2.imwrite(save_path, img_src)