Refactor application structure.

pull/800/head
cattodotpy 2024-11-20 22:53:20 +08:00
parent 6171141505
commit b7d9889565
7 changed files with 1076 additions and 1102 deletions

View File

@ -17,7 +17,7 @@ import tensorflow
import modules.globals
import modules.metadata
import modules.ui as ui
from modules.ui import DeepFakeUI
from modules.processors.frame.core import get_frame_processors_modules
from modules.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
@ -27,9 +27,148 @@ if 'ROCMExecutionProvider' in modules.globals.execution_providers:
warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
def parse_args() -> None:
signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
def suggest_max_memory() -> int:
if platform.system().lower() == 'darwin':
return 4
return 16
def suggest_execution_providers() -> List[str]:
return encode_execution_providers(onnxruntime.get_available_providers())
def suggest_execution_threads() -> int:
if 'DmlExecutionProvider' in modules.globals.execution_providers:
return 1
if 'ROCMExecutionProvider' in modules.globals.execution_providers:
return 1
return 8
def limit_resources() -> None:
# prevent tensorflow memory leak
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tensorflow.config.experimental.set_memory_growth(gpu, True)
# limit memory usage
if modules.globals.max_memory:
memory = modules.globals.max_memory * 1024 ** 3
if platform.system().lower() == 'darwin':
memory = modules.globals.max_memory * 1024 ** 6
if platform.system().lower() == 'windows':
import ctypes
kernel32 = ctypes.windll.kernel32
kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
else:
import resource
resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
def release_resources() -> None:
if 'CUDAExecutionProvider' in modules.globals.execution_providers:
torch.cuda.empty_cache()
class DeepFakeApp:
def __init__(self):
self.ui = DeepFakeUI(
self.start,
self.destroy
)
def update_status(self, message: str, scope: str = 'DLC.CORE') -> None:
print(f'[{scope}] {message}')
if not modules.globals.headless:
self.ui.update_status(message)
def pre_check(self) -> bool:
if sys.version_info < (3, 9):
update_status('Python version is not supported - please upgrade to 3.9 or higher.')
return False
if not shutil.which('ffmpeg'):
self.update_status('ffmpeg is not installed.')
return False
return True
def start(self) -> None:
for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
if not frame_processor.pre_start():
return
self.update_status('Processing...')
# process image to image
if has_image_extension(modules.globals.target_path):
if modules.globals.nsfw_filter and self.ui.check_and_ignore_nsfw(modules.globals.target_path, self.destroy):
return
try:
shutil.copy2(modules.globals.target_path, modules.globals.output_path)
except Exception as e:
print("Error copying file:", str(e))
for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
self.update_status('Progressing...', frame_processor.NAME)
frame_processor.process_image(modules.globals.source_path, modules.globals.output_path, modules.globals.output_path)
release_resources()
if is_image(modules.globals.target_path):
self.update_status('Processing to image succeed!')
else:
self.update_status('Processing to image failed!')
return
# process image to videos
if modules.globals.nsfw_filter and self.ui.check_and_ignore_nsfw(modules.globals.target_path, self.destroy):
return
if not modules.globals.map_faces:
self.update_status('Creating temp resources...')
create_temp(modules.globals.target_path)
self.update_status('Extracting frames...')
extract_frames(modules.globals.target_path)
temp_frame_paths = get_temp_frame_paths(modules.globals.target_path)
for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
self.update_status('Progressing...', frame_processor.NAME)
frame_processor.process_video(modules.globals.source_path, temp_frame_paths)
release_resources()
# handles fps
if modules.globals.keep_fps:
self.update_status('Detecting fps...')
fps = detect_fps(modules.globals.target_path)
self.update_status(f'Creating video with {fps} fps...')
create_video(modules.globals.target_path, fps)
else:
self.update_status('Creating video with 30.0 fps...')
create_video(modules.globals.target_path)
# handle audio
if modules.globals.keep_audio:
if modules.globals.keep_fps:
self.update_status('Restoring audio...')
else:
self.update_status('Restoring audio might cause issues as fps are not kept...')
restore_audio(modules.globals.target_path, modules.globals.output_path)
else:
move_temp(modules.globals.target_path, modules.globals.output_path)
# clean and validate
clean_temp(modules.globals.target_path)
if is_video(modules.globals.target_path):
self.update_status('Processing to video succeed!')
else:
self.update_status('Processing to video failed!')
def destroy(self, to_quit=True) -> None:
if modules.globals.target_path:
clean_temp(modules.globals.target_path)
if to_quit:
sys.exit(0)
def parse_args(self) -> None:
signal.signal(signal.SIGINT, lambda _: self.destroy())
program = argparse.ArgumentParser()
program.add_argument('-s', '--source', help='select an source image', dest='source_path')
program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
@ -105,151 +244,16 @@ def parse_args() -> None:
modules.globals.execution_threads = args.gpu_threads_deprecated
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
def suggest_max_memory() -> int:
if platform.system().lower() == 'darwin':
return 4
return 16
def suggest_execution_providers() -> List[str]:
return encode_execution_providers(onnxruntime.get_available_providers())
def suggest_execution_threads() -> int:
if 'DmlExecutionProvider' in modules.globals.execution_providers:
return 1
if 'ROCMExecutionProvider' in modules.globals.execution_providers:
return 1
return 8
def limit_resources() -> None:
# prevent tensorflow memory leak
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tensorflow.config.experimental.set_memory_growth(gpu, True)
# limit memory usage
if modules.globals.max_memory:
memory = modules.globals.max_memory * 1024 ** 3
if platform.system().lower() == 'darwin':
memory = modules.globals.max_memory * 1024 ** 6
if platform.system().lower() == 'windows':
import ctypes
kernel32 = ctypes.windll.kernel32
kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
else:
import resource
resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
def release_resources() -> None:
if 'CUDAExecutionProvider' in modules.globals.execution_providers:
torch.cuda.empty_cache()
def pre_check() -> bool:
if sys.version_info < (3, 9):
update_status('Python version is not supported - please upgrade to 3.9 or higher.')
return False
if not shutil.which('ffmpeg'):
update_status('ffmpeg is not installed.')
return False
return True
def update_status(message: str, scope: str = 'DLC.CORE') -> None:
print(f'[{scope}] {message}')
if not modules.globals.headless:
ui.update_status(message)
def start() -> None:
for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
if not frame_processor.pre_start():
return
update_status('Processing...')
# process image to image
if has_image_extension(modules.globals.target_path):
if modules.globals.nsfw_filter and ui.check_and_ignore_nsfw(modules.globals.target_path, destroy):
return
try:
shutil.copy2(modules.globals.target_path, modules.globals.output_path)
except Exception as e:
print("Error copying file:", str(e))
for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
update_status('Progressing...', frame_processor.NAME)
frame_processor.process_image(modules.globals.source_path, modules.globals.output_path, modules.globals.output_path)
release_resources()
if is_image(modules.globals.target_path):
update_status('Processing to image succeed!')
else:
update_status('Processing to image failed!')
return
# process image to videos
if modules.globals.nsfw_filter and ui.check_and_ignore_nsfw(modules.globals.target_path, destroy):
return
if not modules.globals.map_faces:
update_status('Creating temp resources...')
create_temp(modules.globals.target_path)
update_status('Extracting frames...')
extract_frames(modules.globals.target_path)
temp_frame_paths = get_temp_frame_paths(modules.globals.target_path)
for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
update_status('Progressing...', frame_processor.NAME)
frame_processor.process_video(modules.globals.source_path, temp_frame_paths)
release_resources()
# handles fps
if modules.globals.keep_fps:
update_status('Detecting fps...')
fps = detect_fps(modules.globals.target_path)
update_status(f'Creating video with {fps} fps...')
create_video(modules.globals.target_path, fps)
else:
update_status('Creating video with 30.0 fps...')
create_video(modules.globals.target_path)
# handle audio
if modules.globals.keep_audio:
if modules.globals.keep_fps:
update_status('Restoring audio...')
else:
update_status('Restoring audio might cause issues as fps are not kept...')
restore_audio(modules.globals.target_path, modules.globals.output_path)
else:
move_temp(modules.globals.target_path, modules.globals.output_path)
# clean and validate
clean_temp(modules.globals.target_path)
if is_video(modules.globals.target_path):
update_status('Processing to video succeed!')
else:
update_status('Processing to video failed!')
def destroy(to_quit=True) -> None:
if modules.globals.target_path:
clean_temp(modules.globals.target_path)
if to_quit: quit()
def run() -> None:
parse_args()
if not pre_check():
def run(self) -> None:
self.parse_args()
if not self.pre_check():
return
for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
if not frame_processor.pre_check():
return
limit_resources()
if modules.globals.headless:
start()
self.start()
else:
window = ui.init(start, destroy)
window.mainloop()
self.ui.root.mainloop()

View File

@ -19,14 +19,10 @@ FRAME_PROCESSORS_INTERFACE = [
def load_frame_processor_module(frame_processor: str) -> Any:
try:
frame_processor_module = importlib.import_module(f'modules.processors.frame.{frame_processor}')
for method_name in FRAME_PROCESSORS_INTERFACE:
if not hasattr(frame_processor_module, method_name):
sys.exit()
except ImportError:
print(f"Frame processor {frame_processor} not found")
sys.exit()
return frame_processor_module

View File

@ -5,7 +5,7 @@ import threading
import numpy as np
import modules.globals
import modules.processors.frame.core
from modules.core import update_status
from run import app
from modules.face_analyser import get_one_face, get_many_faces, default_source_face
from modules.typing import Face, Frame
from modules.utilities import (
@ -36,17 +36,17 @@ def pre_check() -> bool:
def pre_start() -> bool:
if not modules.globals.map_faces and not is_image(modules.globals.source_path):
update_status("Select an image for source path.", NAME)
app.update_status("Select an image for source path.", NAME)
return False
elif not modules.globals.map_faces and not get_one_face(
cv2.imread(modules.globals.source_path)
):
update_status("No face in source path detected.", NAME)
app.update_status("No face in source path detected.", NAME)
return False
if not is_image(modules.globals.target_path) and not is_video(
modules.globals.target_path
):
update_status("Select an image or video for target path.", NAME)
app.update_status("Select an image or video for target path.", NAME)
return False
return True
@ -236,7 +236,7 @@ def process_image(source_path: str, target_path: str, output_path: str) -> None:
cv2.imwrite(output_path, result)
else:
if modules.globals.many_faces:
update_status(
app.update_status(
"Many faces enabled. Using first source image. Progressing...", NAME
)
target_frame = cv2.imread(output_path)
@ -246,7 +246,7 @@ def process_image(source_path: str, target_path: str, output_path: str) -> None:
def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
if modules.globals.map_faces and modules.globals.many_faces:
update_status(
app.update_status(
"Many faces enabled. Using first source image. Progressing...", NAME
)
modules.processors.frame.core.process_video(
@ -256,7 +256,7 @@ def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
def create_lower_mouth_mask(
face: Face, frame: Frame
) -> (np.ndarray, np.ndarray, tuple, np.ndarray):
) -> tuple[np.ndarray, np.ndarray, tuple, np.ndarray]:
mask = np.zeros(frame.shape[:2], dtype=np.uint8)
mouth_cutout = None
landmarks = face.landmark_2d_106

View File

@ -1,7 +1,6 @@
from typing import Any
from typing import Any, TypeAlias
from insightface.app.common import Face
import numpy
Face = Face
Frame = numpy.ndarray[Any, Any]
Frame: TypeAlias = numpy.ndarray[Any, Any]

View File

@ -7,6 +7,8 @@ from cv2_enumerate_cameras import enumerate_cameras # Add this import
from PIL import Image, ImageOps
import time
import json
from numpy import ndarray
from modules.predicter import predict_image, predict_video, predict_frame
import modules.globals
import modules.metadata
@ -27,13 +29,9 @@ from modules.utilities import (
has_image_extension,
)
ROOT = None
POPUP = None
POPUP_LIVE = None
ROOT_HEIGHT = 700
ROOT_WIDTH = 600
PREVIEW = None
PREVIEW_MAX_HEIGHT = 700
PREVIEW_MAX_WIDTH = 1200
PREVIEW_DEFAULT_WIDTH = 960
@ -55,33 +53,9 @@ MAPPER_PREVIEW_MAX_WIDTH = 100
DEFAULT_BUTTON_WIDTH = 200
DEFAULT_BUTTON_HEIGHT = 40
RECENT_DIRECTORY_SOURCE = None
RECENT_DIRECTORY_TARGET = None
RECENT_DIRECTORY_OUTPUT = None
preview_label = None
preview_slider = None
source_label = None
target_label = None
status_label = None
popup_status_label = None
popup_status_label_live = None
source_label_dict = {}
source_label_dict_live = {}
target_label_dict_live = {}
img_ft, vid_ft = modules.globals.file_types
def init(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
global ROOT, PREVIEW
ROOT = create_root(start, destroy)
PREVIEW = create_preview(ROOT)
return ROOT
def save_switch_states():
switch_states = {
"keep_fps": modules.globals.keep_fps,
@ -123,10 +97,58 @@ def load_switch_states():
# If the file doesn't exist, use default values
pass
def fit_image_to_size(image, width: int, height: int):
if width is None and height is None:
return image
h, w, _ = image.shape
ratio_h = 0.0
ratio_w = 0.0
if width > height:
ratio_h = height / h
else:
ratio_w = width / w
ratio = max(ratio_w, ratio_h)
new_size = (int(ratio * w), int(ratio * h))
return cv2.resize(image, dsize=new_size)
def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
global source_label, target_label, status_label, show_fps_switch
def get_available_cameras():
"""Returns a list of available camera names and indices."""
camera_indices = []
camera_names = []
for camera in enumerate_cameras():
cap = cv2.VideoCapture(camera.index)
if cap.isOpened():
camera_indices.append(camera.index)
camera_names.append(camera.name)
cap.release()
return (camera_indices, camera_names)
class DeepFakeUI:
preview_label: ctk.CTkLabel
source_label: ctk.CTkLabel
target_label: ctk.CTkLabel
status_label: ctk.CTkLabel
popup_status_label: ctk.CTkLabel
popup_status_label_live: ctk.CTkLabel
preview_slider: ctk.CTkSlider
source_label_dict: dict[int, ctk.CTkLabel] = {}
source_label_dict_live: dict[int, ctk.CTkLabel] = {}
target_label_dict_live: dict[int, ctk.CTkLabel] = {}
source_label_dict_live = {}
target_label_dict_live = {}
popup_live: ctk.CTkToplevel
popup: ctk.CTkToplevel = None
recent_directory_source: str = os.path.expanduser("~")
recent_directory_target: str = os.path.expanduser("~")
recent_directory_output: str = os.path.expanduser("~")
def __init__(self, start: Callable[[], None], destroy: Callable[[], None]) -> None:
self.root = self.create_root(start, destroy)
self.preview = self.create_preview(self.root)
def create_root(self, start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
load_switch_states()
ctk.deactivate_automatic_dpi_awareness()
@ -148,12 +170,12 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
target_label.place(relx=0.6, rely=0.1, relwidth=0.3, relheight=0.25)
select_face_button = ctk.CTkButton(
root, text="Select a face", cursor="hand2", command=lambda: select_source_path()
root, text="Select a face", cursor="hand2", command=lambda: self.select_source_path()
)
select_face_button.place(relx=0.1, rely=0.4, relwidth=0.3, relheight=0.1)
swap_faces_button = ctk.CTkButton(
root, text="", cursor="hand2", command=lambda: swap_faces_paths()
root, text="", cursor="hand2", command=lambda: self.swap_faces_paths()
)
swap_faces_button.place(relx=0.45, rely=0.4, relwidth=0.1, relheight=0.1)
@ -161,7 +183,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
root,
text="Select a target",
cursor="hand2",
command=lambda: select_target_path(),
command=lambda: self.select_target_path(),
)
select_target_button.place(relx=0.6, rely=0.4, relwidth=0.3, relheight=0.1)
@ -198,7 +220,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
variable=enhancer_value,
cursor="hand2",
command=lambda: (
update_tumbler("face_enhancer", enhancer_value.get()),
self.update_tumbler("face_enhancer", enhancer_value.get()),
save_switch_states(),
),
)
@ -296,7 +318,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
show_mouth_mask_box_switch.place(relx=0.6, rely=0.55)
start_button = ctk.CTkButton(
root, text="Start", cursor="hand2", command=lambda: analyze_target(start, root)
root, text="Start", cursor="hand2", command=lambda: self.analyze_target(start, root)
)
start_button.place(relx=0.15, rely=0.80, relwidth=0.2, relheight=0.05)
@ -306,7 +328,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
stop_button.place(relx=0.4, rely=0.80, relwidth=0.2, relheight=0.05)
preview_button = ctk.CTkButton(
root, text="Preview", cursor="hand2", command=lambda: toggle_preview()
root, text="Preview", cursor="hand2", command=lambda: self.toggle_preview()
)
preview_button.place(relx=0.65, rely=0.80, relwidth=0.2, relheight=0.05)
@ -333,7 +355,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
root,
text="Live",
cursor="hand2",
command=lambda: webcam_preview(
command=lambda: self.webcam_preview(
root,
available_camera_indices[
available_camera_strings.index(camera_variable.get())
@ -357,56 +379,58 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
"<Button>", lambda event: webbrowser.open("https://paypal.me/hacksider")
)
self.source_label = source_label
self.target_label = target_label
self.status_label = status_label
return root
def analyze_target(start: Callable[[], None], root: ctk.CTk):
if POPUP != None and POPUP.winfo_exists():
update_status("Please complete pop-up or close it.")
def analyze_target(self, start: Callable[[], None], root: ctk.CTk):
if self.popup != None and self.popup.winfo_exists():
self.update_status("Please complete pop-up or close it.")
return
if modules.globals.map_faces:
modules.globals.souce_target_map = []
if is_image(modules.globals.target_path):
update_status("Getting unique faces")
self.update_status("Getting unique faces")
get_unique_faces_from_target_image()
elif is_video(modules.globals.target_path):
update_status("Getting unique faces")
self.update_status("Getting unique faces")
get_unique_faces_from_target_video()
if len(modules.globals.souce_target_map) > 0:
create_source_target_popup(start, root, modules.globals.souce_target_map)
self.create_source_target_popup(start, root, modules.globals.souce_target_map)
else:
update_status("No faces found in target")
self.update_status("No faces found in target")
else:
select_output_path(start)
self.select_output_path(start)
def create_source_target_popup(
start: Callable[[], None], root: ctk.CTk, map: list
) -> None:
global POPUP, popup_status_label
POPUP = ctk.CTkToplevel(root)
POPUP.title("Source x Target Mapper")
POPUP.geometry(f"{POPUP_WIDTH}x{POPUP_HEIGHT}")
POPUP.focus()
def create_source_target_popup(
self, start: Callable[[], None], root: ctk.CTk, map: list
) -> None:
popup = ctk.CTkToplevel(root)
popup.title("Source x Target Mapper")
popup.geometry(f"{POPUP_WIDTH}x{POPUP_HEIGHT}")
popup.focus()
def on_submit_click(start):
if has_valid_map():
POPUP.destroy()
select_output_path(start)
popup.destroy()
self.select_output_path(start)
else:
update_pop_status("Atleast 1 source with target is required!")
self.update_pop_status("Atleast 1 source with target is required!")
scrollable_frame = ctk.CTkScrollableFrame(
POPUP, width=POPUP_SCROLL_WIDTH, height=POPUP_SCROLL_HEIGHT
popup, width=POPUP_SCROLL_WIDTH, height=POPUP_SCROLL_HEIGHT
)
scrollable_frame.grid(row=0, column=0, padx=0, pady=0, sticky="nsew")
def on_button_click(map, button_num):
map = update_popup_source(scrollable_frame, map, button_num)
map = self.update_popup_source(scrollable_frame, map, button_num)
for item in map:
id = item["id"]
@ -443,30 +467,31 @@ def create_source_target_popup(
target_image.grid(row=id, column=3, padx=10, pady=10)
target_image.configure(image=tk_image)
popup_status_label = ctk.CTkLabel(POPUP, text=None, justify="center")
popup_status_label = ctk.CTkLabel(popup, text=None, justify="center")
popup_status_label.grid(row=1, column=0, pady=15)
close_button = ctk.CTkButton(
POPUP, text="Submit", command=lambda: on_submit_click(start)
popup, text="Submit", command=lambda: on_submit_click(start)
)
close_button.grid(row=2, column=0, pady=10)
self.popup_status_label = popup_status_label
self.popup = popup
def update_popup_source(
scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int
) -> list:
global source_label_dict
def update_popup_source(
self, scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int
) -> list:
source_path = ctk.filedialog.askopenfilename(
title="select an source image",
initialdir=RECENT_DIRECTORY_SOURCE,
initialdir=self.recent_directory_source,
filetypes=[img_ft],
)
if "source" in map[button_num]:
map[button_num].pop("source")
source_label_dict[button_num].destroy()
del source_label_dict[button_num]
self.source_label_dict[button_num].destroy()
del self.source_label_dict[button_num]
if source_path == "":
return map
@ -498,78 +523,71 @@ def update_popup_source(
)
source_image.grid(row=button_num, column=1, padx=10, pady=10)
source_image.configure(image=tk_image)
source_label_dict[button_num] = source_image
self.source_label_dict[button_num] = source_image
else:
update_pop_status("Face could not be detected in last upload!")
self.update_pop_status("Face could not be detected in last upload!")
return map
def create_preview(parent: ctk.CTkToplevel) -> ctk.CTkToplevel:
global preview_label, preview_slider
def create_preview(self, parent: ctk.CTkToplevel) -> ctk.CTkToplevel:
preview = ctk.CTkToplevel(parent)
preview.withdraw()
preview.title("Preview")
preview.configure()
preview.protocol("WM_DELETE_WINDOW", lambda: toggle_preview())
preview.protocol("WM_DELETE_WINDOW", lambda: self.toggle_preview())
preview.resizable(width=True, height=True)
preview_label = ctk.CTkLabel(preview, text=None)
preview_label.pack(fill="both", expand=True)
self.preview_label = ctk.CTkLabel(preview, text=None)
self.preview_label.pack(fill="both", expand=True)
preview_slider = ctk.CTkSlider(
preview, from_=0, to=0, command=lambda frame_value: update_preview(frame_value)
self.preview_slider = ctk.CTkSlider(
preview, from_=0, to=0, command=lambda frame_value: self.update_preview(frame_value)
)
return preview
def update_status(text: str) -> None:
status_label.configure(text=text)
ROOT.update()
def update_status(self, text: str) -> None:
self.status_label.configure(text=text)
self.root.update()
def update_pop_status(text: str) -> None:
popup_status_label.configure(text=text)
def update_pop_status(self, text: str) -> None:
self.popup_status_label.configure(text=text)
def update_pop_live_status(text: str) -> None:
popup_status_label_live.configure(text=text)
def update_pop_live_status(self, text: str) -> None:
self.popup_status_label_live.configure(text=text)
def update_tumbler(var: str, value: bool) -> None:
def update_tumbler(self, var: str, value: bool) -> None:
modules.globals.fp_ui[var] = value
save_switch_states()
# If we're currently in a live preview, update the frame processors
if PREVIEW.state() == "normal":
global frame_processors
frame_processors = get_frame_processors_modules(
if self.preview.state() == "normal":
self.frame_processors = get_frame_processors_modules(
modules.globals.frame_processors
)
def select_source_path() -> None:
global RECENT_DIRECTORY_SOURCE, img_ft, vid_ft
PREVIEW.withdraw()
def select_source_path(self) -> None:
self.preview.withdraw()
source_path = ctk.filedialog.askopenfilename(
title="select an source image",
initialdir=RECENT_DIRECTORY_SOURCE,
initialdir=self.recent_directory_source,
filetypes=[img_ft],
)
if is_image(source_path):
modules.globals.source_path = source_path
RECENT_DIRECTORY_SOURCE = os.path.dirname(modules.globals.source_path)
image = render_image_preview(modules.globals.source_path, (200, 200))
source_label.configure(image=image)
self.recent_directory_source = os.path.dirname(modules.globals.source_path)
image = self.render_image_preview(modules.globals.source_path, (200, 200))
self.source_label.configure(image=image)
else:
modules.globals.source_path = None
source_label.configure(image=None)
self.source_label.configure(image=None)
def swap_faces_paths() -> None:
global RECENT_DIRECTORY_SOURCE, RECENT_DIRECTORY_TARGET
def swap_faces_paths(self) -> None:
source_path = modules.globals.source_path
target_path = modules.globals.target_path
@ -579,52 +597,48 @@ def swap_faces_paths() -> None:
modules.globals.source_path = target_path
modules.globals.target_path = source_path
RECENT_DIRECTORY_SOURCE = os.path.dirname(modules.globals.source_path)
RECENT_DIRECTORY_TARGET = os.path.dirname(modules.globals.target_path)
self.recent_directory_source = os.path.dirname(modules.globals.source_path)
self.recent_directory_target = os.path.dirname(modules.globals.target_path)
PREVIEW.withdraw()
self.preview.withdraw()
source_image = render_image_preview(modules.globals.source_path, (200, 200))
source_label.configure(image=source_image)
source_image = self.render_image_preview(modules.globals.source_path, (200, 200))
self.source_label.configure(image=source_image)
target_image = render_image_preview(modules.globals.target_path, (200, 200))
target_label.configure(image=target_image)
target_image = self.render_image_preview(modules.globals.target_path, (200, 200))
self.target_label.configure(image=target_image)
def select_target_path() -> None:
global RECENT_DIRECTORY_TARGET, img_ft, vid_ft
PREVIEW.withdraw()
def select_target_path(self) -> None:
self.preview.withdraw()
target_path = ctk.filedialog.askopenfilename(
title="select an target image or video",
initialdir=RECENT_DIRECTORY_TARGET,
initialdir=self.recent_directory_target,
filetypes=[img_ft, vid_ft],
)
if is_image(target_path):
modules.globals.target_path = target_path
RECENT_DIRECTORY_TARGET = os.path.dirname(modules.globals.target_path)
image = render_image_preview(modules.globals.target_path, (200, 200))
target_label.configure(image=image)
self.recent_directory_target = os.path.dirname(modules.globals.target_path)
image = self.render_image_preview(modules.globals.target_path, (200, 200))
self.target_label.configure(image=image)
elif is_video(target_path):
modules.globals.target_path = target_path
RECENT_DIRECTORY_TARGET = os.path.dirname(modules.globals.target_path)
video_frame = render_video_preview(target_path, (200, 200))
target_label.configure(image=video_frame)
self.recent_directory_target = os.path.dirname(modules.globals.target_path)
video_frame = self.render_video_preview(target_path, (200, 200))
self.target_label.configure(image=video_frame)
else:
modules.globals.target_path = None
target_label.configure(image=None)
self.target_label.configure(image=None)
def select_output_path(start: Callable[[], None]) -> None:
global RECENT_DIRECTORY_OUTPUT, img_ft, vid_ft
def select_output_path(self, start: Callable[[], None]) -> None:
if is_image(modules.globals.target_path):
output_path = ctk.filedialog.asksaveasfilename(
title="save image output file",
filetypes=[img_ft],
defaultextension=".png",
initialfile="output.png",
initialdir=RECENT_DIRECTORY_OUTPUT,
initialdir=self.recent_directory_output,
)
elif is_video(modules.globals.target_path):
output_path = ctk.filedialog.asksaveasfilename(
@ -632,22 +646,19 @@ def select_output_path(start: Callable[[], None]) -> None:
filetypes=[vid_ft],
defaultextension=".mp4",
initialfile="output.mp4",
initialdir=RECENT_DIRECTORY_OUTPUT,
initialdir=self.recent_directory_output,
)
else:
output_path = None
if output_path:
modules.globals.output_path = output_path
RECENT_DIRECTORY_OUTPUT = os.path.dirname(modules.globals.output_path)
self.recent_directory_output = os.path.dirname(modules.globals.output_path)
start()
def check_and_ignore_nsfw(target, destroy: Callable = None) -> bool:
def check_and_ignore_nsfw(self, target: str | ndarray, destroy: Callable | None = None) -> bool:
"""Check if the target is NSFW.
TODO: Consider to make blur the target.
"""
from numpy import ndarray
from modules.predicter import predict_image, predict_video, predict_frame
if type(target) is str: # image/video file path
check_nsfw = predict_image if has_image_extension(target) else predict_video
@ -658,37 +669,21 @@ def check_and_ignore_nsfw(target, destroy: Callable = None) -> bool:
destroy(
to_quit=False
) # Do not need to destroy the window frame if the target is NSFW
update_status("Processing ignored!")
self.update_status("Processing ignored!")
return True
else:
return False
def fit_image_to_size(image, width: int, height: int):
if width is None and height is None:
return image
h, w, _ = image.shape
ratio_h = 0.0
ratio_w = 0.0
if width > height:
ratio_h = height / h
else:
ratio_w = width / w
ratio = max(ratio_w, ratio_h)
new_size = (int(ratio * w), int(ratio * h))
return cv2.resize(image, dsize=new_size)
def render_image_preview(image_path: str, size: Tuple[int, int]) -> ctk.CTkImage:
def render_image_preview(self, image_path: str, size: Tuple[int, int]) -> ctk.CTkImage:
image = Image.open(image_path)
if size:
image = ImageOps.fit(image, size, Image.LANCZOS)
return ctk.CTkImage(image, size=image.size)
def render_video_preview(
video_path: str, size: Tuple[int, int], frame_number: int = 0
) -> ctk.CTkImage:
def render_video_preview(
self, video_path: str, size: Tuple[int, int], frame_number: int = 0
) -> None:
capture = cv2.VideoCapture(video_path)
if frame_number:
capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
@ -702,29 +697,29 @@ def render_video_preview(
cv2.destroyAllWindows()
def toggle_preview() -> None:
if PREVIEW.state() == "normal":
PREVIEW.withdraw()
def toggle_preview(self) -> None:
if self.preview.state() == "normal":
self.preview.withdraw()
elif modules.globals.source_path and modules.globals.target_path:
init_preview()
update_preview()
self.init_preview()
self.update_preview()
def init_preview() -> None:
def init_preview(self) -> None:
if is_image(modules.globals.target_path):
preview_slider.pack_forget()
self.preview_slider.pack_forget()
if is_video(modules.globals.target_path):
video_frame_total = get_video_frame_total(modules.globals.target_path)
preview_slider.configure(to=video_frame_total)
preview_slider.pack(fill="x")
preview_slider.set(0)
self.preview_slider.configure(to=video_frame_total)
self.preview_slider.pack(fill="x")
self.preview_slider.set(0)
def update_preview(frame_number: int = 0) -> None:
def update_preview(self, frame_number: int = 0) -> None:
if modules.globals.source_path and modules.globals.target_path:
update_status("Processing...")
self.update_status("Processing...")
temp_frame = get_video_frame(modules.globals.target_path, frame_number)
if modules.globals.nsfw_filter and check_and_ignore_nsfw(temp_frame):
if modules.globals.nsfw_filter and self.check_and_ignore_nsfw(temp_frame):
return
for frame_processor in get_frame_processors_modules(
modules.globals.frame_processors
@ -737,49 +732,32 @@ def update_preview(frame_number: int = 0) -> None:
image, (PREVIEW_MAX_WIDTH, PREVIEW_MAX_HEIGHT), Image.LANCZOS
)
image = ctk.CTkImage(image, size=image.size)
preview_label.configure(image=image)
update_status("Processing succeed!")
PREVIEW.deiconify()
self.preview_label.configure(image=image)
self.update_status("Processing succeed!")
self.preview.deiconify()
def webcam_preview(root: ctk.CTk, camera_index: int):
def webcam_preview(self, root: ctk.CTk, camera_index: int):
if not modules.globals.map_faces:
if modules.globals.source_path is None:
# No image selected
return
create_webcam_preview(camera_index)
self.create_webcam_preview(camera_index)
else:
modules.globals.souce_target_map = []
create_source_target_popup_for_webcam(
self.create_source_target_popup_for_webcam(
root, modules.globals.souce_target_map, camera_index
)
def get_available_cameras():
"""Returns a list of available camera names and indices."""
camera_indices = []
camera_names = []
for camera in enumerate_cameras():
cap = cv2.VideoCapture(camera.index)
if cap.isOpened():
camera_indices.append(camera.index)
camera_names.append(camera.name)
cap.release()
return (camera_indices, camera_names)
def create_webcam_preview(camera_index: int):
global preview_label, PREVIEW
def create_webcam_preview(self, camera_index: int):
camera = cv2.VideoCapture(camera_index)
camera.set(cv2.CAP_PROP_FRAME_WIDTH, PREVIEW_DEFAULT_WIDTH)
camera.set(cv2.CAP_PROP_FRAME_HEIGHT, PREVIEW_DEFAULT_HEIGHT)
camera.set(cv2.CAP_PROP_FPS, 60)
preview_label.configure(width=PREVIEW_DEFAULT_WIDTH, height=PREVIEW_DEFAULT_HEIGHT)
self.preview_label.configure(width=PREVIEW_DEFAULT_WIDTH, height=PREVIEW_DEFAULT_HEIGHT)
PREVIEW.deiconify()
self.preview.deiconify()
frame_processors = get_frame_processors_modules(modules.globals.frame_processors)
@ -801,7 +779,7 @@ def create_webcam_preview(camera_index: int):
if modules.globals.live_resizable:
temp_frame = fit_image_to_size(
temp_frame, PREVIEW.winfo_width(), PREVIEW.winfo_height()
temp_frame, self.preview.winfo_width(), self.preview.winfo_height()
)
if not modules.globals.map_faces:
@ -849,64 +827,62 @@ def create_webcam_preview(camera_index: int):
image, (temp_frame.shape[1], temp_frame.shape[0]), Image.LANCZOS
)
image = ctk.CTkImage(image, size=image.size)
preview_label.configure(image=image)
ROOT.update()
self.preview_label.configure(image=image)
self.root.update()
if PREVIEW.state() == "withdrawn":
if self.preview.state() == "withdrawn":
break
camera.release()
PREVIEW.withdraw()
self.preview.withdraw()
def create_source_target_popup_for_webcam(
root: ctk.CTk, map: list, camera_index: int
) -> None:
global POPUP_LIVE, popup_status_label_live
POPUP_LIVE = ctk.CTkToplevel(root)
POPUP_LIVE.title("Source x Target Mapper")
POPUP_LIVE.geometry(f"{POPUP_LIVE_WIDTH}x{POPUP_LIVE_HEIGHT}")
POPUP_LIVE.focus()
def create_source_target_popup_for_webcam(
self, root: ctk.CTk, map: list, camera_index: int
) -> None:
self.popup_live = ctk.CTkToplevel(root)
self.popup_live.title("Source x Target Mapper")
self.popup_live.geometry(f"{POPUP_LIVE_WIDTH}x{POPUP_LIVE_HEIGHT}")
self.popup_live.focus()
def on_submit_click():
if has_valid_map():
POPUP_LIVE.destroy()
self.popup_live.destroy()
simplify_maps()
create_webcam_preview(camera_index)
self.create_webcam_preview(camera_index)
else:
update_pop_live_status("At least 1 source with target is required!")
self.update_pop_live_status("At least 1 source with target is required!")
def on_add_click():
add_blank_map()
refresh_data(map)
update_pop_live_status("Please provide mapping!")
self.refresh_data(map)
self.update_pop_live_status("Please provide mapping!")
popup_status_label_live = ctk.CTkLabel(POPUP_LIVE, text=None, justify="center")
popup_status_label_live = ctk.CTkLabel(self.popup_live, text=None, justify="center")
popup_status_label_live.grid(row=1, column=0, pady=15)
add_button = ctk.CTkButton(POPUP_LIVE, text="Add", command=lambda: on_add_click())
add_button = ctk.CTkButton(self.popup_live, text="Add", command=lambda: on_add_click())
add_button.place(relx=0.2, rely=0.92, relwidth=0.2, relheight=0.05)
close_button = ctk.CTkButton(
POPUP_LIVE, text="Submit", command=lambda: on_submit_click()
self.popup_live, text="Submit", command=lambda: on_submit_click()
)
close_button.place(relx=0.6, rely=0.92, relwidth=0.2, relheight=0.05)
self.popup_status_label_live = popup_status_label_live
def refresh_data(map: list):
global POPUP_LIVE
def refresh_data(self, map: list):
scrollable_frame = ctk.CTkScrollableFrame(
POPUP_LIVE, width=POPUP_LIVE_SCROLL_WIDTH, height=POPUP_LIVE_SCROLL_HEIGHT
self.popup_live, width=POPUP_LIVE_SCROLL_WIDTH, height=POPUP_LIVE_SCROLL_HEIGHT
)
scrollable_frame.grid(row=0, column=0, padx=0, pady=0, sticky="nsew")
def on_sbutton_click(map, button_num):
map = update_webcam_source(scrollable_frame, map, button_num)
map = self.update_webcam_source(scrollable_frame, map, button_num)
def on_tbutton_click(map, button_num):
map = update_webcam_target(scrollable_frame, map, button_num)
map = self.update_webcam_target(scrollable_frame, map, button_num)
for item in map:
id = item["id"]
@ -974,21 +950,19 @@ def refresh_data(map: list):
target_image.configure(image=tk_image)
def update_webcam_source(
scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int
) -> list:
global source_label_dict_live
def update_webcam_source(
self, scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int
) -> list:
source_path = ctk.filedialog.askopenfilename(
title="select an source image",
initialdir=RECENT_DIRECTORY_SOURCE,
initialdir=self.recent_directory_source,
filetypes=[img_ft],
)
if "source" in map[button_num]:
map[button_num].pop("source")
source_label_dict_live[button_num].destroy()
del source_label_dict_live[button_num]
self.source_label_dict_live[button_num].destroy()
del self.source_label_dict_live[button_num]
if source_path == "":
return map
@ -1020,27 +994,25 @@ def update_webcam_source(
)
source_image.grid(row=button_num, column=1, padx=10, pady=10)
source_image.configure(image=tk_image)
source_label_dict_live[button_num] = source_image
self.source_label_dict_live[button_num] = source_image
else:
update_pop_live_status("Face could not be detected in last upload!")
self.update_pop_live_status("Face could not be detected in last upload!")
return map
def update_webcam_target(
scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int
) -> list:
global target_label_dict_live
def update_webcam_target(
self, scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int
) -> list:
target_path = ctk.filedialog.askopenfilename(
title="select an target image",
initialdir=RECENT_DIRECTORY_SOURCE,
initialdir=self.recent_directory_source,
filetypes=[img_ft],
)
if "target" in map[button_num]:
map[button_num].pop("target")
target_label_dict_live[button_num].destroy()
del target_label_dict_live[button_num]
self.target_label_dict_live[button_num].destroy()
del self.target_label_dict_live[button_num]
if target_path == "":
return map
@ -1072,7 +1044,7 @@ def update_webcam_target(
)
target_image.grid(row=button_num, column=4, padx=20, pady=10)
target_image.configure(image=tk_image)
target_label_dict_live[button_num] = target_image
self.target_label_dict_live[button_num] = target_image
else:
update_pop_live_status("Face could not be detected in last upload!")
self.update_pop_live_status("Face could not be detected in last upload!")
return map

7
run.py
View File

@ -1,6 +1,9 @@
#!/usr/bin/env python3
from modules import core
from modules.core import DeepFakeApp
app = DeepFakeApp()
if __name__ == '__main__':
core.run()
app.run()