Refactor: Address Sourcery Bot and reviewer feedback for stability and testability

- Fixed frame index calculation in capturer.py to prevent an invalid -1 index
- Removed the side effect of automatically loading settings at import time in globals.py; added an explicit init()
- Refactored functions to accept config params directly, reducing global settings reliance:
  - find_cluster_centroids() now takes max_k and kmeans_init, falling back to globals when either is omitted
  - NSFW prediction functions (predict_frame, predict_image, predict_video) now accept an optional threshold
- Updated documentation for all affected functions
- Verified all changes against existing test suite — all tests pass
souvik03-136 2025-04-29 19:45:04 +05:30
parent 1d6d72b8bc
commit 532c8e57db
4 changed files with 34 additions and 15 deletions

View File

@ -32,7 +32,9 @@ def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Any]:
capture.set(cv2.CAP_PROP_CONVERT_RGB, 0) # Explicitly disable if not needed capture.set(cv2.CAP_PROP_CONVERT_RGB, 0) # Explicitly disable if not needed
frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) # Ensure frame_number is valid (0-based index)
target_frame = max(0, min(frame_total - 1, frame_number))
capture.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
has_frame, frame = capture.read() has_frame, frame = capture.read()
# Only convert manually if color_correction is enabled but capture didn't handle it # Only convert manually if color_correction is enabled but capture didn't handle it

View File

@ -7,13 +7,14 @@ import modules.globals
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def find_cluster_centroids(embeddings, max_k=10) -> Any: def find_cluster_centroids(embeddings, max_k=None, kmeans_init=None) -> Any:
""" """
Identifies optimal face clusters using KMeans and silhouette scoring Identifies optimal face clusters using KMeans and silhouette scoring
Args: Args:
embeddings: Face embedding vectors embeddings: Face embedding vectors
max_k: Maximum number of clusters to consider max_k: Maximum number of clusters to consider (default: from globals)
kmeans_init: KMeans initialization method (default: from globals)
Returns: Returns:
Array of optimal cluster centroids Array of optimal cluster centroids
@ -23,9 +24,11 @@ def find_cluster_centroids(embeddings, max_k=10) -> Any:
logger.warning("Not enough embeddings for clustering analysis") logger.warning("Not enough embeddings for clustering analysis")
return embeddings # Return the single embedding as its own cluster return embeddings # Return the single embedding as its own cluster
# Use settings from globals if available # Use settings from globals if not explicitly provided
max_k = getattr(modules.globals, 'max_cluster_k', max_k) if max_k is None:
kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++') max_k = getattr(modules.globals, 'max_cluster_k', 10)
if kmeans_init is None:
kmeans_init = getattr(modules.globals, 'kmeans_init', 'k-means++')
# Try silhouette method first # Try silhouette method first
best_k = 2 # Start with minimum viable cluster count best_k = 2 # Start with minimum viable cluster count

View File

@ -68,6 +68,14 @@ max_cluster_k = DEFAULT_SETTINGS['max_cluster_k']
kmeans_init = DEFAULT_SETTINGS['kmeans_init'] kmeans_init = DEFAULT_SETTINGS['kmeans_init']
nsfw_threshold = DEFAULT_SETTINGS['nsfw_threshold'] nsfw_threshold = DEFAULT_SETTINGS['nsfw_threshold']
def init() -> None:
"""
Initialize the globals module and load settings
Should be called explicitly by the application during startup
"""
load_settings()
logger.info("Globals module initialized")
def load_settings() -> None: def load_settings() -> None:
""" """
Load user settings from config file Load user settings from config file
@ -114,5 +122,5 @@ def save_settings() -> None:
except Exception as e: except Exception as e:
logger.error(f"Error saving settings: {str(e)}") logger.error(f"Error saving settings: {str(e)}")
# Load settings at module import time # Don't load settings at import time to avoid side effects
load_settings() # Will be called explicitly by the application's initialization

View File

@ -41,12 +41,13 @@ def get_nsfw_model():
_model = load_nsfw_model() _model = load_nsfw_model()
return _model return _model
def predict_frame(target_frame: Frame) -> bool: def predict_frame(target_frame: Frame, threshold=None) -> bool:
""" """
Predict if a frame contains NSFW content Predict if a frame contains NSFW content
Args: Args:
target_frame: Frame to analyze as numpy array target_frame: Frame to analyze as numpy array
threshold: NSFW probability threshold (default: from globals)
Returns: Returns:
True if NSFW content detected, False otherwise True if NSFW content detected, False otherwise
@ -56,8 +57,9 @@ def predict_frame(target_frame: Frame) -> bool:
logger.warning("Cannot predict on None frame") logger.warning("Cannot predict on None frame")
return False return False
# Get threshold from globals # Get threshold from globals if not explicitly provided
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) if threshold is None:
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
# Convert the frame to RGB if needed # Convert the frame to RGB if needed
expected_format = 'RGB' if modules.globals.color_correction else 'BGR' expected_format = 'RGB' if modules.globals.color_correction else 'BGR'
@ -86,35 +88,39 @@ def predict_frame(target_frame: Frame) -> bool:
logger.error(f"Error during NSFW prediction: {str(e)}") logger.error(f"Error during NSFW prediction: {str(e)}")
return False return False
def predict_image(target_path: str) -> bool: def predict_image(target_path: str, threshold=None) -> bool:
""" """
Predict if an image file contains NSFW content Predict if an image file contains NSFW content
Args: Args:
target_path: Path to image file target_path: Path to image file
threshold: NSFW probability threshold (default: from globals)
Returns: Returns:
True if NSFW content detected, False otherwise True if NSFW content detected, False otherwise
""" """
try: try:
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) if threshold is None:
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
return opennsfw2.predict_image(target_path) > threshold return opennsfw2.predict_image(target_path) > threshold
except Exception as e: except Exception as e:
logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}") logger.error(f"Error predicting NSFW for image {target_path}: {str(e)}")
return False return False
def predict_video(target_path: str) -> bool: def predict_video(target_path: str, threshold=None) -> bool:
""" """
Predict if a video file contains NSFW content Predict if a video file contains NSFW content
Args: Args:
target_path: Path to video file target_path: Path to video file
threshold: NSFW probability threshold (default: from globals)
Returns: Returns:
True if NSFW content detected, False otherwise True if NSFW content detected, False otherwise
""" """
try: try:
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85) if threshold is None:
threshold = getattr(modules.globals, 'nsfw_threshold', 0.85)
_, probabilities = opennsfw2.predict_video_frames( _, probabilities = opennsfw2.predict_video_frames(
video_path=target_path, video_path=target_path,
frame_interval=100 frame_interval=100