| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  | from typing import Any, List | 
					
						
							|  |  |  | import cv2 | 
					
						
							|  |  |  | import threading | 
					
						
							|  |  |  | import gfpgan | 
					
						
							| 
									
										
										
										
											2024-08-07 23:26:47 +08:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | import modules.globals | 
					
						
							|  |  |  | import modules.processors.frame.core | 
					
						
							|  |  |  | from modules.core import update_status | 
					
						
							|  |  |  | from modules.face_analyser import get_one_face | 
					
						
							|  |  |  | from modules.typing import Frame, Face | 
					
						
							| 
									
										
										
										
											2024-12-19 23:48:28 +08:00
										 |  |  | import platform | 
					
						
							|  |  |  | import torch | 
					
						
							| 
									
										
										
										
											2024-10-15 15:38:03 +08:00
										 |  |  | from modules.utilities import ( | 
					
						
							|  |  |  |     conditional_download, | 
					
						
							|  |  |  |     is_image, | 
					
						
							|  |  |  |     is_video, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | FACE_ENHANCER = None | 
					
						
							|  |  |  | THREAD_SEMAPHORE = threading.Semaphore() | 
					
						
							|  |  |  | THREAD_LOCK = threading.Lock() | 
					
						
							| 
									
										
										
										
											2024-10-15 15:38:03 +08:00
										 |  |  | NAME = "DLC.FACE-ENHANCER" | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 21:43:13 +08:00
										 |  |  | abs_dir = os.path.dirname(os.path.abspath(__file__)) | 
					
						
							| 
									
										
										
										
											2024-12-19 23:48:28 +08:00
										 |  |  | models_dir = os.path.join( | 
					
						
							|  |  |  |     os.path.dirname(os.path.dirname(os.path.dirname(abs_dir))), "models" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def pre_check() -> bool: | 
					
						
							| 
									
										
										
										
											2024-11-09 21:43:13 +08:00
										 |  |  |     download_directory_path = models_dir | 
					
						
							| 
									
										
										
										
											2024-10-15 15:38:03 +08:00
										 |  |  |     conditional_download( | 
					
						
							|  |  |  |         download_directory_path, | 
					
						
							|  |  |  |         [ | 
					
						
							|  |  |  |             "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth" | 
					
						
							|  |  |  |         ], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  |     return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def pre_start() -> bool: | 
					
						
							| 
									
										
										
										
											2024-10-15 15:38:03 +08:00
										 |  |  |     if not is_image(modules.globals.target_path) and not is_video( | 
					
						
							|  |  |  |         modules.globals.target_path | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         update_status("Select an image or video for target path.", NAME) | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  |         return False | 
					
						
							|  |  |  |     return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-19 21:03:49 +08:00
										 |  |  | TENSORRT_AVAILABLE = False | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     import torch_tensorrt | 
					
						
							|  |  |  |     TENSORRT_AVAILABLE = True | 
					
						
							|  |  |  | except ImportError as im: | 
					
						
							|  |  |  |     print(f"TensorRT is not available: {im}") | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | except Exception as e: | 
					
						
							|  |  |  |     print(f"TensorRT is not available: {e}") | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  | def get_face_enhancer() -> Any: | 
					
						
							|  |  |  |     global FACE_ENHANCER | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with THREAD_LOCK: | 
					
						
							|  |  |  |         if FACE_ENHANCER is None: | 
					
						
							| 
									
										
										
										
											2024-12-19 23:48:28 +08:00
										 |  |  |             model_path = os.path.join(models_dir, "GFPGANv1.4.pth") | 
					
						
							| 
									
										
										
										
											2024-12-23 14:29:36 +08:00
										 |  |  |              | 
					
						
							| 
									
										
										
										
											2025-04-19 21:03:49 +08:00
										 |  |  |             selected_device = None | 
					
						
							|  |  |  |             device_priority = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if TENSORRT_AVAILABLE and torch.cuda.is_available(): | 
					
						
							|  |  |  |                 selected_device = torch.device("cuda") | 
					
						
							|  |  |  |                 device_priority.append("TensorRT+CUDA") | 
					
						
							|  |  |  |             elif torch.cuda.is_available(): | 
					
						
							|  |  |  |                 selected_device = torch.device("cuda") | 
					
						
							|  |  |  |                 device_priority.append("CUDA") | 
					
						
							|  |  |  |             elif torch.backends.mps.is_available() and platform.system() == "Darwin": | 
					
						
							|  |  |  |                 selected_device = torch.device("mps") | 
					
						
							|  |  |  |                 device_priority.append("MPS") | 
					
						
							|  |  |  |             elif not torch.cuda.is_available(): | 
					
						
							|  |  |  |                 selected_device = torch.device("cpu") | 
					
						
							|  |  |  |                 device_priority.append("CPU") | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=selected_device) | 
					
						
							| 
									
										
										
										
											2024-12-23 14:29:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-19 21:03:49 +08:00
										 |  |  |             # for debug: | 
					
						
							|  |  |  |             print(f"Selected device: {selected_device} and device priority: {device_priority}") | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  |     return FACE_ENHANCER | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def enhance_face(temp_frame: Frame) -> Frame: | 
					
						
							|  |  |  |     with THREAD_SEMAPHORE: | 
					
						
							| 
									
										
										
										
											2024-10-15 15:38:03 +08:00
										 |  |  |         _, _, temp_frame = get_face_enhancer().enhance(temp_frame, paste_back=True) | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  |     return temp_frame | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def process_frame(source_face: Face, temp_frame: Frame) -> Frame: | 
					
						
							|  |  |  |     target_face = get_one_face(temp_frame) | 
					
						
							|  |  |  |     if target_face: | 
					
						
							|  |  |  |         temp_frame = enhance_face(temp_frame) | 
					
						
							|  |  |  |     return temp_frame | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-15 15:38:03 +08:00
										 |  |  | def process_frames( | 
					
						
							|  |  |  |     source_path: str, temp_frame_paths: List[str], progress: Any = None | 
					
						
							|  |  |  | ) -> None: | 
					
						
							| 
									
										
										
										
											2023-09-24 21:36:57 +08:00
										 |  |  |     for temp_frame_path in temp_frame_paths: | 
					
						
							|  |  |  |         temp_frame = cv2.imread(temp_frame_path) | 
					
						
							|  |  |  |         result = process_frame(None, temp_frame) | 
					
						
							|  |  |  |         cv2.imwrite(temp_frame_path, result) | 
					
						
							|  |  |  |         if progress: | 
					
						
							|  |  |  |             progress.update(1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def process_image(source_path: str, target_path: str, output_path: str) -> None: | 
					
						
							|  |  |  |     target_frame = cv2.imread(target_path) | 
					
						
							|  |  |  |     result = process_frame(None, target_frame) | 
					
						
							|  |  |  |     cv2.imwrite(output_path, result) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def process_video(source_path: str, temp_frame_paths: List[str]) -> None: | 
					
						
							|  |  |  |     modules.processors.frame.core.process_video(None, temp_frame_paths, process_frames) | 
					
						
							| 
									
										
										
										
											2024-10-15 15:38:03 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def process_frame_v2(temp_frame: Frame) -> Frame: | 
					
						
							|  |  |  |     target_face = get_one_face(temp_frame) | 
					
						
							|  |  |  |     if target_face: | 
					
						
							|  |  |  |         temp_frame = enhance_face(temp_frame) | 
					
						
							|  |  |  |     return temp_frame |