Refactor: Address Sourcery Bot and reviewer feedback for stability and testability
- Fixed frame index calculation in capturer.py to prevent invalid -1 index - Removed side-effect of automatic settings load in globals.py; added explicit init() - Refactored functions to accept config params directly, reducing global settings reliance: - find_cluster_centroids() now takes max_k and kmeans_init - NSFW prediction functions now accept threshold - Updated documentation for all affected functions - Verified all changes against existing test suite — all tests pass
parent
1d6d72b8bc
commit
532c8e57db
|
@ -32,7 +32,9 @@ def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Any]:
|
||||||
capture.set(cv2.CAP_PROP_CONVERT_RGB, 0) # Explicitly disable if not needed
|
capture.set(cv2.CAP_PROP_CONVERT_RGB, 0) # Explicitly disable if not needed
|
||||||
|
|
||||||
frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||||
capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1))
|
# Ensure frame_number is valid (0-based index)
|
||||||
|
target_frame = max(0, min(frame_total - 1, frame_number))
|
||||||
|
capture.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
|
||||||
has_frame, frame = capture.read()
|
has_frame, frame = capture.read()
|
||||||
|
|
||||||
# Only convert manually if color_correction is enabled but capture didn't handle it
|
# Only convert manually if color_correction is enabled but capture didn't handle it
|
||||||
|
|
|
@ -7,13 +7,14 @@ import modules.globals
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def find_cluster_centroids(embeddings, max_k=10) -> Any:
|
def find_cluster_centroids(embeddings, max_k=None, kmeans_init=None) -> Any:
|
||||||
"""
|
"""
|
||||||
Identifies optimal face clusters using KMeans and silhouette scoring
|
Identifies optimal face clusters using KMeans and silhouette scoring
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embeddings: Face embedding vectors
|
embeddings: Face embedding vectors
|
||||||
max_k: Maximum number of clusters to consider
|
max_k: Maximum number of clusters to consider (default: from globals)
|
||||||
|
kmeans_init: KMeans initialization method (default: from globals)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Array of optimal cluster centroids
|
Array of optimal cluster centroids
|
||||||
|
@ -23,8 +24,10 @@ def find_cluster_centroids(embeddings, max_k=10) -> Any:
|
||||||
logger.warning("Not enough embeddings for clustering analysis")
|
logger.warning("Not enough embeddings for clustering analysis")
|
||||||
return embeddings # Return the single embedding as its own cluster
|
return embeddings # Return the single embedding as its own cluster
|
||||||
|
|
||||||
# Use settings from globals if available
|
# Use settings from globals if not explicitly provided
|
||||||
max_k = getattr(modules.globals, 'max_cluster_k', max_k)
|
if max_k is None:
|
||||||
|
max_k = getattr(modules.globals, 'max_cluster_k', 10)
|
||||||
|
if kmeans_init is None:
|
||||||
kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++')
|
kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++')
|
||||||
|
|
||||||
# Try silhouette method first
|
# Try silhouette method first
|
||||||
|
|
|
@ -68,6 +68,14 @@ max_cluster_k = DEFAULT_SETTINGS['max_cluster_k']
|
||||||
kmeans_init = DEFAULT_SETTINGS['kmeans_init']
|
kmeans_init = DEFAULT_SETTINGS['kmeans_init']
|
||||||
nsfw_threshold = DEFAULT_SETTINGS['nsfw_threshold']
|
nsfw_threshold = DEFAULT_SETTINGS['nsfw_threshold']
|
||||||
|
|
||||||
|
def init() -> None:
|
||||||
|
"""
|
||||||
|
Initialize the globals module and load settings
|
||||||
|
Should be called explicitly by the application during startup
|
||||||
|
"""
|
||||||
|
load_settings()
|
||||||
|
logger.info("Globals module initialized")
|
||||||
|
|
||||||
def load_settings() -> None:
|
def load_settings() -> None:
|
||||||
"""
|
"""
|
||||||
Load user settings from config file
|
Load user settings from config file
|
||||||
|
@ -114,5 +122,5 @@ def save_settings() -> None:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error saving settings: {str(e)}")
|
logger.error(f"Error saving settings: {str(e)}")
|
||||||
|
|
||||||
# Load settings at module import time
|
# Don't load settings at import time to avoid side effects
|
||||||
load_settings()
|
# Will be called explicitly by the application's initialization
|
|
@ -41,12 +41,13 @@ def get_nsfw_model():
|
||||||
_model = load_nsfw_model()
|
_model = load_nsfw_model()
|
||||||
return _model
|
return _model
|
||||||
|
|
||||||
def predict_frame(target_frame: Frame) -> bool:
|
def predict_frame(target_frame: Frame, threshold=None) -> bool:
|
||||||
"""
|
"""
|
||||||
Predict if a frame contains NSFW content
|
Predict if a frame contains NSFW content
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_frame: Frame to analyze as numpy array
|
target_frame: Frame to analyze as numpy array
|
||||||
|
threshold: NSFW probability threshold (default: from globals)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if NSFW content detected, False otherwise
|
True if NSFW content detected, False otherwise
|
||||||
|
@ -56,7 +57,8 @@ def predict_frame(target_frame: Frame) -> bool:
|
||||||
logger.warning("Cannot predict on None frame")
|
logger.warning("Cannot predict on None frame")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Get threshold from globals
|
# Get threshold from globals if not explicitly provided
|
||||||
|
if threshold is None:
|
||||||
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
|
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
|
||||||
|
|
||||||
# Convert the frame to RGB if needed
|
# Convert the frame to RGB if needed
|
||||||
|
@ -86,34 +88,38 @@ def predict_frame(target_frame: Frame) -> bool:
|
||||||
logger.error(f"Error during NSFW prediction: {str(e)}")
|
logger.error(f"Error during NSFW prediction: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def predict_image(target_path: str) -> bool:
|
def predict_image(target_path: str, threshold=None) -> bool:
|
||||||
"""
|
"""
|
||||||
Predict if an image file contains NSFW content
|
Predict if an image file contains NSFW content
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_path: Path to image file
|
target_path: Path to image file
|
||||||
|
threshold: NSFW probability threshold (default: from globals)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if NSFW content detected, False otherwise
|
True if NSFW content detected, False otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
if threshold is None:
|
||||||
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
|
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
|
||||||
return opennsfw2.predict_image(target_path) > threshold
|
return opennsfw2.predict_image(target_path) > threshold
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}")
|
logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def predict_video(target_path: str) -> bool:
|
def predict_video(target_path: str, threshold=None) -> bool:
|
||||||
"""
|
"""
|
||||||
Predict if a video file contains NSFW content
|
Predict if a video file contains NSFW content
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_path: Path to video file
|
target_path: Path to video file
|
||||||
|
threshold: NSFW probability threshold (default: from globals)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if NSFW content detected, False otherwise
|
True if NSFW content detected, False otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
if threshold is None:
|
||||||
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
|
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
|
||||||
_, probabilities = opennsfw2.predict_video_frames(
|
_, probabilities = opennsfw2.predict_video_frames(
|
||||||
video_path=target_path,
|
video_path=target_path,
|
||||||
|
|
Loading…
Reference in New Issue