From 53d473164bdad16ef89cc33ac508bc5655caee68 Mon Sep 17 00:00:00 2001
From: KRSHH <136873090+KRSHH@users.noreply.github.com>
Date: Wed, 9 Oct 2024 19:51:04 +0530
Subject: [PATCH] remember/save switch states

---
 modules/ui.py | 88 ++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 73 insertions(+), 15 deletions(-)

diff --git a/modules/ui.py b/modules/ui.py
index 428c99d..80a3a6d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -6,6 +6,8 @@ import cv2
 from cv2_enumerate_cameras import enumerate_cameras  # Add this import
 from PIL import Image, ImageOps
 import time
+import json
+
 import modules.globals
 import modules.metadata
 from modules.face_analyser import (
@@ -80,9 +82,49 @@ def init(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
     return ROOT
 
 
+def save_switch_states():
+    switch_states = {
+        "keep_fps": modules.globals.keep_fps,
+        "keep_audio": modules.globals.keep_audio,
+        "keep_frames": modules.globals.keep_frames,
+        "many_faces": modules.globals.many_faces,
+        "map_faces": modules.globals.map_faces,
+        "color_correction": modules.globals.color_correction,
+        "nsfw_filter": modules.globals.nsfw_filter,
+        "live_mirror": modules.globals.live_mirror,
+        "live_resizable": modules.globals.live_resizable,
+        "fp_ui": modules.globals.fp_ui,
+        "show_fps": modules.globals.show_fps,
+    }
+    with open("switch_states.json", "w") as f:
+        json.dump(switch_states, f)
+
+
+def load_switch_states():
+    try:
+        with open("switch_states.json", "r") as f:
+            switch_states = json.load(f)
+        modules.globals.keep_fps = switch_states.get("keep_fps", True)
+        modules.globals.keep_audio = switch_states.get("keep_audio", True)
+        modules.globals.keep_frames = switch_states.get("keep_frames", False)
+        modules.globals.many_faces = switch_states.get("many_faces", False)
+        modules.globals.map_faces = switch_states.get("map_faces", False)
+        modules.globals.color_correction = switch_states.get("color_correction", False)
+        modules.globals.nsfw_filter = switch_states.get("nsfw_filter", False)
+        modules.globals.live_mirror = switch_states.get("live_mirror", False)
+        modules.globals.live_resizable = switch_states.get("live_resizable", False)
+        modules.globals.fp_ui = switch_states.get("fp_ui", {"face_enhancer": False})
+        modules.globals.show_fps = switch_states.get("show_fps", False)
+    except FileNotFoundError:
+        # If the file doesn't exist, use default values
+        pass
+
+
 def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
     global source_label, target_label, status_label, show_fps_switch
 
+    load_switch_states()
+
     ctk.deactivate_automatic_dpi_awareness()
     ctk.set_appearance_mode("system")
     ctk.set_default_color_theme(resolve_relative_path("ui.json"))
@@ -125,8 +167,9 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
         text="Keep fps",
         variable=keep_fps_value,
         cursor="hand2",
-        command=lambda: setattr(
-            modules.globals, "keep_fps", not modules.globals.keep_fps
+        command=lambda: (
+            setattr(modules.globals, "keep_fps", keep_fps_value.get()),
+            save_switch_states(),
         ),
     )
     keep_fps_checkbox.place(relx=0.1, rely=0.6)
@@ -137,20 +180,23 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
         text="Keep frames",
         variable=keep_frames_value,
         cursor="hand2",
-        command=lambda: setattr(
-            modules.globals, "keep_frames", keep_frames_value.get()
+        command=lambda: (
+            setattr(modules.globals, "keep_frames", keep_frames_value.get()),
+            save_switch_states(),
         ),
     )
     keep_frames_switch.place(relx=0.1, rely=0.65)
 
-    # for FRAME PROCESSOR ENHANCER tumbler:
     enhancer_value = ctk.BooleanVar(value=modules.globals.fp_ui["face_enhancer"])
     enhancer_switch = ctk.CTkSwitch(
         root,
         text="Face Enhancer",
         variable=enhancer_value,
         cursor="hand2",
-        command=lambda: update_tumbler("face_enhancer", enhancer_value.get()),
+        command=lambda: (
+            update_tumbler("face_enhancer", enhancer_value.get()),
+            save_switch_states(),
+        ),
     )
     enhancer_switch.place(relx=0.1, rely=0.7)
 
@@ -160,7 +206,10 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
         text="Keep audio",
         variable=keep_audio_value,
         cursor="hand2",
-        command=lambda: setattr(modules.globals, "keep_audio", keep_audio_value.get()),
+        command=lambda: (
+            setattr(modules.globals, "keep_audio", keep_audio_value.get()),
+            save_switch_states(),
+        ),
     )
     keep_audio_switch.place(relx=0.6, rely=0.6)
 
@@ -170,19 +219,22 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
         text="Many faces",
         variable=many_faces_value,
         cursor="hand2",
-        command=lambda: setattr(modules.globals, "many_faces", many_faces_value.get()),
+        command=lambda: (
+            setattr(modules.globals, "many_faces", many_faces_value.get()),
+            save_switch_states(),
+        ),
     )
     many_faces_switch.place(relx=0.6, rely=0.65)
 
-    # Add color correction toggle button
     color_correction_value = ctk.BooleanVar(value=modules.globals.color_correction)
     color_correction_switch = ctk.CTkSwitch(
         root,
         text="Fix Blueish Cam\n(force cv2 to use RGB instead of BGR)",
         variable=color_correction_value,
         cursor="hand2",
-        command=lambda: setattr(
-            modules.globals, "color_correction", color_correction_value.get()
+        command=lambda: (
+            setattr(modules.globals, "color_correction", color_correction_value.get()),
+            save_switch_states(),
         ),
     )
     color_correction_switch.place(relx=0.6, rely=0.70)
@@ -197,18 +249,23 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
         text="Map faces",
         variable=map_faces,
         cursor="hand2",
-        command=lambda: setattr(modules.globals, "map_faces", map_faces.get()),
+        command=lambda: (
+            setattr(modules.globals, "map_faces", map_faces.get()),
+            save_switch_states(),
+        ),
     )
     map_faces_switch.place(relx=0.1, rely=0.75)
 
-    # Add Show FPS switch
-    show_fps_value = ctk.BooleanVar(value=False)
+    show_fps_value = ctk.BooleanVar(value=modules.globals.show_fps)
     show_fps_switch = ctk.CTkSwitch(
         root,
         text="Show FPS",
         variable=show_fps_value,
         cursor="hand2",
-        command=lambda: setattr(modules.globals, "show_fps", show_fps_value.get()),
+        command=lambda: (
+            setattr(modules.globals, "show_fps", show_fps_value.get()),
+            save_switch_states(),
+        ),
     )
     show_fps_switch.place(relx=0.6, rely=0.75)
 
@@ -456,6 +513,7 @@ def update_pop_live_status(text: str) -> None:
 
 def update_tumbler(var: str, value: bool) -> None:
     modules.globals.fp_ui[var] = value
+    save_switch_states()
 
 
 def select_source_path() -> None: