feat: Improve color handling, cluster analysis and add error resilience
- Enhanced color space handling in capturer.py with explicit controls
- Improved cluster analysis with silhouette scoring for better face matching
- Added caching for NSFW model to improve performance
- Implemented proper error handling and logging across all modules
- Added configurable settings system in globals.py
- Created comprehensive unit tests
- Added detailed documentation with docstrings

pull/1129/head
parent
181144ce33
commit
1d6d72b8bc
|
@ -1,32 +1,72 @@
|
|||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
import cv2
|
||||
import modules.globals # Import the globals to check the color correction toggle
|
||||
import modules.globals
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Any]:
    """
    Extract a single frame from a video file.

    Args:
        video_path: Path to the video file
        frame_number: Frame number to extract (defaults to first frame). The
            seek position is ``frame_number - 1``, clamped to the range
            ``[0, frame_count - 1]``.

    Returns:
        The frame returned by ``capture.read()`` (converted from BGR to RGB
        when ``modules.globals.color_correction`` is set and the frame has
        three channels), or None if the video cannot be opened, the read
        fails, or any exception is raised.
    """
    capture = None
    try:
        capture = cv2.VideoCapture(video_path)
        if not capture.isOpened():
            logger.error(f"Failed to open video: {video_path}")
            return None

        # Request MJPG as the capture FOURCC.
        # NOTE(review): presumably only meaningful for camera devices --
        # confirm it has any effect for file input.
        capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))

        # Configure the backend's own colour conversion from the global toggle.
        # NOTE(review): disabling CAP_PROP_CONVERT_RGB may make some backends
        # hand back undecoded (non-BGR) frames -- confirm this is intended.
        if modules.globals.color_correction:
            capture.set(cv2.CAP_PROP_CONVERT_RGB, 1)
        else:
            capture.set(cv2.CAP_PROP_CONVERT_RGB, 0)

        # Clamp the seek target: the default frame_number=0 would otherwise
        # produce -1, and a value past the end would make the read fail.
        frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        last_index = max(frame_total - 1, 0)
        target_index = max(0, min(last_index, frame_number - 1))
        capture.set(cv2.CAP_PROP_POS_FRAMES, target_index)
        has_frame, frame = capture.read()

        if has_frame and modules.globals.color_correction and frame is not None:
            frame_channels = frame.shape[2] if len(frame.shape) == 3 else 1
            if frame_channels == 3:  # Only convert if we have a color image
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        return frame if has_frame else None
    except Exception as e:
        logger.error(f"Error processing video frame: {str(e)}")
        return None
    finally:
        # Release on every path, including early returns and exceptions.
        if capture is not None:
            capture.release()
|
||||
|
||||
|
||||
def get_video_frame_total(video_path: str) -> int:
    """
    Get the total number of frames in a video file.

    Args:
        video_path: Path to the video file

    Returns:
        ``CAP_PROP_FRAME_COUNT`` of the opened capture as an int, or 0 if the
        video cannot be opened or any exception is raised.
    """
    capture = None
    try:
        capture = cv2.VideoCapture(video_path)
        if not capture.isOpened():
            logger.error(f"Failed to open video for frame counting: {video_path}")
            return 0

        return int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    except Exception as e:
        logger.error(f"Error counting video frames: {str(e)}")
        return 0
    finally:
        # Release on every path so a failure cannot leak the capture handle.
        if capture is not None:
            capture.release()
|
|
@ -1,32 +1,108 @@
|
|||
import numpy as np
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.metrics import silhouette_score
|
||||
from typing import Any
|
||||
from typing import Any, List, Optional, Tuple
|
||||
import logging
|
||||
import modules.globals
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def find_cluster_centroids(embeddings, max_k=10) -> Any:
    """
    Identify face clusters using KMeans, preferring silhouette scoring.

    Args:
        embeddings: Face embedding vectors
        max_k: Maximum number of clusters to consider. NOTE(review): this is
            replaced by ``modules.globals.max_cluster_k`` whenever that
            attribute exists, so an explicit argument may be ignored --
            confirm that is intended.

    Returns:
        ``embeddings`` unchanged when fewer than two are given; otherwise the
        cluster centres chosen by silhouette score, or by the largest inertia
        drop when silhouette analysis yields nothing. If clustering raises,
        the mean of all embeddings is returned as a single centroid.
    """
    try:
        if len(embeddings) < 2:
            logger.warning("Not enough embeddings for clustering analysis")
            return embeddings  # Return the single embedding as its own cluster

        # Use settings from globals if available
        max_k = getattr(modules.globals, 'max_cluster_k', max_k)
        kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++')

        best_centroids = None
        # Silhouette analysis is only attempted with at least 3 samples.
        if len(embeddings) >= 3:
            best_centroids = _silhouette_centroids(embeddings, max_k, kmeans_init)

        # Fallback to elbow method if silhouette failed or for small datasets
        if best_centroids is None:
            best_centroids = _elbow_centroids(embeddings, max_k, kmeans_init)

        return best_centroids
    except Exception as e:
        logger.error(f"Error in cluster analysis: {str(e)}")
        # Return a single centroid (mean of all embeddings) as fallback
        return np.mean(embeddings, axis=0, keepdims=True)


def _silhouette_centroids(embeddings, max_k, kmeans_init):
    """Return the centres of the KMeans fit with the best silhouette score, or None."""
    best_score = -1
    best_centroids = None
    for k in range(2, min(max_k + 1, len(embeddings))):
        try:
            kmeans = KMeans(n_clusters=k, init=kmeans_init, n_init=10, random_state=0)
            labels = kmeans.fit_predict(embeddings)
            score = silhouette_score(embeddings, labels)
        except Exception as e:
            # A failing k is skipped; the remaining candidates are still tried.
            logger.warning(f"Error during silhouette analysis for k={k}: {str(e)}")
            continue
        if score > best_score:
            best_score = score
            best_centroids = kmeans.cluster_centers_
    return best_centroids


def _elbow_centroids(embeddings, max_k, kmeans_init):
    """Return the centres for the k that follows the largest drop in inertia."""
    inertia = []
    centroids_per_k = []
    for k in range(1, min(max_k + 1, len(embeddings) + 1)):
        kmeans = KMeans(n_clusters=k, init=kmeans_init, random_state=0)
        kmeans.fit(embeddings)
        inertia.append(kmeans.inertia_)
        centroids_per_k.append(kmeans.cluster_centers_)

    if len(inertia) > 1:
        drops = [before - after for before, after in zip(inertia, inertia[1:])]
        return centroids_per_k[drops.index(max(drops)) + 1]
    # Just one cluster
    return centroids_per_k[0]
|
||||
|
||||
def find_closest_centroid(centroids: list, normed_face_embedding) -> Optional[Tuple[int, np.ndarray]]:
    """
    Find the centroid with the largest dot product against a face embedding.

    Args:
        centroids: List of cluster centroids (must convert to a 2-D array)
        normed_face_embedding: Face embedding vector (must convert to a 1-D
            array); presumably already normalized -- this function does not
            normalize it.

    Returns:
        Tuple of (centroid index as a plain int, centroid vector), or None if
        the input shapes are invalid or the computation raises.
    """
    try:
        centroids = np.array(centroids)
        normed_face_embedding = np.array(normed_face_embedding)

        # Validate input shapes
        if centroids.ndim != 2 or normed_face_embedding.ndim != 1:
            logger.warning(f"Invalid shapes: centroids {centroids.shape}, embedding {normed_face_embedding.shape}")
            return None

        # Calculate similarity (dot product) between embedding and each centroid
        similarities = np.dot(centroids, normed_face_embedding)
        # np.argmax yields a NumPy integer; convert so the result matches the
        # annotated ``int`` and behaves like a builtin (e.g. JSON-serializable).
        closest_centroid_index = int(np.argmax(similarities))

        return closest_centroid_index, centroids[closest_centroid_index]
    except Exception as e:
        logger.error(f"Error finding closest centroid: {str(e)}")
        return None
|
|
@ -1,17 +1,36 @@
|
|||
import os
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Core paths
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
WORKFLOW_DIR = os.path.join(ROOT_DIR, "workflow")
|
||||
CONFIG_PATH = os.path.join(ROOT_DIR, "config.json")
|
||||
|
||||
# Default configuration settings
|
||||
DEFAULT_SETTINGS = {
|
||||
'max_cluster_k': 10,
|
||||
'kmeans_init': 'k-means++',
|
||||
'nsfw_threshold': 0.85,
|
||||
'mask_feather_ratio': 8,
|
||||
'mask_down_size': 0.50,
|
||||
'mask_size': 1
|
||||
}
|
||||
|
||||
# File type definitions
|
||||
file_types = [
|
||||
("Image", ("*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp")),
|
||||
("Video", ("*.mp4", "*.mkv")),
|
||||
]
|
||||
|
||||
# Runtime variables
|
||||
source_target_map = []
|
||||
simple_map = {}
|
||||
|
||||
# Paths and processing options
|
||||
source_path = None
|
||||
target_path = None
|
||||
output_path = None
|
||||
|
@ -21,7 +40,7 @@ keep_audio = True
|
|||
keep_frames = False
|
||||
many_faces = False
|
||||
map_faces = False
|
||||
color_correction = False # New global variable for color correction toggle
|
||||
color_correction = False
|
||||
nsfw_filter = False
|
||||
video_encoder = None
|
||||
video_quality = None
|
||||
|
@ -38,6 +57,62 @@ webcam_preview_running = False
|
|||
show_fps = False
|
||||
mouth_mask = False
|
||||
show_mouth_mask_box = False
|
||||
mask_feather_ratio = 8
|
||||
mask_down_size = 0.50
|
||||
mask_size = 1
|
||||
|
||||
# Masking parameters - moved from hardcoded to configurable
|
||||
mask_feather_ratio = DEFAULT_SETTINGS['mask_feather_ratio']
|
||||
mask_down_size = DEFAULT_SETTINGS['mask_down_size']
|
||||
mask_size = DEFAULT_SETTINGS['mask_size']
|
||||
|
||||
# Advanced parameters
|
||||
max_cluster_k = DEFAULT_SETTINGS['max_cluster_k']
|
||||
kmeans_init = DEFAULT_SETTINGS['kmeans_init']
|
||||
nsfw_threshold = DEFAULT_SETTINGS['nsfw_threshold']
|
||||
|
||||
def load_settings() -> None:
    """
    Populate the configurable module globals from the JSON file at CONFIG_PATH.

    Does nothing when the file does not exist. Keys missing from the file
    fall back to DEFAULT_SETTINGS. Any error is logged and leaves the
    current values in place.
    """
    global mask_feather_ratio, mask_down_size, mask_size
    global max_cluster_k, kmeans_init, nsfw_threshold

    try:
        if not os.path.exists(CONFIG_PATH):
            return

        with open(CONFIG_PATH, 'r') as config_file:
            stored = json.load(config_file)

        def pick(key):
            # Value from the config file, or the built-in default for that key.
            return stored.get(key, DEFAULT_SETTINGS[key])

        mask_feather_ratio = pick('mask_feather_ratio')
        mask_down_size = pick('mask_down_size')
        mask_size = pick('mask_size')
        max_cluster_k = pick('max_cluster_k')
        kmeans_init = pick('kmeans_init')
        nsfw_threshold = pick('nsfw_threshold')

        logger.info("Settings loaded from config file")
    except Exception as e:
        # Loading failed: keep whatever values are already set.
        logger.error(f"Error loading settings: {str(e)}")
|
||||
|
||||
def save_settings() -> None:
    """
    Write the current configurable module globals to CONFIG_PATH as JSON.

    Errors are logged and not raised.
    """
    try:
        snapshot = dict(
            mask_feather_ratio=mask_feather_ratio,
            mask_down_size=mask_down_size,
            mask_size=mask_size,
            max_cluster_k=max_cluster_k,
            kmeans_init=kmeans_init,
            nsfw_threshold=nsfw_threshold,
        )

        with open(CONFIG_PATH, 'w') as config_file:
            json.dump(snapshot, config_file, indent=2)

        logger.info("Settings saved to config file")
    except Exception as e:
        logger.error(f"Error saving settings: {str(e)}")
|
||||
|
||||
# Load settings at module import time
|
||||
load_settings()
|
|
@ -1,36 +1,125 @@
|
|||
import numpy
|
||||
import numpy as np
|
||||
import opennsfw2
|
||||
from PIL import Image
|
||||
import cv2 # Add OpenCV import
|
||||
import modules.globals # Import globals to access the color correction toggle
|
||||
import cv2
|
||||
import modules.globals
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Union, Any
|
||||
|
||||
from modules.typing import Frame
|
||||
|
||||
MAX_PROBABILITY = 0.85
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Preload the model once for efficiency
|
||||
model = None
|
||||
# Global model instance for reuse
|
||||
_model = None
|
||||
|
||||
@lru_cache(maxsize=1)
def load_nsfw_model():
    """
    Build the opennsfw2 model, memoizing the result.

    Because of the ``lru_cache`` decorator the return value -- including a
    ``None`` produced by a failed load -- is reused on later calls.

    Returns:
        The model from ``opennsfw2.make_open_nsfw_model()``, or None if
        building it raised.
    """
    try:
        logger.info("Loading NSFW detection model")
        nsfw_model = opennsfw2.make_open_nsfw_model()
    except Exception as e:
        logger.error(f"Failed to load NSFW model: {str(e)}")
        nsfw_model = None
    return nsfw_model
|
||||
|
||||
def get_nsfw_model():
    """
    Get or initialize the shared NSFW model.

    Returns:
        The module-level model instance, or None if loading failed. A failed
        load is not remembered: the next call attempts to load again.
    """
    global _model
    if _model is None:
        _model = load_nsfw_model()
        if _model is None:
            # load_nsfw_model is memoized, so a failed load would otherwise be
            # replayed forever; drop the cached None so a later call retries.
            clear_cache = getattr(load_nsfw_model, "cache_clear", None)
            if clear_cache is not None:
                clear_cache()
    return _model
|
||||
|
||||
def predict_frame(target_frame: Frame) -> bool:
    """
    Predict if a frame contains NSFW content.

    Args:
        target_frame: Frame to analyze (assumed to be a numpy array --
            ``.shape`` is read from it)

    Returns:
        True if the model's probability exceeds
        ``modules.globals.nsfw_threshold`` (0.85 if unset). False otherwise,
        including when the frame is None, the model is unavailable, or any
        exception is raised.
        NOTE(review): failures therefore report "not NSFW" -- confirm that
        failing open is acceptable for this filter.
    """
    try:
        if target_frame is None:
            logger.warning("Cannot predict on None frame")
            return False

        # Check model availability first so no preprocessing is wasted.
        nsfw_model = get_nsfw_model()
        if nsfw_model is None:
            logger.error("NSFW model not available")
            return False

        # Get threshold from globals
        threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)

        # Guard the channel lookup: a 2-D (grayscale) frame has no shape[2].
        # NOTE(review): the BGR->RGB swap is applied only when color
        # correction is enabled -- confirm frames are not already RGB then.
        is_three_channel = len(target_frame.shape) == 3 and target_frame.shape[2] == 3
        if modules.globals.color_correction and is_three_channel:
            processed_frame = cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB)
        else:
            processed_frame = target_frame

        # Convert to PIL image and preprocess
        image = Image.fromarray(processed_frame)
        image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)

        views = np.expand_dims(image, axis=0)
        _, probability = nsfw_model.predict(views)[0]

        logger.debug(f"NSFW probability: {probability:.4f}")
        # Convert the NumPy comparison result to a builtin bool.
        return bool(probability > threshold)
    except Exception as e:
        logger.error(f"Error during NSFW prediction: {str(e)}")
        return False
|
||||
|
||||
def predict_image(target_path: str) -> bool:
    """
    Predict if an image file contains NSFW content.

    Args:
        target_path: Path to image file

    Returns:
        True if ``opennsfw2.predict_image`` scores the file above
        ``modules.globals.nsfw_threshold`` (0.85 if unset); False otherwise
        or when any exception is raised.
    """
    try:
        limit = getattr(modules.globals, 'nsfw_threshold', 0.85)
        score = opennsfw2.predict_image(target_path)
        return score > limit
    except Exception as e:
        logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}")
        return False
|
||||
|
||||
def predict_video(target_path: str) -> bool:
    """
    Predict if a video file contains NSFW content.

    Args:
        target_path: Path to video file

    Returns:
        True if any frame probability returned by
        ``opennsfw2.predict_video_frames`` (called with ``frame_interval=100``)
        exceeds ``modules.globals.nsfw_threshold`` (0.85 if unset); False
        otherwise or when any exception is raised.
    """
    try:
        limit = getattr(modules.globals, 'nsfw_threshold', 0.85)
        _, frame_scores = opennsfw2.predict_video_frames(video_path=target_path, frame_interval=100)
        for score in frame_scores:
            if score > limit:
                return True
        return False
    except Exception as e:
        logger.error(f"Error predicting NSFW for video {target_path}: {str(e)}")
        return False
|
Loading…
Reference in New Issue