"""
Image Quality Validation Service
Provides comprehensive validation for face registration images.
"""
import cv2
import numpy as np
from typing import Tuple, List, Dict, Optional


def detect_blur(image: np.ndarray, face_bbox: Optional[Tuple[int, int, int, int]] = None, threshold: float = 80.0) -> Tuple[bool, float]:
    """
    Detect blur using Laplacian variance method.
    
    Args:
        image: Input image (BGR format)
        face_bbox: Optional face bounding box (x, y, w, h). If None, uses entire image.
        threshold: Minimum variance threshold (default: 80, a lenient value tuned for face clarity)
    
    Returns:
        Tuple of (is_blurry: bool, blur_score: float)
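
    Example (illustrative sketch; "selfie.jpg" and the bbox are hypothetical):
        >>> img = cv2.imread("selfie.jpg")
        >>> is_blurry, score = detect_blur(img, face_bbox=(120, 80, 200, 200))
        >>> print(f"blurry={is_blurry}, variance={score:.1f}")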
    """
    # Convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    
    # Restrict analysis to the face region when a bounding box is provided
    if face_bbox is not None:
        x, y, w, h = face_bbox
        gray = gray[y:y+h, x:x+w]
    
    # Calculate Laplacian variance
    laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
    
    is_blurry = laplacian_var < threshold
    # Cast NumPy scalars to plain Python types to match the annotated return type
    return bool(is_blurry), float(laplacian_var)


def validate_lighting(image: np.ndarray, face_bbox: Tuple[int, int, int, int]) -> Tuple[bool, List[str], Dict]:
    """
    Comprehensive lighting validation for both face region and overall environment.
    
    Args:
        image: Input image (BGR format)
        face_bbox: Face bounding box (x, y, w, h)
    
    Returns:
        Tuple of (is_valid: bool, issues: List[str], metrics: Dict)
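
    Example (illustrative; assumes img is a BGR frame and a face was found
    by detect_face_in_image()):
        >>> bbox = detect_face_in_image(img)
        >>> if bbox is not None:
        ...     ok, issues, metrics = validate_lighting(img, bbox)
        ...     if not ok:
        ...         print("Lighting issues:", issues)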
    """
    # Convert entire image to grayscale for environment analysis
    gray_full = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    
    # Extract face region
    x, y, w, h = face_bbox
    face_region = image[y:y+h, x:x+w]
    gray_face = cv2.cvtColor(face_region, cv2.COLOR_BGR2GRAY)
    
    # Calculate environment (overall image) metrics
    env_brightness = float(np.mean(gray_full))
    env_glare_pixels = int(np.sum(gray_full > 252))  # near-saturated pixels
    env_total_pixels = gray_full.size
    env_glare_percentage = float(env_glare_pixels / env_total_pixels)
    
    # Calculate face region metrics
    face_mean_brightness = float(np.mean(gray_face))
    face_brightness_std = float(np.std(gray_face))
    face_min_brightness = int(np.min(gray_face))
    face_max_brightness = int(np.max(gray_face))
    face_contrast = int(face_max_brightness - face_min_brightness)
    face_clipped_dark = float(np.sum(gray_face == 0) / gray_face.size)
    face_clipped_bright = float(np.sum(gray_face == 255) / gray_face.size)
    
    # Calculate background brightness (everything except face)
    mask = np.ones(gray_full.shape, dtype=bool)
    mask[y:y+h, x:x+w] = False
    background_region = gray_full[mask]
    background_brightness = float(np.mean(background_region)) if len(background_region) > 0 else env_brightness
    
    # Check for backlighting (background much brighter than face). The high
    # threshold flags only severe backlighting that would actually degrade
    # face recognition.
    backlight_threshold = 70
    is_backlit = background_brightness > face_mean_brightness + backlight_threshold
    
    # Share of near-saturated face pixels (torch/flash indicator)
    face_very_bright_pixels = int(np.sum(gray_face > 240))
    face_very_bright_percentage = float(face_very_bright_pixels / gray_face.size)
    
    # Share of moderately bright face pixels; reported in metrics only, since
    # rejecting on it proved too sensitive
    face_moderate_bright_pixels = int(np.sum(gray_face > 200))
    face_moderate_bright_percentage = float(face_moderate_bright_pixels / gray_face.size)
    
    # Store all metrics
    metrics = {
        'face_mean_brightness': face_mean_brightness,
        'face_brightness_std': face_brightness_std,
        'face_min_brightness': face_min_brightness,
        'face_max_brightness': face_max_brightness,
        'face_contrast': face_contrast,
        'face_clipped_dark': face_clipped_dark,
        'face_clipped_bright': face_clipped_bright,
        'face_very_bright_percentage': face_very_bright_percentage,
        'face_moderate_bright_percentage': face_moderate_bright_percentage,
        'env_brightness': env_brightness,
        'env_glare_percentage': env_glare_percentage,
        'background_brightness': background_brightness,
        'is_backlit': is_backlit,
        'brightness_diff': face_mean_brightness - env_brightness  # How much brighter face is than environment
    }
    
    issues = []
    
    # ===== ENVIRONMENT CHECKS (Overall Image) =====
    # Overall brightness must fall within a deliberately wide 60-200 band;
    # only extreme darkness or brightness is rejected.
    if env_brightness < 60:
        issues.append("environment_too_dark")
    elif env_brightness > 200:
        issues.append("environment_too_bright")
    
    # Reject only excessive glare (more than 3% near-saturated pixels)
    if env_glare_percentage > 0.03:
        issues.append("excessive_glare")
    
    # ===== FACE REGION CHECKS =====
    # Face brightness: accept a wide 60-200 band and reject only faces too
    # dark to show detail or so bright that a torch/flash is likely
    if face_mean_brightness < 60:
        issues.append("face_too_dark")
    elif face_mean_brightness > 200:
        issues.append("face_too_bright")
    
    # Torch/flash detection: reject if any face pixels are fully saturated
    # (an 8-bit pixel saturates at exactly 255, so test with >=, not >)
    if face_max_brightness >= 255:
        issues.append("face_too_bright")
    
    # Torch/flash indicator: reject if a large share (over 30%) of the face
    # is near-saturated
    if face_very_bright_percentage > 0.30:
        issues.append("face_too_bright")
    
    # Backlighting: reject only when severe backlighting coincides with low
    # face contrast, since that combination actually hurts recognition
    if is_backlit and face_contrast < 40:
        issues.append("backlit")
    
    # Contrast: reject only if too low to distinguish facial details
    if face_contrast < 25:
        issues.append("low_contrast")
    
    # Clipping: reject only if under/overexposure affects a significant
    # portion of the face
    if face_clipped_dark > 0.20:
        issues.append("underexposed")
    if face_clipped_bright > 0.15:
        issues.append("overexposed")
    
    # Torch/flash heuristic: a face far brighter (70+ points) than its
    # environment is usually lit by a torch or flash
    if face_mean_brightness > env_brightness + 70:
        issues.append("face_too_bright")
    
    # Brightness variance: an almost perfectly uniform face (std < 10)
    # usually indicates a flat, low-quality capture
    if face_brightness_std < 10:
        issues.append("uniform_lighting")
    
    return len(issues) == 0, issues, metrics


def calculate_face_size_score(face_bbox: Tuple[int, int, int, int], min_size: int = 80, ideal_size: int = 200) -> float:
    """
    Calculate score based on face size.
    Larger faces are generally better for recognition.
    
    Args:
        face_bbox: Face bounding box (x, y, w, h)
        min_size: Minimum acceptable face size
        ideal_size: Ideal face size
    
    Returns:
        Score between 0.0 and 1.0
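
    Example (worked values): a 140x160 bbox gives face_size = min(140, 160)
    = 140, so the score is (140 - 80) / (200 - 80) = 0.5.
        >>> calculate_face_size_score((0, 0, 140, 160))
        0.5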
    """
    _, _, w, h = face_bbox
    face_size = min(w, h)
    
    if face_size < min_size:
        return 0.0
    elif face_size >= ideal_size:
        return 1.0
    else:
        # Linear interpolation between min and ideal
        return (face_size - min_size) / (ideal_size - min_size)


def calculate_sharpness_score(blur_score: float, threshold: float = 80.0) -> float:
    """
    Convert blur score to sharpness score (0.0 to 1.0).
    
    Args:
        blur_score: Laplacian variance score
        threshold: Minimum acceptable threshold
    
    Returns:
        Score between 0.0 and 1.0
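
    Example (worked values, default threshold of 80): 160 (= 2 * threshold)
    maps to 1.0, 120 maps to 0.5 + 0.5 * (120 - 80) / 80 = 0.75, and 40
    maps to 0.5 * 40 / 80 = 0.25.
        >>> calculate_sharpness_score(120.0)
        0.75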
    """
    if blur_score >= threshold * 2:
        return 1.0
    elif blur_score >= threshold:
        # Linear interpolation from 0.5 at threshold to 1.0 at 2*threshold
        return 0.5 + 0.5 * (blur_score - threshold) / threshold
    else:
        # Below threshold, fall linearly from 0.5 toward 0.0 so the mapping
        # stays continuous and monotonic across the threshold
        return max(0.0, 0.5 * blur_score / threshold)


def calculate_lighting_quality_score(metrics: Dict) -> float:
    """
    Calculate lighting quality score from lighting metrics.
    
    Args:
        metrics: Dictionary from validate_lighting()
    
    Returns:
        Score between 0.0 and 1.0
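
    Example (worked values): penalties multiply, so a backlit face (x0.5)
    with a face_contrast of 45 (x0.7) scores 1.0 * 0.5 * 0.7 = 0.35;
    metrics omitted from the dict fall back to neutral defaults.
        >>> calculate_lighting_quality_score({'is_backlit': True, 'face_contrast': 45})
        0.35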
    """
    score = 1.0
    
    # Penalize environments that are too dark or too bright
    env_brightness = metrics.get('env_brightness', 128)
    if env_brightness < 60 or env_brightness > 200:
        score *= 0.3
    elif env_brightness < 80 or env_brightness > 180:
        score *= 0.6
    
    # Penalize excessive glare (over 3% near-saturated pixels)
    env_glare = metrics.get('env_glare_percentage', 0.0)
    if env_glare > 0.03:
        score *= 0.4
    elif env_glare > 0.02:
        score *= 0.7
    
    # Penalize faces that are too dark or too bright; this scoring band is
    # deliberately narrower than the hard accept/reject band used in
    # validate_lighting(), so borderline images pass but score lower
    face_brightness = metrics.get('face_mean_brightness', 128)
    if face_brightness < 80 or face_brightness > 160:
        score *= 0.3
    elif face_brightness < 90 or face_brightness > 150:
        score *= 0.6
    
    # Penalize a near-saturated peak brightness (torch/flash indicator)
    face_max_brightness = metrics.get('face_max_brightness', 128)
    if face_max_brightness > 250:
        score *= 0.4
    
    # Penalize a high share of near-saturated face pixels
    face_very_bright_pct = metrics.get('face_very_bright_percentage', 0.0)
    if face_very_bright_pct > 0.20:
        score *= 0.4
    elif face_very_bright_pct > 0.15:
        score *= 0.7
    
    # Penalize a face much brighter than its environment (torch/flash)
    if face_brightness > env_brightness + 60:
        score *= 0.5
    
    # Penalize for backlighting
    if metrics.get('is_backlit', False):
        score *= 0.5
    
    # Penalize for low contrast
    contrast = metrics.get('face_contrast', 100)
    if contrast < 30:
        score *= 0.4
    elif contrast < 50:
        score *= 0.7
    
    # Penalize clipping (underexposed or overexposed regions)
    if metrics.get('face_clipped_dark', 0) > 0.15:
        score *= 0.6
    if metrics.get('face_clipped_bright', 0) > 0.10:
        score *= 0.5
    
    # Penalize for uniform lighting
    if metrics.get('face_brightness_std', 30) < 15:
        score *= 0.6
    
    return score


def calculate_quality_score(image: np.ndarray, face_bbox: Tuple[int, int, int, int], 
                           blur_score: float, lighting_metrics: Dict) -> Tuple[float, Dict]:
    """
    Calculate overall quality score for an image.
    
    Args:
        image: Input image (BGR format)
        face_bbox: Face bounding box (x, y, w, h)
        blur_score: Blur score from detect_blur()
        lighting_metrics: Metrics from validate_lighting()
    
    Returns:
        Tuple of (total_score: float, breakdown: Dict)
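
    Example (worked values): with sharpness 0.8, face_size 0.5, lighting 0.6,
    and face_angle fixed at 1.0, the weighted total is
    0.8 * 0.3 + 0.5 * 0.2 + 0.6 * 0.3 + 1.0 * 0.2 = 0.72.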
    """
    # Calculate individual scores
    sharpness = calculate_sharpness_score(blur_score)
    face_size = calculate_face_size_score(face_bbox)
    lighting = calculate_lighting_quality_score(lighting_metrics)
    
    # Weighted average
    weights = {
        'sharpness': 0.3,
        'face_size': 0.2,
        'lighting': 0.3,
        'face_angle': 0.2  # Placeholder for future face angle detection
    }
    
    # For now, face_angle is set to 1.0 (no penalty)
    face_angle = 1.0
    
    scores = {
        'sharpness': sharpness,
        'face_size': face_size,
        'lighting': lighting,
        'face_angle': face_angle
    }
    
    total_score = (
        scores['sharpness'] * weights['sharpness'] +
        scores['face_size'] * weights['face_size'] +
        scores['lighting'] * weights['lighting'] +
        scores['face_angle'] * weights['face_angle']
    )
    
    breakdown = {
        'total': total_score,
        'sharpness': sharpness,
        'face_size': face_size,
        'lighting': lighting,
        'face_angle': face_angle
    }
    
    return total_score, breakdown


def cosine_similarity(embedding1: np.ndarray, embedding2: np.ndarray) -> float:
    """
    Calculate cosine similarity between two embeddings.
    
    Args:
        embedding1: First embedding vector
        embedding2: Second embedding vector
    
    Returns:
        Cosine similarity score between -1.0 and 1.0 (typically 0.0 to 1.0 for face embeddings)
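
    Example (worked values): identical vectors score exactly 1.0.
        >>> cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 0.0]))
        1.0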
    """
    dot_product = np.dot(embedding1, embedding2)
    norm1 = np.linalg.norm(embedding1)
    norm2 = np.linalg.norm(embedding2)
    
    if norm1 == 0 or norm2 == 0:
        return 0.0
    
    # Cast to a plain Python float to match the annotated return type
    return float(dot_product / (norm1 * norm2))


def validate_embedding_consistency(embeddings: List[np.ndarray], threshold: float = 0.6) -> Tuple[bool, float, List[int]]:
    """
    Validate that all embeddings are from the same person by comparing to first embedding.
    
    Args:
        embeddings: List of embedding vectors
        threshold: Minimum cosine similarity threshold (default: 0.6)
    
    Returns:
        Tuple of (is_valid: bool, min_similarity: float, failed_indices: List[int])
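
    Example (illustrative; assumes embeddings is a list of vectors produced
    by the same recognition model):
        >>> ok, min_sim, failed = validate_embedding_consistency(embeddings)
        >>> if not ok:
        ...     print(f"Images at indices {failed} do not match image 0")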
    """
    if len(embeddings) < 2:
        return True, 1.0, []
    
    reference = embeddings[0]
    failed_indices = []
    min_similarity = 1.0
    
    for i, embedding in enumerate(embeddings[1:], start=1):
        similarity = cosine_similarity(reference, embedding)
        min_similarity = min(min_similarity, similarity)
        if similarity < threshold:
            failed_indices.append(i)
    
    return len(failed_indices) == 0, min_similarity, failed_indices


def detect_face_in_image(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """
    Detect face in image using OpenCV Haar Cascade.
    
    Args:
        image: Input image (BGR format)
    
    Returns:
        Face bounding box (x, y, w, h) or None if no face detected
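
    Example (illustrative; "frame.jpg" is a hypothetical path):
        >>> frame = cv2.imread("frame.jpg")
        >>> bbox = detect_face_in_image(frame)
        >>> if bbox is not None:
        ...     x, y, w, h = bbox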
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)
    
    if len(faces) == 0:
        return None
    
    # Return the largest detected face, converted to plain Python ints
    largest_face = max(faces, key=lambda f: f[2] * f[3])
    return tuple(int(v) for v in largest_face)


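# Minimal end-to-end sketch of the full validation pipeline (illustrative
# only; "sample.jpg" is a hypothetical path and all thresholds are the
# module defaults).
if __name__ == "__main__":
    img = cv2.imread("sample.jpg")
    if img is None:
        raise SystemExit("Could not read sample.jpg")

    bbox = detect_face_in_image(img)
    if bbox is None:
        raise SystemExit("No face detected")

    is_blurry, blur_score = detect_blur(img, bbox)
    lighting_ok, issues, metrics = validate_lighting(img, bbox)
    total, breakdown = calculate_quality_score(img, bbox, blur_score, metrics)

    print(f"blurry: {is_blurry} (Laplacian variance {blur_score:.1f})")
    print(f"lighting ok: {lighting_ok}, issues: {issues}")
    print(f"quality score: {total:.2f}, breakdown: {breakdown}")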