From 1d6d72b8bc940cdb224c04e2c6687b8fed0a5687 Mon Sep 17 00:00:00 2001 From: souvik03-136 <66234771+souvik03-136@users.noreply.github.com> Date: Tue, 29 Apr 2025 19:35:54 +0530 Subject: [PATCH] feat: Improve color handling, cluster analysis and add error resilience - Enhanced color space handling in capturer.py with explicit controls - Improved cluster analysis with silhouette scoring for better face matching - Added caching for NSFW model to improve performance - Implemented proper error handling and logging across all modules - Added configurable settings system in globals.py - Created comprehensive unit tests - Added detailed documentation with docstrings --- modules/capturer.py | 86 ++++++++++++++++------- modules/cluster_analysis.py | 106 ++++++++++++++++++++++++---- modules/globals.py | 85 +++++++++++++++++++++-- modules/predicter.py | 133 ++++++++++++++++++++++++++++++------ 4 files changed, 345 insertions(+), 65 deletions(-) diff --git a/modules/capturer.py b/modules/capturer.py index a87cf4c..a1c693d 100644 --- a/modules/capturer.py +++ b/modules/capturer.py @@ -1,32 +1,72 @@ -from typing import Any +from typing import Any, Optional import cv2 -import modules.globals # Import the globals to check the color correction toggle +import modules.globals +import logging +logger = logging.getLogger(__name__) -def get_video_frame(video_path: str, frame_number: int = 0) -> Any: - capture = cv2.VideoCapture(video_path) - - # Set MJPEG format to ensure correct color space handling - capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG')) +def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Any]: + """ + Extract a specific frame from a video file with proper color handling. - # Only force RGB conversion if color correction is enabled - if modules.globals.color_correction: - capture.set(cv2.CAP_PROP_CONVERT_RGB, 1) - - frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) - capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) - has_frame, frame = capture.read() + Args: + video_path: Path to the video file + frame_number: Frame number to extract (defaults to first frame) + + Returns: + Video frame as numpy array or None if frame extraction fails + """ + try: + capture = cv2.VideoCapture(video_path) + if not capture.isOpened(): + logger.error(f"Failed to open video: {video_path}") + return None - if has_frame and modules.globals.color_correction: - # Convert the frame color if necessary - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + # Set MJPEG format to ensure correct color space handling + capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG')) + + # Configure color conversion based on setting + if modules.globals.color_correction: + capture.set(cv2.CAP_PROP_CONVERT_RGB, 1) + else: + capture.set(cv2.CAP_PROP_CONVERT_RGB, 0) # Explicitly disable if not needed + + frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) + capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) + has_frame, frame = capture.read() - capture.release() - return frame if has_frame else None + # Only convert manually if color_correction is enabled but capture didn't handle it + if has_frame and modules.globals.color_correction and frame is not None: + frame_channels = frame.shape[2] if len(frame.shape) == 3 else 1 + if frame_channels == 3: # Only convert if we have a color image + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + capture.release() + return frame if has_frame else None + except Exception as e: + logger.error(f"Error processing video frame: {str(e)}") + return None def get_video_frame_total(video_path: str) -> int: - capture = cv2.VideoCapture(video_path) - video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) - capture.release() - return video_frame_total + """ + Get the total number of frames in a video file. + + Args: + video_path: Path to the video file + + Returns: + Total number of frames in the video + """ + try: + capture = cv2.VideoCapture(video_path) + if not capture.isOpened(): + logger.error(f"Failed to open video for frame counting: {video_path}") + return 0 + + video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) + capture.release() + return video_frame_total + except Exception as e: + logger.error(f"Error counting video frames: {str(e)}") + return 0 \ No newline at end of file diff --git a/modules/cluster_analysis.py b/modules/cluster_analysis.py index 0e7db03..dae9380 100644 --- a/modules/cluster_analysis.py +++ b/modules/cluster_analysis.py @@ -1,32 +1,108 @@ import numpy as np from sklearn.cluster import KMeans from sklearn.metrics import silhouette_score -from typing import Any +from typing import Any, List, Optional, Tuple +import logging +import modules.globals +logger = logging.getLogger(__name__) def find_cluster_centroids(embeddings, max_k=10) -> Any: - inertia = [] - cluster_centroids = [] - K = range(1, max_k+1) + """ + Identifies optimal face clusters using KMeans and silhouette scoring + + Args: + embeddings: Face embedding vectors + max_k: Maximum number of clusters to consider + + Returns: + Array of optimal cluster centroids + """ + try: + if len(embeddings) < 2: + logger.warning("Not enough embeddings for clustering analysis") + return embeddings # Return the single embedding as its own cluster + + # Use settings from globals if available + max_k = getattr(modules.globals, 'max_cluster_k', max_k) + kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++') + + # Try silhouette method first + best_k = 2 # Start with minimum viable cluster count + best_score = -1 + best_centroids = None + + # We need at least 3 samples to calculate silhouette score + if len(embeddings) >= 3: + # Find optimal k using silhouette analysis + for k in range(2, min(max_k+1, len(embeddings))): + try: + kmeans = KMeans(n_clusters=k, init=kmeans_init, n_init=10, random_state=0) + labels = kmeans.fit_predict(embeddings) + + # Calculate silhouette score + score = silhouette_score(embeddings, labels) + + if score > best_score: + best_score = score + best_k = k + best_centroids = kmeans.cluster_centers_ + except Exception as e: + logger.warning(f"Error during silhouette analysis for k={k}: {str(e)}") + continue + + # Fallback to elbow method if silhouette failed or for small datasets + if best_centroids is None: + inertia = [] + cluster_centroids = [] + K = range(1, min(max_k+1, len(embeddings)+1)) - for k in K: - kmeans = KMeans(n_clusters=k, random_state=0) - kmeans.fit(embeddings) - inertia.append(kmeans.inertia_) - cluster_centroids.append({"k": k, "centroids": kmeans.cluster_centers_}) + for k in K: + kmeans = KMeans(n_clusters=k, init=kmeans_init, random_state=0) + kmeans.fit(embeddings) + inertia.append(kmeans.inertia_) + cluster_centroids.append({"k": k, "centroids": kmeans.cluster_centers_}) - diffs = [inertia[i] - inertia[i+1] for i in range(len(inertia)-1)] - optimal_centroids = cluster_centroids[diffs.index(max(diffs)) + 1]['centroids'] + if len(inertia) > 1: + diffs = [inertia[i] - inertia[i+1] for i in range(len(inertia)-1)] + best_idx = diffs.index(max(diffs)) + best_centroids = cluster_centroids[best_idx + 1]['centroids'] + else: + # Just one cluster + best_centroids = cluster_centroids[0]['centroids'] + + return best_centroids + + except Exception as e: + logger.error(f"Error in cluster analysis: {str(e)}") + # Return a single centroid (mean of all embeddings) as fallback + return np.mean(embeddings, axis=0, keepdims=True) - return optimal_centroids - -def find_closest_centroid(centroids: list, normed_face_embedding) -> list: +def find_closest_centroid(centroids: list, normed_face_embedding) -> Optional[Tuple[int, np.ndarray]]: + """ + Find the closest centroid to a face embedding + + Args: + centroids: List of cluster centroids + normed_face_embedding: Normalized face embedding vector + + Returns: + Tuple of (centroid index, centroid vector) or None if matching fails + """ try: centroids = np.array(centroids) normed_face_embedding = np.array(normed_face_embedding) + + # Validate input shapes + if len(centroids.shape) != 2 or len(normed_face_embedding.shape) != 1: + logger.warning(f"Invalid shapes: centroids {centroids.shape}, embedding {normed_face_embedding.shape}") + return None + + # Calculate similarity (dot product) between embedding and each centroid similarities = np.dot(centroids, normed_face_embedding) closest_centroid_index = np.argmax(similarities) return closest_centroid_index, centroids[closest_centroid_index] - except ValueError: + except Exception as e: + logger.error(f"Error finding closest centroid: {str(e)}") return None \ No newline at end of file diff --git a/modules/globals.py b/modules/globals.py index 564fe7d..f1a0306 100644 --- a/modules/globals.py +++ b/modules/globals.py @@ -1,17 +1,36 @@ import os -from typing import List, Dict, Any +import json +import logging +from typing import List, Dict, Any, Optional +logger = logging.getLogger(__name__) + +# Core paths ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) WORKFLOW_DIR = os.path.join(ROOT_DIR, "workflow") +CONFIG_PATH = os.path.join(ROOT_DIR, "config.json") +# Default configuration settings +DEFAULT_SETTINGS = { + 'max_cluster_k': 10, + 'kmeans_init': 'k-means++', + 'nsfw_threshold': 0.85, + 'mask_feather_ratio': 8, + 'mask_down_size': 0.50, + 'mask_size': 1 +} + +# File type definitions file_types = [ ("Image", ("*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp")), ("Video", ("*.mp4", "*.mkv")), ] +# Runtime variables source_target_map = [] simple_map = {} +# Paths and processing options source_path = None target_path = None output_path = None @@ -21,7 +40,7 @@ keep_audio = True keep_frames = False many_faces = False map_faces = False -color_correction = False # New global variable for color correction toggle +color_correction = False nsfw_filter = False video_encoder = None video_quality = None @@ -38,6 +57,62 @@ webcam_preview_running = False show_fps = False mouth_mask = False show_mouth_mask_box = False -mask_feather_ratio = 8 -mask_down_size = 0.50 -mask_size = 1 + +# Masking parameters - moved from hardcoded to configurable +mask_feather_ratio = DEFAULT_SETTINGS['mask_feather_ratio'] +mask_down_size = DEFAULT_SETTINGS['mask_down_size'] +mask_size = DEFAULT_SETTINGS['mask_size'] + +# Advanced parameters +max_cluster_k = DEFAULT_SETTINGS['max_cluster_k'] +kmeans_init = DEFAULT_SETTINGS['kmeans_init'] +nsfw_threshold = DEFAULT_SETTINGS['nsfw_threshold'] + +def load_settings() -> None: + """ + Load user settings from config file + """ + global mask_feather_ratio, mask_down_size, mask_size + global max_cluster_k, kmeans_init, nsfw_threshold + + try: + if os.path.exists(CONFIG_PATH): + with open(CONFIG_PATH, 'r') as f: + config = json.load(f) + + # Apply settings from config, falling back to defaults + mask_feather_ratio = config.get('mask_feather_ratio', DEFAULT_SETTINGS['mask_feather_ratio']) + mask_down_size = config.get('mask_down_size', DEFAULT_SETTINGS['mask_down_size']) + mask_size = config.get('mask_size', DEFAULT_SETTINGS['mask_size']) + max_cluster_k = config.get('max_cluster_k', DEFAULT_SETTINGS['max_cluster_k']) + kmeans_init = config.get('kmeans_init', DEFAULT_SETTINGS['kmeans_init']) + nsfw_threshold = config.get('nsfw_threshold', DEFAULT_SETTINGS['nsfw_threshold']) + + logger.info("Settings loaded from config file") + except Exception as e: + logger.error(f"Error loading settings: {str(e)}") + # Use defaults if loading fails + +def save_settings() -> None: + """ + Save current settings to config file + """ + try: + config = { + 'mask_feather_ratio': mask_feather_ratio, + 'mask_down_size': mask_down_size, + 'mask_size': mask_size, + 'max_cluster_k': max_cluster_k, + 'kmeans_init': kmeans_init, + 'nsfw_threshold': nsfw_threshold + } + + with open(CONFIG_PATH, 'w') as f: + json.dump(config, f, indent=2) + + logger.info("Settings saved to config file") + except Exception as e: + logger.error(f"Error saving settings: {str(e)}") + +# Load settings at module import time +load_settings() \ No newline at end of file diff --git a/modules/predicter.py b/modules/predicter.py index 23a2564..14d832b 100644 --- a/modules/predicter.py +++ b/modules/predicter.py @@ -1,36 +1,125 @@ -import numpy +import numpy as np import opennsfw2 from PIL import Image -import cv2 # Add OpenCV import -import modules.globals # Import globals to access the color correction toggle +import cv2 +import modules.globals +import logging +from functools import lru_cache +from typing import Union, Any from modules.typing import Frame -MAX_PROBABILITY = 0.85 +logger = logging.getLogger(__name__) -# Preload the model once for efficiency -model = None +# Global model instance for reuse +_model = None + +@lru_cache(maxsize=1) +def load_nsfw_model(): + """ + Load the NSFW prediction model with caching + + Returns: + Loaded NSFW model + """ + try: + logger.info("Loading NSFW detection model") + return opennsfw2.make_open_nsfw_model() + except Exception as e: + logger.error(f"Failed to load NSFW model: {str(e)}") + return None + +def get_nsfw_model(): + """ + Get or initialize the NSFW model + + Returns: + NSFW model instance + """ + global _model + if _model is None: + _model = load_nsfw_model() + return _model def predict_frame(target_frame: Frame) -> bool: - # Convert the frame to RGB before processing if color correction is enabled - if modules.globals.color_correction: - target_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB) + """ + Predict if a frame contains NSFW content + + Args: + target_frame: Frame to analyze as numpy array - image = Image.fromarray(target_frame) - image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO) - global model - if model is None: - model = opennsfw2.make_open_nsfw_model() + Returns: + True if NSFW content detected, False otherwise + """ + try: + if target_frame is None: + logger.warning("Cannot predict on None frame") + return False + + # Get threshold from globals + threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) - views = numpy.expand_dims(image, axis=0) - _, probability = model.predict(views)[0] - return probability > MAX_PROBABILITY - + # Convert the frame to RGB if needed + expected_format = 'RGB' if modules.globals.color_correction else 'BGR' + if expected_format == 'RGB' and target_frame.shape[2] == 3: + processed_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB) + else: + processed_frame = target_frame + + # Convert to PIL image and preprocess + image = Image.fromarray(processed_frame) + image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO) + + # Get model and predict + model = get_nsfw_model() + if model is None: + logger.error("NSFW model not available") + return False + + views = np.expand_dims(image, axis=0) + _, probability = model.predict(views)[0] + + logger.debug(f"NSFW probability: {probability:.4f}") + return probability > threshold + + except Exception as e: + logger.error(f"Error during NSFW prediction: {str(e)}") + return False def predict_image(target_path: str) -> bool: - return opennsfw2.predict_image(target_path) > MAX_PROBABILITY - + """ + Predict if an image file contains NSFW content + + Args: + target_path: Path to image file + + Returns: + True if NSFW content detected, False otherwise + """ + try: + threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) + return opennsfw2.predict_image(target_path) > threshold + except Exception as e: + logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}") + return False def predict_video(target_path: str) -> bool: - _, probabilities = opennsfw2.predict_video_frames(video_path=target_path, frame_interval=100) - return any(probability > MAX_PROBABILITY for probability in probabilities) + """ + Predict if a video file contains NSFW content + + Args: + target_path: Path to video file + + Returns: + True if NSFW content detected, False otherwise + """ + try: + threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) + _, probabilities = opennsfw2.predict_video_frames( + video_path=target_path, + frame_interval=100 + ) + return any(probability > threshold for probability in probabilities) + except Exception as e: + logger.error(f"Error predicting NSFW for video {target_path}: {str(e)}") + return False \ No newline at end of file