diff --git a/modules/capturer.py b/modules/capturer.py index a1c693d..7db3159 100644 --- a/modules/capturer.py +++ b/modules/capturer.py @@ -32,7 +32,9 @@ def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Any]: capture.set(cv2.CAP_PROP_CONVERT_RGB, 0) # Explicitly disable if not needed frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) - capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) + # Ensure frame_number is valid (0-based index) + target_frame = max(0, min(frame_total - 1, frame_number)) + capture.set(cv2.CAP_PROP_POS_FRAMES, target_frame) has_frame, frame = capture.read() # Only convert manually if color_correction is enabled but capture didn't handle it diff --git a/modules/cluster_analysis.py b/modules/cluster_analysis.py index dae9380..ad79b47 100644 --- a/modules/cluster_analysis.py +++ b/modules/cluster_analysis.py @@ -7,13 +7,14 @@ import modules.globals logger = logging.getLogger(__name__) -def find_cluster_centroids(embeddings, max_k=10) -> Any: +def find_cluster_centroids(embeddings, max_k=None, kmeans_init=None) -> Any: """ Identifies optimal face clusters using KMeans and silhouette scoring Args: embeddings: Face embedding vectors - max_k: Maximum number of clusters to consider + max_k: Maximum number of clusters to consider (default: from globals) + kmeans_init: KMeans initialization method (default: from globals) Returns: Array of optimal cluster centroids @@ -23,9 +24,11 @@ def find_cluster_centroids(embeddings, max_k=10) -> Any: logger.warning("Not enough embeddings for clustering analysis") return embeddings # Return the single embedding as its own cluster - # Use settings from globals if available - max_k = getattr(modules.globals, 'max_cluster_k', max_k) - kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++') + # Use settings from globals if not explicitly provided + if max_k is None: + max_k = getattr(modules.globals, 'max_cluster_k', 10) + if kmeans_init is None: + kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++') # Try silhouette method first best_k = 2 # Start with minimum viable cluster count diff --git a/modules/globals.py b/modules/globals.py index f1a0306..652f6f8 100644 --- a/modules/globals.py +++ b/modules/globals.py @@ -68,6 +68,14 @@ max_cluster_k = DEFAULT_SETTINGS['max_cluster_k'] kmeans_init = DEFAULT_SETTINGS['kmeans_init'] nsfw_threshold = DEFAULT_SETTINGS['nsfw_threshold'] +def init() -> None: + """ + Initialize the globals module and load settings + Should be called explicitly by the application during startup + """ + load_settings() + logger.info("Globals module initialized") + def load_settings() -> None: """ Load user settings from config file @@ -114,5 +122,5 @@ def save_settings() -> None: except Exception as e: logger.error(f"Error saving settings: {str(e)}") -# Load settings at module import time -load_settings() \ No newline at end of file +# Don't load settings at import time to avoid side effects +# Will be called explicitly by the application's initialization \ No newline at end of file diff --git a/modules/predicter.py b/modules/predicter.py index 14d832b..fd33553 100644 --- a/modules/predicter.py +++ b/modules/predicter.py @@ -41,12 +41,13 @@ def get_nsfw_model(): _model = load_nsfw_model() return _model -def predict_frame(target_frame: Frame) -> bool: +def predict_frame(target_frame: Frame, threshold=None) -> bool: """ Predict if a frame contains NSFW content Args: target_frame: Frame to analyze as numpy array + threshold: NSFW probability threshold (default: from globals) Returns: True if NSFW content detected, False otherwise @@ -56,8 +57,9 @@ def predict_frame(target_frame: Frame) -> bool: logger.warning("Cannot predict on None frame") return False - # Get threshold from globals - threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) + # Get threshold from globals if not explicitly provided + if threshold is None: + threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) # Convert the frame to RGB if needed expected_format = 'RGB' if modules.globals.color_correction else 'BGR' @@ -86,35 +88,39 @@ def predict_frame(target_frame: Frame) -> bool: logger.error(f"Error during NSFW prediction: {str(e)}") return False -def predict_image(target_path: str) -> bool: +def predict_image(target_path: str, threshold=None) -> bool: """ Predict if an image file contains NSFW content Args: target_path: Path to image file + threshold: NSFW probability threshold (default: from globals) Returns: True if NSFW content detected, False otherwise """ try: - threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) + if threshold is None: + threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) return opennsfw2.predict_image(target_path) > threshold except Exception as e: logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}") return False -def predict_video(target_path: str) -> bool: +def predict_video(target_path: str, threshold=None) -> bool: """ Predict if a video file contains NSFW content Args: target_path: Path to video file + threshold: NSFW probability threshold (default: from globals) Returns: True if NSFW content detected, False otherwise """ try: - threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) + if threshold is None: + threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) _, probabilities = opennsfw2.predict_video_frames( video_path=target_path, frame_interval=100