From fc52401379b7a698e326df69807daf4225eb339b Mon Sep 17 00:00:00 2001
From: ivideogameboss <zain_faiz@hotmail.com>
Date: Fri, 16 Aug 2024 10:43:03 -0500
Subject: [PATCH] add support for two faces

---
 modules/face_analyser.py                 | 14 +++++++++
 modules/processors/frame/face_swapper.py | 38 +++++++++++++++++++-----
 modules/ui.py                            | 26 +++++++++++-----
 3 files changed, 63 insertions(+), 15 deletions(-)

diff --git a/modules/face_analyser.py b/modules/face_analyser.py
index f2d46bf..a89f8da 100644
--- a/modules/face_analyser.py
+++ b/modules/face_analyser.py
@@ -22,6 +22,20 @@ def get_one_face(frame: Frame) -> Any:
         return min(face, key=lambda x: x.bbox[0])
     except ValueError:
         return None
+    
+def get_one_face_left(frame: Frame) -> Any:
+    face = get_face_analyser().get(frame)
+    try:
+        return min(face, key=lambda x: x.bbox[0])
+    except ValueError:
+        return None
+    
+def get_one_face_right(frame: Frame) -> Any:
+    face = get_face_analyser().get(frame)
+    try:
+        return max(face, key=lambda x: x.bbox[0])
+    except ValueError:
+        return None
 
 
 def get_many_faces(frame: Frame) -> Any:
diff --git a/modules/processors/frame/face_swapper.py b/modules/processors/frame/face_swapper.py
index 4b4a222..148c292 100644
--- a/modules/processors/frame/face_swapper.py
+++ b/modules/processors/frame/face_swapper.py
@@ -6,7 +6,7 @@ import threading
 import modules.globals
 import modules.processors.frame.core
 from modules.core import update_status
-from modules.face_analyser import get_one_face, get_many_faces
+from modules.face_analyser import get_one_face, get_many_faces,get_one_face_left,get_one_face_right
 from modules.typing import Face, Frame
 from modules.utilities import conditional_download, resolve_relative_path, is_image, is_video
 
@@ -47,26 +47,50 @@ def get_face_swapper() -> Any:
 def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
     return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True)
 
+def get_two_faces(frame: Frame) -> List[Face]:
+    faces = get_many_faces(frame)
+    if faces:
+        # Sort faces from left to right based on the x-coordinate of the bounding box
+        sorted_faces = sorted(faces, key=lambda x: x.bbox[0])
+        return sorted_faces[:2]  # Return up to two faces, leftmost and rightmost
+    return []
 
-def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
+def process_frame(source_face: List[Face], temp_frame: Frame) -> Frame:
     if modules.globals.many_faces:
         many_faces = get_many_faces(temp_frame)
         if many_faces:
             for target_face in many_faces:
-                temp_frame = swap_face(source_face, target_face, temp_frame)
+                temp_frame = swap_face(source_face[0], target_face, temp_frame)
     else:
-        target_face = get_one_face(temp_frame)
-        if target_face:
-            temp_frame = swap_face(source_face, target_face, temp_frame)
+        target_faces = get_two_faces(temp_frame)
+        if len(target_faces) >= 2:
+            # Swap the first face
+            temp_frame = swap_face(source_face[0], target_faces[0], temp_frame)
+            # Swap the second face
+            temp_frame = swap_face(source_face[1], target_faces[1], temp_frame)
+        elif len(target_faces) == 1:
+            # If only one face is found, swap with the first source face
+            temp_frame = swap_face(source_face[0], target_faces[0], temp_frame)
+
     return temp_frame
 
 
 def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None:
+    
+    source_image_left = None  # Initialize variable for the selected face image
+    source_image_right = None  # Initialize variable for the selected face image
+
+    if source_image_left is None and source_path:
+        source_image_left = get_one_face_left(cv2.imread(source_path))
+    if source_image_right is None and source_path:
+        source_image_right = get_one_face_right(cv2.imread(source_path))
+
+
     source_face = get_one_face(cv2.imread(source_path))
     for temp_frame_path in temp_frame_paths:
         temp_frame = cv2.imread(temp_frame_path)
         try:
-            result = process_frame(source_face, temp_frame)
+            result = process_frame([source_image_left,source_image_right], temp_frame)
             cv2.imwrite(temp_frame_path, result)
         except Exception as exception:
             print(exception)
diff --git a/modules/ui.py b/modules/ui.py
index ab18e95..643fd37 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -7,7 +7,7 @@ from PIL import Image, ImageOps
 
 import modules.globals
 import modules.metadata
-from modules.face_analyser import get_one_face
+from modules.face_analyser import get_one_face,get_one_face_left,get_one_face_right
 from modules.capturer import get_video_frame, get_video_frame_total
 from modules.processors.frame.core import get_frame_processors_modules
 from modules.utilities import is_image, is_video, resolve_relative_path
@@ -61,7 +61,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
     target_label = ctk.CTkLabel(root, text=None)
     target_label.place(relx=0.6, rely=0.1, relwidth=0.3, relheight=0.25)
 
-    select_face_button = ctk.CTkButton(root, text='Select a face', cursor='hand2', command=lambda: select_source_path())
+    select_face_button = ctk.CTkButton(root, text='Select a face/s \n(left face)(right face)', cursor='hand2', command=lambda: select_source_path())
     select_face_button.place(relx=0.1, rely=0.4, relwidth=0.3, relheight=0.1)
 
     select_target_button = ctk.CTkButton(root, text='Select a target', cursor='hand2', command=lambda: select_target_path())
@@ -239,9 +239,16 @@ def update_preview(frame_number: int = 0) -> None:
             from modules.predicter import predict_frame
             if predict_frame(temp_frame):
                 quit()
+        source_image_left = None  # Initialize variable for the selected face image
+        source_image_right = None  # Initialize variable for the selected face image
+        
+        if source_image_left is None and modules.globals.source_path:
+            source_image_left = get_one_face_left(cv2.imread(modules.globals.source_path))
+        if source_image_right is None and modules.globals.source_path:
+            source_image_right = get_one_face_right(cv2.imread(modules.globals.source_path))
+
         for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
-            temp_frame = frame_processor.process_frame(
-                get_one_face(cv2.imread(modules.globals.source_path)),
+            temp_frame = frame_processor.process_frame([source_image_left,source_image_right],
                 temp_frame
             )
         image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
@@ -269,7 +276,8 @@ def webcam_preview():
 
     frame_processors = get_frame_processors_modules(modules.globals.frame_processors)
 
-    source_image = None  # Initialize variable for the selected face image
+    source_image_left = None  # Initialize variable for the selected face image
+    source_image_right = None  # Initialize variable for the selected face image
 
     while True:
         ret, frame = cap.read()
@@ -277,13 +285,15 @@ def webcam_preview():
             break
 
         # Select and save face image only once
-        if source_image is None and modules.globals.source_path:
-            source_image = get_one_face(cv2.imread(modules.globals.source_path))
+        if source_image_left is None and modules.globals.source_path:
+            source_image_left = get_one_face_left(cv2.imread(modules.globals.source_path))
+        if source_image_right is None and modules.globals.source_path:
+            source_image_right = get_one_face_right(cv2.imread(modules.globals.source_path))
 
         temp_frame = frame.copy()  #Create a copy of the frame
 
         for frame_processor in frame_processors:
-            temp_frame = frame_processor.process_frame(source_image, temp_frame)
+            temp_frame = frame_processor.process_frame([source_image_left,source_image_right], temp_frame)
 
         image = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)  # Convert the image to RGB format to display it with Tkinter
         image = Image.fromarray(image)